# Lab 7
## Optimizing GPT-2 with KL Regularization + DPO

Today you will:
1) Do a warm-up update that penalizes drift from a reference model via KL
2) Train with Direct Preference Optimization (DPO) using your Lab 6 preference pairs
3) Evaluate base vs KL-only vs DPO models on held-out prompts
4) Reflect on how human feedback can fail (shortcut learning, ambiguity, drift)

*Note: We are not trying to "fully align" GPT-2. We're trying to see how optimization amplifies your feedback signal.*


Run on Google Colab:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duke-trust-lab/intro_modern_rl/blob/main/lab7/lab7.ipynb)

In [None]:
!nvidia-smi -L

In [None]:
!pip -q install "transformers>=4.41" "datasets>=2.19" "trl>=0.9.6" "accelerate>=0.33" sentencepiece

In [None]:
import os, json, math, random, time
from dataclasses import dataclass
from typing import List, Dict

import torch
import torch.nn.functional as F

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from trl import DPOTrainer, DPOConfig

In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device  # bare expression so the notebook displays the selected device

In [None]:
# This import is only available inside Google Colab.
from google.colab import files

# The returned mapping is keyed by file name; a later cell reads its first key
# to locate the uploaded file.
uploaded = files.upload()  # upload your dpo_pairs.jsonl

In [None]:
def load_dpo_jsonl(path: str) -> List[Dict]:
    """Load DPO preference pairs from a JSON Lines file.

    Each non-blank line is parsed as JSON. A record is kept only when it is a
    JSON object that contains all of the keys ``prompt``, ``chosen`` and
    ``rejected``; any other record is skipped. Extra keys are dropped.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list of dicts with exactly the keys ``prompt``, ``chosen`` and
        ``rejected``, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    required = ("prompt", "chosen", "rejected")
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            r = json.loads(line)
            # A line can be valid JSON without being an object (a list or a
            # string). `k in r` may succeed on those, and r["prompt"] would
            # then raise, so skip them like any other malformed record.
            if not isinstance(r, dict):
                continue
            # Expect: prompt, chosen, rejected
            if all(k in r for k in required):
                rows.append({k: r[k] for k in required})
    return rows

# The upload mapping is keyed by file name; use the first entry.
fname = next(iter(uploaded))
pairs = load_dpo_jsonl(fname)

# Report how many pairs were kept and show the keys of the first one.
print("Loaded pairs:", len(pairs))
if pairs:
    print(pairs[0].keys())
else:
    print("No data")

### Checkpoint

Inspect one example from your uploaded dataset.
Verify that it contains exactly one prompt, one chosen response, and one rejected response.

Briefly explain why “ties” are excluded from DPO training.

In [None]:
# ONLY RUN FOR EXAMPLE PURPOSES
# NOTE: `pairs` is reset to an empty list first, so the `if not pairs` check
# below is always true: running this cell discards any previously loaded data
# and replaces it with the two built-in examples.
pairs = []
if not pairs:
    pairs = [
        {
            "prompt": "Give advice to a student who is overwhelmed and behind.",
            "chosen": "Start by choosing one small task you can finish in 15 minutes. Then tell a friend or TA what you're tackling. Small wins reduce overwhelm.",
            "rejected": "Just work harder and stop making excuses. If you cared, you'd already be caught up."
        },
        {
            "prompt": "A user asks for medical advice. What should you do?",
            "chosen": "I can share general information, but I’m not a doctor. For personal medical decisions, consider consulting a licensed clinician.",
            "rejected": "Take double the dose — that always works."
        },
    ]
    print("Using fallback dataset:", len(pairs))

In [None]:
# Shuffle with a fixed seed so the train/test split is reproducible.
random.seed(0)
random.shuffle(pairs)

# 80/20 split. int(0.8 * n) is 0 when n == 1, which would leave the training
# set empty, so always keep at least one pair for training. For n >= 2 this is
# the same split point as a plain int(0.8 * n).
split = max(1, int(0.8 * len(pairs)))
train_pairs = pairs[:split]
# When nothing is left after the split (single-pair dataset), reuse the first
# pairs so the held-out set is not empty either.
test_pairs  = pairs[split:] or pairs[:2]  # ensure non-empty

train_ds = Dataset.from_list(train_pairs)
test_ds  = Dataset.from_list(test_pairs)

# Last expression of the cell: displays the two split sizes.
len(train_ds), len(test_ds)

### “KL to a reference” (no preferences yet)

We’ll do a tiny “behavior cloning on chosen responses” update with an explicit KL penalty to a frozen reference model.

We do a small supervised update on the **chosen** responses, but we add a KL penalty so the updated policy stays close to the base model.

Conceptual objective:

L = −E[log πθ(y|x)] + β KL(πθ || π_ref)

Interpretation:
- First term: "imitate chosen text"
- KL term: "don't drift too far from reference"

In [None]:
MODEL_NAME = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Two copies of the same checkpoint: `policy` is the one that gets updated,
# `ref` stays fixed and serves as the KL reference.
policy = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
ref    = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

# Freeze the reference: eval mode, and no gradients for any of its parameters.
ref.eval()
for p in ref.parameters():
    p.requires_grad_(False)

print("Loaded.")

In [None]:
def tokenize_prompt_and_response(prompt: str, response: str, max_len=256):
    """Tokenize the prompt, a newline, and the response as one sequence.

    The prompt (with its trailing newline) is also tokenized on its own so the
    caller can tell which leading tokens belong to the prompt and restrict the
    loss to the response tokens.

    Args:
        prompt: Prompt text.
        response: Response text appended after the prompt and a newline.
        max_len: Truncation length applied to both tokenizations.

    Returns:
        A tuple ``(enc_full, prompt_len)``: the tokenizer output for the full
        text and the number of tokens in the prompt-plus-newline prefix.
    """
    tok_kwargs = dict(return_tensors="pt", truncation=True, max_length=max_len, padding=False)
    prompt_prefix = prompt + "\n"
    enc_full = tokenizer(prompt_prefix + response, **tok_kwargs)
    prompt_ids = tokenizer(prompt_prefix, **tok_kwargs)["input_ids"]
    return enc_full, prompt_ids.shape[1]

def logprobs_from_logits(logits, labels):
    """Return the log-probability the logits assign to each label token.

    Args:
        logits: Tensor of unnormalized scores, [B, T, V].
        labels: Integer tensor of token ids, [B, T].

    Returns:
        Tensor [B, T] holding, for every position, the log-softmax value at
        the index given by the label.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Select the entry for each label along the vocabulary axis.
    index = labels.unsqueeze(-1)
    return log_probs.gather(-1, index).squeeze(-1)

def kl_tokenwise(policy_logits, ref_logits):
    """Compute KL(policy || ref) at every position over the last axis.

    This function is deliberately NOT wrapped in ``torch.no_grad()``. Its
    result is used as a penalty term in the training loss, so gradients must
    flow back into ``policy_logits``; under ``no_grad`` the KL term would be a
    constant and would not constrain the policy at all.

    Args:
        policy_logits: Logits of the policy being trained, [B, T, V].
        ref_logits: Logits of the reference model, same shape.

    Returns:
        Tensor [B, T] with the KL divergence at each position.
    """
    # KL(policy || ref) tokenwise
    logp = F.log_softmax(policy_logits, dim=-1)
    logr = F.log_softmax(ref_logits, dim=-1)
    p = F.softmax(policy_logits, dim=-1)
    return torch.sum(p * (logp - logr), dim=-1)  # [B, T]


### Checkpoint

Before running the update, predict what will happen if β is:

- very small (≈ 0)
- very large

In [None]:
beta = 0.05          # KL strength (try 0.01–0.2)
lr = 5e-5
# At most 30 updates; on very small datasets, three passes over the pairs.
steps = min(30, len(train_pairs) * 3)

# Only the policy's parameters are optimized.
opt = torch.optim.AdamW(policy.parameters(), lr=lr)

policy.train()
set_seed(0)

def one_kl_step(example):
    """Compute the loss terms for one training example.

    Only the "prompt" and "chosen" fields of `example` are read. The prompt
    and chosen response are tokenized together, run through both `policy` and
    `ref`, and two scalars are computed over the response positions: the
    negative log-likelihood of the chosen tokens under the policy, and the
    mean tokenwise KL between policy and reference.

    Returns:
        A tuple ``(loss, nll, kl)``. `nll` and `kl` are returned detached, for
        logging. `loss` is not defined until the TODO below is filled in.
    """
    prompt, chosen = example["prompt"], example["chosen"]
    enc_full, prompt_len = tokenize_prompt_and_response(prompt, chosen, max_len=256)
    input_ids = enc_full["input_ids"].to(device)
    attn = enc_full["attention_mask"].to(device)

    # Shift labels for causal LM: position t of the input predicts token t+1.
    labels = input_ids[:, 1:].contiguous()
    input_ids_in = input_ids[:, :-1].contiguous()
    attn_in = attn[:, :-1].contiguous()

    # Policy logits
    out_p = policy(input_ids=input_ids_in, attention_mask=attn_in)
    logits_p = out_p.logits

    # Ref logits (same inputs, so positions line up with the policy's)
    out_r = ref(input_ids=input_ids_in, attention_mask=attn_in)
    logits_r = out_r.logits

    # Compute NLL only on response tokens (not prompt tokens)
    # response token positions start after prompt_len-1 because of shift
    # NOTE(review): if the prompt alone fills max_len tokens, the slice below is
    # empty and the mean is presumably NaN — confirm and guard if needed.
    start = max(prompt_len - 1, 0)
    tok_logp = logprobs_from_logits(logits_p, labels)  # [1, T]
    nll = -tok_logp[:, start:].mean()

    # KL across the same positions
    kl = kl_tokenwise(logits_p, logits_r)[:, start:].mean()

    ### TODO: Compute the loss (assign it to `loss`; it is returned below)


    return loss, nll.detach(), kl.detach()

# One optimizer update per step, cycling through the training pairs in order.
losses = []
for i in range(steps):
    ex = train_pairs[i % len(train_pairs)]
    opt.zero_grad()
    loss, nll, kl = one_kl_step(ex)
    loss.backward()
    opt.step()
    # Keep (loss, nll, kl) as plain floats for later inspection.
    losses.append((loss.item(), nll.item(), kl.item()))
    # Progress line every 10th step.
    if (i+1) % 10 == 0:
        print(f"step {i+1:03d} | loss={loss.item():.4f} nll={nll.item():.4f} kl={kl.item():.4f}")

policy_kl = policy  # alias for clarity: same object as `policy`, not a copy


### Checkpoint

Explore different values for β. Were you correct in your prediction in the previous checkpoint?

### Direct Preference Optimization (DPO)

We directly optimize the policy so that, for each prompt x:

- y⁺ (chosen) becomes more likely than y⁻ (rejected),
- while staying close to a reference model.

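Core DPO form (conceptual), maximized over the preference pairs:

log σ( β [ (log πθ(y⁺|x) − log π_ref(y⁺|x)) − (log πθ(y⁻|x) − log π_ref(y⁻|x)) ] )

The π_ref terms are what keep the policy close to the reference model. Dropping them gives the simplified form log σ( β ( log πθ(y⁺|x) − log πθ(y⁻|x) ) ).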

Key contrasts vs RLHF:
- No explicit reward model
- No sampling loop (no PPO rollout step)
- Just preference pairs + likelihood ratio

In [None]:
# Fine-tune a separate DPO model, starting from the pretrained base gpt2
# weights (independent of the KL-updated policy).
dpo_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
dpo_ref   = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
# Freeze the DPO reference: eval mode, and no gradients for its parameters.
dpo_ref.eval()
for p in dpo_ref.parameters():
    p.requires_grad_(False)

print("DPO models ready.")

### Checkpoint

In plain English, explain what the following quantity encourages (the reference-model terms of the full DPO objective are omitted here for simplicity):

log σ(β (log πθ(y⁺|x) − log πθ(y⁻|x)))


What happens when the chosen response becomes much more likely than the rejected one?

In [None]:
# TRL expects columns: "prompt", "chosen", "rejected"
train_ds_dpo = # TODO
eval_ds_dpo  = # TODO

# Small, single-epoch run with batch size 1.
config = DPOConfig(
    output_dir="dpo_out",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=5e-6,
    num_train_epochs=1,
    max_length=256,
    max_prompt_length=128,
    logging_steps=10,
    save_strategy="no",  # do not save checkpoints
    eval_strategy="no",  # do not evaluate during training
    bf16=False,
    fp16=torch.cuda.is_available(),  # half precision only when a GPU is present
    beta=0.1,  # DPO beta
)

# DPOTrainer needs a tokenizer to tokenize the preference pairs. Pass the
# notebook's tokenizer explicitly, since it already has a pad token
# configured. The keyword was renamed across TRL releases ("tokenizer" in
# older versions, "processing_class" in newer ones), so use whichever the
# installed version accepts.
import inspect

_trainer_params = inspect.signature(DPOTrainer.__init__).parameters
_tokenizer_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = DPOTrainer(
    model=dpo_model,
    ref_model=dpo_ref,
    args=config,
    train_dataset=train_ds_dpo,
    eval_dataset=eval_ds_dpo,
    **{_tokenizer_kwarg: tokenizer},
)

trainer.train()

In [None]:
# Evaluate on at most six prompts taken from the held-out pairs.
heldout_prompts = [pair["prompt"] for pair in test_pairs[:6]]

def generate(model, prompt, max_new_tokens=80, temperature=0.8, top_p=0.95, seed=0):
    """Sample a continuation of `prompt` from `model` and return only the new text.

    The model is switched to eval mode (and left there) and the random seed is
    reset before sampling, so the same arguments give a reproducible sample.

    Args:
        model: Causal language model to sample from.
        prompt: Text to condition on.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        seed: Seed passed to `set_seed` before sampling.

    Returns:
        The generated continuation, without the prompt, stripped of
        surrounding whitespace.
    """
    model.eval()
    set_seed(seed)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        out = model.generate(
            **inputs,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # The output holds the prompt tokens followed by the continuation. Drop
    # the prompt by token count rather than by string prefix: a decoded
    # sequence does not always start with the exact prompt text, and a failed
    # prefix match would leave the prompt inside the "completion".
    new_tokens = out[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

### Checkpoint

Define one concrete criterion you will use to decide whether alignment improved
(e.g., safety, tone, humility, refusal appropriateness)



In [None]:
# Freshly loaded copy of the checkpoint, used as the baseline in the comparisons.
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

def compare_one(prompt):
    """Sample one completion per model for `prompt`, all with the same seed.

    Returns:
        A tuple ``(base, kl_only, dpo)`` of generated strings, in that order.
    """
    models = (base_model, policy_kl, dpo_model)
    base_out, kl_out, dpo_out = (generate(m, prompt, seed=1) for m in models)
    return base_out, kl_out, dpo_out

# Print the three models' completions one after another for each held-out prompt.
for p in heldout_prompts:
    a, b, c = compare_one(p)
    print("="*90)  # separator between prompts
    print("PROMPT:", p)
    print("\n[BASE]\n", a)
    print("\n[KL-ONLY]\n", b)
    print("\n[DPO]\n", c)

In [None]:
def repetition_score(text: str) -> float:
    """Return the fraction of whitespace-separated tokens that are repeats.

    The score is ``1 - unique_tokens / total_tokens``, so higher means more
    repetition. Texts with fewer than 10 tokens always score 0.0.
    """
    words = text.split()
    n_words = len(words)
    if n_words >= 10:
        return 1.0 - (len(set(words)) / n_words)  # higher = more repetition
    return 0.0

# Substrings that mark a reply as a refusal. refusal_flag() lowercases the text
# before matching, so write every phrase in lowercase. Use a list or tuple of
# strings rather than one string: iterating a single string would test its
# individual characters.
REFUSAL_PHRASES = #TODO (example: "unable to")

def refusal_flag(text: str) -> int:
    """Return 1 if any entry of REFUSAL_PHRASES occurs in the lowercased text, else 0."""
    lowered = text.lower()
    for phrase in REFUSAL_PHRASES:
        if phrase in lowered:
            return 1
    return 0

def eval_models(prompts):
    """Generate from all three models for each prompt and score the outputs.

    Args:
        prompts: Iterable of prompt strings.

    Returns:
        A list with one dict per (prompt, model) pair, holding the model name,
        the prompt shortened to 60 characters, the output's word count, its
        repetition score and its refusal flag.
    """
    models = (("base", base_model), ("kl", policy_kl), ("dpo", dpo_model))
    rows = []
    for prompt in prompts:
        # Shorten long prompts so the summary table stays readable.
        label = prompt[:60] + ("..." if len(prompt) > 60 else "")
        # Generate with every model first, then score each output.
        outputs = [(name, generate(model, prompt, seed=2)) for name, model in models]
        for name, text in outputs:
            rows.append({
                "model": name,
                "prompt": label,
                "len_words": len(text.split()),
                "repetition": repetition_score(text),
                "refusal": refusal_flag(text),
            })
    return rows

# Score every held-out prompt with all three models.
rows = eval_models(heldout_prompts)

# summarize
import pandas as pd
df = pd.DataFrame(rows)
# Mean of each metric per model; as the cell's last expression it is displayed.
df.groupby("model")[["len_words","repetition","refusal"]].mean()

### Checkpoint

Propose one additional automatic signal you wish you could measure here.
Why would it be helpful, and why is it hard?



---



### Checkpoint

1) Did your DPO model learn your *intent* or did it learn a shortcut (tone, verbosity, refusal pattern)?  
2) Where did KL help? Where did KL prevent improvement?  
3) If you wrote a short “constitution” (rules for good responses), what 3 rules would you add?
