In [1]:
import random
from random import Random

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
words = open('names.txt', 'r').read().splitlines()
random.seed(2147483647)
random.shuffle(words)
n = len(words)

n_train = int(0.8 * n)
n_dev = int(0.1 * n)

words_train = words[:n_train]
words_dev = words[n_train:n_train + n_dev]
words_test = words[n_train + n_dev:]

chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}
vocab_size = len(stoi)

In [3]:
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item()
    app = torch.allclose(dt, t.grad)
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approx: {str(app):5s} | maxdiff: {maxdiff}')

In [12]:
def build_xy(words, stoi, block_size=3):
    x, y = [], []
    for w in words[:]:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            x.append(context)
            y.append(ix)
            context = context[1:] + [ix]
    X, Y = torch.tensor(x), torch.tensor(y)
    return X, Y

def data_nll(x, y, C, W1, W2, B2, bngain, bnbias, bnmean_running, bnstd_running, smoothing=0.05, eps=1e-5):
    emb = C[x]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1
    hpreact = bngain * ((hpreact - bnmean_running) / (bnstd_running ** 2 + eps) ** 0.5) + bnbias
    h = torch.tanh(hpreact)
    logits = h @ W2 + B2
    return F.cross_entropy(logits, y, label_smoothing=smoothing).item()

def train_step(x, y, C, W1, W2, B2, parameters, bngain, bnbias, bnmean_running, bnstd_running, lr1=0.1, lr2=None, steps=10, batch_size=32):
    lossi = []
    with torch.no_grad():
        for i in range(steps):
            ix = torch.randint(0, x.shape[0], (batch_size,))
            xb, yb = x[ix], y[ix]
            emb = C[xb]
            embcat = emb.view(emb.shape[0], -1)
            hpreact = embcat @ W1
            bnmeani = hpreact.mean(0, keepdim=True)
            bnstdi = hpreact.std(0, keepdim=True)
            bnraw = (hpreact - bnmeani) / bnstdi
            hpreact = bngain * bnraw + bnbias
            bnvar_inv = 1.0 / (bnstdi ** 2 + 1e-5)

            with torch.no_grad():
                bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
                bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi

            h = torch.tanh(hpreact)
            logits = h @ W2 + B2
            loss = F.cross_entropy(logits, yb)
            for p in parameters:
                p.grad = None
            grads = backward_pass(xb, yb, C, emb, embcat, h, W1, W2, batch_size, logits, bnraw, bngain, bnvar_inv)
            for p, g in zip(parameters, grads):
                p.grad = g
            lr = lr1 if lr2 is None or i > 10000 else lr2
            with torch.no_grad():
                for p in parameters:
                    p.data += -lr * p.grad
            lossi.append(loss.item())
            if i % 1000 == 0:
                print(f"Step {i}, Loss: {loss.item()}")
        return lossi, bnmean_running, bnstd_running

def generate(C, W1, W2, B2, bngain, bnbias, bnmean_running, bnstd_running, itos, block_size=3, seed=2147483647, num_samples=5, eps=1e-5):
    g = torch.Generator().manual_seed(seed)
    for _ in range(num_samples):
        out = []
        context = [0] * block_size
        while True:
            ctx = torch.tensor(context)
            emb = C[ctx]
            embcat = emb.view(1, -1)
            hpreact = embcat @ W1
            hpreact = bngain * ((hpreact - bnmean_running) / (bnstd_running ** 2 + eps) ** 0.5) + bnbias
            h = torch.tanh(hpreact)
            logits = h @ W2 + B2
            p = logits.squeeze(0).softmax(dim=0)
            ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
            context = context[1:] + [ix]
            if ix == 0:
                break
            else:
                out.append(ix)
        print(''.join(itos[i] for i in out))

def backward_pass(Xb, Yb, C, emb, embcat, h, W1, W2, n, logits, bnraw, bngain, bnvar_inv):

    dlogits = F.softmax(logits, 1)
    dlogits[range(n), Yb] -= 1
    dlogits /= n

    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    dB2 = dlogits.sum(0)

    dhpreact = (1.0 - h ** 2) * dh

    dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
    dbnbias = dhpreact.sum(0, keepdim=True)

    dhprebn = bngain * bnvar_inv / n * (n * dhpreact - dhpreact.sum(0) - n / (n - 1) * bnraw * (dhpreact * bnraw).sum(0))

    dembcat = dhprebn @ W1.T
    dW1 = embcat.T @ dhprebn
    demb = dembcat.view(emb.shape)
    dC = torch.zeros_like(C)
    for i in range(Xb.shape[0]):
        for j in range(Xb.shape[1]):
            ix = Xb[i, j]
            dC[ix] += demb[i, j]

    grads = [dC, dW1, dW2, dB2, dbngain, dbnbias]

    return grads


In [5]:
Xtr, Ytr = build_xy(words_train, stoi)
Xdev, Ydev = build_xy(words_dev, stoi)
Xte, Yte = build_xy(words_test, stoi)

In [6]:
n_emb = 10
n_hidden = 64
block_size = 3

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_emb),                generator=g)
W1 = torch.randn((n_emb * block_size, n_hidden),    generator=g) * (5 / 3) / (n_emb * block_size) ** 0.5
#B1 = torch.randn(n_hidden,                          generator=g) * 0.1
W2 = torch.randn((n_hidden, vocab_size),            generator=g) * 0.1
B2 = torch.randn(vocab_size,                        generator=g) * 0.1

bngain = torch.ones((1, n_hidden)) * 0.1 + 1
bnbias = torch.zeros((1, n_hidden)) * 0.1
bnmean_running = torch.zeros((1, n_hidden))
bnstd_running = torch.ones((1, n_hidden))

parameters = [C, W1, W2, B2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

4073


In [7]:
batch_size = 128
n = batch_size

ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
Xb, Yb = Xtr[ix], Ytr[ix]

In [8]:
lossi, bnmean_running, bnstd_running = train_step(Xtr, Ytr, C, W1, W2, B2, parameters, bngain, bnbias, bnmean_running, bnstd_running, lr1=0.1, lr2=0.01, steps=100000, batch_size=n)

Step 0, Loss: 3.4886906147003174
Step 1000, Loss: 2.693060874938965
Step 2000, Loss: 2.586390256881714
Step 3000, Loss: 2.5295629501342773
Step 4000, Loss: 2.6111199855804443
Step 5000, Loss: 2.3618762493133545
Step 6000, Loss: 2.602396249771118
Step 7000, Loss: 2.5475893020629883
Step 8000, Loss: 2.4084579944610596
Step 9000, Loss: 2.4761176109313965
Step 10000, Loss: 2.5534629821777344
Step 11000, Loss: 2.528238296508789
Step 12000, Loss: 2.4177894592285156
Step 13000, Loss: 2.396249532699585
Step 14000, Loss: 2.351808786392212
Step 15000, Loss: 2.330160617828369
Step 16000, Loss: 2.285564422607422
Step 17000, Loss: 2.1952757835388184
Step 18000, Loss: 2.2243423461914062
Step 19000, Loss: 2.191953182220459
Step 20000, Loss: 2.2538888454437256
Step 21000, Loss: 2.200437068939209
Step 22000, Loss: 2.2638981342315674
Step 23000, Loss: 2.1789591312408447
Step 24000, Loss: 2.121102809906006
Step 25000, Loss: 2.210395097732544
Step 26000, Loss: 2.1049230098724365
Step 27000, Loss: 2.317942

In [9]:
print(data_nll(Xtr, Ytr, C, W1, W2, B2, bngain, bnbias, bnmean_running, bnstd_running))
print(data_nll(Xdev, Ydev, C, W1, W2, B2, bngain, bnbias, bnmean_running, bnstd_running))

2.2942135334014893
2.3270814418792725


In [13]:
generate(C, W1, W2, B2, bngain, bnbias, bnmean_running, bnstd_running, itos)
# Not great samples, but at least not complete gibberish

cexbe
moullurailah
tymarionichana
noluwan
kata


In [None]:
"""
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)

#Linear Layer 1
hprebn = embcat @ W1

#BatchNorm Layer
bnmeani = 1 / n * hprebn.sum(0, keepdim=True)
bndiff1 = hprebn - bnmeani
bndiff2 = bndiff1 ** 2
bnvar = 1 / (n - 1) * (bndiff2).sum(0, keepdim=True)
bnvar_inv = (bnvar + 1e-5) ** -0.5
bnraw = bndiff1 * bnvar_inv
hpreact = bngain * bnraw + bnbias
h = torch.tanh(hpreact)

#Linear Layer 2
logits = h @ W2 + B2

#Cross Entropy Loss
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum ** -1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

#Backward Pass
for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff1, hprebn, bnmeani, embcat, emb]:
    t.retain_grad()
loss.backward()
loss.item()
"""

In [None]:
"""
counts.shape, probs.shape, counts_sum_inv.shape
"""

In [None]:
"""
# Manually compute all the gradients
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n
dprobs = dlogprobs / probs
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts_sum = (-counts_sum ** -2) * dcounts_sum_inv
dcounts = counts_sum_inv * dprobs + torch.ones_like(counts) * dcounts_sum
dnorm_logits = counts * dcounts
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits = dnorm_logits.clone() + F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
dB2 = dlogits.sum(0)
dhpreact = (1.0 - h ** 2) * dh
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)
dbnraw = bngain * dhpreact
dbnvar_inv = (bndiff1 * dbnraw).sum(0, keepdim=True)
dbnvar = -0.5 * (bnvar + 1e-5) ** -1.5 * dbnvar_inv
dbndiff2 = 1 / (n - 1) * torch.ones_like(bndiff2) * dbnvar
dbndiff1 = bnvar_inv * dbnraw + 2 * bndiff1 * dbndiff2
dbmeani = (-dbndiff1).sum(0, keepdim=True)
dhprebn = dbndiff1.clone() + 1.0 / n * (torch.ones_like(hprebn) * dbmeani)
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
demb = dembcat.view(emb.shape)
dC = torch.zeros_like(C)
for i in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[i, j]
        dC[ix] += demb[i, j]

cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('dlogit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('B2', dB2, B2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff1', dbndiff1, bndiff1)
cmp('bnmeani', dbmeani, bnmeani)
cmp('hprebn', dhprebn, hprebn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('emb', demb, emb)
cmp('C', dC, C)
"""

In [None]:
"""
# More effecient way to compute dlogits manually
dlogits = F.softmax(logits, 1)
dlogits[range(n), Yb] -= 1
dlogits /= n
cmp('logits', dlogits, logits)
"""

In [None]:
"""
# More efficient way to compute dhprebn manually
dhprebn = bngain * bnvar_inv / n * (n * dhpreact - dhpreact.sum(0) - n / (n - 1) * bnraw * (dhpreact * bnraw).sum(0))
cmp('hprebn', dhprebn, hprebn)
"""