In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# read in all the word
words = open('/home/nguyen-van-anh/Downloads/names.txt', 'r').read().splitlines()
print(len(words))
print(max(len(w) for w in words))


32033
15


In [3]:
# build the vocab of characters and mapping to/from integers
chars = sorted(list(set(''.join(words))))
chars = ['.'] + chars
stoi = {s:i for i, s in enumerate(chars)}
itos = {i:s for s, i in stoi.items()}
vocab_size = len(itos)
print(itos)

{0: '.', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z'}


In [4]:
# build the dataset
block_size = 3
def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for chr in w + '.':
            ix = stoi[chr]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] #crop and append
        
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(f"X.shape: {X.shape}, Y.shape: {Y.shape}")
    return X, Y

import random
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))

Xtr,  Ytr  = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte,  Yte  = build_dataset(words[n2:])


X.shape: torch.Size([182667, 3]), Y.shape: torch.Size([182667])
X.shape: torch.Size([22726, 3]), Y.shape: torch.Size([22726])
X.shape: torch.Size([22753, 3]), Y.shape: torch.Size([22753])


In [5]:
# utility function will be used later to compare manual gradients to Pytorch gradients
def cmp(s, dt, t):
    ex = torch.all(dt == t.grad).item() # exactly equal
    app = torch.allclose(dt, t.grad) # appoximately equal
    maxdiff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')


In [6]:
n_embed  = 10 # the dimensionality of the character embedding vectors
n_hidden = 64 # the number of neurons in the hidden layer of MLP

# Embedded matrix
C = torch.randn((vocab_size, n_embed))

# Layer 1
W1 = torch.randn((n_embed * block_size, n_hidden)) * (5/3)/((n_embed*block_size)**0.5)
b1 = torch.randn(n_hidden                        ) * 0.1 # using b1 just for fun, it's useless because of BN step

#Layer 2
W2 = torch.randn((n_hidden, vocab_size)          ) * 0.1
b2 = torch.randn(vocab_size                      ) * 0.1

# BatchNorm params
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1

params = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in params))
for p in params:
    p.requires_grad = True



4137


In [7]:
# Construct mini batch
batch_size = n = 32
ix = torch.randint(0, Xtr.shape[0], (batch_size, ))
Xb, Yb = Xtr[ix], Ytr[ix] # Batch X, Y

In [8]:
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1) # concatenate the embedded vectors in the block size
# Linear layer 1
hprebn = embcat @ W1 + b1 # hidden layer pre-activation
# BatchNorm layer
bnmeani = 1/batch_size * hprebn.sum(0, keepdim=True)
bndiff = hprebn - bnmeani
bndiff2 = bndiff**2
bnvar = 1/(batch_size-1)*bndiff2.sum(0, keepdim=True) # Bessel's correction: (dividing by n - 1, not n)
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias

# Non-linearity
h = torch.tanh(hpreact)

# Linear layer 2
logits = h @ W2 + b2

#  Cross entropy loss (same as F.cross_entropy(logits, Yb))
logit_maxes = logits.max(1, keepdim=True).values #view the differentiation as how inputs affect the outputs
norm_logits = logits - logit_maxes # for numerical stability (avoid overflow)
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(batch_size), Yb].mean()

# Pytorch backward pass
for p in params:
    p.grad = None
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv,
          norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv,
          bnvar, bndiff2, bndiff, hprebn, bnmeani, embcat, emb]:
    t.retain_grad() # store gradients for non-leaf tensors
loss.backward()
loss


 

tensor(3.3298, grad_fn=<NegBackward0>)

In [9]:
counts.shape, probs.shape, counts_sum_inv.shape

(torch.Size([32, 27]), torch.Size([32, 27]), torch.Size([32, 1]))

In [10]:
logits.max(1, keepdim=True)

torch.return_types.max(
values=tensor([[0.9956],
        [0.8143],
        [1.1203],
        [0.7209],
        [0.8679],
        [1.0998],
        [1.0154],
        [1.6981],
        [1.4528],
        [0.8336],
        [0.7213],
        [0.9415],
        [0.7016],
        [0.8336],
        [0.8036],
        [0.7016],
        [0.7712],
        [1.5058],
        [0.9388],
        [0.8336],
        [0.7368],
        [1.1038],
        [0.8336],
        [1.1685],
        [1.4102],
        [1.1337],
        [1.2150],
        [0.9996],
        [1.0440],
        [0.6016],
        [1.0178],
        [1.0618]], grad_fn=<MaxBackward0>),
indices=tensor([[25],
        [20],
        [ 4],
        [ 4],
        [13],
        [21],
        [22],
        [25],
        [ 4],
        [25],
        [13],
        [22],
        [ 9],
        [25],
        [ 0],
        [ 9],
        [22],
        [19],
        [ 2],
        [25],
        [10],
        [16],
        [25],
        [19],
        [22],
        [

In [11]:
logits.shape, h.shape, W2.shape, b2.shape

(torch.Size([32, 27]),
 torch.Size([32, 64]),
 torch.Size([64, 27]),
 torch.Size([27]))

In [12]:
hpreact.shape, bngain.shape, bnbias.shape, bnraw.shape, h.dtype

(torch.Size([32, 64]),
 torch.Size([1, 64]),
 torch.Size([1, 64]),
 torch.Size([32, 64]),
 torch.float32)

In [13]:
bndiff.shape, bnvar_inv.shape

(torch.Size([32, 64]), torch.Size([1, 64]))

In [14]:
bnmeani.shape

torch.Size([1, 64])

In [15]:
hprebn.shape, embcat.shape, W1.shape, b1.shape

(torch.Size([32, 64]),
 torch.Size([32, 30]),
 torch.Size([30, 64]),
 torch.Size([64]))

In [16]:
emb.shape, embcat.shape

(torch.Size([32, 3, 10]), torch.Size([32, 30]))

In [17]:
C.shape, Xb.shape

(torch.Size([27, 10]), torch.Size([32, 3]))

In [18]:

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1/n
dprobs = 1/probs * dlogprobs #remember: derivatives have the same size as objects, if not-one dimension is on the right, we can omit the keepdim parameter
dcounts_sum_inv = (dprobs * counts).sum(1, keepdim=True)
dcounts = dprobs * counts_sum_inv
dcounts_sum = dcounts_sum_inv * -(counts_sum**-2) 
dcounts += dcounts_sum * torch.ones_like(counts) 
dnorm_logits = dcounts * counts
dlogit_maxes = -(dnorm_logits * 1).sum(1, keepdim=True)
dlogits = dnorm_logits * 1
dlogits += dlogit_maxes * F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) 
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0) # not-one dimension is on the right, can omit keepdim argument.
dhpreact = dh * (1.0 - h**2)
dbngain = (dhpreact * bnraw).sum(0)
dbnbias = (dhpreact * 1.0).sum(0)
dbnraw = dhpreact * bngain
dbnvar_inv = (dbnraw * bndiff).sum(0)
dbndiff = dbnraw * bnvar_inv
dbnvar = dbnvar_inv * -0.5 * (bnvar + 1e-5)**-1.5 * 1.0
dbndiff2 = dbnvar * torch.ones_like(bndiff2) * 1.0/(n-1)
dbndiff += dbndiff2 * 2.0 * bndiff
dbnmeani = (dbndiff * -1).sum(0)
dhprebn = dbndiff.clone()
dhprebn += dbnmeani * torch.ones_like(hprebn) * 1.0/n 
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
demb = dembcat.view(emb.shape)
dC = torch.zeros_like(C)
for i in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[i,j]    # take the index of character Xb[i,j]
        dC[ix] += demb[i, j, :] # accumulate the derivative of each character in the vocab



cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
cmp('bndiff2', dbndiff2, bndiff2)
cmp('bndiff', dbndiff, bndiff)
cmp('bnmeani', dbnmeani, bnmeani)
cmp('hprebn', dhprebn, hprebn)
cmp('embcat', dembcat, embcat)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('emb', demb, emb)
cmp('C', dC, C)

logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: False | approximate: True  | maxdiff: 4.656612873077393e-10
bngain          | exact: False | approximate: True  | maxdiff: 1.862645149230957e-09
bnbias          | exact: False | approximate: True  | maxdiff: 1.862645149230957e-09
bnraw   

In [19]:
loss_fast = F.cross_entropy(logits, Yb)
print(loss_fast.item(), 'diff:', (loss_fast - loss).item())

3.329834461212158 diff: 2.384185791015625e-07


In [20]:
# backward pass

dlogits = F.softmax(logits, 1) # take derivative of loss wrt to logits (softmax function)
dlogits[range(n), Yb] -= 1
dlogits /= n # mean in loss

cmp('logits', dlogits, logits)



logits          | exact: False | approximate: True  | maxdiff: 7.2177499532699585e-09


In [22]:
hppreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias
print('max diff:', (hppreact_fast - hpreact).abs().max())

max diff: tensor(4.7684e-07, grad_fn=<MaxBackward1>)


In [23]:
dhprebn = bngain * bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))
cmp('hprebn', dhprebn, hprebn)

hprebn          | exact: False | approximate: True  | maxdiff: 9.313225746154785e-10


In [None]:
# Train a NN: a for loop involves forward pass (include selecting batches), backward pass, and update step.
# init
n_embd = 10
n_hidden = 200

C = torch.randn((vocab_size, n_embd))
# Layer 1
W1 = torch.randn((n_embd*block_size, n_hidden)) * (5/3) / ((n_embd*block_size)**0.5) # He initilization for a better parameter
b1 = torch.randn(n_hidden) * 0.1
# Layer 2
W2 = torch.randn((n_hidden, vocab_size)) * 0.1
b2 = torch.randn(vocab_size) * 0.1
# Batchnorm params
bngain = torch.randn((1, n_hidden)) * 0.1 + 1.0
bnbias = torch.randn((1, n_hidden)) * 0.1

params = [C, W1, b1, W2, b2, bngain, bnbias]
print(sum(p.nelement() for p in params))
for p in params:
    p.requires_grad = True

max_steps = 2000
n = batch_size = 32
lossi = []

for i in range(max_steps):

    # Mini-batch construction
    ix = torch.randint(0, Xtr.shape[0], (batch_size, ) )
    Xb, Yb = Xtr[ix], Ytr[ix]

    # Forward pass
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)
    # Linear layer
    hprebn = embcat @ W1 + b1
    # Batchnorm layer
    bnmean = hprebn.mean(0, keepdim=True)
    bnvar  = hprebn.var(0, keepdim=True, unbiased=True)
    bnvar_inv = (bnvar + 1e-5) ** -0.5
    bnraw = (hprebn - bnmean) * bnvar_inv
    hpreact = bngain * bnraw + bnbias
    # Non-linearity
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)

    # Backward pass
    for p in params:
        p.grad = None
    loss.backward() # for correctness comparisions
    dC, dW1, db1, dW2, db2, dbngain, dbnbias = None, None, None, None, None, None, None
    # Manual backprop
    dlogits = F.softmax(logits, 1)
    dlogits[range(n), Yb] -= 1
    dlogits /= n
    # 2nd layer backprop
    dh = dlogits @ W2.T
    dW2 = h.T @ dlogits
    db2 = dlogits.sum(0)
    # tanh backprop
    dhpreact = dh * (1.0 - h**2)
    # batchnorm backprop
    dbngain = (dhpreact * bnraw).sum(0, keepdim=True)
    dbnbias = (dhpreact * 1).sum(0, keepdim=True)
    dhprebn = bngain * bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))
    # 1st layer
    dembcat = dhprebn @ W1.T
    dW1 = embcat.T @ dhprebn
    db1 = (dhprebn * 1).sum(0)
    # embedding
    demb = dembcat.view(emb.shape)
    dC = torch.zeros_like(C)
    for k in range(Xb.shape[0]):
        for j in range(Xb.shape[1]):
            ix = Xb[k, j]
            dC[ix] += demb[k, j]
    

    
    grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]

    # Update step
    lr = 0.1 if i < 1000 else 0.01 # step learning rate decy
    for p, grad in zip(params, grads):
        p.data += -lr * grad
    

    if i % 100 == 0: # print every once in a while
        print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())







12297
      0/   2000: 3.6841
    100/   2000: 2.5274
    200/   2000: 2.6610
    300/   2000: 2.7603
    400/   2000: 2.5188
    500/   2000: 2.4026
    600/   2000: 2.8043
    700/   2000: 2.1533
    800/   2000: 2.3337
    900/   2000: 2.8883
   1000/   2000: 2.8480
   1100/   2000: 2.1141
   1200/   2000: 2.5449
   1300/   2000: 2.8278
   1400/   2000: 2.6437
   1500/   2000: 2.5289
   1600/   2000: 2.5408
   1700/   2000: 2.5385
   1800/   2000: 2.1418
   1900/   2000: 2.1343


In [72]:
lossi

[0.5663372874259949,
 0.5540695786476135,
 0.5599709749221802,
 0.5487356185913086,
 0.5599564909934998,
 0.5597522854804993,
 0.5349310636520386,
 0.5248026251792908,
 0.5615686178207397,
 0.5380712151527405,
 0.5388144850730896,
 0.5004662871360779,
 0.5586374998092651,
 0.5386630892753601,
 0.5095565915107727,
 0.510406494140625,
 0.4986518323421478,
 0.4894055724143982,
 0.5150602459907532,
 0.47069239616394043,
 0.5046494007110596,
 0.5101689100265503,
 0.4799480736255646,
 0.5152291655540466,
 0.4798921048641205,
 0.4868813455104828,
 0.5008696913719177,
 0.5041429996490479,
 0.5305931568145752,
 0.4770512580871582,
 0.5086319446563721,
 0.485880970954895,
 0.5150389075279236,
 0.4388466775417328,
 0.4994053244590759,
 0.506272554397583,
 0.4663412272930145,
 0.49060940742492676,
 0.4672057628631592,
 0.46189185976982117,
 0.46452686190605164,
 0.5098555684089661,
 0.48640063405036926,
 0.4502503573894501,
 0.45929089188575745,
 0.44344547390937805,
 0.5012738704681396,
 0.460413

In [73]:
for p, g in zip(params, grads):
    cmp(str(tuple(p.shape)), g, p)

(27, 10)        | exact: False | approximate: True  | maxdiff: 1.1175870895385742e-08
(30, 200)       | exact: False | approximate: True  | maxdiff: 7.450580596923828e-09
(200,)          | exact: False | approximate: True  | maxdiff: 9.313225746154785e-09
(200, 27)       | exact: False | approximate: True  | maxdiff: 1.4901161193847656e-08
(27,)           | exact: False | approximate: True  | maxdiff: 7.450580596923828e-09
(1, 200)        | exact: False | approximate: True  | maxdiff: 2.7939677238464355e-09
(1, 200)        | exact: False | approximate: True  | maxdiff: 3.725290298461914e-09


In [74]:
# evaluate train and val loss

@torch.no_grad() # this decorator disables gradient tracking
def split_loss(split):
  x,y = {
    'train': (Xtr, Ytr),
    'val': (Xdev, Ydev),
    'test': (Xte, Yte),
  }[split]
  emb = C[x] # (N, block_size, n_embd)
  embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)
  hpreact = embcat @ W1 + b1
  hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias
  h = torch.tanh(hpreact) 
  logits = h @ W2 + b2 
  loss = F.cross_entropy(logits, y)
  print(split, loss.item())

split_loss('train')
split_loss('val')


train 2.3791940212249756
val 2.386850595474243
