In [54]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import time

from tqdm import tqdm

import numpy as np
    
from pico.utils import load_encoder_hparams_and_params
from pico.gpt2 import generate, generate_kv

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Resources:

1. https://www.dipkumar.dev/posts/gpt-kvcache/
2. https://github.com/jaymody/picoGPT/pull/7/files

You can also control the number of tokens to generate, the model size (one of `["124M", "355M", "774M", "1558M"]`), and the directory to save the models:

expected_completion = ' the most powerful machines on the planet.\n\nThe computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.\n'

In [64]:
# Number of new tokens each generation run below will produce.
n_tokens_to_generate = 40

# Load the tokenizer, hyper-parameters and weights for the chosen model size
# ("124M" is the smallest of the available sizes); files are stored under `models/`.
tokenizer, hparams, params = load_encoder_hparams_and_params(model_size="774M", models_dir="models")

Fetching checkpoint: 1.00kb [00:00, 2.96Mb/s]                                                       
Fetching encoder.json: 1.04Mb [00:00, 2.61Mb/s]                                                     
Fetching hparams.json: 1.00kb [00:00, 3.19Mb/s]                                                     
Fetching model.ckpt.data-00000-of-00001: 3.10Gb [04:55, 10.5Mb/s]                                   
Fetching model.ckpt.index: 16.0kb [00:00, 9.58Mb/s]                                                 
Fetching model.ckpt.meta: 1.38Mb [00:00, 3.38Mb/s]                                                  
Fetching vocab.bpe: 457kb [00:00, 1.66Mb/s]                                                         


In [65]:
prompt = "Alan Turing theorized that computers would one day become"
input_ids = tokenizer.encode(prompt)

# Time generation + decoding with the `generate` imported from `pico.gpt2` in the
# first cell (note: a later cell in this notebook defines a function of the same name).
# perf_counter() is a monotonic, high-resolution clock, so the measured interval
# cannot be distorted by system clock adjustments the way time.time() can.
start_time = time.perf_counter()

# generate output ids
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
# decode the ids back into a string
output_text = tokenizer.decode(output_ids)

# Record the end time
end_time = time.perf_counter()

# Calculate the elapsed time
elapsed_time = end_time - start_time
print(f"The process took {elapsed_time} seconds to complete.")

output_text

generating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:18<00:00,  2.20it/s]

The process took 18.19883704185486 seconds to complete.





' so powerful that they could be used to solve problems that humans could not.\n\nIn the 1950s, Turing was asked to help develop a computer program that could play chess. He was given a'

In [66]:
prompt = "Alan Turing theorized that computers would one day become"
input_ids = tokenizer.encode(prompt)

# Time generation + decoding with `generate_kv` (imported from `pico.gpt2`).
# perf_counter() is a monotonic, high-resolution clock, so the measured interval
# cannot be distorted by system clock adjustments the way time.time() can.
start_time = time.perf_counter()

# generate output ids
output_ids = generate_kv(input_ids, params, hparams["n_head"], n_tokens_to_generate)
# decode the ids back into a string
output_text = tokenizer.decode(output_ids)

# Record the end time
end_time = time.perf_counter()

# Calculate the elapsed time
elapsed_time = end_time - start_time
print(f"The process took {elapsed_time} seconds to complete.")

output_text

generating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:49<00:00,  1.24s/it]

The process took 49.79473090171814 seconds to complete.





' so powerful that they could be used to solve problems that humans could not.\n\nIn the 1950s, Turing was asked to help develop a computer program that could play chess. He was given a'

### Background

Before we learn kv-cache, let's first understand the non-kv-cache version of autoregressive generation.

First, a tokenizer converts our text into a list of token_ids:

In [5]:
# Turn the prompt string into token ids with the byte-pair-encoding tokenizer,
# report how many ids were produced, then display the ids themselves.
input_ids = tokenizer.encode(prompt)
n_prompt_tokens = len(input_ids)
print(n_prompt_tokens)
input_ids

10


[36235, 39141, 18765, 1143, 326, 9061, 561, 530, 1110, 1716]

### Weights

The weights of this LLM include, besides the transformer `blocks` and the final layer norm `ln_f`, two embedding tables:

1. Word Positional Encoding (wpe)
2. Word Token Embeddings (wte)

```python
print(params.keys()) # dict_keys(['blocks', 'ln_f', 'wpe', 'wte'])
```

By passing `**params` into `logits = gpt2(inputs, **params, n_head=n_head)` we are just passing this dictionary's values (the weights, a.k.a. parameters) into the function as keyword arguments, using the dictionary keys `wte, wpe, blocks, ln_f` as the argument names.

#### Word Positional Encoding

The Word Positional Encoding (wpe) adds a vector that represents a position in time, or order in a sequence, to each token embedding. `print(type(params['wpe']), params['wpe'].shape)` gives `<class 'numpy.ndarray'> (1024, 768)` because the checkpoint stores one positional embedding for each of the 1024 positions in the context window, and the embedding size is 768 (these shapes correspond to the 124M model; larger models use a wider embedding). In doing `wpe[range(len(inputs))]` we simply select the first `len(inputs)` positional embeddings.

#### Word Token Embedding

The Word Token Embedding is used to map each token_id (input_ids) to its corresponding vector. `print(type(params['wte']), params['wte'].shape)` is `<class 'numpy.ndarray'> (50257, 768)` because our vocab size is 50257 and our embedding size is 768. In doing `wte[inputs]` we have just mapped our token id list of size 10 to a sequence of embeddings shape (10, 768)

#### Transformer Input Embeddings

The embeddings that go into the first of the transformer blocks are the element-wise sum of the token and positional embeddings: `x = wte[inputs] + wpe[range(len(inputs))]`

In [68]:
# Positional-embedding rows for positions 0 .. len(input_ids) - 1
positions = range(len(input_ids))
print(params["wpe"][positions])

[[-1.88207198e-02 -1.97418600e-01  4.02672496e-03 ... -4.30437364e-02
   2.82671917e-02  5.44901080e-02]
 [ 2.39594337e-02 -5.37920333e-02 -9.48786438e-02 ...  3.41700129e-02
   1.01718502e-02 -1.55729489e-04]
 [ 4.21607168e-03 -8.47639143e-02  5.45149297e-02 ...  1.97447110e-02
   1.93248559e-02 -2.14238558e-02]
 ...
 [ 2.53077131e-03 -3.17870919e-03  1.17414258e-01 ...  2.00962462e-03
   4.41795774e-03 -6.83258474e-03]
 [-1.23805739e-03 -1.77337788e-03  1.11044556e-01 ... -2.30074697e-03
   4.15364839e-03 -1.04475096e-02]
 [ 4.93714586e-03  2.14576256e-03  1.17781341e-01 ... -2.82027118e-04
   4.07085707e-03 -5.54985739e-03]]


In [69]:
# Token-embedding rows looked up by the prompt's token ids (one row per token)
token_embeddings = params["wte"][input_ids]
print(token_embeddings)

[[ 0.04486499 -0.1522257   0.10908855 ...  0.16187134  0.00406003
  -0.01259668]
 [-0.1435177  -0.1303647  -0.00709237 ... -0.26905674 -0.21710931
  -0.27703205]
 [-0.14161602 -0.06058507  0.05428597 ...  0.16568261  0.1750053
   0.08499283]
 ...
 [ 0.00818344  0.03351058  0.03436588 ...  0.15731247  0.06635052
  -0.08678364]
 [-0.1378994  -0.02936367 -0.00255402 ... -0.09662744 -0.07259481
   0.11599892]
 [ 0.06102467 -0.072351    0.01882253 ... -0.24272189  0.23248099
   0.12684126]]


In [70]:
# Input to the first transformer block: each token's embedding plus the
# embedding of the position it occupies, summed element-wise.
wte, wpe = params["wte"], params["wpe"]
x = wte[input_ids] + wpe[range(len(input_ids))]

### Transformer Block

The blocks are a list of repeating transformer blocks `type(params['blocks']) # list` where each block `params['blocks'][0].keys()` consists of ` dict_keys(['attn', 'ln_1', 'ln_2', 'mlp'])`.

```python

def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
    """Apply one transformer block: attention then feed-forward, each added back to its input.

    `ln_1`/`ln_2` are keyword arguments for `layer_norm`, `attn` for `mha`
    and `mlp` for `ffn`; the input is layer-normed before each sub-layer.
    """
    # multi-head causal self attention, added to the block input (residual)
    attn_out = mha(layer_norm(x, **ln_1), **attn, n_head=n_head)  # [n_seq, n_embd]
    h = x + attn_out

    # position-wise feed forward network, added to the attention result (residual)
    ffn_out = ffn(layer_norm(h, **ln_2), **mlp)  # [n_seq, n_embd]
    return h + ffn_out
```

#### layer norm

The layer_norm weights `params['blocks'][0]['ln_1'].keys()` consist of the gamma and beta params `dict_keys(['b', 'g'])`, which are also called the scale and offset weights because `g` multiplies each element by a factor and `b` shifts the entire vector: `g * x + b`, where `x` is the normalized input. Both `g` and `b` have the same shape `(768,)`.

#### multi-layer-perceptron (mlp) aka feed forward net (ffn) 

This is covered in most basic machine learning classes, so it should suffice that in NumPy, the `@` symbol is used as the matrix multiplication operator, that `ffn` has the same input and output shape and that this is the implementation:

```python

def gelu(x):
    """Element-wise GELU activation, computed with the tanh approximation."""
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1 + np.tanh(inner))
    
def linear(x, w, b):  # [m, in], [in, out], [out] -> [m, out]
    """Affine projection: matrix-multiply `x` by `w`, then add the bias `b` to every row."""
    projected = x @ w
    return projected + b

def ffn(x, c_fc, c_proj):  # [n_seq, n_embd] -> [n_seq, n_embd]
    """Position-wise feed-forward network: linear -> GELU -> linear.

    `c_fc` and `c_proj` are keyword-argument dicts for `linear` (the
    up-projection and down-projection respectively).
    """
    # widen to the hidden size and apply the non-linearity
    hidden = gelu(linear(x, **c_fc))  # [n_seq, n_embd] -> [n_seq, 4*n_embd]
    # narrow back to the embedding size
    return linear(hidden, **c_proj)  # [n_seq, 4*n_embd] -> [n_seq, n_embd]
```

The layer norm, the ffn, the multi-headed attention (mha) and the overall transformer block all have the same input and output shape.

#### causal mask

```python
# causal mask to hide future inputs from being attended to:
# entries above the diagonal (future positions) become -1e10, the rest stay 0
# [n_seq, n_seq], with n_seq = 3 in this small example
causal_mask = (1 - np.tri(3, dtype=x.dtype)) * -1e10  
causal_mask
```

```
array([[-0.e+00, -1.e+10, -1.e+10],
       [-0.e+00, -0.e+00, -1.e+10],
       [-0.e+00, -0.e+00, -0.e+00]], dtype=float32)
```

The very negative values cause these positions to have an attention weight of nearly 0 after the row-wise softmax is applied,
so that no attention weight is placed on future tokens. Schematically (with -1000 standing in for the very negative value):

```
[[ 0, -1000,   -1000],
 [ 0,     0,   -1000],
 [ 0,     0,       0]]
```
#### Attention (Scaled Dot Product QKV attention)

Here is a moving diagram of Scaled Dot Product QKV attention as a transformer starts from 1 token and each of the next 3 tokens it generates attends over the previous positions. The grey squares represent the causal mask and, though not shown in the diagram, a softmax is applied to `QK^T` before it is matrix multiplied with V to produce A. Watch the diagram evolve and notice that the upper left square of the `QK^T` matrix is recomputed at every step; only the bottom row and right column are new. Not only that: because of causal masking, only the bottom row is both new and also needed for the matrix multiplication with V that produces the next A vector. kv-caching is about improving the efficiency of this attention step.

<img src="samples/QKV_scaled_dot_prod_attn.gif" height = 500 width = 1000 >

```python
# Q, K, V -> A
def attention(Q, K, V, mask):
    """Scaled dot-product attention: Q, K, V -> A.

    Shapes: Q [n_seq_q, n_embd], K [n_seq_k, n_embd], V [n_seq_k, n_embd],
    mask [n_seq_q, n_seq_k] -> result [n_seq_q, n_embd].
    The mask is added to the scaled scores before the softmax.
    """
    scale = np.sqrt(Q.shape[-1])
    scores = Q @ K.T / scale + mask
    weights = softmax(scores)
    return weights @ V
```

#### Multi Headed Attention (mha)

Multi-headed attention means that, instead of applying the attention function to the full Q, K and V, we chop Q, K and V
into multiple segments (heads), apply attention between the corresponding segments, and then concatenate the results.

```python
# [n_seq, n_embd] -> [n_seq, n_embd]
def mha(x, c_attn, c_proj, n_head):
    """Multi-head causal self-attention over the whole sequence.

    `c_attn` and `c_proj` are keyword-argument dicts for `linear` (the fused
    q/k/v projection and the output projection); `n_head` is the number of
    equal-width heads the embedding axis is split into.
    """
    # fused projection: queries, keys and values side by side
    # [n_seq, n_embd] -> [n_seq, 3*n_embd]
    qkv_proj = linear(x, **c_attn)
    n_seq = qkv_proj.shape[0]

    # causal mask: -1e10 above the diagonal so future positions are not attended to
    # [n_seq, n_seq]
    causal_mask = (1 - np.tri(n_seq, dtype=qkv_proj.dtype)) * -1e10

    # cut into q, k, v and then cut each of them into heads
    # [n_seq, 3*n_embd] -> 3 lists of n_head arrays shaped [n_seq, n_embd/n_head]
    q_heads, k_heads, v_heads = (
        np.split(part, n_head, axis=-1) for part in np.split(qkv_proj, 3, axis=-1)
    )

    # attention is computed independently for every head
    # -> n_head arrays shaped [n_seq, n_embd/n_head]
    out_heads = [
        attention(q, k, v, causal_mask)
        for q, k, v in zip(q_heads, k_heads, v_heads)
    ]

    # glue the heads back together along the embedding axis, then project out
    # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd] -> [n_seq, n_embd]
    return linear(np.hstack(out_heads), **c_proj)
```

In [102]:
# Step-by-step run-through of multi-headed attention for the first transformer
# block, printing the intermediate shapes along the way.
# NOTE(review): `layer_norm`, `linear` and `attention` are not imported or defined
# in any code cell of this notebook — confirm they are in scope on a fresh kernel
# (presumably they should be imported from `pico.gpt2`).

ln_1 = params['blocks'][0]['ln_1']
attn = params['blocks'][0]['attn']
n_head = hparams['n_head']
c_attn = attn['c_attn']
c_proj = attn['c_proj']

x_ln = layer_norm(x, **ln_1) # x has been layer normed

qkv_proj = linear(x_ln, **c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd]
print("qkv_proj.shape",qkv_proj.shape) # (10, 2304), 768 x 3 = 2304
print(qkv_proj[0,:4])

qkv = np.split(qkv_proj, 3, axis=-1) # [n_seq, 3*n_embd] -> List[3, (n_seq, n_embd)]
print("len(qkv),(qkv[0].shape)",len(qkv),(qkv[0].shape)) # list holding the q, k and v projections (not yet split into heads)

# split into heads
# [3, n_seq, n_embd] -> List[3, List[n_head (n_seq, n_embd/n_head)]]
qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), qkv))  
print("len(qkv_heads), len(qkv_heads[0]), qkv_heads[0][0].shape", len(qkv_heads), len(qkv_heads[0]), qkv_heads[0][0].shape)
# first values of head 0 of q: the same numbers as qkv_proj[0,:4] printed above
print(qkv_heads[0][0][0,:4])

# causal mask to hide future inputs from being attended to
causal_mask = (1 - np.tri(x.shape[0], dtype=x.dtype)) * -1e10  # [n_seq, n_seq]

# perform attention over each head
# List[3, List[n_head (n_seq, n_embd/n_head)]]] -> List[n_head, (n_seq, n_embd/n_head)]
out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)] 

print("len(out_heads), out_heads[0].shape", len(out_heads), out_heads[0].shape)

# merge heads
# List[n_head, (n_seq, n_embd/n_head)] -> [n_seq, n_embd]
x_out = np.hstack(out_heads)  # stack horizontally: heads are joined along the embedding axis, the n_seq rows are preserved

# out projection
# [n_seq, n_embd] -> [n_seq, n_embd]
x_out = linear(x_out, **c_proj)  
print(x_out.shape)

qkv_proj.shape (10, 2304)
[ 0.01042262  0.22827548 -0.7129095  -0.9784453 ]
len(qkv),(qkv[0].shape) 3 (10, 768)
len(qkv_heads), len(qkv_heads[0]), qkv_heads[0][0].shape 3 12 (10, 64)
[ 0.01042262  0.22827548 -0.7129095  -0.9784453 ]
len(out_heads), out_heads[0].shape 12 (10, 64)
(10, 768)


# KV Caching

The moving diagram below compares non-kv-caching on the top row with kv-caching on the bottom row as the transformer starts with 1 token and generates 3 more autoregressively, step by step. With each new step, we only calculate the QKV projection for the most recent token. As we discussed previously, only the bottom row of the `QK^T` attention matrix is used at each step. To compute the next bottom row, you don't actually need the previous Q projections, which is why this isn't called a QKV cache. You do still need all the previous K projections and V projections. Also, `QK^T` is no longer a square attention matrix but rather a new sequence_length sized vector at each step that represents the newest Q projection's attention over all K projections (the previous ones and the newest one).

<img src="samples/KV_cache.gif">

The benefit is that instead of doing a (n_seq x emb_dim) x (emb_dim x n_seq) matrix product, which costs O(n_seq^2 x emb_dim) of compute, you are now doing a
(1 x emb_dim) x (emb_dim x n_seq) product, which costs O(n_seq x emb_dim) of compute. The tradeoff is that we now need to keep growing Key and Value states in GPU VRAM or CPU RAM. Lastly, notice that we don't have to change the scaled dot product QKV `attention()` function; we just need to pass a differently shaped Q and mask (single-row vectors instead of matrices) into the same function.

In [45]:
# [n_seq, n_embd] -> [n_seq, n_embd]
def mha(x, c_attn, c_proj, n_head, kv_cache=None):
    """Multi-head causal self-attention with an optional key/value cache.

    Args:
        x: [n_new, n_embd] layer-normed activations of the tokens to process.
            Without a cache this is the whole sequence (e.g. the prompt on the
            first decoding step); with a cache it is only the token(s) that are
            not in the cache yet — normally just the latest one, so n_new = 1.
        c_attn: keyword arguments for `linear` — the fused q/k/v projection.
        c_proj: keyword arguments for `linear` — the output projection.
        n_head: number of equal-width heads the embedding axis is split into.
        kv_cache: None, or the `[k, v]` pair returned by the previous call for
            this layer, each shaped [n_prev, n_embd].

    Returns:
        (out, current_cache): `out` is [n_new, n_embd]; `current_cache` is
        `[k, v]` covering every token seen so far, to be passed to the next call.
    """
    # qkv projection
    # [n_new, n_embd] -> [n_new, 3*n_embd]
    x = linear(x, **c_attn)

    # split into q, k, v for the new token(s)
    # [n_new, 3*n_embd] -> 3 x [n_new, n_embd]
    q, new_k, new_v = np.split(x, 3, axis=-1)

    if kv_cache:
        # append the new keys/values below the cached ones; the old queries are
        # never needed again, which is why only k and v are cached
        old_k, old_v = kv_cache
        keys = np.vstack([old_k, new_k])  # [n_prev + n_new, n_embd]
        values = np.vstack([old_v, new_v])  # [n_prev + n_new, n_embd]
    else:
        keys, values = new_k, new_v

    # k and v for all tokens so far are handed to the next timestep
    current_cache = [keys, values]

    # Causal mask, [n_new, n_total]: query row i sits at absolute position
    # n_prev + i and may attend to keys 0 .. n_prev + i. Without a cache this is
    # the usual square triangular mask; with a single new token it is one row of
    # zeros. Built in x's dtype (like the non-cached mha) so that float32
    # activations are not silently promoted to float64.
    n_new, n_total = q.shape[0], keys.shape[0]
    causal_mask = (1 - np.tri(n_new, n_total, k=n_total - n_new, dtype=x.dtype)) * -1e10

    # split into heads
    # 3 x [n, n_embd] -> 3 lists of n_head arrays shaped [n, n_embd/n_head]
    qkv_heads = [np.split(part, n_head, axis=-1) for part in (q, keys, values)]

    # perform attention over each head
    # -> n_head arrays shaped [n_new, n_embd/n_head]
    out_heads = [attention(q_h, k_h, v_h, causal_mask) for q_h, k_h, v_h in zip(*qkv_heads)]

    # merge heads
    # [n_head, n_new, n_embd/n_head] -> [n_new, n_embd]
    x = np.hstack(out_heads)

    # out projection
    # [n_new, n_embd] -> [n_new, n_embd]
    x = linear(x, **c_proj)

    return x, current_cache

In [46]:
def transformer_block(x, mlp, attn, ln_1, ln_2, n_head, kv_cache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
    """Apply one transformer block and return its output together with the layer's updated kv cache.

    `kv_cache` is forwarded unchanged to `mha`; whatever cache `mha` returns is
    passed back to the caller as the second element of the result.
    """
    # multi-head causal self attention on the layer-normed input, plus residual
    normed = layer_norm(x, **ln_1)
    attn_out, kv_cache_updated = mha(normed, **attn, n_head=n_head, kv_cache=kv_cache)
    h = x + attn_out  # [n_seq, n_embd]

    # position-wise feed forward network on the layer-normed result, plus residual
    h = h + ffn(layer_norm(h, **ln_2), **mlp)  # [n_seq, n_embd]

    return h, kv_cache_updated

### projection to vocab space

It makes a lot of sense to use `x @ wte.T` to project your sequence of embeddings back into vocab space, because each logit is then the dot product between the transformer output x and a token embedding in wte. For example, if x were a sequence of embeddings most similar to the word embeddings for "this is a cat", then you would expect the largest logits to be at the token indices for the words "this", "is", "a" and "cat".

In [47]:
def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, kv_cache=None):
    """Forward pass that returns `(logits, layerwise_kv_cache)`.

    Without a cache every token in `inputs` is processed ([n_seq] -> logits
    [n_seq, n_vocab]). With a cache (one entry per block, as returned by a
    previous call) only the last token of `inputs` is processed, so the logits
    have a single row.
    """
    if kv_cache:
        # only the newest token is fed through; it sits at the last position
        positions = [len(inputs) - 1]
        inputs = [inputs[-1]]
    else:
        # no cache yet: process all tokens and give every block an empty cache slot
        kv_cache = [None] * len(blocks)
        positions = range(len(inputs))

    # token + positional embeddings
    x = wte[inputs] + wpe[positions]  # [n_seq] -> [n_seq, n_embd]

    layerwise_kv_cache = []
    for block_params, block_cache in zip(blocks, kv_cache):
        # [n_seq, n_embd] -> [n_seq, n_embd]
        x, block_cache_updated = transformer_block(x, **block_params, n_head=n_head, kv_cache=block_cache)

        # TODO: inplace extend new cache instead of re-saving whole layerwise_kv_cache from kv_cache
        layerwise_kv_cache.append(block_cache_updated)

    # final layer norm, then projection to vocab with the transposed token embeddings
    x = layer_norm(x, **ln_f)  # [n_seq, n_embd] -> [n_seq, n_embd]
    logits = x @ wte.T  # [n_seq, n_embd] -> [n_seq, n_vocab]

    return logits, layerwise_kv_cache

In [48]:
def generate(inputs, params, n_head, n_tokens_to_generate):
    """Greedy auto-regressive decoding that reuses the kv cache between steps.

    Args:
        inputs: sequence of prompt token ids. It is copied, so the caller's
            object is NOT modified (it used to be extended in place, which made
            re-running a notebook cell continue from an already-extended prompt).
        params: model weights, unpacked as keyword arguments into `gpt2`.
        n_head: number of attention heads.
        n_tokens_to_generate: how many new tokens to produce.

    Returns:
        List of the generated token ids only (the prompt is not included).
    """
    tokens = list(inputs)  # private copy; also accepts tuples / arrays of ids
    n_prompt = len(tokens)

    kvcache = None  # first step runs the whole prompt and builds the cache
    for _ in tqdm(range(n_tokens_to_generate), "generating"):  # auto-regressive decode loop
        logits, kvcache = gpt2(tokens, **params, n_head=n_head, kv_cache=kvcache)  # model forward pass
        next_id = np.argmax(logits[-1])  # greedy sampling
        tokens.append(int(next_id))  # append prediction to the running sequence

    return tokens[n_prompt:]  # only return generated ids

In [49]:
prompt = "Alan Turing theorized that computers would one day become"
input_ids = tokenizer.encode(prompt)

# make sure the prompt plus the requested continuation stays below the
# max sequence length of our model
n_total_tokens = len(input_ids) + n_tokens_to_generate
assert n_total_tokens < hparams["n_ctx"]

# reference completion to compare the generated text against
expected_completion = ' the most powerful machines on the planet.\n\nThe computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.\n'

In [50]:
# Time generation + decoding with the kv-cache `generate` defined in this notebook.
# perf_counter() is a monotonic, high-resolution clock, so the measured interval
# cannot be distorted by system clock adjustments the way time.time() can.
start_time = time.perf_counter()

# generate output ids; a copy of the prompt ids is passed so that re-running this
# cell never starts from an `input_ids` list that a previous run has extended
output_ids = generate(list(input_ids), params, hparams["n_head"], n_tokens_to_generate)
# decode the ids back into a string
output_text = tokenizer.decode(output_ids)

# Record the end time
end_time = time.perf_counter()

# Calculate the elapsed time
elapsed_time = end_time - start_time
print(f"The process took {elapsed_time} seconds to complete.")

output_text

generating: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:12<00:00,  3.27it/s]

The process took 12.223795890808105 seconds to complete.





' the most powerful machines on the planet.\n\nThe computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.\n'