In [215]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import math

In [17]:
with open("text.txt","r") as file:
    text = file.read()

In [35]:
itos = {index: char for index, char in enumerate(vocab_size)}

In [33]:
vocab_size = sorted(set(text))

In [37]:
stoi = {index: char for char, index in enumerate(vocab_size)}

In [50]:
encode = lambda x: [stoi[s] for s in x]

In [380]:
decode = lambda x: [itos[s] for s in x]

In [381]:
encoded = encode("hii there")
encoded

[46, 47, 47, 1, 58, 46, 43, 56, 43]

In [382]:
decoded = decode(encoded)
decoded

['h', 'i', 'i', ' ', 't', 'h', 'e', 'r', 'e']

In [59]:
data = torch.tensor(encode(text), dtype=torch.long)

In [63]:
train_data = data[:int(0.9 * len(data))]

In [64]:
val_data = data[int(0.9 * len(data)):]

In [65]:
assert len(data) == len(train_data) + len(val_data)

In [66]:
block_size = 8

In [68]:
x = train_data[:block_size]
x

tensor([18, 47, 56, 57, 58,  1, 15, 47])

In [70]:
y = train_data[block_size]
y

tensor(58)

In [93]:
block_size = 8
batch_size = 4

In [94]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7fd3cbd4fcf0>

In [95]:
#We have x as an input and we have the target number as y

In [96]:
for i in range(1, len(x) - 1):
    inp = x[:i]
    out = x[i+1]
    print(f"inp is: {inp}; out is: {out}")

inp is: tensor([18]); out is: 56
inp is: tensor([18, 47]); out is: 57
inp is: tensor([18, 47, 56]); out is: 58
inp is: tensor([18, 47, 56, 57]); out is: 1
inp is: tensor([18, 47, 56, 57, 58]); out is: 15
inp is: tensor([18, 47, 56, 57, 58,  1]); out is: 47


In [154]:
def get_batch(split):
    data = train_data if split == "train" else val_data
    index = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([train_data[i:i+block_size] for i in index])
    y = torch.stack([train_data[i+1:i+block_size+1] for i in index])
    return x,y 

In [155]:
index = torch.randint(len(train_data) - block_size, (batch_size, ))
index

tensor([894151, 111683, 537766, 851008])

In [156]:
x = torch.stack([train_data[i:i+block_size] for i in index])

In [157]:
y = torch.stack([train_data[i+1:i+block_size+1] for i in index])

In [158]:
x

tensor([[ 1, 39, 52, 42,  1, 61, 56, 53],
        [46, 58,  1, 58, 53,  1, 46, 39],
        [39, 58,  1, 57, 43, 43, 57,  1],
        [39, 52, 42,  1, 58, 46, 43,  0]])

In [159]:
y

tensor([[39, 52, 42,  1, 61, 56, 53, 52],
        [58,  1, 58, 53,  1, 46, 39, 60],
        [58,  1, 57, 43, 43, 57,  1, 47],
        [52, 42,  1, 58, 46, 43,  0, 60]])

In [160]:
xb, yb = get_batch("train")

In [365]:
class Model(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, index, targets=None):
        output = self.embedding_layer(index)
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(output.view(batch_size * block_size, -1), yb.view(-1))
        
        return output, loss
    
    def generate(self, index, max_new_token):
        for i in range(max_new_token):
            output, loss = self(index)
            output = output[:, -1, :].softmax(-1)
            index_next = torch.multinomial(output, num_samples=1)
            index = torch.cat((index, index_next), dim=1)
        return index

In [366]:
model = Model(len(vocab_size))

In [367]:
logits, loss = model(xb, yb)

In [368]:
logits.shape

torch.Size([4, 8, 65])

In [373]:
-math.log(1/len(vocab_size)) #So the model is very convoluted at the moment
#The lowest error should be 4.1, but it's higher

4.174387269895637

In [386]:
output = decode(model.generate(index = torch.zeros((1, 1), dtype=torch.long), max_new_token=100)[0].tolist())
output = "".join(output)
output

"\n!Nreh;fSYtuvtuqaAkSH\ngrTVUg'atuUzCN, efnlUHcRxoYhNZU;utDEEFFXUCzuXKGaLCFhp!T?g;aULdLR.syWAn,LGGeeYEq"