In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

from helpers.makemore_helpers import make_char2idx_map, build_dataset

# Load the names corpus, one name per line. The context manager closes the
# file handle deterministically, and the forward-slash path resolves on both
# Windows and POSIX systems (a backslash path only resolves on Windows).
with open('data/names.txt', 'r') as names_file:
    words = names_file.read().splitlines()

# Character -> index lookup built by the project helper, plus its inverse.
char2idx = make_char2idx_map(words)
idx2char = {v: k for k, v in char2idx.items()}

# Context length handed to build_dataset; the printed x_train shape below
# shows one column per context position.
block_size = 3

In [3]:
# Split the word list 80% / 10% / 10% into train / validation / test.
n_words = len(words)
train_count = int(n_words * 0.8)  # end of the training slice
n_1 = int(n_words * 0.9)          # end of the validation slice

train_words = words[:train_count]
val_words = words[train_count:n_1]
test_words = words[n_1:]

# Turn each slice into (inputs, targets) with the project helper.
x_train, y_train = build_dataset(train_words, block_size, char2idx)
x_val, y_val = build_dataset(val_words, block_size, char2idx)
x_test, y_test = build_dataset(test_words, block_size, char2idx)

# Number of examples in each split.
total_train, total_val, total_test = len(x_train), len(x_val), len(x_test)

print("Sizes:", total_train, total_val, total_test)
print("Shapes:", x_train.shape, y_train.shape)

Sizes: 182778 22633 22735
Shapes: torch.Size([182778, 3]) torch.Size([182778])


In [168]:
# Set up parameters
# Every random draw below goes through rand_gen so that a fresh
# "Restart & Run All" reproduces the same parameters and the same minibatch.
rand_gen = torch.Generator().manual_seed(42)
embedding_size = 10
n_hidden = 200
vocab_size = len(char2idx)

# One embedding_size-dimensional vector per vocabulary entry. The width must be
# embedding_size (not a literal 10), otherwise changing embedding_size would
# silently disagree with W1's input width below.
embedding_matrix = torch.randn((vocab_size, embedding_size), generator=rand_gen)

# Hidden layer. Kaiming-style scaling gain / sqrt(fan_in) with the tanh gain 5/3,
# so that the pre-activations keep tanh in its active (non-saturated) region.
fan_in = embedding_size * block_size
W1 = torch.randn((fan_in, n_hidden), generator=rand_gen) * ((5/3) / fan_in**0.5)
b1 = torch.randn(n_hidden, generator=rand_gen) * 0.001 # tiny bias, same motivation: stay in tanh's active region

# Output layer. Small weights/biases because we want smaller logits at initialization.
W2 = torch.randn((n_hidden, vocab_size), generator=rand_gen) * 0.01
b2 = torch.randn(vocab_size, generator=rand_gen) * 0.1

# Batch-norm scale and shift, perturbed slightly around 1 and 0. Drawn from rand_gen
# (rather than the unseeded global RNG) to keep the run reproducible.
batch_norm_gain = torch.randn((1,n_hidden), generator=rand_gen) * 0.1 + 1
batch_norm_bias = torch.randn((1,n_hidden), generator=rand_gen) * 0.1

# Running batch-norm statistics; not read or updated by the cells visible here.
mean_running = torch.zeros((1,n_hidden))
std_running = torch.ones((1,n_hidden))

# Everything autograd should track.
params = [embedding_matrix, W1, W2, b2, b1, batch_norm_gain, batch_norm_bias]
for p in params:
    p.requires_grad = True

In [169]:
# Draw one random minibatch of training examples (indices are sampled
# independently, so the same example can appear more than once).
batch_size = 32
n_examples = x_train.shape[0]
ix = torch.randint(0, n_examples, (batch_size,), generator=rand_gen)
xs = x_train[ix]  # minibatch inputs
ys = y_train[ix]  # minibatch targets

In [170]:
def compare_gradients(name, manual_gradient, target_tensor):
    """Print how closely a hand-derived gradient matches the autograd one.

    Args:
        name: label printed in the first column of the report line.
        manual_gradient: tensor holding the manually computed gradient.
        target_tensor: tensor whose ``.grad`` attribute is used as the reference.

    Prints a single line reporting whether every element is exactly equal,
    whether ``torch.allclose`` accepts the pair at its default tolerances,
    and the largest absolute element-wise difference. Returns nothing.
    """
    reference = target_tensor.grad

    is_exact = torch.all(manual_gradient == reference).item()
    is_close = torch.allclose(manual_gradient, reference)
    largest_gap = (manual_gradient - reference).abs().max().item()

    print(f'{name:15s} | exact: {str(is_exact):5s} | approximate: {str(is_close):5s} | maxdiff: {largest_gap}')

In [171]:

# Forward pass
# ----------------------------------------
# The forward pass is deliberately broken into many tiny steps so that each
# intermediate tensor can have its gradient derived by hand and checked
# against autograd afterwards.

embedding = embedding_matrix[xs] # embed the characters into vectors
emb_concat = embedding.view(embedding.shape[0], -1) # concatenate the vectors of one example into a single row

# Linear layer 1
pre_act_1_before_batch_norm = emb_concat @ W1 + b1 # hidden layer pre-activation

# BatchNorm layer
bn_mean_i_th = (1/batch_size) * pre_act_1_before_batch_norm.sum(0, keepdim=True) # per-column mean over the batch
bn_diff = pre_act_1_before_batch_norm - bn_mean_i_th                # deviation from the batch mean
bn_diff_sqr = bn_diff**2                                            # squared deviation
# note: Bessel's correction (dividing by n-1, not n) is used when calculating the variance
bn_variance = 1/(batch_size-1)*(bn_diff_sqr).sum(0, keepdim=True)   # per-column variance over the batch
bn_variance_inv = (bn_variance + 1e-5)**-0.5                        # 1/sqrt(variance + eps), i.e. the inverse standard deviation
bn_raw_out = bn_diff * bn_variance_inv # normalised pre-activations

# Learnable scale and shift applied to the normalised values
act_batch_norm = batch_norm_gain * bn_raw_out + batch_norm_bias


# Non-linearity
activation_1 = torch.tanh(act_batch_norm) # hidden layer

# Linear layer 2
logits = activation_1 @ W2 + b2 # output layer

# cross entropy loss (same as F.cross_entropy(logits, ys)), written out step by step
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract the row max for numerical stability: exp() of large values would otherwise explode to infinity
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True)
counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
probs = counts * counts_sum_inv # Doing counts / counts_sum (a row-wise softmax)
logprobs = probs.log()
loss = -logprobs[range(batch_size), ys].mean() # mean negative log-probability of the correct characters


# pytorch backward pass
# retain_grad() is called on every intermediate tensor so that its .grad can be
# read after backward() and compared with the hand-derived gradient.
intermediate_variables = [embedding, emb_concat, pre_act_1_before_batch_norm, bn_mean_i_th, bn_diff, bn_diff_sqr, bn_variance, bn_variance_inv, bn_raw_out, act_batch_norm, activation_1, logits, logit_maxes, norm_logits, counts, counts_sum, counts_sum_inv, probs, logprobs, loss]
for v in intermediate_variables:
    v.retain_grad()

# Reset parameter gradients so that re-running this cell does not accumulate into old values
for p in params:
    p.grad = None

loss.backward()



In [174]:
# Calculate gradients manually
# Convention: d_X holds d_loss/d_X and always has the same shape as X.

# logprobs is a matrix of size (batch_size, vocab_size). For us: (32,27)
# logprobs[range(batch_size), ys] is a vector of size (batch_size,). For us: (32,)
# For every row in logprobs, ^ this line plucks out the log prob of the correct character (indexed by ys)

# loss = -logprobs[range(batch_size), ys].mean()
# loss = -(a+b+c.....)/batch_size (where a,b,c are the logprobs of the correct characters)
# d_loss/d_logprobs = -1/batch_size for the plucked out logprobs, and 0 for all others

d_logprobs = torch.zeros_like(logprobs) # (32,27)
d_logprobs[range(batch_size), ys] = -1/batch_size
compare_gradients('logprobs', d_logprobs, logprobs)


# logprobs = log(probs)
# d_logprobs/d_probs = 1/probs (d log(x) / dx = 1/x)
# d_loss/d_probs = d_loss/d_logprobs * d_logprobs/d_probs = d_logprobs * 1/probs (chain rule)
d_probs = d_logprobs * (1/probs)
compare_gradients('probs', d_probs, probs)


# counts.shape, counts_sum_inv.shape = (32,27), (32,1)
# c = a * b
# a[3x3] * b[3x1] -> broadcast -> c[3x3] as follows:
# a11 * b1, a12 * b1, a13 * b1
# a21 * b2, a22 * b2, a23 * b2
# a31 * b3, a32 * b3, a33 * b3
#
# probs = counts * counts_sum_inv

# d_probs/d_counts_sum_inv = counts (d x*y / dy = x)
# why sum? because each counts_sum_inv entry was broadcast (copied) across a whole row,
# so the gradients flowing into every copy have to be added up
d_counts_sum_inv = (counts * d_probs).sum(1, keepdims=True)
compare_gradients('counts_sum_inv', d_counts_sum_inv, counts_sum_inv)

# Part 1 of d_counts: counts contributes to both probs and counts_sum. Two branches.
# No sum is needed here: counts already has the full (32,27) shape, so nothing was
# broadcast on its side; counts_sum_inv simply broadcasts across the row again.
d_counts = counts_sum_inv * d_probs # (? why not summing here)


# counts_sum_inv = counts_sum**-1
# d_counts_sum_inv/d_counts_sum = -counts_sum**-2 (d x^-1 / dx = -1/x**2)
d_count_sum = (-counts_sum**-2) * d_counts_sum_inv
compare_gradients('counts_sum', d_count_sum, counts_sum)

# Differentiate the row sum: counts_sum = counts.sum(1, keepdims=True)
# a11 + a12 + a13 = b1   |      [da11, da12 , da13] = [db1, db1, db1]      |
# a21 + a22 + a23 = b2   | =>   [da21, da22 , da23] = [db2, db2, db2]      | Because d(x+y+z)/dx = 1, the chain rule just multiplies the incoming gradient by 1
# a31 + a32 + a33 = b3   |      [da31, da32 , da33] = [db3, db3, db3]      |

# Part 2 of d_counts: add the contribution that flows back through counts_sum
d_counts += torch.ones_like(counts) * d_count_sum # (32,27)
compare_gradients('counts', d_counts, counts)


# counts = exp(norm_logits)
# d_counts/d_norm_logits = exp(norm_logits) (d e**x / dx = e**x)
d_norm_logits = norm_logits.exp() * d_counts
compare_gradients('norm_logits', d_norm_logits, norm_logits)




# norm_logits = logits - logit_maxes
# norm_logits.shape, logits.shape, logit_maxes.shape = (32,27), (32,27), (32,1) => There is a broadcast
# c11, c12, c13     a11, a12, a13    b1
# c21, c22, c23  =  a21, a22, a23 -  b2
# c31, c32, c33     a31, a32, a33    b3

# c11 = a11 - b1
# local derivatives: d_c11/d_a11 = 1 and d_c11/d_b1 = -1
# For b we additionally have to sum across the columns, because b was broadcast
# and the gradients of each copy must be added up

# Branch 1 for d_logits (direct path through the subtraction, local derivative 1)
d_logits = d_norm_logits.clone()
d_logit_maxes = (-d_norm_logits).sum(1, keepdims=True)
compare_gradients('logit_maxes', d_logit_maxes, logit_maxes)


# Branch 2 for d_logits: logit_maxes = logits.max(1), so only the entry holding the
# row maximum receives gradient; the one-hot mask routes d_logit_maxes to that entry.
# NOTE: `d_logits += torch.ones_like(logits) * d_logit_maxes` would NOT be equivalent,
# as it would spread the gradient over every column instead of only the argmax one.
d_logits += F.one_hot( logits.max(1).indices, num_classes=logits.shape[1]).float() * d_logit_maxes
compare_gradients('logits', d_logits, logits)



# ------------------------------- Linear layer 2: logits = activation_1 @ W2 + b2 ----------------------------------
# activation_1.shape, W2.shape, d_logits.shape, b2.shape = (32, 200), (200, 27), (32,27), (27,)
#
# Element-wise, logits[i,k] = sum_j activation_1[i,j] * W2[j,k] + b2[k]. Applying the chain rule
# to every element that a given input touches gives:
#   d_activation_1[i,j] = sum_k d_logits[i,k] * W2[j,k]             ==  d_logits @ W2.T
#   d_W2[j,k]           = sum_i activation_1[i,j] * d_logits[i,k]   ==  activation_1.T @ d_logits
#   d_b2[k]             = sum_i d_logits[i,k]                       ==  d_logits.sum(0)
#
# Quick sanity check (the "shape trick"): a gradient always has the shape of the tensor it
# belongs to, and these are the only matrix products of the available factors whose shapes work out.


d_activation_1 = d_logits @ W2.T
compare_gradients('activation_1', d_activation_1, activation_1)

d_W2 = activation_1.T @ d_logits
compare_gradients('W2', d_W2, W2)

# b2 was broadcast over the batch dimension, so its gradient is the column-wise sum.
# No keepdims here: the result must have b2's own shape (27,), not (1,27), which is
# also how d_b1 is computed.
d_b2 = d_logits.sum(0)
compare_gradients('b2', d_b2, b2)

# -------------------------------------------------------------------------------------------------------------------------

# activation_1 = tanh(act_batch_norm)
# d tanh(x)/dx = 1 - tanh(x)**2, so the incoming gradient d_activation_1 is multiplied by that local derivative
d_act_batch_norm = d_activation_1 * (1 - torch.tanh(act_batch_norm)**2)
compare_gradients('act_batch_norm', d_act_batch_norm, act_batch_norm)



# for the batch norm: act_batch_norm = batch_norm_gain * bn_raw_out + batch_norm_bias
# gain and bias have shape (1, n_hidden) and were broadcast over the batch rows,
# so their gradients are summed over dimension 0
d_bn_gain = (bn_raw_out * d_act_batch_norm).sum(0, keepdims=True)
compare_gradients('bn_gain', d_bn_gain, batch_norm_gain)

d_bn_bias = d_act_batch_norm.sum(0, keepdims=True)
compare_gradients('bn_bias', d_bn_bias, batch_norm_bias)

# bn_raw_out has the full batch shape, so no sum is needed here
d_bn_raw_out = batch_norm_gain * d_act_batch_norm
compare_gradients('bn_raw_out', d_bn_raw_out, bn_raw_out)



# bn_raw_out = bn_diff * bn_variance_inv
# Part 1 of d_bn_diff (bn_diff also feeds bn_diff_sqr; that branch is added further down)
d_bn_diff = bn_variance_inv * d_bn_raw_out

# bn_variance_inv was broadcast over the batch rows => sum over dimension 0
d_bn_variance_inv = (bn_diff * d_bn_raw_out).sum(0, keepdims=True)
compare_gradients('bn_variance_inv', d_bn_variance_inv, bn_variance_inv)    

# bn_variance_inv = (bn_variance + 1e-5)**-0.5
# d x**-0.5 / dx = -0.5 * x**-1.5
d_bn_variance = -0.5 * (bn_variance + 1e-5)**-1.5 * d_bn_variance_inv
compare_gradients('bn_variance', d_bn_variance, bn_variance)


# bn_variance = 1/(batch_size-1)*(bn_diff_sqr).sum(0, keepdim=True)
# With m = batch_size - 1:
#        | a11    a12 |
# 1/m *  | a21    a22 |.sum(0,keepdim=True) ==>  | (a11+a21)/m    (a12+a22)/m |
# Every element of a column contributes with weight 1/m, so the column's gradient
# is copied back to each row and scaled by 1/m
d_bn_diff_sqr = (1.0/(batch_size-1))*torch.ones_like(bn_diff_sqr) * d_bn_variance
compare_gradients('bn_diff_sqr', d_bn_diff_sqr, bn_diff_sqr)

# bn_diff_sqr = bn_diff**2  (d x**2 / dx = 2x)
# Part 2 of d_bn_diff
d_bn_diff += 2*bn_diff * d_bn_diff_sqr
compare_gradients('bn_diff', d_bn_diff, bn_diff)

# bn_diff = pre_act_1_before_batch_norm - bn_mean_i_th  (Broadcasting in forward pass <=> Sum of gradients in backward pass)
# Part 1 of d_pre_act_1_before_batch_norm: direct path through the subtraction (local derivative 1)
d_pre_act_1_before_batch_norm =  d_bn_diff.clone()
# The mean enters with a minus sign and was broadcast over the batch rows => negate and sum over dimension 0
d_bn_mean_i_th = (-torch.ones_like(bn_diff) * d_bn_diff).sum(0, keepdims=True)
compare_gradients('bn_mean_i_th', d_bn_mean_i_th, bn_mean_i_th)

# bn_mean_i_th = (1/batch_size) * pre_act_1_before_batch_norm.sum(0, keepdim=True)
# Part 2 of d_pre_act_1_before_batch_norm: every row contributes to the mean with weight 1/batch_size
d_pre_act_1_before_batch_norm += (1/batch_size) * torch.ones_like(pre_act_1_before_batch_norm) * d_bn_mean_i_th
compare_gradients('pre_act_1_before_batch_norm', d_pre_act_1_before_batch_norm, pre_act_1_before_batch_norm)


# pre_act_1_before_batch_norm = emb_concat @ W1 + b1 # hidden layer pre-activation
# Same pattern as any linear layer y = x @ W + b:
#   d_W = x.T @ d_y,   d_x = d_y @ W.T,   d_b = d_y summed over the batch dimension
d_W1 = emb_concat.T @ d_pre_act_1_before_batch_norm
compare_gradients('W1', d_W1, W1)

d_emb_concat = d_pre_act_1_before_batch_norm @ W1.T
compare_gradients('emb_concat', d_emb_concat, emb_concat)

# b1 was broadcast over the batch rows => sum over dimension 0 (result has b1's own shape)
d_b1 = d_pre_act_1_before_batch_norm.sum(0)
compare_gradients('b1', d_b1, b1)


# emb_concat = embedding.view(embedding.shape[0], -1)
# A view only re-arranges the shape, so the gradient is simply viewed back into embedding's shape
d_embedding = d_emb_concat.view(embedding.shape)
compare_gradients('embedding', d_embedding, embedding)


# embedding = embedding_matrix[xs]
# The forward pass copies row xs[i,j] of embedding_matrix into embedding[i,j].
# A row that was looked up several times must receive the sum of all gradients
# flowing into its copies, hence the accumulation with +=.
d_embedding_matrix = torch.zeros_like(embedding_matrix)
for i in range(xs.shape[0]):
    for j in range(xs.shape[1]):
        d_embedding_matrix[xs[i,j]] += d_embedding[i,j] 

compare_gradients('embedding_matrix', d_embedding_matrix, embedding_matrix)









logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
activation_1    | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
act_batch_norm  | exact: True  | approximate: True  | maxdiff: 0.0
bn_gain         | exact: True  | approximate: True  | maxdiff: 0.0
bn_bias         | exact: True  | approximate: True  | maxdiff: 0.0
bn_raw_out      | exact: True  | approximate: True  | maxdiff: