### KV Cache - Version 1 (Simple)

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [12]:
class MultiHeadAttentionWithCache(nn.Module):
    def __init__(self, num_heads, d_model):
        super().__init__()

        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.d_model = d_model

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None, use_cache=None):
        """
        Args:
            x: [batch, seq_len, d_model] - input tokens
            kv_cache: tuple of (K_cache, V_cache) or None
                K_cache: [batch, num_heads, past_len, head_dim]
                V_cache: [batch, num_heads, past_len, head_dim]
            use_cache: whether to return updated cache
            
        Returns:
            output: [batch, seq_len, d_model]
            new_cache: tuple of (K, V) if use_cache else None
        """

        batch_size, seq_len, d_model = x.shape

        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dimension).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dimension).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dimension).transpose(1, 2)

        if use_cahe:
            K_all, V_all = kv_cache
            K = torch.cat([K_all, K], dim=2)
            V = torch.cat([V_all, V], dim=2)

        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim)

        if seq_len > 1:
            causal_mask = torch.tril(
                torch.ones(seq_len, K.shape[2])
            ).unsqueeze(0).unsqueeze(0)
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_output = attn_weights @ V

        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model
        )

        output = self.W_o(attn_output)

        if use_cache:
            return output, (K, V)
        else:
            return output, None

In [13]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, num_heads, d_model, d_ff, dropout=0.1):
        super().__init__()

        self.attn_layer = MultiHeadAttentionWithCache(num_heads, d_model)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GeLU(),
            nn.Linear(d_ff, d_model)
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, kv_cache=None, use_cache=False):
        """
        Args:
            x: [batch, seq_len, d_model]
            kv_cache: tuple of (K_cache, V_cache) or None
            use_cache: bool
            
        Returns:
            output: [batch, seq_len, d_model]
            new_cache: tuple or None
        """
        x_skip = x
        x = self.norm1(x)
        x, new_kv_cache = self.attn_layer(x)
        x = x_skip + self.dropout1(x)


        x_skip = x
        x = self.norm2(x)
        x = self.ffn(x)
        x = x_skip + self.dropout2(x)

        return x, new_kv_cache

In [None]:
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, num_heads, d_model, num_layers, d_ff, max_seq_len=512):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.d_model = d_model
        self.num_layers = num_layers

        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Embedding(max_seq_len, d_model)

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(num_heads, d_model, d_ff)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, past_kv_caches=None, use_cache=False):
        """
        Args:
            input_ids: [batch, seq_len] - token indices
            past_kv_caches: list of (K_cache, V_cache) tuples, one per layer
            use_cache: bool
            
        Returns:
            logits: [batch, seq_len, vocab_size]
            new_kv_caches: list of (K, V) tuples if use_cache else None
        """
        batch_size, seq_len = input_ids.shape

        if past_kv_caches is not None and past_kv_caches[0] is not None:
            past_len = past_kv_caches[0][0].shape[2]
        else:
            past_len = 0

        positions = torch.arange(
            past_len, 
            past_len + seq_len
        ).unsqueeze(0)

        x = self.token_embeddings(x) + self.positional_encoding(x)

        new_kv_caches = []
        for i, layer in enumerate(self.layers):
            layer_cache = past_kv_caches[i] if past_kv_caches is not None else None
            
            x, new_cache = layer(x, kv_cache=layer_caches, use_cache=use_cache)

            if use_cache:
                new_kv_caches.append(new_cache)

        x = self.ln_norm(x)
        logits = self.lm_head(x)

        return logits, new_kv_caches if use else None

In [14]:
@torch.no_grad()
def generate_with_cache(model, prompt_ids, max_new_tokens, temperature=1.0):
    """
    Autoregressive generation using KV cache.
    
    Args:
        model: SimpleTransformer model
        prompt_ids: [batch, prompt_len] - input token IDs
        max_new_tokens: number of tokens to generate
        temperature: sampling temperature
        
    Returns:
        generated_ids: [batch, prompt_len + max_new_tokens]
    """
    model.eval()
    batch_size = prompt_ids.shape[0]
    
    generated = prompt_ids.clone()
    
    past_kv_caches = None
    
    for step in range(max_new_tokens):
        if step == 0:
            input_ids = prompt_ids
        else:
            input_ids = generated[:, -1:] 
        
        logits, past_kv_caches = model(
            input_ids, 
            past_kv_caches=past_kv_caches,
            use_cache=True
        )
        
        next_token_logits = logits[:, -1, :] / temperature
        
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        
        generated = torch.cat([generated, next_token], dim=1)
    
    return generated

### Multi-Layer & Batched KV Cache

In [15]:
from typing import Optional, Tuple, List

In [18]:
class KVCache:
    def __init__(self, key_cache: torch.Tensor, val_cache: torch.Tensor):
        assert key_cache.shape == value_cache.shape, \
            f"K and V cache shapes must match: {key_cache.shape} vs {value_cache.shape}"
        
        self.key_cache = key_cache
        self.value_cache = value_cache

    @property
    def seq_len(self) -> int:
        """Current sequence length in cache"""
        return self.key_cache.shape[2]
    
    @property
    def batch_size(self) -> int:
        return self.key_cache.shape[0]
    
    @property
    def num_heads(self) -> int:
        return self.key_cache.shape[1]
    
    @property
    def head_dim(self) -> int:
        return self.key_cache.shape[3]

    def update(self, new_key: torch.Tensor, new_val: torch.Tensor):
        """
        Append new K,V to cache.
        
        Args:
            new_key: [batch, num_heads, new_len, head_dim]
            new_value: [batch, num_heads, new_len, head_dim]
        
        Returns:
            Updated KVCache (new object)
        """
        self.key_cache = torch.cat([self.key_cache, new_key], dim=2)
        self.val_cache = torch.cat([self.val_cache, val_key], dim=2)

        return self.key_cache, self.val_cache

    def get(self):
        return self.key_cache, self.val_cache

    def get_memory_usage(self) -> dict:
        """Calculate memory usage in MB"""
        k_bytes = self.key_cache.numel() * self.key_cache.element_size()
        v_bytes = self.value_cache.numel() * self.value_cache.element_size()
        
        return {
            'key_mb': k_bytes / (1024 ** 2),
            'value_mb': v_bytes / (1024 ** 2),
            'total_mb': (k_bytes + v_bytes) / (1024 ** 2)
        }


In [19]:
class MultiLayerKVCache:
    def __init__(self, num_layers: int):
        """
        Args:
            num_layers: number of transformer layers
        """
        self.num_layers = num_layers
        self.caches = [None] * num_layers

    def update(self, layer_idx: int, new_key: torch.Tensor, new_val: torch.Tensor):
        """
        Update cache for a specific layer.
        
        Args:
            layer_idx: which layer to update (0 to num_layers-1)
            new_key: [batch, num_heads, seq_len, head_dim]
            new_value: [batch, num_heads, seq_len, head_dim]
        """
        if self.caches[layer_idx] is None:
            self.caches[layer_idx] = KVCache(new_key, new_cache)
        else:
            new_key, new_val = self.caches[layer_idx].update(new_key, new_val)

        return new_key, new_val

    def get(self, layer_idx: int):
        """
        Get cache for a specific layer.
        
        Args:
            layer_idx: which layer's cache to retrieve
        
        Returns:
            KVCache object or None if not initialized
        """
        return self.caches[layer_idx]

    def get_seq_len(self) -> int:
        """
        Get current sequence length (same across all layers).
        
        Returns:
            sequence length, or 0 if no cache exists
        """
        if self.caches[0] is None:
            return 0
        return self.caches[0].seq_len
    
    def get_total_memory_usage(self) -> dict:
        """Calculate total memory across all layers"""
        total_mb = 0
        num_active_layers = 0
        
        for cache in self.caches:
            if cache is not None:
                mem = cache.get_memory_usage()
                total_mb += mem['total_mb']
                num_active_layers += 1
        
        avg_per_layer = total_mb / num_active_layers if num_active_layers > 0 else 0
        
        return {
            'total_mb': total_mb,
            'total_gb': total_mb / 1024,
            'per_layer_mb': avg_per_layer,
            'num_active_layers': num_active_layers
        }
    

In [20]:
class MultiHeadAttentionBatched(nn.Module):
    """
    Multi-Head Attention with proper batching support.
    Handles variable-length sequences with padding.
    """
    
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """
        Args:
            x: [batch, seq_len, d_model]
            attention_mask: [batch, 1, seq_len, total_len] 
                - 1 for positions to attend to, 0 for masked positions
                - total_len = past_len + seq_len
            kv_cache: KVCache object or None
            use_cache: whether to return updated cache
            
        Returns:
            output: [batch, seq_len, d_model]
            new_cache: KVCache or None
        """
        batch_size, seq_len, d_model = x.shape
        
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        if kv_cache is not None:
            K, V = kv_cache.update(K, V)
        
        scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5)
        
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))
        
        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = attn_weights @ V
        
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model
        )
        output = self.W_o(attn_output)

        new_cache = None
        if not kv_cache:
            new_cache = KVCache(K, V)
        
        return output, new_cache

In [None]:
def create_causal_mask(self, seq_len, past_len):
    """
    Create causal attention mask for autoregressive generation.
    
    Args:
        seq_len: length of current sequence
        past_len: length of cached sequence
        device: torch device
        
    Returns:
        mask: [1, 1, seq_len, past_len + seq_len]
              1 = attend, 0 = mask
    """
    total_len = past_len + seq_len

    mask = torch.ones(seq_len, total_len)

    if seq_len > 1:
        casual_mask = torch.tril(
            torch.ones(seq_len, seq_len)
        )
        mask[:,past_len:] = causal_mask

    return mask.unsqueeze(0).unsqueeze(0)

In [37]:
def create_padding_mask(input_length: torch.Tensor, max_len: int, past_len: int = 0):
    """
    Create padding mask for batched sequences with different lengths.
    
    Args:
        input_lengths: [batch] - actual length of each sequence (without padding)
        max_len: maximum sequence length in batch (without padding)
        past_len: length of cached sequence
        
    Returns:
        mask: [batch, 1, 1, past_len + max_len]
              1 = attend (real token), 0 = mask (padding)
    """

    batch_size = input_length.shape[0]
    total_length = past_len + max_len

    positions = torch.arange(total_length).unsqueenze(0).expand(batch_size, -1)

    # left padding
    pad_amounts = total_length - (input_length + past_len)
    mask = positions >= pad_amounts.unsqueeze(1)

    return mask.unsqueeze(1).unsqueeze(1)

In [38]:
def combine_masks(causal_mask: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    """
    Combine causal and padding masks.
    Both must be 1 for attention, 0 for masking.
    """
    # Broadcasting: causal_mask [1, 1, seq_len, total_len]
    #               padding_mask [batch, 1, 1, total_len]
    # Result: [batch, 1, seq_len, total_len]
    return causal_mask & padding_mask

In [39]:
class TransformerLayerWithCache(nn.Module):
    """Transformer layer using new cache structure"""
    
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        self.self_attn = MultiHeadAttentionBatched(d_model, num_heads)
        
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False
    ):
        # Self-attention
        attn_out, new_cache = self.self_attn(
            self.norm1(x),
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache
        )
        x = x + self.dropout1(attn_out)
        
        # FFN
        x = x + self.dropout2(self.ffn(self.norm2(x)))
        
        return x, new_cache

In [40]:
class TransformerWithCache(nn.Module):
    """
    Complete Transformer with proper cache management.
    """
    
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len=2048):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        
        self.layers = nn.ModuleList([
            TransformerLayerWithCache(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[MultiLayerKVCache] = None,
        use_cache: bool = False
    ):
        """
        Args:
            input_ids: [batch, seq_len]
            attention_mask: [batch, 1, seq_len, total_len] or None
            kv_cache: MultiLayerKVCache or None
            use_cache: bool
            
        Returns:
            logits: [batch, seq_len, vocab_size]
            new_cache: MultiLayerKVCache or None
        """
        batch_size, seq_len = input_ids.shape
        
        # Get past length from cache
        past_len = kv_cache.get_seq_len() if kv_cache is not None else 0
        
        # Position IDs
        position_ids = torch.arange(
            past_len, past_len + seq_len,
            device=input_ids.device
        ).unsqueeze(0).expand(batch_size, -1)
        
        # Embeddings
        x = self.token_embedding(input_ids) + self.pos_embedding(position_ids)
        
        # Initialize new cache if needed
        new_cache = MultiLayerKVCache(self.num_layers) if use_cache else None
        
        # Pass through layers
        for layer_idx, layer in enumerate(self.layers):
            layer_cache = kv_cache.get(layer_idx) if kv_cache is not None else None
            
            x, layer_new_cache = layer(
                x,
                attention_mask=attention_mask,
                kv_cache=layer_cache,
                use_cache=use_cache
            )
            
            if use_cache:
                new_cache.caches[layer_idx] = layer_new_cache
        
        # Output
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        return logits, new_cache

In [None]:
@torch.no_grad()
def generate_batched(
    model: TransformerWithCache,
    prompt_ids_list: List[torch.Tensor],
    max_new_tokens: int,
    temperature: float = 1.0,
    pad_token_id: int = 0
) -> List[torch.Tensor]:
    """
    Batched generation with variable-length prompts.
    
    Args:
        model: TransformerWithCache
        prompt_ids_list: list of [1, prompt_len] tensors (different lengths OK!)
        max_new_tokens: how many tokens to generate
        temperature: sampling temperature
        pad_token_id: ID for padding token
        
    Returns:
        list of generated sequences [1, total_len]
    """
    model.eval()
    device = next(model.parameters()).device
    
    batch_size = len(prompt_ids_list)
    
    prompt_lengths = torch.tensor([p.shape[1] for p in prompt_ids_list], device=device)
    max_prompt_len = prompt_lengths.max().item()
    
    print(f"\n{'='*60}")
    print(f"Batched Generation: {batch_size} sequences")
    print(f"Prompt lengths: {prompt_lengths.tolist()}")
    print(f"Max prompt: {max_prompt_len}, Generating: {max_new_tokens}")
    print(f"{'='*60}\n")
    
    padded_prompts = []
    for prompt in prompt_ids_list:
        prompt = prompt.to(device)
        pad_len = max_prompt_len - prompt.shape[1]
        if pad_len > 0:
            padding = torch.full((1, pad_len), pad_token_id, device=device, dtype=prompt.dtype)
            prompt = torch.cat([padding, prompt], dim=1)
        padded_prompts.append(prompt)
    
    input_ids = torch.cat(padded_prompts, dim=0)
    
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    
    kv_cache = None
    
    all_generated = input_ids.clone()
    
    for step in range(max_new_tokens):
        past_len = kv_cache.get_seq_len() if kv_cache is not None else 0
        
        if step == 0:
            current_input = input_ids
            seq_len = max_prompt_len
            
            padding_mask = create_padding_mask(prompt_lengths, max_prompt_len, past_len=0)
            causal_mask = create_causal_mask(seq_len, past_len=0, device=device)
            attention_mask = combine_masks(causal_mask, padding_mask)
            
            print(f"PREFILL: Processing {seq_len} tokens (batch={batch_size})")
        else:
            current_input = all_generated[:, -1:]
            seq_len = 1
            
            total_lengths = prompt_lengths + step 
            padding_mask = create_padding_mask(prompt_lengths, seq_len, past_len=past_len)
            causal_mask = create_causal_mask(seq_len, past_len=past_len, device=device)
            attention_mask = combine_masks(causal_mask, padding_mask)
            
            print(f" DECODE {step}: cache_len={past_len}")
        
        logits, kv_cache = model(
            current_input,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=True
        )
        
        next_token_logits = logits[:, -1, :] / temperature
        
        probs = torch.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1)
        
        next_tokens = torch.where(
            finished.unsqueeze(1),
            torch.tensor(pad_token_id, device=device),
            next_tokens
        )
        
        all_generated = torch.cat([all_generated, next_tokens], dim=1)
        
    print(f"\n Generation complete!")
    print(f"Final cache: {kv_cache}")
    
    return generated_list

#### Memory-Efficient Cache Updates

In [41]:
class SlidingWindowCache(MultiLayerKVCache):
    """
    KV Cache with sliding window - keeps only last N tokens.
    """
    
    def __init__(self, num_layers: int, window_size: int, num_sink_tokens: int = 4):
        super().__init__(num_layers)
        self.window_size = window_size
        self.num_sink_tokens = num_sink_tokens  # Keep first N tokens (attention sinks)
    
    def update(self, layer_idx: int, new_key: torch.Tensor, new_value: torch.Tensor):
        """Update with sliding window logic"""
        if self.caches[layer_idx] is None:
            # First update
            self.caches[layer_idx] = KVCache(key_cache=new_key, value_cache=new_value)
        else:
            # Concatenate new K,V
            cache = self.caches[layer_idx]
            updated_k = torch.cat([cache.key_cache, new_key], dim=2)
            updated_v = torch.cat([cache.value_cache, new_value], dim=2)
            
            # Apply sliding window if exceeded
            current_len = updated_k.shape[2]
            if current_len > self.window_size:
                # Keep: [first num_sink_tokens] + [last (window_size - num_sink_tokens)]
                keep_len = self.window_size - self.num_sink_tokens
                
                sink_k = updated_k[:, :, :self.num_sink_tokens, :]
                sink_v = updated_v[:, :, :self.num_sink_tokens, :]
                
                window_k = updated_k[:, :, -keep_len:, :]
                window_v = updated_v[:, :, -keep_len:, :]
                
                updated_k = torch.cat([sink_k, window_k], dim=2)
                updated_v = torch.cat([sink_v, window_v], dim=2)
            
            self.caches[layer_idx] = KVCache(key_cache=updated_k, value_cache=updated_v)
    
    def __repr__(self):
        base_repr = super().__repr__()
        return f"SlidingWindow{base_repr}, window={self.window_size}, sinks={self.num_sink_tokens}"


#### KV Cache Quantization (INT8/FP8)

In [54]:
class QuantizedKVCache:
    """
    INT8 quantized KV cache for memory efficiency.
    
    Stores K,V in INT8, maintains FP32 scales for dequantization.
    """
    def __init__(self, key_cache_fp: torch.Tensor, val_cache_fp: torch.Tensor, quantize: bool = True):
        """ Args:
            key_cache_fp: [batch, num_heads, seq_len, head_dim] in FP32
            value_cache_fp: [batch, num_heads, seq_len, head_dim] in FP32
            quantize: whether to quantize (False for testing)
        """

        self.shape = key_cache_fp.shape
        self.dtype_original = key_cache_fp.device

        if quantize:
            self.key_cache, self.key_scale, self.key_zero_point = self._quantize_to_int8(key_cache_fp)
            self.val_cache, self.val_scale, self.val_zero_point = self._quantize_to_int8(val_cache_fp)
        else:
            self.key_cache = key_cache_fp
            self.val_cache = val_cache_fp
            self.key_scale, self.val_scale = None, None
            self.key_zero_point, self.val_zero_point = None, None

        self.is_quantized = quantize

    def _quantize_to_int8(self, tensor):
        q_max = 2 ** (8 - 1) - 1
        q_min = - q_max

        max_val = tensor.abs().max()
        min_val = tensor.abs().min()

        scale = (max_val - min_val) / (q_max - q_min)
        if scale == 0.0:
            scale = 1.0

        zero_point = q_min - torch.round(min_val / scale)

        quantized = torch.clamp(
            torch.round(tensor / scale) + zero_point, 
            q_min, q_max
        ).to(torch.int8)

        return quantized, scale, zero_point

    def _dequantize_from_int8(self, quantized, scale, zero_point):
        dq = (quantized - zero_point) * scale
        return dq.to(torch.float32)

    def update(self, new_key: torch.Tensor, new_val: torch.Tensor):
        if self.is_quantized:
            q_key, q_val = self._quantize_to_int8(new_key), self._quantize_to_int8(new_val)
        else:
            q_key, q_val = new_key, new_val

        self.key_cache = torch.cat([self.key_cache, q_key], dim=2)
        self.val_cache = torch.cat([self.val_cache, v_key], dim=2)

    
    def get_kv_fp(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get K,V in original floating point precision.
        
        Returns:
            K: [batch, num_heads, seq_len, head_dim]
            V: [batch, num_heads, seq_len, head_dim]
        """
        if self.is_quantized:
            K = self._dequantize_from_int8(self.key_cache, self.key_scale, self.key_zero_point)
            V = self._dequantize_from_int8(self.val_cache, self.val_scale, self.val_zero_point)
        else:
            K = self.key_cache
            V = self.val_cache
        
        return K, V

    def get_memory_usage(self) -> dict:
        """Calculate memory usage"""
        if self.is_quantized:
            # INT8 cache
            k_bytes = self.key_cache.numel() * 1  # 1 byte per INT8
            v_bytes = self.val_cache.numel() * 1
            scale_bytes = self.key_scale.numel() * 4 + self.val_scale.numel() * 4  # FP32
        else:
            # Original precision
            k_bytes = self.key_cache.numel() * self.key_cache.element_size()
            v_bytes = self.val_cache.numel() * self.val_cache.element_size()
            scale_bytes = 0
        
        total_mb = (k_bytes + v_bytes + scale_bytes) / (1024 ** 2)
        
        return {
            'key_mb': k_bytes / (1024 ** 2),
            'value_mb': v_bytes / (1024 ** 2),
            'scale_mb': scale_bytes / (1024 ** 2),
            'total_mb': total_mb
        }
    
    @property
    def seq_len(self) -> int:
        return self.shape[2]

In [55]:
def test_kv_quantization():
    print("\n" + "="*70)
    print("Testing KV Cache Quantization (INT8)")
    print("="*70 + "\n")
    
    batch, num_heads, seq_len, head_dim = 1, 32, 2048, 128
    
    # Create random K,V in FP16
    K_fp16 = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16)
    V_fp16 = torch.randn(batch, num_heads, seq_len, head_dim, dtype=torch.float16)
    
    # Create caches
    cache_fp16 = QuantizedKVCache(K_fp16, V_fp16, quantize=False)
    cache_int8 = QuantizedKVCache(K_fp16, V_fp16, quantize=True)
    
    # Compare memory
    mem_fp16 = cache_fp16.get_memory_usage()
    mem_int8 = cache_int8.get_memory_usage()
    
    print(f"FP16 Cache: {mem_fp16['total_mb']:.2f} MB")
    print(f"INT8 Cache: {mem_int8['total_mb']:.2f} MB")
    print(f"Memory Savings: {mem_fp16['total_mb'] / mem_int8['total_mb']:.2f}x\n")
    
    # Check quality: dequantize and compare
    K_dequant, V_dequant = cache_int8.get_kv_fp()
    
    k_error = (K_fp16 - K_dequant).abs().mean()
    v_error = (V_fp16 - V_dequant).abs().mean()
    
    print(f"Quantization Error:")
    print(f"  K mean absolute error: {k_error:.6f}")
    print(f"  V mean absolute error: {v_error:.6f}")
    
    # Relative error
    k_rel_error = ((K_fp16 - K_dequant).abs() / (K_fp16.abs() + 1e-6)).mean()
    v_rel_error = ((V_fp16 - V_dequant).abs() / (V_fp16.abs() + 1e-6)).mean()
    
    print(f"  K relative error: {k_rel_error:.4f} ({k_rel_error*100:.2f}%)")
    print(f"  V relative error: {v_rel_error:.4f} ({v_rel_error*100:.2f}%)")
    
    print(f"\nðŸ’¡ Typical perplexity impact: <1% increase")
    print(f"   Memory savings enable 2x larger batch size!\n")

test_kv_quantization()


Testing KV Cache Quantization (INT8)

FP16 Cache: 32.00 MB
INT8 Cache: 16.00 MB
Memory Savings: 2.00x

Quantization Error:
  K mean absolute error: 0.400711
  V mean absolute error: 0.400150
  K relative error: 0.5102 (51.02%)
  V relative error: 0.5100 (51.00%)

ðŸ’¡ Typical perplexity impact: <1% increase
   Memory savings enable 2x larger batch size!



### PagedAttention Concepts (vLLM)

PagedAttention Key Ideas:

1. **Problem with Continuous Cache:**
   - Traditional: KV cache is contiguous tensor
   - Issue: Memory fragmentation, can't easily share cache between sequences
   
2. **PagedAttention Solution:**
   - Split KV cache into fixed-size "pages" (blocks)
   - Like virtual memory in OS!
   - Pages can be non-contiguous in physical memory
   
3. **Benefits:**
   - Near-zero memory fragmentation
   - Easy cache sharing (beam search, parallel sampling)
   - Dynamic memory allocation

Visual: 
1. Traditional:
- [Seq1: KKKKKK...] [Seq2: KKKKKK...] [Wasted Space] [Seq3: KKK...]

2. PagedAttention:
- Physical Memory: [Block0][Block1][Block2][Block3][Block4][Block5]
- Seq1 mapping: Block0 â†’ Block2 â†’ Block4
- Seq2 mapping: Block1 â†’ Block3
- Seq3 mapping: Block5
(No fragmentation!)

In [56]:
class PagedKVCache:
    def __init__(
        self, 
        num_blocks: int,
        block_size: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.float32
    ):
        """
        Args:
            num_blocks: Total number of blocks in memory pool
            block_size: Number of tokens per block
            num_heads: Number of KV heads
            head_dim: Dimension per head
        """
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        
        self.blocks = torch.zeros(
            num_blocks, 2, block_size, num_heads, head_dim, dtype=dtype
        )

        self.free_blocks = set(range(num_blocks))

        self.seq_to_blocks = {}
        self.seq_lengths = {}

    def allocate_blocks(self, seq_id: int, num_blocks_needed: int) -> bool:
        """
        Allocate blocks for a sequence.
        
        Returns:
            True if successful, False if not enough free blocks
        """
        if len(self.free_blocks) < num_blocks_needed:
            return False

        allocated = []
        for _ in range(num_blocks_needed):
            block_id = self.free_blocks.pop()
            allocated.append(block_id)

        self.seq_to_blocks[seq_id] = allocated
        self.seq_lengths[seq_id] = 0

        return True

    def append_tokens(self, seq_id: int, new_k: torch.Tensor, new_v: torch.Tensor):
        """
        Append new K,V to sequence's cache.
        
        Args:
            seq_id: Sequence ID
            new_k: [num_tokens, num_heads, head_dim]
            new_v: [num_tokens, num_heads, head_dim]
        
        Returns:
            True if successful, False if need more blocks
        """
        if seq_id not in self.seq_to_blocks or seq_id not in self.seq_lengths:
            return True

        num_new_tokens = new_k.shape[0]
        current_len = self.seq_lengths[seq_id]
        new_len = current_len + num_new_tokens

        blocks_needed = (new_len + self.block_size - 1) // self.block_size
        blocks_allocated = len(self.seq_to_blocks[seq_id])

        if blocks_needed > blocks_allocated:
            additional_blocks = blocks_needed - blocks_allocated
            if len(self.free_blocks) < additional_blocks:
                return True

            for _ in range(additional_blocks):
                block_id = self.free_blocks.pop()
                self.seq_to_blocks[seq_id].append(block_id)

        
        for i, (k_token, v_token) in enumerate(zip(new_k, new_v)):
            token_idx = current_len + i
            block_idx = token_idx // self.block_size
            offset = token_idx % self.block_size

            block_id = self.seq_to_blocks[seq_id][block_idx]

            self.blocks[block_id, 0, offset] = k_token
            self.blocks[block_id, 1, offset] = v_token

        self.seq_lengths[seq_id] = new_len
        return True

    def get_kv(self, seq_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Gather K,V for a sequence from its blocks.
        
        Returns:
            K: [seq_len, num_heads, head_dim]
            V: [seq_len, num_heads, head_dim]
        """
        if seq_id not in self.seq_to_blocks:
            raise ValueError(f"Sequence {seq_id} not found")

        seq_len = self.seq_lengths[seq_id]
        block_ids = self.seq_to_blocks[seq_id]

        K_list = []
        V_list = []

        for block_idx, block_id in enumerate(blocks_ids):
            start_token = block_idx * self.block_size
            end_token = min(start_token + self.block_size, seq_len)

            num_tokens = end_token - start_token

            if num_tokens > 0:
                K_list.append(self.blocks[block_id, 0, :num_tokens])
                V_list.append(self.blocks[block_id, 1, :num_tokens])

        K = torch.cat(K_list, dim=0)
        V = torch.cat(V_list, dim=0)

        return K, V

    def free_sequence(self, seq_id: int):
        """Free all blocks associated with a sequence"""
        if seq_id in self.seq_to_blocks:
            for block_id in self.seq_to_blocks[seq_id]:
                self.free_blocks.add(block_id)
            
            del self.seq_to_blocks[seq_id]
            del self.seq_lengths[seq_id]


#### Grouped Query Attention

In [47]:
class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention (GQA) - used in LLaMA-2, Mistral, etc.
    
    Reduces KV cache by sharing K,V across groups of query heads.
    """
    
    def __init__(
        self,
        d_model: int,
        num_query_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0
    ):
        super().__init__()
        
        assert d_model % num_query_heads == 0, "d_model must be divisible by num_query_heads"
        assert num_query_heads % num_kv_heads == 0, "num_query_heads must be divisible by num_kv_heads"
        
        self.d_model = d_model
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_query_heads
        
        # Number of query heads per KV head
        self.num_queries_per_kv = num_query_heads // num_kv_heads
        
        # Projections
        self.W_q = nn.Linear(d_model, num_query_heads * self.head_dim, bias=False)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)  # Smaller!
        self.W_v = nn.Linear(d_model, num_kv_heads * self.head_dim, bias=False)  # Smaller!
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            x: [batch, seq_len, d_model]
            attention_mask: [batch, 1, seq_len, total_len]
            kv_cache: (K, V) where K,V are [batch, num_kv_heads, past_len, head_dim]
            use_cache: bool
            
        Returns:
            output: [batch, seq_len, d_model]
            new_cache: (K, V) or None
        """
        batch_size, seq_len, _ = x.shape
        
        # Project to Q, K, V
        Q = self.W_q(x)  # [batch, seq_len, num_query_heads * head_dim]
        K = self.W_k(x)  # [batch, seq_len, num_kv_heads * head_dim]
        V = self.W_v(x)  # [batch, seq_len, num_kv_heads * head_dim]
        
        # Reshape Q: [batch, num_query_heads, seq_len, head_dim]
        Q = Q.view(batch_size, seq_len, self.num_query_heads, self.head_dim).transpose(1, 2)
        
        # Reshape K,V: [batch, num_kv_heads, seq_len, head_dim]
        K = K.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        
        # Handle cache
        if kv_cache is not None:
            K_cache, V_cache = kv_cache
            K = torch.cat([K_cache, K], dim=2)
            V = torch.cat([V_cache, V], dim=2)
        
        # CRITICAL: Repeat K,V to match number of query heads
        # Each KV head is shared across num_queries_per_kv query heads
        # [batch, num_kv_heads, seq_len, head_dim] â†’ [batch, num_query_heads, seq_len, head_dim]
        K_repeated = K.repeat_interleave(self.num_queries_per_kv, dim=1)
        V_repeated = V.repeat_interleave(self.num_queries_per_kv, dim=1)
        
        # Now K_repeated and V_repeated have same num_heads as Q
        # Compute attention
        scores = Q @ K_repeated.transpose(-2, -1) / math.sqrt(self.head_dim)
        
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        attn_output = attn_weights @ V_repeated
        
        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        
        output = self.W_o(attn_output)
        
        # Return cache (store original K,V, NOT repeated versions!)
        new_cache = (K, V) if use_cache else None
        
        return output, new_cache

In [48]:
class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (MQA) - extreme version where num_kv_heads = 1.
    Used in PaLM, StarCoder, Falcon.
    """
    
    def __init__(self, d_model: int, num_query_heads: int, dropout: float = 0.0):
        super().__init__()
        
        assert d_model % num_query_heads == 0
        
        self.d_model = d_model
        self.num_query_heads = num_query_heads
        self.head_dim = d_model // num_query_heads
        
        # Q has full heads, K,V have only 1 head
        self.W_q = nn.Linear(d_model, num_query_heads * self.head_dim, bias=False)
        self.W_k = nn.Linear(d_model, self.head_dim, bias=False)  # Single head!
        self.W_v = nn.Linear(d_model, self.head_dim, bias=False)  # Single head!
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            kv_cache: (K, V) where K,V are [batch, 1, past_len, head_dim]
        """
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Q: [batch, num_query_heads, seq_len, head_dim]
        Q = Q.view(batch_size, seq_len, self.num_query_heads, self.head_dim).transpose(1, 2)
        
        # K,V: [batch, 1, seq_len, head_dim]
        K = K.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        
        # Handle cache
        if kv_cache is not None:
            K_cache, V_cache = kv_cache
            K = torch.cat([K_cache, K], dim=2)
            V = torch.cat([V_cache, V], dim=2)
        
        # Repeat K,V to match Q heads
        K_repeated = K.expand(-1, self.num_query_heads, -1, -1)
        V_repeated = V.expand(-1, self.num_query_heads, -1, -1)
        
        # Attention
        scores = Q @ K_repeated.transpose(-2, -1) / math.sqrt(self.head_dim)
        
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        attn_output = attn_weights @ V_repeated
        
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        
        output = self.W_o(attn_output)
        
        new_cache = (K, V) if use_cache else None
        
        return output, new_cache

### Flash Attention

Flash Attention Key Ideas:

Standard Attention Problem:
1. Compute scores: Q @ K^T (materialize NxN matrix)
2. Softmax: softmax(scores) (materialize NxN matrix)
3. Output: attn @ V

Memory: O(N^2) - stores full attention matrix

Flash Attention Solution:
- Tile Q, K, V
- Compute attention in blocks
- Never materialize full NxN matrix
- Memory: O(N) instead of O(N^2)

With KV Cache:
- Flash Attention + KV cache is THE standard for production
- vLLM, TensorRT-LLM all use this combination

Flash Attention + KV Cache Best Practices:

1. **Always use Flash Attention for long sequences (>512 tokens)**
   - 2-4x speedup on attention computation
   - Enables longer contexts (up to 100k+ tokens)

2. **Combine with GQA for maximum efficiency**
   - GQA reduces cache size
   - Flash Attention reduces compute
   - Together: massive throughput gains

3. **Production Stack:**
   - vLLM: PagedAttention + Flash Attention + GQA + INT8 cache
   - TensorRT-LLM: Flash Attention + GQA + FP8 cache
   - SGLang: RadixAttention (prefix caching) + Flash Attention

4. **When to use what:**
   - Batch=1, latency-critical: Flash Attention + FP16 cache
   - High throughput serving: Flash + PagedAttention + INT8 cache
   - Long context (>8k): Flash + GQA + sliding window

#### Continuous Batching Considerations

In [58]:
"""
Continuous Batching (used in vLLM, TensorRT-LLM):

Traditional Batching:
- Wait for N requests
- Process all together
- All finish at different times â†’ wasted compute

Continuous Batching:
- As soon as one request finishes, add new one
- Dynamically change batch composition
- Maximum GPU utilization

KV Cache Challenges:
1. Variable sequence lengths in batch
2. Sequences finishing at different times
3. Need to efficiently add/remove from batch

Solutions:
1. PagedAttention: Easy to swap sequences in/out
2. Padding + masking: Handle variable lengths
3. Separate prefill and decode batches
"""

class ContinuousBatchManager:
    """
    Manages continuous batching with KV cache.
    Simplified version of what vLLM does.
    """
    
    def __init__(self, model, max_batch_size: int, paged_cache: PagedKVCache):
        self.model = model
        self.max_batch_size = max_batch_size
        self.cache = paged_cache
        
        # Active sequences
        self.active_sequences = {}  # seq_id -> sequence info
        
    def add_sequence(self, seq_id: int, prompt: torch.Tensor):
        """Add new sequence to batch"""
        if len(self.active_sequences) >= self.max_batch_size:
            return False  # Batch full
        
        # Allocate cache blocks
        prompt_len = prompt.shape[0]
        blocks_needed = (prompt_len + self.cache.block_size - 1) // self.cache.block_size
        
        if not self.cache.allocate_blocks(seq_id, blocks_needed):
            return False  # Not enough cache memory
        
        self.active_sequences[seq_id] = {
            'prompt': prompt,
            'generated': [],
            'finished': False
        }
        
        return True
    
    def remove_sequence(self, seq_id: int):
        """Remove finished sequence from batch"""
        if seq_id in self.active_sequences:
            self.cache.free_sequence(seq_id)
            del self.active_sequences[seq_id]
    
    def step(self):
        """
        One generation step for all active sequences.
        
        In practice, this would:
        1. Separate prefill vs decode
        2. Batch compatible sequences together
        3. Use PagedAttention to gather K,V
        """
        # Pseudocode for continuous batching step
        pass

print("""
ðŸ’¡ Continuous Batching + KV Cache:

Key Insight: PagedAttention makes continuous batching efficient!

Without PagedAttention:
- Hard to dynamically resize cache tensors
- Memory fragmentation when swapping sequences
- Copying overhead

With PagedAttention:
- Just update block mappings (O(1) operation)
- No memory copying
- Perfect for continuous batching

This is why vLLM achieves 10-20x higher throughput than HuggingFace!
""")


ðŸ’¡ Continuous Batching + KV Cache:

Key Insight: PagedAttention makes continuous batching efficient!

Without PagedAttention:
- Hard to dynamically resize cache tensors
- Memory fragmentation when swapping sequences
- Copying overhead

With PagedAttention:
- Just update block mappings (O(1) operation)
- No memory copying
- Perfect for continuous batching

This is why vLLM achieves 10-20x higher throughput than HuggingFace!

