In [None]:
import random
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

### Helper Functions

In [None]:
with open('names.txt', 'r', encoding='utf-8') as f:
    # splitlines() (unlike split('\n')) neither yields a trailing '' when the file ends
    # with a newline nor leaves '\r' in names from CRLF files; blank lines are dropped
    # so they do not become empty "names" in the dataset.
    text = [line for line in f.read().splitlines() if line]

# Vocabulary: '.' (start/end-of-word token, index 0) followed by the sorted unique characters
chars = ['.'] + sorted(set(''.join(text)))
vocab_size = len(chars)

# Mapping from characters to integers and vice versa
char_to_int = {c: i for i, c in enumerate(chars)}
int_to_char = {i: c for i, c in enumerate(chars)}

# B - batch size, T - block size (context length), C - embedding dimension (n_embd)

In [None]:
block_size = 3 # Context length for predictions

def build_dataset(
    text: list[str],
    context_length: 'int | None' = None,
    stoi: 'dict[str, int] | None' = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build (context, next-character) training pairs from a list of words.

    Each word is terminated with '.', and the context is left-padded with index 0
    (the '.' token), so the first character of a word is predicted from an all-'.'
    context.

    Args:
        text: Words to convert.
        context_length: Number of previous characters per example. Defaults to the
            global ``block_size`` (read at call time).
        stoi: Character-to-index mapping. Defaults to the global ``char_to_int``.

    Returns:
        X: int64 tensor of shape (N, context_length) holding the context indices.
        Y: int64 tensor of shape (N,) holding the index of the next character.
    """
    if context_length is None:
        context_length = block_size
    if stoi is None:
        stoi = char_to_int
    X, Y = [], []
    for word in text:
        context = [0] * context_length # Pad the context with initial '.' characters
        for char in word + '.':
            ix = stoi[char]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] # Slide the window by one character
    # Explicit dtype and shape so that an empty word list still yields a
    # (0, context_length) int64 tensor instead of a 1-D float tensor.
    return (
        torch.tensor(X, dtype=torch.long).reshape(len(X), context_length),
        torch.tensor(Y, dtype=torch.long),
    )

# Seed both RNGs so the train/val split and every later torch.randn/randint call are reproducible
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Shuffle a copy of the word list so that re-running this cell always yields the same split
words = random.sample(text, k=len(text))
split = int(len(words) * 0.9) # 90% train, 10% val

x_train, y_train = build_dataset(words[:split])
x_val, y_val = build_dataset(words[split:])

### Building the MLP

In [None]:
n_embd = 2 # Embedding dimension (2 so that the learned embeddings can be plotted)
n_hidden = 200 # Neurons in the hidden layer

# Embedding lookup table: (B,T) indices -> (B,T,C) vectors
C = torch.randn((vocab_size, n_embd))
# Hidden layer: (B, T*C) -> (B, n_hidden)
W_hidden = torch.randn((block_size * n_embd, n_hidden))
b_hidden = torch.randn(n_hidden)
# Output layer: (B, n_hidden) -> (B, vocab_size) logits
W_out = torch.randn((n_hidden, vocab_size))
b_out = torch.randn(vocab_size)

params = [C, W_hidden, b_hidden, W_out, b_out]
for p in params:
    p.requires_grad_(True) # Every parameter becomes a trainable leaf tensor

In [None]:
@torch.no_grad()
def split_loss(split: str) -> float:
    """Return the cross-entropy loss of the MLP over an entire data split.

    Args:
        split: 'train' or 'val'.

    Returns:
        The mean cross-entropy loss as a Python float.

    Raises:
        ValueError: If ``split`` is not 'train' or 'val'.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    x, y = (x_train, y_train) if split == 'train' else (x_val, y_val)
    # Forward pass
    emb = C[x] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1) # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden + b_hidden
    h = torch.tanh(hpreact)
    logits = h @ W_out + b_out
    loss = F.cross_entropy(logits, y)
    return loss.item()

### Finding a Suitable Learning Rate

In [None]:
batch_size = 32 # Number of samples per batch
max_iters = 1000

# Sweep the learning rate exponentially: the exponent rises linearly from -3 to 0,
# so the learning rate rises from 1e-3 to 1e0 (one value per iteration).
lrs_exp = torch.linspace(-3, 0, max_iters)
lrs = 10 ** lrs_exp

# Mini-batch gradient descent
losses = []
for i in range(max_iters):

    ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
    xb, yb = x_train[ix], y_train[ix]

    # Forward pass
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden + b_hidden
    h = torch.tanh(hpreact)
    logits = h @ W_out + b_out

    # Cross entropy loss
    loss = F.cross_entropy(logits, yb)
    losses.append(loss.item())

    # Backward pass
    for param in params:
        param.grad = None # Set the gradient to zero
    loss.backward()

    # SGD update: modify the leaf tensors in place without recording it in the autograd graph
    with torch.no_grad():
        for param in params:
            param -= lrs[i] * param.grad

# Plot the mini-batch loss vs. the learning-rate exponent
plt.plot(lrs_exp, losses)
plt.xlabel('log10(learning rate)')
plt.ylabel('Mini-batch loss');

The plot suggests that the best learning rate is roughly $10^{-1}$ (an exponent of about $-1$). This value is used from here on.

### Visualise the Learned Embedding Matrix $C$

In [None]:
lr = 1e-1 # Learning rate chosen from the sweep above
max_iters = 100000
log_every = max(1, max_iters // 10) # Guard against a zero modulus when max_iters < 10

# Mini-batch gradient descent
losses = []
for i in range(max_iters):

    ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
    xb, yb = x_train[ix], y_train[ix]

    # Forward pass
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden + b_hidden
    h = torch.tanh(hpreact)
    logits = h @ W_out + b_out

    # Cross entropy loss
    loss = F.cross_entropy(logits, yb)
    losses.append(loss.item())

    # Backward pass
    for param in params:
        param.grad = None # Set the gradient to zero
    loss.backward()

    # SGD update: modify the leaf tensors in place without recording it in the autograd graph
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad

    if i % log_every == 0 or i == max_iters - 1:
        print(f'Iteration {i:6d} | Loss (Mini-Batch): {loss.item():.4f}')

print(f'Train loss: {split_loss("train"):.4f} | Val loss: {split_loss("val"):.4f}')

In [None]:
# Visualise the embedding matrix C for all characters (first two embedding dimensions;
# n_embd is 2 at this point in the notebook)
emb_2d = C.detach() # Plotting only; no gradient tracking needed
plt.figure(figsize=(8,8))
plt.scatter(emb_2d[:, 0], emb_2d[:, 1], s=200)
for i in range(emb_2d.shape[0]):
    plt.text(emb_2d[i, 0].item(), emb_2d[i, 1].item(), int_to_char[i], ha='center', va='center', color='white')
# The first positional argument of plt.grid is `visible`, not `which`, so turn the grid on explicitly
plt.grid(True)

The model has learned a basic clustering of the characters: for example, the vowels tend to be grouped together with similar embedding vectors. The embedding dimension was set to 2 purely so that the embeddings could be plotted. From now on, a larger embedding dimension is used to improve the performance of the model.

### Analysis of Initial Loss

In [None]:
# Redefine the model with a larger embedding dimension of 10
n_embd, n_hidden = 10, 200 # Embedding dimension, neurons in the hidden layer

# Embedding lookup table: (B,T) indices -> (B,T,C) vectors
C = torch.randn((vocab_size, n_embd))
# Hidden layer weights and bias
W_hidden = torch.randn((block_size * n_embd, n_hidden))
b_hidden = torch.randn(n_hidden)
# Output layer weights and bias
W_out = torch.randn((n_hidden, vocab_size))
b_out = torch.randn(vocab_size)

params = [C, W_hidden, b_hidden, W_out, b_out]
for p in params:
    p.requires_grad_(True) # Every parameter becomes a trainable leaf tensor

In [None]:
# Mini-batch gradient descent
log_every = max(1, max_iters // 10) # Guard against a zero modulus when max_iters < 10
losses = []
for i in range(max_iters):

    ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
    xb, yb = x_train[ix], y_train[ix]

    # Forward pass
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden + b_hidden
    h = torch.tanh(hpreact)
    logits = h @ W_out + b_out

    # Cross entropy loss
    loss = F.cross_entropy(logits, yb)
    losses.append(loss.item())

    # Backward pass
    for param in params:
        param.grad = None # Set the gradient to zero
    loss.backward()

    # SGD update: modify the leaf tensors in place without recording it in the autograd graph
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad

    if i % log_every == 0 or i == max_iters - 1:
        print(f'Iteration {i:6d} | Loss (Mini-Batch): {loss.item():.4f}')

# Plot the mini-batch loss
plt.plot(losses)

print(f'Train loss: {split_loss("train"):.4f} | Val loss: {split_loss("val"):.4f}')

### Fix: Large Initial Loss

In [None]:
C = torch.randn((vocab_size, n_embd)) # Embedding table (B,T) -> (B,T,C)
# Hidden layer
W_hidden = torch.randn((block_size * n_embd, n_hidden))
b_hidden = torch.randn(n_hidden)
# Output layer: small weights and zero bias keep the initial logits near zero
W_out = torch.randn((n_hidden, vocab_size)) * 0.01 # Initialised to small values
b_out = torch.zeros(vocab_size) # Initialised to zero

params = [C, W_hidden, b_hidden, W_out, b_out]
for param in params:
    param.requires_grad = True

Initialising the output-layer weights to small random values (and its bias to zero) makes the initial logits close to zero. The softmax therefore starts out close to uniform, and the network is not confidently wrong about its predictions. The initial loss is then close to the loss of a uniform guess, $-\ln(1/V) = \ln V$ where $V$ is the vocabulary size, rather than being very large. As a result, the network does not spend its first training iterations squashing down large logits. It can instead spend more of its training budget actually reducing the loss.

In [None]:
# Mini-batch gradient descent
log_every = max(1, max_iters // 10) # Guard against a zero modulus when max_iters < 10
losses = []
for i in range(max_iters):

    ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
    xb, yb = x_train[ix], y_train[ix]

    # Forward pass
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden + b_hidden
    h = torch.tanh(hpreact)
    logits = h @ W_out + b_out

    # Cross entropy loss
    loss = F.cross_entropy(logits, yb)
    losses.append(loss.item())

    # Backward pass
    for param in params:
        param.grad = None # Set the gradient to zero
    loss.backward()

    # SGD update: modify the leaf tensors in place without recording it in the autograd graph
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad

    if i % log_every == 0 or i == max_iters - 1:
        print(f'Iteration {i:6d} | Loss (Mini-Batch): {loss.item():.4f}')

# Plot the mini-batch loss
plt.plot(losses)

print(f'Train loss: {split_loss("train"):.4f} | Val loss: {split_loss("val"):.4f}')

### Analysis of the Neuron Activations

In [None]:
# Probe the hidden-layer activations on one mini-batch.
# NOTE: this uses the parameters currently in scope, i.e. those left by the previous
# training cell. The histograms below only need a forward pass, so no gradient step
# is taken and `losses` is left untouched.
ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
xb = x_train[ix]

with torch.no_grad():
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden + b_hidden # (B, n_hidden) pre-activations
    h = torch.tanh(hpreact) # (B, n_hidden) activations

In [None]:
# Side-by-side histograms of the hidden layer before and after tanh
fig, axs = plt.subplots(1, 2, figsize=(14, 5))
for ax, values, title in zip(axs, (hpreact, h), ('Pre-Activation', 'Activation')):
    ax.hist(values.view(-1).tolist(), bins=50)
    ax.set_title(title)

The vast majority of activations are close to -1 or 1, i.e. the $\tanh$ is saturated. The local gradient of $\tanh$ is $1 - \tanh^2(x)$, which is close to 0 in this regime. During back-propagation, the gradient flowing through these neurons is therefore almost entirely suppressed, and the weights further back in the network barely get updated.

In [None]:
# Boolean map of saturated units: white where |tanh| > 0.99 (rows: examples, columns: neurons)
saturated = h.abs() > 0.99
plt.figure(figsize=(20, 10))
plt.imshow(saturated, cmap='gray')

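Each white pixel marks an activation with $|\tanh| > 0.99$; rows are the examples in the batch and columns are the 200 hidden neurons. If an entire column is white, that neuron is saturated for every example in the batch. It is then effectively dead: its near-zero local gradient means its incoming weights receive almost no update.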

### Fix: Tanh Layer too Saturated at Initialisation

In [None]:
# Re-initialise the parameters with small random values for the hidden layer

C = torch.randn((vocab_size, n_embd)) # Embedding table (B,T) -> (B,T,C)
# Hidden layer: scaled down to reduce tanh saturation at initialisation
W_hidden = torch.randn((block_size * n_embd, n_hidden)) * 0.2
b_hidden = torch.randn(n_hidden) * 0.01
# Output layer
W_out = torch.randn((n_hidden, vocab_size)) * 0.01 # Initialised to small values
b_out = torch.zeros(vocab_size) # Initialised to zero

params = [C, W_hidden, b_hidden, W_out, b_out]
for param in params:
    param.requires_grad = True

# Probe the hidden-layer activations at initialisation on one mini-batch.
# The diagnostic plots only need a forward pass, so no gradient step is taken
# and no loss is recorded.
ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
xb = x_train[ix]

with torch.no_grad():
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden + b_hidden # (B, n_hidden) pre-activations
    h = torch.tanh(hpreact) # (B, n_hidden) activations

In [None]:
# Histograms of the hidden-layer values before and after tanh
fig, axs = plt.subplots(1, 2, figsize=(14, 5))
panels = {'Pre-Activation': hpreact, 'Activation': h}
for ax, (title, values) in zip(axs, panels.items()):
    ax.hist(values.view(-1).tolist(), bins=50)
    ax.set_title(title)

In [None]:
# White where a unit is saturated (|tanh| > 0.99); rows are examples, columns are neurons
is_saturated = h.abs().gt(0.99)
plt.figure(figsize=(20, 10))
plt.imshow(is_saturated, cmap='gray')

In [None]:
# Mini-batch gradient descent
log_every = max(1, max_iters // 10) # Guard against a zero modulus when max_iters < 10
losses = []
for i in range(max_iters):

    ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
    xb, yb = x_train[ix], y_train[ix]

    # Forward pass
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden + b_hidden
    h = torch.tanh(hpreact)
    logits = h @ W_out + b_out

    # Cross entropy loss
    loss = F.cross_entropy(logits, yb)
    losses.append(loss.item())

    # Backward pass
    for param in params:
        param.grad = None # Set the gradient to zero
    loss.backward()

    # SGD update: modify the leaf tensors in place without recording it in the autograd graph
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad

    if i % log_every == 0 or i == max_iters - 1:
        print(f'Iteration {i:6d} | Loss (Mini-Batch): {loss.item():.4f}')

# Plot the mini-batch loss
plt.plot(losses)

print(f'Train loss: {split_loss("train"):.4f} | Val loss: {split_loss("val"):.4f}')

### Kaiming Initialisation

[Kaiming initialisation](https://arxiv.org/abs/1502.01852) is a method of initialising the weights in a neural network such that the variance of the activations is the same across all layers.

In [None]:
# Example of the need for Kaiming initialisation: with unscaled unit-Gaussian weights,
# the output y has a larger standard deviation than the input x.
x = torch.randn(batch_size, block_size * n_embd)
w = torch.randn(block_size * n_embd, n_hidden)
y = x @ w
fig, axs = plt.subplots(1, 2, figsize=(14, 5))
for ax, (label, t) in zip(axs, (('x', x), ('y', y))):
    ax.hist(t.view(-1).tolist(), bins=50, density=True)
    ax.set_title(f'{label} (mean={t.mean():.2f}, std={t.std():.2f})')

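The standard deviation of `y` has grown after performing `y = x @ w`: each output sums $n_{in}$ products of unit-variance terms, so its variance is roughly $n_{in}$. This deteriorates the performance of the network, because activations should stay roughly unit-scale rather than pushing $\tanh$ into saturation. Kaiming initialisation prevents this by scaling the weights by $\frac{\text{gain}}{\sqrt{n_{in}}}$. For a $\tanh$ nonlinearity the recommended gain is $\frac{5}{3}$, giving a scale of $\frac{5}{3} \cdot \frac{1}{\sqrt{n_{in}}}$. The demonstration below applies no nonlinearity, so it uses a gain of 1.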

In [None]:
# Same experiment with the weights scaled by 1/sqrt(fan_in) (gain 1, no nonlinearity)
fan_in = block_size * n_embd
kaiming_init = 1 / fan_in ** 0.5
x = torch.randn(batch_size, fan_in)
w = torch.randn(fan_in, n_hidden) * kaiming_init
y = x @ w
fig, axs = plt.subplots(1, 2, figsize=(14, 5))
for ax, (label, t) in zip(axs, (('x', x), ('y', y))):
    ax.hist(t.view(-1).tolist(), bins=50, density=True)
    ax.set_title(f'{label} (mean={t.mean():.2f}, std={t.std():.2f})')

In [None]:
C = torch.randn((vocab_size, n_embd)) # Embedding table (B,T) -> (B,T,C)
# Hidden layer: Kaiming scale for tanh, i.e. gain 5/3 divided by sqrt(fan_in)
kaiming_init = (5/3) / (block_size * n_embd) ** 0.5
W_hidden = torch.randn((block_size * n_embd, n_hidden)) * kaiming_init
b_hidden = torch.randn(n_hidden) * 0.01 # Not used by the batch-normalised model below
# Output layer
W_out = torch.randn((n_hidden, vocab_size)) * 0.01 # Initialised to small values
b_out = torch.zeros(vocab_size) # Initialised to zero

### Batch Normalisation

[Batch normalisation](https://arxiv.org/abs/1502.03167) is used to control the distribution of activations in neural networks. It is common to place batch normalisation layers throughout a network, usually right after layers that perform multiplications (linear or convolutional layers). Batch normalisation has learned parameters `bn_gain` and `bn_bias` that control the scale and offset of the normalised activations. It also keeps buffers `bn_mean` and `bn_std`, which are not trained via backpropagation. Instead, they are running (exponential moving average) estimates of the mean and standard deviation of the pre-activations over the training set. These estimates are updated during training and used in place of the batch statistics at evaluation time. Batch normalisation layers prevent the activations from saturating and allow the network to learn more effectively.

In [None]:
# Learnable batch-norm scale and shift, one per hidden neuron (broadcast over the batch)
bn_gain = torch.ones((1, n_hidden))
bn_bias = torch.zeros((1, n_hidden))

# Running estimates of the pre-activation mean and std, updated during training
# and used in place of batch statistics at evaluation time
bn_mean, bn_std = torch.zeros((1, n_hidden)), torch.ones((1, n_hidden))

# b_hidden is deliberately left out: batch normalisation cancels any constant bias
params = [C, W_hidden, W_out, b_out, bn_gain, bn_bias]
for p in params:
    p.requires_grad_(True)

Note that the linear layer no longer needs its bias term `b_hidden`. Batch normalisation subtracts the batch mean from each pre-activation, so any constant bias is cancelled out and has no effect on the output (its gradient is zero). The `bn_bias` parameter of the batch normalisation layer takes over the role of the bias.

In [None]:
# Redefine the split_loss function to include batch normalisation

@torch.no_grad()
def split_loss(split: str) -> float:
    """Return the cross-entropy loss of the batch-normalised MLP over an entire data split.

    The pre-activations are normalised with the running statistics ``bn_mean`` and
    ``bn_std`` collected during training, not with statistics of the evaluated data.

    Args:
        split: 'train' or 'val'.

    Returns:
        The mean cross-entropy loss as a Python float.

    Raises:
        ValueError: If ``split`` is not 'train' or 'val'.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    x, y = (x_train, y_train) if split == 'train' else (x_val, y_val)
    # Forward pass
    emb = C[x] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1) # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden # No bias: batch normalisation would cancel it
    # Batch normalisation using the running statistics
    hpreact = bn_gain * (hpreact - bn_mean) / bn_std + bn_bias
    h = torch.tanh(hpreact)
    logits = h @ W_out + b_out
    loss = F.cross_entropy(logits, y)
    return loss.item()

In [None]:
init_lr = 1e-1 # Learning rate for the first half of training
final_lr = 1e-2 # Learning rate for the second half of training
momentum = 1e-3 # Update rate of the running (moving-average) batch-norm mean and std
log_every = max(1, max_iters // 10) # Guard against a zero modulus when max_iters < 10

# Mini-batch gradient descent
losses = []
for i in range(max_iters):

    ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
    xb, yb = x_train[ix], y_train[ix]

    # Forward pass
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    emb = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    hpreact = emb @ W_hidden # No bias: batch normalisation would cancel it
    # Batch normalisation with the statistics of the current mini-batch
    bn_mean_i = hpreact.mean(0, keepdim=True)
    bn_std_i = hpreact.std(0, keepdim=True)
    hpreact = bn_gain * (hpreact - bn_mean_i) / bn_std_i + bn_bias # Normalise the activations

    # Update the running mean and std on the side of the forward pass
    # (split_loss uses these at evaluation time)
    with torch.no_grad():
        bn_mean = (1 - momentum) * bn_mean + momentum * bn_mean_i
        bn_std = (1 - momentum) * bn_std + momentum * bn_std_i

    h = torch.tanh(hpreact)
    logits = h @ W_out + b_out

    # Cross entropy loss
    loss = F.cross_entropy(logits, yb)
    losses.append(loss.item())

    # Backward pass
    for param in params:
        param.grad = None # Set the gradient to zero
    loss.backward()

    # Step-decay schedule, then an SGD update outside the autograd graph
    lr = init_lr if i < (max_iters / 2) else final_lr
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad

    if i % log_every == 0 or i == max_iters - 1:
        print(f'Iteration {i:6d} | Loss (Mini-Batch): {loss.item():.4f}')

# Plot the mini-batch loss
plt.plot(losses)

print(f'Train loss: {split_loss("train"):.4f} | Val loss: {split_loss("val"):.4f}')

### Full Code

In [None]:
class Linear:
    """Fully connected layer computing ``x @ weights (+ bias)``."""

    def __init__(self, n_in: int, n_out: int, bias: bool = True):
        # Scale by 1/sqrt(fan_in) (Kaiming-style) so unit-variance inputs give roughly
        # unit-variance outputs; no nonlinearity-specific gain is applied here.
        self.weights = torch.randn((n_in, n_out)) / n_in ** 0.5
        self.bias = torch.zeros(n_out) if bias else None

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to an (..., n_in) tensor and cache the result in ``self.out``."""
        self.out = x @ self.weights
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def parameters(self) -> list[torch.Tensor]:
        """Return the trainable tensors: weights, then bias if present."""
        return [self.weights] + ([] if self.bias is None else [self.bias])


class BatchNorm1D:
    """Batch normalisation over dim 0 of a (batch, n_dims) input.

    In training mode the input is normalised with the current batch mean and
    (unbiased) variance, and the running buffers are updated with an exponential
    moving average. With ``training = False`` the running buffers are used instead
    and left unchanged.
    """

    def __init__(self, n_dims: int, epsilon: float = 1e-5, momentum: float = 1e-1):
        self.epsilon = epsilon # Small value to avoid division by zero
        self.momentum = momentum # Update rate for the running mean and variance
        self.training = True
        # Parameters
        self.gamma = torch.ones(n_dims) # Scaling for normalised activations
        self.beta = torch.zeros(n_dims) # Offset for normalised activations
        # Buffers (not returned by parameters(), so not trained); shape stays (n_dims,)
        self.mean = torch.zeros(n_dims)
        self.var = torch.ones(n_dims)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x``, scale/shift it, and cache the result in ``self.out``."""
        if self.training:
            # No keepdim: statistics (and hence the buffers) keep the shape (n_dims,)
            x_mean = x.mean(0) # Batch mean
            x_var = x.var(0) # Batch variance (unbiased)
        else:
            x_mean = self.mean
            x_var = self.var
        x_norm = (x - x_mean) / torch.sqrt(x_var + self.epsilon) # Normalise to unit variance
        self.out = self.gamma * x_norm + self.beta
        # Update buffers outside the autograd graph
        if self.training:
            with torch.no_grad():
                self.mean = (1 - self.momentum) * self.mean + self.momentum * x_mean
                self.var = (1 - self.momentum) * self.var + self.momentum * x_var
        return self.out

    def parameters(self) -> list[torch.Tensor]:
        """Return the trainable tensors: gamma, then beta."""
        return [self.gamma, self.beta]


class Tanh:
    """Tanh activation function."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply tanh element-wise and cache the result in ``self.out``."""
        self.out = torch.tanh(x)
        return self.out

    def parameters(self) -> list[torch.Tensor]:
        """Tanh has no trainable tensors."""
        return []

In [None]:
C = torch.randn((vocab_size, n_embd)) # Embedding matrix

# MLP: five (Linear -> BatchNorm -> Tanh) hidden blocks followed by a Linear -> BatchNorm head.
# Layers are constructed in the same order as a literal list would be, so the RNG draws match.
n_hidden_blocks = 5
layers = []
fan_in = n_embd * block_size
for _ in range(n_hidden_blocks):
    layers += [Linear(fan_in, n_hidden, bias=False), BatchNorm1D(n_hidden), Tanh()]
    fan_in = n_hidden
layers += [Linear(n_hidden, vocab_size, bias=False), BatchNorm1D(vocab_size)]

# Initialisations
with torch.no_grad():
    # Shrink the output BatchNorm gain so the model is not over-confident at the start
    layers[-1].gamma *= 0.1
    # Apply the 5/3 tanh gain to every Linear layer (everything except the final BatchNorm)
    for lin in (l for l in layers[:-1] if isinstance(l, Linear)):
        lin.weights *= 5/3

params = [C] + [p for layer in layers for p in layer.parameters()]
for p in params:
    p.requires_grad = True

print(f'Model parameters: {sum(p.nelement() for p in params)}')

### Diagnostic Plots from 1000 Iterations of Training

In [None]:
n_diag_iters = 1000 # Length of the short diagnostic run
losses = []
updates = [] # log10(update std / parameter std) for each parameter in each iteration

# Train with batch statistics even if the layers were left in evaluation mode
for layer in layers:
    layer.training = True

# First n_diag_iters iterations of mini-batch gradient descent
for i in range(n_diag_iters):

    ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
    xb, yb = x_train[ix], y_train[ix]

    # Forward pass
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    x = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, yb) # Cross entropy loss
    losses.append(loss.item())

    # Backward pass
    for layer in layers:
        layer.out.retain_grad() # Keep gradients of intermediate outputs for the plots below
    for param in params:
        param.grad = None
    loss.backward()

    # Update the parameters (same step-decay schedule as the full training run)
    lr = init_lr if i < (max_iters / 2) else final_lr
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad
        # Size of each update relative to the parameter itself, on a log10 scale
        updates.append([((lr * param.grad).std() / param.std()).log10().item() for param in params])


In [None]:
# Overlay the output distribution of every Tanh layer, with summary statistics in the legend
tanh_layers = [(i, layer) for i, layer in enumerate(layers) if isinstance(layer, Tanh)]
plt.figure(figsize=(20, 4))
legends = []
for i, layer in tanh_layers:
    t = layer.out
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach())
    legends.append(f'Layer {i} ({layer.__class__.__name__}) | Mean: {t.mean():.2f} | Std: {t.std():.2f} | Saturation: {((t.abs() > 0.97).float().mean() * 100):.2f}%')
plt.legend(legends)
plt.title('Activation Distributions')

In [None]:
# Visualise histograms of the gradients with respect to each Tanh layer's output
# (requires layer.out.retain_grad() before the last backward pass)
plt.figure(figsize=(20, 4))
legends = []
for i, layer in enumerate(layers):
    if isinstance(layer, Tanh):
        t = layer.out.grad
        hy, hx = torch.histogram(t, density=True)
        plt.plot(hx[:-1].detach(), hy.detach())
        legends.append(f'Layer {i} ({layer.__class__.__name__}) | Mean: {t.mean():.2e} | Std: {t.std():.2e}')
plt.legend(legends)
plt.title('Gradient Distributions')

In [None]:
# Visualise the update-to-data ratio (log10) for each parameter in each iteration
plt.figure(figsize=(20, 4))
legends = []
for i, param in enumerate(params):
    if param.ndim == 2: # Only plot the weight matrices and the embedding matrix
        plt.plot([step[i] for step in updates])
        legends.append(f'Param {i} {tuple(param.shape)}')
plt.plot([0, len(updates)], [-3, -3], 'k') # Guide line: updates of roughly 1e-3 of the parameter scale
plt.legend(legends)
plt.xlabel('Iteration')
plt.ylabel('log10(update std / param std)')
plt.title('Update Magnitude for Parameters in Each Iteration')

### Evaluation on Final Model

In [None]:
# Redefine the split_loss function

@torch.no_grad()
def split_loss(split: str) -> float:
    """Return the cross-entropy loss of the layered MLP over an entire data split.

    All layers are switched to evaluation mode for the forward pass, so batch norm
    uses (and does not update) its running statistics. Each layer's previous mode
    is restored afterwards.

    Args:
        split: 'train' or 'val'.

    Returns:
        The mean cross-entropy loss as a Python float.

    Raises:
        ValueError: If ``split`` is not 'train' or 'val'.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    x, y = (x_train, y_train) if split == 'train' else (x_val, y_val)
    prev_modes = [getattr(layer, 'training', True) for layer in layers]
    for layer in layers:
        layer.training = False
    try:
        # Forward pass
        emb = C[x] # Embed characters into vectors (B,T) -> (B,T,C)
        out = emb.view(emb.size(0), -1) # Concatenate the embeddings (B,T,C) -> (B,TC)
        for layer in layers:
            out = layer(out)
        loss = F.cross_entropy(out, y)
    finally:
        for layer, mode in zip(layers, prev_modes):
            layer.training = mode
    return loss.item()

In [None]:
C = torch.randn((vocab_size, n_embd)) # Embedding matrix

# MLP: five (Linear -> BatchNorm -> Tanh) hidden blocks followed by a Linear -> BatchNorm head.
# Layers are constructed in the same order as a literal list would be, so the RNG draws match.
n_hidden_blocks = 5
layers = []
fan_in = n_embd * block_size
for _ in range(n_hidden_blocks):
    layers += [Linear(fan_in, n_hidden, bias=False), BatchNorm1D(n_hidden), Tanh()]
    fan_in = n_hidden
layers += [Linear(n_hidden, vocab_size, bias=False), BatchNorm1D(vocab_size)]

# Initialisations
with torch.no_grad():
    # Shrink the output BatchNorm gain so the model is not over-confident at the start
    layers[-1].gamma *= 0.1
    # Apply the 5/3 tanh gain to every Linear layer (everything except the final BatchNorm)
    for lin in (l for l in layers[:-1] if isinstance(l, Linear)):
        lin.weights *= 5/3

params = [C] + [p for layer in layers for p in layer.parameters()]
for p in params:
    p.requires_grad = True

print(f'Model parameters: {sum(p.nelement() for p in params)}')

In [None]:
# Train with batch statistics: a previous run of this cell leaves every layer in evaluation mode
for layer in layers:
    layer.training = True

log_every = max(1, max_iters // 10) # Guard against a zero modulus when max_iters < 10

# Mini-batch gradient descent
losses = []
updates = [] # Fresh log, rather than appending to the diagnostic run's ratios
for i in range(max_iters):

    ix = torch.randint(0, x_train.shape[0], (batch_size, )) # Mini-batch indices
    xb, yb = x_train[ix], y_train[ix]

    # Forward pass
    emb = C[xb] # Embed characters into vectors (B,T) -> (B,T,C)
    x = emb.view(emb.size(0), -1)  # Concatenate the embeddings (B,T,C) -> (B,TC)
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, yb) # Cross entropy loss
    losses.append(loss.item())

    # Backward pass
    for param in params:
        param.grad = None
    loss.backward()

    # Step-decay schedule, then an SGD update outside the autograd graph
    lr = init_lr if i < (max_iters / 2) else final_lr
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad
        # Size of each update relative to the parameter itself, on a log10 scale
        updates.append([((lr * param.grad).std() / param.std()).log10().item() for param in params])

    if i % log_every == 0 or i == max_iters - 1:
        print(f'Iteration {i:6d} | Loss (Mini-Batch): {loss.item():.4f}')

# Plot the mini-batch loss
plt.plot(losses)

for layer in layers: # Evaluation mode: batch norm uses its running statistics
    layer.training = False
print(f'Train loss: {split_loss("train"):.4f} | Val loss: {split_loss("val"):.4f}')


In [None]:
# Sample from the model.
# Inference mode: batch norm must use its running statistics, because the unbiased
# variance of a single-example batch is undefined (NaN).
for layer in layers:
    layer.training = False

with torch.no_grad(): # No gradients are needed for sampling
    for _ in range(5):
        out = []
        context = [0] * block_size # Initialise context to '...'

        while True:
            # Forward pass
            emb = C[torch.tensor([context])] # Embed characters into vectors (B,T) -> (B,T,C)
            x = emb.view(emb.size(0), -1) # Concatenate the embeddings (B,T,C) -> (B,TC)
            for layer in layers:
                x = layer(x)
            probs = F.softmax(x, dim=1)

            # Sample the next character from the predicted distribution
            ix = torch.multinomial(probs, num_samples=1, replacement=True).item()
            context = context[1:] + [ix] # Shift the context window
            out.append(int_to_char[ix])
            if ix == 0:
                break # End of word
        print(''.join(out))