# Notebook 04: GPU-Accelerated Attention

## Scaling Attention with CUDA

Now that you understand attention conceptually, let's implement it efficiently on GPU! In this notebook:

1. **PyTorch GPU Implementation** - Converting CPU code to GPU
2. **Performance Benchmarking** - CPU vs GPU speedup
3. **Memory Optimization** - Managing large attention matrices
4. **Batched Processing** - Handling multiple sequences efficiently

This is where transformers shine: attention for every position of every sequence is computed with a few large matrix multiplications, which GPUs parallelize extremely well.

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
from typing import Tuple, Optional

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ CUDA not available. Using CPU (will be slow).")

## Part 1: GPU Attention Implementation

### PyTorch Implementation

In [None]:
def scaled_dot_product_attention_gpu(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    GPU-accelerated scaled dot-product attention.
    
    Args:
        Q: Queries (batch, n_heads, seq_len, d_k)
        K: Keys (batch, n_heads, seq_len, d_k)
        V: Values (batch, n_heads, seq_len, d_v)
        mask: Optional mask (batch, n_heads, seq_len, seq_len)
    
    Returns:
        output: Attention output (batch, n_heads, seq_len, d_v)
        attention_weights: (batch, n_heads, seq_len, seq_len)
    """
    d_k = Q.size(-1)
    
    # Compute attention scores: Q @ K^T / sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32, device=Q.device))
    
    # Apply mask if provided
    if mask is not None:
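        # Fill with the most negative finite value for the dtype (-1e9 is out of range for float16)
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)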
    
    # Apply softmax
    attention_weights = F.softmax(scores, dim=-1)
    
    # Weighted sum of values
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

# Test the implementation
batch_size = 4
n_heads = 8
seq_len = 16
d_k = 64

Q = torch.randn(batch_size, n_heads, seq_len, d_k, device=device)
K = torch.randn(batch_size, n_heads, seq_len, d_k, device=device)
V = torch.randn(batch_size, n_heads, seq_len, d_k, device=device)

output, attn_weights = scaled_dot_product_attention_gpu(Q, K, V)

print("✅ GPU Attention Implementation")
print(f"Input Q shape: {Q.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"\nAttention weights sum check: {attn_weights.sum(dim=-1)[0, 0, 0]:.6f} (should be 1.0)")
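To see what the `mask` argument does, here is a minimal, self-contained sketch of causal (decoder-style) masking. `causal_attention` is a name introduced for illustration; it repeats the scoring math rather than calling `scaled_dot_product_attention_gpu`, and it assumes the same convention as above: 1 means "may attend", 0 means "blocked".

```python
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    """Scaled dot-product attention with a causal (lower-triangular) mask."""
    seq_len, d_k = q.size(-2), q.size(-1)
    # 1 = may attend, 0 = blocked: the same convention as the `mask == 0` check above
    mask = torch.tril(torch.ones(seq_len, seq_len, device=q.device))
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5
    # The (seq_len, seq_len) mask broadcasts over the batch and head dimensions
    scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
q = torch.randn(2, 4, 6, 8)  # (batch, n_heads, seq_len, d_k)
k = torch.randn(2, 4, 6, 8)
v = torch.randn(2, 4, 6, 8)
out, w = causal_attention(q, k, v)
print(w[0, 0])  # upper triangle is zero: query i only sees keys 0..i
```

Passing the same `torch.tril(...)` matrix as `mask` to `scaled_dot_product_attention_gpu` should give matching weights.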

## Part 2: Performance Comparison - CPU vs GPU

Let's benchmark the performance difference!

In [None]:
def benchmark_attention(seq_lengths: list, d_k: int = 64, batch_size: int = 32, n_repeats: int = 10) -> dict:
    """Benchmark attention across sequence lengths; returns mean time per call in ms."""
    cpu_times = []
    gpu_times = []
    
    for seq_len in seq_lengths:
        print(f"Benchmarking seq_len={seq_len}...")
        
        # Create test data (the attention function is device-agnostic, so it also runs on CPU)
        Q_cpu = torch.randn(batch_size, 1, seq_len, d_k)
        K_cpu = torch.randn(batch_size, 1, seq_len, d_k)
        V_cpu = torch.randn(batch_size, 1, seq_len, d_k)
        
        # CPU benchmark: one warmup call, then average over n_repeats calls
        scaled_dot_product_attention_gpu(Q_cpu, K_cpu, V_cpu)
        start = time.perf_counter()
        for _ in range(n_repeats):
            scaled_dot_product_attention_gpu(Q_cpu, K_cpu, V_cpu)
        cpu_times.append((time.perf_counter() - start) * 1000 / n_repeats)
        
        # GPU benchmark
        if torch.cuda.is_available():
            Q_gpu = Q_cpu.to(device)
            K_gpu = K_cpu.to(device)
            V_gpu = V_cpu.to(device)
            
            # Warmup (CUDA initialization and kernel selection happen on the first calls)
            scaled_dot_product_attention_gpu(Q_gpu, K_gpu, V_gpu)
            torch.cuda.synchronize()
            
            # CUDA kernels run asynchronously: synchronize before reading the clock
            start = time.perf_counter()
            for _ in range(n_repeats):
                scaled_dot_product_attention_gpu(Q_gpu, K_gpu, V_gpu)
            torch.cuda.synchronize()
            gpu_times.append((time.perf_counter() - start) * 1000 / n_repeats)
        else:
            gpu_times.append(0)
    
    return {'cpu': cpu_times, 'gpu': gpu_times}

seq_lengths = [32, 64, 128, 256, 512]
results = benchmark_attention(seq_lengths)

# Display results
print("\n" + "="*70)
print("Attention Performance Benchmark")
print("="*70)
print(f"{'Seq Length':<12} {'CPU (ms)':<15} {'GPU (ms)':<15} {'Speedup':<12}")
print("="*70)

for seq_len, cpu_t, gpu_t in zip(seq_lengths, results['cpu'], results['gpu']):
    speedup = f"{cpu_t / gpu_t:.1f}x" if gpu_t > 0 else "n/a"
    print(f"{seq_len:<12} {cpu_t:<15.2f} {gpu_t:<15.4f} {speedup:<12}")
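Hand-rolled timing is easy to get subtly wrong (warmup, CUDA's asynchronous execution, thread counts). As an alternative sketch, `torch.utils.benchmark.Timer` performs warmup and CUDA synchronization for you. `attention` and `time_attention_ms` are hypothetical helpers written for this example; since `Timer` defaults to a single CPU thread, the sketch passes `num_threads=torch.get_num_threads()` to keep the CPU side comparable to the loop above.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

def attention(q, k, v):
    # Plain scaled dot-product attention, no mask
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    return F.softmax(scores, dim=-1) @ v

def time_attention_ms(seq_len, device, batch_size=32, d_k=64):
    q = torch.randn(batch_size, 1, seq_len, d_k, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    timer = benchmark.Timer(
        stmt="attention(q, k, v)",
        globals={"attention": attention, "q": q, "k": k, "v": v},
        num_threads=torch.get_num_threads(),
    )
    # Repeats the statement in blocks until at least min_run_time seconds have elapsed
    measurement = timer.blocked_autorange(min_run_time=0.2)
    return measurement.median * 1000  # seconds -> ms

cpu_ms = time_attention_ms(128, "cpu")
print(f"CPU, seq_len=128: {cpu_ms:.3f} ms")
if torch.cuda.is_available():
    print(f"GPU, seq_len=128: {time_attention_ms(128, 'cuda'):.3f} ms")
```

Reporting the median rather than a single run makes the numbers less sensitive to one-off stalls.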

In [None]:
# Visualize performance
if torch.cuda.is_available():
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot 1: Time comparison
    ax1.semilogy(seq_lengths, results['cpu'], 'o-', label='CPU', linewidth=2, markersize=8)
    ax1.semilogy(seq_lengths, results['gpu'], 's-', label='GPU', linewidth=2, markersize=8)
    ax1.set_xlabel('Sequence Length', fontsize=12)
    ax1.set_ylabel('Time (ms, log scale)', fontsize=12)
    ax1.set_title('Attention Performance: CPU vs GPU', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Speedup
    speedups = [cpu / gpu for cpu, gpu in zip(results['cpu'], results['gpu'])]
    ax2.plot(seq_lengths, speedups, 'o-', linewidth=2, markersize=8, color='green')
    ax2.set_xlabel('Sequence Length', fontsize=12)
    ax2.set_ylabel('Speedup (×)', fontsize=12)
    ax2.set_title('GPU Speedup over CPU', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)
    ax2.axhline(y=1, color='r', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nüìä Maximum speedup: {max(speedups):.1f}x at seq_len={seq_lengths[speedups.index(max(speedups))]}")

## Part 3: Memory Considerations

### The Memory Challenge

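Attention materializes an $n \times n$ score matrix for every head of every sequence, so its memory grows as $O(n^2)$ in the sequence length $n$.

**Example:** For seq_len=1024, batch=32, heads=8 in float32:
- Attention matrix: `32 × 8 × 1024 × 1024 × 4 bytes = 2^30 bytes ≈ 1.07 GB (1 GiB)`
- The implementation above holds both `scores` and `attention_weights` at the same time, so its forward pass needs roughly twice that at peak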
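The arithmetic is easy to check, and the same helper shows how the matrix scales. This is a pure-Python sketch; `attention_matrix_bytes` is a hypothetical helper, assuming 4 bytes per element for float32 and 2 for float16/bfloat16.

```python
def attention_matrix_bytes(batch, heads, seq_len, bytes_per_element=4):
    """Size of one (batch, heads, seq_len, seq_len) tensor, e.g. the scores."""
    return batch * heads * seq_len * seq_len * bytes_per_element

fp32_1k = attention_matrix_bytes(32, 8, 1024)     # float32
fp32_2k = attention_matrix_bytes(32, 8, 2048)     # double the sequence length
fp16_1k = attention_matrix_bytes(32, 8, 1024, 2)  # float16 / bfloat16

print(f"{fp32_1k:,} bytes = {fp32_1k / 2**30:.0f} GiB = {fp32_1k / 1e9:.2f} GB")
# 1,073,741,824 bytes = 1 GiB = 1.07 GB
print(fp32_2k // fp32_1k)  # 4: doubling n quadruples the matrix
print(fp32_1k // fp16_1k)  # 2: half precision halves it
```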

In [None]:
def estimate_attention_memory(batch_size: int, n_heads: int, seq_len: int, d_k: int) -> dict:
    """Estimate memory requirements for attention."""
    # All values in bytes (float32 = 4 bytes)
    bytes_per_element = 4
    
    # Q, K, V storage
    qkv_memory = 3 * batch_size * n_heads * seq_len * d_k * bytes_per_element
    
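    # Attention scores: one (seq_len x seq_len) matrix per head
    # (at peak the softmax weights need a second matrix of the same size, not counted here)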
    scores_memory = batch_size * n_heads * seq_len * seq_len * bytes_per_element
    
    # Output
    output_memory = batch_size * n_heads * seq_len * d_k * bytes_per_element
    
    total_memory = qkv_memory + scores_memory + output_memory
    
    return {
        'QKV (MB)': qkv_memory / 1e6,
        'Scores (MB)': scores_memory / 1e6,
        'Output (MB)': output_memory / 1e6,
        'Total (MB)': total_memory / 1e6,
        'Total (GB)': total_memory / 1e9
    }

# Test different configurations
configs = [
    (32, 8, 512, 64),
    (32, 8, 1024, 64),
    (32, 8, 2048, 64),
    (64, 16, 1024, 64)
]

print("Memory Requirements for Different Configurations:")
print("="*80)
print(f"{'Config':<25} {'QKV (MB)':<12} {'Scores (MB)':<15} {'Total (MB)':<12}")
print("="*80)

for batch, heads, seq, d_k in configs:
    mem = estimate_attention_memory(batch, heads, seq, d_k)
    config_str = f"B={batch}, H={heads}, L={seq}"
    print(f"{config_str:<25} {mem['QKV (MB)']:<12.1f} {mem['Scores (MB)']:<15.1f} {mem['Total (MB)']:<12.1f}")

print("\nüí° Key Insight: Attention scores dominate memory for large sequences!")

## Part 4: Multi-Head Attention

### Why Multiple Heads?

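Multi-head attention lets the model attend to different aspects of the input simultaneously. Each head runs attention in its own lower-dimensional subspace (`d_k = d_model / n_heads`), so heads can specialize, for example:
- One head might track syntactic relationships
- Another might capture semantic similarity
- Another might focus on long-range dependencies

These roles are not assigned in advance; whatever specialization appears is learned during training.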
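Mechanically, "multiple heads" is a reshape: each head gets its own `d_k`-wide slice of the (projected) feature vector. The sketch below uses small made-up sizes to show the split and merge that `MultiHeadAttention.forward` applies to the outputs of `W_Q`, `W_K`, `W_V`, and why `.contiguous()` is needed before the final `.view()`.

```python
import torch

batch, seq_len, d_model, n_heads = 2, 5, 16, 4
d_k = d_model // n_heads
x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, seq_len, n_heads, d_k) -> (batch, n_heads, seq_len, d_k)
heads = x.view(batch, seq_len, n_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 4])

# Head h works on features h*d_k .. (h+1)*d_k - 1 of every token
print(torch.equal(heads[:, 1], x[..., d_k:2 * d_k]))  # True

# Merge: transpose back; the result is non-contiguous, so .contiguous() before .view()
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
print(torch.equal(merged, x))  # True
```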

In [None]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections
        self.W_Q = torch.nn.Linear(d_model, d_model)
        self.W_K = torch.nn.Linear(d_model, d_model)
        self.W_V = torch.nn.Linear(d_model, d_model)
        self.W_O = torch.nn.Linear(d_model, d_model)
        
    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, 
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size = Q.size(0)
        
        # Linear projections and reshape to (batch, n_heads, seq_len, d_k)
        Q = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # Apply attention
        attn_output, _ = scaled_dot_product_attention_gpu(Q, K, V, mask)
        
        # Concatenate heads: (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # Final linear projection
        output = self.W_O(attn_output)
        
        return output

# Test multi-head attention
d_model = 512
n_heads = 8
batch_size = 16
seq_len = 32

mha = MultiHeadAttention(d_model, n_heads).to(device)
x = torch.randn(batch_size, seq_len, d_model, device=device)

output = mha(x, x, x)

print("✅ Multi-Head Attention")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in mha.parameters()):,}")
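The parameter count can be checked by hand: each of the four projections is an `nn.Linear(d_model, d_model)` with bias, so the total is $4(d_{model}^2 + d_{model}) = 4 \times (512^2 + 512) = 4 \times 262{,}656 = 1{,}050{,}624$. `n_heads` does not appear, because splitting into heads reshapes the projections without adding parameters. As a quick check (`mha_param_count` is a hypothetical helper):

```python
def mha_param_count(d_model: int) -> int:
    """Parameters of the MultiHeadAttention class above: four d_model x d_model Linear layers with bias."""
    per_projection = d_model * d_model + d_model  # weight + bias
    return 4 * per_projection                     # W_Q, W_K, W_V, W_O

print(f"{mha_param_count(512):,}")  # 1,050,624
```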

### Visualizing Attention Heads

In [None]:
# Get attention weights from individual heads
with torch.no_grad():
    x_small = torch.randn(1, 10, d_model, device=device)
    
    Q = mha.W_Q(x_small).view(1, 10, n_heads, mha.d_k).transpose(1, 2)
    K = mha.W_K(x_small).view(1, 10, n_heads, mha.d_k).transpose(1, 2)
    V = mha.W_V(x_small).view(1, 10, n_heads, mha.d_k).transpose(1, 2)
    
    _, attn_weights = scaled_dot_product_attention_gpu(Q, K, V)
    attn_weights = attn_weights.cpu().numpy()[0]  # (n_heads, seq_len, seq_len)

# Plot first 4 attention heads
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

for i in range(4):
    sns.heatmap(attn_weights[i], annot=True, fmt='.2f', cmap='viridis',
                ax=axes[i], cbar_kws={'label': 'Attention Weight'})
    axes[i].set_title(f'Attention Head {i+1}', fontsize=12, fontweight='bold')
    axes[i].set_xlabel('Key Position')
    axes[i].set_ylabel('Query Position')

plt.tight_layout()
plt.show()

print("\nüí° Notice: Different heads learn different attention patterns!")

## Part 5: Optimizations and Best Practices

### 1. Gradient Checkpointing (Trading Compute for Memory)

In [None]:
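# Compare peak memory with/without gradient checkpointing.
# With a single block the savings are modest, because its activations are recomputed
# during backward; they grow with the number of checkpointed layers.
from torch.utils.checkpoint import checkpoint

def measure_memory_usage(model, x, use_checkpoint=False):
    # Clear gradients from any previous run so both measurements start from the same state
    model.zero_grad(set_to_none=True)
    x.grad = None
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    if use_checkpoint:
        # Keep only the inputs; recompute the attention activations during backward
        output = checkpoint(model, x, x, x, use_reentrant=False)
    else:
        output = model(x, x, x)
    
    loss = output.sum()
    loss.backward()
    torch.cuda.synchronize()
    
    peak_memory = torch.cuda.max_memory_allocated() / 1e6
    return peak_memory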

if torch.cuda.is_available():
    mha = MultiHeadAttention(512, 8).to(device)
    x = torch.randn(8, 64, 512, device=device, requires_grad=True)
    
    mem_normal = measure_memory_usage(mha, x, use_checkpoint=False)
    mem_checkpoint = measure_memory_usage(mha, x, use_checkpoint=True)
    
    print(f"Memory Usage Comparison:")
    print(f"  Normal: {mem_normal:.1f} MB")
    print(f"  With Checkpointing: {mem_checkpoint:.1f} MB")
    print(f"  Savings: {(1 - mem_checkpoint/mem_normal) * 100:.1f}%")
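Checkpointing pays off most when applied layer by layer across a deep stack, since only each layer's input is kept between layers. The sketch below runs on CPU and only checks that checkpointing leaves the gradients unchanged; `blocks` is a hypothetical stack of feed-forward layers standing in for transformer layers, and `use_reentrant=False` is the variant current PyTorch documentation recommends.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Hypothetical stand-ins for transformer layers: small feed-forward blocks
blocks = torch.nn.ModuleList([
    torch.nn.Sequential(torch.nn.Linear(64, 256), torch.nn.GELU(), torch.nn.Linear(256, 64))
    for _ in range(6)
])
x_in = torch.randn(8, 32, 64, requires_grad=True)

def run(x, use_checkpoint):
    h = x
    for block in blocks:
        if use_checkpoint:
            # Save only this block's input; its activations are recomputed in backward
            h = checkpoint(block, h, use_reentrant=False)
        else:
            h = block(h)
    return h

input_grads = []
for use_checkpoint in (False, True):
    blocks.zero_grad(set_to_none=True)
    x_in.grad = None
    run(x_in, use_checkpoint).sum().backward()
    input_grads.append(x_in.grad.clone())

# Checkpointing changes memory use and compute time, not the gradients
print(torch.allclose(input_grads[0], input_grads[1], atol=1e-6))
```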

### 2. Flash Attention (Modern Optimization)

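PyTorch 2.0+ provides `F.scaled_dot_product_attention`, a fused attention function that can dispatch to optimized backends such as FlashAttention or memory-efficient attention, falling back to a plain math implementation when the hardware, dtype, or inputs are not supported: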

In [None]:
# Use PyTorch's optimized scaled_dot_product_attention (if available)
if hasattr(F, 'scaled_dot_product_attention'):
    def fast_attention(Q, K, V, mask=None):
        """Use PyTorch's optimized implementation."""
        # Our masks use 1 = attend, 0 = masked. A boolean attn_mask uses True = attend,
        # whereas a float attn_mask would be *added* to the scores, so convert to bool first.
        attn_mask = (mask != 0) if mask is not None else None
        return F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)
    
    print("✅ PyTorch optimized attention available!")
    print("   On supported GPUs and dtypes it can use FlashAttention kernels")
else:
    print("⚠️ Using manual implementation (PyTorch < 2.0)")
    print("   Consider upgrading for better performance")
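A quick sanity check that the fused function computes the same thing as a manual implementation, assuming PyTorch 2.0+. `manual_attention` and its `keep` argument are introduced here for illustration. A boolean `attn_mask` uses True for positions that may be attended to, and `is_causal=True` lets PyTorch apply the causal mask itself.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, keep=None):
    # keep: boolean mask, True = may attend (same convention as a boolean attn_mask)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if keep is not None:
        scores = scores.masked_fill(~keep, float("-inf"))
    return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
causal = torch.ones(16, 16, dtype=torch.bool).tril()

ref = manual_attention(q, k, v, causal)
out_is_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_bool_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
print(torch.allclose(out_is_causal, ref, atol=1e-5), torch.allclose(out_bool_mask, ref, atol=1e-5))
```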

## Exercise Section

### Exercise 1: Attention Dropout
Add dropout to attention weights:
- Apply after softmax
- Compare training stability with/without

In [None]:
# TODO: Implement attention with dropout

### Exercise 2: Cross-Attention GPU
Implement cross-attention on GPU:
- Different Q vs K, V sequences
- Benchmark performance

In [None]:
# TODO: Implement cross-attention

### Exercise 3: Memory Profiling
Profile memory usage for different configurations:
- Vary batch size, sequence length, number of heads
- Find maximum feasible configuration for your GPU

In [None]:
# TODO: Profile memory usage

## Summary

### Key Takeaways

✅ **GPU Acceleration:**
- The same PyTorch code runs on the GPU once the tensors live on a CUDA device
- Speedups typically grow with sequence length and batch size; the exact factor depends on your hardware, so measure it with the benchmark in Part 2
- Critical for training large transformers

✅ **Memory Management:**
- Attention has $O(n^2)$ memory complexity
- Gradient checkpointing trades compute for memory
- Monitor GPU memory usage carefully

✅ **Multi-Head Attention:**
- Parallel attention over different representations
- Heads can learn different patterns
- Concatenate and project back to d_model

✅ **Optimization:**
- Use PyTorch's optimized implementations when available
- Flash Attention avoids materializing the full attention matrix, saving memory and usually time
- Batch operations for efficiency

### Next Steps

In **Notebook 05**, we'll build:
- Complete transformer encoder block
- Feed-forward networks
- Layer normalization and residual connections
- Full encoder stack

## Further reading (Archive.org)

To connect attention mechanisms with GPU implementation details, search Archive.org for:

- "GPU deep learning"
- "high performance deep learning"
- "CUDA deep learning kernels"

Look for discussions of memory access patterns, kernel fusion, and batching strategies, which will help you reason about how an attention kernel can be optimized beyond the straightforward PyTorch implementation used here.