In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import math
%matplotlib inline

In [None]:
g = torch.Generator().manual_seed(2147483647)

In [None]:
words = open('names.txt', 'r').read().splitlines()

In [None]:
chars = ['.'] + sorted(list(set(''.join(words))))
encode = lambda c: chars.index(c)
decode = lambda i: chars[i]

In [None]:
block_size = 3
vocab_size = len(chars)

In [None]:
def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for c in w + '.':
            ix = encode(c)
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xval, Yval = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [None]:
def cmp(s, dt, t):
    exact = torch.all(dt == t)
    approx = torch.allclose(dt, t.grad)
    max_diff = (torch.abs(dt - t.grad).abs().max().item())
    print(f'{s:15s} exact: {str(exact):5s} approx: {str(approx):5s} max_diff: {max_diff}')


In [None]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1 # using b1 just for fun, it's useless because of BN
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

# Note: I am initializating many of these parameters in non-standard ways
# because sometimes initializating with e.g. all zeros could mask an incorrect
# implementation of the backward pass. Because when everything is zero the
# expression of gradient is simplified.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

4137


In [None]:
batch_size = 32
n = batch_size
ix = torch.randint(0, Xtr.shape[0], (n,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]

print(Xb.shape, Yb.shape)


torch.Size([32, 3]) torch.Size([32])


In [None]:
# embed the inputs
emb = C[Xb] # (n, block_size, n_embd)
embcat = emb.view(Xb.shape[0], -1) # (n, block_size * n_embd)
# calxulate the linear layer outputs
hprebn = embcat @ W1 + b1 # (n, n_hidden)
# get the mean of the batch
bnmeani = hprebn / hprebn.sum(dim=0, keepdim=True)
# zero center the batch
bndiff = hprebn - bnmeani
# variance is the sum of the squared differences of examples in the batch from the mean, divided by n-1 (Bessel's correction, divide by n-1 instead of n to improve the estimate of the population variance)
bndiff2 = bndiff**2
bnvar = 1/(n-1) * bndiff2.sum(dim=0, keepdim=True)
# standard deviation is the square root of the variance
# get the inverse of the standard deviation summed with a small constant to avoid division by zero
bnvar_inv = (bnvar + 1e-5)**-0.5
# scale the centered batch by the inverse of the standard deviation
bnraw = bndiff * bnvar_inv
# apply learned gain and bias to the scaled batch
hpreact = bngain * bnraw + bnbias
# apply the tanh activation function
h = torch.tanh(hpreact)
# calculate the logits for the output layer
logits = h @ W2 + b2
# for numerical stability, subtract the maximum logit value in each row
logit_maxes = logits.max(dim=1, keepdim=True).values
norm_logits = logits - logit_maxes # for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(dim=1, keepdim=True)
counts_sum_inv = counts_sum ** -1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean() # we first index by the barch index and then for that particular batch row we take the probability of the character oobserved in the dataset

for p in parameters:
    p.grad = None

for t in reversed([emb, embcat, hprebn, bnmeani, bndiff, bndiff2, bnvar, bnvar_inv, bnraw, hpreact, h, logits, logit_maxes, norm_logits, counts, counts_sum, counts_sum_inv, probs, logprobs]):
    t.retain_grad()

loss.backward()

loss


tensor(3.5024, grad_fn=<NegBackward0>)

In [None]:
Yb

tensor([ 8, 14, 15, 22,  0, 19,  9, 14,  5,  1, 20,  3,  8, 14, 12,  0, 11,  0,
        26,  9, 25,  0,  1,  1,  7, 18,  9,  3,  5,  9,  0, 18])

In [36]:
dlogprobs = -1/n * F.one_hot(Yb, vocab_size) # 32 x 27
dprobs = dlogprobs * (probs**-1)
dcounts = dprobs * counts_sum_inv
dcounts_sum_inv = dprobs * counts
dcounts_sum = dcounts_sum_inv * -1 * (counts_sum ** -2)
dcounts += dcounts_sum * counts
dnorm_logits = dcounts * norm_logits.exp()
dlogits = dnorm_logits * 1
dlogits_maxes = dnorm_logits * -1
# dlogits = dlogits_maxes * logits[range(n), logits.argmax(dim=1, keepdim=True)].view((-1,1))

RuntimeError: The size of tensor a (32) must match the size of tensor b (1024) at non-singleton dimension 0

In [33]:
dcounts_sum.shape

torch.Size([32, 27])

torch.Size([32, 27])

In [None]:
# dlogprobs has to be the same size as loss -> 32 x 27
dlogprobs = -(1/n) * torch.zerows

In [35]:
loss

tensor(3.4843, grad_fn=<NegBackward0>)

In [33]:
Yb

tensor([ 8, 14, 15, 22,  0, 19,  9, 14,  5,  1, 20,  3,  8, 14, 12,  0, 11,  0,
        26,  9, 25,  0,  1,  1,  7, 18,  9,  3,  5,  9,  0, 18])

In [32]:
logprobs 

tensor([[-3.1837, -2.7647, -3.8949, -3.4204, -3.5088, -2.3889, -3.4442, -3.3123,
         -4.0456, -3.9576, -3.0628, -3.4465, -3.0465, -3.2209, -2.9837, -3.9360,
         -4.2538, -4.4107, -3.7310, -2.5961, -3.1221, -4.0849, -4.0183, -2.7534,
         -2.7924, -3.7079, -3.6333],
        [-3.3363, -3.0017, -2.3151, -3.2497, -2.7663, -3.9893, -3.7965, -3.6317,
         -3.8491, -3.5038, -3.2904, -3.4172, -3.1890, -2.9364, -2.6960, -2.8940,
         -3.6866, -3.9671, -4.3571, -2.9187, -3.7511, -4.6416, -4.3379, -2.6600,
         -3.7620, -3.6106, -3.4701],
        [-4.3264, -3.6370, -4.2447, -4.2964, -3.3246, -3.2044, -2.9079, -2.8080,
         -2.7094, -4.0626, -3.7935, -3.8263, -3.1316, -2.8195, -3.1986, -3.6188,
         -4.5389, -4.0627, -3.5916, -1.9032, -3.0939, -3.7700, -3.3071, -3.2362,
         -3.2656, -3.6923, -3.5042],
        [-3.6694, -3.8990, -3.2422, -3.4948, -2.5906, -4.0370, -2.7419, -3.5148,
         -3.1858, -3.9286, -3.5002, -3.6154, -3.3074, -2.8598, -2.8353, -2.2577

In [31]:
logprobs[range(n)]

tensor([[-3.1837, -2.7647, -3.8949, -3.4204, -3.5088, -2.3889, -3.4442, -3.3123,
         -4.0456, -3.9576, -3.0628, -3.4465, -3.0465, -3.2209, -2.9837, -3.9360,
         -4.2538, -4.4107, -3.7310, -2.5961, -3.1221, -4.0849, -4.0183, -2.7534,
         -2.7924, -3.7079, -3.6333],
        [-3.3363, -3.0017, -2.3151, -3.2497, -2.7663, -3.9893, -3.7965, -3.6317,
         -3.8491, -3.5038, -3.2904, -3.4172, -3.1890, -2.9364, -2.6960, -2.8940,
         -3.6866, -3.9671, -4.3571, -2.9187, -3.7511, -4.6416, -4.3379, -2.6600,
         -3.7620, -3.6106, -3.4701],
        [-4.3264, -3.6370, -4.2447, -4.2964, -3.3246, -3.2044, -2.9079, -2.8080,
         -2.7094, -4.0626, -3.7935, -3.8263, -3.1316, -2.8195, -3.1986, -3.6188,
         -4.5389, -4.0627, -3.5916, -1.9032, -3.0939, -3.7700, -3.3071, -3.2362,
         -3.2656, -3.6923, -3.5042],
        [-3.6694, -3.8990, -3.2422, -3.4948, -2.5906, -4.0370, -2.7419, -3.5148,
         -3.1858, -3.9286, -3.5002, -3.6154, -3.3074, -2.8598, -2.8353, -2.2577