In [None]:
from pathlib import Path
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import trange

from kgi import apply_kgi_to_model

In [None]:
# Ask matplotlib to typeset all figure text with LaTeX.
# Switch the value to `False` if plotting raises an error (e.g. LaTeX is not installed).
plt.rcParams.update({'text.usetex': True})

# Data

In [None]:
def batchify(data, labels, batch_size):
    """ Lazily yield aligned `(data, labels)` slices of at most `batch_size` items.

    The last pair is shorter when `len(data)` is not a multiple of `batch_size`.
    """
    starts = range(0, len(data), batch_size)
    for lo in starts:
        hi = lo + batch_size
        yield data[lo:hi], labels[lo:hi]


def _dataset_to_tensors(dataset, device):
    """ materialise a map-style dataset as one image tensor and one label tensor on `device` """
    # Index every sample exactly once: each `dataset[i]` runs the transform
    # pipeline, so fetching images and labels in separate passes would do
    # all of that work twice.
    samples = [dataset[i] for i in range(len(dataset))]
    images = torch.stack([image for image, _ in samples]).to(device)
    labels = torch.tensor([label for _, label in samples], dtype=torch.long).to(device)
    return images, labels


def load_dataset_to_gpu(batch_size=256, device="cuda"):
    """ Load FashionMNIST on the CPU, preprocess it, and move it to `device` as fixed batches.

    The data is downloaded to `./datasets` if missing. Every image goes through
    `ToTensor` followed by `Normalize((0.5,), (0.5,))`.

    Args:
        batch_size: maximum number of samples per batch (the last batch of each split may be smaller).
        device: torch device on which the stacked tensors are placed.

    Returns:
        (train_batches, test_batches): two lists of `(images, labels)` tensor pairs.
        The batches are built once, in dataset order, and are not shuffled.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    splits = []
    for is_train in (True, False):
        dataset = torchvision.datasets.FashionMNIST(root='./datasets', train=is_train, download=True,
                                                    transform=transform)
        data, labels = _dataset_to_tensors(dataset, device)
        splits.append(list(batchify(data, labels, batch_size)))
    train_batches, test_batches = splits
    return train_batches, test_batches


# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device_ = "cuda"
else:
    device_ = "cpu"

# Pre-batched train / test splits that live on `device_` for the whole notebook.
train_set, test_set = load_dataset_to_gpu(batch_size=256, device=device_)

# Training

In [None]:
class FashionMNISTMLP(nn.Module):
    """ Fully connected classifier: 784 -> 256 -> 128 -> 10 with ReLU after the two hidden layers. """

    def __init__(self):
        super().__init__()
        # layers are created in forward order: flattened 28x28 input -> 256 -> 128 -> 10 outputs
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """ Flatten `x` to rows of 784 values and return one row of 10 raw scores per sample. """
        flat = x.view(-1, 28 * 28)
        hidden = torch.relu(self.fc1(flat))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
def train(kgi, seed=0, num_epochs=10000, save_every=20, progress_bar=True):
    """ Train one `FashionMNISTMLP` and evaluate it.

    Reads the notebook-level `train_set`, `test_set` and `device_`.

    Args:
        kgi: if truthy, `apply_kgi_to_model` is applied to the freshly built model before training.
        seed: value passed to `torch.manual_seed` before the model is created.
        num_epochs: number of full passes over `train_set`.
        save_every: the epoch loss is recorded on every `save_every`-th epoch only.
        progress_bar: show the tqdm bar and print the final test accuracy.

    Returns:
        (loss_history, accuracy): list of recorded epoch losses (mean of the
        per-batch losses, as Python floats) and the test accuracy in percent.
        `loss_history` is empty when `num_epochs < save_every`.
    """
    torch.manual_seed(seed)

    # model
    model = FashionMNISTMLP().to(device_)
    if kgi:
        apply_kgi_to_model(model, knot_low=[-0.8, 0., 0.], knot_high=0.8,
                           perturb_factor=0.2, kgi_by_bias=False)
    model.train()

    # loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # train loop
    loss_history = []
    loop = trange(num_epochs, desc="Training Epochs", disable=not progress_bar)
    for epoch in loop:
        # `loss.item()` copies the value to the host and therefore waits for
        # the device; only pay for that on epochs whose loss is actually kept.
        record = (epoch + 1) % save_every == 0
        running_loss = 0.0
        for images, labels in train_set:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if record:
                running_loss += loss.item()
        if record:
            epoch_loss = running_loss / len(train_set)
            loss_history.append(epoch_loss)
            loop.set_postfix({"Epoch Loss": f"{epoch_loss:.2e}"})

    # evaluation
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in test_set:
            outputs = model(images)
            # index of the largest score per row is the predicted class
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()  # noqa

    accuracy = 100 * correct / total
    if progress_bar:
        print(f'Final Accuracy on test set: {accuracy:.2f}%')
    return loss_history, accuracy

In [None]:
# train all models
seeds = list(range(10))  # use `seeds = [0]` for fast test
epochs = 10000  # use a smaller one for fast test
out_dir = Path("results/mnist_paper")

# NOTE: finished runs are cached by file name ("<seed>_<kgi>") only, so the
# cache does not notice a changed `epochs`; delete `out_dir` after changing it.
out_dir.mkdir(exist_ok=True, parents=True)
for seed_ in seeds:
    for kgi_ in [False, True]:
        name_ = f"{seed_}_{kgi_}"
        if (out_dir / name_).exists():
            print(f"{name_} exists")
            continue
        t0 = time()
        hist_, acc = train(kgi_, seed_, epochs, progress_bar=False)
        if not hist_:
            # `train` records a loss only every `save_every` (default 20)
            # epochs, so a very short run returns an empty history. Saving it
            # would leave an empty file that the cache check above treats as
            # a finished run, and `hist_[-1]` below would raise IndexError.
            print(f"{name_} not saved: no loss was recorded, `epochs` ({epochs}) is too small")
            continue
        # history goes into the file body, the test accuracy into the "# ..." header line
        np.savetxt(out_dir / name_, hist_, header=f"{acc}")
        print(f"{name_} trained in {(time() - t0) / 60:.1f} min, loss={hist_[-1]:.2e}, acc={acc:.2f}%")

# Analysis

### Metrics

In [None]:
def print_metrics(kgi):
    """ Print mean and standard deviation over all `seeds` of three metrics, in LaTeX-friendly form.

    For every seed the file `out_dir / f"{seed}_{kgi}"` is read: its body is
    the recorded loss history and its first line is the `# <accuracy>` header.

    Metrics:
        Loss: mean of the last 5 recorded losses of a run.
        Speed: reciprocal of (sum of the loss history minus the final loss).
        Accuracy: the test accuracy stored in the file header.

    Args:
        kgi: selects the file set (`True` / `False` is part of the file name)
            and the heading that is printed.
    """
    losses = []
    speeds = []
    accs = []
    for seed in seeds:
        path = out_dir / f"{seed}_{kgi}"
        # read history; `loadtxt` returns a 0-d array for a single-value file,
        # which cannot be sliced, so force at least one dimension
        hist = np.atleast_1d(np.loadtxt(path))
        # use average of last 100 epochs (5 * 20) as final loss
        final_loss = hist[-5:].mean()
        losses.append(final_loss)
        # AUC for convergence speed
        speeds.append(1. / (hist.sum() - final_loss))
        # read accuracy from the header line ("# <accuracy>")
        with open(path) as fs:
            acc_str = fs.readline()
        accs.append(float(acc_str[1:]))
    losses = np.array(losses)
    speeds = np.array(speeds)
    accs = np.array(accs)
    print("KGI" if kgi else "No KGI")
    # print in latex format; backslashes are escaped explicitly because
    # "\p" and "\%" are not valid Python escape sequences
    print(f"Loss: {losses.mean():.1e} \\pm {losses.std():.1e}")
    print(f"Speed: {speeds.mean():.1e} \\pm {speeds.std():.1e}")
    print(f"Accuracy: {accs.mean():.1f}\\% \\pm {accs.std():.1f}\\%")


# report the baseline first, then the KGI runs
for use_kgi_ in (False, True):
    print_metrics(use_kgi_)

### Loss history

In [None]:
def moving_ave(series, window_size):
    """ Smooth `series` with an unweighted sliding mean over `window_size` points.

    Only fully overlapping windows are kept (`mode='valid'`), so for
    `window_size <= len(series)` the result has `len(series) - window_size + 1` entries.
    """
    kernel = np.ones(window_size)
    kernel /= window_size
    smoothed = np.convolve(series, kernel, mode='valid')
    return smoothed


# `fig_` instead of `_`: assigning to `_` would shadow IPython's last-output variable
fig_, ax = plt.subplots(figsize=(5 / 1.5, 4 / 1.5), dpi=200)
seed_ = 0
save_every_ = 20  # epochs between two recorded losses; must match `save_every` used by `train`
hist_def = np.loadtxt(out_dir / f"{seed_}_{False}")
hist_kgi = np.loadtxt(out_dir / f"{seed_}_{True}")
# A window of 1 leaves the curves unsmoothed. The x values are derived from
# each curve itself, so a larger window (which shortens the output) or a cached
# history of a different length still plots instead of raising a shape error.
for series_, label_, color_ in [(hist_def, "No KGI", 'b'), (hist_kgi, "KGI", 'r')]:
    curve_ = moving_ave(series_, 1)
    ax.plot(np.arange(len(curve_)), curve_, label=label_, lw=1, c=color_)
# tick positions are in units of recorded points; labels are epochs / 1000
# (laid out for the 10 000-epoch runs)
ax.set_xticks([0, 2000 // save_every_, 4000 // save_every_, 6000 // save_every_,
               8000 // save_every_, 10000 // save_every_],
              [0, 2, 4, 6, 8, 10])
ax.set_ylabel("Cross entropy loss")
ax.set_xlabel("Epoch ($10^3$)")
ax.set_yscale("log")
ax.legend(ncol=2, handlelength=.8, columnspacing=.5, handletextpad=.4)
# not saving figure here because we will plot all histories together
plt.show()