# KV caching
> KV Caching is a cornerstone for efficient inference in modern autoregressive Transformer models.

## Understanding KV Caching

### Autoregressive Generation

**What is Autoregressive Generation?**

Autoregressive models, like Large Language Models (LLMs), generate sequences one token (e.g., a word or sub-word) at a time. Each new token prediction depends on all previously generated tokens. Think of writing a story: to choose the next word, you consider the entire story written so far.


Self-attention is the core of the Transformer. When processing the $t$-th token, the model "attends" to every token in its context, i.e. positions $1, \dots, t$. Specifically, the **Query (Q)** vector of the current token interacts with the **Key (K)** and **Value (V)** vectors of all tokens in the current context (itself included).

*Q*: Represents the current token asking a question: "What information is relevant to me?"

*K*: Represents all tokens in the context offering their "attributes" or "topics."

*V*: Represents all tokens in the context offering their "content" or "meaning."
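A toy sketch of one query attending over a three-token context, using random vectors and an assumed dimension $d_k = 4$ (no trained model is involved; the numbers are illustrative only):

```python
import math
import torch

torch.manual_seed(0)
d_k = 4                       # assumed key/query dimension for this toy example
q = torch.randn(d_k)          # query of the current token
K = torch.randn(3, d_k)       # keys of the 3 tokens in the context (current one included)
V = torch.randn(3, d_k)       # values of the same 3 tokens

scores = K @ q / math.sqrt(d_k)         # one relevance score per context token
weights = torch.softmax(scores, dim=0)  # normalize scores into attention weights
output = weights @ V                    # weighted mix of the value vectors

print(weights, output)
```

The scores measure how well each *K* answers *Q*'s question, and the output is the correspondingly weighted mix of the *V* contents.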

### Core idea of KV Caching
1. Once the Key (K) and Value (V) vectors for a token are computed, cache them.
2. When generating the next token:
    - Compute Q, K, V only for the *new, current* token.
    - Append the new K and V to the *cached* K's and V's from previous tokens.
    - Perform the attention calculation using the current Q and the full (cached + new) K's and V's.
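The steps above can be sketched with a tiny cache object. The `KVCache` class and its `append` method are hypothetical names used only for illustration, and the head axis is dropped for brevity, so tensors are `[B, S, D]`:

```python
import torch

class KVCache:
    """Hypothetical minimal cache: stores keys/values along the sequence axis (dim=1)."""

    def __init__(self):
        self.k = None  # [B, S_cached, D] once filled
        self.v = None

    def append(self, k_new, v_new):
        # k_new, v_new: [B, S_new, D] for the newly processed token(s)
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=1)
            self.v = torch.cat([self.v, v_new], dim=1)
        return self.k, self.v  # full (cached + new) keys and values

cache = KVCache()
B, D = 2, 8
for step in range(5):
    k_t = torch.randn(B, 1, D)  # K for the current token only
    v_t = torch.randn(B, 1, D)  # V for the current token only
    k_all, v_all = cache.append(k_t, v_t)
    print(step, k_all.shape)    # the sequence axis grows by one per step
```

Concatenating with `torch.cat` reallocates on every step; real implementations often preallocate a fixed-size buffer instead, but the bookkeeping is the same.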

Let $x_t$ be the embedding of the $t$-th token. The Query, Key, and Value vectors are typically linear projections:

$q_t = x_t W_Q$

$k_t = x_t W_K$

$v_t = x_t W_V$

Here $W_Q, W_K, W_V$ are learnable weight matrices.
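Because the projection is applied row by row, the key (and likewise the value) of token $t$ depends only on $x_t$, so it never changes when later tokens arrive. A quick check, assuming random stand-ins for $W_K$ and the embeddings:

```python
import torch

torch.manual_seed(0)
S, d_model, d_k = 6, 16, 8
X = torch.randn(S, d_model)       # embeddings x_1..x_S stacked as rows
W_K = torch.randn(d_model, d_k)   # random stand-in for the learned key projection

K_all = X @ W_K                   # keys for the whole sequence in one matmul
k_3 = X[2] @ W_K                  # key of the 3rd token computed on its own

# Row t of X @ W_K depends only on x_t, which is what makes caching k_t safe.
print(torch.allclose(K_all[2], k_3, atol=1e-5))
```

In a multi-layer decoder with causal masking, a layer's input at position $t$ depends only on tokens $1, \dots, t$, so the same argument holds at every layer.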

At generation step $t$:

1. Compute $q_t, k_t, v_t$ for the current token $x_t$.
2. Retrieve cached keys $K_{cache} = [k_1, \dots, k_{t-1}]$ and values $V_{cache} = [v_1, \dots, v_{t-1}]$.
3. Form the full key and value sequences for attention:
$\mathbf{K_{total}} = \text{concat}(K_{cache}, k_t) = [k_1, \dots, k_{t-1}, k_t]$
$\mathbf{V_{total}} = \text{concat}(V_{cache}, v_t) = [v_1, \dots, v_{t-1}, v_t]$
4. The attention output for the $t$-th token is then computed as:
$\text{AttentionOutput}_t = \text{softmax}\left(\frac{q_t \mathbf{K_{total}}^T}{\sqrt{d_k}}\right) \mathbf{V_{total}}$
(where $d_k$ is the dimension of key vectors).
5. The updated cache for the next step will be $(K_{total}, V_{total})$.
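The five steps can be checked numerically: decoding token by token with a cache should reproduce what a full causal recomputation gives. A single-head sketch with random stand-in weights (not a trained model):

```python
import math
import torch

torch.manual_seed(0)
T, d_model, d_k = 5, 16, 16
X = torch.randn(T, d_model)  # embeddings x_1..x_T
# Random stand-ins for the learned projections W_Q, W_K, W_V
W_Q, W_K, W_V = (torch.randn(d_model, d_k) / math.sqrt(d_model) for _ in range(3))

def attend(q, K, V):
    # softmax(q K^T / sqrt(d_k)) V
    return torch.softmax(q @ K.T / math.sqrt(d_k), dim=-1) @ V

# Incremental decoding with a KV cache
K_cache = torch.empty(0, d_k)
V_cache = torch.empty(0, d_k)
cached_outputs = []
for t in range(T):
    x_t = X[t : t + 1]                                # [1, d_model]
    q_t, k_t, v_t = x_t @ W_Q, x_t @ W_K, x_t @ W_V   # step 1: new token only
    K_total = torch.cat([K_cache, k_t], dim=0)        # steps 2-3: cache + new
    V_total = torch.cat([V_cache, v_t], dim=0)
    cached_outputs.append(attend(q_t, K_total, V_total))  # step 4
    K_cache, V_cache = K_total, V_total               # step 5: update the cache
cached_outputs = torch.cat(cached_outputs, dim=0)     # [T, d_k]

# Reference: recompute everything at once with a causal mask
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
scores = Q @ K.T / math.sqrt(d_k)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
full_outputs = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1) @ V

print(torch.allclose(cached_outputs, full_outputs, atol=1e-5))
```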

### Get Ready

Input tensors (`query`, `key`, `value`) have the 4-dimensional shape `[B, N_h, S, D_h]`:
1. `B` - Batch size
  - The number of independent sequences processed simultaneously.
  - Sequences in a batch are processed in parallel and do NOT interact with each other.
2. `N_h` - Number of attention heads
  - Multi-Head Attention (MHA) allows the model to jointly attend to information from *different* representation subspaces at different positions. Instead of performing one single attention calculation over the full `embed_dim`, the model splits `embed_dim` into `N_h` "heads." Each head has a dimension `D_h` = `embed_dim` / `N_h`.
  - Each head performs its own scaled dot-product attention independently. The outputs of these heads are then typically concatenated and linearly projected to get the final result.
  - Think of it as having `N_h` different "experts" looking at different aspects of the relationships between tokens.
3. `S` - Sequence length
  - The number of token positions along the sequence axis (`S_q` for queries, `S_k` = `S_v` for keys and values).
4. `D_h` - Head dimension
  - The size of each head's query, key, and value vectors, `D_h` = `embed_dim` / `N_h`.
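Splitting the model dimension into heads is a pure reshape. A sketch with `embed_dim = 64` and `N_h = 4` (the same sizes the Level 1 example uses):

```python
import torch

B, S, embed_dim, N_h = 2, 10, 64, 4
D_h = embed_dim // N_h                           # 16 dims per head

x = torch.randn(B, S, embed_dim)                 # e.g. the output of a q/k/v projection
heads = x.view(B, S, N_h, D_h).transpose(1, 2)   # split into heads -> [B, N_h, S, D_h]
print(heads.shape)                               # torch.Size([2, 4, 10, 16])

# Undo the split: back to [B, S, embed_dim] (the "concatenate heads" step)
merged = heads.transpose(1, 2).contiguous().view(B, S, embed_dim)
print(torch.equal(merged, x))                    # True
```

`.contiguous()` is needed before the final `.view` because `transpose` only changes strides, and `view` requires a compatible memory layout.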



```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math

# Helper function for scaled dot-product attention
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
    query: [B, N_h, S_q, D_h]
    key:   [B, N_h, S_k, D_h]
    value: [B, N_h, S_v, D_h] (S_k == S_v)
    mask:  optional, broadcastable to [B, N_h, S_q, S_k]; positions where
           mask == 0 (or False) are blocked, e.g. for causal attention
    """
    matmul_qk = torch.matmul(query, key.transpose(-2, -1))  # [B, N_h, S_q, S_k]
    d_k = query.size(-1)
    scaled_attention_logits = matmul_qk / math.sqrt(d_k)

    if mask is not None:
        # A large negative logit becomes ~0 weight after the softmax
        scaled_attention_logits = scaled_attention_logits.masked_fill(mask == 0, -1e9)

    attention_weights = F.softmax(scaled_attention_logits, dim=-1)
    output = torch.matmul(attention_weights, value)
    return output # [B, N_h, S_q, D_h]
```
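One way to sanity-check the helper is to compare it with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0). The sketch below re-declares the same computation under the hypothetical name `sdpa_manual` so it runs on its own:

```python
import math
import torch
import torch.nn.functional as F

def sdpa_manual(query, key, value, mask=None):
    # Same computation as the helper above
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    return torch.matmul(F.softmax(scores, dim=-1), value)

torch.manual_seed(0)
B, N_h, S, D_h = 2, 4, 6, 16
q, k, v = (torch.randn(B, N_h, S, D_h) for _ in range(3))
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))  # True = may attend

ours = sdpa_manual(q, k, v, mask=causal)
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # PyTorch >= 2.0
print(torch.allclose(ours, builtin, atol=1e-5))
```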

```python
class MultiHeadAttentionNoCache(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x_query, x_kv, causal_mask=False):
        # x_query: [B, S_q, D_model] (e.g., current token(s) for Q)
        # x_kv:    [B, S_kv, D_model] (e.g., full sequence for K, V)
        # For self-attention, x_query and x_kv are often the same.
        # For cross-attention, they can be different.
        # Here we assume a self-attention context in which the S_q queries are
        # the LAST S_q positions of the S_kv-long x_kv sequence.

        B, S_q, _ = x_query.shape
        _, S_kv, _ = x_kv.shape

        # Without a cache, K and V are re-projected for the whole x_kv on every call.
        q = self.q_proj(x_query) # [B, S_q, D_model]
        k = self.k_proj(x_kv)    # [B, S_kv, D_model]
        v = self.v_proj(x_kv)    # [B, S_kv, D_model]

        # Reshape and transpose for multi-head: [B, N_h, S, D_h]
        q = q.view(B, S_q, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, S_kv, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, S_kv, self.num_heads, self.head_dim).transpose(1, 2)

        mask = None
        if causal_mask:
            # Query i (0-indexed) sits at absolute position S_kv - S_q + i, so it may
            # attend to keys 0 .. S_kv - S_q + i. Shifting the diagonal covers every case:
            # - S_q == S_kv (prompt processing): standard lower-triangular causal mask.
            # - S_q == 1 (one generation step): the single query sees all S_kv keys.
            mask = torch.tril(
                torch.ones(S_q, S_kv, device=x_query.device), diagonal=S_kv - S_q
            ).bool().view(1, 1, S_q, S_kv)  # broadcasts over batch and heads

        attn_output = scaled_dot_product_attention(q, k, v, mask=mask) # [B, N_h, S_q, D_h]

        # Concatenate heads and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S_q, self.embed_dim)
        output = self.o_proj(attn_output) # [B, S_q, D_model]

        return output
```
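When the queries are the last $S_q$ positions of an $S_{kv}$-long context (the situation during cached generation), a causal mask can be built by shifting the diagonal of `torch.tril` by $S_{kv} - S_q$. The helper name `causal_mask` is hypothetical:

```python
import torch

def causal_mask(S_q, S_kv):
    # The S_q queries are assumed to be the LAST S_q positions of an S_kv-long
    # context, so query i may see keys 0 .. (S_kv - S_q + i).
    return torch.tril(torch.ones(S_q, S_kv), diagonal=S_kv - S_q).bool()

print(causal_mask(3, 3).int())  # prompt processing: standard lower triangle
print(causal_mask(1, 5).int())  # one generation step: the query sees all 5 keys
print(causal_mask(2, 5).int())  # two new tokens appended to a 3-token cache
```

With the default `diagonal=0` and $S_q = 1$, the mask would instead be `[1, 0, 0, 0, 0]`, letting the newest token see only the first position, which is why the offset matters.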

```python
# --- Example Usage for MultiHeadAttentionNoCache ---

# Parameters
batch_size_l1 = 2
seq_len_l1 = 10      # Sequence length for both query and key/value (self-attention)
embed_dim_l1 = 64
num_heads_l1 = 4
device_l1 = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("\n--- Running Level 1: MultiHeadAttentionNoCache Example ---")
print(f"Using device: {device_l1}")

# Instantiate the module
mha_no_cache_instance = MultiHeadAttentionNoCache(embed_dim_l1, num_heads_l1).to(device_l1)
mha_no_cache_instance.eval() # Set to evaluation mode if not training

# Create dummy input tensor (e.g., embeddings of a sequence)
# For self-attention, x_query and x_kv are typically the same.
input_sequence = torch.randn(batch_size_l1, seq_len_l1, embed_dim_l1, device=device_l1)

# --- Case 1: Self-attention without causal masking (e.g., Encoder block) ---
print("\nCase 1: Self-attention, no causal mask")
# Here, x_query and x_kv are the same.
output_self_attn = mha_no_cache_instance(x_query=input_sequence, x_kv=input_sequence, causal_mask=False)
print(f"Input shape: {input_sequence.shape}")
print(f"Output shape (self-attention, no mask): {output_self_attn.shape}")
# Expected output shape: [batch_size_l1, seq_len_l1, embed_dim_l1]

# --- Case 2: Self-attention with causal masking (e.g., Decoder block during training/prompt processing) ---
print("\nCase 2: Self-attention, with causal mask")
output_causal_self_attn = mha_no_cache_instance(x_query=input_sequence, x_kv=input_sequence, causal_mask=True)
print(f"Input shape: {input_sequence.shape}")
print(f"Output shape (causal self-attention): {output_causal_self_attn.shape}")
# Expected output shape: [batch_size_l1, seq_len_l1, embed_dim_l1]

# --- Case 3: Simulating a single query token attending to a sequence of KV pairs ---
# This is closer to what happens token-by-token in naive autoregressive generation,
# where the 'query' is the current token and 'kv' is the whole sequence so far.
print("\nCase 3: Single query token, attending to full KV sequence (simulating one step of naive generation)")
current_token_embedding = torch.randn(batch_size_l1, 1, embed_dim_l1, device=device_l1) # S_q = 1
full_kv_sequence = torch.randn(batch_size_l1, seq_len_l1, embed_dim_l1, device=device_l1) # S_kv = 10

# The query stands for the newest token, so every position in x_kv lies in its past and
# nothing needs to be masked: causal_mask=False lets it attend to all S_kv keys.
output_single_query = mha_no_cache_instance(x_query=current_token_embedding, x_kv=full_kv_sequence, causal_mask=False)
print(f"Query shape: {current_token_embedding.shape}")
print(f"KV sequence shape: {full_kv_sequence.shape}")
print(f"Output shape (single query): {output_single_query.shape}")
# Expected output shape: [batch_size_l1, 1, embed_dim_l1]
```

```text
--- Running Level 1: MultiHeadAttentionNoCache Example ---
Using device: cpu

Case 1: Self-attention, no causal mask
Input shape: torch.Size([2, 10, 64])
Output shape (self-attention, no mask): torch.Size([2, 10, 64])

Case 2: Self-attention, with causal mask
Input shape: torch.Size([2, 10, 64])
Output shape (causal self-attention): torch.Size([2, 10, 64])

Case 3: Single query token, attending to full KV sequence (simulating one step of naive generation)
Query shape: torch.Size([2, 1, 64])
KV sequence shape: torch.Size([2, 10, 64])
Output shape (single query): torch.Size([2, 1, 64])
```
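Case 3 is one step of naive generation: each step re-projects K and V for the entire history. A back-of-the-envelope count under a simplified, hypothetical cost model that only counts per-token K/V projections (it ignores the attention matmuls, which grow with context length with or without a cache):

```python
def kv_projections(T, use_cache):
    """Count per-token K/V projections needed to generate T tokens (simplified cost model)."""
    if use_cache:
        return T                   # each token's K/V is computed exactly once
    return sum(range(1, T + 1))    # step t re-projects all t tokens seen so far

for T in (10, 100, 1000):
    print(T, kv_projections(T, use_cache=False), kv_projections(T, use_cache=True))
```

Without a cache the count is $T(T+1)/2$, quadratic in the generated length; with a cache it is linear.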


```python
import torch
test_tensor = torch.randn([2, 3, 4, 5])
test_tensor
```

```text
tensor([[[[-1.3487,  1.6347,  0.1443,  0.7240,  0.7381],
          [-0.7476, -0.6850,  1.7193,  1.9604,  0.5786],
          [-0.2361, -0.5395,  1.2992, -1.2165,  0.8674],
          [ 0.9757, -0.8619,  1.1845,  0.6143, -0.6978]],

         [[ 0.1655, -1.0265, -1.4400,  1.4185, -1.1353],
          [-0.9438,  1.4995, -2.1421, -1.6823,  1.8388],
          [-0.1459,  1.8395,  0.2912,  0.8873, -1.1434],
          [ 0.2176, -0.6245, -0.1394, -0.4310, -0.8455]],

         [[ 0.9501, -1.0886, -0.2018,  1.3524, -0.9115],
          [ 0.7123,  0.3665,  0.1948,  1.0899,  1.2612],
          [ 0.2031,  0.3736,  0.0554, -0.7537, -0.9879],
          [-1.5952, -1.7867, -0.3117,  0.2865, -1.1933]]],


        [[[-1.5285,  0.9560,  0.9153, -0.3917,  0.3930],
          [-0.6037, -0.3035, -0.4196, -1.4552, -1.7482],
          [ 0.8614, -0.9646, -1.6007, -0.0091, -1.1557],
          [ 0.7585, -1.7625, -1.2060,  0.8237, -0.8793]],

         [[-1.5392,  3.9224,  0.7778,  1.1515,  0.7246],
          [-1.6642, -
```

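```python
# Batched matmul over the last two dims: [2, 3, 4, 5] @ [2, 3, 5, 4] -> [2, 3, 4, 4],
# i.e. one S x S score matrix per (batch, head), as in Q K^T
torch.matmul(test_tensor, test_tensor.transpose(-2,-1))
```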

tensor([[[[ 5.5808,  1.9831, -0.6166, -2.6242],
          [ 1.9831,  8.1619,  0.8968,  2.6980],
          [-0.6166,  0.8968,  4.2670,  0.4209],
          [-2.6242,  2.6980,  0.4209,  3.9623]],

         [[ 6.4555, -3.0848,  0.2252,  1.2263],
          [-3.0848, 13.9391, -1.3230, -1.6729],
          [ 0.2252, -1.3230,  5.5844, -0.6367],
          [ 1.2263, -1.6729, -0.6367,  1.3574]],

         [[ 4.7883,  0.5628, -0.3438,  1.9674],
          [ 0.5628,  3.4580, -1.7749, -3.0446],
          [-0.3438, -1.7749,  1.7279, -0.0459],
          [ 1.9674, -3.0446, -0.0459,  7.3402]]],


        [[[ 4.3959,  0.1314, -4.1546, -4.6165],
          [ 0.1314,  5.8064,  2.4779,  0.9217],
          [-4.1546,  2.4779,  5.5701,  5.2926],
          [-4.6165,  0.9217,  5.2926,  6.5881]],

         [[20.2104, -0.9571, -3.0663,  0.5464],
          [-0.9571,  7.9613,  1.0569, -0.6435],
          [-3.0663,  1.0569,  4.4052,  0.8783],
          [ 0.5464, -0.6435,  0.8783,  3.4043]],

         [[ 1.0587, -1.0023,