# Chapters 4 & 5: Policy Gradients and GRPO

**Goal:** Understand how GRPO (Group Relative Policy Optimization) uses reward signals to update model parameters, and why it eliminates the need for a critic network.

This notebook covers both Chapter 4 (REINFORCE theory) and Chapter 5 (GRPO implementation), since `grpo.py` implements the ideas from both chapters.

---

## The Central Problem

We have:
- A model that generates text (Chapter 2)
- A reward function that scores responses (Chapter 3)

**Question:** How do we use the scalar reward to compute gradients and update the model weights?

**The difficulty:** Sampling is **non-differentiable**. When we call `model.generate()` with sampling enabled, each token is drawn from the model's output distribution (for example via `torch.multinomial()`), and the result is a discrete token ID. We can't backpropagate through a discrete selection.

---

## Part 1: The Log-Derivative Trick (REINFORCE)

### Chapter 4's Key Insight

We can't differentiate through sampling, but we CAN differentiate through **log-probabilities**.

The REINFORCE algorithm:
1. **Sample** a response (no gradients needed)
2. **Compute the log-probability** of that response under the current model (this IS differentiable)
3. **Multiply** by the reward
4. **Backpropagate**

### The Math

Our objective is to maximize expected reward:

$$J(\theta) = \mathbb{E}_{a \sim \pi_\theta}[R(s, a)]$$

Using the log-derivative trick:

$$\nabla_\theta J(\theta) = \mathbb{E}_{a \sim \pi_\theta}[R(s, a) \cdot \nabla_\theta \log \pi_\theta(a|s)]$$
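
The identity behind this step is $\nabla_\theta \pi_\theta = \pi_\theta \nabla_\theta \log \pi_\theta$. Written out for a discrete set of responses (a sketch of the standard derivation, assuming the reward $R(s, a)$ does not depend on $\theta$):

$$\begin{aligned}
\nabla_\theta J(\theta) &= \nabla_\theta \sum_a \pi_\theta(a|s) \, R(s, a) \\
&= \sum_a R(s, a) \, \nabla_\theta \pi_\theta(a|s) \\
&= \sum_a \pi_\theta(a|s) \, R(s, a) \, \nabla_\theta \log \pi_\theta(a|s) \\
&= \mathbb{E}_{a \sim \pi_\theta}[R(s, a) \cdot \nabla_\theta \log \pi_\theta(a|s)]
\end{aligned}$$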

In practice, we approximate this expectation with a single sample:

$$\nabla_\theta J \approx R(s, a) \cdot \nabla_\theta \log \pi_\theta(a|s)$$

And since PyTorch minimizes loss:

$$\text{loss} = -R(s, a) \cdot \log \pi_\theta(a|s)$$
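
The gradient formula above can be checked numerically on a toy policy. The sketch below enumerates all three tokens of a softmax policy and compares the expectation of the REINFORCE estimator with the exact gradient of $J$. The reward values are hypothetical, and the gradient of the log-softmax is written in closed form (one-hot minus probabilities) instead of using autograd.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grad_log_prob(probs, a):
    """Gradient of log softmax(logits)[a] w.r.t. the logits: onehot(a) - probs."""
    return [(1.0 if k == a else 0.0) - p for k, p in enumerate(probs)]


logits = [1.0, 0.5, -0.5]
rewards = [1.0, 0.2, -1.0]  # hypothetical reward for each token
probs = softmax(logits)

# Exact gradient of J = sum_a p_a * R_a with respect to each logit z_k:
# dJ/dz_k = p_k * (R_k - J)
J = sum(p * r for p, r in zip(probs, rewards))
exact_grad = [p * (r - J) for p, r in zip(probs, rewards)]

# Expectation of the REINFORCE estimator R(a) * grad log pi(a),
# computed by enumerating every action weighted by its probability.
reinforce_grad = [0.0, 0.0, 0.0]
for a, (p, r) in enumerate(zip(probs, rewards)):
    g = grad_log_prob(probs, a)
    for k in range(len(logits)):
        reinforce_grad[k] += p * r * g[k]

print("exact    :", [round(x, 6) for x in exact_grad])
print("REINFORCE:", [round(x, 6) for x in reinforce_grad])
```

The two printed vectors agree, which is the statement that the estimator is unbiased. A single sample is only a noisy draw whose average is this vector.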

In [None]:
# Intuitive demonstration of why this works

import torch
import torch.nn.functional as F

# Imagine a simple model with 3 possible tokens
logits = torch.tensor([1.0, 0.5, -0.5], requires_grad=True)
probs = F.softmax(logits, dim=0)

print("Initial token probabilities:")
for i, p in enumerate(probs):
    print(f"  Token {i}: {p.item():.4f}")

# Suppose we sampled Token 0 and got reward +1.0
sampled_token = 0
reward = 1.0

# REINFORCE loss
log_prob = F.log_softmax(logits, dim=0)[sampled_token]
loss = -reward * log_prob

print(f"\nSampled token: {sampled_token}")
print(f"Reward: {reward}")
print(f"Log prob of sampled token: {log_prob.item():.4f}")
print(f"Loss: {loss.item():.4f}")

# Compute gradients
loss.backward()

print(f"\nGradients on logits: {logits.grad.tolist()}")
print("\nInterpretation:")
print("  Negative gradient on Token 0 --> optimizer will INCREASE its logit")
print("  Positive gradients on Tokens 1,2 --> optimizer will DECREASE their logits")
print("  Net effect: Token 0 becomes MORE likely (reward was positive)")

In [None]:
# Now with a NEGATIVE reward
logits2 = torch.tensor([1.0, 0.5, -0.5], requires_grad=True)

# Suppose we sampled Token 0 but got reward -1.0
sampled_token = 0
reward = -1.0

log_prob = F.log_softmax(logits2, dim=0)[sampled_token]
loss = -reward * log_prob
loss.backward()

print("With negative reward:")
print(f"  Gradients on logits: {logits2.grad.tolist()}")
print("  Positive gradient on Token 0 --> optimizer will DECREASE its logit")
print("  Net effect: Token 0 becomes LESS likely (reward was negative)")
print()
print("This is the core of RL: positive reward = do more of this,")
print("negative reward = do less of this.")

---

## Part 2: The Variance Problem & Baselines

### Why raw REINFORCE has high variance

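If all rewards are positive (say, ranging from 0.1 to 1.0), then EVERY sampled response gets "encouraged" - even the worst ones. Better responses are pushed up harder, so the update is still correct on average, but each individual estimate carries a large shared offset that makes the gradient signal noisy.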

### The Solution: Subtract a Baseline

Instead of using raw rewards, use **advantages**:

$$A(s, a_i) = R(s, a_i) - b(s)$$

Where $b(s)$ is a baseline. Responses better than the baseline get positive advantages (encouraged), those worse get negative advantages (discouraged).

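**Mathematical guarantee:** Subtracting a baseline that does not depend on the sampled response leaves the expected gradient unchanged (the estimator stays unbiased). A well-chosen baseline, such as the expected reward, reduces its variance.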
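
The reason is that $\sum_a \nabla_\theta \pi_\theta(a|s) = \nabla_\theta 1 = 0$, so the baseline term averages out. The sketch below checks both claims exactly on a toy three-token softmax policy by enumerating every action. The reward values are hypothetical, and the baseline is taken to be the true expected reward, which is an assumption of this example (GRPO later estimates it from a group).

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 0.5, -0.5]
rewards = [0.8, 0.3, 0.9]  # hypothetical, all positive
probs = softmax(logits)
n = len(probs)


def estimator_stats(baseline):
    """Mean and total variance of (R(a) - baseline) * grad log pi(a), a ~ probs."""
    samples = []
    for a in range(n):
        # Closed-form gradient of log softmax[a] w.r.t. logits: onehot(a) - probs
        g = [(1.0 if k == a else 0.0) - probs[k] for k in range(n)]
        samples.append([(rewards[a] - baseline) * gk for gk in g])
    mean = [sum(probs[a] * samples[a][k] for a in range(n)) for k in range(n)]
    var = sum(
        probs[a] * sum((samples[a][k] - mean[k]) ** 2 for k in range(n))
        for a in range(n)
    )
    return mean, var


expected_reward = sum(p * r for p, r in zip(probs, rewards))
mean_raw, var_raw = estimator_stats(0.0)
mean_bl, var_bl = estimator_stats(expected_reward)

print("mean without baseline:", [round(x, 6) for x in mean_raw])
print("mean with baseline   :", [round(x, 6) for x in mean_bl])
print(f"variance without baseline: {var_raw:.4f}")
print(f"variance with baseline   : {var_bl:.4f}")
```

The means match, while the variance drops substantially for this reward set.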

In [None]:
# Demonstrating the baseline's effect

# Scenario: 4 responses with all-positive rewards
rewards_raw = [0.8, 0.3, 0.9, 0.2]

# Without baseline: ALL get encouraged (all positive)
print("WITHOUT baseline:")
for i, r in enumerate(rewards_raw):
    direction = "encourage" if r > 0 else "discourage"
    print(f"  Response {i}: reward={r:+.1f} --> {direction}")
print("  Problem: Even the worst response (0.2) gets encouraged!")

print()

# With baseline: only above-average get encouraged
baseline = sum(rewards_raw) / len(rewards_raw)
advantages = [r - baseline for r in rewards_raw]

print(f"WITH baseline (mean = {baseline:.2f}):")
for i, (r, adv) in enumerate(zip(rewards_raw, advantages)):
    direction = "ENCOURAGE" if adv > 0 else "DISCOURAGE"
    print(f"  Response {i}: reward={r:+.1f}, advantage={adv:+.2f} --> {direction}")
print("  Now only above-average responses get encouraged!")

---

## Part 3: From REINFORCE to GRPO

### The Critic Problem

**PPO** (used for ChatGPT) maintains a **critic network** $V(s)$ that estimates the expected reward for each state. This serves as the baseline. But:

- The critic needs to be as large as the policy (~8B parameters)
- It needs its own optimizer and learning rate schedule  
- Total memory: ~48 GB (impossible on a 24 GB GPU)

### The GRPO Insight

**Instead of *predicting* the expected reward with a critic, *calculate* it directly!**

For each prompt:
1. Generate a **group** of G responses
2. Score all G responses with the reward function
3. Use the **group mean** as the baseline

This is a Monte Carlo estimate of the expected reward - no neural network needed!

### The GRPO Loss

$$\mathcal{L}_{GRPO}(\theta) = -\frac{1}{G} \sum_{i=1}^{G} \hat{A}(s, a_i) \cdot \log \pi_\theta(a_i | s)$$

Where:
$$\hat{A}(s, a_i) = R(s, a_i) - \frac{1}{G} \sum_{j=1}^{G} R(s, a_j)$$
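
As a minimal sketch of the advantage formula, the helper below (a hypothetical name, not taken from `grpo.py`) subtracts the group mean from each reward using plain Python lists.

```python
def group_advantages(rewards):
    """Return reward_i - mean(rewards) for every response in the group."""
    baseline = sum(rewards) / len(rewards)
    return [r - baseline for r in rewards]


rewards = [1.0, -0.5, 1.0, -1.0]
advantages = group_advantages(rewards)
print(advantages)       # [0.875, -0.625, 0.875, -1.125]
print(sum(advantages))  # 0.0
```

Group-relative advantages always sum to zero (up to floating-point rounding), so a response can only be encouraged at the expense of another response in the same group.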

---

## Part 4: `compute_log_probs()` - Line by Line

This function computes $\log \pi_\theta(a|s)$ - the log-probability of a response under the current model.
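
Whether this quantity can be used for learning depends on whether the forward pass is recorded by autograd. The sketch below shows the difference on a tiny made-up weight vector instead of a language model: the same computation yields a tensor with no gradient history inside `torch.no_grad()`, and a differentiable tensor outside it.

```python
import torch

w = torch.tensor([0.5, -0.2], requires_grad=True)  # stand-in for model weights
x = torch.tensor([1.0, 2.0])                       # stand-in for an input

# Inside no_grad: the result is cut off from w, so backward() cannot reach it.
with torch.no_grad():
    y_untracked = (w * x).sum()

# Outside no_grad: the result keeps a graph back to w.
y_tracked = (w * x).sum()

print(y_untracked.requires_grad)  # False
print(y_tracked.requires_grad)    # True

y_tracked.backward()
print(w.grad)  # tensor([1., 2.])
```

This is why generation can run under `torch.no_grad()` while the log-probability computation must not.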

In [None]:
import torch
import torch.nn.functional as F


def compute_log_probs(model, tokenizer, prompt: str, response: str) -> torch.Tensor:
    """Compute the total log-probability of a response given a prompt.
    
    This is the DIFFERENTIABLE computation that allows gradient flow.
    Even though we generated the response with torch.no_grad(),
    we can compute its probability WITH gradients.
    """
    # STEP 1: Concatenate prompt + response into one sequence
    full_text = prompt + response
    inputs = tokenizer(full_text, return_tensors="pt").to(model.device)
    
    # STEP 2: Find where the prompt ends and response begins
    # (assumes the prompt tokenizes to the same IDs on its own as it does
    # as a prefix of full_text)
    prompt_len = len(tokenizer(prompt, return_tensors="pt").input_ids[0])
    
    # STEP 3: Forward pass through the model (WITH gradient tracking!)
    # Note: NO torch.no_grad() here - we need gradients for learning
    outputs = model(**inputs)
    
    # STEP 4: Extract logits for the response tokens only
    # outputs.logits shape: [1, seq_len, vocab_size]
    # We want logits that PREDICT response tokens:
    #   logits[prompt_len-1] predicts response[0]
    #   logits[prompt_len]   predicts response[1]
    #   ...
    #   logits[-2]           predicts response[-1]
    logits = outputs.logits[0, prompt_len-1:-1]
    
    # STEP 5: Get the actual response token IDs
    target_ids = inputs.input_ids[0, prompt_len:]
    
    # STEP 6: Compute log probabilities
    log_probs = F.log_softmax(logits, dim=-1)
    
    # STEP 7: Gather the log-prob of each actual response token
    # This picks out the probability the model assigned to the tokens
    # that were actually generated
    token_log_probs = log_probs.gather(1, target_ids.unsqueeze(1)).squeeze(1)
    
    # STEP 8: Sum to get total sequence log-probability
    # log P(response|prompt) = sum of log P(token_i | prompt, token_1..i-1)
    return token_log_probs.sum()


print("compute_log_probs() defined.")
print("This function computes log P(response | prompt) under the current model.")

### Understanding the logit indexing

This is the trickiest part. Let's visualize:

```
Full sequence:  [P0, P1, P2, R0, R1, R2]
                 |---prompt--|  |--response-|
                 prompt_len=3

Logits:         [L0, L1, L2, L3, L4, L5]
                  |    |    |    |    |    |
                  v    v    v    v    v    v
Predicts:       [P1,  P2,  R0,  R1,  R2,  <next>]

We want logits that predict R0, R1, R2:
  L2 predicts R0  (index: prompt_len-1 = 2)
  L3 predicts R1
  L4 predicts R2  (index: -2, which is len-2)

So: logits[prompt_len-1 : -1] = logits[2:5] = [L2, L3, L4]
And: target_ids = input_ids[prompt_len:] = [R0, R1, R2]
```
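
The diagram assumes that the first `prompt_len` tokens of the full sequence are exactly the prompt's own tokens. Subword tokenizers can merge characters across the prompt/response boundary, in which case that assumption fails. The sketch below illustrates the effect with a hypothetical greedy longest-match tokenizer over a four-entry toy vocabulary. It is not how any real tokenizer is implemented.

```python
VOCAB = ["ab", "a", "b", "c"]  # hypothetical toy vocabulary


def toy_tokenize(text):
    """Greedy longest-match tokenizer over VOCAB (illustrative only)."""
    tokens, i = [], 0
    while i < len(text):
        match = max((v for v in VOCAB if text.startswith(v, i)), key=len)
        tokens.append(match)
        i += len(match)
    return tokens


prompt, response = "ca", "bc"

full_tokens = toy_tokenize(prompt + response)
prompt_len = len(toy_tokenize(prompt))
sliced_response = full_tokens[prompt_len:]
separate_response = toy_tokenize(response)

print(full_tokens)        # ['c', 'ab', 'c']
print(prompt_len)         # 2
print(sliced_response)    # ['c']
print(separate_response)  # ['b', 'c']
```

When the two tokenizations disagree like this, `logits[prompt_len-1:-1]` and `input_ids[prompt_len:]` cover the wrong tokens. A common way to avoid the issue, offered here as a suggestion and not as what `grpo.py` does, is to keep the token IDs returned by generation and concatenate them with the prompt IDs.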

In [None]:
# Visual demonstration of the indexing logic

# Simulating a sequence of 6 tokens: 3 prompt + 3 response
prompt_tokens = ["What", "is", "2+2"]
response_tokens = ["The", "answer", "4"]
full_sequence = prompt_tokens + response_tokens

prompt_len = len(prompt_tokens)  # 3

print("Full sequence:")
for i, token in enumerate(full_sequence):
    label = "prompt" if i < prompt_len else "response"
    print(f"  Position {i}: '{token}' ({label})")

print(f"\nprompt_len = {prompt_len}")
print()

print("Logit-to-prediction mapping:")
for i in range(len(full_sequence) - 1):
    predicted = full_sequence[i + 1]
    in_range = prompt_len - 1 <= i < len(full_sequence) - 1
    marker = " <-- USED" if in_range else ""
    print(f"  Logit[{i}] (at '{full_sequence[i]}') predicts '{predicted}'{marker}")

print(f"\nSlice: logits[{prompt_len-1}:-1] = logits[{prompt_len-1}:{len(full_sequence)-1}]")
print(f"Targets: input_ids[{prompt_len}:] = {response_tokens}")

### Understanding `gather()`

The `gather` operation picks specific values from a tensor along a dimension:

In [None]:
# Demonstrating gather

# Imagine log_probs for 3 tokens, vocabulary size 5
vocab = ["the", "answer", "is", "4", "hello"]
log_probs = torch.tensor([
    [-2.0, -0.5, -3.0, -4.0, -5.0],  # Position 0: "answer" most likely
    [-3.0, -4.0, -0.3, -2.0, -5.0],  # Position 1: "is" most likely
    [-4.0, -5.0, -3.0, -0.1, -6.0],  # Position 2: "4" most likely
])

# The actual response tokens (as indices into vocab)
target_ids = torch.tensor([1, 2, 3])  # "answer", "is", "4"

# gather picks out the log-prob of the actual token at each position
token_log_probs = log_probs.gather(1, target_ids.unsqueeze(1)).squeeze(1)

print("Log-probability table (rows=position, cols=vocab):")
print(f"{'':>12}", end="")
for v in vocab:
    print(f"{v:>8}", end="")
print()

for i in range(3):
    print(f"Position {i}:", end="")
    for j in range(5):
        marker = " *" if j == target_ids[i] else "  "
        print(f"{log_probs[i][j].item():>6.1f}{marker}", end="")
    print()

print(f"\nGathered log-probs (* marks): {token_log_probs.tolist()}")
print(f"Sum (total sequence log-prob): {token_log_probs.sum().item():.1f}")
print(f"\nThis means the model assigned probability exp({token_log_probs.sum().item():.1f})")
print(f"= {torch.exp(token_log_probs.sum()).item():.6f} to the response 'answer is 4'")

---

## Part 5: `grpo_step()` - The Complete Algorithm

This is the heart of the system. Let's walk through it piece by piece.

In [None]:
from collections.abc import Callable


def grpo_step(
    model,
    tokenizer,
    prompt: str,
    ground_truth: str,
    reward_fn: Callable[[str, str], float],
    G: int = 4,                 # Group size
    temperature: float = 0.7,
    max_new_tokens: int = 256,
) -> tuple[float, float, list[str]]:
    """
    One step of GRPO:
    1. Generate G responses for a single prompt
    2. Score all G responses
    3. Compute advantages using group mean as baseline
    4. Compute policy gradient loss
    5. Backpropagate
    
    Returns: (loss_value, mean_reward, responses)
    """
    model.train()  # Enable dropout etc.
    
    # ============================================================
    # PHASE 1: Generate G responses (INFERENCE - no gradients)
    # ============================================================
    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    
    responses = []
    for _ in range(G):
        with torch.no_grad():  # No gradients during generation!
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )
        response = tokenizer.decode(
            output_ids[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        responses.append(response)
    
    # ============================================================
    # PHASE 2: Score all responses and compute advantages
    # ============================================================
    rewards = torch.tensor(
        [reward_fn(r, ground_truth) for r in responses],
        dtype=torch.float32,
        device=model.device,
    )
    
    # The GRPO baseline: group mean
    baseline = rewards.mean()
    advantages = rewards - baseline
    
    # ============================================================
    # PHASE 3: Compute GRPO loss (WITH gradients)
    # ============================================================
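    # Start from a plain zero tensor. It only begins to require grad once a
    # differentiable term is added, so the check in Phase 4 is False when
    # every response was skipped.
    loss = torch.tensor(0.0, device=model.device)
    
    for response, adv in zip(responses, advantages):
        # Skip if advantage is ~0 (this response's reward equals the group mean,
        # e.g. when all responses got the same reward)
        # Nothing to learn when there's no contrast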
        if adv.abs() < 1e-8:
            continue
        
        # Compute log P(response | prompt) - THIS IS DIFFERENTIABLE
        log_prob = compute_log_probs(model, tokenizer, formatted_prompt, response)
        
        # Accumulate: loss = -(1/G) * sum(advantage * log_prob)
        loss = loss - adv * log_prob / G
    
    # ============================================================
    # PHASE 4: Backpropagate
    # ============================================================
    if loss.requires_grad:
        loss.backward()  # Computes gradients on all LoRA parameters
    
    return loss.item(), rewards.mean().item(), responses


print("grpo_step() defined.")
print("This function performs one complete GRPO update step.")

### The Three Phases Visualized

```
PHASE 1: GENERATION (torch.no_grad)
  Prompt: "What is 15 * 23?"
  |
  +--> Response 1: "The answer is 345"     (correct)
  +--> Response 2: "I think it's 355"      (wrong)  
  +--> Response 3: "345"                   (correct)
  +--> Response 4: "Hmm, not sure"         (no answer)

PHASE 2: SCORING & ADVANTAGES
  Rewards:    [+1.0,  -0.5,  +1.0,  -1.0]    (from reward_fn)
  Baseline:   +0.125                           (mean)
  Advantages: [+0.875, -0.625, +0.875, -1.125] (rewards - baseline)

PHASE 3: LOSS COMPUTATION (with gradients!)
  For each (response, advantage) pair:
    loss -= advantage * log_prob / G
  
  Positive advantage --> minimizing -advantage * log_prob pushes log_prob UP   (response more likely)
  Negative advantage --> minimizing -advantage * log_prob pushes log_prob DOWN (response less likely)
```

---

## Part 6: Simulating GRPO

Let's trace through the algorithm with concrete numbers (no model needed):

In [None]:
import torch

# Simulate a GRPO step with G=4
G = 4

# Simulated rewards for 4 responses
rewards = torch.tensor([1.0, -0.5, 1.0, -1.0])
response_descriptions = [
    "'The answer is 345'      (correct)",
    "'I think it's 355'       (wrong answer)",
    "'345'                    (correct)",
    "'Hmm, not sure'          (no answer)",
]

# Simulated log-probabilities (as if computed by compute_log_probs)
log_probs = torch.tensor([-15.2, -18.7, -12.3, -22.1])

print("=" * 70)
print("GRPO Step Simulation")
print("=" * 70)
print(f"\nPrompt: 'What is 15 * 23?'")
print(f"Ground truth: 345")
print(f"Group size G: {G}")

# Phase 1: Show responses and rewards
print(f"\n--- Phase 1: Responses & Rewards ---")
for i in range(G):
    print(f"  Response {i+1}: {response_descriptions[i]}")
    print(f"    Reward: {rewards[i]:+.1f}")

# Phase 2: Compute advantages
baseline = rewards.mean()
advantages = rewards - baseline

print(f"\n--- Phase 2: Advantages ---")
print(f"  Baseline (mean reward): {baseline:.4f}")
for i in range(G):
    sign = "+" if advantages[i] > 0 else ""
    print(f"  Response {i+1}: advantage = {rewards[i]:+.1f} - {baseline:.4f} = {sign}{advantages[i]:.4f}")

# Phase 3: Compute loss
loss = 0.0
print(f"\n--- Phase 3: Loss Computation ---")
for i in range(G):
    if abs(advantages[i]) < 1e-8:
        print(f"  Response {i+1}: advantage ~0, SKIPPED")
        continue
    
    contribution = -advantages[i].item() * log_probs[i].item() / G
    loss += contribution
    print(f"  Response {i+1}: -({advantages[i]:+.4f}) * ({log_probs[i]:.1f}) / {G} = {contribution:+.4f}")

print(f"\n  Total loss: {loss:.4f}")

print(f"\n--- Phase 4: Gradient Direction ---")
for i in range(G):
    if advantages[i] > 0:
        print(f"  Response {i+1}: Positive advantage --> INCREASE probability")
    elif advantages[i] < 0:
        print(f"  Response {i+1}: Negative advantage --> DECREASE probability")

---

## Part 7: Why GRPO Works on a Single GPU

### Memory Comparison

| | PPO | GRPO |
|---|---|---|
| Policy model | 8B params (4-bit) | 8B params (4-bit) |
| Critic model | 8B params (16-bit, ~16 GB) | **None** |
| Critic optimizer | ~32 GB | **None** |
| Total extra cost | ~48 GB | 0 GB |
| Fits on 24 GB GPU? | No | **Yes** |
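
The table does not spell out how its figures are derived. One accounting that reproduces them is sketched below. It assumes decimal gigabytes, 2 bytes per critic weight, and an optimizer that keeps two states per parameter (as Adam does) at 2 bytes each. With 32-bit optimizer states the optimizer figure would double.

```python
PARAMS = 8e9          # critic parameter count from the table
BYTES_PER_GB = 1e9    # assumption: decimal gigabytes

# 16-bit weights: 2 bytes per parameter
critic_weights_gb = PARAMS * 2 / BYTES_PER_GB

# Assumption: two optimizer states per parameter, 2 bytes each
STATES_PER_PARAM = 2
BYTES_PER_STATE = 2
critic_optimizer_gb = PARAMS * STATES_PER_PARAM * BYTES_PER_STATE / BYTES_PER_GB

total_extra_gb = critic_weights_gb + critic_optimizer_gb

print(f"Critic weights:   {critic_weights_gb:.0f} GB")    # 16 GB
print(f"Critic optimizer: {critic_optimizer_gb:.0f} GB")  # 32 GB
print(f"Total extra cost: {total_extra_gb:.0f} GB")       # 48 GB
```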

### Trade-offs

| Advantage | Disadvantage |
|---|---|
| No critic = no extra memory | Higher variance baseline (small G) |
| Simpler code (no second optimizer) | More generation cost (G generations per prompt) |
| Fresh baseline every step | New hyperparameter G to tune |
| Stable training | Less sample-efficient than PPO |

---

## Part 8: The Skip Condition

```python
if adv.abs() < 1e-8:
    continue
```

### When does this trigger?

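Whenever a response's reward equals the group mean, so its advantage is zero. The most common case is when **all G responses get the same reward**. For example, if all 4 responses are correct (+1.0 each), then: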

```
baseline = (1.0 + 1.0 + 1.0 + 1.0) / 4 = 1.0
advantage_i = 1.0 - 1.0 = 0.0 for all i
```

**There's nothing to learn** when all responses are equally good (or equally bad). The group provides no relative signal.

In [None]:
# Example: when skip triggers

scenarios = [
    ("All correct",    [1.0, 1.0, 1.0, 1.0]),
    ("All wrong",      [-0.5, -0.5, -0.5, -0.5]),
    ("Mixed (useful)", [1.0, -0.5, 1.0, -1.0]),
    ("3 correct, 1 wrong", [1.0, 1.0, 1.0, -0.5]),
]

print("When does learning happen?")
print("=" * 60)
for label, rewards in scenarios:
    mean_r = sum(rewards) / len(rewards)
    advantages = [r - mean_r for r in rewards]
    has_nonzero = any(abs(a) > 1e-8 for a in advantages)
    
    status = "LEARNING" if has_nonzero else "SKIP (all same)"
    print(f"\n{label}: rewards={rewards}")
    print(f"  Baseline: {mean_r:.4f}")
    print(f"  Advantages: {[f'{a:+.4f}' for a in advantages]}")
    print(f"  --> {status}")

---

## Part 9: Choosing the Group Size G

G is a critical hyperparameter that balances several factors:

In [None]:
import random
import statistics

# Simulate how G affects baseline quality.
# Rewards are drawn uniformly from {1.0, -0.5, -1.0}, so the true
# expected reward is (1.0 - 0.5 - 1.0) / 3 = -0.1667.
reward_values = [1.0, -0.5, -1.0]
true_expected_reward = sum(reward_values) / len(reward_values)

# Simulate many trials for each G value
random.seed(42)
n_trials = 1000

print("Effect of G on baseline estimation quality")
print(f"(True expected reward = {true_expected_reward:.4f})")
print(f"\n{'G':<5} {'Mean baseline':<18} {'Std of baseline':<18} {'Variance reduction'}")
print("-" * 65)

for G in [2, 4, 8, 16, 32]:
    baselines = []
    for _ in range(n_trials):
        # Simulate G rewards drawn uniformly from reward_values
        rewards = [random.choice(reward_values) for _ in range(G)]
        baselines.append(sum(rewards) / len(rewards))
    
    mean_bl = statistics.mean(baselines)
    std_bl = statistics.stdev(baselines)
    # The last column is a fixed qualitative label, not a computed quantity
    print(f"{G:<5} {mean_bl:<18.4f} {std_bl:<18.4f} {'Low' if G <= 2 else 'Medium' if G <= 8 else 'High'}")

print("\nThe book uses G=4 as a sweet spot:")
print("  - Enough diversity for meaningful advantages")
print("  - Not too many forward passes (4 generations per step)")
print("  - Fits in 24 GB GPU memory with Qwen3-8B")
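
The simulated spread can be compared with the textbook prediction for the mean of G independent draws, whose standard deviation is $\sigma / \sqrt{G}$. The sketch below computes that prediction for the same three-valued reward set, assuming the rewards within a group are independent and uniformly distributed over it.

```python
import math

reward_values = [1.0, -0.5, -1.0]  # same reward set as the simulation

mean_r = sum(reward_values) / len(reward_values)
var_r = sum((r - mean_r) ** 2 for r in reward_values) / len(reward_values)
sigma = math.sqrt(var_r)  # std of a single reward


def baseline_std(G):
    """Predicted std of the group-mean baseline for group size G."""
    return sigma / math.sqrt(G)


for G in [2, 4, 8, 16, 32]:
    print(f"G={G:<3} predicted std of baseline = {baseline_std(G):.4f}")
```

Because the noise shrinks with the square root of G, halving it requires four times as many generations, which is where the diminishing returns come from.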

---

## Exercises

### Exercise 1: Advantage Normalization
The current code uses raw advantages (`rewards - mean`). Many implementations also **normalize** advantages by dividing by the standard deviation: `(rewards - mean) / (std + epsilon)`. Implement this and think about when it would help vs hurt.

### Exercise 2: KL Divergence Penalty
The full GRPO paper includes a KL divergence penalty to prevent the model from drifting too far from the original pretrained model. This looks like: `loss += beta * KL(pi_current || pi_reference)`. Why might this be important? What happens without it?

### Exercise 3: Varying Group Size
Modify the simulation to show what happens with G=2 (very noisy baseline) vs G=16 (accurate but slow). At what point does increasing G have diminishing returns?

### Exercise 4: Understanding the Loss Sign
Trace through the math: if advantage is +0.5 and log_prob is -15.0, what is the contribution to the loss? What direction will the optimizer move? What about advantage -0.5 and log_prob -15.0?

---

## Key Takeaways

1. **The log-derivative trick** lets us compute gradients even though sampling is non-differentiable. We compute the log-probability of the sampled response and multiply by the reward.

2. **Baselines reduce variance** by turning raw rewards into advantages. Only responses better than average get encouraged; worse than average get discouraged.

3. **GRPO uses the group mean as the baseline**, eliminating the need for a critic network. This saves ~48 GB of memory.

4. **The algorithm has two gradient modes**: Generation (no_grad, inference only) and loss computation (with grad, differentiable).

5. **G=4 is the sweet spot** for single-GPU training: enough diversity for meaningful learning, not too expensive in forward passes.

6. **When all responses get the same reward**, there's nothing to learn - the step is skipped.

---

**Previous:** [Chapter 3 - Reward Signals](../ch03_rewards/learn_rewards.ipynb)  
**Next:** [Chapter 7 - Training Loop](../ch07_training/learn_training.ipynb) - How do we put all the pieces together into a complete training script?