In [1]:
import torch
import math
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
words = open('names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))
print(words[:5])

32033
15
['emma', 'olivia', 'ava', 'isabella', 'sophia']


In [3]:
chars = sorted(list(set(''.join(words))))
stoi = {v:k+1 for k,v in enumerate(chars)}
stoi['.'] = 0
itos = {v:k for k,v in stoi.items()}
vocab_size = len(itos)
print(vocab_size)

27


In [4]:
# build dataset

def build_dataset(block_size, words):
    X, Y = [], []
    for w in words:
        context = block_size * [0]
        for c in w + '.':
            ix = stoi[c]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X, Y

block_size = 3
random.seed(42)
random.shuffle(words)
train_len = int(len(words)*0.8)
val_len = int(len(words)*0.9)

X_tr, Y_tr = build_dataset(block_size, words[:train_len])
X_d, Y_d = build_dataset(block_size, words[train_len:val_len])
X_t, Y_t = build_dataset(block_size, words[val_len:])

In [5]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [6]:
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)
# layer 1 
W1 = torch.randn((block_size*n_embd, n_hidden), generator=g) * (5/3) / ((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of BN

# layer 2
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1

# batchnorm params
bngain = torch.randn((1, n_hidden)) *0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) *0.1 

# Note: I am initializating many of these parameters in non-standard ways
# because sometimes initializating with e.g. all zeros could mask an incorrect
# implementation of the backward pass.

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True


4137


In [7]:
torch.randint(0, X_tr.shape[0], (32,), generator=g)

tensor([120000, 157809,  82137,  69514,  73004,  68734,    286, 123947,  13538,
         42674, 165010,  81021,  59151,  46471,  62456,  64636,  24418, 108817,
        169833, 145683, 168275, 157689,  36258, 142280,  32537, 149713, 149734,
        149517, 165139, 153533,  89661,  20039])

In [8]:
batch_size = 32
n = batch_size
#construct mini batch
ix = torch.randint(0,X_tr.shape[0], (batch_size, ),generator=g)
Xb, Yb = X_tr[ix], Y_tr[ix]  

In [9]:
# forward pass, "chunkated" into smaller steps that are possible to backward one at a time

emb = C[Xb]   # embedding
embcat = emb.view(emb.shape[0], -1)

# linear layer 1
hprebn = embcat @ W1 + b1   # hidden layer pre-activation

# batchnorm layer
bnmeani = hprebn.sum(0, keepdim = True)/n
bndiff = hprebn - bnmeani 
bndiff2 = bndiff ** 2
bnvar = bndiff2.sum(0, keepdim = True) / (n-1)
bnvar_inv = (bnvar+1e-5) ** -0.5    # calculate the inverse and eps = 1e-5
bnraw = bndiff * bnvar_inv  # normalize

hpreact =  bngain * bnraw + bnbias  # scale and shift

# non linearity
h = torch.tanh(hpreact)  

# linear layer 2
logits = h @ W2 + b2

# cross entropy loss
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes # subtract max for numerical stability 
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim = True)
counts_sum_inv = counts_sum ** -1 #if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact..

probs = counts * counts_sum_inv
logprobs = probs.log()
loss = - logprobs[range(n), Yb].mean() # index: plucks out the logprobs of the correct next char


# pytorch backward pass
for p in parameters:
    p.grad = None

for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, 
          norm_logits, logit_maxes, logits, h, hpreact, bnraw,
         bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,
         embcat, emb]:
    t.retain_grad()

loss.backward()
loss


tensor(3.4731, grad_fn=<NegBackward0>)

In [10]:
print(logprobs.shape)

torch.Size([32, 27])


In [20]:
counts.shape,counts_sum_inv.shape

(torch.Size([32, 27]), torch.Size([32, 1]))

In [35]:
logits.shape , logit_maxes.shape

(torch.Size([32, 27]), torch.Size([32, 1]))

In [40]:
logits.shape, h.shape, W2.shape,  b2.shape

(torch.Size([32, 27]),
 torch.Size([32, 64]),
 torch.Size([64, 27]),
 torch.Size([27]))

In [47]:
bngain.shape, bnraw.shape, bnbias.shape

(torch.Size([1, 64]), torch.Size([32, 64]), torch.Size([1, 64]))

In [49]:
bndiff.shape , bnvar_inv.shape

(torch.Size([32, 64]), torch.Size([1, 64]))

In [56]:
bndiff2.sum(0, keepdim = True).shape

torch.Size([1, 64])

In [38]:
hprebn.shape, bnmeani.shape

(torch.Size([32, 64]), torch.Size([1, 64]))

In [None]:
hprebn.shape,  embcat.shape, W1.shape,  b1.shape 
# hprebn = embcat @ W1 + b1 


(torch.Size([32, 64]),
 torch.Size([32, 30]),
 torch.Size([30, 64]),
 torch.Size([64]))

In [53]:
embcat.shape, emb.shape

(torch.Size([32, 30]), torch.Size([32, 3, 10]))

In [54]:
emb.shape, C.shape, Xb.shape

(torch.Size([32, 3, 10]), torch.Size([27, 10]), torch.Size([32, 3]))

Exercise 1: backprop through the whole thing manually, backpropagating through exactly all of the variables as they are defined in the forward pass above, one by one

In [11]:
# ---------------------------- logprobs.grad or dlogprobs ------------------------------------------------

# calculating logprobs.grad manually i.e. dlogprobs
# dlogprobs = derivative of loss w.r.t all the elements of logprobs
# i.e. taking derivate of: loss = - logprobs[range(n), Yb].mean()
# sample eg: loss = -(a+b+c)/3 
# so if derivate is taken w.r.t each element in loss it would be dloss/da = -1/3 i.e. generalized to -1/n

#so the derivative for each element indexed by Yb is -1/n and is stored at that index 

dlogprobs = torch.zeros_like(logprobs)  # create an array of zeros in the shape of logprobs
dlogprobs[range(n), Yb] = -1.0 / n
cmp('logprobs', dlogprobs, logprobs)

#'---------------------------- probs.grad or dprobs ------------------------------------------------------'
# logprobs = probs.log()
# loss = - logprobs[range(n), Yb].mean() 

#chain rule: dloss/dprobs = dloss/dlogprobs * dlogprobs/dprobs where dloss/dlogprobs is already calc above as
# dlogprobs and dloss/dprobs would be 1/probs as d/dxlog(x)=1/x
dprobs = dlogprobs * (1.0/probs)
cmp('probs', dprobs, probs)

#'---------------------------- counts_sum_inv.grad or dcounts_sum_inv -----------------------------------'
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = - logprobs[range(n), Yb].mean() 

# chain rule : dloss/dcounts_sum_inv =  dloss/dprobs * dprobs/dcounts_sum_inv
# counts.shape and counts_sum_inv.shape are different so in the multiplication here:probs = counts * counts_sum_inv
# there is an implicit broadcasting that pytorch will do so backprop needs to consider that and there would
# be a sum so dprobs/dcounts_sum_inv would be 
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
dcounts = counts_sum_inv * dprobs  # dloss/dcounts = dloss/dprobs * dprobs/dcounts
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)

#'---------------------------- counts_sum.grad or dcounts_sum -----------------------------------'
# counts_sum_inv = counts_sum ** -1 

# chain rule : dloss/dcounts_sum =  dloss/dcounts_sum_inv * dcounts_sum_inv/dcounts_sum
dcounts_sum = dcounts_sum_inv * (- counts_sum ** -2) 
cmp('counts_sum', dcounts_sum, counts_sum)

#'---------------------------- counts.grad or dcounts -----------------------------------'
# counts_sum = counts.sum(1, keepdim = True)
# counts_sum_inv = counts_sum ** -1 

# chain rule : dloss/dcounts =  dloss/dcounts_sum * dcounts_sum/dcounts and add to the prev value calc above
dcounts += dcounts_sum * torch.ones((n,1))  # or torch.ones_like(counts)
cmp('counts', dcounts, counts)

#'---------------------------- norm_logits.grad or dnorm_logits -----------------------------------'
# counts = norm_logits.exp() 

# chain rule : dloss/dnorm_logits =  dloss/dcounts * dcounts/dnorm_logits
dnorm_logits = dcounts * norm_logits.exp()  # or dcounts * counts as counts = norm_logits.exp()
cmp('norm_logits', dnorm_logits, norm_logits)

#'---------------------------- logit_maxes.grad or dlogit_maxes -----------------------------------'
# norm_logits = logits - logit_maxes  # check shape to know if there is an implicit broadcasting
# counts = norm_logits.exp()

# chain rule : 
# dloss/dlogit_maxes =  dloss/dnorm_logits * dnorm_logits/dlogit_maxes
# dloss/dlogits =  dloss/dnorm_logits * dnorm_logits/dlogits

dlogit_maxes = (dnorm_logits * (-1)).sum(1, keepdim=True)
dlogits = dnorm_logits.clone()    # dnorm_logits * 1
cmp('logit_maxes', dlogit_maxes, logit_maxes)

#'---------------------------- logits.grad or dlogits -----------------------------------'
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes

# chain rule : dloss/dlogits =  dloss/dlogit_maxes * dlogit_maxes/dlogits
dlogits += dlogit_maxes * F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) # use one hot or
# create using torch.zeros and put 1 as the derivative would be 1 for the max number at the indices returned by max func
cmp('logits', dlogits, logits)

#'---------------------------- logit_maxes.grad or dlogit_maxes -----------------------------------'
# logits = h @ W2 + b2

# chain rule : 
# dloss/dh =  dloss/dlogits * dlogits/dh
# dloss/dW2 =  dloss/dlogits * dlogits/dW2
# dloss/db2 =  dloss/dlogits * dlogits/db2

dh = dlogits @ W2.T 
dW2 = h.T @ dlogits
db2 = (dlogits * 1).sum(0)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)

#'---------------------------- hpreact.grad or dhpreact -----------------------------------'
# h = torch.tanh(hpreact) 
# chain rule : dloss/dhpreact = dloss/dh * dh/dhpreact

dhpreact = dh * (1.0-h**2)
cmp('hpreact', dhpreact, hpreact)

#'---------------------------- dbnbias, dbnraw, dbngain -----------------------------------'
# hpreact =  bngain * bnraw + bnbias  # scale and shift
# chain rule : 
# dloss/dbngain = dloss/dhpreact * dhpreact/dbngain
# dloss/dbnraw = dloss/dhpreact * dhpreact/dbnraw
# dloss/dbnbias = dloss/dhpreact * dhpreact/dbnbias

dbngain = (dhpreact * bnraw).sum(0, keepdim=True)  # because of broadcasting
dbnraw =  bngain * dhpreact
dbnbias = (dhpreact).sum(0, keepdim=True) # because of broadcasting
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)

# ----------------------------- dbndiff, dbnvar_inv -----------------------------------
# bnraw = bndiff * bnvar_inv  

# chain rule : 
# dloss/dbndiff = dloss/dbnraw * dbnraw/dbndiff
# dloss/bnvar_inv = dloss/dbnraw * dbnraw/dbnvar_inv

dbndiff = bnvar_inv * dbnraw   
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim = True) # because of broadcasting
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)

# ---------------------------- dbnvar ------------------------------------
# bnvar_inv = (bnvar+1e-5) ** -0.5
# chain rule : dloss/dbnvar =  dloss/dbnvar_inv * dbnvar_inv/dbnvar
dbnvar = (-0.5 * (bnvar+1e-5) ** -1.5) * dbnvar_inv
cmp('bnvar', dbnvar, bnvar)

# ---------------------------- dbndiff2 ------------------------------------
# bnvar = bndiff2.sum(0, keepdim = True) / (n-1)
# chain rule : dloss/dbndiff2 = dloss/dbnvar * dbnvar/dbndiff2
# dbndiff2 = (1.0/(n-1)) * torch.ones_like(bndiff2) * dbnvar # why is this not working?
dbndiff2 = dbnvar / (n-1)
cmp('bndiff2', dbndiff2, bndiff2)

# ---------------------------- bndiff ------------------------------------
# bndiff2 = bndiff ** 2
# chain rule : dloss/dbndiff = dloss/dbndiff2 * dbndiff2/dbndiff
dbndiff += dbndiff2 *  (2 * bndiff)
cmp('bndiff', dbndiff, bndiff)

# ---------------------------- dhprebn, dbnmeani ------------------------------------
# bndiff = hprebn - bnmeani 
# chain rule : 
# dloss/dhprebn = dloss/dbndiff * dbndiff/dhprebn
# dloss/dbnmeani = dloss/dbndiff * dbndiff/dbnmeani
dhprebn = dbndiff.clone()
dbnmeani = (- dbndiff).sum(0)

cmp('bnmeani', dbnmeani, bnmeani)

# ---------------------------- dhprebn ------------------------------------
# bnmeani = hprebn.sum(0, keepdim = True)/n
# chain rule : dloss/dhprebn = dloss/dbnmeani * dbnmeani/dhprebn

dhprebn +=  dbnmeani / n  #(will broadcast)   # or 1.0/n * (torch.ones_like(hprebn) * dbnmeani)
cmp('hprebn', dhprebn, hprebn)

# ---------------------------- dembcat, dW1, db1 ------------------------------------
# hprebn = embcat @ W1 + b1 
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)

# ---------------------------- demb, dC  ------------------------------------------------
# emb = C[Xb]   # embedding
# embcat = emb.view(emb.shape[0], -1)
demb = dembcat.view(emb.shape)  # back to orginal shape, just rerepresent the derivatives into original view
cmp('emb', demb, emb)

# C[Xb] is plucking from the embedding table C for values of Xb
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += demb[k,j]

cmp('C', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | exact: True  | approximate: True  | maxdiff:

Exercise 2: backprop through cross_entropy but all in one go to complete this challenge look at the mathematical expression of the loss, take the derivative, simplify the expression, and just write it out



In [12]:
# forward pass

# before:
# logit_maxes = logits.max(1, keepdim=True).values
# norm_logits = logits - logit_maxes # subtract max for numerical stability
# counts = norm_logits.exp()
# counts_sum = counts.sum(1, keepdims=True)
# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...
# probs = counts * counts_sum_inv
# logprobs = probs.log()
# loss = -logprobs[range(n), Yb].mean()

# now:
loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())


# backward pass

# -----------------
#dloss / dlogits = -1 + logits.exp()/logits.exp().sum() if  i=y else logits.exp()/logits().exp().sum() i.e.softmax

dlogits = F.softmax(logits, 1)
dlogits[range(n),Yb] -= 1
dlogits /= n
# -----------------

cmp('logits', dlogits, logits) 

3.4731383323669434 diff: -2.384185791015625e-07
logits          | exact: False | approximate: True  | maxdiff: 7.2177499532699585e-09


Exercise 3: backprop through batchnorm but all in one go to complete this challenge look at the mathematical expression of the output of batchnorm, take the derivative w.r.t. its input, simplify the expression, and just write it out
BatchNorm paper: https://arxiv.org/abs/1502.03167

In [13]:
# forward pass

# before:
# bnmeani = 1/n*hprebn.sum(0, keepdim=True)
# bndiff = hprebn - bnmeani
# bndiff2 = bndiff**2
# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)
# bnvar_inv = (bnvar + 1e-5)**-0.5
# bnraw = bndiff * bnvar_inv
# hpreact = bngain * bnraw + bnbias

# now:
hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias
print('max diff:', (hpreact_fast - hpreact).abs().max())

# backward pass

# before we had:
# dbnraw = bngain * dhpreact
# dbndiff = bnvar_inv * dbnraw
# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
# dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv
# dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar
# dbndiff += (2*bndiff) * dbndiff2
# dhprebn = dbndiff.clone()
# dbnmeani = (-dbndiff).sum(0)
# dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)

# calculate dhprebn given dhpreact (i.e. backprop through the batchnorm)
# (you'll also need to use some of the variables from the forward pass up above)

# -----------------
dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))
# -----------------

cmp('hprebn', dhprebn, hprebn) # I can only get approximate to be true, my maxdiff is 9e-10

max diff: tensor(4.7684e-07, grad_fn=<MaxBackward1>)
hprebn          | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10


Exercise 4: putting it all together!

In [22]:

# Train the MLP neural net with your own backward pass

# init
n_embd = 10 # the dimensionality of the character embedding vectors
n_hidden = 200 # the number of neurons in the hidden layer of the MLP

g = torch.Generator().manual_seed(2147483647) # for reproducibility
C  = torch.randn((vocab_size, n_embd),            generator=g)
# Layer 1
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1
# Layer 2
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# BatchNorm parameters
bngain = torch.randn((1, n_hidden))*0.1 + 1.0
bnbias = torch.randn((1, n_hidden))*0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

# same optimization as last time
max_steps = 200000
batch_size = 32
n = batch_size # convenience
lossi = []

# use this context manager for efficiency once your backward pass is written (TODO)
with torch.no_grad():
  # kick off optimization
  for i in range(max_steps):

    # minibatch construct
    ix = torch.randint(0, X_tr.shape[0], (batch_size,), generator=g)
    Xb, Yb = X_tr[ix], Y_tr[ix] # batch X,Y

    # forward pass
    emb = C[Xb] # embed the characters into vectors
    embcat = emb.view(emb.shape[0], -1) # concatenate the vectors
    # Linear layer
    hprebn = embcat @ W1 + b1 # hidden layer pre-activation
    # BatchNorm layer
    # -------------------------------------------------------------
    bnmean = hprebn.mean(0, keepdim=True)
    bnvar = hprebn.var(0, keepdim=True, unbiased=True)
    bnvar_inv = (bnvar + 1e-5)**-0.5
    bnraw = (hprebn - bnmean) * bnvar_inv
    hpreact = bngain * bnraw + bnbias
    # -------------------------------------------------------------
    # Non-linearity
    h = torch.tanh(hpreact) # hidden layer
    logits = h @ W2 + b2 # output layer
    loss = F.cross_entropy(logits, Yb) # loss function

    # backward pass
    for p in parameters:
      p.grad = None
    # loss.backward() # use this for correctness comparisons, delete it later!

    # manual backprop! 
    # -----------------
    # cross entropy
    dlogits = F.softmax(logits, 1)
    dlogits[range(n),Yb] -= 1
    dlogits /= n
    # 2nd layer backprop
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = (dlogits * 1).sum(0)
    # tanh
    dhpreact = dh * (1.0-h**2)
    # batchnorm 
    dbngain = (dhpreact * bnraw).sum(0, keepdim=True)  
    dbnbias = (dhpreact).sum(0, keepdim=True) 
    dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))
    # 1st layer
    dembcat = dhprebn @ W1.T
    dW1 = embcat.T @ dhprebn
    db1 = dhprebn.sum(0)
    # embedding
    demb = dembcat.view(emb.shape)  
    dC = torch.zeros_like(C)
    for k in range(Xb.shape[0]):
        for j in range(Xb.shape[1]):
            ix = Xb[k,j]
            dC[ix] += demb[k,j]
    
    grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]
    # -----------------

    # update
    lr = 0.1 if i < 100000 else 0.01 # step learning rate decay
    for p, grad in zip(parameters, grads):
      # p.data += -lr * p.grad # old way (using PyTorch grad from .backward())
      p.data += -lr * grad # new way (manual)

    # track stats
    if i % 10000 == 0: # print every once in a while
      print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())

    # if i >= 100: # TODO: delete early breaking when you're ready to train the full net
    #   break

12297
      0/ 200000: 3.7484
  10000/ 200000: 2.2119
  20000/ 200000: 2.4039
  30000/ 200000: 2.5108
  40000/ 200000: 1.9725
  50000/ 200000: 2.3981
  60000/ 200000: 2.3697
  70000/ 200000: 2.0765
  80000/ 200000: 2.3480
  90000/ 200000: 2.1734
 100000/ 200000: 1.9552
 110000/ 200000: 2.2345
 120000/ 200000: 1.9548
 130000/ 200000: 2.4439
 140000/ 200000: 2.2514
 150000/ 200000: 2.0761
 160000/ 200000: 1.9728
 170000/ 200000: 1.8697
 180000/ 200000: 2.0100
 190000/ 200000: 1.9482


In [None]:
# useful for checking your gradients
# for p,g in zip(parameters, grads):
#   cmp(str(tuple(p.shape)), g, p)

(27, 10)        | exact: False | approximate: True  | maxdiff: 1.862645149230957e-08
(30, 200)       | exact: False | approximate: True  | maxdiff: 5.587935447692871e-09
(200,)          | exact: False | approximate: True  | maxdiff: 4.190951585769653e-09
(200, 27)       | exact: False | approximate: True  | maxdiff: 1.1175870895385742e-08
(27,)           | exact: False | approximate: True  | maxdiff: 7.450580596923828e-09
(1, 200)        | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09
(1, 200)        | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09


In [None]:
# calibrate batch norm parameters at the end of training
with torch.no_grad():
    emb = C[X_tr]
    embcat = emb.view(emb.shape[0],-1)
    hpreact = embcat @ W1 + b1
    # measure mean, std over entire training set
    bnmean = hpreact.mean(0, keepdim=True)
    bnvar = hpreact.var(0, keepdim=True, unbiased = True)

In [26]:
# evaluate train and val loss
@torch.no_grad()
def split_loss(split):
    x, y = {'train': (X_tr, Y_tr),
            'val' : (X_d, Y_d),
            'test' : (X_t, Y_t)}[split]
    
    emb = C[x]
    embcat = emb.view(emb.shape[0],-1)
    hpreact = embcat @ W1 + b1
    hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5) ** -0.5 + bnbias
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, y)
    print(split, loss.item())

split_loss('train')
split_loss('val')

train 2.0708296298980713
val 2.106884002685547


In [29]:
g = torch.Generator().manual_seed(2147483647)

for i in range(20):
    out = ''
    context = [0] * block_size
    while True:
        # --------------
        # forward pass
        # embedding
        emb = C[torch.tensor([context])]
        embcat = emb.view(emb.shape[0],-1)
        hpreact = embcat @ W1 + b1
        hpreact = bngain * (hpreact - bnmean) * (bnvar+1e-5)**-0.5 + bnbias
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2
        #------------------
        # sample
        probs = F.softmax(logits, dim = 1)
        ix = torch.multinomial(probs, num_samples=1, generator=g).item()
        context = context[1:] + [ix]
        out += itos[ix]

        if ix==0:
            break
    print(out)


dex.
mariah.
makilah.
tyha.
malissana.
nella.
kaman.
arreliyah.
jaxson.
mari.
moriella.
kinzie.
darette.
kamside.
eniavion.
rosbut.
huniven.
tahlyn.
kashru.
anesley.
