# Experience Replay: The Memory Bank That Saves DQN

Experience replay is one of the two key innovations that made DQN actually work! This notebook dives deep into why it's essential.

## What You'll Learn

By the end of this notebook, you'll understand:
- Why correlated data is deadly for neural networks (with a studying analogy!)
- How experience replay breaks these correlations
- Implementing a replay buffer from scratch
- Prioritized experience replay: learning from mistakes
- Comparing uniform and prioritized sampling

**Prerequisites:** Notebook 2 (DQN From Scratch)

**Time:** ~25 minutes

---
## The Big Picture: The Study Habits Analogy

```
    ┌────────────────────────────────────────────────────────────────┐
    │          THE STUDY HABITS ANALOGY                              │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Imagine you're studying for a big exam covering 10 chapters. │
    │                                                                │
    │  BAD STRATEGY (Online Learning):                              │
    │    Day 1: Study Ch 1, Ch 1, Ch 1, Ch 1, Ch 1, ...            │
    │    Day 2: Study Ch 2, Ch 2, Ch 2, Ch 2, Ch 2, ...            │
    │    Day 3: Study Ch 3, Ch 3, Ch 3, Ch 3, Ch 3, ...            │
    │                                                                │
    │    PROBLEM: By Day 10, you've forgotten Chapter 1!           │
    │    You only remember what you studied recently.               │
    │                                                                │
    │  GOOD STRATEGY (Experience Replay):                           │
    │    Take notes in a notebook (replay buffer)                   │
    │    Each study session: randomly review notes from ALL chapters│
    │    "Chapter 5 note, Chapter 2 note, Chapter 8 note, ..."     │
    │                                                                │
    │    RESULT: You remember everything evenly!                    │
    │    No correlation between what you study each session.        │
    │                                                                │
    │  Experience Replay = Taking Notes + Random Review             │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
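
To put a number on "consecutive samples are highly similar", here is a small stdlib-only sketch (not part of the original notebook). It measures the lag-1 autocorrelation of a smooth, sine-shaped trajectory twice: once in time order, and once after shuffling, which is roughly what uniform sampling from a buffer does to the order. The signal and the helper name `lag1_autocorr` are illustrative assumptions.

```python
import math
import random


def lag1_autocorr(xs):
    """Pearson correlation between each element and the one after it."""
    a, b = xs[:-1], xs[1:]
    mean_a, mean_b = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


# Hypothetical trajectory: one slowly varying state coordinate
# observed over 200 consecutive time steps.
trajectory = [math.sin(0.1 * t) for t in range(200)]

# Replay-style order: the same values, visited in random order.
shuffled = trajectory[:]
random.Random(0).shuffle(shuffled)

print(f"time order:     {lag1_autocorr(trajectory):.3f}")
print(f"shuffled order: {lag1_autocorr(shuffled):.3f}")
```

A value near 1 means each sample almost predicts the next one. A value near 0 means the ordering carries little information, which is much closer to what SGD assumes.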

In [None]:
import numpy as np
import torch
import torch.nn as nn
from collections import deque
import random
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch

# Visualize the correlation problem
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Online learning (correlated)
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('ONLINE LEARNING\n(Sequential Data)', fontsize=14, fontweight='bold', color='#d32f2f')

# Show sequential samples
for i in range(5):
    x = 1 + i * 1.7
    # Similar, gradually shifting shades to suggest correlated states
    color = f'#{(int(100 + i*30)):02x}8080'
    box = FancyBboxPatch((x, 5), 1.2, 1.2, boxstyle="round,pad=0.05",
                          facecolor=color, edgecolor='black', linewidth=2)
    ax1.add_patch(box)
    ax1.text(x + 0.6, 5.6, f's{i+1}', ha='center', fontsize=11, fontweight='bold')
    if i < 4:
        # Arrow spans the gap between this box and the next one
        ax1.annotate('', xy=(x + 1.6, 5.6), xytext=(x + 1.3, 5.6),
                    arrowprops=dict(arrowstyle='->', lw=2, color='#666'))

ax1.text(5, 4, 'Consecutive samples are HIGHLY SIMILAR', ha='center', fontsize=11, color='#d32f2f')
# '✗' instead of the ❌ emoji, which matplotlib's default font (DejaVu Sans) cannot render
ax1.text(5, 3.2, '✗ Violates i.i.d. assumption', ha='center', fontsize=10)
ax1.text(5, 2.5, '✗ Network overfits to recent data', ha='center', fontsize=10)
ax1.text(5, 1.8, '✗ Forgets older experiences quickly', ha='center', fontsize=10)

# Right: Experience replay (uncorrelated)
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('EXPERIENCE REPLAY\n(Random Sampling)', fontsize=14, fontweight='bold', color='#388e3c')

# Memory buffer
buffer_box = FancyBboxPatch((1, 6), 8, 2.5, boxstyle="round,pad=0.1",
                             facecolor='#e3f2fd', edgecolor='#1976d2', linewidth=2)
ax2.add_patch(buffer_box)
ax2.text(5, 8.7, 'REPLAY BUFFER (Memory)', ha='center', fontsize=11, fontweight='bold', color='#1976d2')

# Experiences in buffer (diverse colors)
colors = ['#e57373', '#81c784', '#64b5f6', '#ffb74d', '#ba68c8', '#4db6ac']
for i, color in enumerate(colors):
    x = 1.5 + i * 1.2
    box = FancyBboxPatch((x, 6.5), 0.9, 0.9, boxstyle="round,pad=0.02",
                          facecolor=color, edgecolor='black', linewidth=1)
    ax2.add_patch(box)
    ax2.text(x + 0.45, 7, f'e{i+1}', ha='center', fontsize=9, fontweight='bold')

# Random sampling
ax2.text(5, 5.3, '↓ Random Sample ↓', ha='center', fontsize=10, color='#388e3c')

# Sampled batch (random order/colors)
sampled_colors = [colors[2], colors[5], colors[0], colors[4]]
ax2.text(5, 4.5, 'Training Batch:', ha='center', fontsize=10, fontweight='bold')
for i, color in enumerate(sampled_colors):
    x = 2 + i * 1.5
    box = FancyBboxPatch((x, 3.5), 0.9, 0.9, boxstyle="round,pad=0.02",
                          facecolor=color, edgecolor='black', linewidth=1)
    ax2.add_patch(box)

ax2.text(5, 2.5, '✓ Diverse, uncorrelated samples', ha='center', fontsize=10, color='#388e3c')
ax2.text(5, 1.8, '✓ Reuses past experiences', ha='center', fontsize=10, color='#388e3c')
ax2.text(5, 1.1, '✓ More stable training!', ha='center', fontsize=10, color='#388e3c', fontweight='bold')

plt.tight_layout()
plt.show()

print("\n" + "="*70)
print("THE CORRELATION PROBLEM")
print("="*70)
print("""
Neural networks trained with stochastic gradient descent expect roughly
i.i.d. (independent and identically distributed) training data.

In RL, consecutive experiences are HIGHLY CORRELATED:
  - State at time t is very similar to state at time t+1
  - Within one episode, the agent stays in the same region of the state space

This causes the network to:
  1. Overfit to recent experiences
  2. Forget how to handle older situations ("catastrophic forgetting")
  3. Oscillate wildly during training

SOLUTION: Store experiences and randomly sample from them!
""")
print("="*70)

---
## Implementing a Replay Buffer

A replay buffer is like a library with limited shelf space:

```
    ┌────────────────────────────────────────────────────────────────┐
    │                    REPLAY BUFFER                               │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  STORAGE:                                                     │
    │    Each experience = (state, action, reward, next_state, done)│
    │    Capacity = max number of experiences to store              │
    │    When full: remove oldest experience (FIFO)                 │
    │                                                                │
    │  OPERATIONS:                                                  │
    │    push(s, a, r, s', done) → Store new experience            │
    │    sample(batch_size) → Random sample of experiences         │
    │                                                                │
    │  TYPICAL SIZES:                                               │
    │    CartPole: 10,000 - 100,000                                 │
    │    Atari: 1,000,000                                           │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
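
The `deque`-based buffer in the next cell is the simplest option, but Python's `deque` only offers O(1) access at its two ends; indexing into the middle is O(n), so `random.sample` gets slower as the buffer grows. A common alternative is a preallocated list used as a circular (ring) buffer: `push` overwrites the oldest slot in O(1), and any slot can be read in O(1). The class below is a hypothetical stdlib-only sketch of that idea (the name `RingReplayBuffer` is ours, not a library API), and it returns plain Python lists instead of tensors.

```python
import random


class RingReplayBuffer:
    """Fixed-capacity FIFO replay buffer backed by a preallocated list."""

    def __init__(self, capacity, seed=None):
        self.capacity = capacity
        self.storage = [None] * capacity  # preallocated slots
        self.next_idx = 0                 # slot the next push writes to
        self.size = 0                     # number of filled slots
        self.rng = random.Random(seed)

    def push(self, state, action, reward, next_state, done):
        # When full, this overwrites the oldest transition (FIFO eviction).
        self.storage[self.next_idx] = (state, action, reward, next_state, done)
        self.next_idx = (self.next_idx + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        # Uniform sampling without replacement over the filled slots.
        # List indexing is O(1), so the cost is O(batch_size).
        indices = self.rng.sample(range(self.size), batch_size)
        batch = [self.storage[i] for i in indices]
        # Transpose list of tuples -> (states, actions, rewards, next_states, dones)
        return tuple(list(field) for field in zip(*batch))

    def __len__(self):
        return self.size


buf = RingReplayBuffer(capacity=3, seed=0)
for t in range(5):                      # push 5 transitions into 3 slots
    buf.push(t, t % 2, 1.0, t + 1, t == 4)

print(len(buf))                         # 3
print(sorted(s[0] for s in buf.storage))  # [2, 3, 4]: transitions 0 and 1 were evicted
states, actions, rewards, next_states, dones = buf.sample(2)
print(states)
```

The trade-off: `deque(maxlen=...)` gives eviction for free, while the ring buffer adds a little index bookkeeping in exchange for constant-time random reads.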

In [None]:
class ReplayBuffer:
    """
    Standard Uniform Replay Buffer.
    
    Stores experiences and allows random sampling for training.
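    Uses a deque with maxlen: appending is O(1) and the oldest experience
    is evicted automatically. (Indexing into the middle of a deque is O(n),
    so sampling slows down somewhat for very large buffers.)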
    
    The key insight: Random sampling breaks temporal correlations!
    """
    
    def __init__(self, capacity):
        """
        Args:
            capacity: Maximum number of experiences to store.
                      Older experiences are automatically removed.
        """
        self.buffer = deque(maxlen=capacity)
        self.capacity = capacity
    
    def push(self, state, action, reward, next_state, done):
        """
        Store a new experience.
        
        If buffer is full, oldest experience is automatically removed.
        """
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        """
        Randomly sample a batch of experiences.
        
        This is where the magic happens - random sampling means:
        - Experiences from different episodes in same batch
        - Experiences from different times in same batch
        - Approximately i.i.d. training data!
        
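        Returns:
            Tuple of (states, actions, rewards, next_states, dones) tensors

        Raises:
            ValueError: if batch_size exceeds the number of stored experiences
                (random.sample cannot draw more items than it is given).
        """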
        # Random sample (this breaks correlations!)
        batch = random.sample(self.buffer, batch_size)
        
        # Unpack and convert to tensors
        states, actions, rewards, next_states, dones = zip(*batch)
        
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones)
        )
    
    def __len__(self):
        return len(self.buffer)


# Demonstrate the buffer
print("REPLAY BUFFER DEMONSTRATION")
print("="*60)

buffer = ReplayBuffer(capacity=1000)

# Simulate collecting experiences
print("\nSimulating 100 experiences...")
for i in range(100):
    state = np.array([i * 0.01, np.sin(i * 0.1), np.cos(i * 0.1), i * 0.001])
    action = i % 2
    reward = 1.0 if i % 10 == 0 else 0.0
    next_state = np.array([(i+1) * 0.01, np.sin((i+1) * 0.1), np.cos((i+1) * 0.1), (i+1) * 0.001])
    done = (i == 99)
    buffer.push(state, action, reward, next_state, done)

print(f"Buffer size: {len(buffer)} / {buffer.capacity}")

# Sample a batch
print("\nSampling a batch of 5 experiences...")
states, actions, rewards, next_states, dones = buffer.sample(5)

print("\nSampled states (state[0] = 0.01 × time step, so it shows when each was collected):")
for i, s in enumerate(states):
    print(f"  Sample {i+1}: state[0] = {s[0].item():.3f}")

print("\nNotice: the sampled time steps are scattered, not consecutive!")
print("This is the key benefit - random sampling breaks correlations.")
print("="*60)
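
In a DQN training loop, the buffer is usually filled for a while before the first gradient step. `random.sample` raises `ValueError` when asked for more items than the buffer holds, and batches drawn from a nearly empty buffer would come from only a handful of consecutive steps. The sketch below shows that control flow with a fake environment. `fake_env_step`, `LEARNING_STARTS` and the update counter are hypothetical placeholders (not part of the notebook), and the actual Q-network update from Notebook 2 would go where the comment is.

```python
import random
from collections import deque

BATCH_SIZE = 32
LEARNING_STARTS = 500          # hypothetical warm-up length, in environment steps
TOTAL_STEPS = 1000

rng = random.Random(0)
replay_memory = deque(maxlen=10_000)   # same FIFO behaviour as ReplayBuffer


def fake_env_step(state):
    """Placeholder environment: random-walk state, reward 1, rare termination."""
    next_state = state + rng.uniform(-1.0, 1.0)
    done = rng.random() < 0.02
    return next_state, 1.0, done


state = 0.0
num_updates = 0
for step in range(TOTAL_STEPS):
    action = rng.randrange(2)                  # stand-in for epsilon-greedy
    next_state, reward, done = fake_env_step(state)
    replay_memory.append((state, action, reward, next_state, done))
    state = 0.0 if done else next_state        # reset when the episode ends

    # Only start learning once the buffer holds enough transitions.
    if len(replay_memory) >= LEARNING_STARTS:
        batch = rng.sample(replay_memory, BATCH_SIZE)
        # ... compute TD targets and take one gradient step on `batch` here ...
        num_updates += 1

print(f"transitions stored: {len(replay_memory)}, gradient updates: {num_updates}")
```

With these settings the loop stores 1000 transitions and performs 501 updates: learning starts at the step where the buffer first reaches 500 entries.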

---
## Why Random Sampling Matters: Visualization

Let's see the difference in what the network "sees" during training:

In [None]:
# Visualize the difference in training data
np.random.seed(42)

# Simulate 100 timesteps of experience
timesteps = np.arange(100)
state_values = np.sin(timesteps * 0.1) + 0.1 * np.random.randn(100)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Left: All experiences
ax1 = axes[0]
ax1.plot(timesteps, state_values, 'b-', alpha=0.7, linewidth=1)
ax1.scatter(timesteps, state_values, c='blue', s=20, alpha=0.7)
ax1.set_xlabel('Time Step', fontsize=11)
ax1.set_ylabel('State Value', fontsize=11)
ax1.set_title('All Collected Experiences', fontsize=12, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Middle: Online learning (sequential batch)
ax2 = axes[1]
# Take consecutive samples (last 8)
online_indices = np.arange(92, 100)
ax2.plot(timesteps, state_values, 'gray', alpha=0.2, linewidth=1)
ax2.scatter(timesteps, state_values, c='gray', s=10, alpha=0.2)
ax2.scatter(online_indices, state_values[online_indices], c='red', s=80, 
           edgecolors='black', linewidth=2, zorder=5, label='Training Batch')
ax2.axvspan(92, 99, alpha=0.2, color='red')
ax2.set_xlabel('Time Step', fontsize=11)
ax2.set_title('Online Learning\n(Consecutive Samples)', fontsize=12, fontweight='bold', color='#d32f2f')
ax2.grid(True, alpha=0.3)
ax2.legend(loc='upper right')

# Right: Replay (random batch)
ax3 = axes[2]
# Random samples
replay_indices = np.random.choice(100, 8, replace=False)
ax3.plot(timesteps, state_values, 'gray', alpha=0.2, linewidth=1)
ax3.scatter(timesteps, state_values, c='gray', s=10, alpha=0.2)
ax3.scatter(replay_indices, state_values[replay_indices], c='green', s=80,
           edgecolors='black', linewidth=2, zorder=5, label='Training Batch')
for idx in replay_indices:
    ax3.axvline(x=idx, alpha=0.1, color='green')
ax3.set_xlabel('Time Step', fontsize=11)
ax3.set_title('Experience Replay\n(Random Samples)', fontsize=12, fontweight='bold', color='#388e3c')
ax3.grid(True, alpha=0.3)
ax3.legend(loc='upper right')

plt.tight_layout()
plt.show()

# Show the actual values
print("\nSAMPLED VALUES COMPARISON:")
print("-"*50)
print(f"Online (consecutive): {state_values[online_indices].round(2)}")
print(f"  → Very similar values! High correlation.")
print(f"\nReplay (random): {state_values[replay_indices].round(2)}")
print(f"  → Diverse values! Low correlation.")

---
## Prioritized Experience Replay: Learning From Mistakes

Not all experiences are equally valuable! Some teach us more than others.

```
    ┌────────────────────────────────────────────────────────────────┐
    │            PRIORITIZED EXPERIENCE REPLAY                       │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  IDEA: Sample experiences that are more "surprising" more often│
    │                                                                │
    │  What makes an experience surprising?                         │
    │    High TD Error = |target - prediction|                      │
    │                                                                │
    │    If TD error is high → We predicted poorly                  │
    │    → This experience has more to teach us!                    │
    │                                                                │
    │  PRIORITY FORMULA:                                            │
    │    priority = |TD error| + ε   (ε is small constant)          │
    │                                                                │
    │  SAMPLING:                                                    │
    │    P(sample experience i) ∝ priority_i^α                      │
    │    α controls how much to prioritize (0 = uniform, 1 = full)  │
    │                                                                │
    │  ANALOGY:                                                     │
    │    Studying for an exam: spend more time on problems you      │
    │    got WRONG, not the ones you already know!                  │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
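
The sampling rule in the box above can be checked by hand. Below, a stdlib-only sketch computes P(i) = p_i^α / Σ_j p_j^α for three made-up priorities (1, 2, 8) at α = 0, 0.6 and 1, showing how α interpolates between uniform and fully proportional sampling. The priority values are illustrative only.

```python
def sampling_probs(priorities, alpha):
    """P(i) = p_i**alpha / sum_j p_j**alpha (prioritized replay sampling rule)."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


priorities = [1.0, 2.0, 8.0]   # hypothetical |TD error| + eps values
for alpha in (0.0, 0.6, 1.0):
    probs = sampling_probs(priorities, alpha)
    print(f"alpha={alpha}: " + ", ".join(f"{p:.3f}" for p in probs))
```

At α = 0 every experience gets 1/3. At α = 1 the largest priority takes 8/11 ≈ 73% of the probability mass. The commonly used α = 0.6 sits in between, at about 58%, so high-error experiences are favoured without starving the rest.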

In [None]:
class PrioritizedReplayBuffer:
    """
    Prioritized Experience Replay Buffer.
    
    Samples experiences with high TD error more frequently.
    This focuses learning on "surprising" experiences.
    
    Key insight: Experiences we predict poorly are more informative!
    """
    
    def __init__(self, capacity, alpha=0.6, epsilon=1e-6):
        """
        Args:
            capacity: Maximum buffer size
            alpha: How much to prioritize (0=uniform, 1=fully prioritized)
            epsilon: Small constant added to priorities (ensures non-zero)
        """
        self.capacity = capacity
        self.alpha = alpha
        self.epsilon = epsilon
        
        self.buffer = []
        self.priorities = []
        self.position = 0
    
    def push(self, state, action, reward, next_state, done):
        """
        Store experience with maximum priority.
        
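        New experiences get max priority so they are likely to be sampled
        soon; their real TD error is then recorded by update_priorities().
        """
        # O(n) scan: fine for a demo; a large buffer could track the running max instead
        max_priority = max(self.priorities) if self.priorities else 1.0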
        
        experience = (state, action, reward, next_state, done)
        
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
            self.priorities.append(max_priority)
        else:
            self.buffer[self.position] = experience
            self.priorities[self.position] = max_priority
        
        self.position = (self.position + 1) % self.capacity
    
    def sample(self, batch_size):
        """
        Sample experiences with probability proportional to priority.
        
        P(i) = priority_i^alpha / sum(priority_j^alpha)
        """
        # Compute sampling probabilities
        priorities = np.array(self.priorities) ** self.alpha
        probabilities = priorities / priorities.sum()
        
        # Sample indices according to priorities (with replacement, so a batch may contain duplicates)
        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
        
        # Get experiences
        batch = [self.buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)
        
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones),
            indices,
            torch.FloatTensor(probabilities[indices])  # For importance sampling
        )
    
    def update_priorities(self, indices, td_errors):
        """
        Update priorities based on new TD errors.
        
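        Called after a training step with the computed TD errors.
        Pass plain numbers or a NumPy array (e.g. td.detach().cpu().numpy()),
        not tensors that still require grad.
        """
        for idx, td_error in zip(indices, np.asarray(td_errors).ravel()):
            self.priorities[idx] = float(abs(td_error)) + self.epsilon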
    
    def __len__(self):
        return len(self.buffer)


# Demonstrate prioritized replay
print("PRIORITIZED REPLAY DEMONSTRATION")
print("="*60)

per_buffer = PrioritizedReplayBuffer(capacity=100, alpha=0.6)

# Add experiences
for i in range(100):
    state = np.random.randn(4)
    per_buffer.push(state, i % 2, np.random.randn(), np.random.randn(4), False)

# Simulate updating priorities (some experiences are "surprising")
for i in range(0, 100, 10):
    per_buffer.priorities[i] = 5.0  # High priority (big TD error)

print("\nAfter setting high priority for indices 0, 10, 20, 30, ...")
print("\nSampling 100 batches of 20 experiences (2000 draws):")

# Sample and count how often high-priority experiences appear
high_priority_count = 0
for _ in range(100):
    _, _, _, _, _, indices, _ = per_buffer.sample(20)
    high_priority_count += sum(1 for i in indices if i % 10 == 0)

print(f"High-priority experiences sampled: {high_priority_count} / 2000")
print("Expected if uniform: ~200 (10% of experiences)")
print("Expected with alpha=0.6: ~450 (5**0.6 ≈ 2.63, so 26.3 / (26.3 + 90) ≈ 22.6%)")
print(f"\nWith prioritization, surprising experiences are sampled MORE OFTEN!")
print("="*60)
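
`PrioritizedReplayBuffer.sample` returns `probabilities[indices]` "for importance sampling", but the notebook stops there. Sampling non-uniformly biases the expected gradient, and the prioritized replay paper (Schaul et al.) corrects this with weights w_i = (N · P(i))^(-β), normalized by a maximum weight so they only scale updates down, with β annealed toward 1 during training. The stdlib sketch below computes such weights for one batch. The function name `importance_weights` and the β values are assumptions. Implementations differ on the normalizer (largest weight in the batch vs. largest possible weight over the whole buffer); this sketch uses the batch maximum.

```python
def importance_weights(sample_probs, buffer_size, beta):
    """w_i = (N * P(i)) ** (-beta), scaled so the largest weight in the batch is 1."""
    raw = [(buffer_size * p) ** (-beta) for p in sample_probs]
    max_w = max(raw)
    return [w / max_w for w in raw]


# Hypothetical batch from a buffer of N = 100 transitions:
# one high-priority item (P = 0.05) and two ordinary ones (P = 0.005).
probs = [0.05, 0.005, 0.005]
for beta in (0.0, 0.4, 1.0):
    w = importance_weights(probs, buffer_size=100, beta=beta)
    print(f"beta={beta}: " + ", ".join(f"{x:.3f}" for x in w))

# In a training step, each squared TD error would be multiplied by its
# weight before averaging, e.g.
#   loss = sum(w_i * td_i ** 2 for w_i, td_i in zip(w, td_errors)) / len(w)
```

The frequently sampled item gets the smallest weight. It is drawn 10× more often than the others, and at β = 1 its weight is exactly 0.1 relative to them, which cancels that bias; β = 0 applies no correction.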

In [None]:
# Visualize prioritized vs uniform sampling
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Create sample priorities
np.random.seed(42)
n_experiences = 50
priorities = np.random.exponential(1, n_experiences)
# Make a few experiences very high priority
priorities[5] = 10
priorities[22] = 8
priorities[38] = 12

# Left: Priorities
ax1 = axes[0]
colors = ['red' if p > 5 else 'steelblue' for p in priorities]
ax1.bar(range(n_experiences), priorities, color=colors, edgecolor='black', linewidth=0.5)
ax1.set_xlabel('Experience Index', fontsize=11)
ax1.set_ylabel('Priority (|TD Error|)', fontsize=11)
ax1.set_title('Experience Priorities\n(Red = High TD Error)', fontsize=12, fontweight='bold')
ax1.axhline(y=np.mean(priorities), color='gray', linestyle='--', label=f'Mean: {np.mean(priorities):.1f}')
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# Right: Sampling probabilities comparison
ax2 = axes[1]

# Uniform probabilities
uniform_probs = np.ones(n_experiences) / n_experiences

# Prioritized probabilities (alpha=0.6)
alpha = 0.6
prioritized_probs = (priorities ** alpha) / (priorities ** alpha).sum()

x = np.arange(n_experiences)
width = 0.35

ax2.bar(x - width/2, uniform_probs * 100, width, label='Uniform', color='#90caf9', edgecolor='black')
ax2.bar(x + width/2, prioritized_probs * 100, width, label='Prioritized', color='#ef5350', edgecolor='black')

ax2.set_xlabel('Experience Index', fontsize=11)
ax2.set_ylabel('Sampling Probability (%)', fontsize=11)
ax2.set_title('Uniform vs Prioritized Sampling', fontsize=12, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nHigh-priority experiences (5, 22, 38) are sampled much more often!")
print(f"Uniform probability: {uniform_probs[5]*100:.1f}%")
print(f"Prioritized probability: {prioritized_probs[5]*100:.1f}% (experience 5)")

---
## Summary: Key Takeaways

### The Problem

| Issue | Description |
|-------|-------------|
| **Correlation** | Consecutive experiences are nearly identical |
| **Forgetting** | Network overfits to recent data, forgets past |
| **Instability** | Training oscillates wildly |

### The Solution

| Component | How It Helps |
|-----------|-------------|
| **Replay Buffer** | Stores past experiences for later use |
| **Random Sampling** | Breaks temporal correlations |
| **Data Reuse** | Same experience can train multiple times |

### Uniform vs Prioritized

| Type | Sampling | When to Use |
|------|----------|-------------|
| **Uniform** | All equal probability | Default choice: simple, cheap, robust |
| **Prioritized** | High TD-error more likely | When sample efficiency matters; needs importance-sampling correction and extra bookkeeping |

---
## Test Your Understanding

**1. Why is correlated training data bad for neural networks?**
<details>
<summary>Click to reveal answer</summary>
Stochastic gradient descent, the usual way neural networks are trained, assumes roughly i.i.d. (independent and identically distributed) training data. Correlated data causes the network to overfit to recent patterns while forgetting older ones ("catastrophic forgetting"). The gradient updates become biased toward the current situation rather than the overall task.
</details>

**2. How does experience replay break correlations?**
<details>
<summary>Click to reveal answer</summary>
By storing experiences in a buffer and randomly sampling from it, we mix experiences from different times and episodes in each training batch. A batch can contain older experiences (as long as they are still in the buffer) alongside recent ones, making the data approximately i.i.d.
</details>

**3. What does "priority" mean in prioritized experience replay?**
<details>
<summary>Click to reveal answer</summary>
Priority is based on TD error (|target - prediction|). Experiences with high TD error were poorly predicted, meaning they have more to teach the network. By sampling these more often, we focus learning on the most informative experiences.
</details>

**4. Why do new experiences get max priority in PER?**
<details>
<summary>Click to reveal answer</summary>
New experiences haven't been trained on yet, so their TD error is unknown. Giving them max priority makes them very likely to be sampled soon, after which their priority is updated based on their actual TD error.
</details>

---
## What's Next?

Experience replay is half the story. The other key innovation is **Target Networks**!

**Continue to:** [Notebook 4: Target Networks](04_target_networks.ipynb)

---

*Experience replay: "Learn from the past, not just the present."*