In [17]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [18]:
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [19]:
# build the vocabulary of chars and mappings to/from ints
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}


In [84]:
# build the dataset

block_size = 3
X, Y = [], []

for w in words:
    context= [0] * block_size
    for ch in w + '.':
        idx = stoi[ch]
        X.append(context)
        Y.append(idx)
        #print(''.join(itos[i] for i in context), '----->', itos[idx])
        context = context[1:] + [idx]

X = torch.tensor(X)
Y = torch.tensor(Y)

In [85]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([228146, 3]), torch.int64, torch.Size([228146]), torch.int64)

In [34]:
C = torch.randn((27, 2))
C

tensor([[-2.4473, -0.0967],
        [ 0.8826, -1.8210],
        [ 0.4507, -0.4967],
        [ 0.6312,  1.0287],
        [ 0.1091, -0.2267],
        [-0.4277,  2.0331],
        [ 0.1640,  1.0193],
        [ 1.9079, -0.3071],
        [ 0.7179,  0.6389],
        [-0.8530,  0.4987],
        [-1.6897, -0.6295],
        [ 1.5163, -0.8672],
        [ 0.9662,  1.1189],
        [ 0.0689, -0.3937],
        [ 0.1939,  1.4098],
        [-0.2188,  0.1498],
        [ 0.3618,  0.1927],
        [ 0.3235, -1.7494],
        [ 1.0883, -0.2664],
        [-0.2719,  0.6883],
        [-1.0290, -0.8363],
        [ 0.0060,  0.1786],
        [ 0.0833, -0.5175],
        [-1.8693, -0.0372],
        [-0.6031,  0.1659],
        [ 0.4414,  0.1703],
        [-0.0422, -0.3846]])

In [39]:
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [40]:
W1 = torch.rand((6, 100))
b1 = torch.randn(100)

In [52]:
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)

In [53]:
W2 = torch.randn((100, 27))
b2 = torch.randn(27)

In [54]:
logits = h @ W2 + b2

In [55]:
logits.shape

torch.Size([32, 27])

In [57]:
counts = logits.exp()

In [58]:
prob = counts / counts.sum(1, keepdim=True)

In [62]:
loss = -prob[torch.arange(32), Y].log().mean()
loss

tensor(10.2639)

In [63]:
# put it all together

In [86]:
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 2), generator=g)
W1 = torch.randn((6, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, 27), generator=g)
b2 = torch.randn(27, generator=g)
parameters = [C, W1, b1, W2, b2]

In [87]:
sum(p.nelement() for p in parameters)

3481

In [88]:
for p in parameters:
    p.requires_grad = True

In [None]:
for _ in range(1000):
    # forward pass
    emb = C[X]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
    logits = h @ W2 + b2                      # (32, 27)
    loss = F.cross_entropy(logits, Y)
    print(loss.item())
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # update
    for p in parameters:
        p.data += -0.1 * p.grad

print(loss.item())