# MLP + MNIST + NTK
Purpose: Fit a feed-forward neural network (FFNN) to the MNIST dataset, as a baseline for benchmarking the KAN's performance.

The PyTorch Lightning library is used for convenience.

I've copied some parts from the KAN experiments, since there is some degree of overlap between the two.

In [1]:
# Imports
import numpy as np
import pandas as pd
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from lightning.pytorch.loggers import CSVLogger

# Pick the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Fix all random seeds (including dataloader workers) for reproducibility
# See https://lightning.ai/docs/pytorch/stable/common/trainer.html
L.seed_everything(42, workers=True)

# Select the 'high' float32 matmul precision setting for faster CUDA matmuls
torch.set_float32_matmul_precision('high')

Seed set to 42


In [2]:
# Dataset Setup -- Inspired by Hugo's Dataset Reformatting
# Images are converted to tensors only; both splits share the same root and transform.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, test_dataset = (
    MNIST("./temp/", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Reformatted, due to odd issues when using NTK on it
class LCDataset(Dataset): # Lightning Compatible Dataset
    """Wrap an (x, label) dataset so that labels are returned as one-hot vectors.

    Args:
        dataset: Indexable dataset yielding ``(x, y)`` pairs, where ``y`` is
            used as an integer index into the one-hot vector.
        num_classes: Length of the one-hot target vector.
        limit: ``-1`` (default) keeps the whole dataset. Any other value keeps
            a random subset of at most ``limit`` samples, drawn once at
            construction time from NumPy's global RNG.
    """

    def __init__(self, dataset, num_classes, limit=-1):
        self.limit = limit
        self.num_classes = num_classes
        if self.limit != -1:
            # Shuffle the indices of the dataset being wrapped (not of a
            # notebook-level global) and keep the first `limit` of them, so
            # the subset is always valid for whatever dataset was passed in.
            shuffled = np.random.permutation(np.arange(len(dataset)))
            sub = shuffled[0:self.limit].tolist()
            self.dataset = Subset(dataset, sub)
        else:
            self.dataset = dataset

    def __len__(self):
        """Number of samples exposed (after any subsampling)."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(x, one_hot(y))`` for the sample at position ``idx``."""
        x, y = self.dataset[idx]
        y_one_hot = torch.zeros(self.num_classes)
        y_one_hot[y] = 1
        return x, y_one_hot

batch_size = 64

# NTK loader: fixed random 2750-sample subset of the training split, iterated in order.
ntk_loader = DataLoader(
    LCDataset(train_dataset, num_classes=10, limit=2750),
    batch_size=batch_size,
    num_workers=10,
)
# Training loader: full training split, reshuffled every epoch.
train_loader = DataLoader(
    LCDataset(train_dataset, num_classes=10),
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
)
# Test loader: full test split, iterated in order.
test_loader = DataLoader(
    LCDataset(test_dataset, num_classes=10),
    batch_size=batch_size,
    num_workers=10,
)

In [3]:
# Model Declaration
class FFNNMLP(L.LightningModule):
    """Fully connected 784-240-60-10 ReLU network trained with MSE on one-hot targets."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(28*28, 240),
            nn.ReLU(),
            nn.Linear(240, 60),
            nn.ReLU(),
            nn.Linear(60, 10)
        )

    def forward(self, x):
        """Flatten the input into rows of 28*28 values and return 10 raw outputs per row."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, 28*28)
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE between the network output and the one-hot target."""
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y) # MSE Loss works better for NTK
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log the MSE loss and the arg-max accuracy for one test batch."""
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y) # MSE Loss works better for NTK
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        # The logged value is the MSE computed above, so the metric is named accordingly.
        self.log("test loss (mse)", loss)
        self.log("accuracy", accuracy)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the network output for a batch.

        Accepts either a bare input tensor or an ``(x, y)`` pair; for a pair
        only the inputs are fed to the network.
        """
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self(x)

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-4 over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

In [4]:
# Train + Test + Results
model = FFNNMLP()

# Metrics are written as CSV under logs/MNISTMLPNTK
csv_logger = CSVLogger("logs", name="MNISTMLPNTK")
trained_model = L.Trainer(
    max_epochs=10,
    deterministic=True,
    logger=csv_logger,
    log_every_n_steps=25,
)

# Fit on the training loader, then report metrics on the test loader
trained_model.fit(model, train_loader)
trained_model.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | net  | Sequential | 203 K  | train
--------------------------------------------
203 K     Trainable params
0         Non-trainable params
203 K     Total params
0.814     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


Training: |                                               | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                | 0/? [00:00<?, ?it/s]

[{'test loss (cross entropy)': 0.005980664398521185,
  'accuracy': 0.9779000282287598}]

In [5]:
# Apply NTK
# NOTE(review): GaussianFit and MSELoss_batch, used further down in this cell,
# are not defined in this notebook -- presumably this script provides them
# (it is run with -i, i.e. inside the notebook's namespace). Confirm.
%run -i 'misc/introduction_code_modded_KAN_NTK.py'

def cross_entropy_loss_batch(y_hat, y):
    """Cross-entropy of `y_hat` against `y`, one value per sample (no reduction)."""
    per_sample_loss = F.cross_entropy(input=y_hat, target=y, reduction='none')
    return per_sample_loss

# Make sure the trained network lives on the selected device before wrapping it
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Wrap the network and fit on the NTK subset.
# MSELoss seems to work better for NTK Models
ntk_model = GaussianFit(model=model, device=device, noise_var=0.0)
ntk_model.fit(ntk_loader, optimizer, MSELoss_batch)

In [6]:
def check_ntk_acc(model, dataloader):
    """Return the fraction of samples whose arg-max prediction matches the target.

    Args:
        model: Object exposing ``eval()``, ``train()``, ``to(device)`` and
            ``forward(x)``; it is moved to the notebook-level ``device``.
        dataloader: Iterable of ``(x, y)`` batches, where the arg-max of ``y``
            along dim 1 is the true class.

    Returns:
        Number of correct predictions divided by the number of samples seen.

    The model is evaluated in eval mode. Afterwards it is switched back to
    train mode only if it was in train mode on entry (or if its mode cannot be
    determined), and this happens even if evaluation raises.
    """
    # Remember the current mode so an eval-mode model is not silently flipped to train mode.
    was_training = getattr(model, "training", True)
    res = 0.0
    sumlength = 0
    model.eval()
    model.to(device)
    try:
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            sumlength += len(x)
            res += (torch.argmax(model.forward(x), dim=1) == torch.argmax(y, dim=1)).sum()
    finally:
        if was_training:
            model.train()
    return res / sumlength

In [7]:
# Accuracy of the NTK-wrapped model on the test split
ntk_accuracy = check_ntk_acc(ntk_model, test_loader)
print(f'NTK Accuracy: {ntk_accuracy}')

NTK Accuracy: 0.9768999814987183


In [8]:
# R: True Value, C: Predicted Value
def make_predict_matrix(model, dataloader, num_classes=10):
    """Build a confusion matrix of arg-max predictions over a dataloader.

    Args:
        model: Object exposing ``forward(x)`` that returns one score per class.
        dataloader: Iterable of ``(x, y)`` batches, where the arg-max of ``y``
            along dim 1 is the true class.
        num_classes: Number of rows/columns of the matrix (default 10).

    Returns:
        Integer array of shape ``(num_classes, num_classes)`` where entry
        ``[r, c]`` counts samples with true class ``r`` predicted as ``c``.
    """
    res = np.zeros(shape=(num_classes, num_classes), dtype=int)
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        x_arg = torch.argmax(model.forward(x), dim=1)
        y_arg = torch.argmax(y, dim=1)
        # Copy each batch to the host once and accumulate in a single call.
        # np.add.at counts repeated (row, col) pairs correctly, unlike `res[rows, cols] += 1`.
        rows = y_arg.detach().cpu().numpy()
        cols = x_arg.detach().cpu().numpy()
        np.add.at(res, (rows, cols), 1)
    return res

In [9]:
# Confusion matrix of the NTK-wrapped model on the test split
ntk_predict_matrix = make_predict_matrix(ntk_model, test_loader)
print(ntk_predict_matrix)

[[ 972    0    1    0    0    2    2    1    2    0]
 [   0 1126    2    2    0    1    2    1    1    0]
 [   3    3 1002    3    2    1    1   12    4    1]
 [   0    0    3  987    0    3    0    9    3    5]
 [   2    1    2    0  958    0    5    3    1   10]
 [   3    0    0    5    2  872    4    2    2    2]
 [   7    3    0    0    4    4  937    0    3    0]
 [   0    6    9    1    1    1    0 1004    2    4]
 [   5    1    3    6    4    3    1    4  943    4]
 [   4    5    0    5   15    2    0    9    1  968]]


In [10]:
# Confusion matrix of the plain MLP on the test split
mlp_predict_matrix = make_predict_matrix(model, test_loader)
print(mlp_predict_matrix)

[[ 971    0    2    0    0    0    3    2    2    0]
 [   0 1126    2    2    0    1    2    0    2    0]
 [   3    2 1008    2    3    0    1    8    4    1]
 [   0    0    3  992    0    3    0    6    3    3]
 [   1    1    3    0  962    0    4    0    1   10]
 [   2    0    0    6    2  871    4    2    3    2]
 [   5    3    0    0    4    5  937    0    4    0]
 [   0    7   14    1    1    1    0  996    2    6]
 [   3    1    2    5    4    3    2    3  947    4]
 [   4    6    1    7    9    2    0    7    4  969]]
