# Part 6.2: Q-Learning and Deep Q-Networks — The Formula 1 Edition

In Notebook 21, we learned the RL framework and solved small MDPs with dynamic programming — but those methods required knowing the environment's transition dynamics. Real-world agents don't have that luxury.

**Q-learning** (Watkins, 1989) removes that requirement: it learns optimal behavior purely from experience, without a model of the environment. Then in 2013, DeepMind's **DQN** scaled this idea with neural networks, learning to play Atari games from raw pixels, a landmark result that reignited interest in deep RL.

**The F1 Connection:** Q-learning is how a race strategist would learn optimal pit stop timing *without* a lap-time simulator, purely from racing experience. Each race provides data: "We pitted on lap 22 from P3 with worn mediums, and the outcome was a P2 finish." Over hundreds of races, the Q-table estimates settle, for example Q(P3_worn_mediums_lap22, pit_now) = 0.85. DQN scales this to high-dimensional states: a deep network takes in the full telemetry snapshot (position, gaps, tire temps, fuel load, weather forecast) and outputs the value of every possible strategic action. In practice, a strategy tool of this kind would be trained on thousands of simulated races rather than on real ones.

## Learning Objectives

- [ ] Implement tabular Q-learning from scratch and understand its convergence guarantees
- [ ] Distinguish between on-policy (SARSA) and off-policy (Q-learning) methods
- [ ] Understand why naive function approximation with neural networks fails in RL
- [ ] Implement experience replay and understand why it's critical
- [ ] Implement target networks and understand how they stabilize training
- [ ] Build a complete DQN from scratch in PyTorch
- [ ] Train a DQN agent on a control task
- [ ] Understand Double DQN and why vanilla DQN overestimates Q-values
- [ ] Recognize the limitations that motivate policy gradient methods

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from collections import defaultdict, deque, namedtuple
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# For reproducibility
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

print("Part 6.2: Q-Learning and Deep Q-Networks")
print("=" * 50)

---

## 1. From TD Learning to Q-Learning

Recall from Notebook 21 that TD(0) learns the value function $V(s)$. But to act optimally, we need $Q(s,a)$ — the value of taking action $a$ in state $s$.

### SARSA: On-Policy TD Control

**SARSA** updates Q-values using the action the agent *actually takes* next:

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[r_t + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t)\right]$$

The name comes from the quintuple: $(S_t, A_t, R_t, S_{t+1}, A_{t+1})$.

### Q-Learning: Off-Policy TD Control

**Q-learning** updates using the *best possible* next action, regardless of what the agent actually does:

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[r_t + \gamma \max_{a'} Q(s_{t+1}, a') - Q(s_t, a_t)\right]$$

This is **off-policy** because it learns about the optimal policy while following an exploratory policy.

**F1 analogy:** SARSA is like a cautious strategist who evaluates strategies based on what the driver *actually does* — including mistakes. If the driver sometimes misses the pit entry (exploration), SARSA learns cautious values that account for that. Q-learning is like an idealist strategist who evaluates every decision assuming *perfect execution going forward*. "If we pit now and the driver nails every remaining lap, we finish P2." Q-learning learns the optimal strategy regardless of how sloppy the current execution is.

| Property | SARSA (On-Policy) | Q-Learning (Off-Policy) | F1 Analogy |
|----------|------------------|------------------------|------------|
| **Update target** | $r + \gamma Q(s', a')$ | $r + \gamma \max_{a'} Q(s', a')$ | Actual outcome vs. best possible outcome |
| **Learns about** | The policy being followed | The optimal policy | Strategy with driver errors vs. ideal strategy |
| **Exploration impact** | Exploration affects learned values | Learns optimal regardless of exploration | Conservative near danger vs. optimistic everywhere |
| **Safety** | More conservative (accounts for exploration) | Can be overoptimistic | SARSA avoids risky pit entries; Q-learning assumes perfect execution |
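
To make the difference between the two update targets concrete, here is a minimal sketch with made-up numbers. The Q-values for $s'$, the sampled next action, and the current estimate `q_sa` are all hypothetical; `gamma`, `alpha`, and the step penalty match the values used in the gridworld code later in this notebook.

```python
import numpy as np

gamma, alpha = 0.9, 0.1
reward = -0.04            # step penalty, as in the gridworld
q_sa = 0.20               # hypothetical current estimate Q(s, a)

# Hypothetical Q(s', .) for actions [up, down, left, right]
q_next = np.array([0.50, -0.20, 0.10, 0.30])
next_action = 1           # assume epsilon-greedy happened to pick 'down'

# SARSA bootstraps from the action actually chosen next
sarsa_target = reward + gamma * q_next[next_action]
# Q-learning bootstraps from the best next action, whatever was chosen
q_learning_target = reward + gamma * q_next.max()

# One Q-learning update step toward its target
q_sa_updated = q_sa + alpha * (q_learning_target - q_sa)

print(f"SARSA target:      {sarsa_target:.3f}")
print(f"Q-learning target: {q_learning_target:.3f}")
print(f"Q(s,a) after one Q-learning update: {q_sa_updated:.3f}")
```

With these numbers the SARSA target is -0.22 and the Q-learning target is 0.41: the exploratory 'down' drags SARSA's estimate lower, while Q-learning ignores it.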

### Implementing the Gridworld (from Notebook 21)

In [None]:
class GridWorld:
    """Gridworld environment from Notebook 21."""
    EMPTY, WALL, GOAL, TRAP = 0, 1, 2, 3
    ACTIONS = ['up', 'down', 'left', 'right']
    ACTION_DELTAS = {'up': (-1,0), 'down': (1,0), 'left': (0,-1), 'right': (0,1)}
    
    def __init__(self, grid_size=4, slip_prob=0.1):
        self.grid_size = grid_size
        self.slip_prob = slip_prob
        self.grid = np.zeros((grid_size, grid_size), dtype=int)
        self.grid[0, 3] = self.GOAL
        self.grid[1, 1] = self.WALL
        self.grid[2, 3] = self.TRAP
        self.states = [(i,j) for i in range(grid_size) for j in range(grid_size)
                       if self.grid[i,j] != self.WALL]
        self.terminal_states = [(i,j) for i in range(grid_size) for j in range(grid_size)
                                if self.grid[i,j] in [self.GOAL, self.TRAP]]
        self.start = (3, 0)
        self.agent_pos = self.start
    
    def reset(self):
        self.agent_pos = self.start
        return self.agent_pos
    
    def _is_valid(self, pos):
        r, c = pos
        return (0 <= r < self.grid_size and 0 <= c < self.grid_size
                and self.grid[r,c] != self.WALL)
    
    def step(self, action):
        if self.agent_pos in self.terminal_states:
            return self.agent_pos, 0.0, True
        
        # Stochastic transitions
        if np.random.random() < self.slip_prob:
            perp = ['left','right'] if action in ['up','down'] else ['up','down']
            action = np.random.choice(perp)
        
        dr, dc = self.ACTION_DELTAS[action]
        new_pos = (self.agent_pos[0] + dr, self.agent_pos[1] + dc)
        if self._is_valid(new_pos):
            self.agent_pos = new_pos
        
        # Rewards
        cell = self.grid[self.agent_pos[0], self.agent_pos[1]]
        if cell == self.GOAL:
            return self.agent_pos, 1.0, True
        elif cell == self.TRAP:
            return self.agent_pos, -1.0, True
        else:
            return self.agent_pos, -0.04, False


env = GridWorld()
print(f"GridWorld: {env.grid_size}×{env.grid_size}, {len(env.states)} states, {len(env.ACTIONS)} actions")

---

## 2. Tabular Q-Learning

Let's implement Q-learning with a lookup table — one entry for every (state, action) pair.

In [None]:
def q_learning(env, n_episodes=10000, alpha=0.1, gamma=0.9, 
               epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995):
    """Tabular Q-learning with epsilon-greedy exploration."""
    Q = defaultdict(float)  # Q[(state, action)] -> value
    
    episode_rewards = []
    episode_lengths = []
    epsilon = epsilon_start
    
    for episode in range(n_episodes):
        state = env.reset()
        total_reward = 0
        steps = 0
        
        for _ in range(200):  # Max steps per episode
            # Epsilon-greedy action selection
            if np.random.random() < epsilon:
                action = np.random.choice(env.ACTIONS)
            else:
                q_values = [Q[(state, a)] for a in env.ACTIONS]
                action = env.ACTIONS[np.argmax(q_values)]
            
            next_state, reward, done = env.step(action)
            
            # Q-learning update: use MAX over next actions (off-policy)
            best_next_q = max(Q[(next_state, a)] for a in env.ACTIONS)
            td_target = reward + gamma * best_next_q * (1 - done)
            td_error = td_target - Q[(state, action)]
            Q[(state, action)] += alpha * td_error
            
            total_reward += reward
            steps += 1
            state = next_state
            
            if done:
                break
        
        episode_rewards.append(total_reward)
        episode_lengths.append(steps)
        epsilon = max(epsilon_end, epsilon * epsilon_decay)
    
    return dict(Q), episode_rewards, episode_lengths


def sarsa(env, n_episodes=10000, alpha=0.1, gamma=0.9,
          epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995):
    """SARSA: on-policy TD control."""
    Q = defaultdict(float)
    episode_rewards = []
    epsilon = epsilon_start
    
    for episode in range(n_episodes):
        state = env.reset()
        # Choose initial action
        if np.random.random() < epsilon:
            action = np.random.choice(env.ACTIONS)
        else:
            q_values = [Q[(state, a)] for a in env.ACTIONS]
            action = env.ACTIONS[np.argmax(q_values)]
        
        total_reward = 0
        
        for _ in range(200):
            next_state, reward, done = env.step(action)
            
            # Choose next action (for SARSA update)
            if np.random.random() < epsilon:
                next_action = np.random.choice(env.ACTIONS)
            else:
                q_values = [Q[(next_state, a)] for a in env.ACTIONS]
                next_action = env.ACTIONS[np.argmax(q_values)]
            
            # SARSA update: use the ACTUAL next action (on-policy)
            td_target = reward + gamma * Q[(next_state, next_action)] * (1 - done)
            td_error = td_target - Q[(state, action)]
            Q[(state, action)] += alpha * td_error
            
            total_reward += reward
            state = next_state
            action = next_action
            
            if done:
                break
        
        episode_rewards.append(total_reward)
        epsilon = max(epsilon_end, epsilon * epsilon_decay)
    
    return dict(Q), episode_rewards


# Train both
Q_ql, rewards_ql, lengths_ql = q_learning(env, n_episodes=10000)
Q_sarsa, rewards_sarsa = sarsa(env, n_episodes=10000)

print("Training complete!")
print(f"Q-learning: avg reward (last 100) = {np.mean(rewards_ql[-100:]):.3f}")
print(f"SARSA:      avg reward (last 100) = {np.mean(rewards_sarsa[-100:]):.3f}")

### Visualization: Q-Learning vs. SARSA

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Learning curves
window = 100
ql_smooth = np.convolve(rewards_ql, np.ones(window)/window, mode='valid')
sarsa_smooth = np.convolve(rewards_sarsa, np.ones(window)/window, mode='valid')

axes[0].plot(ql_smooth, label='Q-Learning (off-policy)', color='#3498db', linewidth=2)
axes[0].plot(sarsa_smooth, label='SARSA (on-policy)', color='#e74c3c', linewidth=2)
axes[0].set_xlabel('Episode', fontsize=12)
axes[0].set_ylabel('Average Reward (100-ep window)', fontsize=12)
axes[0].set_title('Q-Learning vs. SARSA', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Right: Learned Q-values heatmap for Q-learning
n = env.grid_size
max_q_grid = np.full((n, n), np.nan)
best_action_grid = {}

for s in env.states:
    if s not in env.terminal_states:
        q_vals = [Q_ql.get((s, a), 0) for a in env.ACTIONS]
        max_q_grid[s[0], s[1]] = max(q_vals)
        best_action_grid[s] = env.ACTIONS[np.argmax(q_vals)]
    else:
        if env.grid[s[0], s[1]] == GridWorld.GOAL:
            max_q_grid[s[0], s[1]] = 1.0
        else:
            max_q_grid[s[0], s[1]] = -1.0

im = axes[1].imshow(max_q_grid, cmap='RdYlGn', vmin=-1, vmax=1)
for i in range(n):
    for j in range(n):
        if not np.isnan(max_q_grid[i, j]):
            axes[1].text(j, i, f'{max_q_grid[i,j]:.2f}', ha='center', va='center',
                        fontsize=10, fontweight='bold')
            if (i, j) in best_action_grid:
                arrow_map = {'up': '↑', 'down': '↓', 'left': '←', 'right': '→'}
                axes[1].text(j, i + 0.3, arrow_map[best_action_grid[(i,j)]],
                           ha='center', va='center', fontsize=14, color='#2c3e50')
        elif env.grid[i,j] == GridWorld.WALL:
            axes[1].text(j, i, 'W', ha='center', va='center', fontsize=12, fontweight='bold')

axes[1].set_title('Learned Q-Values (max over actions)', fontsize=13, fontweight='bold')
plt.colorbar(im, ax=axes[1], shrink=0.8)

plt.tight_layout()
plt.show()

### Deep Dive: On-Policy vs. Off-Policy

The difference matters in practice:

- **SARSA** learns values that account for its own exploration. Near the trap, SARSA learns cautious values because it knows *it might accidentally step into the trap* during exploration.
- **Q-learning** learns the value of the *optimal* policy, assuming perfect action selection afterward. It's more optimistic near dangerous states.

This is why SARSA is sometimes called "safer" — it produces policies that account for the agent's own imperfections.

**F1 analogy:** Near the end of a wet race at Spa, SARSA is the strategist who says "Stay on inters — our driver sometimes brakes too late on slicks in the wet." Q-learning is the strategist who says "Switch to slicks — the optimal driver would gain 2 seconds per lap." Both are correct in their own frame. If your driver is Max Verstappen, Q-learning's optimism is justified. If it's a rookie, SARSA's caution keeps you in the points.

---

## 3. The Challenge of Scaling: Why Tables Aren't Enough

Tabular Q-learning works beautifully for small problems. But consider:

| Problem | State Space Size | F1 Equivalent |
|---------|------------------|---------------|
| 4x4 Gridworld | 15 states | 15 race scenarios on a napkin |
| Tic-Tac-Toe | ~5,000 states | Simple pit window calculator |
| Chess | ~10^47 states | — |
| Atari (pixel input) | 256^(210x160x3) states | — |
| F1 Race Strategy | Continuous: position x gaps x tire_age x compound x fuel x weather x ... | The real problem — infinite states |

We need **function approximation**: use a neural network to estimate $Q(s, a; \theta)$ instead of storing a table. The same logic applies to race strategy: no team could store a Q-value for every possible race situation, so a learned model, such as a neural network trained on simulated race data, has to generalize across states.
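
A quick back-of-the-envelope calculation shows why a table is hopeless for pixel input. The sketch below only evaluates the Atari row of the table above, working in logarithms because the number itself is far too large to print.

```python
import math

# Atari frame: 210 x 160 pixels, 3 color channels, 256 values per channel
n_values = 210 * 160 * 3
log10_states = n_values * math.log10(256)

print(f"Independent pixel values per frame: {n_values:,}")
print(f"256^{n_values} has roughly {log10_states:,.0f} decimal digits")
```

Even at one byte per entry, a table with that many rows could never be stored, which is why the Q-function has to be represented by a parameterized model.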

### The Deadly Triad

Simply plugging a neural network into Q-learning doesn't work. Three factors combine to cause instability:

1. **Function approximation**: Neural networks generalize across states (a change in one Q-value affects others)
2. **Bootstrapping**: TD updates use the network's own predictions as targets
3. **Off-policy learning**: Training on data from a different policy than we're evaluating

Any two of these are usually manageable; combining all three can cause the Q-values to oscillate or diverge. DQN's key contribution was to tame this instability with two techniques: **experience replay** and **target networks**.

### Visualization: Why Naive Deep Q-Learning Fails

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))

# Problem 1: Correlated data
ax = axes[0]
ax.set_title('Problem 1: Correlated Data', fontsize=12, fontweight='bold')
t = np.arange(100)
# Simulate sequential experiences from one trajectory
trajectory = np.cumsum(np.random.randn(100) * 0.3)
ax.plot(t, trajectory, 'b-', linewidth=2, label='Sequential experience')
# Random samples from different trajectories
random_samples = np.random.randn(100) * 2
ax.scatter(t[::5], random_samples[::5], c='red', s=30, zorder=5, label='Random replay')
ax.set_xlabel('Time step')
ax.set_ylabel('Experience')
ax.legend(fontsize=9)
ax.text(50, -4, 'Sequential data\nis highly correlated', ha='center', fontsize=9,
        style='italic', color='gray')
ax.grid(True, alpha=0.3)

# Problem 2: Moving target
ax = axes[1]
ax.set_title('Problem 2: Moving Target', fontsize=12, fontweight='bold')
steps = np.arange(50)
# Target keeps moving as network updates
predictions = np.zeros(50)
targets = np.zeros(50)
pred, target = 0.0, 1.0
for i in range(50):
    predictions[i] = pred
    targets[i] = target
    pred += 0.15 * (target - pred)  # Chase the target
    target += 0.1 * np.sin(i * 0.3)  # Target also moves

ax.plot(steps, targets, 'r--', linewidth=2, label='Moving target')
ax.plot(steps, predictions, 'b-', linewidth=2, label='Network prediction')
ax.fill_between(steps, predictions, targets, alpha=0.15, color='purple')
ax.set_xlabel('Training step')
ax.set_ylabel('Q-value')
ax.legend(fontsize=9)
ax.text(25, 0.3, 'Target shifts as\nnetwork updates', ha='center', fontsize=9,
        style='italic', color='gray')
ax.grid(True, alpha=0.3)

# Problem 3: Catastrophic forgetting
ax = axes[2]
ax.set_title('Problem 3: Catastrophic Forgetting', fontsize=12, fontweight='bold')
regions = ['Region A\n(visited early)', 'Region B\n(visited now)', 'Region C\n(visited later)']
accuracy_before = [0.9, 0.3, 0.1]
accuracy_after = [0.3, 0.9, 0.2]
x = np.arange(3)
ax.bar(x - 0.15, accuracy_before, 0.3, label='Before training on B', color='#3498db', alpha=0.8)
ax.bar(x + 0.15, accuracy_after, 0.3, label='After training on B', color='#e74c3c', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels(regions, fontsize=9)
ax.set_ylabel('Q-value accuracy')
ax.legend(fontsize=9)
ax.text(1, 0.05, 'Learning B destroys A', ha='center', fontsize=9,
        style='italic', color='gray')
ax.grid(True, alpha=0.3, axis='y')

plt.suptitle('Three Problems with Naive Deep Q-Learning', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

---

## 4. Experience Replay

**Experience replay** stores transitions $(s, a, r, s', \text{done})$ in a buffer and samples random mini-batches for training. This solves three problems:

1. **Breaks correlations**: Random sampling from the buffer produces i.i.d.-like training data
2. **Data efficiency**: Each experience can be reused many times
3. **Prevents forgetting**: Old experiences are revisited during training

**F1 analogy:** Without experience replay, the strategy model would only learn from the *most recent race* — and consecutive laps within that race are highly correlated (similar tire state, similar gaps). That's like training your strategy only on Monza data and forgetting everything about Monaco. Experience replay is like the team's historical race database: during training, you randomly sample from Silverstone 2019, Spa 2022, Suzuka 2023. The randomness breaks the correlations, and revisiting old races prevents catastrophic forgetting of track-specific knowledge.
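
The decorrelation effect can be checked numerically. The sketch below is an illustration under simplifying assumptions: a one-dimensional random walk stands in for a trajectory of states, and `lag1_corr` is a helper defined here, not part of the notebook's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for one trajectory: a random walk, so neighboring samples are similar
trajectory = np.cumsum(rng.normal(size=1000))

def lag1_corr(x):
    """Pearson correlation between each element and its successor."""
    return np.corrcoef(x[:-1], x[1:])[0, 1]

# Consecutive samples, as seen when training directly on the stream
sequential = lag1_corr(trajectory)
# Same samples in random order, as seen when drawing from a replay buffer
shuffled = lag1_corr(rng.permutation(trajectory))

print(f"lag-1 correlation, sequential order: {sequential:.3f}")
print(f"lag-1 correlation, shuffled order:   {shuffled:.3f}")
```

The sequential correlation should come out close to 1 and the shuffled one close to 0, which is the property that makes mini-batches from a replay buffer look closer to i.i.d. data.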

In [None]:
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""
    
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        """Add a transition to the buffer."""
        self.buffer.append(Transition(state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        """Sample a random batch of transitions."""
        transitions = random.sample(self.buffer, batch_size)
        batch = Transition(*zip(*transitions))
        return batch
    
    def __len__(self):
        return len(self.buffer)


# Demonstrate replay buffer
buffer = ReplayBuffer(capacity=1000)

# Fill with some random experience
state = env.reset()
for _ in range(500):
    action = np.random.choice(env.ACTIONS)
    next_state, reward, done = env.step(action)
    buffer.push(state, env.ACTIONS.index(action), reward, next_state, done)
    state = next_state if not done else env.reset()

# Sample a batch
batch = buffer.sample(4)
print(f"Buffer size: {len(buffer)}")
print(f"\nSample batch of 4 transitions:")
for i in range(4):
    print(f"  s={batch.state[i]}, a={env.ACTIONS[batch.action[i]]}, "
          f"r={batch.reward[i]:.2f}, s'={batch.next_state[i]}, done={batch.done[i]}")

---

## 5. Target Networks

The second DQN innovation: use a **separate, slowly-updated copy** of the Q-network to compute TD targets.

Instead of:
$$\text{target} = r + \gamma \max_{a'} Q(s', a'; \theta) \quad \text{(same network — moving target!)}$$

We use:
$$\text{target} = r + \gamma \max_{a'} Q(s', a'; \theta^{-}) \quad \text{(frozen target network)}$$

The target network $\theta^{-}$ is updated periodically by copying the main network's weights. This stabilizes training by keeping the target fixed for multiple updates.

**Intuition**: Imagine trying to hit a target that moves every time you adjust your aim. By freezing the target periodically, you can make progress before it shifts again.

**F1 analogy:** Without a target network, it's like a strategist who changes their "benchmark lap time" every time they recalculate, so the goal keeps moving. With a target network, the benchmark stays fixed for, say, 5 race weekends. The strategist optimizes pit stop timing against that stable benchmark, then updates the benchmark. It's the difference between chasing moving goalposts and methodically improving against a fixed standard.
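
The periodic hard update can be sketched in a few lines of PyTorch. This is an illustrative toy, not the agent built later: the two `nn.Linear` layers, the dummy squared-output loss, and the SGD optimizer are stand-ins chosen only to show that the target stays frozen until the weights are copied.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
online = nn.Linear(4, 2)   # stand-in for the Q-network (theta)
target = nn.Linear(4, 2)   # stand-in for the target network (theta^-)
target.load_state_dict(online.state_dict())  # start in sync

optimizer = torch.optim.SGD(online.parameters(), lr=0.1)
x = torch.randn(8, 4)

# One gradient step on the online network only (dummy loss)
loss = online(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The target network did not move with the online network
in_sync_before_copy = torch.equal(online.weight, target.weight)

# Periodic hard update: copy the online weights into the target
target.load_state_dict(online.state_dict())
in_sync_after_copy = torch.equal(online.weight, target.weight)

print(in_sync_before_copy, in_sync_after_copy)
```

Between copies, every TD target is computed from the same frozen parameters, which is what keeps the regression target stable for a while.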

---

## 6. Building DQN from Scratch

Now let's put it all together. We'll build a DQN to solve a custom control task — a cart that needs to balance a pole (a simplified version of the classic CartPole problem).

In [None]:
class CartPoleSimple:
    """Simplified CartPole environment (no external dependencies needed).
    
    State: [cart_position, cart_velocity, pole_angle, pole_angular_velocity]
    Actions: 0 (push left) or 1 (push right)
    """
    
    def __init__(self):
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masscart + self.masspole
        self.length = 0.5  # Half the pole length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # Time step
        
        # Failure thresholds
        self.x_threshold = 2.4
        self.theta_threshold = 12 * np.pi / 180  # 12 degrees
        
        self.state_dim = 4
        self.n_actions = 2
        self.state = None
    
    def reset(self):
        """Reset to random state near center."""
        self.state = np.random.uniform(-0.05, 0.05, size=4)
        return self.state.copy()
    
    def step(self, action):
        """Simulate one step of physics."""
        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action == 1 else -self.force_mag
        
        cos_theta = np.cos(theta)
        sin_theta = np.sin(theta)
        
        # Physics equations
        temp = (force + self.polemass_length * theta_dot**2 * sin_theta) / self.total_mass
        theta_acc = (self.gravity * sin_theta - cos_theta * temp) / (
            self.length * (4.0/3.0 - self.masspole * cos_theta**2 / self.total_mass))
        x_acc = temp - self.polemass_length * theta_acc * cos_theta / self.total_mass
        
        # Euler integration
        x += self.tau * x_dot
        x_dot += self.tau * x_acc
        theta += self.tau * theta_dot
        theta_dot += self.tau * theta_acc
        
        self.state = np.array([x, x_dot, theta, theta_dot])
        
        # Check termination
        done = (abs(x) > self.x_threshold or abs(theta) > self.theta_threshold)
        reward = 1.0 if not done else 0.0
        
        return self.state.copy(), reward, done


# Test the environment
cart_env = CartPoleSimple()
state = cart_env.reset()
print(f"CartPole state dim: {cart_env.state_dim}")
print(f"Actions: {cart_env.n_actions} (left, right)")
print(f"Initial state: {state}")

# Random agent baseline
episode_lengths = []
for _ in range(100):
    state = cart_env.reset()
    length = 0
    for _ in range(500):
        action = np.random.randint(2)
        state, reward, done = cart_env.step(action)
        length += 1
        if done:
            break
    episode_lengths.append(length)

print(f"\nRandom agent: avg episode length = {np.mean(episode_lengths):.1f} steps")
print("(Goal: balance for 200+ steps)")

In [None]:
class QNetwork(nn.Module):
    """Neural network that estimates Q(s,a) for all actions."""
    
    def __init__(self, state_dim, n_actions, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )
    
    def forward(self, x):
        return self.net(x)


class DQNAgent:
    """Deep Q-Network agent with experience replay and target network."""
    
    def __init__(self, state_dim, n_actions, hidden_dim=64, lr=1e-3,
                 gamma=0.99, buffer_size=10000, batch_size=64,
                 target_update_freq=100, epsilon_start=1.0,
                 epsilon_end=0.01, epsilon_decay=0.995):
        
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        
        # Exploration
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        
        # Networks
        self.q_network = QNetwork(state_dim, n_actions, hidden_dim)
        self.target_network = QNetwork(state_dim, n_actions, hidden_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())  # Copy weights
        
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(buffer_size)
        self.steps_done = 0
    
    def select_action(self, state):
        """Epsilon-greedy action selection."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.q_network(state_tensor)
            return q_values.argmax(dim=1).item()
    
    def update(self):
        """Perform one step of DQN training."""
        if len(self.replay_buffer) < self.batch_size:
            return None
        
        # Sample batch from replay buffer
        batch = self.replay_buffer.sample(self.batch_size)
        
        states = torch.FloatTensor(np.array(batch.state))
        actions = torch.LongTensor(batch.action)
        rewards = torch.FloatTensor(batch.reward)
        next_states = torch.FloatTensor(np.array(batch.next_state))
        dones = torch.FloatTensor(batch.done)
        
        # Current Q-values: Q(s, a) for the actions we actually took
        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        
        # Target Q-values: r + γ max_a' Q_target(s', a')
        with torch.no_grad():
            next_q = self.target_network(next_states).max(dim=1)[0]
            target_q = rewards + self.gamma * next_q * (1 - dones)
        
        # Loss: MSE between current and target Q-values
        loss = F.mse_loss(current_q, target_q)
        
        self.optimizer.zero_grad()
        loss.backward()
        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()
        
        self.steps_done += 1
        
        # Periodically update target network
        if self.steps_done % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        
        return loss.item()
    
    def decay_epsilon(self):
        """Decay exploration rate."""
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)


print("DQN agent architecture:")
agent = DQNAgent(state_dim=4, n_actions=2)
print(agent.q_network)
total_params = sum(p.numel() for p in agent.q_network.parameters())
print(f"\nTotal parameters: {total_params:,}")
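
The `gather` call and the `(1 - dones)` mask in `update` are the two lines that most often cause confusion, so here is a small standalone sketch of both. All tensors are hypothetical values picked to make the arithmetic easy to follow; `next_q_max` stands in for the target network's maximum over next actions.

```python
import torch

# Hypothetical Q-network output for a batch of 3 states and 2 actions
q_values = torch.tensor([[1.0, 2.0],
                         [3.0, 4.0],
                         [5.0, 6.0]])
actions = torch.tensor([0, 1, 0])   # action taken in each transition

# Pick Q(s_i, a_i) from each row; gather needs an index of shape (batch, 1)
chosen_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

# Terminal transitions get no bootstrap term because (1 - done) is 0
rewards = torch.tensor([1.0, 1.0, 0.0])
dones = torch.tensor([0.0, 0.0, 1.0])
next_q_max = torch.tensor([10.0, 20.0, 30.0])
target_q = rewards + 0.99 * next_q_max * (1 - dones)

print(chosen_q)   # Q-values of the actions actually taken
print(target_q)   # TD targets; the last one is just the reward
```

The third transition is terminal, so its target ignores `next_q_max` entirely, which is what stops the network from bootstrapping past the end of an episode.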

### Training the DQN Agent

In [None]:
def train_dqn(env, agent, n_episodes=500, max_steps=500):
    """Train DQN agent on the environment."""
    episode_rewards = []
    episode_lengths = []
    losses = []
    epsilons = []
    
    for episode in range(n_episodes):
        state = env.reset()
        total_reward = 0
        episode_loss = []
        
        for step in range(max_steps):
            action = agent.select_action(state)
            next_state, reward, done = env.step(action)
            
            # Store transition
            agent.replay_buffer.push(state, action, reward, next_state, float(done))
            
            # Train
            loss = agent.update()
            if loss is not None:
                episode_loss.append(loss)
            
            total_reward += reward
            state = next_state
            
            if done:
                break
        
        agent.decay_epsilon()
        episode_rewards.append(total_reward)
        episode_lengths.append(step + 1)
        epsilons.append(agent.epsilon)
        if episode_loss:
            losses.append(np.mean(episode_loss))
        
        if (episode + 1) % 50 == 0:
            avg_reward = np.mean(episode_rewards[-50:])
            avg_length = np.mean(episode_lengths[-50:])
            print(f"Episode {episode+1:4d} | Avg Reward: {avg_reward:6.1f} | "
                  f"Avg Length: {avg_length:5.1f} | ε: {agent.epsilon:.3f}")
    
    return episode_rewards, episode_lengths, losses, epsilons


# Train!
cart_env = CartPoleSimple()
agent = DQNAgent(state_dim=4, n_actions=2, lr=1e-3, gamma=0.99,
                 buffer_size=10000, batch_size=64, target_update_freq=100)

rewards, lengths, losses, epsilons = train_dqn(cart_env, agent, n_episodes=500)
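
It is worth checking what the exploration schedule actually does over 500 episodes. The sketch below replays the multiplicative decay from `decay_epsilon` using the agent's default values; `epsilon_after` is a helper introduced here for illustration.

```python
epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.01, 0.995

def epsilon_after(n_episodes):
    """Epsilon after n_episodes multiplicative decays, floored at epsilon_end."""
    return max(epsilon_end, epsilon_start * epsilon_decay ** n_episodes)

# Count decays until the floor is reached
episodes_to_floor = 0
eps = epsilon_start
while eps > epsilon_end:
    eps *= epsilon_decay
    episodes_to_floor += 1

print(f"epsilon after 500 episodes: {epsilon_after(500):.3f}")
print(f"episodes until epsilon hits the floor: {episodes_to_floor}")
```

With these defaults epsilon is still about 0.08 after 500 episodes, so the agent is taking roughly one random action in twelve at the end of training; the 0.01 floor is only reached after roughly 900 episodes. This is one reason the evaluation below sets `epsilon = 0` before measuring performance.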

### Visualization: DQN Training Progress

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Episode lengths (performance)
ax = axes[0, 0]
ax.plot(lengths, alpha=0.3, color='#3498db')
window = 20
if len(lengths) >= window:
    smoothed = np.convolve(lengths, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(lengths)), smoothed, color='#2c3e50', linewidth=2,
            label=f'{window}-ep moving avg')
ax.axhline(y=200, color='red', linestyle='--', label='Goal (200 steps)')
ax.set_xlabel('Episode')
ax.set_ylabel('Episode Length')
ax.set_title('DQN Training: Episode Length', fontsize=13, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# Episode rewards
ax = axes[0, 1]
ax.plot(rewards, alpha=0.3, color='#2ecc71')
if len(rewards) >= window:
    smoothed_r = np.convolve(rewards, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(rewards)), smoothed_r, color='#27ae60', linewidth=2,
            label=f'{window}-ep moving avg')
ax.set_xlabel('Episode')
ax.set_ylabel('Total Reward')
ax.set_title('DQN Training: Reward', fontsize=13, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# Training loss
ax = axes[1, 0]
if losses:
    ax.plot(losses, alpha=0.5, color='#e74c3c')
    if len(losses) >= window:
        smoothed_l = np.convolve(losses, np.ones(window)/window, mode='valid')
        ax.plot(range(window-1, len(losses)), smoothed_l, color='#c0392b', linewidth=2)
ax.set_xlabel('Episode')
ax.set_ylabel('Loss')
ax.set_title('DQN Training: Loss', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3)

# Epsilon schedule
ax = axes[1, 1]
ax.plot(epsilons, color='#9b59b6', linewidth=2)
ax.set_xlabel('Episode')
ax.set_ylabel('Epsilon')
ax.set_title('Exploration Rate (ε) Decay', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3)

plt.suptitle('DQN Agent Training on CartPole', fontsize=15, fontweight='bold', y=1.01)
plt.tight_layout()
plt.show()

# Final evaluation
agent.epsilon = 0  # No exploration
eval_lengths = []
for _ in range(20):
    state = cart_env.reset()
    length = 0
    for _ in range(500):
        action = agent.select_action(state)
        state, _, done = cart_env.step(action)
        length += 1
        if done:
            break
    eval_lengths.append(length)

print(f"\nFinal evaluation (20 episodes, no exploration):")
print(f"  Average length: {np.mean(eval_lengths):.1f} ± {np.std(eval_lengths):.1f}")
print(f"  Min: {min(eval_lengths)}, Max: {max(eval_lengths)}")

---

## 7. Ablation Study: Why Each Component Matters

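Let's test how much experience replay and target networks each contribute by removing them one at a time. Each configuration is trained once for a few hundred episodes on a small problem, so read the curves as suggestive evidence rather than proof.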

In [None]:
class DQNNoReplay(DQNAgent):
    """DQN without experience replay — trains on most recent transition only."""
    def update(self):
        if len(self.replay_buffer) < 1:
            return None
        
        # Use only the LAST transition (no replay)
        last = self.replay_buffer.buffer[-1]
        
        state = torch.FloatTensor(np.array(last.state)).unsqueeze(0)
        action = torch.LongTensor([last.action])
        reward = torch.FloatTensor([last.reward])
        next_state = torch.FloatTensor(np.array(last.next_state)).unsqueeze(0)
        done = torch.FloatTensor([last.done])
        
        current_q = self.q_network(state).gather(1, action.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_q = self.target_network(next_state).max(dim=1)[0]
            target_q = reward + self.gamma * next_q * (1 - done)
        
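        loss = F.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients as the other variants do, so replay is the only difference
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()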
        self.steps_done += 1
        if self.steps_done % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        return loss.item()


class DQNNoTarget(DQNAgent):
    """DQN without target network — uses same network for targets."""
    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return None
        
        batch = self.replay_buffer.sample(self.batch_size)
        states = torch.FloatTensor(np.array(batch.state))
        actions = torch.LongTensor(batch.action)
        rewards = torch.FloatTensor(batch.reward)
        next_states = torch.FloatTensor(np.array(batch.next_state))
        dones = torch.FloatTensor(batch.done)
        
        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Uses SAME network for target (no target network)
        with torch.no_grad():
            next_q = self.q_network(next_states).max(dim=1)[0]
            target_q = rewards + self.gamma * next_q * (1 - dones)
        
        loss = F.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()
        self.steps_done += 1
        return loss.item()


# Run ablations (shorter training for speed)
n_ep = 300
print("Training full DQN...")
agent_full = DQNAgent(4, 2)
r_full, l_full, _, _ = train_dqn(CartPoleSimple(), agent_full, n_ep)

print("\nTraining DQN without replay...")
agent_no_replay = DQNNoReplay(4, 2)
r_no_replay, l_no_replay, _, _ = train_dqn(CartPoleSimple(), agent_no_replay, n_ep)

print("\nTraining DQN without target network...")
agent_no_target = DQNNoTarget(4, 2)
r_no_target, l_no_target, _, _ = train_dqn(CartPoleSimple(), agent_no_target, n_ep)

In [None]:
# Visualize ablation results
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

window = 20
configs = [
    ('Full DQN', l_full, '#2ecc71'),
    ('No Experience Replay', l_no_replay, '#e74c3c'),
    ('No Target Network', l_no_target, '#f39c12'),
]

for name, lengths, color in configs:
    if len(lengths) >= window:
        smoothed = np.convolve(lengths, np.ones(window)/window, mode='valid')
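        # Align each average with the last episode in its window, as in the training plots
        ax.plot(range(window-1, len(lengths)), smoothed, label=name, color=color, linewidth=2.5)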

ax.axhline(y=200, color='gray', linestyle='--', alpha=0.5, label='Goal')
ax.set_xlabel('Episode', fontsize=12)
ax.set_ylabel('Episode Length (smoothed)', fontsize=12)
ax.set_title('DQN Ablation Study: Both Components Are Critical', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Ablation results (avg length, last 50 episodes):")
for name, lengths, _ in configs:
    print(f"  {name}: {np.mean(lengths[-50:]):.1f}")

---

## 8. Double DQN: Fixing Overestimation

Vanilla DQN tends to **overestimate** Q-values. The max operator in the target:

$$\text{target} = r + \gamma \max_{a'} Q(s', a'; \theta^{-})$$

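systematically selects overestimated values. The Q-estimates are noisy, and taking the max over noisy estimates favors whichever action's error happens to be positive. Even when the noise is zero-mean, $\mathbb{E}[\max_{a'} \hat{Q}(s', a')] \ge \max_{a'} \mathbb{E}[\hat{Q}(s', a')]$, so the target is biased upward.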

**Double DQN** decouples action selection from evaluation:

$$\text{target} = r + \gamma Q\left(s', \arg\max_{a'} Q(s', a'; \theta); \theta^{-}\right)$$

- Use the **online network** to select the best action
- Use the **target network** to evaluate that action

This simple change significantly reduces overestimation.
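
The bias argument and the decoupling fix can be checked numerically without any neural network. The sketch below is an illustration under simplifying assumptions, not part of the DQN code in this notebook: all actions have a true value of 0, and two *independent* sets of Gaussian noise stand in for the online and target networks. In a real agent the two networks are correlated, so Double DQN reduces the bias rather than removing it.

```python
import numpy as np

rng = np.random.default_rng(0)

n_actions = 4
n_trials = 100_000
noise_std = 1.0
true_q = np.zeros(n_actions)  # every action is truly worth 0, so the true max is 0

# Two independent noisy estimates of the same Q-values (illustrative stand-ins
# for the online and target networks).
q_online = true_q + rng.normal(0.0, noise_std, size=(n_trials, n_actions))
q_target = true_q + rng.normal(0.0, noise_std, size=(n_trials, n_actions))

# Vanilla DQN: the same estimates both select and evaluate the action.
single_estimate = q_target.max(axis=1)

# Double DQN: select with one set of estimates, evaluate with the other.
best_actions = q_online.argmax(axis=1)
double_estimate = q_target[np.arange(n_trials), best_actions]

print(f"true max value:        {true_q.max():.3f}")
print(f"single estimator mean: {single_estimate.mean():.3f}")
print(f"double estimator mean: {double_estimate.mean():.3f}")
```

The single-estimator mean comes out clearly positive even though no action is worth more than 0, while the double-estimator mean stays close to 0. The gap grows with the number of actions and with the noise level.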

**F1 analogy:** Regular DQN is like an overconfident strategist who always assumes the *best possible* outcome for each action — "If we pit now, the undercut will definitely work AND we'll get a perfect pit stop AND the safety car will come out." By always picking the most optimistic estimate, they systematically overvalue risky strategies. Double DQN separates the question: one model says "pitting now looks best" (action selection), and a different model says "here's what pitting now is actually worth" (evaluation). The decoupling grounds the optimism in more realistic assessments.

In [None]:
class DoubleDQNAgent(DQNAgent):
    """Double DQN: decouple action selection from evaluation."""
    
    def update(self):
        if len(self.replay_buffer) < self.batch_size:
            return None
        
        batch = self.replay_buffer.sample(self.batch_size)
        states = torch.FloatTensor(np.array(batch.state))
        actions = torch.LongTensor(batch.action)
        rewards = torch.FloatTensor(batch.reward)
        next_states = torch.FloatTensor(np.array(batch.next_state))
        dones = torch.FloatTensor(batch.done)
        
        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        
        with torch.no_grad():
            # Double DQN: select action with ONLINE network
            best_actions = self.q_network(next_states).argmax(dim=1)
            # Evaluate with TARGET network
            next_q = self.target_network(next_states).gather(1, best_actions.unsqueeze(1)).squeeze(1)
            target_q = rewards + self.gamma * next_q * (1 - dones)
        
        loss = F.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()
        self.steps_done += 1
        if self.steps_done % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        return loss.item()


# Demonstrate Q-value overestimation
def measure_q_overestimation(AgentClass, n_episodes=300):
    """Track max Q-value predictions during training."""
    env = CartPoleSimple()
    agent = AgentClass(4, 2)
    
    max_q_history = []
    lengths = []
    
    for ep in range(n_episodes):
        state = env.reset()
        ep_max_q = []
        length = 0
        
        for _ in range(500):
            with torch.no_grad():
                q_vals = agent.q_network(torch.FloatTensor(state).unsqueeze(0))
                ep_max_q.append(q_vals.max().item())
            
            action = agent.select_action(state)
            next_state, reward, done = env.step(action)
            agent.replay_buffer.push(state, action, reward, next_state, float(done))
            agent.update()
            length += 1
            state = next_state
            if done:
                break
        
        agent.decay_epsilon()
        max_q_history.append(np.mean(ep_max_q))
        lengths.append(length)
    
    return max_q_history, lengths


print("Measuring Q-value overestimation...")
q_dqn, l_dqn = measure_q_overestimation(DQNAgent)
q_ddqn, l_ddqn = measure_q_overestimation(DoubleDQNAgent)
print("Done!")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

window = 20

# Q-value estimates
ax = axes[0]
for data, label, color in [(q_dqn, 'DQN', '#e74c3c'), (q_ddqn, 'Double DQN', '#3498db')]:
    smoothed = np.convolve(data, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(data)), smoothed, label=label, color=color, linewidth=2)
ax.set_xlabel('Episode', fontsize=12)
ax.set_ylabel('Average Max Q-Value', fontsize=12)
ax.set_title('Q-Value Estimates: DQN Overestimates', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Performance comparison
ax = axes[1]
for data, label, color in [(l_dqn, 'DQN', '#e74c3c'), (l_ddqn, 'Double DQN', '#3498db')]:
    smoothed = np.convolve(data, np.ones(window)/window, mode='valid')
    ax.plot(range(window-1, len(data)), smoothed, label=label, color=color, linewidth=2)
ax.axhline(y=200, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Episode', fontsize=12)
ax.set_ylabel('Episode Length', fontsize=12)
ax.set_title('Performance: DQN vs Double DQN', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## 9. Limitations of Value-Based Methods

DQN is powerful, but it has fundamental limitations:

| Limitation | Why It Matters | F1 Parallel |
|-----------|----------------|-------------|
| **Discrete actions only** | Can't directly handle continuous control (robotics, steering angles) | Can't output "pit on lap 22.7" — only "pit" or "don't pit" |
| **Deterministic policy** | Always outputs one action per state (can't learn stochastic strategies) | Can't say "70% chance we should pit" — only yes or no |
| **Overestimation** | Double DQN reduces the upward bias but doesn't eliminate it | Still too optimistic about undercut success rates |
| **No direct policy optimization** | The policy is only implicit (argmax over Q), so it can't be trained directly on a policy-level objective, as PPO does in RLHF | Can't directly optimize for "smooth strategic decisions" |

These limitations motivate **policy gradient methods** (Notebook 23), which:
- Learn a policy $\pi_\theta(a|s)$ directly as a neural network
- Handle continuous action spaces naturally — perfect for steering angle, throttle, brake pressure
- Optimize expected return directly; only the policy needs to be differentiable, not the reward
- Are the backbone of PPO and RLHF
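
To make the "deterministic policy" limitation concrete, here is a minimal sketch contrasting a greedy policy derived from Q-values with a policy that outputs action probabilities. The Q-values and logits are made-up numbers for a single hypothetical state with actions `[stay_out, pit]`; they do not come from any agent trained in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical Q-values for one state, actions = [stay_out, pit].
q_values = np.array([0.62, 0.58])

# DQN-style policy: argmax over Q. The same state always yields the same action,
# no matter how close the two values are.
greedy_actions = [int(np.argmax(q_values)) for _ in range(1000)]

# Policy-network-style output: a softmax over (hypothetical) logits gives
# a probability for each action, which can then be sampled.
logits = np.array([0.0, 0.85])
probs = np.exp(logits - logits.max())
probs /= probs.sum()
sampled_actions = rng.choice(len(probs), size=1000, p=probs)

print("greedy actions taken:", set(greedy_actions))
print("policy probabilities:", np.round(probs, 2))
print("fraction of samples that pit:", sampled_actions.mean())
```

The greedy policy can only ever say "stay out" here, while the stochastic policy can express "pit roughly 70% of the time", which is exactly what a Q-network's argmax cannot represent.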

In [None]:
# Summary comparison
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.set_xlim(0, 10)
ax.set_ylim(0, 8)
ax.axis('off')
ax.set_title('DQN Architecture Summary', fontsize=16, fontweight='bold')

# Components
components = [
    (1, 6, 3, 1.2, 'Experience\nReplay Buffer', '#3498db',
     'Stores transitions\nBreaks correlations'),
    (1, 3.5, 3, 1.2, 'Q-Network\n(Online)', '#2ecc71',
     'Predicts Q(s,a)\nUpdated every step'),
    (6, 3.5, 3, 1.2, 'Target Network\n(Frozen)', '#e74c3c',
     'Computes TD targets\nPeriodically synced'),
    (3.5, 1, 3, 1.0, 'ε-Greedy\nExploration', '#f39c12',
     'Balances explore/exploit\nDecays over training'),
]

for x, y, w, h, label, color, desc in components:
    box = mpatches.FancyBboxPatch((x, y), w, h, boxstyle="round,pad=0.2",
                                   facecolor=color, edgecolor='black', linewidth=2, alpha=0.9)
    ax.add_patch(box)
    ax.text(x + w/2, y + h/2 + 0.15, label, ha='center', va='center',
            fontsize=10, fontweight='bold', color='white')
    ax.text(x + w/2, y - 0.3, desc, ha='center', va='center',
            fontsize=8, color='gray', style='italic')

# Arrows
ax.annotate('', xy=(2.5, 6), xytext=(2.5, 4.7),
            arrowprops=dict(arrowstyle='->', lw=2, color='gray'))
ax.annotate('', xy=(6, 4.1), xytext=(4, 4.1),
            arrowprops=dict(arrowstyle='->', lw=2, color='gray'))
ax.text(5, 4.5, 'periodic\ncopy', ha='center', fontsize=8, color='gray')
ax.annotate('', xy=(3.5, 3.5), xytext=(3.5, 2),
            arrowprops=dict(arrowstyle='->', lw=2, color='gray'))

plt.tight_layout()
plt.show()

---

## Exercises

### Exercise 1: Dueling DQN Architecture — Separating Track Position Value from Action Advantage

Implement the **Dueling DQN** architecture, which separates the Q-network into two streams:
- A **value stream** $V(s)$: how good is this state? ("Being P3 with fresh tires is inherently valuable")
- An **advantage stream** $A(s,a)$: how much better is this action than average? ("Pitting now is 0.3 better than average from P3")

$$Q(s,a) = V(s) + A(s,a) - \frac{1}{|\mathcal{A}|}\sum_{a'} A(s,a')$$

In F1 terms, some race positions are good regardless of what you do next (P2 with a 5-second gap). The value stream captures that. The advantage stream captures whether pitting NOW specifically is better than the alternatives. Separating these makes learning more efficient — the network can quickly learn which positions are valuable and independently learn which actions improve things.
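
Before building the network, it can help to check the aggregation formula on plain numbers. The snippet below is a small numeric sketch only, and `dueling_q` is a hypothetical helper name; the two-stream architecture itself is left for the exercise.

```python
import numpy as np

def dueling_q(value, advantages):
    """Combine V(s) and A(s, a) into Q(s, a) using the mean-subtracted form."""
    advantages = np.asarray(advantages, dtype=float)
    return value + advantages - advantages.mean()

v = 1.0            # illustrative state value V(s)
a = [0.5, -0.1]    # illustrative advantages A(s, a) for two actions

q = dueling_q(v, a)
print("Q(s, a):", q)

# Adding a constant to every advantage leaves Q unchanged, which is why
# the mean is subtracted: it pins down the split between V and A.
q_shifted = dueling_q(v, [0.5 + 10.0, -0.1 + 10.0])
print("unchanged by shift:", np.allclose(q, q_shifted))

# With the mean subtracted, the average Q over actions equals V(s).
print("mean Q equals V:", np.isclose(q.mean(), v))
```

In the PyTorch version the same arithmetic applies per batch row, so the mean is taken over the action dimension only.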

In [None]:
# Exercise 1: Your code here
# Hint: Modify QNetwork to have shared layers, then split into
# a value head (outputs 1 value) and advantage head (outputs n_actions values)


### Exercise 2: Prioritized Experience Replay — Learning More from Surprising Races

Not all experiences are equally useful. Implement **prioritized replay**, where transitions with larger absolute TD error are sampled more frequently. Use importance sampling weights to correct the bias that this non-uniform sampling introduces.

In F1 terms, a race where your strategy prediction was wildly wrong (huge TD error) — like an unexpected safety car changing the pit window — is far more informative than a processional race where everything went as expected. Prioritized replay ensures the model trains more on those surprising, information-rich races.
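
For reference, the commonly used proportional formulation of prioritized replay (following Schaul et al.) can be written as below. The exponents $\alpha$ and $\beta$ and the buffer size $N$ are not defined elsewhere in this notebook; treat them as hyperparameters you choose, with $\alpha = 0$ recovering uniform sampling and $\beta = 1$ giving full bias correction.

$$p_i = |\delta_i| + \epsilon, \qquad P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}, \qquad w_i = \left(N \cdot P(i)\right)^{-\beta}$$

Here $\delta_i$ is the TD error of transition $i$, and $\epsilon$ is a small positive constant that keeps every priority nonzero (it is unrelated to the exploration rate). Each sampled transition's loss is multiplied by $w_i$, and the weights are typically divided by $\max_j w_j$ so that they only scale updates down.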

In [None]:
# Exercise 2: Your code here
# Hint: Store |TD error| + small epsilon as priority for each transition
# Sample proportional to priority, apply importance sampling weights to loss


### Exercise 3: Hyperparameter Sensitivity — Tuning the Strategy Engine

DQN has many hyperparameters, much like an F1 car has many setup parameters. Run experiments varying:
- Learning rate: [1e-4, 1e-3, 1e-2] — how fast the model adapts (like setup change aggressiveness)
- Target update frequency: [10, 100, 500] — how often the benchmark refreshes (like strategy model update cadence)
- Buffer size: [1000, 10000, 50000] — how much historical data to retain (like seasons of race data in memory)

Which hyperparameter has the biggest impact on performance? In F1, teams agonize over which setup parameters matter most for each track — this exercise gives you intuition for the same question in DQN.

In [None]:
# Exercise 3: Your code here
# Hint: Create a function that trains a DQN with given hyperparameters
# and returns the average episode length over the last 50 episodes


---

## Summary

### Key Concepts

| Concept | What It Means | F1 Parallel |
|---------|--------------|-------------|
| **Q-learning** | Learn Q(s,a) from experience using off-policy TD updates | Learn pit stop value from race experience, assuming optimal future execution |
| **SARSA** | On-policy variant: learns the value of the policy it actually follows, exploration included | Learn strategy value accounting for driver imperfections |
| **Deadly triad** | Function approx + bootstrapping + off-policy = instability | Strategy model that generalizes, self-references, and learns from old data = chaos |
| **Experience replay** | Break data correlations, enable data reuse | Historical race database with random sampling |
| **Target networks** | Fixed TD targets for stability | Stable benchmark lap time, updated periodically |
| **Double DQN** | Decouple action selection from evaluation | One model picks strategy, another estimates its true value |
| **DQN limitations** | Discrete actions, deterministic policy only | Can't do continuous throttle control or probabilistic strategies |

### Fundamental Insight

DQN showed that the combination of deep learning and RL can solve problems previously thought intractable. The key wasn't a new algorithm — it was engineering: experience replay and target networks turned an unstable process into a robust learning system. Much of deep RL research is about making the learning process stable enough for neural networks to work. In F1 terms, the raw strategy math was always there — the breakthrough was building the engineering infrastructure (data pipelines, simulation tools, real-time telemetry) that made it practical at race speed.

---

## Next Steps

DQN learns *which actions are good* (value function) and derives a policy from that. But what if we could learn the policy **directly**? In **Notebook 23: Policy Gradient Methods**, we'll:

- Derive the policy gradient theorem — the mathematical foundation for directly optimizing policies
- Implement REINFORCE, the simplest policy gradient algorithm — directly learning race strategy by reinforcing good decisions
- Understand variance reduction with baselines
- Build an actor-critic architecture that combines the best of both worlds — like having a driver (actor) and strategist (critic) working together
- See why policy gradients are the path to PPO and RLHF