# KAN + MNIST
Purpose: Fit a Kolmogorov–Arnold Network (KAN) to the MNIST dataset in order to benchmark KAN performance.

The PyTorch Lightning library is used to run the training and testing loops.

In [1]:
# Imports
import numpy as np
import pandas as pd
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from lightning.pytorch.loggers import CSVLogger
from kan import *

# Compute device: use the GPU whenever CUDA reports one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Reproducibility -- https://lightning.ai/docs/pytorch/stable/common/trainer.html
# A single fixed seed is set through Lightning, with the `workers` flag enabled as well.
L.seed_everything(seed=42, workers=True)

# CUDA efficiency: select the 'high' float32 matmul precision mode.
torch.set_float32_matmul_precision(precision='high')

Seed set to 42


In [2]:
# Dataset Setup
# Fetch (downloading if needed) the MNIST train and test splits into ./temp/.
train_dataset = MNIST(root="./temp/", train=True, download=True)
test_dataset = MNIST(root="./temp/", train=False, download=True)

class LCDataset(Dataset): # Lightning Compatible Dataset
    """Tensor-backed dataset of flattened images paired with one-hot labels.

    Args:
        dataset: Object exposing ``data`` (an image tensor that can be reshaped
            to ``(-1, 28*28)``) and ``targets`` (integer class labels), such as
            the ``MNIST`` datasets created in this notebook.
        num_classes: Width of the one-hot target vectors. Defaults to 10.

    Attributes are built once in ``__init__`` so ``__getitem__`` is a plain
    tensor index.
    """

    def __init__(self, dataset, num_classes=10):
        # Flatten every image to a 784-element float32 vector.
        # NOTE(review): pixel values are not rescaled or normalized here.
        self.data = dataset.data.reshape(-1, 28*28).type(torch.float32)

        # One-hot encode with a fixed width so the target always has
        # `num_classes` columns, even when some class is absent from the split
        # (a label-derived encoding would silently produce fewer columns).
        labels = torch.as_tensor(dataset.targets, dtype=torch.long)
        self.target = F.one_hot(labels, num_classes=num_classes).type(torch.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(flattened_image, one_hot_label)`` pair at ``idx``."""
        return self.data[idx], self.target[idx]

# Batched loaders over the wrapped splits; only the training loader shuffles.
train_loader = DataLoader(
    dataset=LCDataset(train_dataset),
    batch_size=64,
    shuffle=True,
    num_workers=10,
)
test_loader = DataLoader(
    dataset=LCDataset(test_dataset),
    batch_size=64,
    num_workers=10,
)

In [3]:
# Model Declaration
class ClassicKAN(L.LightningModule):
    """Three stacked ``KANLayer`` modules (784 -> 240 -> 60 -> 10) trained with cross-entropy.

    Args:
        lr: Learning rate passed to the Adam optimizer. Defaults to 1e-4.
    """

    def __init__(self, lr=1e-4):
        super().__init__()
        self.lr = lr
        self.l1 = KANLayer(in_dim=28*28, out_dim=240, num=2, k=2)
        self.l2 = KANLayer(in_dim=240, out_dim=60, num=2, k=2)
        self.l3 = KANLayer(in_dim=60, out_dim=10, num=2, k=2)

    def forward(self, x):
        """Run ``x`` through the three layers and return the final layer's output."""
        # Each KANLayer call returns four values; only the first one is
        # propagated to the next layer, the remaining three are discarded.
        x1, _, _, _ = self.l1(x)
        x2, _, _, _ = self.l2(x1)
        x3, _, _, _ = self.l3(x2)
        return x3

    def _shared_step(self, batch):
        """Forward an ``(x, y)`` batch; return ``(y_pred, y, loss)`` with cross-entropy loss."""
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        return y_pred, y, loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        _, _, loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log the cross-entropy loss and the classification accuracy for one test batch."""
        y_pred, y, loss = self._shared_step(batch)
        # Predicted class vs. true class: both are taken as the argmax over dim 1.
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        self.log("test loss (cross entropy)", loss)
        self.log("accuracy", accuracy)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the model output for a batch.

        Accepts either a bare input tensor or an ``(x, y, ...)`` sequence such as
        the batches produced by the train/test loaders; in the latter case only
        the first element is fed to the model.
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [4]:
# Train + Test + Results
model = ClassicKAN()

# NOTE: despite its name, `trained_model` holds the Lightning Trainer; the
# network being fitted is `model`.
trained_model = L.Trainer(
    max_epochs=100,
    deterministic=True,
    logger=CSVLogger(save_dir="logs", name="MNISTKAN"),
)

# Fit on the training loader, then evaluate on the test loader. The test call
# stays last so its returned metrics are displayed as the cell output.
trained_model.fit(model, train_dataloaders=train_loader)
trained_model.test(model, dataloaders=test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type     | Params | Mode 
------------------------------------------
0 | l1   | KANLayer | 1.3 M  | train
1 | l2   | KANLayer | 102 K  | train
2 | l3   | KANLayer | 4.6 K  | train
------------------------------------------
1.2 M     Trainable params
210 K     Non-trainable params
1.4 M     Total params
5.719     Total estimated model params size (MB)
4         Modules in train mode
0         Modules in eval mode


Training: |                                               | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                | 0/? [00:00<?, ?it/s]

[{'test loss (cross entropy)': 0.15136770904064178,
  'accuracy': 0.972000002861023}]