In [3]:
from __future__ import annotations
import os, time, math, json, random
from pathlib import Path
from typing import Sequence, Tuple, Union, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms.v2 import (
    Compose, Resize, RandomHorizontalFlip, ToImage,
    ToDtype, Normalize, Lambda                    # ⬅ new
)
from torchvision.utils import make_grid
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm

In [4]:
# -----------------------  constants  ---------------------------------
SEED               = 42
N_EPOCHS           = 10
BATCH_SIZE         = 128
IMAGE_SIZE         = (64, 64)            # Tiny‑ImageNet native size
LEARNING_RATE      = 3e-4
WEIGHT_DECAY       = 1e-5
RUN_DIR            = Path("runs/tinyimagenet_cnn_vs_kan").resolve()
RUN_DIR.mkdir(parents=True, exist_ok=True)

random.seed(SEED)
torch.manual_seed(SEED)

device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps")  if torch.backends.mps.is_available() else
    torch.device("cpu")
)
print(f"➡️  Using device: {device}")


➡️  Using device: mps


In [41]:
# -----------------------  utils  -------------------------------------
def _force_rgb(img: torch.Tensor) -> torch.Tensor:
    """Ensure 3‑channel RGB.

    • If RGBA ➜ drop alpha.
    • If grayscale ➜ replicate to 3 channels.
    """
    if img.shape[0] == 4:        # RGBA → RGB
        return img[:3]
    if img.shape[0] == 1:        # Gray → RGB by tiling
        return img.repeat(3, 1, 1)
    return img                   # already RGB

def get_dataloaders(batch_size:int=BATCH_SIZE,
                    image_size:Tuple[int,int]=IMAGE_SIZE,
                    dataset_name:str="zh-plus/tiny-imagenet"):
    """
    Return train/val/test DataLoaders for Tiny‑ImageNet.

    Splits:
        • train  – 100 000 imgs
        • valid  – 10 000 imgs (used here as held‑out **test**)
    We further split 10 % of *train* into an internal validation set.
    """
    ds = load_dataset(dataset_name)

    # carve INTERNAL val from train
    split = ds["train"].train_test_split(
        test_size=0.1, seed=SEED, stratify_by_column="label")
    train_ds = split["train"]
    val_ds   = split["test"]
    test_ds  = ds["valid"]                 # official validation → test

    # ImageNet normalisation stats
    mean = [0.485, 0.456, 0.406]
    std  = [0.229, 0.224, 0.225]

    tfms = Compose([
        ToImage(),                       # PIL → (C,H,W) uint8 tensor
        Lambda(_force_rgb),              # ⬅ squash grayscale → RGB
        RandomHorizontalFlip(),
        ToDtype(torch.float32, scale=True),
        Normalize(mean=mean, std=std),
    ])

    def add_tfms(example):
        example["image"] = tfms(example["image"])
        return example

    for split_ds in (train_ds, val_ds, test_ds):
        split_ds.set_transform(add_tfms)

    loader_cfg = dict(
        batch_size=batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(train_ds, shuffle=True,  **loader_cfg)
    val_loader   = DataLoader(val_ds,   shuffle=False, **loader_cfg)
    test_loader  = DataLoader(test_ds,  shuffle=False, **loader_cfg)
    return train_loader, val_loader, test_loader

In [67]:
import torch
import torch.nn as nn

class BaselineCNN(nn.Module):
    """
    Tiny‑ImageNet‑sized baseline CNN
    • Input : (B, 3, 64, 64)
    • Output: (B, num_classes)
    """
    def __init__(
        self,
        input_shape: tuple[int, int, int] = (3, 64, 64),
        num_classes: int = 200,
        dropout: float = 0.1,
    ):
        super().__init__()
        C_in, _, _ = input_shape

        # ─────────── feature extractor ───────────
        self.features = nn.Sequential(
            # 64×64 → 64×64
            nn.Conv2d(C_in, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(inplace=True),

            # 64×64 → 32×32
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.SiLU(inplace=True),
            nn.MaxPool2d(2),

            # 32×32 → 16×16
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.SiLU(inplace=True),
            nn.MaxPool2d(2),

            # 16×16 → 8×8
            nn.Conv2d(256, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.SiLU(inplace=True),
            nn.MaxPool2d(2),

            # 8×8 → 4×4
            nn.Conv2d(512, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.SiLU(inplace=True),
            nn.MaxPool2d(2),

            # 4×4 → 1×1 (global pooling)
            nn.AdaptiveMaxPool2d(1),
        )

        # ─────────── classifier ───────────
        self.classifier = nn.Sequential(
            nn.Flatten(),               # (B, 512, 1, 1) → (B, 512)
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(256, num_classes),
        )

        # weight‑init
        self.apply(self._init_weights)

    # ────────────────────────────────────────────
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.classifier(x)
        return x

    @staticmethod
    def _init_weights(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)


In [68]:
batch   = next(iter(train_loader))
images  = batch["image"]           
labels  = batch["label"]          

print("images :", images.shape, images.dtype, images.min().item(), images.max().item())
print("labels :", labels[:20], labels.dtype)
print("# unique classes in batch:", len(torch.unique(labels)))

# sanity checks you should see
#   • dtype: torch.float32
#   • min/max roughly -2 ↔ +2      (ImageNet mean/std normalisation)
#   • label dtype: torch.int64
#   • at least a handful of different label values


images : torch.Size([128, 3, 64, 64]) torch.float32 -2.1179039478302 255.0
labels : tensor([183,  22, 183,  72,  61,  96, 151, 120, 195, 164,   5, 143, 116, 146,
        135, 197, 170, 132,  48,  16]) torch.int64
# unique classes in batch: 93


In [69]:
def train(model, loader, loss_fn, optim, device=device):
    model.train()
    tot_loss = tot_correct = 0
    for batch in loader:
        x, y = batch["image"].to(device), batch["label"].to(device)
        optim.zero_grad(set_to_none=True)
        logits = model(x)
        loss   = loss_fn(logits, y)
        loss.backward()
        optim.step()
        tot_loss    += loss.item() * x.size(0)
        tot_correct += (logits.argmax(1) == y).sum().item()
    n = len(loader.dataset)
    return tot_loss/n, tot_correct/n          # avg_loss, accuracy


@torch.no_grad()
def evaluate(model, loader, loss_fn, device=device):
    model.eval()
    tot_loss = tot_correct = 0
    for batch in loader:
        x, y = batch["image"].to(device), batch["label"].to(device)
        logits = model(x)
        loss   = loss_fn(logits, y)
        tot_loss    += loss.item() * x.size(0)
        tot_correct += (logits.argmax(1) == y).sum().item()
    n = len(loader.dataset)
    return tot_loss/n, tot_correct/n


In [70]:
def train_eval_cnn(lr, dropout, weight_decay,
                   train_loader=train_loader, val_loader=val_loader,
                   n_epochs=30):
    lr, dropout, weight_decay = float(lr), float(dropout), float(weight_decay)
    model     = BaselineCNN(dropout=dropout).to(device)
    optim     = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn   = nn.CrossEntropyLoss()

    print(f"[trial] lr={lr:.1e}  dr={dropout:.2f}  wd={weight_decay:.1e}")
    for epoch in range(1, n_epochs+1):
        tr_loss, tr_acc = train(model, train_loader, loss_fn, optim)
        va_loss, va_acc = evaluate(model, val_loader, loss_fn)
        print(f"Epoch {epoch:02d} • "
              f"train {tr_acc*100:5.2f}% / val {va_acc*100:5.2f}% | "
              f"loss {tr_loss:.3f}/{va_loss:.3f}")
    return va_acc                                # return best/final as you wish


In [71]:
pbounds = {"lr": (1e-4, 3e-3), "dropout": (0.0, 0.4), "weight_decay": (1e-6, 1e-3)}
bo = BayesianOptimization(
    f=lambda lr, dropout, weight_decay: train_eval_cnn(lr, dropout, weight_decay),
    pbounds=pbounds, random_state=42, verbose=2
)
bo.maximize(init_points=5, n_iter=25)
print("best →", bo.max)


|   iter    |  target   |  dropout  |    lr     | weight... |
-------------------------------------------------------------
[trial] lr=2.9e-03  dr=0.15  wd=7.3e-04
Epoch 01 • train  0.62% / val  1.16% | loss 5.291/5.170
Epoch 02 • train  1.90% / val  0.97% | loss 5.052/5.273
Epoch 03 • train  3.36% / val  1.30% | loss 4.853/5.284
Epoch 04 • train  5.39% / val  2.52% | loss 4.634/5.183
Epoch 05 • train  7.31% / val  1.08% | loss 4.453/6.082
Epoch 06 • train  9.32% / val  1.40% | loss 4.289/5.979
Epoch 07 • train 10.85% / val  1.09% | loss 4.152/5.744
Epoch 08 • train 12.23% / val  1.99% | loss 4.043/5.698
Epoch 09 • train 13.26% / val  2.37% | loss 3.964/5.612
Epoch 10 • train 14.04% / val  1.05% | loss 3.909/5.982
Epoch 11 • train 14.88% / val  1.22% | loss 3.858/5.950
Epoch 12 • train 15.25% / val  1.70% | loss 3.832/5.719
Epoch 13 • train 15.63% / val  1.45% | loss 3.800/5.850
Epoch 14 • train 16.23% / val  2.54% | loss 3.774/5.634
Epoch 15 • train 16.58% / val  3.88% | loss 3.754/5.

KeyboardInterrupt: 

In [None]:
model.load_state_dict(torch.load("best_model.pt"))
test_loss, test_acc = evaluate(model, test_loader, loss_fn, device)
print(f"Test Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}")


In [63]:
# ── 0. deps ──────────────────────────────────────────────────────────────
# pip install bayesian-optimization if you haven’t already
from bayes_opt import BayesianOptimization
import torch, time, json
from pathlib import Path
from classes.BSplineActivation import BSplineActivation
# ── 1. make KANCNN fully hyper‑param friendly ────────────────────────────
class KANCNN(nn.Module):
    """
    Same backbone as before but every spline & layer size is configurable,
    so BayesOpt can mess with them.
    """
    def __init__(self,
                 input_shape=(3, *IMAGE_SIZE),
                 conv_channels=(64, 128),
                 kan_1=512,
                 kan_2=256,
                 kan_3=200,
                 spline_cp=7,
                 spline_deg=2,
                 range_min=-3.0,
                 range_max=50.0):
        super().__init__()
        C_in, _, _ = input_shape
        c1, c2 = conv_channels

        self.features = nn.Sequential(
            nn.Conv2d(C_in, c1, kernel_size=3, padding=1),
            nn.BatchNorm2d(c1), nn.ReLU(inplace=True), nn.MaxPool2d(2),

            nn.Conv2d(c1, c2, kernel_size=3, padding=1),
            nn.BatchNorm2d(c2), nn.ReLU(inplace=True), nn.MaxPool2d(2)
        )

        with torch.no_grad():
            flat = self.features(torch.zeros(1, *input_shape)).flatten(1).size(1)

        self.kan1      = nn.Linear(flat, kan_1)
        self.kan1_act  = BSplineActivation(
            num_control_points=spline_cp,
            degree=spline_deg,
            range_min=range_min,
            range_max=range_max
        )
        self.kan2      = nn.Linear(kan_1, kan_2)
        self.kan2_act  = BSplineActivation(
            num_control_points=spline_cp,
            degree=spline_deg,
            range_min=range_min,
            range_max=range_max
        )
        self.kan3      = nn.Linear(kan_2, kan_3)
        self.kan3_act  = BSplineActivation(
            num_control_points=spline_cp,
            degree=spline_deg,
            range_min=range_min,
            range_max=range_max
        )

    def forward(self, x):
        x = torch.flatten(self.features(x), 1)
        x = self.kan1(x)
        x = self.kan1_act(x)
        x = self.kan2(x)
        x = self.kan2_act(x)
        x = self.kan3(x)
        x = self.kan3_act(x)
        
        return x


In [52]:
# ---------------  training / eval helpers  ---------------------------
def accuracy_from_logits(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    return (logits.argmax(dim=1) == targets).float().mean()

def train_one_epoch(model: nn.Module, loader: DataLoader, criterion, optimizer, epoch: int) -> Dict[str, float]:
    model.train()
    loss_sum = acc_sum = 0.0
    for batch in tqdm(loader, desc=f"Train {epoch:02d}", leave=False):
        x = batch["image"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        acc = accuracy_from_logits(logits, y)
        loss_sum += loss.item() * x.size(0)
        acc_sum  += acc.item()  * x.size(0)

    n = len(loader.dataset)
    return {"loss": loss_sum / n, "acc": acc_sum / n}

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, criterion) -> Dict[str, float]:
    model.eval()
    loss_sum = acc_sum = 0.0
    for batch in loader:
        x = batch["image"].to(device, non_blocking=True)
        y = batch["label"].to(device, non_blocking=True)
        logits = model(x)
        loss = criterion(logits, y)
        acc = accuracy_from_logits(logits, y)
        loss_sum += loss.item() * x.size(0)
        acc_sum  += acc.item() * x.size(0)
    n = len(loader.dataset)
    return {"loss": loss_sum / n, "acc": acc_sum / n}

def run_training(model:nn.Module, name:str,
                 train_loader:DataLoader, val_loader:DataLoader) -> Dict[str, List[float]]:
    """Full training loop for one model."""
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    history = {"train_loss":[], "train_acc":[], "val_loss":[], "val_acc":[]}

    for epoch in range(1, N_EPOCHS+1):
        tic = time.time()
        train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
        val_metrics   = evaluate(model, val_loader, criterion)

        history["train_loss"].append(train_metrics["loss"])
        history["train_acc"].append(train_metrics["acc"])
        history["val_loss"].append(val_metrics["loss"])
        history["val_acc"].append(val_metrics["acc"])

        print(f"Epoch {epoch:2d}/{N_EPOCHS} • "
              f"train acc {train_metrics['acc']*100:5.2f}% | "
              f"val acc {val_metrics['acc']*100:5.2f}% | "
              f"Δt {time.time()-tic:4.1f}s")

    torch.save(model.state_dict(), RUN_DIR/f"{name}.pt")
    with open(RUN_DIR/f"{name}_history.json", "w") as f:
        json.dump(history, f)
    return history

### Plot

In [53]:
def plot_metrics(df:pd.DataFrame):
    """Plot accuracy + loss curves for both models."""
    sns.set_theme(style="whitegrid", font_scale=1.2)

    # Accuracy
    fig, ax = plt.subplots(figsize=(8,5))
    ax.plot(df["epoch"], df["baseline_train_acc"], label="Baseline train")
    ax.plot(df["epoch"], df["baseline_val_acc"],   label="Baseline val")
    ax.plot(df["epoch"], df["kan_train_acc"],      label="KAN train", linestyle="--")
    ax.plot(df["epoch"], df["kan_val_acc"],        label="KAN val",   linestyle="--")
    ax.set_xlabel("Epoch"); ax.set_ylabel("Accuracy (%)")
    ax.set_title("Tiny‑ImageNet • Accuracy vs. Epoch")
    ax.legend()
    fig.tight_layout()
    plt.savefig(RUN_DIR/"accuracy_curves.png", dpi=200)

    # Loss
    fig, ax = plt.subplots(figsize=(8,5))
    ax.plot(df["epoch"], df["baseline_train_loss"], label="Baseline train")
    ax.plot(df["epoch"], df["baseline_val_loss"],   label="Baseline val")
    ax.plot(df["epoch"], df["kan_train_loss"],      label="KAN train", linestyle="--")
    ax.plot(df["epoch"], df["kan_val_loss"],        label="KAN val",   linestyle="--")
    ax.set_xlabel("Epoch"); ax.set_ylabel("Cross‑entropy loss")
    ax.set_title("Tiny‑ImageNet • Loss vs. Epoch")
    ax.legend()
    fig.tight_layout()
    plt.savefig(RUN_DIR/"loss_curves.png", dpi=200)
    plt.show()

In [None]:
def count_parameters(model: nn.Module) -> int:
    """
    Count the number of trainable parameters in the model.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Final test evaluation
    

train_loader, val_loader, test_loader = get_dataloaders()

baseline = BaselineCNN()
kan      = KANCNN()

print(f'KAN parameters: {count_parameters(kan)}')
print(f'Baseline parameters: {count_parameters(baseline)}')

In [55]:
#print("📚 Training baseline CNN …")
#hist_base, baseline_model = run_training(baseline, "baseline", train_loader, val_loader)

In [56]:
#print("\n🌀 Training KAN‑CNN …")
#hist_kan  = run_training(kan,      "KAN",      train_loader, val_loader)

In [None]:
'''
criterion = nn.CrossEntropyLoss()
#test_base = evaluate(baseline, test_loader, criterion)
test_kan  = evaluate(kan,      test_loader, criterion)


print(f"\n✅ Test accuracy: "
        #f"Baseline {test_base['acc']*100:5.2f}% | "
        f"KAN {test_kan['acc']*100:5.2f}%")
'''

### Optimize KAN

In [61]:
pbounds = {
    # ints → we pass floats in but will round later
    "epochs":             (35, 55),
    "kan_1":          (64, 512),   # width of first KAN layer
    "kan_2":          (128, 512),   # second KAN layer
    "kan_3":            (200, 200),
    "spline_cp":          (6, 10),      # control points
    "spline_deg":         (2, 5),      # deg ≤ cp‑1 guard enforced later
    "range_min":          (-5.0, -0.5),
    "range_max":          (5.0, 70.0),
    "lr":                 (1e-4, 5e-3)
}

import time
from tqdm.auto import tqdm

def train_kan(model, train_loader, val_loader, epochs, lr):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optim     = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)

    best_val = 0.0
    for ep in range(1, epochs + 1):
        tic = time.time()
        # ─ train ─
        model.train()
        loss_sum = acc_sum = 0.0
        for batch in train_loader:
            x, y = batch["image"].to(device), batch["label"].to(device)
            optim.zero_grad(set_to_none=True)
            logits = model(x)
            loss   = criterion(logits, y)
            loss.backward()
            optim.step()

            acc = (logits.argmax(1) == y).float().mean().item()
            loss_sum += loss.item() * x.size(0)
            acc_sum  += acc * x.size(0)

        # ─ eval ─
        model.eval()
        loss_sum_val = acc_sum_val = 0.0
        with torch.no_grad():
            for batch in val_loader:
                x, y = batch["image"].to(device), batch["label"].to(device)
                logits = model(x)
                loss   = criterion(logits, y)
                acc    = (logits.argmax(1) == y).float().mean().item()
                loss_sum_val += loss.item() * x.size(0)
                acc_sum_val  += acc * x.size(0)

        train_loss = loss_sum     / len(train_loader.dataset)
        train_acc  = acc_sum      / len(train_loader.dataset)
        val_loss   = loss_sum_val / len(val_loader.dataset)
        val_acc    = acc_sum_val  / len(val_loader.dataset)
        elapsed    = time.time() - tic

        # ← print exactly like you had it
        print(f"Epoch [{ep}/{epochs}], "
              f"Loss: {train_loss:.4f}, "
              f"Test Acc: {val_acc*100:5.2f}%, "
              f"Time: {elapsed:5.2f} seconds")

        best_val = max(best_val, val_acc)

    return best_val


def optimize_kan(epochs,
                 kan_inner,
                 kan_outer,
                 spline_cp,
                 spline_deg,
                 range_min,
                 range_max,
                 lr):

    # ─ cast + sanity ─
    epochs      = int(round(epochs))
    kan_inner   = int(round(kan_inner))
    kan_outer   = int(round(kan_outer))
    spline_cp   = int(round(spline_cp))
    spline_deg  = int(round(spline_deg))

    # keep B‑spline well‑formed
    spline_deg  = max(2, min(spline_deg, spline_cp - 1))
    lr          = float(lr)

    model = KANCNN(
        kan_inner=kan_inner,
        kan_outer=kan_outer,
        spline_cp=spline_cp,
        spline_deg=spline_deg,
        range_min=range_min,
        range_max=range_max
    )

    val_acc = train_kan(model, train_loader, val_loader, epochs, lr)

    # BayesOpt maximizes the returned value
    return val_acc



In [None]:
optimizer = BayesianOptimization(
    f=optimize_kan,
    pbounds=pbounds,
    random_state=38,
    verbose=2
)

# 8 random warm‑ups + 10 guided iterations
optimizer.maximize(init_points=6, n_iter=25)

print("🚀 best combo so far →", optimizer.max)