In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [3]:
import random

# Load the corpus: one name per line. The encoding is given explicitly so the
# file decodes the same way regardless of the platform's default locale.
with open('./dataset/names.txt', 'r', encoding='utf-8') as f:
    names = f.read().splitlines()

# Shuffle in place with a fixed seed so the order of names (and therefore any
# positional split derived from it) is identical on every run.
random.seed(108)
random.shuffle(names)

In [4]:
# Character vocabulary. '.' is the padding / end-of-name marker and is placed
# first so that it owns index 0. It is removed from the data-derived character
# set before being prepended: otherwise a literal '.' inside a name would make
# it appear twice in `vocab` and `stoi['.']` would point at the later copy.
vocab = ['.'] + sorted(set(''.join(names)) - {'.'})
vocab_size = len(vocab)

# itos: index -> character, stoi: character -> index (inverse of itos).
itos = dict(enumerate(vocab))
stoi = {s: i for i, s in itos.items()}

In [5]:
# Model dimensions.
#   batch      - number of preceding characters used as context for each
#                prediction (a context length, not a minibatch size)
#   embed_size - width of each character embedding
batch, embed_size = 3, 2

# Layer widths: l1 is the flattened context embedding, l2 the hidden layer,
# l3 the output layer (one logit per vocabulary entry).
l1, l2, l3 = batch * embed_size, 100, vocab_size

In [6]:
# Build (context, next-character) training pairs with a sliding window.
# Each row of X holds the `batch` previous token indices (padded with 0 at the
# start of a name); the matching entry of Y is the index of the token that
# follows. Every name is terminated with '.'.
X, Y = [], []

for name in names:
    window = [0] * batch
    for symbol in name + '.':
        token = stoi[symbol]
        # `window` is rebound (never mutated) below, so storing it directly
        # is equivalent to storing a copy.
        X.append(window)
        Y.append(token)
        window = window[1:] + [token]

X, Y = torch.tensor(X), torch.tensor(Y)

In [8]:
# Positional 80 / 10 / 10 split into train, dev and test portions.
# itr and idev are the row indices at which the dev and test portions begin.
n_examples = X.shape[0]
itr, idev = int(0.8 * n_examples), int(0.9 * n_examples)

Xtr, Xdev, Xts = X[:itr], X[itr:idev], X[idev:]
Ytr, Ydev, Yts = Y[:itr], Y[itr:idev], Y[idev:]

In [9]:
# Seeded generator used for every random draw below.
g = torch.Generator()
g.manual_seed(123)

# Parameter shapes, listed in the order they are drawn from `g`:
# embedding table C, hidden layer (W1, b1), output layer (W2, b2).
param_shapes = [
    (vocab_size, embed_size),
    (l1, l2),
    (l2,),
    (l2, l3),
    (l3,),
]
C, W1, b1, W2, b2 = (
    torch.randn(shape, generator=g, requires_grad=True) for shape in param_shapes
)
params = [C, W1, b1, W2, b2]

n_params = sum(p.nelement() for p in params)
print(f'Parametes: {n_params}')

Parametes: 3481


In [10]:
# Minibatch SGD on the training portion.
lr = 0.1          # step size of each gradient update
n_steps = 10000   # number of gradient updates
# Rows per step: 1% of the training rows, but never fewer than one so that a
# very small dataset cannot produce an empty minibatch.
minibatch = max(1, int(Xtr.shape[0] * 0.01))

for e in range(n_steps):
    # Draw row indices from the seeded generator `g` so that a rerun from a
    # fresh kernel visits the same minibatches.
    b = torch.randint(0, Xtr.shape[0], (minibatch,), generator=g)

    # Forward pass: embed the context, flatten it to l1 features, apply the
    # tanh hidden layer, then the output layer.
    h = torch.tanh(C[Xtr[b]].view(-1, l1) @ W1 + b1)
    logits = h @ W2 + b2
    nll = nn.functional.cross_entropy(logits, Ytr[b])

    # Clear stale gradients before backpropagating this step's loss.
    for p in params:
        p.grad = None
    nll.backward()

    # Update in place without recording the update in the autograd graph.
    with torch.no_grad():
        for p in params:
            p -= lr * p.grad

# Loss of the final minibatch only (not of the whole training set).
print(nll)

tensor(2.3987, grad_fn=<NllLossBackward0>)


In [11]:
# Mean negative log-likelihood over the whole training portion.
# cross_entropy is used (as in the training loop) instead of a hand-written
# exp / sum softmax, which can overflow for large logits.
with torch.no_grad():
    h = torch.tanh(C[Xtr].view(-1, l1) @ W1 + b1)
    logits = h @ W2 + b2
    nll = nn.functional.cross_entropy(logits, Ytr)
    print(nll)

tensor(2.4067)


In [12]:
# Mean negative log-likelihood over the whole dev portion.
# cross_entropy is used (as in the training loop) instead of a hand-written
# exp / sum softmax, which can overflow for large logits.
with torch.no_grad():
    h = torch.tanh(C[Xdev].view(-1, l1) @ W1 + b1)
    logits = h @ W2 + b2
    nll = nn.functional.cross_entropy(logits, Ydev)
    print(nll)

tensor(2.4180)


In [19]:
# Sample one name from the model, one character at a time.
# `out` collects every token (starting with `batch` padding zeros); `idx` is
# the sliding context window holding the last `batch` tokens.
out = [0] * batch
idx = out[-batch:]

# Sampling needs no gradients, so skip building the autograd graph.
with torch.no_grad():
    while True:
        h = torch.tanh(C[idx].view(-1, l1) @ W1 + b1)
        logits = h @ W2 + b2
        # torch.softmax avoids the overflow a manual exp / sum can hit.
        probs = torch.softmax(logits, dim=1)
        pred = torch.multinomial(probs, 1, replacement=True, generator=g).item()
        idx.append(pred)
        out.append(pred)
        idx = idx[1:]
        # Token 0 marks the end of a name.
        if pred == 0:
            break

# Drop the padding / terminator tokens (index 0) when rendering the name.
print(''.join(itos[i] for i in out if i))

jaol
