# Becoming a Backprop Ninja
* You need to understand backpropagation, because it is a leaky abstraction
* We already covered backpropagation for the scalar case, by implementing micrograd
* But we need to expand this knowledge to tensors
* We will use the same neural network as in the last lecture, but this time we will implement the backward pass manually

## The Neural Network

In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
%matplotlib inline

In [3]:
# read in all the words
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:8])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


In [4]:
# build the vocabulary of characters and mappings to/from integers
chars = list(sorted(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
vocab_size = len(stoi)
itos = {i:s for s,i in stoi.items()}
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [5]:
# build the dataset
block_size = 3
def build_dataset(words):
    X,Y = [], []
    for w in words:
        context = [0] * block_size
        for c in w+'.':
            xi = stoi[c]
            X.append(context)
            Y.append(xi)
            context = context[1:]+[xi]
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X,Y

random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))  # first 80% of the words: training set
n2 = int(0.9*len(words))  # next 10%: validation set, last 10%: test set
Xtr,Ytr=build_dataset(words[0:n1])
Xval,Yval=build_dataset(words[n1:n2])
Xtest, Ytest=build_dataset(words[n2:])
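To make the sliding context window in ```build_dataset``` concrete, here is a small standalone sketch that builds the (context, target) pairs for the single word 'emma'. The helper ```build_examples``` is hypothetical, and it assumes the same character mapping as above ('.' is 0, 'a' to 'z' are 1 to 26) with a context of three characters.

```python
import string

def build_examples(word, block_size=3):
    # hypothetical helper, same mapping as above: '.' -> 0, 'a'..'z' -> 1..26
    stoi = {s: i + 1 for i, s in enumerate(string.ascii_lowercase)}
    stoi['.'] = 0
    xs, ys = [], []
    context = [0] * block_size        # start with block_size '.' tokens
    for c in word + '.':              # '.' also marks the end of the word
        ix = stoi[c]
        xs.append(context)            # current context window
        ys.append(ix)                 # the character that follows it
        context = context[1:] + [ix]  # slide the window by one character
    return xs, ys

xs, ys = build_examples('emma')
for ctx, target in zip(xs, ys):
    print(ctx, '->', target)
```

A word of length $k$ yields $k+1$ examples, because the terminating '.' is predicted as well.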


* A new utility function compares our manually computed gradients with the ones computed by PyTorch

In [6]:
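def cmp(s, dt, t):
    # compare a manually computed gradient dt with the gradient PyTorch stored in t.grad
    ex = torch.all(dt == t.grad).item()          # bit-exact match?
    app = torch.allclose(dt, t.grad)             # match within floating point tolerance?
    maxdiff = (dt - t.grad).abs().max().item()   # largest absolute deviation
    print(f'{s:15s} | exact {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')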

In [7]:
n_embd = 10
n_hidden = 200

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd)               , generator = g)
W1 = torch.randn((block_size*n_embd, n_hidden)     , generator = g) * 5/3 / ((block_size*n_embd)**0.5) # Kaiming init: tanh gain 5/3 over sqrt(fan_in) keeps activations from shrinking
b1 = torch.randn((n_hidden)                        , generator = g) * 0.1
W2 = torch.randn((n_hidden, vocab_size)            , generator = g) * 0.1  # make less confident
b2 = torch.randn((vocab_size)                      , generator = g) * 0.1  # not zero to unmask gradient errors

bngain = torch.ones((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.zeros((1, n_hidden)) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

12297


* We do a single forward pass, so we sample just one minibatch

In [8]:
batch_size = 32
n = batch_size # shorter name for use in expressions
xi = torch.randint(0, len(words), (batch_size, ), generator = g)
Xb, Yb = Xtr[xi], Ytr[xi]

In [9]:
# forward pass "chunked" into smaller steps that are possible to backward one at a time
emb = C[Xb]
embcat = emb.view(-1, block_size*n_embd)
# Linear Layer 1
hprebn = embcat @ W1 + b1
# Batchnorm Layer
bnmeani = hprebn.mean(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1) * bndiff2.sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1 instead of n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non Linearity
h = torch.tanh(hpreact)
# Linear layer 2
logits = h @ W2 + b2
# cross entropy loss
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

# PyTorch backward pass
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
          bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
          embcat, emb]:
    t.retain_grad()
loss.backward()
loss



tensor(3.8712, grad_fn=<NegBackward0>)
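The comment on ```norm_logits``` says the row maximum is subtracted for numerical stability. A minimal standalone sketch with deliberately huge, made-up logits shows why: computing softmax directly overflows, while the max-subtracted version stays finite. The helper name ```softmax_naive_and_stable``` is hypothetical; wrapping the code in a function keeps the notebook variables above from being overwritten.

```python
import torch

def softmax_naive_and_stable(logits):
    # naive: exp(1000.) overflows float32 to inf, and inf / inf gives nan
    naive = logits.exp() / logits.exp().sum(1, keepdim=True)
    # stable: subtracting the row max does not change the result mathematically,
    # but every exponent becomes <= 0, so exp() cannot overflow
    norm_logits = logits - logits.max(1, keepdim=True).values
    counts = norm_logits.exp()
    stable = counts / counts.sum(1, keepdim=True)
    return naive, stable

naive, stable = softmax_naive_and_stable(torch.tensor([[1000.0, 1001.0, 1002.0]]))
print(naive)
print(stable)
```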

* Exercise 1: backprop through the whole thing manually, backpropagating through exactly all of the variables as they are defined in the forward pass above, one by one
* We start with ```dlogprobs```; ```logprobs``` has the following shape

In [10]:
logprobs.shape

torch.Size([32, 27])

* The gradient tensor ```dlogprobs``` must have the same shape as ```logprobs```, because it holds the derivative of ```loss``` with respect to every element
* So how does ```logprobs``` influence ```loss```?
* ```loss``` picks one entry from each row of ```logprobs``` by indexing and averages the picked values
* The result is then negated
* The index for each row of ```logprobs``` is taken from ```Yb```, which holds the correct labels for the 32 samples

In [11]:
print(Yb)
Yb.shape

tensor([ 0, 15,  1,  9, 18, 20,  5, 14,  0, 25,  5,  9,  9,  9, 20,  0,  0,  9,
         1,  2,  1,  0,  9,  1,  1,  1,  9, 14,  3, 12, 14,  5])


torch.Size([32])

* As a simpler example with only three picked values $a$, $b$ and $c$, the loss becomes
  
  $loss=-(a+b+c)/3=-a/3 - b/3 - c/3$
* Differentiating with respect to each variable gives, for the example of $a$:
  
  $dloss/da=-1/3$

* More generally, the derivative is $-1/n$ for $n$ picked values
* Only one entry in each of the 32 rows is used in the $loss$ calculation; the other entries don't influence the $loss$ at all and thus receive a gradient of 0
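A minimal standalone check of this reasoning, on toy sizes chosen for illustration (4 rows, 5 columns, random values), compares the hand-built gradient with autograd. The helper name ```check_dlogprobs``` is hypothetical.

```python
import torch

def check_dlogprobs(n=4, vocab=5):
    logprobs = torch.randn(n, vocab, requires_grad=True)
    Yb = torch.randint(0, vocab, (n,))          # one "correct" column per row
    loss = -logprobs[range(n), Yb].mean()
    loss.backward()
    # manual gradient: -1/n at the picked entries, 0 everywhere else
    dlogprobs = torch.zeros(n, vocab)
    dlogprobs[range(n), Yb] = -1.0 / n
    return dlogprobs, logprobs.grad

manual, auto = check_dlogprobs()
print(manual)
```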

In [12]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n

In [13]:
cmp('logprobs', dlogprobs, logprobs)

logprobs        | exact True  | approximate: True  | maxdiff: 0.0


* We continue by differentiating ```logprobs``` with respect to its input
* ```logprobs``` depends only on ```probs``` and results from applying ```log``` to every entry
* So each entry of ```probs``` affects the corresponding entry of ```logprobs``` through the derivative of ```log```, which is ```1/x```, where ```x``` is the entry
* As ```logprobs``` and ```probs``` have the same shape, we just calculate ```1/x``` for each entry
* Finally we apply the chain rule and multiply by ```dlogprobs``` to arrive at the derivative of ```loss``` with respect to ```probs```
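The same step can be checked in isolation. This sketch assumes random positive toy inputs and uses a common trick: summing ```probs.log() * dlogprobs``` before calling ```backward()``` injects ```dlogprobs``` as the upstream gradient. The helper name ```check_dprobs``` is hypothetical.

```python
import torch

def check_dprobs():
    probs = (torch.rand(3, 4) + 0.1).requires_grad_()   # keep values away from 0
    dlogprobs = torch.randn(3, 4)                       # stand-in upstream gradient
    # summing log(probs) * dlogprobs makes dlogprobs the upstream gradient
    (probs.log() * dlogprobs).sum().backward()
    dprobs = 1.0 / probs.detach() * dlogprobs           # local 1/x times upstream
    return dprobs, probs.grad

manual, auto = check_dprobs()
print(torch.allclose(manual, auto))
```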

In [14]:
dprobs = 1/probs * dlogprobs

In [15]:
cmp('dprobs', dprobs, probs)

dprobs          | exact True  | approximate: True  | maxdiff: 0.0


* Moving on to the inputs of ```probs```, which is computed by multiplying ```counts``` and ```counts_sum_inv``` element-wise
* The derivative of a product with respect to one factor is the other factor:
  
  $d(counts\_sum\_inv*counts)/dcounts\_sum\_inv=counts$
* So the local derivative for ```counts_sum_inv``` is ```counts```, which the chain rule multiplies by ```dprobs```
* But we need to take into account that ```counts_sum_inv``` has a different shape than ```probs```

In [16]:
probs.shape, counts_sum_inv.shape

(torch.Size([32, 27]), torch.Size([32, 1]))

* We can see that the single column of ```counts_sum_inv``` is broadcast into each of the 27 columns of ```probs```
* We learned in the micrograd lecture that if a variable takes part in multiple expressions, the gradients from those expressions must be summed for that variable
* Each element of the ```counts_sum_inv``` column vector is used in the 27 multiplications of one row of ```counts```, so we sum the 27 gradients of that row, i.e. the row sum of ```counts * dprobs```
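A small standalone sketch (toy shape [3, 4], random values, hypothetical helper name ```check_dcounts_sum_inv```) illustrates the general rule: a dimension that is broadcast in the forward pass is summed over in the backward pass.

```python
import torch

def check_dcounts_sum_inv():
    counts = torch.rand(3, 4)
    counts_sum_inv = torch.rand(3, 1, requires_grad=True)   # one value per row
    dprobs = torch.randn(3, 4)                               # stand-in upstream gradient
    probs = counts * counts_sum_inv        # [3, 1] is broadcast to [3, 4]
    (probs * dprobs).sum().backward()
    # each row value took part in 4 products, so its 4 gradients are summed
    dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
    return dcounts_sum_inv, counts_sum_inv.grad

manual, auto = check_dcounts_sum_inv()
print(manual.shape, torch.allclose(manual, auto))
```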

In [17]:
dcounts_sum_inv= (counts*dprobs).sum(1, keepdim=True)
cmp('dcounts_sum_inv', dcounts_sum_inv, counts_sum_inv)

dcounts_sum_inv | exact True  | approximate: True  | maxdiff: 0.0


* Next up is ```counts```, which appears in the two expressions that result in ```probs``` and ```counts_sum```
* So before we can compute the derivative of ```counts```, we first need to differentiate ```counts_sum_inv``` with respect to ```counts_sum```
* ```counts_sum``` is inverted, so the derivative becomes

  $d(counts\_sum^{-1})/dcounts\_sum=-counts\_sum^{-2}$
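Derivatives like this one can also be sanity-checked numerically. The sketch below compares the analytic derivative of $x^{-1}$ with a central finite-difference estimate at an arbitrarily chosen point $x = 2.5$; the step size ```h``` is an assumption that balances truncation and rounding error in double precision.

```python
def inv(x):
    return x ** -1

def df_analytic(x):
    return -1 * x ** -2

def df_numeric(x, h=1e-6):
    # central difference: error shrinks with h**2 until rounding error takes over
    return (inv(x + h) - inv(x - h)) / (2 * h)

x = 2.5
print(df_analytic(x), df_numeric(x))
```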
  

In [18]:
dcounts_sum = -1*counts_sum**-2 * dcounts_sum_inv
cmp('dcounts_sum', dcounts_sum, counts_sum)

dcounts_sum     | exact True  | approximate: True  | maxdiff: 0.0


* Now we can compute the derivative of ```counts```
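* Every element of ```counts``` enters its row sum in ```counts_sum``` with coefficient 1, so the local derivative is one for each element, and the gradient ```dcounts_sum``` of a row is broadcast back to every element of that row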

  $d(counts\_sum)/dcounts=1$

* In ```probs```, ```counts``` is multiplied with the broadcasted ```counts_sum_inv```, so

  $d(probs)/dcounts=d(counts * counts\_sum\_inv)/dcounts=counts\_sum\_inv$
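Because ```counts``` reaches ```probs``` along two paths, its gradient is the sum of both contributions. The following standalone sketch (toy shape [3, 4], random inputs, hypothetical helper ```check_dcounts```) chains the steps derived so far and compares the result with autograd.

```python
import torch

def check_dcounts():
    counts = torch.rand(3, 4, requires_grad=True)
    dprobs = torch.randn(3, 4)                    # stand-in upstream gradient
    counts_sum = counts.sum(1, keepdim=True)
    counts_sum_inv = counts_sum ** -1
    probs = counts * counts_sum_inv
    (probs * dprobs).sum().backward()
    with torch.no_grad():
        # path 1: counts -> counts_sum -> counts_sum_inv -> probs
        dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
        dcounts_sum = -counts_sum ** -2 * dcounts_sum_inv
        dcounts = torch.ones_like(counts) * dcounts_sum   # broadcast back over the row
        # path 2: counts -> probs directly; gradients from both paths add up
        dcounts += counts_sum_inv * dprobs
    return dcounts, counts.grad

manual, auto = check_dcounts()
print(torch.allclose(manual, auto, atol=1e-6))
```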

In [19]:
dcounts = torch.ones_like(counts) * dcounts_sum
dcounts += counts_sum_inv * dprobs
cmp('dcounts', dcounts, counts)

dcounts         | exact True  | approximate: True  | maxdiff: 0.0


In [20]:
dnorm_logits = counts * dcounts
cmp('dnorm_logits', dnorm_logits, norm_logits)

dnorm_logits    | exact True  | approximate: True  | maxdiff: 0.0


In [21]:
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
cmp('dlogit_maxes', dlogit_maxes, logit_maxes)

dlogit_maxes    | exact True  | approximate: True  | maxdiff: 0.0


In [22]:
dlogits = torch.zeros_like(logits)
dlogits[range(n), logits.max(1).indices] = 1
dlogits *= dlogit_maxes
dlogits += dnorm_logits.clone()
cmp('dlogits', dlogits, logits)

dlogits         | exact True  | approximate: True  | maxdiff: 0.0
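The cells above use two local derivatives: ```exp``` is its own derivative (hence ```counts * dcounts```), and ```max``` passes the gradient only to the arg-max entry of each row. The sketch below builds that arg-max indicator with ```F.one_hot``` instead of index assignment and checks it against autograd on random toy logits; the helper name is hypothetical.

```python
import torch
import torch.nn.functional as F

def check_dlogits_from_max():
    logits = torch.randn(4, 5, requires_grad=True)
    dlogit_maxes = torch.randn(4, 1)              # stand-in upstream gradient
    logit_maxes = logits.max(1, keepdim=True).values
    (logit_maxes * dlogit_maxes).sum().backward()
    # only the arg-max entry of each row receives the row's gradient
    onehot = F.one_hot(logits.max(1).indices, num_classes=logits.shape[1])
    dlogits = onehot * dlogit_maxes               # int64 one-hot is promoted to float
    return dlogits, logits.grad

manual, auto = check_dlogits_from_max()
print(torch.allclose(manual, auto))
```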


In [23]:
h.shape

torch.Size([32, 200])

* Now we move on to ```logits = h @ W2 + b2```
* Writing down the whole matrix multiply element by element shows that the partial derivatives of the expression ```h @ W2 + b2``` are also matrix multiplications, in particular

  $dL/dh = (dL/dlogits)\,W2^T$
  
  $dL/dW2 = h^T\,(dL/dlogits)$

* The bias ```b2``` is broadcast across the rows (the batch dimension) of the result matrix, so every entry in column $j$ of the $dL/dlogits$ matrix adds to the partial derivative of entry $j$ of ```b2```
  
  $dL/db2_j = \sum_i (dL/dlogits)_{ij}$, which is ```dlogits.sum(0)``` in code

* A shortcut to avoid having to remember these formulas is to look at the shapes of the operands, which have to match up
* For example, $dL/dh$ must have the same shape as $h$
* It results from a matrix multiplication of $dL/dlogits$ with the other factor of the product ($W2$ in this example), transposed as needed to make the shapes line up
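A compact standalone check of all three formulas on toy shapes (a batch of 5, 4 hidden units, 3 outputs, all chosen arbitrarily) is sketched below; the comments trace how the shapes line up. The helper ```check_linear_backward``` is hypothetical.

```python
import torch

def check_linear_backward():
    h = torch.randn(5, 4, requires_grad=True)
    W2 = torch.randn(4, 3, requires_grad=True)
    b2 = torch.randn(3, requires_grad=True)
    dlogits = torch.randn(5, 3)                  # stand-in upstream gradient
    logits = h @ W2 + b2
    (logits * dlogits).sum().backward()
    with torch.no_grad():
        dh = dlogits @ W2.T      # [5,3] @ [3,4] -> [5,4], the shape of h
        dW2 = h.T @ dlogits      # [4,5] @ [5,3] -> [4,3], the shape of W2
        db2 = dlogits.sum(0)     # sum over the batch -> [3], the shape of b2
    return (dh, dW2, db2), (h.grad, W2.grad, b2.grad)

manual, auto = check_linear_backward()
for m, a in zip(manual, auto):
    print(m.shape, torch.allclose(m, a, atol=1e-6))
```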

In [24]:
print(f'h: {h.shape}')
print(f'W2: {W2.shape}')
print(f'dlogits: {dlogits.shape}')

h: torch.Size([32, 200])
W2: torch.Size([200, 27])
dlogits: torch.Size([32, 27])


* The only way we arrive at a dimension of ```[32,200]``` is to multiply ```dlogits``` with ```W2``` transposed

In [25]:
dh = dlogits @ W2.T
cmp('h', dh, h)

h               | exact True  | approximate: True  | maxdiff: 0.0


* The same reasoning applies to $dL/dW2$, which must have the same shape as ```W2``` and results from a matrix multiplication of ```h``` and ```dlogits```

In [26]:
print(f'W2: {W2.shape}')
print(f'h: {h.shape}')
print(f'dlogits: {dlogits.shape}')

W2: torch.Size([200, 27])
h: torch.Size([32, 200])
dlogits: torch.Size([32, 27])


In [27]:
dW2 = h.T @ dlogits
cmp('dW2', dW2, W2)

dW2             | exact True  | approximate: True  | maxdiff: 0.0


* And finally ```db2```, which sums ```dlogits``` over the batch dimension (dim 0), giving one value per column

In [28]:
db2 = dlogits.sum(0)  # shape [27], same as b2
cmp('db2', db2, b2)

db2             | exact True  | approximate: True  | maxdiff: 0.0


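* Next up is $dL/dhpreact$, which backpropagates through $tanh$
* One form of the derivative of $tanh$ is $1/cosh^2(x)$, which equals $1 - tanh^2(x)$; the second form lets us reuse the forward output ```h```, so the local derivative is ```1 - h*h```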
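A quick numerical sketch, on an arbitrary grid of inputs, confirms that both forms of the derivative agree with each other and with autograd.

```python
import torch

x = torch.linspace(-3, 3, 13)
via_cosh = 1 / torch.cosh(x) ** 2          # 1 / cosh^2(x)
via_tanh = 1 - torch.tanh(x) ** 2          # 1 - tanh^2(x), i.e. 1 - h*h with h = tanh(x)

x_req = x.clone().requires_grad_()
torch.tanh(x_req).sum().backward()          # autograd's derivative of tanh
print(torch.allclose(via_cosh, via_tanh, atol=1e-6),
      torch.allclose(via_tanh, x_req.grad, atol=1e-6))
```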

In [29]:
dhpreact = (1. - h*h) * dh
cmp('dhpreact', dhpreact, hpreact)

dhpreact        | exact False | approximate: True  | maxdiff: 4.656612873077393e-10


In [30]:
bngain.shape, bnraw.shape, bnbias.shape

(torch.Size([1, 200]), torch.Size([32, 200]), torch.Size([1, 200]))

In [31]:
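# hpreact = bngain * bnraw + bnbias
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)   # bngain is broadcast over the batch, so sum
dbnraw = (bngain * dhpreact)
dbnbias = dhpreact.sum(0, keepdim=True)
# bnraw = bndiff * bnvar_inv
dbndiff = bnvar_inv * dbnraw                        # first contribution to dbndiff
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# bnvar_inv = (bnvar + 1e-5)**-0.5
dbnvar = -0.5 * (bnvar + 1e-5)**-1.5 * dbnvar_inv
# bnvar = 1/(n-1) * bndiff2.sum(0): the gradient is broadcast back to every row
dbndiff2 = 1/(n-1) * torch.ones_like(bndiff2) * dbnvar
# bndiff2 = bndiff**2: second contribution to dbndiff, accumulated with +=
dbndiff += 2*bndiff * dbndiff2
# bndiff = hprebn - bnmeani: bnmeani is broadcast over the batch, so sum (with a minus sign)
dbnmeani = -dbndiff.sum(0, keepdim=True)
# bnmeani = hprebn.mean(0): every row receives 1/n of the gradient
dhprebn = (1/hprebn.shape[0])*torch.ones_like(hprebn)*dbnmeani
dhprebn += dbndiff                                  # direct path through bndiff
# hprebn = embcat @ W1 + b1
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)                                # shape [200], same as b1
# embcat = emb.view(...): undo the reshape
demb = dembcat.view(-1, block_size, n_embd)
# emb = C[Xb]: add each embedding gradient into the row of C it was read from
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += demb[k,j]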


cmp('dbngain', dbngain, bngain)
cmp('dbnraw', dbnraw, bnraw)
cmp('dbnbias', dbnbias, bnbias)
cmp('dbndiff', dbndiff, bndiff)
cmp('dbnvar_inv', dbnvar_inv, bnvar_inv)
cmp('dbndiff2', dbndiff2, bndiff2)
cmp('dbnmeani', dbnmeani, bnmeani)
cmp('dhprebn', dhprebn, hprebn)
cmp('dembcat', dembcat, embcat)
cmp('dW1', dW1, W1)
cmp('db1', db1, b1)
cmp('dC', dC, C)

dbngain         | exact False | approximate: True  | maxdiff: 2.7939677238464355e-09
dbnraw          | exact False | approximate: True  | maxdiff: 9.313225746154785e-10
dbnbias         | exact False | approximate: True  | maxdiff: 3.725290298461914e-09
dbndiff         | exact False | approximate: True  | maxdiff: 9.313225746154785e-10
dbnvar_inv      | exact False | approximate: True  | maxdiff: 3.026798367500305e-09
dbndiff2        | exact False | approximate: True  | maxdiff: 2.9103830456733704e-11
dbnmeani        | exact False | approximate: True  | maxdiff: 3.725290298461914e-09
dhprebn         | exact False | approximate: True  | maxdiff: 9.313225746154785e-10
dembcat         | exact False | approximate: True  | maxdiff: 3.725290298461914e-09
dW1             | exact False | approximate: True  | maxdiff: 5.122274160385132e-09
db1             | exact False | approximate: True  | maxdiff: 3.725290298461914e-09
dC              | exact False | approximate: True  | maxdiff: 9.3132257461
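The double loop that accumulates ```dC``` can also be written as a single call to ```Tensor.index_add_```, which adds each row of the flattened ```demb``` into the row of ```dC``` selected by the matching entry of the flattened ```Xb```. This sketch uses random toy data with the notebook's sizes (27 characters, 10-dimensional embeddings, context of 3, batch of 32) and a hypothetical helper ```check_dC```; it compares the loop, the vectorized version and autograd.

```python
import torch

def check_dC(vocab_size=27, n_embd=10, block_size=3, n=32):
    g = torch.Generator().manual_seed(0)
    C = torch.randn(vocab_size, n_embd, generator=g, requires_grad=True)
    Xb = torch.randint(0, vocab_size, (n, block_size), generator=g)
    demb = torch.randn(n, block_size, n_embd, generator=g)   # stand-in upstream gradient
    (C[Xb] * demb).sum().backward()
    # loop version, as in the cell above
    dC_loop = torch.zeros_like(C)
    for k in range(Xb.shape[0]):
        for j in range(Xb.shape[1]):
            dC_loop[Xb[k, j]] += demb[k, j]
    # vectorized version: scatter-add every row of demb into the row of C it was read from
    dC_vec = torch.zeros_like(C).index_add_(0, Xb.view(-1), demb.view(-1, n_embd))
    return dC_loop, dC_vec, C.grad

dC_loop, dC_vec, dC_auto = check_dC()
print(torch.allclose(dC_loop, dC_auto, atol=1e-6), torch.allclose(dC_vec, dC_auto, atol=1e-6))
```

Characters that appear several times in the batch automatically accumulate their gradients, just as the ```+=``` in the loop does.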

## Exercise 2: Optimize Cross Entropy Loss Backward Pass
* Cross entropy, i.e. the function that computes the loss from the raw logits, is a single call in PyTorch

In [32]:
F.cross_entropy(logits, Yb)

tensor(3.8712, grad_fn=<NllLossBackward0>)

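* Because it is a single call, PyTorch computes the forward and backward pass without the many intermediate tensors of our chunked version, which makes both passes faster
* If you write down the cross entropy function and its derivative analytically, the gradient with respect to the logits simplifies to $dL/dlogits = (softmax(logits) - onehot(Yb))/n$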
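A minimal standalone sketch of this simplified backward pass, assuming the default mean reduction of ```F.cross_entropy``` and random toy logits (8 samples, 27 classes), compares it with autograd. The helper name ```check_cross_entropy_backward``` is hypothetical.

```python
import torch
import torch.nn.functional as F

def check_cross_entropy_backward(n=8, vocab_size=27):
    logits = torch.randn(n, vocab_size, requires_grad=True)
    Yb = torch.randint(0, vocab_size, (n,))
    loss = F.cross_entropy(logits, Yb)      # default reduction averages over the batch
    loss.backward()
    with torch.no_grad():
        dlogits = F.softmax(logits, 1)      # softmax probabilities
        dlogits[range(n), Yb] -= 1          # subtract the one-hot target
        dlogits /= n                        # account for the mean over n samples
    return dlogits, logits.grad

manual, auto = check_cross_entropy_backward()
print(torch.allclose(manual, auto, atol=1e-6))
```

Since each softmax row sums to 1 and exactly 1 is subtracted per row, every row of this gradient sums to zero.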