In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import math
from torch.nn import ModuleList

In [2]:
with open("text.txt","r") as file:
    text = file.read()

In [3]:
vocab_size = sorted(set(text))

In [4]:
itos = {index: char for index, char in enumerate(vocab_size)}

In [5]:
vocab_size = sorted(set(text))

In [6]:
stoi = {index: char for char, index in enumerate(vocab_size)}

In [7]:
encode = lambda x: [stoi[s] for s in x]

In [8]:
decode = lambda x: [itos[s] for s in x]

In [9]:
encoded = encode("hii there")
encoded

[46, 47, 47, 1, 58, 46, 43, 56, 43]

In [10]:
decoded = decode(encoded)
decoded

['h', 'i', 'i', ' ', 't', 'h', 'e', 'r', 'e']

In [11]:
data = torch.tensor(encode(text), dtype=torch.long)

In [12]:
train_data = data[:int(0.9 * len(data))]

In [13]:
val_data = data[int(0.9 * len(data)):]

In [14]:
assert len(data) == len(train_data) + len(val_data)

In [15]:
block_size = 8

In [16]:
x = train_data[:block_size]
x

tensor([18, 47, 56, 57, 58,  1, 15, 47])

In [17]:
y = train_data[block_size]
y

tensor(58)

In [18]:
block_size = 8
batch_size = 4

In [19]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7fef7df63c90>

In [20]:
#We have x as an input and we have the target number as y

In [21]:
for i in range(1, len(x) - 1):
    inp = x[:i]
    out = x[i+1]
    print(f"inp is: {inp}; out is: {out}")

inp is: tensor([18]); out is: 56
inp is: tensor([18, 47]); out is: 57
inp is: tensor([18, 47, 56]); out is: 58
inp is: tensor([18, 47, 56, 57]); out is: 1
inp is: tensor([18, 47, 56, 57, 58]); out is: 15
inp is: tensor([18, 47, 56, 57, 58,  1]); out is: 47


In [22]:
def get_batch(split):
    data = train_data if split == "train" else val_data
    index = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([train_data[i:i+block_size] for i in index])
    y = torch.stack([train_data[i+1:i+block_size+1] for i in index])
    return x,y 

In [23]:
index = torch.randint(len(train_data) - block_size, (batch_size, ))
index

tensor([ 74928, 231851, 934226, 560077])

In [24]:
x = torch.stack([train_data[i:i+block_size] for i in index])

In [25]:
y = torch.stack([train_data[i+1:i+block_size+1] for i in index])

In [26]:
x

tensor([[56,  6,  0, 24, 43, 58,  1, 61],
        [39, 47, 51,  1, 58, 46, 39, 58],
        [52, 45,  1, 58, 53,  1, 57, 39],
        [43, 47, 52, 45,  1, 46, 53, 50]])

In [27]:
y

tensor([[ 6,  0, 24, 43, 58,  1, 61, 46],
        [47, 51,  1, 58, 46, 39, 58,  1],
        [45,  1, 58, 53,  1, 57, 39, 63],
        [47, 52, 45,  1, 46, 53, 50, 47]])

In [28]:
xb, yb = get_batch("train")

In [29]:
class Model(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, index, targets=None):
        output = self.embedding_layer(index)
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(output.view(batch_size * block_size, -1), yb.view(-1))
        
        return output, loss
    
    def generate(self, index, max_new_token):
        for i in range(max_new_token):
            output, loss = self(index)
            output = output[:, -1, :].softmax(-1)
            index_next = torch.multinomial(output, num_samples=1)
            index = torch.cat((index, index_next), dim=1)
        return index

In [30]:
model = Model(len(vocab_size))

In [31]:
logits, loss = model(xb, yb)

In [32]:
logits.shape

torch.Size([4, 8, 65])

In [33]:
-math.log(1/len(vocab_size)) #So the model is very convoluted at the moment
#The lowest error should be 4.1, but it's higher

4.174387269895637

In [34]:
output = decode(model.generate(index = torch.zeros((1, 1), dtype=torch.long), max_new_token=100)[0].tolist())
output = "".join(output)
output

"\n&rEnLTjLDJIcLVR'JIHDTHdhsV\nv\nwxh,nhUYZzAEOZHpgo3q3ZYZes$zuGw,;eMk QqACRfCLgxiW3.O!zDLgA YsTb!dHb!;pK"

In [35]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [36]:
epochs = 10000

In [37]:
def train(epochs, optimizer):
    for i in range(epochs):
        x, y = get_batch("train")
        output, loss = model(x, y)
        
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if i % 1000 == 0:
            print(f"{i}, loss: {loss}")

In [38]:
train(epochs, optimizer)

0, loss: 4.3336710929870605
1000, loss: 4.190601825714111
2000, loss: 3.6422181129455566
3000, loss: 3.4271457195281982
4000, loss: 3.514732837677002
5000, loss: 3.316194534301758
6000, loss: 3.0088391304016113
7000, loss: 3.1258537769317627
8000, loss: 3.1442017555236816
9000, loss: 3.0557634830474854


In [39]:
output = decode(model.generate(index = torch.zeros((1, 1), dtype=torch.long), max_new_token=100)[0].tolist())
output = "".join(output)
output

'\n\n\nrmr hotk\neIL;,\n\ntfTBBZYLst: rtGhBx&Lmkko\nugTII\nee;f \n  Let ii:Lgos n n \n\ne  r;QzQPrmo Lht somLKn I'

In [40]:
a = torch.tril(torch.ones(3,3))
a

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [41]:
a = a / torch.sum(a, 1, keepdims=True)

In [42]:
b = torch.randint(0,10,(3,3)).float()
b

tensor([[3., 4., 1.],
        [9., 3., 4.],
        [7., 1., 6.]])

In [43]:
result = a @ b
result

tensor([[3.0000, 4.0000, 1.0000],
        [6.0000, 3.5000, 2.5000],
        [6.3333, 2.6667, 3.6667]])

In [44]:
T = 8

In [45]:
tril = torch.tril(torch.ones(4,8))

In [46]:
wei = torch.zeros((4,8))

In [47]:
wei = wei.masked_fill(tril == 0, float("-inf"))
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf]])

In [48]:
wei = F.softmax(wei, dim=1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000]])

In [49]:
B,T,C = 4,8,32 # batch, time, channels

In [50]:
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size)
value = nn.Linear(C, head_size)

In [51]:
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 32])

In [52]:
k = key(x)
q = query(x)
v = value(x)

In [53]:
q.shape, k.shape

(torch.Size([4, 8, 16]), torch.Size([4, 8, 16]))

In [54]:
query_key = q @ k.transpose(-1,-2) * head_size**-0.5

In [55]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5
k.var(), q.var(), v.var()

(tensor(1.0508), tensor(1.0666), tensor(0.3416, grad_fn=<VarBackward0>))

In [56]:
query_key.shape

torch.Size([4, 8, 8])

In [57]:
query_key.std(), query_key.mean()

(tensor(0.3160, grad_fn=<StdBackward0>),
 tensor(0.0101, grad_fn=<MeanBackward0>))

In [58]:
d_k = q.size(-1)
d_k

16

In [59]:
math.sqrt(d_k)

4.0

In [60]:
query_key_scaled = query_key / torch.sqrt(torch.tensor(d_k))

In [61]:
query_key_scaled.mean(), query_key_scaled.std()

(tensor(0.0025, grad_fn=<MeanBackward0>),
 tensor(0.0790, grad_fn=<StdBackward0>))

In [62]:
tril = torch.tril(torch.ones(T,T))

In [63]:
wei = tril.masked_fill(tril == 0, float("-inf"))

In [64]:
wei = wei.softmax(-1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [65]:
wei.shape

torch.Size([8, 8])

In [66]:
v.shape

torch.Size([4, 8, 16])

In [67]:
v.permute(0, 2, 1).shape

torch.Size([4, 16, 8])

In [68]:
result = wei @ v
result.shape
#This is the output of the attention mechanism for one head

torch.Size([4, 8, 16])

In [69]:
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 10000
eval_interval = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.4

In [70]:
class Head(nn.Module):
    def __init__(self, head_size):
        super(Head, self).__init__()
        self.key = nn.Linear(n_embd, head_size)
        self.value = nn.Linear(n_embd, head_size)
        self.query = nn.Linear(n_embd, head_size)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        
        
        wei = q @ k.transpose(-1,-2) * head_size**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = wei.softmax(-1)
        result = wei @ v
        return result

In [71]:
head = Head(head_size)

In [72]:
xb, yb = get_batch("train")
xb.shape

torch.Size([16, 32])

In [73]:
a = torch.randn(8,32,64)
b = torch.randn(32,64)

In [74]:
a_b = (a + b)
a_b.shape

torch.Size([8, 32, 64])

In [75]:
emb = nn.Embedding(32,64)

In [76]:
a = torch.randint(10, (32,64))

In [77]:
a_emb = emb(a)
a_emb.shape

torch.Size([32, 64, 64])

In [78]:
xb.shape, y.shape

(torch.Size([16, 32]), torch.Size([4, 8]))

In [79]:
class MultiHeadAttention(nn.Module):
    def __init__(self, head_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.heads = ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [80]:
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.ReLU()
        )
    
    def forward(self, x):
        return self.net(x)

In [81]:
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super(Block, self).__init__()
        head_size = n_embd // n_head
        self.mha = MultiHeadAttention(head_size, n_head)
        self.feed_forward = FeedForward(n_embd)
        self.layer_norm1 = nn.LayerNorm(n_embd)
        self.layer_norm2 = nn.LayerNorm(n_embd)
        
    def forward(self, x):
        x = x + self.mha(self.layer_norm1(x))
        x = x + self.feed_forward(self.layer_norm2(x))
        return x

In [82]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.token_embeddings = nn.Embedding(len(vocab_size), n_embd)
        self.position_embeddings = nn.Embedding(block_size, n_embd)
        
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.layer_norm = nn.LayerNorm(len(vocab_size)-1)
        self.feed_forward = nn.Linear(n_embd, len(vocab_size))
    
    def forward(self, idx, target=None):
        B, T = idx.shape
        tok_embd = self.token_embeddings(idx)
        position_embeddings = self.position_embeddings(torch.arange(T))
        
        x = tok_embd + position_embeddings
        x = self.blocks(x)
        x = self.layer_norm(x)
        logits = self.feed_forward(x)
        
        if target is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(batch_size * block_size, -1), target.view(-1))
        return logits, loss
    
    def generate(self, index, max_new_token):
        for i in range(max_new_token):
            idx_cond = index[:, -block_size:]
            output, loss = self(idx_cond)
            output = output[:, -1, :].softmax(-1)
            index_next = torch.multinomial(output, num_samples=1)
            index = torch.cat((index, index_next), dim=1)
        return index

In [83]:
model = Model()

In [84]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [85]:
for i in range(max_iters):
    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if i % 500 == 0:
        print(f"iteration: {i}| loss: {loss}")

iteration: 0| loss: 4.350693225860596
iteration: 500| loss: 2.384089946746826
iteration: 1000| loss: 2.271000623703003
iteration: 1500| loss: 2.08839750289917
iteration: 2000| loss: 1.972941279411316
iteration: 2500| loss: 1.947158694267273
iteration: 3000| loss: 1.9004379510879517
iteration: 3500| loss: 1.864082932472229
iteration: 4000| loss: 1.8159171342849731
iteration: 4500| loss: 1.9053051471710205
iteration: 5000| loss: 1.9047762155532837
iteration: 5500| loss: 1.7336891889572144
iteration: 6000| loss: 1.9031013250350952
iteration: 6500| loss: 1.7895954847335815
iteration: 7000| loss: 1.7797973155975342
iteration: 7500| loss: 1.6494497060775757
iteration: 8000| loss: 1.711073398590088
iteration: 8500| loss: 1.7032816410064697
iteration: 9000| loss: 1.660027027130127
iteration: 9500| loss: 1.6481990814208984


In [86]:
output = decode(model.generate(index = torch.zeros((1, 1), dtype=torch.long), max_new_token=2000)[0].tolist())
output = "".join(output)

In [87]:
content = "".join(l for l in output.splitlines() if l)

In [88]:
content

"You. Tirdan, comy be asheek leavelsTo thou hard in for asons'sbeindervy, shall our usper: 't my moralt,Beggelanst in emy appoply tenter that me within me,Which is a side majest: have! Relorsle, I up,No slep hanraten hele; have day.MENENE:You by with man, to foe? mastillont, hasts mon my Badied on my heare brother.To belight all stastaling my blet'sMy lord not day come tourve to dobe pat inpatianter; thou the, livent, but he's forge!No, in death sake him subjel'r: the by slaungeth, shall bethston her.Thy Seafferwerver:'Thon sloy shall speak hearst!VIR HORDWARD II:As say wharge, go turn, day, nothin reving of he battle imabe! bod less bouthde, so nemer wear map would look eyor king it contory.CORTIOLALUS:Fame is netsenteloust are ladtly,Sto blary letterable moss will all'tt, and thou an lookip sents, by lifest an of frommen!Gol, have to as geol, when no me.O me One my stilf ours.Tgrieves porens, gold mean a paleBetwell thy pet I sury was in gaes.DY:Second! A merve oge?ANNO:Make sail tru