# Backpropagation

| Date | User | Change Type | Remarks |  
| ---- | ---- | ----------- | ------- |
| 22/10/2025   | Martin | Created   | Notebook to learn about backpropagation | 
| 24/10/2025   | Martin | Update   | Started with individual element derivatives | 
| 27/10/2025   | Martin | Update   | Continued with backpropagation | 

# Content

* [Dataset Creation](#dataset-creation)

In [43]:
%load_ext watermark

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark


In [44]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# Dataset Creation

These are the same functions as in the previous section.

In [45]:
# Read in all the words
words = open('data/names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [46]:
# Build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {v: k+1 for k, v in enumerate(chars)}
stoi['.'] = 0
itos = {v: k for k, v in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


In [47]:
def build_dataset(words):
  block_size = 3
  X, Y = [], []
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix]
    
  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)

  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

X_train, y_train = build_dataset(words[:n1])
X_val, y_val = build_dataset(words[n1:n2])
X_test, y_test = build_dataset(words[n2:])

block_size = 3

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


# Manual Backpropagation

Backpropagate manually through every intermediate variable, in the order they are defined in the forward pass.

In [48]:
# Function to compare the gradients between manually calculated and torch calculated
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item() # Checks if they are exactly the same
  app = torch.allclose(dt, t.grad) # Checks if they are within a tolerance
  maxdiff = (dt - t.grad).abs().max().item()
  print(f"{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}")

The weights and biases are initialised slightly differently here: each one is scaled by a small non-zero factor. This is because initialising with, e.g., all zeros can sometimes mask an incorrect implementation of the backward pass.
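A minimal sketch of how a zero initialisation can hide a bug. The helper `check_tanh_backward`, the deliberately wrong formula `(1 - h)`, and the tiny shapes are all hypothetical and chosen only for illustration; they are not part of the notebook's network.

```python
import torch

def check_tanh_backward(W2):
    """Compare a correct and a deliberately wrong tanh backward against autograd."""
    g = torch.Generator().manual_seed(0)
    pre = torch.randn((8, 4), generator=g, requires_grad=True)  # stand-in for h_preact
    h = torch.tanh(pre)
    h.retain_grad()
    loss = (h @ W2).sum()
    loss.backward()

    d_h = h.grad
    h_val = h.detach()
    d_pre_correct = (1.0 - h_val**2) * d_h  # true derivative of tanh
    d_pre_buggy = (1.0 - h_val) * d_h       # hypothetical bug: forgot the square
    return (torch.allclose(d_pre_correct, pre.grad),
            torch.allclose(d_pre_buggy, pre.grad))

g = torch.Generator().manual_seed(1)
W2_zero = torch.zeros((4, 3))                       # all-zero initialisation
W2_small = torch.randn((4, 3), generator=g) * 0.1   # small non-zero initialisation

zero_result = check_tanh_backward(W2_zero)
small_result = check_tanh_backward(W2_small)
print("zero init :", zero_result)
print("small init:", small_result)
```

With `W2_zero`, the gradient reaching `h` is zero everywhere, so the wrong formula also yields zeros and passes the check. With `W2_small`, only the correct formula should match autograd.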

In [49]:
# Initialise weights and biases
n_embd = 10
n_hidden = 64

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g) 
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size) ** 0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1
# Layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

# BatchNorm parameters
bngain = torch.ones((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.ones((1, n_hidden)) * 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
  p.requires_grad = True

print(f"Total number of parameters: {sum(p.nelement() for p in parameters)}")

Total number of parameters: 4137


In [50]:
batch_size = 32
n = batch_size

# Minibatch
ix = torch.randint(0, X_train.shape[0], (batch_size, ), generator=g)
X_batch, y_batch = X_train[ix], y_train[ix]

In [51]:
# Explicit forward pass
emb = C[X_batch]
embcat = emb.view(emb.shape[0], -1)

# Linear layer 1
h_pre_bn = embcat @ W1 + b1 # Hidden layer pre-activation
# BatchNorm layer
bn_mean_i = 1/n * h_pre_bn.sum(0, keepdim=True)
bn_diff = h_pre_bn - bn_mean_i
bn_diff_sq = bn_diff ** 2
bn_var = 1/(n-1) * (bn_diff_sq).sum(0, keepdim=True)
bn_var_inv = (bn_var + 1e-5)**-0.5
bn_raw = bn_diff * bn_var_inv
h_preact = bngain * bn_raw + bnbias
# Non-linearity
h = torch.tanh(h_preact)

# Linear layer 2
logits = h @ W2 + b2
# Cross entropy loss
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), y_batch].mean()

# Pytorch backward pass
for p in parameters:
  p.grad = None
for t in [
  logprobs, probs, counts, counts_sum, counts_sum_inv,
  norm_logits, logit_maxes, logits, h, h_preact, bn_raw,
  bn_var, bn_var_inv, bn_diff_sq, bn_diff, h_pre_bn, bn_mean_i,
  embcat, emb
]:
  t.retain_grad()
loss.backward()
loss

tensor(3.3603, grad_fn=<NegBackward0>)

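Manual backpropagation follows.

- The gradient of a tensor always has the same shape as the tensor itself: use the tensor shapes to work out which operation (sum, broadcast, transpose) is needed
- From the multivariable chain rule: when a tensor feeds into more than one place (including through broadcasting), its gradient is the sum of the contributions from each use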
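The two rules above can be sketched on a toy broadcast, assuming a made-up `(4, 3)` tensor `a` and a `(1, 3)` tensor `b` (neither appears in the notebook). Because `b` is reused for every row in the forward pass, its gradient is the sum over rows, which also restores its original shape.

```python
import torch

g = torch.Generator().manual_seed(0)
a = torch.randn((4, 3), generator=g, requires_grad=True)
b = torch.randn((1, 3), generator=g, requires_grad=True)  # broadcast across the 4 rows

out = a * b            # b is used once per row of a
out.retain_grad()
loss = (out ** 2).sum()
loss.backward()

d_out = out.grad
d_a = b.detach() * d_out                          # same shape as a: (4, 3)
d_b = (a.detach() * d_out).sum(0, keepdim=True)   # sum the 4 row contributions -> (1, 3)

print(d_a.shape == a.shape, torch.allclose(d_a, a.grad))
print(d_b.shape == b.shape, torch.allclose(d_b, b.grad))
```

A broadcast in the forward pass becomes a sum in the backward pass, and a sum in the forward pass becomes a broadcast in the backward pass.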

<u>Notes on Interpretation</u>

- `probs`: `d_probs = (1.0 / probs) * d_logprobs`, so when the probability assigned to the correct class is low, the `1/probs` factor boosts the gradient and pushes the weights harder towards the correct class
- `norm_logits`: The gradient of `logit_maxes` should be zero (or very close to zero, because of floating-point precision). Subtracting the row maximum only shifts the logits to prevent overflow in the exponentiation that follows. Softmax is invariant to such a shift, so the probabilities do not change, and `logit_maxes` should therefore have no impact on the update step, i.e. zero gradient
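The `logit_maxes` note can be checked in isolation with a small sketch. The random `(5, 7)` logits and targets are stand-ins, and `F.cross_entropy` is used here in place of the notebook's step-by-step loss purely to keep the example short.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)
logits = torch.randn((5, 7), generator=g, requires_grad=True)
targets = torch.randint(0, 7, (5,), generator=g)

logit_maxes = logits.max(1, keepdim=True).values
logit_maxes.retain_grad()
norm_logits = logits - logit_maxes  # shift each row so its maximum is 0

loss = F.cross_entropy(norm_logits, targets)
loss.backward()

# The shift leaves the softmax probabilities unchanged...
same_probs = torch.allclose(F.softmax(logits, dim=1), F.softmax(norm_logits, dim=1))
# ...so the gradient flowing into logit_maxes is zero up to floating-point error
max_grad = logit_maxes.grad.abs().max().item()
print(same_probs, max_grad)
```

Any non-zero value printed for `max_grad` should be tiny (on the order of float32 rounding error), not a meaningful gradient.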

In [None]:
d_logprobs = torch.zeros_like(logprobs)
d_logprobs[range(n), y_batch] = -1.0/n # derivative
cmp('logprobs', d_logprobs, logprobs)

d_probs = (1.0 / probs) * d_logprobs # chain rule 
cmp('probs', d_probs, probs)

d_counts_sum_inv = (counts * d_probs).sum(1, keepdim=True)
cmp('count_sum_inv', d_counts_sum_inv, counts_sum_inv)

d_counts_sum = (-1.0 / counts_sum**2.0) * d_counts_sum_inv
cmp('counts_sum', d_counts_sum, counts_sum)

# counts is used twice, so d_counts sums the gradient from both paths
# 1. probs = counts * counts_sum_inv
# 2. counts_sum = counts.sum(...)
d_counts = d_probs * counts_sum_inv + torch.ones_like(counts) * d_counts_sum
cmp('counts', d_counts, counts)

d_norm_logits = norm_logits.exp() * d_counts
cmp('norm_logits', d_norm_logits, norm_logits)

d_logit_maxes = (-d_norm_logits).sum(1, keepdim=True)
cmp('logit_maxes', d_logit_maxes, logit_maxes)

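# logits is used twice, so d_logits sums the gradient from both paths
# 1. norm_logits = logits - logit_maxes
# 2. logit_maxes = logits.max(1, keepdim=True).values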
temp = F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * d_logit_maxes # logits_max
d_logits = (d_norm_logits.clone()) + temp
cmp('logits', d_logits, logits)

# Linear layer
d_h = d_logits @ W2.T
d_W2 = h.T @ d_logits
d_b2 = d_logits.sum(0) # do not reassign d_logits here; sum over the batch dimension only
cmp('h', d_h, h)
cmp('W2', d_W2, W2)
cmp('b2', d_b2, b2)

# Non-linearity
d_h_preact = (1.0 - h**2)  * d_h
cmp('h_preact', d_h_preact, h_preact)

# Batch Normalisation
d_bngain = (bn_raw * d_h_preact).sum(0, keepdim=True)
d_bn_raw = bngain * d_h_preact
d_bnbias = d_h_preact.sum(0, keepdim=True)
cmp('bngain', d_bngain, bngain)
cmp('bn_raw', d_bn_raw, bn_raw)
cmp('bnbias', d_bnbias, bnbias)

d_bn_var_inv = (bn_diff * d_bn_raw).sum(0, keepdim=True)
cmp('bn_var_inv', d_bn_var_inv, bn_var_inv)

# d_bn_var = (-0.5) * (1.0/(bn_var + 1e-5)**(3/2)) * d_bn_var_inv
d_bn_var = (-0.5*(bn_var + 1e-5)**-1.5) * d_bn_var_inv
cmp('bn_var', d_bn_var, bn_var)

d_bn_diff_sq = (1.0/(n-1)) * torch.ones_like(bn_diff_sq) * d_bn_var
cmp('bn_diff_sq', d_bn_diff_sq, bn_diff_sq)

# 1. bn_raw = bn_diff * bn_var_inv
# 2. bn_diff_sq = bn_diff**2
temp = 2 * bn_diff * d_bn_diff_sq
d_bn_diff = bn_var_inv * d_bn_raw + temp
cmp('bn_diff', d_bn_diff, bn_diff)

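d_bn_mean_i = (-d_bn_diff).sum(0, keepdim=True) # keepdim so the shape matches bn_mean_i: (1, n_hidden)
cmp('bn_mean_i', d_bn_mean_i, bn_mean_i)

# h_pre_bn is used twice, so its gradient sums both paths
# 1. bn_diff = h_pre_bn - bn_mean_i
# 2. bn_mean_i = 1/n * h_pre_bn.sum(0, keepdim=True)
d_h_pre_bn = d_bn_diff.clone() + (1.0/n) * torch.ones_like(h_pre_bn) * d_bn_mean_i
cmp('h_pre_bn', d_h_pre_bn, h_pre_bn)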

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
count_sum_inv   | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
h_preact        | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bn_raw          | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff:

In [41]:
bn_mean_i.shape, d_bn_diff.shape

(torch.Size([1, 64]), torch.Size([32, 64]))

In [35]:
d_counts.shape

torch.Size([32, 27])

In [33]:
d_probs.shape

torch.Size([32, 27])

In [None]:
%watermark