In [163]:
import torch

words = open("names.txt", "r").read().splitlines()
chars = sorted(list(set("".join(words))))
stoi = {c: i for i, c in enumerate(chars)}
stoi['.'] = 0
itos = {i: c for c, i in stoi.items()}
block_size = 3


# build the dataset

def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]
    return torch.tensor(X), torch.tensor(Y)


import random

random.seed(42)
random.shuffle(words)
n1 = int(len(words) * 0.8)
n2 = int(len(words) * 0.9)

Xtrain, Ytrain = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xtest, Ytest = build_dataset(words[n2:])

In [164]:
# utility functions
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f"{s}: {ex} {app} {maxdiff}")

In [187]:
n_embed = 10  # embedding dimension
n_hidden = 64  # the number of the neurons in the hidden layer
vocab_size = len(stoi)  # the size of the vocabulary

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embed), generator=g)

# Layer 1
W1 = torch.randn((n_embed * block_size, n_hidden), generator=g) * (5 / 3) / ((n_embed * block_size) ** 0.5)
b1 = torch.randn((n_hidden,), generator=g) * 0.1

# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * .1
b2 = torch.randn((vocab_size,), generator=g) * 0.1

# BatchNorm parameters
gamma = torch.randn((1, n_hidden), generator=g) * 0.1 + 1.0
beta = torch.randn((1, n_hidden), generator=g) * 0.1

params = [C, W1, b1, W2, b2, gamma, beta]
print([p.shape for p in params])
for p in params:
    p.requires_grad_(True)



[torch.Size([27, 10]), torch.Size([30, 64]), torch.Size([64]), torch.Size([64, 27]), torch.Size([27]), torch.Size([1, 64]), torch.Size([1, 64])]


In [188]:
batch_size = 32
ix = torch.randint(0, len(Xtrain), (batch_size,), generator=g)
Xb, Yb = Xtrain[ix], Ytrain[ix]

In [189]:
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)

# Layer 1
hprebn = embcat @ W1 + b1

# BatchNorm
bnmean1 = 1 / batch_size * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmean1
bndiff2 = bndiff ** 2
bnvar = 1 / (batch_size - 1) * bndiff2.sum(0, keepdim=True)
bnvar_inv = (bnvar - 1e-5) ** (-0.5)
bnraw = bndiff * bnvar_inv
hpreact = gamma * bnraw + beta

# Nonlinearity
h = torch.tanh(hpreact)

# Layer 2
logits = h @ W2 + b2

# Cross-entropy loss
logits_max = logits.max(1, keepdim=True).values
norm_logits = logits - logits_max  # for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum ** -1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(batch_size), Yb].mean()

# Pytorch backward pass
for p in params:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logits_max, logits, h, hpreact, bnraw,gamma,beta,
          bnvar_inv, bnvar, bndiff2, bndiff, bnmean1, hprebn, embcat, emb]:
    t.retain_grad()
loss.backward()
loss

tensor(3.3584, grad_fn=<NegBackward0>)

In [194]:
bndiff2.shape, dbnvar.shape

(torch.Size([32, 64]), torch.Size([1, 64]))

In [204]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(batch_size), Yb] = -1.0 / batch_size
dprobs = (1.0 / probs) * dlogprobs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
dcounts_sum = (-counts_sum ** - 2) * dcounts_sum_inv
dcounts += torch.ones_like(counts) * dcounts_sum
dnorm_logits = dcounts * norm_logits.exp()
dlogits = dnorm_logits.clone() 
dlogits_max = (-dnorm_logits).sum(1, keepdim=True)
dlogits += torch.nn.functional.one_hot(logits.max(1).indices, num_classes=vocab_size).float() * dlogits_max
dh = dlogits @ W2.t()
dW2 = h.t() @ dlogits
db2 = dlogits.sum(0)
dhpreact =  (1 - h**2) * dh
dgamma = (bnraw * dhpreact).sum(0, keepdim=True)
dbnraw = gamma * dhpreact
bbeta = dhpreact.sum(0, keepdim=True)
dbndiff = bnvar_inv * dbnraw
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
dbnvar = -0.5 * (bnvar - 1e-5) ** (-3/2) * dbnvar_inv
dbndiff2 = 1 / (batch_size - 1) * torch.ones_like(bndiff2) * dbnvar
dbndiff += 2 * bndiff * dbndiff2
dhprebn = dbndiff.clone()
dbnmean1 = -dhprebn.sum(0, keepdim=True)
dhprebn += 1 / batch_size * torch.ones_like(hprebn) * dbnmean1
dembcat = dhprebn @ W1.t()
dW1 = embcat.t() @ dhprebn
db1 = dhprebn.sum(0)
demb = dembcat.view(emb.shape)

dC = torch.zeros_like(C)
for i in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        dC[Xb[i, j]] += demb[i, j]


cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logits_max', dlogits_max, logits_max)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('dW2', dW2, W2)
cmp('db2', db2, b2)
cmp('dhpreact', dhpreact, hpreact)
cmp('dgamma', dgamma, gamma)
cmp('dbeta', bbeta, beta)
cmp('bnraw', dbnraw, bnraw)
cmp('dbnvar_inv', dbnvar_inv, bnvar_inv)
cmp('dbnvar', dbnvar, bnvar)
cmp('dbndiff2', dbndiff2, bndiff2)
cmp('dbndiff', dbndiff, bndiff)


cmp('dhprebn', dhprebn, hprebn)
cmp('dbnmean1', dbnmean1, bnmean1)
cmp('dembcat', dembcat, embcat)
cmp('dW1', dW1, W1)
cmp('db1', db1, b1)
cmp('demb', demb, emb)
cmp('dC', dC, C)



logprobs: True True 0.0
probs: True True 0.0
counts_sum_inv: True True 0.0
counts_sum: True True 0.0
counts: True True 0.0
norm_logits: True True 0.0
logits_max: True True 0.0
logits: True True 0.0
h: True True 0.0
dW2: True True 0.0
db2: True True 0.0
dhpreact: True True 0.0
dgamma: True True 0.0
dbeta: True True 0.0
bnraw: True True 0.0
dbnvar_inv: True True 0.0
dbnvar: True True 0.0
dbndiff2: True True 0.0
dbndiff: True True 0.0
dhprebn: True True 0.0
dbnmean1: True True 0.0
dembcat: True True 0.0
dW1: True True 0.0
db1: True True 0.0
demb: True True 0.0
dC: True True 0.0


In [169]:
import torch.nn.functional as F

# now 
loss_fast = F.cross_entropy(logits, Yb)
loss_fast

tensor(3.3584, grad_fn=<NllLossBackward0>)