# Medical Question Pairs Classification with GPT-2 and PEFT

This notebook implements a medical question similarity classification model using GPT-2 and Parameter-Efficient Fine-Tuning (PEFT).

## Import Required Libraries

In [1]:
import os
import torch
import pandas as pd
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    set_seed,
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
)
import wandb

## Global Configurations

In [2]:
# Global configurations
SEED = 42
MODEL_NAME = "gpt2"
ID2LABEL = {0: "not_similar", 1: "similar"}
LABEL2ID = {"not_similar": 0, "similar": 1}

## Helper Functions

### Seed Setting and Weights & Biases Setup

In [3]:
def set_seeds():
    torch.manual_seed(SEED)
    if torch.mps.is_available():
        torch.mps.manual_seed(SEED)
    set_seed(SEED)

def setup_wandb():
    os.environ["WANDB_WATCH"] = "false"
    config = {
        "model_name": MODEL_NAME,
        "learning_rate": 2e-5,
        "epochs": 3,
        "batch_size": 8,
        "lora_r": 8,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
        "seed": SEED,
    }
    wandb.init(
        project="medical-qa-peft",
        name="gpt2-lora-experiment",
        config=config,
        settings=wandb.Settings(console="off"),
    )
    return config

### Dataset Loading and Preparation

In [4]:
def load_and_prepare_dataset():
    dataset = load_dataset("medical_questions_pairs")
    if "validation" not in dataset:
        train_testvalid = dataset["train"].train_test_split(test_size=0.2, seed=SEED)
        test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=SEED)
        dataset = {
            "train": train_testvalid["train"],
            "validation": test_valid["train"],
            "test": test_valid["test"],
        }
    return dataset

### Metrics Computation and Training Arguments

In [5]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="binary"
    )
    acc = accuracy_score(labels, preds)
    metrics = {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
    wandb.log(metrics)
    return metrics

def get_training_args(config):
    return TrainingArguments(
        output_dir="./checkpoints",
        learning_rate=config["learning_rate"],
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size"],
        num_train_epochs=config["epochs"],
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        report_to="wandb",
        logging_dir="./logs",
        logging_steps=10,
    )

### Data Preprocessing and Model Evaluation

In [6]:
def preprocess_data(dataset, tokenizer):
    def preprocess_function(examples):
        tokenized = tokenizer(
            examples["question_1"],
            examples["question_2"],
            truncation=True,
            max_length=128,
            padding="max_length",
        )
        tokenized["labels"] = examples["label"]
        return tokenized

    return {
        split: dataset[split].map(
            preprocess_function,
            batched=True,
            remove_columns=dataset[split].column_names,
        )
        for split in dataset.keys()
    }

def evaluate_model(model, trainer, model_name=""):
    """Evaluate model performance"""
    print(f"\nEvaluating {model_name}...")
    metrics = trainer.evaluate()
    print(f"{model_name} metrics:", metrics)
    return metrics

## Main Training and Evaluation Process

In [7]:
# Initialize settings
set_seeds()
config = setup_wandb()

# Load dataset
dataset = load_and_prepare_dataset()

# Set up tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Preprocess data
encoded_dataset = preprocess_data(dataset, tokenizer)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtim_lin[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Map:   0%|          | 0/2438 [00:00<?, ? examples/s]

Map:   0%|          | 0/305 [00:00<?, ? examples/s]

Map:   0%|          | 0/305 [00:00<?, ? examples/s]

### Base Model Setup and Evaluation

In [8]:
# Load base model
base_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    pad_token_id=tokenizer.eos_token_id,
    id2label=ID2LABEL,
    label2id=LABEL2ID,
)

# Get training arguments
training_args = get_training_args(wandb.config)

# Create trainer for base model evaluation
base_trainer = Trainer(
    model=base_model,
    args=training_args,
    eval_dataset=encoded_dataset["validation"],
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
)

# Evaluate base model
base_metrics = evaluate_model(base_model, base_trainer, "Base Model")

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Evaluating Base Model...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Base Model metrics: {'eval_loss': 7.587254047393799, 'eval_model_preparation_time': 0.0009, 'eval_accuracy': 0.45901639344262296, 'eval_f1': 0.0, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_runtime': 2.019, 'eval_samples_per_second': 151.066, 'eval_steps_per_second': 19.317}


### PEFT Model Setup and Training

In [9]:
# Configure LoRA
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=config["lora_r"],
    lora_alpha=config["lora_alpha"],
    lora_dropout=config["lora_dropout"],
    target_modules=["c_attn", "c_proj"],
    fan_in_fan_out=True,
)

# Create PEFT model
peft_model = get_peft_model(base_model, peft_config)
print("Trainable parameters:", peft_model.print_trainable_parameters())

# Create trainer for PEFT model
peft_trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
)

# Train PEFT model
print("Training PEFT model...")
peft_trainer.train()

No label_names provided for model class `PeftModelForSequenceClassification`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


trainable params: 812,544 || all params: 125,253,888 || trainable%: 0.6487
Trainable parameters: None
Training PEFT model...


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.7215,0.766343,0.593443,0.560284,0.675214,0.478788
2,0.7197,0.677142,0.629508,0.563707,0.776596,0.442424
3,0.6679,0.658817,0.636066,0.571429,0.787234,0.448485


TrainOutput(global_step=915, training_loss=1.289413628812696, metrics={'train_runtime': 135.3262, 'train_samples_per_second': 54.047, 'train_steps_per_second': 6.761, 'total_flos': 482345291612160.0, 'train_loss': 1.289413628812696, 'epoch': 3.0})

### Model Evaluation and Results Analysis

In [10]:
# Evaluate PEFT model
peft_metrics = evaluate_model(peft_model, peft_trainer, "PEFT Model")

# Compare and record performance differences
print("\nPerformance Comparison:")
print(f"Base Model Accuracy: {base_metrics['eval_accuracy']:.4f}")
print(f"PEFT Model Accuracy: {peft_metrics['eval_accuracy']:.4f}")
print(f"Improvement: {(peft_metrics['eval_accuracy'] - base_metrics['eval_accuracy'])*100:.2f}%")

# Log comparison results to wandb
comparison_data = [
    ["Base Model", base_metrics["eval_accuracy"]],
    ["PEFT Model", peft_metrics["eval_accuracy"]]
]
wandb.log({
    "model_comparison": wandb.plot.bar(
        wandb.Table(data=comparison_data, columns=["Model Type", "Accuracy"]),
        "Model Type",
        "Accuracy",
        title="Model Accuracy Comparison"
    )
})


Evaluating PEFT Model...


PEFT Model metrics: {'eval_loss': 0.6588172912597656, 'eval_accuracy': 0.6360655737704918, 'eval_f1': 0.5714285714285714, 'eval_precision': 0.7872340425531915, 'eval_recall': 0.4484848484848485, 'eval_runtime': 2.0476, 'eval_samples_per_second': 148.957, 'eval_steps_per_second': 19.047, 'epoch': 3.0}

Performance Comparison:
Base Model Accuracy: 0.4590
PEFT Model Accuracy: 0.6361
Improvement: 17.70%


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


### Save Model and Generate Test Results

In [11]:
# Save model
peft_model.save_pretrained("./peft_model")

# Test set prediction and results saving
test_results = peft_trainer.predict(encoded_dataset["test"])

test_df = pd.DataFrame({
    "question_1": [item["question_1"] for item in dataset["test"]],
    "question_2": [item["question_2"] for item in dataset["test"]],
    "predictions": test_results.predictions.argmax(axis=1),
    "true_labels": test_results.label_ids,
})

test_df["prediction_text"] = test_df["predictions"].map(ID2LABEL)
test_df["true_label_text"] = test_df["true_labels"].map(ID2LABEL)
test_df["is_correct"] = test_df["predictions"] == test_df["true_labels"]

os.makedirs("test_results", exist_ok=True)
test_df.to_csv("test_results/predictions.csv", index=False)
print("\nTest results saved to predictions.csv")

wandb.finish()


Test results saved to predictions.csv


0,1
accuracy,▁▆███▇
eval/accuracy,▁▆███
eval/f1,▁████
eval/loss,█▁▁▁▁
eval/model_preparation_time,▁
eval/precision,▁▇███
eval/recall,▁█▇██
eval/runtime,▇▆▅▁█
eval/samples_per_second,▂▃▄█▁
eval/steps_per_second,▂▃▄█▁

0,1
accuracy,0.61967
eval/accuracy,0.63607
eval/f1,0.57143
eval/loss,0.65882
eval/model_preparation_time,0.0009
eval/precision,0.78723
eval/recall,0.44848
eval/runtime,2.0476
eval/samples_per_second,148.957
eval/steps_per_second,19.047
