In [1]:
import torch
import json
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [14]:
# Vocabulary: the 26 lowercase ASCII letters plus one separator token.
VOCAB_SIZE = 27
SOURCE = "names.txt"
SEP = "."

# Read the whole corpus. The context manager closes the file handle
# deterministically instead of leaving it open until garbage collection.
with open(SOURCE, encoding="utf-8") as f:
    text = f.read()

# Character -> token id: "a".."z" map to 1..26 and the separator maps to 0.
s_to_i = {chr(ord("a") + i): i + 1 for i in range(VOCAB_SIZE - 1)}
s_to_i[SEP] = 0
# Token id -> character (inverse of s_to_i).
i_to_s = {v: k for k, v in s_to_i.items()}

In [15]:
# Basic tokenizer: encode() maps text to a list of integer tokens and
# decode() maps integer tokens back to text.

def encode(text):
    """Map each character of ``text`` to its integer token id.

    Looks every character up in the module-level ``s_to_i`` table and
    returns the ids as a list, in order. A character that is not in the
    table raises ``KeyError``.
    """
    lookup = s_to_i.__getitem__
    return list(map(lookup, text))

def decode(tokens):
    """Map integer token id(s) back to text via the ``i_to_s`` table.

    Args:
        tokens: a single ``int``, an iterable of ints, or any object that
            exposes ``.tolist()`` (e.g. a torch tensor or numpy array),
            including 0-d tensors and lists of 0-d tensors.

    Returns:
        The single character for a scalar input, otherwise the
        concatenation of the characters for every token.

    Raises:
        KeyError: if a token id is not present in ``i_to_s``.
    """
    # Tensors/arrays are unwrapped to plain Python values first: a 0-d
    # tensor is not an ``int`` and cannot be iterated, and tensor elements
    # are not usable as dict keys.
    if hasattr(tokens, "tolist"):
        tokens = tokens.tolist()
    if isinstance(tokens, int):
        return i_to_s[tokens]
    # int() also unwraps 0-d tensors that arrive inside a plain list.
    return "".join(i_to_s[int(token)] for token in tokens)

In [17]:
Xs, Ys = [], []
H_LAYERS, H_DIM = 2, VOCAB_SIZE * 2

# Build bigram training pairs: every name is wrapped in SEP on both sides
# and each adjacent (current char, next char) pair becomes one (x, y) example.
with open(SOURCE, encoding="utf-8") as f:
    for name in f.read().splitlines():
        if not name:
            # A blank line would otherwise add a spurious (SEP, SEP) pair.
            continue
        tokens = encode(SEP + name + SEP)
        Xs.extend(tokens[:-1])
        Ys.extend(tokens[1:])

# Explicit integer dtype keeps one_hot/indexing valid even if no pairs exist.
Xs, Ys = torch.tensor(Xs, dtype=torch.long), torch.tensor(Ys, dtype=torch.long)
# Basic embedding: converts each token into a one-hot vector
# of size VOCAB_SIZE (i.e. 27).
x_emb = F.one_hot(Xs, num_classes=VOCAB_SIZE).float()

# Seeded generator so the initial weights are reproducible.
g = torch.Generator().manual_seed(1337)
# Input projection, H_LAYERS square hidden matrices, and output projection.
# All of them track gradients; the hidden matrices previously did not, so
# backward() would never have produced gradients for them.
W0 = torch.randn((VOCAB_SIZE, H_DIM), generator=g, requires_grad=True)
Ws = [
    torch.randn((H_DIM, H_DIM), generator=g, requires_grad=True)
    for _ in range(H_LAYERS)
]
W1 = torch.randn((H_DIM, VOCAB_SIZE), generator=g, requires_grad=True)
# Per-epoch loss history.
losses = []

In [None]:
EPOCHS, LR, ALPHA = 200, 2, 0.01

# Every weight matrix, in forward order.
params = [W0] + Ws + [W1]
# Only tensors that track gradients can be updated by gradient descent.
trainable = [w for w in params if w.requires_grad]
rows = torch.arange(Xs.nelement())

# NOTE(review): the layers are purely linear and initialised with unit
# variance, so initial logits are large; LR may need tuning — confirm.
for e in range(EPOCHS):
    # Forward pass: start from the one-hot inputs, apply the input and
    # hidden matrices, then project to vocabulary logits exactly once.
    x = x_emb
    for w in [W0] + Ws:
        x = x @ w
    logits = x @ W1

    # log_softmax is the numerically stable form of
    # log(exp(logits) / exp(logits).sum(1)); a raw exp() overflows for
    # large logits.
    log_probs = F.log_softmax(logits, dim=1)
    nll = -log_probs[rows, Ys].mean()

    # L2 penalty built with tensor ops so it stays in the autograd graph.
    reg = sum((w ** 2).mean() for w in params)
    loss = nll + ALPHA * reg

    # Backward pass: clear stale gradients first, since backward() accumulates.
    for w in trainable:
        w.grad = None
    loss.backward()

    # Plain gradient-descent step, outside of autograd tracking.
    with torch.no_grad():
        for w in trainable:
            w -= LR * w.grad

    losses.append(loss.item())
