# Prefix Tuning Full USMLE on Letter Answer Choice

## Setup

In [3]:
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    default_data_collator, 
    get_linear_schedule_with_warmup,
    BitsAndBytesConfig
)
from peft import (
    get_peft_model, 
    PrefixTuningConfig, 
    TaskType
)
from accelerate import dispatch_model, infer_auto_device_map
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import os
import sys

# --- GPU Verification ---
assert torch.cuda.is_available(), "GPU not detected!"
print(f"CUDA version: {torch.version.cuda}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
torch.cuda.empty_cache()

# --- Model and Tokenizer Setup ---
model_name = "meta-llama/Meta-Llama-3-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # For causal LM padding

# --- Hyperparameters ---
# Adjust hyperparameters for full dataset and larger model
max_length = 512  # Increased for longer sequences
lr = 5e-3  # Slightly lower learning rate for stability
num_epochs = 1  # Reduced epochs for full dataset
batch_size = 1  # Reduced batch size due to larger model
gradient_accumulation_steps = 8  # Accumulate gradients to simulate larger batch
eval_steps = 500  # Evaluate every 500 steps
save_steps = 1000  # Save checkpoints every 1000 steps
logging_steps = 50  # Log more frequently
max_grad_norm = 1.0  # Gradient clipping

CUDA version: 12.6
VRAM: 23.57GB


## Load Data

In [9]:
sys.path.append(os.path.abspath(os.path.join('..'))) 
from src.helper_functions import format_letter_finetuning

usml_raw = load_dataset("GBaker/MedQA-USMLE-4-options")
usml_train = usml_raw['train']
print(usml_train)

Dataset({
    features: ['question', 'answer', 'options', 'meta_info', 'answer_idx', 'metamap_phrases'],
    num_rows: 10178
})


## Pre-process the Dataset (Letter Strategy)

In [None]:
formatted_train = usml_train.map(
    format_letter_finetuning,
    remove_columns=usml_train.column_names
)

print(formatted_train[0]['prompt'])
print(formatted_train[0]['completion'])

## Create train/validation split for monitoring

In [11]:
train_size = int(0.95 * len(full_train))
val_size = len(full_train) - train_size
train_dataset = full_train.select(range(train_size))
val_dataset = full_train.select(range(train_size, train_size + val_size))

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Format Data
formatted_train = train_dataset.map(
    format_mcf_finetuning,
    remove_columns=train_dataset.column_names,
    num_proc=4
)

formatted_val = val_dataset.map(
    format_mcf_finetuning,
    remove_columns=val_dataset.column_names,
    num_proc=4
)

Training samples: 9669
Validation samples: 509


## Tokenize Data

In [13]:
# --- Tokenization Function ---
def tokenize_function(examples):
    texts = [p + c for p, c in zip(examples['prompt'], examples['completion'])]
    
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors="pt"
    )
    
    # Create labels (mask prompt tokens)
    prompt_lens = [len(tokenizer(p)['input_ids']) for p in examples['prompt']]
    labels = tokenized["input_ids"].clone()
    for i, plen in enumerate(prompt_lens):
        labels[i, :plen] = -100
    
    tokenized["labels"] = labels
    return tokenized

# --- Apply tokenization ---
print("Tokenizing training data...")
tokenized_train = formatted_train.map(
    tokenize_function,
    batched=True,
    remove_columns=['prompt', 'completion'],
    batch_size=8,
    num_proc=4
)

print("Tokenizing validation data...")
tokenized_val = formatted_val.map(
    tokenize_function,
    batched=True,
    remove_columns=['prompt', 'completion'],
    batch_size=8,
    num_proc=4
)

Tokenizing training data...


Map (num_proc=4):   0%|          | 0/9669 [00:00<?, ? examples/s]

Tokenizing validation data...


Map (num_proc=4):   0%|          | 0/509 [00:00<?, ? examples/s]

## Prefix Tuning Configuration

In [14]:
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Adjust prefix tuning config for Llama 3
peft_config = PrefixTuningConfig(
    peft_type="PREFIX_TUNING",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    num_virtual_tokens=30,
    prefix_projection=False,
)
    
# Load Llama 3 8b with optimized settings
print("Loading Llama 3 8b model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    device_map="auto",
    torch_dtype=torch.float16,
    token=True,
    trust_remote_code=True,
    use_cache=False

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Loading Llama 3 8b model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

trainable params: 1,966,080 || all params: 8,032,227,328 || trainable%: 0.0245


## Training Configuration

In [19]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

training_args = TrainingArguments(
    output_dir="./llama3-8b-usmle-prefix-letters-v1",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=lr,
    weight_decay=0.01,
    warmup_steps=100,
    logging_steps=logging_steps,
    eval_steps=eval_steps,
    save_steps=save_steps,
    save_strategy="steps",
    eval_strategy="no",
    load_best_model_at_end=False,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    fp16=True,
    dataloader_drop_last=True,
    remove_unused_columns=False,
    report_to="none",
    max_grad_norm=max_grad_norm,
    push_to_hub=False,
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
    label_names=["labels"],
)

# Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    data_collator=data_collator
)

trainer.train()

Step,Training Loss
50,8.8859
100,6.7613
150,5.8201
200,4.9743
250,3.8625
300,2.9114
350,2.4497
400,2.2152
450,2.0476
500,1.9558


TrainOutput(global_step=1209, training_loss=2.715461001305466, metrics={'train_runtime': 5778.5979, 'train_samples_per_second': 1.673, 'train_steps_per_second': 0.209, 'total_flos': 2.229200383597609e+17, 'train_loss': 2.715461001305466, 'epoch': 1.0})

## Save adapters - to not exclude full model

In [21]:
model.save_pretrained(
    "llama3-8b-usmle-prefix-letters",
    safe_serialization=True,  # Uses modern .safetensors format
    max_shard_size="200MB"  # Optional: splits large adapters
)

In [23]:
# 1. Define your custom model name
MODEL_NAME = "llama3-8b-usmle-prefix-letters"  
USERNAME = "pippalap"  # Your Hugging Face username

model.push_to_hub("pippalap/llama3-8b-usmle-prefix-letters")


Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B/resolve/main/config.json.
Your request to access model meta-llama/Meta-Llama-3-8B has been rejected by the repo's authors. - silently ignoring the lookup for the file config.json in meta-llama/Meta-Llama-3-8B.


Uploading...:   0%|          | 0.00/7.86M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/pippalap/llama3-8b-usmle-prefix-letters/commit/4f2b62b6fc452b297c3ecee4ee582edfdd78b5f1', commit_message='Upload model', commit_description='', oid='4f2b62b6fc452b297c3ecee4ee582edfdd78b5f1', pr_url=None, repo_url=RepoUrl('https://huggingface.co/pippalap/llama3-8b-usmle-prefix-letters', endpoint='https://huggingface.co', repo_type='model', repo_id='pippalap/llama3-8b-usmle-prefix-letters'), pr_revision=None, pr_num=None)