In [63]:
import random
from typing import Final

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
%matplotlib inline

In [64]:
# Load the dataset file and prepare the vocabulary.

# Read the names inside a `with` block so the file handle is closed as soon
# as the contents have been read (one name per line).
with open("names.txt", 'r') as names_file:
    words: Final[list[str]] = names_file.read().splitlines()

# Every distinct character that occurs in the names, in sorted order.
chars: Final[list[str]] = sorted(set(''.join(words)))
# Characters are numbered from 1; index 0 is assigned to '.', the character
# that create_dataset appends to every word.
stoi: Final[dict[str, int]] = {char: i + 1 for i, char in enumerate(chars)}
stoi['.'] = 0
# Reverse lookup: index -> character.
itos: Final[dict[int, str]] = {i: char for char, i in stoi.items()}
# Vocabulary size, including the '.' entry.
char_cnt: Final[int] = len(stoi)

In [65]:
# Create the dataset and the corresponding labels.

def create_dataset(words: list[str], block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Build (context, next-character) training pairs from a list of words.

    Each word is extended with a trailing '.' and scanned left to right. For
    every character, the `block_size` preceding character indices form one
    input row (left-padded with 0) and the character's own index is the label.
    Character indices are looked up in the module-level `stoi` mapping.

    Args:
        words: Words to convert into examples.
        block_size: Number of preceding characters used as context.

    Returns:
        A tuple `(X, Y)` of integer tensors: `X` has shape
        `(num_examples, block_size)` and `Y` has shape `(num_examples,)`.
        An empty `words` list yields shapes `(0, block_size)` and `(0,)`.

    Raises:
        KeyError: If a word contains a character that is missing from `stoi`.
    """
    X: list[list[int]] = []
    Y: list[int] = []
    for word in words:
        # Start every word from an all-zero context.
        context: list[int] = [0] * block_size
        for ch in word + '.':
            target = stoi[ch]
            X.append(context)
            Y.append(target)
            # Slide the window: drop the oldest index, append the newest.
            # This creates a new list, so rows already stored in X are untouched.
            context = context[1:] + [target]
    # Build each tensor exactly once. The explicit dtype and shape keep the
    # result an integer (N, block_size) tensor even when there are no examples.
    X_t = torch.tensor(X, dtype=torch.long).reshape(len(X), block_size)
    Y_t = torch.tensor(Y, dtype=torch.long)
    print(X_t.shape, Y_t.shape)
    return X_t, Y_t

# Run on the GPU when CUDA is available; otherwise stay on the CPU.
device: Final[torch.device] = (
    torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
)

In [66]:
# Prepare the dataset for training, validation, and testing.

# Shuffle with a fixed seed so a fresh run always produces the same split.
# NOTE: the shuffle is in place; re-running this cell without re-running the
# loading cell reshuffles an already shuffled list and changes the split.
random.seed(42)
random.shuffle(words)

block_size: Final[int] = 3 # context length
train_set_ratio: float = 0.8
valid_set_ratio: float = 0.1
test_set_ratio: float = 0.1
# The test split is "whatever is left", so the three ratios must cover
# the whole dataset.
assert abs(train_set_ratio + valid_set_ratio + test_set_ratio - 1.0) < 1e-9, \
    "train/valid/test ratios must sum to 1"

# Split boundaries: [0, n1) train, [n1, n2) validation, [n2, end) test.
n1: int = int(len(words) * train_set_ratio)
n2: int = int(len(words) * (train_set_ratio + valid_set_ratio))

# Build each split and move both of its tensors to the target device.
X_train, Y_train = (t.to(device) for t in create_dataset(words[:n1], block_size))
X_val, Y_val = (t.to(device) for t in create_dataset(words[n1:n2], block_size))
X_test, Y_test = (t.to(device) for t in create_dataset(words[n2:], block_size))


torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [67]:
# Define gradient checking functions.
def cmp_grad(p_name: str, dp: torch.Tensor, p: torch.Tensor) -> None:
    """Print how a hand-derived gradient compares with autograd's `p.grad`.

    The printed line reports whether the shapes match, whether every element
    is exactly equal, whether `torch.allclose` accepts the pair, and the
    largest absolute difference.

    Args:
        p_name: Label printed at the start of the line.
        dp: Manually computed gradient.
        p: Tensor whose `.grad` holds the reference gradient.

    Raises:
        AssertionError: If `p.grad` is None.
    """
    reference = p.grad
    assert reference is not None, f"Gradient for {p_name} is None"
    same_shape = dp.shape == reference.shape
    is_exact = bool((dp == reference).all().item())
    is_close = torch.allclose(dp, reference)
    worst_gap = (dp - reference).abs().max().item()
    fields = [
        f"{p_name:18s}",
        f"shape equal: {same_shape}",
        f"exact equal: {str(is_exact):5s}",
        f"approximate equal: {str(is_close):5s}",
        f"max_diff: {worst_gap}",
    ]
    print(" | ".join(fields))


In [68]:
# With batch normalization the bias would not need to be initialized, but it
# is initialized here so that its gradient can be calculated manually.
# A zero bias can mask an incorrect gradient calculation, so every bias
# starts at a small non-zero value that keeps the check meaningful.

num_neurons: Final[int] = 300
embed_dim: Final[int] = 30
# Width of one flattened context: block_size embeddings side by side.
fan_in = block_size * embed_dim

g = torch.Generator().manual_seed(2147483647)


def _randn(*shape: int) -> torch.Tensor:
    """Draw standard-normal values from `g`, then move them to `device`."""
    return torch.randn(shape, generator=g).to(device)


# The draws below happen in a fixed order so the seeded values stay the same.
C = _randn(char_cnt, embed_dim)
# Layer 1
W1 = _randn(fan_in, num_neurons) * (5 / 3) / (fan_in ** 0.5)
b1 = _randn(num_neurons) * 0.1
# Layer 2
W2 = _randn(num_neurons, char_cnt) * 0.1
b2 = _randn(char_cnt) * 0.1

# Batch normalization parameters
bn_gain = _randn(1, num_neurons) * 0.1 + 1.0
bn_bias = _randn(1, num_neurons) * 0.1
# Running statistics start at mean 0 / std 1.
bn_mean_running = torch.zeros((1, num_neurons), device=device)
bn_std_running = torch.ones((1, num_neurons), device=device)

parameters = [C, W1, b1, W2, b2, bn_gain, bn_bias]
for param in parameters:
    param.requires_grad_(True)
total_parameter_count = sum(param.numel() for param in parameters)
print(f"parameters: {total_parameter_count}")

@torch.no_grad()
def calculate_loss(X_t: torch.Tensor, Y_t: torch.Tensor) -> torch.Tensor:
    emb = C[X_t]
    embcat = emb.view(-1, block_size * embed_dim)
    h_preact = embcat @ W1 + b1
    # Use the running mean/std.
    h_preact_norm = (h_preact - bn_mean_running) / bn_std_running
    h_preact = h_preact_norm * bn_gain + bn_bias
    h = torch.tanh(h_preact)
    logits = b2 + h @ W2
    return F.cross_entropy(logits, Y_t)

# Training settings.
mini_batch_size: Final[int] = 128
epochs: Final[int] = 200_000

# Placeholder value; the forward-pass cell assigns the real loss.
loss = torch.tensor(1000.0)

# Containers for per-step bookkeeping.
stepi = []
lossi: list[float] = []


parameters: 36837


In [69]:
# Mini-batch
# Sample the batch indices from the seeded generator `g` so that a fresh
# Restart & Run All always picks the same mini-batch.
ix = torch.randint(0, X_train.shape[0], (mini_batch_size,), generator=g)
X_t_mini = X_train[ix]
Y_t_mini = Y_train[ix]

# Forward pass:
# Every intermediate value gets its own name so that the backward-pass cell
# can derive and check one gradient per variable.

# Get the embedding for the mini-batch.
emb = C[X_t_mini]
embcat = emb.view(-1, block_size * embed_dim)

# Layer 1 starts.
h_pre_bn = embcat @ W1 + b1

# Batch normalization
# dim=0 means that the rows are the ones to be eliminated:
# adding all batch elements together.
bn_mean_i = 1 / mini_batch_size * h_pre_bn.sum(dim=0, keepdim=True)
bn_diff = h_pre_bn - bn_mean_i
bn_diff_sq = bn_diff * bn_diff
# Variance with the (n - 1) divisor.
bn_var_i = 1 / (mini_batch_size - 1) * bn_diff_sq.sum(dim=0, keepdim=True)
# bn_std_i only feeds the running statistic below; the normalization itself
# uses bn_std_i_inverse.
bn_std_i = (bn_var_i + 1e-5)**(0.5)
bn_std_i_inverse = (bn_var_i + 1e-5)**(-0.5)
bn_raw = bn_diff * bn_std_i_inverse
h_pre_act = bn_raw * bn_gain + bn_bias

# Nonlinear activation
h = torch.tanh(h_pre_act)
# Layer 1 ends.

# Layer 2 starts.
logits = h @ W2 + b2

# Calculate cross entropy loss: loss = F.cross_entropy(logits, Y_t_mini)
logit_maxes = logits.max(dim=1, keepdim=True).values
normalized_logits = logits - logit_maxes # for numerical stability
counts = normalized_logits.exp()
counts_sum = counts.sum(dim=1, keepdim=True)
counts_sum_inverse = counts_sum**(-1)
probs = counts * counts_sum_inverse
log_probs = probs.log()
# Mean negative log-probability of the correct character of each example.
loss = -log_probs[torch.arange(mini_batch_size), Y_t_mini].mean()
# Layer 2 ends.

# Update the running statistics used by calculate_loss; no gradient is
# tracked for this bookkeeping.
with torch.no_grad():
    bn_mean_running = 0.999 * bn_mean_running + 0.001 * bn_mean_i
    bn_std_running = 0.999 * bn_std_running + 0.001 * bn_std_i

# Backward pass
for p in parameters:
    p.grad = None

# Retain the gradient for all the intermediate variables
for t in [
    log_probs, probs, counts_sum_inverse, counts_sum, counts,
    normalized_logits, logit_maxes, logits, h, h_pre_act, bn_raw,
    bn_std_i_inverse, bn_std_i, bn_var_i, bn_diff_sq, bn_diff, bn_mean_i,
    h_pre_bn, embcat, emb
]:
    t.retain_grad()
loss.backward()
print(f"loss: {loss.item():.4f}")

loss: 3.8407


In [100]:
# h_pre_act = bn_raw * bn_gain + bn_bias
# Compare the shapes of the operands to see where broadcasting takes place.
bn_raw.shape, bn_gain.shape, bn_bias.shape

(torch.Size([128, 300]), torch.Size([1, 300]), torch.Size([1, 300]))

In [71]:
log_probs.shape, probs.shape
# Toy example with three selected log-probabilities a, b, c:
# loss = -(a + b + c) / 3
# dloss/da = -1 / 3, and in general -1 / n when n entries are averaged

(torch.Size([128, 27]), torch.Size([128, 27]))

In [72]:
counts.shape, counts_sum_inverse.shape
# probs = counts * counts_sum_inverse
# c = a * b is an element-wise product with broadcasting, not a matrix product:
# a[3x3] * b[3x1] ->
# a11 * b1   a12 * b1   a13 * b1
# a21 * b2   a22 * b2   a23 * b2
# a31 * b3   a32 * b3   a33 * b3
# c[3x3]
# Each b_i is reused by every entry of row i, so the gradient of b_i is the
# sum over that row.

(torch.Size([128, 27]), torch.Size([128, 1]))

In [73]:
counts.shape, counts_sum.shape
# counts_sum = counts.sum(dim=1, keepdim=True): one sum per row.
# a11 a12 a13 -> b1 (=sum(a11, a12, a13))
# a21 a22 a23 -> b2 (=sum(a21, a22, a23))
# a31 a32 a33 -> b3 (=sum(a31, a32, a33))

(torch.Size([128, 27]), torch.Size([128, 1]))

In [74]:
# counts = normalized_logits.exp()
normalized_logits.shape

torch.Size([128, 27])

In [75]:
# Shapes that d_bn_gain and d_bn_bias have to match.
bn_gain.shape, bn_bias.shape

(torch.Size([1, 300]), torch.Size([1, 300]))

In [76]:
# bn_raw = bn_diff * bn_std_i_inverse
bn_diff.shape, bn_std_i_inverse.shape

(torch.Size([128, 300]), torch.Size([1, 300]))

In [77]:
# bn_std_i_inverse = (bn_var_i + 1e-5)**(-0.5)
bn_std_i_inverse.shape, bn_var_i.shape


(torch.Size([1, 300]), torch.Size([1, 300]))

In [78]:
# bn_var_i = 1 / (mini_batch_size - 1) * bn_diff_sq.sum(dim=0, keepdim=True)

bn_diff_sq.shape, bn_var_i.shape

# Toy example with two rows: every column is summed and scaled.
# a11 a12 a13
# a21 a22 a23
# ->
# b1  b2  b3
# b1 = 1 / (mini_batch_size - 1) * (a11 + a21)
# b2 = 1 / (mini_batch_size - 1) * (a12 + a22)
# b3 = 1 / (mini_batch_size - 1) * (a13 + a23)

(torch.Size([128, 300]), torch.Size([1, 300]))

In [79]:
# bn_diff = h_pre_bn - bn_mean_i
bn_diff.shape, h_pre_bn.shape, bn_mean_i.shape

(torch.Size([128, 300]), torch.Size([128, 300]), torch.Size([1, 300]))

In [None]:
# bn_mean_i = 1 / mini_batch_size * h_pre_bn.sum(dim=0, keepdim=True)
bn_mean_i.shape, h_pre_bn.shape


(torch.Size([1, 300]), torch.Size([128, 300]))

In [97]:
# embcat = emb.view(-1, block_size * embed_dim)
# The view only reshapes, so d_emb is d_embcat viewed back in emb's shape.
embcat.shape, emb.shape

(torch.Size([128, 90]), torch.Size([128, 3, 30]))

In [101]:
# emb = C[X_t_mini]
# Each entry of X_t_mini selects one row of C; print the shapes and the
# first few index rows.
print(emb.shape, C.shape, X_t_mini.shape)
print(X_t_mini[:5])

torch.Size([128, 3, 30]) torch.Size([27, 30]) torch.Size([128, 3])
tensor([[19,  7,  1],
        [ 1, 13,  5],
        [ 0,  0, 18],
        [ 0, 12,  5],
        [ 1,  8,  9]], device='cuda:0')


In [102]:
# Exercise 1: backprop through the whole thing manually,
# backpropagating through exactly all of the variables
# as they are defined in the forward pass above, one by one
#
# Naming: d_<x> is the hand-derived gradient of the loss with respect to <x>.
# Whenever a variable is used more than once in the forward pass, its
# gradient is accumulated in several parts ("part1", "part2").

# loss = -log_probs[arange, Y_t_mini].mean()
# Only the selected entries receive a gradient, each -1 / mini_batch_size.
d_log_probs = torch.zeros_like(log_probs)
d_log_probs[torch.arange(mini_batch_size), Y_t_mini] = -1 / mini_batch_size

# log_probs = probs.log()  ->  d/dprobs = 1 / probs
d_probs = (1.0 / probs) * d_log_probs

# probs = counts * counts_sum_inverse
# counts_sum_inverse is broadcast along dim=1, so its gradient sums over dim=1.
d_counts_sum_inverse = (counts * d_probs).sum(dim=1, keepdim=True)

# d_counts part1
d_counts = counts_sum_inverse * d_probs

# counts_sum_inverse = counts_sum**(-1)  ->  derivative is -counts_sum**(-2)
d_counts_sum = (-1) * (counts_sum**(-2)) * d_counts_sum_inverse

# d_counts part2
# counts_sum = counts.sum(dim=1, keepdim=True): every entry of a row gets
# that row's d_counts_sum.
# This can also be simplified to: d_counts += d_counts_sum
# Because `d_counts`'s shape is already the same as `counts`.
# We will do broadcasting. `torch.ones_like(counts)` is unnecessary.
d_counts += torch.ones_like(counts) * d_counts_sum

# counts = normalized_logits.exp(), and the derivative of exp is exp itself.
# `normalized_logits.exp()` is already `counts`.
d_normalized_logits = counts * d_counts

# normalized_logits = logits - logit_maxes
# d_logits part1
d_logits = d_normalized_logits.clone()

# logit_maxes is broadcast along dim=1 and enters with a minus sign.
# At this point d_logits still equals d_normalized_logits.
d_logit_maxes = -d_logits.sum(dim=1, keepdim=True)

# logit_maxes = logits.max(dim=1, keepdim=True).values
# The max routes its gradient to the argmax position of each row.
# d_logits part2
d_logits += F.one_hot(logits.argmax(dim=1), num_classes=logits.shape[1]) * d_logit_maxes

# logits = h @ W2 + b2
d_b2 = d_logits.sum(dim=0)
d_W2 = h.T @ d_logits
d_h = d_logits @ W2.T

# h = torch.tanh(h_pre_act)  ->  derivative is 1 - tanh**2 = 1 - h**2
d_h_pre_act = (1.0 - h**2) * d_h

# h_pre_act = bn_raw * bn_gain + bn_bias
# bn_gain and bn_bias are broadcast over the batch rows, so their gradients
# sum over dim=0.
d_bn_raw = bn_gain * d_h_pre_act # Broadcast happens here.
d_bn_gain = (bn_raw * d_h_pre_act).sum(dim=0, keepdim=True)
d_bn_bias = d_h_pre_act.sum(dim=0, keepdim=True)

# bn_raw = bn_diff * bn_std_i_inverse
# d_bn_diff part1
d_bn_diff = bn_std_i_inverse * d_bn_raw # Broadcast happens here.
d_bn_std_i_inverse = (bn_diff * d_bn_raw).sum(dim=0, keepdim=True)

# bn_std_i_inverse = (bn_var_i + 1e-5)**(-0.5)
d_bn_var_i = (-(0.5) * (bn_var_i + 1e-5)**(-1.5)) * d_bn_std_i_inverse

# bn_var_i = 1 / (mini_batch_size - 1) * bn_diff_sq.sum(dim=0, keepdim=True)
d_bn_diff_sq = (1.0 / (mini_batch_size - 1)) * torch.ones_like(bn_diff_sq) * d_bn_var_i

# bn_diff_sq = bn_diff * bn_diff
# d_bn_diff part2
d_bn_diff += 2.0 * bn_diff * d_bn_diff_sq

# bn_diff = h_pre_bn - bn_mean_i
# bn_mean_i is broadcast over the batch rows and enters with a minus sign.
d_bn_mean_i = (-1.0) * d_bn_diff.sum(dim=0, keepdim=True)
# d_h_pre_bn part1
d_h_pre_bn = d_bn_diff.clone()

# bn_mean_i = 1 / mini_batch_size * h_pre_bn.sum(dim=0, keepdim=True)
# d_h_pre_bn part2
d_h_pre_bn += (1.0 / mini_batch_size) * torch.ones_like(h_pre_bn) * d_bn_mean_i

# h_pre_bn = embcat @ W1 + b1
d_embcat = d_h_pre_bn @ W1.T
d_W1 = embcat.T @ d_h_pre_bn
d_b1 = d_h_pre_bn.sum(dim=0)

# embcat = emb.view(-1, block_size * embed_dim): undo the view.
d_emb = d_embcat.view(emb.shape)

# emb = C[X_t_mini]
# A row of C can be selected many times; every use adds its gradient to
# that row.
d_C = torch.zeros_like(C)
for i in range(X_t_mini.shape[0]):
    for j in range(X_t_mini.shape[1]):
        k = X_t_mini[i, j]
        d_C[k] += d_emb[i, j]

# and checking the gradients with cmp_grad.
# One row per checked variable: (label, manual gradient, tensor holding the
# autograd gradient). The rows are checked in the order listed.
gradient_checks = (
    ("log_probs", d_log_probs, log_probs),
    ("probs", d_probs, probs),
    ("counts_sum_inverse", d_counts_sum_inverse, counts_sum_inverse),
    ("counts_sum", d_counts_sum, counts_sum),
    ("counts", d_counts, counts),
    ("normalized_logits", d_normalized_logits, normalized_logits),
    ("logit_maxes", d_logit_maxes, logit_maxes),
    ("logits", d_logits, logits),
    ("b2", d_b2, b2),
    ("W2", d_W2, W2),
    ("h", d_h, h),
    ("h_pre_act", d_h_pre_act, h_pre_act),
    ("bn_raw", d_bn_raw, bn_raw),
    ("bn_gain", d_bn_gain, bn_gain),
    ("bn_bias", d_bn_bias, bn_bias),
    ("bn_std_i_inverse", d_bn_std_i_inverse, bn_std_i_inverse),
    ("bn_var_i", d_bn_var_i, bn_var_i),
    ("bn_diff_sq", d_bn_diff_sq, bn_diff_sq),
    ("bn_diff", d_bn_diff, bn_diff),
    ("bn_mean_i", d_bn_mean_i, bn_mean_i),
    ("h_pre_bn", d_h_pre_bn, h_pre_bn),
    ("embcat", d_embcat, embcat),
    ("W1", d_W1, W1),
    ("b1", d_b1, b1),
    ("emb", d_emb, emb),
    ("C", d_C, C),
)

print("Gradient check:")
for check_label, manual_grad, checked_tensor in gradient_checks:
    cmp_grad(check_label, manual_grad, checked_tensor)


Gradient check:
log_probs          | shape equal: True | exact equal: True  | approximate equal: True  | max_diff: 0.0
probs              | shape equal: True | exact equal: True  | approximate equal: True  | max_diff: 0.0
counts_sum_inverse | shape equal: True | exact equal: True  | approximate equal: True  | max_diff: 0.0
counts_sum         | shape equal: True | exact equal: True  | approximate equal: True  | max_diff: 0.0
counts             | shape equal: True | exact equal: True  | approximate equal: True  | max_diff: 0.0
normalized_logits  | shape equal: True | exact equal: True  | approximate equal: True  | max_diff: 0.0
logit_maxes        | shape equal: True | exact equal: True  | approximate equal: True  | max_diff: 0.0
logits             | shape equal: True | exact equal: True  | approximate equal: True  | max_diff: 0.0
b2                 | shape equal: True | exact equal: True  | approximate equal: True  | max_diff: 0.0
W2                 | shape equal: True | exact equal: Tru

In [82]:
# normalized_logits = logits - logit_maxes
d_normalized_logits.shape, logits.shape, logit_maxes.shape

(torch.Size([128, 27]), torch.Size([128, 27]), torch.Size([128, 1]))

In [83]:
print(f"d_logits: {d_logits.shape}\nh: {h.shape}\nW2: {W2.shape}\nb2: {b2.shape}")
# Shape-matching argument for the backward pass of a linear layer.
# Given logits = h @ W2 + b2
# We know that:
# d_h ([128, 300]) must be calculated with d_logits([128, 27]) and W2([300, 27])
# So we need to get [128, 300] from [128, 27] and [300, 27]
# which is [128, 27] @ [27, 300]
# which is [128, 27] @ [300, 27].T
# so
# d_h = d_logits @ W2.T

# We know that:
# d_W2 ([300, 27]) must be calculated with h([128, 300]) and d_logits([128, 27])
# So we need to get [300, 27] from [128, 300] and [128, 27]
# which is [128, 300].T @ [128, 27]
# so
# d_W2 = h.T @ d_logits

# We know that:
# d_b2 ([27]) must be calculated with d_logits([128, 27])
# So we need to get [27] from [128, 27]
# which is [128, 27].sum(dim=0)
# so
# d_b2 = d_logits.sum(dim=0)

d_logits: torch.Size([128, 27])
h: torch.Size([128, 300])
W2: torch.Size([300, 27])
b2: torch.Size([27])


![Derivative Calculation](Matrix_Derivative.jpg)

In [84]:
# d_h_pre_act is the upstream gradient used for d_bn_raw, d_bn_gain and d_bn_bias.
d_h_pre_act.shape

torch.Size([128, 300])

In [89]:
# h_pre_bn = embcat @ W1 + b1
d_h_pre_bn.shape, embcat.shape, W1.shape, b1.shape


(torch.Size([128, 300]),
 torch.Size([128, 90]),
 torch.Size([90, 300]),
 torch.Size([300]))

In [92]:
# bn_diff = h_pre_bn - bn_mean_i
d_bn_diff.shape, h_pre_bn.shape, bn_mean_i.shape


(torch.Size([128, 300]), torch.Size([128, 300]), torch.Size([1, 300]))

In [95]:
# The manual gradient should have the same shape as the autograd gradient.
# NOTE(review): the saved output shows torch.Size([300]) for d_bn_mean_i,
# which looks stale -- the backward cell computes it with keepdim=True, which
# should give (1, 300). Re-run this cell to confirm.
d_bn_mean_i.shape, bn_mean_i.grad.shape

(torch.Size([300]), torch.Size([1, 300]))