In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-05-22 15:24:26--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2024-05-22 15:24:28 (963 KB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [4]:
# read data
with open('input.txt','r',encoding='UTF-8') as f:
    text = f.read()

In [5]:
# build vocab, encoder and decoder

chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

stoi = {c:i for i,c in enumerate(chars)}
itos = {i:c for i,c in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda v: ''.join(itos[vv] for vv in v)

print(encode('hii there'))
print(decode(encode('hii there')))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65
[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [6]:
# tokenize the dataset
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[0:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [7]:
# separate data into train, validation and test datasets

n = int(len(data) * 0.9)
train_data = data[0:n]
validation_data = data[n:]

In [8]:
block_size = 8

x = train_data[:block_size]
y = train_data[1:block_size + 1]

for i in range(block_size):
    context = x[:i+1]
    target = y[i]

    print(f"When the context is {context}, the target is {target}")

When the context is tensor([18]), the target is 47
When the context is tensor([18, 47]), the target is 56
When the context is tensor([18, 47, 56]), the target is 57
When the context is tensor([18, 47, 56, 57]), the target is 58
When the context is tensor([18, 47, 56, 57, 58]), the target is 1
When the context is tensor([18, 47, 56, 57, 58,  1]), the target is 15
When the context is tensor([18, 47, 56, 57, 58,  1, 15]), the target is 47
When the context is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is 58


In [9]:
torch.manual_seed(1337)

batch_size = 4
block_size = 8

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else validation_data
    ix = torch.randint(0,len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)
print('---------------')

for b in range(batch_size):
    for t in range(block_size):
        context = xb[b,:t+1]
        target = yb[b,t]
        print(f"When the context is {context}, the target is {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
---------------
When the context is tensor([24]), the target is 43
When the context is tensor([24, 43]), the target is 58
When the context is tensor([24, 43, 58]), the target is 5
When the context is tensor([24, 43, 58,  5]), the target is 57
When the context is tensor([24, 43, 58,  5, 57]), the target is 1
When the context is tensor([24, 43, 58,  5, 57,  1]), the target is 46
When the context is tensor([24, 43, 58,  5, 57,  1, 46]), the target is 43
When the context is tensor([24, 43, 58,  5, 57,  1, 46, 43]), the target is 39
When the context is tensor([44]), the target is 53
When the context is te

In [16]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size) -> None:
        super().__init__()

        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targerts are both of shape (batch_size, block_size) tensor of integers
        logits = self.token_embedding_table(idx) # shape: (batch_size, block_size, vocab_size) or (B, T, C)

        if targets is None:
            loss = None
        else:
            # reshaping logits so that we can compute the loss
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            # predicting the loss through cross entropy
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    # Sample from the distribution through the model 
    def generate(self, idx, max_new_tokens=100):
        with torch.no_grad():
            for t in range(max_new_tokens):
                # get the predictions
                logits, loss = self(idx)
                # Use only the last step
                logits = logits[:, -1, :] # for all batches, take the last token
                # get the probability distribution
                probs = F.softmax(logits, dim=1)
                # sample from the distribution
                sample = torch.multinomial(probs, num_samples=1)
                # append to the sequence
                idx = torch.cat((idx, sample.view(-1, 1)), dim=1)
        return idx

    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

idx = torch.zeros((1,1), dtype=torch.long)
tks = m.generate(idx, max_new_tokens=100)[0].tolist()
print(decode(tks))

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [29]:
# Pytorch optimizer
optimizer = torch.optim.Adam(m.parameters(), lr=0.01)

batch_size = 32
for steps in range(1000):
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    print(loss.item())

2.3436813354492188
2.390263795852661
2.534330368041992
2.4038283824920654
2.603792905807495
2.4936015605926514
2.4469120502471924
2.4720141887664795
2.50858473777771
2.563898801803589
2.516644239425659
2.534618377685547
2.450871706008911
2.5475900173187256
2.5122485160827637
2.3961195945739746
2.5305325984954834
2.547184705734253
2.5595877170562744
2.5234274864196777
2.5520293712615967
2.41968035697937
2.5986216068267822
2.527085542678833
2.6083054542541504
2.469792127609253
2.4987616539001465
2.483426332473755
2.6692516803741455
2.5700106620788574
2.568403482437134
2.635690689086914
2.3716225624084473
2.553765296936035
2.524486780166626
2.4647772312164307
2.5398011207580566
2.530543327331543
2.4723949432373047
2.443235158920288
2.4367730617523193
2.369100332260132
2.494701862335205
2.363583564758301
2.43719744682312
2.6231653690338135
2.4482808113098145
2.380394220352173
2.5506653785705566
2.4480056762695312
2.4132583141326904
2.436281681060791
2.447112560272217
2.566087007522583
2.46

In [31]:
print(decode(m.generate(idx=torch.zeros((1,1), dtype=torch.long), max_new_tokens=500)[0].tolist()))



Buceckierelsaru t IAREEroutersay.
AD:
Se m ane! ctcobsad, s IAse h mexadame, makisoung, hall-hithin p wate st--
TUCLIfun met th hire onsthiapu cour chorlothay Mabe t VOLUKERUS:

ABEY:
Ans s har ug y'd it trr,
I lay:
EOMy athaveanghur amexeousss:
F in

Buru.
Selum's
IA:
AUESTrisin. thed thmyo as ketoret?
Yovethethellend:
GAnd fo,
CKI:
DULird m wotintoupenow an,
Un. houf,
HELOLALod bery s. angupre; l We t lute, ts m re, drer f s thatheewisto burinoura s,
Tono u t acou s, ingh cak r deftilselend wn


### The mathematical trick in self-attention

In [9]:
import torch
torch.manual_seed(1337)
B,T,C = 4,8,2
x =  torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [17]:
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range (T):
        xprev = x[b,:t+1]
        xbow[b,t] = torch.mean(xprev, 0)

x[0], xbow[0]

(tensor([[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [20]:
wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)
print(wei)
xbow2 = wei @ x # wei initially is (T,T) and x is (B,T,C), Pytorch will add a dimension to wei to make it (1,T,T) and broadcast it to (B,T,T)
torch.allclose(xbow, xbow2)
x[0], xbow2[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


(tensor([[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [31]:
import torch
from torch.nn import functional as F

tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)
xbow[0], xbow3[0]

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B,T,C = 4,8,32
x =  torch.randn(B,T,C)

# single head to perform self-attention
head_size = 16 # H
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # B,T,H
q = query(x) # B,T,H
v = value(x) # B,T,H
wei = q @ k.transpose(-2,-1)

tril = torch.tril(torch.ones(T,T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)

#out = wei @ x
oy = wei @ v
out.shape

torch.Size([4, 8, 32])

Notes:


* Attention is a communication mechanism. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
* There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
Each example across batch dimension is of course processed completely independently and never "talk" to each other
* In an "encoder" attention block just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
* "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
* "Scaled" attention additional divides wei by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below


In [24]:
a = torch.randn(3,3)
print(a)
print(a[:,-2:])

tensor([[-1.0036,  0.0850,  1.0277],
        [-0.3999,  0.8997,  0.2513],
        [ 0.0527,  2.4278, -1.3424]])
tensor([[ 0.0850,  1.0277],
        [ 0.8997,  0.2513],
        [ 2.4278, -1.3424]])
