# ConvKAN + MNIST
Purpose: Fit a Convolutional KAN to the CIFAR-10 dataset, for benchmarking the KAN performance.

Furthermore, the PyTorch Lightning library is used for convenience.

I've copied some parts from the KAN experimentation and MNISTMLP, due to there being some degree of overlap between the two.

In [1]:
# Imports
import numpy as np
import pandas as pd
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from lightning.pytorch.loggers import CSVLogger

import sys
sys.path.append('../convkans/kan_convolutional')
from KANLinear import *
from KANConv import KAN_Convolutional_Layer

# Setup Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Setup Randomness -- https://lightning.ai/docs/pytorch/stable/common/trainer.html
L.seed_everything(42, workers=True)

# CUDA Efficiency
torch.set_float32_matmul_precision('high')

Seed set to 42


In [2]:
# Dataset Setup -- Inspired by Hugo's Dataset Reformatting
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = MNIST("./temp/", train=True, download=True, transform=transform)
test_dataset = MNIST("./temp/", train=False, download=True, transform=transform)

# Reformatted, due to odd issues when using NTK on it
class LCDataset(Dataset): # Lightning Compatible Dataset
    def __init__(self, dataset, num_classes, limit=-1):
        self.limit = limit
        self.num_classes = num_classes
        if self.limit != -1:
            sub = list(np.random.permutation(np.arange(len(train_dataset)))[0:self.limit]) # Take a random sample of the first some elements.
            self.dataset = Subset(dataset, sub)
        else:
            self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        y_one_hot = torch.zeros(self.num_classes)
        y_one_hot[y] = 1
        return x, y_one_hot

batch_size = 64
ntk_loader = DataLoader(LCDataset(train_dataset, num_classes=10, limit=100), batch_size=batch_size, num_workers=10)
train_loader = DataLoader(LCDataset(train_dataset, num_classes=10), batch_size=batch_size, shuffle=True, num_workers=10)
test_loader = DataLoader(LCDataset(test_dataset, num_classes=10), batch_size=batch_size, num_workers=10)

In [3]:
# Model Declaration
class ConvKAN(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.convkan = nn.Sequential(
            KAN_Convolutional_Layer(in_channels=1, out_channels=16, kernel_size=(2, 2), grid_size=2, spline_order=2, device=device),
            nn.MaxPool2d(4),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Flatten(1)
        )

        self.dl = nn.Sequential(
            KAN(layers_hidden=[576, 10], grid_size=2, spline_order=2)
        )
    
    def forward(self, x):
        temp = self.convkan(x)
        #print(len(temp[0]))
        return self.dl(temp)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y) # MSE Loss works better for NTK
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y) # MSE Loss works better for NTK
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        self.log("test loss (Cross Entropy Loss)", loss)
        self.log("accuracy", accuracy)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-4)
        return optimizer

In [4]:
# Train + Test + Results
model = ConvKAN()
trained_model = L.Trainer(max_epochs=20, deterministic=True, logger=CSVLogger("logs", name="MNISTConvKANNTK"))
trained_model.fit(model, train_loader)
trained_model.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | convkan | Sequential | 416    | train
1 | dl      | Sequential | 34.6 K | train
-----------------------------------------------
35.0 K    Trainable params
0         Non-trainable params
35.0 K    Total params
0.140     Total estimated model params size (MB)
60        Modules in train mode
0         Modules in eval mode


Training: |                                               | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                | 0/? [00:00<?, ?it/s]

[{'test loss (Cross Entropy Loss)': 0.05913151428103447,
  'accuracy': 0.9812999963760376}]

In [5]:
# Apply NTK
%run -i './introduction_code_modified.py'

def cross_entropy_loss_batch(y_hat, y):
    return F.cross_entropy(y_hat, y, reduction='none')

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
ntk_model = GaussianFit(model=model, device=device, noise_var=0.0)
ntk_model.fit(ntk_loader, optimizer, MSELoss_batch) # MSELoss seems to work better for NTK Models

In [6]:
def check_ntk_acc(model, dataloader):
    res = 0.0
    sumlength = 0
    model.eval()
    model.to(device)
    for it in iter(dataloader):
        x, y = it
        x = x.to(device)
        y = y.to(device)
        sumlength += len(x)
        res += (torch.argmax(model.forward(x), dim=1) == torch.argmax(y, dim=1)).sum()
    model.train()
    return res / sumlength

In [7]:
print(f'NTK Accuracy: {check_ntk_acc(ntk_model, test_loader)}')

NTK Accuracy: 0.12559999525547028


In [8]:
# R: True Value, C: Predicted Value
def make_predict_matrix(model, dataloader):
    res = np.zeros(shape=(10, 10), dtype=int)
    for it in iter(dataloader):
        x, y = it
        x = x.to(device)
        y = y.to(device)
        x_arg = torch.argmax(model.forward(x), dim=1)
        y_arg = torch.argmax(y, dim=1)
        for i in range(len(x_arg)):
            res[y_arg[i], x_arg[i]] += 1
    return res

In [9]:
print(make_predict_matrix(ntk_model, test_loader))

[[321  17 125  93 130  31  18  87  91  67]
 [ 23 724  61 181  33  36  34  36   1   6]
 [ 14  48 713  58  92  38  16  16  16  21]
 [ 25  80 115 564  27  46  58  58  22  15]
 [ 99  87  62  74 427  56  44  29  83  21]
 [ 22  26  42  41  48 526  21  37  77  52]
 [ 26  51  56  84  53  48 394 131  22  93]
 [ 25 112  47 171   9  18   5 565  24  52]
 [ 25 102  89  29  19  96  18  88 467  41]
 [ 26 150  82  63  34  46  35  41  32 500]]


In [10]:
print(make_predict_matrix(model, test_loader))

[[ 971    0    0    0    0    1    2    1    4    1]
 [   0 1129    3    1    0    1    0    0    1    0]
 [   1    2 1020    2    1    0    1    1    4    0]
 [   0    0    4  994    0    5    0    2    5    0]
 [   1    1    2    0  970    0    0    2    2    4]
 [   2    0    1   13    0  866    3    0    6    1]
 [   5    2    1    0    2    3  941    0    4    0]
 [   0    4   19    3    1    0    0  994    3    4]
 [   2    0    2    1    1    1    0    4  960    3]
 [   0    3    0    5   11    2    0   10    6  972]]
