<a href="https://colab.research.google.com/github/manyachawla22/makemore-char-rnn/blob/main/makemore3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# basics
# Core imports for the notebook: PyTorch, its functional API (aliased as F),
# and matplotlib (`plt` is imported but not referenced in the cells shown here).
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# Record the PyTorch version the notebook is being run with.
print('torch', torch.__version__)
# Fetch the names corpus; the later cells read it back as 'names.txt'.
# NOTE(review): this downloads again on every run of the cell — consider
# `wget -nc` (or a pure-Python existence check) to skip an existing file.
!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt


torch 2.8.0+cu126
--2025-11-12 15:35:21--  https://raw.githubusercontent.com/karpathy/makemore/master/names.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 228145 (223K) [text/plain]
Saving to: ‘names.txt’


2025-11-12 15:35:22 (11.2 MB/s) - ‘names.txt’ saved [228145/228145]



In [None]:
# load words
# Read the corpus as one name per line. The context manager closes the file
# handle as soon as the read finishes instead of leaving it to the garbage
# collector, and the explicit encoding keeps the result platform-independent.
with open('names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()
print('num words', len(words))
print('example', words[:6])


num words 32033
example ['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte']


In [None]:

# Vocabulary: every distinct character that occurs in the corpus, sorted.
chars = sorted(set(''.join(words)))
# Real characters are numbered from 1 upwards; index 0 is kept for '.'.
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi['.'] = 0   # dot = end of word
# Reverse lookup: integer index -> character.
itos = dict((index, symbol) for symbol, index in stoi.items())
vocab_size = len(itos)
print('vocab size', vocab_size)


vocab size 27


In [None]:
# build dataset: contexts of length 3 -> next char
block_size = 3
def build_dataset(ws):
    """Turn a list of words into (context, next-character) training pairs.

    Each word is extended with the '.' terminator and scanned left to right.
    For every character, the `block_size` indices that precede it (padded
    with 0 at the start of a word) become one row of X, and the character's
    own index becomes the matching entry of Y. Characters are mapped to
    indices through the module-level `stoi` table.

    Returns the pair (X, Y) as tensors built from those rows and targets.
    """
    contexts, targets = [], []
    for word in ws:
        window = [0] * block_size
        for symbol in word + '.':
            target = stoi[symbol]
            contexts.append(list(window))
            targets.append(target)
            window = window[1:] + [target]   # slide: drop oldest, append newest
    X = torch.tensor(contexts)
    Y = torch.tensor(targets)
    return X, Y

import random

# Shuffle a *copy* of the word list so this cell is idempotent. Shuffling
# `words` in place meant that re-running the cell shuffled an already-shuffled
# list and silently produced a different train/dev/test split than a fresh
# kernel would. With the same seed, a fresh run yields the same split as before.
random.seed(42)
words_shuffled = list(words)
random.shuffle(words_shuffled)

# 80% / 10% / 10% split by word (not by example) into train / dev / test.
n1 = int(0.8*len(words_shuffled))
n2 = int(0.9*len(words_shuffled))
Xtr, Ytr = build_dataset(words_shuffled[:n1])
Xdev, Ydev = build_dataset(words_shuffled[n1:n2])
Xte, Yte = build_dataset(words_shuffled[n2:])
print(Xtr.shape, Ytr.shape)


torch.Size([182625, 3]) torch.Size([182625])


In [None]:
# small helper for gradient comparison (like original)
def compare(name, expected, tensor):
    """Print how closely a hand-derived gradient matches autograd's.

    name     -- label printed in the first column
    expected -- the manually computed gradient
    tensor   -- tensor whose `.grad` (filled in by autograd) is the reference

    The printed line reports element-wise equality, `torch.allclose`
    agreement, and the largest absolute element-wise difference.
    """
    reference = tensor.grad
    is_exact = str(torch.all(expected == reference).item())
    is_close = str(torch.allclose(expected, reference))
    worst = (expected - reference).abs().max().item()
    print(f"{name:15s} | exact: {is_exact:5s} | approx: {is_close:5s} | maxdiff: {worst}")


In [None]:
# model params (student tweak: smaller embed dims for speed)
n_embd = 10     # size of each character embedding vector
n_hidden = 64   # number of hidden units
# Every random draw below comes from this seeded generator so the
# initialisation is identical on every run.
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((vocab_size, n_embd), generator=g)
# First layer, scaled by (5/3)/sqrt(fan_in).
W1 = torch.randn((n_embd*block_size, n_hidden), generator=g) * (5/3)/((n_embd*block_size)**0.5)
b1 = torch.randn(n_hidden, generator=g) * 0.1
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1
b2 = torch.randn(vocab_size, generator=g) * 0.1
# Batchnorm gain/bias start near 1 and 0. They previously used the unseeded
# global RNG, which made the whole run non-reproducible; they now draw from `g`.
bngain = torch.randn((1, n_hidden), generator=g)*0.1 + 1.0
bnbias = torch.randn((1, n_hidden), generator=g)*0.1

parameters = [C, W1, b1, W2, b2, bngain, bnbias]
for p in parameters:
    p.requires_grad = True


In [None]:

# Draw one minibatch of random training rows using the seeded generator;
# `n` is kept as a short alias for the batch size.
batch_size = 32
n = batch_size
ix = torch.randint(low=0, high=Xtr.shape[0], size=(batch_size,), generator=g)
Xb = Xtr[ix]
Yb = Ytr[ix]


In [None]:

# Forward pass, deliberately split into tiny steps so that each intermediate
# has a name and can be back-propagated through by hand in the next cell.
emb = C[Xb]                                   # look up an embedding per context position
embcat = emb.view(emb.shape[0], -1)           # flatten each example's embeddings into one row
hprebn = embcat @ W1 + b1                     # first linear layer (pre-batchnorm)
# batchnorm, written out term by term
bnmean = hprebn.mean(0, keepdim=True)
bndiff = hprebn - bnmean
bndiff2 = bndiff**2
bnvar = 1/(n-1) * bndiff2.sum(0, keepdim=True)   # variance with 1/(n-1) normalisation
bnvar_inv = (bnvar + 1e-5)**-0.5
bnraw = bndiff * bnvar_inv
hpreact = bngain * bnraw + bnbias
h = torch.tanh(hpreact)
logits = h @ W2 + b2

# softmax and loss (split for manual backprop)
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes            # subtract the row max before exponentiating
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdim=True)
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()         # mean negative log-likelihood of the targets

# Clear any parameter gradients left by an earlier run of this cell:
# backward() accumulates into .grad, so without this reset a re-run would
# double the reference gradients and the comparisons would fail.
for p in parameters:
    p.grad = None
# keep grads for a few intermediates so we can compare later
for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits, logit_maxes, logits, h, hpreact, bnraw, bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmean, embcat, emb]:
    t.retain_grad()

loss.backward()
print('loss', loss.item())


loss 3.58122181892395


In [None]:
# manual backprop (step by step) — student comments inline
# Each d<name> below is the gradient of the loss w.r.t. the forward-pass
# intermediate <name>, derived by hand, walking the forward pass in reverse.
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Yb] = -1.0 / n   # d/dlogprobs of mean NLL
dprobs = (1.0 / probs) * dlogprobs   # since log -> derivative is 1/probs

# dcounts_sum_inv: how counts_sum_inv affects all probs in row (sum over cols)
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)

# derivative w.r.t. counts (two pieces): direct piece and through counts_sum
dcounts = counts_sum_inv * dprobs
dcounts_sum = (-counts_sum**-2) * dcounts_sum_inv
dcounts += torch.ones_like(counts) * dcounts_sum

# back through exp(norm_logits)
dnorm_logits = counts * dcounts

# combine gradients for logits (dnorm_logits) plus the effect from subtracting max
dlogits = dnorm_logits.clone()
dlogit_maxes = (-dnorm_logits).sum(1, keepdim=True)
dlogits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * dlogit_maxes

# second linear layer: logits = h @ W2 + b2
dh = dlogits @ W2.T
dW2 = h.T @ dlogits
db2 = dlogits.sum(0)
# tanh: d/dx tanh(x) = 1 - tanh(x)^2
dhpreact = (1.0 - h**2) * dh
# batchnorm, undone one term at a time
dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
dbnbias = dhpreact.sum(0, keepdim=True)
dbnraw = bngain * dhpreact
dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)
dbnvar = (-0.5 * (bnvar + 1e-5)**-1.5) * dbnvar_inv
dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar
dbndiff = (2 * bndiff) * dbndiff2     # through bndiff2 = bndiff**2 ...
dbndiff += dbnraw * bnvar_inv         # ... plus the direct use in bnraw
dhprebn = dbndiff.clone()
# keepdim=True so the gradient has the same (1, n_hidden) shape as bnmean;
# without it the shape mismatch was only hidden by broadcasting.
dbnmean = (-dbndiff).sum(0, keepdim=True)
dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmean)
# first linear layer: hprebn = embcat @ W1 + b1
dembcat = dhprebn @ W1.T
dW1 = embcat.T @ dhprebn
db1 = dhprebn.sum(0)
demb = dembcat.view(emb.shape)        # undo the flattening of the embeddings
# embedding lookup: add each position's gradient into the row of C it came
# from (the same row can be used many times in a batch, hence +=)
dC = torch.zeros_like(C)
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k,j]
        dC[ix] += demb[k,j]

# compare a few grads to PyTorch's autograd (sanity check)
# Each entry: (label, hand-derived gradient, tensor whose autograd .grad is
# the reference). The order follows the backward pass, loss first.
grad_checks = [
    ('logprobs', dlogprobs, logprobs),
    ('probs', dprobs, probs),
    ('counts_sum_inv', dcounts_sum_inv, counts_sum_inv),
    ('counts_sum', dcounts_sum, counts_sum),
    ('counts', dcounts, counts),
    ('norm_logits', dnorm_logits, norm_logits),
    ('logit_maxes', dlogit_maxes, logit_maxes),
    ('logits', dlogits, logits),
    ('h', dh, h),
    ('W2', dW2, W2),
    ('b2', db2, b2),
    ('hpreact', dhpreact, hpreact),
    ('bngain', dbngain, bngain),
    ('bnbias', dbnbias, bnbias),
    ('bnraw', dbnraw, bnraw),
    ('bnvar_inv', dbnvar_inv, bnvar_inv),
    ('bnvar', dbnvar, bnvar),
    ('bndiff2', dbndiff2, bndiff2),
    ('bndiff', dbndiff, bndiff),
    ('bnmean', dbnmean, bnmean),
    ('hprebn', dhprebn, hprebn),
    ('embcat', dembcat, embcat),
    ('W1', dW1, W1),
    ('b1', db1, b1),
    ('emb', demb, emb),
    ('C', dC, C),
]
for label, manual_grad, target in grad_checks:
    compare(label, manual_grad, target)


logprobs        | exact: True  | approx: True  | maxdiff: 0.0
probs           | exact: True  | approx: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approx: True  | maxdiff: 0.0
counts_sum      | exact: True  | approx: True  | maxdiff: 0.0
counts          | exact: True  | approx: True  | maxdiff: 0.0
norm_logits     | exact: True  | approx: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approx: True  | maxdiff: 0.0
logits          | exact: True  | approx: True  | maxdiff: 0.0
h               | exact: True  | approx: True  | maxdiff: 0.0
W2              | exact: True  | approx: True  | maxdiff: 0.0
b2              | exact: True  | approx: True  | maxdiff: 0.0
hpreact         | exact: False | approx: True  | maxdiff: 4.656612873077393e-10
bngain          | exact: False | approx: True  | maxdiff: 1.862645149230957e-09
bnbias          | exact: False | approx: True  | maxdiff: 7.450580596923828e-09
bnraw           | exact: False | approx: True  | maxdiff: 4.656612873077393e-1

In [None]:
# quicker cross-entropy backward (one-liner) - same as PyTorch's shortcut
loss_fast = F.cross_entropy(logits, Yb)
print('loss_fast', loss_fast.item(), 'loss original', loss.item())
# Gradient of the mean cross-entropy w.r.t. the logits:
#   (softmax(logits) - one_hot(target)) / n
# Computed under no_grad because this is a hand-derived gradient, not part of
# the model: it avoids recording an autograd graph for it and avoids editing
# in place a tensor that autograd would otherwise be tracking.
with torch.no_grad():
    dlogits_fast = F.softmax(logits, 1)
    dlogits_fast[range(n), Yb] -= 1
    dlogits_fast /= n
compare('logits', dlogits_fast, logits)


loss_fast 3.5812220573425293 loss original 3.58122181892395
logits          | exact: False | approx: True  | maxdiff: 6.28642737865448e-09


In [None]:
# training loop (manual grads) - student notes inline
# Re-initialises a larger model and trains it with the hand-derived gradients
# only: the whole loop runs under no_grad and never calls loss.backward().
n_embd = 10
n_hidden = 200
g = torch.Generator().manual_seed(2147483647)
C  = torch.randn((vocab_size, n_embd),            generator=g)
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)
b1 = torch.randn(n_hidden,                        generator=g) * 0.1
W2 = torch.randn((n_hidden, vocab_size),          generator=g) * 0.1
b2 = torch.randn(vocab_size,                      generator=g) * 0.1
# Batchnorm gain/bias previously came from the unseeded global RNG, which made
# training non-reproducible; they now draw from the seeded generator as well.
bngain = torch.randn((1, n_hidden), generator=g)*0.1 + 1.0
bnbias = torch.randn((1, n_hidden), generator=g)*0.1
params = [C, W1, b1, W2, b2, bngain, bnbias]
for p in params:
    p.requires_grad = True

max_steps = 200000
batch_size = 32
n = batch_size
losses = []   # log10 of the minibatch loss at every step
with torch.no_grad():
    for i in range(max_steps):
        # random minibatch
        ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
        Xb, Yb = Xtr[ix], Ytr[ix]
        # forward
        emb = C[Xb]
        embcat = emb.view(emb.shape[0], -1)
        hprebn = embcat @ W1 + b1
        bnmean = hprebn.mean(0, keepdim=True)
        bnvar = hprebn.var(0, keepdim=True, unbiased=True)
        bnvar_inv = (bnvar + 1e-5)**-0.5
        bnraw = (hprebn - bnmean) * bnvar_inv
        hpreact = bngain * bnraw + bnbias
        h = torch.tanh(hpreact)
        logits = h @ W2 + b2
        loss = F.cross_entropy(logits, Yb)
        # manual backward (fast expression for CE): (softmax - one_hot) / n
        dlogits = F.softmax(logits, 1)
        dlogits[range(n), Yb] -= 1
        dlogits /= n
        # grads second layer
        dh = dlogits @ W2.T
        dW2 = h.T @ dlogits
        db2 = dlogits.sum(0)
        dhpreact = (1.0 - h**2) * dh
        # batchnorm grads (condensed)
        dbngain = (bnraw * dhpreact).sum(0, keepdim=True)
        dbnbias = dhpreact.sum(0, keepdim=True)
        dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))
        # first layer
        dembcat = dhprebn @ W1.T
        dW1 = embcat.T @ dhprebn
        db1 = dhprebn.sum(0)
        demb = dembcat.view(emb.shape)
        # Embedding gradient: add every position's gradient into the row of C
        # it was looked up from. index_add_ accumulates repeated indices, so it
        # replaces the per-element Python double loop with a single call.
        dC = torch.zeros_like(C)
        dC.index_add_(0, Xb.reshape(-1), demb.reshape(-1, n_embd))
        grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]
        # update params (simple SGD) with a one-off step decay of the learning rate
        lr = 0.1 if i < 100000 else 0.01
        for p, g_ in zip(params, grads):
            p -= lr * g_   # in-place update is fine here: we are inside no_grad
        if i % 10000 == 0:
            print(f"{i:7d}/{max_steps:7d}: {loss.item():.4f}")
        losses.append(loss.log10().item())


      0/ 200000: 3.7645
  10000/ 200000: 2.1757
  20000/ 200000: 2.3695
  30000/ 200000: 2.4349
  40000/ 200000: 2.0139
  50000/ 200000: 2.3806
  60000/ 200000: 2.3613
  70000/ 200000: 2.0748
  80000/ 200000: 2.3880
  90000/ 200000: 2.1920
 100000/ 200000: 1.9972
 110000/ 200000: 2.3468
 120000/ 200000: 1.9976
 130000/ 200000: 2.4412
 140000/ 200000: 2.2924
 150000/ 200000: 2.2099
 160000/ 200000: 2.0261
 170000/ 200000: 1.8294
 180000/ 200000: 2.0307
 190000/ 200000: 1.8484


In [None]:
# calibrate batchnorm stats (student-style) and eval
# A single no-grad pass over the whole training set fixes the mean/variance
# that the evaluation and sampling code normalise with.
with torch.no_grad():
    emb = C[Xtr]
    embcat = emb.view(len(emb), -1)
    hpreact = embcat @ W1 + b1   # activations before batchnorm is applied
    bnmean = hpreact.mean(dim=0, keepdim=True)
    bnvar = hpreact.var(dim=0, keepdim=True, unbiased=True)

@torch.no_grad()
def split_loss(split):
    """Return the mean cross-entropy loss of the model on one data split.

    split -- 'train', 'val' or 'test'; any other key raises KeyError.

    Runs the forward pass with the module-level parameters and the calibrated
    batchnorm statistics `bnmean` / `bnvar`. Decorated with no_grad because
    this is evaluation only: the parameters require grad, so without it a full
    autograd graph would be recorded over the entire split for no purpose.
    """
    x,y = {'train':(Xtr,Ytr),'val':(Xdev,Ydev),'test':(Xte,Yte)}[split]
    emb = C[x]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    # batchnorm with the fixed (calibrated) statistics rather than batch statistics
    hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    return F.cross_entropy(logits, y).item()

# Report the loss on the training and validation splits.
for split_name in ('train', 'val'):
    print(split_name, split_loss(split_name))


train 2.071941375732422
val 2.1088082790374756


In [None]:
# sample from the model (generate names)
g = torch.Generator().manual_seed(2147483647 + 10)
# Sampling is inference only, so run it under no_grad: the parameters require
# grad and would otherwise record an autograd graph for every character.
with torch.no_grad():
    # The batchnorm scale uses the fixed calibrated variance, so it is the same
    # for every step and is computed once up front.
    bnstd_inv = (bnvar + 1e-5)**-0.5
    for _ in range(20):
        out = []
        context = [0]*block_size   # start every name from the all-'.' context
        while True:
            # forward pass for a single context (batch of one)
            emb = C[torch.tensor([context])]
            embcat = emb.view(1, -1)
            hpreact = embcat @ W1 + b1
            hpreact = bngain * (hpreact - bnmean) * bnstd_inv + bnbias
            h = torch.tanh(hpreact)
            logits = h @ W2 + b2
            probs = F.softmax(logits, dim=1)
            # draw the next character and slide the context window forward
            ix = torch.multinomial(probs, num_samples=1, generator=g).item()
            context = context[1:] + [ix]
            out.append(ix)
            if ix == 0:   # index 0 is '.', the end-of-word marker
                break
        print(''.join(itos[i] for i in out))


carmahza.
jahleigh.
mri.
reity.
skaelane.
mahnen.
delynn.
jareen.
ner.
kiah.
maiiv.
kaleigh.
ham.
joce.
quint.
saline.
liveni.
waythoniearynix.
kaellinsley.
dae.
