In [39]:
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import time
import random
%matplotlib inline


In [40]:
# Build the character vocabulary of the dataset.
# Index 0 is reserved for '.', which serves both as the start-of-name padding and the
# end-of-name marker; the distinct characters of the file get indices 1..len(allWords)
# in sorted order.
# NOTE(review): if '.' itself ever occurred in names.txt it would get two indices in
# itos, and stoi['.'] would end up as 0 — the file is assumed not to contain '.'.
with open('names.txt', 'r', encoding='utf-8') as names_file:  # `with` closes the handle
    words = names_file.read().splitlines()
allWords = sorted(set(''.join(words)))
itos = {idx + 1: ch for idx, ch in enumerate(allWords)}  # index -> character
itos[0] = '.'
stoi = {ch: idx for idx, ch in itos.items()}  # character -> index (inverse of itos)


In [41]:
# Building a dataset

block_size = 3  # how many characters of context are used to predict the next character

def build_dataset(words, char_to_idx=None, context_size=None):
    """Turn a list of names into (context, next-character) training examples.

    Each name is processed as ``word + '.'``. A sliding window holding the indices of
    the previous ``context_size`` characters (initially all 0, i.e. '.') is paired with
    the index of the next character, so a name of length L yields L + 1 examples.

    Args:
        words: iterable of strings; every character must be a key of the mapping.
        char_to_idx: character -> integer index mapping. Defaults to the global ``stoi``.
        context_size: length of the context window. Defaults to the global ``block_size``.

    Returns:
        X: int64 tensor of shape (num_examples, context_size) holding context indices.
        Y: int64 tensor of shape (num_examples,) holding the next-character index.
        Both keep these dtypes/shapes even when ``words`` is empty.

    Raises:
        KeyError: if a character is missing from the mapping.
    """
    mapping = stoi if char_to_idx is None else char_to_idx
    window = block_size if context_size is None else context_size

    contexts, targets = [], []
    for word in words:
        context = [0] * window
        for alpha in word + '.':
            ix = mapping[alpha]
            contexts.append(context)
            targets.append(ix)
            context = context[1:] + [ix]  # slide the window by one character

    # An explicit dtype and shape make an empty dataset come out as int64 tensors of shape
    # (0, window) / (0,) instead of torch.tensor([]) -> float32 of shape (0,).
    X = torch.tensor(contexts, dtype=torch.long).reshape(len(contexts), window)
    Y = torch.tensor(targets, dtype=torch.long)
    return X, Y

# Shuffle the names with a fixed seed and split them 80/10/10 into train / dev / test.
# A *copy* is shuffled instead of `words` in place: shuffling in place made this cell
# non-idempotent (re-running it reshuffled an already shuffled list, giving a different
# split). On a first run the resulting split is identical to the in-place version.
random.seed(1599)
shuffled_words = list(words)
random.shuffle(shuffled_words)

n1 = int(len(shuffled_words) * 0.8)
n2 = int(len(shuffled_words) * 0.9)

Xtr, Ytr = build_dataset(shuffled_words[:n1])      # 80% of the data
Xdev, Ydev = build_dataset(shuffled_words[n1:n2])  # 10% of the data
Xtes, Ytes = build_dataset(shuffled_words[n2:])    # 10% of the data


In [42]:
# Utility function.
def cmp(s, df, t):
    """Print how well a manually derived gradient matches the one autograd computed.

    Args:
        s: label printed at the start of the line.
        df: manually computed gradient tensor.
        t: tensor whose ``.grad`` was populated by ``backward()`` (non-leaf tensors
           need ``retain_grad()`` beforehand).

    Prints one line reporting whether all elements match exactly, whether the tensors
    are ``torch.allclose``, and the largest absolute difference. Shapes must match
    exactly: otherwise broadcasting would let e.g. a ``(200,)`` gradient "match" a
    ``(1, 200)`` one, so a shape mismatch is reported as a failure.

    Raises:
        ValueError: if ``t.grad`` is None (no gradient was recorded for ``t``).
    """
    if t.grad is None:
        raise ValueError(f"{s}: reference tensor has no .grad (call retain_grad() before backward())")
    same_shape = df.shape == t.grad.shape
    exact = same_shape and torch.all(df == t.grad).item()
    apro = same_shape and torch.allclose(df, t.grad)
    try:
        diff = (df - t.grad).abs().max().item()
    except RuntimeError:  # shapes are not even broadcast-compatible
        diff = float('nan')
    note = "" if same_shape else f" || SHAPE MISMATCH: {tuple(df.shape)} vs {tuple(t.grad.shape)}"
    print(f"{s:15} ||  exact:{exact} || approximation: {apro} || maxDiff: {diff}{note}")


In [43]:
n_emb = 10  # dimensionality of each character embedding
n_hidden = 200  # number of neurons in the hidden layer
vocab_size = len(itos)  # vocabulary size: number of distinct characters, including '.'

gen = torch.Generator().manual_seed(2147483647)  # dedicated generator for reproducible initialisation
# NOTE: the order of the randn calls below determines which random numbers each tensor
# receives; keep it unchanged to reproduce the same initialisation.
C = torch.randn((vocab_size, n_emb),             generator=gen)  # character embedding table
# tanh gain (5/3) divided by sqrt(fan_in), with fan_in = block_size * n_emb.
w1 = torch.randn((block_size * n_emb, n_hidden), generator=gen) * (5/3)/((n_emb * block_size) ** 0.5)
# b1 is effectively cancelled by the batch-norm mean subtraction in the forward pass,
# so its gradient should be ~0.
b1 = torch.randn(n_hidden,                       generator=gen) * 0.1
# Small-scale initialisation for the output layer.
w2 = torch.randn((n_hidden, vocab_size),         generator=gen) * 0.1
b2 = torch.randn(vocab_size,                     generator=gen) * 0.1

# Batch-norm scale and shift, shaped (1, n_hidden) to broadcast over the batch.
bn_gain = torch.ones((1, n_hidden))
bn_bias = torch.zeros((1, n_hidden))

parameters = [C, b1, w1, b2, w2, bn_gain, bn_bias]  # all trainable parameters of the model
for par in parameters:
    par.requires_grad = True

# Total parameter count. As the cell's last expression it is displayed as output;
# previously it was computed mid-cell and silently discarded.
sum(p.nelement() for p in parameters)


In [44]:
# Draw one minibatch of training examples; indices come from `gen` so the batch is reproducible.
batch_size = 32
n = batch_size  # short alias used by the batch-norm / backprop formulas
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=gen)
Xb = Xtr[ix]  # contexts of the sampled examples
Yb = Ytr[ix]  # their next-character targets

In [45]:
# Forward pass, deliberately broken into small steps so that every intermediate tensor
# can be backpropagated by hand in the next cell and compared with autograd.
emb = C[Xb]  # embed every context character: (n, block_size, n_emb)
embTrans = emb.view(emb.shape[0], -1)  # concatenate the context embeddings: (n, block_size * n_emb)
hpredn = embTrans @ w1  + b1  # hidden-layer pre-activation, before batch norm

# batch normalization (statistics over the batch dimension, dim 0).
bnmeani = 1/n * hpredn.sum(0, keepdim=True)  # batch mean, shape (1, n_hidden)
bndiff = hpredn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1) * bndiff2.sum(0, keepdim=True)  # variance with Bessel's correction (divides by n-1)
bnvar_inv = (bnvar + 1e-5) ** -0.5  # 1/sqrt(var + eps); eps guards against division by zero
bnraw = bndiff * bnvar_inv  # normalized activations
hpred = bn_gain * bnraw + bn_bias 

# Non Linearity.
h = torch.tanh(hpred)

# passing it through the second layer.
logits = h @ w2 + b2

# Cross-entropy loss computed step by step; equivalent to F.cross_entropy(logits, Yb).
logits_max = logits.max(1, keepdim=True).values

# Subtract the row max for numerical stability (exp cannot overflow); the softmax result is unchanged.
norm_logits = logits - logits_max
counts = norm_logits.exp()
count_sum = counts.sum(1, keepdim=True)
count_sum_inv = count_sum ** -1
probs =  counts * count_sum_inv  # softmax probabilities
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()  # mean negative log-likelihood of the correct next character

# backward pass
# Clear any gradients left over from a previous run of this cell.
for param in parameters:
    param.grad = None

# Non-leaf tensors do not keep .grad by default; retain it so the manual gradients
# in the next cell can be compared against autograd.
for t in [logprobs, probs, counts, count_sum, count_sum_inv,
         norm_logits, logits_max, logits, h, hpred, bnraw, bnvar_inv,
         bnvar, bndiff2, bndiff, hpredn, bnmeani, embTrans, emb]:
    t.retain_grad()
    
loss.backward()
loss  # displayed as the cell output


tensor(3.7877, grad_fn=<NegBackward0>)

In [47]:
# Manual backward pass: walk the forward graph in reverse, computing d<name> = dloss/d<name>
# for each intermediate. When a tensor feeds several ops, its gradient contributions are summed (+=).

# loss = -mean(logprobs[range(n), Yb]): only the selected entries get gradient, each -1/n.
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
# logprobs = log(probs): d/dx log(x) = 1/x
dprobs = (1/probs) * dlogprobs
# probs = counts * count_sum_inv; count_sum_inv is (n, 1) and broadcast across the
# columns, so its gradient is summed over that broadcast dimension.
dcount_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = count_sum_inv * dprobs
# count_sum_inv = count_sum ** -1: d/dx x^-1 = -x^-2
dcount_sum = (-count_sum**-2) * dcount_sum_inv
# count_sum = counts.sum(1): every element of a row receives that row's gradient
# (second contribution to dcounts).
dcounts += torch.ones_like(counts) * dcount_sum
# counts = exp(norm_logits): the derivative of exp is exp itself.
dnorm_logits = counts * dcounts
# norm_logits = logits - logits_max (logits_max broadcast across the columns).
dlogits = dnorm_logits.clone()
dlogits_max = (-dnorm_logits).sum(1, keepdim=True)
# logits_max selects one element per row: route its gradient back to the argmax position.
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogits_max
# logits = h @ w2 + b2
dh = dlogits @ w2.T
dw2 = h.T @ dlogits
db2 = dlogits.sum(0)
# h = tanh(hpred): d/dx tanh(x) = 1 - tanh(x)^2
dhpred = (1.0 - h ** 2) * dh
# hpred = bn_gain * bnraw + bn_bias; gain/bias are (1, n_hidden), broadcast over the batch.
dbn_gain =  (bnraw * dhpred).sum(0, keepdim=True)
dbnraw = bn_gain * dhpred
dbn_bias = dhpred.sum(0, keepdim=True)
# bnraw = bndiff * bnvar_inv
dbndiff = bnvar_inv * dbnraw
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# bnvar_inv = (bnvar + 1e-5) ** -0.5
dbnvar =  (-0.5 * (bnvar + 1e-5) ** -1.5) * dbnvar_inv
# bnvar = 1/(n-1) * bndiff2.sum(0): the gradient is spread evenly over the batch.
dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar
# bndiff2 = bndiff ** 2 (second contribution to dbndiff)
dbndiff += (2 * bndiff) * dbndiff2
# bndiff = hpredn - bnmeani
dhpredn = dbndiff.clone()
# keepdim=True gives dbnmeani the same (1, n_hidden) shape as bnmeani; without it the
# (n_hidden,) result only "matched" bnmeani.grad through broadcasting.
dbnmeani = (-dbndiff).sum(0, keepdim=True)
# bnmeani = 1/n * hpredn.sum(0) (second contribution to dhpredn)
dhpredn += 1.0/n * (torch.ones_like(dhpredn) * dbnmeani)
# hpredn = embTrans @ w1 + b1
dembTrans = dhpredn @ w1.T
dw1 = embTrans.T @ dhpredn
db1 = dhpredn.sum(0)
# embTrans = emb.view(n, -1): undo the flattening.
demb = dembTrans.view(emb.shape)

# emb = C[Xb]: scatter each embedding gradient back into the row of C it was read from;
# rows used several times accumulate. The loop variable is `char_ix` rather than `ix`
# so the minibatch indices `ix` are not clobbered.
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        char_ix = Xb[k, j]
        dC[char_ix] += demb[k, j]
                   


# Compare every manually derived gradient with the one autograd stored in `.grad`.
# Each entry is (label, manual gradient, tensor whose .grad is the reference); labels
# follow the d<name> naming of the manual gradients (the b1 check was previously
# mislabelled "dw1").
grad_checks = [
    ("dlogprobs", dlogprobs, logprobs),
    ("dprobs", dprobs, probs),
    ("dcount_sum_inv", dcount_sum_inv, count_sum_inv),
    ("dcount_sum", dcount_sum, count_sum),
    ("dcounts", dcounts, counts),
    ("dnorm_logits", dnorm_logits, norm_logits),
    ("dlogits_max", dlogits_max, logits_max),
    ("dlogits", dlogits, logits),
    ("dh", dh, h),
    ("dw2", dw2, w2),
    ("db2", db2, b2),
    ("dhpred", dhpred, hpred),
    ("dbn_gain", dbn_gain, bn_gain),
    ("dbnraw", dbnraw, bnraw),
    ("dbn_bias", dbn_bias, bn_bias),
    ("dbnvar_inv", dbnvar_inv, bnvar_inv),
    ("dbnvar", dbnvar, bnvar),
    ("dbndiff2", dbndiff2, bndiff2),
    ("dbndiff", dbndiff, bndiff),
    ("dhpredn", dhpredn, hpredn),
    ("dbnmeani", dbnmeani, bnmeani),
    ("dembTrans", dembTrans, embTrans),
    ("dw1", dw1, w1),
    ("db1", db1, b1),
    ("demb", demb, emb),
    ("dC", dC, C),
]
for name, manual, ref in grad_checks:
    cmp(name, manual, ref)



logprobs        ||  exact:True || approximation: True || maxDiff: 0.0
dprobs          ||  exact:True || approximation: True || maxDiff: 0.0
dcount_sum_inv  ||  exact:True || approximation: True || maxDiff: 0.0
dcount_sum      ||  exact:True || approximation: True || maxDiff: 0.0
dcounts         ||  exact:True || approximation: True || maxDiff: 0.0
dnorm_logits    ||  exact:True || approximation: True || maxDiff: 0.0
dlogits_max     ||  exact:True || approximation: True || maxDiff: 0.0
dlogits         ||  exact:True || approximation: True || maxDiff: 0.0
dh              ||  exact:True || approximation: True || maxDiff: 0.0
dw2             ||  exact:True || approximation: True || maxDiff: 0.0
db2             ||  exact:True || approximation: True || maxDiff: 0.0
dhpred          ||  exact:True || approximation: True || maxDiff: 0.0
dbn_gain        ||  exact:True || approximation: True || maxDiff: 0.0
dbnraw          ||  exact:True || approximation: True || maxDiff: 0.0
dbn_bias        ||  