In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from tqdm import tqdm

# Load the whole corpus into memory as a single string.
# The encoding is pinned to UTF-8 so that decoding (and therefore the character
# vocabulary built from `text`) does not depend on the platform's default locale.
file_path = Path('data') / 'tinyshakespeare' / 'input.txt'
text = file_path.read_text(encoding='utf-8')

print(f"{len(text) = }")

len(text) = 1115394


In [2]:
# Character-level vocabulary: every distinct character found in the corpus,
# sorted so that each character gets a stable position.
chars = list(set(text))
chars.sort()
print(f"{len(chars) = }")
print(repr(''.join(chars)))

len(chars) = 65
"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"


In [3]:
# Lookup tables between characters and their integer ids (position in `chars`).
itos = dict(enumerate(chars))
stoi = {ch: ix for ix, ch in itos.items()}


def encode(s):
    """Return the list of integer ids for the characters of `s`, in order."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the list of characters for the ids in `l`, in order.

    The result is a list of single-character strings, not a joined string;
    callers join it themselves.
    """
    return [itos[i] for i in l]


# Round-trip check: encoding then decoding should give back the input.
''.join(decode(encode('hello world')))

'hello world'

In [4]:
# Encode the full corpus as one flat tensor of character ids.
data = torch.as_tensor(encode(text))
data.size()

torch.Size([1115394])

In [5]:
# Split point: the first 90% of the ids are for training, the rest for validation.
idx = int(len(data) * .9)
train_data, val_data = data[:idx], data[idx:]

(train_data.shape, val_data.shape)

(torch.Size([1003854]), torch.Size([111540]))

In [6]:
# Illustration of how one window of `block_size` ids yields `block_size`
# (context, target) training examples: y is x shifted right by one position.
block_size = 8
x, y = train_data[:block_size], train_data[1:block_size + 1]

print('context -> target')
for t, target in enumerate(y):
    # Everything up to and including position t is the context for target y[t].
    context = x[:t + 1]
    print(context, '->', target)

context -> target
tensor([18]) -> tensor(47)
tensor([18, 47]) -> tensor(56)
tensor([18, 47, 56]) -> tensor(57)
tensor([18, 47, 56, 57]) -> tensor(58)
tensor([18, 47, 56, 57, 58]) -> tensor(1)
tensor([18, 47, 56, 57, 58,  1]) -> tensor(15)
tensor([18, 47, 56, 57, 58,  1, 15]) -> tensor(47)
tensor([18, 47, 56, 57, 58,  1, 15, 47]) -> tensor(58)


In [7]:
class SlidingCharacterDataset(Dataset):
    """Sliding-window next-token dataset over a 1-D tensor of token ids.

    Item ``i`` is the pair
    ``(data[i : i + block_size], data[i + 1 : i + block_size + 1])``:
    a context window and the same window shifted one position to the right,
    so ``target[t]`` is the id that follows ``context[: t + 1]``.

    Args:
        data: 1-D tensor of token ids (asserted in ``__init__``).
        block_size: Length of every context/target window.
    """

    def __init__(self, data: torch.Tensor, block_size=8):
        assert data.dim() == 1
        self.data = data
        self.block_size = block_size

    def __len__(self):
        """Number of start positions whose shifted target window fits in ``data``."""
        # Use this instance's block size rather than a same-named global, so the
        # length is right for any `block_size` passed to the constructor.
        # Clamp at 0: `len()` raises on negative values when the data is shorter
        # than one window.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(context, target)`` windows starting at position ``idx``."""
        stop = idx + self.block_size
        return (
            self.data[idx:stop],
            self.data[idx + 1:stop + 1],
        )

In [8]:
# Wrap the training ids in the sliding-window dataset and pull one small
# shuffled batch to inspect the (context, target) layout.
train_dataset = SlidingCharacterDataset(train_data, block_size=8)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    drop_last=True,
)
x, y = next(iter(train_dataloader))
(x, y)

(tensor([[ 1, 47, 52, 55, 59, 47, 56, 43],
         [52, 42,  1, 40, 43, 63, 53, 52],
         [41, 50, 53, 59, 42, 57,  1, 39],
         [45,  1, 39, 52, 42,  1, 44, 53]]),
 tensor([[47, 52, 55, 59, 47, 56, 43, 11],
         [42,  1, 40, 43, 63, 53, 52, 42],
         [50, 53, 59, 42, 57,  1, 39, 56],
         [ 1, 39, 52, 42,  1, 44, 53, 56]]))

In [9]:
# In a bigram, the first token is the context and the second is the target.

class BigramLanguageModel(nn.Module):
    """Bigram model implemented as a single ``vocab_size x vocab_size`` embedding.

    Each input token id selects one row of the table; that row has one entry
    per vocabulary item and is returned unchanged by ``forward``.
    """

    def __init__(self, vocab_size=len(chars)):
        super().__init__()
        self.token_embedding_table = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=vocab_size,
        )

    def forward(self, x):
        """Look up the table row for every token id in ``x``.

        The output has the shape of ``x`` plus one trailing dimension of
        size ``vocab_size``.
        """
        return self.token_embedding_table(x)

def generate(
        model: nn.Module,
        start_token: torch.Tensor = torch.zeros(4, 1, dtype=torch.long), # B, T
        max_iter: int = 100
        ):
    """Autoregressively sample ``max_iter`` new tokens for each row of ``start_token``.

    At every step only the last token of each sequence is fed to ``model``
    (bigram-style context). The model's output for that token is turned into a
    probability distribution with softmax, and one token is drawn from it.

    Args:
        model: Called with a 1-D tensor of ``B`` token ids; must return a
            2-D ``(B, vocab)`` tensor of unnormalized scores.
        start_token: ``(B, T)`` integer tensor of initial token ids. It is not
            modified; new tokens are appended to a fresh tensor.
        max_iter: Number of tokens to sample per sequence.

    Returns:
        ``(B, T + max_iter)`` tensor: ``start_token`` followed by the samples.
        When ``max_iter`` is 0 this is ``start_token`` itself.
    """
    sequence = start_token
    # Sampling never back-propagates, so run the forward passes without autograd
    # bookkeeping instead of building (and discarding) a graph on every step.
    with torch.no_grad():
        for _ in range(max_iter):
            logits = model(sequence[:, -1])                 # (B, vocab)
            proba = F.softmax(logits, dim=1)                # each row sums to 1
            pick = torch.multinomial(proba, num_samples=1)  # (B, 1) sampled ids
            sequence = torch.cat([sequence, pick], dim=1)
    return sequence

# Instantiate an untrained model and compare the shapes of one batch of
# inputs, targets and model outputs.
model = BigramLanguageModel()
y_hat = model(x)
print(f"{x.shape     = }")
print(f"{y.shape     = }")
print(f"{y_hat.shape = }")

# Sample from the untrained model with the default start tokens and
# decode every generated row back into text.
sequence = generate(model)
[''.join(decode(row)) for row in sequence.tolist()]

x.shape     = torch.Size([4, 8])
y.shape     = torch.Size([4, 8])
y_hat.shape = torch.Size([4, 8, 65])


["\n\nggEU:KWq;Sr!thQ:3n-ayxeAz$xcTJhCwsw3X\n.jGeVnffL-zANYwRG&$ep?Jy mp.kyqIWd'T;n-w\nMhAAQDYprbR3Su3ZNVVX",
 "\nsP hyb3WKh!FdyiWN?$I$NqI.FXKxxjC3lb;soba'SfuG&l,hwtf cWHl!PXadeSJ?Jy,a!zrs'aSG3xmakKkBcZZUixQN?u3aDi",
 "\nwtzsyavgI:eQ'$pDv-cW mmbWaQRrHy\n3X3kduWNFAY-.AB\nMd-ZaZcdN-WciDH.&Aabw-$W:KyLg$pNW'vG3z$lbwq3pgNB:\nIx",
 "\nZoXdt 'qI-wn-pgsIG'HF'zsYZLsFqI-Z---C &tDW&-dn--u33wbmAVY-z&?dBkQVKaDnVfh.uho;ZET'yn-OLPy&'iQPB:j,gs"]

In [10]:
# Cross-entropy of the model outputs against the targets: both are flattened
# so that every (batch, position) pair counts as one classification example.
F.cross_entropy(y_hat.reshape(-1, len(chars)), y.reshape(-1))

tensor(4.7419, grad_fn=<NllLossBackward0>)

In [31]:
# AdamW over all model parameters.
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

epoch = 5
model.train()
for e in range(epoch):
    running_loss = 0
    # A new shuffled loader over the training set is built for every epoch.
    for i, (x, y) in enumerate(
        DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True),
        start=1,
    ):
        # Forward pass and loss over the flattened (batch * position) examples.
        y_hat = model(x)
        loss = F.cross_entropy(y_hat.view(-1, len(chars)), y.view(-1))

        # Clear stale gradients, back-propagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # '\r' keeps rewriting one console line with the mean loss of the epoch so far.
        print(f"[epoch {e+1:>3}/{epoch}][train loss] {running_loss/i:>10.5f}", end='\r')
    # Move to the next line so the finished epoch's summary stays visible.
    print()

[epoch   1/5][train loss]    2.45248
[epoch   2/5][train loss]    2.45254
[epoch   3/5][train loss]    2.45256
[epoch   4/5][train loss]    2.45257
[epoch   5/5][train loss]    2.45256


In [38]:
# Sample 200 tokens from the trained model, starting from a single
# sequence whose only token has id 0, and print the decoded text.
sequence = generate(
    model,
    start_token=torch.zeros((1, 1), dtype=torch.long),
    max_iter=200,
)
print(''.join(decode(sequence[0].tolist())))


Betig ape g dle.
N:
Indeyofo andifoull n
Wivopand.
ULofango y gaviser fa'd hir I.

As chind pu f the ind be, pugind, we
ICHES:
Th fo:
Thal mind luck ono lls t ire; tire,
ENGothant,
Whton werod han y h
