In [266]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [267]:
words = open('names.txt', 'r').read().splitlines()


In [268]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [269]:
context_size = 3
embedding_size = 2
batch_size = 32
lr = 0.1
X, Y = [], []
for w in words:

    context = [0] * context_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        context = context[1:] + [ix]  # crop and append

In [270]:
# Count how many items we have in stoi
num_classes = len(stoi)
print(f'{num_classes} characters')  # including the '.' character

27 characters


In [271]:
X = torch.tensor(X)
Y = torch.tensor(Y)
X.shape, Y.shape

(torch.Size([228146, 3]), torch.Size([228146]))

In [272]:
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((num_classes, embedding_size), generator=g)
W1 = torch.randn((context_size * embedding_size, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, num_classes), generator=g)
b2 = torch.randn(num_classes, generator=g)
parameters = [C, W1, b1, W2, b2]

In [273]:
for p in parameters:
    p.requires_grad = True

In [274]:
for _ in range(100):

    ix = torch.randint(0, X.shape[0], (batch_size,))   
    # Forward pass
    emb = C[X[ix]] # batch_size, context_size, embedding_size
    h = torch.tanh(emb.view(-1, context_size * embedding_size) @ W1 + b1) # batch_size, 100
    logits = h @ W2 + b2 # batch_size, num_classes
    loss = F.cross_entropy(logits, Y[ix])
    

    # Backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    for p in parameters:
        p.data += -lr * p.grad

print(loss.item())

3.8046977519989014


In [275]:
torch.randint(0, X.shape[0], (32,), generator=g)


tensor([172659, 215477,  19086,   1265,  62236, 217099,  55529,  55796, 145462,
        124008, 129238, 187267, 223419, 150360, 136266, 179474,  56420, 159301,
         31249, 178134,  81560,  76592, 119629,  78313, 106899, 141168,  67469,
        171834, 207471,  43008, 143433, 138411])