In [1]:
import random
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# Hyperparameters and runtime configuration: everything a reader might tune lives here.
BATCH_SIZE = 32          # sequences per training step
CONTEXT_SIZE = 64        # tokens per training window
EMBEDDING_DIMS = 1024
HIDDEN_LAYERS = 5        # number of hidden->hidden FCL layers after the input projection
HIDDEN_DIMS = 1024
EPOCHS = 100000          # NOTE: counts optimizer steps (one random batch each), not passes over the data
LR = 0.01                # initial AdamW learning rate
SEED = 1337              # seeds Python's `random` (batch sampling) and torch (weight init, sampling)

# Seed every RNG the notebook uses so Restart & Run All is reproducible.
random.seed(SEED)
torch.manual_seed(SEED)  # also seeds CUDA generators

DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Build a character-level vocabulary from the corpus.
# NOTE(review): path is relative to the notebook's working directory.
with open("tiny_shakespeare.txt", encoding="utf-8") as f:
    data = f.read()
# sorted() makes the char -> id mapping stable across kernel restarts; iterating a bare
# set of str depends on per-process hash randomization, so ids would change every run.
chars = sorted(set(data))
VOCAB_SIZE = len(chars)
s_to_i = {s: i for i, s in enumerate(chars)}
i_to_s = {i: s for s, i in s_to_i.items()}

def encode(text):
    """Convert a string into a list of integer token ids, one per character.

    Raises KeyError if ``text`` contains a character missing from ``s_to_i``.
    """
    return list(map(s_to_i.__getitem__, text))

def decode(ids):
    """Inverse of ``encode``: join the characters for a sequence of integer ids."""
    return "".join(map(i_to_s.__getitem__, ids))

# Encode the full corpus once and keep it on DEV; the first 90% of tokens is the
# training split and the remaining 10% the validation split (both are views of `data`).
data = torch.tensor(encode(data), device=DEV)
train_data, val_data = data.tensor_split([int(len(data) * 0.9)])

def get_batch(split, batch_size=None, context_size=None):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: "train" selects ``train_data``, "val" selects ``val_data``.
        batch_size: number of windows; defaults to BATCH_SIZE (read at call time).
        context_size: tokens per window; defaults to CONTEXT_SIZE (read at call time).

    Returns:
        (x, y): two (batch_size, context_size) tensors where ``y`` is ``x`` shifted by
        one token, i.e. ``y[b, t]`` is the token that follows ``x[b, t]``.

    Raises:
        ValueError: for an unknown split name, or a split too short to hold one window.
    """
    if split not in ("train", "val"):
        raise ValueError(f"unknown split {split!r}; expected 'train' or 'val'")
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    context_size = CONTEXT_SIZE if context_size is None else context_size
    source = train_data if split == "train" else val_data
    # Highest start index that still leaves room for the shifted target window
    # (random.randint's upper bound is inclusive).
    max_start = source.shape[-1] - context_size - 1
    if max_start < 0:
        raise ValueError(
            f"{split} split has {source.shape[-1]} tokens; need at least {context_size + 1}"
        )
    starts = [random.randint(0, max_start) for _ in range(batch_size)]
    x = torch.stack([source[i:i + context_size] for i in starts])
    y = torch.stack([source[i + 1:i + 1 + context_size] for i in starts])
    return x, y

In [4]:
class FCL(nn.Module):
    """Fully connected block: Linear(fan_in -> fan_out) followed by an in-place ReLU."""

    def __init__(self, fan_in=HIDDEN_DIMS, fan_out=HIDDEN_DIMS):
        super().__init__()
        # Must be an nn container, not a plain Python list: only registered submodules
        # contribute to .parameters() (so the optimizer trains them) and follow .to(device).
        self.layers = nn.Sequential(
            nn.Linear(fan_in, fan_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the linear projection over the last dimension of ``x``, then ReLU."""
        return self.layers(x)

In [5]:
class LM(nn.Module):
    """Character-level language model: token embedding -> stacked FCL blocks -> vocab logits.

    Every layer acts on each position independently, so the prediction at position t
    depends only on token t (effectively a bigram model).
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Embedding(VOCAB_SIZE, EMBEDDING_DIMS),
            FCL(EMBEDDING_DIMS, HIDDEN_DIMS),
            *[FCL() for _ in range(HIDDEN_LAYERS)],
            nn.Linear(HIDDEN_DIMS, VOCAB_SIZE),
        )

    def forward(self, tokens):
        """Map integer token ids of shape (..., T) to logits of shape (..., T, VOCAB_SIZE)."""
        return self.layers(tokens)

    @torch.no_grad()
    def generate(self, context="\n", new_tokens=100):
        """Sample ``new_tokens`` characters after ``context``, streaming each to stdout.

        Args:
            context: non-empty prompt string; every character must be in the vocabulary.
            new_tokens: number of characters to sample.

        Returns:
            The generated text, without the prompt.

        Raises:
            ValueError: if ``context`` is empty (there is no token to condition on).
        """
        if not context:
            raise ValueError("context must contain at least one character")
        # Build the prompt on whatever device the model's weights live on.
        device = next(self.parameters()).device
        tokens = torch.tensor([encode(context)], device=device)  # (1, T)
        out = []
        for _ in range(new_tokens):
            # Feed at most the last CONTEXT_SIZE tokens (the training window length).
            # Every layer is per-position, so the last-position logits are unchanged,
            # while each step's cost stays constant instead of growing with the output.
            logits = self(tokens[:, -CONTEXT_SIZE:])
            probs = F.softmax(logits[:, -1, :], dim=-1)  # (1, VOCAB_SIZE)
            next_id = torch.multinomial(probs, num_samples=1)  # (1, 1)
            out.append(next_id.item())
            print(decode([out[-1]]), end="", flush=True)
            tokens = torch.cat((tokens, next_id), dim=-1)
        return decode(out)
    
    
# Put the model on DEV, the same device as the data tensors, so batches can be fed to it directly.
net = LM().to(DEV)

In [None]:
# Sample from the untrained model as a baseline. generate() already streams the text,
# so the trailing ';' stops IPython from echoing the returned string a second time.
net.generate("Fir", 100);

In [13]:
@torch.no_grad()
def get_split_losses(eval_iters=1):
    """Estimate the global model ``net``'s cross-entropy loss on the train and val splits.

    Args:
        eval_iters: number of random batches averaged per split (default 1, a single
            noisy batch; larger values give a steadier estimate at proportional cost).

    Returns:
        A dict values view ordered (train_loss, val_loss), so callers can unpack it.

    Raises:
        ValueError: if ``eval_iters`` is less than 1.

    The model is switched to eval mode for the estimate and then restored to whichever
    mode it was in before, so training code does not inherit eval mode.
    """
    if eval_iters < 1:
        raise ValueError(f"eval_iters must be >= 1, got {eval_iters}")
    was_training = net.training
    net.eval()
    try:
        losses = {}
        for split in ["train", "val"]:
            total = 0.0
            for _ in range(eval_iters):
                x, y = get_batch(split)
                logits = net(x)
                total += F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1)).item()
            losses[split] = total / eval_iters
    finally:
        net.train(was_training)
    return losses.values()

In [None]:
# Train with AdamW, halving the learning rate every EPOCHS // 25 steps.
# EPOCHS counts optimizer steps: each step draws one random training batch.
# NOTE(review): ~25 halvings shrink the LR by ~3e7x, so the last stretch of training
# barely moves the weights; consider fewer decays.
EVAL_INTERVAL = max(1, EPOCHS // 100)  # steps between loss estimates / log lines

trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("%d trainable parameters" % trainable)
print("- " * 10)

# Build a fresh optimizer and schedule from the configured LR on every run of this cell,
# instead of mutating the global LR (which made re-runs start from a decayed rate).
optimizer = optim.AdamW(net.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=max(1, EPOCHS // 25), gamma=0.5)

for e in range(1, EPOCHS + 1):
    net.train()  # get_split_losses() switches the model to eval mode
    lr = optimizer.param_groups[0]["lr"]  # learning rate used for this step
    x, y = get_batch("train")
    logits = net(x)
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    # Estimating losses costs two extra forward passes, so only do it periodically.
    if e == 1 or e % EVAL_INTERVAL == 0 or e == EPOCHS:
        t_loss, v_loss = get_split_losses()
        print("%d: lr: %.2e, train loss: %.4f, val loss: %.4f" % (e, lr, t_loss, v_loss))