# Actor-Critic: The Best of Both Worlds

Actor-Critic methods combine policy gradients with value functions, getting the benefits of both!

## What You'll Learn

By the end of this notebook, you'll understand:
- The acting coach analogy: why immediate feedback matters
- The actor-critic architecture (policy + value)
- TD learning for the critic (bootstrapping)
- Why actor-critic often learns faster than REINFORCE
- Implementing actor-critic from scratch
- Online vs episodic updates

**Prerequisites:** Notebook 3 (Variance Reduction)

**Time:** ~30 minutes

---
## The Big Picture: The Acting Coach Analogy

```
    ┌────────────────────────────────────────────────────────────────┐
    │          THE ACTING COACH ANALOGY                              │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Imagine learning to act in a play...                         │
    │                                                                │
    │  REINFORCE (Monte Carlo):                                     │
    │    Perform the ENTIRE play, then get feedback:                │
    │    "The audience gave you 7/10 for the whole performance"     │
    │    Which scene was good? Which was bad? Hard to tell!         │
    │                                                                │
    │  ACTOR-CRITIC (TD Learning):                                  │
    │    A coach watches EACH scene and gives feedback:             │
    │    "That scene was better than I expected! Good job!"         │
    │    "That scene was worse than expected. Try differently."     │
    │    You improve after EVERY scene, not just at the end!        │
    │                                                                │
    │  THE TWO ROLES:                                               │
    │    ACTOR (Policy): You, the performer, deciding how to act    │
    │    CRITIC (Value): The coach, judging how good things are     │
    │                                                                │
    │  WHY THIS WORKS:                                              │
    │    • Immediate feedback → faster learning                     │
    │    • Critic reduces variance by predicting outcomes           │
    │    • No need to wait for episode to end!                      │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Circle, Rectangle

try:
    import gymnasium as gym
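except ImportError:
    # Fallback: the cells below assume the gym>=0.26 API
    # (reset() returns (obs, info); step() returns 5 values)
    import gym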

# Visualize Actor-Critic vs REINFORCE
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: REINFORCE (wait for episode end)
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('REINFORCE (Monte Carlo)\n"Wait for the whole play"', fontsize=14, fontweight='bold', color='#d32f2f')

# Episode timeline
for i in range(5):
    x = 1.5 + i * 1.5
    box = FancyBboxPatch((x, 5.5), 1, 1.5, boxstyle="round,pad=0.1",
                          facecolor='#bbdefb', edgecolor='#1976d2', linewidth=2)
    ax1.add_patch(box)
    ax1.text(x + 0.5, 6.25, f'Step {i+1}', ha='center', fontsize=9)
    if i < 4:
        # Arrow to the next step (annotate draws from xytext to xy)
        ax1.annotate('', xy=(x + 1.4, 6.25), xytext=(x + 1.1, 6.25),
                     arrowprops=dict(arrowstyle='->', lw=1, color='#666'))

# Final feedback arrow
ax1.annotate('', xy=(5, 4.5), xytext=(8.5, 5.4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#d32f2f',
                            connectionstyle='arc3,rad=-0.3'))
ax1.text(5, 3.5, 'Total Return: G = 47', ha='center', fontsize=11, 
         color='#d32f2f', fontweight='bold')
ax1.text(5, 2.5, '"How was the whole thing?"', ha='center', fontsize=10, 
         style='italic', color='#666')

ax1.text(5, 1.5, '❌ Wait until end\n❌ High variance\n❌ Slow credit assignment', 
         ha='center', fontsize=10, color='#d32f2f')

# Right: Actor-Critic (immediate feedback)
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Actor-Critic (TD Learning)\n"Feedback after each scene"', fontsize=14, fontweight='bold', color='#388e3c')

# Episode timeline with feedback arrows
for i in range(5):
    x = 1.5 + i * 1.5
    box = FancyBboxPatch((x, 5.5), 1, 1.5, boxstyle="round,pad=0.1",
                          facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
    ax2.add_patch(box)
    ax2.text(x + 0.5, 6.25, f'Step {i+1}', ha='center', fontsize=9)
    
    # Immediate feedback arrow
    ax2.annotate('', xy=(x + 0.5, 5.4), xytext=(x + 0.5, 4.2),
                 arrowprops=dict(arrowstyle='->', lw=1, color='#388e3c'))
    
    if i < 4:
        # Arrow to the next step (annotate draws from xytext to xy)
        ax2.annotate('', xy=(x + 1.4, 6.25), xytext=(x + 1.1, 6.25),
                     arrowprops=dict(arrowstyle='->', lw=1, color='#666'))

# TD errors
td_errors = ['+2', '+1', '-1', '+3', '+2']
for i, td in enumerate(td_errors):
    x = 1.5 + i * 1.5
    color = '#388e3c' if '+' in td else '#d32f2f'
    ax2.text(x + 0.5, 3.8, f'δ={td}', ha='center', fontsize=9, color=color, fontweight='bold')

ax2.text(5, 2.5, '"Better than expected!" or "Worse than expected!"', 
         ha='center', fontsize=10, style='italic', color='#666')

ax2.text(5, 1.5, '✓ Immediate updates\n✓ Lower variance\n✓ Faster learning', 
         ha='center', fontsize=10, color='#388e3c')

plt.tight_layout()
plt.show()

print("\nKEY INSIGHT:")
print("  REINFORCE: Learn from total return (high variance, slow)")
print("  Actor-Critic: Learn from TD error (lower variance, fast)")
print("  TD error δ = 'How much better/worse than expected?'")

---
## The Actor-Critic Architecture

```
    ┌────────────────────────────────────────────────────────────────┐
    │              ACTOR-CRITIC ARCHITECTURE                         │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │                     ┌─────────────┐                            │
    │                     │    STATE    │                            │
    │                     │     s_t     │                            │
    │                     └──────┬──────┘                            │
    │                            │                                   │
    │                 ┌──────────┴──────────┐                        │
    │                 │                     │                        │
    │          ┌──────▼──────┐       ┌──────▼──────┐                 │
    │          │   ACTOR     │       │   CRITIC    │                 │
    │          │   π(a|s;θ)  │       │   V(s;w)    │                 │
    │          │             │       │             │                 │
    │          │ "What to    │       │ "How good   │                 │
    │          │   do?"      │       │  is this?"  │                 │
    │          └──────┬──────┘       └──────┬──────┘                 │
    │                 │                     │                        │
    │                 ▼                     ▼                        │
    │            Action a_t          Value V(s_t)                    │
    │                                                                │
    │  ACTOR updates using policy gradient + critic's feedback       │
    │  CRITIC updates using TD learning                              │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
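In the diagram the actor and critic have separate parameters, θ and w. Below is a minimal PyTorch sketch of that fully separate layout. The class names `PolicyNet` and `ValueNet` are hypothetical. The implementation later in this notebook uses a shared feature trunk instead.

```python
import torch
import torch.nn as nn

class PolicyNet(nn.Module):
    """Actor π(a|s; θ): maps a state to action probabilities (hypothetical name)."""
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, state):
        # Softmax over the last dimension turns logits into a distribution
        return torch.softmax(self.net(state), dim=-1)

class ValueNet(nn.Module):
    """Critic V(s; w): maps a state to one scalar value (hypothetical name)."""
    def __init__(self, state_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state):
        return self.net(state).squeeze(-1)

torch.manual_seed(0)
actor = PolicyNet(state_dim=4, action_dim=2)
critic = ValueNet(state_dim=4)

# Separate parameter sets θ and w, so each network can have its own optimizer
actor_opt = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_opt = torch.optim.Adam(critic.parameters(), lr=1e-3)

s = torch.randn(3, 4)      # a batch of 3 CartPole-sized states
probs = actor(s)           # shape (3, 2); each row sums to 1
values = critic(s)         # shape (3,)
print(probs.shape, values.shape)
```

Keeping the networks separate stops actor and critic gradients from competing inside one trunk. The cost is that state features get computed twice.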

In [None]:
# Visualize the Actor-Critic architecture in detail

fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(0, 14)
ax.set_ylim(0, 12)
ax.axis('off')
ax.set_title('Actor-Critic Architecture: Two Networks, One Goal', fontsize=16, fontweight='bold')

# State input
state_box = FancyBboxPatch((5.5, 10), 3, 1, boxstyle="round,pad=0.1",
                            facecolor='#e3f2fd', edgecolor='#1976d2', linewidth=3)
ax.add_patch(state_box)
ax.text(7, 10.5, 'State s', ha='center', va='center', fontsize=12, fontweight='bold')

# Shared features (optional in some implementations)
shared_box = FancyBboxPatch((5.5, 8), 3, 1, boxstyle="round,pad=0.1",
                             facecolor='#fff3e0', edgecolor='#f57c00', linewidth=2)
ax.add_patch(shared_box)
ax.text(7, 8.5, 'Shared Features\n(optional)', ha='center', va='center', fontsize=10)

# Actor
actor_box = FancyBboxPatch((1.5, 5), 4, 2, boxstyle="round,pad=0.1",
                            facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax.add_patch(actor_box)
ax.text(3.5, 6.3, 'ACTOR', ha='center', fontsize=12, fontweight='bold', color='#388e3c')
ax.text(3.5, 5.7, 'Policy π(a|s; θ)', ha='center', fontsize=10)
ax.text(3.5, 5.2, '"What action to take?"', ha='center', fontsize=9, style='italic')

# Critic
critic_box = FancyBboxPatch((8.5, 5), 4, 2, boxstyle="round,pad=0.1",
                             facecolor='#e1bee7', edgecolor='#7b1fa2', linewidth=3)
ax.add_patch(critic_box)
ax.text(10.5, 6.3, 'CRITIC', ha='center', fontsize=12, fontweight='bold', color='#7b1fa2')
ax.text(10.5, 5.7, 'Value V(s; w)', ha='center', fontsize=10)
ax.text(10.5, 5.2, '"How good is this state?"', ha='center', fontsize=9, style='italic')

# Arrows from shared to actor/critic
ax.annotate('', xy=(5.5, 8), xytext=(7, 8),
             arrowprops=dict(arrowstyle='-', lw=2, color='#666'))
ax.annotate('', xy=(3.5, 7.1), xytext=(5.5, 8),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(10.5, 7.1), xytext=(8.5, 8),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
# State → shared features (arrow points down, from xytext to xy)
ax.annotate('', xy=(7, 9.1), xytext=(7, 9.9),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Actor output
action_box = FancyBboxPatch((2, 2.5), 3, 1, boxstyle="round,pad=0.1",
                             facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax.add_patch(action_box)
ax.text(3.5, 3, 'Action a', ha='center', fontsize=10, fontweight='bold')
ax.annotate('', xy=(3.5, 3.6), xytext=(3.5, 4.9),
             arrowprops=dict(arrowstyle='->', lw=2, color='#388e3c'))

# Critic output
value_box = FancyBboxPatch((9, 2.5), 3, 1, boxstyle="round,pad=0.1",
                            facecolor='#e1bee7', edgecolor='#7b1fa2', linewidth=2)
ax.add_patch(value_box)
ax.text(10.5, 3, 'V(s)', ha='center', fontsize=10, fontweight='bold')
ax.annotate('', xy=(10.5, 3.6), xytext=(10.5, 4.9),
             arrowprops=dict(arrowstyle='->', lw=2, color='#7b1fa2'))

# TD Error computation
td_box = FancyBboxPatch((5.5, 0.5), 3, 1.5, boxstyle="round,pad=0.1",
                         facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=3)
ax.add_patch(td_box)
ax.text(7, 1.5, 'TD Error δ', ha='center', fontsize=11, fontweight='bold', color='#d32f2f')
ax.text(7, 1, 'δ = r + γV(s\') - V(s)', ha='center', fontsize=9)

# Arrow from V to TD
ax.annotate('', xy=(8.4, 1.25), xytext=(8.9, 2.4),
             arrowprops=dict(arrowstyle='->', lw=2, color='#7b1fa2'))

# Feedback arrows: the TD error flows back to update both networks
ax.annotate('', xy=(3.5, 2.4), xytext=(5.4, 1.5),
             arrowprops=dict(arrowstyle='->', lw=2, color='#d32f2f',
                            connectionstyle='arc3,rad=0.2'))
ax.text(3.5, 1.3, 'Updates Actor', ha='center', fontsize=9, color='#d32f2f')

ax.annotate('', xy=(10.5, 2.4), xytext=(8.6, 1.5),
             arrowprops=dict(arrowstyle='->', lw=2, color='#d32f2f',
                            connectionstyle='arc3,rad=-0.2'))
ax.text(10.5, 1.3, 'Updates Critic', ha='center', fontsize=9, color='#d32f2f')

plt.tight_layout()
plt.show()

print("\nACTOR-CRITIC SUMMARY:")
print("  Actor: Learns WHAT to do (policy)")
print("  Critic: Learns HOW GOOD things are (value function)")
print("  TD Error: Tells actor if action was better/worse than expected")

---
## The TD Error: The Critic's Feedback

```
    ┌────────────────────────────────────────────────────────────────┐
    │              THE TD ERROR (Temporal Difference)                │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  TD ERROR FORMULA:                                            │
    │    δ = r + γV(s') - V(s)                                      │
    │        └─────────┘   └──┘                                     │
    │         TD target   Current                                   │
    │         (what we    estimate                                  │
    │          got)       (what we                                  │
    │                      expected)                                │
    │                                                                │
    │  INTERPRETATION:                                              │
    │    δ > 0: "Better than expected!"  → Reinforce this action   │
    │    δ < 0: "Worse than expected!"   → Discourage this action  │
    │    δ ≈ 0: "About as expected"      → Little or no change     │
    │                                                                │
    │  WHY TD ERROR?                                                │
    │    • Like advantage A(s,a), but computed INSTANTLY            │
    │    • No need to wait for episode to end                       │
    │    • Lower variance (bootstraps from learned V)               │
    │                                                                │
    │  USED FOR:                                                    │
    │    • Actor: -log π(a|s) × δ                                  │
    │    • Critic: (δ)² = (r + γV(s') - V(s))²                     │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
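The statement that δ behaves like an advantage can be made precise. Assume the critic is exact (V = V^π), and let r_t be the reward received after taking a_t in s_t. Then the expected TD error, given the state and action, equals the advantage:

```latex
\begin{aligned}
\delta_t &= r_t + \gamma V^{\pi}(s_{t+1}) - V^{\pi}(s_t) \\
\mathbb{E}\left[\delta_t \mid s_t, a_t\right]
  &= \mathbb{E}\left[r_t + \gamma V^{\pi}(s_{t+1}) \mid s_t, a_t\right] - V^{\pi}(s_t) \\
  &= Q^{\pi}(s_t, a_t) - V^{\pi}(s_t) \\
  &= A^{\pi}(s_t, a_t)
\end{aligned}
```

A learned critic is never exact, so in practice δ_t only approximates the advantage. That gap is the bias that bootstrapping introduces.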

In [None]:
# Demonstrate TD error computation

print("TD ERROR DEMONSTRATION")
print("="*60)

# Simulated scenario
gamma = 0.99

print("\nScenario: Agent in a game...\n")

examples = [
    # (V(s), r, V(s'), description)
    (10.0, 5.0, 8.0, "Got reward AND reached good state"),
    (10.0, 1.0, 5.0, "Got small reward, reached worse state"),
    (5.0,  0.0, 15.0, "No reward but reached great state"),
    (10.0, -2.0, 8.0, "Got penalty, reached slightly worse state"),
]

for i, (V_s, r, V_next, desc) in enumerate(examples, 1):
    td_target = r + gamma * V_next
    td_error = td_target - V_s
    
    interpretation = "✓ Better than expected!" if td_error > 0.5 else \
                     "✗ Worse than expected!" if td_error < -0.5 else \
                     "≈ About as expected"
    
    print(f"Example {i}: {desc}")
    print(f"  V(s) = {V_s:.1f}, r = {r:.1f}, V(s') = {V_next:.1f}")
    print(f"  TD target = r + γV(s') = {r:.1f} + {gamma}×{V_next:.1f} = {td_target:.2f}")
    print(f"  TD error δ = {td_target:.2f} - {V_s:.1f} = {td_error:+.2f}")
    print(f"  → {interpretation}\n")

print("="*60)
print("TD error tells the actor: 'That action was better/worse than I expected!'")

In [None]:
# Visualize TD error interpretation

fig, ax = plt.subplots(figsize=(12, 6))

# TD error values
td_errors = np.linspace(-5, 5, 100)

# Color gradient based on TD error
colors = plt.cm.RdYlGn((td_errors + 5) / 10)  # Red to Green

for i in range(len(td_errors) - 1):
    ax.axvspan(td_errors[i], td_errors[i+1], 
               facecolor=colors[i], alpha=0.8)

ax.axvline(x=0, color='black', linewidth=3, linestyle='--')

# Annotations
ax.text(-2.5, 0.7, 'WORSE than expected\n"Discourage this action"', 
        ha='center', fontsize=12, fontweight='bold', color='white')
ax.text(2.5, 0.7, 'BETTER than expected\n"Reinforce this action"', 
        ha='center', fontsize=12, fontweight='bold', color='white')
ax.text(0, 0.3, 'As expected', ha='center', fontsize=11, fontweight='bold')

ax.set_xlim(-5, 5)
ax.set_ylim(0, 1)
ax.set_xlabel('TD Error δ = r + γV(s\') - V(s)', fontsize=12)
ax.set_title('TD Error: The Critic\'s Feedback to the Actor', fontsize=14, fontweight='bold')
ax.set_yticks([])

plt.tight_layout()
plt.show()

---
## Comparing Update Rules

```
    ┌────────────────────────────────────────────────────────────────┐
    │              REINFORCE vs ACTOR-CRITIC UPDATES                 │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  REINFORCE (with baseline):                                   │
    │    ∇θ J ≈ Σ_t ∇θ log π(a_t|s_t) × (G_t - V(s_t))             │
    │                                    └────────────┘              │
    │                                    Monte Carlo return          │
    │                                    (wait for episode end)      │
    │                                                                │
    │  ACTOR-CRITIC:                                                │
    │    ∇θ J ≈ ∇θ log π(a_t|s_t) × δ_t                            │
    │                                └──┘                            │
    │                                TD error                        │
    │                                (instant feedback!)             │
    │                                                                │
    │  KEY DIFFERENCE:                                              │
    │    REINFORCE: Uses G_t (actual return) - UNBIASED but HIGH VAR│
    │    Actor-Critic: Uses δ_t (TD error) - SOME BIAS but LOW VAR │
    │                                                                │
    │  ONLINE VS EPISODIC:                                          │
    │    REINFORCE: Episodic (update after each episode)            │
    │    Actor-Critic: Online (update after EACH step!)             │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
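n-step targets sit between these two extremes:

G_t^(n) = r_t + γ r_{t+1} + … + γ^(n-1) r_{t+n-1} + γ^n V(s_{t+n})

Below is a minimal sketch that computes them for one finished episode. It assumes the episode ended by true termination, so V = 0 past the last state. The lists `rewards` and `values` are hypothetical inputs. Setting n = 1 recovers the TD(0) target used by actor-critic. Setting n ≥ episode length recovers the Monte Carlo return used by REINFORCE.

```python
def n_step_targets(rewards, values, n, gamma=0.99):
    """
    n-step TD targets for every step of one finished episode.

    rewards[t]: reward received after acting at step t
    values[t]:  critic estimate V(s_t); states past the end count as 0 (terminal)
    """
    T = len(rewards)
    targets = []
    for t in range(T):
        G = 0.0
        horizon = min(n, T - t)          # don't run past the end of the episode
        for k in range(horizon):
            G += (gamma ** k) * rewards[t + k]
        if t + n < T:                    # bootstrap only if s_{t+n} exists
            G += (gamma ** n) * values[t + n]
        targets.append(G)
    return targets

rewards = [1.0, 0.0, 2.0]
values = [0.5, 0.8, 1.2]
gamma = 0.9

td0 = n_step_targets(rewards, values, n=1, gamma=gamma)  # r_t + γV(s_{t+1})
mc = n_step_targets(rewards, values, n=3, gamma=gamma)   # full discounted return
print(td0, mc)
```

Intermediate values of n trade some of TD(0)'s bias for some of Monte Carlo's variance.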

In [None]:
# Visualize the bias-variance tradeoff

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Bias-Variance Spectrum
ax1 = axes[0]

methods = ['Monte Carlo\n(REINFORCE)', 'N-step TD\n(n=10)', 'N-step TD\n(n=3)', 'TD(0)\n(Actor-Critic)']
bias = [0, 1, 2, 3]
variance = [5, 3, 2, 1]

x = np.arange(len(methods))
width = 0.35

bars1 = ax1.bar(x - width/2, bias, width, label='Bias', color='#ef5350', edgecolor='black')
bars2 = ax1.bar(x + width/2, variance, width, label='Variance', color='#42a5f5', edgecolor='black')

ax1.set_ylabel('Relative Amount (illustrative)', fontsize=11)
ax1.set_xticks(x)
ax1.set_xticklabels(methods)
ax1.set_title('Bias-Variance Tradeoff', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# Right: Update frequency
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Update Frequency', fontsize=14, fontweight='bold')

# REINFORCE: episodic
ax2.text(1, 8.5, 'REINFORCE (Episodic):', fontsize=11, fontweight='bold')
for i in range(5):
    color = '#bbdefb' if i < 4 else '#4caf50'
    box = FancyBboxPatch((1 + i*1.5, 7), 1, 1, boxstyle="round,pad=0.05",
                          facecolor=color, edgecolor='#1976d2', linewidth=1)
    ax2.add_patch(box)
    ax2.text(1.5 + i*1.5, 7.5, f's{i+1}', ha='center', fontsize=9)
ax2.annotate('Update!', xy=(8.5, 7.5), xytext=(9.5, 7.5),
             arrowprops=dict(arrowstyle='->', lw=2, color='#4caf50'),
             fontsize=10, fontweight='bold', color='#4caf50')

# Actor-Critic: online
ax2.text(1, 4.5, 'Actor-Critic (Online):', fontsize=11, fontweight='bold')
for i in range(5):
    box = FancyBboxPatch((1 + i*1.5, 3), 1, 1, boxstyle="round,pad=0.05",
                          facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=1)
    ax2.add_patch(box)
    ax2.text(1.5 + i*1.5, 3.5, f's{i+1}', ha='center', fontsize=9)
    # Update arrow after each
    ax2.annotate('', xy=(1.5 + i*1.5, 2.9), xytext=(1.5 + i*1.5, 2.3),
                 arrowprops=dict(arrowstyle='->', lw=1, color='#4caf50'))
    ax2.text(1.5 + i*1.5, 2, '↓', ha='center', fontsize=10, color='#4caf50')

ax2.text(5, 1, '5 updates vs 1 update for same episode!', 
         ha='center', fontsize=11, fontweight='bold', color='#388e3c')

plt.tight_layout()
plt.show()

---
## Implementing Actor-Critic from Scratch

In [None]:
class ActorCritic(nn.Module):
    """
    Actor-Critic Network with shared features.
    
    Architecture:
        State → Shared Layers → Actor Head (policy)
                             → Critic Head (value)
    
    Actor: "What action should I take?" → π(a|s)
    Critic: "How good is this state?" → V(s)
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        
        # ========================================
        # SHARED FEATURE EXTRACTION
        # Both actor and critic see the same features
        # ========================================
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # ========================================
        # ACTOR HEAD: Policy π(a|s)
        # Outputs probability distribution over actions
        # ========================================
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        
        # ========================================
        # CRITIC HEAD: Value V(s)
        # Outputs single scalar value
        # ========================================
        self.critic = nn.Linear(hidden_dim, 1)
    
    def forward(self, state):
        """
        Forward pass returns both policy and value.
        
        Returns:
            action_probs: π(a|s) - probability for each action
            value: V(s) - estimated state value
        """
        if not isinstance(state, torch.Tensor):
            state = torch.FloatTensor(state)
        
        features = self.shared(state)
        action_probs = self.actor(features)
        value = self.critic(features)
        
        return action_probs, value
    
    def get_action(self, state):
        """
        Sample action and return action, log_prob, and value.
        
        This is used during environment interaction.
        """
        action_probs, value = self.forward(state)
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        
        return action.item(), log_prob, value


# Demonstrate the network
print("ACTOR-CRITIC NETWORK")
print("="*60)

state_dim = 4  # CartPole
action_dim = 2  # Left, Right
ac = ActorCritic(state_dim, action_dim)

print(ac)
print(f"\nTotal parameters: {sum(p.numel() for p in ac.parameters()):,}")

# Test forward pass
test_state = torch.randn(1, state_dim)
probs, value = ac(test_state)
action, log_prob, value = ac.get_action(test_state)

print(f"\nTest state: {test_state.numpy()[0].round(2)}")
print(f"Action probabilities: {probs.detach().numpy()[0].round(3)}")
print(f"State value V(s): {value.item():.3f}")
print(f"Sampled action: {action}")
print(f"Log probability: {log_prob.item():.3f}")
print("="*60)

In [None]:
def train_actor_critic(env_name='CartPole-v1', n_episodes=500, gamma=0.99, 
                       lr=1e-3, print_every=50):
    """
    Train Actor-Critic with ONLINE updates (update after each step!).
    
    Key differences from REINFORCE:
    1. Update after EACH step (not after episode)
    2. Use TD error instead of full return
    3. Lower variance, faster learning
    
    Args:
        env_name: Gymnasium environment
        n_episodes: Number of episodes to train
        gamma: Discount factor
        lr: Learning rate
        print_every: Print progress every N episodes
    
    Returns:
        model: Trained actor-critic model
        rewards_history: List of episode rewards
    """
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    # Create Actor-Critic network
    model = ActorCritic(state_dim, action_dim)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    rewards_history = []
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        
        for _ in range(500):  # Max steps
            # ========================================
            # STEP 1: Get action from actor
            # ========================================
            state_tensor = torch.FloatTensor(state)
            action_probs, value = model(state_tensor)
            dist = torch.distributions.Categorical(action_probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            
            # ========================================
            # STEP 2: Take action in environment
            # ========================================
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            total_reward += reward
            
            # ========================================
            # STEP 3: Compute TD error (critic's feedback)
            # δ = r + γV(s') - V(s)
            # ========================================
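            next_state_tensor = torch.FloatTensor(next_state)
            # Semi-gradient TD: treat V(s') as a fixed target,
            # so no gradient flows through it
            with torch.no_grad():
                _, next_value = model(next_state_tensor)
            
            # TD target: r + γV(s'), with V(s') = 0 only on true termination.
            # A time-limit truncation still has a valid next state to bootstrap from.
            td_target = reward + gamma * next_value * (1.0 - float(terminated))
            
            # TD error: δ = TD_target - V(s)
            td_error = td_target - value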
            
            # ========================================
            # STEP 4: Compute losses
            # ========================================
            # Actor loss: -log π(a|s) × δ
            # We use td_error as advantage estimate
            actor_loss = -log_prob * td_error.detach()  # Don't backprop through δ for actor
            
            # Critic loss: (δ)² = (r + γV(s') - V(s))²
            critic_loss = td_error.pow(2)
            
            # Total loss (weight critic loss to balance)
            loss = actor_loss + 0.5 * critic_loss
            
            # ========================================
            # STEP 5: Update IMMEDIATELY (online!)
            # ========================================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Move to next state
            state = next_state
            if done:
                break
        
        # Track progress
        rewards_history.append(total_reward)
        
        if (episode + 1) % print_every == 0:
            avg_reward = np.mean(rewards_history[-print_every:])
            print(f"Episode {episode+1:4d} | Avg Reward: {avg_reward:6.1f}")
    
    env.close()
    return model, rewards_history
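In TD learning the critic's target r + γV(s') is usually treated as a fixed number, and only V(s) is moved toward it. This is called a "semi-gradient" update. Below is a minimal sketch of what `detach()` changes, using a hypothetical two-state tabular critic `V`. With the target detached, the gradient touches only V(s). Without it, the squared-error loss also pushes V(s') toward V(s), which gives a different algorithm (the "residual gradient" approach).

```python
import torch

gamma = 0.9
reward = 1.0

def critic_grad(detach_target):
    # Tabular critic: V[0] = V(s), V[1] = V(s'), both start at 0.5
    V = torch.tensor([0.5, 0.5], requires_grad=True)
    next_v = V[1].detach() if detach_target else V[1]
    td_error = reward + gamma * next_v - V[0]
    loss = td_error.pow(2)
    loss.backward()
    return V.grad.clone()

semi = critic_grad(detach_target=True)    # gradient only on V(s)
full = critic_grad(detach_target=False)   # gradient also flows into V(s')
print(semi, full)
```

Here δ = 1 + 0.9 × 0.5 - 0.5 = 0.95. The gradient on V(s) is 2δ × (-1) = -1.9 in both cases. Only the full-gradient version also produces 2δ × γ = 1.71 on V(s').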

In [None]:
# Train Actor-Critic!
print("TRAINING ACTOR-CRITIC ON CARTPOLE")
print("="*60)
print("\nThis uses ONLINE updates (after each step!)\n")

model, rewards_history = train_actor_critic(
    env_name='CartPole-v1',
    n_episodes=500,
    gamma=0.99,
    lr=1e-3,
    print_every=100
)

print("\n" + "="*60)
print(f"Final average (last 50): {np.mean(rewards_history[-50:]):.1f}")
print("="*60)

In [None]:
# Visualize training
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Learning curve
ax1 = axes[0]
ax1.plot(rewards_history, alpha=0.3, color='blue', label='Episode Reward')

# Smoothed curve
window = 50
smoothed = np.convolve(rewards_history, np.ones(window)/window, mode='valid')
ax1.plot(range(window-1, len(rewards_history)), smoothed, 
         color='red', linewidth=2, label=f'{window}-Episode Average')

ax1.axhline(y=500, color='green', linestyle='--', linewidth=2, label='Max Score')
ax1.set_xlabel('Episode', fontsize=11)
ax1.set_ylabel('Reward', fontsize=11)
ax1.set_title('Actor-Critic Training on CartPole', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right: Reward distribution over time
ax2 = axes[1]
quarters = np.array_split(rewards_history, 4)
labels = ['0-25%', '25-50%', '50-75%', '75-100%']
positions = [1, 2, 3, 4]

bp = ax2.boxplot(quarters, positions=positions, patch_artist=True)
colors = ['#ffcdd2', '#fff3e0', '#c8e6c9', '#81c784']
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)

ax2.set_xticklabels(labels)
ax2.set_xlabel('Training Progress', fontsize=11)
ax2.set_ylabel('Episode Reward', fontsize=11)
ax2.set_title('Reward Distribution Over Training', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---
## Actor-Critic vs REINFORCE: Side by Side

```
    ┌────────────────────────────────────────────────────────────────┐
    │              REINFORCE vs ACTOR-CRITIC                         │
    ├──────────────────────┬─────────────────────────────────────────┤
    │    REINFORCE         │    ACTOR-CRITIC                         │
    ├──────────────────────┼─────────────────────────────────────────┤
    │ Monte Carlo          │ Temporal Difference (TD)                │
    │                      │                                         │
    │ Uses: G_t (return)   │ Uses: δ_t (TD error)                    │
    │                      │                                         │
    │ Update: End of       │ Update: After EACH step                 │
    │         episode      │         (online!)                       │
    │                      │                                         │
    │ Unbiased             │ Some bias (bootstrapping)               │
    │                      │                                         │
    │ HIGH variance        │ LOWER variance                          │
    │                      │                                         │
    │ Slower learning      │ Faster learning                         │
    │                      │                                         │
    │ Needs complete       │ Works with continuing                   │
    │ episodes             │ tasks                                   │
    └──────────────────────┴─────────────────────────────────────────┘
```
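The variance rows can be checked on a toy problem. Below is a minimal NumPy sketch built on a hypothetical chain of T deterministic steps. Each reward is 1 plus standard Gaussian noise, and the critic is assumed to be exact. The Monte Carlo return G_0 sums T noisy rewards, while the TD(0) target r_0 + γV(s_1) contains only one.

```python
import numpy as np

rng = np.random.default_rng(0)
T, gamma, n_samples = 20, 0.99, 10_000

# Exact value of the state at step k when every reward has mean 1:
# V(s_k) = sum_{j=0}^{T-k-1} γ^j
def true_value(k):
    return sum(gamma ** j for j in range(T - k))

# Each row is one episode of noisy rewards: r_t = 1 + N(0, 1)
rewards = 1.0 + rng.standard_normal((n_samples, T))
discounts = gamma ** np.arange(T)

mc_returns = rewards @ discounts                    # G_0 = Σ γ^t r_t
td_targets = rewards[:, 0] + gamma * true_value(1)  # r_0 + γ V(s_1)

print(f"MC return: mean {mc_returns.mean():.2f}, var {mc_returns.var():.2f}")
print(f"TD target: mean {td_targets.mean():.2f}, var {td_targets.var():.2f}")
```

Both targets share the same mean here, but the Monte Carlo variance is roughly Σ γ^(2t) (about 16.7) against 1 for TD. The exact critic is what makes the TD target unbiased in this toy. With a learned V, the lower variance comes with the bias listed in the table.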

In [None]:
# Compare Actor-Critic with REINFORCE

def reinforce_for_comparison(env_name='CartPole-v1', n_episodes=500, gamma=0.99, lr=1e-3):
    """REINFORCE with baseline for fair comparison."""
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    model = ActorCritic(state_dim, action_dim)  # Same architecture
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    rewards_history = []
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        
        log_probs = []
        values = []
        rewards = []
        
        # Collect full episode
        for _ in range(500):
            state_tensor = torch.FloatTensor(state)
            probs, value = model(state_tensor)
            dist = torch.distributions.Categorical(probs)
            action = dist.sample()
            
            log_probs.append(dist.log_prob(action))
            values.append(value)
            
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)
            
            state = next_state
            if terminated or truncated:
                break
        
        # Compute returns (Monte Carlo)
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)
        returns = torch.FloatTensor(returns)
        values = torch.cat(values)
        
        # Compute advantages
        advantages = returns - values.detach()
        
        # Losses (update at end of episode)
        actor_loss = 0
        for log_prob, adv in zip(log_probs, advantages):
            actor_loss -= log_prob * adv
        critic_loss = nn.functional.mse_loss(values, returns)
        loss = actor_loss + 0.5 * critic_loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        rewards_history.append(sum(rewards))
    
    env.close()
    return rewards_history

print("COMPARING REINFORCE vs ACTOR-CRITIC")
print("="*60)
print("\n1. Training Actor-Critic (online updates)...")
_, ac_rewards = train_actor_critic(n_episodes=300, print_every=100)

print("\n2. Training REINFORCE (episodic updates)...")
reinforce_rewards = reinforce_for_comparison(n_episodes=300)
print("Done!")

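For contrast with the Monte Carlo returns computed in `reinforce_for_comparison`, here is a small pure-Python sketch that computes both learning signals for the same short trajectory. The helper names `mc_returns` and `td_targets` and the trajectory numbers are illustrative assumptions, not part of the notebook. REINFORCE with a baseline uses G_t - V(s_t), which needs the whole episode; the online actor-critic uses the TD error r_t + γV(s_{t+1}) - V(s_t), which needs only the next state's value estimate and can be computed as soon as each step finishes.

```python
def mc_returns(rewards, gamma):
    """Discounted returns G_t = r_t + gamma * G_{t+1}, accumulated backwards like the REINFORCE loop."""
    returns, G = [], 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    return returns


def td_targets(rewards, values, last_value, gamma):
    """One-step TD targets r_t + gamma * V(s_{t+1}).

    values[t] is the critic's estimate V(s_t); last_value is V(s_T) for the state
    after the final step (use 0.0 if the episode terminated there).
    """
    next_values = values[1:] + [last_value]
    return [r + gamma * v for r, v in zip(rewards, next_values)]


# Illustrative 3-step trajectory (made-up numbers) with gamma = 0.5
traj_rewards = [1.0, 1.0, 1.0]
traj_values = [1.5, 1.0, 0.5]
G = mc_returns(traj_rewards, gamma=0.5)
targets = td_targets(traj_rewards, traj_values, last_value=0.0, gamma=0.5)
print([g - v for g, v in zip(G, traj_values)])        # REINFORCE-style advantages: [0.25, 0.5, 0.5]
print([t - v for t, v in zip(targets, traj_values)])  # TD errors: [0.0, 0.25, 0.5]
```

The two signals agree on the last step (where the TD target bootstraps from a terminal value of 0) and differ earlier, where the TD error trusts the critic's V(s_{t+1}) instead of the sampled remainder of the episode.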
In [None]:
# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

window = 30

# Left: Learning curves
ax1 = axes[0]
smoothed_ac = np.convolve(ac_rewards, np.ones(window)/window, mode='valid')
smoothed_rf = np.convolve(reinforce_rewards, np.ones(window)/window, mode='valid')

ax1.plot(range(window-1, len(ac_rewards)), smoothed_ac, 
         'g-', linewidth=2, label='Actor-Critic (online)')
ax1.plot(range(window-1, len(reinforce_rewards)), smoothed_rf, 
         'r-', linewidth=2, label='REINFORCE (episodic)')

ax1.set_xlabel('Episode', fontsize=11)
ax1.set_ylabel('Reward (smoothed)', fontsize=11)
ax1.set_title('Learning Speed Comparison', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

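# Right: rolling standard deviation of episode rewards (last 50 episodes)
ax2 = axes[1]
std_window = 50
rolling_std_ac = [np.std(ac_rewards[max(0, i - std_window + 1):i + 1]) for i in range(len(ac_rewards))]
rolling_std_rf = [np.std(reinforce_rewards[max(0, i - std_window + 1):i + 1]) for i in range(len(reinforce_rewards))]

# Smooth with the same moving average as the left panel and align x with episode numbers
smoothed_std_ac = np.convolve(rolling_std_ac, np.ones(window)/window, mode='valid')
smoothed_std_rf = np.convolve(rolling_std_rf, np.ones(window)/window, mode='valid')
ax2.plot(range(window-1, len(ac_rewards)), smoothed_std_ac,
         'g-', linewidth=2, label='Actor-Critic')
ax2.plot(range(window-1, len(reinforce_rewards)), smoothed_std_rf,
         'r-', linewidth=2, label='REINFORCE')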

ax2.set_xlabel('Episode', fontsize=11)
ax2.set_ylabel('Rolling Std (50 episodes)', fontsize=11)
ax2.set_title('Variance Comparison', fontsize=14, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nFINAL STATISTICS:")
print(f"  Actor-Critic - Final avg: {np.mean(ac_rewards[-50:]):.1f}")
print(f"  REINFORCE - Final avg: {np.mean(reinforce_rewards[-50:]):.1f}")

---
## Summary: Key Takeaways

### Actor-Critic Architecture

| Component | Role | Formula |
|-----------|------|--------|
| **Actor** | Policy | π(a\|s; θ) |
| **Critic** | Value function | V(s; w) |

### TD Error (Critic's Feedback)

```
δ = r + γV(s') - V(s)

δ > 0: Better than expected → make that action more likely
δ < 0: Worse than expected → make that action less likely
```
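A worked version of the TD error, as a minimal pure-Python sketch (the helper name `td_error` and the numbers are illustrative assumptions). One detail the formula hides: when s' is a terminal state there is no future to bootstrap from, so the γV(s') term is dropped. A time-limit cutoff (Gymnasium's `truncated` flag) is a different case: the state still has a future, so the usual practice is to keep bootstrapping there.

```python
def td_error(reward, v_s, v_next, gamma=0.99, terminated=False):
    """One-step TD error: delta = r + gamma * V(s') - V(s).

    If s' is terminal there is no future value, so the bootstrap term is dropped
    (equivalent to V(s') = 0).
    """
    bootstrap = 0.0 if terminated else gamma * v_next
    return reward + bootstrap - v_s


print(td_error(1.0, v_s=5.0, v_next=10.0, gamma=0.5))  # 1.0: better than expected
print(td_error(0.0, v_s=8.0, v_next=10.0, gamma=0.5))  # -3.0: worse than expected
print(td_error(1.0, v_s=5.0, v_next=10.0, gamma=0.5, terminated=True))  # -4.0: no bootstrap
```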

### Update Rules

```
Actor:  θ ← θ + α × ∇θ log π(a|s) × δ
Critic: w ← w + β × δ × ∇w V(s)
```
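The two update rules can be sketched in a few lines for the tabular case. This is a minimal sketch under stated assumptions: a hypothetical two-state chain (the `step` function is invented for illustration), a table of softmax action preferences `theta` playing the role of the actor's θ, and a table `V` playing the role of the critic's w, so ∇w V(s) is 1 for the visited state and 0 elsewhere. Like the rules above, it omits the γ^t factor that some textbook presentations include.

```python
import math
import random


def softmax(prefs):
    """Numerically stable softmax over a list of action preferences."""
    m = max(prefs)
    exps = [math.exp(p - m) for p in prefs]
    total = sum(exps)
    return [e / total for e in exps]


def step(state, action):
    """Hypothetical 2-state chain. Returns (next_state, reward, terminated).

    State 0: action 1 moves to state 1 (reward 0); action 0 ends the episode (reward 0).
    State 1: action 1 ends the episode with reward 1; action 0 ends it with reward 0.
    """
    if state == 0:
        return (1, 0.0, False) if action == 1 else (None, 0.0, True)
    return (None, 1.0 if action == 1 else 0.0, True)


def train(episodes=2000, alpha=0.1, beta=0.1, gamma=0.99, seed=0):
    rng = random.Random(seed)
    theta = [[0.0, 0.0], [0.0, 0.0]]  # actor: softmax preferences theta[s][a]
    V = [0.0, 0.0]                     # critic: tabular V(s)
    for _ in range(episodes):
        s = 0
        while True:
            probs = softmax(theta[s])
            a = 0 if rng.random() < probs[0] else 1
            s_next, r, terminated = step(s, a)

            # TD error: delta = r + gamma * V(s') - V(s), with V(terminal) = 0
            target = r if terminated else r + gamma * V[s_next]
            delta = target - V[s]

            # Critic: w <- w + beta * delta * grad_w V(s); for a table the gradient is 1 at s
            V[s] += beta * delta

            # Actor: theta <- theta + alpha * grad log pi(a|s) * delta;
            # for softmax preferences, d log pi(a|s) / d theta[s][b] = 1[a == b] - pi(b|s)
            for b in range(len(probs)):
                grad_log_pi = (1.0 if b == a else 0.0) - probs[b]
                theta[s][b] += alpha * delta * grad_log_pi

            if terminated:
                break
            s = s_next
    return theta, V
```

Calling `train()` and inspecting `softmax(theta[0])` and `softmax(theta[1])` should show both states strongly preferring action 1. State 0 never receives a nonzero reward directly, so it learns this only through the bootstrapped V(1) in its TD error.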

### Advantages Over REINFORCE

| Aspect | REINFORCE | Actor-Critic |
|--------|-----------|-------------|
| Updates | Episodic | Online (each step) |
| Variance | High | Lower |
| Bias | Unbiased | Biased (bootstraps from learned V) |
| Speed | Typically slower | Typically faster |
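The variance and bias rows can be checked on a toy problem. The sketch below is a hypothetical setup, not the CartPole experiment above: a fixed-length chain where every step pays a reward drawn from Normal(1, 1). For the start state it compares the Monte Carlo return G_0 with the one-step TD target r_0 + γV(s_1), using either an exact critic or one that is off by the assumed parameter `critic_error`.

```python
import random
import statistics


def compare_targets(n_episodes=5000, horizon=20, gamma=0.9, critic_error=0.0, seed=0):
    """Hypothetical chain: each of `horizon` steps pays reward ~ Normal(1, 1), then the episode ends.

    Returns (Monte Carlo returns G_0, one-step TD targets r_0 + gamma * V(s_1)) for the start state,
    where V(s_1) is the exact value plus `critic_error`.
    """
    rng = random.Random(seed)
    # Exact V(s_1): expected discounted sum of the remaining horizon - 1 rewards (mean 1 each)
    v_next_true = sum(gamma ** k for k in range(horizon - 1))
    v_next = v_next_true + critic_error  # an imperfect critic is off by critic_error
    mc_returns, td_targets = [], []
    for _ in range(n_episodes):
        rewards = [rng.gauss(1.0, 1.0) for _ in range(horizon)]
        mc_returns.append(sum(gamma ** k * r for k, r in enumerate(rewards)))
        td_targets.append(rewards[0] + gamma * v_next)
    return mc_returns, td_targets


mc, td = compare_targets()
print(f"MC variance: {statistics.variance(mc):.2f}")
print(f"TD variance: {statistics.variance(td):.2f}")
```

With the defaults, the Monte Carlo returns should have a sample variance of roughly 5 (about Σ γ^(2k) for γ = 0.9) against roughly 1 for the TD targets, since the TD target contains only one random reward. Passing `critic_error=1.0` shifts the mean TD target up by about γ × 1.0 while the Monte Carlo mean is unaffected: that shift is the bias that bootstrapping from an imperfect critic introduces.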

---
## Test Your Understanding

**1. What is the actor and what is the critic?**
<details>
<summary>Click to reveal answer</summary>
The actor is the policy network π(a|s) that decides which action to take. The critic is the value network V(s) that estimates how good a state is. Together, they form a complete learning system: the critic tells the actor whether its actions are better or worse than expected.
</details>

**2. What is the TD error and why is it useful?**
<details>
<summary>Click to reveal answer</summary>
TD error δ = r + γV(s') - V(s) measures the difference between what we got (reward + estimated future value) and what we expected (current value estimate). It tells us if an action was better or worse than expected, providing immediate feedback for the actor without waiting for the episode to end.
</details>

**3. Why does actor-critic have lower variance than REINFORCE?**
<details>
<summary>Click to reveal answer</summary>
REINFORCE uses the full return G_t, which sums every future reward and therefore accumulates the randomness of all later actions and state transitions. Actor-critic uses the TD error δ, which looks only one step ahead and substitutes the critic's estimate V(s') for everything after that. By "bootstrapping" from V, we get a more stable (lower variance) learning signal.
</details>

**4. What's the tradeoff between REINFORCE and actor-critic?**
<details>
<summary>Click to reveal answer</summary>
REINFORCE has no bias (it uses sampled Monte Carlo returns, which are unbiased estimates of the true return) but high variance (slow, noisy learning). Actor-critic has some bias (bootstrapping from a learned V is imperfect) but lower variance (faster, more stable learning). In practice, the variance reduction usually outweighs the small bias, making actor-critic the common choice.
</details>

**5. Why are online updates beneficial?**
<details>
<summary>Click to reveal answer</summary>
Online updates (after each step) mean:
1. More updates per episode → faster learning
2. Credit assignment happens immediately → clearer signal
3. Works with continuing tasks (no need for episode boundaries)
4. Lower memory requirements (don't store full episodes)
</details>

---
## What's Next?

Actor-critic is the foundation for many modern RL algorithms. In the next notebook, we'll learn about **A2C and A3C**, which scale actor-critic to multiple parallel environments!

**Continue to:** [Notebook 5: A2C and A3C](05_a2c_a3c.ipynb)

---

*Actor-Critic: The critic coaches the actor, one step at a time!*