# What is RLHF? The Secret Behind ChatGPT

Welcome to RLHF - the breakthrough technique that transformed language models into helpful AI assistants like ChatGPT!

## What You'll Learn

By the end of this notebook, you'll understand:
- Why raw language models need alignment (with a brilliant-but-clueless student analogy!)
- The three-stage RLHF pipeline, step by step
- How reward models learn human preferences
- Why PPO is used for fine-tuning
- The history from InstructGPT to ChatGPT
- Code demonstrations of each concept!

**Prerequisites:** Notebooks in `advanced-algorithms/` (PPO)

**Time:** ~45 minutes

---
## The Big Picture: The Brilliant But Clueless Student

Imagine a student who has read EVERY book ever written:

```
    ┌────────────────────────────────────────────────────────────────┐
    │           THE ALIGNMENT PROBLEM: A STUDENT ANALOGY             │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  PRETRAINED LLM = Brilliant Student Who Read Everything       │
    │                                                                │
    │  KNOWS:                        DOESN'T KNOW:                  │
    │  • All of Wikipedia            • What answer YOU want         │
    │  • Every textbook              • How to be helpful            │
    │  • All of Reddit               • When to refuse requests      │
    │  • Code from GitHub            • How to be safe               │
    │                                                                │
    │  You ask: "What's the capital of France?"                     │
    │                                                                │
    │  HELPFUL response:   "The capital of France is Paris."        │
    │                                                                │
    │  UNHELPFUL (but valid text completion):                       │
    │    • "What's the capital of Germany? Berlin..." (off-topic)   │
    │    • "I'll tell you a story about France..." (rambling)       │
    │    • "The capital of France is Paris. But did you know..."    │
    │      (overly verbose)                                         │
    │                                                                │
    │  THE PROBLEM: Knowing everything ≠ Knowing how to help        │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

**RLHF is like hiring a tutor to teach this brilliant student HOW to use their knowledge helpfully!**

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch, Circle, Rectangle
from matplotlib.colors import LinearSegmentedColormap
import torch
import torch.nn as nn
import torch.nn.functional as F

# Visualize the alignment problem
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Before RLHF
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('BEFORE RLHF\n(Pretrained LLM)', fontsize=14, fontweight='bold', color='#d32f2f')

# LLM box
llm_box = FancyBboxPatch((2, 3), 6, 4, boxstyle="round,pad=0.1",
                          facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=3)
ax1.add_patch(llm_box)
ax1.text(5, 5.5, 'Pretrained LLM', ha='center', fontsize=14, fontweight='bold')
ax1.text(5, 4.5, '"I know EVERYTHING!"', ha='center', fontsize=11, style='italic')
ax1.text(5, 3.5, '"...but what do you want?"', ha='center', fontsize=10, color='#666')

# Outputs
outputs = [
    ('Helpful', '#c8e6c9', 0.2),
    ('Harmful', '#ffcdd2', 0.3),
    ('Off-topic', '#fff3e0', 0.3),
    ('Rambling', '#e3f2fd', 0.2)
]

ax1.text(5, 8.5, 'Possible Outputs:', ha='center', fontsize=11)
start_x = 1
for label, color, width in outputs:
    # Scale by 8 (not 10) so all boxes plus gaps fit inside the 0-10 x-range
    box = FancyBboxPatch((start_x, 7.5), width*8, 0.7, boxstyle="round,pad=0.05",
                          facecolor=color, edgecolor='gray', linewidth=1)
    ax1.add_patch(box)
    ax1.text(start_x + width*4, 7.85, label, ha='center', fontsize=8)
    start_x += width * 8 + 0.2

# Right: After RLHF
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('AFTER RLHF\n(Aligned LLM)', fontsize=14, fontweight='bold', color='#388e3c')

# LLM box
llm_box2 = FancyBboxPatch((2, 3), 6, 4, boxstyle="round,pad=0.1",
                           facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax2.add_patch(llm_box2)
ax2.text(5, 5.5, 'RLHF-Trained LLM', ha='center', fontsize=14, fontweight='bold')
ax2.text(5, 4.5, '"I know EVERYTHING!"', ha='center', fontsize=11, style='italic')
ax2.text(5, 3.5, '"AND I know how to help!"', ha='center', fontsize=10, color='#388e3c', fontweight='bold')

# Outputs - mostly helpful now!
outputs2 = [
    ('Helpful', '#c8e6c9', 0.85),
    ('Other', '#e0e0e0', 0.15)
]

ax2.text(5, 8.5, 'Possible Outputs:', ha='center', fontsize=11)
start_x = 1
for label, color, width in outputs2:
    # Scale by 8 (not 10) so both boxes plus the gap fit inside the 0-10 x-range
    box = FancyBboxPatch((start_x, 7.5), width*8, 0.7, boxstyle="round,pad=0.05",
                          facecolor=color, edgecolor='gray', linewidth=1)
    ax2.add_patch(box)
    ax2.text(start_x + width*4, 7.85, label, ha='center', fontsize=9, fontweight='bold')
    start_x += width * 8 + 0.2

plt.tight_layout()
plt.show()

print("\n" + "="*70)
print("THE ALIGNMENT PROBLEM")
print("="*70)
print("""
Pretrained LLMs are trained on ONE objective:
  → Predict the next token (roughly, a word or word piece)

But we ACTUALLY want:
  → Helpful responses
  → Honest information
  → Safe behavior
  → Instruction following

RLHF bridges this gap by teaching the model human preferences!
""")
print("="*70)

---
## The RLHF Pipeline: Three Stages

RLHF transforms a pretrained LLM into a helpful assistant through three stages:

```
    ┌────────────────────────────────────────────────────────────────┐
    │                    THE RLHF PIPELINE                           │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │   ┌─────────────────┐                                         │
    │   │  PRETRAINED LLM │  GPT, LLaMA, etc.                       │
    │   └────────┬────────┘                                         │
    │            │                                                   │
    │            ▼                                                   │
    │   ┌─────────────────┐                                         │
    │   │  STAGE 1: SFT   │  Supervised Fine-Tuning                 │
    │   │                 │  "Learn from example conversations"     │
    │   └────────┬────────┘                                         │
    │            │                                                   │
    │            ▼                                                   │
    │   ┌─────────────────┐                                         │
    │   │  STAGE 2: RM    │  Reward Model Training                  │
    │   │                 │  "Learn what humans prefer"             │
    │   └────────┬────────┘                                         │
    │            │                                                   │
    │            ▼                                                   │
    │   ┌─────────────────┐                                         │
    │   │  STAGE 3: PPO   │  Reinforcement Learning                 │
    │   │                 │  "Optimize for human preferences"       │
    │   └────────┬────────┘                                         │
    │            │                                                   │
    │            ▼                                                   │
    │   ┌─────────────────┐                                         │
    │   │  ALIGNED LLM    │  ChatGPT, Claude, etc.                  │
    │   └─────────────────┘                                         │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

Let's understand each stage in detail!
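
One way to keep the stages apart is to look at the data each one consumes. The sketch below is illustrative only: the class names `Demonstration`, `PreferencePair`, and `RLPrompt` and the helper `stage_inputs` are hypothetical and not part of any library or of this notebook's later code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Demonstration:
    """Stage 1 (SFT): a prompt with a human-written ideal response."""
    prompt: str
    response: str


@dataclass(frozen=True)
class PreferencePair:
    """Stage 2 (RM): two responses to one prompt, labeled by a human."""
    prompt: str
    chosen: str
    rejected: str


@dataclass(frozen=True)
class RLPrompt:
    """Stage 3 (PPO): only a prompt; the policy writes the response itself."""
    prompt: str


def stage_inputs():
    """Return one example record per stage, keyed by stage name."""
    return {
        "sft": Demonstration(
            prompt="What is the capital of France?",
            response="The capital of France is Paris.",
        ),
        "rm": PreferencePair(
            prompt="Tell me a joke",
            chosen="Why did the chicken cross the road? To get to the other side!",
            rejected="Knock knock...",
        ),
        "ppo": RLPrompt(prompt="How do I learn Python?"),
    }


for stage, record in stage_inputs().items():
    print(f"{stage:>4}: {record}")
```

Note that only Stage 1 needs humans to write responses; Stage 2 needs comparisons, and Stage 3 needs prompts alone because the reward model supplies the feedback.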

In [None]:
# Visualize the three stages
fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(0, 14)
ax.set_ylim(0, 12)
ax.axis('off')

# Title
ax.text(7, 11.5, 'The RLHF Training Pipeline', ha='center', fontsize=18, fontweight='bold')

stages = [
    {
        'y': 9,
        'title': 'Stage 1: Supervised Fine-Tuning (SFT)',
        'color': '#bbdefb',
        'edge': '#1976d2',
        'icon': '1',
        'description': 'Train on high-quality demonstrations',
        'analogy': 'Like a student learning from example answers'
    },
    {
        'y': 6,
        'title': 'Stage 2: Reward Model (RM)',
        'color': '#fff3e0',
        'edge': '#f57c00',
        'icon': '2',
        'description': 'Train a model to predict human preferences',
        'analogy': 'Like training a judge to grade essays'
    },
    {
        'y': 3,
        'title': 'Stage 3: PPO Fine-Tuning',
        'color': '#c8e6c9',
        'edge': '#388e3c',
        'icon': '3',
        'description': 'Use RL to maximize reward model scores',
        'analogy': 'Like a student improving based on grades'
    }
]

for stage in stages:
    y = stage['y']
    
    # Stage box
    box = FancyBboxPatch((1, y-1), 12, 2, boxstyle="round,pad=0.1",
                          facecolor=stage['color'], edgecolor=stage['edge'], linewidth=3)
    ax.add_patch(box)
    
    # Stage number
    circle = Circle((2, y), 0.4, facecolor=stage['edge'], edgecolor='white', linewidth=2)
    ax.add_patch(circle)
    ax.text(2, y, stage['icon'], ha='center', va='center', fontsize=14, 
            fontweight='bold', color='white')
    
    # Title and description
    ax.text(3, y+0.4, stage['title'], fontsize=13, fontweight='bold', color=stage['edge'])
    ax.text(3, y-0.1, stage['description'], fontsize=11)
    ax.text(3, y-0.6, stage['analogy'], fontsize=10, style='italic', color='#666')

# Arrows between stages
for i in range(len(stages) - 1):
    y1 = stages[i]['y'] - 1
    y2 = stages[i+1]['y'] + 1
    ax.annotate('', xy=(7, y2), xytext=(7, y1),
                arrowprops=dict(arrowstyle='->', lw=3, color='#666'))

# Result
result_box = FancyBboxPatch((4, 0.2), 6, 1.2, boxstyle="round,pad=0.1",
                             facecolor='#f3e5f5', edgecolor='#7b1fa2', linewidth=3)
ax.add_patch(result_box)
ax.text(7, 0.8, 'Aligned, Helpful AI Assistant!', ha='center', fontsize=13, 
        fontweight='bold', color='#7b1fa2')

ax.annotate('', xy=(7, 1.4), xytext=(7, 2),
            arrowprops=dict(arrowstyle='->', lw=3, color='#7b1fa2'))

plt.tight_layout()
plt.show()

---
## Stage 1: Supervised Fine-Tuning (SFT)

The first stage teaches the model what good responses look like:

```
    ┌────────────────────────────────────────────────────────────────┐
    │            STAGE 1: SUPERVISED FINE-TUNING (SFT)               │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  WHAT: Train the LLM on high-quality demonstration data       │
    │                                                                │
    │  HOW:                                                          │
    │    1. Collect prompts from users                              │
    │    2. Human experts write ideal responses                     │
    │    3. Fine-tune the LLM on these (prompt, response) pairs     │
    │                                                                │
    │  EXAMPLE:                                                      │
    │    Prompt: "Explain quantum computing to a 5-year-old"        │
    │    Response: "Imagine you have a magic coin that can be       │
    │              both heads AND tails at the same time..."        │
    │                                                                │
    │  ANALOGY: Learning by Imitation                               │
    │    Like a student reading model essay answers before          │
    │    writing their own. They learn the format, tone, and        │
    │    style that's expected.                                     │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
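
The imitation objective can be made concrete with a toy calculation. The sketch below assumes we already have the probability the model assigned to each correct next token (the numbers are invented for illustration) and averages the negative log-likelihood over the response tokens only. Masking out the prompt tokens is a common convention, assumed here rather than taken from this notebook.

```python
import math


def sft_loss(token_probs, is_response):
    """Mean negative log-likelihood over response tokens.

    token_probs: probability the model gave to each correct token.
    is_response: True where the token belongs to the response,
                 False where it belongs to the prompt (masked out).
    """
    nll = [-math.log(p) for p, keep in zip(token_probs, is_response) if keep]
    if not nll:
        raise ValueError("need at least one response token")
    return sum(nll) / len(nll)


# Hypothetical 6-token sequence: 3 prompt tokens followed by 3 response tokens.
mask = [False, False, False, True, True, True]

# Early in training the model is unsure about the response tokens.
before = sft_loss([0.10, 0.20, 0.05, 0.50, 0.25, 0.80], mask)

# Later it assigns them higher probability, so the loss drops.
after = sft_loss([0.10, 0.20, 0.05, 0.90, 0.70, 0.95], mask)

print(f"loss before: {before:.3f}, after: {after:.3f}")
```

The prompt-token probabilities are identical in both calls and do not affect either loss, which is the point of the mask.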

In [None]:
# Demonstrate SFT data format

print("STAGE 1: SUPERVISED FINE-TUNING (SFT)")
print("="*70)

# Example SFT training data
sft_examples = [
    {
        "prompt": "What is the capital of France?",
        "response": "The capital of France is Paris. It's known as the 'City of Light' and is home to famous landmarks like the Eiffel Tower and the Louvre Museum."
    },
    {
        "prompt": "Write a haiku about programming.",
        "response": "Lines of code flow free\nDebugging through the night sky\nProgram comes alive"
    },
    {
        "prompt": "How do I make a peanut butter sandwich?",
        "response": "Here's how to make a peanut butter sandwich:\n1. Get two slices of bread\n2. Spread peanut butter on one slice\n3. Optionally add jelly on the other slice\n4. Press the slices together\n5. Enjoy!"
    }
]

print("\nExample SFT Training Data:")
print("-"*70)

for i, example in enumerate(sft_examples, 1):
    print(f"\n[Example {i}]")
    print(f"PROMPT: {example['prompt']}")
    response = example['response']
    # Truncate long responses for display
    preview = response if len(response) <= 100 else response[:100] + "..."
    print(f"RESPONSE: {preview}")

print("\n" + "="*70)
print("SFT TRAINING OBJECTIVE")
print("="*70)
print("""
The model learns to:
  P(response | prompt) → Maximize likelihood of human-written responses

This is standard language model fine-tuning!
  Loss = -log P(response | prompt)

After SFT, the model:
  ✓ Knows the format of helpful responses
  ✓ Can follow basic instructions
  ✗ Still doesn't know WHICH response is BEST
""")
print("="*70)

In [None]:
# Simple simulation of SFT

class SimpleSFTModel:
    """
    A simplified demonstration of SFT.
    
    In reality, this would be a transformer LLM.
    Here we just simulate the training process.
    """
    
    def __init__(self):
        self.training_data = []
        self.trained = False
    
    def add_demonstration(self, prompt, response):
        """Add a human demonstration to training data."""
        self.training_data.append({
            'prompt': prompt,
            'response': response
        })
    
    def train(self, epochs=3):
        """Simulate training on demonstrations."""
        print(f"Training on {len(self.training_data)} demonstrations...")
        
        for epoch in range(epochs):
            # Simulate loss decreasing
            loss = max(2.5 - epoch * 0.7, 0.1)  # floor keeps the simulated loss positive for larger epoch counts
            print(f"  Epoch {epoch+1}: Loss = {loss:.2f}")
        
        self.trained = True
        print("Training complete!")
    
    def generate(self, prompt):
        """Generate a response (simplified)."""
        if not self.trained:
            return "[Model not trained yet]"
        
        # In reality, this would be transformer generation
        # Here we just show the model learned the pattern
        preview = prompt if len(prompt) <= 30 else prompt[:30] + "..."
        return f"[SFT Model Response to: '{preview}']"


# Demonstrate SFT
print("\nSIMULATED SFT TRAINING")
print("-"*50)

sft_model = SimpleSFTModel()

# Add demonstrations
for example in sft_examples:
    sft_model.add_demonstration(example['prompt'], example['response'])

# Train
sft_model.train(epochs=3)

print("\n" + "-"*50)
print("After SFT, the model can generate helpful-LOOKING responses.")
print("But it doesn't know which response is BEST!")
print("That's where Stage 2 comes in...")

---
## Stage 2: Reward Model Training

The reward model learns to predict which responses humans prefer:

```
    ┌────────────────────────────────────────────────────────────────┐
    │               STAGE 2: REWARD MODEL (RM)                       │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  WHAT: Train a model to score responses based on human prefs  │
    │                                                                │
    │  HOW:                                                          │
    │    1. Generate multiple responses to the same prompt          │
    │    2. Humans rank them (e.g., "A is better than B")           │
    │    3. Train RM to predict these rankings                      │
    │                                                                │
    │  EXAMPLE:                                                      │
    │    Prompt: "Tell me a joke"                                   │
    │    Response A: "Why did the chicken cross the road? To get    │
    │                 to the other side!" → Score: 7.2              │
    │    Response B: "Knock knock... [unfunny joke]" → Score: 3.1   │
    │                                                                │
    │    Human ranked: A > B                                        │
    │    RM learns: score(A) > score(B)                             │
    │                                                                │
    │  ANALOGY: The Essay Grader                                    │
    │    Like training a teaching assistant to grade essays the     │
    │    same way the professor would. Once trained, they can       │
    │    grade thousands of essays automatically!                   │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
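
Steps 2 and 3 in the box can be sketched in a few lines. Assuming annotators return a full best-first ranking of several responses, every pair in that ranking becomes one training comparison, and the Bradley-Terry model turns a score gap into a preference probability. The function names are illustrative, and the scores 7.2 and 3.1 are the ones from the example above.

```python
import math
from itertools import combinations


def ranking_to_pairs(ranked_responses):
    """Expand a best-first ranking into (chosen, rejected) pairs."""
    # combinations() preserves input order, so the first element of
    # each pair is always the higher-ranked response.
    return list(combinations(ranked_responses, 2))


def preference_probability(score_chosen, score_rejected):
    """Bradley-Terry: P(chosen beats rejected) = sigmoid(score gap)."""
    return 1.0 / (1.0 + math.exp(-(score_chosen - score_rejected)))


def pairwise_loss(score_chosen, score_rejected):
    """Negative log-probability that the human-chosen response wins."""
    return -math.log(preference_probability(score_chosen, score_rejected))


pairs = ranking_to_pairs(["A", "B", "C"])
print(pairs)  # a ranking of 3 responses yields 3 comparisons

p = preference_probability(7.2, 3.1)
print(f"P(A preferred over B) = {p:.3f}")
print(f"loss when the RM agrees with the human: {pairwise_loss(7.2, 3.1):.3f}")
print(f"loss when the RM has the scores swapped: {pairwise_loss(3.1, 7.2):.3f}")
```

Only the difference between the two scores enters the probability, so the absolute scale of the scores (7.2 versus 3.1) is arbitrary.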

In [None]:
# Visualize the reward model concept

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Comparison data collection
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('Step 1: Collect Human Preferences', fontsize=14, fontweight='bold')

# Prompt
prompt_box = FancyBboxPatch((1, 7), 8, 1.5, boxstyle="round,pad=0.1",
                             facecolor='#e3f2fd', edgecolor='#1976d2', linewidth=2)
ax1.add_patch(prompt_box)
ax1.text(5, 7.9, 'Prompt: "How do I learn Python?"', ha='center', fontsize=11, fontweight='bold')

# Response A
resp_a = FancyBboxPatch((0.5, 3.5), 4, 3, boxstyle="round,pad=0.1",
                         facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax1.add_patch(resp_a)
ax1.text(2.5, 5.8, 'Response A', ha='center', fontsize=11, fontweight='bold')
ax1.text(2.5, 5.0, '"Start with basics like\nvariables and loops.\nTry Codecademy..."', 
         ha='center', fontsize=9)
ax1.text(2.5, 3.8, 'WINNER', ha='center', fontsize=10, color='#388e3c', fontweight='bold')

# Response B
resp_b = FancyBboxPatch((5.5, 3.5), 4, 3, boxstyle="round,pad=0.1",
                         facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=2)
ax1.add_patch(resp_b)
ax1.text(7.5, 5.8, 'Response B', ha='center', fontsize=11, fontweight='bold')
ax1.text(7.5, 5.0, '"Python is a\nprogramming language\nused for..."', 
         ha='center', fontsize=9)

# Human annotator
ax1.text(5, 1.5, 'Human: "Response A is better!"', ha='center', fontsize=12, 
         style='italic', color='#666')
ax1.text(5, 0.7, '(More helpful, actionable advice)', ha='center', fontsize=10, color='#888')

# Arrows
ax1.annotate('', xy=(2.5, 6.5), xytext=(3.5, 7),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax1.annotate('', xy=(7.5, 6.5), xytext=(6.5, 7),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Right: Reward Model Training
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Step 2: Train Reward Model', fontsize=14, fontweight='bold')

# RM box
rm_box = FancyBboxPatch((2, 4), 6, 3, boxstyle="round,pad=0.1",
                         facecolor='#fff3e0', edgecolor='#f57c00', linewidth=3)
ax2.add_patch(rm_box)
ax2.text(5, 6.2, 'REWARD MODEL', ha='center', fontsize=14, fontweight='bold', color='#f57c00')
ax2.text(5, 5.2, 'RM(prompt, response) → score', ha='center', fontsize=11)
ax2.text(5, 4.5, '(Higher score = Better response)', ha='center', fontsize=10, color='#666')

# Input/Output
ax2.text(5, 8.5, 'Input: (prompt, response) pair', ha='center', fontsize=11)
ax2.annotate('', xy=(5, 7), xytext=(5, 8),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

ax2.text(5, 2, 'Output: Score (e.g., 7.2)', ha='center', fontsize=11)
ax2.annotate('', xy=(5, 2.5), xytext=(5, 4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#f57c00'))

# Training objective
ax2.text(5, 0.8, 'Training: Learn that score(A) > score(B)', ha='center', 
         fontsize=11, style='italic', color='#f57c00')

plt.tight_layout()
plt.show()

In [None]:
# Implement a simple Reward Model

class SimpleRewardModel(nn.Module):
    """
    A simplified Reward Model.
    
    In reality, this would be a transformer that takes
    (prompt, response) and outputs a scalar score.
    
    Here we simulate with a simple network.
    """
    
    def __init__(self, input_dim=10):
        super().__init__()
        
        # Simple network that outputs a score
        self.network = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1)  # Output: scalar score
        )
    
    def forward(self, x):
        """Return a scalar reward score."""
        return self.network(x)


def reward_model_loss(rm, chosen, rejected):
    """
    Reward Model Loss: Bradley-Terry model.
    
    We want: score(chosen) > score(rejected)
    
    Loss = -log(sigmoid(score_chosen - score_rejected))
    
    This is the key insight: we train on PREFERENCES, not absolute scores!
    """
    score_chosen = rm(chosen)
    score_rejected = rm(rejected)
    
    # We want chosen to have higher score
    loss = -F.logsigmoid(score_chosen - score_rejected).mean()
    
    return loss, score_chosen.mean().item(), score_rejected.mean().item()


# Demonstrate RM training
print("REWARD MODEL TRAINING DEMONSTRATION")
print("="*60)

# Create RM
rm = SimpleRewardModel(input_dim=10)
optimizer = torch.optim.Adam(rm.parameters(), lr=0.01)

# Simulate preference data
# In reality, this would be encoded (prompt, response) pairs
# Here we use random vectors where "chosen" has a specific pattern
n_pairs = 100

# Chosen responses have slightly higher "quality signal" in dimension 0
chosen_data = torch.randn(n_pairs, 10)
chosen_data[:, 0] += 1  # Add quality signal

rejected_data = torch.randn(n_pairs, 10)
rejected_data[:, 0] -= 1  # Lower quality signal

print("\nTraining RM on preference pairs...")
print("-"*60)

losses = []
for epoch in range(100):
    optimizer.zero_grad()
    loss, score_c, score_r = reward_model_loss(rm, chosen_data, rejected_data)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    
    if (epoch + 1) % 20 == 0:
        # Accuracy is only for reporting, so skip gradient tracking
        with torch.no_grad():
            accuracy = (rm(chosen_data) > rm(rejected_data)).float().mean().item()
        print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}, "
              f"Chosen: {score_c:.2f}, Rejected: {score_r:.2f}, "
              f"Accuracy: {accuracy:.1%}")

print("\n" + "="*60)
print("The RM learned to give HIGHER scores to CHOSEN responses!")
print("="*60)

In [None]:
# Visualize RM training
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Training loss
ax1 = axes[0]
ax1.plot(losses, color='#f57c00', linewidth=2)
ax1.set_xlabel('Training Step', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Reward Model Training Loss', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Right: Score distribution
ax2 = axes[1]

with torch.no_grad():
    chosen_scores = rm(chosen_data).numpy().flatten()
    rejected_scores = rm(rejected_data).numpy().flatten()

ax2.hist(chosen_scores, bins=20, alpha=0.7, label='Chosen (Preferred)', color='#4caf50')
ax2.hist(rejected_scores, bins=20, alpha=0.7, label='Rejected', color='#f44336')
ax2.axvline(x=np.mean(chosen_scores), color='#388e3c', linestyle='--', linewidth=2, label=f'Chosen mean: {np.mean(chosen_scores):.2f}')
ax2.axvline(x=np.mean(rejected_scores), color='#d32f2f', linestyle='--', linewidth=2, label=f'Rejected mean: {np.mean(rejected_scores):.2f}')
ax2.set_xlabel('Reward Score', fontsize=12)
ax2.set_ylabel('Count', fontsize=12)
ax2.set_title('Learned Score Distribution', fontsize=14, fontweight='bold')
ax2.legend()

plt.tight_layout()
plt.show()

print("\nThe Reward Model has learned to distinguish good from bad responses!")
print("Chosen responses get higher scores (green) than rejected ones (red).")

---
## Stage 3: PPO Fine-Tuning

Now we use reinforcement learning to optimize the LLM!

```
    ┌────────────────────────────────────────────────────────────────┐
    │               STAGE 3: PPO FINE-TUNING                         │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  WHAT: Use RL to make the LLM generate high-reward responses  │
    │                                                                │
    │  THE RL SETUP:                                                │
    │    • State: The prompt plus tokens generated so far          │
    │    • Action: Each token the LLM generates                     │
    │    • Reward: Score from the Reward Model                      │
    │    • Policy: The LLM itself!                                  │
    │                                                                │
    │  PROCESS:                                                      │
    │    1. LLM generates a response                                │
    │    2. Reward Model scores the response                        │
    │    3. PPO updates LLM to generate better responses            │
    │    4. Repeat!                                                 │
    │                                                                │
    │  THE KL PENALTY (Critical!):                                  │
    │    Reward = RM_score - β × KL(policy || reference)            │
    │                                                                │
    │    Without KL penalty: LLM might "hack" the reward model     │
    │    With KL penalty: LLM stays close to the SFT model         │
    │                                                                │
    │  ANALOGY: Studying with a Tutor                               │
    │    - LLM writes essay (generates response)                    │
    │    - Tutor grades it (reward model)                          │
    │    - LLM learns from feedback (PPO update)                   │
    │    - KL penalty: "Stay true to what you learned in class!"   │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
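
To see how the KL term can enter in practice, here is a minimal sketch of one common way to shape rewards per token. It assumes we have the log-probability each sampled token received under the current policy and under the frozen SFT reference (the numbers below are made up), uses the simple sample estimate KL ≈ Σ (log π − log π_ref) over the generated tokens, and adds the reward model score at the final token. The function name `shaped_rewards` and the β value are illustrative; real implementations differ in details.

```python
def shaped_rewards(policy_logprobs, ref_logprobs, rm_score, beta=0.1):
    """Per-token rewards: -beta * (log pi - log pi_ref), plus RM score at the end.

    Returns the reward list and the summed log-ratio (a KL estimate).
    """
    if len(policy_logprobs) != len(ref_logprobs) or not policy_logprobs:
        raise ValueError("log-prob lists must be non-empty and of equal length")

    # Per-token log-ratio; the sum is a sample estimate of KL(policy || reference).
    log_ratios = [p - r for p, r in zip(policy_logprobs, ref_logprobs)]

    rewards = [-beta * lr for lr in log_ratios]
    rewards[-1] += rm_score  # the reward model scores the whole response once
    return rewards, sum(log_ratios)


# Hypothetical log-probs for a 4-token response.
policy_lp = [-0.2, -0.1, -0.3, -0.1]
ref_lp = [-0.5, -0.4, -0.3, -0.6]

rewards, kl_estimate = shaped_rewards(policy_lp, ref_lp, rm_score=7.3, beta=0.1)
total = sum(rewards)

print("per-token rewards:", [round(r, 3) for r in rewards])
print(f"KL estimate: {kl_estimate:.2f}, total reward: {total:.2f}")
```

Summing the per-token rewards recovers the formula in the box: RM_score − β × KL.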

In [None]:
# Visualize the PPO training loop for LLMs

fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(0, 14)
ax.set_ylim(0, 12)
ax.axis('off')

ax.text(7, 11.5, 'PPO Fine-Tuning Loop for LLMs', ha='center', fontsize=18, fontweight='bold')

# 1. Prompt
prompt_box = FancyBboxPatch((0.5, 8.5), 3, 1.5, boxstyle="round,pad=0.1",
                             facecolor='#e3f2fd', edgecolor='#1976d2', linewidth=2)
ax.add_patch(prompt_box)
ax.text(2, 9.5, '1. PROMPT', ha='center', fontsize=11, fontweight='bold')
ax.text(2, 9, '"How do I..."', ha='center', fontsize=10, style='italic')

# 2. LLM Policy
llm_box = FancyBboxPatch((5, 8.5), 4, 1.5, boxstyle="round,pad=0.1",
                          facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax.add_patch(llm_box)
ax.text(7, 9.5, '2. LLM (Policy)', ha='center', fontsize=11, fontweight='bold')
ax.text(7, 9, 'Generates response', ha='center', fontsize=10)

# 3. Response
resp_box = FancyBboxPatch((10.5, 8.5), 3, 1.5, boxstyle="round,pad=0.1",
                           facecolor='#fff3e0', edgecolor='#f57c00', linewidth=2)
ax.add_patch(resp_box)
ax.text(12, 9.5, '3. RESPONSE', ha='center', fontsize=11, fontweight='bold')
ax.text(12, 9, '"You should..."', ha='center', fontsize=10, style='italic')

# 4. Reward Model
rm_box = FancyBboxPatch((8.5, 5.5), 4, 2, boxstyle="round,pad=0.1",
                         facecolor='#f3e5f5', edgecolor='#7b1fa2', linewidth=2)
ax.add_patch(rm_box)
ax.text(10.5, 6.8, '4. REWARD MODEL', ha='center', fontsize=11, fontweight='bold')
ax.text(10.5, 6.2, 'Scores response', ha='center', fontsize=10)
ax.text(10.5, 5.8, 'Score: 7.3', ha='center', fontsize=10, color='#7b1fa2', fontweight='bold')

# 5. KL Penalty
kl_box = FancyBboxPatch((3.5, 5.5), 4, 2, boxstyle="round,pad=0.1",
                         facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=2)
ax.add_patch(kl_box)
ax.text(5.5, 6.8, '5. KL PENALTY', ha='center', fontsize=11, fontweight='bold')
ax.text(5.5, 6.2, 'Stay close to SFT', ha='center', fontsize=10)
ax.text(5.5, 5.8, 'Penalty: -0.5', ha='center', fontsize=10, color='#d32f2f', fontweight='bold')

# 6. Total Reward
reward_box = FancyBboxPatch((5.5, 3), 3, 1.5, boxstyle="round,pad=0.1",
                             facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax.add_patch(reward_box)
ax.text(7, 4, '6. TOTAL REWARD', ha='center', fontsize=11, fontweight='bold')
ax.text(7, 3.4, '7.3 - 0.5 = 6.8', ha='center', fontsize=10, fontweight='bold', color='#388e3c')

# 7. PPO Update
ppo_box = FancyBboxPatch((5.5, 0.5), 3, 1.5, boxstyle="round,pad=0.1",
                          facecolor='#bbdefb', edgecolor='#1976d2', linewidth=2)
ax.add_patch(ppo_box)
ax.text(7, 1.5, '7. PPO UPDATE', ha='center', fontsize=11, fontweight='bold')
ax.text(7, 1, 'Improve policy', ha='center', fontsize=10)

# Arrows
ax.annotate('', xy=(4.9, 9.25), xytext=(3.6, 9.25),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(10.4, 9.25), xytext=(9.1, 9.25),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(10.5, 7.5), xytext=(10.5, 8.4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(5.5, 7.5), xytext=(7, 8.4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(7, 4.5), xytext=(5.5, 5.4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(7, 4.5), xytext=(8.5, 5.4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(7, 2), xytext=(7, 2.9),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Loop back arrow
ax.annotate('', xy=(5.4, 8.5), xytext=(5.4, 1.25),
             arrowprops=dict(arrowstyle='->', lw=2, color='#1976d2',
                            connectionstyle='arc3,rad=-0.3'))
ax.text(3.5, 5, 'REPEAT', fontsize=10, color='#1976d2', fontweight='bold', rotation=90)

plt.tight_layout()
plt.show()

In [None]:
# Demonstrate the RLHF reward calculation

print("RLHF REWARD CALCULATION")
print("="*60)

def calculate_rlhf_reward(rm_score, kl_divergence, beta=0.1):
    """
    Calculate the RLHF reward.
    
    Reward = RM_score - beta * KL(policy || reference)
    
    Args:
        rm_score: Score from reward model
        kl_divergence: KL divergence from reference (SFT) model
        beta: Weight of KL penalty
    
    Returns:
        Total RLHF reward
    """
    kl_penalty = beta * kl_divergence
    total_reward = rm_score - kl_penalty
    return total_reward, kl_penalty

# Example scenarios
scenarios = [
    {"name": "Good response, normal", "rm_score": 8.0, "kl": 0.5},
    {"name": "Great response, normal", "rm_score": 9.5, "kl": 0.8},
    {"name": "Hacked response (high RM, high KL)", "rm_score": 10.0, "kl": 5.0},
    {"name": "Safe response (moderate RM, low KL)", "rm_score": 7.0, "kl": 0.2},
]

print("\nScenarios with beta = 0.1:")
print("-"*60)
print(f"{'Scenario':<35} {'RM':<6} {'KL':<6} {'Penalty':<8} {'Total':<8}")
print("-"*60)

for s in scenarios:
    total, penalty = calculate_rlhf_reward(s["rm_score"], s["kl"], beta=0.1)
    # Two decimals: with beta = 0.1 the penalties (0.02 to 0.50) are too small for one
    print(f"{s['name']:<35} {s['rm_score']:<6.1f} {s['kl']:<6.1f} {penalty:<8.2f} {total:<8.2f}")

print("\n" + "="*60)
print("KEY INSIGHT: The KL penalty")
print("="*60)
print("""
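The 'Hacked response' gets the HIGHEST RM score (10.0),
but the KL penalty (0.5) pulls its total down to 9.5.
With beta = 0.1 it still edges out the 'Great response' (9.42);
a larger beta would penalize the drift more heavily.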

This prevents the LLM from:
  • Exploiting loopholes in the reward model
  • Generating nonsensical but high-scoring text
  • Drifting too far from coherent language

The KL penalty says: "Stay close to what you learned!"
""")
print("="*60)
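
To see how much the choice of beta matters, the sketch below recomputes the totals for two of the scenarios from the cell above across several beta values. The beta grid and the helper name `total_reward` are assumptions chosen for illustration; the scenario numbers are the same made-up values used above.

```python
# Sketch: how the choice of beta changes which response wins.
# Scenario values are the illustrative numbers from the cell above, not measurements.

def total_reward(rm_score, kl, beta):
    """RLHF reward: reward-model score minus the weighted KL penalty."""
    return rm_score - beta * kl

great = {"rm_score": 9.5, "kl": 0.8}    # "Great response, normal"
hacked = {"rm_score": 10.0, "kl": 5.0}  # "Hacked response (high RM, high KL)"

print(f"{'beta':<6} {'Great':<8} {'Hacked':<8} {'Winner'}")
for beta in [0.0, 0.1, 0.2, 0.5]:  # hypothetical grid of beta values
    g = total_reward(great["rm_score"], great["kl"], beta)
    h = total_reward(hacked["rm_score"], hacked["kl"], beta)
    winner = "Hacked" if h > g else "Great"
    print(f"{beta:<6.2f} {g:<8.2f} {h:<8.2f} {winner}")

# Break-even beta solves: 9.5 - 0.8 * beta = 10.0 - 5.0 * beta
break_even = (hacked["rm_score"] - great["rm_score"]) / (hacked["kl"] - great["kl"])
print(f"Break-even beta: {break_even:.3f}")
```

With these illustrative numbers the break-even sits near beta = 0.119: below it the high-KL response still wins, above it the 'Great response' comes out ahead. In practice beta is a tuning knob, and the right value depends on how much the reward model can be trusted.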

---
## Why RLHF Works: The Key Insights

```
    ┌────────────────────────────────────────────────────────────────┐
    │                    WHY RLHF WORKS                              │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  1. PREFERENCES ARE EASIER THAN DEMONSTRATIONS                 │
    │     Telling which response is BETTER is easier than            │
    │     writing the PERFECT response yourself.                     │
    │                                                                │
    │  2. PREFERENCES SCALE BETTER                                   │
    │     Non-experts can compare responses.                         │
    │     Writing expert demonstrations is expensive!                │
    │                                                                │
    │  3. PREFERENCES CAPTURE SUBTLE QUALITIES                       │
    │     "Helpful", "harmless", "honest" are hard to define,      │
    │     but humans can recognize them when they see them!          │
    │                                                                │
    │  4. RL OPTIMIZES FOR WHAT HUMANS ACTUALLY WANT                 │
    │     Instead of just imitating (SFT), RL actively seeks         │
    │     to maximize the qualities humans prefer.                   │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
# Visualize the comparison: Demonstrations vs Preferences

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Demonstrations (SFT)
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('Demonstrations (for SFT)\n"Write the perfect answer"', fontsize=14, fontweight='bold')

# Expert writing
expert_box = FancyBboxPatch((2, 5), 6, 3, boxstyle="round,pad=0.1",
                             facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=2)
ax1.add_patch(expert_box)
ax1.text(5, 7.2, 'Expert writes response', ha='center', fontsize=12, fontweight='bold')
ax1.text(5, 6.5, 'Time: 5-10 minutes per response', ha='center', fontsize=10, color='#d32f2f')
ax1.text(5, 5.8, 'Cost: $$$ (expert time)', ha='center', fontsize=10, color='#d32f2f')
ax1.text(5, 5.2, 'Scale: 10,000 examples', ha='center', fontsize=10)

ax1.text(5, 3, 'Requires domain expertise', ha='center', fontsize=11, style='italic', color='#666')
ax1.text(5, 2, 'Hard to scale', ha='center', fontsize=11, style='italic', color='#666')

# Right: Preferences (RM)
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Preferences (for RM)\n"Which response is better?"', fontsize=14, fontweight='bold')

# Human comparing
pref_box = FancyBboxPatch((2, 5), 6, 3, boxstyle="round,pad=0.1",
                           facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax2.add_patch(pref_box)
ax2.text(5, 7.2, 'Human compares responses', ha='center', fontsize=12, fontweight='bold')
ax2.text(5, 6.5, 'Time: 30 seconds per comparison', ha='center', fontsize=10, color='#388e3c')
ax2.text(5, 5.8, 'Cost: $ (crowdsourcing)', ha='center', fontsize=10, color='#388e3c')
ax2.text(5, 5.2, 'Scale: 100,000+ comparisons', ha='center', fontsize=10)

ax2.text(5, 3, 'Anyone can judge "which is better"', ha='center', fontsize=11, style='italic', color='#388e3c')
ax2.text(5, 2, 'Scales easily!', ha='center', fontsize=11, style='italic', color='#388e3c', fontweight='bold')

plt.tight_layout()
plt.show()

print("\nKEY INSIGHT:")
print("="*60)
print("It's MUCH easier to say 'A is better than B' than to write A!")
print("With the illustrative timings above (30 seconds vs. 5-10 minutes), a comparison takes roughly 10-20x less time than a demonstration.")
print("="*60)
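
As a rough check on the scaling argument, the following back-of-the-envelope sketch converts the illustrative figures from the plot (5-10 minutes per demonstration, 30 seconds per comparison, 10,000 demonstrations, 100,000 comparisons) into labeling hours. These are the plot's example numbers, not measured costs, and the variable names are chosen here for illustration.

```python
# Back-of-the-envelope labeling time, using the illustrative figures from the plot above.
# None of these numbers are measurements.

demo_minutes_low, demo_minutes_high = 5, 10  # minutes per written demonstration
comparison_minutes = 0.5                     # 30 seconds per comparison

n_demos = 10_000
n_comparisons = 100_000

demo_hours_low = n_demos * demo_minutes_low / 60
demo_hours_high = n_demos * demo_minutes_high / 60
comparison_hours = n_comparisons * comparison_minutes / 60

# Per-item ratio: how many comparisons fit in the time of one demonstration
ratio_low = demo_minutes_low / comparison_minutes
ratio_high = demo_minutes_high / comparison_minutes

print(f"Demonstrations: {demo_hours_low:,.0f} to {demo_hours_high:,.0f} hours for {n_demos:,} items")
print(f"Comparisons:    {comparison_hours:,.0f} hours for {n_comparisons:,} items")
print(f"One demonstration takes as long as {ratio_low:.0f} to {ratio_high:.0f} comparisons")
```

Under these assumptions, ten times as many comparisons fit into the low-end time budget for demonstrations. The sketch ignores the difference in hourly cost between experts and crowd workers, which would widen the gap further.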

---
## The History of RLHF

```
    ┌────────────────────────────────────────────────────────────────┐
    │                    RLHF TIMELINE                               │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  2017  Deep RL from Human Preferences                         │
    │        ├── First RLHF paper (Christiano et al.)               │
    │        └── Applied to Atari games, not language               │
    │                                                                │
    │  2020  Learning to Summarize from Human Feedback              │
    │        ├── OpenAI applies RLHF to text summarization          │
    │        └── Shows RLHF > SFT for summarization quality         │
    │                                                                │
    │  2022  InstructGPT Paper                                      │
    │        ├── Full RLHF pipeline for general instruction         │
    │        ├── Establishes the 3-stage process                    │
    │        └── Foundation for ChatGPT                             │
    │                                                                │
    │  Nov   ChatGPT Released                                       │
    │  2022  ├── RLHF at scale                                      │
    │        ├── 100M+ users in months                              │
    │        └── Proves RLHF works in production                    │
    │                                                                │
    │  2023  Alternatives Emerge                                    │
    │        ├── DPO (Direct Preference Optimization)               │
    │        ├── RLAIF (RL from AI Feedback)                        │
    │        └── Constitutional AI                                  │
    │                                                                │
    │  2024  RLHF becomes standard for LLM alignment                │
    │        └── Used by OpenAI, Anthropic, Google, Meta, etc.      │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
# Visualize the timeline

fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(2016, 2025.5)  # extra room so the 2024 label box is not clipped
ax.set_ylim(0, 10)

# Timeline
ax.axhline(y=5, color='#333', linewidth=3)

# Events
events = [
    {'year': 2017, 'y': 7, 'title': 'First RLHF Paper', 'desc': 'Applied to Atari', 'color': '#1976d2'},
    {'year': 2020, 'y': 3, 'title': 'Summarization', 'desc': 'RLHF for text', 'color': '#388e3c'},
    {'year': 2022, 'y': 7.5, 'title': 'InstructGPT', 'desc': '3-stage pipeline', 'color': '#f57c00'},
    {'year': 2022.8, 'y': 2.5, 'title': 'ChatGPT', 'desc': 'RLHF at scale', 'color': '#d32f2f'},
    {'year': 2023.5, 'y': 7, 'title': 'DPO & Alternatives', 'desc': 'Simpler methods', 'color': '#7b1fa2'},
    {'year': 2024, 'y': 3, 'title': 'Industry Standard', 'desc': 'Used everywhere', 'color': '#00838f'},
]

for event in events:
    # Dot on timeline
    ax.scatter(event['year'], 5, s=200, c=event['color'], zorder=5, edgecolors='white', linewidths=2)
    
    # Line to label
    ax.plot([event['year'], event['year']], [5, event['y']], color=event['color'], linewidth=2, linestyle='--')
    
    # Label box
    box = FancyBboxPatch((event['year']-0.4, event['y']-0.5 if event['y'] > 5 else event['y']-0.8), 
                          1.5, 1.3, boxstyle="round,pad=0.05",
                          facecolor=event['color'], edgecolor='white', linewidth=2, alpha=0.9)
    ax.add_patch(box)
    ax.text(event['year']+0.35, event['y']+0.3 if event['y'] > 5 else event['y'], 
            event['title'], ha='center', fontsize=9, fontweight='bold', color='white')
    ax.text(event['year']+0.35, event['y']-0.2 if event['y'] > 5 else event['y']-0.5, 
            event['desc'], ha='center', fontsize=8, color='white')

ax.set_xlabel('Year', fontsize=12)
ax.set_title('The Rise of RLHF', fontsize=16, fontweight='bold')
ax.set_yticks([])
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.show()

---
## Summary: Key Takeaways

### What is RLHF?

RLHF (Reinforcement Learning from Human Feedback) = Teaching LLMs to follow human preferences through reinforcement learning

### The Three Stages

| Stage | Name | What It Does | Analogy |
|-------|------|--------------|----------|
| 1 | SFT | Learn from example responses | Student reading model essays |
| 2 | RM | Learn what humans prefer | Training a grading assistant |
| 3 | PPO | Optimize for preferences | Student improving from grades |

### Key Concepts

| Concept | Description |
|---------|-------------|
| **Reward Model** | Predicts human preference scores |
| **KL Penalty** | Keeps the policy from drifting too far from the SFT model |
| **Preference Data** | Pairs of (chosen, rejected) responses |
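
To make the Reward Model and Preference Data rows concrete, here is a minimal sketch of the pairwise loss commonly used with the Bradley-Terry model: the probability that the chosen response wins is the sigmoid of the score difference, and the loss is its negative log. The function names and example scores are hypothetical; a real reward model would produce the scores from a neural network rather than hard-coded numbers.

```python
import math

def preference_probability(score_chosen, score_rejected):
    """Bradley-Terry probability that 'chosen' beats 'rejected': sigmoid of the score gap."""
    return 1.0 / (1.0 + math.exp(-(score_chosen - score_rejected)))

def pairwise_loss(score_chosen, score_rejected):
    """Negative log-likelihood of the human label under the Bradley-Terry model."""
    return -math.log(preference_probability(score_chosen, score_rejected))

# Hypothetical reward-model scores for three (chosen, rejected) pairs
pairs = [
    (2.0, 0.5),  # model agrees with the human label
    (1.0, 1.0),  # model is undecided
    (0.2, 1.5),  # model disagrees with the human label
]

print(f"{'chosen':<8} {'rejected':<9} {'P(chosen wins)':<15} {'loss':<6}")
for chosen, rejected in pairs:
    p = preference_probability(chosen, rejected)
    loss = pairwise_loss(chosen, rejected)
    print(f"{chosen:<8.1f} {rejected:<9.1f} {p:<15.3f} {loss:<6.3f}")
```

The loss is smallest when the model scores the chosen response well above the rejected one and grows when the ranking is reversed, so minimizing it over many pairs pushes score(chosen) above score(rejected).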

### Why RLHF Works

1. Preferences are easier to provide than demonstrations
2. Scales better (crowdsourcing)
3. Captures subtle qualities
4. Optimizes for what humans actually want

---
## Test Your Understanding

**1. What problem does RLHF solve?**
<details>
<summary>Click to reveal answer</summary>
RLHF solves the alignment problem: pretrained LLMs are good at predicting text but don't know what responses humans actually want. They might be unhelpful, harmful, or off-topic. RLHF teaches them human preferences.
</details>

**2. What are the three stages of RLHF?**
<details>
<summary>Click to reveal answer</summary>
1. Supervised Fine-Tuning (SFT): Train on human-written demonstrations
2. Reward Model Training: Train a model to predict human preferences
3. PPO Fine-Tuning: Use RL to maximize reward model scores, with a KL penalty that keeps the policy close to the SFT model
</details>

**3. Why is preference data easier to collect than demonstration data?**
<details>
<summary>Click to reveal answer</summary>
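It's much easier to say "A is better than B" than to write the perfect answer yourself. Anyone can compare responses (about 30 seconds each), but writing expert demonstrations requires domain knowledge and time (5-10 minutes each). With those illustrative timings, each comparison takes roughly 10-20x less labeler time than a demonstration, so preference data is far cheaper to collect at scale.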
</details>

**4. What is the KL penalty and why is it important?**
<details>
<summary>Click to reveal answer</summary>
The KL penalty, β × KL(policy || reference), penalizes the policy for straying too far from the reference (SFT) model. Without it, the LLM might "hack" the reward model by generating nonsensical but high-scoring text. The penalty keeps outputs coherent and limits reward hacking.
</details>

**5. How does the reward model learn?**
<details>
<summary>Click to reveal answer</summary>
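The reward model learns from comparison data: pairs of (chosen, rejected) responses where humans indicated which is better. It is trained with a pairwise loss that pushes score(chosen) above score(rejected). The loss comes from the Bradley-Terry model, which turns the score difference into the probability that the chosen response is preferred.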
</details>

---
## What's Next?

Excellent work! You now understand the RLHF pipeline that powers ChatGPT!

In the next notebooks, we'll dive deeper into each component:
- **Reward Modeling** in detail
- **PPO for Language Models**
- **DPO** (a simpler alternative to PPO)
- **Practical RLHF** with the TRL library

**Continue to:** [Notebook 2: Reward Modeling](02_reward_modeling.ipynb)

---

*You now understand the secret sauce behind ChatGPT and other AI assistants!*