# Attention Mechanisms - Interactive Notebook

This notebook provides hands-on experience with attention mechanisms - the key innovation behind transformers.

## Setup

In [None]:
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import HTML, display
# Silence every warning category for the rest of the session so library notices
# do not clutter cell output.
# NOTE(review): this also hides warnings that may matter while debugging.
warnings.filterwarnings('ignore')

# Plot styling: a matplotlib style sheet plus seaborn's "husl" colour palette.
# NOTE(review): the 'seaborn-v0_8-*' style names presumably need a recent
# matplotlib release — confirm against the installed version.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Reproducibility: seed NumPy's global RNG and PyTorch's default generator.
np.random.seed(42)
torch.manual_seed(42)

## 1. Understanding Attention Intuitively

Let's start with a simple example to build intuition.

In [None]:
# Simple attention example: Finding relevant words
def simple_attention_demo():
    """Show attention as a softmax over query-word dot products.

    Scores a hand-made 2-D query against toy 2-D word embeddings, turns the
    scores into weights with a softmax, plots the embeddings next to the
    resulting attention distribution, and prints the three highest-weighted
    words. Takes no arguments and returns nothing.
    """
    # Sentence to attend over (contains both "The" and "the")
    words = ["The", "cat", "sat", "on", "the", "mat"]

    # Toy embeddings (2D for visualization), keyed by lowercase token.
    # Every lookup lowercases the word first, so "The" and "the" share one
    # vector instead of the lowercase form raising KeyError.
    embeddings = {
        "the": np.array([0.1, 0.9]),
        "cat": np.array([0.9, 0.2]),
        "sat": np.array([0.5, 0.5]),
        "on": np.array([0.3, 0.7]),
        "mat": np.array([0.8, 0.3])
    }

    # Query: "What did the cat do?"
    query = np.array([0.7, 0.4])  # Similar to "cat" and "sat"

    # Attention scores: dot product of the query with each word's embedding
    scores = np.array([np.dot(query, embeddings[word.lower()]) for word in words])

    # Softmax; subtracting the max leaves the result unchanged mathematically
    # but keeps np.exp from overflowing on large scores
    exp_scores = np.exp(scores - scores.max())
    attention_weights = exp_scores / exp_scores.sum()

    # Visualize
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot embeddings. Iterating the dict (insertion-ordered) rather than a set
    # keeps the point order, and therefore the colours, identical on every run.
    for token, emb in embeddings.items():
        ax1.scatter(emb[0], emb[1], s=100)
        ax1.annotate(token, (emb[0], emb[1]), xytext=(5, 5), textcoords='offset points')

    ax1.scatter(query[0], query[1], s=200, c='red', marker='*')
    ax1.annotate('Query', (query[0], query[1]), xytext=(5, 5), textcoords='offset points', color='red')
    ax1.set_xlabel('Dimension 1')
    ax1.set_ylabel('Dimension 2')
    ax1.set_title('Word Embeddings and Query')
    ax1.grid(True, alpha=0.3)

    # Plot attention weights, one bar per sentence position
    bars = ax2.bar(range(len(words)), attention_weights)
    ax2.set_xticks(range(len(words)))
    ax2.set_xticklabels(words)
    ax2.set_ylabel('Attention Weight')
    ax2.set_title('Attention Distribution')

    # Color bars by weight (doubled so small weights remain visible)
    for bar, weight in zip(bars, attention_weights):
        bar.set_color(plt.cm.Blues(weight * 2))

    # Add values on bars
    for i, weight in enumerate(attention_weights):
        ax2.text(i, weight + 0.01, f'{weight:.3f}', ha='center')

    plt.tight_layout()
    plt.show()

    # Print interpretation: the three positions with the largest weights
    print("Attention weights show which words are most relevant to the query:")
    sorted_idx = np.argsort(attention_weights)[::-1]
    for idx in sorted_idx[:3]:
        print(f"  '{words[idx]}': {attention_weights[idx]:.3f}")

simple_attention_demo()

## 2. Implementing Attention Step by Step

Let's build attention from scratch to understand each component.

In [None]:
def attention_step_by_step():
    """Walk through scaled dot-product attention on small random matrices.

    Prints every intermediate result (raw scores, scaled scores, softmax
    weights, output) and then draws Q, K, V and the three derived matrices as
    heatmaps in a 2x3 grid. Takes no arguments and returns nothing.
    """
    print("=== Scaled Dot-Product Attention ===\n")

    n_positions, d_k = 4, 3

    # The generator is consumed left to right, so the draw order stays Q, K, V
    Q, K, V = (np.random.randn(n_positions, d_k) for _ in range(3))

    for label, matrix in (("Q", Q), ("K", K)):
        print(f"{label} shape: {matrix.shape} (seq_len × d_k)")
    print(f"V shape: {V.shape} (seq_len × d_k)\n")

    # Step 1: similarity of every query with every key
    raw_scores = Q @ K.T
    print("Step 1: Compute scores = Q @ K^T")
    print(f"Scores shape: {raw_scores.shape}")
    print("Scores matrix:", raw_scores, "", sep="\n")

    # Step 2: divide by sqrt(d_k)
    root_dk = np.sqrt(d_k)
    scaled = raw_scores / root_dk
    print(f"Step 2: Scale by √d_k = √{d_k} = {root_dk:.3f}")
    print("Scaled scores:", scaled, "", sep="\n")

    # Step 3: softmax across the key axis, so each query row sums to 1
    exp_scaled = np.exp(scaled)
    attn = exp_scaled / np.sum(exp_scaled, axis=-1, keepdims=True)
    print("Step 3: Apply softmax (row-wise)")
    print("Attention weights:")
    print(attn)
    print(f"\nRow sums: {attn.sum(axis=1)} (should all be 1.0)")
    print()

    # Step 4: weighted combination of the value rows
    result = attn @ V
    print("Step 4: Multiply by values")
    print(f"Output shape: {result.shape}")
    print("Output:")
    print(result)

    # Visualize the process: one heatmap per matrix, described declaratively as
    # (axis, data, title, y-label or None, x-label, extra imshow keywords)
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    diverging = {'cmap': 'RdBu'}
    panels = [
        (axes[0, 0], Q, 'Q (Queries)', 'Position', 'Dimension', diverging),
        (axes[0, 1], K, 'K (Keys)', None, 'Dimension', diverging),
        (axes[0, 2], V, 'V (Values)', None, 'Dimension', diverging),
        (axes[1, 0], raw_scores, 'Scores (Q @ K^T)', 'Query position', 'Key position', diverging),
        (axes[1, 1], attn, 'Attention Weights', 'Query position', 'Key position',
         {'cmap': 'Blues', 'vmin': 0, 'vmax': 1}),
        (axes[1, 2], result, 'Output', 'Position', 'Dimension', diverging),
    ]
    for ax, data, title, ylabel, xlabel, image_kwargs in panels:
        image = ax.imshow(data, aspect='auto', **image_kwargs)
        ax.set_title(title)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)
        plt.colorbar(image, ax=ax)

    plt.suptitle('Attention Computation Visualization', fontsize=16)
    plt.tight_layout()
    plt.show()

attention_step_by_step()

## 3. Self-Attention in Action

Let's see how self-attention helps with understanding context.

In [None]:
class SimpleSelfAttention:
    """Single-head scaled dot-product self-attention built on NumPy.

    Holds three random projection matrices of shape (d_model, d_k) that map an
    input sequence to queries, keys and values.
    """

    def __init__(self, d_model, d_k=None):
        """Create the projections.

        Args:
            d_model: width of each input row.
            d_k: width of the projected queries/keys/values; any falsy value
                (the default None included) falls back to d_model.
        """
        self.d_model = d_model
        self.d_k = d_k if d_k else d_model

        # Small random initialisation; the generator is consumed in order, so
        # the global RNG is drawn for W_q, then W_k, then W_v.
        self.W_q, self.W_k, self.W_v = (
            np.random.randn(d_model, self.d_k) * 0.1 for _ in range(3)
        )

    def forward(self, x):
        """Attend over the rows of x.

        Args:
            x: 2-D array whose rows are sequence positions and whose columns
                match d_model (it is multiplied by the (d_model, d_k) weights).

        Returns:
            (output, weights): the attended values, one d_k-wide row per
            position, and the square matrix of attention weights.
        """
        queries, keys, values = (x @ proj for proj in (self.W_q, self.W_k, self.W_v))

        # Pairwise query-key similarity, scaled by sqrt(d_k), then normalised
        similarity = (queries @ keys.T) / np.sqrt(self.d_k)
        weights = self.softmax(similarity)

        return weights @ values, weights

    def softmax(self, x):
        """Softmax along the last axis, shifted by the row max for stability."""
        shifted = x - np.max(x, axis=-1, keepdims=True)
        numerators = np.exp(shifted)
        return numerators / np.sum(numerators, axis=-1, keepdims=True)

# Demonstrate with a sentence
def self_attention_sentence_demo():
    """Run SimpleSelfAttention over a six-word sentence and inspect the result.

    Gives each distinct token a random 8-dimensional embedding, applies
    self-attention with d_k=4, draws the weight matrix as a heatmap and prints
    the three most-attended positions for every word. Takes no arguments and
    returns nothing.
    """
    # Simple sentence
    words = ["The", "cat", "sat", "on", "the", "mat"]

    # Vocabulary in first-occurrence order. dict.fromkeys de-duplicates while
    # keeping insertion order; a plain set() is ordered by string hashes, which
    # are randomised per interpreter, so the word -> embedding-row mapping (and
    # every number below) would change between kernel restarts despite seeding.
    vocab = list(dict.fromkeys(words))
    word_to_idx = {w: i for i, w in enumerate(vocab)}

    # One random embedding row per vocabulary entry
    d_model = 8
    embedding_matrix = np.random.randn(len(vocab), d_model)

    # Look up the embedding of every position in the sentence
    embeddings = embedding_matrix[[word_to_idx[w] for w in words]]

    # Apply self-attention
    attention = SimpleSelfAttention(d_model, d_k=4)
    output, weights = attention.forward(embeddings)

    # Visualize attention patterns (rows = queries, columns = keys)
    plt.figure(figsize=(8, 6))
    sns.heatmap(weights,
                xticklabels=words,
                yticklabels=words,
                cmap='Blues',
                cbar_kws={'label': 'Attention Weight'},
                square=True,
                vmin=0,
                vmax=1)
    plt.title('Self-Attention Pattern')
    plt.xlabel('Attending to (Keys)')
    plt.ylabel('Position (Queries)')
    plt.tight_layout()
    plt.show()

    # Analyze specific positions: top-3 keys for each query
    print("Attention Analysis:")
    for i, word in enumerate(words):
        top_attention_idx = np.argsort(weights[i])[::-1][:3]
        print(f"\n'{word}' (position {i}) mainly attends to:")
        for idx in top_attention_idx:
            print(f"  - '{words[idx]}' (position {idx}): {weights[i, idx]:.3f}")

self_attention_sentence_demo()

## 4. Comparing Attention Types

Let's compare different attention mechanisms and see their differences.

In [None]:
def compare_attention_types():
    """Derive and plot four attention patterns from one full attention matrix.

    Computes full self-attention on a random 6x8 input, then builds causal,
    local and strided variants by zeroing the disallowed entries and
    renormalising each row. Shows the four matrices in a 2x2 grid and prints a
    rough operation count for each. Takes no arguments and returns nothing.
    """
    seq_len, d_model = 6, 8

    # The input is drawn before the projections are initialised (RNG order)
    x = np.random.randn(seq_len, d_model)

    # 1. Full self-attention: every query sees every key
    _, weights_full = SimpleSelfAttention(d_model).forward(x)

    # 2. Causal (autoregressive): keep the diagonal and everything below it
    weights_causal = np.tril(weights_full)
    weights_causal = weights_causal / weights_causal.sum(axis=-1, keepdims=True)

    # 3. Local: each query keeps keys at most `window` positions away,
    #    i.e. a band of width 2*window + 1
    window = 1
    positions = np.arange(seq_len)
    in_window = np.abs(positions[:, None] - positions[None, :]) <= window
    weights_local = np.where(in_window, weights_full, 0.0)
    weights_local = weights_local / weights_local.sum(axis=-1, keepdims=True)

    # 4. Strided: every query keeps only every `stride`-th key column
    stride = 2
    weights_strided = np.zeros((seq_len, seq_len))
    weights_strided[:, ::stride] = weights_full[:, ::stride]
    weights_strided = weights_strided / (weights_strided.sum(axis=-1, keepdims=True) + 1e-9)

    # Visualize all patterns
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    titled_patterns = zip(
        axes.flat,
        (weights_full, weights_causal, weights_local, weights_strided),
        ('Full Self-Attention', 'Causal Attention',
         'Local Attention (window=1)', 'Strided Attention (stride=2)'),
    )
    cell_edges = [i - 0.5 for i in range(seq_len + 1)]

    for ax, weights, title in titled_patterns:
        image = ax.imshow(weights, cmap='Blues', vmin=0, vmax=1, aspect='auto')
        ax.set_title(title)
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')

        # Thin grey lines on the cell boundaries
        for edge in cell_edges:
            ax.axhline(edge, color='gray', linewidth=0.5)
            ax.axvline(edge, color='gray', linewidth=0.5)

        plt.colorbar(image, ax=ax)

    plt.suptitle('Different Attention Patterns', fontsize=16)
    plt.tight_layout()
    plt.show()

    # Print complexity analysis
    band_width = 2 * window + 1
    print("Complexity Analysis:")
    print(f"Full Attention: O(n²) = O({seq_len}²) = {seq_len**2} operations")
    print(f"Causal Attention: O(n²/2) ≈ {seq_len**2 // 2} operations")
    print(f"Local Attention: O(n×w) = O({seq_len}×{band_width}) = {seq_len*band_width} operations")
    print(f"Strided Attention: O(n²/s) = O({seq_len}²/{stride}) = {seq_len**2 // stride} operations")

compare_attention_types()

## 5. The Importance of Scaling

Let's see why we scale by √d_k in attention.

In [None]:
def demonstrate_scaling_importance():
    """Compare softmax attention with and without the 1/sqrt(d_k) factor.

    For d_k in {4, 16, 64, 256}, draws random Q and K, plots the unscaled (top
    row) and scaled (bottom row) attention weights with their mean row entropy,
    prints a short explanation, and finally plots y * (1 - y) of a softmax over
    a fixed input range for several input scales. Takes no arguments and
    returns nothing.
    """
    print("=== Why Scale by √d_k? ===\n")

    dimensions = [4, 16, 64, 256]
    seq_len = 8

    def row_softmax(scores):
        """Softmax along the last axis, shifted by the row max."""
        shifted = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        return shifted / np.sum(shifted, axis=-1, keepdims=True)

    def mean_row_entropy(weights):
        """Entropy averaged over rows; the 1e-9 guards log(0)."""
        return -np.sum(weights * np.log(weights + 1e-9)) / seq_len

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))

    for col, d_k in enumerate(dimensions):
        top_ax, bottom_ax = axes[0, col], axes[1, col]

        # Random queries first, then keys (RNG draw order)
        Q = np.random.randn(seq_len, d_k)
        K = np.random.randn(seq_len, d_k)

        raw = Q @ K.T
        weights_unscaled = row_softmax(raw)
        weights_scaled = row_softmax(raw / np.sqrt(d_k))

        # Top row: unscaled; bottom row: scaled
        top_ax.imshow(weights_unscaled, cmap='Blues', vmin=0, vmax=1)
        top_ax.set_title(f'd_k={d_k}\nUnscaled')
        bottom_ax.imshow(weights_scaled, cmap='Blues', vmin=0, vmax=1)
        bottom_ax.set_title(f'Scaled by √{d_k}')

        # Lower entropy means a more concentrated distribution
        top_ax.set_xlabel(f'Entropy: {mean_row_entropy(weights_unscaled):.2f}')
        bottom_ax.set_xlabel(f'Entropy: {mean_row_entropy(weights_scaled):.2f}')

    plt.suptitle('Effect of Scaling on Attention Distribution', fontsize=16)
    plt.tight_layout()
    plt.show()

    # Show the mathematical reason
    for line in (
        "Mathematical Explanation:",
        "- Dot product variance: Var(q·k) = d_k × Var(q_i) × Var(k_i)",
        "- As d_k increases, dot products grow larger",
        "- Large values → softmax becomes peaked (near one-hot)",
        "- Peaked softmax → vanishing gradients",
        "- Scaling by √d_k keeps variance constant",
    ):
        print(line)

    # Demonstrate gradient issue
    print("\nGradient magnitude through softmax:")
    grid = np.linspace(-10, 10, 100)

    plt.figure(figsize=(10, 6))
    for scale in [0.5, 1.0, 2.0, 5.0]:
        exp_grid = np.exp(grid * scale)
        probs = exp_grid / np.sum(exp_grid)
        # y * (1 - y): each output's derivative with respect to its own input
        plt.plot(grid, probs * (1 - probs), label=f'Scale={scale}')

    plt.xlabel('Input value')
    plt.ylabel('Gradient magnitude')
    plt.title('Softmax Gradient for Different Scales')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

demonstrate_scaling_importance()

## 6. Attention as Information Retrieval

Let's build an intuitive example showing attention as a soft database lookup.

In [None]:
def attention_as_retrieval():
    """Illustrate attention as a soft lookup over a tiny fact database.

    Each fact and each query is encoded as the mean of its word embeddings.
    A query is scored against every fact by dot product, the scores are turned
    into weights with a softmax, and the weights are drawn as one horizontal
    bar chart per query. Takes no arguments and returns nothing.
    """
    print("=== Attention as Information Retrieval ===\n")

    # Create a "database" of facts
    facts = [
        "Paris is the capital of France",
        "London is the capital of England",
        "Cats are animals",
        "Dogs are animals",
        "The sun is hot",
        "Ice is cold"
    ]

    # Hand-made 4-D word embeddings. Keys MUST be lowercase: every lookup goes
    # through encode(), which lowercases the text first, so a capitalised key
    # such as "Paris" would never match and would silently fall back to
    # "default".
    word_embeddings = {
        "paris": np.array([0.9, 0.1, 0.0, 0.0]),
        "london": np.array([0.8, 0.2, 0.0, 0.0]),
        "capital": np.array([0.7, 0.7, 0.0, 0.0]),
        "france": np.array([0.9, 0.0, 0.0, 0.0]),
        "england": np.array([0.8, 0.0, 0.0, 0.0]),
        "cats": np.array([0.0, 0.0, 0.9, 0.1]),
        "dogs": np.array([0.0, 0.0, 0.8, 0.2]),
        "animals": np.array([0.0, 0.0, 0.7, 0.7]),
        "sun": np.array([0.0, 0.9, 0.0, 0.9]),
        "hot": np.array([0.0, 0.7, 0.0, 0.9]),
        "ice": np.array([0.0, 0.1, 0.0, 0.1]),
        "cold": np.array([0.0, 0.3, 0.0, 0.1]),
        # Default for other words
        "default": np.array([0.1, 0.1, 0.1, 0.1])
    }
    fallback = word_embeddings["default"]

    def encode(text):
        """Mean embedding of the lowercased, '?'-stripped tokens of text."""
        tokens = text.lower().replace("?", "").split()
        return np.mean([word_embeddings.get(tok, fallback) for tok in tokens], axis=0)

    # Encode facts (they serve as both keys and values)
    fact_embeddings = np.array([encode(fact) for fact in facts])

    # Queries
    queries = [
        "What is the capital of France?",
        "Tell me about animals",
        "Temperature information"
    ]

    # Process each query in its own subplot
    fig, axes = plt.subplots(1, len(queries), figsize=(15, 5))

    for query, ax in zip(queries, axes):
        query_vector = encode(query)

        # Dot-product scores against every fact, then a max-shifted softmax
        scores = fact_embeddings @ query_vector
        exp_scores = np.exp(scores - scores.max())
        attention_weights = exp_scores / exp_scores.sum()

        # Visualize
        bars = ax.barh(range(len(facts)), attention_weights)
        ax.set_yticks(range(len(facts)))
        ax.set_yticklabels([f[:20] + "..." if len(f) > 20 else f for f in facts])
        ax.set_xlabel('Attention Weight')
        ax.set_title(f'Query: "{query}"')

        # Color by weight (doubled so small weights remain visible)
        for bar, weight in zip(bars, attention_weights):
            bar.set_color(plt.cm.Blues(weight * 2))

        # Add values next to the bars
        for i, weight in enumerate(attention_weights):
            ax.text(weight + 0.01, i, f'{weight:.3f}', va='center')

    plt.suptitle('Attention as Soft Database Lookup', fontsize=16)
    plt.tight_layout()
    plt.show()

    print("Key Insights:")
    print("- Attention finds relevant information based on similarity")
    print("- Unlike hard lookup, it can combine multiple sources")
    print("- Weights show the relevance of each piece of information")

attention_as_retrieval()

## 7. Multi-Head Attention Preview

Let's get a preview of why we use multiple attention heads.

In [None]:
def multi_head_preview():
    """Plot four hand-built attention matrices standing in for different heads.

    The matrices are simulated, not learned: previous-word, next-word,
    syntactic (determiner/adjective -> noun) and uniform global attention over
    a nine-word sentence. Every matrix is row-stochastic (each row sums to 1).
    Takes no arguments and returns nothing.
    """
    print("=== Why Multiple Attention Heads? ===\n")

    # Sentence for analysis
    words = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"]
    seq_len = len(words)

    # Simulate different attention patterns that different heads might learn
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Head 1: Attending to previous word
    head1 = np.zeros((seq_len, seq_len))
    for i in range(1, seq_len):
        head1[i, i-1] = 0.8
        head1[i, i] = 0.2
    head1[0, 0] = 1.0  # first word has no predecessor

    # Head 2: Attending to next word
    head2 = np.zeros((seq_len, seq_len))
    for i in range(seq_len-1):
        head2[i, i+1] = 0.8
        head2[i, i] = 0.2
    head2[-1, -1] = 1.0  # last word has no successor

    # Head 3: Attending to determiners and their nouns
    head3 = np.eye(seq_len) * 0.3
    # "The" -> "fox", "the" -> "dog"
    head3[0, 3] = 0.7  # The -> fox
    head3[6, 8] = 0.7  # the -> dog
    # Adjectives to nouns
    head3[1, 3] = 0.5  # quick -> fox
    head3[2, 3] = 0.5  # brown -> fox
    head3[7, 8] = 0.5  # lazy -> dog
    # The raw entries above are relative strengths; rows would otherwise sum to
    # 0.3, 0.8 or 1.0. Normalise so each row is a proper attention
    # distribution, consistent with the other three heads.
    head3 = head3 / head3.sum(axis=-1, keepdims=True)

    # Head 4: Global attention (attending to all positions)
    head4 = np.ones((seq_len, seq_len)) / seq_len

    heads = [
        (head1, "Head 1: Previous Word", axes[0, 0]),
        (head2, "Head 2: Next Word", axes[0, 1]),
        (head3, "Head 3: Syntactic Relations", axes[1, 0]),
        (head4, "Head 4: Global Context", axes[1, 1])
    ]

    for weights, title, ax in heads:
        im = ax.imshow(weights, cmap='Blues', vmin=0, vmax=1)
        ax.set_title(title)
        ax.set_xticks(range(seq_len))
        ax.set_yticks(range(seq_len))
        ax.set_xticklabels(words, rotation=45, ha='right')
        ax.set_yticklabels(words)
        ax.set_xlabel('Attending to')
        ax.set_ylabel('From position')
        plt.colorbar(im, ax=ax)

    plt.suptitle('Different Attention Heads Learn Different Patterns', fontsize=16)
    plt.tight_layout()
    plt.show()

    print("Benefits of Multiple Heads:")
    print("1. Different heads can capture different types of relationships")
    print("2. Parallel attention to multiple aspects of the input")
    print("3. More expressive power than single attention")
    print("4. Robustness - if one head fails, others compensate")

multi_head_preview()

## 8. Implementing Efficient Attention

Let's implement attention efficiently using PyTorch.

In [None]:
class EfficientAttention(nn.Module):
    """Single-head scaled dot-product self-attention as a PyTorch module.

    Queries, keys and values are produced by three bias-free linear layers that
    map the last input dimension from d_model to d_k.
    """

    def __init__(self, d_model, d_k=None):
        """
        Args:
            d_model: size of the last dimension of the input.
            d_k: size of the projected queries/keys/values; any falsy value
                (the default None included) falls back to d_model.
        """
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k or d_model

        # Linear projections
        self.W_q = nn.Linear(d_model, self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, self.d_k, bias=False)

    def forward(self, x, mask=None, return_attention=True):
        """Apply self-attention to x.

        Args:
            x: tensor of shape (..., seq_len, d_model); leading batch
                dimensions are optional because only the last two axes are
                used explicitly.
            mask: optional tensor broadcastable to the (..., seq_len, seq_len)
                score matrix. Positions where ``mask == 0`` are filled with
                -1e9 before the softmax.
            return_attention: when True (default) also return the weights.

        Returns:
            ``(output, attention_weights)`` if return_attention is True, else
            just ``output``. ``output`` has shape (..., seq_len, d_k) and
            ``attention_weights`` has shape (..., seq_len, seq_len).
        """
        # Project to Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Scaled scores. The square root is taken with ** 0.5 so the module
        # needs no extra import (the previous math.sqrt call raised NameError
        # because `math` was never imported in this notebook).
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Normalise over the key axis
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of the values
        output = torch.matmul(attention_weights, V)

        if return_attention:
            return output, attention_weights
        return output

# Benchmark and visualize
def benchmark_attention():
    """Time EfficientAttention at several sequence lengths and plot the results.

    For each sequence length a fresh model and a random (batch, seq_len,
    d_model) input are created, a few warm-up passes are run, and the mean
    wall-clock time of the timed passes is recorded in milliseconds. The size
    of the float32 attention matrix is computed analytically (not measured).
    Both series are plotted and a complexity summary is printed.
    Takes no arguments and returns nothing.
    """
    print("=== Efficient Attention Implementation ===\n")

    # Test different sequence lengths
    seq_lengths = [10, 50, 100, 200, 500]
    d_model = 64
    batch_size = 32
    warmup_runs = 5
    timed_runs = 20
    bytes_per_float32 = 4

    times = []
    memory_usage = []

    for seq_len in seq_lengths:
        # Create model and input
        model = EfficientAttention(d_model)
        x = torch.randn(batch_size, seq_len, d_model)

        # no_grad: this measures the forward computation only, so autograd
        # bookkeeping (graph construction, saved activations) is switched off.
        with torch.no_grad():
            # Warm up
            for _ in range(warmup_runs):
                model(x, return_attention=False)

            # perf_counter is the monotonic, high-resolution clock meant for
            # measuring short intervals; time.time() can jump with wall-clock
            # adjustments.
            start = time.perf_counter()
            for _ in range(timed_runs):
                model(x, return_attention=False)
            elapsed = (time.perf_counter() - start) / timed_runs
        times.append(elapsed * 1000)  # Convert to ms

        # Memory (attention matrix size)
        memory = batch_size * seq_len * seq_len * bytes_per_float32 / (1024 * 1024)  # MB
        memory_usage.append(memory)

    # Plot results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    ax1.plot(seq_lengths, times, 'b-o', linewidth=2, markersize=8)
    ax1.set_xlabel('Sequence Length')
    ax1.set_ylabel('Time (ms)')
    ax1.set_title('Attention Computation Time')
    ax1.grid(True, alpha=0.3)

    ax2.plot(seq_lengths, memory_usage, 'r-o', linewidth=2, markersize=8)
    ax2.set_xlabel('Sequence Length')
    ax2.set_ylabel('Memory (MB)')
    ax2.set_title('Attention Matrix Memory Usage')
    ax2.grid(True, alpha=0.3)

    plt.suptitle('Attention Scalability Analysis', fontsize=16)
    plt.tight_layout()
    plt.show()

    print("Complexity: O(n²d) where n=sequence length, d=dimension")
    print("Memory: O(n²) for storing attention weights")
    print(f"\nFor seq_len=1000: {1000**2:,} attention values to compute!")

benchmark_attention()

## 9. Attention Patterns in Practice

Let's visualize some illustrative, hand-constructed attention patterns that mimic what trained models commonly learn.

In [None]:
def analyze_learned_patterns():
    """Plot mock attention matrices for three example sentences.

    Each matrix starts as low-level random noise, gets a boost on the diagonal
    and on the two neighbouring diagonals, receives a couple of sentence
    specific links, and is row-normalised before being drawn as a heatmap.
    Takes no arguments and returns nothing.
    """
    print("=== Attention Patterns in Practice ===\n")

    sentences = [
        "The cat sat on the mat .",
        "She opened the door carefully .",
        "Time flies like an arrow ."
    ]

    # Extra (query, key) -> boost entries, one list per sentence above
    sentence_links = [
        [((2, 1), 0.3),   # sat -> cat
         ((5, 1), 0.2)],  # mat -> cat
        [((1, 0), 0.3),   # opened -> She
         ((4, 1), 0.3)],  # carefully -> opened
        [((1, 0), 0.4),   # flies -> Time
         ((4, 1), 0.2)],  # arrow -> flies
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    for idx, (sentence, ax) in enumerate(zip(sentences, axes)):
        tokens = sentence.split()
        n_tokens = len(tokens)

        # Background noise
        pattern = np.random.rand(n_tokens, n_tokens) * 0.1

        # Generic structure: self-attention plus previous/next neighbour
        pos = np.arange(n_tokens)
        pattern[pos, pos] += 0.3
        pattern[pos[1:], pos[:-1]] += 0.2
        pattern[pos[:-1], pos[1:]] += 0.2

        # Sentence-specific dependencies
        for (query_pos, key_pos), boost in sentence_links[idx]:
            pattern[query_pos, key_pos] += boost

        # Make every row sum to 1
        pattern = pattern / pattern.sum(axis=-1, keepdims=True)

        # Visualize
        image = ax.imshow(pattern, cmap='Blues', vmin=0, vmax=0.5)
        ax.set_xticks(range(n_tokens))
        ax.set_yticks(range(n_tokens))
        ax.set_xticklabels(tokens, rotation=45, ha='right')
        ax.set_yticklabels(tokens)
        ax.set_title(f'Sentence {idx+1}')
        ax.set_xlabel('Attending to')
        if idx == 0:
            # Only the left-most panel carries the shared y-axis label
            ax.set_ylabel('From position')

        plt.colorbar(image, ax=ax)

    plt.suptitle('Typical Attention Patterns in Language', fontsize=16)
    plt.tight_layout()
    plt.show()

    for line in (
        "Common patterns observed:",
        "1. Strong self-attention (diagonal)",
        "2. Local attention to nearby words",
        "3. Syntactic dependencies (subject-verb, verb-object)",
        "4. Long-range semantic connections",
    ):
        print(line)

analyze_learned_patterns()

## 10. Summary and Key Takeaways

Let's summarize what we've learned about attention mechanisms.

In [None]:
def create_summary():
    """Draw a hub-and-spoke concept diagram and print the key formulas.

    Six concept boxes are placed evenly on a circle around a central
    "ATTENTION MECHANISM" disc, with a faint spoke joining each box to the
    hub. The main formulas and a closing banner are then printed.
    Takes no arguments and returns nothing.
    """
    print("=== Attention Mechanisms: Summary ===\n")

    fig, ax = plt.subplots(figsize=(12, 8))

    concepts = [
        "Query-Key-Value\nFramework",
        "Scaled Dot-Product\nAttention",
        "Parallel\nComputation",
        "No Information\nBottleneck",
        "Content-Based\nAddressing",
        "Interpretable\nWeights"
    ]

    # Layout constants
    ring_radius = 3
    hub_radius = 1.2
    box_w, box_h = 1.6, 0.6
    angles = np.linspace(0, 2*np.pi, len(concepts), endpoint=False)

    def polar(r, theta):
        """Cartesian (x, y) of the point at radius r and angle theta."""
        return r * np.cos(theta), r * np.sin(theta)

    # Concept boxes, each centred on its point of the ring
    for label, theta in zip(concepts, angles):
        cx, cy = polar(ring_radius, theta)
        ax.add_patch(plt.Rectangle((cx - box_w / 2, cy - box_h / 2), box_w, box_h,
                                   facecolor='lightblue',
                                   edgecolor='darkblue',
                                   linewidth=2))
        ax.text(cx, cy, label, ha='center', va='center',
                fontsize=10, fontweight='bold')

    # Central hub
    ax.add_patch(plt.Circle((0, 0), hub_radius, facecolor='gold',
                            edgecolor='darkorange', linewidth=3))
    ax.text(0, 0, 'ATTENTION\nMECHANISM', ha='center', va='center',
            fontsize=14, fontweight='bold')

    # Spokes run from the hub's edge to half a box-width short of the ring
    for theta in angles:
        x_in, y_in = polar(hub_radius, theta)
        x_out, y_out = polar(ring_radius - box_w / 2, theta)
        ax.plot([x_in, x_out], [y_in, y_out], 'k-', alpha=0.3, linewidth=2)

    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_aspect('equal')
    ax.axis('off')

    plt.title('Attention Mechanisms: Core Concepts', fontsize=16, pad=20)
    plt.tight_layout()
    plt.show()

    # Key formulas
    for line in (
        "Key Formulas:\n",
        "1. Attention(Q,K,V) = softmax(QK^T/√d_k)V",
        "2. Q = XW_Q, K = XW_K, V = XW_V",
        "3. MultiHead = Concat(head_1, ..., head_h)W_O",
        "\nComplexity: O(n²d) time, O(n²) space",
    ):
        print(line)

    banner = "="*50
    print("\n" + banner)
    print("🎯 You now understand attention mechanisms!")
    print("🚀 Next: See how transformers use attention as their core building block")
    print(banner)

create_summary()

## Exercises

Try these exercises to deepen your understanding:

In [None]:
# Exercise descriptions, printed one entry per line. A trailing "\n" inside an
# entry produces the blank line that separates consecutive exercises.
exercise_prompts = (
    "=== Exercises ===\n",
    "1. Implement Masked Attention",
    "   Create a function that applies different mask types (causal, padding)\n",
    "2. Temperature Scaling",
    "   Modify attention to include a temperature parameter",
    "   Observe how it affects the attention distribution\n",
    "3. Relative Position Encoding",
    "   Instead of absolute positions, implement attention with relative positions\n",
    "4. Sparse Attention",
    "   Implement a pattern where each position only attends to a subset\n",
    "5. Cross-Attention",
    "   Implement attention between two different sequences",
    "   (e.g., for translation or question-answering)\n",
)
for prompt_line in exercise_prompts:
    print(prompt_line)

# Starter scaffold for Exercise 1 (intentionally unfinished)
def exercise_masked_attention():
    """Exercise 1: Implement different mask types.

    Scaffold only: the two masks are placeholders set to None and nothing is
    computed yet, so the function currently returns None.
    """
    seq_len = 8

    # TODO: Create causal mask (lower triangular)
    # TODO: Create padding mask (mask out positions 6,7)
    causal_mask, padding_mask = None, None

    # TODO: Apply masks to attention computation
    return None

print("\n💡 These exercises will prepare you for understanding transformers!")