# Building a Generatively Pretrained Transformer (GPT)

These notes follow Andrej Karpathy's [tutorial](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=1044s) and the *Attention Is All You Need* [paper](https://arxiv.org/pdf/1706.03762).

## Data Processing, Encoder Decoder, Context

In [3]:
# get training text. 
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print(text[:300])
print(f'{len(text)=}')

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us
len(text)=1125396


In [4]:
# all the unique characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(f'{vocab_size=}')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab_size=65


In [5]:
# create a mapping from characters to integers
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}

# encode text
def encode(s):
    return [stoi[c] for c in s]

# decode list of integers
def decode(l):
    return ''.join([itos[i] for i in l])

print(decode(encode("hello")))

hello
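
As a quick sanity check on the mapping above, the sketch below rebuilds `stoi` and `itos` on a tiny made-up string (`sample_text` is a hypothetical stand-in for the Shakespeare corpus) and confirms that encoding followed by decoding returns the original text. It is only an illustration of the round trip, not a replacement for the cell above.

```python
# Hypothetical stand-in corpus; the notebook itself reads shakespeare.txt.
sample_text = "hello world"

# Same character-level vocabulary construction as above.
chars = sorted(set(sample_text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    # string -> list of integer token ids
    return [stoi[c] for c in s]

def decode(ids):
    # list of integer token ids -> string
    return ''.join(itos[i] for i in ids)

ids = encode("hello")
print(ids)          # token ids depend on the vocabulary built from sample_text
print(decode(ids))  # round trip recovers the input string
```

Note that with this dict-based scheme, any character that does not appear in the training text raises a `KeyError` in `encode`.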


In [15]:
# encode text dataset and store as a tensor
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(decode(data[:300].tolist()))

torch.Size([1125396]) torch.int64
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us


In [7]:
# split data into train and validation sets
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [8]:
# we train the transformer on random chunks of the text.
# the length of each chunk is the context length (context_size)
# several chunks are stacked into a batch and processed in parallel

batch_size = 4 # sequences processed in parallel
context_size = 8 # maximum context length

def get_batch(split):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i+context_size] for i in ix])
    y = torch.stack([data[i+1:i+context_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print(f'Input Shape: {xb.shape=}')
print(f'Target Shape: {yb.shape=}')
print(f'Inputs: {xb=}')
print(f'Targets: {yb=}')

print('===============')

for b in range(batch_size):
    for t in range(context_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'When input is {context.tolist()} the target is {target}')

Input Shape: xb.shape=torch.Size([4, 8])
Target Shape: yb.shape=torch.Size([4, 8])
Inputs: xb=tensor([[58,  8,  0, 16, 53,  1, 63, 53],
        [46, 43, 39, 56,  1, 51, 43,  6],
        [ 1, 51, 63,  1, 51, 39, 57, 58],
        [ 1, 39, 57,  1, 58, 46, 53, 59]])
Targets: yb=tensor([[ 8,  0, 16, 53,  1, 63, 53, 59],
        [43, 39, 56,  1, 51, 43,  6,  1],
        [51, 63,  1, 51, 39, 57, 58, 43],
        [39, 57,  1, 58, 46, 53, 59,  1]])
When input is [58] the target is 8
When input is [58, 8] the target is 0
When input is [58, 8, 0] the target is 16
When input is [58, 8, 0, 16] the target is 53
When input is [58, 8, 0, 16, 53] the target is 1
When input is [58, 8, 0, 16, 53, 1] the target is 63
When input is [58, 8, 0, 16, 53, 1, 63] the target is 53
When input is [58, 8, 0, 16, 53, 1, 63, 53] the target is 59
When input is [46] the target is 43
When input is [46, 43] the target is 39
When input is [46, 43, 39] the target is 56
When input is [46, 43, 39, 56] the target is 1
When inp

## Bigram Language Model

We implement a simple bigram language model as a baseline. 

In [9]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# embedding table, essentially a weight matrix
W = nn.Embedding(vocab_size, vocab_size)
print(W(torch.tensor(1)))

# cross entropy loss
# using the module form nn.CrossEntropyLoss() here, as opposed to the functional form F.cross_entropy()

loss = nn.CrossEntropyLoss()
input = torch.randn(1, 4, requires_grad=True)
target = torch.randn(1, 4)
print(loss(input, target))

# torch.cat()
a = torch.arange(10).view(5, 2)
b = torch.arange(10).view(5, 2)
print(a)
print(b)
torch.cat((a, b), dim=1)

tensor([-8.9222e-01, -4.5212e-01,  5.3472e-01, -5.2330e-01, -1.2478e+00,
         2.9198e-01, -2.6053e-02, -1.2701e-01, -1.9188e+00, -9.1911e-03,
         7.3610e-01, -7.9553e-01, -7.4883e-01, -2.4543e-02, -1.0343e+00,
         1.4419e-01,  7.0602e-01, -2.3258e-01,  1.7613e-01,  6.4299e-01,
        -1.1068e+00, -5.7910e-01, -1.4749e+00, -1.4360e+00, -2.3925e-01,
         6.4835e-01,  2.0337e+00, -1.3335e+00,  3.6340e-01,  9.9667e-01,
        -5.9896e-01,  2.7289e-01,  5.0060e-01, -5.3503e-01,  6.5793e-01,
        -1.4281e+00, -1.0557e+00, -3.2065e-01,  5.5331e-01, -2.2441e-01,
         4.9368e-03, -8.5709e-01, -6.2144e-01, -2.9752e+00, -1.0889e+00,
        -1.0538e+00,  1.8081e+00,  7.5744e-01,  1.4311e+00,  1.4998e+00,
         6.5244e-01, -2.6761e-02, -1.3566e+00, -1.6688e-03,  1.4140e+00,
         1.0523e+00, -1.4516e+00,  9.4369e-01,  5.5688e-02, -1.8889e+00,
        -5.4709e-01,  1.7760e+00,  8.6275e-01, -5.1063e-01, -5.2403e-01],
       grad_fn=<EmbeddingBackward0>)
tensor(2.8603

tensor([[0, 1, 0, 1],
        [2, 3, 2, 3],
        [4, 5, 4, 5],
        [6, 7, 6, 7],
        [8, 9, 8, 9]])

In [None]:
class BigramLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx) # (B=batch_size, T=time, C=vocab_size)
        if targets is None:
            loss = None
        else: 
            # logits are (B, T, C) and targets are (B, T); F.cross_entropy expects (B*T, C) and (B*T)
            B, T, C = logits.shape

            logits = logits.view(B*T, C)
            targets = targets.view(-1)

            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, new_token_length):
        # idx is (B, T) array of indices in current context
        for _ in range(new_token_length):
            # get predictions 
            logits, loss = self(idx)
            # we only want the last time step
            logits = logits[:, -1, :]  #(B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

In [20]:
# instantiating a simple bigram model with batched inputs and outputs
model = BigramLM(vocab_size)
logits, loss = model(xb, yb)
print(f'{logits.shape=}')
print(f'{loss.item()=}')

logits.shape=torch.Size([32, 65])
loss.item()=4.831140041351318
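
One way to read the initial loss is to compare it with the loss of a model that predicts every character with equal probability, which is `ln(vocab_size)`. The sketch below computes that reference value with all-zero logits; the batch of 32 rows and the random targets are arbitrary choices for illustration.

```python
import math
import torch
from torch.nn import functional as F

vocab_size = 65  # size of the character vocabulary found above

# All-zero logits correspond to a uniform distribution over the vocabulary.
logits = torch.zeros(32, vocab_size)
# Arbitrary targets: under a uniform prediction the loss does not depend on them.
targets = torch.randint(vocab_size, (32,))

uniform_loss = F.cross_entropy(logits, targets).item()
print(f'{uniform_loss=:.4f}, {math.log(vocab_size)=:.4f}')
```

Both numbers come out at about 4.17. A randomly initialized embedding table produces non-uniform logits, which is presumably why the untrained model's 4.83 sits above this reference.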


In [30]:
# running inference on bigram model
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(model.generate(idx, new_token_length=100)[0].tolist()))



QIhSXXx$SxEerU.a' vEkFM,B OuW
v&QUQk-CRqTGrm;hk,qT
i-Qk.ud-WKHFJoiIdYkt'pM;F:bZzpLSHF&WfaabD&gjAYGrq


Obviously, the output is gibberish: the model's weights are still at their random initialization. Let's train it.

In [31]:
# PyTorch Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [69]:
batch_size=32
# simple training loop
for steps in range(10000):
    xb, yb = get_batch('train')
    
    #evaluate loss
    logits, loss = model(xb, yb)

    # zero out gradients to avoid gradient accumulation
    optimizer.zero_grad()

    # backpropagate the loss
    loss.backward()

    # take one optimizer (gradient descent) step
    optimizer.step()
print(loss.item())

2.2513909339904785


In [70]:
# running inference on bigram model
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(model.generate(idx, new_token_length=300)[0].tolist()))



de'derselishar, s d hatonth ceerind.
Dif!
I ge ce: alend itceamaroancadve ach s
Fo Hoio ang IDUCaime ththat hit u nins vese he ar he mythe.

MENCENus adf ath oushot:
D:

Antyay od thirchiro. that hile o wis CEr lol--S: to w d miod thth wif bey y hare apowisu ssteeemellow t er w's, weered'ere remers 


We get something much more reasonable after the training loop! However, this is just a naive bigram model, where each prediction depends only on the single preceding token. Let's level up to **attention** and **transformers**.
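
The training loop above prints only the loss of the final mini-batch, which can be noisy. A common way to get a steadier number is to average the loss over several random batches from both splits with gradients disabled. The sketch below shows that idea under stated assumptions: `estimate_loss`, `eval_iters`, and `TinyBigram` are hypothetical names, and random token ids stand in for the encoded Shakespeare text so the block runs on its own.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)
vocab_size, batch_size, context_size = 65, 32, 8

# Stand-in data: random token ids instead of the encoded Shakespeare text.
data = torch.randint(vocab_size, (10_000,))
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    src = train_data if split == 'train' else val_data
    ix = torch.randint(len(src) - context_size, (batch_size,))
    x = torch.stack([src[i:i + context_size] for i in ix])
    y = torch.stack([src[i + 1:i + context_size + 1] for i in ix])
    return x, y

class TinyBigram(nn.Module):
    """Minimal bigram model returning (logits, loss), like BigramLM above."""
    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.table(idx)  # (B, T, C)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

@torch.no_grad()  # no gradients are needed for evaluation
def estimate_loss(model, eval_iters=50):
    """Average the loss over eval_iters random batches for each split."""
    model.eval()
    out = {}
    for split in ('train', 'val'):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out

model = TinyBigram(vocab_size)
print(estimate_loss(model))
```

Calling a function like this every few hundred steps inside the training loop would also show whether the validation loss tracks the training loss.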

## Self-Attention

**Toy Example**: we want the current token to have all the previous tokens as context. A naive implementation is to average all previous token values and use that as context. 

In [82]:
# Toy Example

# we want the current token to have all the previous tokens as context
# a naive implementation is to average all previous token values and use that as context 

torch.manual_seed(1)
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)

# bow = bag of words, meaning an unordered collection of word frequencies
xbow = torch.zeros((B, T, C))
for b in range(B): 
    for t in range(T):
        xprev = x[b, :t+1] # (t+1, C)
        xbow[b, t] = torch.mean(xprev, dim=0)

The implementation above is highly inefficient because of the nested Python loops. We can use matrix multiplication to parallelize the process: multiplying by a row-normalized lower triangular matrix computes all the running averages at once.

In [80]:
# take lower triangle of square matrix
a = torch.tril(torch.ones(3, 3))
a /= torch.sum(a, 1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f'{a=}')
print(f'{b=}')
print(f'{c=}')

# now we get running average of the columns of b in c. 

a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=tensor([[5., 9.],
        [4., 8.],
        [3., 3.]])
c=tensor([[5.0000, 9.0000],
        [4.5000, 8.5000],
        [4.0000, 6.6667]])


Now, apply the same trick to our running-average calculation.

In [89]:
weights = torch.tril(torch.ones(T, T))
print(f'{weights=}')
weights /= weights.sum(dim=1, keepdim=True)
print(f'{weights=}')
xbow2 = weights @ x
# allclose checks whether the two tensors are element-wise equal within a small tolerance.
torch.allclose(xbow, xbow2)

weights=tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
weights=tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

A third method, using Softmax, also produces the desired result. 

In [90]:
tril = torch.tril(torch.ones(T, T))
weights = torch.zeros((T, T))
weights = weights.masked_fill(tril==0, float('-inf'))
print(f'{weights=}')
weights = F.softmax(weights, dim=1)
print(f'{weights=}')
xbow3 = weights @ x
torch.allclose(xbow, xbow3)

weights=tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
weights=tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

True

We use this third method in self-attention because softmax gives an intuitive way to express the affinity between elements. Specifically, the weights matrix holds the affinity between the current token and each previous token, with larger values meaning more **attention** is paid to that pair.

Importantly, we also mask the future positions (setting them to `-inf` before the softmax) so that the current token can only talk to tokens in the past, never to tokens in the future.

Below, we perform **self-attention**. We introduce three new concepts:
- Key: a vector for each token representing what that token contains, used for matching against queries.
- Query: a vector representing what the current token is looking for.
- Value: a vector representing the information a token passes along when it is attended to.

All three are produced by simple linear layers applied to the token embeddings. A *dot product* between a query and a key then gives the affinity between the current token and each previous token.

In [110]:
# playing around with transpose

a = torch.arange(24).view(2, 3, 4)
print(f'{a.shape=}')
a = a.transpose(0, 1)
print(f'{a.shape=}')
a

a.shape=torch.Size([2, 3, 4])
a.shape=torch.Size([3, 2, 4])


tensor([[[ 0,  1,  2,  3],
         [12, 13, 14, 15]],

        [[ 4,  5,  6,  7],
         [16, 17, 18, 19]],

        [[ 8,  9, 10, 11],
         [20, 21, 22, 23]]])

In [116]:
# self attention
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B, T, C)

# single head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, C) --> (B, T, 16)
q = query(x) # (B, T, C) --> (B, T, 16)
weights = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) --> (B, T, T)

# do the weighted average
tril = torch.tril(torch.ones((T, T)))
weights = weights.masked_fill(tril==0, float('-inf'))
weights = F.softmax(weights, dim=-1)
# now we have weights for each batch of data. 
v = value(x)
out = weights @ v

print(f'{weights[0]=}')

weights[0]=tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.1865e-01, 8.8135e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [3.9568e-01, 4.2274e-01, 1.8159e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [2.3539e-01, 5.5080e-01, 1.4647e-01, 6.7341e-02, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.2729e-01, 2.4094e-02, 2.2063e-01, 2.9514e-01, 3.3284e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.2596e-01, 6.3780e-02, 6.7478e-02, 2.8705e-01, 4.0217e-01, 5.3562e-02,
         0.0000e+00, 0.0000e+00],
        [2.1200e-02, 5.7241e-04, 3.3019e-01, 4.5086e-02, 2.1366e-02, 5.0661e-02,
         5.3092e-01, 0.0000e+00],
        [4.2966e-01, 9.4396e-03, 1.4951e-01, 6.9958e-02, 2.1870e-01, 7.7199e-02,
         1.4871e-02, 3.0660e-02]], grad_fn=<SelectBackward0>)


Notes:
1. Attention is a *communication mechanism*. It's like nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them with data-dependent weights. 
2. There is no notion of space. Attention acts on a set of vectors, unlike convolutions, which have a spatial structure built in. This is why we need positionally encoded tokens, or a positional embedding vector.
3. Examples across the batch dimension are processed completely independently and never "communicate" with each other. This independence is part of what makes attention efficient: every example can be processed in parallel.
4. In an "encoder" attention block, you remove the `torch.tril` masking, allowing the tokens to freely communicate with each other. The implementation above is a "decoder" implementation because it has triangular masking, and is usually used in autoregressive settings so that the current token cannot see tokens in the future.
5. **Self-attention** just means that the keys, queries, and values are all produced from the same source. **Cross-attention** means the queries come from one source while the keys and values come from a separate source of tokens to pool information from.
6. **Scaled attention** multiplies the `weights` matrix by `1/sqrt(head_size)`. When the inputs $Q, K$ have unit variance, this keeps `weights` at unit variance as well, so softmax does not saturate too much.
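
Notes 4 to 6 can be pulled together in one small module. The sketch below is a minimal single head, assuming hypothetical names (`AttentionHead`, the `causal` flag, the `source` argument): `causal=True` gives the masked "decoder" behaviour, `causal=False` the unmasked "encoder" behaviour, passing `source` switches to cross-attention, and the affinities are scaled by `head_size ** -0.5`. It is an illustration of the notes, not the tutorial's final implementation.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class AttentionHead(nn.Module):
    """Single attention head; names and flags here are illustrative."""
    def __init__(self, n_embd, head_size, context_size, causal=True):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.causal = causal
        # Buffer, not a parameter: the mask is fixed and never trained.
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))

    def forward(self, x, source=None):
        # Self-attention when source is None; cross-attention otherwise.
        kv_input = x if source is None else source
        q = self.query(x)         # (B, T, head_size)
        k = self.key(kv_input)    # (B, S, head_size)
        v = self.value(kv_input)  # (B, S, head_size)
        # Scaled dot-product affinities between queries and keys.
        weights = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, S)
        if self.causal and source is None:
            T = x.shape[1]
            # Hide future positions so each token only attends to the past.
            weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        return weights @ v        # (B, T, head_size)

torch.manual_seed(0)
B, T, C, head_size = 4, 8, 32, 16
x = torch.randn(B, T, C)

decoder_head = AttentionHead(C, head_size, context_size=T, causal=True)
encoder_head = AttentionHead(C, head_size, context_size=T, causal=False)
print(decoder_head(x).shape, encoder_head(x).shape)

# Cross-attention: queries come from x, keys and values from a separate source.
memory = torch.randn(B, 5, C)
print(encoder_head(x, source=memory).shape)
```

Registering the mask as a buffer keeps it out of the trainable parameters while still moving it along with the module when the model is transferred to another device.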

In [135]:
# Demonstration of scaled-attention
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
weights = q @ k.transpose(-2, -1)
print(f'{k.var()=}, {q.var()=}, {weights.var()=}')
weights = q @ k.transpose(-2, -1) * head_size ** -0.5
print(f'{k.var()=}, {q.var()=}, {weights.var()=}')
# scales back down to unit variance 

# the reason we want to scale this down is that softmax saturates when its inputs are large:
print(torch.softmax(torch.tensor([0.1, 0.3, 0.5, 0.2, 0.4]), dim=-1))
print(torch.softmax(torch.tensor([0.1, 0.3, 0.5, 0.2, 0.4])*10, dim=-1))
print(torch.softmax(torch.tensor([0.1, 0.3, 0.5, 0.2, 0.4])*100, dim=-1))
# the larger values get larger, and in the limit the output becomes a one-hot vector.

k.var()=tensor(1.0047), q.var()=tensor(0.9612), weights.var()=tensor(16.4341)
k.var()=tensor(1.0047), q.var()=tensor(0.9612), weights.var()=tensor(1.0271)
tensor([0.1621, 0.1980, 0.2419, 0.1792, 0.2188])
tensor([0.0117, 0.0861, 0.6364, 0.0317, 0.2341])
tensor([4.2482e-18, 2.0611e-09, 9.9995e-01, 9.3572e-14, 4.5398e-05])
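
The variances printed above can be checked by hand. As a short derivation, assume the entries of a query $q$ and a key $k$ are independent with zero mean and unit variance, and write $d$ for `head_size` (the symbol $d$ is introduced here only for brevity):

$$
q \cdot k = \sum_{i=1}^{d} q_i k_i, \qquad
\mathbb{E}[q_i k_i] = \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \qquad
\operatorname{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = 1
$$

$$
\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d} \operatorname{Var}(q_i k_i) = d, \qquad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d}}\right) = \frac{d}{d} = 1
$$

With $d = 16$ this agrees with the measurements: a variance of about 16.43 before scaling and about 1.03 after.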
