# miniFrost

I am hoping to build a mini GPT model along the lines of Karpathy's nanoGPT tutorial [here](https://www.youtube.com/watch?v=kCc8FmEb1nY). Instead of Shakespeare, I will attempt Robert Frost's poems because they are more evocative to me.

In [1]:
!pip install tiktoken
from google.colab import files
import math

import pandas as pd
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tiktoken
  Downloading tiktoken-0.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 41.6 MB/s eta 0:00:00
Installing collected packages: tiktoken
Successfully installed tiktoken-0.3.3


Read the data from the CSV file and parse it. I'm going to print the first 5 rows.

In [2]:
poem_collection = pd.read_csv("robert_frost_collection.csv")
print(poem_collection.head())

                                        Name   \
0                                         NaN   
1        Stopping by Woods on a Snowy Evening   
2                                Fire and Ice   
3                            The Aim was Song   
4  The Need of Being Versed in Country Things   

                                             Content     Collection  \
0                                                NaN            NaN   
1  Whose woods these are I think I know.   \nHis ...  New Hampshire   
2  Some say the world will end in fire,\nSome say...  New Hampshire   
3  Before man came to blow it right\nThe wind onc...  New Hampshire   
4  The house had gone to bring again\nTo the midn...  New Hampshire   

   Year of Publication  
0                  NaN  
1               1923.0  
2               1923.0  
3               1923.0  
4               1923.0  


In [3]:
# Drop the first row, which is all NaN
poem_collection = poem_collection.drop(labels=0, axis=0)

# Join every poem into one string, separated by newlines
# (the leading "\n" matches what the original join loop produced)
all_text = "\n" + "\n".join(poem_collection['Content'])

print("Length of the text: ", len(all_text))

# Save the combined corpus to disk
with open("poems.txt", "w") as fileout:
    fileout.write(all_text)

Length of the text:  221343


In [4]:
# Extract all the unique characters that are in the text
uniq_chars = sorted(list(set(all_text)))
print(''.join(uniq_chars))

VOCAB_SIZE = len(uniq_chars)


 !"'(),-.123458:;>?ABCDEFGHIJKLMNOPQRSTUVWY_abcdefghijklmnopqrstuvwxyz­·æèê–—‘’“”…


### Encoding and Decoding Functions
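These functions encode a string as a list of integers and decode it back. The cell below does create OpenAI's tiktoken GPT-2 encoder (`enc`), which works on sub-words. However, the `encode` and `decode` functions used in the rest of this notebook are character-level: every unique character in the text is mapped to its index in `uniq_chars`, which gives a vocabulary of VOCAB_SIZE = 84 symbols. Sub-word tokens would produce shorter sequences, but at the cost of a much larger vocabulary.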

In [5]:
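# GPT-2 sub-word tokenizer from tiktoken (created here, but not used below)
enc = tiktoken.get_encoding("gpt2")

# Character-level lookup tables: string-to-integer and integer-to-string
STOI = {ch: i for i, ch in enumerate(uniq_chars)}
ITOS = {i: ch for i, ch in enumerate(uniq_chars)}
encode = lambda s: [STOI[c] for c in s]           # str -> list[int]
decode = lambda l: ''.join([ITOS[i] for i in l])  # list[int] -> str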
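As a quick sanity check of the character-level scheme, here is the same idea in plain Python. It runs on a single line of the poem instead of the full corpus. The names `toy_uniq`, `toy_encode` and `toy_decode` are hypothetical, chosen so that running this cell doesn't overwrite the notebook's `encode`/`decode`.

```python
# Character-level tokenizer on one line of text (illustrative sample, not the corpus)
sample_text = "Whose woods these are I think I know.\n"

toy_uniq = sorted(set(sample_text))                  # vocabulary: the unique characters
toy_stoi = {ch: i for i, ch in enumerate(toy_uniq)}  # character -> integer
toy_itos = {i: ch for i, ch in enumerate(toy_uniq)}  # integer -> character

def toy_encode(s):
    """Map each character to its index in the vocabulary."""
    return [toy_stoi[c] for c in s]

def toy_decode(ids):
    """Map indices back to characters and join them into a string."""
    return "".join(toy_itos[i] for i in ids)

ids = toy_encode("I know.")
print(len(toy_uniq), ids)
print(toy_decode(ids))
```

Any string made only of vocabulary characters survives the round trip unchanged. A character outside the vocabulary raises a `KeyError`, which is why the vocabulary is built from the full text.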

### Build a PyTorch Tensor from the Encoded Text

In [6]:
encoded_text = encode(all_text)
data = torch.tensor(encoded_text)
print(data.shape, data.dtype)
print(data[:500])

torch.Size([221343]) torch.int64
tensor([ 0, 43, 53, 60, 64, 50,  2, 68, 60, 60, 49, 64,  2, 65, 53, 50, 64, 50,
         2, 46, 63, 50,  2, 29,  2, 65, 53, 54, 59, 56,  2, 29,  2, 56, 59, 60,
        68, 10,  2,  2,  2,  0, 28, 54, 64,  2, 53, 60, 66, 64, 50,  2, 54, 64,
         2, 54, 59,  2, 65, 53, 50,  2, 67, 54, 57, 57, 46, 52, 50,  2, 65, 53,
        60, 66, 52, 53, 18,  2,  2,  2,  0, 28, 50,  2, 68, 54, 57, 57,  2, 59,
        60, 65,  2, 64, 50, 50,  2, 58, 50,  2, 64, 65, 60, 61, 61, 54, 59, 52,
         2, 53, 50, 63, 50,  2,  2,  2,  0, 40, 60,  2, 68, 46, 65, 48, 53,  2,
        53, 54, 64,  2, 68, 60, 60, 49, 64,  2, 51, 54, 57, 57,  2, 66, 61,  2,
        68, 54, 65, 53,  2, 64, 59, 60, 68, 10,  2,  2,  2,  0, 33, 70,  2, 57,
        54, 65, 65, 57, 50,  2, 53, 60, 63, 64, 50,  2, 58, 66, 64, 65,  2, 65,
        53, 54, 59, 56,  2, 54, 65,  2, 62, 66, 50, 50, 63,  2,  2,  2,  0, 40,
        60,  2, 64, 65, 60, 61,  2, 68, 54, 65, 53, 60, 66, 65,  2, 46,  2, 51,
       

### Training and Testing Split

At this point, we have to decide on the training-testing split. The tutorial uses a 90-10 split, which should be good enough here. The split is contiguous rather than shuffled, so the test set is simply the last 10% of the text.

**Remember that we can change this ratio later and see how it affects the quality of the generated text.**
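To make the split rule concrete, here is a small plain-Python sketch of the same `math.ceil` rule applied to a list. The helper name `contiguous_split` and the 25-token list are hypothetical.

```python
import math

def contiguous_split(tokens, training_portion=0.9):
    """Split a sequence into a leading training part and a trailing test part."""
    n = math.ceil(training_portion * len(tokens))   # same rounding rule as the cell below
    return tokens[:n], tokens[n:]

split_tokens = list(range(25))                       # hypothetical 25-token corpus
train_part, test_part = contiguous_split(split_tokens)
print(len(train_part), len(test_part))
```

For the 221,343-character corpus, this rule gives 199,209 training characters and 22,134 test characters.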


In [7]:
TRAINING_PORTION = 0.9
n = math.ceil(TRAINING_PORTION * len(data))

training_data = data[:n]
testing_data = data[n:]

### Context and Target

To train a transformer, we chunk the data into blocks. Each block is fed in as a context together with the target that the context "implies": the token that comes next.
This is how the model learns. A block of BLOCK_SIZE tokens packs in BLOCK_SIZE training examples, one for every context length from 1 to BLOCK_SIZE, and the model learns to predict the target from each of them.
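Here is a plain-Python sketch of the (context, target) pairs hidden inside one block. It uses the first nine tokens of the encoded corpus (as printed above) and a block size of 8 instead of 256, purely so the output stays short.

```python
# Pairs hidden inside one block, using the first nine tokens of the encoded corpus
toy_block_size = 8
toy_seq = [0, 43, 53, 60, 64, 50, 2, 68, 60]

toy_x = toy_seq[:toy_block_size]        # contexts are drawn from the first 8 tokens
toy_y = toy_seq[1:toy_block_size + 1]   # targets: the same tokens shifted one step right

pairs = [(toy_x[:t + 1], toy_y[t]) for t in range(toy_block_size)]
for context_ids, target_id in pairs:
    print(context_ids, "->", target_id)
```

Training on every prefix length also means the model sees contexts as short as one token, which it needs at generation time.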

In [8]:
BLOCK_SIZE = 256

# Show every (context, target) pair contained in the first block
context = training_data[:BLOCK_SIZE]
target = training_data[1:BLOCK_SIZE + 1]
for i in range(len(context)):
    print("The context is ", context[:i+1], " and the target is ", target[i])

The context is  tensor([0])  and the target is  tensor(43)
The context is  tensor([ 0, 43])  and the target is  tensor(53)
The context is  tensor([ 0, 43, 53])  and the target is  tensor(60)
The context is  tensor([ 0, 43, 53, 60])  and the target is  tensor(64)
The context is  tensor([ 0, 43, 53, 60, 64])  and the target is  tensor(50)
The context is  tensor([ 0, 43, 53, 60, 64, 50])  and the target is  tensor(2)
The context is  tensor([ 0, 43, 53, 60, 64, 50,  2])  and the target is  tensor(68)
The context is  tensor([ 0, 43, 53, 60, 64, 50,  2, 68])  and the target is  tensor(60)
The context is  tensor([ 0, 43, 53, 60, 64, 50,  2, 68, 60])  and the target is  tensor(60)
The context is  tensor([ 0, 43, 53, 60, 64, 50,  2, 68, 60, 60])  and the target is  tensor(49)
The context is  tensor([ 0, 43, 53, 60, 64, 50,  2, 68, 60, 60, 49])  and the target is  tensor(64)
The context is  tensor([ 0, 43, 53, 60, 64, 50,  2, 68, 60, 60, 49, 64])  and the target is  tensor(2)
The context is  ten

### Getting Batches 

Now we want to sample random blocks from the text, take their contexts and targets, and stack them into tensors. Each of the BATCH_SIZE rows of a stack is one block of BLOCK_SIZE tokens. With the values set below (BATCH_SIZE = 64, BLOCK_SIZE = 256), each stack is 64 x 256.
PyTorch processes all the rows in parallel, and *that's a big part of what makes transformers so good: efficiency.*

**Extracting Batches**
The function `get_batch` extracts BATCH_SIZE random blocks of length BLOCK_SIZE from either the training or the testing data and stacks them together. It returns two [BATCH_SIZE x BLOCK_SIZE] stacks (here [64 x 256]). One holds the contexts, and the other holds the targets, which are the same blocks shifted one position to the right.
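The sampling logic of `get_batch` can be sketched without PyTorch. In this sketch, `random.randrange` stands in for `torch.randint`, and nested lists stand in for the stacked tensors. `toy_get_batch` and the 100-token corpus are hypothetical.

```python
import random

def toy_get_batch(tokens, batch_size, block_size, rng=random):
    """Sample batch_size random blocks; return (contexts, targets) as lists of lists."""
    # Start offsets are drawn from 0 .. len(tokens) - block_size - 1, the same range
    # as torch.randint(len(data) - BLOCK_SIZE, ...), so the shifted target still fits
    starts = [rng.randrange(len(tokens) - block_size) for _ in range(batch_size)]
    contexts = [tokens[i:i + block_size] for i in starts]
    targets = [tokens[i + 1:i + block_size + 1] for i in starts]
    return contexts, targets

demo_rng = random.Random(1337)
demo_tokens = list(range(100))            # hypothetical corpus: token i is just i
demo_xs, demo_ys = toy_get_batch(demo_tokens, batch_size=4, block_size=8, rng=demo_rng)
for x_row, y_row in zip(demo_xs, demo_ys):
    print(x_row, "->", y_row)
```

Because token i is just i here, each target row is visibly its context row plus one.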

In [9]:
BATCH_SIZE = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(1337)

# get_batch will extract from either training or testing data depending
# on the value of split_type ('train'; anything else selects the test data)
def get_batch(split_type):
  data = training_data if split_type == "train" else testing_data

  # ix holds BATCH_SIZE (64) random start offsets; from each one we take a
  # block of BLOCK_SIZE (256) tokens. The upper bound leaves room for the
  # target block, which is shifted one position to the right.
  ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))

  # Assemble the stacks: context (cx), target (tg), each (BATCH_SIZE, BLOCK_SIZE)
  cx = torch.stack([data[i:i+BLOCK_SIZE] for i in ix])
  tg = torch.stack([data[i+1:i+BLOCK_SIZE+1] for i in ix])
  x, y = cx.to(DEVICE), tg.to(DEVICE)
  return x, y

In [10]:
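# Sample one batch from the training data and inspect it
xd, yd = get_batch('train')

print(xd.shape)
print(xd)
print("---")
print(yd.shape)
print(yd)

# Print the (context, target) pairs for the first few positions of the first
# row only; printing all BATCH_SIZE x BLOCK_SIZE pairs overflows the notebook output
for batch in range(1):
    for time in range(8):
        context = xd[batch, :time+1]
        target = yd[batch, time]   # targets come from yd (the shifted block), not xd
        print(f'when the input is {context.tolist()} the expected output is {target.tolist()}')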
    print()

In [11]:
EVAL_ITERS = 200

@torch.no_grad()
def estimate_loss(model_est):
    out = {}
    model_est.eval()
    for split in ['train', 'val']:  # get_batch serves testing_data for any split other than 'train'
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            X, Y = get_batch(split)
            logits, loss = model_est(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model_est.train()
    return out

## Bigram Language Model

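The simplest language model there is: it predicts the next token (here, the next character) from the current token alone. Each token reads the logits for the next token directly from its row of a VOCAB_SIZE x VOCAB_SIZE embedding table. Read more about n-gram models [here](https://towardsdatascience.com/introduction-to-language-models-n-gram-e323081503d9)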
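For intuition, a bigram model can also be built by simply counting which character follows which and normalizing the counts. This is a non-neural sketch, not the notebook's `BigramLanguageModel`. As I understand it, minimizing cross-entropy pushes the softmax of each row of the learned table toward these normalized counts. The helper `bigram_probs` and the sample line are illustrative.

```python
from collections import Counter, defaultdict

def bigram_probs(text):
    """Estimate P(next char | current char) by counting adjacent character pairs."""
    counts = defaultdict(Counter)
    for cur, nxt in zip(text, text[1:]):
        counts[cur][nxt] += 1
    probs = {}
    for cur, nxt_counts in counts.items():
        total = sum(nxt_counts.values())
        probs[cur] = {nxt: c / total for nxt, c in nxt_counts.items()}
    return probs

demo_probs = bigram_probs("the woods are lovely, dark and deep")
print(demo_probs["d"])   # what tends to follow 'd' in this line
```

Each row is a probability distribution over next characters, which is exactly what `generate` samples from after the softmax.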

In [12]:
N_EMBED = 384
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


In [13]:
model = BigramLanguageModel(VOCAB_SIZE)
m = model.to(DEVICE)
logits, loss = m(xd, yd)
print(logits.shape)
print(loss)

torch.Size([16384, 84])
tensor(4.7790, grad_fn=<NllLossBackward0>)


Now that we have a generation function, we can try generating some text. It will be gibberish because we haven't trained the model yet, but it shows that the generation loop works end to end.
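The `generate` method above boils down to "softmax the last logits, then sample one index". Here is a dependency-free sketch of that single step, with `random.choices` standing in for `torch.multinomial`. The helper names `softmax_list` and `sample_next` are hypothetical.

```python
import math
import random

def softmax_list(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)                       # subtract the max to avoid overflow in exp
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_next(logits, rng=random):
    """Draw one index from the softmax distribution over logits."""
    probs = softmax_list(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

demo_logits = [2.0, 1.0, 0.1]             # hypothetical logits for a 3-token vocabulary
print(softmax_list(demo_logits))
print(sample_next(demo_logits, random.Random(0)))
```

Sampling, rather than always taking the most likely token, is what makes repeated generations differ from each other.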

In [14]:
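idx = torch.zeros((BLOCK_SIZE, BATCH_SIZE), dtype=torch.long)
print(idx)
# Generate 500 tokens starting from a single newline token (index 0);
# the starting tensor must live on the same device as the model
start = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
print(decode(m.generate(idx=start, max_new_tokens=500)[0].tolist()))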

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

BRb''”!"A)C5rK·wa"JE)FfezS’nU’qUD–P’LE·Rx
wYdI.,;YTdRkHNoEr1Ld_)ê(aC"Ax?
hè…—NE?—Kc_…nYsf5Y
’ld"M‘OPt,b;o-uLcMFF
Uc2Cb”VsI
A‘O!U'b·RUwj
V1.,zps”JLIQêhBkHw‘mQzY–z­_NC',“ME
Uæk"­v8m8pz(vC


## Optimization 

We can now start training the model to minimize the loss function. We will use the AdamW optimizer from PyTorch. The learning rate can be set much higher for
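To see what AdamW does on each step, here is a scalar sketch of its update rule as I understand it from PyTorch's `torch.optim.AdamW` documentation. It applies decoupled weight decay, then the Adam step with bias-corrected moment estimates. The default hyperparameters mirror PyTorch's defaults. The helper `adamw_step` and the toy objective (w - 3)^2 are illustrative assumptions, not part of the notebook.

```python
import math

def adamw_step(param, grad, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """One AdamW update for a single scalar parameter (decoupled weight decay)."""
    b1, b2 = betas
    state["t"] += 1
    param = param * (1 - lr * weight_decay)                 # decoupled weight decay
    state["m"] = b1 * state["m"] + (1 - b1) * grad          # running mean of gradients
    state["v"] = b2 * state["v"] + (1 - b2) * grad * grad   # running mean of squared gradients
    m_hat = state["m"] / (1 - b1 ** state["t"])             # bias corrections for the
    v_hat = state["v"] / (1 - b2 ** state["t"])             # zero-initialized averages
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)

# Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
demo_w, demo_state = 0.0, {"t": 0, "m": 0.0, "v": 0.0}
for _ in range(1000):
    demo_w = adamw_step(demo_w, 2 * (demo_w - 3), demo_state, lr=0.1)
print(demo_w)
```

Dividing by the square root of `v_hat` makes the step size roughly independent of the gradient's scale. That is one reason a single learning rate such as 3e-4 works across all the layers.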

In [15]:
# create an optimizer
LEARNING_RATE=3e-4

optimizer = torch.optim.AdamW(m.parameters(), lr=LEARNING_RATE)

Now we train the model batch by batch to reduce the loss.

In [28]:
MAX_ITERS = 3000
EVAL_INTERVAL = 100
for iter in range(MAX_ITERS):

    # every once in a while evaluate the loss on train and val sets
    if iter % EVAL_INTERVAL == 0:
        losses = estimate_loss(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

step 0: train loss 3.6277, val loss 3.6433
step 100: train loss 3.5993, val loss 3.6125
step 200: train loss 3.5692, val loss 3.5843
step 300: train loss 3.5407, val loss 3.5559
step 400: train loss 3.5127, val loss 3.5261
step 500: train loss 3.4855, val loss 3.4995
step 600: train loss 3.4582, val loss 3.4717
step 700: train loss 3.4312, val loss 3.4457
step 800: train loss 3.4044, val loss 3.4176
step 900: train loss 3.3796, val loss 3.3921
step 1000: train loss 3.3537, val loss 3.3669
step 1100: train loss 3.3308, val loss 3.3425
step 1200: train loss 3.3053, val loss 3.3187
step 1300: train loss 3.2809, val loss 3.2945
step 1400: train loss 3.2586, val loss 3.2713
step 1500: train loss 3.2358, val loss 3.2488
step 1600: train loss 3.2133, val loss 3.2255
step 1700: train loss 3.1927, val loss 3.2037
step 1800: train loss 3.1723, val loss 3.1826
step 1900: train loss 3.1516, val loss 3.1620
step 2000: train loss 3.1298, val loss 3.1413
step 2100: train loss 3.1112, val loss 3.1209


Now let's try generating text again.

In [17]:
idx = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)  # start on the model's device
sequence_gen = m.generate(idx, max_new_tokens=100)[0].tolist()
print(decode(sequence_gen))


‘“èE”­Do‘.AYo  z?’d,’?”
_rnAn_   yMrMitvxM‘WL’l1æ—airæ
8seu_oKd’m1æ_w4:xêmi(T‘apydDOAkNYenU·5è…2"DJ—


## Self-Attention and Transformer Networks

Now we want the tokens within a block to talk to each other. If I am the token at position t, I want to gather information from **all the previous positions** in the block, because that context improves my prediction of the next token.

I **cannot look at any future positions**, because those are exactly what I am trying to predict. Different rows of the batch never talk to each other; they are processed independently.

In [18]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch = 4, time = 8 and channel = 2
x = torch.randn(B,T,C)
x.shape
print(x)

tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]],

        [[ 1.3488, -0.1396],
         [ 0.2858,  0.9651],
         [-2.0371,  0.4931],
         [ 1.4870,  0.5910],
         [ 0.1260, -1.5627],
         [-1.1601, -0.3348],
         [ 0.4478, -0.8016],
         [ 1.5236,  2.5086]],

        [[-0.6631, -0.2513],
         [ 1.0101,  0.1215],
         [ 0.1584,  1.1340],
         [-1.1539, -0.2984],
         [-0.5075, -0.9239],
         [ 0.5467, -1.4948],
         [-1.2057,  0.5718],
         [-0.5974, -0.6937]],

        [[ 1.6455, -0.8030],
         [ 1.3514, -0.2759],
         [-1.5108,  2.1048],
         [ 2.7630, -1.7465],
         [ 1.4516, -1.5103],
         [ 0.8212, -0.2115],
         [ 0.7789,  1.5333],
         [ 1.6097, -0.4032]]])


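Now let's compute, for each time step t, the average of the vectors at positions 1 to t (a "bag of words" over the past), and use it in place of the vector at that step. A lower-triangular matrix whose rows are normalized to sum to 1 does this for all time steps with a single matrix multiplication.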
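Here is the same averaging in plain Python for one channel, with the triangular weight matrix spelled out. The input is the first three values of `x[0, :, 0]` printed above, so the results can be checked against the first rows of `xbow_2[0]`. The helper name `causal_average` is hypothetical.

```python
def causal_average(xs):
    """For each position t, average xs[0..t] using a row-normalized lower-triangular matrix."""
    T = len(xs)
    # Row t has 1/(t+1) in columns 0..t and zeros after, so each row sums to 1
    w = [[(1.0 / (t + 1)) if s <= t else 0.0 for s in range(T)] for t in range(T)]
    # Matrix-vector product: out[t] = sum_s w[t][s] * xs[s]
    return [sum(w[t][s] * xs[s] for s in range(T)) for t in range(T)]

first_channel = [0.1808, -0.3596, 0.6258]   # x[0, 0:3, 0] from the output above
print(causal_average(first_channel))
```

The printed values should agree with the first column of `xbow_2[0]` (0.1808, -0.0894, 0.1490) up to floating-point rounding.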

In [19]:
weights = torch.tril(torch.ones(T,T))
print(weights)
weights = weights / weights.sum(1, keepdim=True)
print(weights)

# This is a valid matrix multiplication because PyTorch broadcasts the (T x T)
# weights matrix across the batch dimension, so it acts like a (B x T x T)
# tensor multiplying the (B x T x C) input
xbow_2 = weights @ x
print(xbow_2[0])

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.14

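The next cell builds the same weights a different way: a row of zero scores, with the future positions set to -inf, passed through a softmax. Because exp(-inf) is 0, the masked positions get zero weight and the rest share the weight equally. Here is a plain-Python sketch of that construction. The helper name `masked_uniform_weights` is hypothetical.

```python
import math

def masked_uniform_weights(T):
    """Softmax over a row of zeros where future positions are set to -inf."""
    rows = []
    for t in range(T):
        scores = [0.0 if s <= t else float("-inf") for s in range(T)]
        exps = [math.exp(z) for z in scores]   # math.exp(-inf) == 0.0
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows

for row in masked_uniform_weights(4):
    print(row)
```

The point of the softmax form is that once the zero scores are replaced by learned, data-dependent affinities, the same masking still works, and that is self-attention.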
In [20]:
tril = torch.tril(torch.ones(T,T))
weights = torch.zeros(T, T)

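# Mask out the upper triangle: position t must not receive information from
# positions after t, for the reasons mentioned before. Setting those scores to
# -inf makes the softmax give them exactly zero weight; with all other scores
# at zero, the softmax reproduces the uniform averaging weights from above.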
weights = weights.masked_fill(tril==0, float('-inf'))
weights = F.softmax(weights, dim=-1)

print(weights)
xbow3 = weights @ x
print(xbow3)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],

## Self-Attention 

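Every token emits a query ("what am I looking for?") and a key ("what do I contain?"). The affinity between two tokens is the dot product of one token's query with the other's key. The value v stores "here's what I will communicate to you if my key satisfies your query" for a single head. Each token's output is the affinity-weighted sum of the values it is allowed to see.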
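Here is a plain-Python sketch of one causal attention head over a short list of token vectors. Unlike the cell below, it already scales the affinities by 1/sqrt(head_size), as the `Head` class in the final build does. The helper names and the 2-dimensional toy vectors are hypothetical.

```python
import math

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in W]

def causal_self_attention(xs, Wq, Wk, Wv):
    """One head of causal self-attention over a list of token vectors xs."""
    qs = [matvec(Wq, x) for x in xs]   # what each token is looking for
    ks = [matvec(Wk, x) for x in xs]   # what each token contains
    vs = [matvec(Wv, x) for x in xs]   # what each token communicates
    head_size = len(qs[0])
    out = []
    for t in range(len(xs)):
        # Scaled affinities with positions 0..t only (the causal mask)
        scores = [sum(a * b for a, b in zip(qs[t], ks[s])) / math.sqrt(head_size)
                  for s in range(t + 1)]
        m = max(scores)
        exps = [math.exp(z - m) for z in scores]
        total = sum(exps)
        attn = [e / total for e in exps]
        # Weighted sum of the visible values
        out.append([sum(attn[s] * vs[s][j] for s in range(t + 1))
                    for j in range(head_size)])
    return out

demo_xs_attn = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
identity_2 = [[1.0, 0.0], [0.0, 1.0]]
print(causal_self_attention(demo_xs_attn, identity_2, identity_2, identity_2))
```

The first token can only attend to itself, so its output is just its own value. If every key is zero, all affinities are equal and the head falls back to the plain running average from earlier.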

In [21]:
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# A single Head performs self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
weights = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T); not yet scaled by head_size**-0.5

tril = torch.tril(torch.ones(T,T))
weights = weights.masked_fill(tril==0, float('-inf'))
weights = F.softmax(weights, dim=-1)

v = value(x)
out = weights @ v

print(out.shape)
tril

torch.Size([4, 8, 16])


tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [22]:
weights[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

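Look at the last row of the matrix above: it holds the attention weights of the 8th token over positions 1 to 8. The 0.2391 at the end is the weight it gives to itself, and the largest weight, 0.2423, goes to the 7th token. In this toy cell x is just random noise, but in the full model a token knows what it is from the token embedding and where it is from the position embedding table. From that it makes a query, such as "I'm looking for a certain kind of character".
Every token also emits a key. A query and a key with a high dot product match well, so the corresponding value gets a large weight.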
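The `Head` class in the final build multiplies the affinities by `k.shape[-1]**-0.5`, while the cell above does not. The reason is that for random q and k with unit-variance entries, the variance of q·k grows with head_size. Large affinities make the softmax nearly one-hot, so each token would attend to only one other token. Here is a plain-Python check of that claim, using a hypothetical helper `affinity_variance` with Gaussian samples.

```python
import random
import statistics

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def affinity_variance(head_size, n_samples=2000, scale=True, seed=0):
    """Empirical variance of q.k for random unit-variance q and k, optionally scaled."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(head_size)]
        k = [rng.gauss(0.0, 1.0) for _ in range(head_size)]
        s = dot(q, k)
        samples.append(s * head_size ** -0.5 if scale else s)
    return statistics.pvariance(samples)

print(affinity_variance(16, scale=False))  # expected value: head_size (16)
print(affinity_variance(16, scale=True))   # expected value: 1
```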

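## Final Build

This build adds multiple heads of attention, residual connections and LayerNorm. Inside each block, LayerNorm is applied *before* the self-attention step and before the feed-forward step (the "pre-norm" arrangement), and one final LayerNorm sits before the output head. DROPOUT is defined, but the dropout layers are currently commented out.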
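Here is a plain-Python sketch of the two ingredients of each block's residual path: LayerNorm over one token's features, and the pre-norm residual connection `x + sublayer(LayerNorm(x))`. It uses the biased variance, which is what the `nn.LayerNorm` documentation describes. The helper names are hypothetical.

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize one token vector across its features, then scale and shift."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n   # biased variance
    if gamma is None:
        gamma = [1.0] * n                       # learnable scale (initialized to 1)
    if beta is None:
        beta = [0.0] * n                        # learnable shift (initialized to 0)
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]

def pre_norm_residual(x, sublayer):
    """Pre-norm residual connection: x + sublayer(LayerNorm(x))."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

print(layer_norm([1.0, 2.0, 3.0, 4.0]))
print(pre_norm_residual([1.0, 2.0], lambda h: [0.0] * len(h)))
```

Because the normalized copy only feeds the sublayer, the un-normalized x flows straight through the residual path. If a sublayer contributes nothing, the block passes its input through unchanged.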

In [35]:
N_HEAD = 6
N_LAYER = 6
DROPOUT = 0.2

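# Hand-written LayerNorm, kept for reference; the model below uses nn.LayerNorm instead
class LayerNorm1d: # (used to be BatchNorm1d)

  def __init__(self, dim, eps=1e-5, momentum=0.1):  # momentum is unused (left over from BatchNorm1d)
    self.eps = eps
    self.gamma = torch.ones(dim)   # learnable scale
    self.beta = torch.zeros(dim)   # learnable shift

  def __call__(self, x):
    # normalize each row (example) across its features, not across the batch
    xmean = x.mean(1, keepdim=True)                 # per-example mean
    xvar = x.var(1, keepdim=True, unbiased=False)   # per-example variance (biased, like nn.LayerNorm)
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to zero mean, unit variance
    self.out = self.gamma * xhat + self.beta
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]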


class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(N_EMBED, head_size, bias=False)
        self.query = nn.Linear(N_EMBED, head_size, bias=False)
        self.value = nn.Linear(N_EMBED, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))

        # self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        # wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, N_EMBED)
        # self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

class FeedFoward(nn.Module):
    """ a linear layer, a ReLU non-linearity, and a projection back to n_embd """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # token embeddings (what each token is) and position embeddings (where it is)
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, N_EMBED)
        self.position_embedding_table = nn.Embedding(BLOCK_SIZE, N_EMBED)
        self.blocks = nn.Sequential(*[Block(N_EMBED, n_head=N_HEAD) for _ in range(N_LAYER)])
        self.ln_f = nn.LayerNorm(N_EMBED) # final layer norm
        self.lm_head = nn.Linear(N_EMBED, VOCAB_SIZE)

        # initialize weights: N(0, 0.02) for Linear and Embedding layers, zero biases
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=DEVICE)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -BLOCK_SIZE:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [36]:
xd, yd = get_batch('train')
sa_model = GPTLanguageModel()
sa_m = sa_model.to(DEVICE)
logits, loss = sa_m(xd, yd)
print(logits.shape)
print(loss)

torch.Size([16384, 84])
tensor(4.5747, grad_fn=<NllLossBackward0>)


Define a new AdamW optimizer for the GPT model, with the same learning rate.


In [37]:
sa_optimizer = torch.optim.AdamW(sa_model.parameters(), lr=LEARNING_RATE)

In [38]:
@torch.no_grad()
def sa_estimate_loss():
    out = {}
    sa_model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            X, Y = get_batch(split)
            logits, loss = sa_model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    sa_model.train()
    return out

In [None]:
MAX_ITERS_SA = 5100
for iter in range(MAX_ITERS_SA):

    # every once in a while evaluate the loss on train and val sets
    if iter % EVAL_INTERVAL == 0 or iter == MAX_ITERS_SA - 1:
        losses = sa_estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = sa_model(xb, yb)
    sa_optimizer.zero_grad(set_to_none=True)
    loss.backward()
    sa_optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
print(decode(sa_m.generate(context, max_new_tokens=500)[0].tolist()))