In [1]:
import torch
import torch.nn.functional
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:

words = open('names.txt', 'r').read().splitlines()
words[:8]
len(words)

32033

In [3]:
chars = [chr(i) for i in range(ord('a'), ord('z')+1)]
stoi = {s: i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i: s for s,i in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [4]:
X, Y = [], []
for w in words:
    chs = list(f"...{w}.")
    for i in range(len(chs)-3):
        x1, x2, x3, y =  chs[i:i+4]
        # print("".join([x1, x2, x3]), "--->", y)
        X.append((stoi[x1],stoi[x2],stoi[x3]))
        Y.append(stoi[y])

n = len(X)
g = torch.Generator().manual_seed(2147483647)
perm = torch.randperm(n, generator=g)

n_train = int(0.8 * n)
n_dev = int(0.1 * n)

Xtrain, Ytrain = torch.tensor([X[i] for i in perm[:n_train]], dtype=torch.long), torch.tensor([Y[i] for i in perm[:n_train]], dtype=torch.long)
Xdev, Ydev = torch.tensor([X[i] for i in perm[n_train:n_train + n_dev]], dtype=torch.long), torch.tensor([Y[i] for i in perm[n_train:n_train + n_dev]], dtype=torch.long)
Xtest, Ytest = torch.tensor([X[i] for i in perm[n_train + n_dev:]], dtype=torch.long), torch.tensor([Y[i] for i in perm[n_train + n_dev:]], dtype=torch.long)


In [5]:
print(Xtrain.shape, Ytrain.shape)
print(Xdev.shape, Ydev.shape)
print(Xtest.shape, Ytest.shape)

torch.Size([182516, 3]) torch.Size([182516])
torch.Size([22814, 3]) torch.Size([22814])
torch.Size([22816, 3]) torch.Size([22816])


In [None]:
from dataclasses import dataclass, field

@dataclass
class Params:
    seed: int = 2147483647
    C: torch.Tensor = field(init=False)
    W1: torch.Tensor = field(init=False)
    b1: torch.Tensor = field(init=False)
    W2: torch.Tensor = field(init=False)
    b2: torch.Tensor = field(init=False)

    def __post_init__(self):
        g = torch.Generator().manual_seed(self.seed)
        self.C = torch.randn((27, 10), generator=g) * 0.01
        self.W1 = torch.randn((30, 100), generator=g) * 0.01
        self.b1 = torch.zeros(100)
        self.W2 = torch.randn((100, 27), generator=g) * 0.01
        self.b2 = torch.zeros(27)
        for tensor in self.learnable_tensors():
            tensor.requires_grad_(True)

    def learnable_tensors(self):
        return [self.C, self.W1, self.b1, self.W2, self.b2]


import torch.nn.functional as F

def train_params(p: Params, learn_rate=0.1, batch_size=(32,)):
    for i in range(200000):
        ix = torch.randint(Xtrain.shape[0], batch_size)
        emb = p.C[Xtrain[ix]]
        h = torch.tanh(emb.view(emb.shape[0], -1) @ p.W1 + p.b1)
        logits = h @ p.W2 + p.b2
        loss = F.cross_entropy(logits, Ytrain[ix])

        for tensor in p.learnable_tensors():
            tensor.grad = None
        loss.backward()

        lr = learn_rate if i < 100_000 else learn_rate * 1e-1
        for tensor in p.learnable_tensors():
            if tensor.grad is not None:
                tensor.data -= lr * tensor.grad

    return loss.item()


def validation_loss(p: Params):
    with torch.no_grad():
        emb = p.C[Xdev]
        h = torch.tanh(emb.view(emb.shape[0], -1) @ p.W1 + p.b1)
        logits = h @ p.W2 + p.b2
        return F.cross_entropy(logits, Ydev).item()


params = Params()
lr = 0.1
nump = sum(t.nelement() for t in params.learnable_tensors())
print(f"{lr=}, num params={nump}")
loss = train_params(params, learn_rate=lr)
vloss = validation_loss(params)
print(f"train loss={loss:.4f}, dev loss={vloss:.4f}")


lr=0.02, num params=6097


train loss=2.4083, dev loss=2.2388
lr=0.05, num params=6097
train loss=2.4298, dev loss=2.1698
lr=0.1, num params=6097
train loss=1.9910, dev loss=2.1530


In [None]:
block_size = 3
g = torch.Generator().manual_seed(2147483647 + 10)
for _ in range(20):
    out = []
    context = [0] * block_size
    while True:
        emb = params.C[torch.tensor([context])] # (1, block_size, 10)
        h = torch.tanh(emb.view(emb.shape[0], -1) @ params.W1 + params.b1)
        logits = h @ params.W2 + params.b2
        probs = F.softmax(logits, dim=1)
        ix = torch.multinomial(probs, num_samples=1, generator=g).item()
        context = context[1:] + [ix]
        out.append(ix)
        if ix == 0:
            break
    print(''.join(itos[i] for i in out))