# Policy Gradient Intuition: Learning Actions Directly

Welcome to a fundamentally different approach to RL! Instead of learning values, we'll learn policies directly.

## What You'll Learn

By the end of this notebook, you'll understand:
- The dance instructor analogy: why learning actions directly makes sense
- Limitations of value-based methods (DQN's struggles)
- Stochastic vs deterministic policies
- The policy gradient theorem (intuition, not math terror!)
- When to use policy gradients vs value methods

**Prerequisites:** Deep RL section (DQN basics)

**Time:** ~25 minutes

---
## The Big Picture: The Dance Instructor Analogy

```
    ┌────────────────────────────────────────────────────────────────┐
    │          THE DANCE INSTRUCTOR ANALOGY                          │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Imagine learning to dance...                                 │
    │                                                                │
    │  VALUE-BASED (DQN approach):                                  │
    │    Instructor rates every possible move:                      │
    │      "Step left: 7 points, spin: 9 points, jump: 3 points"   │
    │    You calculate which move has highest points               │
    │    Then do that move                                         │
    │                                                                │
    │    Problems:                                                  │
    │    • What if you need smooth moves, not discrete steps?      │
    │    • What if calculating max is expensive?                   │
    │    • Some moves work better with randomness (jazz hands!)    │
    │                                                                │
    │  POLICY GRADIENT (Direct approach):                          │
    │    Instructor says: "Just dance! I'll tell you if it's good" │
    │    You try moves, get feedback                               │
    │    Do MORE of what worked, LESS of what didn't              │
    │                                                                │
    │    Benefits:                                                  │
    │    • Works with smooth, continuous movements                 │
    │    • Natural randomness for creativity                       │
    │    • Simple: just learn what to do!                         │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch

# Visualize Value-Based vs Policy-Based
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Value-Based (DQN)
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('VALUE-BASED (DQN)\n"Learn values, derive actions"', fontsize=14, fontweight='bold', color='#1976d2')

# State input
state_box1 = FancyBboxPatch((1, 6), 2, 2, boxstyle="round,pad=0.1",
                             facecolor='#bbdefb', edgecolor='#1976d2', linewidth=2)
ax1.add_patch(state_box1)
ax1.text(2, 7, 'State', ha='center', va='center', fontsize=11, fontweight='bold')

# Q-Network
q_box = FancyBboxPatch((4, 5.5), 2, 3, boxstyle="round,pad=0.1",
                        facecolor='#fff3e0', edgecolor='#f57c00', linewidth=2)
ax1.add_patch(q_box)
ax1.text(5, 7.8, 'Q-Network', ha='center', fontsize=10, fontweight='bold')
ax1.text(5, 7, 'Q(s, a₁)=5.2', ha='center', fontsize=9)
ax1.text(5, 6.4, 'Q(s, a₂)=3.1', ha='center', fontsize=9)
ax1.text(5, 5.8, 'Q(s, a₃)=7.4', ha='center', fontsize=9, color='#d32f2f', fontweight='bold')

# argmax operation
argmax_box = FancyBboxPatch((7, 6), 2, 2, boxstyle="round,pad=0.1",
                             facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=2)
ax1.add_patch(argmax_box)
ax1.text(8, 7.3, 'argmax', ha='center', fontsize=10, fontweight='bold')
ax1.text(8, 6.7, '→ a₃', ha='center', fontsize=11)

# Arrows
ax1.annotate('', xy=(3.9, 7), xytext=(3.1, 7),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax1.annotate('', xy=(6.9, 7), xytext=(6.1, 7),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Problems
ax1.text(5, 4.5, '❌ Needs discrete actions', ha='center', fontsize=10, color='#d32f2f')
ax1.text(5, 3.8, '❌ argmax is not differentiable', ha='center', fontsize=10, color='#d32f2f')
ax1.text(5, 3.1, '❌ Deterministic (needs ε-greedy to explore)', ha='center', fontsize=10, color='#d32f2f')

# Right: Policy-Based
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('POLICY-BASED\n"Learn actions directly"', fontsize=14, fontweight='bold', color='#388e3c')

# State input
state_box2 = FancyBboxPatch((1, 6), 2, 2, boxstyle="round,pad=0.1",
                             facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax2.add_patch(state_box2)
ax2.text(2, 7, 'State', ha='center', va='center', fontsize=11, fontweight='bold')

# Policy Network
policy_box = FancyBboxPatch((4.5, 5.5), 3, 3, boxstyle="round,pad=0.1",
                             facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax2.add_patch(policy_box)
ax2.text(6, 7.8, 'Policy Network', ha='center', fontsize=10, fontweight='bold')
ax2.text(6, 7, 'π(a₁|s)=0.15', ha='center', fontsize=9)
ax2.text(6, 6.4, 'π(a₂|s)=0.10', ha='center', fontsize=9)
ax2.text(6, 5.8, 'π(a₃|s)=0.75', ha='center', fontsize=9, color='#388e3c', fontweight='bold')

# Sample arrow
ax2.annotate('', xy=(3.9, 7), xytext=(3.1, 7),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax2.annotate('', xy=(8.5, 7), xytext=(7.6, 7),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax2.text(9, 7.3, 'Sample', ha='center', fontsize=10)
ax2.text(9, 6.7, '→ Action!', ha='center', fontsize=10, fontweight='bold')

# Benefits
ax2.text(5.5, 4.5, '✓ Works with continuous actions', ha='center', fontsize=10, color='#388e3c')
ax2.text(5.5, 3.8, '✓ Differentiable policy (no argmax step)', ha='center', fontsize=10, color='#388e3c')
ax2.text(5.5, 3.1, '✓ Natural exploration (stochastic)', ha='center', fontsize=10, color='#388e3c')

plt.tight_layout()
plt.show()

print("\n" + "="*70)
print("THE FUNDAMENTAL DIFFERENCE")
print("="*70)
print("""
VALUE-BASED (DQN):
  1. Learn: "How good is each action?" → Q(s, a)
  2. Then: Take the best action → argmax Q(s, a)
  
  Problem: You need to compare ALL actions to pick the best!

POLICY-BASED:
  1. Learn: "What should I do?" → π(a|s) directly!
  2. No argmax needed - just sample from the policy
  
  Simpler, and works with continuous (infinite) action spaces!
""")
print("="*70)
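The two selection rules in the printout can be sketched in a few lines of NumPy. The Q-values and logits are made-up numbers chosen to roughly match the figure. They are not outputs of a trained network. The softmax is assumed here as one common way to turn network outputs into π(a|s).

```python
import numpy as np

rng = np.random.default_rng(0)

# Value-based: hypothetical Q-values for three actions in one state
q_values = np.array([5.2, 3.1, 7.4])
greedy_action = int(np.argmax(q_values))  # deterministic: same index every call

# Policy-based: hypothetical logits turned into probabilities with a softmax
logits = np.array([0.0, -0.4, 1.6])
probs = np.exp(logits - logits.max())     # subtract max for numerical stability
probs /= probs.sum()

# Sampling: repeated calls can return different actions
samples = rng.choice(len(probs), size=10_000, p=probs)
empirical = np.bincount(samples, minlength=len(probs)) / len(samples)

print("greedy action:   ", greedy_action)
print("policy probs:    ", probs.round(3))
print("sample frequency:", empirical.round(3))
```

The greedy rule always picks the same index. The sampled frequencies track the softmax probabilities.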

---
## Why DQN Struggles: The Continuous Action Problem

```
    ┌────────────────────────────────────────────────────────────────┐
    │              THE CONTINUOUS ACTION PROBLEM                     │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  DISCRETE ACTIONS (DQN is happy):                             │
    │    CartPole: [LEFT, RIGHT] → 2 options, easy argmax!          │
    │    Atari: [UP, DOWN, LEFT, RIGHT, FIRE...] → ~18 options      │
    │                                                                │
    │  CONTINUOUS ACTIONS (DQN struggles):                          │
    │    Robot arm: "How much torque?" → Any value from -10 to +10 │
    │    Car steering: "What angle?" → Any value from -30° to +30° │
    │    Drone: "What thrust?" → Any value from 0 to 100%          │
    │                                                                │
    │  WHY DQN FAILS:                                               │
    │    DQN needs: argmax_a Q(s, a)                                │
    │    With continuous a, there are INFINITE options!            │
    │    Can't check Q-value for every possible action!            │
    │                                                                │
    │  POLICY GRADIENT SOLUTION:                                    │
    │    Output a probability distribution over actions            │
    │    For continuous: output mean μ and std σ of Gaussian       │
    │    Sample action from N(μ, σ)                                │
    │    No argmax needed!                                         │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
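A common workaround is to discretize each action dimension and keep using argmax. The arithmetic shows why this scales badly. With `bins` values per dimension, every greedy step must score `bins ** dims` candidate actions. The 21-bin grid and the 7-joint arm below are illustrative assumptions. The Gaussian policy, by contrast, draws one action per step regardless of dimension.

```python
import numpy as np

def grid_size(bins_per_dim: int, action_dims: int) -> int:
    """Number of discrete actions an argmax must compare after gridding."""
    return bins_per_dim ** action_dims

for dims in (1, 3, 7):
    print(f"{dims} action dim(s), 21 bins each -> {grid_size(21, dims):,} Q-values per argmax")

# Gaussian policy: one sample per step, whatever the dimension.
rng = np.random.default_rng(0)
mu = np.zeros(7)            # hypothetical mean output for a 7-joint arm
sigma = np.full(7, 0.5)     # hypothetical per-joint standard deviation
action = rng.normal(mu, sigma)
action = np.clip(action, -10.0, 10.0)  # keep torques inside the [-10, 10] range from the box
print("sampled 7-D action:", action.round(2))
```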

In [None]:
# Visualize discrete vs continuous action spaces
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Left: Discrete actions (DQN-friendly)
ax1 = axes[0]
actions = ['LEFT', 'RIGHT']
q_values = [3.5, 7.2]
colors = ['#90caf9', '#81c784']
bars = ax1.bar(actions, q_values, color=colors, edgecolor='black', linewidth=2)
bars[1].set_facecolor('#4caf50')  # Highlight max (set_color would also recolor the black edge)
ax1.set_ylabel('Q-value', fontsize=11)
ax1.set_title('Discrete Actions\n(DQN works great!)', fontsize=12, fontweight='bold', color='#388e3c')
ax1.text(1, 7.5, '← argmax!', fontsize=10, fontweight='bold', color='#388e3c')
ax1.grid(True, alpha=0.3, axis='y')

# Middle: Continuous action - the problem
ax2 = axes[1]
# Show Q-function over continuous action
a = np.linspace(-2, 2, 100)
q = -0.5 * (a - 0.7)**2 + 2  # Quadratic Q-function
ax2.plot(a, q, 'b-', linewidth=2)
ax2.fill_between(a, q, alpha=0.2, color='blue')
ax2.axvline(x=0.7, color='red', linestyle='--', linewidth=2)
ax2.scatter([0.7], [2], c='red', s=100, zorder=5, label='Max (but how to find?)')
ax2.set_xlabel('Action (continuous)', fontsize=11)
ax2.set_ylabel('Q(s, a)', fontsize=11)
ax2.set_title('Continuous Actions\n(DQN struggles!)', fontsize=12, fontweight='bold', color='#d32f2f')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.text(0, 0.5, '∞ options to check!', fontsize=10, color='#d32f2f', ha='center')

# Right: Policy gradient solution
ax3 = axes[2]
# Show Gaussian policy
a = np.linspace(-2, 2, 100)
mu, sigma = 0.7, 0.3
policy = 1/(sigma*np.sqrt(2*np.pi)) * np.exp(-0.5*((a-mu)/sigma)**2)
ax3.plot(a, policy, 'g-', linewidth=2, label=f'π(a|s) ~ N({mu}, {sigma}²)')
ax3.fill_between(a, policy, alpha=0.3, color='green')
ax3.axvline(x=mu, color='green', linestyle='--', linewidth=2)

# Show samples
np.random.seed(42)
samples = np.random.normal(mu, sigma, 10)
ax3.scatter(samples, np.zeros(10), c='green', s=50, zorder=5, marker='|', label='Sampled actions')

ax3.set_xlabel('Action (continuous)', fontsize=11)
ax3.set_ylabel('π(a|s)', fontsize=11)
ax3.set_title('Policy Gradient Solution\n(Just sample!)', fontsize=12, fontweight='bold', color='#388e3c')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nPOLICY GRADIENT FOR CONTINUOUS ACTIONS:")
print("  Instead of Q-values, output distribution parameters:")
print("  • μ (mean): Where to center actions")
print("  • σ (std): How much to explore")
print("  Then simply sample: a ~ N(μ, σ²)")

---
## Stochastic vs Deterministic Policies

```
    ┌────────────────────────────────────────────────────────────────┐
    │              STOCHASTIC vs DETERMINISTIC                       │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  DETERMINISTIC POLICY:                                        │
    │    π(s) = a                                                   │
    │    "In state s, always do action a"                          │
    │    Example: DQN's argmax policy                              │
    │                                                                │
    │    Problem: no built-in exploration. Noise (e.g. ε-greedy)     │
    │    must be added, or it can get stuck in local optima.         │
    │                                                                │
    │  STOCHASTIC POLICY:                                           │
    │    π(a|s) = probability of action a in state s               │
    │    "In state s, here's how likely each action is"            │
    │                                                                │
    │    Benefits:                                                  │
    │    • Built-in exploration (samples vary)                     │
    │    • Can represent mixed strategies                          │
    │    • Smooth optimization (no discontinuities)                │
    │                                                                │
    │  ANALOGY - Rock Paper Scissors:                               │
    │    Deterministic: "Always play Rock" → Easy to beat!         │
    │    Stochastic: "33% each" → Optimal mixed strategy!         │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
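To make the Rock-Paper-Scissors point concrete, the sketch below scores a strategy by its worst case. That is the expected payoff when the opponent knows the strategy and plays the single best reply to it. Payoffs of +1 for a win, 0 for a draw and −1 for a loss are a standard convention, assumed here.

```python
import numpy as np

# payoff[i, j] = our payoff when we play i and the opponent plays j
# order: Rock, Paper, Scissors
payoff = np.array([
    [ 0, -1,  1],   # Rock: loses to Paper, beats Scissors
    [ 1,  0, -1],   # Paper: beats Rock, loses to Scissors
    [-1,  1,  0],   # Scissors: loses to Rock, beats Paper
])

def worst_case_payoff(policy):
    """Expected payoff if the opponent knows our policy and best-responds."""
    expected_vs_each_reply = policy @ payoff   # one entry per opponent action
    return expected_vs_each_reply.min()

always_rock = np.array([1.0, 0.0, 0.0])
uniform = np.array([1/3, 1/3, 1/3])

print("always Rock:", worst_case_payoff(always_rock))  # the best reply is Paper
print("uniform:    ", worst_case_payoff(uniform))
```

Every deterministic strategy can be beaten outright. The uniform mixed strategy guarantees an expected payoff of zero against any opponent.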

In [None]:
# Demonstrate stochastic vs deterministic policies

class DeterministicPolicy(nn.Module):
    """Always outputs the same action for a given state."""
    
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )
    
    def forward(self, state):
        # Returns Q-values, then we take argmax
        return self.net(state)
    
    def get_action(self, state):
        with torch.no_grad():
            q_values = self.forward(state)
            return q_values.argmax().item()  # Always same action!


class StochasticPolicy(nn.Module):
    """Outputs action probabilities - samples are different each time!"""
    
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)  # Output probabilities!
        )
    
    def forward(self, state):
        return self.net(state)
    
    def get_action(self, state):
        with torch.no_grad():
            probs = self.forward(state)
            dist = torch.distributions.Categorical(probs)
            return dist.sample().item()  # Different each time!


# Compare behavior
print("COMPARING DETERMINISTIC vs STOCHASTIC POLICIES")
print("="*60)

state_dim, action_dim = 4, 3
det_policy = DeterministicPolicy(state_dim, action_dim)
sto_policy = StochasticPolicy(state_dim, action_dim)

# Same state
test_state = torch.randn(1, state_dim)

print("\nFor the SAME state, sampling 10 actions:")
print("\nDeterministic policy:")
det_actions = [det_policy.get_action(test_state) for _ in range(10)]
print(f"  Actions: {det_actions}")
print("  Always the same! ← No built-in exploration!")

print("\nStochastic policy:")
with torch.no_grad():
    probs = sto_policy(test_state)
    print(f"  Probabilities: {probs.numpy()[0].round(3)}")

sto_actions = [sto_policy.get_action(test_state) for _ in range(10)]
print(f"  Actions: {sto_actions}")
print(f"  Varies! ← Natural exploration!")
print("="*60)
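Deterministic value-based agents such as DQN typically get exploration from an external rule, most often ε-greedy. Here is a minimal sketch, assuming ε = 0.1 and three actions with made-up Q-values. The greedy action should be chosen with probability 1 − ε + ε/K, and each other action with probability ε/K (K = 3 here).

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng):
    """Pick a uniformly random action with prob. epsilon, else the argmax."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

rng = np.random.default_rng(0)
q_values = np.array([1.0, 2.5, 0.3])   # hypothetical Q-values; action 1 is greedy
epsilon, n = 0.1, 20_000

picks = np.array([epsilon_greedy(q_values, epsilon, rng) for _ in range(n)])
freq = np.bincount(picks, minlength=3) / n

# Theoretical action probabilities under epsilon-greedy
expected = np.full(3, epsilon / 3)
expected[np.argmax(q_values)] += 1 - epsilon

print("empirical:", freq.round(3))
print("expected: ", expected.round(3))
```

The exploration here is not learned. ε is set by hand, and non-greedy actions are tried uniformly regardless of how good they look.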

In [None]:
# Visualize the difference

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Deterministic
ax1 = axes[0]
actions = ['A1', 'A2', 'A3']
heights = [0, 1, 0]  # Always A2
ax1.bar(actions, heights, color=['#90caf9', '#4caf50', '#90caf9'], edgecolor='black', linewidth=2)
ax1.set_ylabel('Probability', fontsize=11)
ax1.set_ylim(0, 1.2)
ax1.set_title('Deterministic Policy\nπ(s) = A2 (always)', fontsize=12, fontweight='bold')
ax1.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
ax1.text(1, 1.05, '100%', ha='center', fontsize=10, fontweight='bold')
ax1.grid(True, alpha=0.3, axis='y')

# Simulate multiple episodes
ax1_inset = ax1.inset_axes([0.6, 0.5, 0.35, 0.35])
det_samples = [1] * 20  # Always action 1
ax1_inset.hist(det_samples, bins=[0, 1, 2, 3], color='#4caf50', edgecolor='black', rwidth=0.8)
ax1_inset.set_title('20 Samples', fontsize=9)
ax1_inset.set_xticks([0.5, 1.5, 2.5])
ax1_inset.set_xticklabels(['A1', 'A2', 'A3'], fontsize=8)

# Right: Stochastic
ax2 = axes[1]
probs = [0.2, 0.5, 0.3]
ax2.bar(actions, probs, color=['#90caf9', '#81c784', '#ffcc80'], edgecolor='black', linewidth=2)
ax2.set_ylabel('Probability', fontsize=11)
ax2.set_ylim(0, 1.2)
ax2.set_title('Stochastic Policy\nπ(a|s) = [0.2, 0.5, 0.3]', fontsize=12, fontweight='bold')
ax2.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
for i, p in enumerate(probs):
    ax2.text(i, p + 0.03, f'{int(p*100)}%', ha='center', fontsize=10)
ax2.grid(True, alpha=0.3, axis='y')

# Simulate samples
ax2_inset = ax2.inset_axes([0.6, 0.5, 0.35, 0.35])
np.random.seed(42)
sto_samples = np.random.choice([0, 1, 2], 20, p=probs)
ax2_inset.hist(sto_samples, bins=[0, 1, 2, 3], color='#81c784',
               edgecolor='black', rwidth=0.8)
ax2_inset.set_title('20 Samples', fontsize=9)
ax2_inset.set_xticks([0.5, 1.5, 2.5])
ax2_inset.set_xticklabels(['A1', 'A2', 'A3'], fontsize=8)

plt.tight_layout()
plt.show()

print("\nSTOCHASTIC POLICIES ARE BETTER FOR:")
print("  1. Exploration - naturally try different actions")
print("  2. Mixed strategies - optimal in competitive games")
print("  3. Smooth optimization - probabilities change smoothly with θ")

---
## The Policy Gradient Theorem: The Key Insight

```
    ┌────────────────────────────────────────────────────────────────┐
    │              THE POLICY GRADIENT THEOREM                       │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  GOAL: Maximize expected return J(θ)                          │
    │        J(θ) = E[total rewards when following π(a|s; θ)]       │
    │                                                                │
    │  PROBLEM: How do we compute ∇_θ J(θ)?                         │
    │           Rewards depend on environment (not differentiable!) │
    │                                                                │
    │  THE MAGIC THEOREM:                                           │
    │                                                                │
    │        ∇_θ J(θ) = E[ ∇_θ log π(a|s; θ) × R ]                  │
    │                      └───────────────┘   ↑                     │
    │                       Score function   Return                  │
    │                                                                │
    │  INTUITION (Very Important!):                                 │
    │    • If action a got HIGH return R:                          │
    │      → Increase log π(a|s) → Make a MORE likely              │
    │    • If action a got LOW return R:                           │
    │      → Decrease log π(a|s) → Make a LESS likely              │
    │                                                                │
    │  ANALOGY: Like a coach giving feedback                       │
    │    "That move scored! Do more of that!"                      │
    │    "That move failed. Do less of that."                      │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
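The theorem can be checked numerically on a one-state, three-action problem (a bandit). The check reuses the made-up returns from the next figure: A1 → +5, A2 → +50, A3 → −10. The sketch assumes a softmax policy over logits z, for which ∇_z log π(a) = onehot(a) − π. It compares two quantities:

- the exact gradient of J = Σ_a π(a) R(a), which works out to π ⊙ (R − J);
- the sample average of ∇ log π(a) × R.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())   # subtract max for numerical stability
    return e / e.sum()

returns = np.array([5.0, 50.0, -10.0])  # made-up returns for A1, A2, A3
z = np.zeros(3)                         # logits of a uniform policy
pi = softmax(z)

# Exact gradient of J(z) = sum_a pi(a) * R(a) with respect to the logits
J = pi @ returns
exact_grad = pi * (returns - J)

# Score-function (policy gradient) estimate from sampled actions
rng = np.random.default_rng(0)
n = 100_000
actions = rng.choice(3, size=n, p=pi)
score = np.eye(3)[actions] - pi                      # grad_z log pi(a) per sample
mc_grad = (score * returns[actions, None]).mean(axis=0)

print("exact gradient:  ", exact_grad.round(3))
print("sampled estimate:", mc_grad.round(3))
```

The exact gradient for A1 is negative even though its return is positive. Under a softmax policy, an action gains probability in expectation only when its return beats the current average J (15 here).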

In [None]:
# Visualize the policy gradient intuition
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Initial policy
ax1 = axes[0]
actions = ['A1', 'A2', 'A3']
initial_probs = [0.33, 0.33, 0.34]
ax1.bar(actions, initial_probs, color='#90caf9', edgecolor='black', linewidth=2)
ax1.set_ylabel('π(a|s)', fontsize=11)
ax1.set_ylim(0, 0.8)
ax1.set_title('Before Training\n(Uniform Policy)', fontsize=12, fontweight='bold')
ax1.grid(True, alpha=0.3, axis='y')

# Experience: A2 got high return, A3 got low return
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Experience\n(Trial and Feedback)', fontsize=12, fontweight='bold')

# Show experiences
experiences = [
    ('A1', 'R = +5', '#fff3e0', 'Medium'),
    ('A2', 'R = +50', '#c8e6c9', 'HIGH! ✓'),
    ('A3', 'R = -10', '#ffcdd2', 'Bad ✗'),
]

for i, (action, ret, color, comment) in enumerate(experiences):
    y = 7 - i * 2.5
    box = FancyBboxPatch((1, y), 7, 1.5, boxstyle="round,pad=0.1",
                          facecolor=color, edgecolor='black', linewidth=2)
    ax2.add_patch(box)
    ax2.text(2.5, y + 0.75, f'Action {action}:', fontsize=11, fontweight='bold', va='center')
    ax2.text(5.5, y + 0.75, ret, fontsize=11, va='center')
    ax2.text(8.5, y + 0.75, comment, fontsize=10, va='center')

ax2.text(5, 0.5, 'Policy gradient uses this\nfeedback to update!', ha='center', fontsize=10, style='italic')

# After update
ax3 = axes[2]
updated_probs = [0.25, 0.60, 0.15]  # A2 increased, A3 decreased
colors = ['#90caf9', '#4caf50', '#ef9a9a']
bars = ax3.bar(actions, updated_probs, color=colors, edgecolor='black', linewidth=2)
ax3.set_ylabel('π(a|s)', fontsize=11)
ax3.set_ylim(0, 0.8)
ax3.set_title('After Training\n(Learned Policy)', fontsize=12, fontweight='bold')
ax3.grid(True, alpha=0.3, axis='y')

# Annotations
ax3.annotate('↑ More likely\n(worked well!)', xy=(1, 0.60), xytext=(1.5, 0.72),
             fontsize=9, color='#388e3c', ha='center')
ax3.annotate('↓ Less likely\n(failed)', xy=(2, 0.15), xytext=(2.5, 0.30),
             fontsize=9, color='#d32f2f', ha='center')

plt.tight_layout()
plt.show()

print("\nPOLICY GRADIENT UPDATE RULE:")
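print("  Single sample: R > 0 raises π(a|s) for that action, R < 0 lowers it")
print("  In expectation (softmax policy): an action gains probability only if")
print("  its return beats the policy's average return. That is why A1 (R = +5)")
print("  shrinks above: A2 (R = +50) is much better, and probabilities sum to 1.")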
print("\n  It's reinforcing good behavior and weakening bad behavior!")
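Repeating the sampled update many times gives the simplest version of the learning loop. The sketch runs single-sample updates z ← z + α · R · (onehot(a) − π) on the same made-up three-action problem. The step size α = 0.01 and the 2,000 steps are arbitrary choices, and no baseline is used. This is a toy illustration, not a tuned algorithm.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())   # subtract max for numerical stability
    return e / e.sum()

returns = np.array([5.0, 50.0, -10.0])  # made-up returns for A1, A2, A3
z = np.zeros(3)                         # logits of a uniform starting policy
alpha = 0.01                            # step size (arbitrary)
rng = np.random.default_rng(0)

for step in range(2_000):
    pi = softmax(z)
    a = rng.choice(3, p=pi)             # act by sampling from the policy
    r = returns[a]                      # observe the (here deterministic) return
    grad_log_pi = np.eye(3)[a] - pi     # grad_z log pi(a) for a softmax policy
    z += alpha * r * grad_log_pi        # gradient ASCENT on expected return

final_pi = softmax(z)
print("final policy:", final_pi.round(3))
```

The policy should end up concentrated on A2. The returns here are fixed. A real environment returns noisy values, and that is where the variance of this estimator starts to matter.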

In [None]:
# Demonstrate the math: log π trick

print("THE LOG PROBABILITY TRICK")
print("="*60)

# Simple policy
policy = StochasticPolicy(4, 3)
state = torch.randn(1, 4)

# Get probabilities and log probabilities
probs = policy(state)
log_probs = torch.log(probs)

print(f"\nAction probabilities: π(a|s) = {probs.detach().numpy()[0].round(3)}")
print(f"Log probabilities: log π(a|s) = {log_probs.detach().numpy()[0].round(3)}")

# The gradient magic
print("\n" + "-"*60)
print("WHY LOG PROBABILITIES?")
print("-"*60)
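print("""
We want: ∇_θ E_{a~π}[R(a)] = ∇_θ Σ_a π(a|s; θ) R(a)
Problem: the distribution we sample from depends on θ, so we
can't just differentiate a sample average of returns!

Solution: Use the log-derivative trick:
  ∇_θ π(a|s; θ) = π(a|s; θ) × ∇_θ log π(a|s; θ)

This lets us write:
  ∇_θ Σ_a π(a|s; θ) R(a) = Σ_a π(a|s; θ) × R(a) × ∇_θ log π(a|s; θ)
                         = E_{a~π}[ R(a) × ∇_θ log π(a|s; θ) ]

Now we can estimate this from samples!
  ≈ (1/N) Σ_i R_i × ∇_θ log π(a_i|s_i; θ)
""")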
print("="*60)

# Practical gradient computation
print("\nPRACTICAL IMPLEMENTATION:")
action = torch.tensor([1])  # Suppose we took action 1
ret = 10.0  # And got return 10

# The loss we minimize (negative because we want gradient ASCENT)
dist = torch.distributions.Categorical(probs)
log_prob = dist.log_prob(action)
loss = -log_prob * ret  # Negative for gradient ascent

print(f"  Action taken: {action.item()}")
print(f"  Log probability: {log_prob.item():.4f}")
print(f"  Return: {ret}")
print(f"  Loss: {loss.item():.4f}")
print("\n  Gradient descent on this loss → Makes action more likely!")
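The log-derivative identity and the sign of the update can be checked with finite differences, without autograd. The sketch assumes a softmax policy over three logits. The logits, chosen action, return and learning rate are arbitrary, and `numerical_grad` is a hypothetical helper written for this check.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())   # subtract max for numerical stability
    return e / e.sum()

def numerical_grad(f, z, h=1e-6):
    """Central-difference gradient of a scalar function f at z."""
    g = np.zeros_like(z)
    for i in range(len(z)):
        dz = np.zeros_like(z)
        dz[i] = h
        g[i] = (f(z + dz) - f(z - dz)) / (2 * h)
    return g

z = np.array([0.2, -0.5, 1.0])   # arbitrary logits
a, ret, lr = 1, 10.0, 0.1        # action taken, its return, learning rate

grad_pi = numerical_grad(lambda v: softmax(v)[a], z)
grad_log_pi = numerical_grad(lambda v: np.log(softmax(v)[a]), z)
pi = softmax(z)

# Log-derivative trick: grad pi(a) == pi(a) * grad log pi(a)
print(np.allclose(grad_pi, pi[a] * grad_log_pi, atol=1e-6))
# For a softmax, grad log pi(a) == onehot(a) - pi
print(np.allclose(grad_log_pi, np.eye(3)[a] - pi, atol=1e-6))

# One gradient-DESCENT step on loss = -ret * log pi(a)
loss_grad = -ret * grad_log_pi
z_new = z - lr * loss_grad
print(f"pi(a) before: {pi[a]:.4f}, after: {softmax(z_new)[a]:.4f}")
```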

---
## Building a Simple Policy Network

```
    ┌────────────────────────────────────────────────────────────────┐
    │              POLICY NETWORK ARCHITECTURE                       │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  DISCRETE ACTIONS (e.g., CartPole):                           │
    │                                                                │
    │    State [4] → Hidden [64] → Hidden [64] → Softmax → [2]      │
    │                                                                │
    │    Output: [π(LEFT|s), π(RIGHT|s)]                            │
    │    Sample action from this distribution                       │
    │                                                                │
    │  CONTINUOUS ACTIONS (e.g., Pendulum):                         │
    │                                                                │
    │    State [3] → Hidden [64] → Hidden [64] → [μ, log σ]        │
    │                                                                │
    │    Output: Parameters of Gaussian N(μ, σ²)                    │
    │    Sample action from Gaussian                                │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
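The continuous head needs the log-density of a Gaussian with independent action dimensions. Per dimension, log N(a; μ, σ) = −(a − μ)²/(2σ²) − log σ − ½ log 2π, and the joint log-density is the sum over dimensions. That sum is what the `ContinuousPolicy` code below computes. The NumPy sketch checks the per-dimension formula against the full diagonal-covariance formula. The numbers are arbitrary, and `gaussian_log_prob` is a hypothetical helper.

```python
import numpy as np

def gaussian_log_prob(a, mu, log_std):
    """Per-dimension log-density of N(mu, exp(log_std)^2), summed over dims."""
    std = np.exp(log_std)                  # exp keeps std strictly positive
    per_dim = -0.5 * ((a - mu) / std) ** 2 - log_std - 0.5 * np.log(2 * np.pi)
    return per_dim.sum(axis=-1)

mu = np.array([0.3, -1.0])
log_std = np.array([0.0, -0.7])
a = np.array([0.5, -0.8])

# Same density written with a full (diagonal) covariance matrix
cov = np.diag(np.exp(log_std) ** 2)
diff = a - mu
joint = -0.5 * diff @ np.linalg.inv(cov) @ diff - 0.5 * np.log(np.linalg.det(2 * np.pi * cov))

print(gaussian_log_prob(a, mu, log_std), joint)
```

Parameterizing log σ rather than σ lets the network output any real number while the standard deviation stays positive.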

In [None]:
class DiscretePolicy(nn.Module):
    """
    Policy network for DISCRETE action spaces.
    
    Outputs probability distribution over actions.
    Use for: CartPole, Atari, etc.
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)  # Outputs sum to 1
        )
    
    def forward(self, state):
        """Returns action probabilities."""
        return self.network(state)
    
    def get_action(self, state):
        """Sample action from policy."""
        probs = self.forward(state)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)


class ContinuousPolicy(nn.Module):
    """
    Policy network for CONTINUOUS action spaces.
    
    Outputs parameters of Gaussian distribution.
    Use for: Pendulum, MuJoCo, robotics, etc.
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Output mean of Gaussian
        self.mean = nn.Linear(hidden_dim, action_dim)
        
        # Output log std (learnable, state-independent for simplicity)
        self.log_std = nn.Parameter(torch.zeros(action_dim))
    
    def forward(self, state):
        """Returns mean and std of Gaussian."""
        features = self.shared(state)
        mean = self.mean(features)
        std = torch.exp(self.log_std)  # Ensure positive
        return mean, std
    
    def get_action(self, state):
        """Sample continuous action from Gaussian."""
        mean, std = self.forward(state)
        dist = torch.distributions.Normal(mean, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)  # Sum over action dims, keep batch dim
        return action.numpy(), log_prob


# Demonstrate both
print("POLICY NETWORK EXAMPLES")
print("="*60)

# Discrete policy (CartPole-like)
print("\n1. DISCRETE POLICY (CartPole):")
discrete_policy = DiscretePolicy(state_dim=4, action_dim=2)
state = torch.randn(1, 4)
probs = discrete_policy(state)
action, log_prob = discrete_policy.get_action(state)
print(f"   State: {state.numpy()[0].round(2)}")
print(f"   Action probabilities: {probs.detach().numpy()[0].round(3)}")
print(f"   Sampled action: {action}")
print(f"   Log probability: {log_prob.item():.4f}")

# Continuous policy (Pendulum-like)
print("\n2. CONTINUOUS POLICY (Pendulum):")
continuous_policy = ContinuousPolicy(state_dim=3, action_dim=1)
state = torch.randn(1, 3)
mean, std = continuous_policy(state)
action, log_prob = continuous_policy.get_action(state)
print(f"   State: {state.numpy()[0].round(2)}")
print(f"   Gaussian mean: {mean.detach().numpy()[0].round(3)}")
print(f"   Gaussian std: {std.detach().numpy().round(3)}")
print(f"   Sampled action: {action.round(3)}")
print(f"   Log probability: {log_prob.item():.4f}")

print("\n" + "="*60)

---
## When to Use Policy Gradients vs Value Methods

```
    ┌────────────────────────────────────────────────────────────────┐
    │              WHEN TO USE WHAT                                  │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  USE VALUE-BASED (DQN) WHEN:                                  │
    │    ✓ Discrete, small action space                            │
    │    ✓ Sample efficiency is crucial                            │
    │    ✓ You have lots of experience replay data                 │
    │    Example: Atari games, board games                         │
    │                                                                │
    │  USE POLICY GRADIENTS WHEN:                                   │
    │    ✓ Continuous action space                                 │
    │    ✓ High-dimensional action space                           │
    │    ✓ Stochastic policy is beneficial                         │
    │    ✓ You want simpler, more direct optimization              │
    │    Example: Robotics, control, games with mixed strategies   │
    │                                                                │
    │  USE ACTOR-CRITIC (Both!) WHEN:                               │
    │    ✓ You want the best of both worlds                        │
    │    ✓ Continuous actions + sample efficiency                  │
    │    Example: PPO, SAC (strong, widely used defaults)            │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
# Summary comparison visualization
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')

ax.text(7, 9.5, 'Value-Based vs Policy-Based Methods', ha='center', fontsize=16, fontweight='bold')

# Value-based column
value_box = FancyBboxPatch((0.5, 2), 6, 6.5, boxstyle="round,pad=0.1",
                            facecolor='#e3f2fd', edgecolor='#1976d2', linewidth=3)
ax.add_patch(value_box)
ax.text(3.5, 8, 'VALUE-BASED', ha='center', fontsize=14, fontweight='bold', color='#1976d2')
ax.text(3.5, 7.3, '(DQN, Double DQN, Dueling)', ha='center', fontsize=10, color='#666')

value_points = [
    '✓ Sample efficient',
    '✓ Off-policy learning',
    '✓ Experience replay',
    '✗ Discrete actions only',
    '✗ Deterministic policy',
    '✗ argmax not differentiable'
]
for i, point in enumerate(value_points):
    color = '#388e3c' if point.startswith('✓') else '#d32f2f'
    ax.text(1, 6.2 - i*0.8, point, fontsize=10, color=color)

# Policy-based column
policy_box = FancyBboxPatch((7.5, 2), 6, 6.5, boxstyle="round,pad=0.1",
                             facecolor='#e8f5e9', edgecolor='#388e3c', linewidth=3)
ax.add_patch(policy_box)
ax.text(10.5, 8, 'POLICY-BASED', ha='center', fontsize=14, fontweight='bold', color='#388e3c')
ax.text(10.5, 7.3, '(REINFORCE, A2C, PPO)', ha='center', fontsize=10, color='#666')

policy_points = [
    '✓ Continuous actions',
    '✓ Stochastic policies',
    '✓ Differentiable end-to-end',
    '✓ Simpler architecture',
    '✗ Higher variance',
    '✗ On-policy (less sample-efficient)'
]
for i, point in enumerate(policy_points):
    color = '#388e3c' if point.startswith('✓') else '#d32f2f'
    ax.text(8, 6.2 - i*0.8, point, fontsize=10, color=color)

# Best of both
ax.text(7, 0.8, '⭐ ACTOR-CRITIC combines both: Policy + Value function ⭐', 
        ha='center', fontsize=12, fontweight='bold', color='#7b1fa2')

plt.tight_layout()
plt.show()
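The "higher variance" drawback, which is part of the motivation for actor-critic, can be seen on the same toy three-action problem used earlier. The sketch enumerates all actions exactly, with no sampling. It compares the single-sample estimator R · ∇ log π with a version that subtracts a constant baseline b from the return. Using the policy's average return as b is an illustrative choice. Any action-independent b leaves the mean unchanged, because E[∇ log π] = 0.

```python
import numpy as np

pi = np.full(3, 1 / 3)                   # uniform softmax policy over A1, A2, A3
returns = np.array([5.0, 50.0, -10.0])   # made-up returns
scores = np.eye(3) - pi                  # row a: grad_z log pi(a) for a softmax

def estimator_stats(baseline):
    """Exact mean and total variance of (R(a) - baseline) * grad log pi(a), a ~ pi."""
    g = (returns - baseline)[:, None] * scores       # one gradient sample per action
    mean = pi @ g
    total_var = pi @ ((g - mean) ** 2).sum(axis=1)   # trace of the covariance
    return mean, total_var

mean_raw, var_raw = estimator_stats(baseline=0.0)
mean_base, var_base = estimator_stats(baseline=pi @ returns)  # b = average return (15)

print("means equal:", np.allclose(mean_raw, mean_base))
print(f"total variance without baseline: {var_raw:.1f}, with baseline: {var_base:.1f}")
```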

---
## Summary: Key Takeaways

### The Two Approaches

| Aspect | Value-Based (DQN) | Policy-Based |
|--------|-------------------|-------------|
| **Learns** | Q(s, a) values | π(a\|s) directly |
| **Action Selection** | argmax Q(s, a) | Sample from π(a\|s) |
| **Action Space** | Discrete only | Any (discrete or continuous) |
| **Policy Type** | Deterministic | Usually stochastic |
| **Exploration** | ε-greedy (added) | Built-in (sampling) |

### Policy Gradient Theorem

```
∇_θ J(θ) = E[ ∇_θ log π(a|s; θ) × R ]

Intuition:
  High return → Increase action probability
  Low return → Decrease action probability
```

### When to Use

| Scenario | Recommendation |
|----------|---------------|
| Discrete, small action space | DQN |
| Continuous actions | Policy Gradient |
| Strong general-purpose default | Actor-Critic (e.g., PPO) |

---
## Test Your Understanding

**1. Why can't DQN handle continuous action spaces?**
<details>
<summary>Click to reveal answer</summary>
DQN requires computing argmax Q(s, a) to select actions. With continuous actions, there are infinite possible actions to check. You can't evaluate Q(s, a) for every possible real-valued action. Policy gradient methods avoid this by directly outputting the action (or its distribution parameters).
</details>

**2. What's the difference between deterministic and stochastic policies?**
<details>
<summary>Click to reveal answer</summary>
- Deterministic: π(s) = a, always outputs the same action for a given state.
- Stochastic: π(a|s), outputs a probability distribution over actions.

Stochastic policies are better for exploration (naturally vary actions) and can represent mixed strategies (important in competitive games like Rock-Paper-Scissors).
</details>

**3. What does the policy gradient theorem tell us intuitively?**
<details>
<summary>Click to reveal answer</summary>
The policy gradient theorem says:
- If an action led to a HIGH return, increase its probability
- If an action led to a LOW return, decrease its probability

It's like a coach reinforcing good moves and discouraging bad ones!
</details>

**4. Why do we use log probabilities in the policy gradient?**
<details>
<summary>Click to reveal answer</summary>
The log-derivative trick allows us to rewrite:
∇_θ π(a|s; θ) = π(a|s; θ) × ∇_θ log π(a|s; θ)

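This turns the gradient into an expectation under the policy itself, so it can be estimated from sampled actions by multiplying ∇_θ log π by the returns. We never need to differentiate the rewards or the environment's dynamics, which are usually unknown or non-differentiable.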
</details>

**5. When would you choose policy gradient over DQN?**
<details>
<summary>Click to reveal answer</summary>
Choose policy gradient when:
- Action space is continuous (robotics, control)
- Action space is large or high-dimensional
- You want stochastic behavior (exploration, mixed strategies)
- You prefer simpler, end-to-end differentiable optimization
</details>

---
## What's Next?

You now understand the **why** of policy gradients. In the next notebook, we'll implement the simplest policy gradient algorithm: **REINFORCE**!

**Continue to:** [Notebook 2: REINFORCE Algorithm](02_reinforce_algorithm.ipynb)

---

*Policy gradients: "Instead of asking 'how good is this action?', just learn what to do!"*