In [31]:
from __future__ import annotations
import os, time, math, json, random
from pathlib import Path
from typing import Sequence, Tuple, Union, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms.v2 import (
    Compose, Resize, RandomHorizontalFlip, ToImage, ToDtype, Normalize
)
from torchvision.utils import make_grid
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm

In [52]:
SEED               = 42
N_EPOCHS           = 20
BATCH_SIZE         = 128
IMAGE_SIZE         = (64, 64)            # Mini‑ImageNet default → 84×84
LEARNING_RATE      = 3e-4
WEIGHT_DECAY       = 1e-4
RUN_DIR            = Path("runs/miniimagenet_cnn_vs_kan").resolve()
RUN_DIR.mkdir(parents=True, exist_ok=True)

random.seed(SEED)
torch.manual_seed(SEED)

device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps")  if torch.backends.mps.is_available() else
    torch.device("cpu")
)
print(f"➡️  Using device: {device}")

➡️  Using device: mps


In [None]:
def _force_rgb(img: torch.Tensor) -> torch.Tensor:
    """Ensure 3‑channel RGB.

    • If RGBA ➜ drop alpha.
    • If grayscale ➜ replicate to 3 channels.
    """
    if img.shape[0] == 4:
        print("RGBA")# RGBA → RGB
        return img[:3]
    if img.shape[0] == 1:       # Gray → RGB by tiling
        return img.repeat(3, 1, 1)
    return img  # already RGB

def get_dataloaders(batch_size:int=BATCH_SIZE,
                    image_size:Tuple[int,int]=IMAGE_SIZE,
                    dataset_name:str="zh-plus/tiny-imagenet"):
    """
    Return train/val/test DataLoaders for Mini‑ImageNet.

    dataset_name:
        • 'timm/mini-imagenet' – 60 000 images, 200 classes
        • 'zh-plus/tiny-imagenet' – smaller, good for quick prototyping
    """
    ds = load_dataset(dataset_name)

    train_ds = ds["train"].train_test_split(
        test_size=0.1, seed=SEED, stratify_by_column="label")
    val_ds   = train_ds["test"]
    train_ds = train_ds["train"]
    test_ds  = ds["validation"]

    # Normalisation stats from ImageNet
    mean = [0.485, 0.456, 0.406]
    std  = [0.229, 0.224, 0.225]

    tfms = Compose([
        ToImage(),                       # PIL → (C,H,W) uint8 tensor
        Resize(image_size),
        RandomHorizontalFlip(),
        ToDtype(torch.float32, scale=True),
        Normalize(mean=mean, std=std),
    ])

    def add_tfms(example):
        example["image"] = tfms(example["image"])
        return example

    for split in (train_ds, val_ds, test_ds):
        split.set_transform(add_tfms)

    loader_cfg = dict(
        batch_size=batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(train_ds, shuffle=True,  **loader_cfg)
    val_loader   = DataLoader(val_ds,   shuffle=False, **loader_cfg)
    test_loader  = DataLoader(test_ds,  shuffle=False, **loader_cfg)
    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = get_dataloaders()
print(f"➡️  Dataset ready: {len(train_loader.dataset)} train samples, "
      f"{len(val_loader.dataset)} val, {len(test_loader.dataset)} test | "
      f"200 classes")


➡️  Dataset ready: 45000 train samples, 5000 val, 10000 test | 200 classes


In [47]:
class BaselineCNN(nn.Module):
    """A lightweight CNN with two conv blocks + MLP head."""
    def __init__(self, input_shape=(4, *IMAGE_SIZE),
                 conv_channels=(64, 128), fc_dims=(512,)):
        super().__init__()
        C_in, _, _ = input_shape
        c1, c2 = conv_channels

        self.features = nn.Sequential(
            nn.Conv2d(C_in, c1, 3, padding=1, bias=False), nn.BatchNorm2d(c1), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(c1, c2, 3, padding=1, bias=False), nn.BatchNorm2d(c2), nn.ReLU(inplace=True), nn.MaxPool2d(2),
        )
        with torch.no_grad():
            flat = self.features(torch.zeros(1, *input_shape)).view(1, -1).shape[1]
        mlp: List[nn.Module] = []
        in_dim = flat
        for h in fc_dims:
            mlp += [nn.Linear(in_dim, h), nn.ReLU(inplace=True)]
            in_dim = h
        mlp.append(nn.Linear(in_dim, 200))
        self.classifier = nn.Sequential(*mlp)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [49]:
from classes.BSplineActivation import BSplineActivation
class KANCNN(nn.Module):
    """CNN backbone identical to baseline but MLP replaced by KAN."""
    def __init__(self, input_shape=(3, *IMAGE_SIZE),
                 conv_channels=(64, 128), kan_inner=128, kan_outer=200,
                 spline_cp=7, spline_deg=2):
        super().__init__()
        C_in, _, _ = input_shape
        c1, c2 = conv_channels

        self.features = nn.Sequential(
            nn.Conv2d(C_in, c1, 3, padding=1, bias=False), nn.BatchNorm2d(c1), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(c1, c2, 3, padding=1, bias=False), nn.BatchNorm2d(c2), nn.ReLU(inplace=True), nn.MaxPool2d(2),
        )
        with torch.no_grad():
            flat = self.features(torch.zeros(1, *input_shape)).view(1, -1).shape[1]
        self.inner = nn.Linear(flat, kan_inner)
        self.inner_act = BSplineActivation(num_control_points=spline_cp, degree=spline_deg, range_min=-1, range_max=1)
        self.outer = nn.Linear(kan_inner, kan_outer)
        self.outer_act = BSplineActivation(num_control_points=spline_cp, degree=spline_deg, range_min=-1, range_max=1)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.inner_act(self.inner(x))
        x = self.outer_act(self.outer(x))
        return x

In [44]:
def accuracy_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    return (logits.argmax(dim=1) == targets).float().mean()


def train_one_epoch(model: nn.Module, loader: DataLoader, criterion, optimizer, epoch: int) -> Dict[str, float]:
    model.train()
    loss_sum = acc_sum = 0.0
    for batch in tqdm(loader, desc=f"Train {epoch:02d}", leave=False):
        x = batch["image"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        acc = accuracy_from_logits(logits, y)
        loss_sum += loss.item() * x.size(0)
        acc_sum  += acc.item()  * x.size(0)

    n = len(loader.dataset)
    return {"loss": loss_sum / n, "acc": acc_sum / n}


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, criterion) -> Dict[str, float]:
    model.eval()
    loss_sum = acc_sum = 0.0
    for batch in loader:
        x = batch["image"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True)
        logits = model(x)
        loss = criterion(logits, y)
        acc = accuracy_from_logits(logits, y)
        loss_sum += loss.item() * x.size(0)
        acc_sum  += acc.item() * x.size(0)
    n = len(loader.dataset)
    return {"loss": loss_sum / n, "acc": acc_sum / n}


def run_training(model:nn.Module, name:str,
                 train_loader:DataLoader, val_loader:DataLoader) -> Dict[str, List[float]]:
    """Full training loop for one model."""
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    history = {"train_loss":[], "train_acc":[], "val_loss":[], "val_acc":[]}

    for epoch in range(1, N_EPOCHS+1):
        tic = time.time()
        train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
        val_metrics   = evaluate(model, val_loader, criterion)

        history["train_loss"].append(train_metrics["loss"])
        history["train_acc"].append(train_metrics["acc"])
        history["val_loss"].append(val_metrics["loss"])
        history["val_acc"].append(val_metrics["acc"])

        print(f"Epoch {epoch:2d}/{N_EPOCHS} • "
              f"train acc {train_metrics['acc']*100:5.2f}% | "
              f"val acc {val_metrics['acc']*100:5.2f}% | "
              f"Δt {time.time()-tic:4.1f}s")

    torch.save(model.state_dict(), RUN_DIR/f"{name}.pt")
    with open(RUN_DIR/f"{name}_history.json", "w") as f:
        json.dump(history, f)
    return history


In [45]:
def plot_metrics(df:pd.DataFrame):
    """Plot accuracy + loss curves for both models."""
    sns.set_theme(style="whitegrid", font_scale=1.2)

    # Accuracy
    fig, ax = plt.subplots(figsize=(8,5))
    ax.plot(df["epoch"], df["baseline_train_acc"], label="Baseline train")
    ax.plot(df["epoch"], df["baseline_val_acc"],   label="Baseline val")
    ax.plot(df["epoch"], df["kan_train_acc"],      label="KAN train", linestyle="--")
    ax.plot(df["epoch"], df["kan_val_acc"],        label="KAN val",   linestyle="--")
    ax.set_xlabel("Epoch"); ax.set_ylabel("Accuracy (%)")
    ax.set_title("Mini‑ImageNet • Accuracy vs. Epoch")
    ax.legend()
    fig.tight_layout()
    plt.savefig(RUN_DIR/"accuracy_curves.png", dpi=200)

    # Loss
    fig, ax = plt.subplots(figsize=(8,5))
    ax.plot(df["epoch"], df["baseline_train_loss"], label="Baseline train")
    ax.plot(df["epoch"], df["baseline_val_loss"],   label="Baseline val")
    ax.plot(df["epoch"], df["kan_train_loss"],      label="KAN train", linestyle="--")
    ax.plot(df["epoch"], df["kan_val_loss"],        label="KAN val",   linestyle="--")
    ax.set_xlabel("Epoch"); ax.set_ylabel("Cross‑entropy loss")
    ax.set_title("Mini‑ImageNet • Loss vs. Epoch")
    ax.legend()
    fig.tight_layout()
    plt.savefig(RUN_DIR/"loss_curves.png", dpi=200)
    plt.show()

In [51]:
def main():
    baseline = BaselineCNN()
    kan      = KANCNN()

    print("📚 Training baseline CNN …")
    hist_base = run_training(baseline, "baseline", train_loader, val_loader)

    print("\n🌀 Training KAN-CNN …")
    hist_kan  = run_training(kan,      "KAN",      train_loader, val_loader)

    # convert to DataFrame for plotting
    df = pd.DataFrame({
        "epoch": range(1, N_EPOCHS+1),
        "baseline_train_acc": [x*100 for x in hist_base["train_acc"]],
        "baseline_val_acc"  : [x*100 for x in hist_base["val_acc"]],
        "baseline_train_loss": hist_base["train_loss"],
        "baseline_val_loss"  : hist_base["val_loss"],
        "kan_train_acc":      [x*100 for x in hist_kan["train_acc"]],
        "kan_val_acc":        [x*100 for x in hist_kan["val_acc"]],
        "kan_train_loss":     hist_kan["train_loss"],
        "kan_val_loss":       hist_kan["val_loss"],
    })
    plot_metrics(df)

    # Final test evaluation
    criterion = nn.CrossEntropyLoss()
    test_base = evaluate(baseline, test_loader, criterion)
    test_kan  = evaluate(kan,      test_loader, criterion)

    print(f"\n✅ Test accuracy: "
          f"Baseline {test_base['acc']*100:5.2f}% | "
          f"KAN {test_kan['acc']*100:5.2f}%")

if __name__ == "__main__":
    main()

📚 Training baseline CNN …


                                                 

RuntimeError: Given groups=1, weight of size [64, 4, 3, 3], expected input[128, 3, 84, 84] to have 4 channels, but got 3 channels instead