In [1]:
from pathlib import Path

import torch
import torch.nn.functional as F

# Seed PyTorch's default random number generator so repeated runs draw the
# same random numbers. Because this is the last expression in the cell, the
# Generator object returned by manual_seed is displayed as the cell output.
torch.manual_seed(1337)

<torch._C.Generator at 0x10f2ca790>

In [2]:
# Locate the repository root. Start from the current working directory and
# fall back to its parent when only the parent holds a "data" folder, which
# is the case when the notebook is launched from a direct subdirectory.
REPO_ROOT = Path.cwd()
if not (REPO_ROOT / "data").exists() and (REPO_ROOT.parent / "data").exists():
    REPO_ROOT = REPO_ROOT.parent

data_path = REPO_ROOT / "data" / "names.txt"
# Fail early with an actionable message instead of a bare read error.
if not data_path.is_file():
    raise FileNotFoundError(
        f"dataset not found at {data_path}; "
        "start the notebook from the repo root or one of its direct subdirectories"
    )

# One word per line; splitlines() drops the line terminators.
words = data_path.read_text(encoding="utf-8").splitlines()

print("repo root:", REPO_ROOT)
print("num words:", len(words))
print("first 5 words:", words[:5])

repo root: /Users/home/Developer/github/makemore-notes
num words: 32033
first 5 words: ['emma', 'olivia', 'ava', 'isabella', 'sophia']


In [3]:
# Vocabulary: every distinct character found in `words` gets an index
# starting at 1 (in sorted order); index 0 is assigned to the "." marker.
chars = sorted({ch for w in words for ch in w})
stoi = dict(zip(chars, range(1, len(chars) + 1)))
stoi["."] = 0

# Inverse lookup (index -> character), in the same insertion order as stoi.
itos = {idx: ch for ch, idx in stoi.items()}
vocab_size = len(itos)

print("vocab_size:", vocab_size)
print("itos sample:", list(itos.items())[:10])

vocab_size: 27
itos sample: [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e'), (6, 'f'), (7, 'g'), (8, 'h'), (9, 'i'), (10, 'j')]


In [4]:
# Context window length: each training example pairs this many preceding
# character indices with the index of the character that follows them.
block_size = 3  # number of characters used as context

def build_dataset(words, block_size, char_to_ix=None):
    """Turn a list of words into (context, next-character) training pairs.

    Every word is extended with a trailing "." and walked character by
    character. For each character, the previous `block_size` indices form
    the input row and the character's own index is the target. Contexts at
    the start of a word are left-padded with index 0.

    Args:
        words: iterable of strings.
        block_size: number of context indices per example.
        char_to_ix: optional mapping from character to integer index. It must
            contain "." and every character that occurs in `words`. When
            omitted, the notebook-level `stoi` mapping is used.

    Returns:
        A pair (X, Y) of int64 tensors: X has shape (N, block_size) and Y has
        shape (N,), where N is the total number of characters including the
        per-word "." terminators. Empty input yields shapes (0, block_size)
        and (0,).

    Raises:
        KeyError: if a character is missing from the mapping.
    """
    mapping = stoi if char_to_ix is None else char_to_ix

    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for ch in w + ".":
            ix = mapping[ch]
            X.append(context)
            Y.append(ix)
            # Slide the window: a new list is built, so rows already stored
            # in X are never mutated.
            context = context[1:] + [ix]

    # Pin dtype and shape explicitly: torch.tensor([]) would otherwise infer
    # float32 with shape (0,), which breaks downstream indexing and asserts.
    X_t = torch.tensor(X, dtype=torch.long).reshape(len(Y), block_size)
    Y_t = torch.tensor(Y, dtype=torch.long)
    return X_t, Y_t

# Build the full dataset once and report the resulting tensor shapes.
X, Y = build_dataset(words, block_size)

for name, tensor in (("X", X), ("Y", Y)):
    print(f"{name} shape:", tensor.shape)

X shape: torch.Size([228146, 3])
Y shape: torch.Size([228146])


In [5]:
# Sanity checks on the dataset: one target per context row, the expected
# context width, integer index dtypes, and targets inside the vocabulary.
assert len(X) == len(Y)
assert X.size(1) == block_size
assert X.dtype == torch.long
assert Y.dtype == torch.long
assert Y.min().item() >= 0
assert Y.max().item() < vocab_size

In [6]:
def decode_context(ctx):
    """Map a sequence of character indices back to a string using `itos`."""
    decoded = [itos[int(token)] for token in ctx]
    return "".join(decoded)

# Spot-check a handful of (context -> target) pairs in readable form.
for i in (0, 1, 2, 3, 4, 20, 100):
    context_str = decode_context(X[i])
    target_char = itos[Y[i].item()]
    print(f"{i:>5} | X = {context_str!r} -> Y = {target_char!r}")

    0 | X = '...' -> Y = 'e'
    1 | X = '..e' -> Y = 'm'
    2 | X = '.em' -> Y = 'm'
    3 | X = 'emm' -> Y = 'a'
    4 | X = 'mma' -> Y = '.'
   20 | X = 'sab' -> Y = 'e'
  100 | X = 'lla' -> Y = '.'


In [7]:
# Use the MPS backend when PyTorch reports it as available; otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("device:", device)

# Move the dataset tensors onto the selected device.
X, Y = X.to(device), Y.to(device)

device: mps


In [8]:
# Dedicated generator on the compute device, seeded for repeatable draws.
g = torch.Generator(device=device)
g.manual_seed(1337)

# Model hyperparameters.
n_embed = 10     # size of each character embedding vector
n_hidden = 200   # number of units in the hidden layer

In [9]:
# Parameters of a one-hidden-layer character MLP. The draws below happen in
# the order C, W1, b1, W2 from the shared generator `g`.
fan_in = block_size * n_embed   # width of the concatenated context embeddings
w1_gain = 5 / 3                 # same value as torch.nn.init.calculate_gain("tanh")

# Embedding table: one n_embed-dimensional row per vocabulary entry.
C = torch.randn((vocab_size, n_embed), generator=g, device=device)

# Hidden layer (fan_in -> n_hidden): weights scaled by gain / sqrt(fan_in),
# bias drawn small.
W1 = torch.randn((fan_in, n_hidden), generator=g, device=device) * w1_gain / fan_in ** 0.5
b1 = torch.randn((n_hidden,), generator=g, device=device) * 0.01

# Output layer (n_hidden -> vocab_size): small weights, zero bias.
W2 = torch.randn((n_hidden, vocab_size), generator=g, device=device) * 0.01
b2 = torch.zeros((vocab_size,), device=device)

# Enable gradient tracking on every parameter tensor.
parameters = [C, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad_(True)

# Total number of scalar parameters (displayed as the cell output).
sum(p.numel() for p in parameters)

11897

In [10]:
# Forward pass on one random minibatch (cheaper than the full dataset).
batch_size = 32
ix = torch.randint(0, len(X), (batch_size,), generator=g, device=device)

Xb = X[ix]   # (B, block_size)
Yb = Y[ix]   # (B,)

# 1) look up an embedding vector for every context position: (B, block_size, n_embed)
emb = C[Xb]

# 2) flatten each example's embeddings into a single row: (B, block_size*n_embed)
emb_cat = emb.view(emb.size(0), -1)

# 3) hidden layer: affine map followed by tanh -> (B, n_hidden)
h_pre = torch.matmul(emb_cat, W1) + b1
h = h_pre.tanh()

# 4) output layer -> unnormalised scores, (B, vocab_size)
logits = torch.matmul(h, W2) + b2

for label, tensor in (
    ("Xb", Xb),
    ("emb", emb),
    ("emb_cat", emb_cat),
    ("h", h),
    ("logits", logits),
):
    print(f"{label}:", tensor.shape)

Xb: torch.Size([32, 3])
emb: torch.Size([32, 3, 10])
emb_cat: torch.Size([32, 30])
h: torch.Size([32, 200])
logits: torch.Size([32, 27])


In [11]:
# Cross-entropy between the minibatch logits and the true next-character
# indices; F.cross_entropy averages over the batch by default.
# `F` (torch.nn.functional) is imported in the notebook's first cell so that
# all dependencies are declared in one place.
loss = F.cross_entropy(logits, Yb)
loss

tensor(3.3152, device='mps:0', grad_fn=<NllLossBackward0>)

In [12]:
# Sanity checks on the forward pass: one score per vocabulary entry for each
# example, one target per example, and a finite loss value.
assert tuple(logits.shape) == (batch_size, vocab_size)
assert tuple(Yb.shape) == (batch_size,)
assert bool(torch.isfinite(loss))