In [2]:
!pip install torch

Collecting torch
  Downloading torch-2.9.1-cp313-cp313-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting filelock (from torch)
  Downloading filelock-3.20.2-py3-none-any.whl.metadata (2.1 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch)
  Using cached networkx-3.6.1-py3-none-any.whl.metadata (6.8 kB)
Collecting jinja2 (from torch)
  Downloading jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec>=0.8.5 (from torch)
  Using cached fsspec-2025.12.0-py3-none-any.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.8.93 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-runtime-cu12==12.8.90 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-cupti-cu12==12.8.90 (fro

In [1]:
import time
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with optional KV caching.

    Keys/values produced by earlier calls can be passed back in through
    ``kv_cache``; the new keys/values are appended along the sequence axis, so
    an autoregressive decoder only has to project the newest token(s).

    Attention is causal by default: a query at absolute position ``t``
    (counting the cached prefix) may only attend to keys at positions ``<= t``.
    Causal masking is what makes token-by-token decoding with a cache produce
    the same outputs as a single full-sequence pass.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        causal: If True (default), mask out attention to future positions.
            With False, attention is bidirectional and a full pass will no
            longer match incremental cached decoding.
    """

    def __init__(self, d_model: int, n_heads: int, causal: bool = True):
        super().__init__()
        assert d_model % n_heads == 0, (
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        )

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.causal = causal

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape [batch, seq, d_model] -> [batch, n_heads, seq, d_k]."""
        batch_size, seq_len, _ = t.shape
        return t.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            kv_cache: Optional (cached_k, cached_v) from previous calls, each
                [batch_size, n_heads, past_len, d_k]
            use_cache: Whether to return the updated cache

        Returns:
            output: Attention output [batch_size, seq_len, d_model]
            new_cache: (k, v) covering all past_len + seq_len positions if
                use_cache=True, else None
        """
        batch_size, seq_len, _ = x.shape

        # Project and split into heads: [batch_size, n_heads, seq_len, d_k]
        q = self._split_heads(self.W_q(x))
        k = self._split_heads(self.W_k(x))
        v = self._split_heads(self.W_v(x))

        # Prepend cached keys/values along the sequence axis (dim=2) so the
        # new queries can attend to every earlier position without recomputing it.
        past_len = 0
        if kv_cache is not None:
            cached_k, cached_v = kv_cache
            past_len = cached_k.shape[2]
            k = torch.cat([cached_k, k], dim=2)
            v = torch.cat([cached_v, v], dim=2)

        # Scaled dot-product scores: [batch_size, n_heads, seq_len, past_len + seq_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)

        if self.causal:
            # Query row i sits at absolute position past_len + i, so key column j
            # is "future" when j > past_len + i. The offset keeps the mask correct
            # both for single-token decoding and for multi-token chunks on top of
            # a cache. Key 0 is never masked, so no row becomes all -inf.
            total_len = k.shape[2]
            future = torch.ones(
                seq_len, total_len, dtype=torch.bool, device=x.device
            ).triu(diagonal=past_len + 1)
            scores = scores.masked_fill(future, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads back: [batch_size, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        output = self.W_o(attn_output)

        # Return the full (prefix + new) keys/values so the next call can extend them.
        new_cache = (k, v) if use_cache else None

        return output, new_cache


class SimpleTransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block whose attention can reuse a KV cache.

    Computes ``h = x + Attn(LN1(x))`` and then ``h + FF(LN2(h))``, where FF is
    Linear(d_model, d_ff) -> ReLU -> Linear(d_ff, d_model). The KV cache is
    passed straight through to (and returned from) the attention sublayer.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        # Submodules are created in a fixed order (attn, ff, ln1, ln2) so that
        # parameter initialisation draws from the RNG in the same sequence.
        self.attn = MultiHeadAttention(d_model, n_heads)
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        project = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(expand, activation, project)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Apply attention and feed-forward sublayers, each with a residual.

        Returns the block output and whatever cache the attention sublayer
        produced (None unless ``use_cache`` is True).
        """
        attended, updated_cache = self.attn(
            self.ln1(x), kv_cache=kv_cache, use_cache=use_cache
        )
        hidden = x + attended
        hidden = hidden + self.ff(self.ln2(hidden))
        return hidden, updated_cache


# ============= TEST AND DEMO CODE =============

def _elapsed(fn, n_runs: int) -> float:
    """Return the wall-clock seconds spent calling ``fn()`` ``n_runs`` times."""
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    return time.perf_counter() - start


def test_kv_cache(
    batch_size: int = 2,
    d_model: int = 64,
    n_heads: int = 4,
    seq_len: int = 10,
    n_runs: int = 100,
    seed: Optional[int] = 0,
):
    """Test that KV cache produces identical results to full attention.

    Also times three strategies and reports the size of the final cache:
      * one full-sequence forward pass,
      * autoregressive generation WITHOUT a cache (re-run the whole prefix at
        every step) -- the baseline KV caching is meant to speed up,
      * autoregressive generation WITH the KV cache.

    Args:
        batch_size, d_model, n_heads, seq_len: Sizes of the toy model and input.
        n_runs: Number of repetitions for each timing measurement.
        seed: Seed for ``torch.manual_seed`` (None leaves the RNG untouched).

    Raises:
        ValueError: if ``seq_len`` < 1 (there would be nothing to decode).
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    print("=" * 60)
    print("Testing KV Cache Implementation")
    print("=" * 60)

    if seed is not None:
        torch.manual_seed(seed)

    model = SimpleTransformerBlock(d_model, n_heads, d_ff=128)
    model.eval()

    # Generate a full sequence
    full_sequence = torch.randn(batch_size, seq_len, d_model)

    with torch.no_grad():
        # Method 1: Process full sequence at once (standard)
        output_full, _ = model(full_sequence, use_cache=False)

        # Method 2: Process token-by-token with KV cache (autoregressive)
        outputs_cached = []
        cache = None
        for i in range(seq_len):
            token = full_sequence[:, i:i + 1, :]  # [batch_size, 1, d_model]
            output, cache = model(token, kv_cache=cache, use_cache=True)
            outputs_cached.append(output)

    output_cached = torch.cat(outputs_cached, dim=1)

    # Compare results
    max_diff = (output_full - output_cached).abs().max().item()
    print(f"\n✓ Full sequence shape: {output_full.shape}")
    print(f"✓ Cached sequence shape: {output_cached.shape}")
    print(f"\nMaximum difference: {max_diff:.2e}")

    if max_diff < 1e-5:
        print("✓ SUCCESS! KV cache produces identical results.")
    else:
        print("✗ WARNING: Results differ. Check your implementation.")

    # Performance comparison
    print("\n" + "=" * 60)
    print("Performance Analysis")
    print("=" * 60)

    def run_full_pass():
        model(full_sequence)

    def run_generation_without_cache():
        # Each step recomputes attention over the entire prefix so far.
        for step in range(seq_len):
            model(full_sequence[:, :step + 1, :])

    def run_generation_with_cache():
        # Each step projects only the newest token and reuses cached K/V.
        step_cache = None
        for step in range(seq_len):
            _, step_cache = model(full_sequence[:, step:step + 1, :], step_cache, use_cache=True)

    with torch.no_grad():
        # Warmup inside no_grad, like the timed runs, so no autograd graph is built.
        for _ in range(5):
            model(full_sequence[:, :1, :])

        full_time = _elapsed(run_full_pass, n_runs)
        no_cache_time = _elapsed(run_generation_without_cache, n_runs)
        cached_time = _elapsed(run_generation_with_cache, n_runs)

    # NOTE: at this toy size Python/dispatch overhead dominates, so the gap
    # between cached and uncached generation grows with seq_len and d_model.
    print(f"\nFull attention, single pass ({n_runs} runs): {full_time:.3f}s")
    print(f"Generation without cache ({n_runs} runs): {no_cache_time:.3f}s")
    print(f"Cached attention ({n_runs} runs): {cached_time:.3f}s")
    print(f"Speedup vs. no-cache generation: {no_cache_time / cached_time:.2f}x")

    print("\n" + "=" * 60)
    print("Memory Usage (Cached K/V)")
    print("=" * 60)

    # `cache` holds all seq_len positions after the correctness loop above.
    cached_k, cached_v = cache
    print(f"\nCached K shape: {cached_k.shape}")
    print(f"Cached V shape: {cached_v.shape}")
    # Use the real per-element size instead of assuming float32.
    total_bytes = sum(t.numel() * t.element_size() for t in (cached_k, cached_v))
    memory_mb = total_bytes / (1024**2)
    print(f"Total cache memory: ~{memory_mb:.2f} MB")


def demonstrate_autoregressive_generation(
    d_model: int = 64,
    n_heads: int = 4,
    prompt_len: int = 3,
    n_new_tokens: int = 5,
):
    """Simulate autoregressive generation with KV cache.

    The prompt is encoded in one call that seeds the cache. Each later call
    feeds a single new token plus the cache, and the demo prints how many
    positions the cache holds afterwards. No real sampling happens: the "new
    tokens" are random embeddings, not derived from the model's output.

    Args:
        d_model: Model width of the toy transformer block.
        n_heads: Number of attention heads.
        prompt_len: Number of prompt positions processed in the first call.
        n_new_tokens: Number of single-token decoding steps to run.
    """
    print("\n" + "=" * 60)
    print("Simulating Autoregressive Generation")
    print("=" * 60)

    model = SimpleTransformerBlock(d_model, n_heads=n_heads, d_ff=128)
    model.eval()

    # Start with a "prompt" of prompt_len random token embeddings
    prompt = torch.randn(1, prompt_len, d_model)

    print(f"\n1. Processing prompt ({prompt_len} tokens)...")
    with torch.no_grad():
        _, cache = model(prompt, use_cache=True)

        print(f"   Cache initialized. K/V shape: {cache[0].shape}")

        # Generate new tokens autoregressively; only the cache is carried forward.
        print("\n2. Generating tokens one by one...")
        for i in range(n_new_tokens):
            new_token = torch.randn(1, 1, d_model)
            _, cache = model(new_token, kv_cache=cache, use_cache=True)
            print(f"   Token {i+1}: Cache now has {cache[0].shape[2]} positions")

    print("\n✓ Generation complete!")
    print("   Each new token only processes 1 position but attends to all previous.")


if __name__ == "__main__":
    separator = "=" * 60

    # Correctness check + timing first, then a short generation walkthrough.
    test_kv_cache()
    demonstrate_autoregressive_generation()

    # Follow-up exercises, printed as a numbered list.
    exercises = (
        "Add causal masking to prevent attending to future tokens",
        "Implement cache size limits (sliding window)",
        "Add cross-attention with separate KV cache",
        "Visualize attention weights with/without cache",
        "Profile memory usage with longer sequences",
    )

    print("\n" + separator)
    print("EXERCISES TO TRY:")
    print(separator)
    for number, exercise in enumerate(exercises, start=1):
        print(f"{number}. {exercise}")
    print(separator)

ModuleNotFoundError: No module named 'torch'