## Problem: Implement KV Cache for Efficient Autoregressive Generation

### Background

In autoregressive models (like GPT), we generate tokens one at a time. At each step, we compute attention over **all previous tokens**. Without optimization, this means:
- At step 1: Compute K, V for token 0
- At step 2: Compute K, V for tokens 0, 1 (recomputing token 0!)
- At step 3: Compute K, V for tokens 0, 1, 2 (recomputing tokens 0, 1!)

**KV Cache** solves this by storing previously computed K and V tensors:
- At step 1: Compute K₀, V₀, cache them
- At step 2: Compute K₁, V₁, append them to the cached [K₀], [V₀]
- At step 3: Compute K₂, V₂, append them to the cached [K₀, K₁], [V₀, V₁]

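This reduces the number of K/V projections needed to generate n tokens from O(n²) to O(n). Each step still attends over every cached key, so the attention computation itself remains O(t) at step t.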
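
To make the saving concrete, here is a small sketch that counts per-token K/V projections under both strategies. The function names are illustrative only, and the count covers projections alone, not the attention scores.

```python
def kv_projections_without_cache(n: int) -> int:
    """Step t (1-indexed) re-projects all t tokens seen so far."""
    return sum(range(1, n + 1))  # equals n * (n + 1) / 2


def kv_projections_with_cache(n: int) -> int:
    """Each step projects only the newly generated token."""
    return n


n = 100
print(kv_projections_without_cache(n))  # 5050
print(kv_projections_with_cache(n))     # 100
```

For n = 100 the ratio is 50.5, which is where the "about n/2" rule of thumb comes from.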

### Mathematical Formulation

Without cache:
```
Q_t = W_q @ X[0:t+1]  # Query for all tokens up to t
K_t = W_k @ X[0:t+1]  # Recompute keys for all tokens
V_t = W_v @ X[0:t+1]  # Recompute values for all tokens
```

With cache:
```
Q_t = W_q @ X[t]           # Query only for new token
K_new = W_k @ X[t]         # Key only for new token
V_new = W_v @ X[t]         # Value only for new token
K_cached = concat(K_cache, K_new)  # Append to cache
V_cached = concat(V_cache, V_new)  # Append to cache
```
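
The reason caching is safe is that each token's key depends only on that token's own input row. The sketch below checks this numerically under a few assumptions: the sizes are arbitrary, `W_k` is a random stand-in matrix, and tokens are stored as rows, so the projection is written `X @ W_k` rather than `W_k @ X`.

```python
import torch

torch.manual_seed(0)
d_model, t = 8, 4                     # illustrative sizes
X = torch.randn(t + 1, d_model)       # tokens 0..t, one row per token
W_k = torch.randn(d_model, d_model)   # hypothetical key projection

# Without cache: project every token again at step t.
K_full = X @ W_k                      # [t+1, d_model]

# With cache: reuse keys for tokens 0..t-1 and project only token t.
K_cache = X[:t] @ W_k                 # produced during earlier steps
K_new = X[t:t + 1] @ W_k              # [1, d_model]
K_cached = torch.cat([K_cache, K_new], dim=0)

# The two results agree up to floating-point rounding.
print(torch.allclose(K_full, K_cached, atol=1e-5))
```

The same argument applies to the value projection.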

### Learning Objectives

1. Implement attention with KV caching
2. Understand cache management (initialization, updates)
3. Measure performance improvements
4. Handle edge cases (first token, cache limits)

### References
- [Attention is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762)
- [FlashAttention paper](https://arxiv.org/abs/2205.14135) - discusses memory optimization

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from typing import Optional, Tuple

### Step 1: Baseline Attention (No Cache)

First, let's implement standard scaled dot-product attention without any caching.

In [None]:
def attention_no_cache(q, k, v, mask=None):
    """
    Standard scaled dot-product attention without caching.
    
    Args:
        q: Query [batch, seq_len_q, d_k]
        k: Key [batch, seq_len_k, d_k]
        v: Value [batch, seq_len_k, d_v]
        mask: Optional attention mask
    
    Returns:
        output: [batch, seq_len_q, d_v]
    """
    d_k = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, v)
    
    return output
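
The `mask` argument follows the convention that a zero entry blocks attention. As a rough illustration, the sketch below applies a lower-triangular (causal) mask and checks two properties it should guarantee. `sdpa` is a compact stand-in for `attention_no_cache` so the example runs on its own; the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F


def sdpa(q, k, v, mask=None):
    """Compact stand-in for attention_no_cache (same math)."""
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
batch, seq_len, d_k = 2, 5, 16
q = torch.randn(batch, seq_len, d_k)
k = torch.randn(batch, seq_len, d_k)
v = torch.randn(batch, seq_len, d_k)

# Lower-triangular mask: row i has ones in columns 0..i and zeros after.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
out = sdpa(q, k, v, mask=causal_mask)

# Position 0 can only attend to itself, so its output is v at position 0.
print(torch.allclose(out[:, 0], v[:, 0]))

# Changing the last token must not affect any earlier position.
k2, v2 = k.clone(), v.clone()
k2[:, -1] += 1.0
v2[:, -1] += 1.0
out2 = sdpa(q, k2, v2, mask=causal_mask)
print(torch.allclose(out[:, :-1], out2[:, :-1]))
```

One caveat: a row whose entries are all masked produces NaN after the softmax, so every query needs at least one visible key.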

### Step 2: Attention with KV Cache

Now implement attention that maintains a cache of Key and Value tensors.

In [None]:
class KVCache:
    """
    Cache for storing Key and Value tensors during autoregressive generation.
    """
    def __init__(self):
        self.k_cache = None  # [batch, seq_len, d_k]
        self.v_cache = None  # [batch, seq_len, d_v]
    
    def update(self, k_new, v_new):
        """
        Add new key and value tensors to the cache.
        
        Args:
            k_new: New keys [batch, 1, d_k]
            v_new: New values [batch, 1, d_v]
        
        Returns:
            k_cached: All keys including new [batch, seq_len, d_k]
            v_cached: All values including new [batch, seq_len, d_v]
        """
        if self.k_cache is None:
            # First token - initialize cache
            self.k_cache = k_new
            self.v_cache = v_new
        else:
            # Subsequent tokens - concatenate with cache
            self.k_cache = torch.cat([self.k_cache, k_new], dim=1)
            self.v_cache = torch.cat([self.v_cache, v_new], dim=1)
        
        return self.k_cache, self.v_cache
    
    def clear(self):
        """Reset the cache."""
        self.k_cache = None
        self.v_cache = None


def attention_with_cache(
    q, k_new, v_new, 
    cache: Optional[KVCache] = None,
    mask=None
):
    """
    Scaled dot-product attention with KV caching.
    
    Args:
        q: Query for current token [batch, 1, d_k]
        k_new: Key for current token [batch, 1, d_k]
        v_new: Value for current token [batch, 1, d_v]
        cache: KVCache object to store/retrieve cached K,V
        mask: Optional attention mask
    
    Returns:
        output: Attention output [batch, 1, d_v]
    """
    if cache is None:
        # No cache provided - just do normal attention
        return attention_no_cache(q, k_new, v_new, mask)
    
    # Update cache with new K, V
    k_cached, v_cached = cache.update(k_new, v_new)
    
    # Compute attention using all cached keys and values
    d_k = q.shape[-1]
    scores = torch.matmul(q, k_cached.transpose(-2, -1)) / (d_k ** 0.5)
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, v_cached)
    
    return output
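
The `KVCache` above grows without bound. One possible way to handle a cache limit is to keep only the most recent positions. `BoundedKVCache` and `max_len` below are hypothetical names for this sketch, and dropping old entries changes the attention output compared with full attention, so treat it as one policy among several rather than a drop-in replacement.

```python
import torch


class BoundedKVCache:
    """Hypothetical cache that keeps only the most recent `max_len` positions."""

    def __init__(self, max_len: int):
        self.max_len = max_len
        self.k_cache = None  # [batch, <= max_len, d_k]
        self.v_cache = None  # [batch, <= max_len, d_v]

    def update(self, k_new, v_new):
        if self.k_cache is None:
            k, v = k_new, v_new
        else:
            k = torch.cat([self.k_cache, k_new], dim=1)
            v = torch.cat([self.v_cache, v_new], dim=1)
        # Drop the oldest entries once the limit is exceeded.
        self.k_cache = k[:, -self.max_len:, :]
        self.v_cache = v[:, -self.max_len:, :]
        return self.k_cache, self.v_cache


cache = BoundedKVCache(max_len=3)
for t in range(5):
    # Fill each step's K and V with the step index so eviction is visible.
    k_all, v_all = cache.update(torch.full((1, 1, 4), float(t)),
                                torch.full((1, 1, 4), float(t)))

print(k_all.shape)     # torch.Size([1, 3, 4])
print(k_all[0, :, 0])  # tensor([2., 3., 4.])
```

After five updates only steps 2, 3, and 4 remain in the cache.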

### Step 3: Test Correctness

Verify that cached attention produces the same results as non-cached.

In [None]:
# Setup test parameters
torch.manual_seed(42)
batch_size = 2
d_model = 64
seq_len = 5

# Simulate autoregressive generation
# We'll generate seq_len tokens one at a time
X = torch.randn(batch_size, seq_len, d_model)

# Projection matrices (shared for both methods)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

print("Testing correctness of KV cache implementation...")
print("=" * 60)

# Method 1: Without cache (recompute everything each step)
outputs_no_cache = []
for t in range(seq_len):
    # At step t, we attend over tokens 0...t
    x_current = X[:, :t+1, :]  # [batch, t+1, d_model]
    
    q = W_q(x_current[:, -1:, :])  # Query for last token only
    k = W_k(x_current)              # Keys for all tokens 0...t
    v = W_v(x_current)              # Values for all tokens 0...t
    
    output = attention_no_cache(q, k, v)
    outputs_no_cache.append(output)

# Method 2: With cache (only compute new K, V each step)
cache = KVCache()
outputs_with_cache = []
for t in range(seq_len):
    x_token = X[:, t:t+1, :]  # Current token [batch, 1, d_model]
    
    q = W_q(x_token)  # Query for current token
    k = W_k(x_token)  # Key for current token ONLY
    v = W_v(x_token)  # Value for current token ONLY
    
    output = attention_with_cache(q, k, v, cache=cache)
    outputs_with_cache.append(output)

# Compare outputs
print("\nComparing outputs at each timestep:\n")
all_match = True
for t in range(seq_len):
    match = torch.allclose(outputs_no_cache[t], outputs_with_cache[t], atol=1e-6, rtol=1e-5)
    status = "✓" if match else "✗"
    print(f"  Step {t}: {status} {'Match' if match else 'Mismatch'}")
    if not match:
        all_match = False
        print(f"    Max diff: {(outputs_no_cache[t] - outputs_with_cache[t]).abs().max():.2e}")

print()
if all_match:
    print("✓ All outputs match! KV cache is working correctly.")
else:
    print("✗ Outputs don't match. Check your implementation.")

### Step 4: Measure Performance Improvement

Compare the computational cost of cached vs non-cached attention.

In [None]:
# Benchmark with longer sequences
torch.manual_seed(42)
batch_size = 4
d_model = 512
seq_len = 100  # Generate 100 tokens

X = torch.randn(batch_size, seq_len, d_model)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

print("Performance Comparison")
print("=" * 60)
print(f"Batch size: {batch_size}")
print(f"Model dimension: {d_model}")
print(f"Sequence length: {seq_len} tokens\n")

# Benchmark without cache
# time.perf_counter() is a high-resolution clock suited to timing;
# torch.no_grad() skips autograd bookkeeping, as in real inference.
with torch.no_grad():
    start = time.perf_counter()
    for t in range(seq_len):
        x_current = X[:, :t+1, :]
        q = W_q(x_current[:, -1:, :])
        k = W_k(x_current)
        v = W_v(x_current)
        _ = attention_no_cache(q, k, v)
    time_no_cache = time.perf_counter() - start

# Benchmark with cache
cache = KVCache()
with torch.no_grad():
    start = time.perf_counter()
    for t in range(seq_len):
        x_token = X[:, t:t+1, :]
        q = W_q(x_token)
        k = W_k(x_token)
        v = W_v(x_token)
        _ = attention_with_cache(q, k, v, cache=cache)
    time_with_cache = time.perf_counter() - start

speedup = time_no_cache / time_with_cache

print(f"Time without cache: {time_no_cache:.4f}s")
print(f"Time with cache:    {time_with_cache:.4f}s")
print(f"\nSpeedup: {speedup:.2f}x faster")
print(f"Time saved: {(time_no_cache - time_with_cache):.4f}s ({(1 - time_with_cache/time_no_cache)*100:.1f}% reduction)")

# Theoretical analysis
print("\n" + "=" * 60)
print("Theoretical Complexity Analysis")
print("=" * 60)
print(f"Without cache: O(n²) K/V projections")
print(f"  - At each step t, compute K,V for all t tokens")
print(f"  - Total: 1 + 2 + 3 + ... + {seq_len} = {seq_len*(seq_len+1)//2} computations")
print(f"\nWith cache: O(n) K/V projections")
print(f"  - At each step, compute K,V for 1 new token only")
print(f"  - Total: {seq_len} computations")
print(f"\nTheoretical reduction in K/V projections: ~{(seq_len + 1) / 2:.1f}x for this sequence length")
print(f"  - Attention over the cached keys still costs O(t) at step t, so the measured speedup can be smaller")
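
One cost the analysis above does not count: `torch.cat` allocates a new tensor and copies the whole cache at every step, so the copying alone adds up to O(n²) over n steps. A common alternative is to preallocate buffers and write each new token in place. The sketch below assumes the maximum length is known up front and that it runs under inference (no gradients); `PreallocatedKVCache` is a hypothetical name.

```python
import torch


class PreallocatedKVCache:
    """Hypothetical cache that writes into fixed buffers instead of using torch.cat."""

    def __init__(self, batch: int, max_len: int, d_k: int, d_v: int):
        self.k_buf = torch.zeros(batch, max_len, d_k)
        self.v_buf = torch.zeros(batch, max_len, d_v)
        self.length = 0  # number of valid positions written so far

    def update(self, k_new, v_new):
        n = k_new.shape[1]
        if self.length + n > self.k_buf.shape[1]:
            raise ValueError("cache is full")
        # Write the new entries into the next free slots.
        self.k_buf[:, self.length:self.length + n] = k_new
        self.v_buf[:, self.length:self.length + n] = v_new
        self.length += n
        # Return views of the valid prefix; slicing does not copy.
        return self.k_buf[:, :self.length], self.v_buf[:, :self.length]


cache = PreallocatedKVCache(batch=2, max_len=8, d_k=4, d_v=4)
with torch.no_grad():
    for t in range(3):
        k_all, v_all = cache.update(torch.randn(2, 1, 4), torch.randn(2, 1, 4))

print(k_all.shape)  # torch.Size([2, 3, 4])
```

The trade-off is that memory for `max_len` positions is reserved from the start, whether or not generation reaches that length.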

### Step 5: Multi-Head Attention with KV Cache

Extend KV cache to work with multi-head attention.

In [None]:
class MultiHeadKVCache:
    """
    KV Cache for multi-head attention.
    Stores K and V for all heads in one tensor each, using a dedicated head dimension.
    """
    def __init__(self, num_heads):
        self.num_heads = num_heads
        self.k_cache = None  # [batch, num_heads, seq_len, d_k]
        self.v_cache = None  # [batch, num_heads, seq_len, d_v]
    
    def update(self, k_new, v_new):
        """
        Update cache with new multi-head K, V tensors.
        
        Args:
            k_new: [batch, num_heads, 1, d_k]
            v_new: [batch, num_heads, 1, d_v]
        """
        if self.k_cache is None:
            self.k_cache = k_new
            self.v_cache = v_new
        else:
            self.k_cache = torch.cat([self.k_cache, k_new], dim=2)  # Concat on seq_len dim
            self.v_cache = torch.cat([self.v_cache, v_new], dim=2)
        
        return self.k_cache, self.v_cache
    
    def clear(self):
        self.k_cache = None
        self.v_cache = None


def multi_head_attention_with_cache(
    q, k_new, v_new,
    num_heads,
    cache: Optional[MultiHeadKVCache] = None
):
    """
    Multi-head attention with KV caching.
    
    Args:
        q: Query [batch, 1, d_model]
        k_new: New key [batch, 1, d_model]
        v_new: New value [batch, 1, d_model]
        num_heads: Number of attention heads
        cache: MultiHeadKVCache object
    """
    batch_size = q.shape[0]
    d_model = q.shape[-1]
    d_k = d_model // num_heads
    
    # Reshape to multi-head: [batch, num_heads, seq_len, d_k]
    q = q.view(batch_size, -1, num_heads, d_k).transpose(1, 2)  # [batch, heads, 1, d_k]
    k_new = k_new.view(batch_size, -1, num_heads, d_k).transpose(1, 2)
    v_new = v_new.view(batch_size, -1, num_heads, d_k).transpose(1, 2)
    
    # Update cache
    if cache is not None:
        k_cached, v_cached = cache.update(k_new, v_new)
    else:
        k_cached, v_cached = k_new, v_new
    
    # Scaled dot-product attention
    scores = torch.matmul(q, k_cached.transpose(-2, -1)) / (d_k ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, v_cached)  # [batch, heads, 1, d_k]
    
    # Concatenate heads
    output = output.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
    
    return output


# Quick test
print("Testing Multi-Head Attention with KV Cache")
print("=" * 60)

batch_size = 2
seq_len = 5
d_model = 64
num_heads = 8

X = torch.randn(batch_size, seq_len, d_model)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

cache = MultiHeadKVCache(num_heads)
for t in range(seq_len):
    x_token = X[:, t:t+1, :]
    q = W_q(x_token)
    k = W_k(x_token)
    v = W_v(x_token)
    output = multi_head_attention_with_cache(q, k, v, num_heads, cache=cache)
    print(f"Step {t}: Output shape = {output.shape}")

print(f"\nFinal cache size: K={cache.k_cache.shape}, V={cache.v_cache.shape}")
print("✓ Multi-head KV cache ran with the expected output shapes.")
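
The quick test above checks shapes only. A stronger check, sketched below, compares step-by-step cached attention against full causal attention computed in a single pass. It uses random stand-ins for the already-projected Q, K, and V and re-implements the head split inline so it runs on its own; `split_heads` is an illustrative helper name.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_model, num_heads = 2, 5, 64, 8
d_k = d_model // num_heads


def split_heads(x):
    # [batch, seq, d_model] -> [batch, heads, seq, d_k]
    return x.view(batch, -1, num_heads, d_k).transpose(1, 2)


# Random stand-ins for projected queries, keys, and values.
q = split_heads(torch.randn(batch, seq_len, d_model))
k = split_heads(torch.randn(batch, seq_len, d_model))
v = split_heads(torch.randn(batch, seq_len, d_model))

# Reference: full causal attention in one pass.
scores = q @ k.transpose(-2, -1) / d_k ** 0.5
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
reference = F.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ v

# Incremental: grow the cache along the sequence dimension (dim=2).
k_cache = v_cache = None
steps = []
for t in range(seq_len):
    k_t, v_t = k[:, :, t:t + 1], v[:, :, t:t + 1]
    k_cache = k_t if k_cache is None else torch.cat([k_cache, k_t], dim=2)
    v_cache = v_t if v_cache is None else torch.cat([v_cache, v_t], dim=2)
    s = q[:, :, t:t + 1] @ k_cache.transpose(-2, -1) / d_k ** 0.5
    steps.append(F.softmax(s, dim=-1) @ v_cache)
incremental = torch.cat(steps, dim=2)

# The two paths agree up to floating-point rounding.
print(torch.allclose(reference, incremental, atol=1e-5))
```

No mask is needed in the incremental path because the cache at step t holds only positions 0..t.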

## Summary

### What We Learned

1. **KV Cache Concept**: Store previously computed K and V tensors to avoid recomputation
2. **Implementation**: Simple concatenation with cached tensors
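3. **Performance**: For 100 tokens, caching cuts K/V projections from 5,050 to 100 (about 50x fewer, O(n²) → O(n)); the measured wall-clock speedup depends on hardware and can be smaller, since attention over the cache and per-step overhead remain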
4. **Multi-Head Extension**: Cache works independently for each attention head

### Key Insights

- **Memory vs Speed Trade-off**: The cache uses O(n) memory per layer and removes the O(n²) cost of re-projecting K and V
- **Critical for Inference**: Essential for fast autoregressive generation in production
- **Used Everywhere**: KV caching is standard in decoder-only LLM inference, including open implementations of GPT-2 and LLaMA

### Interview Tips

Common questions:
- Why is KV cache needed? (Avoid recomputing K,V for previous tokens)
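- What's the complexity improvement? (K/V projection work drops from O(n²) to O(n) over n tokens; attention at step t is still O(t))
- What's the memory cost? (O(batch × seq_len × d_model) per layer, for each of K and V)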
- How does it work with beam search? (Need separate cache per beam)
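
For the memory question, a back-of-the-envelope calculation helps. The sketch below assumes standard multi-head attention where K and V each have width `d_model` in every layer; the function name and the example numbers are illustrative and not taken from any specific model.

```python
def kv_cache_bytes(num_layers: int, batch: int, seq_len: int,
                   d_model: int, bytes_per_elem: int) -> int:
    """Bytes needed for K and V across all layers (factor 2 covers K and V)."""
    return 2 * num_layers * batch * seq_len * d_model * bytes_per_elem


# Illustrative configuration: 32 layers, 4096-wide model, 4096 tokens, 16-bit values.
size = kv_cache_bytes(num_layers=32, batch=1, seq_len=4096,
                      d_model=4096, bytes_per_elem=2)
print(f"{size / 2**30:.1f} GiB")  # 2.0 GiB
```

The cost scales linearly with batch size and sequence length, which is why long contexts and large batches are often limited by cache memory.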