# CNN + KAN + CIFAR10 + NTK + Few Shot
Purpose: Fit a CNN + KAN Siamese network to the CIFAR10 dataset in a few-shot (one-shot) setting, in order to benchmark KAN performance.

Furthermore, the PyTorch Lightning library is used for convenience.

In [1]:
# Imports
import numpy as np
import pandas as pd
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from lightning.pytorch.loggers import CSVLogger
from random import shuffle
import gc

import sys
sys.path.append('../convkans/kan_convolutional')
from KANLinear import *

# Setup Device -- use the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Setup Randomness -- https://lightning.ai/docs/pytorch/stable/common/trainer.html
# A fixed seed keeps the random class split and the pair shuffling repeatable between runs.
L.seed_everything(seed=42, workers=True)

# CUDA Efficiency -- request the 'high' float32 matmul precision mode (see the PyTorch docs)
torch.set_float32_matmul_precision('high')

Seed set to 42


In [2]:
# Dataset Setup -- Inspired by Hugo's Dataset Reformatting
batch_size = 64
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = CIFAR10("./temp/", train=True, download=True, transform=transform)
test_dataset = CIFAR10("./temp/", train=False, download=True, transform=transform)

# Step 1 -- Select which 5 classes to train on at random, and which 5 to test on.
class_train = np.random.choice(np.arange(0, 10), 5, replace=False)
class_test = np.setdiff1d(np.arange(0, 10), class_train)


# Step 2 -- Convert the train and test sets accordingly (individual lists for each class).
def _group_by_class(dataset, num_classes=10):
    """Split ``dataset`` into one shuffled list of ``(x, y)`` samples per class.

    The dataset is traversed exactly once (each access applies the transform),
    instead of once per class. Classes are shuffled in ascending label order,
    which is the same order the per-class loops used, so the random stream is
    consumed identically.

    Args:
        dataset: Iterable of ``(x, y)`` samples where ``y`` is an integer
            class index in ``[0, num_classes)``.
        num_classes: Number of per-class lists to create.

    Returns:
        A list of ``num_classes`` lists; entry ``z`` holds the samples of class ``z``.
    """
    grouped = [[] for _ in range(num_classes)]
    for x, y in dataset:
        grouped[y].append((x, y))
    for samples in grouped:
        shuffle(samples)
    return grouped


temp_train = _group_by_class(train_dataset)
temp_test = _group_by_class(test_dataset)

# Step 3 -- Convert them to similarity comparison format.

## Training + Validation Setup

# Total number of pair samples per split. create_training_and_validation() fills
# each split half with same-class pairs (label 1) and half with different-class
# pairs (label 0), so each label contributes `limit // 2` samples.
train_limit = 500000 # Limit in total
validation_limit = 20000 # Limit in total

def create_training_and_validation(): # Created function to force garbage collection
    """Build balanced pair datasets for training and validation.

    Reads the module-level ``class_train``, ``temp_train``, ``train_limit`` and
    ``validation_limit``. Every same-class pair of training samples becomes a
    positive ``(x1, x2, 1.0)`` triple; different-class pairs are sub-sampled at
    a regular stride into negative ``(x1, x2, 0.0)`` triples until there are as
    many negatives as positives. Both lists are shuffled and then interleaved
    (positive, negative, positive, ...) into the two splits.

    NOTE: all positive pairs are materialised in memory before the limits are
    applied, so peak memory grows quadratically with the class size.

    Returns:
        ``(train_dataset2, validation_dataset2)``: lists holding
        ``2 * (train_limit // 2)`` and ``2 * (validation_limit // 2)`` triples.

    Raises:
        ValueError: If fewer positive or negative pairs are available than the
            limits require.
    """
    train_half = train_limit // 2
    validation_half = validation_limit // 2
    pairs_needed = train_half + validation_half

    positive_label = np.float32(1)
    negative_label = np.float32(0)

    # Positive pairs: every unordered pair of samples within one training class.
    total_dataset2_one = []
    for z in class_train:
        samples = temp_train[z]
        for i in range(len(samples)):
            x1 = samples[i][0]
            for j in range(i + 1, len(samples)):
                total_dataset2_one.append((x1, samples[j][0], positive_label))

    def negative_candidates():
        """Yield (x1, x2) from different training classes in a fixed order."""
        for z1 in class_train:
            for z2 in class_train:
                if z1 == z2:
                    continue
                first, second = temp_train[z1], temp_train[z2]
                for i in range(len(first)):
                    x1 = first[i][0]
                    for j in range(i + 1, len(second)):
                        yield x1, second[j][0]

    # Exact number of candidates the generator above produces. Counting from the
    # actual training classes (rather than assuming every class has the size of
    # class 0) keeps the stride valid for any class sizes.
    max_val = sum(
        max(0, len(temp_train[z2]) - i - 1)
        for z1 in class_train
        for z2 in class_train
        if z1 != z2
        for i in range(len(temp_train[z1]))
    )
    # Stride between kept candidates; never below 1 so the modulo is always defined.
    mod_val = max(1, max_val // len(total_dataset2_one)) if total_dataset2_one else 1

    total_dataset2_zero = []
    for count, (x1, x2) in enumerate(negative_candidates()):
        if len(total_dataset2_zero) >= len(total_dataset2_one):
            break  # Balanced already; the remaining candidates would all be skipped.
        if count % mod_val == 0:
            total_dataset2_zero.append((x1, x2, negative_label))

    shuffle(total_dataset2_one)
    shuffle(total_dataset2_zero)

    print(f'One List Size: {len(total_dataset2_one)}, Zero List Size: {len(total_dataset2_zero)}')

    available = min(len(total_dataset2_one), len(total_dataset2_zero))
    if pairs_needed > available:
        raise ValueError(
            f'train_limit and validation_limit require {pairs_needed} pairs per label, '
            f'but only {available} are available'
        )

    train_dataset2 = []
    for i in range(train_half):
        train_dataset2.append(total_dataset2_one[i])
        train_dataset2.append(total_dataset2_zero[i])

    validation_dataset2 = []
    for i in range(train_half, pairs_needed):
        validation_dataset2.append(total_dataset2_one[i])
        validation_dataset2.append(total_dataset2_zero[i])

    # The two full pair lists are locals: they are released when the function
    # returns, so no manual emptying is needed.
    return train_dataset2, validation_dataset2

train_dataset2, validation_dataset2 = create_training_and_validation()

# Run a collection pass now that the builder's temporary pair lists are out of scope.
gc.collect()

## Testing Setup
# Support set: the last sample (after shuffling) of every held-out class.
one_shot_test = [temp_test[z][-1] for z in class_test]

# Query set: all remaining samples of the held-out classes, i.e. everything
# except the last value of each class, since it is used in the one_shot_test set.
one_shot_query = [(x, y) for z in class_test for x, y in temp_test[z][:-1]]

# Step 4: Convert to Lightning Compatible Datasets
class LCTrainDataset(Dataset):
    """Map-style dataset over an indexable collection of ``(x1, x2, y)`` pair samples."""

    def __init__(self, dataset):
        # Indexable collection whose items unpack into exactly three values.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        first, second, label = self.dataset[idx]
        return (first, second, label)

class LCTestDataset(Dataset):
    """Map-style dataset over an indexable collection of ``(x, y)`` samples."""

    def __init__(self, dataset):
        # Indexable collection whose items unpack into exactly two values.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, label = self.dataset[idx]
        return (sample, label)

# Only the training loader reshuffles its samples every epoch.
train_loader = DataLoader(
    dataset=LCTrainDataset(train_dataset2),
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
)
validation_loader = DataLoader(
    dataset=LCTrainDataset(validation_dataset2),
    batch_size=batch_size,
    shuffle=False,
    num_workers=10,
)
# Query samples for the one-shot evaluation, served in their stored order.
test_loader = DataLoader(
    dataset=LCTestDataset(one_shot_query),
    batch_size=batch_size,
    num_workers=10,
)

Files already downloaded and verified
Files already downloaded and verified
One List Size: 62487500, Zero List Size: 62487500


In [9]:
# Model Declaration -- Siamese Network, implementation inspired by the video by Shusen Wang: https://www.youtube.com/watch?v=4S-XDefSjTM
class Embedding(nn.Module):
    """Convolutional feature extractor shared by both branches of the Siamese network.

    For a ``(N, 3, 32, 32)`` input the flattened output has 3136 features per sample.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in execution order so parameter initialisation
        # draws from the random stream in the same sequence.
        stages = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=2, padding=0, bias=False),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=0, bias=False),
            nn.MaxPool2d(kernel_size=4),
            nn.Dropout(p=0.5),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.Flatten(start_dim=1),
        ]
        self.cnn = nn.Sequential(*stages)

    def forward(self, x):
        features = self.cnn(x)
        return features

class DenseLayer(nn.Module):
    """KAN head that maps a 3136-feature vector to a single output value.

    Two KAN stages (3136 -> 128 -> 1), each preceded by dropout.
    """

    def __init__(self):
        super().__init__()
        # Built in execution order so the stages are initialised in the same sequence.
        stages = [
            nn.Dropout(p=0.5),
            KAN(layers_hidden=[3136, 128], grid_size=2, spline_order=2),
            nn.Dropout(p=0.5),
            KAN(layers_hidden=[128, 1], grid_size=2, spline_order=2),
        ]
        self.kan = nn.Sequential(*stages)

    def forward(self, x):
        score = self.kan(x)
        return score

class SiameseNet(L.LightningModule):
    """Siamese similarity network: shared CNN embedding followed by a KAN head.

    ``forward(x1, x2)`` embeds both inputs with the same ``Embedding``, takes the
    element-wise absolute difference and maps it through ``DenseLayer`` and a
    sigmoid, giving one similarity score in ``(0, 1)`` per pair.
    """

    def __init__(self):
        super().__init__()
        self.f = Embedding()
        self.dl = DenseLayer()
        self.out = nn.Sigmoid()

    def forward(self, x1, x2):
        """Return one similarity score per pair, as a 1-D tensor of length ``len(x1)``."""
        x1e = self.f(x1)
        x2e = self.f(x2)
        z = torch.abs(x1e - x2e)
        # Drop only the trailing singleton dimension. A bare squeeze() would also
        # remove the batch dimension for a batch of one, producing a 0-d tensor
        # that breaks the loss shapes and the per-sample indexing in test_step.
        return self.out(self.dl(z)).squeeze(-1)

    def _pair_loss(self, batch):
        """Compute the MSE between predicted similarity and the 0/1 pair label."""
        x1, x2, y = batch
        y_pred = self(x1, x2)
        return F.mse_loss(y_pred, y) # MSE Loss works better for NTK

    def training_step(self, batch, batch_idx):
        loss = self._pair_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._pair_loss(batch)
        self.log("validation_loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """One-shot classification accuracy for a batch of query samples.

        Each query is scored against every support sample in the module-level
        ``one_shot_test`` list and assigned the label of the best-scoring one.
        """
        x1, y = batch
        # scores[s, b] = similarity between query b and support sample s.
        scores = torch.stack([
            self(x1, x2.to(x1.device).unsqueeze(0).expand(len(x1), -1, -1, -1))
            for x2, _ in one_shot_test
        ])
        best_support = torch.argmax(scores, dim=0)
        support_labels = torch.tensor([label for _, label in one_shot_test], device=y.device)
        list_y_predicted = support_labels[best_support]
        accuracy = torch.sum(torch.eq(list_y_predicted, y)) / len(y)
        self.log("accuracy", accuracy)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Score the pairs in ``batch``.

        ``batch`` is expected to start with the two inputs ``(x1, x2, ...)``;
        any further entries (e.g. a label) are ignored. Passing the batch as a
        single argument would always fail because ``forward`` needs both inputs.
        """
        x1, x2 = batch[0], batch[1]
        return self(x1, x2)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=5e-4)
        return optimizer

In [10]:
# Train + Test + Results
model = SiameseNet()
trained_model = L.Trainer(
    max_epochs=10,
    deterministic=True,
    logger=CSVLogger("logs", name="CIFAR10FewCNNKANNTK"),
)
# Fit on the pair loaders, then run the one-shot evaluation on the query loader.
trained_model.fit(model, train_loader, validation_loader)
trained_model.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | f    | Embedding  | 18.9 K | train
1 | dl   | DenseLayer | 2.4 M  | train
2 | out  | Sigmoid    | 0      | train
--------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.713     Total estimated model params size (MB)
24        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                  | 0/? [00:00<?, ?it/s]

Training: |                                         | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

Validation: |                                       | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.20380380749702454}]

In [None]:
# Apply NTK
%run -i './introduction_code_modified.py'

def cross_entropy_loss_batch(y_hat, y):
    """Return the un-reduced cross-entropy loss: one value per sample in the batch."""
    per_sample_loss = F.cross_entropy(input=y_hat, target=y, reduction='none')
    return per_sample_loss

# Move the trained model onto the active device before wrapping it.
model = model.to(device)
# NOTE(review): unlike SiameseNet.configure_optimizers, this optimizer sets no weight_decay -- confirm that is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# GaussianFit and MSELoss_batch are not defined in this notebook; presumably they come from
# the `%run -i` of introduction_code_modified.py above -- verify.
ntk_model = GaussianFit(model=model, device=device, noise_var=0.0)
# NOTE(review): ntk_loader is not created in any visible cell -- confirm where it is defined.
ntk_model.fit(ntk_loader, optimizer, MSELoss_batch) # MSELoss seems to work better for NTK Models

In [None]:
def check_ntk_acc(model, dataloader):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and moved to the module-level ``device``
    for the duration of the call. Afterwards train mode is re-enabled only if
    the model was in train mode on entry, including when a batch raises.

    Args:
        model: Object exposing ``eval()``, ``train()``, ``to(device)`` and
            ``forward(x)`` returning per-class scores of shape ``(batch, classes)``.
        dataloader: Iterable of ``(x, y)`` batches. ``y`` may be one-hot /
            per-class scores ``(batch, classes)`` or class indices ``(batch,)``.

    Returns:
        Fraction of correctly classified samples (a tensor, as produced by
        dividing the summed matches by the sample count).
    """
    # Objects without a `training` flag are treated as "was training", which
    # matches the previous unconditional train() call.
    was_training = getattr(model, "training", True)
    model.eval()
    model.to(device)

    correct = 0.0
    total = 0
    try:
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            predicted = torch.argmax(model.forward(x), dim=1)
            # Accept already-integer labels as well as one-hot encoded ones.
            target = torch.argmax(y, dim=1) if y.dim() > 1 else y
            correct += (predicted == target).sum()
            total += len(x)
    finally:
        if was_training:
            model.train()
    return correct / total

In [None]:
# NOTE(review): test_loader serves (image, integer label) batches built from LCTestDataset;
# confirm that ntk_model accepts single-image input before relying on this number.
print(f'NTK Accuracy: {check_ntk_acc(ntk_model, test_loader)}')

In [None]:
# R: True Value, C: Predicted Value
def make_predict_matrix(model, dataloader, num_classes=10):
    """Count (true class, predicted class) occurrences over ``dataloader``.

    The model is evaluated in eval mode (so e.g. dropout does not perturb the
    predictions); train mode is re-enabled afterwards only if it was active on
    entry. Batches are moved to the module-level ``device``.

    Args:
        model: Object exposing ``eval()``, ``train()`` and ``forward(x)``
            returning per-class scores of shape ``(batch, classes)``.
        dataloader: Iterable of ``(x, y)`` batches. ``y`` may be one-hot /
            per-class scores ``(batch, classes)`` or class indices ``(batch,)``.
        num_classes: Side length of the returned square matrix.

    Returns:
        Integer array of shape ``(num_classes, num_classes)`` where entry
        ``[r, c]`` is the number of samples with true class ``r`` predicted as ``c``.
    """
    res = np.zeros(shape=(num_classes, num_classes), dtype=int)
    # Objects without a `training` flag are treated as "was training".
    was_training = getattr(model, "training", True)
    model.eval()
    try:
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            x_arg = torch.argmax(model.forward(x), dim=1)
            # Accept already-integer labels as well as one-hot encoded ones.
            y_arg = torch.argmax(y, dim=1) if y.dim() > 1 else y.long()
            # Unbuffered add so repeated (true, predicted) pairs are all counted.
            np.add.at(res, (y_arg.cpu().numpy(), x_arg.cpu().numpy()), 1)
    finally:
        if was_training:
            model.train()
    return res

In [None]:
# True-vs-predicted count matrix of the NTK-wrapped model on the one-shot query loader.
print(make_predict_matrix(ntk_model, test_loader))

In [None]:
# NOTE(review): `model` is the SiameseNet, whose forward takes two inputs (x1, x2), while
# make_predict_matrix calls model.forward(x) with a single input -- this call looks like it
# would raise a TypeError; confirm the intended model/loader here.
print(make_predict_matrix(model, test_loader))