# Part 4, Lab 5: KV Cache Optimization

**Time:** ~45 minutes

The KV cache is often the memory bottleneck in LLM inference. This lab explores KV cache sizing, quantization, and management strategies.

## Learning Objectives

1. Calculate KV cache memory requirements
2. Implement KV cache quantization
3. Understand PagedAttention concepts
4. Optimize for long-context inference

In [None]:
import numpy as np
import torch
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

---
## 1. KV Cache Memory Calculation

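KV cache grows linearly with sequence length and batch size:

```
Memory = 2 × layers × kv_heads × head_dim × seq_len × batch × bytes_per_element
```

The factor of 2 accounts for storing both K and V. `kv_heads` is the number of key/value heads. It equals the number of attention heads for standard multi-head attention, but is smaller for grouped-query attention models (e.g. Llama-2-70B and Llama-3 use 8 KV heads).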

In [None]:
def calculate_kv_cache_size(model_config, seq_len, batch_size, dtype_bytes=2):
    """
    Calculate KV cache memory in bytes.

    Only key/value heads are cached. For grouped-query or multi-query
    attention models the size therefore depends on the number of KV heads,
    not on the number of query heads.

    Args:
        model_config: dict with 'layers', 'heads', 'head_dim' and optionally
            'kv_heads' (number of key/value heads; defaults to 'heads', i.e.
            standard multi-head attention)
        seq_len: sequence length (tokens cached per sequence)
        batch_size: number of sequences
        dtype_bytes: bytes per element (4 for FP32, 2 for FP16, 1 for INT8/FP8,
            0.5 for INT4)

    Returns:
        Cache size in bytes (a float when dtype_bytes is fractional).
        Quantization scale factors are not included.
    """
    kv_heads = model_config.get('kv_heads', model_config['heads'])
    return (2  # one tensor for K, one for V
            * model_config['layers']
            * kv_heads
            * model_config['head_dim']
            * seq_len
            * batch_size
            * dtype_bytes)

# Common model configurations.
# 'heads' = query heads; 'kv_heads' = key/value heads (the ones actually cached).
# Llama-2-70B and the Llama-3 models use grouped-query attention with 8 KV heads,
# so their KV cache is far smaller than the query-head count would suggest.
models = {
    'Llama-2-7B': {'layers': 32, 'heads': 32, 'kv_heads': 32, 'head_dim': 128},
    'Llama-2-13B': {'layers': 40, 'heads': 40, 'kv_heads': 40, 'head_dim': 128},
    'Llama-2-70B': {'layers': 80, 'heads': 64, 'kv_heads': 8, 'head_dim': 128},
    'Llama-3-8B': {'layers': 32, 'heads': 32, 'kv_heads': 8, 'head_dim': 128},
    'Llama-3-70B': {'layers': 80, 'heads': 64, 'kv_heads': 8, 'head_dim': 128},
}
contexts = [4096, 16384, 32768, 131072]

print("KV Cache Memory Requirements (FP16, batch=1):")
print("=" * 60)
print(f"{'Model':<15} {'4K ctx':>10} {'16K ctx':>10} {'32K ctx':>10} {'128K ctx':>10}")
print("-" * 60)

for name, config in models.items():
    # Only K/V heads are cached: size the cache with the KV-head count.
    kv_config = {**config, 'heads': config.get('kv_heads', config['heads'])}
    sizes = [
        f"{calculate_kv_cache_size(kv_config, ctx, 1, 2) / (1024**3):.1f} GB"
        for ctx in contexts
    ]
    print(f"{name:<15} {sizes[0]:>10} {sizes[1]:>10} {sizes[2]:>10} {sizes[3]:>10}")

---
## 2. KV Cache Quantization

Quantizing the KV cache from FP16 to INT8 or FP8 roughly halves its memory, plus a small overhead for the scale factors, with minimal quality loss for most models.

In [None]:
class QuantizedKVCache:
    """KV cache with symmetric INT8 quantization (batch size 1 only).

    K and V are stored as INT8 in buffers of shape
    [num_layers, max_seq_len, num_heads, head_dim]. Each buffer has one FP16
    scale per (layer, position, head), i.e. per head_dim-sized vector, held
    in buffers of shape [num_layers, max_seq_len, num_heads].
    """

    def __init__(self, num_layers, num_heads, head_dim, max_seq_len, device='cuda'):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.device = device

        # Quantized storage
        self.k_cache = torch.zeros(
            num_layers, max_seq_len, num_heads, head_dim,
            dtype=torch.int8, device=device
        )
        self.v_cache = torch.zeros(
            num_layers, max_seq_len, num_heads, head_dim,
            dtype=torch.int8, device=device
        )

        # Per-(position, head) scale factors
        self.k_scales = torch.ones(
            num_layers, max_seq_len, num_heads,
            dtype=torch.float16, device=device
        )
        self.v_scales = torch.ones(
            num_layers, max_seq_len, num_heads,
            dtype=torch.float16, device=device
        )

        # Tokens written to every layer; advanced once the last layer is updated.
        self.seq_len = 0
        # Tokens readable from each layer. This lets get(layer) see the tokens
        # that update(layer) just wrote, before the remaining layers have been
        # updated (the normal order inside a forward pass).
        self.layer_seq_lens = [0] * num_layers

    @staticmethod
    def _quantize(x):
        """Quantize x [seq, heads, dim] to INT8 with one FP16 scale per (seq, head).

        Returns:
            (int8 tensor [seq, heads, dim], float16 scales [seq, heads])
        """
        x = x.float()
        abs_max = x.abs().max(dim=-1, keepdim=True).values
        # Round the scale to FP16 *before* quantizing, so the scale used here is
        # exactly the one stored and later used by get() to dequantize.
        scale = (abs_max / 127.0).clamp(min=1e-8).half()
        # Tiny scales underflow to 0 in FP16; guard the division. Those
        # near-zero vectors dequantize to 0 because their stored scale is 0.
        divisor = scale.float().clamp(min=1e-8)
        # Symmetric range [-127, 127]: |x| / scale <= 127 by construction.
        q = (x / divisor).round().clamp(-127, 127).to(torch.int8)
        return q, scale.squeeze(-1)

    def update(self, layer_idx, k, v):
        """
        Quantize and store new K, V tensors for one layer.

        All layers of the same step write at position `self.seq_len`.
        `seq_len` advances after the last layer (num_layers - 1) is updated.

        Args:
            layer_idx: layer index in [0, num_layers)
            k, v: [batch=1, num_heads, new_tokens, head_dim]

        Raises:
            ValueError: if the batch dimension is not 1.
            RuntimeError: if the new tokens would exceed max_seq_len.
        """
        if k.shape[0] != 1 or v.shape[0] != 1:
            raise ValueError(
                f"QuantizedKVCache supports batch size 1, got {k.shape[0]} and {v.shape[0]}"
            )
        new_tokens = k.shape[2]
        start_pos = self.seq_len
        end_pos = start_pos + new_tokens
        if end_pos > self.max_seq_len:
            raise RuntimeError(
                f"KV cache overflow: {end_pos} tokens exceed max_seq_len={self.max_seq_len}"
            )

        # [1, heads, seq, dim] -> [seq, heads, dim] to match the storage layout
        k_int8, k_scale = self._quantize(k.squeeze(0).transpose(0, 1))
        v_int8, v_scale = self._quantize(v.squeeze(0).transpose(0, 1))

        # Store
        self.k_cache[layer_idx, start_pos:end_pos] = k_int8
        self.v_cache[layer_idx, start_pos:end_pos] = v_int8
        self.k_scales[layer_idx, start_pos:end_pos] = k_scale
        self.v_scales[layer_idx, start_pos:end_pos] = v_scale

        self.layer_seq_lens[layer_idx] = end_pos
        if layer_idx == self.num_layers - 1:
            self.seq_len = end_pos

    def get(self, layer_idx):
        """Return dequantized (K, V) for one layer, each float32 [1, num_heads, n, head_dim].

        n is the number of tokens written to this layer. It therefore includes
        tokens added by update() on this layer during the current step.
        """
        n = self.layer_seq_lens[layer_idx]
        k_int8 = self.k_cache[layer_idx, :n]  # [n, heads, dim]
        v_int8 = self.v_cache[layer_idx, :n]
        k_scale = self.k_scales[layer_idx, :n].unsqueeze(-1).float()  # [n, heads, 1]
        v_scale = self.v_scales[layer_idx, :n].unsqueeze(-1).float()

        k = k_int8.float() * k_scale
        v = v_int8.float() * v_scale

        # Reshape back: [n, heads, dim] -> [1, heads, n, dim]
        return k.transpose(0, 1).unsqueeze(0), v.transpose(0, 1).unsqueeze(0)

# Test quantized KV cache
cache = QuantizedKVCache(
    num_layers=4, num_heads=8, head_dim=64,
    max_seq_len=1024, device=device
)

# Simulate decoding: one new token per step, written to every layer.
# Keep the originals so the quantization error can be measured.
k_history, v_history = [], []
for step in range(10):
    k = torch.randn(1, 8, 1, 64, device=device)  # 1 new token
    v = torch.randn(1, 8, 1, 64, device=device)
    k_history.append(k)
    v_history.append(v)
    for layer in range(4):
        cache.update(layer, k, v)

# Retrieve and compare against the unquantized values
k_retrieved, v_retrieved = cache.get(0)
k_ref = torch.cat(k_history, dim=2).float()
v_ref = torch.cat(v_history, dim=2).float()
print(f"Cache sequence length: {cache.seq_len}")
print(f"Retrieved K shape: {k_retrieved.shape}")
print(f"Max abs error: K={(k_retrieved - k_ref).abs().max().item():.4f}, "
      f"V={(v_retrieved - v_ref).abs().max().item():.4f}")

# Memory of the pre-allocated buffers: INT8 values plus their FP16 scales
value_bytes = (cache.k_cache.numel() * cache.k_cache.element_size()
               + cache.v_cache.numel() * cache.v_cache.element_size())
scale_bytes = (cache.k_scales.numel() * cache.k_scales.element_size()
               + cache.v_scales.numel() * cache.v_scales.element_size())
fp16_bytes = 2 * (cache.k_cache.numel() + cache.v_cache.numel())
print(f"Memory: {value_bytes} bytes (INT8) + {scale_bytes} bytes (FP16 scales) "
      f"= {value_bytes + scale_bytes} bytes vs {fp16_bytes} bytes for an FP16 cache")

---
## 3. PagedAttention Concepts

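PagedAttention (vLLM) manages the KV cache like virtual memory, using fixed-size blocks allocated on demand. A sequence's blocks need not be contiguous, and each sequence wastes at most the unused tail of its last block.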

In [None]:
class PagedKVCache:
    """Simplified PagedAttention-style KV cache.

    A fixed pool of `num_blocks` physical blocks, each with `block_size`
    token slots for every layer. Each sequence owns a list of physical block
    ids: logical block i of a sequence lives in physical block
    `sequence_blocks[seq_id][i]`. A block is taken from the free list only
    when a sequence first needs a new logical block.
    """

    def __init__(self, num_blocks, block_size, num_layers, num_heads, head_dim, device='cuda'):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device

        # Block pool: [num_blocks, num_layers, block_size, num_heads, head_dim]
        self.k_blocks = torch.zeros(
            num_blocks, num_layers, block_size, num_heads, head_dim,
            dtype=torch.float16, device=device
        )
        self.v_blocks = torch.zeros(
            num_blocks, num_layers, block_size, num_heads, head_dim,
            dtype=torch.float16, device=device
        )

        # Block allocation tracking
        self.free_blocks = list(range(num_blocks))
        self.sequence_blocks = {}  # seq_id -> [block_ids]
        self.sequence_lengths = {}  # seq_id -> length

    def allocate_sequence(self, seq_id):
        """Start a new, empty sequence.

        Raises:
            ValueError: if seq_id is already active (re-registering it would
                leak the blocks it holds).
        """
        if seq_id in self.sequence_blocks:
            raise ValueError(f"Sequence {seq_id!r} is already allocated")
        self.sequence_blocks[seq_id] = []
        self.sequence_lengths[seq_id] = 0

    def append_token(self, seq_id, layer_idx, k, v):
        """Write one token's K/V for one layer of a sequence.

        Call once per layer for each token. The token position advances after
        the last layer (layer_idx == num_layers - 1) has been written.

        Args:
            seq_id: id previously passed to allocate_sequence
            layer_idx: layer index in [0, num_layers)
            k, v: tensors with num_heads * head_dim elements (e.g. [1, heads, dim])

        Raises:
            RuntimeError: if a new block is needed and the pool is exhausted.
        """
        blocks = self.sequence_blocks[seq_id]
        seq_len = self.sequence_lengths[seq_id]
        block_idx_in_seq, pos_in_block = divmod(seq_len, self.block_size)

        # Allocate only when this logical block does not exist yet. Every layer
        # of a token sees the same position, so testing `pos_in_block == 0`
        # alone would allocate one block per layer.
        if block_idx_in_seq == len(blocks):
            if not self.free_blocks:
                raise RuntimeError("Out of KV cache blocks!")
            blocks.append(self.free_blocks.pop())

        # Get physical block
        physical_block = blocks[block_idx_in_seq]

        # Store KV (reshape rather than squeeze so heads/dim of size 1 survive)
        slot_shape = (self.num_heads, self.head_dim)
        self.k_blocks[physical_block, layer_idx, pos_in_block] = k.reshape(slot_shape).half()
        self.v_blocks[physical_block, layer_idx, pos_in_block] = v.reshape(slot_shape).half()

        if layer_idx == self.num_layers - 1:
            self.sequence_lengths[seq_id] += 1

    def free_sequence(self, seq_id):
        """Return all blocks of a sequence to the free list and forget it."""
        blocks = self.sequence_blocks.pop(seq_id)
        self.free_blocks.extend(blocks)
        del self.sequence_lengths[seq_id]

    def memory_utilization(self):
        """Return the percentage of physical blocks currently allocated."""
        used_blocks = self.num_blocks - len(self.free_blocks)
        return used_blocks / self.num_blocks * 100

# Demo PagedAttention
paged_cache = PagedKVCache(
    num_blocks=100, block_size=16,
    num_layers=4, num_heads=8, head_dim=64,
    device=device
)

rng = np.random.default_rng(0)  # seeded so the demo output is reproducible

# Simulate multiple sequences
for seq_id in range(5):
    paged_cache.allocate_sequence(seq_id)
    # Each sequence has different length
    for _ in range(rng.integers(10, 50)):
        k = torch.randn(1, 8, 64, device=device)
        v = torch.randn(1, 8, 64, device=device)
        for layer in range(4):
            paged_cache.append_token(seq_id, layer, k, v)

print("PagedAttention Demo:")
tokens_stored = 0
for seq_id, length in paged_cache.sequence_lengths.items():
    blocks_used = len(paged_cache.sequence_blocks[seq_id])
    tokens_stored += length
    print(f"  Seq {seq_id}: {length} tokens, {blocks_used} blocks")
print(f"Memory utilization: {paged_cache.memory_utilization():.1f}%")
print(f"Free blocks: {len(paged_cache.free_blocks)}")

# Fraction of slots in allocated blocks that hold real tokens
# (i.e. how much internal fragmentation the block size causes)
used_blocks = paged_cache.num_blocks - len(paged_cache.free_blocks)
if used_blocks:
    slot_util = 100 * tokens_stored / (used_blocks * paged_cache.block_size)
    print(f"Slot utilization in allocated blocks: {slot_util:.1f}%")

---
## 4. Memory Comparison

Compare different KV cache strategies.

In [None]:
model_config = {'layers': 32, 'heads': 32, 'head_dim': 128}  # Llama-2-7B (MHA: all 32 heads are KV heads)
batch_size = 32
context_length = 4096

print(f"KV Cache Memory Comparison (Llama-2-7B, batch={batch_size}, ctx={context_length}):")
print("=" * 60)

strategies = [
    ("FP32", 4),
    ("FP16", 2),
    ("FP8", 1),
    ("INT8", 1),
    ("INT4 (experimental)", 0.5),
]

baseline = calculate_kv_cache_size(model_config, context_length, batch_size, 2)

for name, bytes_per_elem in strategies:
    size = calculate_kv_cache_size(model_config, context_length, batch_size, bytes_per_elem)
    # FP16 size / this size: > 1 is a saving, < 1 means larger (FP32 is 0.5x)
    ratio = baseline / size
    print(f"  {name:<20}: {size / (1024**3):.2f} GB ({ratio:.1f}x vs FP16)")
print("  (quantized sizes exclude scale-factor overhead)")

---
## Exercises

1. **Prefix Caching**: Implement copy-on-write for shared prefixes between sequences
2. **Sliding Window**: Implement sliding window attention with fixed KV cache size
3. **Speculative Decoding**: Implement KV cache management for speculative decoding

## Key Takeaways

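- KV cache often dominates memory for long-context or large-batch LLM inference
- INT8/FP8 quantization halves KV cache memory relative to FP16 with minimal quality loss
- PagedAttention wastes under ~4% of KV cache memory (only in each sequence's last block), versus the 60–80% waste the vLLM authors measured in systems that pre-allocate contiguous buffers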
- Block-based management enables efficient multi-sequence batching