<h1 style="text-align: center; font-weight: bold; font-size: 36px;">Character Level MLP - Activations, Gradients and BatchNorm</h1>

# Introduction

Let's create an **MLP** model and explore training and debugging techniques.

Inspired by Karpathy's [Neural Networks: Zero-to-Hero](https://github.com/karpathy/nn-zero-to-hero). 
We are using the same [names.txt](https://github.com/karpathy/makemore/blob/master/names.txt) as in Zero to Hero so we can compare results.

# Imports

In [1]:
import random
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.set_printoptions(precision=4, sci_mode=False)

# PyTorch-ify

In [None]:
class Tokenizer:
    """Character-level tokenizer.

    Each character's id is its position in ``vocab``; ``stoi`` maps
    character -> id and ``itos`` maps id -> character.
    """

    def __init__(self, vocab):
        """Build both lookup tables.

        Args:
            vocab: list of single-character strings.

        Raises:
            AssertionError: if ``vocab`` is not a list of one-character strings.
        """
        assert isinstance(vocab, list)
        assert all(isinstance(v, str) for v in vocab)
        assert all(len(v) == 1 for v in vocab)
        self.stoi = {}
        self.itos = {}
        for idx, char in enumerate(vocab):
            self.stoi[char] = idx
            self.itos[idx] = char

    def encode(self, text):
        """Return the list of ids for the characters of ``text``.

        Raises:
            KeyError: if a character is not in the vocabulary.
        """
        ids = []
        for char in text:
            ids.append(self.stoi[char])
        return ids

    def decode(self, sequence):
        """Turn ids back into a string.

        Args:
            sequence: a list of ids, or a 0-d / 1-d ``torch.Tensor`` of ids.

        Returns:
            The decoded string (a single character for a 0-d tensor).

        Raises:
            ValueError: if ``sequence`` is neither a list nor a tensor.
            AssertionError: if a tensor has more than one dimension.
        """
        if isinstance(sequence, torch.Tensor):
            assert sequence.ndim in [0, 1]
            if sequence.ndim == 0:
                return self.itos[sequence.item()]  # one char
            ids = [element.item() for element in sequence]
        elif isinstance(sequence, list):
            ids = sequence
        else:
            raise ValueError(f"Type {type(sequence)} not supported")
        return ''.join(self.itos[idx] for idx in ids)

class Embedding:
    """Lookup table of dense vectors, one row per id."""

    def __init__(self, num_embeddings, embedding_dim):
        """Create a ``(num_embeddings, embedding_dim)`` table drawn from N(0, 1)."""
        table_shape = (num_embeddings, embedding_dim)
        self.weight = torch.randn(table_shape)

    def __call__(self, x):
        """Index the table with ``x``; the result gains a trailing embedding axis."""
        rows = self.weight[x]
        return rows

    def parameters(self):
        """Return the trainable tensors of this layer (just the table)."""
        trainable = [self.weight]
        return trainable
        

class Linear:
    """Affine layer ``x @ weight (+ bias)``.

    Weight and bias are drawn from a standard normal and scaled by
    ``1 / sqrt(in_features)``.
    """

    def __init__(self, in_features, out_features, bias=True):
        """Initialise the parameters.

        Args:
            in_features: size of the input's last dimension (fan-in).
            out_features: size of the output's last dimension.
            bias: when falsy no bias is created and ``self.bias`` is ``None``.
        """
        scale = 1.0 / (in_features**0.5)      # 1 / sqrt(fan_in)
        # The weight is sampled before the bias; keep this order so a seeded
        # RNG yields the same values.
        self.weight = torch.randn((in_features, out_features)) * scale
        if bias:
            self.bias = torch.randn((1, out_features)) * scale
        else:
            self.bias = None

    def __call__(self, x):
        """Apply the layer and remember the result in ``self.out``."""
        result = x @ self.weight
        if self.bias is not None:
            # In-place on the fresh matmul result; the inputs are not modified.
            result += self.bias
        self.out = result   # for debug/experiments
        return result

    def parameters(self):
        """Return ``[weight]`` or ``[weight, bias]`` depending on the bias flag."""
        if self.bias is None:
            return [self.weight]
        return [self.weight, self.bias]

class Tanh:
    """Element-wise tanh activation; it has no state and no parameters."""

    def __call__(self, x):
        """Return ``tanh(x)`` element-wise."""
        activated = torch.tanh(x)
        return activated

    def parameters(self):
        """Return an empty list: there is nothing to train."""
        return list()

class BatchNorm1d:
    """Batch normalisation along dim 0 with a learnable gain and bias.

    While ``training`` is True the batch mean / standard deviation are used
    and folded into exponential moving averages; otherwise the stored running
    statistics are used.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1):
        """Create the affine parameters and the running statistics.

        Args:
            num_features: number of columns being normalised.
            eps: small constant added to the standard deviation before dividing.
            momentum: weight of the newest batch in the running averages.
        """
        stat_shape = (1, num_features)

        # Learnable affine transform applied after normalisation.
        self.gain = torch.ones(stat_shape)
        self.bias = torch.zeros(stat_shape)
        self.eps = eps

        # Running statistics (start at mean 0, std 1).
        self.momentum = momentum
        self.mean_running = torch.zeros(stat_shape)
        # TODO: should track var for better eps behavior
        self.std_running = torch.ones(stat_shape)

        self.training = True

    def __call__(self, x):
        """Normalise ``x`` column-wise, then scale by ``gain`` and shift by ``bias``."""
        if not self.training:
            x_mean, x_std = self.mean_running, self.std_running
        else:
            x_mean = torch.mean(x, dim=0, keepdim=True)
            x_std = torch.std(x, dim=0, keepdim=True)
            # The running averages are bookkeeping only, so keep them out of
            # the autograd graph.
            with torch.no_grad():
                self.mean_running = (1-self.momentum) * self.mean_running + self.momentum * x_mean
                self.std_running = (1-self.momentum) * self.std_running + self.momentum * x_std
        normalised = (x - x_mean) / (x_std + self.eps)
        return normalised * self.gain + self.bias

    def parameters(self):
        """Return the trainable tensors: ``[gain, bias]``."""
        return [self.gain, self.bias]

# Build the Dataset

In [2]:
# Load the dataset: one name per line.
with open('../data/names.txt', 'r') as names_file:
    names = names_file.read().splitlines()

# Quick summary of what was loaded.
print("Num names:", len(names))
print("Example names:", names[:10])
print("Min length:", min(map(len, names)))
print("Max length:", max(map(len, names)))

Num names: 32033
Example names: ['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn']
Min length: 2
Max length: 15


In [3]:
# Vocabulary: every distinct character occurring in the names, sorted,
# with '.' placed first so it gets index 0.
letters = ['.'] + sorted(set(''.join(names)))
n_vocab = len(letters)
print(letters)

['.', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [5]:
def build_dataset(tok, block_size, names):
    """Turn names into (context, next-character) training pairs.

    Every name is padded with ``block_size`` leading '.' characters and one
    trailing '.', then a window of ``block_size`` characters slides over it;
    the character right after the window is the target.

    Args:
        tok: object whose ``encode(text)`` returns a list of ids.
        block_size: number of context characters per example.
        names: iterable of strings.

    Returns:
        ``(X, Y)`` tensors built from the contexts and the targets; their
        shapes are also printed.
    """
    contexts = []  # inputs
    targets = []   # next-character ids
    for raw_name in names:
        padded = '.'*block_size + raw_name + '.'  # add start/stop tokens '..emma.'
        n_windows = len(padded) - block_size
        for start in range(n_windows):
            stop = start + block_size
            contexts.append(tok.encode(padded[start:stop]))
            targets.append(tok.encode(padded[stop])[0])  # [0] to keep Y 1d tensor
    X = torch.tensor(contexts)
    Y = torch.tensor(targets)
    print(X.shape, Y.shape)
    return X, Y

In [6]:
block_size = 3  # context length
tok = Tokenizer(vocab=letters)

# Shuffle a *copy* of the names. Shuffling `names` in place made this cell
# non-idempotent: running it a second time shuffled the already-shuffled list,
# so the train/val/test membership silently changed between runs.
# The permutation depends only on the seed and the list length, so on a fresh
# kernel the splits are the same as with the in-place shuffle.
random.seed(42)
names_shuffled = list(names)
random.shuffle(names_shuffled)

# 80% train / 10% validation / 10% test
n1 = int(0.8*len(names_shuffled))
n2 = int(0.9*len(names_shuffled))

Xtr, Ytr = build_dataset(tok, block_size, names_shuffled[:n1])
Xval, Yval = build_dataset(tok, block_size, names_shuffled[n1:n2])
Xtest, Ytest = build_dataset(tok, block_size, names_shuffled[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [7]:
# Expected initial loss: the cross-entropy of a model that assigns the same
# probability 1/n_vocab to every character, i.e. -log(1/n_vocab).
uniform_prob = torch.tensor(1/n_vocab)
expected_initial_loss = -uniform_prob.log()
print(expected_initial_loss)

tensor(3.2958)


In [127]:
# Seed PyTorch so the random initialisation (and everything sampled after it)
# is reproducible.
torch.manual_seed(42)

# Hyperparameters
n_embd = 10    # embedding size per character
n_hid = 200    # hidden layer width

# Layers. They are constructed in this order on purpose: each constructor
# draws from the seeded RNG, so reordering would change the initial weights.
embd = Embedding(n_vocab, n_embd)

# No bias on lin1: bn1 subtracts the batch mean and adds its own bias.
lin1 = Linear(n_embd*block_size, n_hid, bias=False)
lin1.weight *= 5/3  # Kaiming init (5/3 gain on top of Linear's 1/sqrt(fan_in))
bn1 = BatchNorm1d(n_hid, momentum=0.001)
tanh1 = Tanh()

lin2 = Linear(n_hid, n_vocab)
# TODO: change to standard pytorch
lin2.weight *= (n_hid**0.5)   # cancels the 1/sqrt(fan_in) factor applied by Linear
lin2.weight *= 0.01           # small output weights
lin2.bias *= 0.0              # disable bias on startup

# All trainable tensors, in forward-pass order (tanh1 has none).
params = embd.parameters() + lin1.parameters() + bn1.parameters() + lin2.parameters()

# Enable gradients only after the in-place rescaling above.
for tensor in params:
    tensor.requires_grad = True

In [128]:
# History of (step, loss) pairs for plotting the training curve.
iters = []
losses = []

# One learning rate per step: 100k steps at 0.1, then 100k steps at 0.01.
lr_schedule = [0.1]*100000 + [0.01]*100000
num_epochs = len(lr_schedule)   # number of mini-batch steps to run
batch_size = 32
i = 0                           # global step counter, also indexes lr_schedule

In [129]:
# Model Setup: use batch statistics (and update the running ones) while
# training. Nothing inside the loop changes this flag, so set it once.
bn1.training = True

for _ in range(num_epochs):

    # Random mini batch
    batch_indices = torch.randint(0, Xtr.shape[0], (batch_size,))
    x_batch = Xtr[batch_indices]
    y_batch = Ytr[batch_indices]

    # TODO: check all dims

    # Forward Pass
    em = embd(x_batch)                         # n_batch, n_seq, n_emb
    embcat = em.view(-1, n_embd*block_size)    # n_batch, n_embd*block_size
    z1 = lin1(embcat)                          # n_batch, n_hid1
    zz = bn1(z1)                               # n_batch, n_hid1
    h1 = tanh1(zz)                             # n_batch, n_hid1
    logits = lin2(h1)                          # n_batch, n_vocab
    loss = F.cross_entropy(logits, y_batch)

    # Backward Pass
    for p in params:
        p.grad = None
    loss.backward()

    # SGD step. The parameters are updated in place under no_grad instead of
    # through the legacy `.data` attribute, so the update stays out of the
    # graph while autograd can still track the in-place modification.
    lr = lr_schedule[i]
    with torch.no_grad():
        for p in params:
            p += -lr * p.grad

    if i % 10000 == 0:
        print(i, loss.item())

    iters.append(i)
    losses.append(loss.item())
    i += 1

print(i, loss.item())

0 3.331486225128174
10000 2.2780206203460693
20000 2.2366952896118164
30000 2.05989146232605
40000 2.5080156326293945
50000 1.9756540060043335
60000 1.8901540040969849
70000 2.3095462322235107
80000 2.197767496109009
90000 2.1742000579833984
100000 2.2180747985839844
110000 2.2235922813415527
120000 1.857809066772461
130000 2.1419620513916016
140000 2.396003246307373
150000 2.208874225616455
160000 2.0063955783843994
170000 1.9545999765396118
180000 1.9888185262680054
190000 1.7216124534606934
200000 1.9555671215057373


In [102]:
def print_train_check():
    """Print a cheap fingerprint of the global ``params`` list.

    For each statistic (std, mean, sum, max, min) the per-tensor value is
    computed and the values are summed over all parameter tensors.
    """
    def total(stat_name):
        # Sum one per-tensor statistic over every parameter tensor.
        return sum(getattr(p, stat_name)() for p in params)

    with torch.no_grad():
        my_std = total('std')
        my_mean = total('mean')
        my_sum = total('sum')
        my_max = total('max')
        my_min = total('min')

    print(f"{my_std=}")
    print(f"{my_mean=}")
    print(f"{my_sum=}")
    print(f"{my_max=}")
    print(f"{my_min=}")

In [103]:
# Print summary statistics of the current parameters.
print_train_check()

my_std=tensor(1.3207)
my_mean=tensor(1.0633)
my_sum=tensor(233.5514)
my_max=tensor(4.3551)
my_min=tensor(-2.7899)


In [None]:
# Reference values recorded from an earlier ("original") run; the output of
# the call below is meant to be compared against them by eye.
# Original:
# my_std=tensor(1.3207)
# my_mean=tensor(1.0633)
# my_sum=tensor(233.5514)
# my_max=tensor(4.3551)
# my_min=tensor(-2.7899)
print_train_check()

my_std=tensor(1.3207)
my_mean=tensor(1.0633)
my_sum=tensor(233.5514)
my_max=tensor(4.3551)
my_min=tensor(-2.7899)


In [None]:
# NOTE: z1 and h1 are the globals left over from the last training step,
# i.e. they describe a single mini-batch.

# hidden layer pre-activations (output of lin1, before batchnorm and tanh)
pre_activations = z1.view(-1).tolist()
plt.hist(pre_activations, bins=100)
plt.show()

# hidden layer outputs (after tanh)
hidden_outputs = h1.view(-1).tolist()
plt.hist(hidden_outputs, bins=100)
plt.show()

# neurons in tanh flat region: white where |h1| > 0.99
saturated = h1.abs() > 0.99
plt.imshow(saturated, cmap='gray', interpolation='nearest')
plt.show()

In [None]:
# Training curve: log10 of the per-step mini-batch loss against the step index.
log_losses = torch.tensor(losses).log10()
plt.plot(iters, log_losses)
plt.show()

In [130]:
@torch.no_grad()
def calc_batch_norm_params_on_train_set():
    """Compute lin1's output statistics over the whole training set.

    Uses the globals ``Xtr``, ``embd``, ``lin1``, ``n_embd`` and
    ``block_size``. As a side effect ``bn1.training`` is set to False.

    Returns:
        ``(std, mean)`` of the lin1 output along dim 0, each with the reduced
        dimension kept.
    """
    # Model Setup
    bn1.training = False

    # Forward pass over the whole training set, up to (and including) lin1.
    flat_emb = embd(Xtr).view(-1, n_embd*block_size)    # n_batch, n_embd*block_size
    pre_act = lin1(flat_emb)                            # n_batch, n_hid1

    # Batchnorm statistics
    mean_all = torch.mean(pre_act, dim=0, keepdim=True)  # 1, n_hid1
    std_all = torch.std(pre_act, dim=0, keepdim=True)    # 1, n_hid1

    return std_all, mean_all

bnstd1, bnmean1 = calc_batch_norm_params_on_train_set()

# Compare approaches: largest absolute gap between the statistics computed
# over the full training set and the running estimates kept by bn1.
std_gap = torch.abs(bnstd1 - bn1.std_running)
mean_gap = torch.abs(bnmean1 - bn1.mean_running)
print( std_gap.max() )
print( mean_gap.max() )

tensor(0.0539)
tensor(0.0266)


In [131]:
@torch.no_grad()
def evaluate(x_batch, y_batch):
    """Return the cross-entropy loss of the model on ``(x_batch, y_batch)``.

    Runs the same forward pass as training but with ``bn1`` switched to its
    running statistics (``bn1.training`` is left False afterwards).

    Returns:
        The loss as a Python float.
    """
    # Model Setup
    bn1.training = False

    # Forward Pass
    activations = embd(x_batch).view(-1, n_embd*block_size)    # n_batch, n_embd*block_size
    for layer in (lin1, bn1, tanh1, lin2):
        activations = layer(activations)
    # `activations` now holds the logits: n_batch, n_vocab
    return F.cross_entropy(activations, y_batch).item()

# Loss on the full train and validation splits
# (reference values noted earlier: ~2.12 train, ~2.15 eval).
for split_label, split_x, split_y in (("train = ", Xtr, Ytr), ("eval =  ", Xval, Yval)):
    print(split_label, evaluate(split_x, split_y))

train =  2.0673458576202393
eval =   2.1093027591705322


In [132]:
# Base:
# train =  2.1214349269866943
# eval =   2.1567015647888184

# Fixed W2*0.01 and b2*0.0 init:
# train =  2.064913511276245
# eval =   2.129284143447876

# Fixed W1*0.2 and b1*0.01 init:
# train =  2.0375447273254395
# eval =   2.104278326034546

# Kaiming init
# same as above, we went 0.2 -> 0.3 on W1, so not much difference
# train =  2.0372910499572754
# eval =   2.1173431873321533

# Initial batchnorm (same, no gains expected, NN probably context limited)
# train =  2.067298650741577
# eval =   2.1195051670074463

# Proper batchnorm (running mean/std, no linear layer bias)
# train =  2.0653347969055176
# eval =   2.121156930923462

# Torch-ified batchnorm (classes)
# train =  2.0673458576202393
# eval =   2.1093027591705322

# Sampling

In [133]:
@torch.no_grad()
def sample_name():
    """Sample one name from the trained model, character by character.

    Starts from a context of ``block_size`` '.' characters, repeatedly feeds
    the last ``block_size`` ids through the network, samples the next id from
    the softmax distribution and stops once the '.' token is drawn.

    Uses the globals ``tok``, ``block_size``, ``n_embd`` and the model layers;
    ``bn1.training`` is left False afterwards.

    Returns:
        The generated name, including the terminating '.'.
    """
    # Model Setup
    bn1.training = False

    # Derive the stop id from the tokenizer instead of assuming it is 0.
    stop_token = tok.encode('.')[0]

    context = tok.encode('.'*block_size)
    while True:
        # Construct Batch. The window must be block_size wide (it was
        # hard-coded to 3), otherwise the view below no longer lines up.
        x_batch = torch.tensor(context[-block_size:]).view(1, -1)   # n_batch=1, n_seq

        # Forward Pass
        em = embd(x_batch)                         # n_batch, n_seq, n_emb
        embcat = em.view(-1, n_embd*block_size)    # n_batch, n_embd*block_size
        z1 = lin1(embcat)                          # n_batch, n_hid1
        zz = bn1(z1)                               # n_batch, n_hid1
        h1 = tanh1(zz)                             # n_batch, n_hid1
        logits = lin2(h1)                          # n_batch, n_vocab

        # Probabilities
        probs = torch.softmax(logits, dim=1)

        # Sample
        sample = torch.multinomial(probs, 1).item()
        context.append(sample)

        # Break
        if sample == stop_token:
            break

    # Drop the leading padding before decoding back to text.
    return tok.decode(context)[block_size:]

In [134]:
# Re-seed so the sampled names are reproducible.
torch.manual_seed(42)

# Use a throwaway loop variable: `i` is the global training step counter, and
# overwriting it here would corrupt a later re-run of the training cell.
for _ in range(10):
    print(sample_name())

anuelen.
tia.
marian.
dan.
shan.
silaylen.
kemah.
lanie.
epiacelle.
jamiy.
