In [20]:
# For cloud gpu runpod.io

# !pip install -q transformers==4.41 datasets peft accelerate deepspeed evaluate rouge-score

In [21]:
import json, re, random, torch, os
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
        AutoTokenizer, AutoModelForCausalLM,
        TrainingArguments, Trainer, DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, TaskType
import evaluate
from transformers import pipeline

# Reproducibility
random.seed(11)
torch.manual_seed(11)

ModuleNotFoundError: No module named 'torch'

### Build train / validation JSON

In [23]:
data_path = Path("../pubmedqa/data/ori_pqal.json")
with open(data_path, "r") as f:
    pubmedqa = json.load(f)

def make_record(item):
    prompt = (
        f"Question: {item['QUESTION']}\n"
        f"Context:  {' '.join(item['CONTEXTS'])}\n"
        f"Answer:"
    )
    completion = " " + item["final_decision"]   # leading space important
    return {"prompt": prompt, "completion": completion}

records = [make_record(v) for v in pubmedqa.values()]
random.shuffle(records)

split_idx = int(0.9 * len(records))
train_recs, val_recs = records[:split_idx], records[split_idx:]

for split, recs in (("train", train_recs), ("validation", val_recs)):
    with open(f"{split}.jsonl", "w") as f:
        for r in recs:
            f.write(json.dumps(r) + "\n")

print(f"Train: {len(train_recs)}  |  Validation: {len(val_recs)}")

Train: 900  |  Validation: 100


### Load & tokenise

In [None]:
base_model = "openai/gpt-oss-20b"
tok = AutoTokenizer.from_pretrained(base_model, use_fast=True)
tok.pad_token = tok.eos_token

raw_ds = load_dataset("json", data_files={"train":"train.jsonl",
                                          "validation":"validation.jsonl"})

block_size = 1024
def tokenize(example):
    full = example["prompt"] + example["completion"]
    out = tok(full, truncation=True, max_length=block_size)
    out["labels"] = out["input_ids"].copy()
    return out

tokenised = raw_ds.map(tokenize, remove_columns=raw_ds["train"].column_names, num_proc=4)

### LoRA

In [None]:
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["c_attn", "c_proj", "c_fc"],
)


model = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=torch.bfloat16,
            device_map="auto")
model = get_peft_model(model, lora_config)
model.enable_input_require_grads()
model.print_trainable_parameters()

### Training

In [None]:
training_args = TrainingArguments(
    output_dir="gpt-oss-20b-pubmedqa-lora",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,   # effective batch 16
    num_train_epochs=1,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    logging_steps=25,
    save_steps=250,
    eval_steps=250,
    evaluation_strategy="steps",
    bf16=True,
    dataloader_pin_memory=False,
    report_to="none",
    # deepspeed="ds_config_zero2.json"  # switch on for multi-GPU
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenised["train"],
    eval_dataset=tokenised["validation"],
    data_collator=DataCollatorForLanguageModeling(tok, mlm=False),
)
trainer.train()
trainer.save_model("gpt-oss-20b-pubmedqa-lora/final")
tok.save_pretrained("gpt-oss-20b-pubmedqa-lora/final")

### Merge adapter → full checkpoint

In [None]:
del model  # free memory
base = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, device_map="auto")
lora = PeftModel.from_pretrained(base, "gpt-oss-20b-pubmedqa-lora/final")
merged = lora.merge_and_unload()
merged.save_pretrained("gpt-oss-20b-pubmed-lora")
tok.save_pretrained("gpt-oss-20b-pubmed-lora")

### Evaluate

In [None]:
pipe = pipeline(
        "text-generation",
        model="gpt-oss-20b-pubmed-lora",
        tokenizer="gpt-oss-20b-pubmed-lora",
        torch_dtype="auto",
        device_map="auto")

test_path = Path("../pubmedqa/data/test_ground_truth.json")
test_data = json.load(open(test_path))

def extract_ans(text):
    m = re.findall(r"Answer:\s*(yes|no|maybe)", text.lower())
    return m[-1] if m else "maybe"

refs, preds = [], []
for rec in test_data.values():
    prompt = (
        f"Question: {rec['QUESTION']}\n"
        f"Context:  {' '.join(rec['CONTEXTS'])}\n"
        f"Answer:"
    )
    gen = pipe(prompt, max_new_tokens=5, do_sample=False)[0]["generated_text"]
    preds.append(extract_ans(gen))
    refs.append(rec["final_decision"])

accuracy_metric = evaluate.load("accuracy")
acc = accuracy_metric.compute(predictions=preds, references=refs)["accuracy"]
print("Official test accuracy:", round(acc, 4))