<h1 style="text-align: center; font-weight: bold; font-size: 36px;">Building a WaveNet</h1>

# Introduction

Let's create an **MLP** model and explore training and debugging techniques.

Inspired by Karpathy's [Neural Networks: Zero-to-Hero](https://github.com/karpathy/nn-zero-to-hero) series.
We are using the same [names.txt](https://github.com/karpathy/makemore/blob/master/names.txt) as in Zero to Hero so we can compare results.

# Imports

In [None]:
import time
import random
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# PyTorch-ify

In [None]:
class Tokenizer:
    """Character-level tokenizer: one integer id per vocabulary character.

    The id of a character is its position in the ``vocab`` list.
    """

    def __init__(self, vocab):
        """Build the char -> id and id -> char lookup tables.

        Args:
            vocab: list of single-character strings.
        """
        assert isinstance(vocab, list)
        assert all(isinstance(ch, str) and len(ch) == 1 for ch in vocab)
        self.itos = dict(enumerate(vocab))
        self.stoi = {ch: idx for idx, ch in self.itos.items()}

    def encode(self, text):
        """Return the list of ids for the characters of ``text``."""
        lookup = self.stoi
        return [lookup[ch] for ch in text]

    def decode(self, sequence):
        """Turn ids back into text.

        Args:
            sequence: a list of ids, or a 0-d / 1-d ``torch.Tensor`` of ids.

        Returns:
            The decoded string (a single character for a 0-d tensor).

        Raises:
            ValueError: if ``sequence`` is neither a list nor a tensor.
        """
        if isinstance(sequence, list):
            return ''.join(self.itos[idx] for idx in sequence)
        if not isinstance(sequence, torch.Tensor):
            raise ValueError(f"Type {type(sequence)} not supported")

        assert sequence.ndim in (0, 1)
        if sequence.ndim == 0:
            return self.itos[sequence.item()]  # scalar tensor -> one char
        return ''.join(self.itos[idx] for idx in sequence.tolist())

# Build the Dataset

In [None]:
# Load the names corpus: one name per line of the text file.
# The encoding is given explicitly so decoding does not depend on the
# platform's default locale.
with open('../data/names.txt', 'r', encoding='utf-8') as f:
    names = f.read().splitlines()

# Quick summary of the corpus (lengths computed once, reused for min/max).
name_lengths = [len(name) for name in names]
print("Num names:", len(names))
print("Example names:", names[:10])
print("Min length:", min(name_lengths))
print("Max length:", max(name_lengths))

In [None]:
# Vocabulary: every distinct character that occurs in the names, sorted,
# with the '.' marker prepended as the first entry.
letters = ['.'] + sorted(set(''.join(names)))
n_vocab = len(letters)
print(letters)

In [None]:
def build_dataset(tok, block_size, names):
    """Build (context, next-character) training pairs from a list of names.

    Every name is padded with ``block_size`` leading '.' characters and one
    trailing '.', e.g. '...emma.' for ``block_size=3``. A window of
    ``block_size`` characters slides over the padded name; each window is an
    input and the character right after it is the target.

    Args:
        tok: object whose ``encode(text)`` returns a list of integer ids.
        block_size: number of context characters per example.
        names: iterable of name strings.

    Returns:
        Tuple ``(X, Y)`` of tensors built from the encoded windows and the
        encoded target characters. Their shapes are printed as a side effect.
    """
    contexts, targets = [], []
    for raw in names:
        padded = '.' * block_size + raw + '.'
        for start in range(len(padded) - block_size):
            stop = start + block_size
            contexts.append(tok.encode(padded[start:stop]))
            # encode() returns a list; take its single element so Y stays 1-d
            targets.append(tok.encode(padded[stop])[0])
    X = torch.tensor(contexts)
    Y = torch.tensor(targets)
    print(X.shape, Y.shape)
    return X, Y

In [None]:
# Context length: how many preceding characters form one model input.
block_size = 3
tok = Tokenizer(vocab=letters)

# Shuffle reproducibly, then cut the list at 80% and 90% to obtain the
# train / validation / test name lists.
random.seed(42)
random.shuffle(names)
n1, n2 = (int(frac * len(names)) for frac in (0.8, 0.9))
train_names, val_names, test_names = names[:n1], names[n1:n2], names[n2:]

Xtr, Ytr = build_dataset(tok, block_size, train_names)
Xval, Yval = build_dataset(tok, block_size, val_names)
Xtest, Ytest = build_dataset(tok, block_size, test_names)

# Build the Model

In [None]:
class Embedding:
    """Lookup table that maps integer ids to rows of a weight matrix."""

    def __init__(self, num_embeddings, embedding_dim):
        """Create a ``(num_embeddings, embedding_dim)`` table of normal samples."""
        table_shape = (num_embeddings, embedding_dim)
        self.weight = torch.randn(table_shape)

    def __call__(self, x):
        """Index the table with ``x``; the result is also kept in ``self.out``."""
        looked_up = self.weight[x]
        self.out = looked_up
        return looked_up

    def parameters(self):
        """Return the trainable tensors of this layer (just the table)."""
        return [self.weight]

class Linear:
    """Affine layer computing ``x @ weight (+ bias)``."""

    def __init__(self, in_features, out_features, bias=True):
        """Initialise the layer.

        Args:
            in_features: size of the last dimension of the input.
            out_features: size of the last dimension of the output.
            bias: when False no bias is created and ``self.bias`` is None.
        """
        # Normal samples scaled down by sqrt(fan_in).
        fan_in_scale = in_features**0.5
        self.weight = torch.randn((in_features, out_features)) / fan_in_scale
        self.bias = torch.zeros(out_features) if bias else None

    def __call__(self, x):
        """Apply the layer to ``x``; the result is also kept in ``self.out``."""
        y = x @ self.weight
        if self.bias is not None:
            y += self.bias
        self.out = y
        return y

    def parameters(self):
        """Return the weight, followed by the bias when there is one."""
        if self.bias is None:
            return [self.weight]
        return [self.weight, self.bias]

class BatchNorm1d:
    """Batch normalisation over dimension 0 with a learnable gain and bias.

    In training mode the statistics of the current batch are used and folded
    into running averages; otherwise the running averages are used.
    """

    def __init__(self, dim, eps=1e-05, momentum=0.1):
        """Set up parameters and running statistics.

        Args:
            dim: number of features being normalised.
            eps: constant added to the variance before the square root.
            momentum: weight of the newest batch statistic in the running
                averages.
        """
        self.eps, self.momentum = eps, momentum
        self.training = True

        # Trainable scale and shift.
        self.gain = torch.ones(dim)
        self.bias = torch.zeros(dim)

        # Running statistics used when not training (not returned by
        # parameters()).
        self.mean_running = torch.zeros(dim)
        self.var_running = torch.ones(dim)

    def __call__(self, x):
        """Normalise ``x``; the result is also kept in ``self.out``."""
        if not self.training:
            mean, var = self.mean_running, self.var_running
        else:
            mean = x.mean(dim=0, keepdim=True)
            var = x.var(dim=0, keepdim=True)
            # The running averages are bookkeeping only, so no graph is built.
            with torch.no_grad():
                keep = 1 - self.momentum
                self.mean_running = keep * self.mean_running + self.momentum * mean
                self.var_running = keep * self.var_running + self.momentum * var
        std = (var + self.eps)**0.5
        normalised = (x - mean) / std
        self.out = normalised * self.gain + self.bias
        return self.out

    def parameters(self):
        """Return the trainable tensors: gain and bias."""
        return [self.gain, self.bias]

class Tanh:
    """Element-wise hyperbolic tangent layer without parameters."""

    def __init__(self):
        """Nothing to set up: the layer holds no state before its first call."""

    def __call__(self, x):
        """Return ``tanh(x)``; the result is also kept in ``self.out``."""
        activated = torch.tanh(x)
        self.out = activated
        return activated

    def parameters(self):
        """Return an empty list: there is nothing to train."""
        return []
    
class Flatten:
    """Collapse every dimension after the first into a single one."""

    def __call__(self, x):
        """Return ``x`` viewed as ``(x.shape[0], -1)``; also kept in ``self.out``."""
        leading = x.shape[0]
        flat = x.view(leading, -1)
        self.out = flat
        return flat

    def parameters(self):
        """Return an empty list: there is nothing to train."""
        return []

class Sequential:
    """Container that applies a list of layers one after another."""

    def __init__(self, layers):
        """Store ``layers``, a list of callables that also offer ``parameters()``."""
        self.layers = layers

    def __call__(self, x):
        """Feed ``x`` through every layer in order; result kept in ``self.out``."""
        result = x
        for layer in self.layers:
            result = layer(result)
        self.out = result
        return result

    def parameters(self):
        """Return the parameters of all layers as one flat list, in layer order."""
        collected = []
        for layer in self.layers:
            collected.extend(layer.parameters())
        return collected

    def train(self):
        """Set ``training = True`` on every layer."""
        self._set_training(True)

    def eval(self):
        """Set ``training = False`` on every layer."""
        self._set_training(False)

    def _set_training(self, flag):
        """Write ``flag`` to the ``training`` attribute of each layer."""
        for layer in self.layers:
            layer.training = flag
    

In [None]:
# Reference value for the first training loss: the negative log of the
# probability 1/n_vocab that a uniform prediction gives to each character.
uniform_prob = torch.tensor(1 / n_vocab)
expected_initial_loss = -uniform_prob.log()
print(expected_initial_loss)

In [None]:
# Experiments to try:
# - play with the Kaiming init gain: 5/3
# - play with the weight init in Linear: / (in_features**0.5)

# Seed the generator so the random initial weights are reproducible
torch.manual_seed(42)

# Hyperparameters
n_embd = 10     # embedding size per character
n_hidden = 200  # width of the hidden layer

# Model: embed the context, flatten it, one Linear -> BatchNorm1d -> Tanh
# block, then a Linear layer producing one logit per vocabulary entry.
# The layers are constructed top to bottom, which fixes the order of the
# random draws.
model = Sequential([
    Embedding(n_vocab, n_embd),
    Flatten(),
    Linear(n_embd * block_size, n_hidden, bias=False),
    BatchNorm1d(n_hidden),
    Tanh(),
    Linear(n_hidden, n_vocab),
])

# Parameter Init
with torch.no_grad():
    # Optional experiment (disabled): Kaiming gain on the hidden layers
    # for layer in model.layers[:-1]:
    #     if isinstance(layer, Linear):
    #         layer.weight *= 5/3
    # Shrink the output weights so the initial softmax is close to uniform
    output_layer = model.layers[-1]
    output_layer.weight *= 0.1

# Gather all trainable tensors and switch on gradient tracking
params = model.parameters()
for p in params:
    p.requires_grad_(True)

# Total number of scalar parameters
print(sum(p.numel() for p in params))

In [None]:
# Per-step history filled in by the training loop: step index, loss, and the
# log10 update-to-data ratio of every parameter tensor.
iters = []
losses = []
ud = []

# One learning rate per step: 0.1 for the first 150k steps, 0.01 for the
# remaining 50k.
n_high_lr_steps, n_low_lr_steps = 150_000, 50_000
lr_schedule = n_high_lr_steps * [0.1] + n_low_lr_steps * [0.01]

# Despite its name this counts mini-batch steps (one per schedule entry).
num_epochs = len(lr_schedule)
batch_size = 32

# Global step counter; the training loop reads and advances it.
i = 0

In [None]:
# Train in resumable chunks: each execution of this block runs at most
# `steps_per_run` mini-batch steps and then stops. The global step counter
# `i` is kept between executions, so running the block again continues where
# the previous run ended.
steps_per_run = 10_000

# Set Mode (nothing inside the loop changes it, so once is enough)
model.train()

time_start = time.time()
for _ in range(num_epochs):

    # Schedule exhausted: stop instead of indexing past lr_schedule
    if i >= num_epochs:
        break

    # Mini Batch: draw batch_size random row indices (with replacement)
    batch_indices = torch.randint(0, Xtr.shape[0], (batch_size,))
    x_batch, y_batch = Xtr[batch_indices], Ytr[batch_indices]

    # Forward
    logits = model(x_batch)
    loss = F.cross_entropy(logits, y_batch)

    # Backward
    for p in params:
        p.grad = None
    loss.backward()

    # Update: plain gradient step with the learning rate scheduled for step i
    lr = lr_schedule[i]
    for p in params:
        p.data += -lr * p.grad

    # Stats
    if i % 10_000 == 0:
        time_taken = time.time() - time_start
        time_start = time.time()
        print(f"{time_taken:.2f}  {i}  {loss.item()}")
    iters.append(i)
    losses.append(loss.item())
    with torch.no_grad():
        # log10 of (lr * gradient std / parameter std) for every parameter
        ud.append([(lr*p.grad.std()/p.data.std()).log10().item() for p in params])

    # Advance first, then Break at the chunk boundary. Testing before the
    # increment ended the loop on its very first pass (0 % 10_000 == 0)
    # without ever advancing `i`, so no run got beyond step 0.
    i += 1
    if i % steps_per_run == 0:
        break

print(i, loss.item())

In [None]:
# Plot the training loss averaged over windows of `smooth_window` steps.
smooth_window = 1000
loss_history = torch.tensor(losses)
n_windowed = (len(losses) // smooth_window) * smooth_window
if n_windowed:
    # Drop the incomplete tail so the history reshapes into full windows;
    # view(-1, 1000) raises when the length is not a multiple of 1000.
    plt.plot(loss_history[:n_windowed].view(-1, smooth_window).mean(dim=-1))
else:
    # Fewer steps than one window: show the raw per-step losses instead.
    plt.plot(loss_history)
plt.show()

In [None]:
@torch.no_grad()
def evaluate(x_batch, y_batch):
    """Return the cross-entropy loss of the global ``model`` as a float.

    The model is switched to eval mode first and is left in that mode.

    Args:
        x_batch: inputs passed to ``model``.
        y_batch: targets passed to ``F.cross_entropy`` with the model output.
    """
    model.eval()
    return F.cross_entropy(model(x_batch), y_batch).item()

# Loss on the training split, then on the validation split.
train_loss = evaluate(Xtr, Ytr)
print("train = ", train_loss)    # ~2.12
val_loss = evaluate(Xval, Yval)
print("eval =  ", val_loss)  # ~2.15

In [None]:
@torch.no_grad()
def sample_name():
    """Sample one name from the global ``model``, one character at a time.

    Generation starts from a context of ``block_size`` '.' tokens. On every
    step the last ``block_size`` tokens are fed to the model, the next token
    is drawn from the softmax of the logits, and sampling ends as soon as the
    '.' token is drawn. The model is switched to eval mode first.

    Returns:
        The generated characters as a string, including the final '.'.
    """
    model.eval()
    context = tok.encode('.'*block_size)
    # Look the stop id up instead of assuming '.' is encoded as 0.
    stop_token = tok.encode('.')[0]
    while True:
        # The window has to follow block_size; it was hard-coded to 3, which
        # feeds the model a wrong-sized input for any other context length.
        x = torch.tensor(context[-block_size:]).view(1, -1)   # n_batch=1, n_seq
        # Forward Pass
        logits = model(x)
        probs = torch.softmax(logits, dim=1)
        # Sample
        sample = torch.multinomial(probs, 1).item()
        context.append(sample)
        # Break
        if sample == stop_token:
            break

    # Strip the leading start padding before returning.
    return tok.decode(context)[block_size:]

In [None]:
# Fix the seed so the same names are sampled on every run.
torch.manual_seed(42)

# Use `_` as the loop variable: the global `i` is the training step counter,
# and looping with `i` here would overwrite it (leaving it at 9).
for _ in range(10):
    print(sample_name())