## Imports and Downloads

Companion notebook to the [Zero To Hero](https://karpathy.ai/zero-to-hero.html) video on GPT.

In [85]:
!pip install torch

Defaulting to user installation because normal site-packages is not writeable


In [86]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

<torch._C.Generator at 0x1182fe5b0>

## Hyperparameters

In [174]:
batch_size = 32
block_size = 8 # context_window size
max_new_tokens = 100 # number of new tokens to predict
embedding_dim = 32
num_attention_heads = 4

eval_interval = 300
eval_iters = 200
training_iters = 5000

learning_rate = 1e-3
dropout_rate = 0.0
device = "cuda" if torch.cuda.is_available() else "cpu"


## Data Processing

**Download text dataset and read it in**

In [88]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
# run this in command line
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

zsh:1: command not found: wget


In [89]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [90]:
# Number of characters in the dataset  
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [91]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [92]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("All characters in dataset:", ''.join(chars))
print("Vocab size:", vocab_size)

All characters in dataset: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocab size: 65


**Tokenize the dataset**

In [93]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

# Create a function that performs the encoding and decoding
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [94]:
# let's now encode the entire text dataset and store it into a torch.Tensor
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

**Split data into train/test**

In [95]:
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

**Define context window size**

In [96]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [97]:
# Illustration of why we use +1 for block size, gives 8 examples in each context

x = train_data[:block_size] # input
y = train_data[1:block_size+1] # next character to predict (target)
for t in range(block_size):
    context = x[:t+1]   
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


**Batch data**

In [98]:
torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?


# generate a small batch of data of inputs x and targets y
def get_batch(split):
    # choose between train and val splits
    data = train_data if split == 'train' else val_data
    # choose random index to sample sequence from, and do this batch_size times
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # get inputs and targets
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print("shape:", xb.shape)
print("data:", xb)
print('targets:')
print("shape:", yb.shape)
print("data:", yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
shape: torch.Size([4, 8])
data: tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
shape: torch.Size([4, 8])
data: tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 56, 1, 58] the target

In [99]:
print("Transformer input:", xb)

Transformer input: tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


## Bigram Model

<div>
<img src="https://devopedia.org/images/article/219/7356.1569499094.png" width="500"/>
</div>

**Create Model**

In [100]:
class BigramLanguageModel(nn.Module):

    # B = batch size (=4)
    # T = context window size (=8)
    # C = vocab size (=65)

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx: (B, T)
        # targets: (B, T)

        # get idx'th row of embedding table and return vector of size vocab_size  
        # interpretation: vector is scores for the next character in the sequnece 
        logits = self.token_embedding_table(idx) 

        # logits: (B,T, C) 
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            # reshape logits and targets
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # negative log likelihood loss
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # generate max_new_tokens new tokens, starting from idx
        # idx: (B, T) (a batch of sequences of tokens)

        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step (i.e. prediction for next token)
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1) = single prediction for each sequence
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

Notes:

1: in generate, not using a NN, just using embedding as a look up table so dimensionality of input can change i.e. we can keep incrementing the output and feeding it into the input to predict the next token

2: This isn't necessary since we're using a bigram model, so only the previous letter is actually being used to predict next one, but it is a more general and flexible implementation 


In [101]:
# Instantiate the model
model = BigramLanguageModel(vocab_size)
model = model.to(device)

logits, loss = model(xb, yb)

In [102]:
# Check model output and loss
print("Logits:")
print(logits.shape)
print(loss)

Logits:
torch.Size([32, 65])
tensor(5.0364, grad_fn=<NllLossBackward0>)


In [103]:
idx = torch.zeros((1, 1), dtype=torch.long)
print("Start sequence with new line characters:", idx)

print("Prediction from model:", decode(model.generate(idx=idx, max_new_tokens=max_new_tokens)[0].tolist()))

Start sequence with new line characters: tensor([[0]])
Prediction from model: 
lfJeukRuaRJKXAYtXzfJ:HEPiu--sDioi;ILCo3pHNTmDwJsfheKRxZCFs
lZJ XQc?:s:HEzEnXalEPklcPU cL'DpdLCafBheH


**Train Model**

In [104]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [110]:
# function to estimate loss with model in eval mode for checkmarking
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            _, loss = model(x, y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [111]:
for iter in range(training_iters): # increase number of steps for good results... 

    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"iter {iter} train loss: {losses['train']:.2f} val loss: {losses['val']:.2f}")
    
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

iter 0 train loss: 3.19 val loss: 3.25
iter 300 train loss: 3.13 val loss: 3.14
iter 600 train loss: 3.06 val loss: 3.07
iter 900 train loss: 2.98 val loss: 3.02
iter 1200 train loss: 2.95 val loss: 2.95
iter 1500 train loss: 2.89 val loss: 2.90
iter 1800 train loss: 2.85 val loss: 2.87
iter 2100 train loss: 2.82 val loss: 2.83
iter 2400 train loss: 2.76 val loss: 2.80
iter 2700 train loss: 2.76 val loss: 2.77


In [107]:
# Repeat predictions from earlier with trained model

idx = torch.zeros((1, 1), dtype=torch.long)
print("Start sequence with new line characters:", idx)

print("Prediction from model:", decode(model.generate(idx=idx, max_new_tokens=max_new_tokens)[0].tolist()))

Start sequence with new line characters: tensor([[0]])
Prediction from model: 
d:Xn;B?Klespy OMiaj,-H:HMIZWt
wal;eroM.e-IIgs
br,!M.xxy p?-TORtukQSVgQNqRf,&$qq?Kjt.DunlGuZEFHlB?!Lr


## The mathematical trick in self-attention

Toy example batch:

In [112]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time (tokens), channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

Aim: Want tokens to communicate with the preceding tokens (and ignore future tokens).
Naive approach: for context of token with its history, simply use the average of all of the previous tokens up until that point.

In [115]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))

# For each batch
for b in range(B):
    # For each time step
    for t in range(T):
        # Get all previous time steps
        xprev = x[b,:t+1] # (t,C)
        # Average over previous tokens
        xbow[b,t] = torch.mean(xprev, 0)


In [118]:
print("Original context vector:", x[0])
print("Bag of words averaged vector:", xbow[0])
print("Each row is the average of the rows above it")

Original context vector: tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
Bag of words averaged vector: tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
Each row is the average of the rows above it


Problem: inefficient implementation (many loops). Solution: use matrix multiplication to speed this up.

In [117]:
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3)) # returns lower triangular matrix of ones (upper=zeros)
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Hence can now use this technique to get average of tokens in an efficient manner:

In [120]:
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
torch.allclose(xbow, xbow2)

True

However, there is a third option for doing this, but with a more intuitive interpretation.

In [124]:
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T)) # lower triangular matrix of ones, only used for masking

# wei can be seen as a matrix that encapsulates the affinities between different tokens 
# For averaging, can be just zeros
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) #i.e. lower triangle of zeros, upper of -inf
# -inf can be interpreted as a mask for "don't attend to this token", i.e. for future tokens
wei = F.softmax(wei, dim=-1) # softmax is a normalisation operation, so rows sum to 1
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

## Updating the Bigram Model

Let's update the bigram model to have:

* a linear layer
* an arbitrary embedding dimension
* a positional encoding

In [136]:
class BigramLanguageModel(nn.Module):

    # B = batch size (=4)
    # T = context window size (=8)

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.positional_embedding_table = nn.Embedding(block_size, embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, idx, targets=None):
        # idx: (B, T)
        # targets: (B, T)

        B, T = idx.shape

        # get idx'th row of embedding table and return vector of size vocab_size  
        # interpretation: vector is scores for the next character in the sequnece 
        token_embed = self.token_embedding_table(idx) # (B, T, embedding_dim)

        # get positional embedding for each token in the sequence, i.e. the position of the token in the sequence
        pos_embed = self.positional_embedding_table(torch.arange(T, device=idx.device)) # (T, embedding_dim)

        logits = self.lm_head(token_embed + pos_embed) # (B, T, vocab_size)

        # logits: (B,T, vocab_size) 
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            # reshape logits and targets
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # negative log likelihood loss
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # generate max_new_tokens new tokens, starting from idx
        # idx: (B, T) (a batch of sequences of tokens)

        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step (i.e. prediction for next token)
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1) = single prediction for each sequence
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [137]:
# Instantiate the model
model = BigramLanguageModel()
model = model.to(device)

for iter in range(training_iters): # increase number of steps for good results... 

    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"iter {iter} train loss: {losses['train']:.2f} val loss: {losses['val']:.2f}")
    
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

iter 0 train loss: 4.38 val loss: 4.37
iter 300 train loss: 4.38 val loss: 4.38
iter 600 train loss: 4.38 val loss: 4.38
iter 900 train loss: 4.38 val loss: 4.37
iter 1200 train loss: 4.38 val loss: 4.38
iter 1500 train loss: 4.38 val loss: 4.38
iter 1800 train loss: 4.38 val loss: 4.37
iter 2100 train loss: 4.37 val loss: 4.38
iter 2400 train loss: 4.38 val loss: 4.38
iter 2700 train loss: 4.38 val loss: 4.37


## One Self-Attention Head

Now, let's use self-attention so that the communication between tokens is more advanced than the average of previous tokens. Method: self-attention. Roughly speaking, in one self-attention head, for a single token have different values of:

* Q = query vector: "what am I looking for?"
* K = key vector: "what do I contain?"
* V = value vector: "if you find me interesting, here's what I will communicate to you"

Each token will emit these two vectors to communicate with the other preceding tokens. To calculate affinities between tokens, calculate the dot product of each token's query vector with all other tokens' key vectors. These dot products become the weight vector that we calculate softmax over (and then predict next token from).

Intuition: if key and query vectors are well aligned then dot product will be high and will learn more about that token.

In [145]:
# One self-attention head:
# Input: (B, T, embedding_dim)
# Output: (B, T, head_size)

torch.manual_seed(1337)
B,T,embedding_dim = 4,8,32 # batch, time, embedding dimension
x = torch.randn(B,T,embedding_dim)
head_size = 16


# Create linear layers for keys, queries and values
key = nn.Linear(embedding_dim, head_size, bias=False)
query = nn.Linear(embedding_dim, head_size, bias=False)
value = nn.Linear(embedding_dim, head_size, bias=False)


# Weights now come from dot product of keys and queries for each token
k = key(x)   # (B, T, head_size)
q = query(x) # (B, T, head_size)
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)


# Can now mask over future tokens and softmax as before
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)


# Rather than outputting the weights, we use the value vector to encode information about the token
# Can be thought of as "private" information about the token
# "If token is interesting, value vector communicates information about the token"
v = value(x)
out = wei @ v

print("input shape: ", x.shape)
print("output shape: ", out.shape)

input shape:  torch.Size([4, 8, 32])
output shape:  torch.Size([4, 8, 16])


Summary of what's happening:
* K * Q creates a vector for each token calculates how similar each token's query is to the other tokens' keys. Tokens which are more similar, i.e. deemed more interesting, will be weighted higher.
* -inf masks over the future tokens (shouldn't attend to future tokens)
* softmax creates a probability distribution over the values 
* Multiply the output by value vector V (i.e. have a weighted sum of all of the value vectors for each token)

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights. For auto-regressive scenarios, nodes are words in sequence and each node points to nodes that precede it.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling. In other words, don't always have to restrict the attention mechanism to look at previous tokens. E.g. in sentiment analysis (rather than next word prediction) may attend fully to the whole sequence.
- "self-attention" just means that the keys and values are produced from the same source as queries (i.e. from the input x). In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `wei` by 1/sqrt(head_size). This ensures that input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. If variance of `wei` was large, then softmax would exagerrate differences and hence would converge to one hot vectors. Illustration below

In [146]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)

# Need to normalise the weights by the square root of the head size for scaled-attention
wei = q @ k.transpose(-2, -1) * head_size**-0.5

In [147]:
k.var()

tensor(1.0449)

In [148]:
q.var()

tensor(1.0700)

In [149]:
wei.var()

tensor(1.0918)

In [150]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [151]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

## Multi-Head Self-Attention

Can now bring an attention head into the language model. Changes:
* pass embedding through self-attention block before linear layer
* need to update the generate function to have constant input size (can no longer have input > context window size)

In [167]:
class Head(nn.Module):
    """ one head of self-attention """

    # B: batch size
    # T: sequence length
    # C: head size

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)

        # Create the lower triangular matrix for masking out the future tokens
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B,T,C = x.shape

        # Calculate keys and queries
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)

        # Compute attention scores ("affinities") and normalise by head_size
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        
        # Mask over future tokens
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)

        # Perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

# out: batch, time, head_size

In [168]:
class BigramLanguageModel(nn.Module):

    # B = batch size (=4)
    # T = context window size (=8)

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.positional_embedding_table = nn.Embedding(block_size, embedding_dim)
        self.sa_head = Head(embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, idx, targets=None):
        # idx: (B, T)
        # targets: (B, T)

        B, T = idx.shape

        # get idx'th row of embedding table and return vector of size vocab_size  
        # interpretation: vector is scores for the next character in the sequnece 
        token_embed = self.token_embedding_table(idx) # (B, T, embedding_dim)

        # get positional embedding for each token in the sequence, i.e. the position of the token in the sequence
        pos_embed = self.positional_embedding_table(torch.arange(T, device=idx.device)) # (T, embedding_dim)

        x = token_embed + pos_embed # (B, T, embedding_dim)
        x = self.sa_head(x) # (B, T, head_size=embedding_dim)

        logits = self.lm_head(x) # (B, T, vocab_size)
                
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            # reshape logits and targets
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # negative log likelihood loss
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # generate max_new_tokens new tokens, starting from idx
        # idx: (B, T) (a batch of sequences of tokens)

        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:] 
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step (i.e. prediction for next token)
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1) = single prediction for each sequence
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [169]:
# Instantiate the model
model = BigramLanguageModel()
model = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(training_iters): # increase number of steps for good results... 

    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"iter {iter} train loss: {losses['train']:.2f} val loss: {losses['val']:.2f}")
    
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

iter 0 train loss: 4.24 val loss: 4.24
iter 300 train loss: 2.89 val loss: 2.90
iter 600 train loss: 2.63 val loss: 2.62
iter 900 train loss: 2.54 val loss: 2.55
iter 1200 train loss: 2.50 val loss: 2.50
iter 1500 train loss: 2.47 val loss: 2.48
iter 1800 train loss: 2.45 val loss: 2.45
iter 2100 train loss: 2.44 val loss: 2.44
iter 2400 train loss: 2.44 val loss: 2.44
iter 2700 train loss: 2.43 val loss: 2.43
iter 3000 train loss: 2.42 val loss: 2.43
iter 3300 train loss: 2.42 val loss: 2.43
iter 3600 train loss: 2.40 val loss: 2.41
iter 3900 train loss: 2.40 val loss: 2.41
iter 4200 train loss: 2.39 val loss: 2.40
iter 4500 train loss: 2.39 val loss: 2.40
iter 4800 train loss: 2.39 val loss: 2.39


This is an improvement on before, so attention is helping, but can still improve! Solution: multiple attention heads.

Implementation:
* Use multiple heads and concatenate outputs

In [175]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return out

In [176]:
class BigramLanguageModel(nn.Module):

    # B = batch size (=4)
    # T = context window size (=8)

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.positional_embedding_table = nn.Embedding(block_size, embedding_dim)

        # Need to make embedding dim smaller when using multiple heads (for cost)
        self.sa_heads = MultiHeadAttention(num_attention_heads, embedding_dim//num_attention_heads)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, idx, targets=None):
        # idx: (B, T)
        # targets: (B, T)

        B, T = idx.shape

        # get idx'th row of embedding table and return vector of size vocab_size  
        # interpretation: vector is scores for the next character in the sequnece 
        token_embed = self.token_embedding_table(idx) # (B, T, embedding_dim)

        # get positional embedding for each token in the sequence, i.e. the position of the token in the sequence
        pos_embed = self.positional_embedding_table(torch.arange(T, device=idx.device)) # (T, embedding_dim)

        x = token_embed + pos_embed # (B, T, embedding_dim)
        x = self.sa_heads(x) # (B, T, head_size=embedding_dim)

        logits = self.lm_head(x) # (B, T, vocab_size)
                
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            # reshape logits and targets
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # negative log likelihood loss
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # generate max_new_tokens new tokens, starting from idx
        # idx: (B, T) (a batch of sequences of tokens)

        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:] 
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step (i.e. prediction for next token)
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1) = single prediction for each sequence
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [177]:
# Instantiate the model
model = BigramLanguageModel()
model = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(training_iters): # increase number of steps for good results... 

    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"iter {iter} train loss: {losses['train']:.2f} val loss: {losses['val']:.2f}")
    
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

iter 0 train loss: 4.22 val loss: 4.22
iter 300 train loss: 2.81 val loss: 2.83
iter 600 train loss: 2.61 val loss: 2.61
iter 900 train loss: 2.50 val loss: 2.52
iter 1200 train loss: 2.45 val loss: 2.45
iter 1500 train loss: 2.40 val loss: 2.42
iter 1800 train loss: 2.38 val loss: 2.38
iter 2100 train loss: 2.37 val loss: 2.36
iter 2400 train loss: 2.34 val loss: 2.35
iter 2700 train loss: 2.33 val loss: 2.34
iter 3000 train loss: 2.31 val loss: 2.33
iter 3300 train loss: 2.31 val loss: 2.31
iter 3600 train loss: 2.29 val loss: 2.30
iter 3900 train loss: 2.28 val loss: 2.29
iter 4200 train loss: 2.27 val loss: 2.30
iter 4500 train loss: 2.26 val loss: 2.29
iter 4800 train loss: 2.25 val loss: 2.29


Result: adding more attention heads improves the predictive performance of the models!

## Other Components of the Transformer

There are a few extra final components of the transformer:

* Feed forward layers (take outputs of multi-head self-attention and communicate between outputs)
* Blocks: One block intersperces communication and computation: first computes MultiHead self-attention and then communicates outputs using feedforward block


In [191]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, embedding_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)

In [197]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, embedding_dim, n_head):
        # embedding_dim: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = embedding_dim // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(embedding_dim)
        self.ffw = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        x = self.sa(x)
        x = self.ffwd(x)
        return x

However, adding multiple blocks of multi-head self-attention and feedforward layers can create a deep network that may be prone to overfitting. 

To remedy this need two regularisation techniques:
* Residual connections
* Layer norm
* Drop out


In [202]:
## Add residual connections to different components

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, embedding_dim, n_head):
        # embedding_dim: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = embedding_dim // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(embedding_dim)

    def forward(self, x):
        x = x + self.sa(x)  # residual connection
        x = x + self.ffwd(x)  # residual connection
        return x


# Also add projection layers to project outputs of blocks back for residual connection

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj  = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out) # projection layer
        return out


# Also in AIAYN, FFN has 4 * embedding_dim hidden units
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, embedding_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.ReLU(),
            nn.Linear(4 * embedding_dim, embedding_dim), # projection layer
        )

    def forward(self, x):
        return self.net(x)

Recap: batchnorm ensures each column has 0 mean and unit std (i.e. across batch).
Layernorm: normalise rows instead of columns (i.e. features/embedding dims).

In [207]:
class LayerNorm1d: 
  
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
  
  def __call__(self, x):
    # calculate the forward pass
    xmean = x.mean(1, keepdim=True) # batch mean
    xvar = x.var(1, keepdim=True) # batch variance
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    return self.out
  
  def parameters(self):
    return [self.gamma, self.beta]

torch.manual_seed(1337)
module = LayerNorm1d(100)
x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors
x = module(x)
x.shape

torch.Size([32, 100])

In [208]:
x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs

(tensor(0.1469), tensor(0.8803))

In [209]:
x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features

(tensor(-3.5763e-09), tensor(1.0000))

Note: in AIAYN paper, layernorm was after attention. Nowadays it's more common practice to apply LayerNorm before transformations.

In [210]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, embedding_dim, n_head):
        # embedding_dim: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = embedding_dim // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(embedding_dim)
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))  # layer norm
        x = x + self.ffwd(self.ln2(x))  # layer norm
        return x

Also usually add LN to end output of all MHA blocks (before final linear layer)

## Putting it all together (Final Code)

Imports

In [185]:
import torch
import torch.nn as nn
from torch.nn import functional as F

Hyperparameters

In [211]:

# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
training_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
embedding_dim = 384
n_head = 6
n_layer = 6
dropout_rate = 0.2
# ------------


Data loading

In [212]:
torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

Model Components:

In [213]:
class Head(nn.Module):
    """ one head of self-attention """

    # B: batch size
    # T: sequence length
    # C: head size

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)

        # Create the lower triangular matrix for masking out the future tokens
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        B,T,C = x.shape

        # Calculate keys and queries
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)

        # Compute attention scores ("affinities") and normalise by head_size
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        
        # Mask over future tokens
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)

        # Perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, embedding_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.ReLU(),
            nn.Linear(4 * embedding_dim, embedding_dim),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, embedding_dim, n_head):
        # embedding_dim: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = embedding_dim // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(embedding_dim)
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

Final Model:

In [214]:
class LM(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding_table = nn.Embedding(block_size, embedding_dim)
        self.blocks = nn.Sequential(*[Block(embedding_dim, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(embedding_dim) # final layer norm
        self.lm_head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

Training and Predictions

In [215]:
model = LM()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(training_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == training_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))

10.788929 M parameters
step 0: train loss 4.2849, val loss 4.2823


## Final Notes

* Transformer architecture has encoder and decoder modules. We have just implemented the decoder modules (as characterised by mask which prevents attention with future tokens)
* Missing the encoder parts and cross-attention mechanisms
* AIAYN comes from a machine translation setting where aim is to encode one language and translate auto-regressively into another language. Encoder projects input language into one vector. Decoder then auto-regressively translates the sentence, beginning with the < START > token and conditioned on the encoder output (via cross-attention). Here, queries are coming from target language decoder but the keys and values are coming from the input language encoder. Hence decoder is generating new tokens based on history of decoder and the encoder.

In [None]:
# French to English translation example:

# <--------- ENCODE ------------------><--------------- DECODE ----------------->
# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>
