In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [4]:
chars = list('abcdefghijklmnopqrstuvwxyz')
stoi = { s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = { i+1:s for i, s in enumerate(chars)}
itos[0] = '.'
print(stoi)
print(itos)

{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}
{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [5]:
vocab_size = 27
block_size = 3

def build_dataset(words):
    X, Y = [], []

    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]

    X = torch.tensor(X)    
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))
Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [6]:
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [7]:
n_embd = 10
n_hidden = 200

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)
# C has shape 27 * 10
# C[X[ix]]
# X[ix] has shape 32 * 10
 
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5 / 3) * (1 / ((n_embd * block_size)**0.5))
b1 = torch.randn(n_hidden, generator=g) * 0.1

W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

bngain = torch.ones((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.zeros((1, n_hidden))* 0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
    p.requires_grad = True
print(sum(p.nelement() for p in parameters))

12297


In [8]:
batch_size = 32
n = batch_size
ix = torch.randint(0, Xtr.shape[0], (batch_size, ), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]

In [9]:
emb = C[Xb] # 32 * 3 * 10
embcat = emb.view(emb.shape[0], -1)
# Linear layer 1
hprebn = embcat @ W1 + b1
bnmeani = (1 / n) * (hprebn.sum(0, keepdim = True))
bndiff = hprebn - bnmeani
bndiff2 = bndiff ** 2
bnvar = (1 / (n - 1)) * (bndiff2.sum(0, keepdim = True))
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
# Non-linearity
h = torch.tanh(hpreact)

# Linear layer 2
logits = h @ W2 + b2
# cross entropy loss: max log likelihood, min negative log likelyhood
logit_maxes = logits.max(1, keepdim=True).values 
norm_logits = logits - logit_maxes
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim = True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log() # 27 * 32
loss = -logprobs[range(n), Yb].mean()


# Pytorch backward pass
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts_sum, counts_sum_inv, counts, norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, bnmeani, hprebn, embcat, emb]:
    t.retain_grad()
loss.backward()
loss
# B

tensor(3.8279, grad_fn=<NegBackward0>)

In [10]:
# Backprob thrugh the whole thing manually

dloss = 1.0
dlogprobs = torch.zeros((logprobs.shape[0], logprobs.shape[1]))
for i in range(n):
    dlogprobs[i][Yb[i]] =  -1.0 / batch_size
# print(dlogprobs[0])
# print(logprobs.grad[0])
# cmp('dlogprobs', dlogprobs, logprobs)

# dprobs / dloss = (dprobs / dlogprops) * (dlogprops / dloss)
# dprobs has shape same a probs i.e. 32 * 27
dprobs = (1.0 / probs) * dlogprobs
# print(dprobs[0])
# print(probs.grad[0])
cmp('dprobs', dprobs, probs)

dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True) # shape: 32 * 1
cmp('dcounts_sum_inv', dcounts_sum_inv, counts_sum_inv)


dcounts_sum =  (-1.0 / (counts_sum)**2) * dcounts_sum_inv # has shape 32 x 1
cmp('dcounts_sum', dcounts_sum, counts_sum)

dcounts = counts_sum_inv * dprobs
dcounts += torch.ones_like(counts) * dcounts_sum
cmp('dcounts', dcounts, counts)

# dcounts = (counts_sum_inv + counts * (-1.0 / (counts_sum ** 2))) # * dprobs
# print(dcounts[1])
# print(counts.grad[1])
# cmp('dcounts', dcounts, counts)

dnorm_logits = norm_logits.exp() * dcounts
cmp('dnorm_logits', dnorm_logits, norm_logits)

dlogit_maxes = (-1.0 * dnorm_logits).sum(1, keepdim=True) # only for logits which were max shape: 32 * 1
cmp('dlogit_maxes', dlogit_maxes, logit_maxes)

dlogits = dnorm_logits
for i in range(logits.shape[0]):
    for j in range(logits.shape[1]):
        if logits[i][j] == logit_maxes[i]:
            dlogits[i][j] += dlogit_maxes[i][0]
cmp('dlogits', dlogits, logits)

db2 = dlogits.sum(0, keepdim=True)
cmp('db2', db2, b2)


# h has shape 32 * 200
# W2 has shape 200 * 27
# dlogits has shape 32 * 27
# print(W2.shape)
# print(dlogits.shape)
dh = torch.transpose(W2 @ torch.transpose(dlogits, 0, 1), 0, 1)
cmp('dh', dh, h)

dW2 = torch.transpose(h, 0, 1) @ dlogits
cmp('dW2', dW2, W2)


dhpreact = (1 - h**2) * dh
cmp('dhpreact', dhpreact, hpreact)

# emb = C[Xb] # 32 * 3 * 10
# embcat = emb.view(emb.shape[0], -1)
# # Linear layer 1
# hprebn = embcat @ W1 + b1
# bnmeani = (1 / n) * (hprebn.sum(0, keepdim = True))
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff ** 2
# bnvar = (1 / (n - 1)) * (bndiff2.sum(0, keepdim = True))
# bnvar_inv = (bnvar + 1e-5)**-0.5
# bnraw = bndiff * bnvar_inv
# hpreact = bngain * bnraw + bnbias


dbnbias = dhpreact.sum(0, keepdim = True)
cmp('dbnbias', dbnbias, bnbias)

dbngain = (bnraw * dhpreact).sum(0, keepdim = True)
cmp('dbngain', dbngain, bngain)

dbnraw = bngain * dhpreact
cmp('dbnraw', dbnraw, bnraw)

dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim = True)
cmp('dbnvar_inv', dbnvar_inv, bnvar_inv)

# bnvar has shape 1 * 200
dbnvar = -0.5 * (bnvar + 1e-5)**-1.5 * dbnvar_inv
cmp('dbnvar', dbnvar, bnvar)

# bnvar = (1 / (n - 1)) * (bndiff2.sum(0, keepdim = True))
# bndiff2 has shape 32 * 200
dbndiff2 = torch.ones_like(bndiff2) * (1 / (n - 1)) * dbnvar
# print(bndiff2.shape)
# print(dbndiff2.shape)
cmp('dbndiff2', dbndiff2, bndiff2)

dbndiff = 2 * bndiff * dbndiff2
dbndiff += bnvar_inv * dbnraw
cmp('dbndiff', dbndiff, bndiff)

# bndiff has shape 32 * 200
# bnmeani has shape 1 * 200
dbnmeani = (-1.0 * dbndiff).sum(0, keepdim = True) 
cmp('dbnmeani', dbnmeani, bnmeani)

# emb = C[Xb] # 32 * 3 * 10
# embcat = emb.view(emb.shape[0], -1)
# hprebn = embcat @ W1 + b1
# bnmeani = (1 / n) * (hprebn.sum(0, keepdim = True))
# bndiff = hprebn - bnmeani

# hprebn has shape 32 * 200
dhprebn = dbndiff
# bnmeani has shape 1 * 200
dhprebn += (1 / n) * dbnmeani
cmp('dhprebn', dhprebn, hprebn)

db1 = dhprebn.sum(0, keepdim = True)
cmp('db1', db1, b1)

dembcat = dhprebn @ torch.transpose(W1, 0, 1)
cmp('dembcat', dembcat, embcat)

dW1 = torch.transpose(embcat, 0, 1) @ dhprebn
cmp('dW1', dW1, W1)

demb = dembcat.view(emb.shape[0], emb.shape[1], -1)
cmp('demb', demb, emb)

# emb = C[Xb]
# Xb has shape 32 * 3
# C has shape 27 * 10
# emb has shape 32 * 3 * 10
dC = torch.zeros_like(C)
for i in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        dC[Xb[i][j]] += demb[i][j]
cmp('dC', dC, C)


dprobs          | exact: True  | approximate: True  | maxdiff: 0.0
dcounts_sum_inv | exact: True  | approximate: True  | maxdiff: 0.0
dcounts_sum     | exact: True  | approximate: True  | maxdiff: 0.0
dcounts         | exact: True  | approximate: True  | maxdiff: 0.0
dnorm_logits    | exact: True  | approximate: True  | maxdiff: 0.0
dlogit_maxes    | exact: True  | approximate: True  | maxdiff: 0.0
dlogits         | exact: True  | approximate: True  | maxdiff: 0.0
db2             | exact: True  | approximate: True  | maxdiff: 0.0
dh              | exact: True  | approximate: True  | maxdiff: 0.0
dW2             | exact: True  | approximate: True  | maxdiff: 0.0
dhpreact        | exact: True  | approximate: True  | maxdiff: 0.0
dbnbias         | exact: True  | approximate: True  | maxdiff: 0.0
dbngain         | exact: True  | approximate: True  | maxdiff: 0.0
dbnraw          | exact: True  | approximate: True  | maxdiff: 0.0
dbnvar_inv      | exact: True  | approximate: True  | maxdiff:

In [33]:
foo = torch.zeros_like(logits)
foo[range(logits.shape[0]), Yb] = 1 / logits.shape[0]

In [50]:
bar = logits.sum(1, keepdim = True) - logits + (1.0 - vocab_size) * logits.max(1, keepdim = True).values
bar *= foo
# print(bar)
A = logits.sum(1, keepdim = True)
print(A)

tensor([[-1.1675],
        [ 0.6853],
        [11.1944],
        [-4.7283],
        [ 3.7751],
        [ 2.2485],
        [-4.2482],
        [-8.2952],
        [ 3.9967],
        [ 7.8760],
        [-5.9808],
        [-0.3015],
        [ 4.5658],
        [-6.4188],
        [ 1.1241],
        [ 6.4003],
        [ 8.2901],
        [ 6.5472],
        [ 3.4046],
        [-8.2952],
        [ 9.7022],
        [-6.9753],
        [ 2.8430],
        [ 1.8309],
        [ 3.0741],
        [-6.2279],
        [ 1.8356],
        [-0.4410],
        [-0.8160],
        [-8.2952],
        [-4.2253],
        [-8.2952]], grad_fn=<SumBackward1>)


In [44]:
cmp('dlogits', bar, logits)

dlogits         | exact: False | approximate: False | maxdiff: 2.1872243881225586


In [29]:
logits.grad

tensor([[ 3.5640e-04, -2.4939e-02,  5.1969e-04,  1.0260e-03,  1.8752e-03,
          5.2618e-04,  3.0502e-04,  4.0244e-04,  1.0106e-03,  6.9478e-04,
          1.8658e-03,  2.1820e-04,  6.5081e-04,  1.6950e-04,  6.3564e-04,
          6.1022e-04,  7.7567e-04,  1.7753e-04,  6.4878e-05,  5.0351e-04,
          4.4877e-03,  1.3322e-03,  2.3320e-03,  4.2050e-04,  1.2719e-04,
          1.8821e-03,  1.9692e-03],
        [ 1.6959e-03,  1.2657e-03,  7.0448e-04,  1.1192e-03,  1.0531e-03,
          3.9933e-04,  2.5606e-04,  7.6596e-04,  2.4770e-03,  1.1053e-03,
          7.0484e-04,  1.4427e-04, -3.0693e-02,  1.7573e-03,  2.8811e-04,
          1.2277e-03,  6.5333e-04,  7.4730e-04,  2.7627e-03,  1.6487e-03,
          1.3067e-03,  5.1192e-04,  1.3522e-03,  1.7902e-03,  1.8281e-03,
          2.0353e-04,  2.9241e-03],
        [-3.0514e-02,  1.0445e-03,  2.6628e-03,  4.6614e-04,  7.1869e-04,
          8.5898e-04,  1.3593e-03,  1.1446e-03,  1.4555e-03,  2.3716e-03,
          1.9871e-03,  3.0276e-04,  4.11

In [64]:
dlogits = probs.clone()
dlogits[range(probs.shape[0]), Yb] -= 1.0
# for i in range(probs.shape[0]):
#     dlogits[i][Yb[i]] -= 1.0
dlogits /= probs.shape[0]

In [65]:
cmp('dlogits', dlogits, logits)

dlogits         | exact: False | approximate: True  | maxdiff: 5.122274160385132e-09


In [66]:
# bnmeani = (1 / n) * (hprebn.sum(0, keepdim = True))
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff ** 2
# bnvar = (1 / (n - 1)) * (bndiff2.sum(0, keepdim = True))
# bnvar_inv = (bnvar + 1e-5)**-0.5
# bnraw = bndiff * bnvar_inv

dhprebn = (1.0 - 1.0 / h.shape[1]) * (h.shape[0] - 1) * bnraw.grad

In [67]:
cmp('dhprebn', dhprebn, hprebn)

dhprebn         | exact: False | approximate: False | maxdiff: 0.3773874044418335


In [70]:
dhprebn = ((1.0 - 1.0 / vocab_size) * bnvar_inv) * bngain * hpreact.grad 
cmp('dhprebn', dhprebn, hprebn)

dhprebn         | exact: False | approximate: False | maxdiff: 0.002149573527276516


In [75]:
m = h.shape[0]
n = h.shape[1]
dhprebn = ((bngain * bnvar_inv) / m) * (m * dhpreact - dhpreact.sum(0) - (m/(m-1)) * bnraw * (dhpreact * bnraw).sum(0))

In [76]:
cmp('dhprebn', dhprebn, hprebn)

dhprebn         | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10
