### Becoming a Backprop Ninja!

Starting witht the code from `NN_v3` (The NN model with BatchNormalization), the aim is to replace `loss.backward()` with manual backpropagation calculations.

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random

%matplotlib inline

In [2]:
BLOCK_SIZE = 3
N_EMBED = 10
N_HIDDEN = 64
N_EPOCHS = 1000
BATCH_SIZE = 32
g = torch.Generator().manual_seed(2147483647)

In [3]:
def createWordsMapping(filename = 'names.txt'):
  words = open(filename, 'r').read().splitlines()
  chars = sorted(list(set(''.join(words))))
  stoi = {s:i+1 for i,s in enumerate(chars)}
  stoi['.'] = 0
  itos = {i:s for s,i in stoi.items()}
  n_vocab = len(stoi)
  return words, stoi, itos, n_vocab

def buildDataset(words, block_size):
  X, Y = [], []
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix]
  X = torch.tensor(X)
  Y = torch.tensor(Y)
  return X,Y

def buildDatasets(words, block_size):
  random.seed(42)
  random.shuffle(words)
  
  n1 = int(0.8 * len(words))
  n2 = int(0.9 * len(words))
  
  Xtr, Ytr = buildDataset(words[:n1], block_size)
  Xdev, Ydev = buildDataset(words[n1:n2], block_size)
  Xte, Yte = buildDataset(words[n2:], block_size)

  return Xtr, Ytr, Xdev, Ydev, Xte, Yte

In [4]:
words, stoi, itos, n_vocab = createWordsMapping()
Xtr, Ytr, Xdev, Ydev, Xte, Yte = buildDatasets(words, BLOCK_SIZE)
n_embed, n_hidden, block_size, n_epochs, batch_size = N_EMBED, N_HIDDEN, BLOCK_SIZE, N_EPOCHS, BATCH_SIZE
X, Y = Xtr, Ytr

In [5]:
# Utility function to compare gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [6]:
C = torch.randn((n_vocab, n_embed), generator=g)

W1 = torch.randn((n_embed * block_size, n_hidden), generator=g) * ((5/3)/((n_embed*block_size) ** 0.5))
b1 = torch.randn(n_hidden, generator=g) * 0.1
W2 = torch.randn((n_hidden, n_vocab), generator=g) * 0.1
b2 = torch.randn(n_vocab, generator=g) * 0.1

bngain = torch.ones((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.zeros((1, n_hidden)) * 0.1

bnmean_running = torch.zeros((1, n_hidden))
bnstd_running = torch.ones((1, n_hidden))

parameters = [C, W1, b1, bngain, bnbias, W2, b2]
static_parameters = [bnmean_running, bnstd_running]

for p in parameters:
  p.requires_grad = True

print(f'Total Parameters: {sum(p.nelement() for p in parameters)}')

Total Parameters: 4137


In [7]:
# Minibatch construct
ix = torch.randint(0, X.shape[0], (batch_size,), generator=g)
X_batch, Y_batch = X[ix], Y[ix]

In [8]:
# Forward Pass

# Get Embeddings
emb = parameters[0][X_batch] # embed characters into vectors
embcat = emb.view(emb.shape[0], -1) # concatentae the vectors

# Linear Layer 1
hprebn = embcat @ W1 + b1 # hidden layer preactivation

# Batch Normalization Layer
bnmeani = 1/batch_size * hprebn.mean(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff ** 2
bnvar = 1/(batch_size-1) * (bndiff2).sum(0, keepdim=True) # Bessel's Correction: divide by (n-1), not n
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias

# Non-Linearity
h = torch.tanh(hpreact)
logits = h @ W2 + b2

# Cross Entropy Loss (Same as  F.cross_entropy(logits, Y_batch))
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for neumerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum ** -1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = - logprobs[range(batch_size), Y_batch].mean()

# Backward Pass
for p in parameters:
  p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani, embcat, emb]:
  t.retain_grad()
loss.backward()
loss

tensor(3.4533, grad_fn=<NegBackward0>)

In [9]:
print(logprobs.shape)
print(logprobs[range(batch_size), Y_batch])
# loss = - (a+b+c+...batch_size numbers)/batch_size
# dloss/da = -1/batch_size

torch.Size([32, 27])
tensor([-3.8284, -2.7227, -3.3022, -3.3303, -4.3977, -2.9712, -3.5001, -3.8357,
        -3.4755, -4.3485, -3.1524, -1.7809, -2.9787, -2.6291, -2.8781, -3.6123,
        -3.8557, -3.4175, -3.5072, -3.6878, -2.9730, -3.4440, -4.4328, -4.2365,
        -3.5293, -3.0859, -3.2361, -4.3650, -3.1878, -3.8061, -3.7368, -3.2603],
       grad_fn=<IndexBackward0>)


In [13]:
dlogprobs = torch.zeros_like(logprobs)

dlogprobs[range(batch_size), Y_batch] = -1.0/batch_size
dprobs = (1.0/probs) * dlogprobs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs

cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
