In [5]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [6]:
# Load the training corpus (one entry per line) and build the character
# vocabulary. The context manager closes the file handle as soon as the read
# finishes instead of leaving it open until garbage collection.
with open('../../names.txt', 'r') as names_file:
    words = names_file.read().splitlines()

# Every distinct character that occurs in the corpus, in sorted order.
chars = sorted(set(''.join(words)))

# Character -> integer id. Ids start at 1 because id 0 is reserved for '.',
# the token used both to pad the start of a context and to end a word.
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi['.'] = 0

# Inverse lookup: integer id -> character.
itos = {i: s for s, i in stoi.items()}

In [156]:
# Hyperparameters, followed by the (context -> next character) training pairs.

block_size = 3      # how many preceding characters form one context
space_size = 5      # width of each character's embedding vector
num_neurons = 100   # number of units in the hidden layer
batch_size = 64     # examples drawn per training step

# Slide a window of `block_size` ids over every word. Id 0 ('.') pads the
# start of the word and also terminates it, so each window is paired with
# the id of the character that immediately follows it.
X, Y = [], []
for w in words:
    ids = [stoi[ch] for ch in w + '.']
    padded = [0] * block_size + ids
    for pos, target in enumerate(ids):
        X.append(padded[pos:pos + block_size])
        Y.append(target)
X, Y = torch.tensor(X), torch.tensor(Y)

In [157]:
# Model parameters. Everything is drawn from one seeded generator, and the
# order of the randn calls below determines which values each tensor gets.
g = torch.Generator().manual_seed(2147483647)
vocab_size = len(stoi)
flat_width = space_size * block_size  # width of one flattened, embedded context

C = torch.randn((vocab_size, space_size), generator=g)      # embedding table
W1 = torch.randn((flat_width, num_neurons), generator=g)    # hidden-layer weights
b1 = torch.randn(num_neurons, generator=g)                  # hidden-layer bias
# NOTE(review): W3/b3 are not listed in `parameters` and the forward pass in
# the cells shown here never reads them. They are kept because they consume
# draws from `g` ahead of W2/b2; removing them would change W2/b2's values.
W3 = torch.randn((flat_width, num_neurons), generator=g)
b3 = torch.randn(num_neurons, generator=g)
W2 = torch.randn((num_neurons, vocab_size), generator=g)    # output-layer weights
b2 = torch.randn(vocab_size, generator=g)                   # output-layer bias

# Only these tensors are trained; switch on gradient tracking for each.
parameters = [C, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad_(True)

In [171]:
# Mini-batch SGD: 2000 steps per execution of this cell. Re-running the cell
# continues training from the current parameter values.
step = 0.1/2  # learning rate
for _ in range(2000):
    # Draw the mini-batch indices from the shared seeded generator so that a
    # fresh top-to-bottom run visits the same batches every time.
    ix = torch.randint(0, X.shape[0], (batch_size,), generator=g)

    # Forward pass: look up embeddings, flatten each context into one row,
    # apply the tanh hidden layer, then project to one logit per character.
    emb = C[X[ix]]
    h1 = torch.tanh(emb.view(-1, space_size * block_size) @ W1 + b1)
    h2 = h1 @ W2 + b2
    # cross_entropy takes raw logits and the target ids for this batch.
    loss = F.cross_entropy(h2, Y[ix])

    # Backward pass: drop stale gradients, then backpropagate.
    for p in parameters:
        p.grad = None
    loss.backward()

    # Gradient step. The in-place update runs under no_grad so autograd does
    # not record it, rather than going through `.data`, which bypasses
    # autograd's correctness checks.
    with torch.no_grad():
        for p in parameters:
            p -= step * p.grad

# Loss of the final mini-batch only (a noisy estimate, not the full-data loss).
print(loss.item())

2.1752820014953613


In [172]:
# generate samples:
# Sampling is inference only, so gradient tracking is disabled; otherwise a
# graph would be built for every generated character because the parameters
# have requires_grad=True.
with torch.no_grad():
    for _ in range(20):
        # Start from an all-padding context (id 0 is '.').
        context = block_size * [0]
        output = []
        while True:
            # Same forward pass as training, on a batch of one context.
            emb = C[torch.tensor([context])]
            h1 = torch.tanh(emb.view(-1, space_size * block_size) @ W1 + b1)
            h2 = h1 @ W2 + b2

            # Turn logits into a distribution and sample the next character id.
            P = F.softmax(h2, dim=1)
            ix = torch.multinomial(P,
                                   num_samples=1,
                                   generator=g).item()
            output.append(ix)
            # Slide the context window forward by one character.
            context = context[1:] + [ix]
            # Id 0 is the end-of-word token: stop once it has been emitted.
            if ix == 0:
                break
        print(''.join(itos[i] for i in output))

laan.
tara.
jorata.
kaylonny.
ghzau.
haalee.
mahdatza.
tah.
jazhiyda.
sadyondie.
kauroa.
seclyzanxa.
eme.
jarawhane.
lein.
dran.
ava.
japmede.
naverlae.
lei.
