This is not the main notebook in this challenge. Start with `understand-engine.ipynb`.

### Understand KV cache and add support to `my_gpt.py`

I can see I won't get very far with the engine without KV cache support in GPT; I left out that code earlier. I'm not sure whether it's better to first understand the `KVCache` class in [engine.py](https://github.com/karpathy/nanochat/blob/master/nanochat/engine.py) or to first add KV cache support to `my_gpt.py`. There's a lot of subtle stuff going on in both places.

Actually, let me first just think about the concept. In the naive generate, we put in, say, 3 tokens and get out 3 distributions over our vocab. We look only at the last one, take (say) the highest-probability token, and that gives us our new token.

Now we put those 4 tokens in, etc.
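The naive loop described above can be sketched like this. The model here is a hypothetical stand-in (an embedding followed by a linear layer, with no attention at all), so it only illustrates the shape of the loop: every step re-runs the model over the whole sequence and keeps only the last distribution.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
V, D = 50, 16  # toy vocab size and embedding width (made up for illustration)

# Hypothetical stand-in for a GPT: maps ids (B, T) -> logits (B, T, V)
toy_model = nn.Sequential(nn.Embedding(V, D), nn.Linear(D, V))

@torch.no_grad()
def naive_generate(model, prompt, max_new_tokens):
    ids = list(prompt)
    for _ in range(max_new_tokens):
        # re-run the model over the WHOLE sequence on every step
        logits = model(torch.tensor([ids]))            # (1, len(ids), V)
        next_id = torch.argmax(logits[0, -1]).item()   # only the last distribution is used
        ids.append(next_id)
    return ids

out = naive_generate(toy_model, [3, 7, 11], max_new_tokens=4)
print(out)
```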

But when working on those 4 tokens, all the calculations for the first 3 will be exactly the same as before.

We want to avoid repeating those calculations, but we still need all of that information available during the "mixing" step in self-attention, because the output at the 4th position is going to be a mix of the values at all 4 positions, including the first 3. (Being a little loose with language here.)
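A quick check of the "exactly the same as before" claim on a single toy attention layer, assuming random made-up projection matrices `Wq`, `Wk`, `Wv` in place of real weights: with a causal mask, the outputs for the first 3 positions come out the same whether we run 3 tokens or 4.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 5                                   # channels (single head, batch of 1)
x = torch.randn(1, 1, 4, C)             # embeddings for 4 tokens, shape (B, H, T, C)
Wq, Wk, Wv = (torch.randn(C, C) for _ in range(3))  # toy projection matrices

def causal_attn(x):
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

out3 = causal_attn(x[:, :, :3])         # "naive generate" step with 3 tokens
out4 = causal_attn(x)                   # next step with 4 tokens

# rows for the first 3 positions should be unchanged (up to float noise)
print(torch.allclose(out3, out4[:, :, :3], atol=1e-5))
```

The 4th row is the only new work, but computing it still needs `k` and `v` for positions 1 to 3, which is exactly what a cache would hold on to.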

So maybe what we do is stick only that new 4th token into the machine but pass a cache so that scaled_dot_product_attention() does the right thing. Let's see if I can construct a simple example of this.

In [1]:
import torch
import torch.nn.functional as F
B = 1 # batch size
T = 4 # sequence length
H = 1 # heads
C = 5 # channels per head
q = torch.randn((B, H, T, C)) # keep it like in the real thing even though doing B = 1 and H = 1 here
k = torch.randn((B, H, T, C))
v = torch.randn((B, H, T, C))
output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
output

tensor([[[[ 0.1555,  0.3554,  0.4374, -1.2441, -1.1757],
          [ 0.1047,  0.3228,  0.1903, -1.0041, -0.8421],
          [ 0.4303,  0.9916,  0.2892, -0.8870, -0.5080],
          [ 0.1842,  0.5891, -0.1506, -0.2712,  0.0022]]]])

In [2]:
# and let's say the only thing we really need is that last row
output[:,:,-1:,:]

tensor([[[[ 0.1842,  0.5891, -0.1506, -0.2712,  0.0022]]]])

In [3]:
# can we get that from k, v and q[:,:,-1:,:]?
q_last_token_only = q[:,:,-1:,:]
q_last_token_only

tensor([[[[ 0.5312, -0.4794,  0.2350,  0.0690, -0.6563]]]])

In [4]:
F.scaled_dot_product_attention(q_last_token_only, k, v, is_causal=True)

tensor([[[[ 0.1555,  0.3554,  0.4374, -1.2441, -1.1757]]]])

In [5]:
F.scaled_dot_product_attention(q_last_token_only, k, v, is_causal=False)

tensor([[[[ 0.1842,  0.5891, -0.1506, -0.2712,  0.0022]]]])

Yes, it seems so. But why only if we pass `is_causal=False`?

I bet the attention mask that gets built in F.scaled_dot_product_attention isn't right for this situation. Let's see:

In [6]:
import math
# copying from https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# (modified to also return attn_bias so we can inspect the mask)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value, attn_bias

In [7]:
# first repeat what we did above but using this version, should match
output, attn_bias = scaled_dot_product_attention(q, k, v, is_causal=True)
output

tensor([[[[ 0.1555,  0.3554,  0.4374, -1.2441, -1.1757],
          [ 0.1047,  0.3228,  0.1903, -1.0041, -0.8421],
          [ 0.4303,  0.9916,  0.2892, -0.8870, -0.5080],
          [ 0.1842,  0.5891, -0.1506, -0.2712,  0.0022]]]])

In [8]:
attn_bias

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [9]:
output, attn_bias = scaled_dot_product_attention(q_last_token_only, k, v, is_causal=True)
output # expect NOT to match the last row, just like above

tensor([[[[ 0.1555,  0.3554,  0.4374, -1.2441, -1.1757]]]])

In [10]:
attn_bias

tensor([[0., -inf, -inf, -inf]])

Yes: the built-in causal mask is aligned to the top-left corner, so a single query row is treated as position 0 and is only allowed to look at the first "token."

So for a single token in q, passing `is_causal=False` works because it means no mask at all, which is exactly right for the last position: it's allowed to attend to every key. However, it looks like the code in `CausalSelfAttention` constructs its own mask, probably to support a q with multiple tokens.

When would that come up during inference? When you first start, you likely have a prompt, which means many tokens, but why would you already have a KV cache at that point? And after that, I'm starting to think generation goes one token at a time.

I'll come back to that. First let's see how the mask is constructed and make sure it's the same as no mask when q has length 1.

In [11]:
# Let's start with q having the two final "tokens"
q_last_two_tokens = q[:,:,-2:,:]
# copying code from CausalSelfAttention.forward
Tq = q_last_two_tokens.size(2) # 2
Tk = k.size(2) # 4
attn_mask = torch.zeros((Tq,Tk), dtype=torch.bool)
prefix_len = Tk - Tq # 2
attn_mask[:, :prefix_len] = True    # so whatever is in the prefix we're always allowed to see
attn_mask

tensor([[ True,  True, False, False],
        [ True,  True, False, False]])

In [12]:
# and now we need a triangle for the remaining Tq x Tq
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))
attn_mask

tensor([[ True,  True,  True, False],
        [ True,  True,  True,  True]])

In [13]:
scaled_dot_product_attention(q_last_two_tokens, k, v, attn_mask=attn_mask)

(tensor([[[[ 0.4303,  0.9916,  0.2892, -0.8870, -0.5080],
           [ 0.1842,  0.5891, -0.1506, -0.2712,  0.0022]]]]),
 tensor([[0., 0., 0., -inf],
         [0., 0., 0., 0.]]))

^ Yes, this matches the last two rows computed the "normal" way above

In [14]:
# and just to be sure, compute the attn_mask when prefix is 3
Tq = q_last_token_only.size(2) # 1
Tk = k.size(2) # 4
attn_mask = torch.zeros((Tq,Tk), dtype=torch.bool)
prefix_len = Tk - Tq # 3
attn_mask[:, :prefix_len] = True
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))
attn_mask

tensor([[True, True, True, True]])

^ Yes, that seems right

In [15]:
scaled_dot_product_attention(q_last_token_only, k, v, attn_mask=attn_mask)

(tensor([[[[ 0.1842,  0.5891, -0.1506, -0.2712,  0.0022]]]]),
 tensor([[0., 0., 0., 0.]]))

^ and yes, that matches the last row

(Later, as I started to copy the code, I saw that when Tq is 1 it does in fact skip forming a mask and calls sdpa with `is_causal=False`.)
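Another way to look at the mask: `is_causal=True` uses a lower-triangular mask aligned to the top-left corner, while the prefix-plus-triangle mask appears to be the same lower triangle shifted right by `Tk - Tq` (aligned to the bottom-right corner). A small sketch that checks this for small sizes; both function names are made up for illustration.

```python
import torch

def prefix_plus_triangle_mask(Tq, Tk):
    # the construction copied from CausalSelfAttention.forward above (True = keep)
    mask = torch.zeros((Tq, Tk), dtype=torch.bool)
    prefix_len = Tk - Tq
    mask[:, :prefix_len] = True
    mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))
    return mask

def shifted_tril_mask(Tq, Tk):
    # lower triangle with its diagonal shifted right by Tk - Tq,
    # i.e. causal alignment at the bottom-right corner instead of the top-left
    return torch.ones((Tq, Tk), dtype=torch.bool).tril(diagonal=Tk - Tq)

for Tk in range(1, 7):
    for Tq in range(1, Tk + 1):
        assert torch.equal(prefix_plus_triangle_mask(Tq, Tk), shifted_tril_mask(Tq, Tk))

print(shifted_tril_mask(2, 4))
```

Since row `i` of both masks keeps exactly the columns `j <= i + (Tk - Tq)`, they should agree for any `Tq <= Tk`.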

So if this is the right idea, I should find somewhere in the engine's generate() where it passes only the latest token rather than all previous tokens. (Or maybe the latest token of each sequence, if it's generating multiple sequences in parallel, say to do a beam search like I talk about [here](https://towardsdatascience.com/tracing-the-transformer-in-diagrams-95dbeb68160c/) and which was used in the original 2017 paper.) Let's see...

In the main generation loop in `engine.py` it does this:

```
logits = self.model.forward(ids, kv_cache=kv_cache_decode)
```

But what is `ids`?

```
...some code...

token_column = [] # contains the next token id along each row

...some code...

ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
```

So it does seem like that's what's happening: `unsqueeze(1)` turns the list of next tokens into shape (B, 1), so each time we send in sequences of length 1. I can come back to confirm and understand this better once I have the engine code implemented.
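A tiny illustration of the shape that line produces, with made-up token ids standing in for `token_column` (three sequences generated in parallel):

```python
import torch

# hypothetical next-token ids, one per row of the batch
token_column = [668, 42, 7]
ids = torch.tensor(token_column, dtype=torch.long).unsqueeze(1)
print(ids.shape)  # B = 3 rows, each a sequence of length T = 1
```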

### Add KV cache support to GPT

I understand the concept well enough that I should be able to follow along and hand-copy the KV cache support into `my_gpt.py`:

```
@@ -48,7 +48,6 @@ class CausalSelfAttention(nn.Module):
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
 
     def forward(self, x, cos_sin, kv_cache=None):
-        assert kv_cache is None # add support for this later
         B, T, C = x.size()
 
         q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
@@ -63,13 +62,36 @@ class CausalSelfAttention(nn.Module):
 
         q, k, v = q.transpose(2,1), k.transpose(2,1), v.transpose(2,1) # (B,T,H,D) -> (B,H,T,D)
 
+        # apply KV cache: insert current k,v into cache and get the full view so far
+        if kv_cache is not None:
+            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
+        Tq = q.size(2) # number of queries in this forward pass (I think will usually be 1)
+        Tk = k.size(2) # number of keys/values in total (in the cache + in this forward pass)
+
         # code related to KV cache goes here
 
         # will understand and add code for GQA later
         assert self.n_head == self.n_kv_head
         enable_gqa = self.n_head != self.n_kv_head # always false for now
 
-        y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
+        # read notes in challenge-20-understand-engine/add-kv-cache-support-to-gpt.ipynb
+        if kv_cache is None or Tq == Tk:
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
+        elif Tq == 1:
+            # believe this is the common case during inference after the initial prompt is processed
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
+        else:
+            attn_mask = torch.zeros((Tq,Tk), dtype=torch.bool, device=q.device) # True = keep
+            prefix_len = Tk - Tq # positions already in the cache (2 when Tk = 4 and Tq = 2)
+            if prefix_len > 0: # he says it can't be negative but could be zero; I don't think it can be 0 here since Tq == Tk is handled above
+                attn_mask[:, :prefix_len] = True    # so whatever is in the prefix is allowed
+            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
+            # A "square" of Trues on the left and a "triangle" of Trues on the right, like
+            # if Tk = 4 and Tq = 2 we'll end up with
+            # True True True False
+            # True True True True
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
+
         y = y.transpose(1,2).contiguous().view(B, T, -1)
         y = self.c_proj(y)
         return y
@@ -193,7 +215,6 @@ class GPT(nn.Module):
 
 
     def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
-        assert kv_cache is None # for now
         
         B, T = idx.size()
 
@@ -201,7 +222,8 @@ class GPT(nn.Module):
         assert idx.device == self.cos.device
         assert self.cos.dtype == torch.bfloat16
 
-        T0 = 0 # TODO T0 = 0 if kv_cache is None else kv_cache.get_pos()
+        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
+        T0 = 0 if kv_cache is None else kv_cache.get_pos()
         cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
```
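To convince myself the three branches in the diff agree with ordinary causal attention, here is a sketch that pulls the branching logic into a standalone helper (`attention_with_cache` is a hypothetical name, and GQA is left out) and compares it against the last `Tq` rows of full causal attention, pretending the earlier positions came from a cache.

```python
import torch
import torch.nn.functional as F

def attention_with_cache(q, k, v, used_cache):
    """Hypothetical standalone copy of the branching logic in the diff above.
    q: (B, H, Tq, D); k, v: (B, H, Tk, D), already including any cached positions."""
    Tq, Tk = q.size(2), k.size(2)
    if not used_cache or Tq == Tk:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    if Tq == 1:
        # the single new query may attend to every key
        return F.scaled_dot_product_attention(q, k, v, is_causal=False)
    attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)  # True = keep
    prefix_len = Tk - Tq
    attn_mask[:, :prefix_len] = True
    attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

torch.manual_seed(0)
B, H, T, D = 2, 2, 6, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# pretend the first T - Tq positions are already cached and only Tq are new
for Tq in range(1, T + 1):
    out = attention_with_cache(q[:, :, -Tq:], k, v, used_cache=True)
    assert torch.allclose(out, full[:, :, -Tq:], atol=1e-5)
print("all branches match full causal attention")
```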

Let's try

In [1]:
import os
import sys
sys.path.append('../my_nanochat')
import torch
import torch.nn.functional as F
from my_nanochat.my_common import get_base_dir
from my_nanochat.my_checkpoint_manager import build_model

In [2]:
checkpoint_dir = os.path.join(get_base_dir(), "base_checkpoints", "d4")
model, tokenizer, meta_data = build_model(checkpoint_dir, step=10, device=torch.get_default_device(), phase="eval")

Building model with config: {'sequence_len': 128, 'vocab_size': 65537, 'n_layer': 4, 'n_head': 2, 'n_kv_head': 2, 'n_embd': 256}


In [3]:
prompt_tokens = tokenizer.encode('Hello', prepend=tokenizer.get_bos_token_id())
prompt_tokens

[65536, 28466]

In [4]:
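class StupidKVCache:
    # naive cache: keeps one growing k and v tensor per layer and concatenates
    # each new chunk of positions onto it along the time dimension (dim=2)
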
    def __init__(self):
        self.cache = {}
        self.pos = 0
    
    def insert_kv(self, layer_idx, k, v):
        print(f"about to cache layer: {layer_idx}, k.shape: {k.shape}")
        k_key = f"layer_{layer_idx}_k"
        v_key = f"layer_{layer_idx}_v"
        cached_k = self.cache.get(k_key)
        cached_v = self.cache.get(v_key)
        if cached_k is not None:
            k = torch.concat((cached_k, k), dim=2)
            v = torch.concat((cached_v, v), dim=2)
        self.cache[k_key] = k
        self.cache[v_key] = v
        self.pos = k.shape[2] # total positions cached so far (same value for every layer)
        print(f"about to return k with shape {k.shape}")
        return k, v

    def get_pos(self):
        return self.pos

kv_cache = StupidKVCache()

In [5]:
logits = model.forward(torch.tensor([prompt_tokens]), kv_cache=kv_cache)
logits

about to cache layer: 0, k.shape: torch.Size([1, 2, 2, 128])
about to return k with shape torch.Size([1, 2, 2, 128])
about to cache layer: 1, k.shape: torch.Size([1, 2, 2, 128])
about to return k with shape torch.Size([1, 2, 2, 128])
about to cache layer: 2, k.shape: torch.Size([1, 2, 2, 128])
about to return k with shape torch.Size([1, 2, 2, 128])
about to cache layer: 3, k.shape: torch.Size([1, 2, 2, 128])
about to return k with shape torch.Size([1, 2, 2, 128])


tensor([[[-1.4920, -1.4920, -1.4920,  ..., -1.4920, -1.4920, -1.4920],
         [-1.6174, -1.6174, -1.6174,  ..., -1.6174, -1.6174, -1.6174]]],
       grad_fn=<MulBackward0>)

In [6]:
next_token = torch.argmax(logits[0,-1]).item(); next_token

668

In [7]:
logits = model.forward(torch.tensor([[next_token]]), kv_cache=kv_cache)
logits

about to cache layer: 0, k.shape: torch.Size([1, 2, 1, 128])
about to return k with shape torch.Size([1, 2, 3, 128])
about to cache layer: 1, k.shape: torch.Size([1, 2, 1, 128])
about to return k with shape torch.Size([1, 2, 3, 128])
about to cache layer: 2, k.shape: torch.Size([1, 2, 1, 128])
about to return k with shape torch.Size([1, 2, 3, 128])
about to cache layer: 3, k.shape: torch.Size([1, 2, 1, 128])
about to return k with shape torch.Size([1, 2, 3, 128])


tensor([[[-2.4710, -2.4710, -2.4710,  ..., -2.4710, -2.4710, -2.4710]]],
       grad_fn=<MulBackward0>)

In [8]:
logits = model.forward(torch.tensor([prompt_tokens + [next_token]]), kv_cache=None)
logits

tensor([[[-1.4920, -1.4920, -1.4920,  ..., -1.4920, -1.4920, -1.4920],
         [-1.6174, -1.6174, -1.6174,  ..., -1.6174, -1.6174, -1.6174],
         [-2.4710, -2.4710, -2.4710,  ..., -2.4710, -2.4710, -2.4710]]],
       grad_fn=<MulBackward0>)
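The comparison above is done by eye on a real checkpoint. Here is a self-contained sketch of the same check, using a hypothetical one-layer, one-head `ToyGPT` with learned absolute position embeddings instead of rotary ones, and a `ToyCache` that concatenates like `StupidKVCache`. The `T0` offset plays the same role as `kv_cache.get_pos()` in the diff: the decoded token has to get position 2, not position 0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyCache:
    """Hypothetical single-layer cache that concatenates along time (like StupidKVCache)."""
    def __init__(self):
        self.k = self.v = None
    def get_pos(self):
        return 0 if self.k is None else self.k.size(2)
    def insert_kv(self, k, v):
        if self.k is not None:
            k, v = torch.cat((self.k, k), dim=2), torch.cat((self.v, v), dim=2)
        self.k, self.v = k, v
        return k, v

class ToyGPT(nn.Module):
    """Hypothetical one-layer, one-head model with learned absolute positions."""
    def __init__(self, vocab=50, d=16, max_len=32):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(max_len, d)
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.head = nn.Linear(d, vocab, bias=False)

    def forward(self, idx, kv_cache=None):
        B, T = idx.shape
        T0 = 0 if kv_cache is None else kv_cache.get_pos()  # offset positions by what's cached
        x = self.tok(idx) + self.pos(torch.arange(T0, T0 + T))
        q, k, v = self.qkv(x).unsqueeze(1).chunk(3, dim=-1)  # each (B, 1, T, d)
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(k, v)
        Tq, Tk = q.size(2), k.size(2)
        if Tq == Tk:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # bottom-right aligned causal mask (prefix fully visible)
            mask = torch.ones(Tq, Tk, dtype=torch.bool).tril(diagonal=Tk - Tq)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return self.head(x + y.squeeze(1))  # (B, T, vocab)

torch.manual_seed(0)
model = ToyGPT().eval()
prompt = torch.tensor([[5, 9]])
with torch.no_grad():
    cache = ToyCache()
    logits_prefill = model(prompt, kv_cache=cache)                      # 2 positions
    next_token = logits_prefill[0, -1].argmax().item()
    logits_step = model(torch.tensor([[next_token]]), kv_cache=cache)  # 1 new position
    logits_full = model(torch.cat((prompt, torch.tensor([[next_token]])), dim=1))  # no cache
print(torch.allclose(logits_step[0, -1], logits_full[0, -1], atol=1e-5))
```

Setting `T0` to 0 in the sketch should break the match for the decoded token, since it would then be embedded as if it were at position 0.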

### "Copy" actual `KVCache` class

I'll hand-copy the `KVCache` class from [engine.py](https://github.com/karpathy/nanochat/blob/master/nanochat/engine.py) into `my_engine.py` (it will be the first thing in the file).

I'm guessing `KVCache` will be similar to the `StupidKVCache` I made above but more efficient, though there could be other things going on too.

In [9]:
# He uses this bit operation "trick" to round up to nearest multiple of 1024
(4234 + 1023) & ~1023

5120
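The trick works because 1024 is a power of two: adding 1023 pushes any nonzero remainder past the next multiple, and `& ~1023` clears the low 10 bits. A sketch comparing it against plain ceiling division (both helper names are made up):

```python
def round_up_bits(n, multiple=1024):
    # only valid when `multiple` is a power of two
    assert multiple & (multiple - 1) == 0, "multiple must be a power of two"
    return (n + multiple - 1) & ~(multiple - 1)

def round_up_plain(n, multiple=1024):
    # ceiling division with integers, works for any positive multiple
    return -(-n // multiple) * multiple

for n in [0, 1, 1023, 1024, 1025, 4234]:
    print(n, round_up_bits(n), round_up_plain(n))
```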

Try it

In [4]:
from my_nanochat.my_engine import KVCache

In [5]:
kv_cache = KVCache(
    batch_size=1,
    num_heads=meta_data['model_config']['n_head'],
    seq_len=100,
    head_dim=meta_data['model_config']['n_embd'] // meta_data['model_config']['n_head'],
    num_layers=meta_data['model_config']['n_layer'])

In [6]:
kv_cache.kv_shape

(4, 2, 1, 2, 100, 128)
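Going by the constructor arguments, the six numbers presumably line up as (num_layers, k/v, batch_size, num_heads, seq_len, head_dim) = (4, 2, 1, 2, 100, 128). Below is a hypothetical sketch of a cache built around one preallocated tensor of that shape, to show how writing into slices can replace `StupidKVCache`'s concatenation. This is my own simplification, not the actual nanochat class: it never grows (the rounding trick above hints that the real one may resize) and it ignores dtype and device.

```python
import torch

class PreallocatedKVCache:
    """Hypothetical sketch: one tensor of shape (num_layers, 2, B, H, seq_len, D),
    where index 0/1 on dim 1 holds keys/values."""
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv = torch.zeros(self.kv_shape)
        self.pos = 0  # positions filled so far (same for every layer)

    def get_pos(self):
        return self.pos

    def insert_kv(self, layer_idx, k, v):
        t0, t1 = self.pos, self.pos + k.size(2)
        assert t1 <= self.kv_shape[4], "this sketch does not grow the cache"
        self.kv[layer_idx, 0, :, :, t0:t1] = k  # write in place, no concatenation
        self.kv[layer_idx, 1, :, :, t0:t1] = v
        # advance pos only after the LAST layer so every layer writes at the same t0
        if layer_idx == self.kv_shape[0] - 1:
            self.pos = t1
        # return views of everything cached so far for this layer
        return self.kv[layer_idx, 0, :, :, :t1], self.kv[layer_idx, 1, :, :, :t1]

cache = PreallocatedKVCache(batch_size=1, num_heads=2, seq_len=100, head_dim=128, num_layers=4)
print(cache.kv_shape)
for step_len in (2, 1):  # a 2-token prompt, then one decoded token
    for layer in range(4):
        k = torch.randn(1, 2, step_len, 128)
        keys, values = cache.insert_kv(layer, k, torch.randn_like(k))
    print(keys.shape, cache.get_pos())
```

The one subtle choice is advancing `pos` only after the last layer: every layer must write its new k and v at the same offset, and `GPT.forward` reads `get_pos()` once before any layer runs.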

In [7]:
logits = model.forward(torch.tensor([prompt_tokens]), kv_cache=kv_cache)
logits

tensor([[[-1.4920, -1.4920, -1.4920,  ..., -1.4920, -1.4920, -1.4920],
         [-1.6174, -1.6174, -1.6174,  ..., -1.6174, -1.6174, -1.6174]]],
       grad_fn=<MulBackward0>)

In [8]:
next_token = torch.argmax(logits[0,-1]).item(); next_token

668

In [9]:
logits = model.forward(torch.tensor([[next_token]]), kv_cache=kv_cache)
logits

tensor([[[-2.4710, -2.4710, -2.4710,  ..., -2.4710, -2.4710, -2.4710]]],
       grad_fn=<MulBackward0>)

In [10]:
logits = model.forward(torch.tensor([prompt_tokens + [next_token]]), kv_cache=None)
logits

tensor([[[-1.4920, -1.4920, -1.4920,  ..., -1.4920, -1.4920, -1.4920],
         [-1.6174, -1.6174, -1.6174,  ..., -1.6174, -1.6174, -1.6174],
         [-2.4710, -2.4710, -2.4710,  ..., -2.4710, -2.4710, -2.4710]]],
       grad_fn=<MulBackward0>)