# CNN + KAN + MNIST + NTK
Purpose: Fit a CNN + KAN to the CIFAR10 dataset, for benchmarking the KAN performance.

Furthermore, the PyTorch Lightning library is used for convenience.

In [1]:
# Imports
import numpy as np
import pandas as pd
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from lightning.pytorch.loggers import CSVLogger

import sys
sys.path.append('../convkans/kan_convolutional')
from KANLinear import *

# Setup Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Setup Randomness -- https://lightning.ai/docs/pytorch/stable/common/trainer.html
L.seed_everything(42, workers=True)

# CUDA Efficiency
torch.set_float32_matmul_precision('high')

Seed set to 42


In [2]:
# Dataset Setup -- Inspired by Hugo's Dataset Reformatting
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = CIFAR10("./temp/", train=True, download=True, transform=transform)
test_dataset = CIFAR10("./temp/", train=False, download=True, transform=transform)

# Reformatted, due to odd issues when using NTK on it
class LCDataset(Dataset): # Lightning Compatible Dataset
    def __init__(self, dataset, num_classes, limit=-1):
        self.limit = limit
        self.num_classes = num_classes
        if self.limit != -1:
            sub = list(np.random.permutation(np.arange(len(train_dataset)))[0:self.limit]) # Take a random sample of the first some elements.
            self.dataset = Subset(dataset, sub)
        else:
            self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        y_one_hot = torch.zeros(self.num_classes)
        y_one_hot[y] = 1
        return x, y_one_hot

batch_size = 64
ntk_loader = DataLoader(LCDataset(train_dataset, num_classes=10, limit=100), batch_size=batch_size, num_workers=10)
train_loader = DataLoader(LCDataset(train_dataset, num_classes=10), batch_size=batch_size, shuffle=True, num_workers=10)
test_loader = DataLoader(LCDataset(test_dataset, num_classes=10), batch_size=batch_size, num_workers=10)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Model Declaration
class CNNandKAN(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=2, padding=0, bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=0, bias=False),
            nn.MaxPool2d(4),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Flatten(1)
        )
        
        self.kan = nn.Sequential(
            KAN(layers_hidden=[3136, 512, 10], grid_size=2, spline_order=2)
        )
    
    def forward(self, x):
        x = self.cnn(x)
        #print(x.size())
        return self.kan(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y) # MSE Loss works better for NTK
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y) # MSE Loss works better for NTK
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        self.log("test loss (MSE Loss)", loss)
        self.log("accuracy", accuracy)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

In [4]:
# Train + Test + Results
model = CNNandKAN()
trained_model = L.Trainer(max_epochs=20, deterministic=True, logger=CSVLogger("logs", name="CIFAR10CNNKANNTK"))
trained_model.fit(model, train_loader)
trained_model.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | cnn  | Sequential | 18.9 K | train
1 | kan  | Sequential | 9.7 M  | train
--------------------------------------------
9.7 M     Trainable params
0         Non-trainable params
9.7 M     Total params
38.734    Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Training: |                                               | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                | 0/? [00:00<?, ?it/s]

[{'test loss (MSE Loss)': 0.04555017128586769, 'accuracy': 0.7488999962806702}]

In [5]:
# Apply NTK
%run -i './introduction_code_modified.py'

def cross_entropy_loss_batch(y_hat, y):
    return F.cross_entropy(y_hat, y, reduction='none')

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
ntk_model = GaussianFit(model=model, device=device, noise_var=0.0)
ntk_model.fit(ntk_loader, optimizer, MSELoss_batch) # MSELoss seems to work better for NTK Models

In [6]:
def check_ntk_acc(model, dataloader):
    res = 0.0
    sumlength = 0
    model.eval()
    model.to(device)
    for it in iter(dataloader):
        x, y = it
        x = x.to(device)
        y = y.to(device)
        sumlength += len(x)
        res += (torch.argmax(model.forward(x), dim=1) == torch.argmax(y, dim=1)).sum()
    model.train()
    return res / sumlength

In [7]:
print(f'NTK Accuracy: {check_ntk_acc(ntk_model, test_loader)}')

NTK Accuracy: 0.7484999895095825


In [8]:
# R: True Value, C: Predicted Value
def make_predict_matrix(model, dataloader):
    res = np.zeros(shape=(10, 10), dtype=int)
    for it in iter(dataloader):
        x, y = it
        x = x.to(device)
        y = y.to(device)
        x_arg = torch.argmax(model.forward(x), dim=1)
        y_arg = torch.argmax(y, dim=1)
        for i in range(len(x_arg)):
            res[y_arg[i], x_arg[i]] += 1
    return res

In [9]:
print(make_predict_matrix(ntk_model, test_loader))

[[802  16  39  25  18   9   5   6  57  23]
 [ 14 845   8  11   5   5   8   4  20  80]
 [ 52   3 635  66  88  59  44  30  12  11]
 [ 28   8  62 572  53 162  58  27  18  12]
 [ 22   4  69  72 684  29  41  59  15   5]
 [ 12   1  48 159  39 664  14  49   7   7]
 [  9   2  38  47  41  29 819   7   8   0]
 [ 19   5  19  42  50  54   8 785   3  15]
 [ 52  22  14  13   6   7   3   6 847  30]
 [ 31  48   6  22   5   7   3   9  26 843]]


In [10]:
print(make_predict_matrix(model, test_loader))

[[808  16  40  26  16   9   5   6  50  24]
 [ 15 841   8  12   4   5   8   4  19  84]
 [ 51   2 643  68  80  59  41  32  13  11]
 [ 28   9  64 578  52 161  54  26  16  12]
 [ 24   4  74  75 676  27  41  59  15   5]
 [ 11   1  50 160  34 664  15  52   6   7]
 [  9   2  40  55  38  28 812   8   8   0]
 [ 19   4  21  41  45  59   8 787   3  13]
 [ 59  23  15  13   6   7   3   6 840  28]
 [ 33  47   6  23   5   7   3   9  26 841]]
