# Part 4.3: Recurrent Neural Networks (RNNs) -- The Formula 1 Edition

Sequential data is everywhere: language, music, stock prices, weather, DNA. Unlike images where pixels exist in a grid, sequential data has an inherent **order** and **temporal context**. A word's meaning depends on the words before it. A stock price depends on its history. RNNs are the first architecture designed to handle this challenge by maintaining a **memory** of what came before.

**F1 analogy:** A Grand Prix is the ultimate sequential data problem. Every lap depends on what happened in previous laps -- tire degradation accumulates, fuel burns off making the car lighter, track evolution (rubber laid down) changes grip levels. You cannot understand lap 45 without knowing the history of laps 1-44. An RNN is like the car's accumulated state: the hidden state carries forward everything the network "remembers" about the race so far -- tire wear, fuel load, track conditions -- and updates it with each new lap of data.

---

## Learning Objectives

By the end of this notebook, you should be able to:

- [ ] Explain why feedforward networks fail on sequential data
- [ ] Implement a vanilla RNN from scratch in NumPy
- [ ] Describe backpropagation through time and the vanishing gradient problem
- [ ] Explain the LSTM architecture and the purpose of each gate
- [ ] Compare GRU and LSTM and know when to use each
- [ ] Use PyTorch's nn.RNN, nn.LSTM, and nn.GRU modules correctly
- [ ] Build a character-level language model for text generation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')
torch.manual_seed(42)
np.random.seed(42)

---

## 1. Why Sequences Need Special Treatment

### Intuitive Explanation

Consider predicting the next word in a sentence:

- "The clouds are dark, it will probably **___**" --> "rain"
- "She picked up the phone and **___**" --> "called"

A feedforward network takes a **fixed-size input** and produces a **fixed-size output**. But sequences break this assumption in three fundamental ways:

1. **Variable length**: Sentences can be 5 words or 500 words
2. **Order matters**: "dog bites man" vs "man bites dog" have the same words but opposite meanings
3. **Context accumulates**: Understanding word 50 may require remembering word 3
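
The "order matters" point can be checked directly. The sketch below is an illustration only: `bag_of_words` is a hypothetical helper (not part of this notebook) standing in for any representation that pools words without regard to position.

```python
from collections import Counter

def bag_of_words(sentence):
    """Order-free representation: maps each word to its count."""
    return Counter(sentence.lower().split())

s1 = "dog bites man"
s2 = "man bites dog"

# Pooling the words discards position, so the two sentences collide.
same_bag = bag_of_words(s1) == bag_of_words(s2)

# Keeping the words as an ordered sequence tells them apart.
same_sequence = s1.split() == s2.split()

print(same_bag)       # True
print(same_sequence)  # False
```

Any model that only sees the pooled representation cannot distinguish the two sentences, no matter how it is trained.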

| Problem | Feedforward Approach | Why It Fails | F1 Parallel |
|---------|---------------------|--------------|-------------|
| Text classification | Average all word vectors into one vector | Loses word order | Averaging all lap times ignores the race trajectory |
| Time series prediction | Fixed window of past values | Can't adapt window size | Looking at only the last 5 laps misses a tire change 10 laps ago |
| Speech recognition | Fixed audio chunk | Utterances vary in length | Races vary from 44 to 78 laps depending on the circuit |
| Machine translation | Fixed input/output size | Languages have different sentence lengths | Stint length varies with compound, fuel load, and strategy |

**The key insight:** We need an architecture that can process inputs **one step at a time** while maintaining a **running summary** (memory) of everything it has seen so far.

**F1 analogy:** Think of a race engineer monitoring a Grand Prix. They do not wait until the race is over to analyze it -- they process each lap as it happens, updating their running understanding of tire degradation, fuel burn, and relative pace. Their mental model after lap 30 incorporates everything from laps 1-30. That running mental model is the hidden state of an RNN.

### Visualization: Feedforward vs Sequential Processing

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Feedforward approach (all at once)
ax = axes[0]
words = ['The', 'cat', 'sat', 'on', 'mat']
# Draw input words
for i, word in enumerate(words):
    ax.text(i * 0.8 + 0.2, 0.0, word, ha='center', va='center',
            fontsize=12, bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue', edgecolor='blue'))
    # Arrow from word to single box
    ax.annotate('', xy=(2.0, 0.4), xytext=(i * 0.8 + 0.2, 0.15),
                arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))

# Single processing box
ax.add_patch(plt.Rectangle((1.0, 0.4), 2.0, 0.3, facecolor='lightyellow', edgecolor='orange', lw=2))
ax.text(2.0, 0.55, 'Feedforward\nNetwork', ha='center', va='center', fontsize=11, fontweight='bold')

# Output
ax.annotate('', xy=(2.0, 0.9), xytext=(2.0, 0.7),
            arrowprops=dict(arrowstyle='->', color='orange', lw=2))
ax.text(2.0, 1.0, 'Output', ha='center', va='center', fontsize=12,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='lightyellow', edgecolor='orange'))

ax.set_xlim(-0.3, 4.0)
ax.set_ylim(-0.3, 1.2)
ax.set_title('Feedforward: All inputs at once\n(loses order information)', fontsize=13)
ax.axis('off')

# Right: Sequential approach (one at a time with memory)
ax = axes[1]
colors_seq = ['#a8d8ea', '#aa96da', '#fcbad3', '#ffffd2', '#a8e6cf']
for i, word in enumerate(words):
    x_pos = i * 0.8 + 0.1
    # Word input
    ax.text(x_pos, 0.0, word, ha='center', va='center', fontsize=12,
            bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue', edgecolor='blue'))
    # Arrow up to hidden state
    ax.annotate('', xy=(x_pos, 0.35), xytext=(x_pos, 0.15),
                arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
    # Hidden state box
    ax.add_patch(plt.Rectangle((x_pos - 0.25, 0.35), 0.5, 0.25,
                                facecolor=colors_seq[i], edgecolor='purple', lw=2, alpha=0.8))
    ax.text(x_pos, 0.475, f'h{i}', ha='center', va='center', fontsize=11, fontweight='bold')
    # Arrow between hidden states
    if i < len(words) - 1:
        ax.annotate('', xy=((i + 1) * 0.8 + 0.1 - 0.25, 0.475),
                    xytext=(x_pos + 0.25, 0.475),
                    arrowprops=dict(arrowstyle='->', color='purple', lw=2))

# Final output
last_x = (len(words) - 1) * 0.8 + 0.1
ax.annotate('', xy=(last_x, 0.8), xytext=(last_x, 0.6),
            arrowprops=dict(arrowstyle='->', color='green', lw=2))
ax.text(last_x, 0.9, 'Output', ha='center', va='center', fontsize=12,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgreen', edgecolor='green'))

ax.text(1.6, 0.75, 'Memory flows forward\nthrough hidden states', ha='center',
        fontsize=10, fontstyle='italic', color='purple')

ax.set_xlim(-0.3, 4.0)
ax.set_ylim(-0.3, 1.15)
ax.set_title('Sequential (RNN): One input at a time\n(preserves order via hidden state)', fontsize=13)
ax.axis('off')

plt.tight_layout()
plt.show()

---

## 2. Vanilla RNN

### Intuitive Explanation

An RNN processes a sequence one element at a time, maintaining a **hidden state** that acts as its memory. At each time step, the RNN:

1. Looks at the current input $x_t$
2. Looks at its previous memory $h_{t-1}$
3. Combines them to create a new memory $h_t$

**Analogy:** Imagine reading a book one word at a time. After each word, your understanding of the story (hidden state) updates. You don't re-read the entire book at every word -- you carry a summary in your head and update it.

**F1 analogy:** The hidden state is the car's accumulated wear state. At each lap $t$, the RNN receives new data (lap time, tire temps, fuel load) and combines it with the car's accumulated state from all previous laps. After lap 30, $h_{30}$ encodes everything the network knows about the car's condition -- how much rubber is left, how the balance has shifted, how the fuel burn is tracking. The same update rule applies at every lap, just as the same physics governs tire degradation whether it is lap 5 or lap 50.

### The Recurrence Relation

$$h_t = \tanh(W_{hh} \cdot h_{t-1} + W_{xh} \cdot x_t + b_h)$$
$$y_t = W_{hy} \cdot h_t + b_y$$

#### Breaking down the formula:

| Component | Shape | Meaning | F1 Parallel |
|-----------|-------|---------|-------------|
| $x_t$ | (input_size,) | Input at time step t | Lap t data: [lap_time, tire_temp, fuel_load, gap_to_leader] |
| $h_{t-1}$ | (hidden_size,) | Previous hidden state (memory) | Car's accumulated state after lap t-1 |
| $W_{xh}$ | (hidden_size, input_size) | How to process new input | How new lap data updates the car state |
| $W_{hh}$ | (hidden_size, hidden_size) | How to process previous memory | How the car's existing state carries forward |
| $b_h$ | (hidden_size,) | Bias term | -- |
| $\tanh$ | -- | Squashes each value into (-1, 1), keeping the hidden state bounded | -- |
| $h_t$ | (hidden_size,) | New hidden state (updated memory) | Car's accumulated state after lap t |
| $W_{hy}$ | (output_size, hidden_size) | Transform memory to output | Predict next-lap performance from current state |

**What this means:** The hidden state $h_t$ is a compressed summary of everything the network has seen from time step 0 to t. The same weights $W_{hh}$ and $W_{xh}$ are shared across all time steps -- the RNN applies the same transformation at every step.
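
To make the weight-sharing point concrete, here is a minimal NumPy sketch. The names `rnn_step` and `final_hidden_state`, the sizes, and the random weights are assumptions chosen for illustration. It applies one fixed set of parameters to a 5-step and a 500-step sequence.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4

# One set of parameters, created once and reused at every time step.
W_xh = rng.normal(scale=0.5, size=(hidden_size, input_size))
W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
b_h = np.zeros(hidden_size)

def rnn_step(x_t, h_prev):
    """One application of h_t = tanh(W_hh h_{t-1} + W_xh x_t + b_h)."""
    return np.tanh(W_hh @ h_prev + W_xh @ x_t + b_h)

def final_hidden_state(sequence):
    """Apply rnn_step across a sequence of any length, starting from zeros."""
    h = np.zeros(hidden_size)
    for x_t in sequence:
        h = rnn_step(x_t, h)
    return h

short_seq = rng.normal(size=(5, input_size))
long_seq = rng.normal(size=(500, input_size))

h_short = final_hidden_state(short_seq)
h_long = final_hidden_state(long_seq)

# Parameter count depends only on the layer sizes, not on sequence length.
n_params = W_xh.size + W_hh.size + b_h.size
print(n_params, h_short.shape, h_long.shape)
```

Both sequences end in a hidden state of the same shape, and the parameter count never changes. This is what lets an RNN handle variable-length input.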

### Visualization: Unrolling Through Time

In [None]:
fig, ax = plt.subplots(figsize=(14, 6))

n_steps = 5
labels = ['x0', 'x1', 'x2', 'x3', 'x4']
spacing = 2.5

for t in range(n_steps):
    x = t * spacing
    
    # Input
    ax.text(x, 0, labels[t], ha='center', va='center', fontsize=13,
            bbox=dict(boxstyle='round,pad=0.4', facecolor='lightblue', edgecolor='blue', lw=2))
    
    # Arrow from input to hidden
    ax.annotate('', xy=(x, 1.2), xytext=(x, 0.35),
                arrowprops=dict(arrowstyle='->', color='blue', lw=1.5))
    ax.text(x + 0.3, 0.75, '$W_{xh}$', fontsize=10, color='blue')
    
    # Hidden state
    ax.add_patch(plt.Rectangle((x - 0.6, 1.2), 1.2, 0.8,
                                facecolor='#e8daef', edgecolor='purple', lw=2.5, zorder=3))
    ax.text(x, 1.6, f'$h_{t}$', ha='center', va='center', fontsize=14, fontweight='bold', zorder=4)
    ax.text(x, 1.3, 'tanh', ha='center', va='center', fontsize=9, color='gray', zorder=4)
    
    # Arrow from hidden to output
    ax.annotate('', xy=(x, 2.8), xytext=(x, 2.0),
                arrowprops=dict(arrowstyle='->', color='green', lw=1.5))
    ax.text(x + 0.3, 2.4, '$W_{hy}$', fontsize=10, color='green')
    
    # Output
    ax.text(x, 3.0, f'$y_{t}$', ha='center', va='center', fontsize=13,
            bbox=dict(boxstyle='round,pad=0.4', facecolor='lightgreen', edgecolor='green', lw=2))
    
    # Arrow between hidden states
    if t < n_steps - 1:
        ax.annotate('', xy=((t + 1) * spacing - 0.6, 1.6),
                    xytext=(x + 0.6, 1.6),
                    arrowprops=dict(arrowstyle='->', color='purple', lw=2.5))
        ax.text(x + spacing / 2, 1.85, '$W_{hh}$', ha='center', fontsize=10, color='purple')

# Initial hidden state
ax.annotate('', xy=(-0.6, 1.6), xytext=(-1.5, 1.6),
            arrowprops=dict(arrowstyle='->', color='purple', lw=2))
ax.text(-2.0, 1.6, '$h_{-1}$\n(zeros)', ha='center', va='center', fontsize=11,  # initial state, distinct from the h_0 box
        bbox=dict(boxstyle='round,pad=0.3', facecolor='lavender', edgecolor='purple'))

ax.text(n_steps * spacing / 2 - 0.5, 3.7,
        'Same weights (W_xh, W_hh, W_hy) shared across ALL time steps',
        ha='center', fontsize=12, fontstyle='italic', color='darkred',
        bbox=dict(boxstyle='round,pad=0.4', facecolor='#fff3e0', edgecolor='darkred', alpha=0.8))

ax.set_xlim(-2.5, n_steps * spacing)
ax.set_ylim(-0.5, 4.2)
ax.set_title('RNN Unrolled Through Time', fontsize=15, fontweight='bold')
ax.axis('off')
plt.tight_layout()
plt.show()

### Implementing a Vanilla RNN from Scratch

In [None]:
class VanillaRNN:
    """
    Vanilla RNN implemented from scratch in NumPy.
    
    Args:
        input_size: Dimension of input at each time step
        hidden_size: Dimension of hidden state
        output_size: Dimension of output at each time step
    """
    def __init__(self, input_size, hidden_size, output_size):
        # Xavier initialization for weights
        scale_xh = np.sqrt(2.0 / (input_size + hidden_size))
        scale_hh = np.sqrt(2.0 / (hidden_size + hidden_size))
        scale_hy = np.sqrt(2.0 / (hidden_size + output_size))
        
        self.W_xh = np.random.randn(hidden_size, input_size) * scale_xh
        self.W_hh = np.random.randn(hidden_size, hidden_size) * scale_hh
        self.b_h = np.zeros(hidden_size)
        
        self.W_hy = np.random.randn(output_size, hidden_size) * scale_hy
        self.b_y = np.zeros(output_size)
        
        self.hidden_size = hidden_size
    
    def forward(self, inputs, h_prev=None):
        """
        Process a sequence of inputs.
        
        Args:
            inputs: List of input vectors, each shape (input_size,)
            h_prev: Initial hidden state, shape (hidden_size,)
        
        Returns:
            outputs: List of output vectors
            hidden_states: List of hidden states (including initial)
        """
        if h_prev is None:
            h_prev = np.zeros(self.hidden_size)
        
        hidden_states = [h_prev]
        outputs = []
        
        for x_t in inputs:
            # Core RNN computation
            h_t = np.tanh(self.W_hh @ hidden_states[-1] + self.W_xh @ x_t + self.b_h)
            y_t = self.W_hy @ h_t + self.b_y
            
            hidden_states.append(h_t)
            outputs.append(y_t)
        
        return outputs, hidden_states

# Create an RNN and process a simple sequence
np.random.seed(42)
rnn = VanillaRNN(input_size=3, hidden_size=4, output_size=2)

# Create a sequence of 5 time steps, each with 3 features
sequence = [np.random.randn(3) for _ in range(5)]

outputs, hidden_states = rnn.forward(sequence)

print("Vanilla RNN from scratch:")
print(f"  Input size: 3, Hidden size: 4, Output size: 2")
print(f"  Sequence length: {len(sequence)}")
print(f"\nHidden states evolve over time:")
for t, h in enumerate(hidden_states):
    print(f"  h_{t}: {h.round(3)}")
print(f"\nOutputs at each step:")
for t, y in enumerate(outputs):
    print(f"  y_{t}: {y.round(3)}")

### Visualization: Hidden State Evolution

Let's see how the hidden state evolves as the RNN reads a sequence. We will feed in a sine wave and watch how each hidden unit responds. The weights here are random and untrained, so the variety across units comes from initialization rather than learning.

**F1 analogy:** Think of each hidden unit as a different aspect of the car's state. One might track tire degradation, another fuel burn, another track evolution. As the RNN reads lap-by-lap data, the units update together through $W_{hh}$, and after training each can capture a different facet of the car's evolving condition.

In [None]:
# Feed a sine wave through our RNN and watch hidden states
np.random.seed(42)
rnn_vis = VanillaRNN(input_size=1, hidden_size=8, output_size=1)

# Create sine wave input
t = np.linspace(0, 4 * np.pi, 50)
sine_input = [np.array([np.sin(ti)]) for ti in t]

outputs, hidden_states = rnn_vis.forward(sine_input)

# Convert to arrays for plotting
h_array = np.array(hidden_states[1:])  # Skip initial zero state
out_array = np.array(outputs).flatten()

fig, axes = plt.subplots(3, 1, figsize=(12, 10))

# Plot 1: Input signal
ax = axes[0]
ax.plot(t, [s[0] for s in sine_input], 'b-', linewidth=2, label='Input: sin(t)')
ax.set_xlabel('Time')
ax.set_ylabel('Input Value')
ax.set_title('Input Sequence (Sine Wave)')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 2: Hidden state evolution (each unit as a different color)
ax = axes[1]
colors = plt.cm.tab10(np.linspace(0, 1, 8))
for i in range(8):
    ax.plot(t, h_array[:, i], linewidth=1.5, alpha=0.8, color=colors[i],
            label=f'h[{i}]')
ax.set_xlabel('Time')
ax.set_ylabel('Hidden Unit Value')
ax.set_title('Hidden State Evolution (Each Unit Responds Differently)')
ax.legend(ncol=4, fontsize=9, loc='upper right')
ax.grid(True, alpha=0.3)

# Plot 3: Hidden state as heatmap
ax = axes[2]
im = ax.imshow(h_array.T, aspect='auto', cmap='RdBu', vmin=-1, vmax=1,
               extent=[t[0], t[-1], 7.5, -0.5])
ax.set_xlabel('Time')
ax.set_ylabel('Hidden Unit')
ax.set_title('Hidden State Heatmap (Red = -1, Blue = +1)')  # RdBu maps low values to red, high to blue
plt.colorbar(im, ax=ax, label='Activation')

plt.tight_layout()
plt.show()

print("Each hidden unit responds differently to the input, even with random, untrained weights.")
print("Training would shape these responses to track useful features of the signal.")

---

## 3. Backpropagation Through Time (BPTT)

### Intuitive Explanation

To train an RNN, we need to compute gradients. Since the RNN is unrolled through time, backpropagation must flow **backward through every time step**. This is called **Backpropagation Through Time (BPTT)**.

The gradient of the loss at time step $T$ with respect to earlier hidden states involves a **chain of multiplications** through $W_{hh}$:

$$\frac{\partial L_T}{\partial h_t} = \frac{\partial L_T}{\partial h_T} \cdot \prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}}$$

Each factor $\frac{\partial h_k}{\partial h_{k-1}}$ involves multiplying by $W_{hh}$ and the derivative of tanh.
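
For the tanh recurrence above, that factor works out to $\frac{\partial h_k}{\partial h_{k-1}} = \text{diag}(1 - h_k^2) \, W_{hh}$. The sketch below checks this against central finite differences. The sizes and random values are arbitrary assumptions made for the check.

```python
import numpy as np

rng = np.random.default_rng(0)
n, input_size = 4, 3
W_hh = rng.normal(scale=0.5, size=(n, n))
W_xh = rng.normal(scale=0.5, size=(n, input_size))
b_h = np.zeros(n)
x_k = rng.normal(size=input_size)
h_prev = np.tanh(rng.normal(size=n))

def step(h):
    """h_k as a function of h_{k-1}, with the input x_k held fixed."""
    return np.tanh(W_hh @ h + W_xh @ x_k + b_h)

# Analytic Jacobian: derivative of tanh times the recurrent weights.
h_k = step(h_prev)
jac_analytic = np.diag(1.0 - h_k ** 2) @ W_hh

# Numerical Jacobian, one column per perturbed component of h_{k-1}.
eps = 1e-6
jac_numeric = np.zeros((n, n))
for j in range(n):
    d = np.zeros(n)
    d[j] = eps
    jac_numeric[:, j] = (step(h_prev + d) - step(h_prev - d)) / (2 * eps)

max_err = np.max(np.abs(jac_analytic - jac_numeric))
print(max_err)
```

The two Jacobians agree to within finite-difference error. BPTT multiplies one such matrix per time step.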

**The problem:** Multiplying many numbers together causes exponential behavior:
- If each factor is < 1: the product **vanishes** (goes to 0)
- If each factor is > 1: the product **explodes** (goes to infinity)

This is the **vanishing/exploding gradient problem** -- the fundamental limitation of vanilla RNNs.

**F1 analogy:** This is like forgetting what happened in lap 1 by the time you reach lap 50. The vanilla RNN's memory degrades exponentially with distance. If you are trying to predict tire performance on lap 50 and the key information (tire compound choice, formation lap conditions) was set on lap 1, the gradient signal from lap 50 back to lap 1 has to survive 49 multiplications. With vanishing gradients, that signal arrives as essentially zero -- the network cannot learn that lap 1 conditions matter for lap 50 performance.

### Visualization: Gradient Magnitude vs Sequence Length

In [None]:
# Demonstrate vanishing/exploding gradients
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

seq_lengths = np.arange(1, 51)

# Case 1: Vanishing (factor < 1)
ax = axes[0]
factors = [0.9, 0.7, 0.5]
for f in factors:
    grad_magnitudes = f ** seq_lengths
    ax.plot(seq_lengths, grad_magnitudes, linewidth=2, label=f'factor = {f}')
ax.set_xlabel('Steps Back in Time')
ax.set_ylabel('Gradient Magnitude')
ax.set_title('Vanishing Gradients\n(factor < 1)')
ax.set_yscale('log')
ax.axhline(y=1e-6, color='red', linestyle='--', alpha=0.5, label='Effective zero')
ax.legend()  # called after axhline so the threshold line appears in the legend
ax.grid(True, alpha=0.3)

# Case 2: Exploding (factor > 1)
ax = axes[1]
factors = [1.1, 1.3, 1.5]
for f in factors:
    grad_magnitudes = f ** seq_lengths
    ax.plot(seq_lengths, grad_magnitudes, linewidth=2, label=f'factor = {f}')
ax.set_xlabel('Steps Back in Time')
ax.set_ylabel('Gradient Magnitude')
ax.set_title('Exploding Gradients\n(factor > 1)')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)

# Case 3: Simulate actual RNN gradient flow
ax = axes[2]

def simulate_gradient_flow(hidden_size, seq_len, n_trials=50):
    """Simulate gradient magnitude through an RNN."""
    magnitudes = []
    for _ in range(n_trials):
        W_hh = np.random.randn(hidden_size, hidden_size) * (1.0 / np.sqrt(hidden_size))
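        grad = np.eye(hidden_size)
        base_norm = np.linalg.norm(grad)  # Frobenius norm of I is sqrt(hidden_size)
        mags = [1.0]
        for t in range(seq_len):
            # Gradient through tanh: diag(1 - tanh^2(h)) * W_hh
            # Approximate tanh derivative as a random diagonal with entries in [0.1, 1.0)
            tanh_deriv = np.diag(np.random.uniform(0.1, 1.0, hidden_size))
            grad = tanh_deriv @ W_hh @ grad
            # Normalize so every curve starts at 1.0 and hidden sizes are comparable
            mags.append(np.linalg.norm(grad) / base_norm)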
        magnitudes.append(mags)
    return np.mean(magnitudes, axis=0)

for hs in [16, 32, 64]:
    mags = simulate_gradient_flow(hs, 50)
    ax.plot(mags, linewidth=2, label=f'hidden_size={hs}')

ax.set_xlabel('Steps Back in Time')
ax.set_ylabel('Gradient Magnitude')
ax.set_title('Simulated RNN Gradient Flow\n(averaged over 50 trials)')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Key takeaway: in these simulations, vanilla RNN gradients shrink roughly exponentially with sequence length.")
print("Within a few dozen steps the gradient is close to zero -- the RNN 'forgets' early inputs.")

### Deep Dive: Vanishing vs Exploding Gradients in RNNs

The vanishing gradient problem is not unique to RNNs -- deep feedforward networks suffer from it too. But RNNs make it **much worse** because the same weight matrix $W_{hh}$ is multiplied at every step. In a 100-layer feedforward network, you have 100 different weight matrices. In an RNN processing a 100-step sequence, you multiply by the **same** $W_{hh}$ 100 times.

#### Key Insight

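The eigenvalues of $W_{hh}$ largely determine what happens:
- If the largest eigenvalue magnitude (the spectral radius) is < 1: gradients tend to vanish, and the tanh derivative (at most 1) only shrinks them further
- If the spectral radius is > 1: gradients can explode, unless tanh saturation damps them
- We need eigenvalue magnitudes close to 1 -- but that is a razor's edge to balance on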
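
A rough numerical illustration of this, under a simplifying assumption: the tanh derivative is ignored, so only repeated multiplication by $W_{hh}$ is modeled. The helper `norm_after` and the target radii 0.9 and 1.1 are choices made for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 16
W = rng.normal(size=(n, n))

# Spectral radius: the largest eigenvalue magnitude.
rho = np.max(np.abs(np.linalg.eigvals(W)))

def norm_after(W_hh, steps):
    """Spectral norm of W_hh multiplied by itself `steps` times."""
    return np.linalg.norm(np.linalg.matrix_power(W_hh, steps), ord=2)

# Rescale the same matrix to sit just below and just above radius 1.
W_small = W * (0.9 / rho)
W_large = W * (1.1 / rho)

shrink = norm_after(W_small, 200)
grow = norm_after(W_large, 200)
print(f"radius 0.9 after 200 steps: {shrink:.3e}")
print(f"radius 1.1 after 200 steps: {grow:.3e}")
```

The same matrix collapses toward zero or blows up depending only on which side of 1 its spectral radius falls.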

**F1 analogy:** Imagine a relay of information along the pit wall. If each station attenuates the signal by 5% (eigenvalue = 0.95), after 50 stations the signal is at $0.95^{50} \approx 0.08$ -- barely audible. If each station amplifies by 5% (eigenvalue = 1.05), after 50 stations it is at $1.05^{50} \approx 11.5$ -- deafening feedback. This is exactly the vanishing/exploding gradient dilemma.
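
The arithmetic in the relay analogy is easy to reproduce. The sketch below also asks a follow-up question: how many stations does it take for a 5% loss per station to push the signal below $10^{-6}$? That threshold is an arbitrary choice, matching the "effective zero" line used in the plots.

```python
import math

attenuated = 0.95 ** 50   # 5% loss at each of 50 stations
amplified = 1.05 ** 50    # 5% gain at each of 50 stations

# Smallest n with 0.95**n < 1e-6, from n > log(1e-6) / log(0.95)
steps_to_vanish = math.ceil(math.log(1e-6) / math.log(0.95))

print(f"{attenuated:.4f}")   # 0.0769
print(f"{amplified:.2f}")    # 11.47
print(steps_to_vanish)       # 270
```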

#### Practical Consequences

| Problem | Effect | Solution |
|---------|--------|----------|
| Vanishing gradients | RNN cannot learn long-range dependencies | Use LSTM/GRU |
| Exploding gradients | Training diverges, loss becomes NaN | Gradient clipping |

#### Gradient Clipping

Exploding gradients have a simple fix -- **clip** the gradient norm:

$$\text{if } \|\nabla\| > \text{threshold}: \quad \nabla \leftarrow \frac{\text{threshold}}{\|\nabla\|} \cdot \nabla$$

This rescales the gradient to have a maximum norm, preventing explosions while preserving direction. In PyTorch: `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)`
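
The clipping rule can be written out by hand. This is a minimal NumPy sketch of the formula above, not PyTorch's implementation. `clip_by_global_norm` is a hypothetical helper that treats a list of gradient arrays as one long vector.

```python
import numpy as np

def clip_by_global_norm(grads, threshold):
    """Rescale gradient arrays so their joint L2 norm is at most threshold."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm > threshold:
        scale = threshold / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

# Two parameter gradients whose joint norm is sqrt(9 + 16 + 144) = 13.
grads = [np.array([3.0, 4.0]), np.array([[12.0]])]
clipped, norm_before = clip_by_global_norm(grads, threshold=5.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))

print(norm_before)  # 13.0
print(norm_after)
```

Every component is scaled by the same factor (5/13 here), so the direction of the update is preserved and only its length changes.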

Vanishing gradients are harder to fix. We need a fundamentally different architecture -- which leads us to LSTM.

---

## 4. LSTM (Long Short-Term Memory)

### Intuitive Explanation

The LSTM (Hochreiter & Schmidhuber, 1997) tackles the vanishing gradient problem with one key idea: **add a separate memory highway**.

**Analogy:** Think of the vanilla RNN as passing a message through a long chain of people (telephone game). By the end, the message is garbled. The LSTM adds a **conveyor belt** running alongside the chain. Important information can be placed on the belt and travel long distances without being distorted.

This conveyor belt is called the **cell state** $C_t$. It runs through the entire sequence with only minor linear interactions, allowing gradients to flow largely unchanged over many steps.

**F1 analogy:** The LSTM gates are like a race engineer deciding what information to remember and what to forget lap by lap. The **forget gate** says: "the pit crew fumble 20 laps ago? Irrelevant now -- forget it." The **input gate** says: "tire age just crossed 15 laps -- that is critical, store it." The **cell state** is like the strategy whiteboard that carries forward only the information the engineer has decided is important. And the **output gate** decides what to broadcast to the driver right now: "Focus on tire management, ignore the rest."

The LSTM uses **three gates** plus a **candidate update** to control what goes on and off the conveyor belt:

| Gate | Symbol | Purpose | Analogy | F1 Parallel |
|------|--------|---------|---------|-------------|
| **Forget gate** | $f_t$ | What old info to discard | Removing items from the belt | "Forget the formation lap rain -- the track is dry now" |
| **Input gate** | $i_t$ | What new info to add | Placing items on the belt | "Record that we just switched to hard tires" |
| **Cell update** | $\tilde{C}_t$ | Candidate new info | The actual items to place | The specific tire compound and age data |
| **Output gate** | $o_t$ | What to read from memory | Peeking at the belt | "Report current tire life and fuel delta to the driver" |

### LSTM Equations

#### Step 1: Forget Gate -- What to throw away

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

| Component | Meaning | Range | F1 Example |
|-----------|---------|-------|------------|
| $\sigma$ | Sigmoid function | 0 to 1 | -- |
| $f_t = 0$ | Completely forget | -- | "Discard last stint's tire data after a pit stop" |
| $f_t = 1$ | Completely remember | -- | "Keep tracking cumulative fuel burn" |

#### Step 2: Input Gate -- What new info to store

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$
$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$$

The input gate $i_t$ decides **how much** to store, and $\tilde{C}_t$ creates **candidate values** to store.

#### Step 3: Cell State Update

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

**What this means:** Forget some of the old cell state ($f_t \odot C_{t-1}$), then add some new info ($i_t \odot \tilde{C}_t$). The $\odot$ symbol means element-wise multiplication.
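
A tiny worked example of the update, with made-up numbers. The three cell dimensions and their F1 labels are purely illustrative assumptions; a trained LSTM does not assign such clean meanings to its units.

```python
import numpy as np

# Hypothetical 3-dimensional cell state: [tire_wear, fuel_burned, track_grip]
C_prev = np.array([0.8, 0.5, 0.3])

f_t = np.array([0.0, 1.0, 1.0])        # forget tire wear, keep the other two
i_t = np.array([1.0, 0.0, 0.0])        # write only into the tire-wear slot
C_tilde = np.array([0.1, 0.9, -0.7])   # candidate values proposed this step

# Element-wise update: C_t = f_t * C_{t-1} + i_t * C_tilde
C_t = f_t * C_prev + i_t * C_tilde
print(C_t)  # [0.1 0.5 0.3]
```

The first slot is overwritten by its candidate. The other two pass through untouched, even though their candidates were non-zero.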

**F1 analogy:** After a pit stop, the forget gate might zero out tire degradation data (new tires = fresh state), while the input gate stores the new compound type. Meanwhile, fuel data and track evolution data carry forward unchanged -- those are not "reset" by a pit stop.

#### Step 4: Output Gate -- What to output

$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$
$$h_t = o_t \odot \tanh(C_t)$$

The hidden state $h_t$ is a filtered version of the cell state -- only outputting what is relevant right now.
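
The four steps can be put together into a single function. This is a minimal NumPy sketch that follows the equations above. The `params` dictionary layout, the sizes, and the small random initialization are assumptions made for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, params):
    """One LSTM time step following Steps 1-4."""
    z = np.concatenate([h_prev, x_t])                      # [h_{t-1}, x_t]
    f_t = sigmoid(params["W_f"] @ z + params["b_f"])       # Step 1: forget gate
    i_t = sigmoid(params["W_i"] @ z + params["b_i"])       # Step 2: input gate
    C_tilde = np.tanh(params["W_C"] @ z + params["b_C"])   # Step 2: candidate
    C_t = f_t * C_prev + i_t * C_tilde                     # Step 3: cell update
    o_t = sigmoid(params["W_o"] @ z + params["b_o"])       # Step 4: output gate
    h_t = o_t * np.tanh(C_t)                               # Step 4: hidden state
    return h_t, C_t

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4

# Each weight matrix acts on the concatenated vector [h_{t-1}, x_t].
params = {}
for name in ["f", "i", "C", "o"]:
    params[f"W_{name}"] = rng.normal(scale=0.1, size=(hidden_size, hidden_size + input_size))
    params[f"b_{name}"] = np.zeros(hidden_size)

h = np.zeros(hidden_size)
C = np.zeros(hidden_size)
for x_t in rng.normal(size=(5, input_size)):
    h, C = lstm_step(x_t, h, C, params)

print(h.shape, C.shape)
```

The function carries two state vectors from step to step, where the vanilla RNN carries one. $h_t$ stays inside (-1, 1) while $C_t$ is not bounded by a squashing function.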

### Visualization: LSTM Data Flow

In [None]:
fig, ax = plt.subplots(figsize=(14, 8))

# Main cell body
ax.add_patch(plt.Rectangle((1, 1), 10, 6, facecolor='#f5f5f5', edgecolor='black', lw=2, zorder=0))

# === Cell state highway (top) ===
ax.annotate('', xy=(11.5, 6.2), xytext=(-0.5, 6.2),
            arrowprops=dict(arrowstyle='->', color='green', lw=4))
ax.text(6, 6.7, 'Cell State $C_t$ (the conveyor belt)', ha='center', fontsize=13,
        fontweight='bold', color='green')

# === Forget Gate ===
# Gate circle
circle_f = plt.Circle((3, 6.2), 0.4, facecolor='#ffcdd2', edgecolor='red', lw=2, zorder=5)
ax.add_patch(circle_f)
ax.text(3, 6.2, 'x', ha='center', va='center', fontsize=16, fontweight='bold', color='red', zorder=6)

# Forget gate sigmoid
ax.add_patch(plt.Rectangle((2.3, 3.5), 1.4, 0.8, facecolor='#ffcdd2', edgecolor='red',
                            lw=2, zorder=3, alpha=0.9))
ax.text(3, 3.9, '$f_t$\n$\\sigma$', ha='center', va='center', fontsize=11, fontweight='bold', zorder=4)

# Arrow from forget gate to multiply
ax.annotate('', xy=(3, 5.8), xytext=(3, 4.3),
            arrowprops=dict(arrowstyle='->', color='red', lw=2))
ax.text(2.2, 5.0, 'forget', fontsize=9, color='red', fontstyle='italic')

# === Input Gate ===
# Gate circle (add)
circle_i = plt.Circle((6, 6.2), 0.4, facecolor='#c8e6c9', edgecolor='green', lw=2, zorder=5)
ax.add_patch(circle_i)
ax.text(6, 6.2, '+', ha='center', va='center', fontsize=18, fontweight='bold', color='green', zorder=6)

# Input gate sigmoid
ax.add_patch(plt.Rectangle((5.0, 3.5), 1.2, 0.8, facecolor='#c8e6c9', edgecolor='green',
                            lw=2, zorder=3, alpha=0.9))
ax.text(5.6, 3.9, '$i_t$\n$\\sigma$', ha='center', va='center', fontsize=11, fontweight='bold', zorder=4)

# Candidate values
ax.add_patch(plt.Rectangle((6.5, 3.5), 1.2, 0.8, facecolor='#fff9c4', edgecolor='orange',
                            lw=2, zorder=3, alpha=0.9))
ax.text(7.1, 3.9, '$\\tilde{C}_t$\ntanh', ha='center', va='center', fontsize=11, fontweight='bold', zorder=4)

# Multiply circle between input gate and candidate
circle_ic = plt.Circle((6, 5.0), 0.3, facecolor='#c8e6c9', edgecolor='green', lw=2, zorder=5)
ax.add_patch(circle_ic)
ax.text(6, 5.0, 'x', ha='center', va='center', fontsize=14, fontweight='bold', color='green', zorder=6)

# Arrows for input path
ax.annotate('', xy=(5.75, 4.7), xytext=(5.6, 4.3),
            arrowprops=dict(arrowstyle='->', color='green', lw=1.5))
ax.annotate('', xy=(6.25, 4.7), xytext=(7.1, 4.3),
            arrowprops=dict(arrowstyle='->', color='orange', lw=1.5))
ax.annotate('', xy=(6, 5.8), xytext=(6, 5.3),
            arrowprops=dict(arrowstyle='->', color='green', lw=2))

# === Output Gate ===
# tanh on cell state
circle_tanh = plt.Circle((9, 5.2), 0.35, facecolor='#e1bee7', edgecolor='purple', lw=2, zorder=5)
ax.add_patch(circle_tanh)
ax.text(9, 5.2, 'tanh', ha='center', va='center', fontsize=8, fontweight='bold', color='purple', zorder=6)

# Arrow from cell state down to tanh
ax.annotate('', xy=(9, 5.55), xytext=(9, 6.0),
            arrowprops=dict(arrowstyle='->', color='purple', lw=1.5))

# Output gate sigmoid
ax.add_patch(plt.Rectangle((8.4, 3.5), 1.2, 0.8, facecolor='#bbdefb', edgecolor='blue',
                            lw=2, zorder=3, alpha=0.9))
ax.text(9.0, 3.9, '$o_t$\n$\\sigma$', ha='center', va='center', fontsize=11, fontweight='bold', zorder=4)

# Multiply circle for output
circle_o = plt.Circle((9, 4.6), 0.3, facecolor='#bbdefb', edgecolor='blue', lw=2, zorder=5)
ax.add_patch(circle_o)
ax.text(9, 4.6, 'x', ha='center', va='center', fontsize=14, fontweight='bold', color='blue', zorder=6)

# Arrows for output path
ax.annotate('', xy=(9, 4.9), xytext=(9, 4.3),
            arrowprops=dict(arrowstyle='->', color='blue', lw=1.5))
ax.annotate('', xy=(8.75, 4.6), xytext=(9, 4.85),
            arrowprops=dict(arrowstyle='->', color='purple', lw=1.5))

# === Inputs at bottom ===
ax.annotate('', xy=(6, 3.5), xytext=(6, 1.5),
            arrowprops=dict(arrowstyle='->', color='gray', lw=2))
ax.text(6, 1.0, '$[h_{t-1}, x_t]$', ha='center', va='center', fontsize=14,
        bbox=dict(boxstyle='round,pad=0.4', facecolor='lightyellow', edgecolor='gray', lw=2))

# Connect input to all gates
for gate_x in [3.0, 5.6, 7.1, 9.0]:
    ax.plot([6, gate_x], [2.5, 3.5], 'gray', lw=1, alpha=0.5, linestyle='--')

# === Output ===
ax.annotate('', xy=(9, 0.5), xytext=(9, 4.3),
            arrowprops=dict(arrowstyle='->', color='blue', lw=2))
ax.text(9, 0.2, '$h_t$ (output)', ha='center', va='center', fontsize=13,
        bbox=dict(boxstyle='round,pad=0.4', facecolor='lightblue', edgecolor='blue', lw=2))

# === Hidden state arrow (bottom) ===
ax.annotate('', xy=(11.5, 1.6), xytext=(9.5, 1.6),
            arrowprops=dict(arrowstyle='->', color='blue', lw=3))
ax.text(11.5, 1.2, '$h_t$ to\nnext step', fontsize=10, ha='center', color='blue')

# Legend
legend_y = 7.8
legend_items = [
    (1.5, '#ffcdd2', 'red', 'Forget Gate: what to discard from cell state'),
    (4.5, '#c8e6c9', 'green', 'Input Gate: what new info to store'),
    (8.0, '#bbdefb', 'blue', 'Output Gate: what to output as hidden state'),
]
for lx, fc, ec, text in legend_items:
    ax.add_patch(plt.Rectangle((lx, legend_y), 0.4, 0.3, facecolor=fc, edgecolor=ec, lw=2))
    ax.text(lx + 0.6, legend_y + 0.15, text, va='center', fontsize=9)

ax.set_xlim(-1, 13)
ax.set_ylim(-0.3, 8.5)
ax.set_title('LSTM Cell Architecture', fontsize=16, fontweight='bold')
ax.axis('off')
plt.tight_layout()
plt.show()

### Why LSTM Solves Vanishing Gradients

The cell state update is:

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

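Along the direct cell-state path, the gradient of $C_t$ with respect to $C_{t-1}$ is simply $f_t$ (element-wise); the gates also depend on $C_{t-1}$ indirectly through $h_{t-1}$, but that path is secondary. When the forget gate is close to 1, the gradient flows through **nearly unchanged**. Compare this to the vanilla RNN, where the gradient must pass through $W_{hh}$ and the tanh derivative at every step.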
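
A quick numerical check of that claim, under one explicit assumption: the gate values are held fixed, which isolates the direct path from $C_{t-1}$ to $C_t$ and ignores the indirect dependence of the gates on earlier states. The random gate values and the 0.99 forget-gate figure are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
f_t = rng.uniform(0.0, 1.0, size=n)       # fixed forget-gate activations
i_t = rng.uniform(0.0, 1.0, size=n)       # fixed input-gate activations
C_tilde = np.tanh(rng.normal(size=n))     # fixed candidate values
C_prev = rng.normal(size=n)

def cell_update(C):
    """C_t as a function of C_{t-1}, with the gates held fixed."""
    return f_t * C + i_t * C_tilde

# Central finite differences, one column per component of C_{t-1}.
eps = 1e-6
jac = np.zeros((n, n))
for j in range(n):
    d = np.zeros(n)
    d[j] = eps
    jac[:, j] = (cell_update(C_prev + d) - cell_update(C_prev - d)) / (2 * eps)

# The Jacobian should be diagonal with f_t on the diagonal.
max_err = np.max(np.abs(jac - np.diag(f_t)))

# With every forget gate at 0.99, this fraction survives 50 steps.
surviving = 0.99 ** 50
print(max_err, round(surviving, 3))
```

With forget gates at 0.99, roughly 60% of the gradient is left after 50 steps. A per-step factor of 0.7 would leave almost nothing.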

**F1 analogy:** The cell state is a "gradient highway" -- like the DRS straight of information flow. Information about tire compound choice from lap 1 can travel all the way to lap 50 without being distorted, because the forget gate keeps it at 1.0 (remember everything) for the dimensions that store that information. The vanilla RNN is like taking the twisty sector -- the signal degrades at every corner.

### Visualization: RNN vs LSTM Gradient Flow

In [None]:
# Compare gradient flow: RNN vs LSTM
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

seq_lengths = np.arange(1, 101)
n_trials = 100

# Simulate vanilla RNN gradient flow
rnn_grads = []
for _ in range(n_trials):
    grad = 1.0
    grads_trial = [1.0]
    for t in range(100):
        # Each step: multiply by W_hh eigenvalue * tanh_deriv
        factor = np.random.uniform(0.3, 0.95)  # tanh derivative shrinks things
        grad *= factor
        grads_trial.append(grad)
    rnn_grads.append(grads_trial)
rnn_mean = np.mean(rnn_grads, axis=0)

# Simulate LSTM gradient flow (through cell state)
lstm_grads = []
for _ in range(n_trials):
    grad = 1.0
    grads_trial = [1.0]
    for t in range(100):
        # Cell state gradient: just multiply by forget gate (close to 1)
        forget_gate = np.random.uniform(0.85, 1.0)  # Forget gate typically near 1
        grad *= forget_gate
        grads_trial.append(grad)
    lstm_grads.append(grads_trial)
lstm_mean = np.mean(lstm_grads, axis=0)

# Plot comparison
ax = axes[0]
ax.plot(rnn_mean, 'r-', linewidth=2, label='Vanilla RNN')
ax.plot(lstm_mean, 'b-', linewidth=2, label='LSTM (cell state path)')
ax.set_xlabel('Steps Back in Time', fontsize=12)
ax.set_ylabel('Gradient Magnitude', fontsize=12)
ax.set_title('Gradient Flow: RNN vs LSTM', fontsize=14)
ax.set_yscale('log')
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
ax.axhline(y=1e-6, color='gray', linestyle='--', alpha=0.5)
ax.text(50, 2e-6, 'Effectively zero', fontsize=9, color='gray')

# Show the key difference visually
ax = axes[1]
steps = np.arange(20)

# RNN: repeated matrix multiply
rnn_path = 0.7 ** steps  # typical shrinkage factor per step
# LSTM: forget gate close to 1
lstm_path = 0.95 ** steps  # forget gate near 1

ax.bar(steps - 0.15, rnn_path, width=0.3, color='red', alpha=0.7, label='RNN gradient')
ax.bar(steps + 0.15, lstm_path, width=0.3, color='blue', alpha=0.7, label='LSTM gradient')
ax.set_xlabel('Steps Back in Time', fontsize=12)
ax.set_ylabel('Gradient Magnitude', fontsize=12)
ax.set_title('First 20 Steps: Why LSTM Remembers Better', fontsize=14)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("After 20 steps:")
print(f"  RNN gradient:   {0.7**20:.6f} (less than 0.1% of the signal left)")
print(f"  LSTM gradient:  {0.95**20:.6f} (still meaningful!)")
print("\nAfter 100 steps:")
print(f"  RNN gradient:   {0.7**100:.2e}")
print(f"  LSTM gradient:  {0.95**100:.4f}")

---

## 5. GRU (Gated Recurrent Unit)

### Intuitive Explanation

The GRU (Cho et al., 2014) is a simplified version of the LSTM. Instead of a separate cell state and hidden state controlled by three gates plus a candidate update (four weight blocks in total), the GRU:

- Merges the cell state and hidden state into one
- Uses only **two gates** plus a candidate (three weight blocks)
- Has fewer parameters, so each training step is typically cheaper

**Analogy:** If the LSTM is a full-featured word processor, the GRU is a streamlined text editor. It handles most tasks just as well with less complexity.

**F1 analogy:** The LSTM is like a full pit wall setup with separate displays for every metric and dedicated engineers for tires, fuel, aero, and strategy. The GRU is like a single combined dashboard that merges everything into fewer, more efficient displays. For most race scenarios, the streamlined setup works just as well -- and the team can react faster with less information overhead.

### GRU Equations

$$z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \quad \text{(update gate)}$$
$$r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \quad \text{(reset gate)}$$
$$\tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) \quad \text{(candidate)}$$
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \quad \text{(final state)}$$

#### Breaking down the gates:

| Gate | Purpose | When value is 0 | When value is 1 | F1 Parallel |
|------|---------|-----------------|-----------------|-------------|
| **Update** $z_t$ | How much to update | Keep old state entirely | Replace with new candidate | "Nothing changed this lap" vs "pit stop -- reset everything" |
| **Reset** $r_t$ | How much past to forget | Ignore previous hidden state | Use full previous state | "Start fresh (new stint)" vs "carry forward all history" |

**What this means:** The update gate $z_t$ plays the role of both the forget and input gates in LSTM. When $z_t = 0$, the hidden state is copied forward unchanged (like an LSTM with forget gate = 1 and input gate = 0). When $z_t = 1$, the state is completely replaced by the candidate $\tilde{h}_t$.
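
The two extremes of the update gate can be seen with a few made-up numbers. This is a minimal sketch of the final-state equation only; the helper name `gru_blend` and the values of `h_prev` and `h_tilde` are assumptions for illustration, not part of any library.

```python
import numpy as np

h_prev = np.array([0.5, -0.3, 0.8])    # previous hidden state (illustrative values)
h_tilde = np.array([-0.9, 0.7, 0.1])   # candidate state (illustrative values)

def gru_blend(z, h_prev, h_tilde):
    """Final GRU state: h_t = (1 - z) * h_{t-1} + z * h_tilde."""
    return (1.0 - z) * h_prev + z * h_tilde

keep = gru_blend(np.zeros(3), h_prev, h_tilde)       # z = 0: copy the old state
replace = gru_blend(np.ones(3), h_prev, h_tilde)     # z = 1: take the candidate
mixed = gru_blend(np.array([0.0, 1.0, 0.5]), h_prev, h_tilde)  # per-dimension mix

print("z = 0         :", keep)
print("z = 1         :", replace)
print("z = [0, 1, .5]:", mixed)
```

Because the gate is applied element-wise, each hidden dimension can independently decide whether to hold its value or overwrite it.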

### GRU vs LSTM Comparison

| Feature | LSTM | GRU |
|---------|------|-----|
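| Number of gates | 3 (forget, input, output) plus a candidate cell update | 2 (update, reset) plus a candidate state |
| Separate cell state | Yes ($C_t$ and $h_t$) | No (only $h_t$) |
| Parameters | More (4 weight blocks, ~4x (hidden_size^2 + hidden_size x input_size)) | Fewer (3 weight blocks, ~3x (hidden_size^2 + hidden_size x input_size)) |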
| Training speed | Slower | Faster |
| Performance on long sequences | Often better | Comparable |
| When to use | Long sequences, complex dependencies | Smaller datasets, faster training |

#### Key Insight

In practice, GRU and LSTM perform similarly on most tasks. The GRU is often preferred when computational resources are limited or the dataset is small. Try both and pick what works best for your specific problem.

### Visualization: RNN vs LSTM vs GRU Architecture Comparison

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Helper function to draw a simplified cell
def draw_cell(ax, title, gates, state_lines, color_scheme):
    """Draw a simplified RNN cell diagram."""
    ax.set_xlim(-0.5, 4.5)
    ax.set_ylim(-0.5, 4.5)
    
    # Cell body
    ax.add_patch(plt.Rectangle((0.5, 0.5), 3.5, 3.5, facecolor='#f9f9f9',
                                edgecolor='black', lw=2))
    
    # Input
    ax.text(2.25, -0.2, '$x_t, h_{t-1}$', ha='center', fontsize=10,
            bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='gray'))
    ax.annotate('', xy=(2.25, 0.5), xytext=(2.25, 0.1),
                arrowprops=dict(arrowstyle='->', color='gray', lw=2))
    
    # Gates
    gate_y = 1.5
    gate_width = 0.8
    for i, (name, gcolor) in enumerate(gates):
        gx = 1.0 + i * 1.2
        ax.add_patch(plt.Rectangle((gx - gate_width/2, gate_y - 0.3),
                                    gate_width, 0.6, facecolor=gcolor, edgecolor='black', lw=1.5))
        ax.text(gx, gate_y, name, ha='center', va='center', fontsize=8, fontweight='bold')
    
    # State lines
    for line_y, label, lcolor in state_lines:
        ax.annotate('', xy=(4.3, line_y), xytext=(0.2, line_y),
                    arrowprops=dict(arrowstyle='->', color=lcolor, lw=3))
        ax.text(2.25, line_y + 0.25, label, ha='center', fontsize=9,
                color=lcolor, fontweight='bold')
    
    # Output
    ax.text(3.5, 4.3, '$h_t$', ha='center', fontsize=11,
            bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='blue'))
    
    ax.set_title(title, fontsize=13, fontweight='bold', color=color_scheme)
    ax.axis('off')

# Vanilla RNN
draw_cell(axes[0], 'Vanilla RNN\n(1 operation)', 
          [('tanh', '#e8daef')],
          [(3.0, '$h_t$', 'purple')],
          'purple')
axes[0].text(2.25, 2.5, 'Simple but\nforgets quickly', ha='center', fontsize=10,
             fontstyle='italic', color='gray')

# LSTM
draw_cell(axes[1], 'LSTM\n(3 gates + candidate)',
          [('$f_t$', '#ffcdd2'), ('$i_t$', '#c8e6c9'), ('$o_t$', '#bbdefb')],
          [(3.5, '$C_t$ (cell state)', 'green'), (2.8, '$h_t$ (hidden)', 'blue')],
          'darkblue')
axes[1].text(2.25, 2.3, '+ candidate $\\tilde{C}_t$', ha='center', fontsize=9, color='gray')

# GRU
draw_cell(axes[2], 'GRU\n(2 gates)',
          [('$z_t$', '#c8e6c9'), ('$r_t$', '#ffcdd2')],
          [(3.0, '$h_t$', 'purple')],
          'darkgreen')
axes[2].text(2.25, 2.5, 'Simpler than LSTM\nsimilar performance', ha='center', fontsize=10,
             fontstyle='italic', color='gray')

plt.suptitle('Architecture Comparison', fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Parameter count comparison (textbook formulation: one bias vector per weight block;
# PyTorch's modules store two bias vectors, so their counts are slightly higher)
hidden_size = 256
input_size = 100
print(f"\nParameter count comparison (input={input_size}, hidden={hidden_size}):")
print(f"  Vanilla RNN: {(input_size * hidden_size + hidden_size * hidden_size + hidden_size):,} params")
print(f"  LSTM:         {4 * (input_size * hidden_size + hidden_size * hidden_size + hidden_size):,} params")
print(f"  GRU:          {3 * (input_size * hidden_size + hidden_size * hidden_size + hidden_size):,} params")
print("\n  LSTM has 4x as many params as the RNN, GRU has 3x as many")
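
The hand-computed counts above use one bias vector per weight block. As a cross-check, the sketch below counts the parameters of the actual PyTorch modules. The helper `count_params` is a hypothetical name introduced here; the 4:3:1 ratio should still hold, but each block is `hidden_size` parameters larger because PyTorch keeps separate `bias_ih` and `bias_hh` vectors.

```python
import torch.nn as nn

input_size, hidden_size = 100, 256

def count_params(module):
    """Total number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

counts = {
    name: count_params(cls(input_size=input_size, hidden_size=hidden_size))
    for name, cls in [("RNN", nn.RNN), ("LSTM", nn.LSTM), ("GRU", nn.GRU)]
}

# One weight block in PyTorch: W_ih, W_hh, and two bias vectors
per_block = input_size * hidden_size + hidden_size * hidden_size + 2 * hidden_size

for name, n in counts.items():
    print(f"{name:5s}: {n:,} params  ({n // per_block} weight blocks)")
```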

---

## 6. RNNs in PyTorch

### Intuitive Explanation

PyTorch provides optimized implementations of all three architectures: `nn.RNN`, `nn.LSTM`, and `nn.GRU`. The API is nearly identical across all three (the only difference is that `nn.LSTM` also takes and returns a cell state), so once you learn one, you know them all.

The key challenge is understanding the **shape conventions**:

| Dimension | Meaning | Example | F1 Parallel |
|-----------|---------|---------|-------------|
| `seq_len` | Length of the sequence | 20 words in a sentence | 56 laps in a race |
| `batch` | Number of sequences in parallel | 32 sentences at once | 20 cars' race data simultaneously |
| `input_size` | Features per time step | 100-dim word embedding | [speed, throttle, brake, tire_temp, fuel] per lap |
| `hidden_size` | Size of hidden state | 256 units | How much "memory" the model carries |
| `num_layers` | Stacked RNN layers | 2 layers deep | Hierarchical analysis (raw -> events -> strategy) |
| `num_directions` | 1 (forward) or 2 (bidirectional) | -- | -- |

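**Input shape:** `(seq_len, batch, input_size)` -- note that sequence length comes FIRST by default. Pass `batch_first=True` to use `(batch, seq_len, input_size)` instead.

**Output shape:** `(seq_len, batch, hidden_size * num_directions)` -- hidden state of the last layer at every time step.

**Hidden state shape:** `(num_layers * num_directions, batch, hidden_size)` -- final hidden state of every layer and direction. This shape does not change with `batch_first`.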

In [None]:
# Basic PyTorch RNN usage
print("=" * 60)
print("nn.RNN - Basic Usage")
print("=" * 60)

# Create an RNN
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=1, batch_first=False)

# Input: (seq_len=5, batch=3, input_size=10)
x = torch.randn(5, 3, 10)

# Forward pass
output, h_n = rnn(x)

print(f"Input shape:       {x.shape}  (seq_len, batch, input_size)")
print(f"Output shape:      {output.shape}  (seq_len, batch, hidden_size)")
print(f"Final hidden shape: {h_n.shape}  (num_layers, batch, hidden_size)")
print()

# Output contains hidden states at ALL time steps
# h_n contains ONLY the final hidden state
print("Verify: output[-1] == h_n[0] (last output equals final hidden state)")
print(f"  Match: {torch.allclose(output[-1], h_n[0])}")

print("\n" + "=" * 60)
print("nn.LSTM - Returns (output, (h_n, c_n))")
print("=" * 60)

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
output, (h_n, c_n) = lstm(x)

print(f"Input shape:       {x.shape}")
print(f"Output shape:      {output.shape}")
print(f"Final hidden (h):  {h_n.shape}")
print(f"Final cell (c):    {c_n.shape}  <-- LSTM also returns cell state!")

print("\n" + "=" * 60)
print("nn.GRU - Same API as RNN")
print("=" * 60)

gru = nn.GRU(input_size=10, hidden_size=20, num_layers=1)
output, h_n = gru(x)

print(f"Input shape:       {x.shape}")
print(f"Output shape:      {output.shape}")
print(f"Final hidden (h):  {h_n.shape}")

### Bidirectional RNNs

A standard RNN only reads the sequence left-to-right. But sometimes context from the **future** matters too. For example, in "I saw the bank by the river," the word "river" helps disambiguate "bank."

A **bidirectional** RNN runs two separate RNNs: one forward and one backward. Their outputs are concatenated.

**F1 analogy:** A forward-only RNN is like watching the race in real time -- you can only see what has already happened. A bidirectional RNN is like analyzing the race after it is over, where you can look both backward (what led to this lap) and forward (what happened next). Post-race analysis is naturally bidirectional: knowing that a driver pitted on lap 35 helps you understand why their pace dropped on lap 33.
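
One practical consequence of running two directions is easy to miss: the forward RNN finishes at the last time step, but the backward RNN finishes at the first. The following sketch, with arbitrary sizes chosen for illustration, checks where each direction's final state sits inside `output`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 20
bi_lstm = nn.LSTM(input_size=10, hidden_size=hidden_size, bidirectional=True)
x = torch.randn(5, 3, 10)            # (seq_len, batch, input_size)
output, (h_n, c_n) = bi_lstm(x)      # output: (5, 3, 40), h_n: (2, 3, 20)

fwd_out = output[..., :hidden_size]  # forward-direction features
bwd_out = output[..., hidden_size:]  # backward-direction features

# Forward RNN finishes at the LAST time step
fwd_matches = torch.allclose(fwd_out[-1], h_n[0])
# Backward RNN finishes at the FIRST time step
bwd_matches = torch.allclose(bwd_out[0], h_n[1])
print(f"forward final state  == output[-1, :, :H]: {fwd_matches}")
print(f"backward final state == output[0, :, H:]:  {bwd_matches}")

# Whole-sequence summary: concatenate both final states
summary = torch.cat([h_n[0], h_n[1]], dim=-1)   # (batch, 2 * hidden_size)
print(summary.shape)
```

For whole-sequence tasks such as classification, concatenating `h_n[0]` and `h_n[1]` is usually a better summary than `output[-1]`, whose backward half has only seen the final time step.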

### Stacking RNN Layers

Like feedforward networks, we can stack multiple RNN layers. The output of one layer becomes the input of the next. This lets the network learn hierarchical representations of sequences.
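
To make "the output of one layer becomes the input of the next" concrete, the sketch below chains two single-layer LSTMs by hand and compares the result with `nn.LSTM(num_layers=2)`. It assumes no dropout between layers and copies the weights across using PyTorch's per-layer parameter suffixes (`_l0`, `_l1`); the variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
stacked = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)  # no dropout
layer1 = nn.LSTM(input_size=10, hidden_size=20)
layer2 = nn.LSTM(input_size=20, hidden_size=20)  # input size = layer1's hidden_size

# Copy the stacked module's weights into the two single-layer modules
sd = stacked.state_dict()
layer1.load_state_dict({k: v for k, v in sd.items() if k.endswith("_l0")})
layer2.load_state_dict({k.replace("_l1", "_l0"): v
                        for k, v in sd.items() if k.endswith("_l1")})

x = torch.randn(5, 3, 10)               # (seq_len, batch, input_size)
out_stacked, (h_n, _) = stacked(x)

out1, (h1, _) = layer1(x)               # layer 1 reads the raw input
out2, (h2, _) = layer2(out1)            # layer 2 reads layer 1's outputs

same_output = torch.allclose(out_stacked, out2, atol=1e-6)
same_hidden = torch.allclose(h_n, torch.cat([h1, h2], dim=0), atol=1e-6)
print(f"Manual stack matches nn.LSTM(num_layers=2): {same_output}")
print(f"Per-layer final hidden states match:        {same_hidden}")
```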

In [None]:
# Bidirectional LSTM
print("=" * 60)
print("Bidirectional LSTM")
print("=" * 60)

bi_lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1, bidirectional=True)
x = torch.randn(5, 3, 10)  # (seq_len=5, batch=3, features=10)
output, (h_n, c_n) = bi_lstm(x)

print(f"Input:           {x.shape}")
print(f"Output:          {output.shape}  <-- hidden_size * 2 = {20*2} (forward + backward concatenated)")
print(f"Hidden states:   {h_n.shape}  <-- 2 directions * 1 layer = 2")
print(f"  h_n[0] = final forward hidden state")
print(f"  h_n[1] = final backward hidden state")

print("\n" + "=" * 60)
print("Stacked (Multi-layer) LSTM")
print("=" * 60)

stacked_lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=3, dropout=0.2)
output, (h_n, c_n) = stacked_lstm(x)

print(f"Input:           {x.shape}")
print(f"Output:          {output.shape}  <-- output from LAST layer only")
print(f"Hidden states:   {h_n.shape}  <-- 3 layers, each with its own hidden state")
print(f"  h_n[0] = layer 1 final hidden state")
print(f"  h_n[1] = layer 2 final hidden state")
print(f"  h_n[2] = layer 3 final hidden state")

print("\n" + "=" * 60)
print("Stacked + Bidirectional")
print("=" * 60)

combo = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
output, (h_n, c_n) = combo(x)

print(f"Input:           {x.shape}")
print(f"Output:          {output.shape}  <-- hidden * 2 directions = {20*2}")
print(f"Hidden states:   {h_n.shape}  <-- 2 layers * 2 directions = 4")

print("\n" + "=" * 60)
print("batch_first=True (more intuitive ordering)")
print("=" * 60)

lstm_bf = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
x_bf = torch.randn(3, 5, 10)  # (batch=3, seq_len=5, features=10)
output, (h_n, c_n) = lstm_bf(x_bf)

print(f"Input:           {x_bf.shape}  (batch, seq_len, features)")
print(f"Output:          {output.shape}  (batch, seq_len, hidden_size)")
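
A common source of shape bugs is that `batch_first=True` only affects the input and `output` tensors. The short sketch below, using arbitrary sizes, shows that `h_n` keeps its `(num_layers, batch, hidden_size)` layout and how to read the top layer's final state from either tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm_bf = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
x_bf = torch.randn(3, 5, 10)            # (batch, seq_len, features)
output, (h_n, c_n) = lstm_bf(x_bf)

print(output.shape)   # batch comes first here
print(h_n.shape)      # ...but not here: (num_layers, batch, hidden_size)

# Last time step of the output == final hidden state of the top layer
last_step = output[:, -1, :]            # (batch, hidden_size)
print(torch.allclose(last_step, h_n[-1]))
```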

### Interactive Exploration: How Hidden Size Affects Capacity

Let's see how different hidden sizes affect an RNN's ability to learn a simple sequence pattern.

**F1 analogy:** The hidden size is like the bandwidth of the race engineer's mental model. A small hidden size (say, 8) might only track a couple of metrics -- tire age and fuel. A large hidden size (256) can simultaneously track degradation curves, gaps to every competitor, weather forecasts, and optimal pit windows. More capacity means richer race understanding, but also more data needed to train.

In [None]:
# Train simple RNNs with different hidden sizes to predict a sine wave
def train_sine_predictor(hidden_size, rnn_type='LSTM', epochs=200):
    """
    Train an RNN to predict the next value in a sine wave.
    
    Args:
        hidden_size: Number of hidden units
        rnn_type: 'RNN', 'LSTM', or 'GRU'
        epochs: Number of training epochs
    
    Returns:
        List of losses per epoch
    """
    torch.manual_seed(42)
    
    # Generate sine wave data
    t = torch.linspace(0, 8 * np.pi, 200)
    data = torch.sin(t)
    
    # Create sequences: use 20 steps to predict the next
    seq_len = 20
    X, y = [], []
    for i in range(len(data) - seq_len):
        X.append(data[i:i+seq_len])
        y.append(data[i+seq_len])
    X = torch.stack(X).unsqueeze(-1)  # (samples, seq_len, 1)
    y = torch.stack(y).unsqueeze(-1)  # (samples, 1)
    
    # Build model
    RNNClass = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}[rnn_type]
    rnn = RNNClass(input_size=1, hidden_size=hidden_size, batch_first=True)
    fc = nn.Linear(hidden_size, 1)
    
    params = list(rnn.parameters()) + list(fc.parameters())
    optimizer = optim.Adam(params, lr=0.01)
    criterion = nn.MSELoss()
    
    losses = []
    for epoch in range(epochs):
        optimizer.zero_grad()
        output, _ = rnn(X)
        pred = fc(output[:, -1, :])  # Use last hidden state
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    
    return losses

# Compare different hidden sizes
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Vary hidden size
ax = axes[0]
for hs in [2, 8, 32, 128]:
    losses = train_sine_predictor(hs, 'LSTM', epochs=200)
    ax.plot(losses, linewidth=2, label=f'hidden_size={hs}')

ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.set_title('Effect of Hidden Size (LSTM)')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)

# Compare architectures
ax = axes[1]
for rnn_type, color in [('RNN', 'red'), ('LSTM', 'blue'), ('GRU', 'green')]:
    losses = train_sine_predictor(32, rnn_type, epochs=200)
    ax.plot(losses, linewidth=2, color=color, label=rnn_type)

ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.set_title('RNN vs LSTM vs GRU (hidden_size=32)')
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Why This Matters in Machine Learning

| Application | Architecture | Key Details | F1 Parallel |
|-------------|-------------|-------------|-------------|
| Sentiment analysis | Bidirectional LSTM | Use the final hidden states of both directions for classification | Post-race analysis of radio messages for driver mood |
| Machine translation | Encoder-decoder LSTM | Encoder summarizes input, decoder generates output | Translating one team's strategy into predicted race outcome |
| Speech recognition | Stacked bidirectional LSTM | Multiple layers capture different time scales | Decoding team radio through engine noise at multiple frequencies |
| Time series forecasting | LSTM or GRU | Use last output to predict next value | Predicting next-lap tire degradation from race history |
| Music generation | LSTM with sampling | Generate one note at a time | Generating synthetic telemetry for simulation testing |
| Named entity recognition | Bidirectional LSTM | Output at every time step | Classifying each lap as "push," "conserve," or "pit window" |

### Packed Sequences (Brief Note)

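When processing batches of sequences with **different lengths**, PyTorch provides `pack_padded_sequence` and `pad_packed_sequence`. Packing lets the RNN skip the padding positions, so no computation is wasted on them and the returned final hidden state corresponds to each sequence's last real time step rather than to padding: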

```python
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# lengths = actual length of each sequence in the batch
packed = pack_padded_sequence(padded_input, lengths, batch_first=True, enforce_sorted=False)
output, hidden = lstm(packed)
output, lengths = pad_packed_sequence(output, batch_first=True)
```
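
The snippet above leaves `padded_input`, `lengths`, and `lstm` undefined. A self-contained version might look like the following sketch, where the tensor sizes and the three sequence lengths are made up for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True)

# Three sequences of true length 5, 3, and 2, zero-padded to length 5
lengths = torch.tensor([5, 3, 2])
padded_input = torch.randn(3, 5, 4)
for i, n in enumerate(lengths.tolist()):
    padded_input[i, n:] = 0.0           # zero out the padding positions

packed = pack_padded_sequence(padded_input, lengths,
                              batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# h_n holds the state at each sequence's LAST REAL step, not at the padding
for i, n in enumerate(lengths.tolist()):
    ok = torch.allclose(h_n[0, i], output[i, n - 1])
    print(f"seq {i}: h_n matches output at step {n - 1}: {ok}")

# Positions past the true length come back as zeros
print(output[2, 2:].abs().sum().item())
```

Without packing, `h_n` for the shorter sequences would be computed after several steps of padding, which is rarely what you want when the final state feeds a classifier.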

---

## 7. Text Generation Project

### Character-Level Language Model

Now let's put everything together and build a **character-level language model**. This model:

1. Reads text one character at a time
2. Learns patterns in the text (spelling, common words, structure)
3. Generates new text by predicting one character at a time

We will train on a small text corpus and watch the generated text improve from random gibberish to recognizable patterns.

**F1 analogy:** This is analogous to a model that reads race telemetry one sample at a time and learns to generate synthetic telemetry. After training on enough real laps, the model would produce realistic-looking speed traces, brake pressure curves, and steering inputs -- capturing the statistical patterns of actual driving without ever having been on track.

In [None]:
# Our training text -- Hamlet's "To be or not to be" soliloquy with punctuation removed
text = """To be or not to be that is the question
Whether tis nobler in the mind to suffer
The slings and arrows of outrageous fortune
Or to take arms against a sea of troubles
And by opposing end them To die to sleep
No more and by a sleep to say we end
The heartache and the thousand natural shocks
That flesh is heir to Tis a consummation
Devoutly to be wished To die to sleep
To sleep perchance to dream ay there is the rub
For in that sleep of death what dreams may come
When we have shuffled off this mortal coil
Must give us pause There is the respect
That makes calamity of so long life
For who would bear the whips and scorns of time
The oppressor wrong the proud man contumely
The pangs of despised love the law delay
The insolence of office and the spurns
That patient merit of the unworthy takes"""

# Build character vocabulary
chars = sorted(list(set(text)))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for ch, i in char_to_idx.items()}
vocab_size = len(chars)

print(f"Text length: {len(text)} characters")
print(f"Vocabulary size: {vocab_size} unique characters")
print(f"Characters: {''.join(chars)}")

# Encode the text as integers
encoded = torch.tensor([char_to_idx[ch] for ch in text], dtype=torch.long)
print(f"\nFirst 50 chars: '{text[:50]}'")
print(f"Encoded:        {encoded[:50].tolist()}")

In [None]:
# Create training sequences
def create_sequences(encoded_text, seq_len=40):
    """
    Create input-target pairs for training.
    
    For each position, the input is a sequence of characters and
    the target is the same sequence shifted by one character.
    """
    inputs, targets = [], []
    for i in range(0, len(encoded_text) - seq_len):
        inputs.append(encoded_text[i:i+seq_len])
        targets.append(encoded_text[i+1:i+seq_len+1])
    return torch.stack(inputs), torch.stack(targets)

seq_len = 40
X, y = create_sequences(encoded, seq_len)
print(f"Training sequences: {X.shape[0]}")
print(f"Sequence length: {seq_len}")
print(f"\nExample input:  '{''.join([idx_to_char[i.item()] for i in X[0]])}'")
print(f"Example target: '{''.join([idx_to_char[i.item()] for i in y[0]])}'")
print("\nNotice: target is shifted by 1 character (predict the next character)")

In [None]:
class CharLSTM(nn.Module):
    """
    Character-level language model using LSTM.
    
    Architecture:
        Embedding -> LSTM -> Linear (logits over characters; softmax is applied in the loss and when sampling)
    
    Args:
        vocab_size: Number of unique characters
        embed_size: Dimension of character embeddings
        hidden_size: LSTM hidden state size
        num_layers: Number of stacked LSTM layers
    """
    def __init__(self, vocab_size, embed_size=32, hidden_size=128, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        # Character embedding: maps integer indices to dense vectors
        self.embed = nn.Embedding(vocab_size, embed_size)
        
        # LSTM layers
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
                           batch_first=True, dropout=0.2 if num_layers > 1 else 0)
        
        # Output projection: hidden state -> character logits
        self.fc = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, x, hidden=None):
        """
        Forward pass.
        
        Args:
            x: Input character indices, shape (batch, seq_len)
            hidden: Optional initial hidden state
        
        Returns:
            output: Character logits at each position, shape (batch, seq_len, vocab_size)
            hidden: Final hidden state tuple (h_n, c_n)
        """
        # Embed characters: (batch, seq_len) -> (batch, seq_len, embed_size)
        embedded = self.embed(x)
        
        # LSTM: (batch, seq_len, embed_size) -> (batch, seq_len, hidden_size)
        lstm_out, hidden = self.lstm(embedded, hidden)
        
        # Project to vocabulary: (batch, seq_len, hidden_size) -> (batch, seq_len, vocab_size)
        output = self.fc(lstm_out)
        
        return output, hidden
    
    def generate(self, start_char, char_to_idx, idx_to_char, length=200, temperature=1.0):
        """
        Generate text one character at a time.
        
        Args:
            start_char: Starting character
            char_to_idx: Character to index mapping
            idx_to_char: Index to character mapping
            length: Number of characters to generate
            temperature: Controls randomness (lower = more deterministic)
        
        Returns:
            Generated text string
        """
        self.eval()
        generated = [start_char]
        x = torch.tensor([[char_to_idx[start_char]]])
        hidden = None
        
        with torch.no_grad():
            for _ in range(length - 1):
                output, hidden = self(x, hidden)
                
                # Apply temperature scaling
                logits = output[0, -1, :] / temperature
                probs = torch.softmax(logits, dim=0)
                
                # Sample from the distribution
                next_idx = torch.multinomial(probs, 1).item()
                generated.append(idx_to_char[next_idx])
                
                x = torch.tensor([[next_idx]])
        
        self.train()
        return ''.join(generated)

model = CharLSTM(vocab_size, embed_size=32, hidden_size=128, num_layers=2)
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Training loop
torch.manual_seed(42)
model = CharLSTM(vocab_size, embed_size=32, hidden_size=128, num_layers=2)
optimizer = optim.Adam(model.parameters(), lr=0.003)
criterion = nn.CrossEntropyLoss()

# Training
epochs = 300
losses = []
samples = {}

print("Training character-level language model...")
print("=" * 60)

# Generate sample before training
sample = model.generate('T', char_to_idx, idx_to_char, length=100, temperature=0.8)
samples[0] = sample
print(f"\nEpoch 0 (before training):")
print(f"  '{sample[:80]}...'")

for epoch in range(1, epochs + 1):
    model.train()
    
    # Shuffle training data
    perm = torch.randperm(X.shape[0])
    X_shuffled = X[perm]
    y_shuffled = y[perm]
    
    epoch_loss = 0
    batch_size = 64
    n_batches = 0
    
    for i in range(0, X.shape[0], batch_size):
        X_batch = X_shuffled[i:i+batch_size]
        y_batch = y_shuffled[i:i+batch_size]
        
        optimizer.zero_grad()
        output, _ = model(X_batch)
        
        # Reshape for cross entropy: (batch * seq_len, vocab_size) vs (batch * seq_len,)
        loss = criterion(output.reshape(-1, vocab_size), y_batch.reshape(-1))
        loss.backward()
        
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    
    avg_loss = epoch_loss / n_batches
    losses.append(avg_loss)
    
    # Print progress and generate samples at milestones
    if epoch in [50, 100, 150, 200, 300]:
        sample = model.generate('T', char_to_idx, idx_to_char, length=100, temperature=0.8)
        samples[epoch] = sample
        print(f"\nEpoch {epoch} (loss={avg_loss:.4f}):")
        print(f"  '{sample[:80]}...'")

print("\n" + "=" * 60)
print("Training complete!")

In [None]:
# Plot training loss
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses, 'b-', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Cross-Entropy Loss', fontsize=12)
ax.set_title('Character Language Model Training Loss', fontsize=14)
ax.grid(True, alpha=0.3)

# Mark sample points
for epoch in samples:
    if epoch > 0:
        ax.axvline(x=epoch-1, color='red', linestyle='--', alpha=0.3)
        ax.text(epoch-1, ax.get_ylim()[1] * 0.95, f'Epoch {epoch}',
                fontsize=8, rotation=90, va='top', color='red')

plt.tight_layout()
plt.show()

### Temperature Sampling

The **temperature** parameter controls how "creative" vs "conservative" the model is when generating text:

| Temperature | Effect | Use Case | F1 Parallel |
|-------------|--------|----------|-------------|
| 0.2 (low) | Very predictable, repetitive | When accuracy matters | Conservative strategy: follow the data, minimize risk |
| 0.8 (medium) | Balanced creativity | General text generation | Balanced strategy: react to conditions but stay within model |
| 1.0 (default) | True model distribution | Standard sampling | Neutral: let the probabilities speak for themselves |
| 1.5+ (high) | Very random, creative | Brainstorming, art | Aggressive strategy: gamble on an unusual call for a podium |

Mathematically, temperature divides the logits before softmax:

$$P(c) = \frac{e^{z_c / T}}{\sum_j e^{z_j / T}}$$

Low temperature makes the distribution **sharper** (peak character dominates). High temperature makes it **flatter** (more uniform).

In [None]:
# Demonstrate temperature effect
print("Generated text at different temperatures:")
print("=" * 60)

for temp in [0.2, 0.5, 0.8, 1.0, 1.5]:
    sample = model.generate('T', char_to_idx, idx_to_char, length=120, temperature=temp)
    print(f"\nTemperature = {temp}:")
    print(f"  '{sample[:100]}...'")

# Visualize temperature effect on probability distribution
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Create sample logits
logits = torch.tensor([2.0, 1.5, 0.5, 0.2, -0.5, -1.0, -1.5, -2.0])
labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']

for ax, temp in zip(axes, [0.2, 1.0, 2.0]):
    probs = torch.softmax(logits / temp, dim=0).numpy()
    bars = ax.bar(labels, probs, color='steelblue', edgecolor='black', alpha=0.8)
    bars[0].set_color('orange')  # Highlight most likely
    ax.set_xlabel('Character')
    ax.set_ylabel('Probability')
    ax.set_title(f'Temperature = {temp}')
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3)

plt.suptitle('How Temperature Affects the Probability Distribution', fontsize=14)
plt.tight_layout()
plt.show()

---

## Exercises

### Exercise 1: Implement a GRU Cell from Scratch

Implement the GRU equations in NumPy, following the same pattern as the VanillaRNN class.

**F1 scenario:** You are building a lightweight race-state tracker that updates the car's condition lap by lap using only two gates (update and reset) instead of the LSTM's three. This is the GRU approach -- simpler than the full LSTM pit-wall setup, but often just as effective for predicting tire degradation curves.

In [None]:
# EXERCISE 1: Implement GRU from scratch
def sigmoid(x):
    """Numerically stable sigmoid."""
    return 1.0 / (1.0 + np.exp(-np.clip(x, -500, 500)))

class GRUFromScratch:
    """
    GRU implemented from scratch in NumPy.
    
    Equations:
        z_t = sigmoid(W_z @ [h_{t-1}, x_t] + b_z)    # update gate
        r_t = sigmoid(W_r @ [h_{t-1}, x_t] + b_r)    # reset gate
        h_tilde = tanh(W_h @ [r_t * h_{t-1}, x_t] + b_h)  # candidate
        h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde    # final state
    """
    def __init__(self, input_size, hidden_size):
        scale = np.sqrt(2.0 / (input_size + hidden_size))
        combined = input_size + hidden_size
        
        # Weights are initialized for you: one (W, b) pair each for the
        # update gate (W_z, b_z), reset gate (W_r, b_r), and candidate (W_h, b_h).
        # Each W has shape (hidden_size, input_size + hidden_size)
        # Each b has shape (hidden_size,)
        
        self.W_z = np.random.randn(hidden_size, combined) * scale
        self.b_z = np.zeros(hidden_size)
        self.W_r = np.random.randn(hidden_size, combined) * scale
        self.b_r = np.zeros(hidden_size)
        self.W_h = np.random.randn(hidden_size, combined) * scale
        self.b_h = np.zeros(hidden_size)
        
        self.hidden_size = hidden_size
    
    def forward(self, inputs, h_prev=None):
        """
        Process a sequence through the GRU.
        
        Args:
            inputs: List of input vectors
            h_prev: Initial hidden state
        
        Returns:
            hidden_states: List of hidden states
        """
        if h_prev is None:
            h_prev = np.zeros(self.hidden_size)
        
        hidden_states = [h_prev]
        
        for x_t in inputs:
            h = hidden_states[-1]
            
            # TODO: Implement the GRU equations
            # Step 1: Concatenate h and x_t
            # Step 2: Compute update gate z_t
            # Step 3: Compute reset gate r_t
            # Step 4: Compute candidate h_tilde (using r_t * h concatenated with x_t)
            # Step 5: Compute final state h_t
            
            # Hint: np.concatenate([h, x_t])
            
            pass  # Replace with your implementation
        
        return hidden_states

# Test your implementation
np.random.seed(42)
gru_scratch = GRUFromScratch(input_size=3, hidden_size=4)
test_seq = [np.random.randn(3) for _ in range(5)]

# If implemented correctly, this should produce 6 hidden states (initial + 5 steps)
# hidden_states = gru_scratch.forward(test_seq)
# print(f"Number of hidden states: {len(hidden_states)}")
# for t, h in enumerate(hidden_states):
#     print(f"  h_{t}: {h.round(3)}")

# Verify against PyTorch GRU
print("Verify: PyTorch GRU output for reference:")
torch.manual_seed(42)
gru_pt = nn.GRU(input_size=3, hidden_size=4, batch_first=True)
x_pt = torch.tensor(np.stack(test_seq)).unsqueeze(0).float()
out_pt, h_pt = gru_pt(x_pt)
print(f"  PyTorch final hidden: {h_pt.squeeze().detach().numpy().round(3)}")
print("  (Weights are initialized differently, so expect similar magnitudes, not identical values)")
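
A note on the reference cell above: `nn.GRU` draws its own random weights, and its documented equations differ slightly from the docstring's. PyTorch applies the reset gate after the hidden-state matrix multiply, and lets `z_t` weight the old state rather than the candidate. The sketch below steps through those documented equations by hand and checks the result against `nn.GRU`. The names `gru_ref` and `manual_gru_step` are illustrative and not part of the exercise.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru_ref = nn.GRU(input_size=3, hidden_size=4, batch_first=True)
x = torch.randn(1, 5, 3)  # (batch, seq_len, features)

def manual_gru_step(x_t, h, gru):
    """One step of nn.GRU's documented equations for a single layer."""
    H = gru.hidden_size
    # PyTorch stacks the gate weights row-wise in the order reset, update, new
    gi = gru.weight_ih_l0 @ x_t + gru.bias_ih_l0   # shape (3H,)
    gh = gru.weight_hh_l0 @ h + gru.bias_hh_l0     # shape (3H,)
    i_r, i_z, i_n = gi.split(H)
    h_r, h_z, h_n = gh.split(H)
    r = torch.sigmoid(i_r + h_r)
    z = torch.sigmoid(i_z + h_z)
    n = torch.tanh(i_n + r * h_n)   # reset gate applied after the matmul
    return (1 - z) * n + z * h      # z weights the OLD state in PyTorch

with torch.no_grad():
    h = torch.zeros(4)
    for t in range(x.shape[1]):
        h = manual_gru_step(x[0, t], h, gru_ref)
    _, h_n = gru_ref(x)

max_diff = (h - h_n.squeeze()).abs().max().item()
print(f"max |manual - nn.GRU| = {max_diff:.2e}")
```

To compare `GRUFromScratch` against PyTorch numerically, you would need to copy the weights across and account for both differences. Otherwise, similar magnitudes are the most you can expect.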

### Exercise 2: Sequence Classification with LSTM

Build an LSTM that classifies whether a sequence of numbers is trending up or down.

**F1 scenario:** Think of this as classifying whether a driver's pace over a stint is improving (trending down in lap times as they warm up the tires and the car gets lighter) or degrading (trending up as tires wear out). The LSTM reads the sequence of lap times and outputs a single classification: "improving" or "degrading."

In [None]:
# EXERCISE 2: Sequence Classification

# Generate data: sequences that trend up (label=1) or down (label=0)
def generate_trend_data(n_samples=500, seq_len=20):
    """Generate sequences with upward or downward trends."""
    X = []
    y = []
    for _ in range(n_samples):
        if np.random.random() > 0.5:
            # Upward trend
            slope = np.random.uniform(0.05, 0.2)
            seq = slope * np.arange(seq_len) + np.random.randn(seq_len) * 0.3
            label = 1
        else:
            # Downward trend
            slope = np.random.uniform(-0.2, -0.05)
            seq = slope * np.arange(seq_len) + np.random.randn(seq_len) * 0.3
            label = 0
        X.append(seq)
        y.append(label)
    
    X = torch.tensor(np.array(X), dtype=torch.float32).unsqueeze(-1)  # (N, seq_len, 1)
    y = torch.tensor(y, dtype=torch.float32).unsqueeze(-1)  # (N, 1)
    return X, y

X_train, y_train = generate_trend_data(500, 20)
X_test, y_test = generate_trend_data(100, 20)

print(f"Training data: {X_train.shape}, Labels: {y_train.shape}")
print(f"Test data: {X_test.shape}, Labels: {y_test.shape}")

class TrendClassifier(nn.Module):
    """
    LSTM-based sequence classifier.
    
    TODO: Implement this model with:
    1. An LSTM layer (input_size=1, hidden_size=32)
    2. A Linear layer that maps hidden_size -> 1
    
    The forward method should:
    1. Pass input through LSTM
    2. Take the LAST hidden state output[:, -1, :]
    3. Pass through linear layer
    4. Return logits (no sigmoid -- use BCEWithLogitsLoss)
    """
    def __init__(self, hidden_size=32):
        super().__init__()
        # TODO: Define layers
        # Hint: self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
        # Hint: self.fc = nn.Linear(hidden_size, 1)
        pass
    
    def forward(self, x):
        # TODO: Implement forward pass
        # Hint: output, (h_n, c_n) = self.lstm(x)
        # Hint: return self.fc(output[:, -1, :])
        pass

# Train and evaluate
# Note: named trend_model so it does not overwrite `model` (the CharLSTM used in Exercise 3)
# trend_model = TrendClassifier(hidden_size=32)
# optimizer = optim.Adam(trend_model.parameters(), lr=0.01)
# criterion = nn.BCEWithLogitsLoss()
# 
# for epoch in range(50):
#     optimizer.zero_grad()
#     output = trend_model(X_train)
#     loss = criterion(output, y_train)
#     loss.backward()
#     optimizer.step()
#     
#     if (epoch + 1) % 10 == 0:
#         with torch.no_grad():
#             pred = (torch.sigmoid(trend_model(X_test)) > 0.5).float()
#             acc = (pred == y_test).float().mean()
#             print(f"Epoch {epoch+1}: loss={loss.item():.4f}, test_acc={acc.item():.4f}")

# Expected: Test accuracy > 0.90 after 50 epochs
print("\n(Uncomment the training code above after implementing TrendClassifier)")

### Exercise 3: Exploring LSTM Gates

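Visualize what the LSTM is doing during text generation. `nn.LSTM` does not expose its gate activations directly (a forward hook only sees the module's inputs and outputs), so recording the forget and input gates means stepping through the LSTM manually with its weights. The starter code takes the simpler route and plots the hidden state on real data.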

**F1 scenario:** Imagine peering inside the LSTM during a race simulation. When does the forget gate fire (what information is being discarded)? When does the input gate spike (what new information is being stored)? Visualizing gate activations is like watching the strategy engineer's decision-making in real time -- seeing exactly when they decide to forget old tire data and store new pit stop information.

In [None]:
# EXERCISE 3: Visualize LSTM gate activations

def visualize_lstm_gates(model, text_input, char_to_idx, idx_to_char):
    """
    Feed a string through the trained CharLSTM and visualize gate activations.
    
    TODO: 
    1. Encode the text as indices
    2. Pass through the model's embedding layer
    3. Manually step through the LSTM to record gate activations
    
    Hint: PyTorch stacks the LSTM weights along dim 0 of `weight_ih_l0` as
    [W_ii, W_if, W_ig, W_io] (and likewise for `weight_hh_l0`),
    where i=input, f=forget, g=cell candidate, o=output
    
    For a simpler approach, just feed the text through and visualize
    the hidden state heatmap (similar to what we did with the vanilla RNN).
    """
    model.eval()
    
    # Encode text
    indices = torch.tensor([[char_to_idx[ch] for ch in text_input]])
    
    # Get hidden states at each position
    with torch.no_grad():
        embedded = model.embed(indices)
        output, _ = model.lstm(embedded)
        # output shape: (1, seq_len, hidden_size)
        hidden_states = output.squeeze(0).numpy()
    
    # Visualize
    fig, axes = plt.subplots(2, 1, figsize=(14, 6))
    
    # Hidden state heatmap
    ax = axes[0]
    im = ax.imshow(hidden_states.T[:32], aspect='auto', cmap='RdBu', vmin=-1, vmax=1)
    ax.set_ylabel('Hidden Unit (first 32)')
    ax.set_title('LSTM Hidden State Activations')
    plt.colorbar(im, ax=ax)
    
    # Set x-axis to show characters
    ax.set_xticks(range(len(text_input)))
    ax.set_xticklabels(list(text_input), fontsize=8)
    
    # Mean activation magnitude per character
    ax = axes[1]
    mean_act = np.abs(hidden_states).mean(axis=1)
    ax.bar(range(len(text_input)), mean_act, color='steelblue', alpha=0.8)
    ax.set_xticks(range(len(text_input)))
    ax.set_xticklabels(list(text_input), fontsize=8)
    ax.set_ylabel('Mean |Activation|')
    ax.set_title('Activation Magnitude Per Character')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Visualize with a sample from the training text
sample_text = "To be or not to be"
visualize_lstm_gates(model, sample_text, char_to_idx, idx_to_char)
print("Notice how the hidden state changes as the model processes each character.")
print("Spaces and word boundaries often cause distinct activation patterns.")
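
The docstring hint about gate ordering can be made concrete. The sketch below steps a single-layer, unidirectional `nn.LSTM` by hand, records the input, forget, and output gates at every time step, and checks the hand-computed hidden states against the module's own output. It runs on a small stand-in `nn.LSTM` with random inputs rather than the trained CharLSTM; `lstm_gate_trace` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)  # stand-in for model.lstm
x = torch.randn(1, 10, 8)  # stand-in for the embedded characters

def lstm_gate_trace(lstm, x):
    """Step a single-layer, unidirectional nn.LSTM by hand and record its gates."""
    H = lstm.hidden_size
    h = torch.zeros(H)
    c = torch.zeros(H)
    trace = {"input": [], "forget": [], "output": [], "hidden": []}
    for x_t in x[0]:  # iterate over the time steps of the first batch element
        gates = (lstm.weight_ih_l0 @ x_t + lstm.bias_ih_l0
                 + lstm.weight_hh_l0 @ h + lstm.bias_hh_l0)
        i, f, g, o = gates.split(H)  # PyTorch order: input, forget, cell, output
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c + i * g
        h = o * torch.tanh(c)
        for name, value in zip(trace, (i, f, o, h)):
            trace[name].append(value)
    return {name: torch.stack(values) for name, values in trace.items()}

with torch.no_grad():
    trace = lstm_gate_trace(lstm, x)
    reference, _ = lstm(x)

max_diff = (trace["hidden"] - reference[0]).abs().max().item()
print(trace["forget"].shape, f"max diff vs nn.LSTM: {max_diff:.2e}")
```

Assuming the trained model's LSTM is also single-layer and unidirectional, `trace["forget"].T` could be passed to `imshow` in place of the hidden-state matrix, with `vmin=0, vmax=1` since gate values lie between 0 and 1.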

---

## Summary

### Key Concepts

**Why Sequences Need RNNs:**
- Feedforward networks require fixed-size inputs and ignore order
- Sequential data (text, audio, time series) needs memory of past context
- RNNs process one step at a time, maintaining a hidden state as memory
- **F1 parallel:** A Grand Prix is inherently sequential -- each lap depends on cumulative tire wear, fuel burn, and track evolution from all previous laps

**Vanilla RNN:**
- Core equation: $h_t = \tanh(W_{hh} \cdot h_{t-1} + W_{xh} \cdot x_t + b)$
- Same weights shared across all time steps (parameter efficiency)
- Hidden state is a compressed summary of the entire sequence so far
- Limited by vanishing gradients for long sequences
- **F1 parallel:** The hidden state is the car's accumulated wear state -- but vanilla RNNs "forget" what happened on lap 1 by the time they reach lap 50
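
As a compact reminder, the core equation above can be sketched in a few lines of NumPy. The sizes and random weights are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4

# Illustrative parameters; the names follow the equation above
W_xh = rng.normal(scale=0.5, size=(hidden_size, input_size))
W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
b = np.zeros(hidden_size)

def rnn_forward(inputs, h0):
    """Apply h_t = tanh(W_hh h_{t-1} + W_xh x_t + b) across a sequence."""
    h = h0
    states = []
    for x_t in inputs:
        h = np.tanh(W_hh @ h + W_xh @ x_t + b)  # same weights at every step
        states.append(h)
    return np.stack(states)

sequence = rng.normal(size=(6, input_size))
states = rnn_forward(sequence, np.zeros(hidden_size))
print(states.shape)  # one hidden vector per time step
```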

**Backpropagation Through Time (BPTT):**
- Gradients flow backward through unrolled time steps
- Repeated multiplication causes exponential vanishing or exploding
- Gradient clipping limits explosions; gated architectures mitigate vanishing
- **F1 parallel:** Like a radio relay chain -- the signal degrades at every station unless you build a direct link
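
The "repeated multiplication" point can be seen numerically. For a tanh RNN the per-step Jacobian is $\mathrm{diag}(1 - h_t^2)\, W_{hh}$, and the gradient reaching step 0 is a product of these. The sketch below tracks the norm of that product, assuming small random recurrent weights and random inputs (both illustrative choices).

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size = 16

def jacobian_norms(weight_scale, steps=30):
    """Track ||d h_t / d h_0|| for a tanh RNN driven by random inputs."""
    W_hh = rng.normal(scale=weight_scale, size=(hidden_size, hidden_size))
    h = np.zeros(hidden_size)
    total = np.eye(hidden_size)  # running product of per-step Jacobians
    norms = []
    for _ in range(steps):
        h = np.tanh(W_hh @ h + rng.normal(size=hidden_size))
        step_jac = np.diag(1.0 - h**2) @ W_hh  # d h_t / d h_{t-1}
        total = step_jac @ total
        norms.append(np.linalg.norm(total, 2))
    return norms

small = jacobian_norms(weight_scale=0.1)
print(f"step 1: {small[0]:.3e}, step 30: {small[-1]:.3e}")
```

With larger recurrent weights the same product can grow instead of shrink, which is the exploding case that gradient clipping addresses.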

**LSTM:**
- Cell state provides a "gradient highway" for long-range dependencies
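- Four components: three sigmoid gates, forget (what to discard), input (what to store), and output (what to reveal), plus a tanh cell candidate (the new values to store)
- Key: along the direct cell-state path, $\partial c_t / \partial c_{t-1} = f_t$, so a forget gate near 1 lets the gradient flow almost unchanged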
- **F1 parallel:** Gates decide what race info to remember vs forget. Tire age matters; a pit crew fumble 20 laps ago does not. The cell state is the strategy whiteboard.
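
The "gradient highway" claim can be checked with autograd. In the sketch below the gate values are fixed random stand-ins (in a real LSTM they come from layers applied to $[h_{t-1}, x_t]$); the point is only that the Jacobian of the cell update with respect to the previous cell state is the forget gate on the diagonal.

```python
import torch

torch.manual_seed(0)
hidden_size = 4

# Illustrative gate values for one time step
f_t = torch.sigmoid(torch.randn(hidden_size))   # forget gate
i_t = torch.sigmoid(torch.randn(hidden_size))   # input gate
g_t = torch.tanh(torch.randn(hidden_size))      # cell candidate

def cell_update(c_prev):
    """LSTM cell-state update: c_t = f_t * c_{t-1} + i_t * g_t."""
    return f_t * c_prev + i_t * g_t

c_prev = torch.randn(hidden_size)
jac = torch.autograd.functional.jacobian(cell_update, c_prev)
is_diag_f = torch.allclose(jac, torch.diag(f_t))
print("Jacobian equals diag(f_t):", is_diag_f)

# Over many steps the direct-path gradient is a product of forget gates
steps = 50
print(f"f=0.99 -> {0.99 ** steps:.3f}, f=0.5 -> {0.5 ** steps:.1e}")
```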

**GRU:**
- Simplified LSTM with two gates: update and reset
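- Fewer parameters (three weight blocks instead of the LSTM's four, about 25% fewer at equal sizes), typically faster training, often comparable performance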
- Merges cell state and hidden state into one
- **F1 parallel:** A streamlined dashboard vs the full pit wall setup -- less overhead, similar results
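
The parameter difference is easy to confirm by counting. The sizes below are arbitrary; `count_params` is a small helper defined here, not a PyTorch function.

```python
import torch.nn as nn

def count_params(module):
    """Total number of scalar parameters in a module."""
    return sum(p.numel() for p in module.parameters())

input_size, hidden_size = 32, 64
rnn = nn.RNN(input_size, hidden_size)
gru = nn.GRU(input_size, hidden_size)
lstm = nn.LSTM(input_size, hidden_size)

for name, module in [("RNN", rnn), ("GRU", gru), ("LSTM", lstm)]:
    print(f"{name:5s} {count_params(module):6d} parameters")
```

For single-layer modules of the same size, the GRU holds three times the vanilla RNN's parameters and the LSTM four times.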

**PyTorch RNN Modules:**
- Consistent API: `nn.RNN`, `nn.LSTM`, `nn.GRU`
- Shape convention: `(seq_len, batch, features)` by default, or `(batch, seq_len, features)` with `batch_first=True`
- Support for bidirectional and stacked layers
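
A quick shape check covering these points, with arbitrary sizes chosen for illustration:

```python
import torch
import torch.nn as nn

batch, seq_len, features, hidden = 2, 5, 3, 8
x = torch.randn(batch, seq_len, features)

lstm = nn.LSTM(features, hidden, num_layers=2, bidirectional=True, batch_first=True)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # (batch, seq_len, 2 * hidden): both directions concatenated
print(h_n.shape)     # (num_layers * 2, batch, hidden): not affected by batch_first
print(c_n.shape)     # same shape as h_n

# Default layout: time comes first
gru = nn.GRU(features, hidden)
out_tf, h_tf = gru(x.transpose(0, 1))  # input is (seq_len, batch, features)
print(out_tf.shape)  # (seq_len, batch, hidden)
```

Note that `batch_first=True` changes only the input and `output` layout; `h_n` and `c_n` keep the layer dimension first.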

### Connection to Deep Learning

| Concept | Application | F1 Parallel |
|---------|------------|-------------|
| Vanilla RNN | Simple sequence tasks, educational baseline | Basic lap-by-lap state tracking |
| LSTM | Machine translation, speech recognition, text generation | Full race-state modeling with selective memory |
| GRU | Smaller models, mobile deployment, quick prototyping | Lightweight real-time tire degradation predictor |
| Bidirectional RNN | NER, sentiment analysis, any task needing full context | Post-race analysis looking forward and backward |
| Character-level LM | Text generation, spelling correction, data augmentation | Synthetic telemetry generation for simulation |
| Temperature sampling | Controlling creativity in generation | Conservative vs aggressive strategy selection |
| Gradient clipping | Essential for stable RNN training | Keeping the signal from exploding in long races |
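
The temperature sampling row can be illustrated with a minimal sketch: dividing the logits by a temperature before the softmax sharpens the distribution when $T < 1$ and flattens it when $T > 1$. The logits here are made-up scores for three candidate characters.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities, sharpening (T<1) or flattening (T>1) them."""
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()

logits = [2.0, 1.0, 0.1]  # illustrative scores for three candidate characters
rng = np.random.default_rng(0)

for T in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, T)
    next_idx = rng.choice(len(probs), p=probs)
    print(f"T={T}: probs={probs.round(3)}, sampled index={next_idx}")
```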

### Checklist

- [ ] I can explain why feedforward networks fail on sequences
- [ ] I can implement a vanilla RNN from scratch and describe each component
- [ ] I understand why gradients vanish/explode and how BPTT causes it
- [ ] I can draw the LSTM architecture and explain each gate's purpose
- [ ] I can compare GRU and LSTM and choose appropriately
- [ ] I can use PyTorch RNN modules with correct input/output shapes
- [ ] I can build and train a character-level language model
- [ ] I understand temperature sampling and its effect on generation

---

## Next Steps

Now that you understand recurrent architectures, you are ready for the next major breakthrough in sequence modeling:

1. **Attention Mechanisms**: Instead of compressing the entire sequence into a fixed-size hidden state, attention lets the model "look back" at any position. This solves the information bottleneck of RNNs. In F1 terms: instead of relying on a compressed race summary, the model can look back at any specific lap and ask "what happened there?"

2. **Transformers**: Built around attention with no recurrence, transformers process all positions in parallel. They power GPT, BERT, and virtually all modern language models.

3. **Sequence-to-Sequence Models**: Encoder-decoder architectures for translation, summarization, and other sequence transformation tasks.
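
As a preview, the "look back at any position" idea from item 1 can be sketched as dot-product attention over stored hidden states. This is a minimal illustration with random vectors, not a trained model; `query` and the scaling by $\sqrt{\text{hidden size}}$ follow the common scaled dot-product form.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, hidden_size = 6, 4

hidden_states = rng.normal(size=(seq_len, hidden_size))  # one vector per time step (e.g. per lap)
query = rng.normal(size=hidden_size)                     # what the model is currently looking for

scores = hidden_states @ query / np.sqrt(hidden_size)    # similarity of the query to each step
weights = np.exp(scores - scores.max())
weights /= weights.sum()                                 # softmax over time steps
context = weights @ hidden_states                        # weighted summary of the sequence

print(weights.round(3))
print(context.shape)
```

Unlike the RNN's final hidden state, `context` is rebuilt for every query, so no single fixed-size vector has to carry the whole sequence.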

**Historical context:** RNNs dominated NLP from 2013-2017. The Transformer (Vaswani et al., 2017) largely replaced them for most tasks. However, RNNs remain valuable for:
- Understanding the foundations of sequence modeling
- Resource-constrained environments
- Streaming data where you process one element at a time
- Background for state-space models (Mamba, etc.), which borrow RNN-like ideas

**Practical next steps:**
- Try training the character-level model on a larger text corpus
- Build a sentiment classifier using bidirectional LSTM
- Experiment with different sequence lengths to observe vanishing gradients firsthand
- Compare RNN and Transformer performance on the same sequence task