In [None]:
import torch
import random

'''
The MLP implemented in this notebook is based on the design from
the paper 'A Neural Probabilistic Language Model' (Bengio et al. 2003)
https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf

Kaiming initialisation is implemented based on 'Delving Deep into Rectifiers'
(He et al. 2015) https://arxiv.org/pdf/1502.01852.pdf

Batch normalisation is implemented based on the paper 'Batch
Normalization: Accelerating Deep Network Training by Reducing Internal
Covariate Shift' (Ioffe & Szegedy 2015) https://arxiv.org/pdf/1502.03167.pdf
'''

In [None]:
# Load the dataset: one name per line. `with` guarantees the file handle is closed.
with open('data/names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()

# Vocabulary: '.' (index 0) is used both as start-of-context padding and as the
# end-of-name token, followed by every distinct character in the dataset.
dataset_chars = set(''.join(words))
# If '.' occurred in the data it would be listed twice and `stoi['.']` would silently
# point at the second copy, breaking the start/end-token convention.
assert '.' not in dataset_chars, "'.' is reserved as the start/end token"
chars = ['.'] + sorted(dataset_chars)
stoi = {s: i for i, s in enumerate(chars)}
itos = {i: s for s, i in stoi.items()}
num_chars = len(chars)

In [None]:
# Hyperparameters
seed = 42 # Seed for Python's `random` (dataset shuffle) and torch (init, batches, sampling).
block_size = 3 # Context length (characters).
embedding_dims = 10 # Number of dimensions for the embedding space.
batch_size = 32 # Number of examples to process at a time in training.
hidden_layer_size = 200 # Number of neurons in the hidden layer.
init_lr = 0.1 # Learning rate for the first half of training.
final_lr = 0.01 # Learning rate for the second half of training.
max_steps = 200000 # Total number of gradient descent steps.

# Seed both RNGs before anything random happens so that Restart & Run All
# reproduces the same data split, initial weights, batches and samples.
random.seed(seed)
torch.manual_seed(seed)

def build_dataset(
    words: list[str],
    context_size: 'int | None' = None,
    char_to_idx: 'dict[str, int] | None' = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    '''
    Builds (context, next character) training pairs from a list of words.

    Every word is followed by the terminator '.', and each word's context starts
    as `context_size` copies of index 0 (the '.' padding in `stoi`), so the model
    learns both how names start and where they end.

    Args:
        words: Words to convert into examples.
        context_size: Number of preceding characters per example. Defaults to
            the global `block_size`.
        char_to_idx: Character-to-index mapping; must contain '.'. Defaults to
            the global `stoi`.

    Returns:
        X: int64 tensor of shape (num_examples, context_size) with the context indices.
        Y: int64 tensor of shape (num_examples,) with the index of the character
           that follows each context.

    Raises:
        KeyError: If a word contains a character missing from `char_to_idx`.
    '''
    if context_size is None:
        context_size = block_size
    if char_to_idx is None:
        char_to_idx = stoi

    X, Y = [], []

    for word in words:
        context = [0] * context_size # Padding the context with initial '.' characters.
        for char in word + '.':
            idx = char_to_idx[char]
            X.append(context)
            Y.append(idx)
            context = context[1:] + [idx] # Slide the window one character to the right.

    # Explicit dtype and shape so that an empty word list still yields int64 tensors
    # of shape (0, context_size) and (0,), instead of a 1-D float tensor.
    return (
        torch.tensor(X, dtype=torch.long).reshape(len(X), context_size),
        torch.tensor(Y, dtype=torch.long),
    )

# Shuffle a copy so the `words` list loaded earlier is not mutated and this cell
# can be re-run without depending on how many times it ran before.
shuffled_words = list(words)
random.shuffle(shuffled_words)
n1 = int(0.8 * len(shuffled_words)) # End of the training split (first 80%).
n2 = int(0.9 * len(shuffled_words)) # End of the validation split (next 10%); the rest is test.

x_train, y_train = build_dataset(shuffled_words[:n1])
x_val, y_val = build_dataset(shuffled_words[n1:n2])
x_test, y_test = build_dataset(shuffled_words[n2:])

# Embedding matrix: one `embedding_dims`-dimensional vector per character.
C = torch.randn((num_chars, embedding_dims))
# Hidden layer: gain 5/3 (the gain recommended for tanh) divided by sqrt(fan_in).
# There is no hidden bias: batch normalisation subtracts the batch mean, which would
# cancel it, and `bn_bias` provides the shift instead.
kaiming_init = (5/3) / (block_size * embedding_dims)**0.5
W1 = torch.randn((block_size * embedding_dims, hidden_layer_size)) * kaiming_init
# Output layer, scaled down so the initial logits are near zero (close to uniform softmax).
W2 = torch.randn((hidden_layer_size, num_chars)) * 0.01
b2 = torch.zeros(num_chars)

# Batch normalisation params: learnable scale and shift.
bn_gain = torch.ones((1, hidden_layer_size))
bn_bias = torch.zeros((1, hidden_layer_size))
# Running mean/std of the hidden pre-activations, updated in the training loop and
# used at evaluation/sampling time. They are not trained by gradient descent.
bn_mean_live = torch.zeros((1, hidden_layer_size))
bn_std_live = torch.ones((1, hidden_layer_size))

params = [C, W1, W2, b2, bn_gain, bn_bias]
for param in params:
    param.requires_grad_()

print(f'Number of parameters: {sum(param.nelement() for param in params)}')

In [None]:
# Gradient descent.
bn_momentum = 0.001 # Weight of the current batch in the running batch-norm statistics.
bn_eps = 1e-5 # Added to the batch variance so a constant feature cannot cause a division by zero.
log_every = 10000 # Print the training loss every this many steps.

for i in range(max_steps):

    # Constructing batches.
    idx = torch.randint(0, x_train.shape[0], (batch_size, ))
    x_batch, y_batch = x_train[idx], y_train[idx]

    # Forward pass.
    embedding = C[x_batch] # Embed characters into vectors.
    h_pre_act = embedding.view(embedding.size(0), -1) @ W1 # Hidden layer pre-activation.
    # Batch normalisation (normalise with the batch statistics, then scale and shift).
    bn_mean_i = h_pre_act.mean(0, keepdim=True)
    bn_std_i = (h_pre_act.var(0, keepdim=True) + bn_eps).sqrt()
    h_pre_act = bn_gain * (h_pre_act - bn_mean_i) / bn_std_i + bn_bias
    # Exponential moving averages of the batch statistics, used at inference time.
    # The running std includes eps, so evaluation code can divide by it directly.
    with torch.no_grad():
        bn_mean_live = (1 - bn_momentum) * bn_mean_live + bn_momentum * bn_mean_i
        bn_std_live = (1 - bn_momentum) * bn_std_live + bn_momentum * bn_std_i
    # Non-linearity.
    h = torch.tanh(h_pre_act)
    logits = h @ W2 + b2
    # Calculate the cross entropy loss.
    loss = torch.nn.functional.cross_entropy(logits, y_batch)

    # Backward pass.
    for param in params:
        param.grad = None # Set the gradient to zero.
    loss.backward()

    # Stochastic gradient descent update; the learning rate steps down halfway through.
    lr = init_lr if i < (max_steps / 2) else final_lr
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad

    if i % log_every == 0:
        print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')

In [None]:
@torch.no_grad()
def split_loss(split: str) -> float:
    '''
    Evaluates the model on the chosen split, prints and returns the mean cross-entropy loss.

    Batch normalisation uses the running statistics (`bn_mean_live`, `bn_std_live`)
    accumulated during training rather than statistics of the evaluated data.
    Runs without autograd, so no graph is built for the full-split forward pass.

    Args:
        split: One of 'train', 'val' or 'test'.

    Returns:
        The loss as a Python float.

    Raises:
        KeyError: If `split` is not one of the names above.
    '''

    x, y = {
        'train': (x_train, y_train),
        'val': (x_val, y_val),
        'test': (x_test, y_test)
    }[split]

    # Forward pass.
    embedding = C[x]
    h_pre_act = embedding.view(embedding.size(0), -1) @ W1
    # Batch normalisation with the running (inference-time) statistics.
    h_pre_act = bn_gain * (h_pre_act - bn_mean_live) / bn_std_live + bn_bias
    h = torch.tanh(h_pre_act)
    logits = h @ W2 + b2
    loss = torch.nn.functional.cross_entropy(logits, y).item()
    print(f'{split.capitalize()} Loss: {loss}')
    return loss

In [None]:
# Report the loss on the held-out splits, test first, then validation.
for split_name in ('test', 'val'):
    split_loss(split_name)

In [None]:
# Sample from the model.
num_samples = 5 # Number of names to generate.
max_name_len = 50 # Safety cap on generated characters in case '.' is never sampled.

with torch.no_grad(): # Inference only: no autograd graph is needed.
    for _ in range(num_samples):
        out = []
        context = [0] * block_size # Initialise context to '.' padding.

        for _step in range(max_name_len):
            # Forward pass.
            embedding = C[torch.tensor([context])]
            h_pre_act = embedding.view(embedding.size(0), -1) @ W1
            # Batch normalisation with the running (inference-time) statistics.
            h_pre_act = bn_gain * (h_pre_act - bn_mean_live) / bn_std_live + bn_bias
            h = torch.tanh(h_pre_act)
            logits = h @ W2 + b2
            probs = torch.nn.functional.softmax(logits, dim=1)

            # Sample from the distribution.
            idx = torch.multinomial(probs, num_samples=1, replacement=True).item()
            context = context[1:] + [idx] # Shift the context window.
            out.append(itos[idx])
            if idx == 0: # If we sample '.', stop.
                break
        print(''.join(out))