## Problem: Implement KV Cache for Efficient Autoregressive Generation

### Background

In autoregressive models (like GPT), we generate tokens one at a time. At each step, we compute attention over **all previous tokens**. Without optimization, this means:
- At step 1: Compute K, V for token 0
- At step 2: Compute K, V for tokens 0, 1 (recomputing token 0!)
- At step 3: Compute K, V for tokens 0, 1, 2 (recomputing tokens 0, 1!)

**KV Cache** solves this by storing previously computed K and V tensors:
- At step 1: Compute K₀, V₀, cache them
- At step 2: Compute K₁, V₁, concatenate with cached [K₀], [V₀]
- At step 3: Compute K₂, V₂, concatenate with cached [K₀, K₁], [V₀, V₁]

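This reduces the number of K/V projections needed to generate n tokens from O(n²) to O(n). Each step still attends over all t cached positions, so the attention computation itself remains O(t) per step.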
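
To make the gap concrete, the sketch below counts how many per-token K/V projections each strategy performs while generating n tokens. The helper name `projection_counts` is hypothetical and only illustrates the counting argument from the two lists above.

```python
from typing import Tuple


def projection_counts(n: int) -> Tuple[int, int]:
    """Return (without_cache, with_cache) per-token K/V projection counts."""
    # Without a cache, step t re-projects all t + 1 tokens seen so far.
    without_cache = sum(t + 1 for t in range(n))
    # With a cache, each step projects only the newly generated token.
    with_cache = n
    return without_cache, with_cache


for n in (10, 100, 1000):
    naive, cached = projection_counts(n)
    print(f"n={n:>4}: without cache={naive:>6}, with cache={cached:>4}")
```

The uncached count is n(n+1)/2, which is where the quadratic term comes from.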

### Mathematical Formulation

Without cache:
```
Q_t = W_q @ X[0:t+1]  # Query for all tokens up to t
K_t = W_k @ X[0:t+1]  # Recompute keys for all tokens
V_t = W_v @ X[0:t+1]  # Recompute values for all tokens
```

With cache:
```
Q_t = W_q @ X[t]           # Query only for new token
K_new = W_k @ X[t]         # Key only for new token
V_new = W_v @ X[t]         # Value only for new token
K_cached = concat(K_cache, K_new)  # Append to cache
V_cached = concat(V_cache, V_new)  # Append to cache
```
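
As a quick sanity check of the formulation above, this minimal sketch appends one projected key per step and confirms the cache matches a full recomputation. It assumes a row-per-token layout (`X @ W_k` rather than `W_k @ X`), and the sizes and the random `W_k` are illustrative only.

```python
import torch

torch.manual_seed(0)
seq_len, d_model = 4, 8
X = torch.randn(seq_len, d_model)      # one row per token
W_k = torch.randn(d_model, d_model)    # illustrative key projection

K_cache = torch.empty(0, d_model)      # empty cache before the first token
for t in range(seq_len):
    K_new = X[t:t + 1] @ W_k                       # project only the new token
    K_cache = torch.cat([K_cache, K_new], dim=0)   # append along the sequence axis
    K_full = X[:t + 1] @ W_k                       # recompute from scratch
    assert torch.allclose(K_cache, K_full, atol=1e-4)

print(K_cache.shape)  # torch.Size([4, 8])
```

The same argument applies to V, since each row of K and V depends only on its own token.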

### Learning Objectives

1. Implement attention with KV caching
2. Understand cache management (initialization, updates)
3. Measure performance improvements
4. Handle edge cases (first token, cache limits)

<details>
  <summary>💡 Hint 1: Cache Initialization</summary>
  For the first token, initialize the cache with the computed K and V. For subsequent tokens, concatenate new K,V with cached values along the sequence dimension.
</details>

<details>
  <summary>💡 Hint 2: Concatenation</summary>
  Use `torch.cat([cached, new], dim=1)` where dim=1 is the sequence length dimension.
</details>

<details>
  <summary>💡 Hint 3: Multi-Head</summary>
  For multi-head attention, cache shape should be [batch, num_heads, seq_len, d_k]. Concatenate along dim=2 (sequence dimension).
</details>
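
The concatenation axes mentioned in the hints can be checked with dummy tensors. This is a shape-only sketch; the sizes are arbitrary assumptions.

```python
import torch

batch, num_heads, d_k = 2, 4, 16

# Single-head layout: [batch, seq_len, d_k], so the sequence axis is dim=1.
cached = torch.zeros(batch, 3, d_k)
new = torch.zeros(batch, 1, d_k)
single_head = torch.cat([cached, new], dim=1)
print(single_head.shape)  # torch.Size([2, 4, 16])

# Multi-head layout: [batch, num_heads, seq_len, d_k], so the sequence axis is dim=2.
cached_mh = torch.zeros(batch, num_heads, 3, d_k)
new_mh = torch.zeros(batch, num_heads, 1, d_k)
multi_head = torch.cat([cached_mh, new_mh], dim=2)
print(multi_head.shape)  # torch.Size([2, 4, 4, 16])
```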

### References
- [Attention is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762)
- [FlashAttention paper](https://arxiv.org/abs/2205.14135) - discusses memory optimization

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from typing import Optional, Tuple

### Step 1: Baseline Attention (No Cache)

First, let's implement standard scaled dot-product attention without any caching.
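
If you want something to compare your solution against, here is one possible reference sketch. The name `sdpa_reference` is hypothetical, and the mask convention (boolean, `True` where attention is allowed) is an assumption, since the exercise does not specify one.

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F


def sdpa_reference(q, k, v, mask: Optional[torch.Tensor] = None):
    """Hypothetical reference: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)   # [batch, seq_q, seq_k]
    if mask is not None:
        # Assumption: mask is boolean, True where attention is allowed.
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)                  # each row sums to 1
    return weights @ v                                   # [batch, seq_q, d_v]


q = torch.randn(2, 1, 8)
k = torch.randn(2, 5, 8)
v = torch.randn(2, 5, 8)
out = sdpa_reference(q, k, v)
print(out.shape)  # torch.Size([2, 1, 8])
```

With a single key, the softmax weight is exactly 1, so the output equals that key's value vector. That makes a handy edge-case check for the first generated token.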

In [None]:
def attention_no_cache(q, k, v, mask=None):
    """
    Standard scaled dot-product attention without caching.
    
    Args:
        q: Query [batch, seq_len_q, d_k]
        k: Key [batch, seq_len_k, d_k]
        v: Value [batch, seq_len_k, d_v]
        mask: Optional attention mask
    
    Returns:
        output: [batch, seq_len_q, d_v]
    """
    # TODO: Compute scaled dot-product attention
    # Hint: scores = Q @ K^T / sqrt(d_k)
    
    pass

### Step 2: Attention with KV Cache

Now implement attention that maintains a cache of Key and Value tensors.
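
The exercise below grows the cache by concatenation. As a point of comparison for the "cache limits" edge case, here is a sketch of a different design: a fixed-capacity buffer that is written in place. `PreallocatedKVCache` and its `max_len` parameter are hypothetical and not part of the exercise API.

```python
import torch


class PreallocatedKVCache:
    """Hypothetical fixed-capacity cache; an alternative to concatenation."""

    def __init__(self, batch: int, max_len: int, d_k: int, d_v: int):
        self.k = torch.zeros(batch, max_len, d_k)
        self.v = torch.zeros(batch, max_len, d_v)
        self.length = 0          # number of valid positions
        self.max_len = max_len

    def update(self, k_new, v_new):
        n_new = k_new.size(1)
        if self.length + n_new > self.max_len:
            # Assumption: fail loudly; eviction policies are out of scope here.
            raise RuntimeError("KV cache capacity exceeded")
        end = self.length + n_new
        self.k[:, self.length:end] = k_new   # write in place, no reallocation
        self.v[:, self.length:end] = v_new
        self.length = end
        # Return views over the valid prefix only.
        return self.k[:, :end], self.v[:, :end]


cache = PreallocatedKVCache(batch=2, max_len=3, d_k=8, d_v=8)
for _ in range(3):
    k_all, v_all = cache.update(torch.randn(2, 1, 8), torch.randn(2, 1, 8))
print(k_all.shape)  # torch.Size([2, 3, 8])
```

Preallocating avoids copying the whole cache on every step, at the cost of reserving memory for the maximum length up front.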

In [None]:
class KVCache:
    """
    Cache for storing Key and Value tensors during autoregressive generation.
    """
    def __init__(self):
        self.k_cache = None  # [batch, seq_len, d_k]
        self.v_cache = None  # [batch, seq_len, d_v]
    
    def update(self, k_new, v_new):
        """
        Add new key and value tensors to the cache.
        
        Args:
            k_new: New keys [batch, 1, d_k]
            v_new: New values [batch, 1, d_v]
        
        Returns:
            k_cached: All keys including new [batch, seq_len, d_k]
            v_cached: All values including new [batch, seq_len, d_v]
        """
        # TODO: Implement cache update logic
        # If cache is None (first token): initialize with k_new, v_new
        # Otherwise: concatenate k_new, v_new with cached values
        # Hint: torch.cat([self.k_cache, k_new], dim=1)
        
        pass
    
    def clear(self):
        """Reset the cache."""
        self.k_cache = None
        self.v_cache = None


def attention_with_cache(
    q, k_new, v_new, 
    cache: Optional[KVCache] = None,
    mask=None
):
    """
    Scaled dot-product attention with KV caching.
    
    Args:
        q: Query for current token [batch, 1, d_k]
        k_new: Key for current token [batch, 1, d_k]
        v_new: Value for current token [batch, 1, d_v]
        cache: KVCache object to store/retrieve cached K,V
        mask: Optional attention mask
    
    Returns:
        output: Attention output [batch, 1, d_v]
    """
    # TODO: Implement attention with KV cache
    # 1. Update cache with new K, V
    # 2. Compute attention using query and ALL cached keys/values
    
    pass

### Step 3: Test Correctness

Verify that cached attention produces the same results as non-cached.
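
It can help to see why the two methods should agree at all. Decoding token by token is equivalent to a single full pass with a causal mask, because step t only ever sees keys and values 0..t. The sketch below checks this with random tensors; it is independent of the exercise functions and uses its own inline attention.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d = 2, 5, 8
q = torch.randn(batch, seq_len, d)
k = torch.randn(batch, seq_len, d)
v = torch.randn(batch, seq_len, d)

# One full pass with a causal (lower-triangular) mask.
scores = q @ k.transpose(-2, -1) / math.sqrt(d)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
full = F.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ v

# Token-by-token decoding: step t sees only keys/values 0..t.
steps = []
for t in range(seq_len):
    s = q[:, t:t + 1] @ k[:, :t + 1].transpose(-2, -1) / math.sqrt(d)
    steps.append(F.softmax(s, dim=-1) @ v[:, :t + 1])
stepwise = torch.cat(steps, dim=1)

print(torch.allclose(full, stepwise, atol=1e-5))
```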

In [None]:
# Setup test parameters
torch.manual_seed(42)
batch_size = 2
d_model = 64
seq_len = 5

# Simulate autoregressive generation
X = torch.randn(batch_size, seq_len, d_model)

# Projection matrices (shared for both methods)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

print("Testing correctness of KV cache implementation...")
print("=" * 60)

# Method 1: Without cache (recompute everything each step)
outputs_no_cache = []
for t in range(seq_len):
    x_current = X[:, :t+1, :]
    q = W_q(x_current[:, -1:, :])
    k = W_k(x_current)
    v = W_v(x_current)
    output = attention_no_cache(q, k, v)
    outputs_no_cache.append(output)

# Method 2: With cache (only compute new K, V each step)
cache = KVCache()
outputs_with_cache = []
for t in range(seq_len):
    x_token = X[:, t:t+1, :]
    q = W_q(x_token)
    k = W_k(x_token)
    v = W_v(x_token)
    output = attention_with_cache(q, k, v, cache=cache)
    outputs_with_cache.append(output)

# Compare outputs
print("\nComparing outputs at each timestep:\n")
all_match = True
for t in range(seq_len):
    match = torch.allclose(outputs_no_cache[t], outputs_with_cache[t], atol=1e-6, rtol=1e-5)
    status = "✓" if match else "✗"
    print(f"  Step {t}: {status} {'Match' if match else 'Mismatch'}")
    if not match:
        all_match = False

print()
if all_match:
    print("✓ All outputs match! KV cache is working correctly.")
else:
    print("✗ Outputs don't match. Check your implementation.")

### Step 4: Measure Performance Improvement

Compare the computational cost of cached vs non-cached attention.
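
Single-shot wall-clock timings can be noisy. The hypothetical helper `time_fn` below sketches a slightly more careful measurement: it warms up first, disables autograd, and reports the best of several runs. It assumes CPU execution.

```python
import time

import torch


def time_fn(fn, warmup: int = 2, repeats: int = 5) -> float:
    """Hypothetical helper: best wall-clock time of fn() over several runs."""
    with torch.no_grad():                 # skip autograd bookkeeping
        for _ in range(warmup):           # warm caches and lazy initialization
            fn()
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()   # monotonic, high-resolution clock
            fn()
            best = min(best, time.perf_counter() - start)
    return best


x = torch.randn(64, 64)
elapsed = time_fn(lambda: x @ x)
print(f"{elapsed * 1e6:.1f} µs")
```

On a GPU, kernels launch asynchronously, so you would also need `torch.cuda.synchronize()` before reading the clock.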

In [None]:
# Benchmark with longer sequences
torch.manual_seed(42)
batch_size = 4
d_model = 512
seq_len = 100

X = torch.randn(batch_size, seq_len, d_model)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

print("Performance Comparison")
print("=" * 60)
print(f"Sequence length: {seq_len} tokens\n")

# Benchmark without cache (no_grad keeps autograd overhead out of the timing)
with torch.no_grad():
    start = time.perf_counter()
    for t in range(seq_len):
        x_current = X[:, :t+1, :]
        q = W_q(x_current[:, -1:, :])
        k = W_k(x_current)
        v = W_v(x_current)
        _ = attention_no_cache(q, k, v)
    time_no_cache = time.perf_counter() - start

# Benchmark with cache
cache = KVCache()
with torch.no_grad():
    start = time.perf_counter()
    for t in range(seq_len):
        x_token = X[:, t:t+1, :]
        q = W_q(x_token)
        k = W_k(x_token)
        v = W_v(x_token)
        _ = attention_with_cache(q, k, v, cache=cache)
    time_with_cache = time.perf_counter() - start

speedup = time_no_cache / time_with_cache

print(f"Time without cache: {time_no_cache:.4f}s")
print(f"Time with cache:    {time_with_cache:.4f}s")
print(f"\nSpeedup: {speedup:.2f}x faster")
print(f"Time saved: {(time_no_cache - time_with_cache):.4f}s")

### Step 5: Multi-Head Attention with KV Cache (Bonus)

Extend KV cache to work with multi-head attention.
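
Steps 1 and 4 of the multi-head TODO are pure reshapes. A minimal sketch of both follows, using the hypothetical helper names `split_heads` and `merge_heads` and assuming `d_model` is divisible by `num_heads`.

```python
import torch


def split_heads(x, num_heads: int):
    """[batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]."""
    batch, seq_len, d_model = x.shape
    d_k = d_model // num_heads            # assumes d_model is divisible by num_heads
    return x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)


def merge_heads(x):
    """[batch, num_heads, seq_len, d_k] -> [batch, seq_len, d_model]."""
    batch, num_heads, seq_len, d_k = x.shape
    # transpose makes the tensor non-contiguous, so use reshape rather than view.
    return x.transpose(1, 2).reshape(batch, seq_len, num_heads * d_k)


x = torch.randn(2, 1, 64)
heads = split_heads(x, num_heads=8)
print(heads.shape)  # torch.Size([2, 8, 1, 8])
```

Splitting the new token's K and V before the cache update means the cache already has the `[batch, num_heads, seq_len, d_k]` layout and never needs reshaping.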

In [None]:
class MultiHeadKVCache:
    """
    KV Cache for multi-head attention.
    """
    def __init__(self, num_heads):
        self.num_heads = num_heads
        self.k_cache = None  # [batch, num_heads, seq_len, d_k]
        self.v_cache = None  # [batch, num_heads, seq_len, d_v]
    
    def update(self, k_new, v_new):
        """
        Update cache with new multi-head K, V tensors.
        
        Args:
            k_new: [batch, num_heads, 1, d_k]
            v_new: [batch, num_heads, 1, d_v]
        """
        # TODO: Implement multi-head cache update
        # Hint: Concatenate along dim=2 (sequence length dimension)
        
        pass
    
    def clear(self):
        self.k_cache = None
        self.v_cache = None


def multi_head_attention_with_cache(
    q, k_new, v_new,
    num_heads,
    cache: Optional[MultiHeadKVCache] = None
):
    """
    Multi-head attention with KV caching.
    """
    # TODO: Implement multi-head attention with cache
    # 1. Reshape Q, K, V to [batch, num_heads, seq_len, d_k]
    # 2. Update cache
    # 3. Compute attention
    # 4. Concatenate heads back
    
    pass

## Summary

### Key Concepts

- **KV Cache**: Store K,V from previous tokens to avoid recomputation
- **Performance**: K/V projection work over n generated tokens drops from O(n²) to O(n)
- **Trade-off**: Costs O(n) extra memory per layer in exchange for skipping the redundant projections
- **Production Critical**: KV caching is standard practice in LLM inference
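
To get a feel for the memory side of the trade-off, this sketch estimates cache size for dense multi-head attention. The function name `kv_cache_bytes` and the example configuration are hypothetical, not measurements of any particular model.

```python
def kv_cache_bytes(batch, seq_len, num_layers, num_heads, head_dim, bytes_per_elem=2):
    """Bytes needed to hold K and V for every layer (dense multi-head attention)."""
    per_tensor = batch * seq_len * num_heads * head_dim   # elements in one K or V
    return 2 * num_layers * per_tensor * bytes_per_elem   # 2 = one K plus one V


# Hypothetical configuration: 32 layers, 32 heads of size 128, fp16 values.
size = kv_cache_bytes(batch=1, seq_len=2048, num_layers=32, num_heads=32, head_dim=128)
print(f"{size / 2**30:.2f} GiB")  # 1.00 GiB
```

The cost grows linearly in both batch size and sequence length, which is why long contexts and large batches compete for the same memory.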

### Interview Tips

Be ready to answer:
- Why is KV cache needed? (Avoid recomputing K,V)
- What's the complexity improvement? (K/V projections: O(n²) → O(n))
- What's the memory cost? (O(batch × seq_len × d_model) per layer, for each of K and V)
- How does it work with beam search? (Separate cache per beam)
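
One common way to realize "separate cache per beam" is to stack beams along the batch dimension and reorder the cache after each step, so every surviving beam inherits its parent's history. The sketch below assumes that layout; the tensor contents are dummy values chosen so the reordering is visible.

```python
import torch

num_beams, seq_len, d_k = 3, 4, 8
# Assumption: beams are stacked along the batch dimension of the cache.
# Beam i's cache is filled with the value i so the reordering is easy to see.
k_cache = (
    torch.arange(num_beams).float().view(num_beams, 1, 1)
    .expand(num_beams, seq_len, d_k)
    .clone()
)

# Suppose new beams 0 and 1 both extend old beam 0, and new beam 2 extends old beam 2.
beam_idx = torch.tensor([0, 0, 2])
k_cache = k_cache.index_select(0, beam_idx)   # each new beam copies its parent's cache

print(k_cache[:, 0, 0])  # tensor([0., 0., 2.])
```

The value cache would be reordered with the same `beam_idx`.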