In [1]:
# import os
# os.chdir('/root')

In [2]:
# !git clone https://github.com/AntonioTepsich/Convolutional-KANs
# !pip install tqdm pyprof
# !mv 'Convolutional-KANs' Convolutional_KANs
# !cd Convolutional_KANs && mv kan_convolutional .. 
!ls

Convolutional_KANs  __pycache__  efficientKAN.py  kan_convolutional
Few-KAN		    data	 env		  wandb


In [3]:
import os
# export CUDA_VISIBLE_DEVICES=0
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["WANDB_NOTEBOOK_NAME"] = "siam-kan.ipynb"

import wandb
import torch
import numpy as np
import torchmetrics
import seaborn as sns
import torch.nn as nn
import lightning as L
import torchvision as tv
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from sklearn.manifold import TSNE
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix
from typing import Optional, Tuple, List, Dict, defaultdict
from lightning.pytorch.loggers import WandbLogger
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import InterpolationMode
from sklearn.metrics import precision_recall_curve, roc_curve, auc
from torch.utils.data import Dataset, DataLoader, random_split, Subset


from efficientKAN import KAN as EfficientKAN
from kan_convolutional.KANConv import KAN_Convolutional_Layer as KANConv


torch.set_float32_matmul_precision('medium')

seed_val = 42
L.seed_everything(seed_val)

Seed set to 42


42

In [4]:
class KANDataModule(L.LightningDataModule):
    """
    A PyTorch Lightning DataModule for handling CIFAR10 and MNIST datasets.
    Provides unified interface and preprocessing for both datasets.
    """
    
    def __init__(
        self,
        data_dir: str = "data",
        dataset_name: str = "cifar10", 
        batch_size: int = 32,
        num_workers: int = 4,
        val_split: float = 0.2,
        random_seed: int = 42,
        img_size: int = 32,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.dataset_name = dataset_name.lower()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.random_seed = random_seed
        self.img_size = img_size
        
        if self.dataset_name not in ["cifar10", "mnist"]:
            raise ValueError("dataset_name must be either 'cifar10' or 'mnist'")
        
        self.num_classes = 10
        self.channels = 3 if self.dataset_name == "cifar10" else 1
        
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _get_transforms(self) -> Tuple[transforms.Compose, transforms.Compose]:
        """
        Returns train and test transforms for the selected dataset.
        Train transforms include augmentations, test transforms only include normalization.
        """
        if self.dataset_name == "cifar10":
            # CIFAR10 normalization values
            mean = [0.4914, 0.4822, 0.4465]
            std = [0.2470, 0.2435, 0.2616]
            
            train_transforms = transforms.Compose([
                transforms.Resize(self.img_size, interpolation=InterpolationMode.BILINEAR),
                transforms.RandomCrop(self.img_size, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
            
            test_transforms = transforms.Compose([
                transforms.Resize(self.img_size, interpolation=InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
        
        else:
            # MNIST normalization values
            mean = [0.1307]
            std = [0.3081]
            
            train_transforms = transforms.Compose([
                transforms.Resize(self.img_size, interpolation=InterpolationMode.BILINEAR),
                transforms.RandomRotation(10),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
            
            test_transforms = transforms.Compose([
                transforms.Resize(self.img_size, interpolation=InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
        
        return train_transforms, test_transforms

    def prepare_data(self):
        """
        Downloads the dataset if not already present.
        """
        if self.dataset_name == "cifar10":
            datasets.CIFAR10(self.data_dir, train=True, download=True)
            datasets.CIFAR10(self.data_dir, train=False, download=True)
        else:
            datasets.MNIST(self.data_dir, train=True, download=True)
            datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        """
        Sets up train, validation, and test datasets.
        """
        train_transforms, test_transforms = self._get_transforms()
        
        if stage == "fit" or stage is None:
            if self.dataset_name == "cifar10":
                full_dataset = datasets.CIFAR10(
                    self.data_dir, train=True, transform=train_transforms
                )
            else:
                full_dataset = datasets.MNIST(
                    self.data_dir, train=True, transform=train_transforms
                )
            
            val_length = int(len(full_dataset) * self.val_split)
            train_length = len(full_dataset) - val_length
            
            self.train_dataset, self.val_dataset = random_split(
                full_dataset,
                [train_length, val_length],
                generator=torch.Generator().manual_seed(self.random_seed)
            )
            
            self.val_dataset.dataset.transform = test_transforms
        
        if stage == "test" or stage is None:
            if self.dataset_name == "cifar10":
                self.test_dataset = datasets.CIFAR10(
                    self.data_dir, train=False, transform=test_transforms
                )
            else:
                self.test_dataset = datasets.MNIST(
                    self.data_dir, train=False, transform=test_transforms
                )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

In [5]:
import random
class SiameseDataset(Dataset):
    """
    Creates pairs of images for Siamese network training.
    Generates both positive pairs (same class) and negative pairs (different classes).
    """
    def __init__(self, dataset, images_per_class: int = 10, n_pairs_per_class: int = 100):
        self.dataset = dataset
        self.images_per_class = images_per_class
        self.n_pairs_per_class = n_pairs_per_class
        
        all_class_indices: Dict[int, List[int]] = defaultdict(list)
        for idx, (_, label) in enumerate(dataset):
            all_class_indices[label].append(idx)
        
        self.class_indices: Dict[int, List[int]] = {}
        for label, indices in all_class_indices.items():
            if len(indices) >= images_per_class:
                self.class_indices[label] = random.sample(indices, images_per_class)
        
        self.pairs = self._generate_pairs()
    
    def _generate_pairs(self):
        pairs = []
        for label in self.class_indices:
            indices = self.class_indices[label]
            
            for i in range(len(indices)):
                for j in range(i + 1, len(indices)):
                    pairs.append((indices[i], indices[j], 1))
        
        n_neg_pairs = len(pairs) 
        for _ in range(n_neg_pairs):
            label1, label2 = random.sample(list(self.class_indices.keys()), 2)
            idx1 = random.choice(self.class_indices[label1])
            idx2 = random.choice(self.class_indices[label2])
            pairs.append((idx1, idx2, 0)) 
        
        random.shuffle(pairs)
        return pairs
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        idx1, idx2, label = self.pairs[idx]
        img1, class1 = self.dataset[idx1]
        img2, class2 = self.dataset[idx2]
        return img1, img2, torch.tensor(label, dtype=torch.float32), class1, class2

class SiameseDataModule(L.LightningDataModule):
    """
    Data module for Siamese network training with CIFAR10 or MNIST datasets.
    """
    def __init__(
        self,
        data_dir: str = "data",
        dataset_name: str = "cifar10",
        batch_size: int = 32,
        num_workers: int = 4,
        images_per_class: int = 10,
        n_pairs_per_class: int = 100,
        val_split: float = 0.2,
        img_size: int = 32,
        random_seed: int = 42,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.dataset_name = dataset_name.lower()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.images_per_class = images_per_class
        self.n_pairs_per_class = n_pairs_per_class
        self.val_split = val_split
        self.img_size = img_size
        self.random_seed = random_seed
        
        if self.dataset_name not in ["cifar10", "mnist"]:
            raise ValueError("dataset_name must be either 'cifar10' or 'mnist'")
        
        self.num_classes = 10
        self.channels = 3 if self.dataset_name == "cifar10" else 1
    
    def _get_transforms(self):
        if self.dataset_name == "cifar10":
            mean = [0.4914, 0.4822, 0.4465]
            std = [0.2470, 0.2435, 0.2616]
        else:
            mean = [0.1307]
            std = [0.3081]
        
        train_transforms = transforms.Compose([
            transforms.Resize(self.img_size, interpolation=InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        
        eval_transforms = transforms.Compose([
            transforms.Resize(self.img_size, interpolation=InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        
        return train_transforms, eval_transforms

    def prepare_data(self):
        if self.dataset_name == "cifar10":
            datasets.CIFAR10(self.data_dir, train=True, download=True)
            datasets.CIFAR10(self.data_dir, train=False, download=True)
        else:
            datasets.MNIST(self.data_dir, train=True, download=True)
            datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        train_transforms, eval_transforms = self._get_transforms()
        
        if stage == "fit" or stage is None:
            if self.dataset_name == "cifar10":
                full_dataset = datasets.CIFAR10(
                    self.data_dir, train=True, transform=train_transforms
                )
            else:
                full_dataset = datasets.MNIST(
                    self.data_dir, train=True, transform=train_transforms
                )
            
            dataset_size = len(full_dataset)
            val_size = int(dataset_size * self.val_split)
            train_size = dataset_size - val_size
            
            train_subset, val_subset = torch.utils.data.random_split(
                full_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(self.random_seed)
            )
            
            self.train_dataset = SiameseDataset(
                train_subset,
                images_per_class=self.images_per_class,
                n_pairs_per_class=self.n_pairs_per_class
            )
            
            self.val_dataset = SiameseDataset(
                val_subset,
                images_per_class=self.images_per_class,
                n_pairs_per_class=self.n_pairs_per_class // 2 
            )
        
        if stage == "test" or stage is None:
            if self.dataset_name == "cifar10":
                self.test_dataset = datasets.CIFAR10(
                    self.data_dir, train=False, transform=eval_transforms
                )
            else:
                self.test_dataset = datasets.MNIST(
                    self.data_dir, train=False, transform=eval_transforms
                )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

In [6]:
class KANModule(L.LightningModule):
    """
    PyTorch Lightning module for KAN Networks with integrated W&B logging
    """
    def __init__(
        self,
        model_name: str = "kan_basic",  # kan_basic, kan_plus, kan_deep
        num_classes: int = 10,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        channels: int = 3,  # 3 for CIFAR10, 1 for MNIST
        img_size: int = 32,
        hidden_dim: int = 128,
        num_layers: int = 3,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = self._create_model()
        
        self.criterion = nn.CrossEntropyLoss()
        
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        
        self.val_predictions = []
        self.val_targets = []

    def _create_model(self) -> nn.Module:
        """
        Create the specified KAN architecture
        """
        input_dim = self.hparams.channels * self.hparams.img_size * self.hparams.img_size
        
        if self.hparams.model_name == "kan_basic":
            return KANBasic(
                input_dim=input_dim,
                hidden_dim=self.hparams.hidden_dim,
                num_classes=self.hparams.num_classes,
                dropout=self.hparams.dropout
            )
        elif self.hparams.model_name == "kan_with_CNN":
            return KANwithCNN(
                input_dim=input_dim,
                hidden_dim=self.hparams.hidden_dim,
                num_classes=self.hparams.num_classes,
                num_layers=self.hparams.num_layers,
                dropout=self.hparams.dropout
            )
        elif self.hparams.model_name == "kkan":
            return KKan(
                input_dim=input_dim,
                hidden_dim=self.hparams.hidden_dim,
                num_classes=self.hparams.num_classes,
                num_layers=self.hparams.num_layers,
                dropout=self.hparams.dropout
            )
        else:
            raise ValueError(f"Unknown model name: {self.hparams.model_name}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )
        
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=5,
            verbose=True
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch
        x = x.view(x.size(0), -1)
        
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        
        self.train_acc(preds, y)
        
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True)
        
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch
        x = x.view(x.size(0), -1)
        
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        
        self.val_predictions.extend(preds.cpu().numpy())
        self.val_targets.extend(y.cpu().numpy())
        
        self.val_acc(preds, y)
        
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_acc, prog_bar=True)
        
        return loss

    def on_validation_epoch_end(self):
        """
        Create and log visualizations to W&B at the end of validation
        """
        y_pred = np.array(self.val_predictions)
        y_true = np.array(self.val_targets)
        
        cm = confusion_matrix(y_true, y_pred)
        
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title(f'Confusion Matrix - Epoch {self.current_epoch}')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        
        wandb.log({
            "confusion_matrix": wandb.Image(plt),
            "epoch": self.current_epoch
        })
        
        plt.figure(figsize=(10, 6))
        plt.hist(y_pred, bins=self.hparams.num_classes, alpha=0.5, label='Predictions')
        plt.hist(y_true, bins=self.hparams.num_classes, alpha=0.5, label='Ground Truth')
        plt.title(f'Prediction Distribution - Epoch {self.current_epoch}')
        plt.xlabel('Class')
        plt.ylabel('Count')
        plt.legend()
        
        wandb.log({
            "prediction_distribution": wandb.Image(plt),
            "epoch": self.current_epoch
        })
        
        self.val_predictions = []
        self.val_targets = []
        
        plt.close('all')

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch
        x = x.view(x.size(0), -1)
        
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        
        self.test_acc(preds, y)
        
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc)
        
        return loss

# KAN Model Architectures
class SampleCNN(nn.Module):
    def __init__(self, in_channels: int, hidden_dim: int, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.model = nn.Sequential(
            # simple CNN architecture
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten()
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

class KANBasic(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float = 0.1):
        super().__init__()
        self.model = nn.Sequential(
            EfficientKAN([input_dim, input_dim//2, input_dim//4, 64, num_classes]),
            nn.Dropout(dropout)
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

class KANwithCNN(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, num_layers: int = 3, dropout: float = 0.1):
        super().__init__()
        self.img_size = int(np.sqrt(input_dim))
        
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten()
        )
        
        self.conv_output_dim = 64 * (self.img_size // 4) * (self.img_size // 4)
        
        self.classifier = nn.Sequential(
            # nn.Linear(self.conv_output_dim, hidden_dim),
            # nn.ReLU(),
            # nn.Linear(hidden_dim, num_classes)
            nn.Dropout(dropout),
            EfficientKAN([self.conv_output_dim, self.conv_output_dim//2, self.conv_output_dim//4, 64, num_classes]),
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1, 1, self.img_size, self.img_size)
        x = self.features(x)
        x = self.classifier(x)
        return x

class KKan(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, num_layers: int = 3, dropout: float = 0.1):
        super().__init__()
        self.img_size = int(np.sqrt(input_dim))  
        self.conv1 = KANConv(
            in_channels=1,  # Changed to 1 for grayscale
            out_channels=6,
            kernel_size=(3,3),
        )

        self.conv2 = KANConv(
            in_channels=6,
            out_channels=12,
            kernel_size=(3,3),
        )

        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.flat = nn.Flatten()
        
        conv_output_size = ((self.img_size - 2) // 2 - 2) // 2
        self.linear1 = nn.Linear(12 * conv_output_size * conv_output_size, num_classes)
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x.view(x.size(0), 1, self.img_size, self.img_size)
        
        x = self.conv1(x)
        x = self.pool1(x)
        
        x = self.conv2(x)
        x = self.pool1(x)
        
        x = self.flat(x)
        x = self.dropout(x)
        x = self.linear1(x)
        return x
    #     self.img_size = int(np.sqrt(input_dim // 3))
        
    #     self.conv1 = KANConv(
    #         in_channels=3,  # Changed from 1 to 3 for RGB
    #         out_channels=5,
    #         kernel_size=(3,3),
    #     )

    #     self.conv2 = KANConv(
    #         in_channels=5,
    #         out_channels=5,
    #         kernel_size=(3,3),
    #     )

    #     self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
    #     self.flat = nn.Flatten()
        
    #     # Calculate the size after convolutions and pooling
    #     conv_output_size = ((self.img_size // 2) // 2)  # After two pooling layers
    #     self.linear1 = nn.Linear(180, num_classes)

    # def forward(self, x):
    #     # Reshape the flattened input back to image format
    #     x = x.view(x.size(0), 3, self.img_size, self.img_size)
        
    #     x = self.conv1(x)
    #     x = self.pool1(x)
        
    #     x = self.conv2(x)
    #     x = self.pool1(x)
        
    #     x = self.flat(x)
    #     x = self.linear1(x)
    #     return x

In [7]:
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        """
        Args:
            distance: euclidean distance between pairs
            label: 1 for similar pairs, 0 for dissimilar pairs
        """
        loss_contrastive = torch.mean(
            label * torch.pow(distance, 2) +  # pull similar pairs together
            (1 - label) * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)  # push dissimilar pairs apart
        )
        return loss_contrastive

class SiameseNetwork(nn.Module):
    """
    Siamese network implementation with shared encoder.
    """
    def __init__(self, input_dim: int, hidden_dim: int, embedding_dim: int = 128, channels: int = 1):
        super().__init__()
        self.input_dim = input_dim
        self.channels = channels
        
        # Calculate the flattened input dimension
        self.flat_dim = input_dim * channels
        
        # Encoder network
        self.encoder = nn.Sequential(
            nn.Linear(self.flat_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, embedding_dim)
        )
        
    def forward_one(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        return self.encoder(x)
    
    def forward(self, x1, x2):
        # Get embeddings for both inputs
        out1 = self.forward_one(x1)
        out2 = self.forward_one(x2)
        return out1, out2

class SiameseModule(L.LightningModule):
    """
    Lightning module for training Siamese network with contrastive loss.
    """
    def __init__(
        self,
        img_size: int = 32,
        channels: int = 1,
        hidden_dim: int = 512,
        embedding_dim: int = 128,
        learning_rate: float = 1e-3,
        margin: float = 2.0
    ):
        super().__init__()
        self.save_hyperparameters()
        
        input_dim = img_size * img_size
        
        self.model = SiameseNetwork(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            embedding_dim=embedding_dim,
            channels=channels
        )
        
        # Contrastive loss
        self.criterion = ContrastiveLoss(margin=margin)
        
        # Metrics
        self.train_accuracy = torchmetrics.Accuracy(task='binary')
        self.val_accuracy = torchmetrics.Accuracy(task='binary')
        self.test_accuracy = torchmetrics.Accuracy(task='binary')
        self.test_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        self.test_data = None
        self.test_embeddings = None
        self.test_labels = None
        
        self.class_prototypes = {}
    
    def forward(self, x1, x2):
        return self.model(x1, x2)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=5,
            verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }
    
    def compute_distance_and_prediction(self, emb1, emb2):
        """Compute euclidean distance and binary prediction"""
        distance = F.pairwise_distance(emb1, emb2)
        # Use a threshold on the distance for prediction
        # If distance is small, images are similar (label=1)
        predictions = (distance < self.hparams.margin/2).float()
        return distance, predictions
    
    def training_step(self, batch, batch_idx):
        img1, img2, label, _, _ = batch
        emb1, emb2 = self(img1, img2)
        distance, predictions = self.compute_distance_and_prediction(emb1, emb2)
        loss = self.criterion(distance, label)
        
        self.train_accuracy(predictions, label)
        
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_accuracy, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        img1, img2, label, _, _ = batch
        emb1, emb2 = self(img1, img2)
        distance, predictions = self.compute_distance_and_prediction(emb1, emb2)
        loss = self.criterion(distance, label)
        
        self.val_accuracy(predictions, label)
        
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)
        
        return loss
    
    def test_step(self, batch, batch_idx):
        """
        Store batch data for processing in test_epoch_end.
        We need all test data to perform query-based evaluation.
        """
        images, labels = batch
        
        embeddings = self.model.forward_one(images)
        
        if self.test_data is None:
            self.test_data = images
            self.test_embeddings = embeddings
            self.test_labels = labels
        else:
            self.test_data = torch.cat([self.test_data, images])
            self.test_embeddings = torch.cat([self.test_embeddings, embeddings])
            self.test_labels = torch.cat([self.test_labels, labels])
    
    def on_test_epoch_start(self):
        """Reset stored test data at the start of test epoch"""
        self.test_data = None
        self.test_embeddings = None
        self.test_labels = None
        
    def on_test_epoch_end(self):
        """
        Perform query-based evaluation at the end of test epoch
        """
        class_indices = defaultdict(list)
        for idx, label in enumerate(self.test_labels):
            class_indices[label.item()].append(idx)
        
        all_predictions = []
        all_true_labels = []
        per_class_accuracy = defaultdict(list)
        
        # For each class, select a random query image and compare with all test images
        for query_class in class_indices.keys():
            query_idx = random.choice(class_indices[query_class])
            query_embedding = self.test_embeddings[query_idx]
            
            distances = F.pairwise_distance(
                query_embedding.unsqueeze(0).repeat(len(self.test_embeddings), 1),
                self.test_embeddings
            )
            
            class_distances = defaultdict(list)
            for idx, dist in enumerate(distances):
                class_distances[self.test_labels[idx].item()].append(dist.item())
            
            avg_class_distances = {
                cls: np.mean(dists) for cls, dists in class_distances.items()
            }
            
            predicted_class = min(avg_class_distances.items(), key=lambda x: x[1])[0]
            
            all_predictions.append(predicted_class)
            all_true_labels.append(query_class)
            
            is_correct = predicted_class == query_class
            per_class_accuracy[query_class].append(is_correct)
            
            self.log(f"test_query_class_{query_class}_predicted", predicted_class)
            self.log(f"test_query_class_{query_class}_distance", avg_class_distances[predicted_class])
        
        correct_predictions = sum(p == t for p, t in zip(all_predictions, all_true_labels))
        overall_accuracy = correct_predictions / len(all_predictions)
        
        class_accuracies = {
            cls: sum(results) / len(results) 
            for cls, results in per_class_accuracy.items()
        }
        
        self.log("test_accuracy", overall_accuracy)
        for cls, acc in class_accuracies.items():
            self.log(f"test_class_{cls}_accuracy", acc)
        
        cm = confusion_matrix(all_true_labels, all_predictions)
        
        print("\nTest Results:")
        print(f"Overall Accuracy: {overall_accuracy:.4f}")
        print("\nPer-class Accuracy:")
        for cls, acc in class_accuracies.items():
            print(f"Class {cls}: {acc:.4f}")
        
        print("\nConfusion Matrix:")
        print(cm)
        
        if hasattr(self.logger, 'experiment') and hasattr(self.logger.experiment, 'log'):
            try:
                fig, ax = plt.subplots(figsize=(10, 8))
                sns.heatmap(cm, annot=True, fmt='d', ax=ax)
                plt.title("Test Confusion Matrix")
                plt.xlabel("Predicted Class")
                plt.ylabel("True Class")
                
                self.logger.experiment.log({
                    "test_confusion_matrix": wandb.Image(fig),
                    "test_epoch": self.current_epoch
                })
                plt.close()
            except:
                pass
    
    def predict_class(self, query_image):
        """Predict class for a query image using class prototypes"""
        query_embedding = self.model.forward_one(query_image.unsqueeze(0))
        
        distances = {}
        for class_idx, prototype in self.class_prototypes.items():
            distance = F.pairwise_distance(
                query_embedding,
                prototype.unsqueeze(0)
            )
            distances[class_idx] = distance.item()
        
        return min(distances.items(), key=lambda x: x[1])[0]
    
def predict_query_similarity(self, query_image, reference_images):
    """
    Utility method to compute similarities between a query image and reference images
    """
    query_embedding = self.model.forward_one(query_image.unsqueeze(0))
    reference_embeddings = self.model.forward_one(reference_images)
    
    distances = F.pairwise_distance(
        query_embedding.repeat(len(reference_embeddings), 1),
        reference_embeddings
    )
    
    return distances

In [8]:
class SiameseKANNetwork(nn.Module):
    """
    Siamese network implementation that can use different KAN architectures as encoders.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        embedding_dim: int = 128,
        channels: int = 1,
        architecture: str = "kan_basic"  # kan_basic, kan_with_cnn, kkan
    ):
        super().__init__()
        self.input_dim = input_dim
        self.channels = channels
        self.architecture = architecture
        self.img_size = int(np.sqrt(input_dim))
        
        if architecture == "kan_basic":
            self.encoder = KANBasicEncoder(
                input_dim=input_dim * channels,
                hidden_dim=hidden_dim,
                embedding_dim=embedding_dim
            )
        elif architecture == "kan_with_cnn":
            self.encoder = KANwithCNNEncoder(
                img_size=self.img_size,
                channels=channels,
                hidden_dim=hidden_dim,
                embedding_dim=embedding_dim
            )
        elif architecture == "kkan":
            self.encoder = KKANEncoder(
                img_size=self.img_size,
                channels=channels,
                embedding_dim=embedding_dim
            )
        else:
            raise ValueError(f"Unknown architecture: {architecture}")
    
    def forward_one(self, x):
        return self.encoder(x)
    
    def forward(self, x1, x2):
        out1 = self.forward_one(x1)
        out2 = self.forward_one(x2)
        return out1, out2

class KANBasicEncoder(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, embedding_dim: int):
        super().__init__()
        self.model = nn.Sequential(
            EfficientKAN([input_dim, hidden_dim, hidden_dim//2, hidden_dim//4, embedding_dim]),
            nn.BatchNorm1d(embedding_dim),
            nn.Dropout(0.3)
        )
    
    def forward(self, x):
        return self.model(x.view(x.size(0), -1))

class KANwithCNNEncoder(nn.Module):
    def __init__(self, img_size: int, channels: int, hidden_dim: int, embedding_dim: int):
        super().__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten()
        )
        
        conv_output_dim = 64 * (img_size // 4) * (img_size // 4)
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            EfficientKAN([conv_output_dim, hidden_dim, hidden_dim//2, hidden_dim//4, embedding_dim]),
            nn.BatchNorm1d(embedding_dim)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

class KKANEncoder(nn.Module):
    def __init__(self, img_size: int, channels: int, embedding_dim: int):
        super().__init__()
        
        self.conv1 = KANConv(
            in_channels=channels,
            out_channels=6,
            kernel_size=(3,3),
        )

        self.conv2 = KANConv(
            in_channels=6,
            out_channels=12,
            kernel_size=(3,3),
        )

        self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.flat = nn.Flatten()
        
        conv_output_size = ((img_size - 2) // 2 - 2) // 2
        self.final = nn.Sequential(
            nn.Linear(12 * conv_output_size * conv_output_size, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.Dropout(0.3)
        )
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.flat(x)
        return self.final(x)

class SiameseKANModule(L.LightningModule):
    """
    Lightning module for Siamese network with KAN architectures
    """
    def __init__(
        self,
        img_size: int = 32,
        channels: int = 1,
        hidden_dim: int = 512,
        embedding_dim: int = 128,
        learning_rate: float = 1e-3,
        margin: float = 2.0,
        architecture: str = "kan_basic"
    ):
        super().__init__()
        self.save_hyperparameters()
        
        input_dim = img_size * img_size
        
        self.model = SiameseKANNetwork(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            embedding_dim=embedding_dim,
            channels=channels,
            architecture=architecture
        )
        
        self.criterion = ContrastiveLoss(margin=margin)
        self.train_accuracy = torchmetrics.Accuracy(task='binary')
        self.val_accuracy = torchmetrics.Accuracy(task='binary')
        self.test_accuracy = torchmetrics.Accuracy(task='binary')
        
        self.test_data = None
        self.test_embeddings = None
        self.test_labels = None
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=5,
            verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }
    
    def forward(self, x1, x2):
        return self.model(x1, x2)
    
    def compute_distance_and_prediction(self, emb1, emb2):
        distance = F.pairwise_distance(emb1, emb2)
        predictions = (distance < self.hparams.margin/2).float()
        return distance, predictions
    
    def training_step(self, batch, batch_idx):
        img1, img2, label, _, _ = batch
        emb1, emb2 = self(img1, img2)
        distance, predictions = self.compute_distance_and_prediction(emb1, emb2)
        loss = self.criterion(distance, label)
        
        self.train_accuracy(predictions, label)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_accuracy, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        img1, img2, label, _, _ = batch
        emb1, emb2 = self(img1, img2)
        distance, predictions = self.compute_distance_and_prediction(emb1, emb2)
        loss = self.criterion(distance, label)
        
        self.val_accuracy(predictions, label)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)
        
        return loss
    
        
    def test_step(self, batch, batch_idx):
        """
        Store batch data for processing in test_epoch_end.
        We need all test data to perform query-based evaluation.
        """
        images, labels = batch
        
        embeddings = self.model.forward_one(images)
        
        if self.test_data is None:
            self.test_data = images
            self.test_embeddings = embeddings
            self.test_labels = labels
        else:
            self.test_data = torch.cat([self.test_data, images])
            self.test_embeddings = torch.cat([self.test_embeddings, embeddings])
            self.test_labels = torch.cat([self.test_labels, labels])
    
    def on_test_epoch_start(self):
        """Reset stored test data at the start of test epoch"""
        self.test_data = None
        self.test_embeddings = None
        self.test_labels = None
        
    def on_test_epoch_end(self):
        """
        Perform query-based evaluation at the end of test epoch
        """
        class_indices = defaultdict(list)
        for idx, label in enumerate(self.test_labels):
            class_indices[label.item()].append(idx)
        
        all_predictions = []
        all_true_labels = []
        per_class_accuracy = defaultdict(list)
        
        for query_class in class_indices.keys():
            query_idx = random.choice(class_indices[query_class])
            query_embedding = self.test_embeddings[query_idx]
            
            distances = F.pairwise_distance(
                query_embedding.unsqueeze(0).repeat(len(self.test_embeddings), 1),
                self.test_embeddings
            )
            
            class_distances = defaultdict(list)
            for idx, dist in enumerate(distances):
                class_distances[self.test_labels[idx].item()].append(dist.item())
            
            avg_class_distances = {
                cls: np.mean(dists) for cls, dists in class_distances.items()
            }
            
            predicted_class = min(avg_class_distances.items(), key=lambda x: x[1])[0]
            
            all_predictions.append(predicted_class)
            all_true_labels.append(query_class)
            
            is_correct = predicted_class == query_class
            per_class_accuracy[query_class].append(is_correct)
            
            self.log(f"test_query_class_{query_class}_predicted", predicted_class)
            self.log(f"test_query_class_{query_class}_distance", avg_class_distances[predicted_class])
        
        correct_predictions = sum(p == t for p, t in zip(all_predictions, all_true_labels))
        overall_accuracy = correct_predictions / len(all_predictions)
        
        class_accuracies = {
            cls: sum(results) / len(results) 
            for cls, results in per_class_accuracy.items()
        }
        
        self.log("test_accuracy", overall_accuracy)
        for cls, acc in class_accuracies.items():
            self.log(f"test_class_{cls}_accuracy", acc)
        
        cm = confusion_matrix(all_true_labels, all_predictions)
        
        print("\nTest Results:")
        print(f"Overall Accuracy: {overall_accuracy:.4f}")
        print("\nPer-class Accuracy:")
        for cls, acc in class_accuracies.items():
            print(f"Class {cls}: {acc:.4f}")
        
        print("\nConfusion Matrix:")
        print(cm)
        
        if hasattr(self.logger, 'experiment') and hasattr(self.logger.experiment, 'log'):
            try:
                fig, ax = plt.subplots(figsize=(10, 8))
                sns.heatmap(cm, annot=True, fmt='d', ax=ax)
                plt.title("Test Confusion Matrix")
                plt.xlabel("Predicted Class")
                plt.ylabel("True Class")
                
                self.logger.experiment.log({
                    "test_confusion_matrix": wandb.Image(fig),
                    "test_epoch": self.current_epoch
                })
                plt.close()
            except:
                pass
    
    def predict_class(self, query_image):
        """Predict class for a query image using class prototypes"""
        query_embedding = self.model.forward_one(query_image.unsqueeze(0))
        
        distances = {}
        for class_idx, prototype in self.class_prototypes.items():
            distance = F.pairwise_distance(
                query_embedding,
                prototype.unsqueeze(0)
            )
            distances[class_idx] = distance.item()
        
        return min(distances.items(), key=lambda x: x[1])[0]
   

In [9]:
def run_experiments(base_names, architectures):
    results = {}
    
    data_module = SiameseDataModule(
        dataset_name=base_names[1],
        images_per_class=int(base_names[3]),
        n_pairs_per_class=100,
        val_split=0.2,
        batch_size=16,
        img_size=28 if base_names[1] == "mnist" else 32
    )
    
    channels = 1 if data_module.dataset_name == "mnist" else 3
    img_size = 28 if data_module.dataset_name == "mnist" else 32
    
    for arch in architectures:
        run_names = base_names.copy()
        run_names[2] = arch 
        run_name = ".".join(run_names)
        
        wandb_logger = WandbLogger(
            project="Few-KAN",
            name=run_name,
            group=f"{base_names[1]}_{base_names[-1]}"  
        )
        
        trainer = L.Trainer(
            max_epochs=50,
            callbacks=[
                L.pytorch.callbacks.EarlyStopping(
                    monitor='val_loss',
                    patience=10
                ),
                L.pytorch.callbacks.ModelCheckpoint(
                    monitor='val_loss',
                    filename=f'{run_name}-{{epoch:02d}}-{{val_loss:.2f}}',
                    save_top_k=1
                )
            ],
            logger=wandb_logger
        )
        
        model = SiameseKANModule(
            img_size=img_size,
            channels=channels,
            architecture=arch
        )
        
        print(f"\nTraining {run_name}")
        trainer.fit(model, data_module)
        results[arch] = trainer.test(model, data_module)
        
        wandb.finish()
    
    return results

names = ["siam", "cifar10", "v1", "15"]
architectures = ["kan_basic", "kan_with_cnn", "kkan"]

results = run_experiments(names, architectures)

print("\nResults Summary:")
for arch, result in results.items():
    print(f"\n{arch}:")
    for metric, value in result[0].items():
        print(f"{metric}: {value:.4f}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



Training siam.cifar10.kan_basic.15
Files already downloaded and verified
Files already downloaded and verified


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33miarata[0m ([33mhdu-dk[0m). Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type              | Params | Mode 
-------------------------------------------------------------
0 | model          | SiameseKANNetwork | 17.5 M | train
1 | criterion      | ContrastiveLoss   | 0      | train
2 | train_accuracy | BinaryAccuracy    | 0      | train
3 | val_accuracy   | BinaryAccuracy    | 0      | train
4 | test_accuracy  | BinaryAccuracy    | 0      | train
-------------------------------------------------------------
17.5 M    Trainable params
0         Non-trainable params
17.5 M    Total params
70.125    Total estimated model params size (MB)
19        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]


Test Results:
Overall Accuracy: 0.2000

Per-class Accuracy:
Class 3: 0.0000
Class 8: 1.0000
Class 0: 0.0000
Class 6: 1.0000
Class 1: 0.0000
Class 9: 0.0000
Class 5: 0.0000
Class 7: 0.0000
Class 4: 0.0000
Class 2: 0.0000

Confusion Matrix:
[[0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 0 0 1]
 [0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 1 0 0 0]]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        Test metric                 DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_accuracy            0.20000000298023224
   test_class_0_accuracy                0.0
   test_class_1_accuracy                0.0
   test_class_2_accuracy                0.0
   test_class_3_accuracy                0.0
   test_cl

0,1
epoch,▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
test_accuracy,▁
test_class_0_accuracy,▁
test_class_1_accuracy,▁
test_class_2_accuracy,▁
test_class_3_accuracy,▁
test_class_4_accuracy,▁
test_class_5_accuracy,▁
test_class_6_accuracy,▁
test_class_7_accuracy,▁

0,1
epoch,50.0
test_accuracy,0.2
test_class_0_accuracy,0.0
test_class_1_accuracy,0.0
test_class_2_accuracy,0.0
test_class_3_accuracy,0.0
test_class_4_accuracy,0.0
test_class_5_accuracy,0.0
test_class_6_accuracy,1.0
test_class_7_accuracy,0.0


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



Training siam.cifar10.kan_with_cnn.15
Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type              | Params | Mode 
-------------------------------------------------------------
0 | model          | SiameseKANNetwork | 22.8 M | train
1 | criterion      | ContrastiveLoss   | 0      | train
2 | train_accuracy | BinaryAccuracy    | 0      | train
3 | val_accuracy   | BinaryAccuracy    | 0      | train
4 | test_accuracy  | BinaryAccuracy    | 0      | train
-------------------------------------------------------------
22.8 M    Trainable params
0         Non-trainable params
22.8 M    Total params
91.174    Total estimated model params size (MB)
27        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]


Test Results:
Overall Accuracy: 0.5000

Per-class Accuracy:
Class 3: 0.0000
Class 8: 1.0000
Class 0: 1.0000
Class 6: 1.0000
Class 1: 0.0000
Class 9: 0.0000
Class 5: 1.0000
Class 7: 1.0000
Class 4: 0.0000
Class 2: 0.0000

Confusion Matrix:
[[1 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1]
 [0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 0 0 1 0]
 [0 1 0 0 0 0 0 0 0 0]]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        Test metric                 DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_accuracy                    0.5
   test_class_0_accuracy                1.0
   test_class_1_accuracy                0.0
   test_class_2_accuracy                0.0
   test_class_3_accuracy                0.0
   test_class_4_ac

0,1
epoch,▁▁▁▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇█████
test_accuracy,▁
test_class_0_accuracy,▁
test_class_1_accuracy,▁
test_class_2_accuracy,▁
test_class_3_accuracy,▁
test_class_4_accuracy,▁
test_class_5_accuracy,▁
test_class_6_accuracy,▁
test_class_7_accuracy,▁

0,1
epoch,33.0
test_accuracy,0.5
test_class_0_accuracy,1.0
test_class_1_accuracy,0.0
test_class_2_accuracy,0.0
test_class_3_accuracy,0.0
test_class_4_accuracy,0.0
test_class_5_accuracy,1.0
test_class_6_accuracy,1.0
test_class_7_accuracy,1.0


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



Training siam.cifar10.kkan.15
Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type              | Params | Mode 
-------------------------------------------------------------
0 | model          | SiameseKANNetwork | 63.8 K | train
1 | criterion      | ContrastiveLoss   | 0      | train
2 | train_accuracy | BinaryAccuracy    | 0      | train
3 | val_accuracy   | BinaryAccuracy    | 0      | train
4 | test_accuracy  | BinaryAccuracy    | 0      | train
-------------------------------------------------------------
63.8 K    Trainable params
0         Non-trainable params
63.8 K    Total params
0.255     Total estimated model params size (MB)
286       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]


Test Results:
Overall Accuracy: 0.2000

Per-class Accuracy:
Class 3: 0.0000
Class 8: 0.0000
Class 0: 0.0000
Class 6: 0.0000
Class 1: 1.0000
Class 9: 1.0000
Class 5: 0.0000
Class 7: 0.0000
Class 4: 0.0000
Class 2: 0.0000

Confusion Matrix:
[[0 1 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1]
 [0 0 0 0 0 0 0 0 0 1]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0]
 [0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1]]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        Test metric                 DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_accuracy            0.20000000298023224
   test_class_0_accuracy                0.0
   test_class_1_accuracy                1.0
   test_class_2_accuracy                0.0
   test_class_3_accuracy                0.0
   test_cl

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▇▇▇▇▇▇▇▇█████
test_accuracy,▁
test_class_0_accuracy,▁
test_class_1_accuracy,▁
test_class_2_accuracy,▁
test_class_3_accuracy,▁
test_class_4_accuracy,▁
test_class_5_accuracy,▁
test_class_6_accuracy,▁
test_class_7_accuracy,▁

0,1
epoch,31.0
test_accuracy,0.2
test_class_0_accuracy,0.0
test_class_1_accuracy,1.0
test_class_2_accuracy,0.0
test_class_3_accuracy,0.0
test_class_4_accuracy,0.0
test_class_5_accuracy,0.0
test_class_6_accuracy,0.0
test_class_7_accuracy,0.0



Results Summary:

kan_basic:
test_query_class_3_predicted: 8.0000
test_query_class_3_distance: 0.5359
test_query_class_8_predicted: 8.0000
test_query_class_8_distance: 0.5135
test_query_class_0_predicted: 8.0000
test_query_class_0_distance: 0.5316
test_query_class_6_predicted: 6.0000
test_query_class_6_distance: 0.4598
test_query_class_1_predicted: 9.0000
test_query_class_1_distance: 0.5308
test_query_class_9_predicted: 6.0000
test_query_class_9_distance: 0.4495
test_query_class_5_predicted: 6.0000
test_query_class_5_distance: 0.5282
test_query_class_7_predicted: 1.0000
test_query_class_7_distance: 0.6287
test_query_class_4_predicted: 6.0000
test_query_class_4_distance: 0.5357
test_query_class_2_predicted: 1.0000
test_query_class_2_distance: 0.6865
test_accuracy: 0.2000
test_class_3_accuracy: 0.0000
test_class_8_accuracy: 1.0000
test_class_0_accuracy: 0.0000
test_class_6_accuracy: 1.0000
test_class_1_accuracy: 0.0000
test_class_9_accuracy: 0.0000
test_class_5_accuracy: 0.0000
test_cla

In [10]:
# names = ["siam", "mnist", "kan_basic", "linear_kan_embedding", "v1", "shot_5"]
# wandb_logger = WandbLogger(project="Few-KAN", name=".".join(names))

In [11]:
# model = SiameseKAN(model_name=names[2], 
#     channels=1,  # 1 for MNIST, 3 for CIFAR10
#     img_size=32,
#     hidden_dim=128,
#     embedding_dim=64,
#     margin=1.0,
#     learning_rate=1e-3
# )

# data_module = SiameseDataModule(
#     dataset_name=names[1],  # or "mnist"
#     images_per_class=10,  # Can be 5, 10, or 15
#     n_pairs_per_class=100,
#     val_split=0.2,
#     batch_size=16,
#     img_size=28 if names[1] == "mnist" else 32
# )
# trainer = L.Trainer(
#     max_epochs=50,
#     callbacks=[
#         L.pytorch.callbacks.EarlyStopping(monitor='val_loss', patience=10),
#         L.pytorch.callbacks.ModelCheckpoint(monitor='val_loss')
#     ]
#     ,
#      logger=wandb_logger
# )


# model = SiameseModule(
#     img_size=img_size,
#     channels=channels,
#     hidden_dim=512,
#     embedding_dim=128,
#     margin=2.0  # Contrastive loss margin
# )

# # Train the model
# trainer.fit(model, data_module)

# # Build class prototypes
# trainer.test(model, data_module)

# trainer = L.Trainer(max_epochs=100, logger=wandb_logger)
# trainer.fit(model, data_module)