In [187]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

file_path = Path('data') / 'tinyshakespeare' / 'input.txt'
# Decode explicitly as UTF-8: a bare open() uses the platform's locale encoding
# (e.g. cp1252 on Windows), which would make the character vocabulary machine-dependent.
with open(file_path, 'r', encoding='utf-8') as f:
    text = f.read()

print(f"{len(text) = }")

len(text) = 1115394


In [84]:
# Character vocabulary: every distinct character in the corpus, in sorted order,
# so that token ids are stable across runs.
chars = sorted({*text})
print(f"{len(chars) = }")
print(f"{''.join(chars)!r}")

len(chars) = 65
"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"


In [85]:
# Lookup tables between characters and integer token ids (position in `chars`).
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))


def encode(s):
    """Convert a string into a list of integer token ids via `stoi`."""
    return [stoi[c] for c in s]


def decode(l):
    """Convert an iterable of token ids into a list of single characters via `itos`.

    Note: returns a list, not a string -- join it with ``''.join(...)``.
    """
    return [itos[i] for i in l]


''.join(decode(encode('hello world')))

'hello world'

In [86]:
# Whole corpus as one 1-D tensor of token ids (int64, as required by nn.Embedding / cross_entropy).
data = torch.tensor(encode(text), dtype=torch.long)
data.shape

torch.Size([1115394])

In [87]:
# Contiguous 90/10 split: the first 90% of characters for training, the tail for validation.
idx = int(.9 * len(data))
train_data, val_data = data[:idx], data[idx:]

train_data.shape, val_data.shape

(torch.Size([1003854]), torch.Size([111540]))

In [88]:
block_size = 8
x = train_data[:block_size]         # one window of inputs
y = train_data[1:block_size + 1]    # the same window shifted by one: next-character targets

# A single window of length block_size packs block_size training examples,
# one per prefix length.
print('context -> target')
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(context, '->', target)

context -> target
tensor([18]) -> tensor(47)
tensor([18, 47]) -> tensor(56)
tensor([18, 47, 56]) -> tensor(57)
tensor([18, 47, 56, 57]) -> tensor(58)
tensor([18, 47, 56, 57, 58]) -> tensor(1)
tensor([18, 47, 56, 57, 58,  1]) -> tensor(15)
tensor([18, 47, 56, 57, 58,  1, 15]) -> tensor(47)
tensor([18, 47, 56, 57, 58,  1, 15, 47]) -> tensor(58)


In [97]:
class SlidingCharacterDataset(Dataset):
    """All contiguous windows of a token sequence, paired with their one-step-shifted targets.

    Item ``i`` is ``(data[i : i + block_size], data[i + 1 : i + block_size + 1])``:
    the target at each position is the token that follows the input at that position.

    Args:
        data: 1-D tensor of token ids.
        block_size: length of each input/target window.
    """

    def __init__(self, data: torch.Tensor, block_size=8):
        assert data.dim() == 1, "data must be a 1-D tensor of token ids"
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Use this instance's block_size (not a notebook-level global). A window starting at i
        # needs block_size + 1 tokens (inputs plus the shifted target), so the valid starts are
        # 0 .. len(data) - block_size - 1; clamp at 0 for data shorter than one window.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        end = idx + self.block_size
        return (
            self.data[idx:end],
            self.data[idx + 1:end + 1],
        )

In [111]:
train_dataset = SlidingCharacterDataset(data=train_data, block_size=8)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,   # keep every batch exactly batch_size long
)

# Peek at one random mini-batch: x holds inputs, y the next-character targets.
x, y = next(iter(train_dataloader))
x, y

(tensor([[49, 43, 50, 63,  1,  5, 58, 47],
         [58,  0, 33, 52, 51, 43, 56, 47],
         [51,  6,  1, 58, 53,  1, 51, 39],
         [ 1, 58, 46, 47, 57,  1, 39, 58]]),
 tensor([[43, 50, 63,  1,  5, 58, 47, 57],
         [ 0, 33, 52, 51, 43, 56, 47, 58],
         [ 6,  1, 58, 53,  1, 51, 39, 49],
         [58, 46, 47, 57,  1, 39, 58, 58]]))

In [186]:
# Bigram model: the current character alone is the context, and the following character is the target.

class BigramLanguageModel(nn.Module):
    """Character-level bigram language model.

    A single (vocab_size x vocab_size) embedding table is the whole model:
    row ``i`` holds the unnormalised next-token logits for token ``i``.
    """

    def __init__(self, vocab_size=len(chars)):
        super(BigramLanguageModel, self).__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x):
        """Look up logits for every token id in ``x``; the output gains a trailing vocab_size axis."""
        logits = self.token_embedding_table(x)
        return logits

def generate(
        model: nn.Module,
        start_token: torch.Tensor = torch.zeros(4, 1, dtype=torch.long), # B, T
        max_iter: int = 100
        ) -> torch.Tensor:
    """Sample ``max_iter`` new tokens per row, conditioning only on the previous token.

    Args:
        model: module mapping a (B,) tensor of token ids to (B, V) next-token logits
            (the bigram model above does exactly this).
        start_token: (B, T) tensor of prompt token ids; only its last column is used
            as the first context. The default is four rows starting from token 0.
        max_iter: number of tokens to append to each row.

    Returns:
        (B, T + max_iter) tensor: the prompt followed by the sampled tokens.
    """
    sequence = start_token
    # Pure inference: disable autograd so no computation graph is built for each step.
    with torch.no_grad():
        for _ in range(max_iter):
            logits = model(sequence[:, -1])                      # (B, V)
            proba = F.softmax(logits, dim=-1)
            pick = torch.multinomial(proba, num_samples=1)       # (B, 1)
            sequence = torch.cat([sequence, pick], dim=1)
    return sequence

model = BigramLanguageModel()
y_hat = model(x)  # token ids (B, T) -> logits (B, T, vocab_size)

print(f"{x.shape     = }")
print(f"{y.shape     = }")
print(f"{y_hat.shape = }")

# Sample from the (still untrained) model and decode each row back into text.
sequence = generate(model)
[''.join(decode(row)) for row in sequence.tolist()]

x.shape     = torch.Size([4, 8])
y.shape     = torch.Size([4, 8])
y_hat.shape = torch.Size([4, 8, 65])


["\noA!ptlZ:Fa&S\nxoKNY jVHkSpOCsjEonGIXnb$OXYCYwnV'QAKTXC&&GDMRaTwkaPUj$rOK\nO P o\nIlCM.ZjjuDL.IJ\nfqwbAxV",
 "\n':.OfCg'E,dfK!;$\n\nIfpcjKZrYEAo;XKh,bfi;SpTrrNyVlk\nnWUUZGMGRtlEAAYkRxd,rOTwAl\n3gVUC.'OXCglm!3tNU,Hptv",
 "\n\nU!tykWg.k !AwZcj:FH;oUU?XSj&VTXjVR\nkck;'vXPf\nNb\nsKO cYkIqZqpAmxLyU3XhJd\n&$HE$Jcr vs,3?yUGKyK,EJaLaC",
 "\n'KLCg$O'TpT:3?IZGAQv-'gaWAQyG'hghsvOoZYR:FJsrPc&j.PT.CiFcA\nN$oZASfCLtrViFLTP3PYErstrdKRu\nN!CgzVCvuyb"]

In [148]:
# cross_entropy wants (N, C) logits and (N,) targets, so fold batch and time into one axis.
F.cross_entropy(y_hat.reshape(-1, y_hat.size(-1)), y.reshape(-1))

tensor(4.7857, grad_fn=<NllLossBackward0>)

In [None]:
# AdamW over all model parameters, learning rate 1e-3.
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
