In [35]:
import random
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
words = open('names.txt', 'r').read().splitlines()
random.shuffle(words)
n = len(words)

n_train = int(0.8 * n)
n_dev = int(0.1 * n)

words_train = words[:n_train]
words_dev = words[n_train:n_train + n_dev]
words_test = words[n_train + n_dev:]

chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}

In [3]:
def build_xy(words, stoi, block_size=3):
    x, y = [], []
    for w in words[:]:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            x.append(context)
            y.append(ix)
            context = context[1:] + [ix]
    X, Y = torch.tensor(x), torch.tensor(y)
    return X, Y

def data_nll(x, y, C, W1, B1, W2, B2, block_size=3, smoothing=0):
    emb = C[x]
    H = torch.tanh(emb.view(-1, C.size(1) * block_size) @ W1 + B1)
    logits = H @ W2 + B2
    return F.cross_entropy(logits, y, label_smoothing=smoothing).item()

def train_step(x, y, C, W1, B1, W2, B2, parameters, block_size=3, lr=0.1, cycle=10):
    for i in range(cycle):
        ix = torch.randint(0, x.shape[0], (32,))
        emb = C[x[ix]]
        H = torch.tanh(emb.view(-1, C.size(1) * block_size) @ W1 + B1)
        logits = H @ W2 + B2
        loss = F.cross_entropy(logits, y[ix])
        for p in parameters:
            p.grad = None
        loss.backward()
        for p in parameters:
            p.data += -lr * p.grad
    return loss.item()

def generate(C, W1, B1, W2, B2, itos, block_size=3, seed=2147483647, num_samples=5):
    g = torch.Generator().manual_seed(seed)
    for _ in range(num_samples):
        out = []
        context = [0] * block_size
        while True:
            emb = C[torch.tensor(context)]
            h = torch.tanh(emb.view(-1, C.size(1) * block_size) @ W1 + B1)
            logits = h @ W2 + B2
            probs = F.softmax(logits, dim=1)
            ix = torch.multinomial(probs, num_samples=1, generator=g).item()
            context = context[1:] + [ix]
            if ix == 0:
                break
            else:
                out.append(ix)
        print(''.join(itos[i] for i in out))


In [4]:
x_tr, y_tr = build_xy(words_train, stoi)
x_dev, y_dev = build_xy(words_dev, stoi)
x_te, y_te = build_xy(words_test, stoi)

In [5]:
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 10), generator=g)
W1 = torch.randn((30, 500), generator=g)
B1 = torch.randn(500, generator=g)
W2 = torch.randn((500, 27), generator=g)
B2 = torch.randn(27, generator=g)
parameters = [C, W1, B1, W2, B2]
for p in parameters:
    p.requires_grad = True

In [33]:
print(train_step(x_tr, y_tr, C, W1, B1, W2, B2, parameters, 3, 0.0005, 4000000))
print(train_step(x_dev, y_dev, C, W1, B1, W2, B2, parameters, 3, 0.0005, 80000))

print(data_nll(x_tr, y_tr, C, W1, B1, W2, B2))
print(data_nll(x_dev, y_dev, C, W1, B1, W2, B2))

2.1739916801452637
1.683990478515625
2.0133633613586426
2.0003249645233154


In [36]:
print(data_nll(x_te, y_te, C, W1, B1, W2, B2))
# Final Test Loss Value = 2.1557

2.155707836151123


In [37]:
generate(C, W1, B1, W2, B2, itos)

dex
maleah
makila
kayden
maira
