In [40]:
import random
import torch
import torch.nn.functional as F

In [42]:
words = open('makemore-p1/names.txt', 'r').read().splitlines()
random.seed(2147483647)
random.shuffle(words)
n = len(words)

n_train = int(n * 0.8)
n_dev = int(n * 0.1)

words_train = words[:n_train]
words_dev = words[n_train:n_train+n_dev]
words_test = words[n_train+n_dev:]

chars = sorted(set(''.join(words)))
stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}

In [43]:
def build_xy(words, stoi):
    xs, ys = [], []
    for w in words[:]:
        chs = ['.'] + list(w) + ['.']
        for ch1, ch2 in zip(chs, chs[1:]):
            xs.append(stoi[ch1])
            ys.append(stoi[ch2])
    return torch.tensor(xs), torch.tensor(ys)

def data_nll(xs, ys, W):
    logits = W[xs].float()
    return F.cross_entropy(logits, ys)

def loss_calc(xs, ys, W):
    return data_nll(xs, ys, W) + 0.1 * (W**2).mean()

def train_step(xs, ys, W, lr=1.0, cycle=10):
    for _ in range(cycle):
        loss = loss_calc(xs, ys, W)
        W.grad = None
        loss.backward()
        W.data += -lr * W.grad
    return loss.item()

def generate(W, seed=2147483647, num_samples=5):
    g = torch.Generator().manual_seed(seed)
    for i in range(num_samples):
        out = []
        ix = 0
        while True:
            logits = W[ix].unsqueeze(0)
            p = logits.softmax(dim=1).squeeze(0)
            ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
            if ix == 0:
                break
            else:
                out.append(itos[ix])
        print(''.join(out))

In [44]:
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

xs_tr, ys_tr = build_xy(words_train, stoi)
xs_dev, ys_dev = build_xy(words_dev, stoi)
xs_te, ys_te = build_xy(words_test, stoi)

train_step(xs_tr, ys_tr, W, 1.0, 1000)

2.6525280475616455

In [45]:
train_nll = loss_calc(xs_tr, ys_tr, W)
dev_nll = loss_calc(xs_dev, ys_dev, W)
test_nll = loss_calc(xs_te, ys_te, W)

print(f"Train NLL: {train_nll}")
print(f"Dev NLL: {dev_nll}")
print(f"Test NLL: {test_nll}")

Train NLL: 2.65242075920105
Dev NLL: 2.647392511367798
Test NLL: 2.6453683376312256


In [46]:
generate(W, 2147483647, 5)

cexza
mogllurailezktyha
kllimittain
llayn
ka
