In [2]:
import torch
print(f"GPU name: {torch.cuda.get_device_name(0)}")
print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9} GB")

GPU available: True
GPU name: Tesla P100-PCIE-16GB
GPU memory: 17.059545088 GB


### Load and Prepare PubMedQA Dataset

In [7]:
"""The PubMedQA dataset contains biomedical research questions with yes/no/maybe answers,
collected from PubMed abstracts 310. The dataset includes questions, contexts (abstracts), and answer labels."""
from datasets import load_dataset

# Load PubMedQA dataset
dataset = load_dataset("qiaojin/PubMedQA", "pqa_labeled")

# Explore dataset structure
print(dataset)
print("\nFirst sample:")
print(dataset['train'][0])

DatasetDict({
    train: Dataset({
        features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
        num_rows: 1000
    })
})

First sample:
{'pubid': 21645374, 'question': 'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?', 'context': {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.', 'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A s

In [9]:
# Convert PubMedQA format to instruction format compatible with gpt-oss

"""This function formats the PubMedQA samples into the Harmony chat format required by gpt-oss models,
which uses a system message, user message, and assistant response structure"""

def format_pubmedqa_sample(sample):
    return {
        "messages": [
            {"role": "system", "content": "You are a helpful medical AI assistant. Answer the question based on the provided context."},
            {"role": "user", "content": f"Context: {sample['context']}\n\nQuestion: {sample['question']}\n\nAnswer (yes/no/maybe):"},
            {"role": "assistant", "content": sample['final_decision']}
        ]
    }

# Apply formatting
formatted_dataset = dataset.map(format_pubmedqa_sample, remove_columns=dataset['train'].column_names)

# Split dataset into train and validation
train_dataset = formatted_dataset['train'].train_test_split(test_size=0.1, seed=42)
train_data = train_dataset['train']
val_data = train_dataset['test']

print(train_data[0])

{'messages': [{'content': 'You are a helpful medical AI assistant. Answer the question based on the provided context.', 'role': 'system'}, {'content': "Context: {'contexts': ['In recent clinical trials (RCT) of bowel preparation, Golytely was more efficacious than MiraLAX. We hypothesised that there is a difference in adenoma detection between Golytely and MiraLAX.', 'To compare the adenoma detection rate (ADR) between these bowel preparations, and to identify independent predictors of bowel preparation quality and adenoma detection.', 'This was a post hoc analysis of an RCT that assessed efficacy and patient tolerability of Golytely vs. MiraLAX/Gatorade in average risk screening colonoscopy patients. Bowel preparation quality was measured with the Boston Bowel Preparation Scale (BBPS). An excellent/good equivalent BBPS score was defined as ≥ 7. Polyp pathology review was performed. ADR was defined as the proportion of colonoscopies with an adenoma. Univariate and multivariate analyses

### Load Model and Tokenizer

In [26]:
"""We use 4-bit quantization to reduce memory usage, allowing the 20B parameter model to fit on the A40 GPU.
The BitsAndBytesConfig configures the quantization parameters"""

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import torch

# Model configuration
model_name = "openai/gpt-oss-20b"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

ModuleNotFoundError: Could not import module 'AutoModelForCausalLM'. Are this object's requirements defined correctly?

### Configure LoRA for Parameter-Efficient Fine-Tuning

In [None]:
# LoRA configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Prepare model for PEFT
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

### Training 

In [None]:
from transformers import TrainingArguments

# Training arguments
training_args = TrainingArguments(
    output_dir="./gpt-oss-20b-pubmedqa",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    optim="paged_adamw_32bit",
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    report_to="wandb",
    fp16=False,
    bf16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    remove_unused_columns=False,
)

In [None]:
from trl import SFTTrainer
from transformers import DataCollatorForLanguageModeling

# Create trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    dataset_text_field="messages",
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    max_seq_length=2048,
    packing=False,
)

# Start training
trainer.train()

In [None]:
# Save the fine-tuned model
trainer.save_model()
tokenizer.save_pretrained("./gpt-oss-20b-pubmedqa")

# Evaluate the model
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

In [None]:
# Load the fine-tuned model for inference
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
fine_tuned_model = PeftModel.from_pretrained(base_model, "./gpt-oss-20b-pubmedqa")

# Create inference pipeline
pipe = pipeline(
    "text-generation",
    model=fine_tuned_model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Test with a biomedical question
test_question = {
    "messages": [
        {"role": "system", "content": "You are a helpful medical AI assistant. Answer the question based on the provided context."},
        {"role": "user", "content": "Context: This study examines the effects of aspirin on cardiovascular risk in diabetic patients.\n\nQuestion: Does aspirin reduce cardiovascular risk in diabetes?\n\nAnswer (yes/no/maybe):"}
    ]
}

formatted_prompt = tokenizer.apply_chat_template(test_question["messages"], tokenize=False)
outputs = pipe(formatted_prompt, max_new_tokens=50, temperature=0.1)
print(outputs[0]['generated_text'])