In [145]:
from __future__ import annotations
import os, time, math, json, random
from pathlib import Path
from typing import Sequence, Tuple, Union, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms.v2 import (
    Compose, Resize, RandomHorizontalFlip, ToImage,
    ToDtype, Normalize, Lambda                    # ⬅ new
)
from torchvision.utils import make_grid
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm

In [146]:
# -----------------------  constants  ---------------------------------
# Hyper-parameters shared by every training run in this notebook.
SEED          = 42
N_EPOCHS      = 44
BATCH_SIZE    = 64
IMAGE_SIZE    = (64, 64)       # (H, W); Tiny-ImageNet native resolution
LEARNING_RATE = 0.002271
WEIGHT_DECAY  = 1e-5

# Checkpoints, histories and figures are all written below this directory.
RUN_DIR = Path("runs/tinyimagenet_cnn_vs_kan").resolve()
RUN_DIR.mkdir(parents=True, exist_ok=True)

# Reproducibility: seed Python's and torch's global RNGs.
random.seed(SEED)
torch.manual_seed(SEED)

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"➡️  Using device: {device}")


➡️  Using device: mps


In [147]:
# -----------------------  utils  -------------------------------------
def _force_rgb(img: torch.Tensor) -> torch.Tensor:
    """Ensure 3‑channel RGB.

    • If RGBA ➜ drop alpha.
    • If grayscale ➜ replicate to 3 channels.
    """
    if img.shape[0] == 4:        # RGBA → RGB
        return img[:3]
    if img.shape[0] == 1:        # Gray → RGB by tiling
        return img.repeat(3, 1, 1)
    return img                   # already RGB

def get_dataloaders(batch_size:int=BATCH_SIZE,
                    image_size:Tuple[int,int]=IMAGE_SIZE,
                    dataset_name:str="zh-plus/tiny-imagenet"):
    """
    Return train/val/test DataLoaders for Tiny‑ImageNet.

    Splits:
        • train  – 90 % of the official *train* split (100 000 imgs in total)
        • val    – the other 10 % of *train*, stratified by label (internal validation)
        • test   – the official *valid* split (10 000 imgs), used as held‑out **test**

    Only the training split is augmented (random horizontal flip); val/test use a
    deterministic pipeline so their metrics do not change between evaluations.

    Args:
        batch_size:   mini-batch size for all three loaders.
        image_size:   (H, W) every image is resized to.
        dataset_name: Hugging Face hub id passed to ``load_dataset``.

    Returns:
        (train_loader, val_loader, test_loader); each batch is a dict with an
        "image" float tensor and a "label" tensor.
    """
    ds = load_dataset(dataset_name)

    # carve INTERNAL val from train (stratified so every class is represented)
    split = ds["train"].train_test_split(
        test_size=0.1, seed=SEED, stratify_by_column="label")
    train_ds = split["train"]
    val_ds   = split["test"]
    test_ds  = ds["valid"]                 # official validation → test

    # ImageNet normalisation stats
    mean = [0.485, 0.456, 0.406]
    std  = [0.229, 0.224, 0.225]

    def _build_pipeline(augment: bool) -> Compose:
        """Per-image pipeline: PIL → 3-channel float tensor, normalised."""
        steps = [
            ToImage(),                       # PIL → (C,H,W) uint8 tensor
            Lambda(_force_rgb),              # squash grayscale / RGBA → RGB
            Resize(image_size, antialias=True),
        ]
        if augment:
            steps.append(RandomHorizontalFlip())
        steps += [
            ToDtype(torch.float32, scale=True),
            Normalize(mean=mean, std=std),
        ]
        return Compose(steps)

    def _as_batch_transform(pipeline: Compose):
        """Wrap a per-image pipeline for ``Dataset.set_transform``."""
        def _transform(batch):
            # set_transform passes a *batch* (dict of lists). Run the pipeline
            # once per image: calling it on the whole list would draw a single
            # flip for the entire batch, and v2's pure-tensor heuristic can skip
            # images that _force_rgb returned as plain tensors whenever a
            # tv_tensors.Image is in the same call (leaving them uint8/unnormalised).
            batch["image"] = [pipeline(img) for img in batch["image"]]
            return batch
        return _transform

    train_ds.set_transform(_as_batch_transform(_build_pipeline(augment=True)))
    eval_transform = _as_batch_transform(_build_pipeline(augment=False))
    val_ds.set_transform(eval_transform)
    test_ds.set_transform(eval_transform)

    loader_cfg = dict(
        batch_size=batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(train_ds, shuffle=True,  **loader_cfg)
    val_loader   = DataLoader(val_ds,   shuffle=False, **loader_cfg)
    test_loader  = DataLoader(test_ds,  shuffle=False, **loader_cfg)
    return train_loader, val_loader, test_loader

In [148]:
# ---------------  models  --------------------------------------------
class BaselineCNN(nn.Module):
    """A lightweight CNN baseline: three conv blocks + a 3-layer MLP head.

    Feature extractor (spatial size shown for a 64×64 input):
        Block 1: conv3×3 → BN → SiLU                 (64×64)
        Block 2: conv3×3 → BN → SiLU → MaxPool(2)    (32×32)
        Block 3: conv3×3 → BN → SiLU → MaxPool(2)    (16×16)
    Head:
        Linear(flat → hidden_dims[0]) → ReLU → Dropout
        → Linear(hidden_dims[0] → hidden_dims[1]) → ReLU → Dropout
        → Linear(hidden_dims[1] → num_classes)

    Args:
        input_shape:   (C, H, W) of one image; sets the first conv's input
                       channels and (via a dry run) the first Linear's width.
        conv_channels: output channels of the three conv blocks.
        num_classes:   number of output logits (200 for Tiny-ImageNet).
        hidden_dims:   widths of the two hidden MLP layers.
        dropout:       dropout probability used in the head.
    """
    def __init__(self, input_shape=(3, *IMAGE_SIZE),   # ⬅ 3‑channel
                 conv_channels=(64, 128, 256),
                 num_classes=200,
                 hidden_dims=(512, 256),
                 dropout=0.1):
        super().__init__()
        C_in, _, _ = input_shape
        c1, c2, c3 = conv_channels
        h1, h2 = hidden_dims
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(C_in, c1, 3, padding=1),
            nn.BatchNorm2d(c1),          # width tied to conv_channels, not hard-coded
            nn.SiLU(inplace=True),

            # Block 2
            nn.Conv2d(c1, c2, 3, padding=1),
            nn.BatchNorm2d(c2),
            nn.SiLU(inplace=True),
            nn.MaxPool2d(2),             # H, W halved

            # Block 3
            nn.Conv2d(c2, c3, 3, padding=1),
            nn.BatchNorm2d(c3),
            nn.SiLU(inplace=True),
            nn.MaxPool2d(2),             # H, W halved again
        )

        # Dry run to size the first Linear layer. Done in eval mode so the
        # dummy zero batch does not update BatchNorm running statistics.
        self.features.eval()
        with torch.no_grad():
            flat_feats = self.features(torch.zeros(1, *input_shape)).flatten(1).size(1)
        self.features.train()

        self.flatten = nn.Flatten()

        # Layer widths now chain correctly (the original fed a 200-wide
        # layer into Linear(512, ...), which crashed in forward).
        self.ff = nn.Sequential(
            nn.Linear(flat_feats, h1),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(h2, num_classes),
        )

    def forward(self, x):
        """Map a (B, C, H, W) batch to (B, num_classes) logits."""
        x = self.features(x)
        x = self.flatten(x)
        x = self.ff(x)
        return x

In [149]:
# ── 0. deps ──────────────────────────────────────────────────────────────
# Requires the `bayesian-optimization` package (`%pip install bayesian-optimization`) and the project-local `classes.BSplineActivation` module.
from bayes_opt import BayesianOptimization
import torch, time, json
from pathlib import Path
from classes.BSplineActivation import BSplineActivation
# ── 1. make KANCNN fully hyper‑param friendly ────────────────────────────
class KANCNN(nn.Module):
    """
    ResNet-style conv backbone (without skip connections) + B-spline "KAN" head.

    Backbone: 7×7 stride-2 stem → 3×3 stride-2 max-pool → four stages of two
    3×3 conv-BN-ReLU layers (widths stem_out, ch2, ch3, ch4; stages 2-4 start
    with stride 2) → global average pool.
    Head: three Linear layers, each followed by a BSplineActivation; note the
    last activation is applied to the output logits themselves.

    Every width / spline knob is a constructor argument so the model can be
    driven by BayesianOptimization.
    """
    def __init__(self,
                 input_shape=(3, *IMAGE_SIZE),
                 stem_out=64,          # first 7×7 conv channels
                 ch2=128, ch3=256, ch4=512,   # block widths
                 kan_1=512, kan_2=256, kan_3=200,
                 spline_cp=8, spline_deg=3,
                 range_min=-5.0, range_max=5.0):
        super().__init__()
        in_channels, *_spatial = input_shape   # only the channel count matters (GAP follows)

        def conv_bn_relu(c_in, c_out, stride=1):
            """3×3 conv (bias-free: BN follows) → BN → ReLU, as a list of modules."""
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # ── 1. Convolutional backbone (sizes for a 64×64 input) ─────────
        layers = [
            # Stem: 64×64 → 32×32
            nn.Conv2d(in_channels, stem_out, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(stem_out),
            nn.ReLU(inplace=True),
            # 32×32 → 16×16
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        # (in, out, stride of the first conv): 16→16, 16→8, 8→4, 4→2
        stages = (
            (stem_out, stem_out, 1),
            (stem_out, ch2, 2),
            (ch2, ch3, 2),
            (ch3, ch4, 2),
        )
        for stage_in, stage_out, first_stride in stages:
            layers += conv_bn_relu(stage_in, stage_out, first_stride)
            layers += conv_bn_relu(stage_out, stage_out)
        layers.append(nn.AdaptiveAvgPool2d(1))   # 2×2 → 1×1
        self.features = nn.Sequential(*layers)

        # ── 2. KAN head with B‑spline activations ───────────────────────
        def make_spline():
            """Fresh BSplineActivation sharing this model's spline settings."""
            return BSplineActivation(
                num_control_points=spline_cp,
                degree=spline_deg,
                range_min=range_min,
                range_max=range_max
            )

        # After GAP the feature vector has ch4 entries.
        self.kan1 = nn.Linear(ch4, kan_1)
        self.kan1_act = make_spline()
        self.kan2 = nn.Linear(kan_1, kan_2)
        self.kan2_act = make_spline()
        self.kan3 = nn.Linear(kan_2, kan_3)
        self.kan3_act = make_spline()

    # ── 3. Forward pass ────────────────────────────────────────────────
    def forward(self, x):
        """Map a (B, C, H, W) batch to (B, kan_3) outputs."""
        out = torch.flatten(self.features(x), 1)   # B × ch4 × 1 × 1 → B × ch4
        head = (
            (self.kan1, self.kan1_act),
            (self.kan2, self.kan2_act),
            (self.kan3, self.kan3_act),
        )
        for linear, activation in head:
            out = activation(linear(out))
        return out


In [150]:
# ---------------  training / eval helpers  ---------------------------
def accuracy_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Fraction of rows whose arg-max logit (over dim 1) equals the target.

    Returns a 0-dim float32 tensor in [0, 1].
    """
    predictions = torch.argmax(logits, dim=1)
    hits = predictions.eq(targets)
    return hits.to(torch.float32).mean()

def train_one_epoch(model: nn.Module, loader: DataLoader, criterion, optimizer, epoch: int) -> Dict[str, float]:
    """Run one optimisation pass over `loader` and return sample-weighted metrics.

    For every batch: forward → loss → backward → optimizer step. Loss and
    accuracy are weighted by batch size and divided by the number of samples
    actually seen. The original divided by ``len(loader.dataset)``, which is
    wrong with ``drop_last=True`` or a subsetting sampler.

    Args:
        model:     network to train; switched to train mode.
        loader:    yields dicts with "image" and "label" tensors.
        criterion: loss callable; assumed to return the batch *mean*
                   (nn.CrossEntropyLoss default), since it is re-weighted by batch size.
        optimizer: optimiser over ``model``'s parameters.
        epoch:     only used to label the progress bar.

    Returns:
        {"loss": mean training loss, "acc": mean training accuracy in [0, 1]};
        both are NaN if the loader yields no samples.
    """
    model.train()
    loss_sum = acc_sum = 0.0
    n_seen = 0
    for batch in tqdm(loader, desc=f"Train {epoch:02d}", leave=False):
        x = batch["image"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = x.size(0)
        loss_sum += loss.item() * batch_size
        acc_sum  += accuracy_from_logits(logits, y).item() * batch_size
        n_seen   += batch_size

    if n_seen == 0:   # empty loader: no meaningful average
        return {"loss": float("nan"), "acc": float("nan")}
    return {"loss": loss_sum / n_seen, "acc": acc_sum / n_seen}

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, criterion) -> Dict[str, float]:
    """Compute sample-weighted mean loss and accuracy of `model` over `loader`.

    Runs in eval mode (dropout off, BatchNorm on running stats) with gradient
    tracking disabled. Averages are divided by the number of samples actually
    yielded, so they stay correct with ``drop_last=True`` or custom samplers.

    Args:
        model:     network to evaluate; left in eval mode afterwards.
        loader:    yields dicts with "image" and "label" tensors.
        criterion: loss callable; assumed to return the batch *mean*.

    Returns:
        {"loss": mean loss, "acc": accuracy in [0, 1]}; both NaN for an empty loader.
    """
    model.eval()
    loss_sum = acc_sum = 0.0
    n_seen = 0
    for batch in loader:
        x = batch["image"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True)
        logits = model(x)
        batch_size = x.size(0)
        loss_sum += criterion(logits, y).item() * batch_size
        acc_sum  += accuracy_from_logits(logits, y).item() * batch_size
        n_seen   += batch_size

    if n_seen == 0:   # empty loader: no meaningful average
        return {"loss": float("nan"), "acc": float("nan")}
    return {"loss": loss_sum / n_seen, "acc": acc_sum / n_seen}

def run_training(model:nn.Module, name:str,
                 train_loader:DataLoader, val_loader:DataLoader,
                 epochs: Union[int, None] = None,
                 lr: Union[float, None] = None,
                 weight_decay: Union[float, None] = None) -> Dict[str, List[float]]:
    """Full training loop for one model; saves weights + history under RUN_DIR.

    Trains with AdamW and cross-entropy. After each epoch it evaluates on
    `val_loader` and prints a one-line summary.

    Args:
        model:        network to train (moved to ``device``).
        name:         stem for the output files ``{name}.pt`` / ``{name}_history.json``.
        train_loader: training batches.
        val_loader:   validation batches evaluated after every epoch.
        epochs:       number of epochs; defaults to ``N_EPOCHS`` (read at call time).
        lr:           AdamW learning rate; defaults to ``LEARNING_RATE``.
        weight_decay: AdamW weight decay; defaults to ``WEIGHT_DECAY``.

    Returns:
        history dict with per-epoch lists "train_loss", "train_acc",
        "val_loss", "val_acc" (accuracies as fractions in [0, 1]).
    """
    # Resolve defaults at call time so later edits to the config cell are honoured.
    epochs       = N_EPOCHS      if epochs       is None else int(epochs)
    lr           = LEARNING_RATE if lr           is None else float(lr)
    weight_decay = WEIGHT_DECAY  if weight_decay is None else float(weight_decay)

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    history: Dict[str, List[float]] = {"train_loss":[], "train_acc":[], "val_loss":[], "val_acc":[]}

    for epoch in range(1, epochs+1):
        tic = time.time()
        train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
        val_metrics   = evaluate(model, val_loader, criterion)

        for split, metrics in (("train", train_metrics), ("val", val_metrics)):
            history[f"{split}_loss"].append(metrics["loss"])
            history[f"{split}_acc"].append(metrics["acc"])

        print(f"Epoch {epoch:2d}/{epochs} • "
              f"train acc {train_metrics['acc']*100:5.2f}% | "
              f"val acc {val_metrics['acc']*100:5.2f}% | "
              f"Δt {time.time()-tic:4.1f}s")

    torch.save(model.state_dict(), RUN_DIR/f"{name}.pt")
    with open(RUN_DIR/f"{name}_history.json", "w") as f:
        json.dump(history, f)
    return history

### Plot

In [151]:
def plot_metrics(df:pd.DataFrame):
    """Plot accuracy and loss curves (train + val) for the baseline and KAN models.

    Draws one figure per metric and saves them as ``accuracy_curves.png`` and
    ``loss_curves.png`` in ``RUN_DIR``. It then shows both figures.

    Args:
        df: one row per epoch with an ``epoch`` column plus the columns
            ``{baseline,kan}_{train,val}_{acc,loss}``.
            NOTE(review): the accuracy axis is labelled "%", but histories from
            ``run_training`` store accuracies as fractions in [0, 1]. Confirm
            the caller scales them by 100 before plotting.
    """
    sns.set_theme(style="whitegrid", font_scale=1.2)

    # (metric column suffix, y-axis label, title, output file name)
    panels = [
        ("acc",  "Accuracy (%)",       "Tiny‑ImageNet • Accuracy vs. Epoch", "accuracy_curves.png"),
        ("loss", "Cross‑entropy loss", "Tiny‑ImageNet • Loss vs. Epoch",     "loss_curves.png"),
    ]
    # (column prefix, legend prefix, extra line kwargs)
    models = [
        ("baseline", "Baseline", {}),
        ("kan",      "KAN",      {"linestyle": "--"}),
    ]

    for metric, ylabel, title, filename in panels:
        fig, ax = plt.subplots(figsize=(8,5))
        for prefix, legend, line_kw in models:
            for split in ("train", "val"):
                ax.plot(df["epoch"], df[f"{prefix}_{split}_{metric}"],
                        label=f"{legend} {split}", **line_kw)
        ax.set_xlabel("Epoch"); ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        fig.tight_layout()
        # Save via the figure handle rather than pyplot's "current figure".
        fig.savefig(RUN_DIR/filename, dpi=200)
    plt.show()

In [None]:
def count_parameters(model: nn.Module) -> int:
    """Return the total number of scalar elements in `model`'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Build the data pipeline and both models.
train_loader, val_loader, test_loader = get_dataloaders()

# Re-seed right before construction so the initial weights are the same no
# matter how many times earlier cells (or this one) have been executed.
torch.manual_seed(SEED)
baseline = BaselineCNN()
kan      = KANCNN()

print(f'KAN parameters: {count_parameters(kan)}')
print(f'Baseline parameters: {count_parameters(baseline)}')

In [89]:
#print("📚 Training baseline CNN …")
#hist_base = run_training(baseline, "baseline", train_loader, val_loader)  # run_training returns only the history dict

In [None]:
# Train the KAN-CNN with the shared config; the returned per-epoch history is
# also written to RUN_DIR as KAN_history.json.
print("\n🌀 Training KAN‑CNN …")
hist_kan = run_training(model=kan, name="KAN",
                        train_loader=train_loader, val_loader=val_loader)

In [None]:

# Held-out evaluation on the official Tiny-ImageNet "valid" split (our test set).
# The baseline lines stay disabled while its training cell is commented out.
criterion = nn.CrossEntropyLoss()
# test_base = evaluate(baseline, test_loader, criterion)
test_kan = evaluate(kan, test_loader, criterion)

kan_test_pct = test_kan["acc"] * 100
print(f"\n✅ Test accuracy: KAN {kan_test_pct:5.2f}%")


### Optimize KAN

In [92]:
# Search space for BayesianOptimization. bayes_opt proposes floats only, so the
# integer-valued knobs are rounded inside `optimize_kan`.
pbounds = dict(
    epochs=(25, 55),           # training epochs per trial
    kan_1=(200, 400),          # width of first KAN layer
    kan_2=(128, 512),          # width of second KAN layer
    kan_3=(200, 200),          # pinned: output width, presumably the 200 classes
    spline_cp=(6, 10),         # B-spline control points
    spline_deg=(2, 5),         # degree; clamped to ≤ spline_cp - 1 in optimize_kan
    range_min=(-5.0, -0.5),    # BSplineActivation range_min
    range_max=(30.0, 70.0),    # BSplineActivation range_max
    lr=(1e-4, 1e-2),           # AdamW learning rate
)

import time
from tqdm.auto import tqdm

def train_kan(model, train_loader, val_loader, epochs, lr):
    """Train `model` for `epochs` epochs and return its best validation accuracy.

    This is the training routine behind the Bayesian search objective. It reuses
    ``train_one_epoch`` / ``evaluate`` so metrics are computed exactly as in
    ``run_training``, instead of keeping a duplicated copy of the loop here.

    Args:
        model:        network to train (moved to ``device``).
        train_loader: training batches.
        val_loader:   internal validation batches (carved from the train split).
        epochs:       number of epochs (int).
        lr:           AdamW learning rate; weight decay is ``WEIGHT_DECAY``.

    Returns:
        Best (max over epochs) validation accuracy, as a fraction in [0, 1].
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optim     = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)

    best_val = 0.0
    for ep in range(1, epochs + 1):
        tic = time.time()
        train_metrics = train_one_epoch(model, train_loader, criterion, optim, ep)
        val_metrics   = evaluate(model, val_loader, criterion)
        elapsed       = time.time() - tic

        # Labelled "Val": this is the internal validation split used for model
        # selection, not the held-out test set.
        print(f"Epoch [{ep}/{epochs}], "
              f"Loss: {train_metrics['loss']:.4f}, "
              f"Val Acc: {val_metrics['acc']*100:5.2f}%, "
              f"Time: {elapsed:5.2f} seconds")

        best_val = max(best_val, val_metrics["acc"])

    return best_val


def optimize_kan(epochs,
                 kan_1,
                 kan_2,
                 kan_3,
                 spline_cp,
                 spline_deg,
                 range_min,
                 range_max,
                 lr):
    """BayesianOptimization objective: build, train and score one KANCNN config.

    bayes_opt calls this with float values for every key of ``pbounds``.
    Integer-valued knobs are rounded here, and the spline degree is clamped to
    ``[2, spline_cp - 1]``.

    Returns:
        Best validation accuracy (fraction in [0, 1]); BayesOpt maximises it.
    """
    # ─ cast + sanity ─
    epochs     = int(round(epochs))
    kan_1      = int(round(kan_1))
    kan_2      = int(round(kan_2))
    kan_3      = int(round(kan_3))
    spline_cp  = int(round(spline_cp))
    spline_deg = int(round(spline_deg))

    # keep B‑spline well‑formed
    spline_deg = max(2, min(spline_deg, spline_cp - 1))
    lr         = float(lr)

    # Re-seed every trial: otherwise weight init and shuffle order depend on how
    # many trials ran before, adding noise the GP wrongly attributes to the
    # hyper-parameters and making trials irreproducible.
    random.seed(SEED)
    torch.manual_seed(SEED)

    model = KANCNN(
        kan_1=kan_1,
        kan_2=kan_2,
        kan_3=kan_3,
        spline_cp=spline_cp,
        spline_deg=spline_deg,
        range_min=range_min,
        range_max=range_max
    )

    val_acc = train_kan(model, train_loader, val_loader, epochs, lr)

    # BayesOpt maximizes the returned value
    return val_acc



In [None]:
# Run the search: 6 random probes, then 10 GP-guided iterations.
optimizer = BayesianOptimization(
    f=optimize_kan,
    pbounds=pbounds,
    random_state=38,
    verbose=2
)

optimizer.maximize(init_points=6, n_iter=10)

print("🚀 best combo so far →", optimizer.max)

# Persist the search results. Every trial is a full training run, so keeping
# them only in kernel memory would lose hours of work on a restart.
# default=str guards against any non-JSON value a bayes_opt version may include.
with open(RUN_DIR / "kan_bayesopt.json", "w") as f:
    json.dump({"max": optimizer.max, "res": optimizer.res}, f, indent=2, default=str)