# Variance Reduction: Making Policy Gradients Practical

REINFORCE is elegant, but its gradient estimates have high variance. This notebook covers the techniques that make policy gradients practical: baseline subtraction and the advantage function.

## What You'll Learn

By the end of this notebook, you'll understand:
- The exam grading analogy: why baselines help
- Baseline subtraction: the variance reduction trick
- Why V(s) is a near-optimal (and the standard) baseline
- The advantage function A(s, a)
- Implementing baselines in PyTorch
- Comparing REINFORCE with and without baseline

**Prerequisites:** Notebook 2 (REINFORCE Algorithm)

**Time:** ~25 minutes

---
## The Big Picture: The Exam Grading Analogy

```
    ┌────────────────────────────────────────────────────────────────┐
    │          THE EXAM GRADING ANALOGY                              │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Imagine you're a student getting exam scores...              │
    │                                                                │
    │  WITHOUT BASELINE (Raw Scores):                               │
    │    "You scored 75/100 on the exam"                           │
    │    Is this good? Bad? Hard to tell!                          │
    │    • Easy exam: 75 is bad (everyone got 90+)                 │
    │    • Hard exam: 75 is great (average was 50)                 │
    │                                                                │
    │  WITH BASELINE (Curve Grading):                               │
    │    "You scored 75, class average was 60"                     │
    │    Advantage = 75 - 60 = +15                                 │
    │    This is CLEARLY good! You beat the baseline!              │
    │                                                                │
    │  HOW THIS HELPS:                                              │
    │    • If return > baseline: Action was BETTER than average    │
    │    • If return < baseline: Action was WORSE than average     │
    │    • Much clearer signal for learning!                       │
    │                                                                │
    │  THE MAGIC: Subtracting a baseline doesn't change the expected │
    │  gradient, but a good one REDUCES ITS VARIANCE!                │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

try:
    import gymnasium as gym
except ImportError:
    import gym  # fallback: needs gym >= 0.26 for the (obs, info) reset and 5-tuple step used below

# Visualize why baselines help
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Without baseline (high variance)
ax1 = axes[0]
np.random.seed(42)
returns = np.random.normal(100, 30, 50)  # Mean 100, std 30

ax1.bar(range(len(returns)), returns, color='#64b5f6', edgecolor='black', linewidth=0.5)
ax1.axhline(y=0, color='black', linewidth=2)
ax1.axhline(y=100, color='red', linewidth=2, linestyle='--', label='Mean: 100')
ax1.set_xlabel('Episode', fontsize=11)
ax1.set_ylabel('Return G', fontsize=11)
ax1.set_title('Without Baseline\n(All positive → Always increase!)', fontsize=12, fontweight='bold', color='#d32f2f')
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')
ax1.text(25, 20, 'Problem: ALL returns are positive!\nEven bad actions get reinforced!', 
         ha='center', fontsize=10, color='#d32f2f')

# Right: With baseline (centered)
ax2 = axes[1]
baseline = 100
advantages = returns - baseline  # Subtract baseline

colors = ['#4caf50' if a > 0 else '#f44336' for a in advantages]
ax2.bar(range(len(advantages)), advantages, color=colors, edgecolor='black', linewidth=0.5)
ax2.axhline(y=0, color='black', linewidth=2, label='Baseline')
ax2.set_xlabel('Episode', fontsize=11)
ax2.set_ylabel('Advantage (G - baseline)', fontsize=11)
ax2.set_title('With Baseline\n(Clear good vs bad signal!)', fontsize=12, fontweight='bold', color='#388e3c')
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')
ax2.text(25, -50, 'Better: Green = above average (reinforce!)\nRed = below average (discourage!)', 
         ha='center', fontsize=10, color='#388e3c')

plt.tight_layout()
plt.show()

print("\nWHY BASELINES HELP:")
print(f"  Without baseline: all returns positive (mean {np.mean(returns):.0f})")
print(f"  With baseline: roughly centered on 0 (mean {np.mean(advantages):.0f})")
print(f"\n  Spread of the returns is unchanged: std {np.std(returns):.1f} vs {np.std(advantages):.1f}")
print("  What a baseline shrinks is the variance of the gradient estimate ∇log π × (G - b).")
print("  Meanwhile, the sign now separates better-than-average from worse-than-average episodes.")

---
## The Math: Why Baselines Don't Add Bias

```
    ┌────────────────────────────────────────────────────────────────┐
    │              BASELINE SUBTRACTION MATH                         │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  ORIGINAL GRADIENT:                                           │
    │    ∇_θ J = E[ ∇_θ log π(a|s) × G ]                            │
    │                                                                │
    │  WITH BASELINE b(s):                                          │
    │    ∇_θ J = E[ ∇_θ log π(a|s) × (G - b(s)) ]                   │
    │                                                                │
    │  WHY DOESN'T THIS ADD BIAS?                                   │
    │                                                                │
    │    E_{a~π}[ ∇_θ log π(a|s) × b(s) ]                            │
    │    = b(s) × Σ_a π(a|s) ∇_θ log π(a|s)  (b doesn't depend on a) │
    │    = b(s) × Σ_a ∇_θ π(a|s)             (since π ∇log π = ∇π)   │
    │    = b(s) × ∇_θ Σ_a π(a|s)             (swap sum and gradient) │
    │    = b(s) × ∇_θ 1                      (probabilities sum to 1)│
    │    = 0                                                         │
    │                                                                │
    │  KEY INSIGHT: Any baseline that doesn't depend on a can be   │
    │  subtracted without changing the expected gradient!          │
    │                                                                │
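    │  A NEAR-OPTIMAL BASELINE (the standard choice):                │
    │    b(s) = E[G|s] = V(s)                                        │
    │  The exact variance minimizer (per state) is                   │
    │    b*(s) = E[‖∇_θ log π‖² G] / E[‖∇_θ log π‖²]                 │
    │  which equals V(s) when ‖∇_θ log π‖² and G are uncorrelated.   │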
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

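The derivation can be checked exactly, with no sampling, on a toy problem. The sketch below makes several assumptions that do not come from the notebook: there is a single state and three actions, the policy is a softmax whose parameters θ are the logits themselves, and the Q-values are made up. Under this parameterization the score of action a is e_a - π(·|s). The code sums π(a|s)∇_θ log π(a|s) over actions, which should give the zero vector. It then computes the exact expected gradient Σ_a π(a|s)∇_θ log π(a|s)(Q(s, a) - b) for several constant baselines b. All of them should agree up to floating-point error.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)

# Toy setup (assumed for illustration): one state, 3 actions, and a softmax
# policy parameterized directly by its logits theta.
theta = torch.randn(3)
q_values = torch.tensor([80.0, 50.0, 20.0])  # made-up Q(s, a)

probs = torch.softmax(theta, dim=0)  # pi(a|s)
# Row a of this Jacobian is grad_theta log pi(a|s), the score of action a.
scores = jacobian(lambda t: torch.log_softmax(t, dim=0), theta)

# E_{a~pi}[grad log pi(a|s)] = sum_a pi(a|s) * grad log pi(a|s)
expected_score = probs @ scores
print("E[score] =", expected_score)


def exact_gradient(baseline):
    """Exact E_{a~pi}[grad log pi(a|s) * (Q(s, a) - b)] for a constant baseline b."""
    weights = probs * (q_values - baseline)  # pi(a|s) * (Q(s, a) - b)
    return weights @ scores


for b in [0.0, 50.0, 1000.0]:
    print(f"b = {b:7.1f} -> exact gradient {exact_gradient(b)}")
```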
In [None]:
# Demonstrate variance reduction mathematically

print("VARIANCE REDUCTION DEMONSTRATION")
print("="*60)

# Simulate returns from a state
np.random.seed(42)
n_samples = 1000

# True expected return from this state
true_value = 50

# Actual returns have variance
returns = np.random.normal(true_value, 20, n_samples)

# Stand-ins for the score ∇_θ log π(a|s), one scalar per sample. Real scores have
# mean 0 under the policy; here they are also drawn independently of the returns.
log_probs = np.random.randn(n_samples)

# Gradient estimates without baseline
grad_no_baseline = log_probs * returns

# Gradient estimates with baseline (using true V(s))
baseline = true_value
grad_with_baseline = log_probs * (returns - baseline)

print(f"\nTrue V(s) = {true_value}")
print(f"\nGradient WITHOUT baseline:")
print(f"  Mean: {np.mean(grad_no_baseline):.2f}")
print(f"  Variance: {np.var(grad_no_baseline):.2f}")
print(f"  Std: {np.std(grad_no_baseline):.2f}")

print(f"\nGradient WITH baseline (V(s)):")
print(f"  Mean: {np.mean(grad_with_baseline):.2f}")
print(f"  Variance: {np.var(grad_with_baseline):.2f}")
print(f"  Std: {np.std(grad_with_baseline):.2f}")

variance_reduction = (np.var(grad_no_baseline) - np.var(grad_with_baseline)) / np.var(grad_no_baseline)
print(f"\n  Variance reduced by: {variance_reduction*100:.1f}%!")
print("  Same expected value (up to sampling noise), much lower variance!")
print("="*60)

In [None]:
# Visualize the variance reduction
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Distribution of gradient estimates
ax1 = axes[0]
ax1.hist(grad_no_baseline, bins=50, alpha=0.5, color='red', label='No baseline', density=True)
ax1.hist(grad_with_baseline, bins=50, alpha=0.5, color='green', label='With baseline', density=True)
ax1.axvline(x=0, color='black', linewidth=2, linestyle='--')
ax1.set_xlabel('Gradient Estimate', fontsize=11)
ax1.set_ylabel('Density', fontsize=11)
ax1.set_title('Distribution of Gradient Estimates', fontsize=12, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right: Effect of baseline choice
ax2 = axes[1]
baselines = np.linspace(0, 100, 50)
variances = []

for b in baselines:
    grad = log_probs * (returns - b)
    variances.append(np.var(grad))

ax2.plot(baselines, variances, 'b-', linewidth=2)
ax2.axvline(x=true_value, color='red', linewidth=2, linestyle='--', label=f'V(s) = {true_value}')
best_b = baselines[np.argmin(variances)]  # grid point with the lowest variance
ax2.scatter([best_b], [min(variances)], c='red', s=100, zorder=5, label='Minimum variance')
ax2.set_xlabel('Baseline Value', fontsize=11)
ax2.set_ylabel('Gradient Variance', fontsize=11)
ax2.set_title('Here, V(s) Minimizes the Variance', fontsize=12, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

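print("\nKEY INSIGHT:")
print("  In this demo the score is independent of G, so V(s) = E[G|s] minimizes")
print("  the gradient variance while keeping the estimate unbiased.")
print("  In general V(s) is near-optimal, and it is the standard practical choice.")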

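The demo above is a special case. The simulated scores are independent of the returns, and under that condition the variance-minimizing constant baseline is exactly E[G|s] = V(s). When the squared score g² and the return G are correlated, the minimizer shifts to b* = E[g²G] / E[g²].

A minimal sketch under an assumed toy model compares the two baselines and checks b* against a brute-force grid search. In this model, scores are larger in magnitude on high-return samples; this is an illustration, not a claim about any particular environment. With these assumed numbers, b* lands noticeably above V(s) and gives a clearly lower variance.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Assumed toy model (not from the notebook): returns G from one state and a
# zero-mean score g whose magnitude is larger on high-return samples.
G = rng.normal(50.0, 20.0, n)
g = rng.standard_normal(n) * np.where(G > 50.0, 2.0, 0.5)


def grad_variance(b):
    """Sample variance of the single-sample gradient estimate g * (G - b)."""
    return np.var(g * (G - b))


value_baseline = G.mean()                                # V(s) = E[G | s]
optimal_baseline = np.mean(g**2 * G) / np.mean(g**2)     # b* = E[g^2 G] / E[g^2]

print(f"V(s) baseline:  b = {value_baseline:5.1f}, variance = {grad_variance(value_baseline):7.1f}")
print(f"b* baseline:    b = {optimal_baseline:5.1f}, variance = {grad_variance(optimal_baseline):7.1f}")

# Brute-force check: the best baseline on a fine grid should sit next to b*.
grid = np.linspace(0.0, 100.0, 201)
best_on_grid = grid[np.argmin([grad_variance(b) for b in grid])]
print(f"Grid minimizer: b = {best_on_grid:5.1f}")
```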
---
## The Advantage Function

```
    ┌────────────────────────────────────────────────────────────────┐
    │              THE ADVANTAGE FUNCTION                            │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  DEFINITION:                                                  │
    │    A(s, a) = Q(s, a) - V(s)                                   │
    │                                                                │
    │  INTUITION:                                                   │
    │    • Q(s, a): Expected return from taking a in state s       │
    │    • V(s): Expected return from state s (under policy)       │
    │    • A(s, a): How much BETTER is action a than average?      │
    │                                                                │
    │  PROPERTIES:                                                  │
    │    • A(s, a) > 0: Action a is BETTER than average            │
    │    • A(s, a) < 0: Action a is WORSE than average             │
    │    • E_{a~π}[A(s, a)] = 0 (π-weighted average is zero)         │
    │                                                                │
    │  IN PRACTICE:                                                 │
    │    We estimate A using returns and learned V(s):             │
    │    A_t ≈ G_t - V(s_t)                                        │
    │                                                                │
    │  THE POLICY GRADIENT BECOMES:                                 │
    │    ∇_θ J = E[ ∇_θ log π(a|s) × A(s, a) ]                     │
    │                                                                │
    │  This is what Actor-Critic methods use!                      │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

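The zero-mean property is a π-weighted average, E_{a~π}[A(s, a)] = Σ_a π(a|s)A(s, a) = 0, not a plain sum over actions. A small sketch with made-up Q-values and an assumed non-uniform policy makes the difference concrete. It also shows that the sign of A depends on the policy. Here the middle action, with Q = 50, has a negative advantage because the policy already puts most of its probability on the better action.

```python
import numpy as np

# Made-up numbers: Q-values for 3 actions and an assumed non-uniform policy.
q_values = np.array([80.0, 50.0, 20.0])  # Q(s, a)
pi = np.array([0.6, 0.3, 0.1])           # pi(a|s), sums to 1

v_value = pi @ q_values                  # V(s) = sum_a pi(a|s) Q(s, a)
advantages = q_values - v_value          # A(s, a) = Q(s, a) - V(s)

print("V(s)           =", v_value)
print("A(s, a)        =", advantages)
print("E_pi[A(s, a)]  =", pi @ advantages)     # zero up to rounding
print("plain sum of A =", advantages.sum())    # not zero for a non-uniform policy
```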
In [None]:
# Visualize the advantage function
fig, ax = plt.subplots(figsize=(12, 6))

# Example: 3 actions from some state
actions = ['Action A\n(Good)', 'Action B\n(Average)', 'Action C\n(Bad)']
q_values = [80, 50, 20]  # Q(s, a)
v_value = 50  # V(s) = Σ_a π(a|s) Q(s, a); assumes a uniform policy here, so V(s) is the mean of Q
advantages = [q - v_value for q in q_values]  # A(s, a) = Q(s, a) - V(s)

x = np.arange(len(actions))
width = 0.35

# Bar chart
bars_q = ax.bar(x - width/2, q_values, width, label='Q(s, a)', color='#64b5f6', edgecolor='black')
colors_a = ['#4caf50' if a > 0 else ('#f44336' if a < 0 else '#9e9e9e') for a in advantages]  # gray for A = 0
bars_a = ax.bar(x + width/2, advantages, width, label='A(s, a)', color=colors_a, edgecolor='black')

ax.axhline(y=v_value, color='orange', linewidth=2, linestyle='--', label=f'V(s) = {v_value}')
ax.axhline(y=0, color='black', linewidth=1)

ax.set_ylabel('Value', fontsize=11)
ax.set_title('Q-Values vs Advantages\nA(s,a) = Q(s,a) - V(s)', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(actions)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Annotations
for i, (q, a) in enumerate(zip(q_values, advantages)):
    sign = '+' if a > 0 else ''
    ax.annotate(f'A={sign}{a}', xy=(i + width/2, a), xytext=(i + width/2 + 0.1, a + 5),
                fontsize=10, fontweight='bold', color=colors_a[i])

plt.tight_layout()
plt.show()

print("\nADVANTAGE INTERPRETATION:")
print(f"  V(s) = {v_value} (expected return from this state)")
for action, q, a in zip(actions, q_values, advantages):
    sign = '+' if a > 0 else ''
    verdict = "REINFORCE" if a > 0 else ("DISCOURAGE" if a < 0 else "NO CHANGE")
    action_name = action.split('\n')[0]
    print(f"  {action_name}: Q={q}, A={sign}{a} → {verdict}")

---
## Implementing REINFORCE with Baseline

In [None]:
class PolicyWithBaseline(nn.Module):
    """
    Policy network with value function baseline.
    
    Two heads on shared features:
    - Policy head: π(a|s) for action selection
    - Value head: V(s) for baseline
    
    This architecture is used in Actor-Critic!
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        
        # Shared feature extraction
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Policy head: outputs action probabilities
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        
        # Value head: outputs V(s) estimate
        self.value_head = nn.Linear(hidden_dim, 1)
    
    def forward(self, state):
        """Returns both policy and value."""
        if not isinstance(state, torch.Tensor):
            state = torch.FloatTensor(state)
        
        features = self.shared(state)
        action_probs = self.policy_head(features)
        value = self.value_head(features)
        
        return action_probs, value
    
    def get_action(self, state):
        """Sample action and return action, log_prob, and value."""
        action_probs, value = self.forward(state)
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob, value


# Demonstrate
policy_baseline = PolicyWithBaseline(state_dim=4, action_dim=2)
print("POLICY WITH BASELINE")
print("="*60)
print(policy_baseline)

# Test forward pass
test_state = torch.randn(1, 4)
probs, value = policy_baseline(test_state)
print(f"\nTest state: {test_state.numpy()[0].round(2)}")
print(f"Action probs: {probs.detach().numpy()[0].round(3)}")
print(f"V(s): {value.item():.3f}")
print("="*60)

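A common variant of this network drops the final Softmax and passes raw logits to `torch.distributions.Categorical(logits=...)`. That distribution normalizes the logits internally with a log-sum-exp, which tends to be more numerically stable than taking the log of softmax outputs. The sketch below is a hypothetical `LogitsPolicyWithBaseline` with the same trunk and heads as `PolicyWithBaseline`. It is an illustration only and is not used elsewhere in this notebook.

```python
import torch
import torch.nn as nn


class LogitsPolicyWithBaseline(nn.Module):
    """Hypothetical variant of PolicyWithBaseline whose policy head returns logits."""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Same shared trunk as PolicyWithBaseline
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.policy_head = nn.Linear(hidden_dim, action_dim)  # raw logits, no Softmax
        self.value_head = nn.Linear(hidden_dim, 1)            # V(s) estimate

    def forward(self, state):
        state = torch.as_tensor(state, dtype=torch.float32)
        features = self.shared(state)
        return self.policy_head(features), self.value_head(features)

    def get_action(self, state):
        """Sample an action; Categorical(logits=...) handles normalization."""
        logits, value = self.forward(state)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        return action.item(), dist.log_prob(action), value


torch.manual_seed(0)
net = LogitsPolicyWithBaseline(state_dim=4, action_dim=2)
state = [0.0, 0.1, -0.2, 0.05]  # a made-up CartPole-like observation
action, log_prob, value = net.get_action(state)
logits, _ = net(state)
print(f"action={action}, log_prob={log_prob.item():.3f}, V(s)={value.item():.3f}")
```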
In [None]:
def reinforce_with_baseline(env_name='CartPole-v1', n_episodes=500, gamma=0.99, 
                            lr=1e-2, print_every=50):
    """
    REINFORCE with learned value function baseline.
    
    Key differences from vanilla REINFORCE:
    1. Learn V(s) alongside π(a|s)
    2. Use advantage A = G - V(s) instead of raw return
    3. Two losses: policy loss + value loss
    """
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    policy = PolicyWithBaseline(state_dim, action_dim)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    
    rewards_history = []
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        
        log_probs = []
        values = []
        rewards = []
        
        # ----------------------------------------
        # Collect episode
        # ----------------------------------------
        for _ in range(500):
            action, log_prob, value = policy.get_action(state)
            
            log_probs.append(log_prob)
            values.append(value)
            
            next_state, reward, terminated, truncated, _ = env.step(action)
            rewards.append(reward)
            
            state = next_state
            if terminated or truncated:
                break
        
        # ----------------------------------------
        # Compute returns and advantages
        # ----------------------------------------
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)
        returns = torch.FloatTensor(returns)
        
        values = torch.cat(values)
        
        # Advantage = Return - Baseline
        advantages = returns - values.detach()  # detach to not backprop through baseline
        
        # ----------------------------------------
        # Compute losses
        # ----------------------------------------
        # Policy loss: -log π × A
        policy_loss = 0
        for log_prob, adv in zip(log_probs, advantages):
            policy_loss -= log_prob * adv
        
        # Value loss: (V(s) - G)²
        value_loss = nn.functional.mse_loss(values, returns)
        
        # Total loss
        loss = policy_loss + 0.5 * value_loss  # 0.5 weight for value loss
        
        # ----------------------------------------
        # Update
        # ----------------------------------------
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track progress
        episode_reward = sum(rewards)
        rewards_history.append(episode_reward)
        
        if (episode + 1) % print_every == 0:
            avg_reward = np.mean(rewards_history[-print_every:])
            print(f"Episode {episode+1:4d} | Avg Reward: {avg_reward:6.1f}")
    
    env.close()
    return policy, rewards_history

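The return computation and the two losses inside `reinforce_with_baseline` can also be written in vectorized form, which makes them easy to check on a tiny hand-worked example. This is a sketch with hypothetical helper names (`discounted_returns`, `baseline_losses`). It reproduces the same math, including the 0.5 value-loss weight and the `detach()` that keeps the policy loss from updating V(s). With rewards [1, 1, 1] and γ = 0.5, the returns are G = [1.75, 1.5, 1.0]. The value-loss gradient is then (V - G) / T, since 0.5 × d/dV of mean((V - G)²) = (V - G) / T.

```python
import torch
import torch.nn.functional as F


def discounted_returns(rewards, gamma):
    """G_t = r_t + gamma * G_{t+1}, computed backwards from the last step."""
    returns = torch.zeros(len(rewards))
    G = 0.0
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
        returns[t] = G
    return returns


def baseline_losses(log_probs, values, returns, value_coef=0.5):
    """Vectorized policy loss + weighted value loss, as in reinforce_with_baseline."""
    advantages = returns - values.detach()         # A_t = G_t - V(s_t); no gradient into V here
    policy_loss = -(log_probs * advantages).sum()  # same as the per-step `policy_loss -= ...` loop
    value_loss = F.mse_loss(values, returns)       # mean of (V(s_t) - G_t)^2
    return policy_loss + value_coef * value_loss, advantages


# Hand-checkable example with made-up numbers
returns = discounted_returns([1.0, 1.0, 1.0], gamma=0.5)
log_probs = torch.log(torch.tensor([0.5, 0.4, 0.9])).requires_grad_()
values = torch.tensor([1.0, 2.0, 0.5], requires_grad=True)

loss, advantages = baseline_losses(log_probs, values, returns)
loss.backward()
print("returns:", returns)
print("advantages:", advantages)
print("dL/dV:", values.grad)  # only the value loss contributes: (V - G) / T
```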
In [None]:
# Train and compare
print("COMPARING VANILLA REINFORCE vs REINFORCE WITH BASELINE")
print("="*70)

print("\n1. Training REINFORCE with Baseline...\n")
policy_with, rewards_with = reinforce_with_baseline(
    n_episodes=500, print_every=100
)

In [None]:
# Also train vanilla REINFORCE for comparison
def reinforce_vanilla(env_name='CartPole-v1', n_episodes=500, gamma=0.99, lr=1e-2):
    """Vanilla REINFORCE without baseline."""
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    policy = nn.Sequential(
        nn.Linear(state_dim, 128), nn.ReLU(),
        nn.Linear(128, 128), nn.ReLU(),
        nn.Linear(128, action_dim), nn.Softmax(dim=-1)
    )
    optimizer = optim.Adam(policy.parameters(), lr=lr)
    
    rewards_history = []
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        log_probs = []
        rewards = []
        
        for _ in range(500):
            state_tensor = torch.FloatTensor(state)
            probs = policy(state_tensor)
            dist = torch.distributions.Categorical(probs)
            action = dist.sample()
            log_probs.append(dist.log_prob(action))
            
            state, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)
            
            if terminated or truncated:
                break
        
        # Compute returns
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)
        returns = torch.FloatTensor(returns)
        
        # Vanilla REINFORCE loss (no baseline!)
        loss = 0
        for log_prob, G in zip(log_probs, returns):
            loss -= log_prob * G
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        rewards_history.append(sum(rewards))
    
    env.close()
    return rewards_history

print("\n2. Training Vanilla REINFORCE (no baseline)...")
rewards_vanilla = reinforce_vanilla(n_episodes=500)
print("Done!")

In [None]:
# Compare results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Learning curves
ax1 = axes[0]
window = 50

smoothed_with = np.convolve(rewards_with, np.ones(window)/window, mode='valid')
smoothed_vanilla = np.convolve(rewards_vanilla, np.ones(window)/window, mode='valid')

ax1.plot(range(window-1, len(rewards_vanilla)), smoothed_vanilla, 
         color='red', linewidth=2, label='Vanilla REINFORCE')
ax1.plot(range(window-1, len(rewards_with)), smoothed_with, 
         color='green', linewidth=2, label='With Baseline')
ax1.axhline(y=500, color='gray', linestyle='--', linewidth=1, label='Max Score')

ax1.set_xlabel('Episode', fontsize=11)
ax1.set_ylabel('Reward (50-ep moving avg)', fontsize=11)
ax1.set_title('Learning Curves (one run each)', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

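# Right: spread of episode rewards. This is only a rough proxy: it is NOT the
# variance of the gradient estimate, and it also drops once a policy reliably hits 500.
ax2 = axes[1]

# Rolling std over the last `window` episodes (fewer episodes at the start)
rolling_std_vanilla = [np.std(rewards_vanilla[max(0, i - window + 1):i + 1]) for i in range(len(rewards_vanilla))]
rolling_std_with = [np.std(rewards_with[max(0, i - window + 1):i + 1]) for i in range(len(rewards_with))]

ax2.plot(rolling_std_vanilla, color='red', alpha=0.5, linewidth=1, label='Vanilla')
ax2.plot(rolling_std_with, color='green', alpha=0.5, linewidth=1, label='With Baseline')

# Smoothed std; 'valid' convolution drops the first (smooth - 1) points, so shift x to match
smooth = 30
ax2.plot(range(smooth - 1, len(rolling_std_vanilla)),
         np.convolve(rolling_std_vanilla, np.ones(smooth)/smooth, mode='valid'),
         color='red', linewidth=2)
ax2.plot(range(smooth - 1, len(rolling_std_with)),
         np.convolve(rolling_std_with, np.ones(smooth)/smooth, mode='valid'),
         color='green', linewidth=2)

ax2.set_xlabel('Episode', fontsize=11)
ax2.set_ylabel(f'Rolling Std of Reward ({window} episodes)', fontsize=11)
ax2.set_title('Episode-Reward Variability', fontsize=14, fontweight='bold')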
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nFINAL STATISTICS:")
print(f"  Vanilla REINFORCE - Final avg: {np.mean(rewards_vanilla[-50:]):.1f}")
print(f"  With Baseline - Final avg: {np.mean(rewards_with[-50:]):.1f}")

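Each curve above comes from a single training run, and REINFORCE runs can vary a lot from seed to seed, so one run per method is weak evidence. A fairer comparison repeats each method with several seeds and plots the mean learning curve with a spread band. Below is a minimal sketch of the summarizing step. It assumes each run's per-episode rewards are stored as equal-length sequences, and the helper name `summarize_runs` is hypothetical. The commented-out loop shows one assumed way to collect the runs.

```python
import numpy as np


def summarize_runs(reward_histories, window=50):
    """Moving-average learning curve for each run, then mean and std across runs.

    reward_histories: list of equal-length per-episode reward sequences, one per seed.
    Returns (episodes, mean_curve, std_curve), with each point aligned to the
    last episode of its averaging window.
    """
    kernel = np.ones(window) / window
    smoothed = np.array([np.convolve(r, kernel, mode='valid') for r in reward_histories])
    episodes = np.arange(window - 1, window - 1 + smoothed.shape[1])
    return episodes, smoothed.mean(axis=0), smoothed.std(axis=0)


# Example use (commented out because it retrains the agent several times):
# runs = []
# for seed in range(5):
#     torch.manual_seed(seed)  # seeds network init and action sampling, not the env
#     runs.append(reinforce_with_baseline(n_episodes=500, print_every=100)[1])
# episodes, mean_curve, std_curve = summarize_runs(runs, window=50)
# plt.plot(episodes, mean_curve)
# plt.fill_between(episodes, mean_curve - std_curve, mean_curve + std_curve, alpha=0.2)

# Quick self-check on synthetic histories
fake_runs = [np.ones(100), 3.0 * np.ones(100)]
episodes, mean_curve, std_curve = summarize_runs(fake_runs, window=10)
print(episodes[0], episodes[-1], mean_curve[:3], std_curve[:3])
```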
---
## Summary: Key Takeaways

### Baseline Subtraction

| Concept | Description |
|---------|-------------|
| **Baseline** | Value to subtract from returns |
| **Purpose** | Reduce variance without adding bias |
| **Standard choice** | V(s), the expected return from the state (near-optimal) |

### The Advantage Function

```
A(s, a) = Q(s, a) - V(s)
        = "How much better is this action than average?"

A > 0: Action is better than average → Reinforce!
A < 0: Action is worse than average → Discourage!
```

### Policy Gradient with Baseline

```
∇_θ J = E[ ∇_θ log π(a|s) × A(s, a) ]
      = E[ ∇_θ log π(a|s) × (G - V(s)) ]
```

---
## Test Your Understanding

**1. Why does subtracting a baseline not add bias?**
<details>
<summary>Click to reveal answer</summary>
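Because E_{a~π}[∇log π(a|s) × b(s)] = b(s) × Σ_a π(a|s) ∇log π(a|s) = b(s) × ∇Σ_a π(a|s) = b(s) × ∇1 = 0 whenever b does not depend on the action a. The key step is π∇log π = ∇π; after that, the probabilities sum to 1 and the gradient of a constant vanishes. So we can subtract any state-dependent baseline without changing the expected gradient!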
</details>

**2. Why is V(s) the optimal baseline?**
<details>
<summary>Click to reveal answer</summary>
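V(s) is close to the variance-minimizing baseline and is the standard choice. The exact minimizer weights returns by the squared score, b* = E[‖∇log π‖² G] / E[‖∇log π‖²], which reduces to V(s) when the squared score and the return are uncorrelated. Intuitively, V(s) is the expected return from state s, so (G - V(s)) measures how much better or worse this specific trajectory was than average. This gives a clearer signal than raw returns.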
</details>

**3. What is the advantage function?**
<details>
<summary>Click to reveal answer</summary>
A(s, a) = Q(s, a) - V(s) measures how much better action a is than the average action in state s, where the average is taken under the current policy. Positive advantage means the action is better than average; negative means worse. Actor-Critic methods weight the policy gradient by an estimate of this quantity!
</details>

**4. Why do we have two losses (policy and value)?**
<details>
<summary>Click to reveal answer</summary>
- Policy loss trains π(a|s) to take actions with high advantage
- Value loss trains V(s) to accurately predict returns (better baseline)

Both are needed. A poorly fit V(s) keeps the gradient unbiased but removes little variance, and a badly wrong baseline can even add variance. Without the policy loss, we never learn which actions are good.
</details>

**5. How much does baseline help in practice?**
<details>
<summary>Click to reveal answer</summary>
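Often a lot, though the size of the gain depends on the task and on how well V(s) is fit. Without a baseline, returns are often all positive (in CartPole every reward is +1), so every sampled action gets pushed up, just by different amounts. The right direction only emerges after averaging many noisy samples. With a baseline, better-than-average actions are reinforced and worse-than-average ones are discouraged directly. This typically reduces gradient variance and speeds up learning.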
</details>

---
## What's Next?

We've now seen how to reduce variance with baselines. But we still wait for full episodes! In the next notebook, we'll learn **Actor-Critic** methods that update after every step!

**Continue to:** [Notebook 4: Actor-Critic Methods](04_actor_critic.ipynb)

---

*Baselines: Simple idea, huge impact!*