# Bake Data

In [38]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
import os

save_path = './data/MNIST/baked/'
os.makedirs(save_path, exist_ok=True)

def file_path(x):
    """Return the path of file `x` inside the baked-data directory."""
    return os.path.join(save_path, x)

# Convert each image to a tensor, normalise it, then flatten it to a vector
transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
    Lambda(torch.flatten),
])

train_set = MNIST('./data/', train=True, download=True, transform=transform)
test_set = MNIST('./data/', train=False, download=True, transform=transform)

# One batch that spans the entire dataset, so a single `next()` yields
# every transformed example (and its label) as one tensor
train_loader = DataLoader(train_set, batch_size=len(train_set), shuffle=False)
train_x, train_y = next(iter(train_loader))

test_loader = DataLoader(test_set, batch_size=len(test_set), shuffle=False)
test_x, test_y = next(iter(test_loader))

# Persist the pre-processed tensors for the training cells to load
torch.save(train_x, file_path('train_x.pt'))
torch.save(train_y, file_path('train_y.pt'))
torch.save(test_x, file_path('test_x.pt'))
torch.save(test_y, file_path('test_y.pt'))

# Utils

In [39]:
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

class UnitLength(nn.Module):
    """
    Layer that rescales its input to unit L2 norm along dimension 1.

    Stateless: it has no parameters and simply delegates to `F.normalize`.
    """
    def forward(self, x):
        # p=2 and dim=1 are the defaults of F.normalize, spelled out here
        unit_x = F.normalize(x, p=2.0, dim=1)
        return unit_x
    
class LayerOutputs:
    """
    Single-pass iterator over the intermediate activations of a model.

    The model must be an iterable of callables (layers). Each step feeds the
    previous output (initially `x`) through the next layer and yields the
    result. The most recent output is also kept in `self.x`.

    Example:
        >>> model = nn.Sequential(...)
        >>> [h.mean() for h in LayerOutputs(model, x)]
    """
    def __init__(self, model, x):
        self.x = x
        self.layers = iter(model)

    def __iter__(self):
        # The object is its own iterator, so it can be consumed only once
        return self

    def __next__(self):
        # StopIteration raised by the exhausted layer iterator propagates
        # and terminates this iterator as well
        self.x = next(self.layers)(self.x)
        return self.x

def visualise_sample(x, title='', sample_index=0):
    """
    Display one sample from `x` as a 28x28 greyscale image.

    Args:
        x: indexable collection of tensors; the selected sample must hold
            exactly 28 * 28 values.
        title: text shown above the image.
        sample_index: position of the sample within `x`.
    """
    # Move to the CPU before plotting and restore the 2-D image layout
    pixels = x[sample_index].cpu().reshape(28, 28)
    plt.figure(figsize=(4, 4))
    plt.title(title)
    plt.imshow(pixels, cmap="gray")
    plt.show()

# Train

In [40]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import split
from torch.optim import Adam
#from utils import LayerOutputs, UnitLength

In [48]:
def class_centroids(h, y_true):
    """
    Calculates the centroid of each class present in y_true.

    Args:
        h: tensor of shape [n_examples, n_in].
        y_true: tensor of shape [n_examples] holding the class label of each
            row of h.

    Returns:
        A tuple (centroids, y_idx):
            centroids: tensor of shape [n_in, n_classes_present]; column k is
                the mean of the rows of h whose label is the k-th distinct
                label (in the order returned by torch.unique).
            y_idx: tensor of shape [n_examples] giving, for each example, the
                column index of its own class in `centroids`.
    """
    # Use the y_true argument (not a notebook-level `y`) so the result
    # depends only on the inputs passed in
    class_labels, y_idx = torch.unique(y_true, return_inverse=True) # [n_classes_present], [n_examples]
    centroids = torch.stack([h[y_true == label].mean(0) for label in class_labels], dim=1) # [n_in, n_classes_present]
    return centroids, y_idx

def d(a, b, dim=1):
    """Mean of the squared element-wise difference of a and b, reduced over `dim`."""
    squared_diff = (a - b).pow(2)
    return squared_diff.mean(dim)

def distance_to_centroids(h, y_true, d=d):
    """
    Calculates the distance from every example to the centroid of every class.

    Centroids are computed from (h, y_true) themselves via class_centroids;
    the distance function `d` defaults to the mean squared difference.

    Returns a tensor of shape [n_examples, n_classes_present].
    """
    centroids = class_centroids(h, y_true)[0] # [n_in, n_classes_present]
    # Add a trailing axis so h broadcasts against every centroid column
    h_expanded = h.unsqueeze(2) # [n_examples, n_in, 1]
    return d(h_expanded, centroids) # [n_examples, n_classes_present]
    
@torch.no_grad()
def predict(model, x, y_true):
    """
    Predict by finding the class with closest centroid to each example.

    The distance to every class centroid is computed for the output of each
    layer in turn and summed over layers; the prediction is the argmin of
    that sum.

    Args:
        model: iterable of layers (see LayerOutputs).
        x: input batch fed to the first layer.
        y_true: labels of `x`, used to compute the class centroids.

    Returns:
        Tensor of shape [n_examples] with the index of the closest centroid.
        NOTE(review): this index is a position among the classes present in
        y_true, which equals the label only when the labels are 0..C-1 and
        all are present -- confirm against callers.
    """
    # FIXME: this won't work, as y_true will be different every time.
    # The centroids are rebuilt from whatever (x, y_true) is passed in, so
    # the labels of the examples being predicted are needed to predict them.
    total_distance = sum(distance_to_centroids(h, y_true) for h in LayerOutputs(model, x))
    return total_distance.argmin(1) # type: ignore


In [42]:
def centroid_loss(h, y_true, alpha=10, epsilon=1e-12):
    """
    Smoothed triplet-style loss on squared distances to class centroids.

    For each example, the distance^2 to its true class centroid is compared
    with the distance^2 to one randomly sampled class centroid, where closer
    centroids are more likely to be sampled.

    Args:
        h: layer output, shape [n_examples, n_in].
        y_true: class label of each example; also used as a column index
            into the distance matrix.
        alpha: scale applied to the distance difference before the SiLU.
        epsilon: added to the distances before inversion to avoid dividing
            by zero.

    Achieves an error rate of ~2.0%.
    """

    # Distance from every example to every class centroid
    d2 = distance_to_centroids(h, y_true)

    # Sample one class per example, weighting each class by its inverse
    # distance (the true class itself may be drawn)
    sampling_weights = 1 / (d2 + epsilon)
    y_near = torch.multinomial(sampling_weights, 1).squeeze(1)

    # Smoothed version of triplet loss: max(0, d2_same - d2_near + margin)
    rows = range(d2.shape[0])
    d2_true = d2[rows, y_true] # ||anchor - positive||^2
    d2_near = d2[rows, y_near] # ||anchor - negative||^2
    scaled_gap = alpha * (d2_true - d2_near)
    return F.silu(scaled_gap).mean()


In [43]:
# Prefer the GPU when torch can see one, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

# The data is pre-processed, to speed up this script; each tensor is mapped
# directly onto the selected device as it is loaded
x_tr, y_tr, x_te, y_te = (
    torch.load(baked_file, device)
    for baked_file in (
        './data/MNIST/baked/train_x.pt',
        './data/MNIST/baked/train_y.pt',
        './data/MNIST/baked/test_x.pt',
        './data/MNIST/baked/test_y.pt',
    )
)


Using device: cuda


In [44]:
# Define the model
# ----------------
# Must be an iterable of layers. I find it works best if each layer starts with
# a UnitLength() sub-layer.
#
# Each (n_in, n_out) pair below becomes one UnitLength -> Linear -> ReLU layer,
# built in the order listed.
model = nn.Sequential(*(
    nn.Sequential(UnitLength(), nn.Linear(n_in, n_out), nn.ReLU())
    for n_in, n_out in ((784, 500), (500, 500))
)).to(device)


In [45]:
# Evaluate the model on the training and test set
def print_evaluation(epoch=None):
    """
    Print the error rate of the notebook-level `model` on the training set
    (x_tr, y_tr) and the test set (x_te, y_te).

    Args:
        epoch: epoch number shown in the log line; None prints 'init'.
    """
    def prediction_error(inputs, labels):
        # Fraction of examples whose prediction differs from its label
        predicted = predict(model, inputs, labels)
        accuracy = torch.mean((predicted == labels).float()).item()
        return 1.0 - accuracy

    train_error = prediction_error(x_tr, y_tr)
    test_error = prediction_error(x_te, y_te)

    if epoch is None:
        epoch_str = 'init'
    else:
        epoch_str = f"{epoch:>4d}"
    print(f"[{epoch_str}] Training: {train_error*100:>5.2f}%\tTest: {test_error*100:>5.2f}%")


In [46]:
# Training parameters
# -------------------
# NOTE: when the cells run top to bottom the model already exists at this
# point, so this seed affects the sampling done during training but not the
# initial weights.
torch.manual_seed(42)

learning_rate = 0.05
batch_size = 4096
num_epochs = 121  # epochs 0..120, so the last epoch lands on an evaluation step
loss_fn = centroid_loss

# A single optimiser holds the parameters of every layer
optimiser = Adam(model.parameters(), lr=learning_rate)


In [49]:
# Train the model
# Report the error rates of the untrained model first, as a baseline
print_evaluation()
for epoch in range(num_epochs):

    # Mini-batch training: consecutive, unshuffled slices of the training set
    for x, y in zip(split(x_tr, batch_size), split(y_tr, batch_size)):

        # Train layers in turn, using backprop locally only
        for layer in model:
            h = layer(x)
            # Use the configured loss so that changing `loss_fn` in the
            # parameter cell actually takes effect
            loss = loss_fn(h, y)
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
            # Detach so the next layer's loss cannot backpropagate into this one
            x = h.detach() # no need to forward propagate x again, as direction doesn't change

    # Evaluate the model on the training and test set (epochs 0, 5, 10, ...)
    if (epoch + 1) % 5 == 1:
        print_evaluation(epoch)


[init] Training: 19.16%	Test: 17.31%
[   0] Training: 11.63%	Test: 10.89%
[   5] Training:  5.20%	Test:  5.44%
[  10] Training:  3.78%	Test:  3.97%
[  15] Training:  2.89%	Test:  3.31%
[  20] Training:  2.53%	Test:  2.97%
[  25] Training:  2.25%	Test:  2.76%
[  30] Training:  2.06%	Test:  2.71%
[  35] Training:  1.89%	Test:  2.51%
[  40] Training:  1.72%	Test:  2.51%
[  45] Training:  1.56%	Test:  2.37%
[  50] Training:  1.48%	Test:  2.42%
[  55] Training:  1.43%	Test:  2.38%
[  60] Training:  1.30%	Test:  2.19%
[  65] Training:  1.24%	Test:  2.15%
[  70] Training:  1.16%	Test:  2.20%
[  75] Training:  1.15%	Test:  2.18%
[  80] Training:  1.08%	Test:  2.20%
[  85] Training:  1.04%	Test:  2.03%
[  90] Training:  1.01%	Test:  2.23%
[  95] Training:  0.96%	Test:  2.08%
[ 100] Training:  0.89%	Test:  2.02%
[ 105] Training:  0.86%	Test:  2.05%
[ 110] Training:  0.82%	Test:  2.08%
[ 115] Training:  0.77%	Test:  2.02%
[ 120] Training:  0.79%	Test:  2.09%
