In [1]:
import torch
print(torch.__version__)
print (torch.cuda.is_available())
print (torch.cuda.device_count())

1.13.1+cu117
True
3


This notebook implements a character-level transformer model. The goal is to replicate a lecture to learn how to make a transformer architecture.

This is a decoder-only architecture. There is no sequence-to-sequence functionality, no clasification of text, other NLP tasks... just text generation / pure language modeling.  

Source:
https://www.youtube.com/watch?v=kCc8FmEb1nY

# Dataset loading & exploration

In [2]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# !mkdir ../data/tinyshakespeare -p
# !mv input.txt ../data/tinyshakespeare/input.txt

In [3]:
# Observe dataset size, first 1000 characters
with open ('../data/tinyshakespeare/input.txt', 'r') as f:
    text = f.read()

print ("Length: ", len(text))
print (text[:1000])

Length:  1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thir

In [4]:
# Print sorted vocabulrary and vocabulary size
vocab = sorted(set(text))
print ("Vocabulary: ", vocab)
print ("Vocabulary size: ", len(vocab))

Vocabulary:  ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Vocabulary size:  65


# Tokenization, train-test split

In [5]:
# Create tokenizer to encode and decode text by mapping vocabulary to integers
stoi = {c: i for i, c in enumerate(vocab)}
itos = {i: c for i, c in enumerate(vocab)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[c] for c in l])

print(encode(text[:1000]))

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63, 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1, 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49, 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 37, 53, 59, 1, 39, 56, 43, 1, 39, 50, 50, 1, 56, 43, 57, 53, 50, 60, 43, 42, 1, 56, 39, 58, 46, 43, 56, 1, 58, 53, 1, 42, 47, 43, 1, 58, 46, 39, 52, 1, 58, 53, 1, 44, 39, 51, 47, 57, 46, 12, 0, 0, 13, 50, 50, 10, 0, 30, 43, 57, 53, 50, 60, 43, 42, 8, 1, 56, 43, 57, 53, 50, 60, 43, 42, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 18, 47, 56, 57, 58, 6, 1, 63, 53, 59, 1, 49, 52, 53, 61, 1, 15, 39, 47, 59, 57, 1, 25, 39, 56, 41, 47, 59, 57, 1, 47, 57, 1, 41, 46, 47, 43, 44, 1, 43, 52, 43, 51, 63, 1, 58, 53, 1, 58, 46, 43, 1, 54, 43, 53, 54, 50, 43, 8, 0, 0, 13, 50, 50, 10, 0, 35, 43, 1, 49, 52, 53, 61, 5, 

In [6]:
# split dataset into train and validation
cutoff = int(0.9 * len(text))
train_text = text[:cutoff]
val_text = text[cutoff:]

In [7]:
data = torch.tensor(encode(text), dtype=torch.long)
train_data = data[: cutoff]
val_data = data[cutoff:]
(len(train_data), len(val_data))

  data = torch.tensor(encode(text), dtype=torch.long)


(1003854, 111540)

# Loading examples, batches, and creating an oracle

In [8]:
# Show the first input sequence and the expected predictions
block_size = 8
example = encode(train_text[:block_size + 1])
example = torch.tensor(example, dtype=torch.long)
print(example)

for idx in range(block_size):
    print(f'For value {example[idx]} the target is {example[idx+1]}')

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])
For value 18 the target is 47
For value 47 the target is 56
For value 56 the target is 57
For value 57 the target is 58
For value 58 the target is 1
For value 1 the target is 15
For value 15 the target is 47
For value 47 the target is 58


In [9]:
torch.manual_seed(1337)
batch_size = 4

def get_batch(is_train: bool = True):
    data = train_data if is_train else val_data
    ix = torch.randint(0, len(data)-block_size, (batch_size,)) # should this be len(data) - 1? we need to account for yb
    xb = torch.stack([data[i: i+block_size] for i in ix])
    yb = torch.stack([data[i+1: i+1+block_size] for i in ix])
    return xb, yb

xb, yb = get_batch()

for b in range(batch_size):
    for t in range(block_size):
        print(f"When the input is {xb[b][:t+1].tolist()} the target is {yb[b][t]}")

When the input is [24] the target is 43
When the input is [24, 43] the target is 58
When the input is [24, 43, 58] the target is 5
When the input is [24, 43, 58, 5] the target is 57
When the input is [24, 43, 58, 5, 57] the target is 1
When the input is [24, 43, 58, 5, 57, 1] the target is 46
When the input is [24, 43, 58, 5, 57, 1, 46] the target is 43
When the input is [24, 43, 58, 5, 57, 1, 46, 43] the target is 39
When the input is [44] the target is 53
When the input is [44, 53] the target is 56
When the input is [44, 53, 56] the target is 1
When the input is [44, 53, 56, 1] the target is 58
When the input is [44, 53, 56, 1, 58] the target is 46
When the input is [44, 53, 56, 1, 58, 46] the target is 39
When the input is [44, 53, 56, 1, 58, 46, 39] the target is 58
When the input is [44, 53, 56, 1, 58, 46, 39, 58] the target is 1
When the input is [52] the target is 58
When the input is [52, 58] the target is 1
When the input is [52, 58, 1] the target is 58
When the input is [52, 

# Bigram model

In [10]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class BigramModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, x, y= None):
        if y is None:
            loss = None
            pred = self.embed(x)

        else:
            pred = self.embed(x)
            loss = F.cross_entropy(pred.view(-1, self.vocab_size), y.view(-1)) # must convert shapes from (B, T, C) to (B*T, C) channels must be second!
        return pred, loss

    def generate(self, x, max_tokens):
        for _ in range(max_tokens):
            logits = self(x)[0] 
            last = logits[:, -1, :] # get prediction for latest token only
            probs = F.softmax(last, dim=-1)  # normalize logits to prob distr
            val = torch.multinomial(probs, num_samples=1) #sample from distr
            x = torch.cat([x,val], dim=1) # add sampled char to input
        return x


m = BigramModel(len(vocab))
print(m(xb, yb)[0].shape, m(xb,yb)[1])

test = torch.zeros((1,1), dtype=torch.long)
print(decode(m.generate(test, 100)[0].tolist()))


torch.Size([4, 8, 65]) tensor(4.8786, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [11]:
optimizer = torch.optim.AdamW(m.parameters(), 3e-4)

In [12]:
batch_size = 32

for epoch in range(10000):
    xb, yb = get_batch()
    logits, loss = m(xb, yb)
    loss.backward()
    optimizer.step()
    if epoch%200 == 0:
        print(loss)


tensor(4.6924, grad_fn=<NllLossBackward0>)
tensor(4.6463, grad_fn=<NllLossBackward0>)
tensor(4.4171, grad_fn=<NllLossBackward0>)
tensor(4.3734, grad_fn=<NllLossBackward0>)
tensor(4.2321, grad_fn=<NllLossBackward0>)
tensor(4.0963, grad_fn=<NllLossBackward0>)
tensor(3.8495, grad_fn=<NllLossBackward0>)
tensor(3.8547, grad_fn=<NllLossBackward0>)
tensor(3.7979, grad_fn=<NllLossBackward0>)
tensor(3.5864, grad_fn=<NllLossBackward0>)
tensor(3.5940, grad_fn=<NllLossBackward0>)
tensor(3.5129, grad_fn=<NllLossBackward0>)
tensor(3.3269, grad_fn=<NllLossBackward0>)
tensor(3.3191, grad_fn=<NllLossBackward0>)
tensor(3.2556, grad_fn=<NllLossBackward0>)
tensor(3.2865, grad_fn=<NllLossBackward0>)
tensor(3.0176, grad_fn=<NllLossBackward0>)
tensor(3.0480, grad_fn=<NllLossBackward0>)
tensor(3.0441, grad_fn=<NllLossBackward0>)
tensor(3.0933, grad_fn=<NllLossBackward0>)
tensor(3.0888, grad_fn=<NllLossBackward0>)
tensor(2.9682, grad_fn=<NllLossBackward0>)
tensor(2.8663, grad_fn=<NllLossBackward0>)
tensor(2.98

In [14]:
print(decode(m.generate(test, 200)[0].tolist()))



A:
GLeco wh hangotLO:ws, oraingor, s ve!
A:


Theleseeserer hee an beeOFonoreme ain cketoty dedo lo'lllI at ta d:
ELIS me turf lal y his d w pe atho oraingre n y t
Enganoreralo anicererupa anse trcor


# Attention with masking

In [20]:
sample = torch.rand((4, 8, 2))
xbow = torch.rand(4,8,2)
print(sample)

for i in range(sample.shape[0]): # iterate on batch
    for j in range(sample.shape[1]): #iterate to tth token
        xprev = sample[i][0:j+1]
        xbow[i][j] = xprev.mean(0)

print(xbow)


tensor([[[3.9964e-01, 8.4944e-01],
         [7.7692e-01, 6.9311e-01],
         [7.9659e-01, 7.0895e-01],
         [9.8861e-01, 5.0518e-02],
         [4.8716e-01, 5.3531e-01],
         [1.7477e-01, 8.2057e-01],
         [1.8916e-02, 4.4383e-01],
         [6.2715e-01, 3.0160e-01]],

        [[3.7760e-01, 9.7373e-01],
         [3.9331e-02, 9.7179e-01],
         [7.6383e-01, 8.0790e-01],
         [7.0879e-02, 6.1775e-01],
         [6.1360e-01, 6.8784e-04],
         [9.5227e-01, 1.4038e-01],
         [6.0251e-02, 8.2354e-01],
         [5.2606e-01, 7.3723e-01]],

        [[6.7681e-02, 9.0671e-02],
         [5.9465e-01, 3.9542e-01],
         [4.4005e-01, 6.1657e-01],
         [5.1777e-01, 2.7781e-01],
         [6.3437e-01, 4.8064e-01],
         [8.9553e-01, 2.4832e-01],
         [8.4402e-01, 5.6527e-01],
         [9.8648e-01, 8.4320e-01]],

        [[6.7831e-01, 2.5030e-01],
         [8.8675e-01, 7.9661e-01],
         [4.5590e-01, 6.3645e-02],
         [2.6560e-01, 9.9764e-02],
         [1.84

In [28]:
a = torch.randint(0, 10, (3,2))
print(a)
print(mask @ a)
# We get a cumulative sum. note how pytorch is height by width

tensor([[1, 5],
        [0, 6],
        [1, 2]])
tensor([[ 1,  5],
        [ 1, 11],
        [ 2, 13]])


In [40]:
mask = torch.tril(torch.ones((8,8), dtype=torch.long))  
mask = mask / mask.sum(dim=1, keepdim=True)
mask


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [43]:
mask @ sample

tensor([[[0.3996, 0.8494],
         [0.5883, 0.7713],
         [0.6577, 0.7505],
         [0.7404, 0.5755],
         [0.6898, 0.5675],
         [0.6040, 0.6096],
         [0.5204, 0.5860],
         [0.5337, 0.5504]],

        [[0.3776, 0.9737],
         [0.2085, 0.9728],
         [0.3936, 0.9178],
         [0.3129, 0.8428],
         [0.3730, 0.6744],
         [0.4696, 0.5854],
         [0.4111, 0.6194],
         [0.4255, 0.6341]],

        [[0.0677, 0.0907],
         [0.3312, 0.2430],
         [0.3675, 0.3676],
         [0.4050, 0.3451],
         [0.4509, 0.3722],
         [0.5250, 0.3516],
         [0.5706, 0.3821],
         [0.6226, 0.4397]],

        [[0.6783, 0.2503],
         [0.7825, 0.5235],
         [0.6737, 0.3702],
         [0.5716, 0.3026],
         [0.4941, 0.2765],
         [0.4382, 0.3798],
         [0.3789, 0.4684],
         [0.3501, 0.4812]]])

In [50]:
val = torch.tril(torch.ones((8,8), dtype=torch.float))
val[val==0] = -torch.inf
val

tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., -inf],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [52]:
torch.softmax(val, dim=-1)


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [61]:
import math
torch.manual_seed(1337)
n_embed = 65 # same as vocab
xb = torch.randn((4, 8, n_embed))

B, T, C = xb.shape
head_size = 32

# Self attention values (B, T, H) (smaller headsize, as attention calc has O(n^2))
query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(xb)
k = key(xb)
v = value(xb)

# implement attention formula softmax(Q*K/sqrt(H)) * V
affinity = q @ k.transpose(-2, -1) # (B,T,H) * (B,H,T) -> (B,T,T) we see this is context squared in dimensions, not embeddings squared...
affinity = torch.masked_fill(affinity, torch.tril(torch.ones(T,T)) == 0, -torch.inf)/ math.sqrt(head_size)
wei= torch.softmax(affinity, dim=-1)  # B T T
out = wei @ v
out


tensor([[[ 0.4769,  1.1327,  0.7731,  ..., -0.2679,  0.6497,  0.0312],
         [ 0.0106, -0.2966, -0.1085,  ...,  0.4612, -0.0433,  0.5922],
         [-0.1206,  0.1316,  0.0050,  ...,  0.0691, -0.1379,  0.4689],
         ...,
         [-0.1779,  0.3911, -0.0929,  ..., -0.1690, -0.1640,  0.1042],
         [-0.3149,  0.0428, -0.4238,  ...,  0.1710, -0.1707,  0.0250],
         [-0.2148,  0.3741, -0.0986,  ...,  0.0213, -0.0842,  0.2443]],

        [[-1.8004,  0.0076, -0.0238,  ..., -0.7904, -0.1716,  0.1303],
         [-0.9701,  0.0128, -0.1903,  ..., -0.3677, -0.4117,  0.0375],
         [-0.1394,  0.4624, -0.1384,  ...,  0.0072, -0.0533,  0.0413],
         ...,
         [ 0.1976, -0.1131, -0.1375,  ...,  0.1324, -0.5433, -0.0589],
         [ 0.0924, -0.1698, -0.1323,  ..., -0.0724, -0.2501, -0.0958],
         [ 0.1734, -0.1607, -0.1650,  ...,  0.0443, -0.4195, -0.1125]],

        [[ 0.6144,  0.7227, -0.4373,  ..., -0.4756, -0.1982,  0.6854],
         [ 0.2863,  0.2347,  0.0836,  ..., -0

In [62]:
wei

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4166, 0.5834, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3572, 0.4355, 0.2072, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3124, 0.2359, 0.1973, 0.2544, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1460, 0.2074, 0.1991, 0.2169, 0.2305, 0.0000, 0.0000, 0.0000],
         [0.2731, 0.1570, 0.2305, 0.1111, 0.1061, 0.1222, 0.0000, 0.0000],
         [0.1197, 0.0957, 0.0872, 0.0999, 0.2224, 0.3185, 0.0567, 0.0000],
         [0.2045, 0.0996, 0.1021, 0.1003, 0.1049, 0.1006, 0.1552, 0.1328]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6872, 0.3128, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2581, 0.2713, 0.4705, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2696, 0.2137, 0.1674, 0.3493, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2369, 0.1408, 0.1960, 0.2376, 0.1887, 0.0000, 0.0000, 0.0000],
         [0.1553, 0.339