# Full RLHF Pipeline: From Base Model to Aligned Assistant

Let's put everything together and build a complete LLM alignment pipeline!

## What You'll Learn

By the end of this notebook, you'll understand:
- The factory assembly line analogy: complete RLHF workflow
- End-to-end pipeline: SFT → Reward Model → PPO (or DPO)
- Practical implementation with real code
- Evaluation metrics: how to measure success
- Common pitfalls and how to avoid them
- Tips for production deployment

**Prerequisites:** Notebooks 1-5 (all previous RLHF notebooks)

**Time:** ~40 minutes

---
## The Big Picture: The Factory Assembly Line

```
    ┌────────────────────────────────────────────────────────────────┐
    │          THE FACTORY ASSEMBLY LINE ANALOGY                     │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Building an aligned AI assistant is like a factory...        │
    │                                                                │
    │  RAW MATERIALS:                                               │
    │    ┌─────────────────────────────────────────┐               │
    │    │  Base LLM (pre-trained on internet)     │               │
    │    │  Instruction Data (Q&A pairs)           │               │
    │    │  Preference Data (good vs bad)          │               │
    │    └─────────────────────────────────────────┘               │
    │                      ↓                                        │
    │  ASSEMBLY LINE:                                               │
    │    ┌─────────┐    ┌─────────┐    ┌─────────┐                │
    │    │ Station │ → │ Station │ → │ Station │                │
    │    │   1     │    │   2     │    │   3     │                │
    │    │  SFT    │    │ Reward  │    │  PPO    │                │
    │    └─────────┘    └─────────┘    └─────────┘                │
    │                      ↓                                        │
    │  QUALITY CONTROL:                                            │
    │    ┌─────────────────────────────────────────┐               │
    │    │  Evaluation: Win rate, Human eval, KL   │               │
    │    └─────────────────────────────────────────┘               │
    │                      ↓                                        │
    │  FINISHED PRODUCT:                                           │
    │    ┌─────────────────────────────────────────┐               │
    │    │  Helpful, Harmless, Honest AI Assistant │               │
    │    └─────────────────────────────────────────┘               │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch
import warnings
warnings.filterwarnings('ignore')

# Check available libraries
print("LIBRARY CHECK")
print("="*60)

libraries = {
    'torch': 'PyTorch',
    'transformers': 'Transformers',
    'trl': 'TRL',
    'peft': 'PEFT',
    'datasets': 'Datasets',
}

available = {}
for module, name in libraries.items():
    try:
        lib = __import__(module)
        version = getattr(lib, '__version__', 'unknown')
        available[module] = True
        print(f"✓ {name}: {version}")
    except ImportError:
        available[module] = False
        print(f"✗ {name}: Not installed")

print("="*60)

In [None]:
# Visualize the complete RLHF pipeline

fig, ax = plt.subplots(figsize=(16, 10))
ax.set_xlim(0, 16)
ax.set_ylim(0, 12)
ax.axis('off')
ax.set_title('Complete RLHF Pipeline: From Base Model to Aligned Assistant', 
             fontsize=16, fontweight='bold')

# Stage 0: Input
input_box = FancyBboxPatch((0.5, 8.5), 3, 2.5, boxstyle="round,pad=0.1",
                            facecolor='#e3f2fd', edgecolor='#1976d2', linewidth=2)
ax.add_patch(input_box)
ax.text(2, 10.5, 'INPUT', ha='center', fontsize=10, fontweight='bold', color='#1976d2')
ax.text(2, 9.8, 'Base LLM', ha='center', fontsize=9)
ax.text(2, 9.2, '(Llama, GPT-2, etc.)', ha='center', fontsize=8, color='#666')

# Stage 1: SFT
sft_box = FancyBboxPatch((4.5, 8.5), 3, 2.5, boxstyle="round,pad=0.1",
                          facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax.add_patch(sft_box)
ax.text(6, 10.5, 'STAGE 1: SFT', ha='center', fontsize=10, fontweight='bold', color='#388e3c')
ax.text(6, 9.6, 'Supervised', ha='center', fontsize=9)
ax.text(6, 9.1, 'Fine-Tuning', ha='center', fontsize=9)

# Data for SFT
sft_data = FancyBboxPatch((4.5, 6.5), 3, 1.5, boxstyle="round,pad=0.1",
                           facecolor='#fff3e0', edgecolor='#f57c00', linewidth=1)
ax.add_patch(sft_data)
ax.text(6, 7.5, 'Instruction Data', ha='center', fontsize=8, fontweight='bold')
ax.text(6, 6.9, '(Q&A pairs)', ha='center', fontsize=7, color='#666')

# Stage 2a: Reward Model (for PPO path)
rm_box = FancyBboxPatch((8.5, 9), 3, 2, boxstyle="round,pad=0.1",
                         facecolor='#e1bee7', edgecolor='#7b1fa2', linewidth=2)
ax.add_patch(rm_box)
ax.text(10, 10.5, 'STAGE 2a', ha='center', fontsize=9, fontweight='bold', color='#7b1fa2')
ax.text(10, 9.9, 'Reward Model', ha='center', fontsize=9)
ax.text(10, 9.3, '(for PPO)', ha='center', fontsize=8, color='#666')

# Stage 3a: PPO
ppo_box = FancyBboxPatch((12.5, 9), 3, 2, boxstyle="round,pad=0.1",
                          facecolor='#fff3e0', edgecolor='#f57c00', linewidth=2)
ax.add_patch(ppo_box)
ax.text(14, 10.5, 'STAGE 3a', ha='center', fontsize=9, fontweight='bold', color='#f57c00')
ax.text(14, 9.9, 'PPO Training', ha='center', fontsize=9)
ax.text(14, 9.3, '(RL optimization)', ha='center', fontsize=8, color='#666')

# Stage 2b: DPO (simpler path)
dpo_box = FancyBboxPatch((10.5, 5.5), 3, 2, boxstyle="round,pad=0.1",
                          facecolor='#bbdefb', edgecolor='#1976d2', linewidth=3)
ax.add_patch(dpo_box)
ax.text(12, 7, 'STAGE 2b', ha='center', fontsize=9, fontweight='bold', color='#1976d2')
ax.text(12, 6.4, 'DPO Training', ha='center', fontsize=9)
ax.text(12, 5.8, '(Simpler!)', ha='center', fontsize=8, color='#666')

# Preference data (shared)
pref_data = FancyBboxPatch((8.5, 3.5), 3, 1.5, boxstyle="round,pad=0.1",
                            facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=1)
ax.add_patch(pref_data)
ax.text(10, 4.5, 'Preference Data', ha='center', fontsize=8, fontweight='bold')
ax.text(10, 3.9, '(chosen vs rejected)', ha='center', fontsize=7, color='#666')

# Output: Aligned model
output_box = FancyBboxPatch((12.5, 1), 3, 2, boxstyle="round,pad=0.1",
                             facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax.add_patch(output_box)
ax.text(14, 2.5, 'OUTPUT', ha='center', fontsize=10, fontweight='bold', color='#388e3c')
ax.text(14, 1.8, 'Aligned Model', ha='center', fontsize=9)
ax.text(14, 1.3, '✓', ha='center', fontsize=14, color='#388e3c')

# Arrows
# Input → SFT
ax.annotate('', xy=(4.4, 9.75), xytext=(3.6, 9.75),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# SFT → RM
ax.annotate('', xy=(8.4, 10), xytext=(7.6, 9.75),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# RM → PPO
ax.annotate('', xy=(12.4, 10), xytext=(11.6, 10),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# SFT → DPO
ax.annotate('', xy=(10.4, 6.5), xytext=(7.6, 8.4),
            arrowprops=dict(arrowstyle='->', lw=2, color='#1976d2'))

# PPO → Output
ax.annotate('', xy=(14, 3.1), xytext=(14, 8.9),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# DPO → Output
ax.annotate('', xy=(13, 3.1), xytext=(12, 5.4),
            arrowprops=dict(arrowstyle='->', lw=2, color='#1976d2'))

# Data arrows
ax.annotate('', xy=(6, 8.4), xytext=(6, 8.1),
            arrowprops=dict(arrowstyle='->', lw=1, color='#f57c00'))
ax.annotate('', xy=(10, 8.9), xytext=(10, 5.1),
            arrowprops=dict(arrowstyle='->', lw=1, color='#d32f2f'))
ax.annotate('', xy=(12, 5.4), xytext=(11, 5.1),
            arrowprops=dict(arrowstyle='->', lw=1, color='#d32f2f'))

# Labels
ax.text(9, 7.5, 'PPO Path\n(Complex)', ha='center', fontsize=8, color='#f57c00')
ax.text(9, 6, 'DPO Path\n(Simple)', ha='center', fontsize=8, color='#1976d2', fontweight='bold')

plt.tight_layout()
plt.show()

print("\nTWO PATHS TO ALIGNMENT:")
print("  1. PPO Path: SFT → Reward Model → PPO (more complex, more control)")
print("  2. DPO Path: SFT → DPO (simpler; often competitive with PPO)")
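
On the PPO path, Stage 2a fits a reward model to the chosen/rejected pairs. A common objective is the Bradley-Terry pairwise loss, −log σ(r_chosen − r_rejected), which pushes the score of the preferred response above the rejected one. The sketch below assumes the reward model already outputs one scalar per response; `reward_model_loss` is a hypothetical helper name, not a TRL API, and the reward values are made up.

```python
import torch
import torch.nn.functional as F

def reward_model_loss(r_chosen: torch.Tensor, r_rejected: torch.Tensor) -> torch.Tensor:
    """Bradley-Terry pairwise loss: -log sigmoid(r_chosen - r_rejected), averaged over pairs."""
    # F.logsigmoid is numerically stabler than torch.log(torch.sigmoid(...))
    return -F.logsigmoid(r_chosen - r_rejected).mean()

# Toy scalar rewards for three preference pairs (invented numbers)
r_chosen = torch.tensor([2.0, 0.5, 1.0])
r_rejected = torch.tensor([0.0, 0.5, 3.0])

loss = reward_model_loss(r_chosen, r_rejected)
# Fraction of pairs where the reward model ranks the chosen response higher
accuracy = (r_chosen > r_rejected).float().mean()
print(f"loss={loss.item():.4f}, pairwise accuracy={accuracy.item():.2f}")
```

Tied scores give a loss of log 2 ≈ 0.693 for that pair, so a reward model whose loss stays near 0.693 has not learned to separate the pairs.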

---
## Complete Implementation: The DPO Path (Recommended)

```
    ┌────────────────────────────────────────────────────────────────┐
    │              RECOMMENDED PIPELINE: SFT + DPO                   │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  For most projects, we recommend the DPO path:                │
    │                                                                │
    │    Step 1: SFT (Supervised Fine-Tuning)                       │
    │      • Teaches model to follow instructions                  │
    │      • Input: Base model + instruction dataset               │
    │      • Output: Instruction-following model                   │
    │                                                                │
    │    Step 2: DPO (Direct Preference Optimization)               │
    │      • Aligns model with human preferences                   │
    │      • Input: SFT model + preference dataset                 │
    │      • Output: Aligned model                                 │
    │                                                                │
    │  WHY DPO?                                                     │
    │    • Simpler (no reward model training)                      │
    │    • Faster (2 stages instead of 3)                          │
    │    • Often competitive with PPO (results vary by task)       │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
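
Step 2 needs no reward model because DPO scores each pair with log-probabilities from the policy π_θ and a frozen reference π_ref (the SFT model): loss = −log σ(β[(log π_θ(y_w|x) − log π_ref(y_w|x)) − (log π_θ(y_l|x) − log π_ref(y_l|x))]). The minimal sketch below assumes the summed per-sequence log-probs are already computed; `dpo_loss` is a hypothetical name (TRL's `DPOTrainer` computes the equivalent internally), and the log-prob values are invented.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Sigmoid DPO loss from summed per-sequence log-probabilities."""
    # Implicit rewards: how much more likely each response became vs. the reference
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Push the chosen implicit reward above the rejected one
    losses = -F.logsigmoid(chosen_rewards - rejected_rewards)
    return losses.mean(), chosen_rewards, rejected_rewards

# Toy batch of two preference pairs (made-up log-probs)
policy_chosen = torch.tensor([-10.0, -12.0])
policy_rejected = torch.tensor([-14.0, -11.0])
ref_chosen = torch.tensor([-11.0, -12.0])
ref_rejected = torch.tensor([-13.0, -11.0])

loss, chosen_r, rejected_r = dpo_loss(policy_chosen, policy_rejected,
                                      ref_chosen, ref_rejected, beta=0.1)
print(f"DPO loss: {loss.item():.4f}")
print("reward margins:", (chosen_r - rejected_r).tolist())
```

At the start of training the policy equals the reference, every margin is 0, and the loss is log 2 ≈ 0.693. TRL logs the same quantities as `rewards/chosen`, `rewards/rejected` and `rewards/margins`, which are worth watching during Step 2.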

In [None]:
# Complete RLHF Pipeline Code

print("COMPLETE RLHF PIPELINE (DPO PATH)")
print("="*70)

complete_code = '''
# ============================================================
# COMPLETE RLHF PIPELINE: SFT + DPO
# ============================================================

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig

# ============================================================
# CONFIGURATION
# ============================================================

BASE_MODEL = "meta-llama/Llama-2-7b-hf"  # Or any base model
SFT_OUTPUT = "./sft_model"
DPO_OUTPUT = "./dpo_aligned_model"

# LoRA configuration (for memory efficiency)
LORA_CONFIG = LoraConfig(
    r=16,                         # Rank
    lora_alpha=32,                # Scaling
    lora_dropout=0.05,            # Regularization
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

# ============================================================
# STEP 1: SUPERVISED FINE-TUNING (SFT)
# ============================================================

print("Step 1: Supervised Fine-Tuning")
print("-" * 40)

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token

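# Load instruction dataset
# Popular choices: openassistant-guanaco, dolly-15k, alpaca
# NOTE: openassistant-guanaco uses a "### Human: ... ### Assistant: ..." template,
# while hh-rlhf (Step 2) uses "\\n\\nHuman: ... \\n\\nAssistant: ...". In practice,
# format both stages with the same chat template.
sft_dataset = load_dataset(
    "timdettmers/openassistant-guanaco",
    split="train"
)
print(f"SFT dataset size: {len(sft_dataset)} examples")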

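# Configure SFT training
sft_config = SFTConfig(
    output_dir=SFT_OUTPUT,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_seq_length=512,            # named max_length in newer TRL releases
    packing=True,
    learning_rate=2e-4,
    logging_steps=25,
    save_steps=500,
)

# Create SFT trainer
sft_trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=sft_dataset,
    processing_class=tokenizer,    # older TRL versions: tokenizer=tokenizer
    peft_config=LORA_CONFIG,
)

# Train!
print("Training SFT model...")
sft_trainer.train()

# With peft_config, save_model() would store only the LoRA adapter.
# Merge the adapter into the base weights so Step 2 can load a full model
# and attach a fresh LoRA adapter for DPO.
merged_model = sft_trainer.model.merge_and_unload()
merged_model.save_pretrained(SFT_OUTPUT)
tokenizer.save_pretrained(SFT_OUTPUT)
print(f"SFT model saved to {SFT_OUTPUT}")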

# ============================================================
# STEP 2: DIRECT PREFERENCE OPTIMIZATION (DPO)
# ============================================================

print("\nStep 2: Direct Preference Optimization")
print("-" * 40)

# Load SFT model
model = AutoModelForCausalLM.from_pretrained(
    SFT_OUTPUT,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(SFT_OUTPUT)

# Load preference dataset
# Popular choices: hh-rlhf, ultrafeedback
preference_dataset = load_dataset(
    "Anthropic/hh-rlhf",
    split="train[:5000]"  # Start small!
)
print(f"Preference dataset size: {len(preference_dataset)} examples")

# Process preference data into the prompt/chosen/rejected columns DPO expects.
# hh-rlhf stores whole (possibly multi-turn) dialogues; chosen and rejected
# share the conversation up to the final "\\n\\nAssistant:" turn.
def format_for_dpo(example):
    """Split an hh-rlhf example at the last Assistant turn."""
    marker = "\\n\\nAssistant:"
    chosen = example["chosen"]
    rejected = example["rejected"]

    # Prompt = full history up to and including the last Assistant marker
    split_idx = chosen.rfind(marker) + len(marker)
    prompt = chosen[:split_idx]

    return {
        "prompt": prompt,
        "chosen": chosen[split_idx:],
        "rejected": rejected[rejected.rfind(marker) + len(marker):],
    }

dpo_dataset = preference_dataset.map(format_for_dpo)

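# Configure DPO training
dpo_config = DPOConfig(
    output_dir=DPO_OUTPUT,
    beta=0.1,                      # How strongly the policy is tied to the reference
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    max_length=512,
    max_prompt_length=256,         # Option name may differ across TRL versions
    learning_rate=5e-5,
    logging_steps=25,
    save_steps=500,
)

# Create DPO trainer. With peft_config and no ref_model, TRL uses the model
# with its adapter disabled (i.e. the merged SFT weights) as the frozen reference.
dpo_trainer = DPOTrainer(
    model=model,
    args=dpo_config,
    train_dataset=dpo_dataset,
    processing_class=tokenizer,    # older TRL versions: tokenizer=tokenizer
    peft_config=LORA_CONFIG,       # Fresh LoRA adapter on top of the SFT weights
)

# Train!
print("Training DPO model...")
dpo_trainer.train()

# Merge the DPO adapter so the evaluation step can load a standalone model
dpo_trainer.model.merge_and_unload().save_pretrained(DPO_OUTPUT)
tokenizer.save_pretrained(DPO_OUTPUT)
print(f"Aligned model saved to {DPO_OUTPUT}")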

# ============================================================
# STEP 3: EVALUATION
# ============================================================

print("\nStep 3: Evaluation")
print("-" * 40)

# Load aligned model
aligned_model = AutoModelForCausalLM.from_pretrained(
    DPO_OUTPUT,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(DPO_OUTPUT)

# Test generation
test_prompts = [
    "What is the capital of France?",
    "How do I learn programming?",
    "Explain machine learning to a 5-year-old.",
]

for prompt in test_prompts:
    # Use the same hh-rlhf dialogue template as the DPO data
    text = f"\\n\\nHuman: {prompt}\\n\\nAssistant:"
    inputs = tokenizer(text, return_tensors="pt").to(aligned_model.device)
    outputs = aligned_model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the prompt
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    print(f"\\nPrompt: {prompt}")
    print(f"Response: {response.strip()}")

print("\\n" + "=" * 70)
print("PIPELINE COMPLETE!")
print("=" * 70)
'''

print(complete_code)
print("="*70)
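
Before launching DPO it is cheap to sanity-check the preference formatting on a hand-written example. hh-rlhf stores whole, possibly multi-turn dialogues, so the prompt should be everything up to the final `\n\nAssistant:` marker and each response should contain only the last turn. The sketch below does that split in plain Python; `split_hh_example` is a hypothetical helper and the toy dialogue is invented for illustration.

```python
def split_hh_example(chosen: str, rejected: str, marker: str = "\n\nAssistant:"):
    """Split an hh-rlhf style pair into prompt / chosen / rejected at the last Assistant turn."""
    if marker not in chosen or marker not in rejected:
        raise ValueError("expected an Assistant turn in both responses")
    split_idx = chosen.rfind(marker) + len(marker)
    return {
        "prompt": chosen[:split_idx],
        "chosen": chosen[split_idx:],
        "rejected": rejected[rejected.rfind(marker) + len(marker):],
    }

# Invented two-turn dialogue in the hh-rlhf "\n\nHuman: ... \n\nAssistant: ..." layout
history = "\n\nHuman: Hi!\n\nAssistant: Hello! How can I help?\n\nHuman: Name a primary color."
example = split_hh_example(
    chosen=history + "\n\nAssistant: Red is a primary color.",
    rejected=history + "\n\nAssistant: I don't know.",
)
for key, value in example.items():
    print(f"{key}: {value!r}")

# The prompt keeps the whole conversation and ends at the final Assistant marker
assert example["prompt"].endswith("\n\nAssistant:")
```

Splitting at the first Assistant turn instead would silently drop earlier turns from the prompt and leak them into the responses.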

---
## Evaluation Metrics: Measuring Alignment Success

```
    ┌────────────────────────────────────────────────────────────────┐
    │              EVALUATION METRICS                                │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  1. WIN RATE (Automated)                                      │
    │     Compare aligned model vs baseline on same prompts         │
    │     Use reward model or GPT-4 as judge                        │
    │     Target: > 50% (better than baseline)                      │
    │                                                                │
    │  2. REWARD SCORE (Automated)                                  │
    │     Average reward model score on held-out prompts            │
    │     Higher = better aligned with preferences                  │
    │     Watch for reward hacking!                                 │
    │                                                                │
    │  3. KL DIVERGENCE (Automated)                                 │
    │     Sequence-level KL (nats) from the reference (SFT) model   │
    │     Range is task-dependent; 5-15 is only a rough guide       │
    │                                                                │
    │  4. PERPLEXITY (Automated)                                    │
    │     Language quality on held-out text                        │
    │     Should not increase significantly                        │
    │                                                                │
    │  5. HUMAN EVALUATION (Gold Standard!)                         │
    │     Real humans rate responses                               │
    │     Most reliable but expensive                              │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
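
A raw win rate above 50% on a small prompt set may still be noise. One common convention is to count ties as half a win and attach a confidence interval. The sketch below uses a normal-approximation 95% interval on the tie-adjusted rate; both the tie convention and the interval method are assumptions, `win_rate_with_ci` is a hypothetical helper, and the scores are invented.

```python
import numpy as np

def win_rate_with_ci(scores_a, scores_b, z=1.96):
    """Tie-adjusted win rate of A over B with a normal-approximation CI."""
    a = np.asarray(scores_a, dtype=float)
    b = np.asarray(scores_b, dtype=float)
    # Per-prompt outcome: 1 = A wins, 0.5 = tie, 0 = B wins
    outcomes = np.where(a > b, 1.0, np.where(a == b, 0.5, 0.0))
    rate = outcomes.mean()
    # Standard error of the mean outcome across prompts
    se = outcomes.std(ddof=1) / np.sqrt(len(outcomes)) if len(outcomes) > 1 else 0.0
    # Clip to [0, 1] because the normal approximation can overshoot on small samples
    ci_low, ci_high = np.clip([rate - z * se, rate + z * se], 0.0, 1.0)
    return rate, (ci_low, ci_high)

rate, (ci_low, ci_high) = win_rate_with_ci([7, 8, 6, 9, 5, 7], [6, 8, 7, 6, 4, 6])
print(f"win rate {rate:.1%}, 95% CI [{ci_low:.1%}, {ci_high:.1%}]")
```

With only six prompts the interval still includes 50%, so a 75% win rate here is not yet evidence that the aligned model is better; the improvement is convincing only when the whole interval sits above 50%.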

In [None]:
# Evaluation utilities

def compute_win_rate(model_a_scores, model_b_scores):
    """
    Compute win rate of model A over model B.
    
    Args:
        model_a_scores: Scores for model A responses
        model_b_scores: Scores for model B responses
    
    Returns:
        Win rate (0-1), tie rate, lose rate
    """
    wins = sum(a > b for a, b in zip(model_a_scores, model_b_scores))
    ties = sum(a == b for a, b in zip(model_a_scores, model_b_scores))
    losses = sum(a < b for a, b in zip(model_a_scores, model_b_scores))
    total = len(model_a_scores)
    
    return wins/total, ties/total, losses/total


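def compute_kl_divergence(log_probs_policy, log_probs_ref, mask=None):
    """
    Monte Carlo estimate of sequence-level KL(π || π_ref).

    KL(π || π_ref) = E_{y~π}[log π(y|x) - log π_ref(y|x)]

    Args:
        log_probs_policy: [batch, seq_len] tensor of per-token log-probs of
            responses sampled from the policy, scored by the policy
        log_probs_ref: [batch, seq_len] log-probs of the same tokens under
            the reference model
        mask: optional [batch, seq_len] 0/1 tensor marking response tokens

    Returns:
        Mean over sequences of the summed per-token log-ratio (nats)
    """
    diff = log_probs_policy - log_probs_ref
    if mask is not None:
        diff = diff * mask
    # Sum over tokens (sequence-level KL), then average over sequences
    kl = diff.sum(dim=-1).mean()
    return kl.item()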


# Simulate evaluation results
print("EVALUATION EXAMPLE")
print("="*60)

np.random.seed(42)
n_eval = 100

# Simulated scores
base_scores = np.random.normal(5, 1.5, n_eval)  # Base model: mean 5
sft_scores = np.random.normal(6, 1.5, n_eval)   # SFT: mean 6
aligned_scores = np.random.normal(7.5, 1.2, n_eval)  # Aligned: mean 7.5

# Compute win rates
win_sft, tie_sft, lose_sft = compute_win_rate(sft_scores, base_scores)
win_aligned, tie_aligned, lose_aligned = compute_win_rate(aligned_scores, sft_scores)

print("\nWin Rate Results:")
print(f"  SFT vs Base: {win_sft:.1%} wins, {tie_sft:.1%} ties, {lose_sft:.1%} losses")
print(f"  Aligned vs SFT: {win_aligned:.1%} wins, {tie_aligned:.1%} ties, {lose_aligned:.1%} losses")

print("\nAverage Reward Scores:")
print(f"  Base Model: {base_scores.mean():.2f} ± {base_scores.std():.2f}")
print(f"  SFT Model: {sft_scores.mean():.2f} ± {sft_scores.std():.2f}")
print(f"  Aligned Model: {aligned_scores.mean():.2f} ± {aligned_scores.std():.2f}")

# Simulated sequence-level KL divergences (nats); illustrative values only
kl_sft = 3.2      # SFT model vs. base model
kl_aligned = 8.5  # Aligned model vs. its SFT reference

print("\nKL Divergence (simulated):")
print(f"  SFT vs Base: {kl_sft:.1f}")
print(f"  Aligned vs SFT reference: {kl_aligned:.1f}")
print("  Rough heuristic range: 5-15 (task-dependent)")

print("\n" + "="*60)
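
Perplexity is the exponential of the average negative log-likelihood per token on held-out text, so a clear rise after alignment signals degraded language modeling. The minimal sketch below assumes you already have per-token natural-log probabilities and a padding mask as tensors; `perplexity_from_logprobs` is a hypothetical helper and the probabilities are toy values.

```python
import torch

def perplexity_from_logprobs(token_logprobs: torch.Tensor, mask: torch.Tensor) -> float:
    """exp(mean negative log-likelihood) over the real (mask == 1) tokens."""
    keep = mask.bool()
    # torch.where avoids NaNs if padded positions hold -inf log-probs
    total_nll = -torch.where(keep, token_logprobs, torch.zeros_like(token_logprobs)).sum()
    return torch.exp(total_nll / keep.sum()).item()

# Two toy sequences; the last position of the second one is padding
logprobs = torch.log(torch.tensor([[0.5, 0.25, 0.5],
                                   [0.5, 0.5, 1.0]]))
mask = torch.tensor([[1, 1, 1],
                     [1, 1, 0]])
ppl = perplexity_from_logprobs(logprobs, mask)
print(f"perplexity: {ppl:.3f}")
```

Averaging over all real tokens in the batch (rather than averaging per-sequence perplexities) weights long and short sequences by their token counts, which is the usual corpus-level definition.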

In [None]:
# Visualize evaluation results

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top-left: Reward score distribution
ax1 = axes[0, 0]
ax1.hist(base_scores, bins=20, alpha=0.6, label='Base', color='#ef5350')
ax1.hist(sft_scores, bins=20, alpha=0.6, label='SFT', color='#ff9800')
ax1.hist(aligned_scores, bins=20, alpha=0.6, label='Aligned', color='#4caf50')
ax1.set_xlabel('Reward Score', fontsize=11)
ax1.set_ylabel('Count', fontsize=11)
ax1.set_title('Reward Score Distribution', fontsize=12, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Top-right: Win rate comparison
ax2 = axes[0, 1]
comparisons = ['SFT vs Base', 'Aligned vs SFT']
wins = [win_sft * 100, win_aligned * 100]
ties = [tie_sft * 100, tie_aligned * 100]
losses = [lose_sft * 100, lose_aligned * 100]

x = np.arange(len(comparisons))
width = 0.25

ax2.bar(x - width, wins, width, label='Wins', color='#4caf50')
ax2.bar(x, ties, width, label='Ties', color='#ff9800')
ax2.bar(x + width, losses, width, label='Losses', color='#ef5350')

ax2.set_ylabel('Percentage', fontsize=11)
ax2.set_title('Win Rate Analysis', fontsize=12, fontweight='bold')
ax2.set_xticks(x)
ax2.set_xticklabels(comparisons)
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')

# Bottom-left: simulated training progress
ax3 = axes[1, 0]
steps = np.arange(0, 1001, 100)
reward_progress = 5 + 2 * (1 - np.exp(-steps/300)) + np.random.randn(len(steps)) * 0.2
# KL is non-negative, so clip the simulated noise at zero
kl_progress = np.maximum(0.5 * steps / 100 + np.random.randn(len(steps)) * 0.5, 0)

ax3.plot(steps, reward_progress, 'g-', linewidth=2, label='Reward Score')
ax3.set_xlabel('Training Steps', fontsize=11)
ax3.set_ylabel('Reward Score', fontsize=11, color='g')
ax3.tick_params(axis='y', labelcolor='g')

ax3_twin = ax3.twinx()
ax3_twin.plot(steps, kl_progress, 'b--', linewidth=2, label='KL Divergence')
ax3_twin.set_ylabel('KL Divergence', fontsize=11, color='b')
ax3_twin.tick_params(axis='y', labelcolor='b')

ax3.set_title('Training Progress', fontsize=12, fontweight='bold')
ax3.grid(True, alpha=0.3)

# Bottom-right: quality metrics (illustrative numbers, not real evaluation results)
ax4 = axes[1, 1]
metrics = ['Helpfulness', 'Harmlessness', 'Honesty', 'Coherence']
base_metrics = [5, 5, 5, 7]
aligned_metrics = [8, 7, 8, 8]

x = np.arange(len(metrics))
width = 0.35

ax4.bar(x - width/2, base_metrics, width, label='Base', color='#ef5350', alpha=0.8)
ax4.bar(x + width/2, aligned_metrics, width, label='Aligned', color='#4caf50', alpha=0.8)

ax4.set_ylabel('Score (1-10)', fontsize=11)
ax4.set_title('Quality Metrics (Illustrative Human Eval)', fontsize=12, fontweight='bold')
ax4.set_xticks(x)
ax4.set_xticklabels(metrics)
ax4.legend()
ax4.set_ylim(0, 10)
ax4.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---
## Common Pitfalls and Solutions

```
    ┌────────────────────────────────────────────────────────────────┐
    │              COMMON PITFALLS                                   │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  1. REWARD HACKING                                            │
    │     Problem: Model finds exploits in reward model            │
    │     Symptoms: High reward but bad outputs                    │
    │     Solution: KL penalty, diverse eval, human checks         │
    │                                                                │
    │  2. MODE COLLAPSE                                             │
    │     Problem: Model gives same response to everything         │
    │     Symptoms: Very low diversity in outputs                  │
    │     Solution: Increase KL penalty, reduce learning rate      │
    │                                                                │
    │  3. CATASTROPHIC FORGETTING                                   │
    │     Problem: Model forgets base knowledge                    │
    │     Symptoms: Can't answer basic questions anymore           │
    │     Solution: Use LoRA, lower learning rate, more KL         │
    │                                                                │
    │  4. OVERFITTING TO PREFERENCES                                │
    │     Problem: Model memorizes preference data                 │
    │     Symptoms: Perfect on train, bad on test                  │
    │     Solution: Early stopping, regularization, more data      │
    │                                                                │
    │  5. LENGTH BIAS                                               │
    │     Problem: Model learns "longer = better"                  │
    │     Symptoms: Verbose, repetitive outputs                    │
    │     Solution: Length normalization in reward/loss            │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
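
Mode collapse shows up as low diversity across sampled responses. One simple proxy is distinct-n: the number of unique n-grams divided by the total number of n-grams across a batch of generations. The sketch below tokenizes on whitespace for simplicity (an assumption; a real pipeline would more likely use the model tokenizer), and `distinct_n` is a hypothetical helper applied to invented responses.

```python
def distinct_n(responses, n=2):
    """Unique n-grams / total n-grams across a list of responses (whitespace tokens)."""
    total, unique = 0, set()
    for text in responses:
        tokens = text.split()
        ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
        total += len(ngrams)
        unique.update(ngrams)
    return len(unique) / total if total else 0.0

diverse = ["Paris is the capital of France.", "Try Python first, then build small projects."]
collapsed = ["I am happy to help with that.", "I am happy to help with that."]
print(f"diverse:   distinct-2 = {distinct_n(diverse):.2f}")
print(f"collapsed: distinct-2 = {distinct_n(collapsed):.2f}")
```

Tracking distinct-n on a fixed prompt set across checkpoints makes a sudden drop easy to spot, even when reward keeps rising.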

In [None]:
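# Pitfall detection utilities
# NOTE: the thresholds below are illustrative heuristics; tune them to your
# reward model's scale and to how you measure KL.

def check_for_pitfalls(training_stats, reward_ceiling=10.0, kl_range=(1.0, 15.0)):
    """
    Analyze logged training stats for common pitfalls.
    
    Args:
        training_stats: Dict with 'rewards', 'kl', 'loss', 'response_lengths'
            (one entry per logging step)
        reward_ceiling: Reward above which scores look suspicious
            (depends on the reward model's scale)
        kl_range: (low, high) KL bounds considered healthy
    
    Returns:
        List of warning strings
    """
    issues = []  # not named `warnings`, to avoid shadowing the warnings module
    
    # Reward hacking: reward climbing beyond plausible values
    if training_stats['rewards'][-1] > reward_ceiling:
        issues.append("⚠️ REWARD HACKING: Suspiciously high reward scores")
    
    # KL too high: policy drifting far from the reference (forgetting risk)
    if training_stats['kl'][-1] > kl_range[1]:
        issues.append("⚠️ HIGH KL: Risk of catastrophic forgetting")
    
    # KL too low: policy has barely moved from the reference
    if training_stats['kl'][-1] < kl_range[0]:
        issues.append("⚠️ LOW KL: Model may not be learning enough")
    
    # Length bias: compare the first and last quarter of the run
    lengths = training_stats['response_lengths']
    window = max(1, len(lengths) // 4)
    if np.mean(lengths[-window:]) > 1.5 * np.mean(lengths[:window]):
        issues.append("⚠️ LENGTH BIAS: Responses getting longer")
    
    # Loss spike: any logging step where the loss more than doubles
    loss = training_stats['loss']
    if any(curr > 2 * prev for prev, curr in zip(loss, loss[1:])):
        issues.append("⚠️ LOSS SPIKE: Training may be unstable")
    
    return issues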


# Example: Good training run
print("PITFALL ANALYSIS EXAMPLES")
print("="*60)

# Good training
good_stats = {
    'rewards': [5, 5.5, 6, 6.3, 6.5, 6.8, 7, 7.1, 7.2, 7.3],
    'kl': [0, 1, 2, 3, 4, 5, 6, 7, 8, 8.5],
    'loss': [0.7, 0.6, 0.55, 0.5, 0.48, 0.46, 0.45, 0.44, 0.43, 0.42],
    'response_lengths': [100, 105, 108, 110, 112, 115, 118, 120, 122, 125],
}

print("\nGood Training Run:")
issues = check_for_pitfalls(good_stats)
if not issues:
    print("  ✓ No issues detected!")
else:
    for w in issues:
        print(f"  {w}")

# Problematic training (reward hacking, KL drift, length growth, loss spikes)
bad_stats = {
    'rewards': [5, 6, 8, 10, 12, 14, 16, 18, 20, 25],  # Too high!
    'kl': [0, 2, 5, 10, 15, 20, 25, 30, 35, 40],  # Too high!
    'loss': [0.7, 0.5, 0.3, 0.5, 1.0, 0.3, 2.0, 0.4, 0.3, 0.2],  # Spiky
    'response_lengths': [100, 120, 150, 200, 280, 350, 450, 550, 700, 900],  # Growing!
}

print("\nProblematic Training Run:")
issues = check_for_pitfalls(bad_stats)
for w in issues:
    print(f"  {w}")

print("\n" + "="*60)
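
The length check in `check_for_pitfalls` only looks at whether responses grow over training. A complementary diagnostic is whether the reward itself tracks length: a strong positive correlation between reward score and response length on a batch of samples suggests the reward model (or the preference data behind it) favors verbosity. A minimal sketch using NumPy's `np.corrcoef`; `reward_length_correlation` is a hypothetical helper, the sample numbers are invented, and the 0.5 cutoff is an arbitrary assumption.

```python
import numpy as np

def reward_length_correlation(rewards, lengths):
    """Pearson correlation between per-response reward and response length."""
    rewards = np.asarray(rewards, dtype=float)
    lengths = np.asarray(lengths, dtype=float)
    if rewards.std() == 0 or lengths.std() == 0:
        return 0.0  # correlation is undefined for constant inputs
    return float(np.corrcoef(rewards, lengths)[0, 1])

# Invented batch where longer answers always score higher
lengths = [80, 120, 200, 310, 450]
rewards = [4.0, 4.8, 5.9, 7.1, 8.6]
corr = reward_length_correlation(rewards, lengths)
print(f"reward/length correlation: {corr:.2f}")
if corr > 0.5:  # arbitrary illustrative cutoff
    print("⚠️ Reward strongly tracks length; consider length normalization")
```

Correlation alone cannot tell you whether longer answers are genuinely better, so pair it with a human spot-check of the longest high-reward samples.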

---
## Production Tips

```
    ┌────────────────────────────────────────────────────────────────┐
    │              PRODUCTION TIPS                                   │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  DATA QUALITY:                                                │
    │    • Quality >> quantity: curated beats large-but-noisy       │
    │    • Clean preference data manually if possible              │
    │    • Remove low-quality/contradictory examples               │
    │                                                                │
    │  TRAINING:                                                    │
    │    • Start with smaller model to debug pipeline              │
    │    • Use LoRA for memory efficiency                          │
    │    • Save checkpoints frequently                             │
    │    • Monitor KL divergence closely                           │
    │                                                                │
    │  EVALUATION:                                                  │
    │    • Always test on held-out data                           │
    │    • Use multiple metrics (not just reward)                 │
    │    • Include human evaluation when possible                 │
    │    • Test for harmful outputs explicitly                    │
    │                                                                │
    │  DEPLOYMENT:                                                   │
    │    • A/B test against baseline                              │
    │    • Monitor outputs in production                          │
    │    • Have fallback mechanisms                               │
    │    • Plan for iterative improvement                         │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
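
On the PPO path, "monitor KL divergence closely" can be partly automated by adapting the KL penalty coefficient toward a target KL, as in the proportional controller described by Ziegler et al. (2019). The sketch below is a standalone re-implementation of that idea under stated assumptions, not a library class: `AdaptiveKLCoefficient` is a hypothetical name, and the starting coefficient, target and horizon are illustrative values. DPO has no such per-step coefficient; its β stays fixed.

```python
import numpy as np

class AdaptiveKLCoefficient:
    """Proportional controller nudging the KL penalty coefficient toward a target KL."""

    def __init__(self, init_coef=0.2, target_kl=6.0, horizon=10_000):
        self.coef = init_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, observed_kl, n_steps):
        # Relative error, clipped so one noisy batch cannot swing the coefficient much
        error = np.clip(observed_kl / self.target_kl - 1.0, -0.2, 0.2)
        # KL above target -> raise the penalty; below target -> lower it
        self.coef *= 1.0 + error * n_steps / self.horizon
        return self.coef

ctrl = AdaptiveKLCoefficient(init_coef=0.2, target_kl=6.0, horizon=10_000)
for kl in [12.0, 12.0, 3.0]:
    coef = ctrl.update(kl, n_steps=256)
    print(f"observed KL {kl:4.1f} -> coefficient {coef:.5f}")
```

The small per-update change (at most 0.2 × n_steps / horizon) keeps the controller slow, which is intended: it corrects sustained drift rather than reacting to individual batches.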

In [None]:
# Production checklist

print("PRODUCTION CHECKLIST")
print("="*60)

checklist = [
    ("Data", [
        "☐ Clean instruction dataset (remove duplicates, errors)",
        "☐ Clean preference dataset (remove contradictions)",
        "☐ Create held-out test sets",
        "☐ Check for data leakage",
    ]),
    ("SFT Stage", [
        "☐ Validate base model loads correctly",
        "☐ Test tokenizer (padding, special tokens)",
        "☐ Run small training to verify pipeline",
        "☐ Save intermediate checkpoints",
    ]),
    ("Alignment Stage", [
        "☐ Choose DPO (simple) or PPO (flexible)",
        "☐ Set appropriate beta/KL coefficient",
        "☐ Monitor training metrics (reward, KL, loss)",
        "☐ Watch for reward hacking",
    ]),
    ("Evaluation", [
        "☐ Test on held-out prompts",
        "☐ Compare to baseline (SFT model)",
        "☐ Check for harmful outputs",
        "☐ Run human evaluation if possible",
    ]),
    ("Deployment", [
        "☐ Merge LoRA weights (if using LoRA)",
        "☐ Optimize for inference (quantization)",
        "☐ Set up monitoring",
        "☐ Plan iteration cycle",
    ]),
]

for category, items in checklist:
    print(f"\n{category}:")
    for item in items:
        print(f"  {item}")

print("\n" + "="*60)

---
## Summary: RLHF Key Takeaways

### The Complete Pipeline

```
Base LLM → SFT → (Reward Model →) PPO/DPO → Aligned Model
```

### Two Paths to Alignment

| Path | Stages | Complexity | When to Use |
|------|--------|------------|-------------|
| **DPO** | SFT → DPO | Simple | Most projects |
| **PPO** | SFT → RM → PPO | Complex | Online learning, fine control |

### Key Metrics

| Metric | Purpose | Target |
|--------|---------|--------|
| Win Rate | Compare to baseline | > 50% |
| Reward Score | Preference alignment | Higher is better |
| KL Divergence | Model stability | Task-dependent (~5-15 nats is a rough guide) |
| Human Eval | True quality | Gold standard |

### Common Pitfalls

| Pitfall | Symptom | Solution |
|---------|---------|----------|
| Reward Hacking | High reward, bad output | More KL penalty |
| Mode Collapse | Repetitive outputs | Lower learning rate |
| Forgetting | Lost base knowledge | Use LoRA |
| Length Bias | Verbose responses | Length normalization |

---
## Test Your Understanding

**1. Why is DPO recommended over PPO for most projects?**
<details>
<summary>Click to reveal answer</summary>
DPO is simpler (it trains no separate reward model), faster (2 stages instead of 3), and often competitive with PPO in practice. PPO is more complex and can be unstable. DPO optimizes the policy directly on preference pairs, using the frozen SFT model as a reference instead of an intermediate reward model.
</details>

**2. What does KL divergence measure and why is it important?**
<details>
<summary>Click to reveal answer</summary>
KL divergence measures how different the aligned model is from the reference (SFT) model. It's important because:
- Too low: Model isn't learning the preferences
- Too high: Risk of catastrophic forgetting or reward hacking
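- Typical range: a common rule of thumb is roughly 5-15 nats of sequence-level KL, but the right value depends on the task, response length, and how KL is measured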
</details>

**3. What is reward hacking and how do you detect it?**
<details>
<summary>Click to reveal answer</summary>
Reward hacking is when the model exploits weaknesses in the reward model to get high scores without actually being helpful. Signs include:
- Reward scores climbing far above what the reward model gives strong reference responses (e.g., > 10 on this notebook's simulated scale)
- Outputs that look good to RM but are actually bad
- Repetitive or formulaic responses

Prevention: Strong KL penalty, diverse evaluation, human checks.
</details>

**4. Why should you start with a smaller model when developing?**
<details>
<summary>Click to reveal answer</summary>
Starting with a smaller model (e.g., GPT-2 instead of Llama-7B) lets you:
- Debug the pipeline quickly
- Iterate faster on data processing
- Test hyperparameters cheaply
- Identify issues before expensive training

Once the pipeline works, scale up to the larger model.
</details>

**5. What's the most reliable evaluation method for alignment?**
<details>
<summary>Click to reveal answer</summary>
Human evaluation is the gold standard! While automated metrics (reward score, win rate, KL) are useful for development, only humans can truly judge:
- Whether responses are actually helpful
- Subtle issues like tone, appropriateness
- Edge cases and safety concerns

Use automated metrics for iteration, human eval for final assessment.
</details>

---
## Congratulations!

You've completed the **RLHF** section! You now understand:

- ✅ What RLHF is and why it matters
- ✅ How reward models learn human preferences
- ✅ How PPO optimizes language models
- ✅ DPO as a simpler alternative
- ✅ Using the TRL library
- ✅ Building complete alignment pipelines

**Next Steps:**

Move on to **[Applications](../applications/)** to see RL in action across different domains - games, robotics, recommendations, and more!

---

*RLHF: "Teaching AI to be helpful, harmless, and honest - one preference at a time!"*