In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

In [2]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [3]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [4]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [5]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [6]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# Note: I am initializating many of these parameters in non-standard ways
# because sometimes initializating with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137


In [7]:
batch_size = 32
n = batch_size # a shorter variable also, for convenience
# construct a minibatch
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y

In [8]:
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb] # embed the characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/n*hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact) # hidden layer
# Linear layer 2
logits = h @ W2 + b2 # output layer
# cross entropy loss (same as F.cross_entropy(logits, Yb)) - note that these are all element-wise operations
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb,C]:
  t.retain_grad()
loss.backward()
loss

tensor(3.3069, grad_fn=<NegBackward0>)

In [14]:
C.shape

torch.Size([27, 10])

In [None]:
emb

tensor([[[-4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
           2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01],
         [-4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
           2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01],
         [-9.6478e-01, -2.3211e-01, -3.4762e-01,  3.3244e-01, -1.3263e+00,
           1.1224e+00,  5.9641e-01,  4.5846e-01,  5.4011e-02, -1.7400e+00]],

        [[ 1.2815e+00, -6.3182e-01, -1.2464e+00,  6.8305e-01, -3.9455e-01,
           1.4388e-02,  5.7216e-01,  8.6726e-01,  6.3149e-01, -1.2230e+00],
         [ 4.6827e-01, -6.5650e-01,  6.1662e-01, -6.2197e-01,  5.1007e-01,
           1.3563e+00,  2.3445e-01, -4.5585e-01, -1.3132e-03, -5.1161e-01],
         [-4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
           2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01]],

        [[-5.6533e-01,  5.4281e-01,  1.7549e-01, -2.2901e+00, -7.0928e-01,
          -2.92

: 

In [15]:
embcat

tensor([[-4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
          2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01,
         -4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
          2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01,
         -9.6478e-01, -2.3211e-01, -3.4762e-01,  3.3244e-01, -1.3263e+00,
          1.1224e+00,  5.9641e-01,  4.5846e-01,  5.4011e-02, -1.7400e+00],
        [ 1.2815e+00, -6.3182e-01, -1.2464e+00,  6.8305e-01, -3.9455e-01,
          1.4388e-02,  5.7216e-01,  8.6726e-01,  6.3149e-01, -1.2230e+00,
          4.6827e-01, -6.5650e-01,  6.1662e-01, -6.2197e-01,  5.1007e-01,
          1.3563e+00,  2.3445e-01, -4.5585e-01, -1.3132e-03, -5.1161e-01,
         -4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
          2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01],
        [-5.6533e-01,  5.4281e-01,  1.7549e-01, -2.2901e+00, -7.0928e-01,
         -2.9283e-01, -2.1803e+00,  

In [89]:
embcat.shape

torch.Size([32, 30])

In [87]:
embcat

tensor([[-4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
          2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01,
         -4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
          2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01,
         -9.6478e-01, -2.3211e-01, -3.4762e-01,  3.3244e-01, -1.3263e+00,
          1.1224e+00,  5.9641e-01,  4.5846e-01,  5.4011e-02, -1.7400e+00],
        [ 1.2815e+00, -6.3182e-01, -1.2464e+00,  6.8305e-01, -3.9455e-01,
          1.4388e-02,  5.7216e-01,  8.6726e-01,  6.3149e-01, -1.2230e+00,
          4.6827e-01, -6.5650e-01,  6.1662e-01, -6.2197e-01,  5.1007e-01,
          1.3563e+00,  2.3445e-01, -4.5585e-01, -1.3132e-03, -5.1161e-01,
         -4.7125e-01,  7.8682e-01, -3.2843e-01, -4.3297e-01,  1.3729e+00,
          2.9334e+00,  1.5618e+00, -1.6261e+00,  6.7716e-01, -8.4039e-01],
        [-5.6533e-01,  5.4281e-01,  1.7549e-01, -2.2901e+00, -7.0928e-01,
         -2.9283e-01, -2.1803e+00,  

In [78]:
hprebn.shape

torch.Size([32, 64])

In [79]:
bnraw.shape

torch.Size([32, 64])

In [76]:
hprebn.shape

torch.Size([32, 64])

In [77]:
bnmeani.shape

torch.Size([1, 64])

In [9]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n # set the gradient at the index of Yb to be -1/n
cmp('dlogprobs', dlogprobs, logprobs)

dlogprobs       | exact: True  | approximate: True  | maxdiff: 0.0


In [10]:
dlogprobs.shape

torch.Size([32, 27])

In [11]:
(1/probs).shape

torch.Size([32, 27])

In [12]:
dprops = dlogprobs * (1.0 / probs) # derivative of logprobs w.r.t. probs
cmp('dprops', dprops, probs)

dprops          | exact: True  | approximate: True  | maxdiff: 0.0


In [13]:
dcounts_sum_inv = (dprops * counts).sum(1,keepdim=True) #  need sum to reflec the broaddcasting operations when multiplying counts(nxV) and counts_sum_inv(nx1)
cmp('dcounts', dcounts_sum_inv, counts_sum_inv)

dcounts         | exact: True  | approximate: True  | maxdiff: 0.0


In [14]:
dcounts_sum_inv.shape

torch.Size([32, 1])

In [15]:
dcounts = dprops * counts_sum_inv # counts is used multiple times, so the gradient is not complete.

In [16]:
dcounts_sum = dcounts_sum_inv * (-1/counts_sum**2)
cmp('dcounts_sum', dcounts_sum, counts_sum)

dcounts_sum     | exact: True  | approximate: True  | maxdiff: 0.0


In [17]:
dcounts += dcounts_sum * torch.ones_like(counts)
cmp('dcounts', dcounts, counts)

dcounts         | exact: True  | approximate: True  | maxdiff: 0.0


In [71]:
dcounts.shape

torch.Size([32, 27])

In [18]:
dnorm_logits = dcounts * torch.exp(norm_logits) # derivative of counts w.r.t. norm_logits
cmp('dnorm_logits', dnorm_logits, norm_logits)

dnorm_logits    | exact: True  | approximate: True  | maxdiff: 0.0


In [19]:
dlogit_maxes = -dnorm_logits.sum(1, keepdim=True) # derivative of norm_logits w.r.t. logit_maxes
cmp('dlogit_maxes', dlogit_maxes, logit_maxes)

dlogit_maxes    | exact: True  | approximate: True  | maxdiff: 0.0


In [20]:
dlogit_maxes # should be close to zero, because subtracting a number doesn't change the probability distribution

tensor([[-3.0268e-09],
        [ 3.7253e-09],
        [-1.8626e-09],
        [ 2.3283e-09],
        [-4.6566e-10],
        [ 3.2596e-09],
        [-4.6566e-10],
        [ 4.6566e-09],
        [-0.0000e+00],
        [ 4.1910e-09],
        [ 1.3970e-09],
        [-3.9581e-09],
        [ 5.3551e-09],
        [-0.0000e+00],
        [-2.3283e-09],
        [ 5.5879e-09],
        [ 3.2596e-09],
        [-0.0000e+00],
        [ 2.3283e-09],
        [ 2.7940e-09],
        [ 4.6566e-10],
        [ 6.9849e-10],
        [ 2.3283e-09],
        [ 2.0955e-09],
        [ 3.7253e-09],
        [ 9.3132e-10],
        [-0.0000e+00],
        [-6.0536e-09],
        [ 1.3970e-09],
        [ 4.6566e-09],
        [-5.5879e-09],
        [ 1.6298e-09]], grad_fn=<NegBackward0>)

In [21]:
dlogits = dnorm_logits.clone()
max_indces = F.one_hot(logits.max(1).indices,num_classes=logits.shape[1])# indices of the max logits
dlogits +=  dlogit_maxes * max_indces
cmp('dlogits', dlogits, logits)

dlogits         | exact: True  | approximate: True  | maxdiff: 0.0


In [22]:
dh = dlogits @ W2.T # derivative of logits w.r.t. h
cmp('dh', dh, h)

dh              | exact: True  | approximate: True  | maxdiff: 0.0


In [23]:
dhpreact = dh*(1 - h**2) # derivative of h w.r.t. hpreact
cmp('dhpreact', dhpreact, hpreact)

dhpreact        | exact: True  | approximate: True  | maxdiff: 0.0


In [24]:
dbngain = (bnraw * dhpreact).sum(0, keepdim=True) # derivative of hpreact w.r.t. bngain
cmp('dbngain', dbngain, bngain)

dbngain         | exact: True  | approximate: True  | maxdiff: 0.0


In [25]:
dbnraw = dhpreact * bngain
cmp('dbnraw', dbnraw, bnraw)

dbnraw          | exact: True  | approximate: True  | maxdiff: 0.0


In [26]:
dbnbias = dhpreact.sum(0, keepdim=True) # derivative of hpreact w.r.t. bnbias
cmp('dbnbias', dbnbias, bnbias)

dbnbias         | exact: True  | approximate: True  | maxdiff: 0.0


In [27]:
dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim=True) # derivative of bnraw w.r.t. bnvar_inv
cmp('dbnvar_inv', dbnvar_inv, bnvar_inv)

dbnvar_inv      | exact: True  | approximate: True  | maxdiff: 0.0


In [28]:
dbndiff = (dbnraw * bnvar_inv) # derivative of bnvar_inv w.r.t. bndiff


In [29]:
dbnvar = dbnvar_inv * (-0.5*((bnvar + 1e-5)**(-3/2))) # derivative of bnvar_inv w.r.t. bnvar
cmp('dbnvar', dbnvar, bnvar)

dbnvar          | exact: True  | approximate: True  | maxdiff: 0.0


In [30]:
dbndiff2  = dbnvar *  (1.0/(n-1))*torch.ones_like(bndiff2)
cmp('dbndiff', dbndiff2, bndiff2)

dbndiff         | exact: True  | approximate: True  | maxdiff: 0.0


In [31]:
bnvar.shape

torch.Size([1, 64])

In [32]:
bnvar

tensor([[2.9677, 1.2682, 2.1554, 1.4667, 4.6703, 2.9705, 2.0831, 1.4324, 1.3041,
         1.2941, 1.5416, 1.3771, 3.7823, 1.3272, 1.0750, 2.1390, 2.0086, 4.0767,
         0.9473, 2.6632, 5.2006, 1.3070, 2.7925, 2.8935, 1.6161, 2.7019, 1.9697,
         1.4905, 2.7771, 1.8918, 1.4346, 3.3289, 2.6032, 1.7041, 2.5652, 3.1497,
         2.4295, 3.7256, 0.3788, 1.5171, 1.6284, 2.5108, 2.3224, 1.2441, 1.9543,
         2.9438, 1.9290, 2.3941, 1.5941, 2.6831, 2.9968, 1.2193, 3.5117, 2.7399,
         3.2270, 3.6581, 1.3968, 2.2523, 2.1150, 0.7927, 1.5964, 3.3448, 2.2397,
         1.8453]], grad_fn=<MulBackward0>)

In [33]:
bndiff2.shape

torch.Size([32, 64])

In [34]:
dbndiff += dbndiff2* 2 * bndiff # derivative of bndiff2 w.r.t. bndiff
cmp('dbndiff', dbndiff, bndiff)

dbndiff         | exact: True  | approximate: True  | maxdiff: 0.0


In [35]:
dbmeani = -dbndiff.sum(0, keepdim=True) # derivative of bndiff w.r.t. bnmeani
cmp('dbmeani', dbmeani, bnmeani)

dbmeani         | exact: True  | approximate: True  | maxdiff: 0.0


In [36]:
dhprebn = dbndiff.clone()
dhprebn += dbmeani * (1.0 / n)*torch.ones_like(hprebn) # derivative of bndiff w.r.t. hprebn
cmp('dhprebn', dhprebn, hprebn)

dhprebn         | exact: True  | approximate: True  | maxdiff: 0.0


In [37]:
dembcat = dhprebn @ W1.T # derivative of hprebn w.r.t. embcat
cmp('dembcat', dembcat, embcat)

dembcat         | exact: True  | approximate: True  | maxdiff: 0.0


In [38]:
embcat.shape

torch.Size([32, 30])

In [39]:
emb.shape

torch.Size([32, 3, 10])

In [40]:
demb = dembcat.view(emb.shape) # reshape to the original shape
cmp('demb', demb, emb)

demb            | exact: True  | approximate: True  | maxdiff: 0.0


In [41]:
C.shape

torch.Size([27, 10])

In [42]:
dC = torch.zeros_like(C) # derivative of emb w.r.t. C
for i in range(Xb.shape[0]):
  for j in range(Xb.shape[1]):
    ix = Xb[i, j]
    dC[ix] += demb[i, j] # accumulate the gradients for each character

cmp('dC', dC, C)

dC              | exact: True  | approximate: True  | maxdiff: 0.0


Cross Entropy Loss Backward Pass

In [43]:
logits.shape

torch.Size([32, 27])

In [52]:
dlogits = F.softmax(logits,dim=1)/n # derivative of logits w.r.t. logits
dlogits[range(n), Yb] -= 1.0 / n # subtract the gradient at the index of Yb
cmp('dlogits', dlogits, logits)

dlogits         | exact: False | approximate: True  | maxdiff: 5.587935447692871e-09


In [53]:
logits

tensor([[ 7.8443e-01,  9.7966e-01, -5.9479e-01,  4.3666e-01, -5.1079e-01,
          9.5369e-01, -1.6717e-01,  1.6261e-01, -5.1531e-01, -4.2125e-02,
          1.3779e-01,  1.1514e-01,  1.0858e-01, -8.2557e-02,  9.9455e-02,
         -8.5229e-01, -1.1941e+00, -4.3297e-01, -6.1949e-01,  6.1622e-01,
          4.0452e-01, -4.3064e-01, -1.9092e-01,  7.6450e-01,  6.4273e-01,
         -2.6198e-01, -3.4824e-01],
        [ 3.6329e-01,  3.1453e-01,  8.3930e-01,  2.6326e-01, -9.2531e-02,
         -1.7934e-01, -7.4664e-01,  2.0051e-01, -6.5657e-01, -5.6760e-01,
          2.3579e-01,  1.4602e-01,  1.3580e-01, -2.0345e-01,  1.5115e-01,
         -2.7986e-02, -3.5169e-01, -8.3423e-01, -5.7010e-01, -6.9704e-02,
         -7.3184e-01, -6.0282e-01, -9.1704e-01,  4.4497e-01, -5.4146e-01,
         -1.2583e-01, -3.7733e-01],
        [-5.1014e-01, -4.0033e-01, -8.8524e-01, -9.6981e-01, -3.7685e-01,
          2.4569e-01,  6.2260e-01,  6.3863e-01,  6.5615e-01, -8.0189e-02,
         -4.9034e-01,  5.1649e-02,  2.80

In [65]:
F.softmax(logits,dim=1)

tensor([[0.0698, 0.0848, 0.0176, 0.0493, 0.0191, 0.0826, 0.0269, 0.0375, 0.0190,
         0.0305, 0.0365, 0.0357, 0.0355, 0.0293, 0.0352, 0.0136, 0.0096, 0.0207,
         0.0171, 0.0590, 0.0477, 0.0207, 0.0263, 0.0684, 0.0606, 0.0245, 0.0225],
        [0.0570, 0.0543, 0.0917, 0.0515, 0.0361, 0.0331, 0.0188, 0.0484, 0.0205,
         0.0225, 0.0501, 0.0458, 0.0454, 0.0323, 0.0461, 0.0385, 0.0279, 0.0172,
         0.0224, 0.0369, 0.0191, 0.0217, 0.0158, 0.0618, 0.0230, 0.0349, 0.0272],
        [0.0198, 0.0221, 0.0136, 0.0125, 0.0226, 0.0422, 0.0615, 0.0625, 0.0636,
         0.0304, 0.0202, 0.0347, 0.0437, 0.0508, 0.0222, 0.0254, 0.0142, 0.0348,
         0.0275, 0.1071, 0.0662, 0.0361, 0.0410, 0.0392, 0.0374, 0.0195, 0.0293],
        [0.0322, 0.0248, 0.0404, 0.0528, 0.0591, 0.0273, 0.0539, 0.0460, 0.0517,
         0.0175, 0.0399, 0.0257, 0.0309, 0.0427, 0.0550, 0.0634, 0.0272, 0.0264,
         0.0182, 0.0506, 0.0176, 0.0229, 0.0370, 0.0197, 0.0250, 0.0441, 0.0480],
        [0.0164, 0.0153,

In [73]:
dlogits[0].sum()

tensor(3.4925e-09, grad_fn=<SumBackward0>)