# GPT with KV Cache — A Beginner's Walkthrough

*Based on Sebastian Raschka's ["Build a Large Language Model From Scratch"](https://www.manning.com/books/build-a-large-language-model-from-scratch) — Chapters 1 through 4.*

This notebook brings together concepts from every chapter so far and adds the **KV Cache** optimization that makes real-world LLM inference fast:

| Chapter | What we learned | Where it shows up here |
|---------|----------------|----------------------|
| **Ch 1** | How LLMs work at a high level | The overall generate → predict → append loop |
| **Ch 2** | Tokenization (BPE) and data loading | `tiktoken` converts text ↔ token IDs |
| **Ch 3** | The self-attention mechanism | `MultiHeadAttention` — now enhanced with a KV cache |
| **Ch 4** | The full GPT architecture | `GPTModel` = embeddings + transformer blocks + output head |

> **How to use this notebook:** Run the cells top-to-bottom (`Shift+Enter`). The model uses random weights (not pretrained), so generated text will be gibberish — that's expected! The goal is to understand the **architecture and caching mechanism**, not to produce meaningful text. Training comes in later chapters.

## So what exactly is a KV Cache?

When a GPT model generates text, it produces **one token at a time**. Without a cache, every time we generate a new token, we have to reprocess the *entire* sequence from scratch — that's a lot of wasted computation!

**The KV Cache stores the Keys and Values** from previous tokens so we never recompute them. Think of it like taking notes during a lecture — instead of re-listening to everything from the start each time you want to write the next sentence, you just look at your notes (the cache) and only process the new information.

### With KV cache (fast):
```
Step 1 (Prefill):  "Hello, I am" → process all 4 tokens, cache their K,V
Step 2 (Decode):   " a"          → process ONLY this 1 token, attend to cached K,V
Step 3 (Decode):   " large"      → process ONLY this 1 token, attend to cached K,V
...each step projects only 1 new token (attention still reads the cached K,V)
```

### Without cache (slow):
```
Step 1: "Hello, I am"           → process 4 tokens
Step 2: "Hello, I am a"         → reprocess ALL 5 tokens from scratch
Step 3: "Hello, I am a large"   → reprocess ALL 6 tokens from scratch
...each step gets slower!
```
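To put rough numbers on the two schedules above, here is a minimal sketch that counts how many tokens pass through the model in each case. The helper name `tokens_processed` and its arguments are illustrative and not part of `gpt_with_kv_cache.py`; it only counts tokens and does not time anything.

```python
def tokens_processed(prompt_len, num_steps):
    """Count tokens fed through the model over `num_steps` forward passes.

    Step 1 always processes the whole prompt. Later steps process either
    the full growing sequence (no cache) or just the one new token (cache).
    Hypothetical helper for illustration only; assumes num_steps >= 1.
    """
    without_cache = sum(prompt_len + i for i in range(num_steps))
    with_cache = prompt_len + (num_steps - 1)
    return without_cache, with_cache


# The 3 steps shown above: "Hello, I am" is 4 tokens
print(tokens_processed(prompt_len=4, num_steps=3))    # (15, 6)

# A longer run makes the gap obvious
print(tokens_processed(prompt_len=4, num_steps=100))  # (5350, 103)
```

Without the cache, the total grows roughly quadratically with the number of steps; with the cache, it grows linearly.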

This is a simplified version of `gpt_with_kv_cache.py` where the cache is **always on** — no toggle flags — so you can focus purely on understanding the mechanism.


In [None]:
import time       # To measure how fast generation is
import tiktoken   # OpenAI's tokenizer — converts text ↔ token IDs
import torch      # PyTorch — the deep learning framework
import torch.nn as nn  # Neural network building blocks (Linear, Embedding, etc.)


## MultiHeadAttention with KV Cache

This is the core of the KV cache mechanism. Here's the key insight:

In standard attention (no cache), **every forward pass** computes Keys, Values, and Queries for **all** tokens in the input. During generation, this means recomputing K,V for tokens we've already seen.

With a KV cache:
- We **only** compute K, V, Q for the **new** token(s) coming in
- We **append** the new K, V to a stored cache
- Queries attend to **all** cached Keys/Values (old + new)

### Visual: What happens at each decode step

```
Cache before:  [K₀, K₁, K₂]     (keys from previous tokens)
New input:     [K₃]              (key from the new token)
Cache after:   [K₀, K₁, K₂, K₃] (append the new key)

Query Q₃ attends to → [K₀, K₁, K₂, K₃]  (full history!)
```

### The "window" concept

The cache has a fixed size (`window_size`). If it fills up, the oldest entries are discarded (shifted out). This is like a sliding window over the conversation history.

```
Window size = 4
Cache: [K₀, K₁, K₂, K₃]  ← full!
New token K₄ arrives...
Cache: [K₁, K₂, K₃, K₄]  ← K₀ was discarded (shifted left)
```
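The shift can be tried out on a tiny 1-D tensor before reading the full attention code. This is a sketch under simplifying assumptions: integers stand in for key vectors, and `append_with_window` is a hypothetical helper that mirrors the overflow logic used in the `MultiHeadAttention` class below.

```python
import torch


def append_with_window(cache, ptr, new):
    """Write `new` into `cache` starting at `ptr`; shift left first if it would overflow.

    Returns the updated write pointer. Hypothetical helper for illustration.
    """
    window = cache.size(0)
    n = new.size(0)
    assert n <= window, "chunk is larger than the window"
    if ptr + n > window:
        overflow = ptr + n - window
        # Drop the oldest `overflow` entries by moving everything left
        cache[:-overflow] = cache[overflow:].clone()
        ptr -= overflow
    cache[ptr:ptr + n] = new
    return ptr + n


cache = torch.zeros(4, dtype=torch.long)  # window size = 4
ptr = 0

ptr = append_with_window(cache, ptr, torch.tensor([10, 11, 12, 13]))
print(cache.tolist(), ptr)  # [10, 11, 12, 13] 4  (full)

ptr = append_with_window(cache, ptr, torch.tensor([14]))
print(cache.tolist(), ptr)  # [11, 12, 13, 14] 4  (10 was discarded)
```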


In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, window_size=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Each head works on a slice of the embedding

        # These are the learned projection matrices that transform input into Q, K, V
        # Think of them as "what am I looking for?" (Q), "what do I contain?" (K), "what do I offer?" (V)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Combines all heads back together
        self.dropout = nn.Dropout(dropout)

        # ========== KV CACHE SETUP ==========
        # window_size = how many tokens the cache can hold at once
        # cache_k, cache_v = the actual storage tensors (initialized to None, created on first use)
        self.window_size = window_size or context_length
        self.register_buffer("cache_k", None, persistent=False)  # persistent=False → not saved to disk
        self.register_buffer("cache_v", None, persistent=False)

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # b = batch size, num_tokens = how many NEW tokens we're processing this call
        # During prefill: num_tokens = length of prompt (e.g., 4 for "Hello, I am")
        # During decode:  num_tokens = 1 (just the latest generated token)

        assert num_tokens <= self.window_size, (
            f"Input chunk size ({num_tokens}) exceeds KV cache window size ({self.window_size})."
        )

        # Step 1: Project the NEW input tokens into keys, values, queries
        # These are ONLY for the new tokens — we don't recompute old ones!
        keys_new = self.W_key(x)       # (b, num_tokens, d_out)
        values_new = self.W_value(x)   # (b, num_tokens, d_out)
        queries = self.W_query(x)      # (b, num_tokens, d_out)

        # Reshape to split into multiple heads:
        # (b, num_tokens, d_out) → (b, num_tokens, num_heads, head_dim) → (b, num_heads, num_tokens, head_dim)
        # The transpose puts num_heads before num_tokens so each head can process independently
        keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # ========== STEP 2: APPEND NEW K,V TO THE CACHE ==========

        # First call? Create the cache tensors (pre-allocated to window_size for efficiency)
        if self.cache_k is None or self.cache_k.size(0) != b:
            # Shape: (batch, num_heads, window_size, head_dim)
            # The cache is like a fixed-size array with a write pointer (ptr_cur)
            self.cache_k = torch.zeros(
                b, self.num_heads, self.window_size, self.head_dim,
                device=x.device, dtype=x.dtype  # match the input so half-precision models work
            )
            self.cache_v = torch.zeros_like(self.cache_k)
            self.ptr_cur = 0  # ptr_cur = "next free slot" in the cache
            # Example: ptr_cur=0 means cache is empty
            #          ptr_cur=3 means slots 0,1,2 are filled

        # Would the new tokens overflow the cache? If so, shift left to make room
        # This discards the OLDEST tokens (like a sliding window)
        if self.ptr_cur + num_tokens > self.window_size:
            overflow = self.ptr_cur + num_tokens - self.window_size
            # Shift everything left by 'overflow' positions
            # Before: [tok0, tok1, tok2, tok3, tok4] with overflow=2
            # After:  [tok2, tok3, tok4, ___, ___]   (tok0, tok1 discarded)
            self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
            self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
            self.ptr_cur -= overflow

        # Write the new keys and values into the cache at the current pointer position
        self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
        self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
        self.ptr_cur += num_tokens  # Advance the pointer past what we just wrote

        # ========== STEP 3: READ ALL CACHED K,V FOR ATTENTION ==========
        # Slice from 0 to ptr_cur → this is ALL the keys/values we've seen so far
        keys = self.cache_k[:, :, :self.ptr_cur, :]    # (b, num_heads, total_cached, head_dim)
        values = self.cache_v[:, :, :self.ptr_cur, :]   # (b, num_heads, total_cached, head_dim)

        # ========== STEP 4: COMPUTE ATTENTION SCORES ==========
        # queries shape: (b, num_heads, num_tokens, head_dim)   ← only NEW tokens
        # keys shape:    (b, num_heads, total_cached, head_dim) ← ALL cached tokens
        # Result:        (b, num_heads, num_tokens, total_cached)
        # Each new query gets a score against EVERY cached key
        attn_scores = queries @ keys.transpose(2, 3)

        K = attn_scores.size(-1)  # K = total number of cached tokens

        # ========== STEP 5: APPLY CAUSAL MASK ==========
        # Causal mask ensures a token can only attend to itself and to tokens BEFORE it
        # (no peeking at the future!)
        #
        # The tricky part with a cache: our queries start at position 'offset' in the sequence,
        # not at position 0. So we need to shift the mask accordingly.
        #
        # Example: cache has [tok0, tok1, tok2] and we're processing [tok3]
        #   offset = 3 (3 tokens were cached before this one)
        #   tok3 (query row 0) can attend to positions 0,1,2,3 → all of them
        #   So causal_mask is all False → no masking needed (tok3 can see everything before it)
        #
        # Example: prefill with [tok0, tok1, tok2] (no prior cache)
        #   offset = 0
        #   tok0 (row 0) can attend to [tok0] only
        #   tok1 (row 1) can attend to [tok0, tok1]
        #   tok2 (row 2) can attend to [tok0, tok1, tok2]
        offset = K - num_tokens
        row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1)  # (num_tokens, 1)
        col_idx = torch.arange(K, device=x.device).unsqueeze(0)           # (1, K)
        causal_mask = (row_idx + offset) < col_idx  # True = "block this position" (can't see the future)

        # Replace masked positions with -infinity so softmax gives them 0 weight
        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)

        # ========== STEP 6: SOFTMAX + WEIGHTED SUM ==========
        # Scale by sqrt(head_dim) to keep gradients stable, then softmax to get attention weights
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Multiply attention weights by values to get the output
        # (b, num_heads, num_tokens, K) @ (b, num_heads, K, head_dim) → (b, num_heads, num_tokens, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine all heads back into a single vector per token
        # (b, num_tokens, num_heads, head_dim) → (b, num_tokens, d_out)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # Final linear projection

        return context_vec

    def reset_cache(self):
        """Clear the cache — call this before starting a new generation."""
        self.cache_k, self.cache_v = None, None


## Wait — What Does the Cache Actually Save?

A common misconception is that the KV cache skips the attention computation entirely after the first call. **It doesn't.** The dot product (Q × K) and the weighted sum (weights × V) happen **every single call**. Here's what actually changes:

### Without cache (old way) — generating the 5th token

The model reprocesses the **entire** sequence from scratch:

```
Input to model: [tok₀, tok₁, tok₂, tok₃, tok₄]        ← all 5 tokens fed in

W_key(x)   runs on 5 tokens → K₀, K₁, K₂, K₃, K₄     ← RECOMPUTED from scratch
W_value(x) runs on 5 tokens → V₀, V₁, V₂, V₃, V₄     ← RECOMPUTED from scratch
W_query(x) runs on 5 tokens → Q₀, Q₁, Q₂, Q₃, Q₄     ← RECOMPUTED from scratch

Attention: all Q's @ all K's  → (5 × 5) score matrix   ← 25 dot products
Softmax → weighted sum of V's → 5 context vectors

But we only CARE about the last row (Q₄'s output) to predict the next token!
The other 4 rows (Q₀-Q₃) are completely WASTED computation.
```

### With cache — generating the 5th token

The model only processes the **1 new token**:

```
Input to model: [tok₄]                                  ← just 1 token!

W_key(x)   runs on 1 token  → K₄                        ← only the new one
W_value(x) runs on 1 token  → V₄                        ← only the new one
W_query(x) runs on 1 token  → Q₄                        ← only the new one

Cache already has: [K₀, K₁, K₂, K₃] from previous calls
Append K₄, V₄ → cache now: [K₀, K₁, K₂, K₃, K₄]

Attention: Q₄ @ all K's  → (1 × 5) score vector         ← 5 dot products (not 25!)
Softmax → weighted sum of V's → 1 context vector
```

### Side-by-side comparison

```
                            Without Cache     With Cache      Savings
                            ─────────────     ──────────      ───────
W_key linear layer runs:    5 tokens          1 token         4 tokens saved ✓
W_value linear layer runs:  5 tokens          1 token         4 tokens saved ✓
W_query linear layer runs:  5 tokens          1 token         4 tokens saved ✓
Attention score matrix:     5 × 5 = 25        1 × 5 = 5      20 ops saved ✓
Context vectors computed:   5 (4 thrown away)  1 (the one      4 wasted vectors
                                               we need)        avoided ✓
```

### The key insight

The cache saves you from:
1. **Recomputing K and V** for old tokens (read them from cache instead — this is the "KV" in "KV cache")
2. **Computing Q for old tokens** (we don't need their queries — we only care about the new token's query)
3. **Computing attention rows we'd throw away** (without cache, you get a 5×5 matrix but only use the last row)

What the cache does **NOT** skip:
- The new token's Q still does a dot product against **ALL** cached keys — that (1 × total_cached) multiplication happens every decode step
- Softmax, scaling, and the weighted sum of values still happen every step

This is why the line `attn_scores = queries @ keys.transpose(2, 3)` in the code above works identically in both cases — it's always "queries dot-product with all keys." The difference is just whether `queries` has 5 rows or 1 row.
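A quick way to convince yourself that nothing is lost is to compute the last token's context vector both ways and compare. The sketch below uses a single attention head with random weight matrices (`W_q`, `W_k`, `W_v` are made-up stand-ins, not the model's layers) and no dropout.

```python
import torch

torch.manual_seed(0)
d = 8                                   # embedding size, arbitrary
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
x = torch.randn(5, d)                   # embeddings for tok0..tok4

# Without cache: project all 5 tokens, build the 5x5 causal scores, keep the last row
q, k, v = x @ W_q, x @ W_k, x @ W_v
scores = q @ k.T / d**0.5
mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
full = torch.softmax(scores.masked_fill(mask, -torch.inf), dim=-1) @ v
last_row_full = full[-1]

# With cache: K, V for tok0..tok3 are already stored; project only tok4
k_cache, v_cache = x[:4] @ W_k, x[:4] @ W_v
x_new = x[4:5]
k_all = torch.cat([k_cache, x_new @ W_k])   # append K4
v_all = torch.cat([v_cache, x_new @ W_v])   # append V4
q_new = x_new @ W_q                         # only Q4
weights = torch.softmax(q_new @ k_all.T / d**0.5, dim=-1)  # (1 x 5), nothing to mask
last_row_cached = (weights @ v_all)[0]

print(torch.allclose(last_row_full, last_row_cached, atol=1e-4))
```

If the two vectors match, the cached path reproduces the last row of the full computation while projecting only one token.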

### A helpful analogy

Imagine you're a teacher grading essays:

**Without cache (old way):** Every time a new student submits an essay, you re-read ALL previously submitted essays before reading the new one. You take notes on every essay, but then throw away all notes except the ones for the latest essay.

**With cache:** You keep your notes from all previous essays in a folder (the cache). When a new essay arrives, you only read THAT one essay, add notes to your folder, then compare your new notes against ALL your previous notes to give a grade. You still look through all your notes (the dot product) — but you don't re-read the old essays (no recomputation of K, V).


## Why the Causal Mask is Trickier with a KV Cache

In Chapter 3, we learned about the **causal mask** — it prevents tokens from "peeking" at future tokens during attention. Without the KV cache, building this mask was simple. With the cache, it's not — and understanding why is key to understanding the code.

### Without cache (old way) — simple!

All tokens are processed together, so Q and K have the **same number of rows**. The mask is a square upper-triangular matrix:

```
Processing all at once: [tok₀, tok₁, tok₂, tok₃]

Q has 4 rows, K has 4 rows → attention score matrix is (4 × 4)

Mask (✓ = can attend, ✗ = blocked):

          K₀   K₁   K₂   K₃
   Q₀  [  ✓    ✗    ✗    ✗  ]   ← tok₀ only sees itself
   Q₁  [  ✓    ✓    ✗    ✗  ]   ← tok₁ sees tok₀, tok₁
   Q₂  [  ✓    ✓    ✓    ✗  ]   ← tok₂ sees tok₀, tok₁, tok₂
   Q₃  [  ✓    ✓    ✓    ✓  ]   ← tok₃ sees everything before it

Code: mask = torch.triu(torch.ones(4, 4), diagonal=1)
      ↑ that's it! One line. The matrix is square, diagonal starts at 1.
```

### With cache (during decode) — the matrix isn't even square!

During decode, we have **1 query** (the new token) but the keys include **all cached tokens**. The attention matrix is `(1 × total_cached)`, not square:

```
Cache has: [K₀, K₁, K₂, K₃]   ← 4 tokens from prefill
New token: [tok₄]              ← 1 new token, Q₄ is the only query

Q has 1 row, K has 5 columns → attention score matrix is (1 × 5)

          K₀   K₁   K₂   K₃   K₄
   Q₄  [  ✓    ✓    ✓    ✓    ✓  ]   ← tok₄ can see EVERYTHING (it comes last)

The square Chapter 3 mask doesn't fit here: the matrix is (1 × 5),
and torch.triu(..., diagonal=1) would wrongly block K₁ through K₄.
```

In this case, the mask is all ✓ (no blocking) because tok₄ is the latest token — it should be able to attend to all previous tokens. Simple enough.

### With cache (during prefill) — offset matters!

It gets more interesting during prefill when we process multiple tokens. Imagine we process 3 tokens, but 2 were already cached:

```
Cache before: [K₀, K₁]          ← 2 tokens already cached
New tokens:   [tok₂, tok₃, tok₄] ← 3 new tokens being prefilled

Q has 3 rows, K has 5 columns → attention score matrix is (3 × 5)

          K₀   K₁   K₂   K₃   K₄
   Q₂  [  ✓    ✓    ✓    ✗    ✗  ]   ← tok₂ sees tok₀-₂, NOT tok₃ or tok₄
   Q₃  [  ✓    ✓    ✓    ✓    ✗  ]   ← tok₃ sees tok₀-₃, NOT tok₄
   Q₄  [  ✓    ✓    ✓    ✓    ✓  ]   ← tok₄ sees everything

Notice: Q₂ is NOT the 0th token in the sequence. It sits at position 2!
The mask diagonal needs to be SHIFTED by 2 (the offset).
```

### The offset trick in code

This is what these four lines in `MultiHeadAttention.forward()` do:

```python
offset = K - num_tokens                                    # = 5 - 3 = 2
row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1)  # [0, 1, 2] (query positions)
col_idx = torch.arange(K, device=x.device).unsqueeze(0)           # [0, 1, 2, 3, 4] (key positions)
causal_mask = (row_idx + offset) < col_idx                 # True = blocked
```

Let's trace through it for the example above (3 new queries, 5 total keys, offset=2):

```
row_idx + offset:  [0+2, 1+2, 2+2] = [2, 3, 4]   ← real positions of our queries

Comparison: (row_idx + offset) < col_idx

              col_idx →    0      1      2      3      4
row+offset=2          [  2<0?   2<1?   2<2?   2<3?   2<4?  ]   = [F, F, F, T, T] → block col 3,4 ✓
row+offset=3          [  3<0?   3<1?   3<2?   3<3?   3<4?  ]   = [F, F, F, F, T] → block col 4   ✓
row+offset=4          [  4<0?   4<1?   4<2?   4<3?   4<4?  ]   = [F, F, F, F, F] → block nothing ✓

This gives the exact mask we want:
          K₀   K₁   K₂   K₃   K₄
   Q₂  [  ✓    ✓    ✓    ✗    ✗  ]
   Q₃  [  ✓    ✓    ✓    ✓    ✗  ]
   Q₄  [  ✓    ✓    ✓    ✓    ✓  ]
```
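To check the formula against all three situations, the same lines can be wrapped in a small function and printed. `cached_causal_mask` is a hypothetical name used only for this sketch; the body is the expression from `forward()` without the `device` argument.

```python
import torch


def cached_causal_mask(num_tokens, total_keys):
    """Boolean mask of shape (num_tokens, total_keys); True means "blocked"."""
    offset = total_keys - num_tokens
    row_idx = torch.arange(num_tokens).unsqueeze(1)  # (num_tokens, 1)
    col_idx = torch.arange(total_keys).unsqueeze(0)  # (1, total_keys)
    return (row_idx + offset) < col_idx


prefill_from_empty = cached_causal_mask(3, 3)  # 3 new tokens, empty cache
chunked_prefill = cached_causal_mask(3, 5)     # 3 new tokens, 2 already cached
single_decode = cached_causal_mask(1, 5)       # 1 new token, 4 already cached

print(prefill_from_empty.int().tolist())  # [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
print(chunked_prefill.int().tolist())     # [[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0]]
print(single_decode.int().tolist())       # [[0, 0, 0, 0, 0]]
```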

### Why the Chapter 3 `torch.triu` mask doesn't fit

The Chapter 3 mask is a **square** `torch.triu(..., diagonal=1)` matrix whose diagonal starts at position (0,0). With the KV cache:
- The matrix is often **not square** (1 query × many keys during decode)
- The diagonal needs to be **shifted** by the number of previously cached tokens
- The offset changes every decode step as the cache grows

`torch.triu` itself accepts non-square matrices, and `torch.triu(torch.ones(num_tokens, K), diagonal=offset + 1)` would produce the same pattern. The `row_idx + offset < col_idx` formula used in the code builds the boolean mask directly and handles all three cases (prefill from empty, prefill with existing cache, single-token decode) with one expression.


## Transformer Building Blocks (Chapter 4)

Now that we understand the attention mechanism with its KV cache (the hard part!), the remaining pieces of the transformer are unchanged from Chapter 4. These components process each token's representation **after** attention has mixed information across tokens.

Quick refresher on what each does:

- **LayerNorm**: Normalizes each token's embedding to have mean=0, variance=1. Helps training stability. Think of it as "resetting the scale" before each sub-layer.
- **GELU**: The activation function (like ReLU but smoother). Adds non-linearity so the network can learn complex patterns.
- **FeedForward**: A two-layer MLP that processes each token independently. If attention is "tokens talking to each other", FeedForward is "each token thinking on its own."
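
As a sanity check on the first two bullets, the formulas can be compared against PyTorch's built-in versions. This is a sketch with assumptions: it uses `torch.nn.LayerNorm` and `torch.nn.functional.gelu(..., approximate="tanh")` as references, which requires a reasonably recent PyTorch release, and the tensor shape is arbitrary.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4, 16)  # (batch, tokens, emb_dim), arbitrary sizes

with torch.no_grad():
    # LayerNorm by hand: normalize each token's embedding over the last dimension
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    manual_ln = (x - mean) / torch.sqrt(var + 1e-5)

    # A freshly created nn.LayerNorm has scale=1 and shift=0, so it should agree
    builtin_ln = torch.nn.LayerNorm(16, eps=1e-5)(x)
    print(torch.allclose(manual_ln, builtin_ln, atol=1e-5))

    # GELU, tanh approximation, written out by hand
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    manual_gelu = 0.5 * x * (1 + torch.tanh(inner))
    builtin_gelu = F.gelu(x, approximate="tanh")
    print(torch.allclose(manual_gelu, builtin_gelu, atol=1e-5))
```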


In [None]:
class LayerNorm(nn.Module):
    """Normalizes activations across the embedding dimension.
    Makes training more stable by keeping values in a reasonable range."""
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5  # Small constant to avoid division by zero
        self.scale = nn.Parameter(torch.ones(emb_dim))   # Learnable scale (starts at 1)
        self.shift = nn.Parameter(torch.zeros(emb_dim))  # Learnable shift (starts at 0)

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)  # Normalize to mean=0, var=1
        return self.scale * norm_x + self.shift             # Then scale and shift (learned)


class GELU(nn.Module):
    """Gaussian Error Linear Unit — a smooth activation function.
    Unlike ReLU which has a hard cutoff at 0, GELU has a soft curve,
    which helps with gradient flow during training."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4x the embedding size, apply GELU, then project back.
    This is where each token 'thinks' independently (no interaction between tokens here)."""
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),  # Expand: 768 → 3072
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),  # Contract: 3072 → 768
        )

    def forward(self, x):
        return self.layers(x)


## TransformerBlock & GPTModel

### TransformerBlock
Each block follows the same pattern: **Norm → Attention → Add Residual → Norm → FeedForward → Add Residual**

The residual connections ("shortcuts") add the input back to the output. This helps with training deep networks — if a layer can't learn anything useful, the residual lets the signal pass through unchanged.

### GPTModel — the key change for KV cache

Without a cache, position embeddings are simple: token at index 0 gets position 0, index 1 gets position 1, etc.

With a cache, we need to **remember where we left off**. If we prefilled 4 tokens (positions 0-3), then the next generated token should get position 4, not position 0.

That's what `ptr_current_pos` does — it tracks the next position to assign.

```
Prefill "Hello, I am" → positions [0, 1, 2, 3], ptr_current_pos becomes 4
Decode next token     → position  [4],           ptr_current_pos becomes 5
Decode next token     → position  [5],           ptr_current_pos becomes 6
```
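
The bookkeeping is small enough to sketch in plain Python. `PositionTracker` is a hypothetical stand-in for the pointer logic inside `GPTModel.forward()`, using lists instead of tensors.

```python
class PositionTracker:
    """Hands out consecutive position IDs across calls (illustrative only)."""

    def __init__(self, context_length):
        self.context_length = context_length
        self.ptr_current_pos = 0

    def next_positions(self, seq_len):
        end = self.ptr_current_pos + seq_len
        assert end <= self.context_length, "Position overflow"
        pos_ids = list(range(self.ptr_current_pos, end))
        self.ptr_current_pos = end  # remember where we left off
        return pos_ids

    def reset(self):
        self.ptr_current_pos = 0


tracker = PositionTracker(context_length=1024)
print(tracker.next_positions(4))  # [0, 1, 2, 3]  (prefill "Hello, I am")
print(tracker.next_positions(1))  # [4]           (first decode step)
print(tracker.next_positions(1))  # [5]           (second decode step)

tracker.reset()                   # new generation starts from 0 again
print(tracker.next_positions(1))  # [0]
```

Forgetting the reset between generations would keep assigning ever larger positions until the overflow assertion fires.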

Also notice we use `nn.ModuleList` instead of `nn.Sequential` so we can loop through the blocks manually. `nn.Sequential` passes only a single input from one block to the next, so a manual loop leaves room to pass extra arguments between blocks if we ever need to.


In [None]:
class TransformerBlock(nn.Module):
    """One transformer block = Attention + FeedForward, each with LayerNorm and residual connections.

    The data flow through one block:

        Input x
          │
          ├──────────────┐ (save for residual)
          ▼              │
        LayerNorm        │
          ▼              │
        Attention        │  (KV cached)
          ▼              │
        Dropout          │
          ▼              │
        + ◄──────────────┘ (add residual — "shortcut connection")
          │
          ├──────────────┐ (save for residual)
          ▼              │
        LayerNorm        │
          ▼              │
        FeedForward      │
          ▼              │
        Dropout          │
          ▼              │
        + ◄──────────────┘ (add residual)
          │
        Output
    """
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"],
            window_size=cfg.get("kv_window_size", cfg["context_length"])
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # --- Attention sub-block with residual ---
        shortcut = x              # Save input for the residual connection
        x = self.norm1(x)         # Normalize before attention (Pre-LayerNorm)
        x = self.att(x)           # Multi-head attention — KV cache handles the rest
        x = self.drop_shortcut(x) # Dropout for regularization
        x = x + shortcut          # Add residual: attention output + original input

        # --- FeedForward sub-block with residual ---
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)            # Each token processes independently through the MLP
        x = self.drop_shortcut(x)
        x = x + shortcut

        return x


In [None]:
class GPTModel(nn.Module):
    """The full GPT model: Token Embeddings + Position Embeddings → N Transformer Blocks → Output Logits.

    This ties together everything from Chapters 2-4 of "Build a Large Language Model From Scratch":
      - Ch 2: Tokenization and data loading (GPTDatasetV1, tiktoken)
      - Ch 3: The attention mechanism (MultiHeadAttention — now with KV cache)
      - Ch 4: The full GPT architecture (this class)

    The KV cache changes two things here:
      1. We use ModuleList (not Sequential) so we can loop through blocks explicitly
      2. We track ptr_current_pos to give each token the correct position embedding,
         even when tokens arrive one at a time across multiple forward() calls
    """
    def __init__(self, cfg):
        super().__init__()

        # Token embedding: converts token IDs (integers) into dense vectors
        # Vocabulary of 50,257 tokens, each mapped to a 768-dim vector
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])

        # Position embedding: encodes WHERE a token is in the sequence
        # Position 0 gets one vector, position 1 gets another, etc.
        # This is how the model knows word order (attention itself is order-agnostic)
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])

        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        # Stack of transformer blocks — GPT-2 124M has 12 of these
        # Each block refines the token representations through attention + feedforward
        self.trf_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = LayerNorm(cfg["emb_dim"])

        # Output head: projects from embedding space (768) back to vocabulary space (50,257)
        # The logits tell us how likely each token in the vocabulary is to come next
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

        self.kv_window_size = cfg.get("kv_window_size", cfg["context_length"])

        # ptr_current_pos: tracks the absolute position in the sequence
        # This is crucial for the KV cache — when we generate token by token,
        # each new token needs to know its position in the overall sequence
        self.ptr_current_pos = 0

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape

        # Step 1: Convert token IDs to embeddings
        # "Hello" (token 15496) → [0.12, -0.34, 0.56, ...] (768 numbers)
        tok_embeds = self.tok_emb(in_idx)  # (batch, seq_len, emb_dim)

        # Step 2: Create position embeddings starting from ptr_current_pos
        # Without cache: always start from position 0
        # With cache: continue from where we left off
        context_length = self.pos_emb.num_embeddings  # max positions (1024)
        assert self.ptr_current_pos + seq_len <= context_length, (
            f"Position overflow: want position {self.ptr_current_pos + seq_len}, max is {context_length}"
        )

        # Generate position IDs: e.g., if ptr_current_pos=4 and seq_len=1, pos_ids = [4]
        pos_ids = torch.arange(
            self.ptr_current_pos, self.ptr_current_pos + seq_len,
            device=in_idx.device, dtype=torch.long
        )
        self.ptr_current_pos += seq_len  # Advance for the next call

        pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)  # (1, seq_len, emb_dim)

        # Step 3: Combine token + position embeddings
        # The model now knows WHAT each token is AND WHERE it is
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)

        # Step 4: Pass through all transformer blocks
        # Each block: Attention (with KV cache) → FeedForward → residual connections
        for blk in self.trf_blocks:
            x = blk(x)

        # Step 5: Final layer norm + project to vocabulary logits
        x = self.final_norm(x)
        logits = self.out_head(x)  # (batch, seq_len, vocab_size)
        # logits[0, -1, :] = unnormalized scores for the NEXT token after the last input token
        return logits

    def reset_kv_cache(self):
        """Reset all caches and position pointer — call before each new generation."""
        for blk in self.trf_blocks:
            blk.att.reset_cache()
        self.ptr_current_pos = 0


## Text Generation with KV Cache

Generation happens in two distinct phases:

### Phase 1: Prefill
Feed the **entire prompt** through the model in one shot (or in chunks if the prompt is very long). This populates the KV cache with keys and values for every prompt token.

```
Prompt: "Hello, I am"  (4 tokens)
→ model("Hello, I am")
→ Cache now holds K,V for all 4 tokens
→ logits tell us what comes after "am"
```

### Phase 2: Decode (autoregressive loop)
Generate **one token at a time**. Each step:
1. Take the logits from the previous step, pick the most likely next token
2. Feed **only that one token** into the model
3. The model computes Q, K, V for just that token, appends K,V to cache, and attends to the full cached history
4. Get new logits, repeat

```
Step 1: model("_a")      → cache: [Hello, ,, I, am, _a]      → predicts " large"
Step 2: model("_large")  → cache: [Hello, ,, I, am, _a, _large] → predicts " language"
...
```

This is where the KV cache saves time: without it, step 2 would need to process `["Hello", ",", "I", "am", " a", " large"]` — all 6 tokens. With the cache, it only processes `[" large"]` — 1 token — and reads the other 5 from cache.
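
The calling pattern of the two phases can be sketched without a real model. In the example below, `CountingModel` is a hypothetical stand-in that only records how many tokens each call receives and returns a made-up "next token"; it does no attention at all.

```python
class CountingModel:
    """Fake model: remembers every token it was fed (its "cache") and call sizes."""

    def __init__(self):
        self.history = []      # plays the role of the KV cache
        self.call_sizes = []   # how many tokens each forward call received

    def __call__(self, new_tokens):
        self.call_sizes.append(len(new_tokens))
        self.history.extend(new_tokens)
        return sum(self.history) % 50  # arbitrary deterministic "prediction"


def generate(model, prompt, max_new_tokens):
    tokens = list(prompt)
    next_token = model(tokens)            # Phase 1: prefill with the whole prompt
    for _ in range(max_new_tokens):
        tokens.append(next_token)
        next_token = model([next_token])  # Phase 2: feed ONLY the new token
    return tokens


model = CountingModel()
out = generate(model, prompt=[15496, 11, 314, 716], max_new_tokens=3)

print(model.call_sizes)  # [4, 1, 1, 1]
print(len(out))          # 7
```

The first call sees the full prompt and every later call sees exactly one token, which is the pattern the real generation function relies on.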


In [None]:
def generate_text_cached(model, idx, max_new_tokens, context_size=None):
    """Generate tokens one at a time using the KV cache.

    Args:
        model:          The GPT model (with KV cache built in)
        idx:            Starting token IDs, shape (batch, prompt_length)
        max_new_tokens: How many new tokens to generate
        context_size:   Max sequence length (defaults to model's context_length)

    Returns:
        Token IDs tensor with the prompt + all generated tokens appended
    """
    model.eval()  # Turn off dropout (we're generating, not training)

    ctx_len = context_size or model.pos_emb.num_embeddings  # 1024 for GPT-2
    kv_window_size = model.kv_window_size

    with torch.no_grad():  # No gradients needed during generation (saves memory)

        # ============ PHASE 1: PREFILL ============
        # Process the entire prompt to populate the KV cache.
        # If the prompt is longer than kv_window_size, process it in chunks.
        model.reset_kv_cache()  # Start fresh — clear any leftover cache

        input_tokens = idx[:, -ctx_len:]  # Truncate prompt if it exceeds context length
        input_tokens_length = input_tokens.size(1)

        # Feed prompt through the model (populates cache with K,V for all prompt tokens)
        for i in range(0, input_tokens_length, kv_window_size):
            chunk = input_tokens[:, i:i + kv_window_size]
            logits = model(chunk)
            # After this loop, the cache holds K,V for the entire prompt
            # and 'logits' contains predictions for what comes after the last prompt token

        # ============ PHASE 2: DECODE (one token at a time) ============
        # We can generate at most (context_length - prompt_length) new tokens
        # because position embeddings have a fixed size
        max_generable = ctx_len - input_tokens_length
        max_new_tokens = min(max_new_tokens, max_generable)

        for _ in range(max_new_tokens):
            # Pick the most likely next token from the logits
            # logits[:, -1] = predictions for the position after the last token
            # argmax picks the token with the highest score (greedy decoding)
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)  # (batch, 1)

            # Append the new token to our running sequence
            idx = torch.cat([idx, next_idx], dim=1)

            # Feed ONLY the new token into the model
            # The cache already has K,V for all previous tokens —
            # we just need to compute K,V for this one new token and append it
            logits = model(next_idx)

    return idx
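
The bookkeeping in `generate_text_cached` (prompt truncation, prefill chunking, and the cap on new tokens) can be checked without running a model. The helper below, `plan_generation`, is a hypothetical dry-run that mirrors only the index arithmetic of the function above.

```python
def plan_generation(prompt_len, max_new_tokens, ctx_len, kv_window_size):
    """Return the prefill chunk boundaries and the number of decode steps."""
    # idx[:, -ctx_len:] keeps at most the last ctx_len prompt tokens
    kept = min(prompt_len, ctx_len)

    # Prefill walks the kept prompt in chunks of kv_window_size tokens
    chunks = [(start, min(start + kv_window_size, kept))
              for start in range(0, kept, kv_window_size)]

    # Decode is capped by the positions left in the context
    decode_steps = min(max_new_tokens, ctx_len - kept)
    return chunks, decode_steps

# The notebook's setting: 4-token prompt, 200 new tokens, window = context
print(plan_generation(4, 200, 1024, 1024))

# A small window splits the prompt into several prefill chunks
print(plan_generation(10, 2000, 1024, 4))

# A prompt longer than the context leaves no room to generate
print(plan_generation(1500, 50, 1024, 1024))
```

The last case is worth noticing: when the prompt already fills the context, the function returns without generating anything.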


## Putting It All Together — Run the Model

Now we create a GPT-2 124M model (random weights — not pretrained) and generate text. The output will be gibberish since we haven't trained the model, but the architecture and KV cache mechanism are exactly the same as the real GPT-2.

### GPT-2 124M Configuration
| Parameter | Value | Meaning |
|-----------|-------|---------|
| vocab_size | 50,257 | Number of unique tokens (words/subwords) the model knows |
| context_length | 1,024 | Maximum sequence length the model can handle |
| emb_dim | 768 | Size of each token's internal representation |
| n_heads | 12 | Number of attention heads (each looks at different patterns) |
| n_layers | 12 | Number of transformer blocks stacked on top of each other |
| drop_rate | 0.1 | 10% dropout for regularization during training |
| kv_window_size | 1,024 | How many tokens the KV cache can hold (= context_length here) |


In [None]:
# GPT-2 124M configuration — same architecture as OpenAI's smallest GPT-2
GPT_CONFIG_124M = {
    "vocab_size": 50257,      # BPE vocabulary size (Ch 2: Byte Pair Encoding tokenizer)
    "context_length": 1024,   # Max sequence length the model can process
    "emb_dim": 768,           # Embedding dimension (each token = 768-dim vector)
    "n_heads": 12,            # Number of attention heads (Ch 3: Multi-Head Attention)
    "n_layers": 12,           # Number of transformer blocks (Ch 4: full GPT architecture)
    "drop_rate": 0.1,         # Dropout rate (10%) for regularization
    "qkv_bias": False,        # No bias in Q, K, V projections (GPT-2 style)
    "kv_window_size": 1024,   # KV cache window = context_length (no sliding window)
}

# Set random seed for reproducibility
torch.manual_seed(123)

# Create the model and move to GPU if available
model = GPTModel(GPT_CONFIG_124M)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # Switch to evaluation mode (disables dropout)

print(f"Model on: {device}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
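# Note: this prints about 163M if the output head is a separate layer (as in the book's GPTModel).
# The "124M" name counts the output head as sharing (tying) its weights with the token embedding.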
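
The printed parameter count can be sanity-checked by hand. The sketch below assumes the book's layer layout, which is not shown in this section: Q/K/V projections, an output projection with bias, a 4x-wide feed-forward network, two LayerNorms per block, one final LayerNorm, and a bias-free output head. The function name `gpt_param_count` and the `tie_weights` flag are hypothetical.

```python
def gpt_param_count(cfg, tie_weights=False):
    """Estimate the parameter count of a GPT-2 style model from its config."""
    V, C = cfg["vocab_size"], cfg["context_length"]
    D, L = cfg["emb_dim"], cfg["n_layers"]

    qkv_bias = D if cfg["qkv_bias"] else 0
    attention = 3 * (D * D + qkv_bias) + (D * D + D)      # W_q, W_k, W_v + out_proj
    feed_forward = (D * 4 * D + 4 * D) + (4 * D * D + D)  # expand 4x, then project back
    layer_norms = 2 * (2 * D)                             # scale + shift, twice per block
    block = attention + feed_forward + layer_norms

    total = V * D + C * D + L * block + 2 * D  # embeddings + blocks + final LayerNorm
    if not tie_weights:
        total += V * D  # separate output head (assumed to have no bias)
    return total

cfg = {"vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
       "n_layers": 12, "qkv_bias": False}

print(f"Separate output head: {gpt_param_count(cfg):,}")
print(f"Tied output head:     {gpt_param_count(cfg, tie_weights=True):,}")
```

Under these assumptions the count is 163,009,536 with a separate output head and 124,412,160 with tied weights, which is where the "124M" name comes from.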


In [None]:
# Tokenize the input text using tiktoken (Ch 2: BPE tokenizer)
# This converts human-readable text into token IDs that the model understands
start_context = "Hello, I am"

tokenizer = tiktoken.get_encoding("gpt2")  # Use GPT-2's tokenizer
encoded = tokenizer.encode(start_context)   # "Hello, I am" → [15496, 11, 314, 716]
encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)  # Add batch dimension

print("Input text:", start_context)
print("Encoded token IDs:", encoded)
print("Tensor shape:", encoded_tensor.shape)  # (1, 4) = batch of 1, prompt of 4 tokens


In [None]:
# Generate text using the KV cache!
# The model has random weights so the output will be gibberish,
# but this demonstrates the full generation pipeline:
#   1. Prefill: "Hello, I am" → cache K,V for 4 tokens
#   2. Decode: generate 200 tokens one at a time, each reading from cache

if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.time()

token_ids = generate_text_cached(
    model=model,
    idx=encoded_tensor,
    max_new_tokens=200,
)

if torch.cuda.is_available():
    torch.cuda.synchronize()
total_time = time.time() - start

# Decode token IDs back to human-readable text
decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())

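num_prompt = encoded_tensor.size(1)
num_generated = token_ids.size(1) - num_prompt
print(f"Output length: {token_ids.size(1)} tokens ({num_prompt} prompt + {num_generated} generated)")
print(f"Time: {total_time:.2f} sec")
print(f"Speed: {num_generated / total_time:.0f} generated tokens/sec")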
print(f"\nGenerated text:\n{decoded_text}")

if torch.cuda.is_available():
    max_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"\nMax GPU memory: {max_mem_gb:.2f} GB")


## Beyond the Code: How the KV Cache Powers Safety in Production LLMs

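Everything we built above isn't just an academic exercise. In production LLMs like ChatGPT, Claude, and Gemini, the KV cache is where system-prompt instructions, including safety instructions, live while the model generates. Most safety behavior is learned during training, and the system prompt adds a layer of instructions on top of that. Here's the connection.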

### System Prompts: Safety Instructions as Tokens

When you chat with an LLM through an API or app, the model usually doesn't see only your message. Behind the scenes, a **system prompt** (sometimes called a "safety prompt") is placed in front of your input. It's just regular text, tokenized like everything else, that tells the model how to behave.

```
What actually enters the model:
┌─────────────────────────────────────────────────────────────────┐
│ [System Prompt Tokens]   [User Prompt Tokens]   [Generation]   │
│                                                                 │
│ "You are a helpful        "Hello, I am"          " a helpful    │
│  assistant. You must                               AI assistant │
│  not produce harmful                               made by..."  │
│  content. Always be                                             │
│  respectful and..."                                             │
│                                                                 │
│  (could be 500-2000+      (your message)         (model output) │
│   tokens!)                                                      │
└─────────────────────────────────────────────────────────────────┘
```

### System Prompts Live in the KV Cache

During the **prefill phase** (the first `for` loop in our `generate_text_cached` function), the system prompt tokens are processed just like any other tokens. Their Keys and Values are computed and written into `cache_k` and `cache_v`.

From that point on, **every single generated token attends to the system prompt's K,V** through the attention mechanism. This is how the model "remembers" its safety instructions throughout an entire conversation — the system prompt isn't magic, it's just tokens sitting in the KV cache that every new token pays attention to.

```python
# Simplified sketch of generate_text_cached when a system prompt is present.
# system_prompt_tokens and user_message_tokens are token-ID tensors of shape (batch, n).

# Phase 1: Prefill (system prompt + user prompt get cached)
full_prompt = torch.cat([system_prompt_tokens, user_message_tokens], dim=1)
logits = model(full_prompt)  # K,V for ALL tokens (including safety instructions) now in cache

# Phase 2: Decode (every generated token attends to the cached system prompt)
for _ in range(max_tokens):
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick, shape (batch, 1)
    logits = model(next_token)  # This 1 token's query attends to ALL cached K,V
                                 # including the system prompt's K,V
                                 # → the model "sees" its safety instructions on every step
```

### The Real-World Cost: System Prompts Eat Your Context Window

This has a direct practical impact. If the system prompt is 1,000 tokens and your context window is 4,096 tokens:

```
Total context window:     4,096 tokens
System prompt:           -1,000 tokens (safety instructions)
Available for you:        3,096 tokens (your conversation + generated output)

That's ~24% of the context used up before you type a single word!
```
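
The same budget arithmetic as a small helper, so other window and prompt sizes are easy to try. The function name `context_budget` is hypothetical.

```python
def context_budget(context_window, system_prompt_tokens):
    """Return (tokens left for the conversation, system prompt's share of the window)."""
    remaining = context_window - system_prompt_tokens
    share = system_prompt_tokens / context_window
    return remaining, share

remaining, share = context_budget(context_window=4096, system_prompt_tokens=1000)
print(f"{remaining:,} tokens left; the system prompt uses {share:.1%} of the window")
```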

Because the system prompt costs context space and prefill compute on every request, providers invest heavily in making it cheaper. The KV cache mechanisms we learned above directly enable two such optimizations:

### Industry Optimization 1: Prompt Caching

The system prompt is **identical** across millions of requests. So why recompute its K,V every time?

**Prompt caching** (used by Anthropic, OpenAI, and others) computes the system prompt's Keys and Values once and reuses them across requests. The first request pays the prefill cost for the system prompt; later requests with the same prefix skip that part and only prefill their own new tokens.

```
User A: [system prompt K,V computed] + "What is Python?" → response
User B: [system prompt K,V REUSED]   + "Tell me a joke" → response  ← saved ~1000 tokens of compute!
User C: [system prompt K,V REUSED]   + "Write an email" → response  ← saved again!
```

This is the same idea as the cache in our `MultiHeadAttention.forward()`, just reused across requests instead of across decode steps.
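
A minimal sketch of the lookup logic, assuming the cache is keyed by the exact token prefix. `PrefixKVCache` and `fake_compute_kv` are hypothetical names, and the "K,V" here are placeholder strings rather than tensors. Real services add details this sketch ignores, such as expiry and memory limits.

```python
class PrefixKVCache:
    """Toy prompt cache: maps a token prefix to its precomputed K,V."""

    def __init__(self, compute_kv):
        self._compute_kv = compute_kv
        self._store = {}
        self.prefill_calls = 0  # how often the prefix was actually computed

    def get(self, prefix_tokens):
        key = tuple(prefix_tokens)  # lists are not hashable, tuples are
        if key not in self._store:
            self.prefill_calls += 1
            self._store[key] = self._compute_kv(prefix_tokens)
        return self._store[key]

def fake_compute_kv(tokens):
    # Stand-in for a real prefill pass: one (key, value) pair per token
    return [(f"k{t}", f"v{t}") for t in tokens]

system_prompt = [101, 102, 103]
prefix_cache = PrefixKVCache(fake_compute_kv)

tokens_prefilled = []
for user_message in ([7, 8], [9], [10, 11, 12]):
    calls_before = prefix_cache.prefill_calls
    prefix_kv = prefix_cache.get(system_prompt)
    cache_miss = prefix_cache.prefill_calls > calls_before
    # Only the first request pays for the system prompt's tokens
    tokens_prefilled.append(len(user_message) + (len(system_prompt) if cache_miss else 0))

print("Tokens prefilled per request:", tokens_prefilled)
print("System prompt computed", prefix_cache.prefill_calls, "time(s)")
```

Only the first request prefills the three system-prompt tokens, so the per-request counts come out as `[5, 1, 3]`.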

### Industry Optimization 2: Prefix Sharing

Multiple concurrent conversations that start with the same system prompt can **share the same KV cache memory on the GPU**, rather than duplicating it per request. With thousands of simultaneous users, this saves gigabytes of GPU memory.

```
GPU Memory without prefix sharing:
  User A: [system_KV (copy 1)] + [user_A_KV]     ← 1000 tokens of K,V duplicated
  User B: [system_KV (copy 2)] + [user_B_KV]     ← same 1000 tokens again
  User C: [system_KV (copy 3)] + [user_C_KV]     ← and again...

GPU Memory with prefix sharing:
  Shared: [system_KV (1 copy)]                    ← stored once
  User A:                       + [user_A_KV]     ← only unique K,V per user
  User B:                       + [user_B_KV]
  User C:                       + [user_C_KV]
```
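
A rough estimate of the memory at stake, using this notebook's GPT-2 124M shape. The sketch assumes each layer caches one K and one V vector of `emb_dim` float32 values per token. The user count and token counts are made-up illustration values, and the function names are hypothetical.

```python
def kv_bytes_per_token(n_layers, emb_dim, bytes_per_value=4):
    """Bytes of cache per token: K and V, emb_dim values each, in every layer."""
    return 2 * n_layers * emb_dim * bytes_per_value

def kv_cache_bytes(users, system_tokens, user_tokens, per_token, shared):
    """Total cache size with or without a shared system-prompt prefix."""
    if shared:
        total_tokens = system_tokens + users * user_tokens
    else:
        total_tokens = users * (system_tokens + user_tokens)
    return total_tokens * per_token

per_token = kv_bytes_per_token(n_layers=12, emb_dim=768)

# Hypothetical load: 1,000 users, 500-token system prompt, 100 tokens per user
duplicated = kv_cache_bytes(1000, 500, 100, per_token, shared=False)
shared = kv_cache_bytes(1000, 500, 100, per_token, shared=True)

gib = 1024 ** 3
print(f"Per token:       {per_token:,} bytes")
print(f"Without sharing: {duplicated / gib:.1f} GiB")
print(f"With sharing:    {shared / gib:.1f} GiB")
```

The larger the system prompt is relative to each user's own tokens, the bigger the saving from storing it once.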

### Key Takeaway

The KV cache we implemented in this notebook is the **same fundamental mechanism** that production LLMs rely on for efficient, scalable inference, and it is where system-prompt instructions live during generation. System prompts aren't a separate safety system. They're just tokens whose Keys and Values sit in the cache, influencing every generated token through the attention mechanism we coded in `MultiHeadAttention.forward()`.

Understanding this connection helps explain:
- **Why LLMs sometimes "forget" safety instructions** in very long conversations (one plausible reason: the system prompt's share of attention can shrink as the cache fills with conversation tokens, and with a sliding-window cache its K,V can eventually be evicted)
- **Why context window size matters so much** (system prompt + conversation must all fit)
- **Why API providers charge differently for cached vs. uncached tokens** (the compute cost is genuinely different)
- **Why prompt injection attacks are a concern** (adversarial user tokens in the cache compete for attention with the system prompt tokens)