# Introduction to Sequence Modeling - Interactive Notebook

This notebook provides hands-on experience with RNNs and demonstrates why we need attention mechanisms.

## Setup

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
import time
import warnings
warnings.filterwarnings('ignore')

# Seed both RNG libraries used below from one named constant so a fresh
# "Restart & Run All" draws the same random numbers every time.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Helper for pretty printing
def print_section(title):
    """Print ``title`` as a banner framed by two 50-character ``=`` rules.

    A blank line is emitted first so that consecutive sections are visually
    separated in the cell output. The title is indented by two spaces.
    """
    rule = "=" * 50
    banner = ["", rule, f"  {title}", rule]
    print("\n".join(banner))

## 1. Understanding Sequential Data

Let's start by seeing why order matters in sequences.

In [None]:
print_section("Why Order Matters")

# Example 1: Word order changes meaning
sentences = [
    "The cat chased the mouse",
    "The mouse chased the cat",
    "Chased the mouse the cat"  # Scrambled
]

print("Sentences with different word orders:")
for i, sent in enumerate(sentences, 1):
    print(f"{i}. {sent}")

# Example 2: Bag of Words loses order
print("\nBag of Words representation:")
bags = []
for sent in sentences[:2]:
    words = sent.lower().split()
    # Walk the vocabulary in sorted order. Raw set iteration order is not
    # guaranteed to agree between two sets holding the same words, so equal
    # bags could otherwise be printed with differently ordered keys and look
    # different to the reader.
    bow = {word: words.count(word) for word in sorted(set(words))}
    bags.append(bow)
    print(f"'{sent}' → {bow}")

# Check the claim printed below instead of only stating it.
assert bags[0] == bags[1], "expected both sentences to share one bag of words"
print("\n⚠️  Both sentences have the same BoW representation!")

## 2. Building a Simple RNN from Scratch

Let's implement a basic RNN to understand how it processes sequences.

In [None]:
class SimpleRNN:
    """Bare-bones recurrent network in NumPy, written for teaching.

    The hidden state is a column vector of shape ``(hidden_size, 1)``; inputs
    are expected to be column vectors that ``Wxh`` can multiply, i.e.
    ``(input_size, 1)``.

    Attributes:
        Wxh: input-to-hidden weights, shape ``(hidden_size, input_size)``.
        Whh: hidden-to-hidden weights, shape ``(hidden_size, hidden_size)``.
        Why: hidden-to-output weights, shape ``(output_size, hidden_size)``.
        bh:  hidden bias, shape ``(hidden_size, 1)``.
        by:  output bias, shape ``(output_size, 1)``.
        hidden_size: number of hidden units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Draw small Gaussian weights (std 0.01) and zero-initialise biases."""
        self.hidden_size = hidden_size

        init_scale = 0.01
        # Draw order (Wxh, Whh, Why) determines which random numbers each
        # matrix receives under a fixed seed.
        self.Wxh = init_scale * np.random.randn(hidden_size, input_size)
        self.Whh = init_scale * np.random.randn(hidden_size, hidden_size)
        self.Why = init_scale * np.random.randn(output_size, hidden_size)

        self.bh = np.zeros((hidden_size, 1))
        self.by = np.zeros((output_size, 1))

    def step(self, x, h_prev):
        """Advance the network by one time step.

        Args:
            x: input at this step.
            h_prev: hidden state from the previous step.

        Returns:
            ``(h, y)``: the new hidden state ``tanh(Wxh x + Whh h_prev + bh)``
            and the linear read-out ``Why h + by``.
        """
        pre_activation = self.Wxh @ x + self.Whh @ h_prev + self.bh
        h = np.tanh(pre_activation)
        return h, self.Why @ h + self.by

    def forward(self, inputs):
        """Run the network over a whole sequence, starting from a zero state.

        Args:
            inputs: iterable of per-step input vectors.

        Returns:
            ``(outputs, hiddens)``. ``outputs`` holds one read-out per input.
            ``hiddens`` holds the initial zero state followed by the state
            after every step, so it is one element longer than ``outputs``.
        """
        state = np.zeros((self.hidden_size, 1))
        hiddens = [state]
        outputs = []

        for x in inputs:
            state, y = self.step(x, state)
            hiddens.append(state)
            outputs.append(y)

        return outputs, hiddens

# Toy network: 3 input features, 4 hidden units, 2 outputs
rnn = SimpleRNN(input_size=3, hidden_size=4, output_size=2)

# Five random column vectors stand in for a length-5 input sequence
sequence = []
for _ in range(5):
    sequence.append(np.random.randn(3, 1))
outputs, hiddens = rnn.forward(sequence)

# Every weight matrix and bias vector contributes its element count
param_count = sum(p.size for p in (rnn.Wxh, rnn.Whh, rnn.Why, rnn.bh, rnn.by))

print_section("RNN Processing")
print(f"Sequence length: {len(sequence)}")
print(f"Hidden state shape: {hiddens[0].shape}")
print(f"Output shape: {outputs[0].shape}")
print(f"\nNumber of parameters: {param_count}")

## 3. Visualizing RNN Processing

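Let's visualize how an RNN's hidden state changes as it reads a sentence one word at a time. (The activations below are randomly simulated for illustration; they do not come from a trained network.)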

In [None]:
def visualize_rnn_processing():
    """Visualize how an RNN's hidden state evolves word by word.

    The hidden states are *simulated*: each step applies ``tanh`` to fresh
    Gaussian noise plus half of the previous state, so the plots illustrate
    the mechanics rather than a trained model.

    Two figures are shown:
      1. one bar chart per word with the hidden activations after that word,
         captioned with the prefix of the sentence read so far;
      2. a heat map of all hidden states (units x time steps).
    """
    # Sample sequence
    words = ["The", "cat", "sat", "on", "mat"]
    hidden_size = 4

    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    fig.suptitle('RNN Processing Steps', fontsize=16)
    # Row-major flattening: panel t is axes[t // 3, t % 3]
    panels = axes.flatten()

    hidden_states = []
    h = np.zeros(hidden_size)

    for t, word in enumerate(words):
        # Simulate processing (random for visualization)
        h = np.tanh(np.random.randn(hidden_size) + 0.5 * h)
        hidden_states.append(h)

        if t >= len(panels):
            continue  # more words than subplot slots: still record the state
        ax = panels[t]

        ax.bar(range(hidden_size), h, color='blue', alpha=0.7)
        ax.set_ylim(-1, 1)
        ax.set_title(f"Step {t+1}: '{word}'")
        ax.set_xlabel('Hidden units')
        ax.set_ylabel('Activation')
        ax.grid(True, alpha=0.3)

        # Caption with what has been read so far. The position is given in
        # axes-fraction coordinates (transAxes), so y must be a small negative
        # fraction to sit just under the x-label; a data-style value such as
        # -1.3 would land far below the panel.
        encoded = " ".join(words[:t+1])
        ax.text(0.5, -0.3, f"Encoding: '{encoded}'",
                transform=ax.transAxes, ha='center', fontsize=10)

    # Hide every subplot slot that did not receive a word
    for unused in panels[len(words):]:
        unused.axis('off')

    plt.tight_layout()
    plt.show()

    # Show hidden state evolution: rows are hidden units, columns are steps
    hidden_matrix = np.array(hidden_states).T

    fig2, ax2 = plt.subplots(figsize=(10, 6))
    image = ax2.imshow(hidden_matrix, aspect='auto', cmap='RdBu', vmin=-1, vmax=1)
    fig2.colorbar(image, ax=ax2, label='Activation')
    ax2.set_xlabel('Time step')
    ax2.set_ylabel('Hidden unit')
    ax2.set_title('Hidden State Evolution Over Time')
    ax2.set_xticks(range(len(words)))
    ax2.set_xticklabels(words)
    plt.show()

visualize_rnn_processing()

## 4. The Vanishing Gradient Problem

Let's demonstrate why RNNs struggle with long sequences.

In [None]:
def demonstrate_gradient_flow():
    """Show how gradients vanish or explode in RNNs.

    Backpropagation through time is modelled as repeatedly multiplying a
    gradient of 1.0 by a constant factor: 0.8 (vanishing) or 1.2 (exploding).
    One log-scale panel is drawn per sequence length (5, 10, 20, 50), with
    shaded bands marking magnitudes treated as "too small to learn"
    (1e-4 to 1e-2) and "numerically unstable" (1e2 to 1e3). A short textual
    explanation is printed afterwards.
    """

    def compound(factor, steps):
        """Return [factor, factor**2, ...] built by repeated multiplication."""
        magnitudes = []
        grad = 1.0
        for _ in range(steps):
            grad *= factor
            magnitudes.append(grad)
        return magnitudes

    print_section("Gradient Flow in RNNs")

    # Simulate gradient backpropagation
    sequence_lengths = [5, 10, 20, 50]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    for ax, seq_len in zip(axes.flat, sequence_lengths):
        gradient_vanish = compound(0.8, seq_len)    # factor < 1: shrinks
        gradient_explode = compound(1.2, seq_len)   # factor > 1: grows

        ax.semilogy(gradient_vanish, 'b-', linewidth=2, label='Vanishing (×0.8)')
        ax.semilogy(gradient_explode, 'r-', linewidth=2, label='Exploding (×1.2)')
        ax.axhline(y=0.01, color='gray', linestyle='--', alpha=0.5)
        ax.axhline(y=100, color='gray', linestyle='--', alpha=0.5)

        # Shaded problem regions
        steps = range(seq_len)
        ax.fill_between(steps, 0.01, 0.0001, alpha=0.2, color='blue',
                        label='Too small to learn')
        ax.fill_between(steps, 100, 1000, alpha=0.2, color='red',
                        label='Numerical instability')

        ax.set_xlabel('Backprop steps')
        ax.set_ylabel('Gradient magnitude')
        ax.set_title(f'Sequence length: {seq_len}')
        # Build the legend only after every labelled artist exists; calling
        # it before fill_between would leave the shaded bands out.
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.suptitle('Gradient Flow Through Time', fontsize=16)
    plt.tight_layout()
    plt.show()

    # Show the mathematical reason
    print("\nMathematical explanation:")
    print("Gradient after t steps: ∂L/∂h₀ = ∂L/∂hₜ × ∏(∂hᵢ/∂hᵢ₋₁)")
    print("If |∂hᵢ/∂hᵢ₋₁| < 1: gradient → 0 (vanishing)")
    print("If |∂hᵢ/∂hᵢ₋₁| > 1: gradient → ∞ (exploding)")

demonstrate_gradient_flow()

## 5. LSTM: Attempting to Solve the Problem

Let's see how LSTMs try to address gradient issues with gating mechanisms.

In [None]:
class LSTMCell:
    """Single LSTM cell in NumPy with all four gates packed into one matrix.

    ``W`` has shape ``(4 * hidden_size, input_size + hidden_size)`` and ``b``
    has shape ``(4 * hidden_size, 1)``. Their rows are laid out in four
    consecutive blocks of ``hidden_size`` rows: input gate, forget gate,
    candidate values, output gate.
    """

    def __init__(self, input_size, hidden_size):
        """Draw small Gaussian weights (std 0.01) and zero biases."""
        self.hidden_size = hidden_size
        stacked_rows = 4 * hidden_size
        self.W = 0.01 * np.random.randn(stacked_rows, input_size + hidden_size)
        self.b = np.zeros((stacked_rows, 1))

    def forward(self, x, h_prev, c_prev):
        """Compute one LSTM step.

        Args:
            x: input, stacked vertically on top of ``h_prev``.
            h_prev: previous hidden state.
            c_prev: previous cell state.

        Returns:
            ``(h, c, (i_gate, f_gate, g_gate, o_gate))``: the new hidden
            state, the new cell state, and the four gate activations.
        """
        # One matrix product yields the pre-activations of every gate
        stacked = self.W @ np.vstack([x, h_prev]) + self.b

        # Cut the stacked rows back into the four equally sized gate blocks
        raw_i, raw_f, raw_g, raw_o = np.split(stacked, 4, axis=0)
        i_gate = self._sigmoid(raw_i)   # Input gate
        f_gate = self._sigmoid(raw_f)   # Forget gate
        g_gate = np.tanh(raw_g)         # Candidate
        o_gate = self._sigmoid(raw_o)   # Output gate

        # Cell state: keep part of the old memory, add part of the candidate
        c = f_gate * c_prev + i_gate * g_gate
        # Hidden state: gated, squashed view of the cell state
        h = o_gate * np.tanh(c)

        return h, c, (i_gate, f_gate, g_gate, o_gate)

    def _sigmoid(self, x):
        """Logistic function; the input is clipped to [-500, 500] before ``exp``."""
        bounded = np.clip(x, -500, 500)
        return 1 / (1 + np.exp(-bounded))

# Visualize LSTM gates
def visualize_lstm_gates():
    """Run one LSTM step on a random input and chart the four gate outputs.

    A fresh ``LSTMCell`` (5 inputs, 4 hidden units) is evaluated once with
    zero previous hidden and cell state. Each gate gets its own bar chart in
    a 2x2 grid: sigmoid gates are shaded with the Blues colormap on a 0..1
    axis, while the candidate values are coloured green/red by sign on a
    -1.5..1.5 axis. A few take-aways are printed afterwards.
    """
    print_section("LSTM Gate Visualization")

    n_features, n_hidden = 5, 4
    lstm = LSTMCell(input_size=n_features, hidden_size=n_hidden)

    # One random input column; previous hidden and cell state start at zero
    x = np.random.randn(n_features, 1)
    h_prev = np.zeros((n_hidden, 1))
    c_prev = np.zeros((n_hidden, 1))

    # Only the gate activations are needed for the plots
    _, _, (i_gate, f_gate, g_gate, o_gate) = lstm.forward(x, h_prev, c_prev)

    candidate_title = 'Candidate Values'
    panels = (
        ('Input Gate', 'Controls new information', i_gate),
        ('Forget Gate', 'Controls what to forget', f_gate),
        (candidate_title, 'New information to add', g_gate),
        ('Output Gate', 'Controls what to output', o_gate),
    )

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    for ax, (title, desc, gate) in zip(axes.flat, panels):
        values = gate.flatten()
        is_candidate = title == candidate_title

        bars = ax.bar(range(len(gate)), values)
        for bar, val in zip(bars, values):
            if is_candidate:
                shade = 'green' if val > 0 else 'red'
            else:
                shade = plt.cm.Blues(val)
            bar.set_color(shade)

        # tanh candidates are signed; sigmoid gates live in [0, 1]
        if is_candidate:
            ax.set_ylim(-1.5, 1.5)
        else:
            ax.set_ylim(0, 1)
        ax.set_xlabel('Hidden units')
        ax.set_ylabel('Gate value')
        ax.set_title(f'{title}\n{desc}')
        ax.grid(True, alpha=0.3)

    plt.suptitle('LSTM Gates in Action', fontsize=16)
    plt.tight_layout()
    plt.show()

    print("\nKey insights:")
    for insight in (
        "- Forget gate ≈ 0.5: Keeping about half of previous information",
        "- Input gate controls how much new information to add",
        "- Cell state provides a highway for gradient flow",
        "- But still processes sequentially!",
    ):
        print(insight)

visualize_lstm_gates()

## 6. The Information Bottleneck Problem

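Let's visualize how an RNN encoder must compress an entire input sequence, however long, into a single fixed-size context vector. (The "information units" in the plots are arbitrary and only illustrate the trend.)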

In [None]:
def demonstrate_bottleneck():
    """Illustrate the fixed-size context bottleneck of a seq2seq encoder.

    For four sentences of increasing length, a bar chart compares a made-up
    "information content" (50 units per word) with the made-up capacity of an
    8-dimensional context vector (10 units per dimension). Panels where the
    input exceeds the capacity are coloured red/orange and flagged with
    'Information Lost!'. The numbers are illustrative, not measured.
    """
    print_section("Information Bottleneck in Seq2Seq")

    # Simulate encoding sentences of different lengths
    sentences = [
        "Hello",
        "Hello world",
        "The quick brown fox jumps",
        "The quick brown fox jumps over the lazy dog near the river"
    ]

    hidden_size = 8          # Fixed size context vector
    units_per_word = 50      # Arbitrary "information units" per word
    units_per_dim = 10       # Arbitrary capacity per context dimension
    # The capacity does not depend on the sentence, so compute it once
    context_capacity = hidden_size * units_per_dim

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    for ax, sentence in zip(axes.flat, sentences):
        words = sentence.split()
        word_info = len(words) * units_per_word

        bars = ax.bar(['Input Information', 'Context Capacity'],
                      [word_info, context_capacity])

        # Leave headroom above the tallest bar so the warning label stays
        # inside the axes instead of colliding with the title
        ax.set_ylim(0, 1.25 * max(word_info, context_capacity))

        # Color based on whether information fits
        if word_info > context_capacity:
            bars[0].set_color('red')
            bars[1].set_color('orange')
            # x=0.5 is midway between the two categorical bars (at 0 and 1)
            ax.text(0.5, 1.05 * word_info, 'Information Lost!',
                    ha='center', va='bottom', color='red', fontweight='bold')
        else:
            bars[0].set_color('green')
            bars[1].set_color('green')

        ax.set_ylabel('Information (arbitrary units)')
        ax.set_title(f'Sentence: "{sentence}"\n({len(words)} words)')

        # Add text showing the compression ratio
        compression_ratio = word_info / context_capacity
        ax.text(0.5, -0.15, f'Compression ratio: {compression_ratio:.1f}:1',
                transform=ax.transAxes, ha='center')

    fig.suptitle('Information Bottleneck: Longer Sequences Lose More Information',
                 fontsize=16)
    plt.tight_layout()
    plt.show()

    print("\n🔴 Problem: All sentences compressed to same size vector!")
    print("🟢 Solution: Attention lets decoder look at all encoder states")

demonstrate_bottleneck()

## 7. Sequential vs Parallel Processing

Let's compare how RNNs and Attention process sequences.

In [None]:
def compare_processing_methods(sequence_length=8, step_time=0.1):
    """Contrast sequential (RNN) and parallel (attention) processing time.

    The simulation charges one ``step_time`` per RNN step, because each step
    does a constant amount of work but cannot start before the previous one
    has finished. Attention is charged a single ``step_time`` for the whole
    sequence, assuming one processor per position. The printed totals, the
    reported speed-up and the two timeline plots all use this same model.

    Args:
        sequence_length: number of positions in the simulated sequence.
        step_time: simulated duration of one unit of work, in seconds. The
            function really sleeps for this long per unit.

    Raises:
        ValueError: if ``sequence_length`` < 1 or ``step_time`` <= 0.
    """
    if sequence_length < 1:
        raise ValueError("sequence_length must be at least 1")
    if step_time <= 0:
        raise ValueError("step_time must be positive")

    print_section("Sequential (RNN) vs Parallel (Attention) Processing")

    # Simulate RNN processing
    print("RNN Processing (Sequential):")
    print("Each step must wait for the previous one...\n")

    for t in range(sequence_length):
        time.sleep(step_time)  # one unit of work per step
        # The arrows show the time elapsed so far: one per finished step
        elapsed = "→" * (t + 1)
        if t == 0:
            dependency = "no earlier steps needed"
        elif t == 1:
            dependency = "depends on step 1"
        else:
            dependency = f"depends on steps 1-{t}"
        print(f"Step {t+1}: {elapsed} Done ({dependency})")

    total_rnn_time = sequence_length * step_time
    print(f"\nTotal RNN time: {total_rnn_time:.1f}s")

    print("\n" + "="*50 + "\n")

    # Simulate Attention processing
    print("Attention Processing (Parallel):")
    print("All positions processed simultaneously!\n")

    print("All steps: ", end='')
    for t in range(sequence_length):
        print(f"[{t+1}]", end=' ')

    attention_time = step_time  # every position handled in the same unit
    time.sleep(attention_time)
    print("→ Done!")
    print(f"\nTotal Attention time: {attention_time:.1f}s "
          f"(with {sequence_length} parallel processors)")

    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # RNN timeline: step t occupies the time slot [t, t + 1)
    for t in range(sequence_length):
        ax1.barh(t, 1, left=t, height=0.8, alpha=0.7,
                 color=plt.cm.Blues(min(0.5 + t * 0.05, 1.0)))
        ax1.text(t + 0.5, t, f'Step {t+1}', ha='center', va='center')

    ax1.set_xlim(0, sequence_length)
    ax1.set_ylim(-0.5, sequence_length - 0.5)
    ax1.set_xlabel('Time')
    ax1.set_ylabel('Processing Step')
    ax1.set_title('RNN: Sequential Processing')
    ax1.grid(True, alpha=0.3)

    # Attention timeline: every position occupies the same slot [0, 1)
    ax2.barh(range(sequence_length), [1] * sequence_length,
             height=0.8, alpha=0.7, color='green')
    for t in range(sequence_length):
        ax2.text(0.5, t, f'Pos {t+1}', ha='center', va='center')

    ax2.set_xlim(0, 2)
    ax2.set_ylim(-0.5, sequence_length - 0.5)
    ax2.set_xlabel('Time')
    ax2.set_ylabel('Position')
    ax2.set_title('Attention: Parallel Processing')
    ax2.grid(True, alpha=0.3)

    fig.suptitle('Processing Time Comparison', fontsize=16)
    plt.tight_layout()
    plt.show()

    speedup = total_rnn_time / attention_time
    print(f"\n⚡ Speedup with parallelization: {speedup:.0f}x faster!")

compare_processing_methods()

## 8. Training a Simple RNN Language Model

Let's train a character-level RNN to see its limitations in practice.

In [None]:
class CharRNN(nn.Module):
    """Character-level language model: embedding -> vanilla RNN -> linear.

    Args:
        vocab_size: number of distinct characters (size of the output layer).
        hidden_size: width of both the embedding and the RNN hidden state.
    """

    def __init__(self, vocab_size, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # batch_first=True: inputs and outputs are (batch, seq, feature)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """Return next-character logits for every position.

        Args:
            x: integer tensor of character indices, shape ``(batch, seq)``.
            hidden: optional initial RNN state; ``None`` lets ``nn.RNN``
                start from zeros.

        Returns:
            ``(logits, hidden)`` where ``logits`` has shape
            ``(batch, seq, vocab_size)`` and ``hidden`` is the final RNN state.
        """
        embed = self.embedding(x)
        out, hidden = self.rnn(embed, hidden)
        out = self.output(out)
        return out, hidden

    def generate(self, start_char, char_to_idx, idx_to_char, length=100):
        """Sample a string of ``length`` characters starting with ``start_char``.

        Characters are drawn one at a time from the softmax distribution over
        the model's logits, feeding each sample back in as the next input.

        Args:
            start_char: first character of the result; must be a key of
                ``char_to_idx``.
            char_to_idx: mapping from character to vocabulary index.
            idx_to_char: inverse mapping from index to character.
            length: total number of characters to return, including
                ``start_char``. Values below 1 still return ``start_char``.

        Returns:
            The generated string.

        Raises:
            KeyError: if ``start_char`` is not in ``char_to_idx``.
        """
        # Remember the current mode so sampling does not silently leave a
        # model that is still being trained stuck in eval mode.
        was_training = self.training
        # Build inputs on the same device as the weights (CPU or GPU).
        device = next(self.parameters()).device

        pieces = [start_char]
        self.eval()
        try:
            with torch.no_grad():
                # Start with the given character
                current = torch.tensor([[char_to_idx[start_char]]], device=device)
                hidden = None

                for _ in range(length - 1):
                    output, hidden = self(current, hidden)
                    # Distribution over the vocabulary at the last position
                    probs = torch.softmax(output[0, -1], dim=0)
                    next_char_idx = torch.multinomial(probs, 1).item()
                    pieces.append(idx_to_char[next_char_idx])
                    current = torch.tensor([[next_char_idx]], device=device)
        finally:
            self.train(was_training)

        return "".join(pieces)

# Train on a simple pattern
def train_char_rnn():
    """Train a small ``CharRNN`` on a repeated phrase and sample from it.

    The training text is ``"hello world! "`` repeated 20 times. It is fed to
    the model in consecutive chunks of 10 characters with next-character
    targets, carrying the (detached) hidden state from chunk to chunk within
    an epoch. After 100 epochs the function prints three generated samples
    and plots the per-epoch loss (the sum of the chunk losses).
    """
    print_section("Training Character-Level RNN")

    # Simple repetitive text to learn
    text = "hello world! " * 20
    # Sort the vocabulary: set iteration order for strings changes between
    # interpreter sessions, which would change the char -> index mapping and
    # make runs irreproducible even with fixed RNG seeds.
    chars = sorted(set(text))
    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for ch, i in char_to_idx.items()}

    print(f"Training text: '{text[:50]}...'")
    print(f"Vocabulary size: {len(chars)}")
    print(f"Characters: {chars}")

    # Prepare data
    data = torch.tensor([char_to_idx[ch] for ch in text])

    # Create model
    model = CharRNN(len(chars), hidden_size=32)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    # Training
    chunk_size = 10  # characters per truncated-backprop chunk
    losses = []
    model.train()

    for epoch in range(100):
        hidden = None  # start every epoch from a fresh RNN state
        total_loss = 0

        # Stop early enough that each chunk has a full, shifted-by-one target
        for i in range(0, len(data) - chunk_size, chunk_size):
            # Get input and target (target = input shifted by one character)
            inputs = data[i:i+chunk_size].unsqueeze(0)
            targets = data[i+1:i+chunk_size+1].unsqueeze(0)

            # Forward pass
            outputs, hidden = model(inputs, hidden)
            loss = criterion(outputs.view(-1, len(chars)), targets.view(-1))

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Detach so the next chunk does not backpropagate into this one
            hidden = hidden.detach()
            total_loss += loss.item()

        losses.append(total_loss)

        if epoch % 20 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss:.4f}")

    # Generate text
    print("\nGenerated text samples:")
    for start in ['h', 'w', ' ']:
        generated = model.generate(start, char_to_idx, idx_to_char, length=50)
        print(f"Starting with '{start}': {generated}")

    # Plot loss
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(losses)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('RNN Training Loss')
    ax.grid(True, alpha=0.3)
    plt.show()

    print("\n📝 Note: RNN learned the pattern but would struggle with longer dependencies!")

train_char_rnn()

## 9. Demonstrating Long-Range Dependencies

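Let's see why RNNs struggle with long-range dependencies. The accuracies below are simulated from simple decay formulas to illustrate the trend; they are not measurements from trained models.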

In [None]:
def _dependency_distance(sentence, head="cat", mask="[MASK]"):
    """Return how many tokens separate ``head`` from the masked slot.

    Tokens are whitespace-separated. The mask token may carry trailing
    punctuation (e.g. ``"[MASK]."``), so it is matched by prefix.

    Raises:
        ValueError: if ``head`` or the mask token is missing.
    """
    tokens = sentence.split()
    head_pos = tokens.index(head)
    for pos, token in enumerate(tokens):
        if token.startswith(mask):
            return pos - head_pos
    raise ValueError(f"no {mask} token in: {sentence!r}")


def long_range_dependency_test():
    """Illustrate how subject-verb distance affects RNNs versus attention.

    Each test sentence separates the subject "cat" from the masked verb by a
    growing relative clause. The distance is measured from the sentence
    itself, so the printed and plotted numbers always agree with the text.

    The accuracies are *simulated*, not measured: the RNN curve decays
    exponentially with distance (0.95 * 0.8 ** (d / 5)) and the attention
    curve decays slightly and linearly (0.95 - 0.01 * d / 10).
    """
    print_section("Long-Range Dependency Challenge")

    # (sentence, correct completion) with dependencies at different distances
    test_sentences = [
        ("The cat [MASK].", "sat"),
        ("The cat that was black [MASK].", "sat"),
        ("The cat that chased the mouse that ate the cheese [MASK].", "sat"),
        ("The cat that lived in the house that Jack built which stood on the hill [MASK].", "sat")
    ]

    distances = []
    rnn_accuracy = []
    attention_accuracy = []

    for sent, answer in test_sentences:
        distance = _dependency_distance(sent)
        distances.append(distance)

        # Simulate accuracy (RNN degrades, Attention maintains)
        rnn_acc = 0.95 * (0.8 ** (distance / 5))  # Exponential decay
        att_acc = 0.95 - 0.01 * (distance / 10)   # Slight linear decay

        rnn_accuracy.append(rnn_acc)
        attention_accuracy.append(att_acc)

        print(f"\nSentence: {sent}")
        print(f"Correct answer: '{answer}'")
        print(f"Distance to dependency: {distance} words")
        print(f"RNN accuracy: {rnn_acc:.2%}")
        print(f"Attention accuracy: {att_acc:.2%}")

    # Plot comparison as grouped bars, one group per test sentence
    fig, ax = plt.subplots(figsize=(12, 8))
    x = np.arange(len(distances))
    width = 0.35

    ax.bar(x - width/2, rnn_accuracy, width, label='RNN', color='red', alpha=0.7)
    ax.bar(x + width/2, attention_accuracy, width, label='Attention', color='green', alpha=0.7)

    ax.set_xlabel('Test Case')
    ax.set_ylabel('Accuracy')
    ax.set_title('Performance on Long-Range Dependencies')
    ax.set_xticks(x)
    ax.set_xticklabels([f'Distance: {d}' for d in distances])
    ax.legend()
    ax.set_ylim(0, 1)

    # Add the percentage above each bar
    for i, (rnn, att) in enumerate(zip(rnn_accuracy, attention_accuracy)):
        ax.text(i - width/2, rnn + 0.02, f'{rnn:.0%}', ha='center')
        ax.text(i + width/2, att + 0.02, f'{att:.0%}', ha='center')

    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.show()

    print("\n🔴 RNN: Performance drops significantly with distance")
    print("🟢 Attention: Maintains performance regardless of distance")

long_range_dependency_test()

## 10. Summary: Why We Need Attention

Let's summarize all the problems with RNNs and preview the solution.

In [None]:
def create_summary_visualization():
    """Draw a 2x2 summary figure contrasting RNNs with attention.

    Panels:
      1. Processing: sequential vs parallel (text only).
      2. Information flow: a narrow bottleneck vs several direct paths.
      3. Gradient flow: an illustrative 0.8**t decay vs a constant 0.9.
      4. Memory access: a chain of neighbours vs links between every pair of
         positions.

    All curves and shapes are schematic illustrations, not measurements.
    """
    print_section("RNN Limitations vs Attention Solutions")

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # 1. Sequential vs Parallel
    ax = axes[0, 0]
    ax.text(0.5, 0.8, 'Processing', ha='center', fontsize=14, fontweight='bold')
    ax.text(0.25, 0.6, 'RNN', ha='center', fontsize=12, color='red')
    ax.text(0.25, 0.4, '→→→→→', ha='center', fontsize=10)
    ax.text(0.25, 0.2, 'Sequential\n(Slow)', ha='center', fontsize=10)

    ax.text(0.75, 0.6, 'Attention', ha='center', fontsize=12, color='green')
    ax.text(0.75, 0.4, '⇉⇉⇉⇉⇉', ha='center', fontsize=10)
    ax.text(0.75, 0.2, 'Parallel\n(Fast)', ha='center', fontsize=10)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')

    # 2. Information flow
    ax = axes[0, 1]
    ax.text(0.5, 0.8, 'Information Flow', ha='center', fontsize=14, fontweight='bold')

    # RNN bottleneck: a wide stream squeezed through a narrow neck. The neck
    # is drawn as a polygon rather than an emoji glyph so that it does not
    # depend on the plotting font containing that character.
    ax.arrow(0.05, 0.5, 0.1, 0, head_width=0.05, color='red')
    ax.fill([0.25, 0.31, 0.37, 0.37, 0.31, 0.25],
            [0.62, 0.52, 0.62, 0.38, 0.48, 0.38],
            color='red', alpha=0.6)
    ax.arrow(0.39, 0.5, 0.05, 0, head_width=0.05, color='red')
    ax.text(0.31, 0.28, 'Bottleneck', ha='center', color='red')

    # Attention direct connections
    for i in range(5):
        y = 0.5 + (i - 2) * 0.05
        ax.arrow(0.6, y, 0.3, 0, head_width=0.02, color='green', alpha=0.7)
    ax.text(0.75, 0.3, 'Direct paths', ha='center', color='green')

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')

    # 3. Gradient flow
    ax = axes[1, 0]
    steps = np.arange(10)
    rnn_gradient = 0.8 ** steps
    attention_gradient = np.full(steps.shape, 0.9)

    ax.plot(steps, rnn_gradient, 'r-', linewidth=2, label='RNN')
    ax.plot(steps, attention_gradient, 'g-', linewidth=2, label='Attention')
    ax.set_xlabel('Backprop Steps')
    ax.set_ylabel('Gradient Magnitude')
    ax.set_title('Gradient Flow')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 4. Memory access pattern
    ax = axes[1, 1]
    n_positions = 8
    rnn_y, attention_y = 0.75, 0.2

    ax.text((n_positions - 1) / 2, 0.9, 'Memory Access',
            ha='center', fontsize=12, fontweight='bold')

    # RNN: each position is linked only to its immediate neighbour
    ax.plot(range(n_positions), [rnn_y] * n_positions, 'r-o',
            linewidth=2, markersize=8)
    ax.text((n_positions - 1) / 2, rnn_y - 0.1, 'RNN: Sequential access only',
            ha='center', color='red')

    # Attention: one arc per pair of positions. Straight segments would all
    # lie on the same horizontal line and be indistinguishable from a chain,
    # so each link is bowed upward in proportion to the distance it spans.
    for i in range(n_positions):
        for j in range(i + 1, n_positions):
            xs = np.linspace(i, j, 30)
            ys = attention_y + 0.04 * (j - i) * np.sin(np.pi * (xs - i) / (j - i))
            ax.plot(xs, ys, 'g-', alpha=0.3, linewidth=1)
    ax.plot(range(n_positions), [attention_y] * n_positions, 'go', markersize=8)
    ax.text((n_positions - 1) / 2, attention_y - 0.12,
            'Attention: Random access to any position',
            ha='center', color='green')

    ax.set_xlim(-0.5, n_positions - 0.5)
    ax.set_ylim(0, 1)
    ax.axis('off')

    fig.suptitle('Why Attention Solves RNN Problems', fontsize=16)
    plt.tight_layout()
    plt.show()

create_summary_visualization()

# Each entry pairs an RNN limitation with how attention addresses it
wide_rule = "=" * 60
key_insights = [
    ("\n1. RNNs must process sequences step by step (slow)",
     "   → Attention processes all positions in parallel (fast)"),
    ("\n2. RNNs compress everything into fixed-size vectors",
     "   → Attention maintains all information"),
    ("\n3. RNNs suffer from vanishing/exploding gradients",
     "   → Attention has direct gradient paths"),
    ("\n4. RNNs can only access memory sequentially",
     "   → Attention can access any position directly"),
]

print("\n" + wide_rule)
print("🎯 KEY INSIGHTS:")
print(wide_rule)
for limitation, remedy in key_insights:
    print(limitation)
    print(remedy)
print("\n🚀 Next: Learn how attention mechanisms work!")

## Exercises

Try these exercises to deepen your understanding:

In [None]:
print_section("Exercises")

# One (heading, task, follow-up) triple per exercise; the follow-up line ends
# with a newline so a blank line separates consecutive exercises.
exercises = [
    ("1. Gradient Clipping",
     "   Implement gradient clipping in the RNN backward pass.",
     "   Why is this necessary?\n"),
    ("2. Bidirectional RNN",
     "   Modify the SimpleRNN to process sequences in both directions.",
     "   What are the benefits?\n"),
    ("3. GRU Implementation",
     "   Implement a GRU cell (simpler than LSTM).",
     "   Compare its gates with LSTM.\n"),
    ("4. Attention Preview",
     "   Given a sequence, compute pairwise similarity scores.",
     "   This is the foundation of attention!\n"),
]

for exercise in exercises:
    for line in exercise:
        print(line)

# Skeleton for Exercise 4
def attention_preview():
    """Exercise 4 scaffold: pairwise similarities between word embeddings.

    Sets up five words with random 4-dimensional embeddings and leaves the
    similarity and softmax steps as TODOs. Returns ``None`` until the
    exercise is completed.
    """
    words = ["The", "cat", "sat", "on", "mat"]
    # Random stand-ins for learned embeddings: one 4-dimensional row per word
    embeddings = np.random.randn(len(words), 4)

    # TODO: Compute similarity matrix using dot products
    # similarity[i, j] = embeddings[i] · embeddings[j]

    # TODO: Apply softmax to get attention weights
    # For each word, what other words does it "attend to"?

    return None

# Closing hint printed after the exercise list
print("💡 Hint for Exercise 4: This is exactly what self-attention does!")