In [1]:
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random
%matplotlib inline

In [2]:
# Read one name per line. The encoding is pinned so decoding does not depend on
# the platform's locale default.
with open("data/names.txt", encoding="utf-8") as f:
    names = f.read().splitlines()
# Shuffle with a private, seeded generator: the train/dev/test split is taken
# positionally from this ordering, so a fixed seed keeps the split identical
# after every Restart & Run All without touching the global `random` state.
random.Random(42).shuffle(names)

In [3]:
def build_ngram(s, n):
    res = []
    l = ['.'] + list(s) + ['.']
    for i in range(len(l)-n+1):
        res.append(tuple(l[i:i+n]))
    return res

In [4]:
# Character -> integer id: 'a'..'z' take 1..26, and the '.' boundary marker
# takes 0.
ctoi = {
    letter: index
    for index, letter in enumerate(map(chr, range(ord('a'), ord('z') + 1)), start=1)
}
ctoi['.'] = 0

In [5]:
# Inverse lookup of ctoi: integer id -> character.
itoc = dict((index, char) for char, index in ctoi.items())

In [6]:
# Turn every trigram of every name into one training example:
# the ids of the first two characters are the input, the id of the third is
# the target.
xs, ys = [], []
for name in names:
    for first, second, target in build_ngram(name, 3):
        xs.append((ctoi[first], ctoi[second]))
        ys.append(ctoi[target])
xs = torch.tensor(xs)
ys = torch.tensor(ys)

In [7]:
# Positional 80% / 10% / 10% split into train / dev / test.
n_examples = ys.shape[0]
ntrain = int(n_examples * .8)  # end of the training slice
ndev = int(n_examples * .9)    # end of the dev slice

x_train, x_dev, x_test = xs[:ntrain], xs[ntrain:ndev], xs[ndev:]
y_train, y_dev, y_test = ys[:ntrain], ys[ntrain:ndev], ys[ndev:]

In [8]:
# Train a single linear layer that maps two one-hot encoded context characters
# (2 * 27 inputs) to logits over the 27 symbols.
xenc = F.one_hot(x_train, num_classes=27).float()
# The encoded inputs never change, so flatten the pair of one-hot vectors once
# instead of on every iteration.
x_flat = xenc.view(-1, 27*2)
# Seeded generator so the initial weights, and therefore the printed losses,
# are reproducible across kernel restarts.
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27*2, 27), generator=g, requires_grad=True)
for k in range(500):
    logits = x_flat @ W
    # cross_entropy computes the same mean negative log-likelihood as the
    # hand-written exp / normalise / log chain, but via a numerically stable
    # log-softmax, so large logits cannot overflow to inf/nan.
    # The second term is an L2 penalty on the weights.
    loss = F.cross_entropy(logits, y_train) + 0.01*(W**2).mean()
    if k % 100 == 0:
        print(loss.item())
    W.grad = None
    loss.backward()
    # Plain gradient-descent step with learning rate 50. The update is done
    # under no_grad rather than through `.data`, so autograd does not record it
    # and in-place misuse would still be detected.
    with torch.no_grad():
        W += -50 * W.grad

4.2937092781066895
2.2731430530548096
2.2601845264434814
2.256619930267334
2.2551889419555664


In [36]:
# Loss of the trained weights on the test split, computed two ways as a sanity
# check: a hand-written softmax and F.cross_entropy. Both include the same L2
# penalty that was used during training.
xenc = F.one_hot(x_test, num_classes=27).float()
logits = xenc.view(-1, 2 * 27) @ W
# Softmax spelled out: exponentiate, then normalise each row to sum to 1.
counts = torch.exp(logits)
probs = counts / counts.sum(dim=1, keepdim=True)
rows = torch.arange(len(y_test))
penalty = 0.01 * W.pow(2).mean()
loss = -torch.log(probs[rows, y_test]).mean() + penalty
print(loss.item(), (F.cross_entropy(logits, y_test) + penalty).item())

2.255495309829712 2.255495309829712


In [10]:
# Loss of the trained weights on the dev split (hand-written softmax plus the
# same L2 penalty as in training).
xenc = F.one_hot(x_dev, num_classes=27).float()
logits = xenc.view(-1, 2 * 27) @ W
counts = torch.exp(logits)
probs = counts / counts.sum(dim=1, keepdim=True)
rows = torch.arange(len(y_dev))
penalty = 0.01 * W.pow(2).mean()
loss = -torch.log(probs[rows, y_dev]).mean() + penalty
print(loss.item())

# Losses noted from earlier runs -- presumably "penalty coefficient -> dev
# loss"; TODO confirm, the code does not record how they were produced.
# .01 2.257039785385132
# .001 2.258058786392212
# .1 2.2723448276519775
# .0001 2.2585158348083496

2.2494254112243652


In [28]:
# Sample five names from the trained model, one character at a time.
# Sampling is inference only, so it runs under no_grad to avoid building an
# autograd graph through W for every generated character.
with torch.no_grad():
    for i in range(5):
        # Start from the all-boundary context ('.', '.').
        out = ['.', '.']
        while True:
            # Encode the last two emitted characters as the model input.
            context = torch.tensor([(ctoi[out[-2]], ctoi[out[-1]])])
            xenc_in = F.one_hot(context, num_classes=27).float()
            logits = xenc_in.view(-1, 27*2) @ W
            # Numerically stable softmax instead of a manual exp / normalise.
            probs = torch.softmax(logits, dim=1)
            ix = torch.multinomial(probs, num_samples=1, replacement=True).item()
            out.append(itoc[ix])
            # Id 0 is the '.' boundary marker: the name is complete.
            if ix == 0:
                break
        # Drop the two start markers; the trailing '.' is kept in the output.
        print("".join(out[2:]))

orteliasin.
olishirahan.
uon.
uan.
myrla.
