In [None]:
import random
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

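The model in this notebook is a simplified, character-level version of the [WaveNet](https://arxiv.org/abs/1609.03499) architecture. Instead of squashing the whole context into one hidden layer at once, it fuses pairs of neighbouring positions hierarchically, halving the sequence length at each of three levels (8 → 4 → 2 → 1).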

### Helper Functions

In [None]:
with open('names.txt', 'r', encoding='utf-8') as f:
    # splitlines() + filtering drops the empty string that a trailing newline would
    # otherwise leave at the end of the list (it would become a zero-length "name")
    text = [line for line in f.read().splitlines() if line]

# Unique characters in the text; '.' is reserved at index 0 as the start/end-of-word token,
# so it is excluded from the data characters to avoid a duplicate vocabulary entry
chars = ['.'] + sorted(set(''.join(text)) - {'.'})
vocab_size = len(chars)

# Mapping from characters to integers and vice versa
char_to_int = {c: i for i, c in enumerate(chars)}
int_to_char = {i: c for i, c in enumerate(chars)}

# Shape notation: B - batch size, T - sequence length (block size at the input),
# C - channels/features (the embedding dimension right after the Embedding layer)

In [None]:
block_size = 8 # Context length: number of preceding characters used to predict the next one

def build_dataset(text: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn a list of words into (context, next-character) training pairs.

    Every word is terminated with '.', and each character position yields one example.
    Its context is the `block_size` preceding character indices, left-padded with 0
    (the index of '.').

    Returns:
        X: tensor of shape (N, block_size) holding the context indices.
        Y: tensor of shape (N,) holding the target character indices.
    """
    contexts, targets = [], []
    for word in text:
        window = [0] * block_size # Start each word from an all-'.' context
        for ix in (char_to_int[ch] for ch in word + '.'):
            contexts.append(window)
            targets.append(ix)
            window = window[1:] + [ix] # Slide the window one character to the right
    return torch.tensor(contexts), torch.tensor(targets)

# Shuffle a copy with a fixed seed: the train/val split is reproducible, and re-running
# this cell does not reshuffle (and therefore re-split) already-shuffled data
words = list(text)
random.Random(42).shuffle(words)
split = int(len(words) * 0.9) # 90% train, 10% val

x_train, y_train = build_dataset(words[:split])
x_val, y_val = build_dataset(words[split:])

### WaveNet Model

In [None]:
class Linear:
    """Fully connected layer computing ``x @ weights`` plus an optional bias.

    Weights are drawn from a standard normal and divided by sqrt(n_in) (fan-in scaling,
    i.e. Kaiming initialisation with unit gain). The bias, when enabled, starts at zero.
    """

    def __init__(self, n_in: int, n_out: int, bias: bool = True):
        self.weights = torch.randn((n_in, n_out)) / n_in ** 0.5
        self.bias = None
        if bias:
            self.bias = torch.zeros(n_out)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        result = x @ self.weights
        if self.bias is not None:
            result += self.bias
        self.out = result # Kept for layer-by-layer inspection
        return result

    def parameters(self) -> list[torch.Tensor]:
        params = [self.weights]
        if self.bias is not None:
            params.append(self.bias)
        return params


class BatchNorm1D:
    """Batch normalisation over every dimension except the last (feature) dimension.

    In training mode, each forward pass normalises with the current batch's mean and
    (unbiased) variance. It then folds those statistics into the running ``mean``/``var``
    buffers as an exponential moving average. In evaluation mode (``training = False``),
    the running buffers are used instead.
    """

    def __init__(self, n_dims: int, epsilon: float = 1e-5, momentum: float = 1e-1):
        self.epsilon = epsilon # Small value to avoid division by zero
        self.momentum = momentum # Update rate for the running mean and variance
        self.training = True
        # Parameters
        self.gamma = torch.ones(n_dims) # Scaling for normalised activations
        self.beta = torch.zeros(n_dims) # Offset for normalised activations
        # Buffers: always kept at shape (n_dims,), whatever rank the input has
        self.mean = torch.zeros(n_dims)
        self.var = torch.ones(n_dims)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` whose last dimension holds the features.

        Raises:
            ValueError: in training mode, if ``x`` has fewer than 2 dimensions
                (there is no batch dimension to compute statistics over).
        """
        if self.training:
            if x.ndim < 2:
                raise ValueError(
                    f'BatchNorm1D needs at least 2 dimensions in training mode, got {x.ndim}')
            dims = tuple(range(x.ndim - 1)) # (B,C) -> (0,); (B,T,C) -> (0, 1)
            x_mean = x.mean(dims, keepdim=True) # Batch mean
            x_var = x.var(dims, keepdim=True) # Batch variance (unbiased)
        else:
            x_mean = self.mean
            x_var = self.var
        x_norm = (x - x_mean) / torch.sqrt(x_var + self.epsilon) # Normalise to unit variance
        self.out = self.gamma * x_norm + self.beta
        # Update buffers
        if self.training:
            with torch.no_grad():
                # Flatten the keepdim statistics so the buffers stay (n_dims,). Otherwise a layer
                # trained on (B,T,C) input would broadcast (B,C) eval input to (1,B,C).
                self.mean = (1 - self.momentum) * self.mean + self.momentum * x_mean.reshape(-1)
                self.var = (1 - self.momentum) * self.var + self.momentum * x_var.reshape(-1)
        return self.out

    def parameters(self) -> list[torch.Tensor]:
        return [self.gamma, self.beta]

class Tanh:
    """Element-wise hyperbolic tangent activation; it has no trainable parameters."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        result = x.tanh()
        self.out = result # Kept for layer-by-layer inspection
        return result

    def parameters(self) -> list[torch.Tensor]:
        return list()

class Embedding:
    """Lookup table mapping integer indices to learned vectors.

    Args:
        n_embd: number of rows in the table (distinct indices that can be embedded).
        n_dims: length of each embedding vector.
    """

    def __init__(self, n_embd: int, n_dims: int):
        self.weights = torch.randn((n_embd, n_dims))

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        embedded = self.weights[x] # Row lookup: (B,T) indices -> (B,T,C) vectors
        self.out = embedded
        return embedded

    def parameters(self) -> list[torch.Tensor]:
        return [self.weights]
    
class FlattenConsecutive:
    """Concatenates each run of ``n`` consecutive time steps into one feature vector.

    Maps (B, T, C) -> (B, T // n, C * n). If that leaves a single time step, the time
    dimension is dropped, giving (B, C * n).
    """

    def __init__(self, n: int):
        self.n = n

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten groups of ``n`` time steps of a 3-D input.

        Raises:
            ValueError: if the time dimension is not divisible by ``n``.
        """
        batch, steps, channels = x.size(0), x.size(1), x.size(2)
        if steps % self.n != 0:
            raise ValueError(
                f'FlattenConsecutive({self.n}) needs a time dimension divisible by {self.n}, got {steps}')
        # reshape (unlike view) also accepts non-contiguous inputs, copying only if needed
        x = x.reshape(batch, steps // self.n, channels * self.n) # (B,T,C) -> (B,T//n,C*n)
        if x.size(1) == 1: # Remove T dimension if equal to 1
            x = x.squeeze(1)
        self.out = x
        return self.out

    def parameters(self) -> list[torch.Tensor]:
        return []

class Sequential:
    """Container that chains layers, feeding each layer's output into the next one."""

    def __init__(self, layers: list):
        self.layers = layers

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for layer in self.layers:
            out = layer(out)
        self.out = out
        return out

    def parameters(self) -> list[torch.Tensor]:
        """All layers' parameters, in layer order, as one flat list."""
        params = []
        for layer in self.layers:
            params.extend(layer.parameters())
        return params

In [None]:
torch.manual_seed(42) # Reproducible initialisation (and, downstream, mini-batches and samples)

n_embd = 24 # Embedding dimension
n_hidden = 128 # Neurons in the hidden layer

# Each FlattenConsecutive(2) stage halves the time dimension, so the three stages consume
# exactly 2 ** 3 = 8 positions; any other block_size breaks the layer shapes below.
assert block_size == 2 ** 3, f'This architecture needs block_size == 8, got {block_size}'

# WaveNet. The hidden Linear layers skip their bias because the BatchNorm1D that follows
# subtracts the batch mean and adds its own offset (beta).
model = Sequential([
    Embedding(vocab_size, n_embd),
    FlattenConsecutive(2), Linear(n_embd * 2, n_hidden, bias=False), BatchNorm1D(n_hidden), Tanh(),
    FlattenConsecutive(2), Linear(n_hidden * 2, n_hidden, bias=False), BatchNorm1D(n_hidden), Tanh(),
    FlattenConsecutive(2), Linear(n_hidden * 2, n_hidden, bias=False), BatchNorm1D(n_hidden), Tanh(),
    Linear(n_hidden, vocab_size) # Output layer keeps its bias: no batch norm follows it
])

# Initialisations
with torch.no_grad():
    model.layers[-1].weights *= 0.1 # Ensure the output layer is not too confident initially

for param in model.parameters():
    param.requires_grad = True

print(f'Model parameters: {sum(param.nelement() for param in model.parameters())}')

In [None]:
@torch.no_grad()
def split_loss(split: str) -> float:
    """Return the mean cross-entropy loss of `model` over a whole data split.

    Args:
        split: 'train' or 'val'.

    Returns:
        The loss as a Python float.

    Raises:
        ValueError: if `split` is neither 'train' nor 'val'.

    The model is run in whatever mode it is currently in. Switch the batch-norm layers
    to evaluation mode (training = False) first; otherwise this forward pass uses the
    split's own statistics and also updates the running buffers.
    """
    datasets = {'train': (x_train, y_train), 'val': (x_val, y_val)}
    if split not in datasets:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    x, y = datasets[split]
    # Forward pass
    logits = model(x)
    return F.cross_entropy(logits, y).item()

In [None]:
# Forward a batch of 4 examples for layer-by-layer inspection. No gradients are needed,
# so skip building the autograd graph. While the model is in training mode, this pass
# still updates the batch-norm running statistics.
with torch.no_grad():
    ix = torch.randint(0, x_train.shape[0], (4, ))
    xb, yb = x_train[ix], y_train[ix]
    logits = model(xb)

for layer in model.layers:
    print(f'{layer.__class__.__name__}: {tuple(layer.out.size())}')

In [None]:
init_lr = 1e-1 # Initial learning rate
final_lr = 1e-2 # Final learning rate
lr_switch = 3 / 4 # Fraction of iterations after which to switch to the final learning rate
batch_size = 32 # Number of samples per batch
max_iters = 100000
log_every = max(1, max_iters // 10) # Print the mini-batch loss 10 times per run
plot_window = max(1, max_iters // 100) # Iterations averaged into each point of the loss plot

# Re-enable batch statistics: the end of this cell switches the model to evaluation mode,
# so without this a re-run would silently train with frozen running statistics.
for layer in model.layers:
    layer.training = True

# Mini-batch gradient descent
losses = []
for i in range(max_iters):

    ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batches
    xb, yb = x_train[ix], y_train[ix]

    # Forward pass
    logits = model(xb)
    loss = F.cross_entropy(logits, yb) # Cross entropy loss
    losses.append(loss.item())

    # Backward pass
    for param in model.parameters():
        param.grad = None
    loss.backward()

    # Update the parameters (a single step decay from init_lr to final_lr)
    lr = init_lr if i < (max_iters * lr_switch) else final_lr
    with torch.no_grad():
        for param in model.parameters():
            param -= lr * param.grad

    if i % log_every == 0 or i == max_iters - 1:
        print(f'Iteration {i:6d} | Loss (Mini-Batch): {loss.item():.4f}')

# Plot the loss averaged over consecutive windows of `plot_window` iterations
# (100 points in total); trailing iterations that do not fill a whole window are dropped
n_plotted = len(losses) // plot_window * plot_window
plt.plot(torch.tensor(losses[:n_plotted]).view(-1, plot_window).mean(dim=1))
plt.title('Training loss')
plt.xlabel(f'Iteration (x{plot_window})')
plt.ylabel('Mean mini-batch loss')
plt.show()

for layer in model.layers: # Batch norm uses its running statistics during evaluation
    layer.training = False
print(f'Train loss: {split_loss("train"):.4f} | Val loss: {split_loss("val"):.4f}')


In [None]:
# Sample from the model. Batch norm must use its running statistics here: with a single
# context, the batch statistics are degenerate (the unbiased variance of one sample is NaN).
for layer in model.layers:
    layer.training = False

with torch.no_grad(): # Inference only; no autograd graph needed
    for _ in range(5):
        out = []
        context = [0] * block_size # Initialise context to all '.' characters

        while True:
            # Forward pass
            logits = model(torch.tensor([context]))
            probs = F.softmax(logits, dim=1)
            # Sample the next character from the predicted distribution
            ix = torch.multinomial(probs, num_samples=1, replacement=True).item()
            context = context[1:] + [ix] # Shift the context window
            out.append(int_to_char[ix])
            if ix == 0:
                break # End of word
        print(''.join(out))