In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
torch.device('mps') # use GPU if available
%matplotlib inline

In [3]:
with open('shakespeare.txt', 'r') as f:
    text = f.read()

print(text[:100])


First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print('vocab size:', vocab_size)
print(''.join(chars))

vocab size: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [7]:
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda x: [stoi[ch] for ch in x]
decode = lambda x: ''.join([itos[ch] for ch in x])

print(encode("dupa"))
print(decode(encode("dupa")))

[42, 59, 54, 39]
dupa


In [8]:
data = torch.tensor(encode(text), dtype=torch.long)

In [9]:
n = int(len(data) * 0.9)
train_data = data[:n]
valid_data = data[n:]

In [10]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [14]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'When the context is {decode(context.tolist())}, the target is {decode([target.tolist()])}')


When the context is F, the target is i
When the context is Fi, the target is r
When the context is Fir, the target is s
When the context is Firs, the target is t
When the context is First, the target is  
When the context is First , the target is C
When the context is First C, the target is i
When the context is First Ci, the target is t


In [16]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == 'train' else valid_data
    ix = torch.randint(len(data) - block_size , (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print(xb.shape, yb.shape)

for b in range(batch_size):
    for t in range(block_size):
        context = xb[b,:t+1]
        target = yb[b,t]
        print(f'When the context is {decode(context.tolist())}, the target is {decode([target.tolist()])}')

torch.Size([4, 8]) torch.Size([4, 8])
When the context is L, the target is e
When the context is Le, the target is t
When the context is Let, the target is '
When the context is Let', the target is s
When the context is Let's, the target is  
When the context is Let's , the target is h
When the context is Let's h, the target is e
When the context is Let's he, the target is a
When the context is f, the target is o
When the context is fo, the target is r
When the context is for, the target is  
When the context is for , the target is t
When the context is for t, the target is h
When the context is for th, the target is a
When the context is for tha, the target is t
When the context is for that, the target is  
When the context is n, the target is t
When the context is nt, the target is  
When the context is nt , the target is t
When the context is nt t, the target is h
When the context is nt th, the target is a
When the context is nt tha, the target is t
When the context is nt that, the 

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx) # (batch_size, block_size, vocab_size) (B, T, C) (batch time channel)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # get the predictions
            logits , loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax
            probs = F.softmax(logits, dim=1) # B, C
            # sample from the distribution
            idx_next = torch.multinomial(probs, 1) # B, 1
            # add the new token to the sequence
            idx = torch.cat([idx, idx_next], dim=1) # B, T+1
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, 100)[0].tolist()))

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [34]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [37]:
batch_size = 32
for steps in range(10000):
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % 10 == 0:
        print(f'Loss at step {steps} is {loss.item()}')

Loss at step 0 is 3.6924993991851807
Loss at step 10 is 3.703143358230591
Loss at step 20 is 3.5420799255371094
Loss at step 30 is 3.6378514766693115
Loss at step 40 is 3.654754638671875
Loss at step 50 is 3.638540267944336
Loss at step 60 is 3.634798288345337
Loss at step 70 is 3.546691417694092
Loss at step 80 is 3.6222987174987793
Loss at step 90 is 3.597193956375122
Loss at step 100 is 3.6769022941589355
Loss at step 110 is 3.6126725673675537
Loss at step 120 is 3.5891876220703125
Loss at step 130 is 3.53181791305542
Loss at step 140 is 3.469095468521118
Loss at step 150 is 3.425056219100952
Loss at step 160 is 3.67372989654541
Loss at step 170 is 3.5732293128967285
Loss at step 180 is 3.454685926437378
Loss at step 190 is 3.475454092025757
Loss at step 200 is 3.4926655292510986
Loss at step 210 is 3.506186008453369
Loss at step 220 is 3.451472759246826
Loss at step 230 is 3.4598655700683594
Loss at step 240 is 3.476881742477417
Loss at step 250 is 3.480024576187134
Loss at step 26

In [39]:
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, 300)[0].tolist()))


Werorreltcod w f heatug t my't care!
Adle, cld f sthen:
AThemy TI my ne fu sthe wet:

HERDaienothichevee t sinom me d:
BOMNod he ak, he we m, eengmor t
Wesivesp, d
WWhen
Ththithart overt musthe hes hogie thake:
Sn bend swe buriman ithea
Clovert it ck;
Nos y ts bag,

EWhed st sus d we aves! br:
O:
Th


In [40]:
# mathematical trick in self-attention

torch.manual_seed(1337)
B,T,C = 4, 8, 2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [41]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]
        xbow[b,t] = torch.mean(xprev, dim=0)

In [47]:
wei = torch.tril(torch.ones((T,T)))
wei = wei / torch.sum(wei, dim=1, keepdim=True)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) -> (B, T, C)

torch.allclose(xbow, xbow2)

True

In [42]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [43]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [46]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(a)
print(b)
print(c)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [48]:
# version 3: use softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)
xbow3 = wei @ x
torch.allclose(xbow3, xbow)

True

In [62]:
# version 4: self-attention
torch.manual_seed(1337)
B,T,C = 4, 8, 32 # batch, time, channels
x = torch.randn(B,T,C)

# single head of self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, C) -> (B, T, head_size)
q = query(x) # (B, T, C) -> (B, T, head_size)
wei = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) -> (B, T, T) 

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
v = value(x) # (B, T, C) -> (B, T, head_size)

out = wei @ v

out.shape

torch.Size([4, 8, 32])

In [63]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [64]:
wei

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
         [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
         [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
         [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1687, 0.8313, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2477, 0.0514, 0.7008, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4410, 0.0957, 0.3747, 0.0887, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0069, 0.0456, 0.0300, 0.7748, 0.1427, 0.0000, 0.0000, 0.0000],
         [0.0660, 0.089