In [26]:
# Inspired by martin's implementation: https://github.com/iarata/02456-kan-ntk-project/blob/Phase1Experiments/Experiments/MNISTKANNTK.ipynb
import numpy as np
import pandas as pd
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from lightning.pytorch.loggers import WandbLogger
from introduction_code import GaussianFit, MSELoss_batch
from typing import Union

import sys
sys.path.append('./Convolutional-KANs/kan_convolutional')
from kan import KAN

# Setup Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Setup Randomness -- https://lightning.ai/docs/pytorch/stable/common/trainer.html
L.seed_everything(628, workers=True)

# CUDA Efficiency
torch.set_float32_matmul_precision('high')

# Logging
import wandb
wandb.login(key='2a6309b336d51d2918fc4fb2d51ffef9505c370a')

Device: cpu


Seed set to 628
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\hugom\_netrc


True

In [27]:
# Dataset Setup -- Inspired by Hugo's Dataset Reformatting
# Reformatted, due to odd issues when using NTK on it
class LCDataset(Dataset):
    def __init__(self, dataset, num_classes, limit=-1):
        self.limit = limit
        self.num_classes = num_classes
        if self.limit != -1:
            sub = list(np.random.permutation(np.arange(len(dataset)))[0:self.limit]) # Take a random sample of the first some elements.
            self.dataset = Subset(dataset, sub)
        else:
            self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        y_one_hot = torch.zeros(self.num_classes)
        y_one_hot[y] = 1
        return x, y_one_hot


# Split train_dataset into train and validation sets
from torch.utils.data import random_split
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = MNIST("./temp/", train=True, download=True, transform=transform)
test_dataset = MNIST("./temp/", train=False, download=True, transform=transform)
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Add validation dataloader
num_workers = 0
def get_dataloader(split: Union["train", "val", "test", "ntk"]="train", batch_size=64, limit=-1):
    data_loader = None
    if split == "ntk":
        data_loader = DataLoader(
            LCDataset(train_dataset, num_classes=10, limit=500), 
            batch_size=batch_size, num_workers=num_workers
        )
    elif split == "train":
        data_loader = DataLoader(
            LCDataset(train_dataset, num_classes=10, limit=limit), 
            batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
    elif split == "val":
        data_loader = DataLoader(
            LCDataset(val_dataset, num_classes=10, limit=limit), 
            batch_size=batch_size, num_workers=num_workers
        )
    elif split == "test":
        data_loader = DataLoader(
            LCDataset(test_dataset, num_classes=10, limit=500), 
            batch_size=batch_size, num_workers=num_workers
        )
    return data_loader


In [28]:
# Model time
class ClassicKAN(L.LightningModule):
    def __init__(self, num_hidden_layers=1, hidden_dim=64, inp_size=28*28, out_size=10, grid_size=2, spline_order=2):
        super().__init__()
        self.inp_size = inp_size
        layers_hidden = [inp_size] + [hidden_dim] * num_hidden_layers + [out_size]
        self.net = KAN(width=layers_hidden, grid=3, k=3)
    
    def forward(self, x):
        x = x.view(-1, self.inp_size)
        return self.net(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y)  # MSE Loss works better for NTK
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y)
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y)
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        self.log("test_loss", loss)
        self.log("test_accuracy", accuracy)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

In [29]:
def check_ntk_acc(model, dataloader):
    res = 0.0
    sumlength = 0
    model.eval()
    model.to(device)
    for it in iter(dataloader):
        x, y = it
        x = x.to(device)
        y = y.to(device)
        sumlength += len(x)
        res += (torch.argmax(model.forward(x), dim=1) == torch.argmax(y, dim=1)).sum()
    model.train()
    return res / sumlength

In [30]:
import torch
from torch.utils.data import DataLoader
from torch.nn import MSELoss
from matplotlib import pyplot as plt

def check_ntk_acc(model, dataloader):
    res = 0.0
    sumlength = 0
    model.eval()
    model.to(device)
    with torch.no_grad():
        for it in iter(dataloader):
            x, y = it
            x = x.to(device)
            y = y.to(device)
            sumlength += len(x)
            res += (torch.argmax(model(x), dim=1) == torch.argmax(y, dim=1)).sum().item()
    model.train()
    return res / sumlength

# Define configuration
config = {
    "epochs": 8,
    "learning_rate": 0.001,
    "batch_size": 64,
    "num_hidden": 2,
    "hidden_dim": 64,
    "limit": 128,
    "grid_size": 4,
    "spline_order": 3
}

# Load data
train_loader = get_dataloader(split="train", batch_size=config["batch_size"], limit=-1)
val_loader = get_dataloader(split="val", batch_size=config["batch_size"], limit=-1)
test_loader = get_dataloader(split="test", batch_size=config["batch_size"])
ntk_loader = get_dataloader(split="ntk", batch_size=config["batch_size"], limit=config["limit"])

# Initialize model
model = ClassicKAN(
    num_hidden_layers=config["num_hidden"],
    hidden_dim=config["hidden_dim"],
    grid_size=config["grid_size"],
    spline_order=config["spline_order"]
).to(device)

# Train and test model
trainer = L.Trainer(
    max_epochs=config["epochs"],
    deterministic=True,
    val_check_interval=1  # Validate every 1 epoch
)
trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)

# Compute NTK accuracy
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
ntk_model = GaussianFit(model=model, device=device, noise_var=0.0)
ntk_model.fit(ntk_loader, optimizer, MSELoss())
ntk_acc = check_ntk_acc(model, test_loader)
print(f"NTK Accuracy: {ntk_acc}")

# Log model predictions
example_batch = next(iter(test_loader))
example_images, example_labels = example_batch
example_images, example_labels = example_images.to(device), example_labels.to(device)

predictions_KAN = model(example_images).argmax(dim=1)
predictions_NTKAN = ntk_model(example_images).argmax(dim=1)

# Display some test examples
for i in range(min(len(example_images), 10)):
    plt.imshow(example_images[i].cpu().numpy().squeeze(), cmap='gray')
    plt.title(f"Pred: {predictions_KAN[i]}, Label: {example_labels[i]}")
    plt.show()

    plt.imshow(example_images[i].cpu().numpy().squeeze(), cmap='gray')
    plt.title(f"Pred NTKAN: {predictions_NTKAN[i]}, Label: {example_labels[i]}")
    plt.show()

checkpoint directory created: ./model
saving model version 0.0


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1)` was configured so validation will run after every batch.

  | Name | Type    | Params | Mode 
-----------------------------------------
0 | net  | MultKAN | 778 K  | train
-----------------------------------------
658 K     Trainable params
119 K     Non-trainable params
778 K     Total params
3.114     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\hugom\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
c:\Users\hugom\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [31]:
inp = [f"({i},{j})" for i in range(28) for j in range(28)]
out = list(range(10))
model.net.plot(beta=100, scale=1, in_vars=inp, out_vars=out)

IndexError: list index out of range

In [None]:
model.net.plot()

IndexError: list index out of range

In [None]:
import numpy as np
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import transforms
from typing import Union
import matplotlib.pyplot as plt

# Dataset Setup
class LCDataset(Dataset):
    def __init__(self, dataset, num_classes, limit=-1):
        self.limit = limit
        self.num_classes = num_classes
        if self.limit != -1:
            sub = list(np.random.permutation(np.arange(len(dataset)))[0:self.limit])
            self.dataset = Subset(dataset, sub)
        else:
            self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        y_one_hot = torch.zeros(self.num_classes)
        y_one_hot[y] = 1
        return x, y_one_hot

class ClassicKAN(L.LightningModule):
    def __init__(self, num_hidden_layers=1, hidden_dim=64, inp_size=28*28, out_size=10, grid_size=2, spline_order=2):
        super().__init__()
        self.inp_size = inp_size
        self.out_size = out_size
        layers_hidden = [inp_size] + [hidden_dim] * num_hidden_layers + [out_size]
        self.net = KAN(width=layers_hidden, grid=grid_size, k=spline_order)
        
        # Add visualization helpers
        self.example_inputs = None
        self.example_outputs = None
    
    def forward(self, x):
        x = x.view(-1, self.inp_size)
        return self.net(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y)
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        
        # Store example inputs/outputs for visualization
        if batch_idx == 0:
            self.example_inputs = x[:5].detach().cpu()
            self.example_outputs = y[:5].detach().cpu()
            
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y)
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y)
        v1 = torch.argmax(y_pred, dim=1)
        v2 = torch.argmax(y, dim=1)
        accuracy = torch.sum(torch.eq(v1, v2)) / len(y)
        self.log("test_loss", loss)
        self.log("test_accuracy", accuracy)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)
    
    def visualize_prediction(self, x, y_true=None, y_pred=None):
        """Visualize a single prediction"""
        plt.figure(figsize=(8, 4))
        
        plt.subplot(1, 2 if y_pred is not None else 1, 1)
        plt.imshow(x.view(28, 28).cpu().numpy(), cmap='gray')
        if y_true is not None:
            plt.title(f'True: {torch.argmax(y_true).item()}')
        
        if y_pred is not None:
            plt.subplot(1, 2, 2)
            plt.bar(range(self.out_size), y_pred.cpu().numpy())
            plt.title(f'Prediction: {torch.argmax(y_pred).item()}')
        
        plt.tight_layout()
        plt.show()

def train_and_evaluate(config):
    # Set seeds for reproducibility
    L.seed_everything(628, workers=True)
    torch.set_float32_matmul_precision('high')
    
    # Prepare data
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = MNIST("./temp/", train=True, download=True, transform=transform)
    test_dataset = MNIST("./temp/", train=False, download=True, transform=transform)
    
    # Split train into train/val
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
    
    # Create dataloaders
    train_loader = DataLoader(
        LCDataset(train_dataset, num_classes=10, limit=config.get("limit", -1)),
        batch_size=config["batch_size"],
        shuffle=True
    )
    
    val_loader = DataLoader(
        LCDataset(val_dataset, num_classes=10),
        batch_size=config["batch_size"]
    )
    
    test_loader = DataLoader(
        LCDataset(test_dataset, num_classes=10),
        batch_size=config["batch_size"]
    )
    
    # Initialize model
    model = ClassicKAN(
        num_hidden_layers=config["num_hidden"],
        hidden_dim=config["hidden_dim"],
        grid_size=config["grid_size"],
        spline_order=config["spline_order"]
    )
    
    # Train model
    trainer = L.Trainer(
        max_epochs=config["epochs"],
        deterministic=True,
        val_check_interval=1,
        enable_progress_bar=True
    )
    
    trainer.fit(model, train_loader, val_loader)
    test_results = trainer.test(model, test_loader)
    
    return model, test_results

if __name__ == "__main__":
    config = {
        "epochs": 8,
        "batch_size": 64,
        "num_hidden": 2,
        "hidden_dim": 64,
        "limit": 128,
        "grid_size": 4,
        "spline_order": 3
    }
    
    model, results = train_and_evaluate(config)
    print(f"Test Results: {results}")
    
    # Visualize some predictions
    model.eval()
    with torch.no_grad():
        test_loader = DataLoader(
            LCDataset(MNIST("./temp/", train=False), num_classes=10),
            batch_size=5
        )
        batch = next(iter(test_loader))
        x, y = batch
        y_pred = model(x)
        
        for i in range(5):
            model.visualize_prediction(x[i], y[i], y_pred[i])