In [343]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [344]:
# Load the training corpus: one name per line.
# The context manager guarantees the file handle is closed after reading
# (a bare open(...).read() leaves closing to the garbage collector).
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()


In [345]:
# Character vocabulary: every distinct character of the corpus gets an index
# starting at 1 (in sorted order); '.' is mapped to 0 and serves as the
# start/end-of-name marker. itos is the inverse lookup table.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}

In [346]:
# --- Hyperparameters ---------------------------------------------------------
context_size = 5      # how many previous characters the model conditions on
embedding_size = 3    # dimensionality of each character embedding vector
batch_size = 128      # minibatch size used by the training loop

# --- Build (context -> next character) examples -----------------------------
# Every name starts from a window of `context_size` zeros (the '.' token) and is
# terminated by '.', so the examples also cover where names begin and end.
X, Y = [], []
for word in words:
    window = [0] * context_size
    for target in (stoi[c] for c in word + '.'):
        X.append(window)
        Y.append(target)
        window = window[1:] + [target]  # slide: drop oldest index, append newest

In [347]:
# Vocabulary size (all characters plus the '.' marker) = number of output classes.
# itos is the exact inverse of stoi, so it has the same number of entries.
num_classes = len(itos)


In [348]:
# Convert the example lists to int64 tensors: X is used to index the embedding
# table C and Y supplies the class targets for F.cross_entropy, which expects
# integer (long) class indices.
# torch.as_tensor (instead of torch.tensor) makes re-running this cell safe:
# once X/Y are already tensors it returns them as-is rather than copying them
# with a UserWarning. The explicit dtype also keeps an empty list from
# defaulting to float.
X = torch.as_tensor(X, dtype=torch.long)
Y = torch.as_tensor(Y, dtype=torch.long)
X.shape, Y.shape

(torch.Size([228146, 5]), torch.Size([228146]))

In [349]:
# MLP parameters: character embedding table C, one hidden layer (W1, b1) and an
# output layer (W2, b2) producing one logit per vocabulary entry.
# A dedicated, seeded generator makes the initialization reproducible.
n_hidden = 100  # width of the hidden layer (previously a repeated literal)

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((num_classes, embedding_size), generator=g)
W1 = torch.randn((context_size * embedding_size, n_hidden), generator=g)
b1 = torch.randn(n_hidden, generator=g)
W2 = torch.randn((n_hidden, num_classes), generator=g)
b2 = torch.randn(num_classes, generator=g)
# NOTE(review): every parameter is drawn from an unscaled standard normal. With a
# fan-in of context_size * embedding_size, the hidden pre-activations are likely
# large enough to saturate tanh and inflate the initial loss; consider scaling
# W1/W2 down (and b2 toward zero). Not changed here because it alters results.
parameters = [C, W1, b1, W2, b2]

In [350]:
# Turn on gradient tracking for every parameter so loss.backward() fills .grad.
for param in parameters:
    param.requires_grad_(True)

In [351]:
# Minibatch SGD with a single step-decay of the learning rate.
max_steps = 10000
lr_decay_step = 5000          # switch to the decayed learning rate from this step on
lr_initial, lr_decayed = 0.1, 0.01

for i in range(max_steps):
    lr = lr_initial if i < lr_decay_step else lr_decayed

    # Draw minibatch indices from the seeded generator `g` so training is
    # reproducible; the global torch RNG is not seeded in this notebook.
    ix = torch.randint(0, X.shape[0], (batch_size,), generator=g)

    # Forward pass
    emb = C[X[ix]]  # (batch_size, context_size, embedding_size)
    h = torch.tanh(emb.view(-1, context_size * embedding_size) @ W1 + b1)  # (batch_size, hidden width)
    logits = h @ W2 + b2  # (batch_size, num_classes)
    loss = F.cross_entropy(logits, Y[ix])

    # Backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # SGD update. torch.no_grad() keeps the in-place update out of the autograd
    # graph without going through the discouraged `.data` attribute.
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad

# Loss of the final minibatch only: a noisy estimate of the full training loss.
print(loss.item())

2.38696026802063
