# Advanced Attention Mechanisms in Transformers

This notebook explores advanced attention mechanisms beyond the basic multi-head attention covered in earlier notebooks. We'll dive deep into the computational complexity, optimization techniques, and modern variants that make transformers more efficient and powerful.

## Learning Objectives

By the end of this notebook, you will understand:
1. **Computational Complexity**: Why attention is O(n²) and what this means for scaling
2. **KV Caching**: How to optimize inference through key-value caching
3. **Sparse Attention**: Techniques to reduce attention complexity
4. **Attention Patterns**: How to visualize and interpret attention weights
5. **Modern Variants**: Flash Attention, Multi-Query Attention, and Grouped-Query Attention

## Prerequisites

- Understanding of the basic attention mechanism (notebook 02)
- Familiarity with PyTorch tensors and operations
- Basic knowledge of computational complexity

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
from typing import Optional, Tuple, List
import time
from dataclasses import dataclass

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Attention Complexity Analysis

### The O(n²) Problem

Standard attention computes:
```
Attention(Q, K, V) = softmax(QK^T / √d_k)V
```

Let's analyze why this is quadratic in sequence length and what it means for memory and computation.
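
As a minimal sketch of the formula above (toy sizes, a single head, and variable names chosen only for illustration), the score matrix `scores` has one entry per (query, key) pair, which is where the n² term comes from:

```python
import math
import torch

n, d_k = 6, 4  # toy sequence length and head dimension (illustrative values)
Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)
V = torch.randn(n, d_k)

scores = Q @ K.T / math.sqrt(d_k)        # [n, n]: one score per (query, key) pair
weights = torch.softmax(scores, dim=-1)  # each row is a distribution over keys
output = weights @ V                     # [n, d_k]: weighted average of value rows

print(scores.shape, output.shape)
```

Doubling `n` quadruples the number of entries in `scores` and `weights`, while `output` only doubles in size.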

In [None]:
class AttentionComplexityAnalyzer:
    """Analyzes computational complexity of attention mechanisms."""
    
    @staticmethod
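    def compute_attention_flops(seq_len: int, d_model: int, n_heads: int) -> dict:
        """Estimate attention cost, counting each multiply-accumulate as one operation.

        Strict FLOP counts for the matmul terms are roughly 2x these values.
        """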
        d_k = d_model // n_heads
        
        # QK^T computation: (seq_len, d_k) @ (d_k, seq_len) = (seq_len, seq_len)
        qk_flops = seq_len * seq_len * d_k * n_heads
        
        # Softmax: roughly 3 operations per element (exp, sum, divide)
        softmax_flops = 3 * seq_len * seq_len * n_heads
        
        # Attention @ V: (seq_len, seq_len) @ (seq_len, d_k) = (seq_len, d_k)
        av_flops = seq_len * seq_len * d_k * n_heads
        
        # Linear projections: 4 projections (Q, K, V, output)
        linear_flops = 4 * seq_len * d_model * d_model
        
        total_flops = qk_flops + softmax_flops + av_flops + linear_flops
        
        return {
            'qk_computation': qk_flops,
            'softmax': softmax_flops,
            'attention_values': av_flops,
            'linear_projections': linear_flops,
            'total': total_flops,
            'quadratic_component': qk_flops + softmax_flops + av_flops,
            'linear_component': linear_flops
        }
    
    @staticmethod
    def compute_attention_memory(seq_len: int, d_model: int, n_heads: int, 
                               batch_size: int = 1) -> dict:
        """Compute memory requirements for attention."""
        d_k = d_model // n_heads
        
        # Input embeddings
        input_memory = batch_size * seq_len * d_model * 4  # 4 bytes per float32
        
        # Q, K, V matrices
        qkv_memory = 3 * batch_size * n_heads * seq_len * d_k * 4
        
        # Attention weights matrix (the big one!)
        attention_weights_memory = batch_size * n_heads * seq_len * seq_len * 4
        
        # Output
        output_memory = batch_size * seq_len * d_model * 4
        
        total_memory = input_memory + qkv_memory + attention_weights_memory + output_memory
        
        return {
            'input': input_memory,
            'qkv_matrices': qkv_memory,
            'attention_weights': attention_weights_memory,  # This is O(n²)!
            'output': output_memory,
            'total_bytes': total_memory,
            'total_mb': total_memory / (1024 * 1024),
            'total_gb': total_memory / (1024 * 1024 * 1024)
        }

# Analyze complexity for different sequence lengths
seq_lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
d_model = 512
n_heads = 8
batch_size = 4

complexity_results = []
memory_results = []

for seq_len in seq_lengths:
    flops = AttentionComplexityAnalyzer.compute_attention_flops(seq_len, d_model, n_heads)
    memory = AttentionComplexityAnalyzer.compute_attention_memory(seq_len, d_model, n_heads, batch_size)
    
    complexity_results.append({
        'seq_len': seq_len,
        'total_gflops': flops['total'] / 1e9,
        'quadratic_gflops': flops['quadratic_component'] / 1e9,
        'linear_gflops': flops['linear_component'] / 1e9
    })
    
    memory_results.append({
        'seq_len': seq_len,
        'total_gb': memory['total_gb'],
        'attention_weights_gb': memory['attention_weights'] / (1024**3),
        'other_gb': (memory['total_bytes'] - memory['attention_weights']) / (1024**3)
    })

print("Computational Complexity Analysis:")
print("Seq Len | Total GFLOPs | Quadratic | Linear")
print("-" * 45)
for result in complexity_results:
    print(f"{result['seq_len']:7d} | {result['total_gflops']:11.2f} | {result['quadratic_gflops']:9.2f} | {result['linear_gflops']:6.2f}")

print("\nMemory Requirements Analysis:")
print("Seq Len | Total GB | Attn Weights GB | Other GB")
print("-" * 50)
for result in memory_results:
    print(f"{result['seq_len']:7d} | {result['total_gb']:8.2f} | {result['attention_weights_gb']:14.2f} | {result['other_gb']:8.2f}")

In [None]:
# Visualize the scaling behavior
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Computational complexity
seq_lens = [r['seq_len'] for r in complexity_results]
total_gflops = [r['total_gflops'] for r in complexity_results]
quad_gflops = [r['quadratic_gflops'] for r in complexity_results]
linear_gflops = [r['linear_gflops'] for r in complexity_results]

ax1.loglog(seq_lens, total_gflops, 'o-', label='Total', linewidth=2, markersize=8)
ax1.loglog(seq_lens, quad_gflops, 's--', label='Quadratic Component', linewidth=2, markersize=6)
ax1.loglog(seq_lens, linear_gflops, '^:', label='Linear Component', linewidth=2, markersize=6)

# Add theoretical O(n²) reference line
# Anchor the reference curve at the first measured point and scale it as n²
reference_n2 = [total_gflops[0] * (n / seq_lens[0])**2 for n in seq_lens]
ax1.loglog(seq_lens, reference_n2, 'k--', alpha=0.5, label='O(n²) reference')

ax1.set_xlabel('Sequence Length')
ax1.set_ylabel('GFLOPs')
ax1.set_title('Attention Computational Complexity')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Memory requirements
total_gb = [r['total_gb'] for r in memory_results]
attn_gb = [r['attention_weights_gb'] for r in memory_results]
other_gb = [r['other_gb'] for r in memory_results]

ax2.loglog(seq_lens, total_gb, 'o-', label='Total Memory', linewidth=2, markersize=8)
ax2.loglog(seq_lens, attn_gb, 's--', label='Attention Weights', linewidth=2, markersize=6)
ax2.loglog(seq_lens, other_gb, '^:', label='Other Components', linewidth=2, markersize=6)

ax2.set_xlabel('Sequence Length')
ax2.set_ylabel('Memory (GB)')
ax2.set_title('Attention Memory Requirements')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\n📊 Key Observations:")
print(f"• At seq_len=8192, attention weights alone require {attn_gb[-1]:.1f} GB of memory!")
print(f"• Quadratic component dominates total FLOPs once seq_len exceeds roughly 2·d_model (about {2 * d_model} tokens here)")
print(f"• Memory grows as O(n²), making long sequences prohibitively expensive")
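
The point at which the quadratic terms overtake the linear projections can be worked out directly from the operation counts used in `compute_attention_flops`. The sketch below restates those counts (the helper names are hypothetical, introduced only for this example) and solves for the crossover length:

```python
def quadratic_flops(n: int, d_model: int, n_heads: int) -> int:
    # QK^T and weights @ V each cost n * n * d_model multiply-adds across all heads
    # (d_k * n_heads == d_model), plus about 3 operations per attention entry for softmax.
    return 2 * n * n * d_model + 3 * n * n * n_heads

def linear_flops(n: int, d_model: int) -> int:
    # Four d_model x d_model projections: Q, K, V and output.
    return 4 * n * d_model * d_model

def crossover_length(d_model: int, n_heads: int) -> float:
    # Solve quadratic_flops(n) == linear_flops(n) for n.
    return 4 * d_model**2 / (2 * d_model + 3 * n_heads)

n_star = crossover_length(512, 8)
print(f"Quadratic terms overtake the projections near n = {n_star:.1f}")
```

For `d_model = 512` and `n_heads = 8` the crossover sits just above 1,000 tokens, close to the rule of thumb `2 * d_model`. Below that length the projections are the larger cost under this counting.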

## 2. KV Caching for Efficient Inference

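During autoregressive generation, the keys and values of earlier tokens do not change from one step to the next, so we can cache them and project only the newest token at each step. Each step then computes attention for a single query against the cached keys, which costs O(n) instead of the O(n²) needed to recompute the whole sequence. This is crucial for efficient text generation.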
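
A minimal single-head sketch, using toy sizes and randomly initialized projection matrices as stand-ins for trained weights, checks that incremental decoding with a cache reproduces full causal attention:

```python
import math
import torch

def attend(q, k, v):
    """Scaled dot-product attention for a single head (no mask)."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v

d, n = 8, 5
x = torch.randn(n, d)  # n token embeddings (toy data)
w_q, w_k, w_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))

# Full recompute: project every token, then apply a causal mask.
q_all, k_all, v_all = x @ w_q, x @ w_k, x @ w_v
scores = q_all @ k_all.T / math.sqrt(d)
causal = torch.tril(torch.ones(n, n, dtype=torch.bool))
full = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ v_all

# Incremental: project one token at a time and append its K, V to the cache.
k_cache = torch.empty(0, d)
v_cache = torch.empty(0, d)
outputs = []
for t in range(n):
    x_t = x[t:t + 1]                                  # [1, d]: only the new token
    k_cache = torch.cat([k_cache, x_t @ w_k], dim=0)  # [t + 1, d]
    v_cache = torch.cat([v_cache, x_t @ w_v], dim=0)
    outputs.append(attend(x_t @ w_q, k_cache, v_cache))
incremental = torch.cat(outputs, dim=0)

print(torch.allclose(full, incremental, atol=1e-4))
```

No mask is needed in the incremental loop because the cache only ever contains positions up to and including the current token.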

In [None]:
class CachedMultiHeadAttention(nn.Module):
    """Multi-head attention with KV caching for efficient inference."""
    
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        
        # Cache for keys and values
        self.k_cache = None
        self.v_cache = None
        self.cache_len = 0
    
    def clear_cache(self):
        """Clear the KV cache."""
        self.k_cache = None
        self.v_cache = None
        self.cache_len = 0
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, 
                use_cache: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass with optional KV caching.
        
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Attention mask
            use_cache: Whether to use/update cache
        """
        batch_size, seq_len, d_model = x.shape
        
        # Compute Q, K, V
        Q = self.w_q(x)  # [batch_size, seq_len, d_model]
        K = self.w_k(x)  # [batch_size, seq_len, d_model]
        V = self.w_v(x)  # [batch_size, seq_len, d_model]
        
        # Reshape for multi-head attention
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        if use_cache:
            if self.k_cache is not None:
                # Concatenate with cache
                K = torch.cat([self.k_cache, K], dim=2)
                V = torch.cat([self.v_cache, V], dim=2)
            
            # Update cache
            self.k_cache = K
            self.v_cache = V
            self.cache_len = K.size(2)
        
        # Attention computation
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention to values
        out = torch.matmul(attn_weights, V)
        
        # Reshape and project output
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.w_o(out)
        
        return out, attn_weights

# Demonstrate KV caching efficiency
@torch.no_grad()  # inference only: avoids building autograd graphs that the cache would keep alive
def benchmark_kv_caching():
    """Benchmark the efficiency gain from KV caching."""
    d_model = 512
    n_heads = 8
    batch_size = 1
    max_len = 100
    
    # Create attention layer
    attention = CachedMultiHeadAttention(d_model, n_heads).to(device).eval()  # eval() disables dropout
    
    # Simulate autoregressive generation
    input_seq = torch.randn(batch_size, 1, d_model).to(device)
    
    # Without caching - recompute everything each step
    print("Benchmarking without KV caching...")
    attention.clear_cache()
    
    start_time = time.time()
    current_seq = input_seq
    
    for step in range(max_len):
        # Create causal mask
        seq_len = current_seq.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len)).to(device)
        mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, n_heads, -1, -1)
        
        # Forward pass (recomputing everything)
        out, _ = attention(current_seq, mask, use_cache=False)
        
        # Add new token (random for demo)
        new_token = torch.randn(batch_size, 1, d_model).to(device)
        current_seq = torch.cat([current_seq, new_token], dim=1)
    
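    if device.type == 'cuda':
        torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
    time_without_cache = time.time() - start_time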
    
    # With caching - only compute new K,V each step
    print("Benchmarking with KV caching...")
    attention.clear_cache()
    
    start_time = time.time()
    current_seq = input_seq
    
    for step in range(max_len):
        # Only need to pass the new token
        if step == 0:
            # First step - pass initial sequence
            seq_len = current_seq.size(1)
            mask = torch.tril(torch.ones(seq_len, seq_len)).to(device)
            mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, n_heads, -1, -1)
            out, _ = attention(current_seq, mask, use_cache=True)
        else:
            # Subsequent steps - only pass new token
            new_token = torch.randn(batch_size, 1, d_model).to(device)
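            # The new token may attend to every cached position and to itself,
            # so the mask is all ones (torch.tril on a one-row matrix would keep only column 0)
            cache_len = attention.cache_len
            mask = torch.ones(1, cache_len + 1, device=device)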
            mask = mask.unsqueeze(0).unsqueeze(0).expand(batch_size, n_heads, -1, -1)
            out, _ = attention(new_token, mask, use_cache=True)
            current_seq = torch.cat([current_seq, new_token], dim=1)
    
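    if device.type == 'cuda':
        torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
    time_with_cache = time.time() - start_time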
    
    speedup = time_without_cache / time_with_cache
    
    print(f"\n⏱️  Performance Results:")
    print(f"Without KV caching: {time_without_cache:.3f}s")
    print(f"With KV caching:    {time_with_cache:.3f}s")
    print(f"Speedup:            {speedup:.1f}x")
    
    return speedup

speedup = benchmark_kv_caching()

## 3. Sparse Attention Patterns

To reduce the O(n²) complexity, various sparse attention patterns have been proposed. Let's implement and visualize some common patterns.

In [None]:
class SparseAttentionPatterns:
    """Collection of sparse attention pattern generators."""
    
    @staticmethod
    def create_local_attention_mask(seq_len: int, window_size: int) -> torch.Tensor:
        """Create local attention mask (each token attends to itself and window_size // 2 tokens on each side)."""
        mask = torch.zeros(seq_len, seq_len)
        
        for i in range(seq_len):
            start = max(0, i - window_size // 2)
            end = min(seq_len, i + window_size // 2 + 1)
            mask[i, start:end] = 1
        
        return mask
    
    @staticmethod
    def create_strided_attention_mask(seq_len: int, stride: int) -> torch.Tensor:
        """Create strided attention mask (attend to every k-th token)."""
        mask = torch.zeros(seq_len, seq_len)
        
        for i in range(seq_len):
            # Attend to positions at regular intervals
            positions = torch.arange(0, seq_len, stride)
            mask[i, positions] = 1
            # Always attend to self
            mask[i, i] = 1
        
        return mask
    
    @staticmethod
    def create_global_attention_mask(seq_len: int, num_global: int) -> torch.Tensor:
        """Create global attention mask (some tokens attend to all, all attend to globals)."""
        mask = torch.eye(seq_len)  # Self-attention
        
        # First num_global tokens are global
        mask[:num_global, :] = 1  # Global tokens attend to all
        mask[:, :num_global] = 1  # All tokens attend to global tokens
        
        return mask
    
    @staticmethod
    def create_block_sparse_mask(seq_len: int, block_size: int) -> torch.Tensor:
        """Create block sparse attention mask."""
        mask = torch.zeros(seq_len, seq_len)
        
        # Assumes seq_len is divisible by block_size; any trailing tokens would get no connections
        num_blocks = seq_len // block_size
        
        for i in range(num_blocks):
            for j in range(num_blocks):
                # Attend within block and to adjacent blocks
                if abs(i - j) <= 1:
                    start_i, end_i = i * block_size, (i + 1) * block_size
                    start_j, end_j = j * block_size, (j + 1) * block_size
                    mask[start_i:end_i, start_j:end_j] = 1
        
        return mask

# Visualize different sparse attention patterns
seq_len = 64
patterns = {
    # Bidirectional baseline, so connection counts are comparable with the sparse masks below
    'Full Attention': torch.ones(seq_len, seq_len),
    'Local (window=8)': SparseAttentionPatterns.create_local_attention_mask(seq_len, 8),
    'Strided (stride=4)': SparseAttentionPatterns.create_strided_attention_mask(seq_len, 4),
    'Global (4 global)': SparseAttentionPatterns.create_global_attention_mask(seq_len, 4),
    'Block Sparse (8x8)': SparseAttentionPatterns.create_block_sparse_mask(seq_len, 8)
}

fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for idx, (name, pattern) in enumerate(patterns.items()):
    axes[idx].imshow(pattern.numpy(), cmap='Blues', origin='upper')
    axes[idx].set_title(f'{name}\n{pattern.sum().item():.0f}/{seq_len**2} connections')
    axes[idx].set_xlabel('Key Position')
    if idx == 0:
        axes[idx].set_ylabel('Query Position')
    
    # Add sparsity information
    sparsity = 1 - (pattern.sum() / (seq_len ** 2))
    axes[idx].text(0.02, 0.98, f'Sparsity: {sparsity:.1%}', 
                  transform=axes[idx].transAxes, 
                  bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                  verticalalignment='top')

plt.tight_layout()
plt.show()

# Analyze complexity reduction
print("\nComplexity Analysis of Sparse Patterns:")
print("Pattern\t\t\tConnections\tReduction\tComplexity")
print("-" * 65)

for name, pattern in patterns.items():
    connections = pattern.sum().item()
    reduction = 1 - (connections / (seq_len ** 2))
    if 'Local' in name:
        complexity = "O(n·w)"  # w = window size
    elif 'Strided' in name:
        complexity = "O(n²/s)"  # s = stride
    elif 'Global' in name:
        complexity = "O(n·g)"  # g = global tokens (n + 2·n·g - g² connections)
    elif 'Block' in name:
        complexity = "O(n·b)"  # b = block size
    else:
        complexity = "O(n²)"
    
    print(f"{name:<20}\t{connections:>4.0f}\t{reduction:>6.1%}\t{complexity}")
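
The masks above only describe which connections are allowed. A minimal sketch of how such a mask could be applied inside attention follows; `masked_attention` and `half_window` are names introduced here for illustration, and the local mask is built in vectorized form rather than with the loop used above:

```python
import math
import torch

def masked_attention(q, k, v, mask):
    """Attention where mask[i, j] == 0 forbids query i from attending to key j."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    scores = scores.masked_fill(mask == 0, float("-inf"))  # exp(-inf) == 0 after softmax
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

n, d, half_window = 8, 4, 1
idx = torch.arange(n)
# Vectorized local mask: query i may see key j when |i - j| <= half_window.
local_mask = ((idx[:, None] - idx[None, :]).abs() <= half_window).float()

q, k, v = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
out, weights = masked_attention(q, k, v, local_mask)
print(weights[3])  # nonzero only at key positions 2, 3 and 4
```

Note that masking a dense score matrix like this still computes all n² scores. The complexity figures in the table are only realized by kernels that skip the masked entries entirely.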

## 4. Attention Pattern Visualization and Analysis

Understanding what attention patterns emerge during training is crucial for interpreting transformer behavior.
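
One of the statistics used in this section is the entropy of each query's attention distribution. As a small worked example in plain Python (the helper name is illustrative), entropy is largest when attention is spread evenly and zero when it collapses onto a single key:

```python
import math

def attention_entropy(weights):
    """Shannon entropy (in nats) of one row of attention weights."""
    return sum(-p * math.log(p) for p in weights if p > 0)

n = 8
uniform = [1 / n] * n              # attention spread evenly over all keys
peaked = [1.0] + [0.0] * (n - 1)   # all attention on a single key

print(attention_entropy(uniform), math.log(n))  # the two values agree
print(attention_entropy(peaked))                # zero: no uncertainty
```

Entropy therefore ranges from 0 to log(n), so values should be compared across heads only at the same sequence length.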

In [None]:
class AttentionAnalyzer:
    """Tools for analyzing and visualizing attention patterns."""
    
    @staticmethod
    def compute_attention_statistics(attention_weights: torch.Tensor) -> dict:
        """
        Compute statistics about attention patterns.
        
        Args:
            attention_weights: [batch, heads, seq_len, seq_len]
        """
        batch_size, n_heads, seq_len, _ = attention_weights.shape
        
        # Average across batch
        attn = attention_weights.mean(dim=0)  # [heads, seq_len, seq_len]
        
        # Compute entropy (measure of attention spread)
        entropy = -torch.sum(attn * torch.log(attn + 1e-8), dim=-1)  # [heads, seq_len]
        
        # Compute attention distance (how far attention looks)
        positions = torch.arange(seq_len, dtype=torch.float32, device=attn.device)
        # |i - j| for every (query i, key j) pair; broadcasts across heads
        distances = torch.abs(positions.unsqueeze(1) - positions.unsqueeze(0))  # [seq_len, seq_len]
        avg_distance = torch.sum(attn * distances, dim=-1)  # [heads, seq_len]
        
        # Compute attention to self vs others
        self_attention = torch.diagonal(attn, dim1=-2, dim2=-1)  # [heads, seq_len]
        
        return {
            'entropy': entropy,
            'average_distance': avg_distance,
            'self_attention': self_attention,
            'max_attention': attn.max(dim=-1)[0],
            'attention_concentration': (attn**2).sum(dim=-1)  # Sum of squared weights: 1/seq_len if uniform, 1 if one-hot
        }
    
    @staticmethod
    def visualize_attention_heads(attention_weights: torch.Tensor, 
                                 layer_name: str = "Layer",
                                 max_heads: int = 8):
        """
        Visualize attention patterns for multiple heads.
        
        Args:
            attention_weights: [batch, heads, seq_len, seq_len]
            layer_name: Name of the layer for the title
            max_heads: Maximum number of heads to visualize
        """
        # Take first batch, limit heads
        attn = attention_weights[0][:max_heads]  # [heads, seq_len, seq_len]
        n_heads_to_show = min(max_heads, attn.size(0))
        
        # Create subplots
        cols = 4
        rows = (n_heads_to_show + cols - 1) // cols
        
        fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 3*rows))
        if rows == 1:
            axes = axes.reshape(1, -1)
        
        for head_idx in range(n_heads_to_show):
            row = head_idx // cols
            col = head_idx % cols
            
            # Plot attention pattern
            im = axes[row, col].imshow(attn[head_idx].detach().cpu().numpy(), 
                                     cmap='Blues', origin='upper')
            axes[row, col].set_title(f'Head {head_idx + 1}')
            axes[row, col].set_xlabel('Key Position')
            axes[row, col].set_ylabel('Query Position')
            
            # Add colorbar
            plt.colorbar(im, ax=axes[row, col], fraction=0.046, pad=0.04)
        
        # Hide unused subplots
        for head_idx in range(n_heads_to_show, rows * cols):
            row = head_idx // cols
            col = head_idx % cols
            axes[row, col].set_visible(False)
        
        plt.suptitle(f'{layer_name} - Attention Patterns', fontsize=16, y=1.02)
        plt.tight_layout()
        plt.show()
    
    @staticmethod
    def plot_attention_statistics(stats: dict, layer_name: str = "Layer"):
        """Plot various attention statistics."""
        n_heads, seq_len = stats['entropy'].shape
        
        fig, axes = plt.subplots(2, 3, figsize=(18, 10))
        
        # 1. Entropy across positions
        for head in range(n_heads):
            axes[0, 0].plot(stats['entropy'][head].cpu(), 
                          label=f'Head {head+1}', alpha=0.7)
        axes[0, 0].set_title('Attention Entropy by Position')
        axes[0, 0].set_xlabel('Position')
        axes[0, 0].set_ylabel('Entropy')
        axes[0, 0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        
        # 2. Average attention distance
        for head in range(n_heads):
            axes[0, 1].plot(stats['average_distance'][head].cpu(), 
                          label=f'Head {head+1}', alpha=0.7)
        axes[0, 1].set_title('Average Attention Distance')
        axes[0, 1].set_xlabel('Query Position')
        axes[0, 1].set_ylabel('Average Distance')
        
        # 3. Self-attention strength
        for head in range(n_heads):
            axes[0, 2].plot(stats['self_attention'][head].cpu(), 
                          label=f'Head {head+1}', alpha=0.7)
        axes[0, 2].set_title('Self-Attention Strength')
        axes[0, 2].set_xlabel('Position')
        axes[0, 2].set_ylabel('Self-Attention Weight')
        
        # 4. Head-wise statistics (bar charts)
        head_entropy_means = stats['entropy'].mean(dim=1).cpu()
        head_distance_means = stats['average_distance'].mean(dim=1).cpu()
        head_self_means = stats['self_attention'].mean(dim=1).cpu()
        
        head_labels = [f'H{i+1}' for i in range(n_heads)]
        
        axes[1, 0].bar(head_labels, head_entropy_means)
        axes[1, 0].set_title('Mean Entropy by Head')
        axes[1, 0].set_ylabel('Mean Entropy')
        
        axes[1, 1].bar(head_labels, head_distance_means)
        axes[1, 1].set_title('Mean Distance by Head')
        axes[1, 1].set_ylabel('Mean Distance')
        
        axes[1, 2].bar(head_labels, head_self_means)
        axes[1, 2].set_title('Mean Self-Attention by Head')
        axes[1, 2].set_ylabel('Mean Self-Attention')
        
        plt.suptitle(f'{layer_name} - Attention Analysis', fontsize=16)
        plt.tight_layout()
        plt.show()

# Create a sample attention pattern to analyze
def create_synthetic_attention_patterns(batch_size=2, n_heads=6, seq_len=32):
    """Create synthetic attention patterns that mimic real transformer behavior."""
    attention_weights = torch.zeros(batch_size, n_heads, seq_len, seq_len)
    
    for b in range(batch_size):
        for h in range(n_heads):
            if h == 0:  # Local attention head
                for i in range(seq_len):
                    start = max(0, i-3)
                    end = min(seq_len, i+4)
                    attention_weights[b, h, i, start:end] = torch.softmax(
                        torch.randn(end-start), dim=0)
            
            elif h == 1:  # Global attention head (attends to beginning)
                for i in range(seq_len):
                    weights = torch.zeros(seq_len)
                    weights[:5] = torch.rand(5) * 2  # Higher weight on first tokens
                    weights[i] = torch.rand(1) * 2   # And self
                    attention_weights[b, h, i, :] = torch.softmax(weights, dim=0)
            
            elif h == 2:  # Self-attention focused
                for i in range(seq_len):
                    weights = torch.zeros(seq_len)
                    weights[i] = 3.0  # Strong self-attention
                    weights += torch.randn(seq_len) * 0.5  # Weak noise
                    attention_weights[b, h, i, :] = torch.softmax(weights, dim=0)
            
            else:  # Random patterns for other heads
                for i in range(seq_len):
                    # Causal pattern: position i attends only to positions 0..i
                    weights = torch.randn(i+1)
                    full_weights = torch.full((seq_len,), -float('inf'))
                    full_weights[:i+1] = weights
                    attention_weights[b, h, i, :] = torch.softmax(full_weights, dim=0)
    
    return attention_weights

# Generate and analyze synthetic attention patterns
synthetic_attention = create_synthetic_attention_patterns()

# Visualize attention heads
AttentionAnalyzer.visualize_attention_heads(synthetic_attention, "Synthetic Layer")

# Compute and plot statistics
stats = AttentionAnalyzer.compute_attention_statistics(synthetic_attention)
AttentionAnalyzer.plot_attention_statistics(stats, "Synthetic Layer")

## 5. Modern Attention Variants

Let's implement two modern attention variants, Multi-Query Attention and Grouped-Query Attention, which shrink the KV cache by sharing key and value heads across query heads.
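
The effect of sharing is easy to see with plain arithmetic. The sketch below uses two hypothetical helpers: one for per-layer cache size, and one that assumes consecutive query heads share a KV head. The sizes are illustrative (float32, `d_k = 64`):

```python
def kv_cache_bytes(batch_size, seq_len, n_kv_heads, d_k, bytes_per_elem=4):
    """Bytes needed to cache K and V for one layer."""
    return 2 * batch_size * n_kv_heads * seq_len * d_k * bytes_per_elem

def kv_head_for(query_head, n_heads, n_kv_heads):
    """Index of the shared KV head used by a given query head."""
    return query_head // (n_heads // n_kv_heads)

n_heads, d_k = 8, 64
for name, n_kv in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    mb = kv_cache_bytes(4, 1024, n_kv, d_k) / 2**20
    mapping = [kv_head_for(h, n_heads, n_kv) for h in range(n_heads)]
    print(f"{name}: {mb:.1f} MB per layer, query->KV head map {mapping}")
```

The cache shrinks in proportion to `n_kv_heads`, while the number of query heads (and hence the attention weight matrices) stays the same.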

In [None]:
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention (MQA) - shares K,V across heads."""
    
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # MQA: multiple Q heads, single K,V
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # Full Q projection
        self.w_k = nn.Linear(d_model, self.d_k, bias=False)  # Single K projection
        self.w_v = nn.Linear(d_model, self.d_k, bias=False)  # Single V projection
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        batch_size, seq_len, d_model = x.shape
        
        # Compute Q (multiple heads), K, V (single head each)
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).unsqueeze(1).expand(-1, self.n_heads, -1, -1)  # Broadcast to all heads
        V = self.w_v(x).unsqueeze(1).expand(-1, self.n_heads, -1, -1)  # Broadcast to all heads
        
        # Attention computation (same as standard)
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        out = torch.matmul(attn_weights, V)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.w_o(out)
        
        return out, attn_weights


class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention (GQA) - groups heads for K,V sharing."""
    
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        assert n_heads % n_kv_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # How many Q heads per KV head
        self.d_k = d_model // n_heads
        
        # GQA: full Q, reduced K,V
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.w_v = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        batch_size, seq_len, d_model = x.shape
        
        # Compute Q, K, V
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_kv_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_kv_heads, self.d_k).transpose(1, 2)
        
        # Repeat K,V for each group
        K = K.repeat_interleave(self.n_rep, dim=1)  # [batch, n_heads, seq_len, d_k]
        V = V.repeat_interleave(self.n_rep, dim=1)  # [batch, n_heads, seq_len, d_k]
        
        # Standard attention computation
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        out = torch.matmul(attn_weights, V)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.w_o(out)
        
        return out, attn_weights


# Compare parameter counts and memory usage
def compare_attention_variants():
    """Compare different attention mechanisms."""
    d_model = 512
    n_heads = 8
    n_kv_heads = 2  # For GQA
    batch_size = 4
    seq_len = 1024
    
    # Create different attention mechanisms
    standard_attn = CachedMultiHeadAttention(d_model, n_heads)
    mqa_attn = MultiQueryAttention(d_model, n_heads)
    gqa_attn = GroupedQueryAttention(d_model, n_heads, n_kv_heads)
    
    mechanisms = {
        'Standard MHA': standard_attn,
        'Multi-Query (MQA)': mqa_attn,
        'Grouped-Query (GQA)': gqa_attn
    }
    
    results = []
    
    for name, mechanism in mechanisms.items():
        # Count parameters
        n_params = sum(p.numel() for p in mechanism.parameters())
        
        # Estimate KV cache memory (during inference)
        if name == 'Standard MHA':
            kv_cache_size = 2 * batch_size * n_heads * seq_len * (d_model // n_heads) * 4  # 4 bytes per float
        elif name == 'Multi-Query (MQA)':
            kv_cache_size = 2 * batch_size * 1 * seq_len * (d_model // n_heads) * 4  # Single KV
        else:  # GQA
            kv_cache_size = 2 * batch_size * n_kv_heads * seq_len * (d_model // n_heads) * 4
        
        results.append({
            'mechanism': name,
            'parameters': n_params,
            'kv_cache_mb': kv_cache_size / (1024 * 1024),
            'param_reduction': 0,  # Will calculate relative to standard
            'cache_reduction': 0   # Will calculate relative to standard
        })
    
    # Calculate relative reductions
    standard_params = results[0]['parameters']
    standard_cache = results[0]['kv_cache_mb']
    
    for result in results:
        result['param_reduction'] = 1 - (result['parameters'] / standard_params)
        result['cache_reduction'] = 1 - (result['kv_cache_mb'] / standard_cache)
    
    # Display results
    print("Attention Mechanism Comparison:")
    print("Mechanism\t\tParameters\tKV Cache (MB)\tParam Reduction\tCache Reduction")
    print("-" * 85)
    
    for result in results:
        print(f"{result['mechanism']:<20}\t{result['parameters']:>8,}\t{result['kv_cache_mb']:>10.1f}\t"
              f"{result['param_reduction']:>11.1%}\t{result['cache_reduction']:>11.1%}")
    
    return results

comparison_results = compare_attention_variants()

# Visualize the comparison
mechanisms = [r['mechanism'] for r in comparison_results]
param_counts = [r['parameters'] / 1000 for r in comparison_results]  # In thousands
cache_sizes = [r['kv_cache_mb'] for r in comparison_results]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Parameter comparison
bars1 = ax1.bar(mechanisms, param_counts, color=['skyblue', 'lightcoral', 'lightgreen'])
ax1.set_title('Parameter Count Comparison')
ax1.set_ylabel('Parameters (thousands)')
ax1.tick_params(axis='x', rotation=45)

# Add value labels on bars
for bar, count in zip(bars1, param_counts):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + height*0.01,
             f'{count:.0f}K', ha='center', va='bottom')

# KV Cache comparison
bars2 = ax2.bar(mechanisms, cache_sizes, color=['skyblue', 'lightcoral', 'lightgreen'])
ax2.set_title('KV Cache Memory Comparison')
ax2.set_ylabel('Memory (MB)')
ax2.tick_params(axis='x', rotation=45)

# Add value labels on bars
for bar, size in zip(bars2, cache_sizes):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height + height*0.01,
             f'{size:.1f}MB', ha='center', va='bottom')

plt.tight_layout()
plt.show()

print(f"\n🎯 Key Insights:")
print(f"• MQA reduces KV cache by {comparison_results[1]['cache_reduction']:.0%} (critical for long sequences)")
print(f"• GQA balances between MHA and MQA, reducing cache by {comparison_results[2]['cache_reduction']:.0%}")
print("• Parameter reduction is modest, but KV cache reduction is substantial")
print("• The cache savings matter most when serving long sequences or large batches")
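The three cache estimates in the comparison above follow a single formula: `2 × batch × n_kv_heads × seq_len × head_dim × bytes_per_element`, per layer. The sketch below restates it as one function. `kv_cache_bytes` and its `n_layers` parameter are hypothetical names introduced here for illustration and are not defined elsewhere in this notebook.

```python
def kv_cache_bytes(batch_size: int, seq_len: int, n_kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 4, n_layers: int = 1) -> int:
    """Bytes needed to cache keys and values (hypothetical helper for illustration)."""
    # Factor 2: one cached tensor for K and one for V
    return 2 * n_layers * batch_size * n_kv_heads * seq_len * head_dim * bytes_per_elem


# Same configuration as the comparison above: d_model=512, n_heads=8 -> head_dim=64
head_dim = 512 // 8
for name, n_kv in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    mb = kv_cache_bytes(batch_size=4, seq_len=1024, n_kv_heads=n_kv, head_dim=head_dim) / 2**20
    print(f"{name}: {mb:.1f} MB per layer")
# MHA: 16.0 MB per layer
# GQA: 4.0 MB per layer
# MQA: 2.0 MB per layer
```

Only `n_kv_heads` changes between the variants, so the cache shrinks by the ratio `n_kv_heads / n_heads`. For a full model, multiply by the number of layers.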

## 6. Practical Implementation Tips

Let's look at some practical considerations when implementing advanced attention mechanisms.

In [None]:
class AttentionOptimizationUtils:
    """Utilities for optimizing attention computation."""
    
    @staticmethod
    def fused_scaled_dot_product_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                                         mask: Optional[torch.Tensor] = None,
                                         dropout_p: float = 0.0,
                                         is_causal: bool = False) -> torch.Tensor:
        """
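        Reference scaled dot-product attention with the interface of a fused kernel.

        Note: this version is NOT fused. It materializes the full score matrix and only
        illustrates the math. In practice, use
        torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+).
        """
        # Real Flash Attention processes Q, K, V in tiles with an online softmax
        # and recomputes attention in the backward pass instead of storing it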
        scale = 1.0 / math.sqrt(Q.size(-1))
        
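        if is_causal:
            # Lower-triangular causal mask of shape (q_len, k_len): 1 = attend, 0 = masked.
            # Using both lengths keeps the mask valid when Q and K differ in length.
            q_len, k_len = Q.size(-2), K.size(-2)
            causal_mask = torch.tril(torch.ones(q_len, k_len, device=Q.device))
            if mask is not None:
                mask = mask * causal_mask
            else:
                mask = causal_mask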
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -float('inf'))
        
        # Apply softmax
        attn_weights = F.softmax(scores, dim=-1)
        
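        # Apply dropout (F.dropout defaults to training=True, so it is active
        # whenever dropout_p > 0; pass dropout_p=0.0 for evaluation)
        if dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=dropout_p)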
        
        # Apply attention to values
        output = torch.matmul(attn_weights, V)
        
        return output
    
    @staticmethod
    def benchmark_attention_variants(seq_lengths: List[int], 
                                   d_model: int = 512, 
                                   n_heads: int = 8,
                                   batch_size: int = 4):
        """Benchmark different attention implementations."""
        results = {}
        
        for seq_len in seq_lengths:
            print(f"\nBenchmarking seq_len = {seq_len}...")
            
            # Create random Q, K, V directly; input projections are not part of this benchmark
            d_k = d_model // n_heads
            Q = torch.randn(batch_size, n_heads, seq_len, d_k, device=device)
            K = torch.randn(batch_size, n_heads, seq_len, d_k, device=device)
            V = torch.randn(batch_size, n_heads, seq_len, d_k, device=device)
            
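            def sync():
                # CUDA kernels run asynchronously; wait for them before reading the clock
                if device.type == 'cuda':
                    torch.cuda.synchronize()

            # Standard attention: one untimed warm-up run, then 10 timed runs
            output = torch.matmul(F.softmax(torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k), dim=-1), V)
            sync()
            start_time = time.perf_counter()
            for _ in range(10):
                scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
                attn_weights = F.softmax(scores, dim=-1)
                output = torch.matmul(attn_weights, V)
            sync()
            standard_time = (time.perf_counter() - start_time) / 10

            # "Fused" attention (simulated): it runs the same operations as the
            # standard path, so expect a speedup close to 1x from this function
            output = AttentionOptimizationUtils.fused_scaled_dot_product_attention(Q, K, V)
            sync()
            start_time = time.perf_counter()
            for _ in range(10):
                output = AttentionOptimizationUtils.fused_scaled_dot_product_attention(Q, K, V)
            sync()
            fused_time = (time.perf_counter() - start_time) / 10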
            
            results[seq_len] = {
                'standard_time': standard_time,
                'fused_time': fused_time,
                'speedup': standard_time / fused_time
            }
            
            print(f"  Standard: {standard_time:.4f}s")
            print(f"  Fused:    {fused_time:.4f}s")
            print(f"  Speedup:  {standard_time/fused_time:.2f}x")
        
        return results

# Run benchmarks (skip if no GPU available)
if device.type == 'cuda':
    print("Running attention benchmarks on GPU...")
    benchmark_results = AttentionOptimizationUtils.benchmark_attention_variants(
        seq_lengths=[128, 256, 512, 1024],
        d_model=512,
        n_heads=8,
        batch_size=4
    )
else:
    print("Skipping GPU benchmarks (no CUDA device available)")
    # Illustrative placeholder values so the plots below still render; these are NOT measurements
    benchmark_results = {
        128: {'standard_time': 0.005, 'fused_time': 0.003, 'speedup': 1.67},
        256: {'standard_time': 0.015, 'fused_time': 0.008, 'speedup': 1.88},
        512: {'standard_time': 0.045, 'fused_time': 0.020, 'speedup': 2.25},
        1024: {'standard_time': 0.150, 'fused_time': 0.055, 'speedup': 2.73}
    }

# Visualize benchmark results
if benchmark_results:
    seq_lens = list(benchmark_results.keys())
    standard_times = [benchmark_results[s]['standard_time'] * 1000 for s in seq_lens]  # Convert to ms
    fused_times = [benchmark_results[s]['fused_time'] * 1000 for s in seq_lens]
    speedups = [benchmark_results[s]['speedup'] for s in seq_lens]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Timing comparison
    x = np.arange(len(seq_lens))
    width = 0.35
    
    ax1.bar(x - width/2, standard_times, width, label='Standard Attention', alpha=0.8)
    ax1.bar(x + width/2, fused_times, width, label='Fused Attention', alpha=0.8)
    
    ax1.set_xlabel('Sequence Length')
    ax1.set_ylabel('Time (ms)')
    ax1.set_title('Attention Implementation Timing')
    ax1.set_xticks(x)
    ax1.set_xticklabels(seq_lens)
    ax1.legend()
    ax1.set_yscale('log')
    
    # Speedup
    ax2.plot(seq_lens, speedups, 'o-', linewidth=2, markersize=8)
    ax2.set_xlabel('Sequence Length')
    ax2.set_ylabel('Speedup (x)')
    ax2.set_title('Fused Attention Speedup')
    ax2.grid(True, alpha=0.3)
    
    # Add horizontal line at 1x
    ax2.axhline(y=1, color='gray', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    plt.show()

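print("\n🚀 Optimization Summary:")
print("• Fused kernels reduce memory traffic by never writing the full score matrix to GPU main memory")
print("• The simulated 'fused' function runs the same operations as the standard path; real gains require a truly fused kernel")
print("• Speedups from real fused kernels generally increase with sequence length")
print("• Modern frameworks provide optimized attention kernels (e.g. F.scaled_dot_product_attention)")
print("• Flash Attention needs only O(n) memory for attention, while compute remains O(n²)")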
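The docstring above points to `torch.nn.functional.scaled_dot_product_attention` as the practical choice. As a minimal sketch, assuming PyTorch 2.0 or newer, the snippet below checks that the built-in call returns the same values as an explicit masked-softmax implementation on small random inputs. The tensor shapes and the seed are arbitrary choices for illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, head_dim = 16, 8
Q = torch.randn(2, 4, seq_len, head_dim)  # (batch, heads, seq_len, head_dim)
K = torch.randn(2, 4, seq_len, head_dim)
V = torch.randn(2, 4, seq_len, head_dim)

# Reference path: materializes the full (seq_len x seq_len) score matrix
scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~causal, float('-inf'))
reference = F.softmax(scores, dim=-1) @ V

# Built-in path: PyTorch selects a backend (fused where available) behind one call
builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

max_diff = (reference - builtin).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

Both paths compute exact attention, so the difference should be at the level of floating-point rounding. Which backend the built-in call selects depends on the device, dtype, and PyTorch build, so timings will vary between setups.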

## Summary and Key Takeaways

### 🎯 What We've Learned

1. **Computational Complexity**:
   - Standard attention is O(n²) in both time and memory
   - The attention weight matrix dominates memory usage for long sequences
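   - Quadratic scaling becomes increasingly costly at long context lengths; the practical limit depends on hardware and on whether a memory-efficient kernel is used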

2. **KV Caching**:
   - Essential optimization for autoregressive generation
   - Provides substantial speedup by avoiding redundant computation
   - Critical for real-time inference applications

3. **Sparse Attention Patterns**:
   - Local, strided, global, and block-sparse patterns reduce complexity
   - Trade-off between efficiency and modeling capability
   - Different patterns suit different tasks and sequence types

4. **Modern Variants**:
   - **Multi-Query Attention (MQA)**: Shares K,V across heads, reduces KV cache
   - **Grouped-Query Attention (GQA)**: Balances between MHA and MQA
   - **Flash Attention**: Computes exact attention with O(n) memory by tiling with an online softmax; compute remains O(n²)

5. **Practical Considerations**:
   - Memory optimization is often more important than FLOP reduction
   - Hardware-aware implementations provide significant speedups
   - Modern frameworks offer optimized attention kernels
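
The tiling idea behind Flash Attention (item 4 above) can be sketched in plain PyTorch. The function below walks over keys and values in blocks and keeps a running maximum and running softmax denominator per query row, so the full n×n score matrix is never stored. This is only an illustration of the online-softmax idea under simplifying assumptions (no masking, no dropout, tiling over keys only); `tiled_attention` and `block_size` are names introduced here, and the real algorithm is a single fused GPU kernel that also tiles over queries.

```python
import math
import torch
import torch.nn.functional as F


def tiled_attention(Q, K, V, block_size=4):
    """Exact attention computed one key/value block at a time (illustrative sketch)."""
    scale = 1.0 / math.sqrt(Q.size(-1))
    # Running statistics per query row
    row_max = Q.new_full((*Q.shape[:-1], 1), float('-inf'))  # largest score seen so far
    row_sum = Q.new_zeros((*Q.shape[:-1], 1))                # softmax denominator so far
    acc = Q.new_zeros((*Q.shape[:-1], V.size(-1)))           # unnormalized weighted sum of V

    for start in range(0, K.size(-2), block_size):
        K_blk = K[..., start:start + block_size, :]
        V_blk = V[..., start:start + block_size, :]
        s = (Q @ K_blk.transpose(-2, -1)) * scale            # scores for this block only
        new_max = torch.maximum(row_max, s.amax(dim=-1, keepdim=True))
        # Rescale earlier partial results to the new maximum, then add this block
        correction = torch.exp(row_max - new_max)
        p = torch.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ V_blk
        row_max = new_max

    return acc / row_sum


torch.manual_seed(0)
Q = torch.randn(2, 3, 10, 8)
K = torch.randn(2, 3, 10, 8)
V = torch.randn(2, 3, 10, 8)

# Naive reference that materializes the full score matrix
full = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(8), dim=-1) @ V
tiled = tiled_attention(Q, K, V, block_size=4)
print(torch.allclose(full, tiled, atol=1e-5))
```

At any moment the loop holds scores for only one block of keys, which is why memory grows linearly with sequence length even though every query still attends to every key.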

### 🔄 Next Steps

In the next notebook (08_modern_architecture_improvements), we'll explore:
- RMSNorm vs LayerNorm
- SwiGLU activation functions
- Rotary Position Embedding (RoPE)
- Pre-norm vs Post-norm architectures

### 📚 Further Reading

- **Flash Attention**: Dao et al. (2022) - "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"
- **Multi-Query Attention**: Shazeer (2019) - "Fast Transformer Decoding: One Write-Head is All You Need"
- **Sparse Attention**: Child et al. (2019) - "Generating Long Sequences with Sparse Transformers"
- **Grouped-Query Attention**: Ainslie et al. (2023) - "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"