In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import math
%matplotlib inline

In [2]:
g = torch.Generator().manual_seed(2147483647)

In [3]:
words = open('names.txt', 'r').read().splitlines()

In [4]:
chars = ['.'] + sorted(list(set(''.join(words))))
encode = lambda c: chars.index(c)
decode = lambda i: chars[i]

In [5]:
block_size = 3
vocab_size = len(chars)

In [6]:
def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for c in w + '.':
            ix = encode(c)
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xval, Yval = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [7]:
def cmp(s, dt, t):
    exact = torch.all(dt == t)
    approx = torch.allclose(dt, t.grad)
    max_diff = (torch.abs(dt - t.grad).abs().max().item())
    print(f'{s:15s} exact: {str(exact):5s} approx: {str(approx):5s} max_diff: {max_diff}')


In [8]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# Note: I am initializating many of these parameters in non-standard ways
# because sometimes initializating with e.g. all zeros could mask an incorrect
# implementation of the backward pass. Because when everything is zero the
# expression of gradient is simplified.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137


In [9]:
batch_size = 32
n = batch_size
ix = torch.randint(0, Xtr.shape[0], (n,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]

print(Xb.shape, Yb.shape)


torch.Size([32, 3]) torch.Size([32])


In [76]:
# embed the inputs
emb = C[Xb] # (n, block_size, n_embd)
embcat = emb.view(Xb.shape[0], -1) # (n, block_size * n_embd)
# calxulate the linear layer outputs
hprebn = embcat @ W1 + b1 # (n, n_hidden)
# get the mean of the batch
bnmeani = hprebn / hprebn.sum(dim=0, keepdim=True)
# zero center the batch
bndiff = hprebn - bnmeani
# variance is the sum of the squared differences of examples in the batch from the mean, divided by n-1 (Bessel's correction, divide by n-1 instead of n to improve the estimate of the population variance)
bndiff2 = bndiff**2
bnvar = 1/(n-1) * bndiff2.sum(dim=0, keepdim=True)
# standard deviation is the square root of the variance
# get the inverse of the standard deviation summed with a small constant to avoid division by zero
bnvar_inv = (bnvar + 1e-5)**-0.5
# scale the centered batch by the inverse of the standard deviation
bnraw = bndiff * bnvar_inv
# apply learned gain and bias to the scaled batch
hpreact = bngain * bnraw + bnbias
# apply the tanh activation function
h = torch.tanh(hpreact)
# calculate the logits for the output layer
logits = h @ W2 + b2
# for numerical stability, subtract the maximum logit value in each row
logit_maxes = logits.max(dim=1, keepdim=True).values
norm_logits = logits - logit_maxes # for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(dim=1, keepdim=True)
counts_sum_inv = counts_sum ** -1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean() # we first index by the barch index and then for that particular batch row we take the probability of the character oobserved in the dataset

for p in parameters:
    p.grad = None

for t in reversed([emb, embcat, hprebn, bnmeani, bndiff, bndiff2, bnvar, bnvar_inv, bnraw, hpreact, h, logits, logit_maxes, norm_logits, counts, counts_sum, counts_sum_inv, probs, logprobs]):
    t.retain_grad()

loss.backward()

loss


tensor(3.4916, grad_fn=<NegBackward0>)

In [94]:
dlogprobs = -1/n * F.one_hot(Yb, vocab_size) # 32 x 27
dprobs = dlogprobs * (probs**-1)
dcounts_sum_inv = dprobs[range(n), Yb].view((n,-1)) * counts[range(n), Yb].view((n,-1))
dcounts_sum = dcounts_sum_inv * -1 * (counts_sum ** -2)
dcounts = dprobs * counts_sum_inv + dcounts_sum
dnorm_logits = dcounts * norm_logits.exp()
dlogit_maxes = torch.zeros((n,1))
dlogits = dnorm_logits
dW2 = h.T @ dlogits
db2 = dlogits.sum(dim=0)
dh = dlogits @ W2.T
dhpreact = dh * (1 - h**2)
dbnraw = dhpreact * bngain
dbngain = dhpreact
dbnbias = dhpreact

In [95]:
cmp("dlogprobs", dlogprobs, logprobs)
cmp("dprobs", dprobs, probs)
cmp("dcounts_sum_inv", dcounts_sum_inv, counts_sum_inv)
cmp("dcounts_sum", dcounts_sum, counts_sum)
cmp("dcounts", dcounts, counts)
cmp("dnorm_logits", dnorm_logits, norm_logits)
cmp("dlogits_maxes", dlogit_maxes, logit_maxes)
cmp("dlogits", dlogits, logits)
cmp("dW2", dW2, W2)
cmp("db2", db2, b2)
cmp("dh", dh, h)
cmp("dhpreact", dhpreact, hpreact)
cmp("dbnraw", dbnraw, bnraw)
cmp("dbngain", dbngain, bngain)
cmp("dbnbias", dbnbias, bnbias)

dlogprobs       exact: tensor(False) approx: True  max_diff: 0.0
dprobs          exact: tensor(False) approx: True  max_diff: 0.0
dcounts_sum_inv exact: tensor(False) approx: True  max_diff: 0.0
dcounts_sum     exact: tensor(False) approx: True  max_diff: 0.0
dcounts         exact: tensor(False) approx: True  max_diff: 0.0
dnorm_logits    exact: tensor(False) approx: True  max_diff: 0.0
dlogits_maxes   exact: tensor(False) approx: True  max_diff: 5.122274160385132e-09
dlogits         exact: tensor(False) approx: True  max_diff: 5.122274160385132e-09
dW2             exact: tensor(False) approx: True  max_diff: 7.450580596923828e-09
db2             exact: tensor(False) approx: True  max_diff: 7.450580596923828e-09
dh              exact: tensor(False) approx: True  max_diff: 1.3969838619232178e-09
dhpreact        exact: tensor(False) approx: True  max_diff: 1.3969838619232178e-09
dbnraw          exact: tensor(False) approx: True  max_diff: 1.3969838619232178e-09
dbngain         exact: ten