In [1]:
import torch
import torch.nn.functional as F
import importlib
import matplotlib.pyplot as plt
import random
mm = importlib.import_module("makemore-1")

In [2]:
def make_dataset_split(words, B, char_indices=None):
  """Build (context, next-character) examples for a character-level model.

  Every word is processed as `word + '.'`. The context starts as B copies of
  the '.' index. For each character of `word + '.'`, one example is emitted:
  the B most recent character indices form an X row, and the character's own
  index is the Y entry. The context then shifts left by one and appends that
  character.

  Args:
    words: iterable of strings to encode.
    B: context length (number of preceding characters per example).
    char_indices: optional mapping from character to integer index. It must
      contain '.' and every character that appears in `words`. Defaults to
      mm.CHAR_INDICES.

  Returns:
    (X, Y): X is an int64 tensor of shape (num_examples, B) and Y is an int64
    tensor of shape (num_examples,). Row i of X is the context for Y[i].
  """
  if char_indices is None:
    char_indices = mm.CHAR_INDICES
  start_idx = char_indices['.']

  xs = []
  ys = []
  for word in words:
    context = [start_idx] * B
    for ch in word + '.':
      ch_idx = char_indices[ch]
      # Append the (context, target) pair together, so X and Y can never
      # drift apart, whatever index '.' maps to or whatever the word holds.
      xs.append(context)
      ys.append(ch_idx)
      # Build a new list rather than mutating, so the row just appended to xs
      # is not changed.
      context = context[1:] + [ch_idx]

  # An explicit dtype and shape keep the result well-formed for empty input,
  # which would otherwise come back as a float32 tensor of shape (0,).
  X = torch.tensor(xs, dtype=torch.long).reshape(len(xs), B)
  Y = torch.tensor(ys, dtype=torch.long)
  print(X.shape, Y.shape)
  return (X, Y)

# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  """Compare a manually computed gradient with the one autograd stored on `t`.

  Prints one line: whether the two are bitwise equal, whether they match
  within torch.allclose's default tolerances, and the largest absolute
  elementwise difference. If the shapes differ, it prints a shape-mismatch
  line instead.

  Args:
    s: label printed at the start of the line.
    dt: manually computed gradient tensor.
    t: tensor whose `.grad` was populated by `backward()`. For non-leaf
      tensors this requires calling `retain_grad()` before `backward()`.

  Raises:
    ValueError: if `t.grad` is None (autograd stored no gradient for `t`).
  """
  grad = t.grad
  if grad is None:
    raise ValueError(f'{s}: t.grad is None; call retain_grad() on it before backward()')
  # Without this check, a wrongly shaped dt (e.g. (1, k) against (n, k)) would
  # be broadcast against grad and could be reported as an exact match.
  if dt.shape != grad.shape:
    print(f'{s:15s} | shape mismatch: manual {tuple(dt.shape)} vs autograd {tuple(grad.shape)}')
    return
  ex = torch.all(dt == grad).item()
  app = torch.allclose(dt, grad)
  # .max() fails on an empty tensor; two empty tensors trivially differ by 0.
  maxdiff = (dt - grad).abs().max().item() if dt.numel() else 0.0
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [3]:
# Load the corpus, shuffle it reproducibly, and split it 80/10/10 into
# train / validation / test sets of (context, next-char) examples.
words = mm.load_words_from_file('names.txt')

B = 3  # context length: number of preceding characters per example
M = len(words)
n_train_end = int(0.8 * M)
n_val_end = int(0.9 * M)

random.seed(12345)
random.shuffle(words)

X_tr, Y_tr = make_dataset_split(words[:n_train_end], B)
X_val, Y_val = make_dataset_split(words[n_train_end:n_val_end], B)
X_tst, Y_tst = make_dataset_split(words[n_val_end:], B)

# total number of examples across the three splits
print(sum(split_X.shape[0] for split_X in (X_tr, X_val, X_tst)))

torch.Size([182512, 3]) torch.Size([182512])
torch.Size([22864, 3]) torch.Size([22864])
torch.Size([22770, 3]) torch.Size([22770])
228146


In [4]:
# Model: embedding lookup -> linear -> batch norm -> tanh -> linear

size_vocab = 27     # number of characters; presumably a-z plus '.' (TODO confirm vs mm.CHAR_INDICES)
size_embed = 10     # dimensions per character embedding
size_hidden = 100   # hidden-layer width

RAND_SEED = 1729
gen = torch.Generator().manual_seed(RAND_SEED)

# Each draw below consumes `gen` in sequence, so this order fixes the initial values.
C = torch.randn((size_vocab, size_embed), generator=gen)
W1 = 0.1 * torch.randn((B*size_embed, size_hidden), generator=gen)
# b1 is effectively a no-op: batch norm subtracts the per-feature batch mean,
# which cancels any constant per-feature offset.
b1 = 0.1 * torch.randn((size_hidden,), generator=gen)
W2 = 0.1 * torch.randn((size_hidden, size_vocab), generator=gen)
b2 = 0.1 * torch.randn((size_vocab,), generator=gen)
# batch-norm shift and scale; the scale is initialized around 1
bn_bias = 0.1 * torch.randn((1, size_hidden), generator=gen)
bn_gain = 1 + 0.1 * torch.randn((1, size_hidden), generator=gen)

parameters = [C, W1, b1, W2, b2, bn_bias, bn_gain]
for p in parameters:
  p.requires_grad_(True)

print(f"total # of params: {sum(p.nelement() for p in parameters)}")

total # of params: 6297


In [5]:
# Draw one minibatch of training examples, using the same seeded generator as
# the parameter initialization. N is kept as a short alias for the batch size.
N = size_batch = 32
b_idxs = torch.randint(low=0, high=X_tr.shape[0], size=(size_batch,), generator=gen)
X_b = X_tr[b_idxs]
Y_b = Y_tr[b_idxs]

In [7]:
# Forward pass, split into small named steps so each intermediate's gradient
# can be checked individually against a hand-written backward pass.
# Look up the B context characters per example and concatenate their
# embeddings into one row of length B*size_embed.
embeddings = C[X_b].view(X_b.shape[0], -1)
# Hidden pre-activation. b1 cancels out under the batch-norm mean subtraction below.
l1_act = embeddings @ W1 + b1

# Batch norm, with statistics taken over the batch dimension (dim 0).
bn_eps = 1e-5
l1_mean = l1_act.mean(0, keepdim=True)
bn_diff = l1_act - l1_mean
bn_diff_sq = bn_diff**2
# unbiased sample variance: divides by (size_batch - 1), so size_batch must be > 1
bn_var = 1./(size_batch - 1) * bn_diff_sq.sum(0, keepdim=True)
# 1 / sqrt(var + eps); eps guards against division by zero for constant features
bn_var_inv = (bn_var + bn_eps)**-0.5
bn_act = bn_gain * (bn_diff * bn_var_inv) + bn_bias

nl_act = torch.tanh(bn_act)

logits = nl_act @ W2 + b2

# Softmax written out step by step. Subtracting the row max keeps exp() from
# overflowing and does not change the resulting probabilities.
logits_max = logits.max(1, keepdim=True).values
counts = torch.exp(logits - logits_max)
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv

# Cross-entropy: mean negative log-probability of each example's target character.
neg_log_probs = -torch.log(probs)
loss = neg_log_probs[range(size_batch), Y_b].mean()


# Clear gradients left from a previous run of this cell; backward() accumulates into .grad.
for p in parameters:
  p.grad = None

# Intermediates are non-leaf tensors, so autograd discards their .grad unless
# retain_grad() is called; cmp() needs these gradients to compare against.
for v in [neg_log_probs, probs, counts_sum_inv, counts_sum, counts, logits_max, logits, nl_act, bn_act,
         bn_var_inv, bn_var, bn_diff_sq, bn_diff, l1_mean, l1_act, embeddings]:
  v.retain_grad()
loss.backward()
loss


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


tensor(3.4346, grad_fn=<MeanBackward0>)

In [17]:
# Exercise 1: backprop by hand through the cross-entropy and softmax steps,
# then check each manual gradient against the one autograd computed.

# loss = mean over i of neg_log_probs[i, Y_b[i]]: 1/n at each selected entry, 0 elsewhere
d_neg_log_probs = torch.zeros_like(neg_log_probs)
d_neg_log_probs[range(size_batch), Y_b] = 1./size_batch

# neg_log_probs = -log(probs)
d_probs = d_neg_log_probs * (-1. / probs)

# probs = counts * counts_sum_inv, where counts_sum_inv (n, 1) is broadcast across each row
d_counts_sum_inv = (d_probs * counts).sum(1, keepdim=True)

# counts_sum_inv = counts_sum ** -1
d_counts_sum = d_counts_sum_inv * (- counts_sum**-2)

# counts reaches the loss both directly through probs and through the row sum
# counts_sum, so its gradient adds both contributions.
d_counts = d_probs * counts_sum_inv + torch.ones_like(counts) * d_counts_sum

for label, manual_grad, ref_tensor in [
    ('d_neg_log_probs', d_neg_log_probs, neg_log_probs),
    ('d_probs', d_probs, probs),
    ('d_counts_sum_inv', d_counts_sum_inv, counts_sum_inv),
    ('d_counts_sum', d_counts_sum, counts_sum),
    ('d_counts', d_counts, counts),
]:
  cmp(label, manual_grad, ref_tensor)

d_neg_log_probs | exact: True  | approximate: True  | maxdiff: 0.0
d_probs         | exact: True  | approximate: True  | maxdiff: 0.0
d_counts_sum_inv | exact: True  | approximate: True  | maxdiff: 0.0
d_counts_sum    | exact: True  | approximate: True  | maxdiff: 0.0
d_counts        | exact: True  | approximate: True  | maxdiff: 0.0


# Debug: cross-check Exercise 1 with an equivalent formulation based on `dlogprobs`

In [None]:
# Cross-check Exercise 1 using an equivalent formulation written in terms of
# logprobs = log(probs) = -neg_log_probs. Its upstream gradient is therefore
# dlogprobs = -d_neg_log_probs (d_neg_log_probs comes from the Exercise 1 cell).
dlogprobs = -d_neg_log_probs
dprobs = (1.0 / probs) * dlogprobs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
# counts also reaches the loss through counts_sum, so add that path's gradient
dcounts += torch.ones_like(counts) * dcounts_sum

cmp('dprobs', dprobs, probs)
cmp('dcounts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('dcounts_sum', dcounts_sum, counts_sum)
cmp('dcounts', dcounts, counts)