In [44]:
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
%matplotlib inline

In [2]:
with open("data/names.txt") as f:
    names = f.read().splitlines()

In [30]:
def build_ngram(s, n):
    res = []
    l = ['.'] + list(s) + ['.']
    for i in range(len(l)-n+1):
        res.append(tuple(l[i:i+n]))
    return res

In [4]:
ctoi = {chr(ord('a') + i): i+1 for i in range(26)}
ctoi['.'] = 0

In [5]:
itoc = {v: k for k,v in ctoi.items()}

In [151]:
xs, ys = [], []
for s in names:
    for ngram in build_ngram(s, 3):
        x = ngram[0:2]
        y = ngram[2]
        xs.append(ctoi[x[0]]*27 + ctoi[x[1]])
        ys.append(ctoi[y])
xs, ys = torch.tensor(xs), torch.tensor(ys)
num = xs.nelement()

In [152]:
xenc = F.one_hot(xs, num_classes=27*27).float()

In [155]:
W = torch.randn((27*27, 27), requires_grad=True)

In [156]:
for k in range(100):
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
    print(loss.item())
    W.grad = None
    loss.backward()
    W.data += -50 * W.grad

3.830918550491333
3.746774196624756
3.668682336807251
3.595921754837036
3.528170108795166
3.4652700424194336
3.4070959091186523
3.3534679412841797
3.3041205406188965
3.2587060928344727
3.2168314456939697
3.1781082153320312
3.142191171646118
3.108795404434204
3.077687978744507
3.048676013946533
3.021582841873169
2.9962422847747803
2.9724936485290527
2.9501850605010986
2.9291799068450928
2.9093568325042725
2.890610694885254
2.872851610183716
2.8560001850128174
2.8399882316589355
2.8247530460357666
2.8102400302886963
2.796398401260376
2.7831811904907227
2.770545482635498
2.758451461791992
2.7468626499176025
2.7357449531555176
2.725066661834717
2.714799404144287
2.704915761947632
2.6953916549682617
2.6862051486968994
2.677335262298584
2.6687631607055664
2.6604719161987305
2.6524460315704346
2.644669771194458
2.6371309757232666
2.6298162937164307
2.622715950012207
2.6158182621002197
2.6091136932373047
2.602593183517456
2.5962488651275635
2.5900728702545166
2.584057331085205
2.57819652557373

In [191]:
for i in range(5):
    out = []
    ix = 0
    while True:
        if len(out) > 1:
            ix = ctoi[out[-2]]*27 + ctoi[out[-1]]
        else:
            ix = 0
        xenc_in = F.one_hot(torch.tensor([ix]), num_classes=27*27).float()
        logits = xenc_in @ W
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdims=True)
        ix = torch.multinomial(probs, num_samples=1, replacement=True).item()
        out.append(itoc[ix])
        if ix == 0:
            break
    print("".join(out))

jximqpadewedvwkufbdlzztlwqlltmkthanijaomkybmon.
wjah.
mfqtdily.
fxw.
vxff.
