# MLP + MNIST + NTK
Purpose: Fit a feed-forward neural network (FFNN) to the MNIST dataset, as a baseline for benchmarking the KAN's performance.

Furthermore, the PyTorch Lightning library is used for convenience.

Some parts are copied from the KAN experiments, because the two notebooks overlap to some degree.

In [1]:
# Imports
import numpy as np
import pandas as pd
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from lightning.pytorch.loggers import CSVLogger

# Device selection: use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Reproducibility -- https://lightning.ai/docs/pytorch/stable/common/trainer.html
# (workers=True also derives seeds for DataLoader worker processes.)
L.seed_everything(seed=42, workers=True)

# Allow faster, slightly lower-precision float32 matmuls on CUDA.
torch.set_float32_matmul_precision('high')

Seed set to 42


In [2]:
# Dataset Setup: fetch (or reuse the cached copy of) both MNIST splits under ./temp/.
# The train split is created first, then the test split.
train_dataset, test_dataset = (
    MNIST("./temp/", train=is_train_split, download=True)
    for is_train_split in (True, False)
)

class LCDataset(Dataset): # Lightning Compatible Dataset
    """In-memory dataset of flattened 28x28 images with one-hot float targets.

    Reads ``dataset.data`` (images) and ``dataset.targets`` (integer labels),
    flattens each image to a 784-vector and one-hot encodes the labels.
    Pixel values are cast to float32 as-is; no rescaling is applied.

    Args:
        dataset: Object exposing ``.data`` (N x 28 x 28 tensor) and
            ``.targets`` (N integer labels), e.g. a torchvision ``MNIST``.
        limit: Number of randomly chosen samples to keep, or ``-1`` (default)
            to keep every sample in its original order.
        num_classes: Width of the one-hot target vectors. Defaults to 10.
    """

    def __init__(self, dataset, limit=-1, num_classes=10):
        self.limit = limit
        images = dataset.data
        labels = torch.as_tensor(dataset.targets)
        if self.limit != -1:
            # Draw the subset from the dataset that was actually passed in (not
            # from a global), so indices can never exceed its length.
            # np.random is kept so the draw follows NumPy's global seed.
            self.sub = list(np.random.permutation(np.arange(len(images)))[0:self.limit])
            images = images[self.sub, :, :]
            labels = labels[self.sub]
        else:
            self.sub = None
        self.data = images.reshape(-1, 28*28).type(torch.float32)
        # Fixed-width one-hot encoding: the target always has `num_classes`
        # columns, even when a small subset happens to miss some classes.
        self.target = F.one_hot(labels.long(), num_classes=num_classes).type(torch.float32)

    def __len__(self):
        """Return the number of samples kept."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(flattened_image, one_hot_target)`` pair at ``idx``."""
        return self.data[idx], self.target[idx]

# Training uses a random 2750-sample subset, reshuffled every epoch.
train_loader = DataLoader(
    dataset=LCDataset(train_dataset, limit=2750),
    batch_size=64,
    shuffle=True,
    num_workers=10,
)
# Evaluation uses the complete test split in its stored order.
test_loader = DataLoader(
    dataset=LCDataset(test_dataset),
    batch_size=64,
    num_workers=10,
)

In [3]:
# Model Declaration
class FFNNMLP(L.LightningModule):
    """Fully connected classifier for flattened 28x28 images.

    Architecture: 784 -> 240 -> 60 -> 10 with ReLU between the linear layers.
    The last layer returns raw logits (no softmax); the loss used in the
    training and test steps is ``F.cross_entropy``.

    Args:
        lr: Learning rate of the Adam optimizer created in
            ``configure_optimizers``. Defaults to ``1e-4``.
    """

    def __init__(self, lr=1e-4):
        super().__init__()
        self.lr = lr
        self.net = nn.Sequential(
            nn.Linear(28*28, 240),
            nn.ReLU(),
            nn.Linear(240, 60),
            nn.ReLU(),
            nn.Linear(60, 10)
        )

    def forward(self, x):
        """Return the logits produced by the network for input ``x``."""
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch."""
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log cross-entropy loss and accuracy for one ``(x, y)`` batch.

        Both predictions and targets are reduced with ``argmax`` over dim 1
        before being compared.
        """
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        self.log("test loss (cross entropy)", loss)
        self.log("accuracy", accuracy)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return logits for a batch.

        ``batch`` may be a bare input tensor or an ``(inputs, targets)``
        list/tuple; in the latter case only the inputs are fed to the network,
        so loaders that also carry labels can be used for prediction.
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [4]:
# Train + Test + Results
model = FFNNMLP()
# Note: despite its name, `trained_model` is the Lightning Trainer; the network is `model`.
trained_model = L.Trainer(
    max_epochs=100,
    deterministic=True,
    log_every_n_steps=25,
    logger=CSVLogger(save_dir="logs", name="MNISTMLPNTK"),
)
trained_model.fit(model, train_dataloaders=train_loader)
trained_model.test(model, dataloaders=test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | net  | Sequential | 203 K  | train
--------------------------------------------
203 K     Trainable params
0         Non-trainable params
203 K     Total params
0.814     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


Training: |                                               | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                | 0/? [00:00<?, ?it/s]

[{'test loss (cross entropy)': 0.48374447226524353,
  'accuracy': 0.9089999794960022}]

In [5]:
# Apply NTK
# `%run -i` executes the script inside this notebook's interactive namespace, so the
# script can see names defined in earlier cells and the names it defines remain
# available afterwards.
# NOTE(review): presumably this script is what provides GaussianFit (used below) -- confirm.
%run -i 'misc/introduction_code_modded_KAN_NTK.py'

def cross_entropy_loss_batch(y_hat, y):
    """Return the unreduced cross-entropy: one loss value per sample in the batch."""
    per_sample_loss = F.cross_entropy(input=y_hat, target=y, reduction='none')
    return per_sample_loss

# Make sure the MLP lives on the selected device before it is handed to the NTK fit.
model = model.to(device)
# A fresh Adam optimizer over the MLP's parameters, passed to GaussianFit.fit below.
# NOTE(review): how (or whether) fit() steps this optimizer is not visible here -- confirm in the script.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# GaussianFit is not defined in this notebook; presumably it comes from the %run script above -- TODO confirm.
ntk_model = GaussianFit(model=model, device=device, noise_var=0.0)
# Fit on the training loader, using the per-sample (unreduced) cross-entropy as the loss.
ntk_model.fit(train_loader, optimizer, cross_entropy_loss_batch)

In [6]:
def check_ntk_acc(model, dataloader):
    """Return the fraction of samples in ``dataloader`` that ``model`` classifies correctly.

    Predictions and targets are both reduced with ``argmax`` over dim 1, so
    targets are expected to be score/one-hot vectors rather than class indices.
    The model is moved to the module-level ``device`` and evaluated in eval
    mode; afterwards it is put back into train mode only if it was in train
    mode on entry (also when an exception interrupts the loop).

    Args:
        model: Object with ``eval()``, ``train()``, ``to()`` and ``forward()``.
        dataloader: Iterable of ``(x, y)`` batches.

    Returns:
        Correct predictions divided by samples seen (a 0-dim tensor once at
        least one batch has been processed).

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    res = 0.0
    sumlength = 0
    # Remember the caller's mode; objects without a `training` flag fall back
    # to the previous behaviour of always ending in train mode.
    was_training = getattr(model, "training", True)
    model.eval()
    model.to(device)
    try:
        # NOTE(review): intentionally not wrapped in torch.no_grad() -- the model's
        # forward may rely on autograd (e.g. an NTK wrapper); confirm before adding it.
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            sumlength += len(x)
            res += (torch.argmax(model.forward(x), dim=1) == torch.argmax(y, dim=1)).sum()
    finally:
        if was_training:
            model.train()
    return res / sumlength

In [7]:
# Score the NTK model on the MNIST test loader and report it.
ntk_test_accuracy = check_ntk_acc(ntk_model, test_loader)
print(f'NTK Accuracy: {ntk_test_accuracy}')

NTK Accuracy: 0.18639999628067017


In [8]:
# R: True Value, C: Predicted Value
def make_predict_matrix(model, dataloader, num_classes=10):
    """Build a confusion matrix of counts for ``model`` over ``dataloader``.

    Entry ``[r, c]`` is the number of samples whose true class is ``r`` and
    whose predicted class is ``c``. Classes are obtained with ``argmax`` over
    dim 1 of both the model output and the targets. Inputs are moved to the
    module-level ``device``; the model's train/eval mode is left untouched.

    Args:
        model: Object whose ``forward(x)`` returns per-class scores.
        dataloader: Iterable of ``(x, y)`` batches with score/one-hot targets.
        num_classes: Number of classes (matrix side length). Defaults to 10.

    Returns:
        ``numpy.ndarray`` of shape ``(num_classes, num_classes)`` and integer dtype.
    """
    res = np.zeros(shape=(num_classes, num_classes), dtype=int)
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        # Bring the whole batch of class indices to the host in one transfer
        # instead of indexing NumPy with one device tensor per sample.
        pred_idx = torch.argmax(model.forward(x), dim=1).detach().cpu().numpy()
        true_idx = torch.argmax(y, dim=1).detach().cpu().numpy()
        # np.add.at accumulates correctly when the same (row, col) pair repeats.
        np.add.at(res, (true_idx, pred_idx), 1)
    return res

In [9]:
# Confusion matrix of the NTK model on the test loader (rows: true, columns: predicted).
ntk_predict_matrix = make_predict_matrix(ntk_model, test_loader)
print(ntk_predict_matrix)

[[175  65  91 123  85 133  81  98  54  75]
 [ 36 351 107  97 100 102  76 118  74  74]
 [ 62 102 203 109 103 101  91 114  92  55]
 [ 41 120 139 173 119 125  57  80  90  66]
 [ 51  69  88 119 178 120  96 102  72  87]
 [ 35  93  80 123 109 161  66  73  85  67]
 [ 59  56 132 114 142 125 121  71  72  66]
 [ 56  74 114 139 138  77  61 199  69 101]
 [ 58  81  98  96 152 101  69  97 137  85]
 [ 56  73 102  96 138  85  89 122  82 166]]


In [10]:
# Confusion matrix of the trained MLP on the test loader (rows: true, columns: predicted).
mlp_predict_matrix = make_predict_matrix(model, test_loader)
print(mlp_predict_matrix)

[[ 938    0    2    1    3   10    7    2   13    4]
 [   0 1113    5    3    1    1    4    1    7    0]
 [  15    3  908   16   11    7   14   27   26    5]
 [   8    2   24  870    0   34    8   19   36    9]
 [   5    3    4    0  880    0   11    2    4   73]
 [   5    4    7   25    7  779   28    5   26    6]
 [  15    3    9    0    8    7  914    0    2    0]
 [   2   12   14    6   12    4    4  950    1   23]
 [  15    6   15   17   11   14   14   25  833   24]
 [   5    6    3    9   30    9    2   32    8  905]]
