# Module 1.6: KV Cache & Attention Optimization

**Goal**: Understand how KV caching and attention variants speed up autoregressive generation

**Time**: 50 minutes

**Concepts Covered**:
- Naive generation (O(n²) problem)
- KV cache implementation
- Grouped Query Attention (GQA)
- Multi-Head Latent Attention (MLA)
- Benchmark: MHA vs GQA vs MLA
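Much of the MHA vs GQA vs MLA comparison listed above comes down to KV-cache memory, so a back-of-the-envelope calculation is useful before any timing. This is a sketch under stated assumptions: 16-bit elements; MHA caches one key and one value vector per head; GQA caches them only for `n_kv_heads` shared heads; MLA caches one compressed latent of width `d_latent` plus a small decoupled RoPE key of width `d_rope` per layer. All helper names and model dimensions are hypothetical and chosen only for illustration.

```python
def kv_cache_bytes_per_token(n_layers, floats_per_layer, bytes_per_elem=2):
    """Cache size for one token across all layers (bytes_per_elem=2 assumes fp16/bf16)."""
    return n_layers * floats_per_layer * bytes_per_elem


def mha_floats(n_heads, d_head):
    # One key vector and one value vector for every attention head
    return 2 * n_heads * d_head


def gqa_floats(n_kv_heads, d_head):
    # Keys/values exist only for the shared KV heads; query heads in a group reuse them
    return 2 * n_kv_heads * d_head


def mla_floats(d_latent, d_rope):
    # One compressed latent (keys and values are re-projected from it)
    # plus a small decoupled positional (RoPE) key
    return d_latent + d_rope


# Hypothetical model dimensions, for illustration only
n_layers, n_heads, d_head = 32, 32, 128
configs = {
    "MHA": mha_floats(n_heads, d_head),
    "GQA (8 KV heads)": gqa_floats(8, d_head),
    "MLA (latent 512, rope 64)": mla_floats(512, 64),
}

seq_len = 4096
for name, floats in configs.items():
    per_token = kv_cache_bytes_per_token(n_layers, floats)
    total_mib = per_token * seq_len / 2**20
    print(f"{name:28s} {per_token:>8,d} B/token  {total_mib:8.1f} MiB for {seq_len} tokens")
```

With these assumed numbers, a 4096-token cache takes 2048 MiB under MHA, 512 MiB under GQA, and 144 MiB under MLA. Real models differ in layer count, head sizes, and MLA dimensions, so treat this only as a way to see how each design shrinks the per-token state.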

## Setup

In [None]:
!pip install torch numpy matplotlib seaborn transformers -q

In [None]:
import time

import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.softmax in CachedAttention

torch.manual_seed(42)

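## Naive Generation (O(n²))

Without a cache, every decoding step feeds the entire sequence back through the model. Starting from a one-token prompt, generating n tokens pushes 1 + 2 + ... + n = n(n+1)/2 = O(n²) token positions through the projections and MLPs. Because step t also builds a t × t attention score matrix, the attention work alone grows as O(n³).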

In [None]:
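# Naive autoregressive generation (no cache).
# Assumes `model` maps token ids (batch, seq) to logits (batch, seq, vocab).
@torch.no_grad()
def naive_generate(model, prompt, max_len=10):
    sequence = prompt
    for _ in range(max_len):
        # Recompute the forward pass (including attention) over the entire sequence
        logits = model(sequence)
        # Greedy decoding: pick the highest-scoring token at the last position
        next_token = logits[:, -1:].argmax(dim=-1)  # (batch, 1)
        sequence = torch.cat([sequence, next_token], dim=1)
    return sequence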

print("Naive generation recomputes all previous tokens each step!")
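To see the recomputation cost concretely, the sketch below runs the same greedy loop against a hypothetical `CountingToyLM` (an embedding plus a linear head, with no attention) that only counts how many token positions each forward pass processes. The toy model, its sizes, and the prompt are assumptions for illustration.

```python
import torch


class CountingToyLM(torch.nn.Module):
    """Hypothetical toy language model used only to count processed positions."""

    def __init__(self, vocab_size=50, d_model=16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, d_model)
        self.head = torch.nn.Linear(d_model, vocab_size)
        self.positions_processed = 0

    def forward(self, tokens):
        # tokens: (batch, seq) -> logits: (batch, seq, vocab)
        self.positions_processed += tokens.shape[1]
        return self.head(self.embed(tokens))


@torch.no_grad()
def naive_generate(model, prompt, max_len=10):
    # Same loop as the naive version: the whole sequence is fed back in every step
    sequence = prompt
    for _ in range(max_len):
        logits = model(sequence)
        next_token = logits[:, -1:].argmax(dim=-1)  # (batch, 1)
        sequence = torch.cat([sequence, next_token], dim=1)
    return sequence


torch.manual_seed(0)
model = CountingToyLM()
prompt = torch.randint(0, 50, (1, 4))
out = naive_generate(model, prompt, max_len=10)
print(out.shape)                  # torch.Size([1, 14])
print(model.positions_processed)  # 4 + 5 + ... + 13 = 85
```

With a 4-token prompt and 10 steps, the loop processes 85 positions, whereas a cached decoder would process the 4 prompt tokens once and then a single new token per step (4 + 9 = 13). In a real transformer each naive step also rebuilds the full attention score matrix, so the gap is even larger.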

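## KV Cache Implementation

During decoding, the keys and values computed for earlier tokens never change, so they can be computed once and stored. Each later step then projects only the new token and attends over the cached K and V instead of reprocessing the whole sequence.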

In [None]:
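class CachedAttention(nn.Module):
    """Multi-head self-attention that reuses cached keys/values across decoding steps."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, cache=None):
        # x: (batch, seq, d_model); when a cache is passed, x holds only the NEW tokens
        batch, seq, _ = x.shape
        # Project and split into heads: (batch, n_heads, seq, d_k)
        Q = self.W_q(x).view(batch, seq, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch, seq, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch, seq, self.n_heads, self.d_k).transpose(1, 2)

        if cache is not None:
            # Prepend cached keys/values along the sequence axis
            K = torch.cat([cache['k'], K], dim=2)
            V = torch.cat([cache['v'], V], dim=2)

        # Scaled dot-product scores: (batch, n_heads, seq, total_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Causal mask: new query i sits at absolute position past_len + i and may
        # only attend to keys at positions <= past_len + i (matters during prefill)
        total_len = K.size(2)
        past_len = total_len - seq
        q_pos = torch.arange(past_len, total_len, device=x.device).unsqueeze(1)
        k_pos = torch.arange(total_len, device=x.device).unsqueeze(0)
        scores = scores.masked_fill(k_pos > q_pos, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        # The updated cache holds K/V for every token seen so far
        new_cache = {'k': K, 'v': V}

        # Merge heads back: (batch, seq, d_model)
        out = out.transpose(1, 2).contiguous().view(batch, seq, self.d_model)
        return self.W_o(out), new_cache

print("KV cache stores K and V from previous tokens, so each step only projects the new token!")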
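A quick way to check that caching does not change the result is to compare full causal attention with token-by-token decoding from a cache. This is a simplified sketch assuming a single head, batch size 1, and already-projected `q`, `k`, `v` tensors; the helper names `causal_attention` and `decode_with_cache` are hypothetical.

```python
import torch
import torch.nn.functional as F


def causal_attention(q, k, v):
    """Full causal attention over a whole sequence. Shapes: (seq, d)."""
    seq = q.shape[0]
    scores = q @ k.T / q.shape[-1] ** 0.5
    # Block attention to future positions (strictly above the diagonal)
    mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


def decode_with_cache(q, k, v):
    """Process one position at a time, appending each new key/value to a cache."""
    k_cache, v_cache, outputs = [], [], []
    for t in range(q.shape[0]):
        # Only the new token's key/value are added; earlier entries are reused
        k_cache.append(k[t])
        v_cache.append(v[t])
        K = torch.stack(k_cache)  # (t + 1, d)
        V = torch.stack(v_cache)
        # The single new query attends to every cached key; no mask is needed
        # because the cache holds only past and current positions
        scores = q[t] @ K.T / q.shape[-1] ** 0.5  # (t + 1,)
        outputs.append(F.softmax(scores, dim=-1) @ V)
    return torch.stack(outputs)


torch.manual_seed(0)
seq, d = 6, 8
q, k, v = (torch.randn(seq, d) for _ in range(3))
full = causal_attention(q, k, v)
cached = decode_with_cache(q, k, v)
print(torch.allclose(full, cached, atol=1e-5))  # True
```

Each decoding step appends exactly one key and one value, so cache memory grows linearly with the number of tokens, while per-step attention cost is proportional to the current cache length rather than to a full recomputation.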

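## Key Takeaways

- Naive generation reprocesses the whole prefix at every step, so the number of token positions processed grows quadratically with the number of generated tokens (and attention-score work grows cubically).
- A KV cache stores each token's keys and values once; every later decoding step projects only the new token and attends over the cache.
- The cache itself grows linearly with sequence length, which is the memory cost that GQA (fewer key/value heads) and MLA (a compressed latent cache) aim to reduce.

✅ **Module Complete**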

## Next Steps

Continue to the next module in the course.