# Prefix Tuning on Subset of USMLE

## Setup

In [1]:
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    default_data_collator, 
    get_linear_schedule_with_warmup
)
from peft import (
    get_peft_model, 
    PrefixTuningConfig, 
    TaskType
)
from accelerate import dispatch_model, infer_auto_device_map
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import os
import sys

#  GPU Verification
assert torch.cuda.is_available(), "GPU not detected!"
print(f"CUDA version: {torch.version.cuda}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")
torch.cuda.empty_cache()

# --- 3. Model and Tokenizer Setup ---
model_name = "meta-llama/Llama-2-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # For causal LM padding

# --- 4. Hyperparameters ---
max_length = 128
lr = 1e-2
num_epochs = 3
batch_size = 2

CUDA version: 12.1
VRAM: 23.57GB


## Load Data

In [2]:
sys.path.append(os.path.abspath(os.path.join('..'))) 
from src.helper_functions import format_mcf_finetuning


# Load model
usml_raw = load_dataset("GBaker/MedQA-USMLE-4-options")
usml_train = usml_raw['train']
sample_train = usml_train.shuffle(seed=42).select(range(100))
print(sample_train)

Dataset({
    features: ['question', 'answer', 'options', 'meta_info', 'answer_idx', 'metamap_phrases'],
    num_rows: 100
})


In [3]:
formatted_train_subset = sample_train.map(
    format_mcf_finetuning,
    remove_columns=sample_train.column_names
)

print(formatted_train_subset[0]['prompt'])
print(formatted_train_subset[0]['completion'])

Question: A 35-year-old woman comes to your office with a variety of complaints. As part of her evaluation, she undergoes laboratory testing which reveals the presence of anti-centromere antibodies. All of the following symptoms and signs would be expected to be present EXCEPT:
A. Pallor, cyanosis, and erythema of the hands
B. Blanching vascular abnormalities
C. Hypercoagulable state
D. Heartburn and regurgitation
Answer:
Hypercoagulable state


## Tokenize Data

In [4]:
# 5. Tokenization Function
def tokenize_function(examples):
    texts = [p + c for p, c in zip(examples['prompt'], examples['completion'])]
    
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors="pt"
    )
    
    # Create labels (mask prompt tokens)
    prompt_lens = [len(tokenizer(p)['input_ids']) for p in examples['prompt']]
    labels = tokenized["input_ids"].clone()
    for i, plen in enumerate(prompt_lens):
        labels[i, :plen] = -100
    
    tokenized["labels"] = labels
    return tokenized

# 6. Apply tokenization
tokenized_dataset = formatted_train_subset.map(
    tokenize_function,
    batched=True,
    remove_columns=['prompt', 'completion'],
    batch_size=8  # Smaller batches for tokenization
)

## Data Loader

In [5]:
train_dataloader = DataLoader(
    tokenized_dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=default_data_collator
)

## Prefix Tuning

In [6]:
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,  # Correct for LLaMA
    inference_mode=False,
    num_virtual_tokens=20
)

# Load LLaMA-2
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,  # For speed and lower memory
    token=True
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 5,242,880 || all params: 6,743,658,496 || trainable%: 0.0777


## Optimizer & Scheduler

In [7]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=len(train_dataloader) * num_epochs,
)

## Training

In [11]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

# 1. Training arguments - configure as needed
training_args = TrainingArguments(
    output_dir="./llama7b-prefix",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    learning_rate=lr, 
    logging_steps=10,
    save_strategy="epoch",
    fp16=True,
    report_to="none",
)

# 2. Data collator to handle padding and labels
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # for causal LM (e.g. LLaMA, GPT)
)

# 3. Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,  # your tokenized train dataset
    data_collator=data_collator,
    optimizers=(optimizer, lr_scheduler)
)

# 4. Start training
trainer.train()

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
10,8.0366
20,5.7689
30,4.2056
40,3.4808
50,2.8479
60,2.4562
70,2.3215
80,2.2384
90,2.2085
100,2.1981


TrainOutput(global_step=150, training_loss=2.9849599329630534, metrics={'train_runtime': 112.705, 'train_samples_per_second': 2.662, 'train_steps_per_second': 1.331, 'total_flos': 6089327876505600.0, 'train_loss': 2.9849599329630534, 'epoch': 3.0})

## Save adapters - to not exclude full model

In [None]:
model.save_pretrained(
    "llama7b-prefix-subset",
    safe_serialization=True,  # Uses modern .safetensors format
    max_shard_size="200MB"  # Optional: splits large adapters
)

## Upload model to Huggingface

!pip install huggingface_hub transformers

In [None]:
from huggingface_hub import login
# Place Login Token
login(token="LOGIN_TOKEN")

In [None]:
# 1. Define your custom model name
MODEL_NAME = "llama7b-prefix-subset"  
USERNAME = "pippalap"  # Your Hugging Face username

model.push_to_hub("pippalap/llama7b-prefix-subset")

In [None]:
# CommitInfo(commit_url='https://huggingface.co/pippalap/llama7b-prefix-subset/commit/ee52723c2b44d04181b4f5be0bcaeb10fda69172', commit_message='Upload model', 
# commit_description='', oid='ee52723c2b44d04181b4f5be0bcaeb10fda69172', pr_url=None, repo_url=RepoUrl('https://huggingface.co/pippalap/llama7b-prefix-subset', 
# endpoint='https://huggingface.co', repo_type='model', repo_id='pippalap/llama7b-prefix-subset'), pr_revision=None, pr_num=None)