# Bake Data

In [2]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
import os

save_path = './data/MNIST/baked/'
os.makedirs(save_path, exist_ok=True)


def file_path(name):
    """Return the location of `name` inside the baked-data directory."""
    return os.path.join(save_path, name)


# Normalise with the mean/std commonly used for MNIST pixels, then flatten each
# image into a single vector so the baked tensors feed straight into nn.Linear.
transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
    Lambda(torch.flatten),
])

train_set, test_set = (
    MNIST('./data/', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# One batch spanning the whole split materialises every example in one go.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    for dataset in (train_set, test_set)
)

train_x, train_y = next(iter(train_loader))
test_x, test_y = next(iter(test_loader))

baked_tensors = {
    'train_x.pt': train_x,
    'train_y.pt': train_y,
    'test_x.pt': test_x,
    'test_y.pt': test_y,
}
for name, tensor in baked_tensors.items():
    torch.save(tensor, file_path(name))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 387128928.20it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 102138021.77it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 136077984.31it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 21550371.91it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



# Utils

In [1]:
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

class UnitLength(nn.Module):
    """
    Rescale the input to unit Euclidean (L2) length along dim 1.

    For a 2-D batch this normalises each row. The norm is clamped below by
    1e-12 before dividing, so an all-zero row stays all-zero instead of
    producing NaNs.
    """

    def forward(self, x):
        return F.normalize(x, p=2.0, dim=1, eps=1e-12)
    
class LayerOutputs:
    """
    Single-use iterator over the successive activations of a stack of layers.

    Each step applies the next layer to the previous step's output (starting
    from `x`) and yields it, so the k-th item is the output of the k-th layer.
    `model` only has to be iterable (e.g. an nn.Sequential or a list of
    callables).

    Example:
        >>> model = nn.Sequential(...)
        >>> [h.mean() for h in LayerOutputs(model, x)]
    """

    def __init__(self, model, x):
        self.layers = iter(model)  # layers are pulled lazily, one per step
        self.x = x                 # latest activation (the raw input before step 1)

    def __iter__(self):
        return self

    def __next__(self):
        # Once the layers run out, next() raises StopIteration, ending the loop.
        self.x = next(self.layers)(self.x)
        return self.x

def visualise_sample(x, title='', sample_index=0, image_shape=(28, 28)):
    """
    Display one example from a batch as a greyscale image.

    Args:
        x: batch tensor; row `sample_index` must contain
            image_shape[0] * image_shape[1] elements (784 for flattened MNIST).
        title: figure title.
        sample_index: which row of `x` to display.
        image_shape: (height, width) the row is reshaped into before display.
            Defaults to 28x28; pass another shape to view e.g. layer outputs.
    """
    # detach() lets tensors that require grad (such as layer activations) be
    # converted to a NumPy array, which matplotlib needs for imshow.
    img = x[sample_index].detach().cpu().reshape(image_shape).numpy()
    _, ax = plt.subplots(figsize=(4, 4))
    ax.set_title(title)
    ax.imshow(img, cmap="gray")
    plt.show()

In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import split
from torch.optim import Adam
from utils import LayerOutputs, UnitLength


ModuleNotFoundError: ignored

In [None]:
def distance2_to_centroids(h, y_true, epsilon=1e-12, num_classes=10):
    """
    Calculate the mean squared distance from each example to each class centroid.

    Each centroid is the mean of the rows of `h` whose label in `y_true` is that
    class, so centroids are computed from the same batch that is measured.

    Args:
        h: activations, shape [n_examples, n_features].
        y_true: integer class labels, shape [n_examples], values in
            [0, num_classes).
        epsilon: added to the per-class counts to guard the division.
        num_classes: number of classes (10 for MNIST).

    Returns:
        Tensor of shape [n_examples, num_classes]. Entry [i, c] is
        mean_k (h[i, k] - centroid_c[k]) ** 2. A class with no examples in
        `y_true` has no centroid, so its column is +inf. It can therefore never
        be the nearest class or receive sampling weight (1 / inf == 0).
    """
    one_hot = F.one_hot(y_true.long(), num_classes).to(h.dtype)       # [n_examples, C]
    counts = one_hot.sum(0)                                            # [C]
    centroids = (one_hot.T @ h) / (counts + epsilon).unsqueeze(1)      # [C, n_features]

    # ||h - c||^2 = ||h||^2 - 2 h.c + ||c||^2. This avoids materialising an
    # [n_examples, n_features, C] tensor. clamp_min guards against tiny
    # negative values caused by floating-point cancellation.
    sq_dist = (
        h.pow(2).sum(1, keepdim=True)
        - 2 * (h @ centroids.T)
        + centroids.pow(2).sum(1)
    ).clamp_min(0)
    d2 = sq_dist / h.shape[1]                                          # mean over features

    return d2.masked_fill(counts == 0, float('inf'))
    
@torch.no_grad()
def predict(model, x, y_true):
    """
    Assign each example the class whose centroid is nearest, summed over layers.

    For every layer output of `model` on `x`, the squared distance to each class
    centroid is computed and accumulated. The prediction is the class with the
    smallest total.

    NOTE(review): centroids are built from `y_true`, i.e. from the labels of the
    very examples being classified. When called on a held-out set, its own
    labels therefore shape the decision. Confirm this is the intended protocol.
    """
    total_d2 = 0
    for layer_out in LayerOutputs(model, x):
        total_d2 = total_d2 + distance2_to_centroids(layer_out, y_true)
    return total_d2.argmin(1) # type: ignore


In [None]:
def centroid_loss(h, y_true, alpha=10, epsilon=1e-12, exclude_true=True):
    """
    Loss function based on distance^2 to the true centroid vs a nearby centroid.

    For each example a "near" (negative) class is drawn at random, with
    probability proportional to 1 / distance^2 to that class's centroid. The
    smoothed triplet loss silu(alpha * (d2_true - d2_near)) is then averaged
    over the batch.

    The previously reported error rate of ~2.0% was measured while the true
    class could also be drawn as the negative (equivalent to
    exclude_true=False). Re-measure with the default.

    Args:
        h: activations of one layer, [n_examples, n_features].
        y_true: integer labels, [n_examples].
        alpha: sharpness of the smoothed hinge.
        epsilon: keeps the inverse distance finite when a distance is 0.
        exclude_true: if True, never draw the true class as the negative.
            Drawing it makes d2_true - d2_near identically 0, a term with zero
            gradient that silently drops that example from the update.

    Returns:
        Scalar loss tensor.
    """
    # Distance from h to centroids of each class
    d2 = distance2_to_centroids(h, y_true)
    rows = torch.arange(d2.shape[0], device=d2.device)

    # Sampling is not differentiable, so build the weights from detached values.
    weights = 1 / (d2.detach() + epsilon)
    if exclude_true:
        weights[rows, y_true] = 0
        # A row with no other class to choose from (e.g. only one class present)
        # falls back to the true class, which contributes silu(0) == 0.
        no_negative = weights.sum(1) == 0
        weights[rows[no_negative], y_true[no_negative]] = 1
    y_near = torch.multinomial(weights, 1).squeeze(1)

    # Smoothed version of triplet loss: max(0, d2_same - d2_near + margin)
    d2_true = d2[rows, y_true]  # ||anchor - positive||^2
    d2_near = d2[rows, y_near]  # ||anchor - negative||^2
    return F.silu(alpha * (d2_true - d2_near)).mean()


In [None]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)


def _load_baked(name):
    """Load one pre-processed tensor from the baked directory onto `device`."""
    return torch.load('./data/MNIST/baked/' + name, map_location=device)


# The data is pre-processed, to speed up this script
x_tr, y_tr = _load_baked('train_x.pt'), _load_baked('train_y.pt')
x_te, y_te = _load_baked('test_x.pt'), _load_baked('test_y.pt')


In [None]:
# Define the model
# ----------------
# Must be an iterable of layers: training and prediction walk it one layer at
# a time. I find it works best if each layer starts with a UnitLength()
# sub-layer. The input size 784 is a flattened 28x28 MNIST image.
#
# Seed *before* construction so the initial weights, and therefore the
# reported error rates, are reproducible. Seeding only in the training cell
# happens after the weights have already been drawn.
torch.manual_seed(42)
model = nn.Sequential(
    nn.Sequential(UnitLength(), nn.Linear(784, 500), nn.ReLU()),
    nn.Sequential(UnitLength(), nn.Linear(500, 500), nn.ReLU()),
).to(device)


In [None]:
# Evaluate the model on the training and test set
def print_evaluation(epoch=None):
    """
    Print the nearest-centroid error rate (as a percentage) on train and test.

    Reads the notebook globals `model`, `x_tr`, `y_tr`, `x_te` and `y_te`. Only
    reading them requires no `global` statement. `epoch` labels the line; None
    prints "init".
    """
    def error_rate(pred, target):
        return 1.0 - torch.mean((pred == target).float()).item()

    def prediction_error(inputs, labels):
        return error_rate(predict(model, inputs, labels), labels)

    train_error, test_error = prediction_error(x_tr, y_tr), prediction_error(x_te, y_te)
    epoch_str = f"{epoch:>4d}" if epoch is not None else 'init'
    print(f"[{epoch_str}] Training: {train_error*100:>5.2f}%\tTest: {test_error*100:>5.2f}%")


In [None]:
# Training parameters
# -------------------
torch.manual_seed(42)      # seeds torch's RNG, e.g. the negative sampling in centroid_loss

num_epochs = 120 + 1       # epochs 0..120, so the final epoch is also evaluated
batch_size = 4096
learning_rate = 0.05
loss_fn = centroid_loss

optimiser = Adam(model.parameters(), lr=learning_rate)


In [None]:
# Train the model
# ---------------
# Each layer is trained greedily on its own loss. Gradients never cross layer
# boundaries because every layer's input is detached.
print_evaluation()
for epoch in range(num_epochs):

    # Mini-batch training (fixed batch order; the data is not reshuffled)
    for x, y in zip(split(x_tr, batch_size), split(y_tr, batch_size)):

        # Train layers in turn, using backprop locally only
        for layer in model:
            h = layer(x)
            loss = loss_fn(h, y)  # honour the loss configured in the parameters cell
            # The other layers' grads must be None, not zero, so the shared
            # Adam optimiser skips them. With zeroed grads it would still apply
            # momentum-only updates to layers that are not being trained.
            optimiser.zero_grad(set_to_none=True)
            loss.backward()
            optimiser.step()
            # Reuse this layer's pre-update activations as the next layer's
            # input, avoiding a second forward pass through the just-updated
            # layer. detach() keeps the next loss local to the next layer.
            x = h.detach()

    # Evaluate every 5 epochs, starting with epoch 0
    if epoch % 5 == 0:
        print_evaluation(epoch)
