# Tutorial 13: Recurrent Neural Networks

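This tutorial implements a vanilla RNN and an LSTM from scratch in NumPy, illustrates why gradients vanish or explode in recurrent networks, and compares PyTorch RNN and LSTM models on a task that requires long-term memory.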

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
np.random.seed(42)
torch.manual_seed(42)

## Part 1: Basic RNN from Scratch

In [None]:
class SimpleRNN:
    """
    Minimal Elman-style recurrent cell written directly in NumPy.

    Update rule applied at every timestep:
        h_t = tanh(W_hh @ h_{t-1} + W_xh @ x_t + b_h)

    Parameters held by the instance:
        hidden_size: number of hidden units
        W_xh: input-to-hidden weights, shape (hidden_size, input_size)
        W_hh: hidden-to-hidden weights, shape (hidden_size, hidden_size)
        b_h:  hidden bias, shape (hidden_size, 1), initialised to zero
    """

    def __init__(self, input_size, hidden_size):
        self.hidden_size = hidden_size

        # Xavier-style scale, shared by both weight matrices
        fan_total = input_size + hidden_size
        scale = np.sqrt(2.0 / fan_total)

        # W_xh is drawn before W_hh; this order fixes the weights for a given seed
        self.W_xh = scale * np.random.randn(hidden_size, input_size)
        self.W_hh = scale * np.random.randn(hidden_size, hidden_size)
        self.b_h = np.zeros((hidden_size, 1))

    def forward(self, x, h_prev):
        """
        Advance the cell by one timestep.

        x:      (input_size, 1) input column vector
        h_prev: (hidden_size, 1) previous hidden state

        Returns (h, z): the new hidden state and its pre-activation z,
        where h = tanh(z). z is handed back for use in a backward pass.
        """
        recurrent_term = self.W_hh @ h_prev
        input_term = self.W_xh @ x
        z = recurrent_term + input_term + self.b_h
        return np.tanh(z), z

    def forward_sequence(self, xs, h0=None):
        """
        Run the cell over a whole sequence.

        xs: list of (input_size, 1) arrays, one per timestep
        h0: optional initial hidden state; zeros of shape (hidden_size, 1)
            are used when it is omitted

        Returns (hs, zs): hs starts with h0 and gains one state per input
        (len(xs) + 1 entries); zs holds the pre-activations (len(xs) entries).
        """
        if h0 is None:
            h0 = np.zeros((self.hidden_size, 1))

        hs = [h0]
        zs = []

        for x_t in xs:
            # The most recent state is always the last entry of hs
            h_t, z_t = self.forward(x_t, hs[-1])
            hs.append(h_t)
            zs.append(z_t)

        return hs, zs

# Smoke test: build a small RNN and push a short random sequence through it
rnn = SimpleRNN(input_size=4, hidden_size=8)

# Five timesteps, each a (4, 1) column vector of standard-normal noise
sequence = [np.random.randn(4, 1) for _ in range(5)]
hs, zs = rnn.forward_sequence(sequence)

# hs holds h0 plus one state per timestep, so it is one longer than the input
print(
    f"Input sequence length: {len(sequence)}",
    f"Hidden states: {len(hs)} (including h0)",
    f"Final hidden state shape: {hs[-1].shape}",
    sep="\n",
)

## Part 2: Visualize Vanishing Gradients

In [None]:
def compute_gradient_magnitude(W_hh, sequence_length):
    """
    Estimate how a back-propagated gradient scales with the number of timesteps.

    In a vanilla RNN each step contributes the Jacobian
        ∂h_i/∂h_{i-1} = diag(1 - h_i²) @ W_hh
    and ∂h_t/∂h_0 is the product of t such factors.

    Simplification used here: take tanh'(z) ≈ 1 (activations near zero), so
    the product reduces to W_hh^t, and summarise its size by the largest
    eigenvalue modulus of W_hh raised to the power t.

    W_hh: square recurrent weight matrix
    sequence_length: number of timesteps to evaluate (t = 0 .. sequence_length - 1)

    Returns (magnitudes, max_eigenvalue), where magnitudes is a list with
    magnitudes[t] = max_eigenvalue ** t.
    """
    # Largest |λ| over all (possibly complex) eigenvalues
    spectral_radius = np.abs(np.linalg.eigvals(W_hh)).max()

    growth_curve = [spectral_radius ** step for step in range(sequence_length)]
    return growth_curve, spectral_radius

# Compare three initialisations of a recurrent weight matrix of the same size
hidden_size = 100
sequence_length = 50

# The three matrices are drawn in the order small -> stable -> large, which
# fixes their values for a given seed.
# Small weights (vanishing)
W_small = np.random.randn(hidden_size, hidden_size) * 0.5 / np.sqrt(hidden_size)
# Identity-like (stable)
W_stable = np.eye(hidden_size) + np.random.randn(hidden_size, hidden_size) * 0.01
# Large weights (exploding)
W_large = np.random.randn(hidden_size, hidden_size) * 1.5 / np.sqrt(hidden_size)

# Growth curves and largest eigenvalue modulus for each matrix
mag_small, eig_small = compute_gradient_magnitude(W_small, sequence_length)
mag_stable, eig_stable = compute_gradient_magnitude(W_stable, sequence_length)
mag_large, eig_large = compute_gradient_magnitude(W_large, sequence_length)

plt.figure(figsize=(12, 5))

# Left panel: estimated gradient magnitude per timestep on a log axis
plt.subplot(1, 2, 1)
for mags, lam, tag in [
    (mag_small, eig_small, 'Small'),
    (mag_stable, eig_stable, 'Stable'),
    (mag_large, eig_large, 'Large'),
]:
    plt.semilogy(mags, label=f'{tag} init (λ_max={lam:.2f})')
plt.title('Gradient Flow Through Time')
plt.xlabel('Timestep')
plt.ylabel('Gradient magnitude (log scale)')
plt.legend()
plt.grid(True, alpha=0.3)

# Right panel: eigenvalues in the complex plane, with the unit circle for reference
plt.subplot(1, 2, 2)
for name, W in zip(('Small', 'Stable', 'Large'), (W_small, W_stable, W_large)):
    eigs = np.linalg.eigvals(W)
    plt.scatter(eigs.real, eigs.imag, alpha=0.5, label=name, s=10)
circle = plt.Circle((0, 0), 1, fill=False, color='red', linestyle='--', label='Unit circle')
plt.gca().add_patch(circle)
plt.title('Eigenvalue Distribution of W_hh')
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.legend()
plt.axis('equal')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(
    "Eigenvalues inside unit circle → vanishing gradients",
    "Eigenvalues outside unit circle → exploding gradients",
    sep="\n",
)

## Part 3: LSTM from Scratch

In [None]:
def sigmoid(x):
    """Logistic function 1 / (1 + e^(-x)), applied elementwise.

    The input is clipped to [-500, 500] before exponentiation so that
    np.exp stays within float64 range for arbitrarily large |x|.
    """
    bounded = np.clip(x, -500, 500)
    return 1 / (1 + np.exp(-bounded))

class SimpleLSTM:
    """
    Single-layer LSTM cell in plain NumPy.

    Every gate reads the stacked vector [h_{t-1}, x_t]:

        f_t = σ(W_f [h_{t-1}, x_t] + b_f)       forget gate
        i_t = σ(W_i [h_{t-1}, x_t] + b_i)       input gate
        c̃_t = tanh(W_c [h_{t-1}, x_t] + b_c)    candidate cell content
        o_t = σ(W_o [h_{t-1}, x_t] + b_o)       output gate
        c_t = f_t * c_{t-1} + i_t * c̃_t         new cell state
        h_t = o_t * tanh(c_t)                   new hidden state
    """

    def __init__(self, input_size, hidden_size):
        self.hidden_size = hidden_size
        concat_size = hidden_size + input_size

        gate_shape = (hidden_size, concat_size)
        bias_shape = (hidden_size, 1)
        scale = np.sqrt(2.0 / concat_size)

        # One weight matrix per gate. The draw order f, i, c, o fixes the
        # weights for a given seed.
        self.W_f = scale * np.random.randn(*gate_shape)
        self.W_i = scale * np.random.randn(*gate_shape)
        self.W_c = scale * np.random.randn(*gate_shape)
        self.W_o = scale * np.random.randn(*gate_shape)

        # The forget-gate bias starts at 1 so the gate begins mostly open;
        # the remaining biases start at 0.
        self.b_f = np.ones(bias_shape)
        self.b_i = np.zeros(bias_shape)
        self.b_c = np.zeros(bias_shape)
        self.b_o = np.zeros(bias_shape)

    def forward(self, x, h_prev, c_prev):
        """
        Advance the cell by one timestep.

        x is stacked under h_prev, so both are expected to be column vectors
        (forward_sequence supplies (hidden_size, 1) states by default).

        Returns (h, c, gates): the new hidden state, the new cell state, and
        a dict with keys 'f', 'i', 'o', 'c_tilde' holding this step's gate
        activations for later inspection.
        """
        # Shared input to all four gates: previous hidden state on top of x
        stacked = np.vstack([h_prev, x])

        forget_gate = sigmoid(self.W_f @ stacked + self.b_f)
        input_gate = sigmoid(self.W_i @ stacked + self.b_i)
        candidate = np.tanh(self.W_c @ stacked + self.b_c)
        output_gate = sigmoid(self.W_o @ stacked + self.b_o)

        # Keep part of the old cell content, add part of the candidate
        c_next = forget_gate * c_prev + input_gate * candidate
        h_next = output_gate * np.tanh(c_next)

        gate_values = {
            'f': forget_gate,
            'i': input_gate,
            'o': output_gate,
            'c_tilde': candidate,
        }
        return h_next, c_next, gate_values

    def forward_sequence(self, xs, h0=None, c0=None):
        """
        Run the cell over a whole sequence.

        h0 / c0 default to zero vectors of shape (hidden_size, 1).

        Returns (hs, cs, gates): hs and cs begin with the initial states and
        gain one entry per input; gates holds one dict per input.
        """
        if h0 is None:
            h0 = np.zeros((self.hidden_size, 1))
        if c0 is None:
            c0 = np.zeros((self.hidden_size, 1))

        hs = [h0]
        cs = [c0]
        gates = []

        for x_t in xs:
            # The latest states are always the last entries of hs / cs
            h_t, c_t, step_gates = self.forward(x_t, hs[-1], cs[-1])
            hs.append(h_t)
            cs.append(c_t)
            gates.append(step_gates)

        return hs, cs, gates

# Smoke test: run a small LSTM over ten random (4, 1) input vectors
lstm = SimpleLSTM(input_size=4, hidden_size=8)
sequence = [np.random.randn(4, 1) for _ in range(10)]
hs, cs, gates = lstm.forward_sequence(sequence)

# Report the sequence length and the shapes of the final states
print(
    f"Sequence length: {len(sequence)}",
    f"Final hidden state shape: {hs[-1].shape}",
    f"Final cell state shape: {cs[-1].shape}",
    sep="\n",
)

## Part 4: LSTM Gradient Flow

In [None]:
# Compare gradient flow: RNN vs LSTM
# For LSTM, the gradient along the direct cell-state path is ∂c_t/∂c_{t-1} = f_t
# If f_t ≈ 1, gradients flow (almost) unchanged!

# Simulate gradient flow
sequence_length = 50

# RNN: gradient ∝ λ^t where λ is the max eigenvalue.
# Index t of the list is the gradient after t steps, so index 0 is 1.0.
rnn_gradient = [0.9 ** t for t in range(sequence_length)]  # Typical vanishing

# LSTM: gradient ∝ ∏ f_t where f_t can be close to 1
forget_gate_values = np.random.uniform(0.9, 1.0, sequence_length)  # High forget gates
# Use the same convention as the RNN curve: index t is the product of the
# first t forget gates, so index 0 is 1.0 (no steps taken yet). A plain
# cumprod would put t + 1 factors at index t and shift the LSTM curve one
# step ahead of the RNN curve.
lstm_gradient = np.concatenate(([1.0], np.cumprod(forget_gate_values[:-1])))

plt.figure(figsize=(10, 5))
plt.semilogy(rnn_gradient, 'r-', linewidth=2, label='Vanilla RNN (λ=0.9)')
plt.semilogy(lstm_gradient, 'b-', linewidth=2, label='LSTM (forget gate ≈ 0.95)')
plt.xlabel('Timestep')
plt.ylabel('Gradient magnitude (log scale)')
plt.title('Gradient Flow: RNN vs LSTM')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# The last list entry is timestep sequence_length - 1, not sequence_length
print(f"RNN gradient at t={sequence_length - 1}: {rnn_gradient[-1]:.2e}")
print(f"LSTM gradient at t={sequence_length - 1}: {lstm_gradient[-1]:.2e}")
print("\nLSTM maintains gradient flow through the cell state 'highway'!")

## Part 5: Train on Sequence Task

In [None]:
# Task: Remember the first element of a sequence
# Input: [x1, x2, ..., xT] where x1 ∈ {0, 1} and x2..xT are small uniform noise in [0, 0.1)
# Output: x1 (requires long-term memory)

def generate_data(n_samples, sequence_length):
    """
    Build a batch for the "remember the first element" task.

    Each sequence starts with a random bit (0 or 1); every later position is
    uniform noise in [0, 0.1). The label is that first bit.

    Returns (X, y): X is a float tensor of shape (n_samples, sequence_length, 1)
    and y is a long tensor of shape (n_samples,).
    """
    X = torch.zeros(n_samples, sequence_length, 1)
    y = torch.zeros(n_samples, dtype=torch.long)

    for row in range(n_samples):
        # One NumPy draw for the bit, then one torch draw for the noise;
        # this per-sample order fixes the output for given seeds.
        label = np.random.randint(0, 2)
        X[row, 0, 0] = label

        distractors = torch.rand(sequence_length - 1) * 0.1
        X[row, 1:, 0] = distractors

        y[row] = label

    return X, y

# Models
class RNNClassifier(nn.Module):
    """Two-class classifier: a single-layer nn.RNN over scalar inputs,
    followed by a linear readout of the final hidden state."""

    def __init__(self, hidden_size=32):
        super().__init__()
        # rnn is created before fc; this order fixes the initial weights for a given seed
        self.rnn = nn.RNN(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_size, out_features=2)

    def forward(self, x):
        """Map x of shape (batch, seq_len, 1) to class logits of shape (batch, 2)."""
        # Only the final hidden state is needed; the per-step outputs are discarded
        _, final_hidden = self.rnn(x)
        # Drop the leading layer axis (size 1 for this single-layer RNN)
        last_state = final_hidden.squeeze(0)
        return self.fc(last_state)

class LSTMClassifier(nn.Module):
    """Two-class classifier: a single-layer nn.LSTM over scalar inputs,
    followed by a linear readout of the final hidden state."""

    def __init__(self, hidden_size=32):
        super().__init__()
        # lstm is created before fc; this order fixes the initial weights for a given seed
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_size, out_features=2)

    def forward(self, x):
        """Map x of shape (batch, seq_len, 1) to class logits of shape (batch, 2)."""
        # The LSTM returns (outputs, (h_n, c_n)); only the final hidden state is used
        _, final_states = self.lstm(x)
        final_hidden = final_states[0]
        # Drop the leading layer axis (size 1 for this single-layer LSTM)
        last_state = final_hidden.squeeze(0)
        return self.fc(last_state)

# Train and compare
def train_model(model, sequence_length, epochs=100, *, batch_size=100, lr=0.01, data_fn=None):
    """
    Train a classifier with Adam on freshly generated batches.

    A new batch is drawn every epoch, so the model never sees the same
    samples twice. Accuracy is measured on that batch using the logits
    computed before the optimizer step.

    model:           nn.Module mapping a (batch, sequence_length, 1) tensor to
                     class logits of shape (batch, n_classes)
    sequence_length: length of each generated sequence
    epochs:          number of optimisation steps (one batch per epoch)
    batch_size:      samples per batch (keyword-only, default 100)
    lr:              Adam learning rate (keyword-only, default 0.01)
    data_fn:         callable (n_samples, sequence_length) -> (X, y);
                     defaults to generate_data when omitted (keyword-only)

    Returns a list with one training-batch accuracy (float in [0, 1]) per
    epoch; the list is empty when epochs is 0.
    """
    if data_fn is None:
        # Resolved at call time so the default task stays the file's own generator
        data_fn = generate_data

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    accuracies = []

    for _ in range(epochs):
        X, y = data_fn(batch_size, sequence_length)

        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # Fraction of the batch whose highest logit matches the label
        batch_accuracy = (logits.argmax(1) == y).float().mean().item()
        accuracies.append(batch_accuracy)

    return accuracies

# Train a fresh RNN and a fresh LSTM for each sequence length, one subplot each
sequence_lengths = [10, 25, 50, 100]

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

for ax, seq_len in zip(axes.flatten(), sequence_lengths):
    # Both models are built before either is trained, and the RNN is trained
    # first; this order fixes the results for a given seed.
    rnn_model = RNNClassifier()
    lstm_model = LSTMClassifier()

    rnn_acc = train_model(rnn_model, seq_len)
    lstm_acc = train_model(lstm_model, seq_len)

    # Per-epoch accuracy curves, with chance level (0.5) as a dashed reference
    for curve, fmt, tag in ((rnn_acc, 'r-', 'RNN'), (lstm_acc, 'b-', 'LSTM')):
        ax.plot(curve, fmt, alpha=0.7, label=tag)
    ax.axhline(0.5, color='gray', linestyle='--', label='Random')

    ax.set_title(f'Sequence Length = {seq_len}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(0.4, 1.05)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('RNN vs LSTM: Remembering First Element', fontsize=14)
plt.tight_layout()
plt.show()

print(
    "As sequence length increases, RNN fails but LSTM succeeds!",
    "This demonstrates LSTM's ability to capture long-range dependencies.",
    sep="\n",
)

## Summary

**Key insights:**
1. **RNNs** process sequences with shared weights across time
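2. **Vanishing gradients** (and their counterpart, exploding gradients) make long-range dependencies hard to learn in vanilla RNNs
3. **LSTM** mitigates this with an additively updated cell state that can carry information nearly unchanged while the forget gate stays close to 1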
4. **Gates** (forget, input, output) control information flow
5. For very long sequences, consider **Transformers** (no recurrence)