# A2C and A3C: Scaling Actor-Critic

A2C (Advantage Actor-Critic) and A3C (Asynchronous Advantage Actor-Critic) scale actor-critic to multiple parallel environments for much faster training!

## What You'll Learn

By the end of this notebook, you'll understand:
- The parallel training analogy: why more environments help
- A2C: synchronous parallel training
- A3C: asynchronous parallel training (historical)
- Generalized Advantage Estimation (GAE)
- Implementing A2C from scratch
- Using Stable-Baselines3 for production A2C

**Prerequisites:** Notebook 4 (Actor-Critic)

**Time:** ~30 minutes

---
## The Big Picture: The Parallel Training Analogy

```
    ┌────────────────────────────────────────────────────────────────┐
    │          THE PARALLEL TRAINING ANALOGY                         │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Imagine training for a marathon...                           │
    │                                                                │
    │  SINGLE ENVIRONMENT (Basic Actor-Critic):                     │
    │    You run one route, analyze, adjust, repeat.                │
    │    Slow! You only see one type of terrain at a time.          │
    │                                                                │
    │  PARALLEL ENVIRONMENTS (A2C):                                 │
    │    4 clones of you run 4 different routes simultaneously!     │
    │    Every 10 minutes, all report back:                         │
    │      Clone 1: "Hills were tough"                             │
    │      Clone 2: "Flat terrain was easy"                        │
    │      Clone 3: "Rain made it slippery"                        │
    │      Clone 4: "Crowds helped motivation"                     │
    │    You learn from ALL experiences at once!                    │
    │                                                                │
    │  WHY THIS HELPS:                                              │
    │    • 4x more experience per update → faster learning          │
    │    • Diverse experiences → more stable gradients              │
    │    • Better exploration → see more situations                 │
    │                                                                │
    │  A2C: All clones finish, then update (SYNCHRONOUS)           │
    │  A3C: Clones update as soon as they finish (ASYNCHRONOUS)    │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
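The "diverse experiences → more stable gradients" claim can be checked with a toy simulation. The sketch below is not the notebook's training code: `noisy_gradient` and `batch_gradient` are hypothetical helpers, and it assumes each environment yields an independent, Gaussian-noised estimate of the same true gradient (real environments sharing one policy are only approximately independent).

```python
import random
import statistics

def noisy_gradient(rng, true_grad=1.0, noise_std=2.0):
    """One noisy gradient estimate, as a single environment might produce."""
    return rng.gauss(true_grad, noise_std)

def batch_gradient(rng, n_envs):
    """Average the estimates from n_envs independent environments."""
    return statistics.fmean([noisy_gradient(rng) for _ in range(n_envs)])

rng = random.Random(0)  # fixed seed so the run is repeatable
spread = {}
for n_envs in (1, 4, 16):
    estimates = [batch_gradient(rng, n_envs) for _ in range(5000)]
    spread[n_envs] = statistics.stdev(estimates)
    print(f"n_envs={n_envs:2d}: std of averaged gradient = {spread[n_envs]:.2f}")
```

Under the independence assumption the spread shrinks roughly with the square root of `n_envs`, so quadrupling the number of environments should about halve the noise in each update.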

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Circle, Rectangle, FancyArrowPatch

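# The cells below assume the Gymnasium-style API, which gym >= 0.26 also follows:
# reset() returns (obs, info) and step() returns a 5-tuple.
try:
    import gymnasium as gym
except ImportError:
    import gym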

# Visualize parallel environments
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Single environment
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('Single Environment\n(Basic Actor-Critic)', fontsize=14, fontweight='bold', color='#d32f2f')

# Timeline
for i in range(4):
    x = 1 + i * 2
    box = FancyBboxPatch((x, 5), 1.5, 1.5, boxstyle="round,pad=0.1",
                          facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=2)
    ax1.add_patch(box)
    ax1.text(x + 0.75, 5.75, f'Exp {i+1}', ha='center', fontsize=9)
    if i < 3:
        # The arrow is drawn from xytext to xy, so it points right, toward the next box
        ax1.annotate('', xy=(x + 1.9, 5.75), xytext=(x + 1.6, 5.75),
                     arrowprops=dict(arrowstyle='->', lw=1, color='#666'))

ax1.text(5, 3.5, 'Sequential: One experience at a time', ha='center', fontsize=11, color='#d32f2f')
ax1.text(5, 2.5, 'Time to collect 4 experiences: ████████', ha='center', fontsize=10)

# Right: Parallel environments (A2C)
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Parallel Environments\n(A2C)', fontsize=14, fontweight='bold', color='#388e3c')

# 4 parallel environments
for i in range(4):
    y = 7.5 - i * 1.5
    box = FancyBboxPatch((3, y), 1.5, 1, boxstyle="round,pad=0.1",
                          facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
    ax2.add_patch(box)
    ax2.text(3.75, y + 0.5, f'Env {i+1}', ha='center', fontsize=9)
    ax2.text(2.5, y + 0.5, f'→', ha='center', fontsize=12)

# Arrows converging to update
update_box = FancyBboxPatch((6, 4), 2.5, 2, boxstyle="round,pad=0.1",
                             facecolor='#fff3e0', edgecolor='#f57c00', linewidth=3)
ax2.add_patch(update_box)
ax2.text(7.25, 5.3, 'Batch', ha='center', fontsize=10, fontweight='bold')
ax2.text(7.25, 4.7, 'Update', ha='center', fontsize=10)

for i in range(4):
    y = 7.5 - i * 1.5 + 0.5
    ax2.annotate('', xy=(5.9, 5), xytext=(4.6, y),
                 arrowprops=dict(arrowstyle='->', lw=1, color='#388e3c'))

ax2.text(5, 2, 'Parallel: All experiences simultaneously!', ha='center', fontsize=11, color='#388e3c')
ax2.text(5, 1, 'Time to collect 4 experiences: ██', ha='center', fontsize=10)

plt.tight_layout()
plt.show()

print("\nPARALLEL ENVIRONMENTS BENEFITS:")
print("  1. 4x more data per update (4 envs = 4x experience)")
print("  2. More diverse gradients (different situations)")
print("  3. Lower variance (batch of experiences averages out noise)")

---
## A2C vs A3C: Synchronous vs Asynchronous

```
    ┌────────────────────────────────────────────────────────────────┐
    │              A2C vs A3C COMPARISON                             │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  A3C (Asynchronous - 2016):                                   │
    │    ┌────┐ ┌────┐ ┌────┐ ┌────┐                               │
    │    │Env1│ │Env2│ │Env3│ │Env4│  Each has own copy of policy  │
    │    └──┬─┘ └─┬──┘ └──┬─┘ └─┬──┘                               │
    │       ↓     ↓      ↓     ↓     Update whenever ready         │
    │    ┌────────────────────────┐                                │
    │    │   Shared Parameters    │  But may use stale params!     │
    │    └────────────────────────┘                                │
    │                                                                │
    │  A2C (Synchronous - simpler):                                 │
    │    ┌────┐ ┌────┐ ┌────┐ ┌────┐                               │
    │    │Env1│ │Env2│ │Env3│ │Env4│  All use SAME policy         │
    │    └──┬─┘ └─┬──┘ └──┬─┘ └─┬──┘                               │
    │       └─────┴───┬───┴─────┘     Wait for all, then update    │
    │                 ↓                                            │
    │    ┌────────────────────────┐                                │
    │    │   Batch Update Once    │  All params fresh!             │
    │    └────────────────────────┘                                │
    │                                                                │
    │  MODERN PREFERENCE: A2C (simpler, just as effective on GPU)  │
    │    A3C was designed for CPU parallelism                      │
    │    A2C works better with GPU batch processing                │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
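To make "stale params" concrete, here is a deliberately simplified sketch on a one-parameter toy loss. The function names (`a2c_style_update`, `a3c_style_update`) are hypothetical, and the asynchronous case models a worst-case interleaving in which every worker reads the parameters at the same moment. Real A3C scheduling is nondeterministic and usually less extreme.

```python
def grad(theta):
    """Gradient of the toy loss L(theta) = 0.5 * theta**2 (minimum at 0)."""
    return theta

def a2c_style_update(theta, n_workers, lr):
    """Synchronous: all workers read the same theta, then one averaged step."""
    grads = [grad(theta) for _ in range(n_workers)]
    return theta - lr * sum(grads) / n_workers

def a3c_style_update(theta, n_workers, lr):
    """Asynchronous, worst case: every worker snapshots theta at once and the
    updates land one after another, so all but the first gradient were
    computed from parameters that have since changed."""
    snapshot = theta
    for _ in range(n_workers):
        theta = theta - lr * grad(snapshot)  # stale gradient, newer theta
    return theta

theta0, lr, n_workers = 1.0, 0.3, 4
sync_theta = a2c_style_update(theta0, n_workers, lr)
async_theta = a3c_style_update(theta0, n_workers, lr)
print(f"synchronous:  theta = {sync_theta:.2f}")
print(f"asynchronous: theta = {async_theta:.2f}")
```

In this toy setup the synchronous step moves part of the way toward the minimum, while the stale asynchronous updates overshoot it. A3C still trained well in practice, so treat this only as an illustration of what staleness means.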

In [None]:
# Visualize A2C vs A3C architecture

fig, axes = plt.subplots(1, 2, figsize=(14, 7))

# Left: A3C (Asynchronous)
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('A3C: Asynchronous\n(Historical, CPU-focused)', fontsize=14, fontweight='bold', color='#7b1fa2')

# Global parameters
global_box = FancyBboxPatch((3.5, 7.5), 3, 1.5, boxstyle="round,pad=0.1",
                             facecolor='#e1bee7', edgecolor='#7b1fa2', linewidth=3)
ax1.add_patch(global_box)
ax1.text(5, 8.25, 'Global Parameters', ha='center', fontsize=11, fontweight='bold')

# Workers
colors = ['#ffcdd2', '#c8e6c9', '#bbdefb', '#fff3e0']
for i in range(4):
    x = 1.5 + i * 2
    worker_box = FancyBboxPatch((x, 4), 1.5, 2, boxstyle="round,pad=0.1",
                                 facecolor=colors[i], edgecolor='#333', linewidth=2)
    ax1.add_patch(worker_box)
    ax1.text(x + 0.75, 5.3, f'Worker {i+1}', ha='center', fontsize=9, fontweight='bold')
    ax1.text(x + 0.75, 4.7, 'Local copy', ha='center', fontsize=8)
    ax1.text(x + 0.75, 4.3, '+ Env', ha='center', fontsize=8)
    
    # Arrows (different heights to show async)
    offset = [0, 0.3, -0.2, 0.1][i]
    ax1.annotate('', xy=(x + 0.75, 6.1 + offset), xytext=(x + 0.75, 7.4),
                 arrowprops=dict(arrowstyle='<->', lw=1, color='#7b1fa2'))

ax1.text(5, 2.5, 'Updates happen independently!', ha='center', fontsize=10, color='#7b1fa2')
ax1.text(5, 1.8, 'Problem: May use stale parameters', ha='center', fontsize=9, color='#d32f2f')

# Right: A2C (Synchronous)
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('A2C: Synchronous\n(Modern, GPU-friendly)', fontsize=14, fontweight='bold', color='#388e3c')

# Single policy
policy_box = FancyBboxPatch((3.5, 7.5), 3, 1.5, boxstyle="round,pad=0.1",
                             facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax2.add_patch(policy_box)
ax2.text(5, 8.25, 'Single Policy', ha='center', fontsize=11, fontweight='bold')

# Vectorized environments
env_box = FancyBboxPatch((2, 4), 6, 2, boxstyle="round,pad=0.1",
                          facecolor='#fff3e0', edgecolor='#f57c00', linewidth=3)
ax2.add_patch(env_box)
ax2.text(5, 5.5, 'Vectorized Environments', ha='center', fontsize=10, fontweight='bold')

for i in range(4):
    x = 2.5 + i * 1.3
    mini_box = FancyBboxPatch((x, 4.3), 1, 0.8, boxstyle="round,pad=0.05",
                               facecolor='#ffe0b2', edgecolor='#f57c00', linewidth=1)
    ax2.add_patch(mini_box)
    ax2.text(x + 0.5, 4.7, f'E{i+1}', ha='center', fontsize=9)

# Single synchronized arrow
ax2.annotate('', xy=(5, 6.1), xytext=(5, 7.4),
             arrowprops=dict(arrowstyle='<->', lw=3, color='#388e3c'))

ax2.text(5, 2.5, 'All envs step together, then batch update', ha='center', fontsize=10, color='#388e3c')
ax2.text(5, 1.8, 'Simpler, better GPU utilization!', ha='center', fontsize=9, color='#388e3c')

plt.tight_layout()
plt.show()

print("\nWHY A2C IS PREFERRED TODAY:")
print("  • Simpler implementation (no async complications)")
print("  • Better GPU utilization (batch processing)")
print("  • Equally effective in practice")
print("  • A3C's async was designed for CPU parallelism")

---
## Generalized Advantage Estimation (GAE)

```
    ┌────────────────────────────────────────────────────────────────┐
    │              GENERALIZED ADVANTAGE ESTIMATION                  │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  RECALL THE BIAS-VARIANCE TRADEOFF:                           │
    │    TD(0): A_t = r_t + γV(s_{t+1}) - V(s_t)                    │
    │           Low variance, but biased (bootstraps from V)        │
    │                                                                │
    │    MC:    A_t = G_t - V(s_t)                                  │
    │           Unbiased, but high variance (full return)           │
    │                                                                │
    │  GAE: GET THE BEST OF BOTH!                                   │
    │    A_t^GAE = Σ (γλ)^l × δ_{t+l}                               │
    │                                                                │
    │    where δ_t = r_t + γV(s_{t+1}) - V(s_t) (TD error)         │
    │                                                                │
    │  THE λ PARAMETER:                                             │
    │    λ = 0: Pure TD(0) - low variance, more bias                │
    │    λ = 1: Pure MC - high variance, no bias                    │
    │    λ = 0.95: Good balance (commonly used)                     │
    │                                                                │
    │  INTUITION:                                                   │
    │    GAE is like a weighted average of n-step returns:          │
    │    "Look ahead many steps, but trust closer steps more"       │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
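The two endpoints of λ can be verified numerically. The following is a minimal standard-library sketch, assuming a single finished episode in which the value after the last step is 0. The helper names `gae`, `td_errors` and `mc_advantages` are illustrative, not part of any library.

```python
def gae(rewards, values, gamma, lam):
    """GAE for one finished episode (the value after the last step is 0)."""
    next_values = values[1:] + [0.0]
    advantages, running = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_values[t] - values[t]  # TD error
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages

def td_errors(rewards, values, gamma):
    """One-step TD errors, the lambda = 0 endpoint."""
    next_values = values[1:] + [0.0]
    return [r + gamma * nv - v for r, nv, v in zip(rewards, next_values, values)]

def mc_advantages(rewards, values, gamma):
    """Discounted return-to-go minus V(s_t), the lambda = 1 endpoint."""
    out, g = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        out[t] = g - values[t]
    return out

def close(a, b, tol=1e-9):
    return all(abs(x - y) < tol for x, y in zip(a, b))

rewards = [1.0, 1.0, 1.0, 1.0, 10.0]
values = [5.0, 5.5, 6.0, 6.5, 7.0]
gamma = 0.99

lam0 = gae(rewards, values, gamma, lam=0.0)
lam1 = gae(rewards, values, gamma, lam=1.0)
print("lambda=0 equals TD errors:     ", close(lam0, td_errors(rewards, values, gamma)))
print("lambda=1 equals MC advantages: ", close(lam1, mc_advantages(rewards, values, gamma)))
```

Both checks should print `True`: at λ = 1 the discounted TD errors telescope into the full return minus V(s_t), and at λ = 0 only the first TD error survives.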

In [None]:
def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """
    Compute Generalized Advantage Estimation (GAE).
    
    GAE provides a balance between bias and variance:
    A_t^GAE = Σ_{l=0}^{∞} (γλ)^l × δ_{t+l}
    
    where δ_t = r_t + γV(s_{t+1}) - V(s_t)
    
    Args:
        rewards: List/array of rewards
        values: List/array of value estimates V(s)
        dones: List/array of done flags
        gamma: Discount factor
        lam: GAE lambda (0 = TD(0), 1 = MC)
    
    Returns:
        advantages: GAE advantage estimates
        returns: Estimated returns (advantages + values)
    """
    n_steps = len(rewards)
    advantages = np.zeros(n_steps)
    
    # Compute GAE backwards
    gae = 0
    for t in reversed(range(n_steps)):
        # No value estimate exists past the final step, so bootstrap from 0.
        # This assumes the trajectory ends there (dones[-1] == 1).
        if t == n_steps - 1:
            next_value = 0
        else:
            next_value = values[t + 1]
        
        # TD error: δ_t = r_t + γV(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        
        # GAE: A_t = δ_t + γλ × A_{t+1}
        gae = delta + gamma * lam * (1 - dones[t]) * gae
        advantages[t] = gae
    
    # Returns = advantages + values
    returns = advantages + np.array(values)
    
    return advantages, returns


# Demonstrate GAE
print("GAE DEMONSTRATION")
print("="*60)

# Example episode
rewards = [1, 1, 1, 1, 10]  # Big reward at end
values = [5, 5.5, 6, 6.5, 7]  # V(s) estimates
dones = [0, 0, 0, 0, 1]  # Episode ends at step 5

print(f"\nRewards: {rewards}")
print(f"Values:  {values}")

print("\nGAE with different λ values:")
for lam in [0.0, 0.5, 0.95, 1.0]:
    adv, ret = compute_gae(rewards, values, dones, gamma=0.99, lam=lam)
    print(f"  λ={lam:.2f}: advantages = {adv.round(2)}")

print("\n" + "="*60)
print("λ=0: Only looks 1 step ahead (low variance, more bias)")
print("λ=1: Uses full return (high variance, no bias)")
print("λ=0.95: Good balance (commonly used in PPO)")

In [None]:
# Visualize the effect of lambda

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Lambda spectrum
ax1 = axes[0]
lambdas = [0.0, 0.5, 0.9, 0.95, 1.0]
labels = ['TD(0)', 'Mixed', 'GAE', 'Standard', 'MC']
# Illustrative magnitudes only (hand-picked to show the trend, not measured)
bias = [3, 2, 1, 0.5, 0]
variance = [0, 1, 2, 2.5, 4]

x = np.arange(len(lambdas))
width = 0.35

bars1 = ax1.bar(x - width/2, bias, width, label='Bias', color='#ef5350', edgecolor='black')
bars2 = ax1.bar(x + width/2, variance, width, label='Variance', color='#42a5f5', edgecolor='black')

ax1.set_ylabel('Relative Amount', fontsize=11)
ax1.set_xticks(x)
ax1.set_xticklabels([f'λ={l}\n({lab})' for l, lab in zip(lambdas, labels)])
ax1.set_title('GAE: Bias-Variance Tradeoff via λ', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# Highlight optimal region
ax1.axvspan(2.5, 3.5, alpha=0.2, color='green')
ax1.text(3, 3.5, 'Sweet spot!', ha='center', fontsize=10, color='green', fontweight='bold')

# Right: Weighting visualization
ax2 = axes[1]
steps = np.arange(10)

gamma, lam = 0.99, 0.95
weights = [(gamma * lam) ** i for i in steps]

ax2.bar(steps, weights, color='#64b5f6', edgecolor='black')
ax2.set_xlabel('Steps into Future', fontsize=11)
ax2.set_ylabel('Weight (γλ)^l', fontsize=11)
ax2.set_title(f'GAE Weights (γ={gamma}, λ={lam})', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

ax2.text(5, 0.5, 'Closer steps get more weight!\nThis reduces variance.', 
         ha='center', fontsize=10, color='#1976d2')

plt.tight_layout()
plt.show()

---
## Implementing A2C from Scratch

In [None]:
class ActorCritic(nn.Module):
    """
    Actor-Critic network for A2C.
    Same architecture as before, but used with vectorized environments.
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh()
        )
        
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)
    
    def forward(self, state):
        if isinstance(state, np.ndarray):
            state = torch.FloatTensor(state)
        
        features = self.shared(state)
        action_logits = self.actor(features)
        value = self.critic(features)
        
        return action_logits, value
    
    def get_action_and_value(self, state):
        """Sample action and return everything needed for training."""
        action_logits, value = self.forward(state)
        dist = torch.distributions.Categorical(logits=action_logits)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()
        
        return action, log_prob, entropy, value

In [None]:
class VectorizedEnv:
    """
    Simple vectorized environment wrapper.
    Steps multiple environments in lockstep (synchronously), one after
    another inside a single process.
    
    This is a simplified version - use SubprocVecEnv for real parallelism!
    """
    
    def __init__(self, env_name, n_envs):
        self.envs = [gym.make(env_name) for _ in range(n_envs)]
        self.n_envs = n_envs
        
        # Get environment info
        self.observation_space = self.envs[0].observation_space
        self.action_space = self.envs[0].action_space
    
    def reset(self):
        """Reset all environments and return stacked observations."""
        observations = []
        for env in self.envs:
            obs, _ = env.reset()
            observations.append(obs)
        return np.array(observations)
    
    def step(self, actions):
        """Step all environments with given actions."""
        observations = []
        rewards = []
        dones = []
        
        for env, action in zip(self.envs, actions):
            obs, reward, terminated, truncated, _ = env.step(action)
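            # Simplification: a time-limit truncation is treated like a true
            # termination, so no value is bootstrapped past a truncated episode.
            done = terminated or truncated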
            
            # Auto-reset on done
            if done:
                obs, _ = env.reset()
            
            observations.append(obs)
            rewards.append(reward)
            dones.append(done)
        
        return np.array(observations), np.array(rewards), np.array(dones)
    
    def close(self):
        for env in self.envs:
            env.close()


# Demonstrate vectorized environments
print("VECTORIZED ENVIRONMENTS")
print("="*60)

n_envs = 4
vec_env = VectorizedEnv('CartPole-v1', n_envs)

# Reset all environments
observations = vec_env.reset()
print(f"\nNumber of parallel environments: {n_envs}")
print(f"Observation shape: {observations.shape}")
print(f"  → {n_envs} environments × {observations.shape[1]} state dims")

# Take actions in all environments
actions = np.random.randint(0, 2, size=n_envs)
next_obs, rewards, dones = vec_env.step(actions)

print(f"\nActions taken: {actions}")
print(f"Rewards: {rewards}")
print(f"Dones: {dones}")

vec_env.close()
print("\n" + "="*60)

In [None]:
def train_a2c(env_name='CartPole-v1', n_envs=4, n_steps=5, n_updates=1000,
              gamma=0.99, lam=0.95, lr=7e-4, ent_coef=0.01, vf_coef=0.5,
              print_every=100):
    """
    Train A2C (Advantage Actor-Critic) with vectorized environments.
    
    Key components:
    1. Vectorized environments: Multiple envs running in parallel
    2. N-step returns: Collect n_steps before updating
    3. GAE: Generalized Advantage Estimation for better advantages
    4. Entropy bonus: Encourage exploration
    
    Args:
        env_name: Environment name
        n_envs: Number of parallel environments
        n_steps: Steps to collect before each update
        n_updates: Number of update iterations
        gamma: Discount factor
        lam: GAE lambda
        lr: Learning rate
        ent_coef: Entropy coefficient (exploration bonus)
        vf_coef: Value function coefficient
        print_every: Print interval
    """
    # ========================================
    # SETUP
    # ========================================
    vec_env = VectorizedEnv(env_name, n_envs)
    state_dim = vec_env.observation_space.shape[0]
    action_dim = vec_env.action_space.n
    
    model = ActorCritic(state_dim, action_dim)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Initialize
    observations = vec_env.reset()
    rewards_history = []
    episode_rewards = np.zeros(n_envs)  # Track ongoing episode rewards
    
    # ========================================
    # TRAINING LOOP
    # ========================================
    for update in range(n_updates):
        # Storage for this rollout
        mb_obs = []
        mb_actions = []
        mb_log_probs = []
        mb_values = []
        mb_rewards = []
        mb_dones = []
        
        # ----------------------------------------
        # STEP 1: Collect n_steps of experience
        # ----------------------------------------
        for step in range(n_steps):
            with torch.no_grad():
                obs_tensor = torch.FloatTensor(observations)
                actions, log_probs, _, values = model.get_action_and_value(obs_tensor)
            
            # Store
            mb_obs.append(observations.copy())
            mb_actions.append(actions.numpy())
            mb_log_probs.append(log_probs.numpy())
            # squeeze(-1) keeps the env dimension even when n_envs == 1
            mb_values.append(values.squeeze(-1).numpy())
            
            # Step environments
            observations, rewards, dones = vec_env.step(actions.numpy())
            
            mb_rewards.append(rewards)
            mb_dones.append(dones)
            
            # Track episode rewards
            episode_rewards += rewards
            for i, done in enumerate(dones):
                if done:
                    rewards_history.append(episode_rewards[i])
                    episode_rewards[i] = 0
        
        # Convert to arrays: shape (n_steps, n_envs)
        mb_obs = np.array(mb_obs)
        mb_actions = np.array(mb_actions)
        mb_log_probs = np.array(mb_log_probs)
        mb_values = np.array(mb_values)
        mb_rewards = np.array(mb_rewards)
        mb_dones = np.array(mb_dones)
        
        # ----------------------------------------
        # STEP 2: Compute GAE advantages
        # ----------------------------------------
        # Get value of last state for bootstrapping
        with torch.no_grad():
            _, last_values = model(torch.FloatTensor(observations))
            # squeeze(-1) keeps shape (n_envs,) even when n_envs == 1
            last_values = last_values.squeeze(-1).numpy()
        
        # Compute advantages for each environment
        mb_advantages = np.zeros_like(mb_rewards)
        mb_returns = np.zeros_like(mb_rewards)
        
        for env_idx in range(n_envs):
            # Append last value for GAE computation
            values_with_last = np.append(mb_values[:, env_idx], last_values[env_idx])
            
            gae = 0
            for t in reversed(range(n_steps)):
                delta = (mb_rewards[t, env_idx] + 
                         gamma * values_with_last[t+1] * (1 - mb_dones[t, env_idx]) -
                         values_with_last[t])
                gae = delta + gamma * lam * (1 - mb_dones[t, env_idx]) * gae
                mb_advantages[t, env_idx] = gae
            
            mb_returns[:, env_idx] = mb_advantages[:, env_idx] + mb_values[:, env_idx]
        
        # ----------------------------------------
        # STEP 3: Flatten and compute loss
        # ----------------------------------------
        # Flatten: (n_steps, n_envs) → (n_steps * n_envs,)
        batch_obs = mb_obs.reshape(-1, state_dim)
        batch_actions = mb_actions.flatten()
        batch_log_probs_old = mb_log_probs.flatten()  # not used by the A2C loss below
        batch_advantages = mb_advantages.flatten()
        batch_returns = mb_returns.flatten()
        
        # Normalize advantages
        batch_advantages = (batch_advantages - batch_advantages.mean()) / (batch_advantages.std() + 1e-8)
        
        # Forward pass
        obs_tensor = torch.FloatTensor(batch_obs)
        actions_tensor = torch.LongTensor(batch_actions)
        
        action_logits, values = model(obs_tensor)
        dist = torch.distributions.Categorical(logits=action_logits)
        log_probs = dist.log_prob(actions_tensor)
        entropy = dist.entropy().mean()
        
        # ----------------------------------------
        # STEP 4: Compute A2C losses
        # ----------------------------------------
        advantages_tensor = torch.FloatTensor(batch_advantages)
        returns_tensor = torch.FloatTensor(batch_returns)
        
        # Policy loss (actor): -log π × A
        policy_loss = -(log_probs * advantages_tensor).mean()
        
        # Value loss (critic): (V - R)²
        value_loss = ((values.squeeze(-1) - returns_tensor) ** 2).mean()
        
        # Entropy loss (exploration bonus): -H[π]
        entropy_loss = -entropy
        
        # Total loss
        loss = policy_loss + vf_coef * value_loss + ent_coef * entropy_loss
        
        # ----------------------------------------
        # STEP 5: Update
        # ----------------------------------------
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping for stability
        nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        
        # Print progress
        if (update + 1) % print_every == 0 and len(rewards_history) > 0:
            avg_reward = np.mean(rewards_history[-100:]) if len(rewards_history) >= 100 else np.mean(rewards_history)
            print(f"Update {update+1:4d} | Episodes: {len(rewards_history):4d} | Avg Reward: {avg_reward:6.1f}")
    
    vec_env.close()
    return model, rewards_history

In [None]:
# Train A2C!
print("TRAINING A2C ON CARTPOLE")
print("="*60)
print("\nUsing 4 parallel environments with GAE\n")

model, rewards_history = train_a2c(
    env_name='CartPole-v1',
    n_envs=4,
    n_steps=5,
    n_updates=500,
    gamma=0.99,
    lam=0.95,
    lr=7e-4,
    print_every=100
)

print("\n" + "="*60)
if len(rewards_history) > 0:
    print(f"Total episodes completed: {len(rewards_history)}")
    print(f"Final average (last 100): {np.mean(rewards_history[-100:]):.1f}")
print("="*60)

In [None]:
# Visualize training
if len(rewards_history) > 10:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Left: Learning curve
    ax1 = axes[0]
    ax1.plot(rewards_history, alpha=0.3, color='blue', label='Episode Reward')
    
    window = min(50, len(rewards_history) // 3)
    if window > 1:
        smoothed = np.convolve(rewards_history, np.ones(window)/window, mode='valid')
        ax1.plot(range(window-1, len(rewards_history)), smoothed, 
                 color='red', linewidth=2, label=f'{window}-Episode Average')
    
    ax1.axhline(y=500, color='green', linestyle='--', linewidth=2, label='Max Score')
    ax1.set_xlabel('Episode', fontsize=11)
    ax1.set_ylabel('Reward', fontsize=11)
    ax1.set_title('A2C Training on CartPole', fontsize=14, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Right: Average reward per training-progress bin
    ax2 = axes[1]
    
    # Bin episodes into groups
    n_bins = min(10, len(rewards_history) // 5)
    if n_bins > 1:
        bins = np.array_split(rewards_history, n_bins)
        bin_means = [np.mean(b) for b in bins]
        bin_stds = [np.std(b) for b in bins]
        
        x = range(n_bins)
        ax2.bar(x, bin_means, yerr=bin_stds, capsize=5, 
                color='#64b5f6', edgecolor='black', alpha=0.8)
        ax2.set_xlabel('Training Progress (bins)', fontsize=11)
        ax2.set_ylabel('Average Reward', fontsize=11)
        ax2.set_title('Reward Improvement Over Training', fontsize=14, fontweight='bold')
        ax2.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
else:
    print("Not enough episodes to plot. Try increasing n_updates.")

---
## Using Stable-Baselines3 for Production A2C

For real projects, use a well-tested library!

In [None]:
# Check if Stable-Baselines3 is available
try:
    from stable_baselines3 import A2C
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.evaluation import evaluate_policy
    SB3_AVAILABLE = True
    print("✓ Stable-Baselines3 is installed!")
except ImportError:
    SB3_AVAILABLE = False
    print("Stable-Baselines3 not installed.")
    print("\nTo install, run:")
    print("  pip install stable-baselines3")

In [None]:
if SB3_AVAILABLE:
    print("TRAINING A2C WITH STABLE-BASELINES3")
    print("="*60)
    
    # Create vectorized environment
    vec_env = make_vec_env('CartPole-v1', n_envs=4)
    
    # Create A2C agent
    model = A2C(
        'MlpPolicy',      # Use MLP policy network
        vec_env,          # Vectorized environment
        verbose=1,        # Print training info
        learning_rate=7e-4,
        n_steps=5,        # Steps before each update
        gamma=0.99,
        gae_lambda=0.95,  # GAE lambda
        ent_coef=0.01,    # Entropy coefficient
        vf_coef=0.5,      # Value function coefficient
    )
    
    print("\nTraining for 25,000 steps...")
    model.learn(total_timesteps=25000)
    
    # Evaluate
    print("\nEvaluating trained agent...")
    mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=10)
    print(f"Mean reward: {mean_reward:.1f} ± {std_reward:.1f}")
    
    vec_env.close()
    print("\n" + "="*60)
else:
    print("\n(Showing example code - install SB3 to run)")
    example_code = '''
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env

# Create 4 parallel environments
vec_env = make_vec_env('CartPole-v1', n_envs=4)

# Create and train A2C agent
model = A2C('MlpPolicy', vec_env, verbose=1)
model.learn(total_timesteps=100000)

# Save the model
model.save('a2c_cartpole')

# Load and use later
model = A2C.load('a2c_cartpole')
'''
    print(example_code)

---
## Summary: Key Takeaways

### A2C vs A3C

| Aspect | A2C | A3C |
|--------|-----|-----|
| **Updates** | Synchronous | Asynchronous |
| **Parameters** | Single copy | Local copies |
| **Hardware** | GPU-friendly | CPU-focused |
| **Modern Use** | ✓ Preferred | Historical |

### GAE (Generalized Advantage Estimation)

```
A_t^GAE = Σ (γλ)^l × δ_{t+l}

λ = 0: TD(0) - low variance, more bias
λ = 1: MC - high variance, no bias  
λ = 0.95: Sweet spot!
```
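One way to read these settings is through the total weight GAE places on future TD errors. The sketch below uses hypothetical helper names (`gae_weights`, `effective_horizon`) and the closed form of an infinite geometric series. It turns γ and λ into a rough "effective horizon" in steps, which is an informal reading aid rather than a quantity defined in this notebook.

```python
def gae_weights(gamma, lam, n_terms):
    """Weights (gamma * lam) ** l applied to the TD error l steps ahead."""
    return [(gamma * lam) ** l for l in range(n_terms)]

def effective_horizon(gamma, lam):
    """Sum of the infinite geometric series of weights: 1 / (1 - gamma * lam)."""
    return 1.0 / (1.0 - gamma * lam)

for lam in (0.0, 0.9, 0.95):
    horizon = effective_horizon(0.99, lam)
    print(f"lambda={lam:.2f}: effective horizon ~ {horizon:.1f} steps")
```

With γ = 0.99, moving λ from 0 to 0.95 stretches the horizon from a single step to roughly 17 steps, which matches the intuition of looking further ahead while still discounting distant TD errors.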

### A2C Components

| Component | Purpose |
|-----------|--------|
| Vectorized Envs | Parallel experience collection |
| N-step Returns | Short rollouts bootstrapped with the critic's value estimate |
| GAE | Better advantage estimates |
| Entropy Bonus | Encourage exploration |

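The n-step return row can be sketched in a few lines. This is a minimal standard-library illustration for a single environment, not the notebook's training loop: `n_step_returns` is a hypothetical helper, and `last_value` stands for the critic's estimate of the state reached after the final collected step. With λ = 1, the GAE-based returns used earlier reduce to these targets.

```python
def n_step_returns(rewards, dones, last_value, gamma=0.99):
    """Discounted return targets for one environment's short rollout.

    rewards, dones: lists of length n_steps
    last_value: critic's estimate V(s) for the state after the final step
    """
    returns = [0.0] * len(rewards)
    running = last_value
    for t in reversed(range(len(rewards))):
        # A finished episode contributes nothing beyond its last reward.
        running = rewards[t] + gamma * running * (1.0 - dones[t])
        returns[t] = running
    return returns

# Rollout that is still running: the critic's estimate fills in the future.
targets = n_step_returns([1.0, 1.0, 1.0], [0, 0, 0], last_value=10.0, gamma=0.9)
print("no episode end: ", [round(x, 3) for x in targets])

# Episode ends at the second step: nothing after it leaks backwards.
cut = n_step_returns([1.0, 1.0, 1.0], [0, 1, 0], last_value=10.0, gamma=0.9)
print("ends at step 2: ", [round(x, 3) for x in cut])
```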
---
## Test Your Understanding

**1. Why use parallel environments?**
<details>
<summary>Click to reveal answer</summary>
Parallel environments provide:
1. N× more experience per update (N envs = N× data)
2. More diverse experiences in each batch
3. Lower variance gradients (batch averages out noise)
4. Better exploration (different situations simultaneously)
</details>

**2. What's the difference between A2C and A3C?**
<details>
<summary>Click to reveal answer</summary>
A3C (Asynchronous): Workers update independently, may use stale parameters. Designed for CPU parallelism.

A2C (Synchronous): All workers wait, then batch update with fresh parameters. Simpler and better for GPUs.

A2C is preferred today because GPUs are common and it's simpler while being equally effective.
</details>

**3. What does GAE lambda control?**
<details>
<summary>Click to reveal answer</summary>
GAE lambda (λ) controls the bias-variance tradeoff:
- λ = 0: TD(0), only look 1 step ahead (low variance, more bias)
- λ = 1: Monte Carlo, use full return (no bias, high variance)
- λ = 0.95: Good balance (commonly used)

It's a weighted average of n-step returns, trusting closer steps more.
</details>

**4. Why add an entropy bonus?**
<details>
<summary>Click to reveal answer</summary>
The entropy bonus encourages exploration by preventing the policy from becoming too deterministic too quickly. Higher entropy = more random actions = more exploration. The coefficient (ent_coef) balances exploration vs exploitation.
</details>

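To see what the entropy term measures, the sketch below computes the entropy of two hand-picked two-action policies using only the standard library. The `softmax` and `entropy` helpers and the example logits are illustrative assumptions. The training code above gets the same quantity from `dist.entropy()`.

```python
import math

def softmax(logits):
    """Convert logits to probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; zero-probability actions contribute 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

# A uniform policy has the maximum entropy, ln(2) for two actions.
# A confident policy has entropy close to 0, so the bonus pushes against it.
for name, logits in [("uniform", [0.0, 0.0]), ("confident", [4.0, 0.0])]:
    h = entropy(softmax(logits))
    print(f"{name:9s} policy: entropy = {h:.3f} nats")
```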
**5. When should you use A2C vs other algorithms?**
<details>
<summary>Click to reveal answer</summary>
A2C is good for:
- Simple environments where sample efficiency isn't critical
- When you want a baseline policy gradient method
- Educational purposes (simpler than PPO)

For most real tasks, PPO is preferred (next section!) because it's more stable and sample efficient.
</details>

---
## Congratulations!

You've completed the **Policy Gradient** section! You now understand:

- ✅ Why policy gradients exist (continuous actions, stochastic policies)
- ✅ REINFORCE algorithm (Monte Carlo policy gradient)
- ✅ Variance reduction with baselines
- ✅ Actor-Critic methods (TD learning + policy gradient)
- ✅ A2C/A3C (parallel environments, GAE)

**Next Steps:**

Move on to **[Advanced Algorithms](../advanced-algorithms/)** to learn about:
- Trust Region Methods (TRPO)
- PPO (Proximal Policy Optimization) - the most popular algorithm!
- SAC (Soft Actor-Critic) for continuous control

---

*A2C: "Why train one agent when you can train four at once?"*