# DQN Improvements: Supercharging the Original

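The original DQN was revolutionary, but it had weaknesses: it overestimates Q-values, learns state values inefficiently, and explores with crude ε-greedy randomness. This notebook covers the key improvements that address them!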

## What You'll Learn

By the end of this notebook, you'll understand:
- The overestimation problem (with a restaurant analogy!)
- Double DQN: fixing the optimism bias
- Dueling DQN: separating "where" from "what"
- Other improvements: N-step, Noisy Nets
- Rainbow DQN: combining everything
- Implementing all improvements in PyTorch

**Prerequisites:** Notebooks 1-4 (DQN fundamentals)

**Time:** ~30 minutes

---
## The Big Picture: The Restaurant Review Analogy

```
    ┌────────────────────────────────────────────────────────────────┐
    │          WHY DQN NEEDS IMPROVEMENTS: THE RESTAURANT           │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Imagine you're choosing a restaurant based on reviews...     │
    │                                                                │
    │  THE OVERESTIMATION PROBLEM (Fixed by Double DQN):            │
    │    You always pick the restaurant with the HIGHEST rating.    │
    │    But ratings are noisy! That "5-star" place might be:       │
    │    - Actually 4.5 stars + lucky reviews                       │
    │    - Meanwhile a 4.8 star place got unlucky reviews           │
    │    Always picking max picks the luckiest, not the best!       │
    │                                                                │
    │  THE EFFICIENCY PROBLEM (Fixed by Dueling DQN):               │
    │    Some neighborhoods are just BETTER (safer, nicer).         │
    │    Any restaurant there is probably good!                     │
    │    You can learn: "Good neighborhood" + "Good menu"           │
    │    separately, instead of rating each restaurant from scratch.│
    │                                                                │
    │  THE EXPLORATION PROBLEM (Fixed by Noisy Nets):               │
    │    With ε-greedy, you randomly try ANY restaurant.            │
    │    Wouldn't it be smarter to explore where you're still        │
    │    unsure, instead of picking completely at random?            │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Rectangle, Circle, FancyArrowPatch

# Visualize DQN improvements overview
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')

ax.text(7, 9.5, 'Evolution of DQN Improvements', ha='center', fontsize=18, fontweight='bold')

# Original DQN
dqn_box = FancyBboxPatch((0.5, 6), 3, 2.5, boxstyle="round,pad=0.1",
                          facecolor='#bbdefb', edgecolor='#1976d2', linewidth=3)
ax.add_patch(dqn_box)
ax.text(2, 7.8, 'Original DQN', ha='center', fontsize=11, fontweight='bold', color='#1976d2')
ax.text(2, 7.2, '(2015)', ha='center', fontsize=9)
ax.text(2, 6.5, '• Experience Replay\n• Target Network', ha='center', fontsize=9)

# Double DQN
ddqn_box = FancyBboxPatch((4.5, 6), 3, 2.5, boxstyle="round,pad=0.1",
                           facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax.add_patch(ddqn_box)
ax.text(6, 7.8, 'Double DQN', ha='center', fontsize=11, fontweight='bold', color='#388e3c')
ax.text(6, 7.2, '(2015)', ha='center', fontsize=9)
ax.text(6, 6.5, 'Fix: Overestimation', ha='center', fontsize=9)

# Dueling DQN
duel_box = FancyBboxPatch((8.5, 6), 3, 2.5, boxstyle="round,pad=0.1",
                           facecolor='#fff3e0', edgecolor='#f57c00', linewidth=3)
ax.add_patch(duel_box)
ax.text(10, 7.8, 'Dueling DQN', ha='center', fontsize=11, fontweight='bold', color='#f57c00')
ax.text(10, 7.2, '(2016)', ha='center', fontsize=9)
ax.text(10, 6.5, 'Fix: V/A Separation', ha='center', fontsize=9)

# Arrows
ax.annotate('', xy=(4.4, 7.25), xytext=(3.6, 7.25),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax.annotate('', xy=(8.4, 7.25), xytext=(7.6, 7.25),
            arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Other improvements
others = [
    ('Prioritized\nReplay', 0.5, '#e1bee7', '#7b1fa2'),
    ('N-step\nReturns', 3.5, '#b2dfdb', '#00796b'),
    ('Noisy\nNets', 6.5, '#ffe0b2', '#e65100'),
    ('Distributional\nRL', 9.5, '#f8bbd9', '#c2185b'),
]

for name, x, color, edge in others:
    box = FancyBboxPatch((x, 2.5), 2.5, 2, boxstyle="round,pad=0.1",
                          facecolor=color, edgecolor=edge, linewidth=2)
    ax.add_patch(box)
    ax.text(x + 1.25, 3.5, name, ha='center', va='center', fontsize=9, fontweight='bold')

# Rainbow
rainbow_box = FancyBboxPatch((4, 0), 6, 1.8, boxstyle="round,pad=0.1",
                              facecolor='#fff', edgecolor='#333', linewidth=3)
ax.add_patch(rainbow_box)

# Rainbow gradient text
colors = ['#e57373', '#ffb74d', '#fff176', '#81c784', '#64b5f6', '#7986cb', '#9575cd']  # One color per letter of RAINBOW
for i, (c, letter) in enumerate(zip(colors, 'RAINBOW')):
    ax.text(5.5 + i*0.5, 1.1, letter, fontsize=14, fontweight='bold', color=c)
ax.text(7, 0.5, '(2017) - Combines all six extensions!', ha='center', fontsize=10)

# Arrows to Rainbow
for x in [1.75, 4.75, 7.75, 10.75]:
    ax.annotate('', xy=(7, 1.9), xytext=(x, 2.4),
                arrowprops=dict(arrowstyle='->', lw=1, color='#666', connectionstyle='arc3,rad=0'))

plt.tight_layout()
plt.show()

print("\n" + "="*70)
print("DQN IMPROVEMENTS TIMELINE")
print("="*70)
print("""
2015: Original DQN - First deep RL agent to learn Atari games directly from pixels
2015: Double DQN - Reduced overestimation bias
2015: Prioritized Replay - Learn from important experiences
2016: Dueling DQN - Separated value and advantage
2017: Distributional RL (C51) and Noisy Nets
2017: Rainbow - Combined six extensions, outperforming each one alone on Atari
""")
print("="*70)

---
## Problem 1: Overestimation Bias

DQN has a fundamental problem: it **overestimates** Q-values!

```
    ┌────────────────────────────────────────────────────────────────┐
    │              THE OVERESTIMATION PROBLEM                        │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  DQN TARGET: y = r + γ × max_a' Q(s', a')                     │
    │                          └────────────┘                        │
    │                          This is the problem!                  │
    │                                                                │
    │  WHY MAX CAUSES OVERESTIMATION:                               │
    │                                                                │
    │  Q-values have noise (estimation errors):                     │
    │    True Q:  [3.0, 2.8, 2.9, 3.1]   (action 4 is actually best)│
    │    Noisy Q: [3.5, 2.6, 3.2, 2.9]   (with estimation errors)   │
    │                ↑                                              │
    │    max picks: 3.5 (WRONG! This is noise, not skill!)          │
    │                                                                │
    │  RESULT: We consistently overestimate because max always     │
    │  picks the action with the highest NOISE + value!            │
    │                                                                │
    │  E[max_a (Q(a) + ε_a)] ≥ max_a Q(a) for zero-mean noise ε!     │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
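The inequality is easy to check numerically. For two actions with the same true value μ and independent N(0, σ²) errors, the expected maximum has a known closed form, μ + σ/√π. So even though each individual estimate is unbiased, their maximum is biased upward by σ/√π ≈ 0.28 for σ = 0.5. A minimal Monte Carlo sketch (the seed and sample count are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, n_samples = 2.0, 0.5, 200_000

# Two actions with the SAME true value mu; each estimate has zero-mean noise
estimates = mu + sigma * rng.standard_normal((n_samples, 2))

mc_expected_max = estimates.max(axis=1).mean()  # Monte Carlo estimate of E[max]
closed_form = mu + sigma / np.sqrt(np.pi)       # E[max] of 2 iid normals

print(f"true max:           {mu:.3f}")
print(f"Monte Carlo E[max]: {mc_expected_max:.3f}")
print(f"closed form:        {closed_form:.3f}")
```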

In [None]:
# Demonstrate overestimation bias
np.random.seed(42)

# True Q-values for 4 actions
true_q = np.array([2.0, 2.2, 2.1, 2.3])  # Action 4 is truly best (2.3)
n_samples = 1000

# Simulate noisy Q-estimates
noise_std = 0.5
max_estimates = []
selected_actions = []

for _ in range(n_samples):
    noisy_q = true_q + np.random.randn(4) * noise_std
    max_estimates.append(np.max(noisy_q))
    selected_actions.append(np.argmax(noisy_q))

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Left: Distribution of max estimates
ax1 = axes[0]
ax1.hist(max_estimates, bins=50, color='#ef5350', alpha=0.7, edgecolor='black')
ax1.axvline(np.max(true_q), color='#388e3c', linewidth=3, linestyle='--', 
            label=f'True max: {np.max(true_q):.2f}')
ax1.axvline(np.mean(max_estimates), color='#d32f2f', linewidth=3, 
            label=f'Avg estimate: {np.mean(max_estimates):.2f}')
ax1.set_xlabel('Max Q-value Estimate', fontsize=11)
ax1.set_ylabel('Count', fontsize=11)
ax1.set_title('Overestimation Bias\n(max picks highest noise!)', fontsize=12, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Middle: Which action is selected?
ax2 = axes[1]
action_counts = [selected_actions.count(i) for i in range(4)]
colors = ['#90caf9', '#90caf9', '#90caf9', '#81c784']  # Highlight true best
bars = ax2.bar(['A1', 'A2', 'A3', 'A4 (best)'], action_counts, color=colors, edgecolor='black')
ax2.set_ylabel('Times Selected', fontsize=11)
ax2.set_title('Action Selection with Noise\n(Often picks wrong action!)', fontsize=12, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

# Right: Overestimation grows with more actions
ax3 = axes[2]
n_actions_list = [2, 4, 8, 16, 32]
overestimation = []

for n_actions in n_actions_list:
    true_value = 2.5  # Every action has the SAME true value, so the true max is 2.5
    estimates = []
    for _ in range(500):
        noisy = true_value + np.random.randn(n_actions) * noise_std
        estimates.append(np.max(noisy))
    overestimation.append(np.mean(estimates) - true_value)

ax3.plot(n_actions_list, overestimation, 'ro-', linewidth=2, markersize=10)
ax3.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax3.set_xlabel('Number of Actions', fontsize=11)
ax3.set_ylabel('Overestimation', fontsize=11)
ax3.set_title('More Actions = More Overestimation', fontsize=12, fontweight='bold')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nOVERESTIMATION ANALYSIS:")
print(f"  True max Q-value: {np.max(true_q):.2f}")
print(f"  Average max estimate: {np.mean(max_estimates):.2f}")
print(f"  Overestimation: +{np.mean(max_estimates) - np.max(true_q):.2f}")
print(f"\n  Times correct action selected: {selected_actions.count(3)/n_samples*100:.1f}%")
print("  (Without noise this would be 100%: noise often makes a worse action look best!)")

---
## Solution 1: Double DQN

The elegant fix: **Decouple selection from evaluation!**

```
    ┌────────────────────────────────────────────────────────────────┐
    │                    DOUBLE DQN                                  │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  ORIGINAL DQN:                                                │
    │    target = r + γ × max_a' Q(s', a'; θ⁻)                      │
    │                    └─────────────────┘                         │
    │               Same network selects AND evaluates              │
    │                                                                │
    │  DOUBLE DQN:                                                  │
    │    Step 1: a* = argmax_a' Q(s', a'; θ)   ← Q-net SELECTS     │
    │    Step 2: target = r + γ × Q(s', a*; θ⁻) ← Target EVALUATES │
    │                                                                │
    │  WHY THIS WORKS:                                              │
    │    • Even if the Q-net's noise picks the wrong action...       │
    │    • the target net re-scores it with different errors,       │
    │    • so the selection noise is not also used as the value!    │
    │    • (θ⁻ is a lagged copy of θ, so errors only partly differ.) │
    │                                                                │
    │  ANALOGY:                                                     │
    │    • Friend 1 (Q-net): "Let's go to Restaurant A!"           │
    │    • Friend 2 (Target): "Hmm, Restaurant A is actually 3.5⭐"│
    │    • Two opinions reduce the bias!                           │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
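The "different noise" argument can be checked with a toy simulation. This sketch idealizes the two networks as *fully independent* noisy estimates of the same Q-values, which is an assumption: in Double DQN the target network is a lagged copy of the online network, so its errors are only partly independent. All actions share the same true value, so any nonzero average is pure bias.

```python
import numpy as np

rng = np.random.default_rng(1)
n_actions, sigma, n_samples = 4, 0.5, 100_000
true_q = np.zeros(n_actions)  # All actions equally good, so the true max is 0

# Two independent noisy estimates (idealized stand-ins for the online and target nets)
q_online = true_q + sigma * rng.standard_normal((n_samples, n_actions))
q_target = true_q + sigma * rng.standard_normal((n_samples, n_actions))

# Standard DQN: the same estimate selects AND evaluates
single_estimate = q_target.max(axis=1)

# Double DQN: the online estimate selects, the target estimate evaluates
a_star = q_online.argmax(axis=1)
double_estimate = q_target[np.arange(n_samples), a_star]

single_bias = single_estimate.mean() - true_q.max()
double_bias = double_estimate.mean() - true_q.max()
print(f"single-estimator bias: {single_bias:+.3f}")
print(f"double-estimator bias: {double_bias:+.3f}")
```

With independent errors the double estimator removes the upward bias in this setting; when the true values differ, it can err slightly low instead of high.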

In [None]:
# Implement Double DQN vs Standard DQN

class QNetwork(nn.Module):
    """Simple Q-Network."""
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
    
    def forward(self, x):
        return self.network(x)


def compute_dqn_target(reward, next_state, done, target_net, gamma):
    """
    Standard DQN target.
    
    Uses same network for both selection AND evaluation.
    target = r + γ × max_a' Q(s', a'; θ⁻)
    """
    with torch.no_grad():
        # Target network does both: select AND evaluate
        next_q = target_net(next_state)
        max_next_q = next_q.max(dim=1)[0]
        target = reward + gamma * max_next_q * (1 - done)
    return target


def compute_double_dqn_target(reward, next_state, done, q_net, target_net, gamma):
    """
    Double DQN target - reduces overestimation!
    
    Step 1: Q-network SELECTS the best action
    Step 2: Target network EVALUATES that action
    
    Different networks have different noise → less bias!
    """
    with torch.no_grad():
        # ========================================
        # STEP 1: Q-network SELECTS best action
        # ========================================
        q_values = q_net(next_state)
        best_actions = q_values.argmax(dim=1)  # What Q-net thinks is best
        
        # ========================================
        # STEP 2: Target network EVALUATES
        # ========================================
        next_q = target_net(next_state)
        # Get Q-value for the actions Q-net selected
        next_q_values = next_q.gather(1, best_actions.unsqueeze(1)).squeeze(1)  # squeeze(1) keeps shape (batch,) even when batch=1
        
        target = reward + gamma * next_q_values * (1 - done)
    return target


# Demonstrate the difference
print("DOUBLE DQN vs STANDARD DQN")
print("="*60)

state_dim, action_dim = 4, 4
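q_net = QNetwork(state_dim, action_dim)
# During training, target_net is a periodically synced copy of q_net; here it is
# initialized independently so the two networks disagree and the targets differ
target_net = QNetwork(state_dim, action_dim)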

# Sample data
batch_size = 5
rewards = torch.zeros(batch_size)
next_states = torch.randn(batch_size, state_dim)
dones = torch.zeros(batch_size)
gamma = 0.99

dqn_targets = compute_dqn_target(rewards, next_states, dones, target_net, gamma)
ddqn_targets = compute_double_dqn_target(rewards, next_states, dones, q_net, target_net, gamma)

print("\nTarget values for same batch:")
print(f"  DQN targets:    {dqn_targets.numpy().round(3)}")
print(f"  Double DQN:     {ddqn_targets.numpy().round(3)}")
print(f"\n  Difference:     {(dqn_targets - ddqn_targets).numpy().round(3)}")
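print("\nFor the same target network, Double DQN's target can never exceed DQN's:")
print("  Q(s', a*; θ⁻) ≤ max_a' Q(s', a'; θ⁻). In training, this counters overestimation.")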
print("="*60)
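Only the target computation differs between DQN and Double DQN; the rest of the update is unchanged. A minimal sketch of one Double DQN gradient step, assuming a Huber (smooth L1) loss, Adam, and a random stand-in minibatch. The network size, learning rate and batch are illustrative choices, not values from this notebook or the paper:

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
state_dim, action_dim, batch_size, gamma = 4, 2, 32, 0.99

# Illustrative online network; the target network starts as an exact copy
q_net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, action_dim))
target_net = copy.deepcopy(q_net)
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)

# Random stand-in for a minibatch sampled from a replay buffer
states = torch.randn(batch_size, state_dim)
actions = torch.randint(0, action_dim, (batch_size,))
rewards = torch.randn(batch_size)
next_states = torch.randn(batch_size, state_dim)
dones = torch.randint(0, 2, (batch_size,)).float()

# Q(s, a) for the actions actually taken
q_taken = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

with torch.no_grad():
    best_next = q_net(next_states).argmax(dim=1, keepdim=True)        # online net SELECTS
    next_q = target_net(next_states).gather(1, best_next).squeeze(1)  # target net EVALUATES
    targets = rewards + gamma * next_q * (1 - dones)

loss = F.smooth_l1_loss(q_taken, targets)  # Huber loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")

# Every so often (e.g. every few thousand steps), sync the target network
target_net.load_state_dict(q_net.state_dict())
```

Replacing the two lines inside `torch.no_grad()` with `next_q = target_net(next_states).max(dim=1).values` would turn this back into standard DQN.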

In [None]:
# Visualize Double DQN architecture

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Standard DQN
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('Standard DQN\n(Overestimates)', fontsize=14, fontweight='bold', color='#d32f2f')

# Target network does everything
target_box = FancyBboxPatch((3, 5), 4, 2.5, boxstyle="round,pad=0.1",
                             facecolor='#ffcdd2', edgecolor='#d32f2f', linewidth=3)
ax1.add_patch(target_box)
ax1.text(5, 6.8, 'Target Network (θ⁻)', ha='center', fontsize=11, fontweight='bold')
ax1.text(5, 6.0, 'SELECT: argmax Q(s\',a\';θ⁻)', ha='center', fontsize=9)
ax1.text(5, 5.4, 'EVALUATE: Q(s\',a*;θ⁻)', ha='center', fontsize=9)

# Input/Output arrows
ax1.annotate('', xy=(5, 4.9), xytext=(5, 3.5),
             arrowprops=dict(arrowstyle='->', lw=2, color='#d32f2f'))
ax1.text(5, 3, 'Target: r + γ×max Q', ha='center', fontsize=10, color='#d32f2f')

ax1.annotate('', xy=(5, 7.6), xytext=(5, 8.5),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax1.text(5, 9, "Next state s'", ha='center', fontsize=10)

ax1.text(5, 1.5, '✗ Same noise for selection\n    and evaluation!', 
         ha='center', fontsize=10, color='#d32f2f')

# Right: Double DQN
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Double DQN\n(Less Overestimation)', fontsize=14, fontweight='bold', color='#388e3c')

# Q-network for selection
q_box = FancyBboxPatch((0.5, 5), 4, 2.5, boxstyle="round,pad=0.1",
                        facecolor='#bbdefb', edgecolor='#1976d2', linewidth=3)
ax2.add_patch(q_box)
ax2.text(2.5, 6.8, 'Q-Network (θ)', ha='center', fontsize=11, fontweight='bold')
ax2.text(2.5, 6.0, 'SELECT:', ha='center', fontsize=9)
ax2.text(2.5, 5.4, 'a* = argmax Q(s\',a\';θ)', ha='center', fontsize=9)

# Target network for evaluation
target_box2 = FancyBboxPatch((5.5, 5), 4, 2.5, boxstyle="round,pad=0.1",
                              facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=3)
ax2.add_patch(target_box2)
ax2.text(7.5, 6.8, 'Target Network (θ⁻)', ha='center', fontsize=11, fontweight='bold')
ax2.text(7.5, 6.0, 'EVALUATE:', ha='center', fontsize=9)
ax2.text(7.5, 5.4, 'Q(s\', a*; θ⁻)', ha='center', fontsize=9)

# Arrow between networks
ax2.annotate('', xy=(5.4, 6.25), xytext=(4.6, 6.25),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax2.text(5, 6.8, 'a*', ha='center', fontsize=10, fontweight='bold')

# Output
ax2.annotate('', xy=(5, 4.9), xytext=(5, 3.5),
             arrowprops=dict(arrowstyle='->', lw=2, color='#388e3c'))
ax2.text(5, 3, 'Target: r + γ×Q(s\',a*)', ha='center', fontsize=10, color='#388e3c')

ax2.text(5, 1.5, '✓ Different noise for\n    selection vs evaluation!', 
         ha='center', fontsize=10, color='#388e3c')

plt.tight_layout()
plt.show()

---
## Problem 2: Learning Efficiency

Sometimes the state matters more than the action!

```
    ┌────────────────────────────────────────────────────────────────┐
    │              THE EFFICIENCY PROBLEM                            │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Consider playing Pong near the end of a point:               │
    │                                                                │
    │  SITUATION 1: Ball going into opponent's corner               │
    │    → You're going to WIN no matter what action you take       │
    │    → The STATE is good, actions don't matter much             │
    │                                                                │
    │  SITUATION 2: Ball coming at you very fast                    │
    │    → Your ACTION matters a lot! Move or lose!                 │
    │    → Need to distinguish between good and bad actions         │
    │                                                                │
    │  INSIGHT: Sometimes V(s) is enough, sometimes A(s,a) matters! │
    │                                                                │
    │  SOLUTION: Learn BOTH separately, then combine:               │
    │    Q(s, a) = V(s) + A(s, a)                                   │
    │             ↑        ↑                                        │
    │     "How good is    "How much better                         │
    │      this state?"    is this action?"                        │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
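To put numbers on the two Pong situations, here is a small sketch with made-up Q-values for three actions (up, stay, down). It assumes V(s) is the mean of Q(s, ·), i.e. a uniform policy, purely to keep the arithmetic simple:

```python
import numpy as np

# Made-up Q-values for three actions (up, stay, down) in each situation
q_winning = np.array([5.0, 5.1, 4.9])  # Situation 1: ball heading into the opponent's corner
q_urgent = np.array([1.0, 5.0, -3.0])  # Situation 2: ball coming at you fast

results = {}
for name, q in [("winning", q_winning), ("urgent", q_urgent)]:
    v = q.mean()  # V(s) under a uniform policy
    a = q - v     # A(s, a) = Q(s, a) - V(s)
    results[name] = (v, a)
    print(f"{name:8s} V = {v:+.2f}  A = {np.round(a, 2)}")
```

The state value carries almost all the information in the first case, while the advantages carry it in the second.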

---
## Solution 2: Dueling DQN

```
    ┌────────────────────────────────────────────────────────────────┐
    │                    DUELING ARCHITECTURE                        │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │                        ┌───────────┐                          │
    │                        │   State   │                          │
    │                        │   Input   │                          │
    │                        └─────┬─────┘                          │
    │                              │                                │
    │                        ┌─────▼─────┐                          │
    │                        │  Shared   │                          │
    │                        │  Layers   │                          │
    │                        └─────┬─────┘                          │
    │                      ┌───────┴───────┐                        │
    │                      │               │                        │
    │                ┌─────▼─────┐   ┌─────▼─────┐                  │
    │                │   Value   │   │ Advantage │                  │
    │                │  Stream   │   │  Stream   │                  │
    │                │   V(s)    │   │   A(s,a)  │                  │
    │                └─────┬─────┘   └─────┬─────┘                  │
    │                      │               │                        │
    │                      └───────┬───────┘                        │
    │                              │                                │
    │                   Q(s,a) = V(s) + A(s,a) - mean(A)            │
    │                                                                │
    │  WHY SUBTRACT MEAN(A)?                                        │
    │    To make V(s) and A(s,a) uniquely identifiable!              │
    │    Without it, V + c and A - c give the same Q for any c,      │
    │    so nothing pins down how Q splits into V and A.             │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
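The identifiability problem takes one line of arithmetic to see: with a plain sum, moving any constant c from A into V leaves Q unchanged, so training gives no signal about which split is correct. Subtracting mean(A) forces the advantages to average zero, which pins V(s) to the mean of Q(s, ·). A NumPy check with made-up numbers:

```python
import numpy as np

v = 2.0
a = np.array([0.5, -0.2, 1.1])
c = 3.0  # Arbitrary constant moved from A into V

# Naive combination Q = V + A: two different (V, A) splits give identical Q
q_naive_1 = v + a
q_naive_2 = (v + c) + (a - c)
print(np.allclose(q_naive_1, q_naive_2))  # True: V and A are not identifiable

def combine(v, a):
    """Mean-subtracted combination: Q = V + A - mean(A)."""
    return v + a - a.mean()

q = combine(v, a)
print(np.isclose(q.mean(), v))  # True: V is recovered as the mean of Q(s, .)
print(np.allclose(combine(v, a), combine(v + c, a - c)))  # False: the split now changes Q
```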

In [None]:
class DuelingDQN(nn.Module):
    """
    Dueling DQN Architecture.
    
    Separates the Q-value into:
    - V(s): How good is this state?
    - A(s,a): How much better is this action than average?
    
    Q(s, a) = V(s) + A(s, a) - mean(A(s, .))
    
    Benefits:
    - V(s) is updated on every step, whichever action was taken
    - Often learns faster when many actions have similar values
    - Can evaluate states well even when actions don't matter
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        
        # ========================================
        # SHARED FEATURE EXTRACTION
        # ========================================
        self.features = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # ========================================
        # VALUE STREAM: V(s) - single number
        # "How good is this state overall?"
        # ========================================
        self.value_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)  # Single value output
        )
        
        # ========================================
        # ADVANTAGE STREAM: A(s, a) - one per action
        # "How much better is each action than average?"
        # ========================================
        self.advantage_stream = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, action_dim)  # One per action
        )
    
    def forward(self, state):
        """
        Forward pass: combine value and advantage streams.
        
        Q(s, a) = V(s) + A(s, a) - mean(A(s, .))
        
        Subtracting mean(A) ensures V(s) is unique.
        """
        # Extract shared features
        features = self.features(state)
        
        # Compute value and advantage
        value = self.value_stream(features)           # Shape: (batch, 1)
        advantage = self.advantage_stream(features)   # Shape: (batch, actions)
        
        # ========================================
        # COMBINE: Q = V + A - mean(A)
        # ========================================
        # Subtract mean advantage for identifiability
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        
        return q_values
    
    def get_value_and_advantage(self, state):
        """Helper to see V and A separately (for visualization)."""
        features = self.features(state)
        value = self.value_stream(features)
        advantage = self.advantage_stream(features)
        return value, advantage


# Create and examine Dueling DQN
print("DUELING DQN DEMONSTRATION")
print("="*60)

state_dim = 4
action_dim = 4
dueling = DuelingDQN(state_dim, action_dim)

# Test on sample state
test_state = torch.randn(1, state_dim)

with torch.no_grad():
    q_values = dueling(test_state)
    value, advantage = dueling.get_value_and_advantage(test_state)

print(f"\nFor a sample state:")
print(f"  V(s) = {value.item():.3f} (state value)")
print(f"  A(s, a) = {advantage.numpy()[0].round(3)} (advantage per action)")
print(f"  Q(s, a) = {q_values.numpy()[0].round(3)}")

print(f"\nVerify: Q = V + A - mean(A)")
mean_adv = advantage.mean().item()
reconstructed = value.item() + advantage.numpy()[0] - mean_adv
print(f"  Computed: {reconstructed.round(3)}")
print(f"  Matches Q: {np.allclose(reconstructed, q_values.numpy()[0], atol=1e-5)}")
print("="*60)
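The dueling paper also describes a variant that subtracts the *maximum* advantage instead of the mean. With that choice the greedy action's Q-value equals V(s) exactly; the mean version used above is the one the paper reports as more stable to optimize. A minimal sketch of the max variant (`combine_max` is a hypothetical helper, not part of the class above):

```python
import torch

def combine_max(value: torch.Tensor, advantage: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper for the max variant: Q = V + A - max_a A.

    value: shape (batch, 1); advantage: shape (batch, n_actions)
    """
    return value + advantage - advantage.max(dim=1, keepdim=True).values

torch.manual_seed(0)
value = torch.randn(3, 1)
advantage = torch.randn(3, 4)
q = combine_max(value, advantage)

# Under the max variant, the best action's Q-value equals V(s)
print(torch.allclose(q.max(dim=1).values, value.squeeze(1)))  # True
```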

In [None]:
# Visualize the Dueling architecture

fig, axes = plt.subplots(1, 2, figsize=(14, 7))

# Left: Standard DQN architecture
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('Standard DQN Architecture', fontsize=14, fontweight='bold')

# Input
input_box = FancyBboxPatch((3.5, 8), 3, 1, boxstyle="round,pad=0.1",
                            facecolor='#bbdefb', edgecolor='#1976d2', linewidth=2)
ax1.add_patch(input_box)
ax1.text(5, 8.5, 'State Input', ha='center', fontsize=10)

# Hidden layers
for i, y in enumerate([6, 4]):
    box = FancyBboxPatch((3.5, y), 3, 1, boxstyle="round,pad=0.1",
                          facecolor='#fff3e0', edgecolor='#f57c00', linewidth=2)
    ax1.add_patch(box)
    ax1.text(5, y + 0.5, f'Hidden Layer {i+1}', ha='center', fontsize=10)
    ax1.annotate('', xy=(5, y + 1), xytext=(5, y + 1.9),
                 arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Output
output_box = FancyBboxPatch((3.5, 1.5), 3, 1.5, boxstyle="round,pad=0.1",
                             facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax1.add_patch(output_box)
ax1.text(5, 2.5, 'Q(s,a₁), Q(s,a₂), ...', ha='center', fontsize=10)
ax1.text(5, 2, '(One output per action)', ha='center', fontsize=9, color='#666')
ax1.annotate('', xy=(5, 3.1), xytext=(5, 3.9),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Right: Dueling DQN architecture
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Dueling DQN Architecture', fontsize=14, fontweight='bold')

# Input
input_box = FancyBboxPatch((3.5, 8), 3, 1, boxstyle="round,pad=0.1",
                            facecolor='#bbdefb', edgecolor='#1976d2', linewidth=2)
ax2.add_patch(input_box)
ax2.text(5, 8.5, 'State Input', ha='center', fontsize=10)

# Shared layers
shared_box = FancyBboxPatch((3.5, 6), 3, 1, boxstyle="round,pad=0.1",
                             facecolor='#fff3e0', edgecolor='#f57c00', linewidth=2)
ax2.add_patch(shared_box)
ax2.text(5, 6.5, 'Shared Features', ha='center', fontsize=10)
ax2.annotate('', xy=(5, 7), xytext=(5, 7.9),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Value stream
value_box = FancyBboxPatch((1, 4), 3, 1.2, boxstyle="round,pad=0.1",
                            facecolor='#e1bee7', edgecolor='#7b1fa2', linewidth=2)
ax2.add_patch(value_box)
ax2.text(2.5, 4.9, 'Value Stream', ha='center', fontsize=10, fontweight='bold')
ax2.text(2.5, 4.3, 'V(s)', ha='center', fontsize=10)

# Advantage stream
adv_box = FancyBboxPatch((6, 4), 3, 1.2, boxstyle="round,pad=0.1",
                          facecolor='#b2dfdb', edgecolor='#00796b', linewidth=2)
ax2.add_patch(adv_box)
ax2.text(7.5, 4.9, 'Advantage Stream', ha='center', fontsize=10, fontweight='bold')
ax2.text(7.5, 4.3, 'A(s,a)', ha='center', fontsize=10)

# Arrows from shared to streams
ax2.annotate('', xy=(2.5, 5.3), xytext=(4, 6),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
ax2.annotate('', xy=(7.5, 5.3), xytext=(6, 6),
             arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

# Combination
combine_box = FancyBboxPatch((3.5, 1.5), 3, 1.5, boxstyle="round,pad=0.1",
                              facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax2.add_patch(combine_box)
ax2.text(5, 2.6, 'Q = V + A - mean(A)', ha='center', fontsize=10, fontweight='bold')
ax2.text(5, 2, 'Q(s,a₁), Q(s,a₂), ...', ha='center', fontsize=9)

# Arrows to combination
ax2.annotate('', xy=(4, 3.1), xytext=(2.5, 3.9),
             arrowprops=dict(arrowstyle='->', lw=2, color='#7b1fa2'))
ax2.annotate('', xy=(6, 3.1), xytext=(7.5, 3.9),
             arrowprops=dict(arrowstyle='->', lw=2, color='#00796b'))

plt.tight_layout()
plt.show()

print("\nDUELING ADVANTAGE:")
print("  • V(s) gets a gradient on every update, whichever action was taken")
print("  • Helps most in states where the choice of action barely matters")
print("  • Tends to evaluate states better when many actions have similar values")

---
## Other Important Improvements

```
    ┌────────────────────────────────────────────────────────────────┐
    │              OTHER DQN IMPROVEMENTS                            │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  N-STEP RETURNS:                                              │
    │    Instead of: y = r + γ × max_a' Q(s', a')                    │
    │    Use: y = r₀ + γr₁ + ... + γⁿ⁻¹rₙ₋₁ + γⁿ max_a' Q(sₙ, a')    │
    │    → Faster reward propagation; less bias, more variance       │
    │                                                                │
    │  NOISY NETWORKS:                                              │
    │    Replace ε-greedy with learned noise in weights             │
    │    w = μ + σ × ε  (μ, σ learned; ε ~ N(0,1))                   │
    │    → Smarter exploration, state-dependent uncertainty         │
    │                                                                │
    │  DISTRIBUTIONAL RL (C51):                                     │
    │    Don't predict E[return], predict the full distribution!    │
    │    → Captures uncertainty, more stable learning               │
    │                                                                │
    │  PRIORITIZED REPLAY (covered in notebook 3):                  │
    │    Sample important experiences more often                    │
    │    → More efficient use of experience                         │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
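To make the distributional idea concrete: C51 represents the return of each action as probabilities over 51 fixed "atoms" spaced evenly between V_min and V_max, and the Q-value used for greedy action selection is simply the mean of that distribution. This sketch assumes the Atari setting V_min = -10, V_max = 10 and uses random logits in place of a trained network; the distributional Bellman update (projecting the shifted distribution back onto the atoms) is omitted:

```python
import torch

torch.manual_seed(0)
n_atoms, v_min, v_max = 51, -10.0, 10.0
n_actions, batch = 4, 2

# Fixed support of possible returns z_1 ... z_51
atoms = torch.linspace(v_min, v_max, n_atoms)

# Stand-in for a network head: one logit per (action, atom)
logits = torch.randn(batch, n_actions, n_atoms)
probs = torch.softmax(logits, dim=-1)  # A distribution over atoms for each action

# Q(s, a) = sum_i p_i(s, a) * z_i, i.e. the mean of each return distribution
q_values = (probs * atoms).sum(dim=-1)  # Shape: (batch, n_actions)
greedy_actions = q_values.argmax(dim=1)
print(q_values)
print(greedy_actions)
```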

In [None]:
# Implement N-step returns

def compute_nstep_return(rewards, next_q_value, gamma, n_steps):
    """
    Compute n-step return.
    
    G_t = r_t + γr_{t+1} + ... + γ^{n-1}r_{t+n-1} + γⁿ max_a Q(s_{t+n}, a)
    
    Benefits:
    - Less bias than 1-step (uses more actual rewards)
    - Less variance than Monte Carlo (still bootstraps)
    
    Args:
        rewards: List of n rewards [r_t, r_{t+1}, ..., r_{t+n-1}]
        next_q_value: Bootstrap value max_a Q(s_{t+n}, a) (use 0 if s_{t+n} is terminal)
        gamma: Discount factor
        n_steps: Number of steps to look ahead
    """
    n_step_return = 0
    
    # Sum discounted rewards
    for i, r in enumerate(rewards):
        n_step_return += (gamma ** i) * r
    
    # Add bootstrapped value
    n_step_return += (gamma ** n_steps) * next_q_value
    
    return n_step_return


# Demonstrate n-step returns
print("N-STEP RETURNS DEMONSTRATION")
print("="*60)

# Example: 3 steps of experience
rewards = [1, 1, 1]  # Got reward 1 for 3 consecutive steps
next_q = 5.0         # Simplification: assume Q(s_{t+n}) = 5.0 for every n
gamma = 0.9

print(f"\nRewards: {rewards}")
print(f"Q(s_n): {next_q}  (assumed the same for n = 1, 2, 3)")
print(f"Gamma: {gamma}")

print("\nDifferent n-step returns:")
for n in [1, 2, 3]:
    ret = compute_nstep_return(rewards[:n], next_q, gamma, n)
    formula = ' + '.join([f'{gamma}^{i}×{r}' for i, r in enumerate(rewards[:n])])
    formula += f' + {gamma}^{n}×{next_q}'
    print(f"  n={n}: {ret:.3f}  ({formula})")

print("\n" + "="*60)
print("More steps = less bias (more real rewards) but more variance")
print("Rainbow uses n=3 as a good balance!")
print("="*60)

In [None]:
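import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinear(nn.Module):
    """
    Noisy Linear Layer for exploration (independent Gaussian noise).
    
    Instead of ε-greedy (random action with probability ε),
    we add learned, parameterised noise to the weights.
    
    w = μ_w + σ_w * ε  (element-wise), where ε ~ N(0, 1)
    
    Benefits:
    - The perturbation of the Q-values depends on the input state
      (the noise itself is still random)
    - σ is trained, so the network learns how much noise to keep
    - No need to tune an ε schedule
    """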
    
    def __init__(self, in_features, out_features, sigma_init=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Learnable parameters
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        
        # Initialize
        mu_range = 1 / np.sqrt(in_features)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(sigma_init / np.sqrt(in_features))
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(sigma_init / np.sqrt(out_features))
        
        # Register buffer for noise (not a parameter)
        self.register_buffer('weight_epsilon', torch.zeros(out_features, in_features))
        self.register_buffer('bias_epsilon', torch.zeros(out_features))
        
        self.reset_noise()
    
    def reset_noise(self):
        """Sample new noise."""
        self.weight_epsilon.normal_()
        self.bias_epsilon.normal_()
    
    def forward(self, x):
        """Apply noisy linear transformation."""
        if self.training:
            # Use noisy weights during training
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            # Use mean weights during evaluation (greedy)
            weight = self.weight_mu
            bias = self.bias_mu
        
        return F.linear(x, weight, bias)


# Demonstrate Noisy Networks
print("NOISY NETWORKS DEMONSTRATION")
print("="*60)

noisy_layer = NoisyLinear(4, 4)
test_input = torch.randn(1, 4)

print("\nSame input, different noise:")
noisy_layer.train()  # Training mode = use noise
for i in range(3):
    noisy_layer.reset_noise()  # New noise
    output = noisy_layer(test_input)
    print(f"  Sample {i+1}: {output.detach().numpy()[0].round(3)}")

print("\nDuring evaluation (no noise):")
noisy_layer.eval()  # Eval mode = no noise
output_eval = noisy_layer(test_input)
print(f"  Consistent: {output_eval.detach().numpy()[0].round(3)}")

print("\n" + "="*60)
print("NOISY NETS vs ε-GREEDY:")
print("  ε-greedy: Random action with probability ε (state-blind)")
print("  Noisy nets: Uncertainty is STATE-DEPENDENT (smarter!)")
print("="*60)

---
## Rainbow DQN: All Together

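Rainbow (Hessel et al., 2017) combines six of the extensions above into a single agent, and the combination outperformed every individual extension on the Atari benchmark.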

```
    ┌────────────────────────────────────────────────────────────────┐
    │                    RAINBOW DQN (2017)                          │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Rainbow = DQN + six complementary extensions:                │
    │                                                                │
    │    1. ✓ Double DQN - Reduce overestimation                    │
    │    2. ✓ Dueling architecture - V/A separation                 │
    │    3. ✓ Prioritized Experience Replay - Sample efficiently    │
    │    4. ✓ N-step returns - Better credit assignment             │
    │    5. ✓ Distributional RL (C51) - Learn return distribution   │
    │    6. ✓ Noisy Networks - Smarter exploration                  │
    │                                                                │
    │  PERFORMANCE (median human-normalized, 57 Atari games):       │
    │    • Outperformed DQN and each single-extension agent         │
    │    • State-of-the-art on Atari in 2017                        │
    │                                                                │
    │  FOR PRODUCTION: prefer a tested library (e.g. CleanRL or     │
    │    Stable-Baselines3) and check which extensions it has.      │
    │    Implement from scratch mainly for learning.                │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
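Several of Rainbow's pieces meet in the target computation. As a simplified sketch, here is a scalar target that combines Double DQN action selection with an n-step return. It leaves out the distributional part (full Rainbow bootstraps a whole return distribution, not a scalar Q-value) and assumes the online and target Q-values for the state n steps ahead are already available as NumPy arrays; the function name is hypothetical.

```python
import numpy as np

def double_nstep_target(n_step_return, q_online_next, q_target_next,
                        gamma, n_steps, done):
    """Scalar n-step Double DQN target for a batch (illustrative sketch).

    n_step_return: (batch,) discounted sum of the n real rewards
    q_online_next: (batch, actions) online-network Q(s_{t+n}, .)
    q_target_next: (batch, actions) target-network Q(s_{t+n}, .)
    done:          (batch,) 1.0 if the episode ended within the n steps
    """
    # Online network SELECTS the action ...
    best_actions = np.argmax(q_online_next, axis=1)
    # ... target network EVALUATES it
    bootstrap = q_target_next[np.arange(len(best_actions)), best_actions]
    # No bootstrap term for transitions that hit a terminal state
    return n_step_return + (gamma ** n_steps) * (1.0 - done) * bootstrap
```

The selection and evaluation lines are exactly the Double DQN decoupling; n-step learning only changes the reward sum and the discount exponent.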

In [None]:
# Visualize Rainbow components (ILLUSTRATIVE numbers, not measured results)
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 7))

components = [
    'DQN (baseline)',
    '+ Double DQN',
    '+ Prioritized Replay',
    '+ Dueling',
    '+ N-step',
    '+ Distributional',
    '+ Noisy Nets',
    'Rainbow (all)'
]

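# Made-up values for illustration only: they are NOT results from the Rainbow
# paper, which reports ablations (removing one component at a time) instead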
scores = [100, 115, 135, 150, 170, 200, 210, 250]

colors = ['#bbdefb', '#c8e6c9', '#e1bee7', '#fff3e0', '#b2dfdb', '#f8bbd9', '#ffe0b2', '#ff8a65']

bars = ax.barh(components, scores, color=colors, edgecolor='black', linewidth=2)

# Add score labels
for bar, score in zip(bars, scores):
    ax.text(bar.get_width() + 5, bar.get_y() + bar.get_height()/2, 
            f'{score}%', va='center', fontsize=11, fontweight='bold')

ax.set_xlabel('Relative Performance, illustrative (DQN = 100%)', fontsize=12)
ax.set_title('Rainbow DQN Components (illustrative values, not measured)', fontsize=14, fontweight='bold')
ax.set_xlim(0, 280)
ax.grid(True, alpha=0.3, axis='x')

# Highlight Rainbow
bars[-1].set_edgecolor('#d32f2f')
bars[-1].set_linewidth(3)

plt.tight_layout()
plt.show()

print("\nKEY INSIGHT:")
print("  • Each improvement helps on its own")
print("  • Combined, they're even better (synergy!)")
print("  • Rainbow = 2.5x baseline DQN performance")

---
## Summary: Key Takeaways

### The Problems and Solutions

| Problem | Solution | Key Idea |
|---------|----------|----------|
| Overestimation | Double DQN | Separate selection and evaluation |
| Efficiency | Dueling DQN | Learn V(s) and A(s,a) separately |
| Credit assignment | N-step | Use more real rewards |
| Exploration | Noisy Nets | Learned, state-dependent noise |
| Sample efficiency | PER | Prioritize surprising experiences |

### Double DQN Formula

```
a*     = argmax_a' Q(s', a'; θ)   ← online Q-network SELECTS
target = r + γ × Q(s', a*; θ⁻)    ← target network EVALUATES
```
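The formula can be checked with a quick simulation. The sketch below assumes every action is truly worth 0 and that the online and target estimates carry independent zero-mean Gaussian noise; that independence is an idealisation, since a real target network is a lagged copy of the online one.

```python
import numpy as np

rng = np.random.default_rng(0)
num_actions, num_trials, noise_std = 10, 100_000, 1.0
true_q = np.zeros(num_actions)  # every action is truly worth 0

# Two estimates with independent zero-mean noise (idealised assumption)
online = true_q + rng.normal(0.0, noise_std, size=(num_trials, num_actions))
target = true_q + rng.normal(0.0, noise_std, size=(num_trials, num_actions))

# Standard DQN: the target network both selects and evaluates (max)
dqn_estimate = target.max(axis=1).mean()

# Double DQN: online network selects, target network evaluates
a_star = online.argmax(axis=1)
double_estimate = target[np.arange(num_trials), a_star].mean()

print(f"True max value:      {true_q.max():.3f}")
print(f"DQN max estimate:    {dqn_estimate:.3f}")     # clearly above 0
print(f"Double DQN estimate: {double_estimate:.3f}")  # close to 0
```

With fully independent errors the Double DQN estimate is unbiased here; with correlated errors, as in practice, the bias shrinks but does not vanish.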

### Dueling DQN Formula

```
Q(s, a) = V(s) + A(s, a) - mean(A(s, .))
```
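A small NumPy sketch of this aggregation (the function name is illustrative) shows both the formula and why the mean is subtracted: the naive sum V + A cannot tell a shift of V from an opposite shift of A, while the mean-subtracted Q averages back to V.

```python
import numpy as np

def dueling_q(value, advantages):
    """Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').

    value:      (batch, 1) state values V(s)
    advantages: (batch, num_actions) advantages A(s, .)
    """
    return value + advantages - advantages.mean(axis=1, keepdims=True)

v = np.array([[2.0]])             # V(s) for one state
a = np.array([[1.0, -1.0, 3.0]])  # A(s, .) for three actions (mean = 1)
q = dueling_q(v, a)               # -> [[2., 0., 4.]]

# Naive sum V + A: shifting V up by c and A down by c leaves Q unchanged,
# so V cannot be recovered from Q alone.
c = 5.0
naive_ambiguous = np.allclose((v + c) + (a - c), v + a)

# Mean-subtracted version: averaging Q over actions gives back V exactly.
v_recovered = q.mean(axis=1, keepdims=True)
print(q, naive_ambiguous, v_recovered)
```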

---
## Test Your Understanding

**1. Why does DQN overestimate Q-values?**
<details>
<summary>Click to reveal answer</summary>
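The max operator always picks the action with the highest estimated Q-value, but the estimates contain noise. Taking the max of noisy estimates is biased upward: we are likely picking the action that got lucky with positive noise, not necessarily the truly best action. For zero-mean noise, E[max_a (Q(s,a) + noise_a)] ≥ max_a Q(s,a) (Jensen's inequality, since max is convex), and the gap typically grows with the noise level and the number of actions.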
</details>

**2. How does Double DQN fix overestimation?**
<details>
<summary>Click to reveal answer</summary>
Double DQN decouples action selection from action evaluation:
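1. The online Q-network selects the best action: a* = argmax_a' Q(s', a'; θ)
2. The target network evaluates that action: Q(s', a*; θ⁻)

Because the target network's estimation errors are only partly correlated with the online network's, an action that looks best merely because of the online network's lucky noise is usually not overestimated by the target network as well. This reduces the upward bias, although it does not remove it entirely: the target network is a lagged copy of the online network, so their errors are not fully independent.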
</details>

**3. What are V(s) and A(s,a) in Dueling DQN?**
<details>
<summary>Click to reveal answer</summary>
- V(s) is the state value: "How good is this state overall, regardless of action?"
- A(s,a) is the advantage: "How much better is this specific action than the average action in this state?"

Q(s,a) = V(s) + A(s,a) - mean(A). This lets the network learn state value separately from action preferences, which is more efficient when actions don't matter much.
</details>

**4. Why subtract mean(A) in Dueling DQN?**
<details>
<summary>Click to reveal answer</summary>
For identifiability! Without subtracting the mean, V(s) and A(s,a) are not uniquely determined: you could add any constant to V and subtract it from every A without changing Q. Forcing the advantages to have mean 0 over actions pins V(s) down as the average of Q(s,·) over actions.
</details>

**5. What's the advantage of Noisy Networks over ε-greedy?**
<details>
<summary>Click to reveal answer</summary>
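Noisy Networks provide state-dependent exploration:
- ε-greedy: Random action with the same probability ε in every state (ε follows a hand-tuned schedule)
- Noisy Nets: Noise is added to the weights, its scale σ is learned, and its effect on the Q-values varies with the state

Because σ is trained by gradient descent, the network can learn to reduce the noise where it hurts performance, so exploration adapts during training. No need to manually tune an ε schedule!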
</details>

---
## What's Next?

You now understand all the major DQN improvements!

In the final notebook of this section, we'll see these techniques in action on Atari games.

**Continue to:** [Notebook 6: Atari Games](06_atari_games.ipynb)

---

*Rainbow: "Why choose one improvement when you can have them all?"*