In [1]:
# import os
# os.chdir('/root')

In [2]:
# !git clone https://github.com/AntonioTepsich/Convolutional-KANs
# !pip install tqdm pyprof
# !mv 'Convolutional-KANs' Convolutional_KANs
# !cd Convolutional_KANs && mv kan_convolutional .. 
!ls

Convolutional_KANs  __pycache__  efficientKAN.py  kan_convolutional
Few-KAN		    data	 env		  wandb


In [3]:
import os
# export CUDA_VISIBLE_DEVICES=0
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import wandb
import torch
import numpy as np
import torchmetrics
import seaborn as sns
import torch.nn as nn
import lightning as L
import torchvision as tv
import torch.functional as F
import matplotlib.pyplot as plt


from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix
from typing import Optional, Tuple, List, Dict
from lightning.pytorch.loggers import WandbLogger
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision.transforms.functional import InterpolationMode

from efficientKAN import KAN as EfficientKAN
from kan_convolutional.KANConv import KAN_Convolutional_Layer as KANConv


# Select the 'medium' float32 matmul precision setting
# (presumably trading some numerical accuracy for speed — see the PyTorch docs).
torch.set_float32_matmul_precision('medium')

# Single seed handed to Lightning's seed_everything so runs can be repeated.
seed_val = 42
L.seed_everything(seed_val)

Seed set to 42


42

In [4]:
class KANDataModule(L.LightningDataModule):
    """
    A PyTorch Lightning DataModule for handling CIFAR10 and MNIST datasets.
    Provides unified interface and preprocessing for both datasets.

    Args:
        data_dir: Directory the torchvision datasets are downloaded to / read from.
        dataset_name: Either "cifar10" or "mnist" (case-insensitive).
        batch_size: Batch size used by all three dataloaders.
        num_workers: Worker processes used by all three dataloaders.
        val_split: Fraction of the official training set held out for validation.
        random_seed: Seed of the generator that draws the train/validation split.
        img_size: Side length images are resized to.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """

    def __init__(
        self,
        data_dir: str = "data",
        dataset_name: str = "cifar10",  # "cifar10" or "mnist"
        batch_size: int = 32,
        num_workers: int = 4,
        val_split: float = 0.2,
        random_seed: int = 42,
        img_size: int = 32,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.dataset_name = dataset_name.lower()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.random_seed = random_seed
        self.img_size = img_size

        if self.dataset_name not in ["cifar10", "mnist"]:
            raise ValueError("dataset_name must be either 'cifar10' or 'mnist'")

        self.num_classes = 10
        self.channels = 3 if self.dataset_name == "cifar10" else 1

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _dataset_class(self):
        """Returns the torchvision dataset class matching ``dataset_name``."""
        return datasets.CIFAR10 if self.dataset_name == "cifar10" else datasets.MNIST

    def _get_transforms(self) -> Tuple[transforms.Compose, transforms.Compose]:
        """
        Returns train and test transforms for the selected dataset.
        Train transforms include augmentations, test transforms only include normalization.
        """
        resize = transforms.Resize(self.img_size, interpolation=InterpolationMode.BILINEAR)

        if self.dataset_name == "cifar10":
            # CIFAR10 normalization values
            mean = [0.4914, 0.4822, 0.4465]
            std = [0.2470, 0.2435, 0.2616]
            augmentations = [
                transforms.RandomCrop(self.img_size, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            # MNIST normalization values
            mean = [0.1307]
            std = [0.3081]
            augmentations = [transforms.RandomRotation(10)]

        to_normalized_tensor = [transforms.ToTensor(), transforms.Normalize(mean, std)]

        train_transforms = transforms.Compose([resize, *augmentations, *to_normalized_tensor])
        test_transforms = transforms.Compose([resize, *to_normalized_tensor])

        return train_transforms, test_transforms

    def prepare_data(self):
        """
        Downloads the dataset if not already present.
        """
        dataset_cls = self._dataset_class()
        dataset_cls(self.data_dir, train=True, download=True)
        dataset_cls(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """
        Sets up train, validation, and test datasets.

        The train and validation subsets are drawn from the official training
        set with a seeded ``random_split``. They are backed by two separate
        dataset objects so the validation subset can use the evaluation
        transforms without disabling augmentation for the training subset.
        """
        train_transforms, test_transforms = self._get_transforms()
        dataset_cls = self._dataset_class()

        if stage in ("fit", "validate") or stage is None:
            augmented_source = dataset_cls(
                self.data_dir, train=True, transform=train_transforms
            )
            # Same samples, evaluation transforms. Assigning `.transform` on the
            # dataset shared by both random_split subsets would change the
            # training pipeline as well, hence the second instance.
            plain_source = dataset_cls(
                self.data_dir, train=True, transform=test_transforms
            )

            val_length = int(len(augmented_source) * self.val_split)
            train_length = len(augmented_source) - val_length

            train_subset, val_subset = random_split(
                augmented_source,
                [train_length, val_length],
                generator=torch.Generator().manual_seed(self.random_seed)
            )

            self.train_dataset = train_subset
            self.val_dataset = Subset(plain_source, val_subset.indices)

        if stage == "test" or stage is None:
            self.test_dataset = dataset_cls(
                self.data_dir, train=False, transform=test_transforms
            )

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Builds a DataLoader with the module-wide batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

In [5]:
class SiamesePair(Dataset):
    """
    Dataset wrapper that creates Siamese pairs from a base dataset.

    Args:
        dataset: Base dataset yielding ``(image, label)`` tuples.
        labels_to_indices: Maps each class label to the positions of its
            samples. The positions must be valid indices into ``dataset``
            itself (not into whatever dataset it was subsetted from).
        same_pair_ratio: Probability of drawing the partner from the anchor's
            own class.

    Each item is ``((img1, img2), target)`` where ``target`` is 1.0 for a
    same-class pair and 0.0 otherwise.
    """
    def __init__(self, dataset, labels_to_indices: Dict[int, List[int]], same_pair_ratio: float = 0.5):
        self.dataset = dataset
        self.labels_to_indices = labels_to_indices
        self.same_pair_ratio = same_pair_ratio
        self.classes = list(labels_to_indices.keys())

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img1, label1 = self.dataset[idx]
        label1 = int(label1)

        # Classes a negative partner can actually be drawn from.
        other_classes = [
            c for c in self.classes
            if c != label1 and len(self.labels_to_indices[c]) > 0
        ]

        should_get_same_class = np.random.random() < self.same_pair_ratio
        if not other_classes:
            # No other class available: a negative pair cannot be formed.
            should_get_same_class = True

        if should_get_same_class:
            label2 = label1
            partners = [i for i in self.labels_to_indices.get(label1, []) if i != idx]
            # With a single sample in the class there is no distinct partner;
            # pair the anchor with itself instead of retrying forever.
            idx2 = int(np.random.choice(partners)) if partners else idx
        else:
            label2 = int(np.random.choice(other_classes))
            idx2 = int(np.random.choice(self.labels_to_indices[label2]))

        img2, _ = self.dataset[idx2]
        target = torch.tensor(1.0 if label1 == label2 else 0.0, dtype=torch.float32)

        return (img1, img2), target


class FewShotDataModule(L.LightningDataModule):
    """
    PyTorch Lightning DataModule for few-shot learning on CIFAR10 and MNIST datasets.
    Provides support for N-shot K-way classification tasks.

    Classes are partitioned into disjoint train and test groups; for every
    class only ``shots_per_class`` samples are kept. All three splits are
    exposed as ``SiamesePair`` datasets.

    Args:
        data_dir: Directory the torchvision datasets are downloaded to / read from.
        dataset_name: Either "cifar10" or "mnist" (case-insensitive).
        shots_per_class: Number of samples kept per class (N-shot).
        ways: Stored on the module; not used by the code in this class.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Worker processes used by all three dataloaders.
        val_split: Fraction of the few-shot training samples held out for validation.
        random_seed: Seed for the class partition, the shot selection and the
            train/validation split, so repeated ``setup`` calls agree.
        img_size: Side length images are resized to.
        test_split_ratio: Ratio of classes to reserve for testing.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """

    def __init__(
        self,
        data_dir: str = "data",
        dataset_name: str = "cifar10",  # "cifar10" or "mnist"
        shots_per_class: int = 5,  # N-shot
        ways: int = 5,  # K-way
        batch_size: int = 32,
        num_workers: int = 4,
        val_split: float = 0.2,
        random_seed: int = 42,
        img_size: int = 32,
        test_split_ratio: float = 0.2,  # Ratio of classes to reserve for testing
    ):
        super().__init__()
        self.data_dir = data_dir
        self.dataset_name = dataset_name.lower()
        self.shots_per_class = shots_per_class
        self.ways = ways
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.random_seed = random_seed
        self.img_size = img_size
        self.test_split_ratio = test_split_ratio

        if self.dataset_name not in ["cifar10", "mnist"]:
            raise ValueError("dataset_name must be either 'cifar10' or 'mnist'")

        self.num_classes = 10
        self.channels = 3 if self.dataset_name == "cifar10" else 1

        self.num_test_classes = int(self.num_classes * self.test_split_ratio)
        self.num_train_classes = self.num_classes - self.num_test_classes

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _get_transforms(self) -> Tuple[transforms.Compose, transforms.Compose]:
        """Returns train and test transforms for the selected dataset."""
        if self.dataset_name == "cifar10":
            mean = [0.4914, 0.4822, 0.4465]
            std = [0.2470, 0.2435, 0.2616]
        else:
            mean = [0.1307]
            std = [0.3081]

        train_transforms = transforms.Compose([
            transforms.Resize(self.img_size, interpolation=InterpolationMode.BILINEAR),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        test_transforms = transforms.Compose([
            transforms.Resize(self.img_size, interpolation=InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        return train_transforms, test_transforms

    def prepare_data(self):
        """Downloads the dataset if not already present."""
        dataset_cls = datasets.CIFAR10 if self.dataset_name == "cifar10" else datasets.MNIST
        dataset_cls(self.data_dir, train=True, download=True)
        dataset_cls(self.data_dir, train=False, download=True)

    @staticmethod
    def _labels_of(dataset) -> np.ndarray:
        """Returns every sample's label as an array, without decoding images when possible."""
        targets = getattr(dataset, "targets", None)
        if targets is not None:
            return np.asarray(targets)
        # Fallback for datasets without a `targets` attribute: read labels one by one.
        return np.asarray([int(label) for _, label in dataset])

    @staticmethod
    def _remap_to_split(split_indices: List[int], labels_to_indices: Dict[int, List[int]]) -> Dict[int, List[int]]:
        """Re-expresses a label -> positions map in the index space of a sub-split."""
        position_to_label = {
            position: label
            for label, positions in labels_to_indices.items()
            for position in positions
        }
        remapped: Dict[int, List[int]] = {}
        for new_position, old_position in enumerate(split_indices):
            remapped.setdefault(position_to_label[old_position], []).append(new_position)
        return remapped

    def _create_few_shot_dataset(self, dataset, classes: List[int]) -> Tuple[Dataset, Dict[int, List[int]]]:
        """
        Creates a few-shot dataset by selecting N examples per class.

        Returns:
            The ``Subset`` holding the selected samples and a map from label to
            the positions of that label's samples *inside the returned subset*.
            Classes with no samples are left out of the map.
        """
        all_labels = self._labels_of(dataset)
        # Seeded generator: the same samples are picked on every call.
        rng = np.random.default_rng(self.random_seed)

        selected_indices: List[int] = []
        labels_to_indices: Dict[int, List[int]] = {}

        for label in classes:
            candidates = np.flatnonzero(all_labels == label)
            shots = min(self.shots_per_class, len(candidates))
            if shots == 0:
                continue
            chosen = rng.choice(candidates, size=shots, replace=False)

            first_position = len(selected_indices)
            selected_indices.extend(int(i) for i in chosen)
            labels_to_indices[int(label)] = list(range(first_position, first_position + shots))

        few_shot_dataset = Subset(dataset, selected_indices)
        return few_shot_dataset, labels_to_indices

    def setup(self, stage: Optional[str] = None):
        """Sets up train, validation, and test datasets."""
        train_transforms, test_transforms = self._get_transforms()
        dataset_cls = datasets.CIFAR10 if self.dataset_name == "cifar10" else datasets.MNIST

        # The class partition comes from a seeded generator so that separate
        # setup("fit") and setup("test") calls see the same disjoint groups.
        class_order = np.random.default_rng(self.random_seed).permutation(self.num_classes)
        # Slice from the front: a negative-zero slice would be empty when no
        # classes are reserved for testing.
        self.train_classes = class_order[:self.num_train_classes].tolist()
        self.test_classes = class_order[self.num_train_classes:].tolist()

        if stage == "fit" or stage is None:
            full_dataset = dataset_cls(self.data_dir, train=True, transform=train_transforms)
            few_shot_train, train_labels_to_indices = self._create_few_shot_dataset(
                full_dataset, self.train_classes
            )

            train_size = int((1 - self.val_split) * len(few_shot_train))
            val_size = len(few_shot_train) - train_size

            train_split, val_split = random_split(
                few_shot_train,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(self.random_seed)
            )

            # Each pair dataset gets a label map expressed in its own index space.
            self.train_dataset = SiamesePair(
                train_split,
                self._remap_to_split(train_split.indices, train_labels_to_indices),
            )
            self.val_dataset = SiamesePair(
                val_split,
                self._remap_to_split(val_split.indices, train_labels_to_indices),
            )

        if stage == "test" or stage is None:
            test_dataset = dataset_cls(self.data_dir, train=False, transform=test_transforms)
            few_shot_test, test_labels_to_indices = self._create_few_shot_dataset(
                test_dataset, self.test_classes
            )
            self.test_dataset = SiamesePair(few_shot_test, test_labels_to_indices)

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Builds a DataLoader with the module-wide batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

In [6]:
class KANModule(L.LightningModule):
    """
    PyTorch Lightning module for KAN Networks with integrated W&B logging.

    Args:
        model_name: Architecture key, matched case-insensitively:
            "kan_basic", "kan_with_cnn" or "kkan".
        num_classes: Number of target classes.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        channels: Image channels (3 for CIFAR10, 1 for MNIST).
        img_size: Image side length.
        hidden_dim: Forwarded to the architecture constructors.
        num_layers: Forwarded to the architecture constructors that accept it.
        dropout: Dropout probability forwarded to the architecture.

    Raises:
        ValueError: If ``model_name`` is not a known architecture key.
    """
    def __init__(
        self,
        model_name: str = "kan_basic",  # kan_basic, kan_with_cnn, kkan
        num_classes: int = 10,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        channels: int = 3,  # 3 for CIFAR10, 1 for MNIST
        img_size: int = 32,
        hidden_dim: int = 128,
        num_layers: int = 3,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = self._create_model()

        self.criterion = nn.CrossEntropyLoss()

        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

        # Per-epoch buffers consumed by on_validation_epoch_end.
        self.val_predictions = []
        self.val_targets = []

    def _create_model(self) -> nn.Module:
        """
        Create the specified KAN architecture
        """
        input_dim = self.hparams.channels * self.hparams.img_size * self.hparams.img_size
        # Lower-casing keeps the historical "kan_with_CNN" spelling working
        # while also accepting the "kan_with_cnn" key used by the Siamese module.
        key = str(self.hparams.model_name).lower()

        if key == "kan_basic":
            return KANBasic(
                input_dim=input_dim,
                hidden_dim=self.hparams.hidden_dim,
                num_classes=self.hparams.num_classes,
                dropout=self.hparams.dropout
            )
        elif key == "kan_with_cnn":
            return KANwithCNN(
                input_dim=input_dim,
                hidden_dim=self.hparams.hidden_dim,
                num_classes=self.hparams.num_classes,
                num_layers=self.hparams.num_layers,
                dropout=self.hparams.dropout
            )
        elif key == "kkan":
            return KKan(
                input_dim=input_dim,
                hidden_dim=self.hparams.hidden_dim,
                num_classes=self.hparams.num_classes,
                num_layers=self.hparams.num_layers,
                dropout=self.hparams.dropout
            )
        else:
            raise ValueError(f"Unknown model name: {self.hparams.model_name}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self):
        """AdamW with a plateau scheduler driven by the logged ``val_loss``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

        # `verbose` is deprecated on PyTorch LR schedulers, so it is not passed.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=5,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def _evaluate_batch(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Flattens the images, runs the model and returns ``(loss, preds, targets)``."""
        x, y = batch
        x = x.view(x.size(0), -1)

        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        loss, preds, y = self._evaluate_batch(batch)

        self.train_acc(preds, y)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True)

        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        loss, preds, y = self._evaluate_batch(batch)

        self.val_predictions.extend(preds.cpu().numpy())
        self.val_targets.extend(y.cpu().numpy())

        self.val_acc(preds, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_acc, prog_bar=True)

        return loss

    def on_validation_epoch_end(self):
        """
        Create and log visualizations to W&B at the end of validation
        """
        if len(self.val_predictions) == 0:
            return

        y_pred = np.array(self.val_predictions)
        y_true = np.array(self.val_targets)
        num_classes = self.hparams.num_classes

        # Explicit labels keep the matrix num_classes x num_classes even when
        # some classes do not occur in this epoch's batches.
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))

        cm_fig, cm_ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=cm_ax)
        cm_ax.set_title(f'Confusion Matrix - Epoch {self.current_epoch}')
        cm_ax.set_xlabel('Predicted')
        cm_ax.set_ylabel('True')

        # One bin per class id, centred on the integer value.
        bin_edges = np.arange(num_classes + 1) - 0.5
        dist_fig, dist_ax = plt.subplots(figsize=(10, 6))
        dist_ax.hist(y_pred, bins=bin_edges, alpha=0.5, label='Predictions')
        dist_ax.hist(y_true, bins=bin_edges, alpha=0.5, label='Ground Truth')
        dist_ax.set_title(f'Prediction Distribution - Epoch {self.current_epoch}')
        dist_ax.set_xlabel('Class')
        dist_ax.set_ylabel('Count')
        dist_ax.legend()

        # Only log when a W&B run is active (e.g. a different logger is in use).
        if wandb.run is not None:
            wandb.log({
                "confusion_matrix": wandb.Image(cm_fig),
                "prediction_distribution": wandb.Image(dist_fig),
                "epoch": self.current_epoch
            })

        self.val_predictions = []
        self.val_targets = []

        plt.close(cm_fig)
        plt.close(dist_fig)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        loss, preds, y = self._evaluate_batch(batch)

        self.test_acc(preds, y)

        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc)

        return loss

# KAN Model Architectures
class KANBasic(nn.Module):
    """
    Fully-connected KAN classifier: an EfficientKAN stack followed by dropout.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Accepted for a uniform constructor signature; not used here.
        num_classes: Width of the output layer.
        dropout: Probability of the dropout applied to the network output.
    """
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.model = nn.Sequential(
            EfficientKAN([input_dim, input_dim//2, input_dim//4, 64, num_classes]),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept image-shaped batches as well, like the sibling architectures
        # that reshape their input internally; 2-D input passes through as is.
        if x.dim() > 2:
            x = x.flatten(start_dim=1)
        return self.model(x)

class KANwithCNN(nn.Module):
    """
    Two conv/ReLU/avg-pool stages followed by dropout and an EfficientKAN head.

    Args:
        input_dim: Flattened input size, ``channels * img_size * img_size``.
        hidden_dim: Accepted for a uniform constructor signature; not used here.
        num_classes: Width of the output layer.
        num_layers: Accepted for a uniform constructor signature; not used here.
        dropout: Probability of the dropout applied before the KAN head.
        channels: Number of image channels. When omitted it is inferred from
            ``input_dim`` as 1 (perfect square) or otherwise 3.

    Raises:
        ValueError: If ``input_dim`` does not describe square images with the
            given / inferred channel count.
    """
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, num_layers: int = 3, dropout: float = 0.1, channels: Optional[int] = None):
        super().__init__()
        if channels is None:
            # 3 * s**2 is never a perfect square, so a square input_dim can
            # only be a single-channel image; anything else is tried as RGB.
            root = int(round(np.sqrt(input_dim)))
            channels = 1 if root * root == input_dim else 3
        side = int(round(np.sqrt(input_dim / channels)))
        if channels * side * side != input_dim:
            raise ValueError(
                f"input_dim={input_dim} is not channels * size * size for channels={channels}"
            )
        self.channels = channels
        self.img_size = side

        self.features = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten()
        )

        # Padding keeps the spatial size through each conv; each pool halves it.
        self.conv_output_dim = 64 * (self.img_size // 4) * (self.img_size // 4)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            EfficientKAN([self.conv_output_dim, self.conv_output_dim//2, self.conv_output_dim//4, 64, num_classes]),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Restore the image layout from flattened (or already image-shaped) input.
        x = x.view(-1, self.channels, self.img_size, self.img_size)
        x = self.features(x)
        x = self.classifier(x)
        return x

class KKan(nn.Module):
    """
    Two KAN-convolution / max-pool stages followed by dropout and a linear head.

    Args:
        input_dim: Flattened input size, ``channels * img_size * img_size``.
        hidden_dim: Accepted for a uniform constructor signature; not used here.
        num_classes: Width of the output layer.
        num_layers: Accepted for a uniform constructor signature; not used here.
        dropout: Probability of the dropout applied before the linear head.
        channels: Number of image channels. When omitted it is inferred from
            ``input_dim`` as 1 (perfect square) or otherwise 3.

    Raises:
        ValueError: If ``input_dim`` does not describe square images with the
            given / inferred channel count.
    """
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, num_layers: int = 3, dropout: float = 0.1, channels: Optional[int] = None):
        super().__init__()
        if channels is None:
            # 3 * s**2 is never a perfect square, so a square input_dim can
            # only be a single-channel image; anything else is tried as RGB.
            root = int(round(np.sqrt(input_dim)))
            channels = 1 if root * root == input_dim else 3
        side = int(round(np.sqrt(input_dim / channels)))
        if channels * side * side != input_dim:
            raise ValueError(
                f"input_dim={input_dim} is not channels * size * size for channels={channels}"
            )
        self.channels = channels
        self.img_size = side

        self.conv1 = KANConv(
            in_channels=channels,
            out_channels=6,
            kernel_size=(3,3),
        )

        self.conv2 = KANConv(
            in_channels=6,
            out_channels=12,
            kernel_size=(3,3),
        )

        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.flat = nn.Flatten()

        # Each 3x3 conv is taken to shrink the side by 2 and each pool to halve it
        # (assumes KANConv defaults to stride 1 without padding — TODO confirm).
        conv_output_size = ((self.img_size - 2) // 2 - 2) // 2
        self.linear1 = nn.Linear(12 * conv_output_size * conv_output_size, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Restore the image layout from flattened (or already image-shaped) input.
        x = x.view(x.size(0), self.channels, self.img_size, self.img_size)

        x = self.conv1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.pool1(x)

        x = self.flat(x)
        x = self.dropout(x)
        x = self.linear1(x)
        return x

In [7]:
class SiameseKANModule(L.LightningModule):
    """
    PyTorch Lightning module for Siamese Networks with KAN architectures.

    Both images of a pair go through the same backbone and projection head;
    training minimises a contrastive loss on the Euclidean distance between
    the two embeddings.

    Args:
        model_name: Backbone key, matched case-insensitively:
            "kan_basic", "kan_with_cnn" or "kkan".
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        channels: Image channels (3 for CIFAR10, 1 for MNIST).
        img_size: Image side length.
        hidden_dim: Hidden width of the projection head (also forwarded to the backbone).
        embedding_dim: Size of the backbone output and of the final embedding.
        num_layers: Forwarded to the backbones that accept it.
        dropout: Dropout probability for the backbone and the projection head.
        margin: Margin of the contrastive loss; pairs closer than ``margin / 2``
            are predicted to be of the same class.

    Raises:
        ValueError: If ``model_name`` is not a known backbone key.
    """
    def __init__(
        self,
        model_name: str = "kan_basic",  # kan_basic, kan_with_cnn, kkan
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        channels: int = 3,  # 3 for CIFAR10, 1 for MNIST
        img_size: int = 32,
        hidden_dim: int = 128,
        embedding_dim: int = 64,
        num_layers: int = 3,
        dropout: float = 0.5,
        margin: float = 1.0,  # Margin for contrastive loss
    ):
        super().__init__()
        self.save_hyperparameters()

        # Initialize Siamese network with specified backbone
        self.backbone = self._create_backbone()

        input_dim = self._get_backbone_output_dim()
        self.projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embedding_dim)
        )

        self.train_acc = torchmetrics.Accuracy(task="binary")
        self.val_acc = torchmetrics.Accuracy(task="binary")
        self.test_acc = torchmetrics.Accuracy(task="binary")

        # Per-epoch buffers consumed by on_validation_epoch_end.
        self.val_predictions = []
        self.val_targets = []

        self.margin = margin

    def _get_backbone_output_dim(self) -> int:
        """Calculate the output dimension of the backbone network"""
        # Probe with one flattened dummy sample, the layout forward() feeds the backbone.
        flat_dim = self.hparams.channels * self.hparams.img_size * self.hparams.img_size
        x = torch.randn(1, flat_dim)
        with torch.no_grad():
            out = self.backbone(x)
        return out.numel()

    def _create_backbone(self) -> nn.Module:
        """Create the specified KAN architecture as backbone"""
        input_dim = self.hparams.channels * self.hparams.img_size * self.hparams.img_size
        key = str(self.hparams.model_name).lower()

        if key == "kan_basic":
            return KANBasic(
                input_dim=input_dim,
                hidden_dim=self.hparams.hidden_dim,
                num_classes=self.hparams.embedding_dim,
                dropout=self.hparams.dropout
            )
        elif key == "kan_with_cnn":
            return KANwithCNN(
                input_dim=input_dim,
                hidden_dim=self.hparams.hidden_dim,
                num_classes=self.hparams.embedding_dim,
                num_layers=self.hparams.num_layers,
                dropout=self.hparams.dropout
            )
        elif key == "kkan":
            return KKan(
                input_dim=input_dim,
                hidden_dim=self.hparams.hidden_dim,
                num_classes=self.hparams.embedding_dim,
                num_layers=self.hparams.num_layers,
                dropout=self.hparams.dropout
            )
        else:
            raise ValueError(f"Unknown model name: {self.hparams.model_name}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for a single image"""
        # The backbones are fed one flat vector per sample (as KANModule does);
        # the convolutional ones restore the image layout themselves.
        features = self.backbone(x.flatten(start_dim=1))
        embeddings = self.projection(features.view(features.size(0), -1))
        return embeddings

    def _shared_step(self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]) -> Dict:
        """
        Shared step for training, validation and testing.

        Returns a dict with the scalar ``loss``, the per-pair ``distance``,
        the 0/1 ``preds`` and the float ``targets`` (1 = same class).
        """
        (img1, img2), target = batch
        target = target.float()

        embed1 = self(img1)
        embed2 = self(img2)

        # `F` in this notebook is bound to torch.functional, which does not
        # provide the distance/loss functions, so go through nn.functional.
        distance = nn.functional.pairwise_distance(embed1, embed2)

        # Contrastive loss: same-class pairs are pulled together, different-class
        # pairs are pushed apart until they are at least `margin` away.
        pull_term = target * distance.pow(2)
        push_term = (1.0 - target) * torch.clamp(self.margin - distance, min=0.0).pow(2)
        loss = 0.5 * (pull_term + push_term).mean()

        pred = (distance < self.margin/2).float()

        return {
            'loss': loss,
            'distance': distance,
            'preds': pred,
            'targets': target
        }

    def training_step(self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int) -> torch.Tensor:
        outputs = self._shared_step(batch)

        self.train_acc(outputs['preds'], outputs['targets'].long())

        self.log('train_loss', outputs['loss'], prog_bar=True)
        self.log('train_acc', self.train_acc, prog_bar=True)

        return outputs['loss']

    def validation_step(self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int) -> torch.Tensor:
        outputs = self._shared_step(batch)

        self.val_predictions.extend(outputs['preds'].cpu().numpy())
        self.val_targets.extend(outputs['targets'].cpu().numpy())

        self.val_acc(outputs['preds'], outputs['targets'].long())

        self.log('val_loss', outputs['loss'], prog_bar=True)
        self.log('val_acc', self.val_acc, prog_bar=True)

        return outputs['loss']

    def on_validation_epoch_end(self):
        """Create and log visualizations to W&B"""
        if len(self.val_predictions) == 0:
            return

        y_pred = np.array(self.val_predictions).astype(int)
        y_true = np.array(self.val_targets).astype(int)

        # Fixed labels keep the matrix 2x2 even if only one outcome occurred.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

        cm_fig, cm_ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=cm_ax)
        cm_ax.set_title(f'Confusion Matrix - Epoch {self.current_epoch}')
        cm_ax.set_xlabel('Predicted')
        cm_ax.set_ylabel('True')

        # One bin for each of the two outcomes, centred on 0 and 1.
        bin_edges = [-0.5, 0.5, 1.5]
        dist_fig, dist_ax = plt.subplots(figsize=(10, 6))
        dist_ax.hist(y_pred, bins=bin_edges, alpha=0.5, label='Predictions')
        dist_ax.hist(y_true, bins=bin_edges, alpha=0.5, label='Ground Truth')
        dist_ax.set_title(f'Prediction Distribution - Epoch {self.current_epoch}')
        dist_ax.set_xlabel('Class')
        dist_ax.set_ylabel('Count')
        dist_ax.legend()

        # Only log when a W&B run is active (e.g. a different logger is in use).
        if wandb.run is not None:
            wandb.log({
                "confusion_matrix": wandb.Image(cm_fig),
                "prediction_distribution": wandb.Image(dist_fig),
                "epoch": self.current_epoch
            })

        self.val_predictions = []
        self.val_targets = []

        plt.close(cm_fig)
        plt.close(dist_fig)

    def test_step(self, batch: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int) -> torch.Tensor:
        outputs = self._shared_step(batch)

        self.test_acc(outputs['preds'], outputs['targets'].long())

        self.log('test_loss', outputs['loss'])
        self.log('test_acc', self.test_acc)

        return outputs['loss']

    def configure_optimizers(self):
        """AdamW with a plateau scheduler driven by the logged ``val_loss``."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

        # `verbose` is deprecated on PyTorch LR schedulers, so it is not passed.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=5,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

In [None]:
# Run descriptor parts; the first two are reused below as the dataset name
# and the model key, and all of them are joined into the W&B run name.
names = [
    "mnist",
    "kkan",
    "kkan_mlp_linear",
    "v2",
]
wandb_logger = WandbLogger(
    name=".".join(names),
    project="Few-KAN",
)

In [None]:
# Model and data are both selected through the run descriptor `names`.
model = KANModule(
    model_name=names[1],
    num_classes=10,
    channels=1,
    img_size=32,
    hidden_dim=128,
)
data_module = KANDataModule(dataset_name=names[0])

# Train for up to 100 epochs, reporting to the W&B logger created above.
trainer = L.Trainer(logger=wandb_logger, max_epochs=100)
trainer.fit(model, data_module)