In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
words = open('D:/Download/names.txt', 'r').read().splitlines()

chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)

block_size = 3

def build_dataset(words):
    X,Y = [],[]
    for w in words:
        context = [0]*block_size
        for ch in w +'.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    # print (X.shape, Y.shape)
    return X,Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))
Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte, Yte = build_dataset(words[n2:])

def cmp (s, dt, t):
    ex = torch.all(dt == t.grad).item() 
    close = torch.allclose(dt, t.grad)
    max_diff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact:{str(ex)} | close:{str(close)} | max diff: {max_diff}')

In [33]:
n_embd = 10
n_hidden = 200
batch_size = 32

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator = g)
W1 = torch.randn((block_size*n_embd, n_hidden), generator = g) * (5/3)/(block_size*n_embd)**0.5
b1 = torch.randn(n_hidden, generator = g) * 0.1
W2 = torch.randn((n_hidden, vocab_size), generator = g) * 0.1
b2 = torch.randn(vocab_size, generator = g) * 0.1
bngain = torch.randn(1, n_hidden, generator = g)*0.1 + 1.0
bnbias = torch.randn(1, n_hidden, generator = g)*0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
    p.requires_grad = True
print(sum(p.nelement() for p in parameters))

12297


In [34]:
n = batch_size
ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator = g)
Xb, Yb = Xtr[ix], Ytr[ix]

emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)

hprebn = embcat @ W1 + b1
bnmeani = 1/n * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(n-1) * (bndiff2).sum(0, keepdim=True)
bnvar_inv = (bnvar+1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
h = torch.tanh(hpreact)

logits = h @ W2 + b2
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()

for p in parameters:
    p.grad = None
for t in [logprobs, probs, counts_sum_inv, counts_sum, counts, norm_logits, logit_maxes, logits, 
          h, hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, bnmeani, hprebn, embcat, emb]:
    t.retain_grad()
loss.backward()
loss
# loss = 3.5571

tensor(3.9508, grad_fn=<NegBackward0>)

In [35]:
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0/n
dprobs = dlogprobs.clone() * 1.0/probs.data.detach()
dcounts_sum_inv = (dprobs * counts.data.detach()).sum(1, keepdim = True)
dcounts = dprobs * counts_sum_inv
dcounts_sum = dcounts_sum_inv * -counts_sum**-2
dcounts += dcounts_sum * torch.ones_like(counts)
dnorm_logits = dcounts * norm_logits.exp()
dlogit_maxes = -dnorm_logits.sum(1, keepdim=True)
dlogits = dnorm_logits.clone()
dextra = torch.zeros_like(logits)
dextra[range(n), logits.max(1).indices] = 1
dlogits += dlogit_maxes * dextra
# dlogits += F.one_hot(logits.max(1).indices, num_classes = logits.shape[1])*dlogit_maxes
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
dhpreact = dh * (1.0 - h**2)
dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)
dbnraw = dhpreact * bngain
dbndiff = dbnraw * bnvar_inv
dbnvar_inv = (dbnraw * bndiff).sum(0, keepdim= True)
dbnvar = dbnvar_inv * (-0.5 * (bnvar+1e-5)**-1.5)
dbndiff2 = dbnvar * torch.ones_like(bndiff2) * (1.0/(n-1))
dbndiff += dbndiff2 * 2 * bndiff
dbnmeani = -dbndiff.sum(0, keepdim = True)
dhprebn = dbndiff.clone()
dhprebn += dbnmeani * torch.ones_like(hprebn) / n
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
demb = dembcat.clone().view(emb.shape)
dC = torch.zeros_like(C)
for i in range (Xb.shape[0]):
    for j in range (Xb.shape[1]):
        x = Xb[i, j]
        dC[x] += demb[i,j]


cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
cmp('bnmeani', dbnmeani, bnmeani)
cmp('hprebn', dhprebn, hprebn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)

logprobs        | exact:True | close:True | max diff: 0.0
probs           | exact:True | close:True | max diff: 0.0
counts_sum_inv  | exact:True | close:True | max diff: 0.0
counts_sum      | exact:True | close:True | max diff: 0.0
counts          | exact:True | close:True | max diff: 0.0
norm_logits     | exact:True | close:True | max diff: 0.0
logit_maxes     | exact:True | close:True | max diff: 0.0
logits          | exact:True | close:True | max diff: 0.0
h               | exact:True | close:True | max diff: 0.0
W2              | exact:True | close:True | max diff: 0.0
b2              | exact:True | close:True | max diff: 0.0
hpreact         | exact:True | close:True | max diff: 0.0
bngain          | exact:True | close:True | max diff: 0.0
bnbias          | exact:True | close:True | max diff: 0.0
bnraw           | exact:True | close:True | max diff: 0.0
bnvar_inv       | exact:True | close:True | max diff: 0.0
bnvar           | exact:True | close:True | max diff: 0.0
bndiff2       

In [40]:
n_embd = 16
n_hidden = 200
batch_size = 64
n = batch_size
max_steps = 200001
lr = 0.1
lr_decay = 0.03

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator = g)
W1 = torch.randn((block_size*n_embd, n_hidden), generator = g) * (5/3)/(block_size*n_embd)**0.5
b1 = torch.randn(n_hidden, generator = g) * 0.1
W2 = torch.randn((n_hidden, vocab_size), generator = g) * 0.1
b2 = torch.randn(vocab_size, generator = g) * 0.1
bngain = torch.randn(1, n_hidden, generator = g)*0.1 + 1.0
bnbias = torch.randn(1, n_hidden, generator = g)*0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
    p.requires_grad = True
print(sum(p.nelement() for p in parameters))

16059


In [23]:
with torch.no_grad():
    for i in range (max_steps):
        ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator = g)
        Xb, Yb = Xtr[ix], Ytr[ix]
        
        emb = C[Xb]
        embcat = emb.view(emb.shape[0], -1)
        
        hprebn = embcat @ W1 + b1
        
        bnmean = hprebn.mean(0, keepdim=True)
        bnvar = hprebn.var(0, keepdim=True)
        bnvar_inv = (bnvar+1e-5)**-0.5
        bnraw = (hprebn-bnmean) * bnvar_inv
        hpreact = bngain * bnraw + bnbias
        
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2
        loss = F.cross_entropy(logits, Yb)
        
        # for p in parameters:
        #     p.grad = None
        # loss.backward()
        
        dlogits = F.softmax(logits, 1)
        dlogits[range(n), Yb] -= 1
        dlogits/=n
        dh = dlogits @ W2.T
        dW2 = h.T @ dlogits
        db2 = dlogits.sum(0)
        dhpreact = dh * (1-h**2)
        dbngain = (dhpreact * bnraw).sum(0, keepdim= True)
        dbnbias = (dhpreact).sum(0, keepdim= True)
        dhprebn = bngain * bnvar_inv / n * (n*dhpreact - dhpreact.sum(0) - n/(n-1) * bnraw * (bnraw * dhpreact).sum(0))
        dembcat = dhprebn @ W1.T
        dW1 = embcat.T @ dhprebn
        db1 = dhprebn.sum(0)
        demb = dembcat.view(emb.shape)
        dC = torch.zeros_like(C)
        for j in range (emb.shape[0]):
            for k in range (emb.shape[1]):
                x = Xb[j,k]
                dC[x] += demb[j,k]
    
        grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]
    
        lr = lr if i<(max_steps//2) else lr_decay
        for p, grad in zip(parameters, grads):
            p.data += - lr * grad
            # p.data += - lr * p.grad
        if i%10000 == 0:
            print (i, loss)
        # if i==100:
        #     break

0 tensor(1.9306)
10000 tensor(1.9482)
20000 tensor(2.0499)
30000 tensor(1.8415)
40000 tensor(1.8654)
50000 tensor(2.4896)
60000 tensor(2.1975)
70000 tensor(1.9699)
80000 tensor(2.0600)
90000 tensor(1.8972)
100000 tensor(2.1812)
110000 tensor(1.7568)
120000 tensor(1.9225)
130000 tensor(2.2120)
140000 tensor(2.1268)
150000 tensor(1.9954)
160000 tensor(2.2609)
170000 tensor(1.6453)
180000 tensor(1.9065)
190000 tensor(2.0167)
200000 tensor(1.9223)


In [43]:
for p,g in zip(parameters, grads):
    cmp(str(tuple(p.shape)), g, p)

(27, 16)        | exact:False | close:True | max diff: 1.4901161193847656e-08
(48, 200)       | exact:False | close:True | max diff: 1.1175870895385742e-08
(200,)          | exact:False | close:True | max diff: 3.259629011154175e-09
(200, 27)       | exact:False | close:True | max diff: 1.4901161193847656e-08
(27,)           | exact:False | close:True | max diff: 7.450580596923828e-09
(1, 200)        | exact:False | close:True | max diff: 2.7939677238464355e-09
(1, 200)        | exact:False | close:True | max diff: 2.7939677238464355e-09


In [24]:
with torch.no_grad():
    emb = C[Xtr]
    embcat = emb.view(emb.shape[0], -1)
    hprebn = embcat @ W1 + b1
    bnmean = hprebn.mean(0, keepdim = True)
    bnvar = hprebn.var(0, keepdim = True)

In [25]:
@torch.no_grad()
def split_loss(split):
    x, y = {
        'train':(Xtr,Ytr),
        'val':(Xdev,Ydev),
        'test':(Xte,Yte)
    }[split]
    emb = C[x]
    embcat = emb.view(emb.shape[0], -1)
    hprebn = embcat @ W1 + b1
    hpreact = bngain*(hprebn - bnmean)/ (bnvar+1e-5)**0.5 + bnbias
    h = torch.tanh(hpreact)
    logits = h @ W2 +b2
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())
split_loss('train')
split_loss('val')

train 2.019171714782715
val 2.08778715133667


In [28]:
g = torch.Generator().manual_seed(2147483647)
for _ in range(20):
    out=[]
    context = [0] * block_size
    while True:
        emb = C[torch.tensor([context])]
        embcat = emb.view(emb.shape[0], -1)
        hprebn = embcat @ W1 + b1 
        hpreact = bngain * (hprebn-bnmean)/(bnvar+1e-5)**-0.5 + bnbias
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2
        
        probs = F.softmax (logits, 1)
        ix = torch.multinomial(probs, 1, generator = g).item()
        context = context[1:] + [ix]
        out.append(itos[ix])
        if ix == 0:
            break
    print(''.join(out))

chris.
quan.
maximark.
fhampplivitta.
mckendram.
qua.
jamiyah.
jaxstyn.
tri.
mckiella.
jakelsey.
jakodaktusakau.
jammierrick.
sperson.
gwennalupudnson.
blaksh.
pynn.
branvik.
fenne.
fram.


In [27]:
split_loss('test')

test 2.0899453163146973
