# PPO for Language Models: Training AI to Follow Instructions

PPO is the workhorse of RLHF: it is the algorithm that actually updates the language model's weights using the reward model's scores.

## What You'll Learn

By the end of this notebook, you'll understand:
- The dog training analogy: how PPO guides LLM behavior
- Why standard PPO needs modifications for LLMs
- The KL penalty: keeping the model grounded
- The value head: predicting future rewards
- Implementing PPO for text generation from scratch
- Using TRL's PPOTrainer

**Prerequisites:** Notebooks 1-2 (RLHF intro, Reward Modeling)

**Time:** ~35 minutes

---
## The Big Picture: The Dog Training Analogy

```
    ┌────────────────────────────────────────────────────────────────┐
    │          THE DOG TRAINING ANALOGY                              │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Imagine training a dog to perform tricks...                  │
    │                                                                │
    │  THE DOG (Language Model):                                    │
    │    Already knows how to move, bark, sit (pre-trained)         │
    │    But doesn't know WHAT behaviors you want                   │
    │                                                                │
    │  THE TRAINER'S CLICKER (Reward Model):                        │
    │    Click! = Good behavior                                     │
    │    Silence = Not quite right                                  │
    │                                                                │
    │  THE TREAT (PPO Gradient):                                    │
    │    Higher reward → More of that behavior                      │
    │    Lower reward → Less of that behavior                       │
    │                                                                │
    │  THE LEASH (KL Penalty):                                      │
    │    Don't stray TOO far from original behavior!                │
    │    A dog that only does tricks isn't a dog anymore...        │
    │    We want to REFINE behavior, not replace it!                │
    │                                                                │
    │  RLHF with PPO = Training with clicker + treats + leash      │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Circle, FancyArrowPatch
import warnings
warnings.filterwarnings('ignore')

# Visualize the PPO for LLMs concept
fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(0, 14)
ax.set_ylim(0, 12)
ax.axis('off')
ax.set_title('PPO for Language Models: The Training Loop', fontsize=16, fontweight='bold')

# Input: Prompt
prompt_box = FancyBboxPatch((0.5, 9), 2.5, 2, boxstyle="round,pad=0.1",
                             facecolor='#e3f2fd', edgecolor='#1976d2', linewidth=3)
ax.add_patch(prompt_box)
ax.text(1.75, 10.3, 'PROMPT', ha='center', fontsize=10, fontweight='bold', color='#1976d2')
ax.text(1.75, 9.5, '"How do I\nlearn Python?"', ha='center', fontsize=9, style='italic')

# Policy (LLM)
llm_box = FancyBboxPatch((4, 8.5), 3, 3, boxstyle="round,pad=0.1",
                          facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax.add_patch(llm_box)
ax.text(5.5, 10.8, 'POLICY', ha='center', fontsize=11, fontweight='bold', color='#388e3c')
ax.text(5.5, 10.1, '(Language Model)', ha='center', fontsize=9)
ax.text(5.5, 9.4, 'π_θ(token|context)', ha='center', fontsize=10)
ax.text(5.5, 8.8, '+Value Head', ha='center', fontsize=9, style='italic')

# Response
response_box = FancyBboxPatch((8, 9), 3, 2, boxstyle="round,pad=0.1",
                               facecolor='#fff3e0', edgecolor='#f57c00', linewidth=2)
ax.add_patch(response_box)
ax.text(9.5, 10.3, 'RESPONSE', ha='center', fontsize=10, fontweight='bold', color='#f57c00')
ax.text(9.5, 9.5, '"Start with the\nbasics..."', ha='center', fontsize=9, style='italic')

# Reward Model
rm_box = FancyBboxPatch((9, 5.5), 3.5, 2.5, boxstyle="round,pad=0.1",
                         facecolor='#e1bee7', edgecolor='#7b1fa2', linewidth=3)
ax.add_patch(rm_box)
ax.text(10.75, 7.3, 'REWARD MODEL', ha='center', fontsize=10, fontweight='bold', color='#7b1fa2')
ax.text(10.75, 6.5, 'Score: 7.3', ha='center', fontsize=11, fontweight='bold')
ax.text(10.75, 5.9, '(The "Clicker")', ha='center', fontsize=9, style='italic')

# KL Penalty
kl_box = FancyBboxPatch((5, 5.5), 3, 2.5, boxstyle="round,pad=0.1",
                         facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=2)
ax.add_patch(kl_box)
ax.text(6.5, 7.3, 'KL PENALTY', ha='center', fontsize=10, fontweight='bold', color='#d32f2f')
ax.text(6.5, 6.5, 'β × KL(π || π_ref)', ha='center', fontsize=10)
ax.text(6.5, 5.9, '(The "Leash")', ha='center', fontsize=9, style='italic')

# Combined Reward
combined_box = FancyBboxPatch((6.5, 2.5), 4, 2, boxstyle="round,pad=0.1",
                               facecolor='#b2dfdb', edgecolor='#00796b', linewidth=3)
ax.add_patch(combined_box)
ax.text(8.5, 3.8, 'TOTAL REWARD', ha='center', fontsize=10, fontweight='bold', color='#00796b')
ax.text(8.5, 3, 'R = RM(x,y) - β×KL', ha='center', fontsize=10)

# PPO Update
ppo_box = FancyBboxPatch((1.5, 2.5), 3.5, 2, boxstyle="round,pad=0.1",
                          facecolor='#fff9c4', edgecolor='#fbc02d', linewidth=3)
ax.add_patch(ppo_box)
ax.text(3.25, 3.8, 'PPO UPDATE', ha='center', fontsize=10, fontweight='bold', color='#f57f17')
ax.text(3.25, 3, '(The "Treat")', ha='center', fontsize=9, style='italic')

# Arrows
ax.annotate('', xy=(3.9, 10), xytext=(3.1, 10),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(7.9, 10), xytext=(7.1, 10),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(10.5, 8), xytext=(10.5, 8.9),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(6.5, 8), xytext=(5.8, 8.4),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(8.5, 4.6), xytext=(8.5, 5.4),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(5, 3.5), xytext=(6.4, 3.5),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(3.25, 4.6), xytext=(3.25, 8.4),
            arrowprops=dict(arrowstyle='->', lw=2, color='#fbc02d', connectionstyle='arc3,rad=-0.3'))

ax.text(2, 6.5, 'Update\nPolicy', ha='center', fontsize=9, color='#f57f17')

plt.tight_layout()
plt.show()

print("\nPPO FOR LANGUAGE MODELS:")
print("  1. Generate response with current policy")
print("  2. Score with reward model")
print("  3. Compute KL penalty (don't drift too far!)")
print("  4. PPO update to maximize reward - KL penalty")

---
## Why Standard PPO Needs Modifications

```
    ┌────────────────────────────────────────────────────────────────┐
    │              LLM-SPECIFIC CHALLENGES                           │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  1. HUGE ACTION SPACE                                         │
    │     Standard RL: Maybe 4-10 actions (up, down, left, right)  │
    │     LLM: 50,000+ actions (every token in vocabulary!)        │
    │     → Need efficient probability computation                 │
    │                                                                │
    │  2. SEQUENTIAL DECISIONS                                      │
    │     Each token depends on all previous tokens                │
    │     Response = [token1, token2, ..., tokenN]                 │
    │     → Credit assignment is tricky                            │
    │                                                                │
    │  3. REWARD AT THE END                                         │
    │     Reward model scores COMPLETE response                    │
    │     Individual tokens don't get direct feedback              │
    │     → Need value function to distribute credit               │
    │                                                                │
    │  4. MUST STAY COHERENT                                        │
    │     Can't just maximize reward at any cost                   │
    │     Model might produce gibberish that fools RM              │
    │     → KL penalty keeps language quality                      │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
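
Challenge 1 is usually handled with one `log_softmax` over the vocabulary followed by a `gather` that picks out the log-probability of each token that was actually generated. The sketch below assumes the logits are already aligned with the chosen tokens (a real causal LM needs a one-position shift first); `token_log_probs` is an illustrative helper name, not a library function.

```python
import torch
import torch.nn.functional as F

def token_log_probs(logits, token_ids):
    """Log-probability of each chosen token (illustrative helper).

    logits:    (batch, seq_len, vocab_size)
    token_ids: (batch, seq_len), assumed already aligned with logits
    """
    log_probs = F.log_softmax(logits, dim=-1)  # normalize over the vocabulary
    # Select the entry for the token that was actually generated at each position
    return log_probs.gather(dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)

torch.manual_seed(0)
batch, seq_len, vocab_size = 2, 5, 50_000
logits = torch.randn(batch, seq_len, vocab_size)
token_ids = torch.randint(0, vocab_size, (batch, seq_len))

lp = token_log_probs(logits, token_ids)  # per-token log probs
seq_lp = lp.sum(dim=-1)                  # log prob of each whole response
print(lp.shape, seq_lp.shape)
```

Summing the per-token values gives the log-probability of the whole response, which is the quantity the PPO ratio and the KL penalty are built from.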

In [None]:
# Demonstrate the action space difference

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Standard RL action space
ax1 = axes[0]
actions = ['Up', 'Down', 'Left', 'Right']
probs = [0.3, 0.1, 0.2, 0.4]
colors = ['#64b5f6', '#81c784', '#ffb74d', '#ef5350']

bars = ax1.bar(actions, probs, color=colors, edgecolor='black', linewidth=2)
ax1.set_ylabel('Probability', fontsize=11)
ax1.set_title('Standard RL: 4 Actions', fontsize=14, fontweight='bold')
ax1.set_ylim(0, 0.6)
ax1.grid(True, alpha=0.3, axis='y')

for bar, prob in zip(bars, probs):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
             f'{prob:.0%}', ha='center', fontsize=10, fontweight='bold')

# Right: LLM action space (tokens)
ax2 = axes[1]
# Show top tokens and "...50,000 more"
tokens = ['the', 'a', 'to', 'of', 'and', '...', '50K+\nmore']
probs2 = [0.15, 0.08, 0.06, 0.05, 0.04, 0, 0]
colors2 = ['#64b5f6']*5 + ['white', '#e0e0e0']

bars = ax2.bar(tokens, probs2, color=colors2, edgecolor='black', linewidth=2)
bars[-1].set_edgecolor('#999')
bars[-1].set_linewidth(1)
bars[-1].set_linestyle('--')

ax2.set_ylabel('Probability', fontsize=11)
ax2.set_title('LLM: 50,000+ Actions (Tokens)!', fontsize=14, fontweight='bold', color='#d32f2f')
ax2.set_ylim(0, 0.25)
ax2.grid(True, alpha=0.3, axis='y')

# Add long tail annotation
ax2.annotate('Long tail of\nrare tokens', xy=(5.5, 0.02), fontsize=9, ha='center', color='#666')

plt.tight_layout()
plt.show()

print("\nACTION SPACE COMPARISON:")
print("  Standard RL: ~4-10 discrete actions")
print("  Language Models: ~50,000 tokens in vocabulary!")
print("  That is roughly 5,000-12,500x more choices at every step!")

---
## The RLHF Objective: Reward Minus KL Penalty

```
    ┌────────────────────────────────────────────────────────────────┐
    │              THE RLHF OBJECTIVE                                │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  NAIVE OBJECTIVE (doesn't work!):                             │
    │    maximize E[RM(prompt, response)]                           │
    │                                                                │
    │    Problem: Model finds exploits!                             │
    │    - Repeats phrases that fool RM                            │
    │    - Produces gibberish with high RM scores                  │
    │    - Loses ability to be a general language model            │
    │                                                                │
    │  RLHF OBJECTIVE (with KL penalty):                            │
    │                                                                │
    │    maximize E[RM(x,y) - β × KL(π_θ(y|x) || π_ref(y|x))]      │
    │                         └───────────────────────────┘         │
    │                           Don't drift too far!                │
    │                                                                │
    │  WHERE:                                                       │
    │    π_θ    = Current policy (being trained)                   │
    │    π_ref  = Reference policy (SFT model, frozen)             │
    │    β      = KL penalty coefficient (0.01 - 0.2)              │
    │                                                                │
    │  INTUITION:                                                   │
    │    "Be helpful (high RM) but stay sane (low KL)"             │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
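
To see why the log-ratio of sampled tokens can stand in for the KL term, here is a small worked example on a three-token vocabulary. The distributions `p` (policy) and `q` (reference) are made-up numbers chosen for illustration; the point is only that averaging `log p - log q` over tokens sampled from the policy approaches the exact KL.

```python
import math
import random

# Hypothetical next-token distributions over a 3-token vocabulary
p = [0.7, 0.2, 0.1]  # policy π_θ
q = [0.5, 0.3, 0.2]  # reference π_ref

# Exact KL(p || q) = Σ p(a) * log(p(a) / q(a))
exact_kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# Sample-based estimate: draw tokens from the policy, average the log-ratio
rng = random.Random(0)
n = 20_000
samples = rng.choices(range(len(p)), weights=p, k=n)
estimate = sum(math.log(p[a]) - math.log(q[a]) for a in samples) / n

print(f"exact KL:    {exact_kl:.4f}")
print(f"sampled KL:  {estimate:.4f}")
```

Individual log-ratios can be negative (whenever the sampled token is less likely under the policy than under the reference), but their expectation under the policy is never negative.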

In [None]:
def compute_kl_penalty(log_probs_policy, log_probs_ref):
    """
    Per-token KL estimate between policy and reference.
    
    KL(π || π_ref) = E_π[log π(a|s) - log π_ref(a|s)]
    
    Each element is the log-ratio for one token sampled from the policy;
    summing over the response tokens gives a sample-based estimate of
    the sequence-level KL.
    """
    kl = log_probs_policy - log_probs_ref
    return kl


def compute_rlhf_reward(rm_score, log_probs_policy, log_probs_ref, beta=0.1):
    """
    Compute the RLHF reward with KL penalty.
    
    R = RM(x, y) - β × KL(π || π_ref)
    
    Args:
        rm_score: Reward model score for the response
        log_probs_policy: Log probs from current policy
        log_probs_ref: Log probs from reference model
        beta: KL penalty coefficient
    
    Returns:
        total_reward: Scalar reward after subtracting the KL penalty
        kl: Per-token KL estimates (same shape as the log probs)
    """
    kl = compute_kl_penalty(log_probs_policy, log_probs_ref)
    total_reward = rm_score - beta * kl.sum()
    return total_reward, kl


# Demonstrate the tradeoff
print("RLHF REWARD CALCULATION")
print("="*60)

# Example: 5 tokens in response
torch.manual_seed(42)
rm_score = 7.5  # Good response
log_probs_policy = torch.tensor([-2.0, -1.5, -3.0, -2.5, -1.8])  # Current policy
log_probs_ref = torch.tensor([-2.1, -1.6, -2.8, -2.6, -1.9])     # Reference (similar)

print(f"\nReward Model Score: {rm_score}")
print(f"Policy log probs: {log_probs_policy.numpy()}")
print(f"Reference log probs: {log_probs_ref.numpy()}")

betas = [0.0, 0.05, 0.1, 0.2, 0.5]
print(f"\n{'β':<8} {'KL Penalty':<15} {'Total Reward':<15}")
print("-"*40)

for beta in betas:
    total_reward, kl = compute_rlhf_reward(rm_score, log_probs_policy, log_probs_ref, beta)
    kl_sum = kl.sum().item()
    print(f"{beta:<8} {beta * kl_sum:<15.3f} {total_reward.item():<15.3f}")

print("\n" + "="*60)
print("Higher β → More penalty for diverging from reference!")

In [None]:
# Visualize the KL penalty effect

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Reward vs KL tradeoff
ax1 = axes[0]
kl_values = np.linspace(0, 10, 100)
rm_score = 8.0

for beta in [0.05, 0.1, 0.2, 0.5]:
    total_reward = rm_score - beta * kl_values
    ax1.plot(kl_values, total_reward, linewidth=2, label=f'β = {beta}')

ax1.axhline(y=0, color='red', linestyle='--', alpha=0.5)
ax1.fill_between(kl_values, -5, 0, alpha=0.1, color='red', label='Negative reward zone')

ax1.set_xlabel('KL Divergence from Reference', fontsize=11)
ax1.set_ylabel('Total RLHF Reward', fontsize=11)
ax1.set_title('KL Penalty Effect\n(Higher β = Stronger Leash)', fontsize=12, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_ylim(-5, 10)

# Right: What happens without KL penalty
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Why KL Penalty Matters', fontsize=12, fontweight='bold')

# Without KL
ax2.text(2, 8.5, 'Without KL Penalty:', fontsize=11, fontweight='bold', color='#d32f2f')
problems = [
    '❌ Model finds reward hacks',
    '❌ Outputs become repetitive',
    '❌ Loses general language ability',
    '❌ Gibberish that fools RM',
]
for i, problem in enumerate(problems):
    ax2.text(2, 7.5 - i*0.8, problem, fontsize=10, color='#d32f2f')

# With KL
ax2.text(2, 4, 'With KL Penalty:', fontsize=11, fontweight='bold', color='#388e3c')
benefits = [
    '✓ Stays close to SFT model',
    '✓ Maintains language quality',
    '✓ Learns genuine improvements',
    '✓ Can still be used as LLM',
]
for i, benefit in enumerate(benefits):
    ax2.text(2, 3 - i*0.8, benefit, fontsize=10, color='#388e3c')

plt.tight_layout()
plt.show()

---
## The Value Head: Predicting Future Rewards

```
    ┌────────────────────────────────────────────────────────────────┐
    │              THE VALUE HEAD                                    │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  PROBLEM: Reward comes at the END                             │
    │    Response: [t1, t2, t3, ..., t100] → RM Score = 7.5        │
    │    How do we know which tokens were good?                     │
    │                                                                │
    │  SOLUTION: Add a VALUE HEAD to the LLM                        │
    │                                                                │
    │    ┌─────────────────────┐                                    │
    │    │  Transformer LLM    │                                    │
    │    │  (frozen or fine-   │                                    │
    │    │   tuned backbone)   │                                    │
    │    └─────────┬───────────┘                                    │
    │              │ hidden_state                                   │
    │        ┌─────┴─────┐                                          │
    │        │           │                                          │
    │        ▼           ▼                                          │
    │    ┌──────┐    ┌──────┐                                       │
    │    │ LM   │    │Value │                                       │
    │    │ Head │    │ Head │  ← NEW!                               │
    │    └──┬───┘    └──┬───┘                                       │
    │       │           │                                           │
    │       ▼           ▼                                           │
    │    Vocab        Scalar                                        │
    │    Probs        Value                                         │
    │                                                                │
    │  THE VALUE HEAD predicts: "How much reward will we get       │
    │  from this point onwards?"                                   │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
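
One way to make "reward from this point onwards" concrete is to compute the return-to-go at each token position and treat it as the regression target for the value head. The sketch below is plain Python with made-up numbers; `returns_to_go` and the `predictions` list are illustrative, not part of the notebook's model.

```python
def returns_to_go(rewards, gamma=1.0):
    """Discounted sum of rewards from each position to the end."""
    out = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    return out

# Sparse reward: only the final token carries the RM score
rewards = [0.0, 0.0, 0.0, 0.0, 7.5]

targets = returns_to_go(rewards, gamma=1.0)
print(targets)                           # [7.5, 7.5, 7.5, 7.5, 7.5]
print(returns_to_go(rewards, gamma=0.5)) # [0.46875, 0.9375, 1.875, 3.75, 7.5]

# Hypothetical value-head outputs, and the mean squared error against the targets
predictions = [6.0, 6.5, 7.0, 7.5, 7.5]
mse = sum((v - g) ** 2 for v, g in zip(predictions, targets)) / len(targets)
print(f"value loss (MSE): {mse:.2f}")    # 0.70
```

With no discounting, every position shares the same target, so the value head has to learn how promising a partial response is from its content alone.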

In [None]:
class SimpleTransformerWithValueHead(nn.Module):
    """
    Simplified model showing the value head architecture.
    
    In practice, this would be a full transformer.
    Here we show the key concept: LM head + Value head.
    """
    
    def __init__(self, vocab_size=1000, hidden_dim=256):
        super().__init__()
        
        # Simplified "backbone" (in reality: full transformer)
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.backbone = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )
        
        # LM Head: Predicts next token probabilities
        self.lm_head = nn.Linear(hidden_dim, vocab_size)
        
        # Value Head: Predicts expected future reward (NEW!)
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )
    
    def forward(self, input_ids):
        """
        Forward pass returning both token probs and values.
        
        Args:
            input_ids: Token indices (batch_size, seq_len)
            
        Returns:
            logits: Token logits (batch_size, seq_len, vocab_size)
            values: Value predictions (batch_size, seq_len, 1)
        """
        # Get embeddings
        x = self.embedding(input_ids)
        
        # Pass through backbone
        hidden_states = self.backbone(x)
        
        # Two outputs!
        logits = self.lm_head(hidden_states)  # For next token
        values = self.value_head(hidden_states)  # For PPO
        
        return logits, values
    
    def generate(self, input_ids, max_length=20):
        """Simple greedy generation (for demonstration)."""
        for _ in range(max_length):
            logits, _ = self.forward(input_ids)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids


# Create and inspect model
print("MODEL WITH VALUE HEAD")
print("="*60)

model = SimpleTransformerWithValueHead(vocab_size=1000, hidden_dim=256)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
lm_head_params = sum(p.numel() for p in model.lm_head.parameters())
value_head_params = sum(p.numel() for p in model.value_head.parameters())

print(f"\nTotal parameters: {total_params:,}")
print(f"LM Head parameters: {lm_head_params:,}")
print(f"Value Head parameters: {value_head_params:,}")
print(f"Value Head accounts for only {value_head_params/total_params*100:.1f}% of all parameters!")

# Test forward pass
test_input = torch.randint(0, 1000, (2, 10))  # Batch of 2, length 10
logits, values = model(test_input)

print(f"\nInput shape: {test_input.shape}")
print(f"Logits shape: {logits.shape} (for next token prediction)")
print(f"Values shape: {values.shape} (for PPO)")
print("="*60)

In [None]:
# Visualize value head architecture

fig, ax = plt.subplots(figsize=(12, 10))
ax.set_xlim(0, 12)
ax.set_ylim(0, 12)
ax.axis('off')
ax.set_title('LLM Architecture with Value Head', fontsize=16, fontweight='bold')

# Input tokens
input_box = FancyBboxPatch((4, 10), 4, 1, boxstyle="round,pad=0.1",
                            facecolor='#e3f2fd', edgecolor='#1976d2', linewidth=2)
ax.add_patch(input_box)
ax.text(6, 10.5, 'Input Tokens', ha='center', fontsize=10)

# Embedding
emb_box = FancyBboxPatch((4, 8.2), 4, 1.2, boxstyle="round,pad=0.1",
                          facecolor='#fff3e0', edgecolor='#f57c00', linewidth=2)
ax.add_patch(emb_box)
ax.text(6, 8.8, 'Embedding Layer', ha='center', fontsize=10)
ax.annotate('', xy=(6, 9.5), xytext=(6, 9.9),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Transformer backbone
trans_box = FancyBboxPatch((3.5, 5), 5, 2.5, boxstyle="round,pad=0.1",
                            facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax.add_patch(trans_box)
ax.text(6, 6.7, 'Transformer Backbone', ha='center', fontsize=11, fontweight='bold')
ax.text(6, 6, '(Self-Attention + FFN)', ha='center', fontsize=9)
ax.text(6, 5.4, 'Hidden States', ha='center', fontsize=9, style='italic')
ax.annotate('', xy=(6, 7.6), xytext=(6, 8.1),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Split into two heads
ax.annotate('', xy=(3.5, 4), xytext=(5.5, 4.9),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(8.5, 4), xytext=(6.5, 4.9),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# LM Head (left)
lm_box = FancyBboxPatch((1.5, 2), 3.5, 1.8, boxstyle="round,pad=0.1",
                         facecolor='#e1bee7', edgecolor='#7b1fa2', linewidth=2)
ax.add_patch(lm_box)
ax.text(3.25, 3.3, 'LM Head', ha='center', fontsize=10, fontweight='bold', color='#7b1fa2')
ax.text(3.25, 2.6, 'Linear(hidden → vocab)', ha='center', fontsize=9)

# LM output
ax.text(3.25, 1.2, 'Token Probabilities', ha='center', fontsize=10, fontweight='bold')
ax.text(3.25, 0.6, 'P(next_token | context)', ha='center', fontsize=9, style='italic')
ax.annotate('', xy=(3.25, 1.5), xytext=(3.25, 1.9),
            arrowprops=dict(arrowstyle='->', lw=2, color='#7b1fa2'))

# Value Head (right)
value_box = FancyBboxPatch((7, 2), 3.5, 1.8, boxstyle="round,pad=0.1",
                            facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=2)
ax.add_patch(value_box)
ax.text(8.75, 3.3, 'Value Head', ha='center', fontsize=10, fontweight='bold', color='#d32f2f')
ax.text(8.75, 2.6, 'MLP(hidden → 1)', ha='center', fontsize=9)
ax.text(9.5, 2.1, 'NEW!', fontsize=8, color='#d32f2f', fontweight='bold')

# Value output
ax.text(8.75, 1.2, 'Expected Reward', ha='center', fontsize=10, fontweight='bold')
ax.text(8.75, 0.6, 'V(state) = future rewards', ha='center', fontsize=9, style='italic')
ax.annotate('', xy=(8.75, 1.5), xytext=(8.75, 1.9),
            arrowprops=dict(arrowstyle='->', lw=2, color='#d32f2f'))

plt.tight_layout()
plt.show()

---
## PPO Update for Language Models

```
    ┌────────────────────────────────────────────────────────────────┐
    │              PPO FOR LANGUAGE MODELS                           │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  STANDARD PPO OBJECTIVE (from advanced-algorithms):           │
    │                                                                │
    │    L_PPO = E[ min(r(θ)×Â, clip(r(θ), 1-ε, 1+ε)×Â) ]         │
    │                                                                │
    │    where r(θ) = π_θ(a|s) / π_old(a|s)                        │
    │                                                                │
    │  FOR LANGUAGE MODELS:                                         │
    │                                                                │
    │    • State s = (prompt, tokens_so_far)                       │
    │    • Action a = next_token                                   │
    │    • π_θ(a|s) = softmax(logits)[token_id]                   │
    │                                                                │
    │  ADVANTAGE ESTIMATION:                                        │
    │                                                                │
    │    For each token position t:                                │
    │    Â_t = R_total - V(state_t)                               │
    │                                                                │
    │    Or using GAE for better variance reduction:               │
    │    Â_t = Σ (γλ)^l × δ_{t+l}                                 │
    │                                                                │
    │  VALUE LOSS:                                                  │
    │    L_V = (V(s) - R_target)²                                  │
    │                                                                │
    │  TOTAL LOSS:                                                  │
    │    L = -L_PPO + c₁×L_V - c₂×Entropy                         │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
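
The total loss in the box can be assembled in a few lines. This is a minimal sketch for a single response, not a reference implementation: the function name `ppo_total_loss` and the coefficient values `c1=0.5` and `c2=0.01` are assumptions chosen for illustration, and the inputs are random tensors.

```python
import torch
import torch.nn.functional as F

def ppo_total_loss(logits, actions, log_probs_old, advantages, values, returns,
                   epsilon=0.2, c1=0.5, c2=0.01):
    """L = -L_PPO + c1 * L_V - c2 * Entropy (coefficients are illustrative)."""
    log_probs_all = F.log_softmax(logits, dim=-1)  # (seq_len, vocab_size)
    log_probs_new = log_probs_all.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    # Clipped surrogate objective
    ratio = torch.exp(log_probs_new - log_probs_old)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    l_ppo = torch.min(ratio * advantages, clipped * advantages).mean()

    # Value loss: squared error against the return targets
    l_v = F.mse_loss(values, returns)

    # Mean entropy of the next-token distributions (encourages exploration)
    entropy = -(log_probs_all.exp() * log_probs_all).sum(dim=-1).mean()

    return -l_ppo + c1 * l_v - c2 * entropy

torch.manual_seed(0)
seq_len, vocab_size = 6, 50
logits = torch.randn(seq_len, vocab_size, requires_grad=True)
actions = torch.randint(0, vocab_size, (seq_len,))
with torch.no_grad():  # "old" log probs are fixed at generation time
    log_probs_old = F.log_softmax(logits, dim=-1).gather(
        -1, actions.unsqueeze(-1)).squeeze(-1)
advantages = torch.randn(seq_len)
values = torch.randn(seq_len, requires_grad=True)
returns = torch.randn(seq_len)

loss = ppo_total_loss(logits, actions, log_probs_old, advantages, values, returns)
loss.backward()  # gradients reach both the policy logits and the values
print(f"total loss: {loss.item():.4f}")
```

The surrogate enters with a minus sign because optimizers minimize, while the entropy term is subtracted so that higher entropy lowers the loss.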

In [None]:
def compute_advantages(rewards, values, gamma=0.99, lam=0.95):
    """
    Compute Generalized Advantage Estimation (GAE).
    
    For language models, this is computed per-token.
    
    Args:
        rewards: Per-token rewards (mostly 0, final token has RM score)
        values: Value predictions at each position
        gamma: Discount factor
        lam: GAE lambda
    
    Returns:
        advantages: GAE advantages for each position
        returns: Discounted returns for value loss
    """
    seq_len = len(rewards)
    advantages = torch.zeros(seq_len)
    returns = torch.zeros(seq_len)
    
    # Work backwards
    last_gae = 0
    last_return = 0
    
    for t in reversed(range(seq_len)):
        if t == seq_len - 1:
            next_value = 0  # Terminal state
        else:
            next_value = values[t + 1]
        
        # TD error
        delta = rewards[t] + gamma * next_value - values[t]
        
        # GAE
        advantages[t] = delta + gamma * lam * last_gae
        last_gae = advantages[t]
        
        # Returns for value loss
        returns[t] = rewards[t] + gamma * last_return
        last_return = returns[t]
    
    return advantages, returns


def ppo_loss_for_lm(log_probs_new, log_probs_old, advantages, epsilon=0.2):
    """
    Compute PPO clipped objective for language model.
    
    Args:
        log_probs_new: Log probs under current policy
        log_probs_old: Log probs under old policy (from generation)
        advantages: GAE advantages
        epsilon: Clipping range
    
    Returns:
        Clipped surrogate objective (maximize it; negate to get a loss)
    """
    # Probability ratio
    ratio = torch.exp(log_probs_new - log_probs_old)
    
    # Clipped and unclipped objectives
    obj_unclipped = ratio * advantages
    obj_clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    
    # PPO objective: take minimum (pessimistic)
    ppo_obj = torch.min(obj_unclipped, obj_clipped)
    
    return ppo_obj.mean()


# Demonstrate
print("PPO FOR LANGUAGE MODELS: EXAMPLE")
print("="*60)

# Example: 10 token response
seq_len = 10
torch.manual_seed(42)

# Rewards: 0 for all but last token (which has RM score minus KL)
rewards = torch.zeros(seq_len)
rewards[-1] = 5.0  # Final reward from RM - KL penalty

# Value predictions
values = torch.tensor([2.0, 2.5, 3.0, 3.5, 4.0, 4.2, 4.5, 4.8, 5.0, 5.0])

# Compute advantages
advantages, returns = compute_advantages(rewards, values)

print(f"\nPer-token rewards: {rewards.numpy()}")
print(f"Value predictions: {values.numpy()}")
print(f"GAE advantages: {advantages.numpy().round(3)}")
print(f"Returns: {returns.numpy().round(3)}")

# Compute PPO loss
log_probs_old = torch.randn(seq_len) * 0.5 - 2  # Simulated old log probs
log_probs_new = log_probs_old + torch.randn(seq_len) * 0.1  # Slight change

ppo_obj = ppo_loss_for_lm(log_probs_new, log_probs_old, advantages)
print(f"\nPPO Objective: {ppo_obj.item():.4f}")
print("(Surrogate objective to maximize; negate it to get a loss for the optimizer)")
print("="*60)

In [None]:
# Visualize the PPO update for a single response

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top left: Rewards per token
ax1 = axes[0, 0]
positions = np.arange(seq_len)
ax1.bar(positions, rewards.numpy(), color='#64b5f6', edgecolor='black')
ax1.set_xlabel('Token Position', fontsize=11)
ax1.set_ylabel('Reward', fontsize=11)
ax1.set_title('Per-Token Rewards\n(Only final token has RM score)', fontsize=12, fontweight='bold')
ax1.grid(True, alpha=0.3, axis='y')
ax1.annotate('RM Score!', xy=(seq_len-1, rewards[-1].item()), xytext=(seq_len-2, rewards[-1].item()+1),
             arrowprops=dict(arrowstyle='->', color='red'), fontsize=10, color='red')

# Top right: Value predictions vs returns
ax2 = axes[0, 1]
ax2.plot(positions, values.numpy(), 'go-', linewidth=2, label='V(s) predictions')
ax2.plot(positions, returns.numpy(), 'rs-', linewidth=2, label='Actual returns')
ax2.set_xlabel('Token Position', fontsize=11)
ax2.set_ylabel('Value', fontsize=11)
ax2.set_title('Value Predictions vs Actual Returns', fontsize=12, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Bottom left: Advantages
ax3 = axes[1, 0]
colors = ['#4caf50' if a > 0 else '#f44336' for a in advantages.numpy()]
ax3.bar(positions, advantages.numpy(), color=colors, edgecolor='black')
ax3.axhline(y=0, color='black', linewidth=1)
ax3.set_xlabel('Token Position', fontsize=11)
ax3.set_ylabel('Advantage', fontsize=11)
ax3.set_title('GAE Advantages\n(Green = better than V(s) predicted, Red = worse)', fontsize=12, fontweight='bold')
ax3.grid(True, alpha=0.3, axis='y')

# Bottom right: Probability ratio effect
ax4 = axes[1, 1]
ratio = torch.exp(log_probs_new - log_probs_old).numpy()
ax4.bar(positions, ratio, color='#9575cd', edgecolor='black')
ax4.axhline(y=1, color='black', linewidth=1, linestyle='--')
ax4.axhline(y=1.2, color='red', linewidth=1, linestyle=':', label='Clip bounds (1±0.2)')
ax4.axhline(y=0.8, color='red', linewidth=1, linestyle=':')
ax4.set_xlabel('Token Position', fontsize=11)
ax4.set_ylabel('Probability Ratio', fontsize=11)
ax4.set_title('Probability Ratio r(θ) = π_new/π_old', fontsize=12, fontweight='bold')
ax4.legend()
ax4.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---
## Using TRL's PPOTrainer

```
    ┌────────────────────────────────────────────────────────────────┐
    │              TRL PPOTrainer                                    │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  TRL (Transformer Reinforcement Learning) from Hugging Face   │
    │  provides production-ready RLHF components.                   │
    │                                                                │
    │  KEY COMPONENTS:                                              │
    │                                                                │
    │  1. AutoModelForCausalLMWithValueHead                        │
    │     • Wraps any HF model with value head                     │
    │     • Automatically handles architecture                     │
    │                                                                │
    │  2. PPOConfig                                                 │
    │     • learning_rate, batch_size, mini_batch_size             │
    │     • ppo_epochs, kl_penalty                                 │
    │                                                                │
    │  3. PPOTrainer                                                │
    │     • .generate() - Generate responses                       │
    │     • .step() - PPO update step                              │
    │     • Handles KL penalty automatically                       │
    │                                                                │
    │  WORKFLOW:                                                    │
    │    for batch in dataloader:                                  │
    │        responses = trainer.generate(prompts)                 │
    │        rewards = reward_model(prompts, responses)            │
    │        trainer.step(prompts, responses, rewards)             │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
# Check TRL availability
try:
    from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
    from transformers import AutoTokenizer
    TRL_AVAILABLE = True
    print("✓ TRL is installed!")
except ImportError:
    TRL_AVAILABLE = False
    print("✗ TRL not installed.")
    print('  Install with: pip install "trl<0.12" transformers  (classic .generate()/.step() API)')

# Show example code
print("\n" + "="*60)
print("TRL PPOTrainer EXAMPLE CODE")
print("="*60)

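example_code = '''
# Note: this uses the classic PPOTrainer API (.generate() / .step()),
# which is available in TRL releases before 0.12.
# dataloader, num_epochs, and reward_model are placeholders you supply.
import torch
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer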

# 1. Load model WITH value head
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 2. Configure PPO
config = PPOConfig(
    learning_rate=1.41e-5,
    batch_size=16,
    mini_batch_size=4,
    ppo_epochs=4,
    kl_penalty="kl",     # KL estimator: per-token log-prob difference
    init_kl_coef=0.2,    # Initial KL coefficient (β)
    target=6.0,          # Target KL for adaptive β control
)

# 3. Create trainer
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    ref_model=ref_model,  # For KL penalty
    tokenizer=tokenizer,
)

# 4. Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
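        # Tokenize prompts; the trainer expects a list of 1-D tensors
        query_tensors = [
            tokenizer.encode(q, return_tensors="pt").squeeze(0)
            for q in batch["prompts"]
        ]
        
        # Generate responses only (no prompt), so step() gets queries and responses separately
        response_tensors = ppo_trainer.generate(
            query_tensors,
            return_prompt=False,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
        
        # Score prompt + response with the reward model (placeholder scorer)
        responses = [tokenizer.decode(r) for r in response_tensors]
        texts = [q + r for q, r in zip(batch["prompts"], responses)]
        rewards = [reward_model.score(text) for text in texts]
        rewards = [torch.tensor(r) for r in rewards]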
        
        # PPO step - this does everything!
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        
        # Log stats
        print(f"Mean reward: {stats['ppo/mean_scores']:.3f}")
        print(f"KL divergence: {stats['objective/kl']:.3f}")
'''

print(example_code)
print("="*60)

In [None]:
# Visualize the TRL workflow

fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 15)  # wide enough that the rightmost box (x=11, width 3.5) is not clipped
ax.set_ylim(0, 10)
ax.axis('off')
ax.set_title('TRL PPOTrainer Workflow', fontsize=16, fontweight='bold')

# Steps
steps = [
    (1, 7.5, 'Load Models', 'AutoModelFor...WithValueHead\n+ Reference model', '#e3f2fd', '#1976d2'),
    (5, 7.5, 'Configure', 'PPOConfig(...)\nbatch_size, kl_coef, etc.', '#fff3e0', '#f57c00'),
    (9, 7.5, 'Create Trainer', 'PPOTrainer(\n  model, ref_model,\n  tokenizer)', '#c8e6c9', '#388e3c'),
    (3, 4.5, 'Generate', 'trainer.generate(\n  prompts)', '#e1bee7', '#7b1fa2'),
    (7, 4.5, 'Get Rewards', 'reward_model(\n  prompts, responses)', '#ffcdd2', '#d32f2f'),
    (11, 4.5, 'PPO Step', 'trainer.step(\n  queries, responses,\n  rewards)', '#b2dfdb', '#00796b'),
]

for x, y, title, desc, fcolor, ecolor in steps:
    box = FancyBboxPatch((x, y), 3.5, 2, boxstyle="round,pad=0.1",
                          facecolor=fcolor, edgecolor=ecolor, linewidth=2)
    ax.add_patch(box)
    ax.text(x + 1.75, y + 1.6, title, ha='center', fontsize=10, fontweight='bold', color=ecolor)
    ax.text(x + 1.75, y + 0.7, desc, ha='center', fontsize=8)

# Arrows for top row
ax.annotate('', xy=(4.9, 8.5), xytext=(4.6, 8.5),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(8.9, 8.5), xytext=(8.6, 8.5),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Arrow down to loop
ax.annotate('', xy=(10.75, 6.5), xytext=(10.75, 7.4),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Arrows for bottom row
ax.annotate('', xy=(6.9, 5.5), xytext=(6.6, 5.5),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(10.9, 5.5), xytext=(10.6, 5.5),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Loop back arrow
ax.annotate('', xy=(3, 6.5), xytext=(12, 4.3),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666',
                            connectionstyle='arc3,rad=0.3'))
ax.text(7, 2.5, 'Repeat for many batches...', ha='center', fontsize=10, style='italic')

# Labels
ax.text(7, 9.5, 'Setup (once)', ha='center', fontsize=11, fontweight='bold', color='#666')
ax.text(7, 6.8, 'Training Loop (repeat)', ha='center', fontsize=11, fontweight='bold', color='#666')

plt.tight_layout()
plt.show()

---
## Summary: Key Takeaways

### PPO for LLMs: Key Modifications

| Aspect | Standard PPO | PPO for LLMs |
|--------|--------------|-------------|
| Action space | Small | 50K+ tokens |
| Rewards | Dense | Sparse (end only) |
| States | Simple | Sequential context |
| Reference policy | None | KL penalty against a reference model |

### The RLHF Objective

```
R = RM(prompt, response) - β × KL(π_θ || π_ref)
```
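
As a rough sketch of how this objective is often turned into per-token rewards: the KL term is estimated at each token as the log-probability difference between the policy and the reference model, and the reward model's score is added only at the last token. The function name `shaped_rewards` and the sample numbers are hypothetical, and plain Python lists stand in for tensors.

```python
def shaped_rewards(rm_score, logprobs_policy, logprobs_ref, beta=0.2):
    """Per-token rewards: -beta * (log pi_theta - log pi_ref), plus the RM score at the end.

    Assumption: the KL term is approximated per token by the log-prob
    difference of the sampled token, a common estimator, not the exact KL.
    """
    if not logprobs_policy or len(logprobs_policy) != len(logprobs_ref):
        raise ValueError("need two non-empty log-prob sequences of equal length")
    rewards = [-beta * (lp - lr) for lp, lr in zip(logprobs_policy, logprobs_ref)]
    rewards[-1] += rm_score  # sparse reward: only the final token sees the RM score
    return rewards


# Hypothetical log-probs for a 4-token response
policy_lp = [-1.0, -0.5, -2.0, -0.8]
ref_lp = [-1.2, -0.5, -1.5, -1.0]
rewards = shaped_rewards(1.5, policy_lp, ref_lp, beta=0.2)
print([round(r, 3) for r in rewards])
```

Tokens where the policy assigns more probability than the reference get a small negative reward, tokens where it assigns less get a small positive one, and the final token carries the RM score on top.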

### TRL Components

| Component | Purpose |
|-----------|---------|
| `AutoModelForCausalLMWithValueHead` | LLM + Value head |
| `PPOConfig` | Hyperparameters |
| `PPOTrainer.generate()` | Sample responses |
| `PPOTrainer.step()` | PPO update |

---
## Test Your Understanding

**1. Why do we need a KL penalty in RLHF?**
<details>
<summary>Click to reveal answer</summary>
Without a KL penalty, the model can "reward hack": find outputs that score high with the reward model but are actually gibberish or repetitive. The KL penalty keeps the model close to the original SFT model, preserving language quality and preventing degenerate solutions.
</details>

**2. What is the value head and why is it needed?**
<details>
<summary>Click to reveal answer</summary>
The value head is a neural network layer added to the LLM that predicts expected future rewards from any position. It's needed because:
- Rewards only come at the END of generation
- We need to estimate which tokens contributed to the final reward
- It enables computing advantages for each token position
</details>

**3. Why is the action space so much larger for LLMs?**
<details>
<summary>Click to reveal answer</summary>
In standard RL, actions might be discrete choices like up/down/left/right (4 actions). For LLMs, each action is choosing the next token from the entire vocabulary - typically 32K to 100K+ tokens! This makes the policy's output layer (a softmax over the vocabulary) much larger and requires efficient implementations.
</details>

**4. What does β control in the RLHF objective?**
<details>
<summary>Click to reveal answer</summary>
β is the KL penalty coefficient:
- β = 0: No penalty, maximize RM score only (dangerous!)
- β small (0.01-0.05): Light constraint, more reward optimization
- β large (0.2-0.5): Strong constraint, stay very close to reference

Typical values: 0.05-0.2. Often tuned based on observed KL divergence.
</details>
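
One common way to tune β from the observed KL is a proportional controller in the style of Ziegler et al. (2019): raise β when the measured KL is above a target and lower it when below. The sketch below is an illustration under assumptions, not a specific library's implementation; the class name and the `horizon` default are hypothetical, while the initial coefficient 0.2 and target 6.0 echo the config values used in this notebook.

```python
class AdaptiveKLController:
    """Proportional controller that nudges beta toward a target KL (illustrative sketch)."""

    def __init__(self, init_kl_coef=0.2, target=6.0, horizon=10000):
        self.value = init_kl_coef  # current beta
        self.target = target       # desired KL between policy and reference
        self.horizon = horizon     # larger horizon means slower adjustments

    def update(self, current_kl, n_steps):
        # Relative error, clipped so one noisy batch cannot swing beta too far
        error = max(-0.2, min(0.2, current_kl / self.target - 1.0))
        self.value *= 1.0 + error * n_steps / self.horizon
        return self.value


controller = AdaptiveKLController()
beta_after_high_kl = controller.update(current_kl=12.0, n_steps=16)  # KL above target
print(beta_after_high_kl)
```

Because the adjustment is multiplicative and scaled by `n_steps / horizon`, β drifts gradually rather than jumping after a single batch.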

**5. How does TRL's PPOTrainer simplify RLHF?**
<details>
<summary>Click to reveal answer</summary>
TRL handles:
- Value head wrapping (via `AutoModelForCausalLMWithValueHead`)
- KL penalty computation vs reference model
- Advantage estimation (GAE)
- PPO clipped objective
- Batching and optimization

You just need to provide: prompts, generated responses, and rewards. The trainer does the rest!
</details>
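
To make the "PPO clipped objective" item concrete, here is a minimal plain-Python sketch of the clipped surrogate loss for one response. The names `ppo_clip_loss` and `clip_eps` are illustrative, not TRL's API; the clip range of 0.2 matches the clip bounds plotted in the probability-ratio chart.

```python
import math


def ppo_clip_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    """Mean clipped surrogate loss over tokens (a quantity to minimize)."""
    losses = []
    for ln, lo, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(ln - lo)  # r(theta) = pi_new / pi_old for this token
        clipped = max(1.0 - clip_eps, min(1.0 + clip_eps, ratio))
        # Take the more pessimistic of the clipped and unclipped objectives
        losses.append(-min(ratio * adv, clipped * adv))
    return sum(losses) / len(losses)


# Hypothetical token whose probability rose by 50% and had a positive advantage
loss = ppo_clip_loss([math.log(1.5)], [0.0], [1.0])
print(loss)
```

With a positive advantage the gain is capped once the ratio exceeds 1.2, so the update has no incentive to push the probability further in a single step. With a negative advantage and the same ratio, the unclipped term is kept, so the penalty is not softened.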

---
## What's Next?

PPO is powerful but complex. In the next notebook, we'll explore **DPO (Direct Preference Optimization)** - a simpler alternative that skips the separate reward model and trains directly on preference pairs!

**Continue to:** [Notebook 4: DPO and Alternatives](04_dpo_and_alternatives.ipynb)

---

*PPO for LLMs: Like training a dog - reward good behavior, keep the leash tight!*