In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [2]:
# Load the corpus: one name per line.
# splitlines() (rather than split("\n")) avoids a spurious empty "name" when the
# file ends with a newline, and also strips "\r" from CRLF-terminated files.
with open("names.txt", "r", encoding="utf-8") as f:
    names = f.read().splitlines()

In [5]:
# Character vocabulary.
# Index 0 is reserved for "." (used as padding / end-of-name marker); every
# distinct character found in `names` gets 1..N in sorted order. Indices are
# derived from the data instead of a hard-coded range(1, 27), so a corpus with
# more than 26 distinct symbols is no longer silently truncated by zip().
chars = sorted(set("".join(names)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi["."] = 0
# Inverse mapping: index -> character.
itos = {idx: ch for ch, idx in stoi.items()}

In [42]:
# Build the dataset.
# Each example pairs a window of `block_size` preceding character indices (X)
# with the index of the character that follows it (Y).

block_size = 3

X, Y = [], []
for name in names:
    # Left-pad with index 0 (the "." entry of stoi) so that even the first
    # character of a name has a full-width context; the trailing "." marks the end.
    encoded = [0] * block_size + [stoi[ch] for ch in name + "."]
    for pos in range(block_size, len(encoded)):
        X.append(encoded[pos - block_size:pos])  # the block_size indices before pos
        Y.append(encoded[pos])                   # the index to predict

X = torch.tensor(X)
Y = torch.tensor(Y)

In [43]:
# Sanity-check the dataset before training on it, then display the shapes.
assert X.ndim == 2 and X.shape[1] == block_size, "each row of X must be one context window"
assert Y.ndim == 1 and len(Y) == len(X), "X and Y must hold one entry per example"
X.shape, Y.shape

(torch.Size([228146, 3]), torch.Size([228146]))

In [44]:
# Model parameters, all drawn from one seeded generator.
# The draw order (C, W1, b1, W2, b2) determines the values, so it must not change.
g = torch.Generator().manual_seed(2147483647)

param_shapes = [
    (27, 2),    # C:  embedding table, one 2-d vector per index
    (6, 100),   # W1: hidden-layer weights
    (100,),     # b1: hidden-layer bias
    (100, 27),  # W2: output-layer weights
    (27,),      # b2: output-layer bias
]
C, W1, b1, W2, b2 = (
    torch.randn(shape, generator=g, requires_grad=True) for shape in param_shapes
)
params = [C, W1, b1, W2, b2]

In [65]:
# Mini-batch SGD: embedding lookup -> tanh hidden layer -> logits -> cross-entropy.
n_iter = 10
batch_size = 32
lr = 0.1

for _ in range(n_iter):
    # Mini-batch: sample example indices from the seeded generator `g` so the
    # run is reproducible (the global RNG is never seeded in this notebook).
    ix = torch.randint(0, len(X), (batch_size,), generator=g)

    # Forward pass
    emb = C[X[ix]]  # one embedding vector per context position of each sampled row
    # Flatten each example's context embeddings into a single row; deriving the
    # width from emb keeps this in step with block_size and the embedding size.
    h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)
    logits = h @ W2 + b2
    # F.cross_entropy rather than a hand-written softmax + NLL:
    #   - avoids storing intermediate tensors
    #   - more efficient backward pass
    #   - avoids overflow for extreme positive logits
    loss = F.cross_entropy(logits, Y[ix])
    print(f"loss: {loss.item()}")

    # Backward pass: clear stale gradients, then backpropagate.
    for p in params:
        p.grad = None
    loss.backward()

    # SGD step: update in place with autograd disabled instead of rebinding
    # p.data, which bypasses autograd's bookkeeping.
    with torch.no_grad():
        for p in params:
            p -= lr * p.grad

loss: 2.826667308807373
loss: 2.912942886352539
loss: 2.8861660957336426
loss: 3.537803888320923
loss: 3.755579948425293
loss: 3.3454151153564453
loss: 2.554062843322754
loss: 3.730494976043701
loss: 3.039793014526367
loss: 3.310731887817383


In [40]:
# Scratch check: draw 32 random row indices into X, as used for one mini-batch.
torch.randint(low=0, high=len(X), size=(32,))

tensor([ 34726, 102938, 198036, 172705, 160447,  34541,  48375,  43281,  54324,
         65115, 179563,   5511, 104463, 152653,  69899,  56099, 109265, 179036,
         80475, 185210,  65294, 217302,  50359,  61409, 134505,  70859,   9209,
         64115,  35139,  44922,  88347,  48986])