In [1]:
import random
import torch
import torch.nn.functional as F

In [2]:
words = open('names.txt', 'r').read().splitlines()
random.shuffle(words)
n = len(words)

n_train = int(n * 0.8)
n_dev = int(n * 0.1)

words_train = words[:n_train]
words_dev = words[n_train:n_train+n_dev]
words_test = words[n_train+n_dev:]

chars = sorted(set(''.join(words)))
stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}

In [3]:
def build_xy(words, stoi):
    xs, ys = [], []
    for w in words[:]:
        chs = ['.', '.'] + list(w) + ['.']
        for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
            ix12 = stoi[ch1] * 27 + stoi[ch2]
            xs.append(ix12)
            ys.append(stoi[ch3])
    return torch.tensor(xs), torch.tensor(ys)

def data_nll(xs, ys, W):
    logits = W[xs].float()
    return F.cross_entropy(logits, ys)

def loss_calc(xs, ys, W):
    return data_nll(xs, ys, W) + 0.1 * (W**2).mean()

def train_step(xs, ys, W, lr=1.0, cycle=10):
    for _ in range(cycle):
        loss = loss_calc(xs, ys, W)
        W.grad = None
        loss.backward()
        W.data += -lr * W.grad
    return loss.item()

def generate(W, seed=2147483647, num_samples=5):
    g = torch.Generator().manual_seed(seed)
    for i in range(num_samples):
        out = []
        i1, i2 = 0, 0
        while True:
            pair = i1 * 27 + i2
            logits = W[pair].unsqueeze(0)
            p = logits.softmax(dim=1).squeeze(0)
            i3 = torch.multinomial(p, num_samples=1, generator=g).item()
            if i3 == 0:
                break
            else:
                out.append(itos[i3])
                i1, i2 = i2, i3
        print(''.join(out))

In [4]:
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((729, 27), generator=g, requires_grad=True)

xs_tr, ys_tr = build_xy(words_train, stoi)
xs_dev, ys_dev = build_xy(words_dev, stoi)
xs_te, ys_te = build_xy(words_test, stoi)

train_step(xs_tr, ys_tr, W, 1.0, 1000)

3.0602707862854004

In [5]:
print(loss_calc(xs_tr, ys_tr, W).item())
print(loss_calc(xs_dev, ys_dev, W).item())
print(loss_calc(xs_te, ys_te, W).item())

3.0599329471588135
3.059915065765381
3.0769760608673096


In [6]:
generate(W, num_samples=5)

cexzdfzjglkurxycezkwyhhmvlzimjtna
nmlbfvk
ka
da
stexvpbbpwkhrggitmj
