# Advanced Attention Mechanisms: Practical Optimizations

In the basic attention notebook, we learned the core mechanism. Now let's explore practical improvements that make transformers faster and more efficient in real applications.

## What You'll Learn

1. **KV Caching** - Speed up inference by caching key-value pairs
2. **Sparse Attention** - Reduce complexity with smart attention patterns
3. **Modern Variants** - Multi-Query and Grouped-Query Attention

These optimizations are used in production systems to make transformers practical at scale!

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
from typing import Optional, Tuple, List
import time
from dataclasses import dataclass

# Import our basic attention mechanism
from src.model.attention import MultiHeadAttention

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
plt.style.use('default')
sns.set_palette("husl")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("Environment setup complete!")

## 1. KV Caching for Efficient Inference

During autoregressive generation, we can cache previously computed key and value vectors to avoid redundant computation. This is crucial for efficient text generation.

**The Problem**: In normal autoregressive generation, we recompute K and V for all previous tokens at every step. This is wasteful!

**The Solution**: Cache the K and V tensors and just append new ones for new tokens.
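The sketch below shows why caching is safe: decoding one token at a time against a growing K/V cache gives exactly the same outputs as recomputing causal attention over the whole prefix. It assumes a single head and no batch dimension, and it uses NumPy rather than PyTorch to stay short. `full_causal_attention` and `cached_decode` are hypothetical helpers for illustration only. They are not part of `src.model.attention`.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(X, Wq, Wk, Wv):
    """No cache: project every token and apply a causal mask over the full sequence."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    causal = np.tril(np.ones((n, n), dtype=bool))   # query i may see keys 0..i
    return softmax(np.where(causal, scores, -np.inf)) @ V

def cached_decode(X, Wq, Wk, Wv):
    """With a cache: at each step, project only the new token and append its K, V."""
    K_cache = np.empty((0, Wk.shape[1]))
    V_cache = np.empty((0, Wv.shape[1]))
    outputs = []
    for x_t in X:                                   # one new token per step
        q = x_t @ Wq
        K_cache = np.vstack([K_cache, x_t @ Wk])    # cache grows by one row per step
        V_cache = np.vstack([V_cache, x_t @ Wv])
        scores = K_cache @ q / np.sqrt(q.shape[0])  # new query vs. all cached keys
        outputs.append(softmax(scores) @ V_cache)
    return np.stack(outputs)

# Usage: both paths produce the same output at every position
rng = np.random.default_rng(0)
n, d_model, d_k = 6, 16, 8
X = rng.standard_normal((n, d_model))
Wq, Wk, Wv = (rng.standard_normal((d_model, d_k)) for _ in range(3))
assert np.allclose(full_causal_attention(X, Wq, Wk, Wv), cached_decode(X, Wq, Wk, Wv))
```

With the cache, step t costs one K/V projection plus one query against t cached keys. Without the cache, recomputing the prefix costs t projections and a t×t score matrix at every step.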

In [None]:
class CachedMultiHeadAttention(nn.Module):
    """Multi-head attention with an optional KV cache for faster autoregressive inference."""
    
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
        
        # Cache of (K, V) tensors, keyed by a caller-chosen name (e.g. one per generation stream)
        self.kv_cache = {}
    
    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, d_model) -> (batch, n_heads, seq, d_k)."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
    
    def forward(self, query, key, value, mask=None, use_cache=False, cache_key="default"):
        batch_size, seq_len, _ = query.shape
        
        # Project only the new tokens (key/value may have a different length than query)
        Q = self._split_heads(self.w_q(query))
        K = self._split_heads(self.w_k(key))
        V = self._split_heads(self.w_v(value))
        
        if use_cache:
            if cache_key in self.kv_cache:
                # Prepend the cached K, V along the sequence dimension (dim=2)
                cached_K, cached_V = self.kv_cache[cache_key]
                K = torch.cat([cached_K, K], dim=2)
                V = torch.cat([cached_V, V], dim=2)
            self.kv_cache[cache_key] = (K, V)
        
        # Scaled dot-product attention: scores is (batch, n_heads, seq_len, total_kv_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # No causal mask is applied automatically: pass one when feeding a multi-token
        # prompt. Single-token decoding steps do not need one.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.w_o(attn_output)
        
        return output
    
    def clear_cache(self):
        """Clear the KV cache."""
        self.kv_cache.clear()

# Demonstrate KV caching
print("🚀 KV CACHING DEMONSTRATION")
print("=" * 40)

d_model, n_heads = 256, 8
cached_attn = CachedMultiHeadAttention(d_model, n_heads)

print("Simulating autoregressive generation (one new token per step):")
for step in range(1, 6):
    new_token = torch.randn(1, 1, d_model)  # stand-in for the newest token's embedding
    output = cached_attn(new_token, new_token, new_token, use_cache=True, cache_key="gen")
    cached_len = cached_attn.kv_cache["gen"][0].shape[2]
    print(f"Step {step}: output shape {tuple(output.shape)}, cached K/V length {cached_len}")

print("\n✅ Each step projects only the new token; K and V for earlier tokens come from the cache")
print("💡 Attention still scans all cached positions, so per-step cost grows linearly with length")

## 2. Sparse Attention Patterns

Standard attention's cost grows as O(n²) in the sequence length n. Various sparse attention patterns have been proposed to reduce it. Let's implement and visualize some common ones.

**Why Sparse Attention?**
- Standard attention is O(n²) in both memory and computation
- This becomes prohibitive for long sequences (>8K tokens)
- Many attention weights are close to zero anyway
- Well-chosen sparsity patterns can preserve most of the model's quality
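The O(n·w) claim for the local pattern can be checked by counting allowed (query, key) pairs exactly. The local mask built by `create_local_attention_mask` in the next cell uses half-window h = ⌊w/2⌋. Each interior query sees 2h+1 keys. The h queries nearest each edge lose h, h−1, …, 1 keys to the sequence boundary:

$$
\text{connections} = n(2h+1) - 2\sum_{k=1}^{h} k = n(2h+1) - h(h+1) = O(n \cdot w)
$$

For n = 64 and w = 8 (h = 4), this gives 64 · 9 − 20 = 556 of the 4096 possible pairs, about 13.6%.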

In [None]:
class SparseAttentionPatterns:
    """Collection of sparse attention pattern generators."""
    
    @staticmethod
    def create_local_attention_mask(seq_len: int, window_size: int) -> torch.Tensor:
        """Create local attention mask (each token attends to nearby tokens)."""
        mask = torch.zeros(seq_len, seq_len)
        
        for i in range(seq_len):
            start = max(0, i - window_size // 2)
            end = min(seq_len, i + window_size // 2 + 1)
            mask[i, start:end] = 1
        
        return mask
    
    @staticmethod
    def create_strided_attention_mask(seq_len: int, stride: int) -> torch.Tensor:
        """Create strided attention mask (attend to every k-th token)."""
        mask = torch.zeros(seq_len, seq_len)
        
        for i in range(seq_len):
            # Attend to positions at regular intervals
            positions = torch.arange(0, seq_len, stride)
            mask[i, positions] = 1
            # Always attend to self
            mask[i, i] = 1
        
        return mask
    
    @staticmethod
    def create_global_attention_mask(seq_len: int, num_global: int) -> torch.Tensor:
        """Create global attention mask (some tokens attend to all, all attend to globals)."""
        mask = torch.eye(seq_len)  # Self-attention
        
        # First num_global tokens are global
        mask[:num_global, :] = 1  # Global tokens attend to all
        mask[:, :num_global] = 1  # All tokens attend to global tokens
        
        return mask
    
    @staticmethod
    def create_block_sparse_mask(seq_len: int, block_size: int) -> torch.Tensor:
        """Create block sparse attention mask."""
        mask = torch.zeros(seq_len, seq_len)
        
        num_blocks = seq_len // block_size
        
        for i in range(num_blocks):
            for j in range(num_blocks):
                # Attend within block and to adjacent blocks
                if abs(i - j) <= 1:
                    start_i, end_i = i * block_size, (i + 1) * block_size
                    start_j, end_j = j * block_size, (j + 1) * block_size
                    mask[start_i:end_i, start_j:end_j] = 1
        
        return mask

# Visualize different sparse attention patterns
seq_len = 64
patterns = {
    'Full Attention': torch.ones(seq_len, seq_len),  # dense and bidirectional, like the sparse masks below
    'Local (window=8)': SparseAttentionPatterns.create_local_attention_mask(seq_len, 8),
    'Strided (stride=4)': SparseAttentionPatterns.create_strided_attention_mask(seq_len, 4),
    'Global (4 global)': SparseAttentionPatterns.create_global_attention_mask(seq_len, 4),
    'Block Sparse (8x8)': SparseAttentionPatterns.create_block_sparse_mask(seq_len, 8)
}

fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for idx, (name, pattern) in enumerate(patterns.items()):
    axes[idx].imshow(pattern.numpy(), cmap='Blues', origin='upper')
    axes[idx].set_title(f'{name}\n{pattern.sum().item():.0f}/{seq_len**2} connections')
    axes[idx].set_xlabel('Key Position')
    if idx == 0:
        axes[idx].set_ylabel('Query Position')
    
    # Add sparsity information
    sparsity = 1 - (pattern.sum() / (seq_len ** 2))
    axes[idx].text(0.02, 0.98, f'Sparsity: {sparsity:.1%}', 
                  transform=axes[idx].transAxes, 
                  bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                  verticalalignment='top')

plt.tight_layout()
plt.show()

# Analyze complexity reduction
print("\nComplexity Analysis of Sparse Patterns:")
print(f"{'Pattern':<20}{'Connections':>12}{'Reduction':>12}  Complexity")
print("-" * 60)

for name, pattern in patterns.items():
    connections = pattern.sum().item()
    reduction = 1 - connections / seq_len ** 2
    if 'Local' in name:
        complexity = "O(n·w)"  # w = window size
    elif 'Strided' in name:
        complexity = "O(n²/s)"  # s = stride
    elif 'Global' in name:
        complexity = "O(n·g + g²)"  # g = global tokens
    elif 'Block' in name:
        complexity = "O(n·b)"  # b = block size
    else:
        complexity = "O(n²)"
    
    print(f"{name:<20}{connections:>12.0f}{reduction:>12.1%}  {complexity}")
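The masks above only describe which pairs are allowed. If you apply them with `masked_fill` on a full score matrix, as the attention classes in this notebook do, you still pay O(n²) to compute and store the scores. The saving only appears when just the in-window scores are computed.

Below is a minimal NumPy sketch for the local pattern. It assumes a single head and no batch dimension. `local_attention` and `dense_masked_attention` are hypothetical helpers.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def local_attention(Q, K, V, window_size: int) -> np.ndarray:
    """Query i attends only to keys j with |i - j| <= window_size // 2.

    Only in-window scores are computed, so the work is O(n * w) instead of O(n^2).
    """
    n, d_k = Q.shape
    half = window_size // 2
    out = np.empty((n, V.shape[1]))
    for i in range(n):
        lo, hi = max(0, i - half), min(n, i + half + 1)  # same bounds as create_local_attention_mask
        scores = K[lo:hi] @ Q[i] / np.sqrt(d_k)          # at most 2 * half + 1 scores
        out[i] = softmax(scores) @ V[lo:hi]
    return out

def dense_masked_attention(Q, K, V, mask) -> np.ndarray:
    """Reference: full n x n scores, with disallowed pairs set to -inf before softmax."""
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    return softmax(np.where(mask, scores, -np.inf)) @ V

# Usage: both agree, but only the dense version materializes an n x n matrix
rng = np.random.default_rng(0)
n, d_k, w = 64, 16, 8
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
idx = np.arange(n)
mask = np.abs(idx[:, None] - idx[None, :]) <= w // 2
assert np.allclose(local_attention(Q, K, V, w), dense_masked_attention(Q, K, V, mask))
```

The per-query loop is kept for clarity. A practical implementation would vectorize it or process blocks of queries together.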

## 3. Modern Attention Variants

Let's implement some modern attention mechanisms that address efficiency and scaling concerns.

**Key Ideas:**
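- **Multi-Query Attention (MQA)**: All query heads share a single K/V head
- **Grouped-Query Attention (GQA)**: Each group of query heads shares one K/V head
- Both shrink the KV cache by a factor of n_heads / n_kv_heads, which speeds up inference when decoding is limited by memory bandwidth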
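The cache saving is easy to quantify. The cache stores one K vector and one V vector per layer, per K/V head, per position. Its size is therefore 2 × n_layers × n_kv_heads × d_k × seq_len × batch × bytes per element.

The helper below is hypothetical, and the model shape is assumed for illustration. It does not describe any specific model.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_k: int, seq_len: int,
                   batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes needed to cache K and V (the leading factor 2) for every layer and position."""
    return 2 * n_layers * n_kv_heads * d_k * seq_len * batch_size * bytes_per_elem

# Assumed config: 32 layers, 32 query heads of size 128, 4096-token context,
# batch size 1, fp16 storage (2 bytes per element)
cfg = dict(n_layers=32, d_k=128, seq_len=4096)
for name, n_kv_heads in [("MHA", 32), ("GQA (8 KV heads)", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_kv_heads=n_kv_heads, **cfg)
    print(f"{name:<18} {size / 2**20:>8.0f} MiB")  # 2048, 512 and 64 MiB respectively
```

The cache shrinks by exactly n_heads / n_kv_heads. Parameter savings are more modest, because only the K and V projections get smaller.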

In [None]:
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention: One key/value head, multiple query heads."""
    
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Multiple query heads, single key/value head
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # n_heads query heads
        self.w_k = nn.Linear(d_model, self.d_k, bias=False)  # 1 key head
        self.w_v = nn.Linear(d_model, self.d_k, bias=False)  # 1 value head
        self.w_o = nn.Linear(d_model, d_model)
    
    def forward(self, query, key, value, mask=None):
        batch_size, seq_len, _ = query.shape
        
        # Multiple query heads
        Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Single key and value head (broadcast to all query heads); key may differ in length from query
        kv_len = key.shape[1]
        K = self.w_k(key).view(batch_size, kv_len, 1, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, kv_len, 1, self.d_k).transpose(1, 2)
        
        # Expand K, V to match Q heads
        K = K.expand(-1, self.n_heads, -1, -1)
        V = V.expand(-1, self.n_heads, -1, -1)
        
        # Standard scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.w_o(attn_output)
        
        return output


class GroupedQueryAttention(nn.Module):
    """Grouped-Query Attention: Groups of query heads share key/value heads."""
    
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        
        assert d_model % n_heads == 0
        assert n_heads % n_kv_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.d_k = d_model // n_heads
        self.group_size = n_heads // n_kv_heads
        
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.w_v = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
    
    def forward(self, query, key, value, mask=None):
        batch_size, seq_len, _ = query.shape
        
        # Query heads (full set)
        Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Fewer key/value heads; key may differ in length from query
        kv_len = key.shape[1]
        K = self.w_k(key).view(batch_size, kv_len, self.n_kv_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, kv_len, self.n_kv_heads, self.d_k).transpose(1, 2)
        
        # Expand K, V to match Q heads by repeating each K,V head group_size times
        K = K.repeat_interleave(self.group_size, dim=1)
        V = V.repeat_interleave(self.group_size, dim=1)
        
        # Standard scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.w_o(attn_output)
        
        return output


# Compare different attention mechanisms
print("🔄 MODERN ATTENTION VARIANTS COMPARISON")
print("=" * 50)

d_model, seq_len = 256, 32
batch_size = 1

# Create test input
x = torch.randn(batch_size, seq_len, d_model)

# Standard Multi-Head Attention
mha = MultiHeadAttention(d_model, n_heads=8)
mha_params = sum(p.numel() for p in mha.parameters())

# Multi-Query Attention (8 query heads, 1 kv head)
mqa = MultiQueryAttention(d_model, n_heads=8)
mqa_params = sum(p.numel() for p in mqa.parameters())

# Grouped-Query Attention (8 query heads, 2 kv heads)
gqa = GroupedQueryAttention(d_model, n_heads=8, n_kv_heads=2)
gqa_params = sum(p.numel() for p in gqa.parameters())

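print("Parameter comparison:")
print(f"Multi-Head Attention (MHA):     {mha_params:,} params")
print(f"Multi-Query Attention (MQA):    {mqa_params:,} params ({mha_params/mqa_params:.1f}x fewer)")
print(f"Grouped-Query Attention (GQA):  {gqa_params:,} params ({mha_params/gqa_params:.1f}x fewer)")

# The larger saving is in the KV cache: 2 (K and V) * n_kv_heads * d_k values per token per layer
d_k = d_model // 8
print(f"\nKV cache per token per layer: MHA {2 * 8 * d_k}, MQA {2 * 1 * d_k}, GQA {2 * 2 * d_k} values")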

# Test forward passes
mha_out = mha(x, x, x)
mqa_out = mqa(x, x, x)
gqa_out = gqa(x, x, x)

print(f"\nOutput shapes (all should be identical):")
print(f"MHA output: {mha_out.shape}")
print(f"MQA output: {mqa_out.shape}")
print(f"GQA output: {gqa_out.shape}")

print("\n🎯 Key Benefits:")
print("• MQA: one shared K/V head gives the smallest KV cache, which speeds up memory-bound decoding")
print("• GQA: a few K/V heads trade a larger cache for quality closer to MHA")
print("• Both keep the same output shape as standard multi-head attention")
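The `repeat_interleave` call in `GroupedQueryAttention` assigns query head h to K/V head h // group_size. The NumPy sketch below shows that mapping. It also shows that GQA contains both other variants as special cases: n_kv_heads == n_heads is MHA, and n_kv_heads == 1 is MQA.

It assumes no batch dimension and no mask. `kv_head_index` and `grouped_query_attention` are hypothetical helpers.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def kv_head_index(n_heads: int, n_kv_heads: int) -> np.ndarray:
    """K/V head used by each query head (what repeat_interleave on dim=1 produces)."""
    assert n_heads % n_kv_heads == 0
    return np.repeat(np.arange(n_kv_heads), n_heads // n_kv_heads)

def grouped_query_attention(Q, K, V):
    """Q: (n_heads, seq, d_k); K, V: (n_kv_heads, seq, d_k)."""
    idx = kv_head_index(Q.shape[0], K.shape[0])
    K, V = K[idx], V[idx]                      # expand to one K/V per query head
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V

print(kv_head_index(8, 2))  # [0 0 0 0 1 1 1 1]  (GQA, group_size = 4)
print(kv_head_index(8, 8))  # [0 1 2 3 4 5 6 7]  (MHA)
print(kv_head_index(8, 1))  # [0 0 0 0 0 0 0 0]  (MQA)

# With a single K/V head, GQA matches MQA-style broadcasting of that head
rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 5, 4))
K1, V1 = rng.standard_normal((1, 5, 4)), rng.standard_normal((1, 5, 4))
mqa_ref = softmax(Q @ K1.transpose(0, 2, 1) / np.sqrt(4)) @ V1
assert np.allclose(grouped_query_attention(Q, K1, V1), mqa_ref)
```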

## Summary

You've learned the essential attention optimizations used in modern transformer systems!

### Key Techniques:

1. **KV Caching** - Cache key-value pairs during autoregressive generation for faster inference
2. **Sparse Attention** - Use structured attention patterns to reduce O(n²) complexity
3. **Multi-Query Attention (MQA)** - Share one key/value head across all query heads
4. **Grouped-Query Attention (GQA)** - Share each key/value head within a group of query heads, balancing efficiency and quality

### Real-World Impact:
- **KV Caching**: Essential for fast text generation in chatbots and language models
- **Sparse Attention**: Enables processing of longer sequences (Longformer, BigBird)
- **MQA/GQA**: Used in modern models for efficient inference, e.g. PaLM (MQA) and Llama 2 70B (GQA)

### Key Takeaways:
- Memory traffic is often a bigger bottleneck than FLOPs during generation, so memory optimizations can matter more than FLOP reduction
- Different sparsity patterns suit different tasks and sequence types
- Modern variants aim to preserve most of the model's quality while improving efficiency
- These optimizations are crucial for production transformer systems

These techniques bridge the gap between basic attention and production-ready transformers! 🚀