In [32]:
import torch

# Report which accelerator the kernel sees. The device queries raise when no
# CUDA device is present, so they are guarded to keep this cell runnable on a
# CPU-only runtime.
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    # total_memory is in bytes; dividing by 1e9 reports decimal gigabytes
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9} GB")
else:
    print("No CUDA GPU detected.")

GPU name: Tesla P100-PCIE-16GB
GPU memory: 17.059545088 GB


### Load and Prepare PubMedQA Dataset

In [33]:
"""The PubMedQA dataset contains biomedical research questions with yes/no/maybe answers,
collected from PubMed abstracts. The dataset includes questions, contexts (abstracts), and answer labels."""
from datasets import load_dataset

# Fetch the `pqa_labeled` configuration of PubMedQA from the Hub
dataset = load_dataset("qiaojin/PubMedQA", "pqa_labeled")

# Peek at the first training record to see which fields each example carries
first_record = dataset["train"][0]
print(first_record)

{'pubid': 21645374, 'question': 'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?', 'context': {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.', 'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells i

In [34]:
# Convert PubMedQA records into chat messages that the gpt-oss (Harmony) chat
# template can render: a developer message, a user message, and a single
# assistant turn that carries both its reasoning and its final answer.

def _flatten_pubmedqa_context(context):
    """Return the abstract text of a PubMedQA ``context`` field as one string.

    In the raw dataset ``context`` is a dict whose ``"contexts"`` entry is a
    list of abstract passages (alongside other metadata). Interpolating that
    dict directly into a prompt would embed its Python repr, so the passages
    are joined with newlines instead. Lists are joined the same way and any
    other value is converted with ``str``.
    """
    if isinstance(context, dict):
        passages = context.get("contexts") or []
        return "\n".join(str(passage) for passage in passages)
    if isinstance(context, (list, tuple)):
        return "\n".join(str(passage) for passage in context)
    return str(context)


def convert_to_harmony_format(example):
    """
    Convert one PubMedQA example into a chat-format training sample.

    Args:
        example: Mapping with the keys ``question``, ``context``,
            ``long_answer`` and ``final_decision`` (yes/no/maybe).

    Returns:
        ``{"messages": [...]}`` holding a developer message, a user message
        and one assistant message. The assistant message stores the reasoning
        under ``"thinking"`` and the answer under ``"content"``; the gpt-oss
        chat template is expected to render these as the analysis and final
        channels respectively (a ``"channel"`` key on the message is not read
        by the template).

    Raises:
        KeyError: If one of the required keys is missing from ``example``.
    """
    question = example["question"]
    context = _flatten_pubmedqa_context(example["context"])
    long_answer = example["long_answer"]
    answer = example["final_decision"]  # yes/no/maybe

    # Create reasoning analysis in biomedical style
    reasoning_text = (
        f"Analyzing biomedical research question: {question}\n"
        f"Context: {context}\n"
        f"Evidence from abstract: {long_answer}\n"
        f"Conclusion: The answer is {answer} because the evidence supports this conclusion."
    )

    # Reasoning and answer live on the same assistant turn so the template
    # can emit them as one response with two channels.
    messages = [
        {"role": "developer", "content": "You are a biomedical AI assistant. "
         "Analyze research questions step-by-step based on provided abstract context."},
        {"role": "user", "content": f"Question: {question}\nContext: {context}"},
        {
            "role": "assistant",
            "thinking": reasoning_text,
            "content": f"The answer is {answer}.",
        },
    ]
    return {"messages": messages}

# Run the converter over every record so each one gains a "messages" column
dataset = dataset.map(function=convert_to_harmony_format)

### Load Model and Tokenizer

In [35]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

# No add_special_tokens() call on purpose: the Harmony markers (<|start|>,
# <|end|>, <|return|>, <|channel|>, ...) ship with the checkpoint's tokenizer.
# Registering fused strings such as "<|channel|>analysis" would create brand-new
# tokens with untrained embeddings and change how the chat template output is
# split, and passing "additional_special_tokens" replaces the existing list by
# default.

# Keep the checkpoint's own padding token; fall back to EOS only if none is
# defined, so that the end-of-sequence token is not treated as padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/27.9M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

In [36]:
import torch
from transformers import AutoModelForCausalLM, Mxfp4Config

# Load gpt-oss-20b with its MXFP4 checkpoint dequantized at load time
# (dequantize=True) and bfloat16 requested as the weight dtype.
# NOTE(review): Mxfp4Config must be exported by the installed transformers
# release; the ImportError recorded under this cell comes from that import.
# NOTE(review): the first cell's output reports a Tesla P100-PCIE-16GB, not an
# A40 -- confirm the dequantized weights fit the available memory.
quantization_config = Mxfp4Config(dequantize=True)

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    quantization_config=quantization_config,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    use_cache=False,
    device_map="auto"
)

# Grow the embedding matrix only when the tokenizer has more entries than the
# model has embedding rows. An unconditional resize would shrink the matrix
# whenever the checkpoint's vocabulary is larger than len(tokenizer).
embedding_rows = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_rows:
    model.resize_token_embeddings(len(tokenizer))

ImportError: cannot import name 'Mxfp4Config' from 'transformers' (/usr/local/lib/python3.11/dist-packages/transformers/__init__.py)

### Configure LoRA for Parameter-Efficient Fine-Tuning

In [None]:
from peft import LoraConfig

# LoRA configuration targeting the attention projections and the expert
# (MLP) weights of layers 7, 15 and 23.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    # Module names are matched by suffix. The attention block is registered as
    # `self_attn`, so an "attn.q_proj" suffix would not match it; the bare
    # projection names match in every layer.
    target_modules=["q_proj", "k_proj", "v_proj"],
    # The expert projections are plain parameters rather than sub-modules, so
    # they are selected through `target_parameters` (needs a PEFT release that
    # provides this option).
    target_parameters=[
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ],
    # NOTE(review): dropout is set to 0 because PEFT's parameter-level LoRA
    # wrapper is understood to reject a non-zero lora_dropout -- confirm
    # against the installed PEFT version before raising it again.
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM"
)

### Training 

In [37]:
from transformers import TrainingArguments

from trl import SFTConfig

# Supervised fine-tuning configuration. SFTConfig extends TrainingArguments
# with SFT-specific options, which is where the sequence-length limit belongs
# (recent TRL releases name it `max_length`; older ones used `max_seq_length`).
# NOTE(review): batch size and precision flags were chosen with an A40 in mind,
# while the first cell's output reports a Tesla P100-PCIE-16GB -- revisit them
# for the actual hardware.
training_args = SFTConfig(
    output_dir="./gpt-oss-20b-pubmedqa",
    per_device_train_batch_size=2,  # Adjust based on VRAM
    gradient_accumulation_steps=8,  # effective batch of 16 sequences per device
    learning_rate=2e-5,
    num_train_epochs=3,
    logging_steps=10,
    save_steps=500,
    fp16=True,
    optim="paged_adamw_8bit",
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    report_to="none",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    max_length=2048,  # truncate tokenized conversations to bound memory use
)

In [38]:
from trl import SFTTrainer

# Build the trainer. Current TRL takes the tokenizer as `processing_class` and
# reads the truncation length from the config object passed as `args`
# (`max_length` on SFTConfig), not from a constructor keyword. A dataset with
# a "messages" column is rendered with the tokenizer's chat template by the
# trainer itself, so no formatting function is supplied.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
    peft_config=peft_config,
)

# Run fine-tuning, then write the LoRA adapter to the directory that the
# adapter-loading cell reads from.
trainer.train()
trainer.save_model("./gpt-oss-20b-pubmedqa-final")

RuntimeError: Failed to import trl.trainer.sft_trainer because of the following error (look up to see its traceback):
cannot import name 'add_model_info_to_auto_map' from 'transformers.utils' (/usr/local/lib/python3.11/dist-packages/transformers/utils/__init__.py)

In [39]:
from peft import PeftModel

# Instantiate the base checkpoint, reusing the quantization settings defined earlier
base_model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    device_map="auto",
    quantization_config=quantization_config,
)

# Wrap it with the LoRA adapter stored on disk
adapter_dir = "./gpt-oss-20b-pubmedqa-final"
model = PeftModel.from_pretrained(base_model, adapter_dir)

ImportError: cannot import name 'BloomPreTrainedModel' from 'transformers' (/usr/local/lib/python3.11/dist-packages/transformers/__init__.py)

In [None]:
def ask_biomedical_question(question, context):
    """Generate reasoning and answer for biomedical question.

    Uses the notebook-level ``tokenizer`` and ``model``.

    Args:
        question: The research question to ask.
        context: Abstract text the answer should be grounded in.

    Returns:
        The decoded token sequence returned by ``model.generate`` for the
        first (only) sample, with special tokens kept so the channel markers
        stay visible.
    """
    messages = [
        {"role": "developer", "content": "You are a biomedical AI assistant. "
         "Analyze research questions step-by-step based on provided abstract context."},
        {"role": "user", "content": f"Question: {question}\nContext: {context}"}
    ]

    # return_dict=True yields input_ids *and* attention_mask; passing the mask
    # to generate() avoids it having to guess which positions are padding.
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    outputs = model.generate(
        **encoded,
        max_new_tokens=512,
        temperature=0.1,  # near-greedy sampling
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=False)

# Try the helper on a PubMedQA-style question with a short context snippet
sample_question = "Do preoperative statins reduce atrial fibrillation after coronary artery bypass grafting?"
sample_context = "Statins reduce inflammation and postoperative atrial fibrillation..."

result = ask_biomedical_question(sample_question, sample_context)
print(result)

In [None]:
# Fold the LoRA weights into the base model
merged_model = model.merge_and_unload()

# Write the merged model and the tokenizer into the same deployment directory
merged_output_dir = "./gpt-oss-20b-pubmedqa-merged"
merged_model.save_pretrained(merged_output_dir)
tokenizer.save_pretrained(merged_output_dir)

In [None]:
# inference.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(model_path):
    """Load the tokenizer and causal-LM stored under ``model_path``.

    The model is requested in bfloat16 with ``device_map="auto"``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    load_kwargs = {"device_map": "auto", "torch_dtype": torch.bfloat16}
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    return model, tokenizer

def answer_question(model, tokenizer, question, context):
    """Generate reasoning and an answer for a biomedical question.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``decode`` and
            ``eos_token_id``.
        question: The research question to ask.
        context: Abstract text the answer should be grounded in.

    Returns:
        The decoded token sequence returned by ``model.generate`` for the
        first (only) sample, with special tokens kept.
    """
    messages = [
        {"role": "developer", "content": "You are a biomedical AI assistant. "
         "Analyze research questions step-by-step based on provided abstract context."},
        {"role": "user", "content": f"Question: {question}\nContext: {context}"},
    ]

    # return_dict=True yields input_ids and attention_mask, both of which are
    # forwarded to generate().
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    outputs = model.generate(
        **encoded,
        max_new_tokens=512,
        temperature=0.1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=False)

if __name__ == "__main__":
    # Load the merged checkpoint from disk and run one placeholder query
    merged_dir = "./gpt-oss-20b-pubmedqa-merged"
    model, tokenizer = load_model(merged_dir)
    result = answer_question(model, tokenizer, "Question here", "Context here")
    print(result)