### KV Cache - Version 1 (Simple)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
class MultiHeadAttentionWithCache(nn.Module):
    """Multi-head self-attention with an optional ``(k, v)`` tuple cache.

    Cached keys/values are tensors shaped [batch, n_heads, cached_len, head_dim];
    the keys/values of the current call are appended to them along dim 2.
    """

    def __init__(self, n_heads, d_model):
        super().__init__()

        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")

        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads

        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None, use_cache=False):
        """
        Args:
            x: [batch, seq_len, d_model] - the new tokens.
            kv_cache: optional (k_prev, v_prev) tuple returned by an earlier call.
            use_cache: when True, also return the concatenated (k, v) tuple.

        Returns:
            (output [batch, seq_len, d_model], (k, v) if use_cache else None)
        """
        batch_size, seq_len, _ = x.shape

        qkv = self.qkv_proj(x).view(batch_size, seq_len, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, n_heads, seq_len, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        past_len = 0
        if kv_cache is not None:
            k_prev, v_prev = kv_cache
            past_len = k_prev.shape[2]
            k = torch.cat([k_prev, k], dim=2)
            v = torch.cat([v_prev, v], dim=2)

        attn = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)

        # A single new token may attend to everything, so masking is only
        # needed when several tokens are processed at once.
        if seq_len > 1:
            # Query i sits at absolute position past_len + i and may see keys
            # 0..past_len + i. Build the mask on x's device to avoid a CPU/GPU mismatch.
            key_pos = torch.arange(past_len + seq_len, device=x.device)
            query_pos = torch.arange(seq_len, device=x.device).unsqueeze(1) + past_len
            blocked = key_pos.unsqueeze(0) > query_pos  # [seq_len, past_len + seq_len]
            attn = attn.masked_fill(blocked, float('-inf'))

        attn = nn.functional.softmax(attn, dim=-1)
        scores = attn @ v

        scores = scores.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(scores)

        if use_cache:
            return output, (k, v)
        return output, None
        

In [3]:
class TransformerDecoderLayer(nn.Module):
    """Pre-norm decoder block.

    hidden = x + Dropout(Attention(LayerNorm(x)))
    out    = hidden + Dropout(MLP(LayerNorm(hidden)))
    The attention sub-layer accepts and returns a per-layer KV cache.
    """

    def __init__(self, n_heads, d_model, d_ff, dropout=0.1):
        super().__init__()

        self.n_heads, self.d_model = n_heads, d_model
        self.head_dim = d_model // n_heads

        self.attention_layer = MultiHeadAttentionWithCache(n_heads, d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, kv_cache=None, use_cache=False):
        """Run one block; returns (hidden_states, cache returned by the attention layer)."""
        attn_out, new_kv_cache = self.attention_layer(self.norm1(x), kv_cache, use_cache)
        hidden = x + self.dropout1(attn_out)
        hidden = hidden + self.dropout2(self.mlp(self.norm2(hidden)))
        return hidden, new_kv_cache

In [4]:
class SimpleTransformer(nn.Module):
    """Decoder-only language model using tuple-based per-layer KV caches.

    Token embeddings plus learned absolute position embeddings, a stack of
    TransformerDecoderLayer blocks, a final LayerNorm and a linear LM head.
    """

    def __init__(self, vocab_size, n_heads, d_model, n_layers, d_ff, max_len=512):
        super().__init__()

        self.vocab_size = vocab_size
        self.n_heads = n_heads
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.max_len = max_len
        self.head_dim = d_model // n_heads

        self.positional_encoding = nn.Embedding(max_len, d_model)
        self.embedding = nn.Embedding(vocab_size, d_model)

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(n_heads, d_model, d_ff)
            for _ in range(n_layers)
        ])

        self.ln_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, past_kv_caches=None, use_cache=False):
        """
        Args:
            input_ids: [batch, seq_len] token IDs (only the new tokens when caching).
            past_kv_caches: list with one (k, v) tuple (or None) per layer, or None.
            use_cache: when True, return the updated list of per-layer caches.

        Returns:
            (logits [batch, seq_len, vocab_size], list of caches or None)
        """
        batch_size, seq_len = input_ids.shape

        # Cached length is read from layer 0's keys (dim 2 = sequence).
        if past_kv_caches is not None and past_kv_caches[0] is not None:
            past_len = past_kv_caches[0][0].shape[2]
        else:
            past_len = 0

        # Continue absolute positions after the cached tokens; create them on the
        # same device as the inputs so the embedding lookup works on GPU too.
        positions = torch.arange(
            past_len, past_len + seq_len, device=input_ids.device
        ).unsqueeze(0)

        x = self.embedding(input_ids) + self.positional_encoding(positions)

        new_kv_caches = []
        for i, layer in enumerate(self.layers):
            layer_cache = past_kv_caches[i] if past_kv_caches is not None else None

            x, new_cache = layer(x, layer_cache, use_cache)

            if use_cache:
                new_kv_caches.append(new_cache)

        x = self.lm_head(self.ln_norm(x))

        return x, new_kv_caches if use_cache else None

In [5]:
@torch.no_grad()
def generate_with_cache(model, prompt_ids, max_new_tokens, eos_token_id, temperature=1.0):
    """
    Autoregressive generation using KV cache.

    The whole prompt is processed in the first forward pass; afterwards only the
    newest token is fed to the model together with the returned caches.

    Args:
        model: SimpleTransformer model (anything called as
            ``model(ids, past_kv_caches=..., use_cache=True) -> (logits, caches)``)
        prompt_ids: [batch, prompt_len] - input token IDs
        max_new_tokens: maximum number of tokens to generate
        eos_token_id: end-of-sequence token ID, or None to disable early stopping.
            After a sequence emits it, that sequence is filled with eos_token_id;
            generation stops once every sequence has finished.
        temperature: sampling temperature; a value <= 0 selects greedy (argmax) decoding

    Returns:
        generated_ids: [batch, prompt_len + n_new] with n_new <= max_new_tokens
        (fewer when all sequences finish early)
    """
    model.eval()

    generated = prompt_ids.clone()
    finished = torch.zeros(prompt_ids.shape[0], dtype=torch.bool, device=prompt_ids.device)
    past_kv_caches = None
    input_ids = prompt_ids  # prefill with the full prompt, then one token per step

    for _ in range(max_new_tokens):
        logits, past_kv_caches = model(
            input_ids,
            past_kv_caches=past_kv_caches,
            use_cache=True
        )

        next_token_logits = logits[:, -1, :]

        if temperature > 0:
            probs = F.softmax(next_token_logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            # logits / 0 would give inf/NaN probabilities; use greedy decoding instead.
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)

        if eos_token_id is not None:
            # Sequences that already finished keep emitting EOS as filler, and the
            # finished flags accumulate across steps (not just the current step).
            next_token = next_token.masked_fill(finished.unsqueeze(1), eos_token_id)
            finished |= next_token.squeeze(1) == eos_token_id

        generated = torch.cat([generated, next_token], dim=1)
        input_ids = next_token

        if eos_token_id is not None and finished.all():
            break

    return generated

### Multi-Layer & Batched KV Cache

In [6]:
from typing import Optional, Tuple, List

In [7]:
class KVCache:
    """Key/value cache for a single attention layer.

    Both tensors are laid out as [batch, num_heads, seq_len, head_dim] (as the
    properties below read them); new entries are appended along dim 2.
    """

    def __init__(self, key_cache: torch.Tensor, val_cache: torch.Tensor):
        self.key_cache = key_cache
        self.val_cache = val_cache

    @property
    def seq_len(self) -> int:
        """Current sequence length in cache"""
        return self.key_cache.shape[2]

    @property
    def batch_size(self) -> int:
        return self.key_cache.shape[0]

    @property
    def num_heads(self) -> int:
        return self.key_cache.shape[1]

    @property
    def head_dim(self) -> int:
        return self.key_cache.shape[3]

    def update(self, new_key: torch.Tensor, new_val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Append new K,V to the cache in place.

        Args:
            new_key: [batch, num_heads, new_len, head_dim]
            new_val: [batch, num_heads, new_len, head_dim]

        Returns:
            (key_cache, val_cache): the full, updated tensors held by this object
        """
        self.key_cache = torch.cat([self.key_cache, new_key], dim=2)
        self.val_cache = torch.cat([self.val_cache, new_val], dim=2)

        return self.key_cache, self.val_cache

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the cached (key, value) tensors."""
        return self.key_cache, self.val_cache

    def get_memory_usage(self) -> dict:
        """Calculate memory usage in MB"""
        k_bytes = self.key_cache.numel() * self.key_cache.element_size()
        v_bytes = self.val_cache.numel() * self.val_cache.element_size()

        return {
            'key_mb': k_bytes / (1024 ** 2),
            'value_mb': v_bytes / (1024 ** 2),
            'total_mb': (k_bytes + v_bytes) / (1024 ** 2)
        }

    def __repr__(self) -> str:
        return (
            f"KVCache(key_shape={tuple(self.key_cache.shape)}, "
            f"val_shape={tuple(self.val_cache.shape)})"
        )


In [8]:
class MultiLayerKVCache:
    """Container holding one KVCache per transformer layer (None until first use)."""

    def __init__(self, num_layers: int):
        """
        Args:
            num_layers: number of transformer layers
        """
        self.num_layers = num_layers
        self.caches = [None for _ in range(num_layers)]

    def update(self, layer_idx: int, new_key: torch.Tensor, new_val: torch.Tensor):
        """
        Append K,V to one layer's cache, creating the cache on first use.

        Args:
            layer_idx: which layer to update (0 to num_layers-1)
            new_key: [batch, num_heads, seq_len, head_dim]
            new_val: [batch, num_heads, seq_len, head_dim]

        Returns:
            (key, value) tensors now held for that layer
        """
        layer_cache = self.caches[layer_idx]
        if layer_cache is not None:
            return layer_cache.update(new_key, new_val)

        self.caches[layer_idx] = KVCache(new_key, new_val)
        return new_key, new_val

    def get(self, layer_idx: int):
        """
        Fetch one layer's cache.

        Args:
            layer_idx: which layer's cache to retrieve

        Returns:
            KVCache object or None if not initialized
        """
        return self.caches[layer_idx]

    def get_seq_len(self) -> int:
        """
        Cached sequence length, read from layer 0 (assumed equal for all layers).

        Returns:
            sequence length, or 0 if no cache exists
        """
        first_layer = self.caches[0]
        return 0 if first_layer is None else first_layer.seq_len

    def get_total_memory_usage(self) -> dict:
        """Sum memory over all initialized layers and report the per-layer average."""
        layer_totals = [
            layer_cache.get_memory_usage()['total_mb']
            for layer_cache in self.caches
            if layer_cache is not None
        ]
        total_mb = sum(layer_totals)
        active = len(layer_totals)

        return {
            'total_mb': total_mb,
            'total_gb': total_mb / 1024,
            'per_layer_mb': total_mb / active if active > 0 else 0,
            'num_active_layers': active
        }
    

In [9]:
class MultiHeadAttentionBatched(nn.Module):
    """Multi-head self-attention with an explicit attention mask and KVCache support.

    Mask convention: nonzero/True = attend, 0/False = masked out.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """
        Args:
            x: [batch, seq_len, d_model] - new tokens
            attention_mask: broadcastable to [batch, num_heads, seq_len, past_len + seq_len]
            kv_cache: this layer's cache; it is extended IN PLACE with the new K,V
            use_cache: when True, return the cache holding all K,V seen so far

        Returns:
            output [batch, seq_len, d_model], and the KVCache (None if use_cache is False)
        """
        batch_size, seq_len, d_model = x.shape

        QKV = self.qkv_proj(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        QKV = QKV.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq_len, head_dim]
        Q, K, V = QKV[0], QKV[1], QKV[2]

        if kv_cache is not None:
            K, V = kv_cache.update(K, V)

        scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)

        blocked = None
        if attention_mask is not None:
            blocked = attention_mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        if blocked is not None:
            # A query row with every key masked (e.g. a left-padding token) gives
            # softmax over all -inf = NaN; zero those rows so NaNs cannot reach
            # later layers through this token's K/V.
            attn_weights = attn_weights.masked_fill(blocked.all(dim=-1, keepdim=True), 0.0)
        attn_output = attn_weights @ V

        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model
        )
        output = self.out_proj(attn_output)

        if not use_cache:
            new_cache = None
        elif kv_cache is not None:
            # update() above already appended the new K,V to this object.
            new_cache = kv_cache
        else:
            new_cache = KVCache(K, V)

        return output, new_cache

In [10]:
def create_causal_mask(seq_len, past_len, device=None):
    """
    Create causal attention mask for autoregressive generation.

    Query row i sits at absolute position past_len + i and may attend to all
    cached keys plus the current keys up to and including itself.

    Args:
        seq_len: length of current sequence
        past_len: length of cached sequence
        device: torch device for the mask (None = PyTorch's default device)

    Returns:
        mask: bool tensor [1, 1, seq_len, past_len + seq_len]
              True = attend, False = mask
    """
    key_pos = torch.arange(past_len + seq_len, device=device)
    query_pos = torch.arange(seq_len, device=device).unsqueeze(1) + past_len

    mask = key_pos.unsqueeze(0) <= query_pos  # [seq_len, past_len + seq_len]
    return mask.unsqueeze(0).unsqueeze(0)

In [11]:
def create_padding_mask(input_length: torch.Tensor, max_len: int, past_len: int = 0):
    """
    Create a left-padding mask for batched sequences with different lengths.

    For sequence b, the first ``max_len - input_length[b]`` positions of the
    key axis (length ``past_len + max_len``) are treated as padding and masked;
    every later position is attended.

    During decoding, pass the total number of real tokens (prompt + generated)
    as ``input_length`` and the total key length (cache + current) as
    ``max_len`` with ``past_len=0`` so the prompt padding stored in the cache
    stays masked.

    Args:
        input_length: [batch] - actual (unpadded) length of each sequence
        max_len: padded length that ``input_length`` is measured against
        past_len: extra key positions prepended to the key axis

    Returns:
        mask: bool tensor [batch, 1, 1, past_len + max_len]
              True = attend (real token), False = mask (padding)
    """
    batch_size = input_length.shape[0]
    total_len = past_len + max_len

    # Build positions on the lengths' device so the comparison works on GPU.
    positions = torch.arange(total_len, device=input_length.device).unsqueeze(0).expand(batch_size, -1)

    pad_amounts = total_len - (input_length + past_len)

    mask = positions >= pad_amounts.unsqueeze(1)

    return mask.unsqueeze(1).unsqueeze(1)

In [12]:
def combine_masks(causal_mask: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    """
    Combine causal and padding masks.

    A position is attended only if both masks allow it (nonzero/True).
    Masks may be bool or numeric 0/1; they are converted to bool first,
    because ``&`` is not defined for floating-point tensors.

    Returns:
        bool mask; broadcasting causal [1, 1, seq_len, total_len] with
        padding [batch, 1, 1, total_len] gives [batch, 1, seq_len, total_len].
    """
    return causal_mask.bool() & padding_mask.bool()

In [13]:
class TransformerLayerWithCache(nn.Module):
    """Pre-norm decoder block built on MultiHeadAttentionBatched.

    residual = x + Dropout(SelfAttn(LayerNorm(x)))
    output   = residual + Dropout(FFN(LayerNorm(residual)))
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.self_attn = MultiHeadAttentionBatched(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False
    ):
        """Run the block; returns (hidden_states, cache from the attention layer)."""
        normed = self.norm1(x)
        attn_out, new_cache = self.self_attn(
            normed,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )
        residual = x + self.dropout1(attn_out)
        ffn_out = self.ffn(self.norm2(residual))
        return residual + self.dropout2(ffn_out), new_cache

In [14]:
class TransformerWithCache(nn.Module):
    """
    Complete Transformer with proper cache management.

    Decoder-only LM: token + learned absolute position embeddings, a stack of
    TransformerLayerWithCache blocks, final LayerNorm and a linear LM head.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len=2048):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)

        self.layers = nn.ModuleList([
            TransformerLayerWithCache(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[MultiLayerKVCache] = None,
        use_cache: bool = False
    ):
        """
        Args:
            input_ids: [batch, seq_len]
            attention_mask: [batch, 1, seq_len, total_len] (1/True = attend) or None.
                When None and seq_len > 1, a causal mask is used.
            kv_cache: MultiLayerKVCache or None
            use_cache: bool

        Returns:
            logits: [batch, seq_len, vocab_size]
            new_cache: MultiLayerKVCache or None

        Raises:
            IndexError: if past_len + seq_len exceeds max_seq_len.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        past_len = kv_cache.get_seq_len() if kv_cache is not None else 0

        max_positions = self.pos_embedding.num_embeddings
        if past_len + seq_len > max_positions:
            # Fail with a clear message instead of an opaque embedding index error
            # (a device-side assert on CUDA).
            raise IndexError(
                f"sequence length {past_len + seq_len} exceeds max_seq_len={max_positions}"
            )

        position_ids = torch.arange(
            past_len, past_len + seq_len,
            device=device
        ).unsqueeze(0).expand(batch_size, -1)

        if attention_mask is None and seq_len > 1:
            # The attention layers apply no masking on their own; default to
            # causal so a multi-token forward matches token-by-token decoding.
            key_pos = torch.arange(past_len + seq_len, device=device)
            causal = key_pos.unsqueeze(0) <= position_ids[0].unsqueeze(1)
            attention_mask = causal.unsqueeze(0).unsqueeze(0)

        x = self.token_embedding(input_ids) + self.pos_embedding(position_ids)
        new_cache = MultiLayerKVCache(self.num_layers) if use_cache else None

        for layer_idx, layer in enumerate(self.layers):
            layer_cache = kv_cache.get(layer_idx) if kv_cache is not None else None

            x, layer_new_cache = layer(
                x,
                attention_mask=attention_mask,
                kv_cache=layer_cache,
                use_cache=use_cache
            )

            if new_cache is not None:
                # If the layer extended the existing cache in place but returned
                # None, keep that cache rather than dropping this layer's history.
                new_cache.caches[layer_idx] = (
                    layer_new_cache if layer_new_cache is not None else layer_cache
                )

        x = self.ln_f(x)
        logits = self.lm_head(x)

        return logits, new_cache

In [15]:
@torch.no_grad()
def generate_batched(
    model: TransformerWithCache,
    prompt_ids_list: List[torch.Tensor],
    max_new_tokens: int,
    temperature: float = 1.0,
    pad_token_id: int = 0,
    eos_token_id: Optional[int] = None
) -> List[torch.Tensor]:
    """
    Batched generation with variable-length prompts.

    Prompts are left-padded to a common length, prefilled in one forward pass,
    then decoded one token per step using the model's KV cache. Padding is
    masked out of attention in both phases.

    NOTE(review): the model derives positions from the cache length, so a
    left-padded prompt's first real token is not at position 0; shorter
    prompts see shifted position embeddings. Fixing that needs explicit
    position ids on the model.

    Args:
        model: TransformerWithCache
        prompt_ids_list: list of [1, prompt_len] tensors (different lengths OK!)
        max_new_tokens: how many tokens to generate
        temperature: sampling temperature; <= 0 selects greedy (argmax) decoding
        pad_token_id: ID for padding token (also emitted by finished sequences)
        eos_token_id: optional end-of-sequence ID; a sequence is finished after
            emitting it, and generation stops early once all sequences are.

    Returns:
        list of generated sequences [1, total_len] with the left padding removed;
        when eos_token_id is given, each is cut right after its first generated EOS.
    """
    model.eval()
    device = next(model.parameters()).device

    batch_size = len(prompt_ids_list)
    if batch_size == 0:
        return []

    # Lengths stay on the CPU: masks are built there and moved to `device` once.
    prompt_lengths = torch.tensor([p.shape[1] for p in prompt_ids_list])
    max_prompt_len = prompt_lengths.max().item()

    print(f"\n{'='*60}")
    print(f"Batched Generation: {batch_size} sequences")
    print(f"Prompt lengths: {prompt_lengths.tolist()}")
    print(f"Max prompt: {max_prompt_len}, Generating: {max_new_tokens}")
    print(f"{'='*60}\n")

    padded_prompts = []
    for prompt in prompt_ids_list:
        prompt = prompt.to(device)
        pad_len = max_prompt_len - prompt.shape[1]
        if pad_len > 0:
            padding = torch.full((1, pad_len), pad_token_id, device=device, dtype=prompt.dtype)
            prompt = torch.cat([padding, prompt], dim=1)
        padded_prompts.append(prompt)

    input_ids = torch.cat(padded_prompts, dim=0)

    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    kv_cache = None

    all_generated = input_ids.clone()

    for step in range(max_new_tokens):
        past_len = kv_cache.get_seq_len() if kv_cache is not None else 0

        if step == 0:
            current_input = input_ids
            seq_len = max_prompt_len

            padding_mask = create_padding_mask(prompt_lengths, max_prompt_len, past_len=0)

            print(f"PREFILL: Processing {seq_len} tokens (batch={batch_size})")
        else:
            current_input = all_generated[:, -1:]
            seq_len = 1

            # Real tokens so far = prompt + `step` generated ones, and all padding
            # sits at the front of the key axis (cache + current token).
            total_lengths = prompt_lengths + step
            padding_mask = create_padding_mask(total_lengths, past_len + seq_len, past_len=0)

            print(f" DECODE {step}: cache_len={past_len}")

        causal_mask = create_causal_mask(seq_len, past_len=past_len).bool()
        attention_mask = combine_masks(causal_mask, padding_mask)
        if step == 0:
            # Let every query (including pad tokens) see itself so no row is fully
            # masked; an all -inf row makes softmax return NaN.
            attention_mask = attention_mask | torch.eye(seq_len).bool()
        attention_mask = attention_mask.to(device)

        logits, kv_cache = model(
            current_input,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=True
        )

        next_token_logits = logits[:, -1, :]
        if temperature > 0:
            probs = torch.softmax(next_token_logits / temperature, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1)
        else:
            next_tokens = next_token_logits.argmax(dim=-1, keepdim=True)

        # Finished sequences only emit padding.
        next_tokens = next_tokens.masked_fill(finished.unsqueeze(1), pad_token_id)

        all_generated = torch.cat([all_generated, next_tokens], dim=1)

        if eos_token_id is not None:
            finished |= next_tokens.squeeze(1) == eos_token_id
            if finished.all():
                break

    print(f"\n Generation complete!")
    print(f"Final cache: {kv_cache}")

    generated_list = []
    for i, prompt_len in enumerate(prompt_lengths.tolist()):
        seq = all_generated[i:i + 1, max_prompt_len - prompt_len:]  # drop left padding
        if eos_token_id is not None:
            eos_hits = (seq[0, prompt_len:] == eos_token_id).nonzero()
            if eos_hits.numel() > 0:
                seq = seq[:, :prompt_len + int(eos_hits[0, 0]) + 1]
        generated_list.append(seq)

    return generated_list

#### Memory-Efficient Cache Updates

In [16]:
class SlidingWindowCache(MultiLayerKVCache):
    """
    KV Cache with sliding window - keeps only last N tokens.

    Once a layer's cache grows beyond ``window_size`` tokens, it is trimmed to
    the first ``num_sink_tokens`` positions (kept permanently as "attention
    sinks") plus the most recent ``window_size - num_sink_tokens`` positions.
    """

    def __init__(self, num_layers: int, window_size: int, num_sink_tokens: int = 4):
        super().__init__(num_layers)
        if num_sink_tokens < 0:
            raise ValueError(f"num_sink_tokens must be >= 0, got {num_sink_tokens}")
        if window_size <= num_sink_tokens:
            # keep_len would be <= 0; a slice like [-0:] keeps everything, so the
            # window would silently never shrink.
            raise ValueError(
                f"window_size ({window_size}) must be greater than num_sink_tokens ({num_sink_tokens})"
            )
        self.window_size = window_size
        self.num_sink_tokens = num_sink_tokens

    def update(self, layer_idx: int, new_key: torch.Tensor, new_value: torch.Tensor):
        """
        Append new K,V to a layer and apply the sliding window.

        Args:
            layer_idx: which layer to update
            new_key: [batch, num_heads, new_len, head_dim]
            new_value: [batch, num_heads, new_len, head_dim]

        Returns:
            (key, value) tensors now stored for the layer. The first insert for a
            layer is stored as-is; trimming applies from the following update on.
        """
        cache = self.caches[layer_idx]
        if cache is None:
            self.caches[layer_idx] = KVCache(new_key, new_value)
            return new_key, new_value

        updated_k = torch.cat([cache.key_cache, new_key], dim=2)
        updated_v = torch.cat([cache.val_cache, new_value], dim=2)

        if updated_k.shape[2] > self.window_size:
            updated_k = self._trim(updated_k)
            updated_v = self._trim(updated_v)

        self.caches[layer_idx] = KVCache(updated_k, updated_v)
        return updated_k, updated_v

    def _trim(self, tensor: torch.Tensor) -> torch.Tensor:
        """Keep the sink prefix and the most recent tokens along the sequence dim (2)."""
        keep_len = self.window_size - self.num_sink_tokens  # > 0, enforced in __init__
        sink = tensor[:, :, :self.num_sink_tokens, :]
        recent = tensor[:, :, -keep_len:, :]
        return torch.cat([sink, recent], dim=2)

    def __repr__(self):
        base_repr = super().__repr__()
        return f"SlidingWindow{base_repr}, window={self.window_size}, sinks={self.num_sink_tokens}"


#### KV Cache Quantization (INT8/FP8)

In [17]:
class QuantizedKVCache:
    """
    INT8 quantized KV cache with adaptive re-quantization.
    Re-quantizes only when new tokens exceed current scale.

    One symmetric FP32 scale is kept for all keys and one for all values.
    """
    def __init__(self, key_cache_fp: torch.Tensor, val_cache_fp: torch.Tensor, quantize: bool = True):
        """
        Args:
            key_cache_fp: [batch, num_heads, seq_len, head_dim] floating point
            val_cache_fp: [batch, num_heads, seq_len, head_dim] floating point
            quantize: whether to quantize (False for testing)
        """
        self.device = key_cache_fp.device
        self.dtype_original = key_cache_fp.dtype  # restored by get_kv_fp()

        if quantize:
            self.key_cache, self.key_scale = self._quantize_to_int8(key_cache_fp)
            self.val_cache, self.val_scale = self._quantize_to_int8(val_cache_fp)
        else:
            self.key_cache = key_cache_fp
            self.val_cache = val_cache_fp
            self.key_scale = None
            self.val_scale = None

        self.is_quantized = quantize
        self.num_requantizations = 0

    def _quantize_to_int8(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Symmetric per-tensor quantization to INT8.

        The input is upcast to FP32 first so the scale is always FP32 and the
        rounding also works for half-precision inputs.

        Args:
            tensor: floating point tensor to quantize

        Returns:
            quantized: INT8 tensor
            scale: FP32 scalar scale factor
        """
        q_max = 127.0
        tensor = tensor.to(torch.float32)

        max_abs = tensor.abs().max()

        if max_abs == 0:
            # All zeros: any scale works; 1.0 avoids dividing by zero.
            scale = torch.tensor(1.0, device=tensor.device, dtype=torch.float32)
        else:
            scale = max_abs / q_max

        quantized = torch.clamp(
            torch.round(tensor / scale),
            -127, 127
        ).to(torch.int8)

        return quantized, scale

    def _dequantize_from_int8(self, quantized: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Dequantize INT8 tensor back to FP32."""
        return quantized.to(torch.float32) * scale

    def update(self, new_key: torch.Tensor, new_val: torch.Tensor):
        """
        Append new K,V to cache with adaptive re-quantization.

        Re-quantizes entire cache only if new tokens have larger scale.
        Note that re-quantizing starts from the already-dequantized cache, so
        rounding error can accumulate across re-quantizations.
        """
        if self.is_quantized:
            new_key = new_key.to(torch.float32)
            new_val = new_val.to(torch.float32)

            # Compute what scales the new tokens would need
            _, new_k_scale = self._quantize_to_int8(new_key)
            _, new_v_scale = self._quantize_to_int8(new_val)

            # Check if we need to re-quantize
            needs_requant_k = new_k_scale > self.key_scale
            needs_requant_v = new_v_scale > self.val_scale

            if needs_requant_k or needs_requant_v:
                # Re-quantize entire cache with new data
                print(f"Re-quantizing cache (K: {needs_requant_k}, V: {needs_requant_v})")
                self.num_requantizations += 1

                # Dequantize existing cache
                K_fp = self._dequantize_from_int8(self.key_cache, self.key_scale)
                V_fp = self._dequantize_from_int8(self.val_cache, self.val_scale)

                # Concatenate in FP32
                K_fp = torch.cat([K_fp, new_key], dim=2)
                V_fp = torch.cat([V_fp, new_val], dim=2)

                # Re-quantize entire cache with new unified scale
                self.key_cache, self.key_scale = self._quantize_to_int8(K_fp)
                self.val_cache, self.val_scale = self._quantize_to_int8(V_fp)
            else:
                # Use existing scale (safe - won't clip)
                q_key = torch.clamp(
                    torch.round(new_key / self.key_scale),
                    -127, 127
                ).to(torch.int8)
                q_val = torch.clamp(
                    torch.round(new_val / self.val_scale),
                    -127, 127
                ).to(torch.int8)

                self.key_cache = torch.cat([self.key_cache, q_key], dim=2)
                self.val_cache = torch.cat([self.val_cache, q_val], dim=2)
        else:
            self.key_cache = torch.cat([self.key_cache, new_key], dim=2)
            self.val_cache = torch.cat([self.val_cache, new_val], dim=2)

    def get_kv_fp(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get K,V in original floating point precision (the dtype passed to __init__)."""
        if self.is_quantized:
            K = self._dequantize_from_int8(self.key_cache, self.key_scale).to(self.dtype_original)
            V = self._dequantize_from_int8(self.val_cache, self.val_scale).to(self.dtype_original)
        else:
            K = self.key_cache
            V = self.val_cache

        return K, V

    def get_memory_usage(self) -> dict:
        """Calculate memory usage"""
        k_bytes = self.key_cache.numel() * self.key_cache.element_size()
        v_bytes = self.val_cache.numel() * self.val_cache.element_size()
        if self.is_quantized:
            scale_bytes = (
                self.key_scale.numel() * self.key_scale.element_size()
                + self.val_scale.numel() * self.val_scale.element_size()
            )
        else:
            scale_bytes = 0

        total_mb = (k_bytes + v_bytes + scale_bytes) / (1024 ** 2)

        return {
            'key_mb': k_bytes / (1024 ** 2),
            'value_mb': v_bytes / (1024 ** 2),
            'scale_mb': scale_bytes / (1024 ** 2),
            'total_mb': total_mb,
            'num_requantizations': self.num_requantizations
        }

    @property
    def seq_len(self) -> int:
        """Current sequence length in cache"""
        return self.key_cache.shape[2]

In [18]:
class QuantizedKVCachePerToken:
    """
    INT8 quantized KV cache with per-token scales.

    Every (batch, head, position) vector of head_dim values gets its own
    symmetric FP32 scale. Scales are stored in ``key_scales`` / ``val_scales``
    as lists with one [batch, num_heads, 1, 1] tensor per sequence position.
    """
    def __init__(self, key_cache_fp: torch.Tensor, val_cache_fp: torch.Tensor, quantize: bool = True):
        """
        Args:
            key_cache_fp: [batch, num_heads, seq_len, head_dim] floating point
            val_cache_fp: [batch, num_heads, seq_len, head_dim] floating point
            quantize: whether to quantize (False for testing)
        """
        self.device = key_cache_fp.device
        self.dtype_original = key_cache_fp.dtype  # restored by get_kv_fp()
        self.is_quantized = quantize

        if quantize:
            self.key_cache, self.key_scales = self._quantize_sequence(key_cache_fp)
            self.val_cache, self.val_scales = self._quantize_sequence(val_cache_fp)
        else:
            self.key_cache = key_cache_fp
            self.val_cache = val_cache_fp
            self.key_scales = None
            self.val_scales = None

    def _quantize_to_int8_per_token(self, token: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize along head_dim with one scale per (batch, head, position).

        Works for any number of positions (vectorized; no Python loop).

        Args:
            token: [batch, num_heads, n, head_dim]

        Returns:
            quantized: INT8 tensor [batch, num_heads, n, head_dim]
            scale: FP32 tensor [batch, num_heads, n, 1]
        """
        q_max = 127.0
        token = token.to(torch.float32)  # FP32 scales; rounding also works for half inputs

        max_abs = token.abs().amax(dim=-1, keepdim=True)

        # An all-zero vector would divide by zero; any scale works there, use 1.0.
        scale = torch.where(max_abs == 0, torch.ones_like(max_abs), max_abs / q_max)

        quantized = torch.clamp(
            torch.round(token / scale),
            -127, 127
        ).to(torch.int8)

        return quantized, scale

    def _quantize_sequence(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, list]:
        """Quantize all positions at once and split the scales into a per-position list."""
        quantized, scale = self._quantize_to_int8_per_token(tensor)
        scales = [scale[:, :, i:i + 1, :] for i in range(scale.shape[2])]
        return quantized, scales

    def update(self, new_key: torch.Tensor, new_val: torch.Tensor):
        """
        Append new K,V to cache, quantizing each token separately.

        Args:
            new_key: [batch, num_heads, seq_len_new, head_dim]
            new_val: [batch, num_heads, seq_len_new, head_dim]
        """
        if self.is_quantized:
            q_k, k_scales = self._quantize_sequence(new_key)
            q_v, v_scales = self._quantize_sequence(new_val)

            self.key_scales.extend(k_scales)
            self.val_scales.extend(v_scales)

            self.key_cache = torch.cat([self.key_cache, q_k], dim=2)
            self.val_cache = torch.cat([self.val_cache, q_v], dim=2)
        else:
            self.key_cache = torch.cat([self.key_cache, new_key], dim=2)
            self.val_cache = torch.cat([self.val_cache, new_val], dim=2)

    def get_kv_fp(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get K,V in original floating point precision (the dtype passed to __init__)."""
        if not self.is_quantized:
            return self.key_cache, self.val_cache

        if not self.key_scales:
            # Empty cache: nothing to rescale, just return empty tensors of the right dtype.
            return self.key_cache.to(self.dtype_original), self.val_cache.to(self.dtype_original)

        # Stack per-position scales to [batch, num_heads, seq_len, 1] and rescale in one op.
        k_scale = torch.cat(self.key_scales, dim=2)
        v_scale = torch.cat(self.val_scales, dim=2)

        K = (self.key_cache.to(torch.float32) * k_scale).to(self.dtype_original)
        V = (self.val_cache.to(torch.float32) * v_scale).to(self.dtype_original)

        return K, V

    def get_memory_usage(self) -> dict:
        """Calculate memory usage"""
        k_bytes = self.key_cache.numel() * self.key_cache.element_size()
        v_bytes = self.val_cache.numel() * self.val_cache.element_size()
        if self.is_quantized:
            # One [batch, num_heads, 1, 1] scale tensor per position, for K and V.
            scale_bytes = sum(
                s.numel() * s.element_size() for s in self.key_scales + self.val_scales
            )
        else:
            scale_bytes = 0

        total_mb = (k_bytes + v_bytes + scale_bytes) / (1024 ** 2)

        return {
            'key_mb': k_bytes / (1024 ** 2),
            'value_mb': v_bytes / (1024 ** 2),
            'scale_mb': scale_bytes / (1024 ** 2),
            'total_mb': total_mb,
            'num_scales': len(self.key_scales) if self.is_quantized else 0
        }

    @property
    def seq_len(self) -> int:
        """Current sequence length in cache"""
        return self.key_cache.shape[2]

### PagedAttention Concepts (vLLM)

PagedAttention Key Ideas:

1. **Problem with Contiguous Cache:**
   - Traditional: KV cache is contiguous tensor
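   - Issue: each sequence reserves one contiguous region (often sized for the maximum length), which causes memory fragmentation and waste; cache entries also can't easily be shared between sequences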
   
2. **PagedAttention Solution:**
   - Split KV cache into fixed-size "pages" (blocks)
   - Like virtual memory in OS!
   - Pages can be non-contiguous in physical memory
   
3. **Benefits:**
   - Near-zero memory fragmentation
   - Easy cache sharing (beam search, parallel sampling)
   - Dynamic memory allocation

Visual: 
1. Traditional:
- [Seq1: KKKKKK...] [Seq2: KKKKKK...] [Wasted Space] [Seq3: KKK...]

2. PagedAttention:
- Physical Memory: [Block0][Block1][Block2][Block3][Block4][Block5]
- Seq1 mapping: Block0 → Block2 → Block4
- Seq2 mapping: Block1 → Block3
- Seq3 mapping: Block5
(No external fragmentation; the only waste is the unused tail of each sequence's last block.)

In [19]:
class PagedKVCache:
    """
    Block-based ("paged") KV cache in the style of vLLM's PagedAttention.

    One pre-allocated pool ``blocks`` of shape
    [num_blocks, 2, block_size, num_heads, head_dim] stores K (index 0 on
    dim 1) and V (index 1) for every sequence.  Each sequence owns an ordered
    list of block ids (its block table), so its tokens need not be contiguous
    in the pool.
    """

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        num_heads: int,
        head_dim: int,
        dtype=torch.float32,
        device=None,
    ):
        """
        Args:
            num_blocks: Total number of blocks in memory pool
            block_size: Number of tokens per block
            num_heads: Number of KV heads
            head_dim: Dimension per head
            dtype: dtype of the block pool
            device: device of the block pool (None -> torch's default device)
        """
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Physical pool: [block, K/V, slot within block, head, head_dim]
        self.blocks = torch.zeros(
            num_blocks, 2, block_size, num_heads, head_dim, dtype=dtype, device=device
        )

        # Ids of blocks not currently owned by any sequence
        self.free_blocks = set(range(num_blocks))

        # seq_id -> ordered list of physical block ids (logical block i -> id)
        self.seq_to_blocks = {}
        # seq_id -> number of tokens written so far
        self.seq_lengths = {}

    def _take_free_blocks(self, count: int) -> list:
        """Pop ``count`` ids from the free pool (caller has checked availability)."""
        return [self.free_blocks.pop() for _ in range(count)]

    def allocate_blocks(self, seq_id: int, num_blocks_needed: int) -> bool:
        """
        Reserve blocks for a new sequence and start it at length 0.

        Returns:
            True if successful; False if there are not enough free blocks or
            ``seq_id`` already owns blocks (re-allocating would leak them).
        """
        if seq_id in self.seq_to_blocks:
            return False
        if len(self.free_blocks) < num_blocks_needed:
            return False

        self.seq_to_blocks[seq_id] = self._take_free_blocks(num_blocks_needed)
        self.seq_lengths[seq_id] = 0
        return True

    def append_tokens(self, seq_id: int, new_k: torch.Tensor, new_v: torch.Tensor) -> bool:
        """
        Append new K,V to sequence's cache, growing its block table if needed.

        Args:
            seq_id: Sequence ID
            new_k: [num_tokens, num_heads, head_dim]
            new_v: [num_tokens, num_heads, head_dim]

        Returns:
            True if the tokens were written; False if ``seq_id`` is unknown or
            the pool has too few free blocks (nothing is written in that case).

        Raises:
            ValueError: if new_k and new_v hold different numbers of tokens.
        """
        if seq_id not in self.seq_to_blocks or seq_id not in self.seq_lengths:
            return False

        num_new_tokens = new_k.shape[0]
        if new_v.shape[0] != num_new_tokens:
            raise ValueError(
                f"new_k has {num_new_tokens} tokens but new_v has {new_v.shape[0]}"
            )

        block_table = self.seq_to_blocks[seq_id]
        current_len = self.seq_lengths[seq_id]
        new_len = current_len + num_new_tokens

        # Ceil division: blocks required to hold new_len tokens
        blocks_needed = (new_len + self.block_size - 1) // self.block_size
        missing = blocks_needed - len(block_table)
        if missing > 0:
            # Check before popping so a failed append leaves the pool untouched
            if len(self.free_blocks) < missing:
                return False
            block_table.extend(self._take_free_blocks(missing))

        # Copy in runs that stay inside one block: one slice assignment per
        # block touched instead of one per token.
        written = 0
        while written < num_new_tokens:
            token_idx = current_len + written
            block_id = block_table[token_idx // self.block_size]
            offset = token_idx % self.block_size
            run = min(self.block_size - offset, num_new_tokens - written)

            self.blocks[block_id, 0, offset:offset + run] = new_k[written:written + run]
            self.blocks[block_id, 1, offset:offset + run] = new_v[written:written + run]
            written += run

        self.seq_lengths[seq_id] = new_len
        return True

    def get_kv(self, seq_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Gather K,V for a sequence from its blocks.

        Returns:
            K: [seq_len, num_heads, head_dim]
            V: [seq_len, num_heads, head_dim]
            (both have 0 rows when nothing has been appended yet)

        Raises:
            ValueError: if ``seq_id`` has no allocated blocks.
        """
        if seq_id not in self.seq_to_blocks:
            raise ValueError(f"Sequence {seq_id} not found")

        seq_len = self.seq_lengths[seq_id]
        block_ids = self.seq_to_blocks[seq_id]

        K_list = []
        V_list = []

        for block_idx, block_id in enumerate(block_ids):
            start_token = block_idx * self.block_size
            num_tokens = min(self.block_size, seq_len - start_token)
            if num_tokens <= 0:
                # This and all later blocks are reserved but not yet filled
                break
            K_list.append(self.blocks[block_id, 0, :num_tokens])
            V_list.append(self.blocks[block_id, 1, :num_tokens])

        if not K_list:
            # torch.cat rejects an empty list; return empty [0, heads, dim] tensors
            empty = self.blocks.new_zeros((0, self.num_heads, self.head_dim))
            return empty, empty.clone()

        K = torch.cat(K_list, dim=0)
        V = torch.cat(V_list, dim=0)

        return K, V

    def free_sequence(self, seq_id: int):
        """Return all blocks of ``seq_id`` to the free pool; no-op if unknown."""
        block_ids = self.seq_to_blocks.pop(seq_id, None)
        if block_ids is None:
            return
        self.free_blocks.update(block_ids)
        self.seq_lengths.pop(seq_id, None)


#### Grouped Query Attention

In [20]:
class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention (GQA), the scheme used by LLaMA-2 and Mistral.

    The ``num_query_heads`` query heads are split into ``num_kv_heads`` equal
    groups.  Every query head in a group attends with the same K/V head, so
    the K/V projections, and therefore the KV cache, are
    ``num_query_heads // num_kv_heads`` times narrower than in standard
    multi-head attention.
    """

    def __init__(
        self,
        d_model: int,
        num_query_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0
    ):
        super().__init__()

        assert d_model % num_query_heads == 0, "d_model must be divisible by num_query_heads"
        assert num_query_heads % num_kv_heads == 0, "num_query_heads must be divisible by num_kv_heads"

        head_dim = d_model // num_query_heads

        self.d_model = d_model
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # How many query heads share each K/V head
        self.num_queries_per_kv = num_query_heads // num_kv_heads

        q_width = num_query_heads * head_dim
        kv_width = num_kv_heads * head_dim  # narrower than q_width when grouping
        # Modules are created in q, k, v, o order so seeded initialisation and
        # state_dict keys are unchanged.
        self.W_q = nn.Linear(d_model, q_width, bias=False)
        self.W_k = nn.Linear(d_model, kv_width, bias=False)
        self.W_v = nn.Linear(d_model, kv_width, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Attend from the new tokens in ``x`` to the cached plus new keys.

        Args:
            x: [batch, seq_len, d_model] hidden states of the new tokens
            attention_mask: broadcastable to [batch, num_query_heads, seq_len,
                total_len]; positions where it equals 0 are masked out.
                No mask (including no causal mask) is applied when None.
            kv_cache: (K, V), each [batch, num_kv_heads, past_len, head_dim]
            use_cache: if True, also return the updated (K, V)

        Returns:
            output: [batch, seq_len, d_model]
            new_cache: un-repeated (K, V), each [batch, num_kv_heads,
                past_len + seq_len, head_dim], or None when use_cache is False
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(projected, heads):
            # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
            return projected.view(batch_size, seq_len, heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.W_q(x), self.num_query_heads)
        keys = split_heads(self.W_k(x), self.num_kv_heads)
        values = split_heads(self.W_v(x), self.num_kv_heads)

        if kv_cache is not None:
            past_keys, past_values = kv_cache
            keys = torch.cat((past_keys, keys), dim=2)
            values = torch.cat((past_values, values), dim=2)

        # Duplicate each K/V head for every query head in its group:
        # [batch, num_kv_heads, L, head_dim] -> [batch, num_query_heads, L, head_dim]
        group = self.num_queries_per_kv
        keys_per_query = keys.repeat_interleave(group, dim=1)
        values_per_query = values.repeat_interleave(group, dim=1)

        logits = torch.matmul(queries, keys_per_query.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            logits = logits.masked_fill(attention_mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))
        context = torch.matmul(weights, values_per_query)

        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # The cache keeps the compact (un-repeated) K/V to save memory.
        cache_out = (keys, values) if use_cache else None
        return self.W_o(merged), cache_out

In [21]:
class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (MQA): the limiting case of grouped-query attention
    with a single K/V head shared by every query head.
    Used in PaLM, StarCoder and Falcon.
    """

    def __init__(self, d_model: int, num_query_heads: int, dropout: float = 0.0):
        super().__init__()

        assert d_model % num_query_heads == 0

        head_dim = d_model // num_query_heads

        self.d_model = d_model
        self.num_query_heads = num_query_heads
        self.head_dim = head_dim

        # Full-width query projection; K and V project to one head only.
        # Creation order (q, k, v, o) is kept so seeded initialisation and
        # state_dict keys are unchanged.
        self.W_q = nn.Linear(d_model, num_query_heads * head_dim, bias=False)
        self.W_k = nn.Linear(d_model, head_dim, bias=False)
        self.W_v = nn.Linear(d_model, head_dim, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Attend from the new tokens in ``x`` to the cached plus new keys.

        Args:
            x: [batch, seq_len, d_model] hidden states of the new tokens
            attention_mask: broadcastable to [batch, num_query_heads, seq_len,
                total_len]; positions where it equals 0 are masked out
            kv_cache: (K, V) where K,V are [batch, 1, past_len, head_dim]
            use_cache: if True, also return the updated (K, V)

        Returns:
            output: [batch, seq_len, d_model]
            new_cache: (K, V), each [batch, 1, past_len + seq_len, head_dim],
                or None when use_cache is False
        """
        batch_size, seq_len, _ = x.shape
        scale = math.sqrt(self.head_dim)

        # Q: [batch, num_query_heads, seq_len, head_dim]
        queries = (
            self.W_q(x)
            .view(batch_size, seq_len, self.num_query_heads, self.head_dim)
            .transpose(1, 2)
        )
        # Shared K/V with a singleton head axis: [batch, 1, seq_len, head_dim]
        keys = self.W_k(x).unsqueeze(1)
        values = self.W_v(x).unsqueeze(1)

        if kv_cache is not None:
            past_keys, past_values = kv_cache
            keys = torch.cat((past_keys, keys), dim=2)
            values = torch.cat((past_values, values), dim=2)

        # The size-1 head axis of keys/values broadcasts against the
        # num_query_heads axis of queries inside matmul, so no explicit
        # expand/repeat is needed.
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        if attention_mask is not None:
            logits = logits.masked_fill(attention_mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))
        context = torch.matmul(weights, values)

        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        cache_out = (keys, values) if use_cache else None
        return self.W_o(merged), cache_out

### Flash Attention

Flash Attention Key Ideas:

Standard Attention Problem:
1. Compute scores: Q @ K^T (materialize NxN matrix)
2. Softmax: softmax(scores) (materialize NxN matrix)
3. Output: attn @ V

Memory: O(N^2) - stores full attention matrix

Flash Attention Solution:
- Tile Q, K, V
- Compute attention in blocks
- Never materialize full NxN matrix
- Memory: O(N) instead of O(N^2)

With KV Cache:
- Flash Attention + KV cache is THE standard for production
- vLLM, TensorRT-LLM all use this combination

Flash Attention + KV Cache Best Practices:

1. **Always use Flash Attention for long sequences (>512 tokens)**
   - 2-4x speedup on attention computation
   - Enables longer contexts (up to 100k+ tokens)

2. **Combine with GQA for maximum efficiency**
   - GQA reduces cache size
   - Flash Attention reduces memory traffic (exact attention with the same FLOPs, but far fewer reads/writes to GPU HBM)
   - Together: massive throughput gains

3. **Production Stack:**
   - vLLM: PagedAttention + Flash Attention + GQA + optional FP8 KV cache
   - TensorRT-LLM: Flash Attention + GQA + FP8 cache
   - SGLang: RadixAttention (prefix caching) + Flash Attention

4. **When to use what:**
   - Batch=1, latency-critical: Flash Attention + FP16 cache
   - High throughput serving: Flash + PagedAttention + quantized (FP8/INT8) cache
   - Long context (>8k): Flash + GQA + sliding window

#### Continuous Batching Considerations

In [58]:
"""
Continuous Batching (used in vLLM, TensorRT-LLM):

Traditional Batching:
- Wait for N requests
- Process all together
- Requests finish at different times, but the batch runs until the longest one is done → wasted compute

Continuous Batching:
- As soon as one request finishes, add new one
- Dynamically change batch composition
- Maximum GPU utilization

KV Cache Challenges:
1. Variable sequence lengths in batch
2. Sequences finishing at different times
3. Need to efficiently add/remove from batch

Solutions:
1. PagedAttention: Easy to swap sequences in/out
2. Padding + masking: Handle variable lengths
3. Separate prefill and decode batches
"""

class ContinuousBatchManager:
    """
    Manages continuous batching with KV cache.
    Simplified version of what vLLM does.

    Tracks in-flight sequences and reserves/frees their blocks in a shared
    PagedKVCache as sequences join and leave the batch.
    """

    def __init__(self, model, max_batch_size: int, paged_cache: PagedKVCache):
        self.model = model
        self.max_batch_size = max_batch_size
        self.cache = paged_cache

        # Active sequences: seq_id -> {'prompt', 'generated', 'finished'}
        self.active_sequences = {}

    def add_sequence(self, seq_id: int, prompt: torch.Tensor) -> bool:
        """
        Admit a new sequence if there is a free batch slot and enough cache.

        Args:
            seq_id: Caller-chosen id; must not already be active.
            prompt: Prompt tensor; only ``prompt.shape[0]`` is used here, as
                the prompt length (presumably 1-D token ids -- confirm).

        Returns:
            True if admitted; False if ``seq_id`` is already active, the
            batch is full, or the cache cannot reserve blocks for the prompt.
        """
        if seq_id in self.active_sequences:
            # Re-adding would overwrite the entry and orphan its cache blocks
            return False
        if len(self.active_sequences) >= self.max_batch_size:
            return False  # Batch full

        # Ceil division: blocks needed to hold the whole prompt
        block_size = self.cache.block_size
        prompt_len = prompt.shape[0]
        blocks_needed = (prompt_len + block_size - 1) // block_size

        if not self.cache.allocate_blocks(seq_id, blocks_needed):
            return False  # Not enough cache memory

        self.active_sequences[seq_id] = {
            'prompt': prompt,
            'generated': [],
            'finished': False
        }

        return True

    def remove_sequence(self, seq_id: int):
        """Remove a sequence from the batch and free its cache blocks; no-op if unknown."""
        if seq_id in self.active_sequences:
            self.cache.free_sequence(seq_id)
            del self.active_sequences[seq_id]

    def step(self):
        """
        One generation step for all active sequences (not implemented).

        In practice, this would:
        1. Separate prefill vs decode
        2. Batch compatible sequences together
        3. Use PagedAttention to gather K,V
        """
        # Placeholder: intentionally does nothing in this simplified sketch
        pass

# Summary of why paged KV caches suit continuous batching.
# The emoji was previously mis-encoded (UTF-8 bytes decoded as cp1252).
print("""
💡 Continuous Batching + KV Cache:

Key Insight: PagedAttention makes continuous batching efficient!

Without PagedAttention:
- Hard to dynamically resize cache tensors
- Memory fragmentation when swapping sequences
- Copying overhead

With PagedAttention:
- Just update block mappings (O(1) operation)
- No memory copying
- Perfect for continuous batching

This is why vLLM achieves 10-20x higher throughput than HuggingFace!
""")


💡 Continuous Batching + KV Cache:

Key Insight: PagedAttention makes continuous batching efficient!

Without PagedAttention:
- Hard to dynamically resize cache tensors
- Memory fragmentation when swapping sequences
- Copying overhead

With PagedAttention:
- Just update block mappings (O(1) operation)
- No memory copying
- Perfect for continuous batching

This is why vLLM achieves 10-20x higher throughput than HuggingFace!

