In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load data
with open("names.txt", "r", encoding="utf-8") as f:
    names = f.read().splitlines()

chars = sorted(list(set("".join(names))))

stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi["."] = 0
itos = {i: s for s, i in stoi.items()}

In [3]:
# Build the dataset
block_size = 3

def build_dataset(words):
    X, Y = [], []

    for word in words:
        context = [0] * block_size

        for char in word + ".":
            token = stoi[char]
            X.append(context)
            Y.append(token)
            context = context[1:] + [token]

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

In [4]:
# Split up training, validation, test sets

import random
random.shuffle(names)
n1 = int(0.8 * len(names))
n2 = int(0.9 * len(names))

Xtr, Ytr = build_dataset(names[:n1])
Xdev, Ydev = build_dataset(names[n1:n2])
Xte, Yte = build_dataset(names[n2:]) 

torch.Size([182480, 3]) torch.Size([182480])
torch.Size([22749, 3]) torch.Size([22749])
torch.Size([22917, 3]) torch.Size([22917])


In [105]:
class Linear:

    def __init__(self, fan_in: int, fan_out: int, bias:bool=True):
        self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5
        self.bias = torch.randn(fan_out) if bias else None

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        self.out = x @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def parameters(self):
        return [self.weight] + ([] if self.bias is None else [self.bias])

class BatchNorm1d:

    def __init__(self, dim: int, eps: float=1e-5, momentum: float=0.1):
        self.momentum = momentum
        self.eps = eps
        self.training = True
        self.running_mean = torch.ones((1, dim))
        self.running_var = torch.zeros((1, dim))
        self.gain = torch.ones((1, dim))
        self.bias = torch.zeros((1, dim))

    def __call__(self, x: torch.Tensor):
        if self.training:
            x_mean = x.mean(0, keepdim=True)
            x_var = x.var(0, keepdim=True)
        else:
            x_mean = self.running_mean
            x_var = self.running_var

        # Normalize
        xhat = (x - x_mean) / torch.sqrt(x_var + self.eps)
        self.out = self.gain * xhat + self.bias

        if self.training:
            # Running stats
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * x_mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * x_var
        
        return self.gain * (x - x_mean) / x_var + self.bias

    def parameters(self) -> list[torch.Tensor]:
        return [self.gain, self.bias]

class Tanh:
    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out
        

    def parameters(self):
        return []

In [106]:
vocab_size = len(stoi.keys())
embedding_size = 10
hidden_size = 200

C = torch.randn((vocab_size, embedding_size))
layers = [
    Linear(embedding_size * block_size, hidden_size, bias=False), BatchNorm1d(hidden_size), Tanh(),
    Linear(hidden_size, vocab_size),
]

# Param init
with torch.no_grad():
    layers[-1].weight *= 0.1 # Make the last layer less confident

parameters = [C] + [p for layer in layers for p in layer.parameters()]
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

12097


In [107]:
max_steps = 2000000
minibatch_size = 32

# Training
for i in range(max_steps):
    # Construct minibatch
    minibatch_ixs = torch.randint(0, Xtr.shape[0], (minibatch_size,))
    Xb, Yb = Xtr[minibatch_ixs], Ytr[minibatch_ixs]
    
    # Forward pass
    emb = C[Xb]
    x = emb.view(emb.shape[0], -1)

    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, Yb)

    # Backward
    for p in parameters:
        p.grad = None
    loss.backward()

    # Update
    lr = 0.1 if i < 150000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad

    # Track stats
    losses.append(loss.log10().item())
    steps.append(i)
    if i % 10000 == 0:
        print(f"{i:7d}/{max_steps:7d}: {loss.item():0.4f}")

      0/2000000: 3.6199
  10000/2000000: 2.1675
  20000/2000000: 2.6014
  30000/2000000: 2.1431
  40000/2000000: 2.3884
  50000/2000000: 1.8408
  60000/2000000: 1.8843
  70000/2000000: 2.7445
  80000/2000000: 2.1415
  90000/2000000: 2.5392
 100000/2000000: 2.3680
 110000/2000000: 2.1752
 120000/2000000: 2.1710
 130000/2000000: 1.9797
 140000/2000000: 2.2489
 150000/2000000: 2.4057
 160000/2000000: 2.5833
 170000/2000000: 2.1312
 180000/2000000: 2.2822
 190000/2000000: 1.8204
 200000/2000000: 2.1840
 210000/2000000: 1.9097
 220000/2000000: 2.1071
 230000/2000000: 1.8752
 240000/2000000: 2.1295
 250000/2000000: 1.7789
 260000/2000000: 2.4175
 270000/2000000: 2.2578
 280000/2000000: 2.0341
 290000/2000000: 2.0111
 300000/2000000: 2.1889
 310000/2000000: 2.1699
 320000/2000000: 1.8922
 330000/2000000: 2.3911
 340000/2000000: 1.8432
 350000/2000000: 2.2223


In [None]:
plt.plot(steps, losses)
plt.show()

In [None]:
# Post training
for layer in layers:
    layer.training = False

In [77]:
@torch.no_grad()
def split_loss(X, Y):
    emb = C[Xtr]
    embcat = emb.view(emb.shape[0], -1)
    hpreact = embcat @ W1 + b1
    hpreact = bn_gain * (hpreact - bn_mean_running) / bn_std_running + bn_bias
    h = torch.tanh(hpreact)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Ytr)
    return loss.item()

print(f"Training: {split_loss(Xtr, Ytr):.4f}")
print(f"Dev: {split_loss(Xdev, Ydev):.4f}")

Training: 2.0681
Dev: 2.0681


In [86]:
# Sample
for _ in range(10):
    name = []
    context = [0] * block_size

    while True:
        embs = C[torch.tensor([context])]
        h = torch.tanh(embs.view(1, -1) @ W1 + b1)
        logits = h @ W2 + b2
        probs = F.softmax(logits, dim=1)
        token = torch.multinomial(probs, num_samples=1).item()
        context = context[1:] + [token]
        if token == 0:
            break
        else:
            name.append(itos[token])

    print("".join(name))

duza
ffandrandren
cafisseden
maditon
josioksagust
rog
kovijahzav
sahmontavivtangasyn
mic
casslynn
