In [1]:
%pip install --quiet pykan
%pip install --quiet git+https://github.com/ZiyaoLi/fast-kan.git

%pip install ipywidgets --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/78.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.1/78.1 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for fastkan (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import argparse
import csv
import math
import os
import random
import time
from dataclasses import dataclass
from typing import Tuple, List, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset
from kan import KAN as PyKAN
from fastkan import FastKAN as FastKANNet
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as W
from IPython.display import display, clear_output
from torch.utils.data import TensorDataset
from IPython.display import display as _display
from google.colab import drive
from pathlib import Path
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import GradScaler, autocast

In [3]:
def set_seed(seed: int):
    """Seed every random number generator this notebook uses with `seed`.

    Covers Python's `random`, NumPy's legacy global RNG, and PyTorch's CPU
    and CUDA generators, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def count_params(module: nn.Module) -> int:
    """Return the number of scalar elements in the module's trainable parameters.

    Parameters with `requires_grad=False` (frozen) are not counted.
    """
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def accuracy_top1(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Fraction of rows whose arg-max over dim 1 equals the target label."""
    predicted = torch.argmax(logits, dim=1)
    hits = predicted.eq(targets)
    return hits.float().mean().item()


def nll_criterion(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy (negative log-likelihood) of `targets` under `logits`.

    Returns a 0-dim tensor; callers use `.item()` when a float is needed.
    """
    loss = F.cross_entropy(input=logits, target=targets, reduction="mean")
    return loss


def brier_score(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Multi-class Brier score.

    Computes the squared distance between the softmax probabilities and the
    one-hot target, summed over classes and averaged over samples.
    """
    probabilities = logits.softmax(dim=1)
    one_hot = F.one_hot(targets, num_classes=probabilities.shape[1]).float()
    squared_error = (probabilities - one_hot).pow(2)
    per_sample = squared_error.sum(dim=1)
    return per_sample.mean().item()


def ece_score(logits: torch.Tensor, targets: torch.Tensor, n_bins: int = 10) -> float:
    """Expected calibration error with equal-mass (equal-count) bins.

    Samples are sorted by top-1 confidence and cut into `n_bins` contiguous
    groups of (nearly) equal size. The score is the size-weighted mean of
    |accuracy - mean confidence| over the non-empty groups. Returns 0.0 for
    an empty batch.
    """
    confidence, predicted = torch.softmax(logits, dim=1).max(dim=1)
    hit = predicted.eq(targets).float()
    total = confidence.numel()
    if total == 0:
        return 0.0

    conf_sorted, order = torch.sort(confidence)
    hit_sorted = hit[order]

    # Bin boundaries are positions in the sorted order, not confidence values.
    boundaries = (
        torch.linspace(0, total, steps=n_bins + 1, device=logits.device)
        .round()
        .long()
        .tolist()
    )

    ece = torch.zeros((), device=logits.device)
    for start, stop in zip(boundaries[:-1], boundaries[1:]):
        if stop <= start:
            continue  # bin collapsed to zero samples after rounding
        gap = (hit_sorted[start:stop].mean() - conf_sorted[start:stop].mean()).abs()
        ece = ece + ((stop - start) / total) * gap
    return ece.item()




In [4]:
class WrappedDataset(Dataset):
    """View over `base` that applies `transform` to each sample's input on access.

    Labels pass through untouched, and `targets` is forwarded from the base
    dataset.
    """

    def __init__(self, base, transform):
        self.base, self.transform = base, transform

    def __getitem__(self, i):
        sample, label = self.base[i]
        return self.transform(sample), label

    def __len__(self):
        return len(self.base)

    @property
    def targets(self):
        # Forwarded as-is; raises AttributeError if the base has no `targets`.
        return self.base.targets


In [5]:
class TemperatureScaler(nn.Module):
    """Single learnable scalar that divides logits (temperature scaling)."""

    def __init__(self, initial_temp=1.5):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1) * initial_temp)

    def forward(self, logits):
        return logits / self.temperature


def find_optimal_temp(logits, labels, device="cuda", max_iter=50, lr=0.01):
    """Fit a temperature on held-out logits by minimising cross-entropy.

    Parameters
    ----------
    logits, labels : torch.Tensor
        Held-out logits and their integer class labels.
    device : str or torch.device
        Where to run the fit. A CUDA device is replaced by the CPU when CUDA
        is not available, so the default no longer crashes on CPU-only hosts.
    max_iter, lr : int, float
        Passed straight to `torch.optim.LBFGS`.

    Returns
    -------
    float
        The fitted temperature, or 1.0 (identity scaling) when `logits` is
        empty and nothing can be fitted.

    NOTE(review): no line search is configured, so every L-BFGS update is
    scaled by `lr`; with the defaults the result may stop short of the true
    optimum. Confirm whether `line_search_fn="strong_wolfe"` is wanted.
    """
    target = torch.device(device)
    if target.type == "cuda" and not torch.cuda.is_available():
        target = torch.device("cpu")

    if logits.numel() == 0:
        return 1.0

    scaler = TemperatureScaler().to(target)
    # Logits produced under autocast may be float16; fit in float32.
    logits = logits.detach().to(device=target, dtype=torch.float32)
    labels = labels.to(target)
    optimizer = torch.optim.LBFGS([scaler.temperature], lr=lr, max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        loss = F.cross_entropy(scaler(logits), labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return scaler.temperature.item()

In [6]:
class LabelSmoothingCE(nn.Module):
    """Cross-entropy against label-smoothed targets.

    The true class gets probability `1 - smoothing`; the remainder is spread
    evenly over the other `num_classes - 1` classes. With `smoothing <= 0`
    this is plain `F.cross_entropy`.
    """

    def __init__(self, smoothing: float = 0.0):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        eps = self.smoothing
        if eps <= 0.0:
            return F.cross_entropy(logits, targets)

        n_classes = logits.size(1)
        log_p = F.log_softmax(logits, dim=1)
        # The soft target is a constant: build it without tracking gradients.
        with torch.no_grad():
            soft_target = torch.full_like(log_p, eps / (n_classes - 1))
            soft_target.scatter_(1, targets.unsqueeze(1), 1.0 - eps)
        per_sample = (-soft_target * log_p).sum(dim=1)
        return per_sample.mean()



In [7]:

class PyKANHead(nn.Module):
    """Classification head backed by a pykan `KAN` of width [d_in, d_hid, num_classes].

    Parameters
    ----------
    d_in, d_hid, num_classes : int
        Input, hidden and output widths passed to the KAN as `width`.
    grid, k, base_fun
        Forwarded unchanged to the `KAN` constructor.
    """

    def __init__(
        self,
        d_in: int,
        d_hid: int,
        num_classes: int,
        grid: int = 5,
        k: int = 3,
        base_fun: str = "silu",
    ):
        super().__init__()
        self.model = PyKAN(
            width=[d_in, d_hid, num_classes],
            grid=grid,
            k=k,
            base_fun=base_fun,
            symbolic_enabled=False,
        )
        # Call `speed()` only when the installed pykan exposes it
        # (presumably a faster mode; see the pykan docs).
        if hasattr(self.model, "speed"):
            self.model.speed()

        # Set unconditionally so the flag exists on every instance,
        # matching FastKANHead.
        self.expects_normalized_inputs = True

    # `forward` must be a method of the class; defined inside `__init__` it
    # is a throwaway local function and calling the head raises
    # NotImplementedError.
    def forward(self, z):
        return self.model(z)


class FastKANHead(nn.Module):
    """Classification head backed by a `FastKAN` with layers [d_in, d_hid, num_classes].

    Parameters
    ----------
    d_in, d_hid, num_classes : int
        Input, hidden and output widths passed as `layers_hidden`.
    num_grids, use_base_update
        Forwarded unchanged to the `FastKAN` constructor.
    """

    def __init__(
        self,
        d_in: int,
        d_hid: int,
        num_classes: int,
        num_grids: int = 8,
        use_base_update: bool = True,
    ):
        super().__init__()
        self.model = FastKANNet(
            layers_hidden=[d_in, d_hid, num_classes],
            num_grids=num_grids,
            use_base_update=use_base_update,
        )
        self.expects_normalized_inputs = True

    # `forward` must be a method of the class; defined inside `__init__` it
    # is a throwaway local function and calling the head raises
    # NotImplementedError.
    def forward(self, z):
        return self.model(z)


In [8]:
class MLPHead(nn.Module):
    """One-hidden-layer MLP head: Linear -> SiLU -> Dropout -> Linear."""

    def __init__(self, d_in: int, d_hid: int, num_classes: int, dropout: float = 0.0):
        super().__init__()
        layers = [
            nn.Linear(d_in, d_hid),
            nn.SiLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_hid, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        logits = self.net(z)
        return logits




class FrozenResNet18(nn.Module):
    """ResNet-18 trunk used as a fixed feature extractor.

    Keeps every child module of torchvision's `resnet18` except the last one
    and freezes all retained parameters. `forward` returns the trunk output
    flattened from dim 1 onwards.
    """

    def __init__(self, weights=None):
        super().__init__()
        resnet = models.resnet18(weights=weights)
        trunk = list(resnet.children())
        trunk.pop()  # drop the final child module
        self.features = nn.Sequential(*trunk)
        self.features.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.features(x).flatten(1)


In [9]:
def tune_hidden_to_param_target(build_fn, d_in, num_classes, target_params, h0=512):
    """Search for the hidden width whose parameter count is closest to a target.

    Starting from `h0`, hill-climbs with step sizes 256, 128, ..., 1. At each
    step size it tries `h - step` and `h + step` (never below 4) and moves
    whenever the absolute gap to `target_params` shrinks.

    Parameters
    ----------
    build_fn : callable
        `build_fn(h)` returns a module with hidden width `h`; its trainable
        parameters are counted with `count_params`.
    d_in, num_classes
        Unused here; kept so existing call sites keep working.
    target_params : int
        Parameter count to match.
    h0 : int
        Starting hidden width.

    Returns
    -------
    int
        The best hidden width found.
    """
    # The climb revisits widths (e.g. the previous `h` after a successful
    # move). Building a head can be expensive, so cache each width's gap.
    gap_by_width = {}

    def param_gap(width):
        if width not in gap_by_width:
            gap_by_width[width] = abs(count_params(build_fn(width)) - target_params)
        return gap_by_width[width]

    best_h, best_diff = h0, float("inf")
    h = h0
    for step in (256, 128, 64, 32, 16, 8, 4, 2, 1):
        improved = True
        while improved:
            improved = False
            for delta in (-step, step):
                h_try = max(4, h + delta)
                diff = param_gap(h_try)
                if diff < best_diff:
                    best_h, best_diff, h = h_try, diff, h_try
                    improved = True
    return best_h


In [10]:
# Per-channel mean and standard deviation passed to `transforms.Normalize`
# in `build_transforms` (three values each, one per image channel).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def build_transforms(dataset: str, train: bool) -> transforms.Compose:
    """Build the 224-pixel, normalised preprocessing pipeline for a dataset.

    * ``cifar10`` — train: random-resized crop plus horizontal flip;
      eval: resize to 256 then centre-crop to 224.
    * ``mnist`` — resize to 224 (train adds a small random affine), then the
      single channel is repeated three times.

    Every pipeline ends with `Normalize(IMAGENET_MEAN, IMAGENET_STD)`.
    Raises ValueError for any other dataset name.
    """
    if dataset not in ("cifar10", "mnist"):
        raise ValueError("Unknown dataset")

    if dataset == "cifar10":
        if train:
            steps = [
                transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            steps = [transforms.Resize(256), transforms.CenterCrop(224)]
        steps.append(transforms.ToTensor())
    else:  # mnist
        steps = [transforms.Resize(224)]
        if train:
            steps.append(
                transforms.RandomAffine(
                    degrees=10, translate=(0.05, 0.05), scale=(0.9, 1.1)
                )
            )
        steps.append(transforms.ToTensor())
        # 1 -> 3 channels
        steps.append(transforms.Lambda(lambda x: x.repeat(3, 1, 1)))

    steps.append(transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return transforms.Compose(steps)


def load_dataset(dataset: str, data_dir: str):
    """Download (if needed) and return the untransformed train and test splits.

    Returns `(train_ds, test_ds, num_classes)`. Both supported datasets
    (``cifar10`` and ``mnist``) have 10 classes. Any other name raises
    ValueError.
    """
    if dataset == "cifar10":
        factory = datasets.CIFAR10
    elif dataset == "mnist":
        factory = datasets.MNIST
    else:
        raise ValueError("Unknown dataset")

    def _split(is_train):
        return factory(root=data_dir, train=is_train, download=True, transform=None)

    train_ds = _split(True)
    test_ds = _split(False)
    return train_ds, test_ds, 10


def stratified_fraction_indices(
    targets: List[int], fraction: float, num_classes: int, seed: int
) -> List[int]:
    """Pick a class-balanced random subset of sample indices.

    For each class `0..num_classes-1`, shuffles that class's indices and keeps
    the first `floor(count * fraction)` of them, with a minimum request of
    one per class. The combined list is shuffled before returning. All
    randomness comes from a `RandomState(seed)`, so the result is
    deterministic.
    """
    rng = np.random.RandomState(seed)
    labels = np.array(targets)
    chosen: List[int] = []
    for cls in range(num_classes):
        members = np.nonzero(labels == cls)[0]
        rng.shuffle(members)
        keep = max(1, int(math.floor(len(members) * fraction)))
        chosen += members[:keep].tolist()
    rng.shuffle(chosen)
    return chosen


def make_loaders(
    dataset: str,
    data_dir: str,
    label_fraction: float,
    seed: int,
    batch_size: int,
    num_workers: int = 4,
):
    """Build train / validation / test image loaders for a labelled subset.

    A stratified `label_fraction` of the training set is drawn
    (`stratified_fraction_indices`) and split 90/10 into train and
    validation with a generator seeded by `seed`.

    Parameters
    ----------
    dataset, data_dir : str
        Passed to `load_dataset` / `build_transforms`.
    label_fraction : float
        Fraction of each class kept from the training split.
    seed : int
        Seeds both the stratified draw and the train/val split.
    batch_size : int
        Batch size for all three loaders.
    num_workers : int
        Worker processes per loader. The original code accepted this
        argument but hard-coded 2; it is now honoured.

    Returns
    -------
    (train_loader, val_loader, test_loader, num_classes)
    """
    train_raw, test_raw, num_classes = load_dataset(dataset, data_dir)
    tr_tf = build_transforms(dataset, train=True)
    te_tf = build_transforms(dataset, train=False)

    targets = (
        train_raw.targets if hasattr(train_raw, "targets") else train_raw.train_labels
    )
    sub_idx = stratified_fraction_indices(targets, label_fraction, num_classes, seed)
    train_sub = Subset(WrappedDataset(train_raw, tr_tf), sub_idx)

    val_size = max(1, int(0.1 * len(train_sub)))
    train_size = len(train_sub) - val_size
    train_ds, val_split = random_split(
        train_sub, [train_size, val_size], generator=torch.Generator().manual_seed(seed)
    )

    # Validation keeps the same samples as the random split, but uses the
    # deterministic eval transform: random augmentation would make the
    # validation loss noisy from epoch to epoch.
    val_base_idx = [sub_idx[i] for i in val_split.indices]
    val_ds = Subset(WrappedDataset(train_raw, te_tf), val_base_idx)

    test_ds = WrappedDataset(test_raw, te_tf)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader, num_classes


In [11]:
@torch.no_grad()
def estimate_feature_stats(backbone, loader, device, max_batches=32):
    """Estimate per-feature mean and std of backbone outputs.

    Accumulates the sum and sum of squares of `forward_features(backbone, x)`
    over at most `max_batches` batches from `loader`. Returns `(mu, std)`,
    where the variance is `E[z^2] - mu^2`, clamped to at least 1e-6 before
    the square root.
    """
    sum_z, sum_z_sq, seen = 0.0, 0.0, 0
    for batch_idx, (inputs, _) in enumerate(loader):
        if batch_idx >= max_batches:
            break
        feats = forward_features(backbone, inputs.to(device))
        sum_z = sum_z + feats.sum(0)
        sum_z_sq = sum_z_sq + (feats ** 2).sum(0)
        seen = seen + feats.size(0)

    mu = sum_z / seen
    variance = (sum_z_sq / seen) - mu ** 2
    return mu, torch.sqrt(variance.clamp_min(1e-6))

In [12]:
def precompute_and_save_features(backbone, loader, device, save_path):
    """Extract features and labels from a data loader and save them to a file.

    If `save_path` already exists the function returns immediately without
    touching it; the caller is expected to load the file itself. Otherwise
    the backbone is put in eval mode and every batch is run through it under
    `torch.no_grad()`. Outputs are flattened from dim 1 and saved with
    `torch.save` as `{"features": ..., "labels": ...}` (CPU tensors).

    Parameters
    ----------
    backbone : nn.Module
        Feature extractor.
    loader : iterable of (x, y) batches
    device : str or torch.device
        Device the inputs are moved to before the forward pass.
    save_path : str
        Destination file. Parent directories are created when the path has
        any.
    """
    if os.path.exists(save_path):
        print(f"Features found at {save_path}, loading them.")
        return

    print(f"No pre-computed features found. Creating them at {save_path}...")
    backbone.eval()
    feature_chunks = []
    label_chunks = []

    with torch.no_grad():
        for x, y in loader:
            features = backbone(x.to(device))
            feature_chunks.append(torch.flatten(features, 1).cpu())
            label_chunks.append(y.cpu())

    all_features = torch.cat(feature_chunks, dim=0)
    all_labels = torch.cat(label_chunks, dim=0)

    # A bare filename has an empty dirname, and os.makedirs("") raises.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save({"features": all_features, "labels": all_labels}, save_path)
    print(f"Saved pre-computed features to {save_path}")

In [22]:
@dataclass
class TrainConfig:
    """Bundle of experiment settings handed to `train_head`.

    `train_head` itself reads only `device`, `lr`, `weight_decay`, `epochs`
    and `patience`; the remaining fields record how the run was configured.
    """

    dataset: str  # dataset name, e.g. "cifar10" or "mnist"
    data_dir: str  # root directory for raw dataset downloads
    label_fraction: float  # fraction of training labels used
    seed: int  # random seed of this run
    device: str  # device string, converted with torch.device() in train_head
    epochs: int  # maximum number of epochs (also the cosine schedule's T_max)
    patience: int  # epochs without validation improvement before early stopping
    batch_size: int
    lr: float  # AdamW learning rate
    weight_decay: float  # AdamW weight decay
    smoothing: float  # label-smoothing value recorded for the run
    mlp_hidden: int  # hidden width of the MLP head
    kan_hidden: int  # requested hidden width of the KAN heads

    kan_K: int  # run_experiment stores its `pykan_k` argument here
    match_params: bool  # whether KAN widths were tuned to the MLP's parameter count
    results_dir: str  # not read by the code visible in this notebook section
    tag: str  # not read by the code visible in this notebook section


def build_backbone(device: str):
    """Create the frozen ResNet-18 feature extractor on `device`, in eval mode.

    Uses torchvision's `ResNet18_Weights.IMAGENET1K_V1` weights. If the
    installed torchvision has no such enum, the backbone is built without
    pretrained weights and a RuntimeWarning is emitted. Previously this
    fallback was silent, which turns the extractor into a random projection.
    """
    try:
        weights = models.ResNet18_Weights.IMAGENET1K_V1
    except AttributeError:
        import warnings

        warnings.warn(
            "torchvision does not provide ResNet18_Weights.IMAGENET1K_V1; "
            "the backbone will be randomly initialised.",
            RuntimeWarning,
            stacklevel=2,
        )
        weights = None
    backbone = FrozenResNet18(weights=weights).to(device)
    backbone.eval()
    return backbone


def forward_features(backbone: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run `backbone` on `x` with gradient tracking disabled and return its output."""
    grad_off = torch.no_grad()
    with grad_off:
        features = backbone(x)
    return features









def train_head(
    cfg: TrainConfig,
    model_head: nn.Module,
    loaders,
    label_smoothing: float,
    mu: torch.Tensor = None,
    std: torch.Tensor = None,
    feature_dropout_p: float = 0.0,
) -> Tuple[Dict[str, float], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Train a classification head with early stopping and AMP.

    Returns
    -------
    results : dict
        Final metrics (acc, nll, brier, ece, etc.)
    logits_val, y_val : torch.Tensor
        Validation logits and labels (best model, CPU)
    logits_test, y_test : torch.Tensor
        Test logits and labels (best model, CPU)
    """
    global _cancelled

    train_loader, val_loader, test_loader, num_classes = loaders
    device = torch.device(cfg.device)

    eps = 1e-6
    criterion = LabelSmoothingCE(smoothing=label_smoothing).to(device)
    opt = torch.optim.AdamW(
        model_head.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    scheduler = CosineAnnealingLR(opt, T_max=cfg.epochs, eta_min=1e-6)

    best_val_loss = float("inf")
    best_state = None
    best_epoch = -1
    no_improve = 0
    device_str = str(device)
    scaler = GradScaler(device_str, enabled=(device_str == "cuda"))

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    for epoch in range(cfg.epochs):
        if _cancelled:
            print("Cancelled by user.")
            break

        # ---------------- Train one epoch ----------------
        model_head.train()
        for z, y in train_loader:
            z, y = z.to(device, non_blocking=True), y.to(device, non_blocking=True)

            # Optional feature dropout
            if feature_dropout_p > 0.0:
                z = F.dropout(z, p=feature_dropout_p, training=True)

            # Normalise
            z_norm = (z - mu.to(device)) / (std.to(device) + eps)

            with autocast(device_str, enabled=(device_str == "cuda")):
              logits = model_head(z_norm)
              loss = criterion(logits, y)

            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

        scheduler.step()  # <-- once per epoch

        # ---------------- Validation ----------------
        model_head.eval()
        total_val_loss, n_val = 0.0, 0
        with torch.no_grad():
            for z, y in val_loader:
                z, y = z.to(device, non_blocking=True), y.to(device, non_blocking=True)
                z_norm = (z - mu.to(device)) / (std.to(device) + eps)

                with autocast(device_type=str(device), enabled=(str(device) == "cuda")):
                    logits = model_head(z_norm)
                    total_val_loss += F.cross_entropy(logits, y, reduction="sum").item()
                n_val += y.size(0)

        current_val_loss = total_val_loss / n_val
        print(f"Epoch {epoch+1:03d}/{cfg.epochs}: val_loss={current_val_loss:.4f}")

        # ---------------- Early stopping ----------------
        if current_val_loss < best_val_loss - 1e-5:
            best_val_loss = current_val_loss
            best_state = {k: v.detach().cpu() for k, v in model_head.state_dict().items()}
            best_epoch = epoch + 1
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= cfg.patience:
                print(f"Early stopping at epoch {epoch+1}. Best epoch = {best_epoch}.")
                break

    # ---------------- Load best model ----------------
    if best_state is not None:
        model_head.load_state_dict(best_state)
    else:
        # Never improved – keep last weights
        pass

    # ---------------- Final evaluation ----------------
    model_head.eval()
    with torch.no_grad():
        # Validation set
        logits_val, y_val = [], []
        for z, y in val_loader:
            z, y = z.to(device, non_blocking=True), y.to(device, non_blocking=True)
            z_norm = (z - mu.to(device)) / (std.to(device) + eps)
            with autocast(device_type=str(device), enabled=(str(device) == "cuda")):
                logits_val.append(model_head(z_norm))
            y_val.append(y)
        logits_val = torch.cat(logits_val).cpu()
        y_val = torch.cat(y_val).cpu()

        # Test set
        logits_test, y_test = [], []
        for z, y in test_loader:
            z, y = z.to(device, non_blocking=True), y.to(device, non_blocking=True)
            z_norm = (z - mu.to(device)) / (std.to(device) + eps)
            with autocast(device_type=str(device), enabled=(str(device) == "cuda")):
                logits_test.append(model_head(z_norm))
            y_test.append(y)
        logits_test = torch.cat(logits_test).cpu()
        y_test = torch.cat(y_test).cpu()

    # ---------------- Metrics ----------------
    acc = accuracy_top1(logits_test, y_test)
    nll = nll_criterion(logits_test, y_test).item()
    brier = brier_score(logits_test, y_test)
    ece = ece_score(logits_test, y_test)

    results = {
        "acc": acc,
        "nll_uncalibrated": nll,
        "brier_uncalibrated": brier,
        "ece_uncalibrated": ece,
        "params": count_params(model_head),
        "epoch_best": best_epoch,
    }

    return results, logits_val, y_val, logits_test, y_test

In [14]:
@torch.no_grad()
def predict_logits(
    backbone, head, loader, device, mu, std, eps: float = 1e-6
):
    """Run backbone and head over `loader`; return all logits and labels on CPU.

    The head is put in eval mode. Features from
    `forward_features(backbone, x)` are normalised as
    `(z - mu) / (std + eps)` with the supplied statistics before being
    passed to `head`.
    """
    head.eval()
    shift, scale = mu.to(device), std.to(device)

    logit_chunks, label_chunks = [], []
    for inputs, labels in loader:
        feats = forward_features(backbone, inputs.to(device, non_blocking=True))
        normalised = (feats - shift) / (scale + eps)
        logit_chunks.append(head(normalised).cpu())
        label_chunks.append(labels.cpu())
    return torch.cat(logit_chunks, 0), torch.cat(label_chunks, 0)


In [23]:
def run_experiment(
    dataset: str = "cifar10",
    label_fraction: float = 0.02,
    modes=("mlp", "pykan", "fastkan"),
    seeds=(42,),
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    epochs: int = 100,
    patience: int = 20,
    batch_size: int = 128,
    lr: float = 3e-3,
    weight_decay: float = 1e-4,
    smoothing: float = 0.1,
    mlp_hidden: int = 512,
    kan_hidden: int = 512,
    pykan_grid: int = 5,
    pykan_k: int = 3,
    fastkan_grids: int = 8,
    match_params: bool = True,
    ece_bins: int = 15,
    feature_dir: str = "/content/data/features",
):
    """Train and calibrate MLP / PyKAN / FastKAN heads on frozen ResNet-18 features.

    For every seed the function:
      1. extracts and caches backbone features under `feature_dir/<dataset>/`
         (skipped when both cache files already exist),
      2. draws a stratified `label_fraction` subset and splits it 90/10 into
         train and validation,
      3. optionally tunes each requested KAN head's hidden width and grid to
         match the MLP head's parameter count,
      4. trains each head in `modes` with `train_head`, fits a temperature on
         the validation logits and re-scores the test logits.

    Parameters
    ----------
    dataset, label_fraction, modes, seeds, device
        What to run and where.
    epochs, patience, batch_size, lr, weight_decay, smoothing
        Training settings, stored in `TrainConfig` / passed to `train_head`.
    mlp_hidden, kan_hidden, pykan_grid, pykan_k, fastkan_grids
        Head sizes used directly, or as starting points when `match_params`.
    match_params : bool
        Tune KAN hidden width and grid towards the MLP parameter count. Only
        heads listed in `modes` are tuned.
    ece_bins : int
        Number of bins for both the uncalibrated and the calibrated ECE, so
        the two numbers are comparable.
    feature_dir : str
        Root directory of the feature cache (default is the Colab path the
        notebook used originally).

    Returns
    -------
    df : pandas.DataFrame
        One row per (model, seed) with metrics before and after calibration.
    all_logits_targets : dict
        `"<model>_seed_<s>"` and `"<model>_cal_seed_<s>"` mapped to
        `(test_logits, test_labels)`.

    The loop over seeds stops early when the notebook-level flag `_cancelled`
    is truthy (treated as False if it has not been defined).
    """
    all_results = []
    all_logits_targets = {}
    device = torch.device(device)
    # Compare the device type, not its string form, so "cuda:0" counts as CUDA.
    on_cuda = device.type == "cuda"

    feature_root = Path(feature_dir) / dataset
    feature_root.mkdir(parents=True, exist_ok=True)
    train_feat_path = feature_root / "train_features.pt"
    test_feat_path = feature_root / "test_features.pt"

    # Candidate grids for parameter matching
    pykan_grid_candidates = [3, 5, 8, 12]
    fastkan_grids_candidates = [3, 5, 8, 12, 16]

    def cancelled():
        return bool(globals().get("_cancelled", False))

    def best_hidden_and_grid(make_head, grid_candidates, feat_dim, n_classes, target):
        """Return the (hidden, grid) pair whose parameter count is closest to `target`."""
        best_combo, best_diff = None, float("inf")
        for g in grid_candidates:
            def build(h, _g=g):
                return make_head(h, _g)

            h_try = tune_hidden_to_param_target(
                build, feat_dim, n_classes, target, h0=mlp_hidden
            )
            diff = abs(count_params(build(h_try)) - target)
            if diff < best_diff:
                best_diff = diff
                best_combo = (h_try, g)
        return best_combo

    for seed in seeds:
        if cancelled():
            break
        set_seed(seed)
        print(f"\n=====  SEED {seed}  =====")

        # --------------------------------------------------------------
        # 1.  Pre-compute features once
        # --------------------------------------------------------------
        if not (train_feat_path.exists() and test_feat_path.exists()):
            backbone = build_backbone(device)
            train_raw_full, test_raw, num_classes = load_dataset(dataset, "./data")
            eval_tf = build_transforms(dataset, train=False)

            tmp_tr = DataLoader(WrappedDataset(train_raw_full, eval_tf),
                                batch_size=batch_size * 2, shuffle=False,
                                num_workers=2, pin_memory=True)
            tmp_te = DataLoader(WrappedDataset(test_raw, eval_tf),
                                batch_size=batch_size * 2, shuffle=False,
                                num_workers=2, pin_memory=True)

            precompute_and_save_features(backbone, tmp_tr, device, str(train_feat_path))
            precompute_and_save_features(backbone, tmp_te, device, str(test_feat_path))
            del backbone, tmp_tr, tmp_te
            if on_cuda:
                torch.cuda.empty_cache()

        # --------------------------------------------------------------
        # 2.  Build datasets
        # --------------------------------------------------------------
        train_data = torch.load(train_feat_path)
        test_data = torch.load(test_feat_path)

        full_feat = train_data["features"]
        full_lbls = train_data["labels"]
        test_feat = test_data["features"]
        test_lbls = test_data["labels"]
        num_classes = int(full_lbls.max().item()) + 1

        # Subset by label fraction
        sub_idx = stratified_fraction_indices(full_lbls, label_fraction,
                                              num_classes, seed)
        sub_feat = full_feat[sub_idx]
        sub_lbls = full_lbls[sub_idx]
        feat_dim = sub_feat.size(1)

        # Train / val split
        val_size = max(1, int(0.1 * len(sub_idx)))
        train_size = len(sub_idx) - val_size
        train_ds, val_ds = random_split(
            TensorDataset(sub_feat, sub_lbls),
            [train_size, val_size],
            generator=torch.Generator().manual_seed(seed)
        )

        # Data loaders
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                                  num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_ds, batch_size=batch_size * 2, shuffle=False,
                                num_workers=2, pin_memory=True)
        test_loader = DataLoader(TensorDataset(test_feat, test_lbls),
                                 batch_size=batch_size * 2, shuffle=False,
                                 num_workers=2, pin_memory=True)

        # Normalisation stats from the training part of the split only
        train_rows = np.array(train_ds.indices)
        mu = sub_feat[train_rows].mean(0)
        std = sub_feat[train_rows].std(0)

        # --------------------------------------------------------------
        # 3.  Parameter-matching (only for the heads that will be trained)
        # --------------------------------------------------------------
        target_params = count_params(MLPHead(feat_dim, mlp_hidden, num_classes))
        h_py, chosen_py_grid = kan_hidden, pykan_grid
        h_fk, chosen_fk_grid = kan_hidden, fastkan_grids

        if match_params and "pykan" in modes:
            h_py, chosen_py_grid = best_hidden_and_grid(
                lambda h, g: PyKANHead(feat_dim, h, num_classes, grid=g, k=pykan_k),
                pykan_grid_candidates, feat_dim, num_classes, target_params,
            )
            print(f"[Param-match] PyKAN  hidden={h_py}  grid={chosen_py_grid}")

        if match_params and "fastkan" in modes:
            h_fk, chosen_fk_grid = best_hidden_and_grid(
                lambda h, g: FastKANHead(feat_dim, h, num_classes, num_grids=g),
                fastkan_grids_candidates, feat_dim, num_classes, target_params,
            )
            print(f"[Param-match] FastKAN hidden={h_fk}  grids={chosen_fk_grid}")

        # --------------------------------------------------------------
        # 4.  Train heads
        # --------------------------------------------------------------
        cfg_obj = TrainConfig(
            dataset=dataset, data_dir="./data", label_fraction=label_fraction,
            seed=seed, device=str(device), epochs=epochs, patience=patience,
            batch_size=batch_size, lr=lr, weight_decay=weight_decay,
            smoothing=smoothing, mlp_hidden=mlp_hidden, kan_hidden=kan_hidden,
            kan_K=pykan_k, match_params=match_params, results_dir="./results", tag=""
        )

        loaders = (train_loader, val_loader, test_loader, num_classes)

        builders = {
            "mlp": lambda: MLPHead(feat_dim, mlp_hidden, num_classes),
            "pykan": lambda: PyKANHead(feat_dim, h_py, num_classes,
                                       grid=chosen_py_grid, k=pykan_k),
            "fastkan": lambda: FastKANHead(feat_dim, h_fk, num_classes,
                                           num_grids=chosen_fk_grid),
        }

        for name, builder in builders.items():
            if name not in modes:
                continue
            print(f"\n---- {name.upper()} head ----")
            head = builder().to(device)
            if on_cuda:
                head = torch.compile(head)

            base_res, logits_val, y_val, logits_test, y_test = train_head(
                cfg_obj, head, loaders, smoothing, mu=mu, std=std
            )

            # Logits may come back in half precision when AMP was active.
            logits_val, logits_test = logits_val.float(), logits_test.float()

            # Score the uncalibrated ECE with the same bin count as the
            # calibrated one; train_head uses ece_score's default bin count.
            base_res["ece_uncalibrated"] = ece_score(logits_test, y_test, n_bins=ece_bins)

            # Temperature scaling
            temp = find_optimal_temp(logits_val, y_val, cfg_obj.device)
            scaled_logits = logits_test / temp

            nll_cal = nll_criterion(scaled_logits, y_test).item()
            bri_cal = brier_score(scaled_logits, y_test)
            ece_cal = ece_score(scaled_logits, y_test, n_bins=ece_bins)

            print(f"ECE {base_res['ece_uncalibrated']:.4f} -> {ece_cal:.4f}")

            if name == "pykan":
                grid_used = chosen_py_grid
            elif name == "fastkan":
                grid_used = chosen_fk_grid
            else:
                grid_used = None

            all_results.append({
                "model": name, "seed": seed, **base_res,
                "temperature": temp,
                "nll_calibrated": nll_cal,
                "brier_calibrated": bri_cal,
                "ece_calibrated": ece_cal,
                "grid": grid_used,
            })

            all_logits_targets[f"{name}_seed_{seed}"] = (logits_test, y_test)
            all_logits_targets[f"{name}_cal_seed_{seed}"] = (scaled_logits, y_test)

    # --------------------------------------------------------------
    # 5.  Result table
    # --------------------------------------------------------------
    col_order = [
        "model", "seed", "params", "acc", "epoch_best", "temperature", "grid",
        "nll_uncalibrated", "nll_calibrated",
        "brier_uncalibrated", "brier_calibrated",
        "ece_uncalibrated", "ece_calibrated",
    ]
    df = pd.DataFrame(all_results).reindex(columns=col_order)

    return df, all_logits_targets

In [16]:
def plot_reliability_diagrams(logits_dict, n_bins=15, title="Reliability Diagram"):
    """Plot one reliability curve per entry of `logits_dict` and show the figure.

    `logits_dict` maps a legend label to `(logits, targets)`. Each curve comes
    from `reliability_curve(logits, targets, n_bins=n_bins)` and is drawn
    against the dashed "perfect" diagonal.
    """
    fig, ax = plt.subplots(figsize=(5, 5))

    diagonal = np.linspace(0, 1, 101)
    ax.plot(diagonal, diagonal, linestyle="--", label="perfect")

    for label, (logits, y) in logits_dict.items():
        bin_conf, bin_acc = reliability_curve(logits, y, n_bins=n_bins)
        ax.plot(bin_conf, bin_acc, marker="o", label=label)

    ax.set_xlabel("Confidence")
    ax.set_ylabel("Accuracy")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, linestyle="--", alpha=0.3)
    fig.tight_layout()
    plt.show()


In [17]:
def plot_reliability_diagrams(logits_dict, n_bins=15, title="Reliability Diagram"):
    """Draw reliability curves for every `(logits, targets)` pair in `logits_dict`.

    NOTE(review): this notebook defines `plot_reliability_diagrams` in two
    consecutive cells with the same body. The later definition is the one in
    effect once both cells have run; consider deleting one of them.
    """
    plt.figure(figsize=(5, 5))

    # Dashed diagonal: the curve of a perfectly calibrated model.
    ideal = np.linspace(0, 1, 101)
    plt.plot(ideal, ideal, linestyle="--", label="perfect")

    for model_name in logits_dict:
        model_logits, model_targets = logits_dict[model_name]
        curve = reliability_curve(model_logits, model_targets, n_bins=n_bins)
        bin_conf, bin_acc = curve
        plt.plot(bin_conf, bin_acc, marker="o", label=model_name)

    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.title(title)
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.3)
    plt.tight_layout()
    plt.show()

In [18]:
def plot_metric_summary(df, metrics=("acc", "ece"), *, combine=True, show=False):
    """
    Compute mean/std per model and plot bars with error bars.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a "model" column and one column per entry of `metrics`;
        a "seed" column, when present, is used for the title.
    metrics : sequence of str
        Columns to summarise.
    combine : bool
        True draws all metrics side by side in one Figure; False draws one
        Figure per metric.
    show : bool
        Also display each Figure via `IPython.display.display`.

    Returns
    -------
    A Figure (or dict of Figures keyed by metric when combine=False).
    """
    grouped = df.groupby("model")[list(metrics)].agg(["mean", "std"])
    nseeds = int(df["seed"].nunique()) if "seed" in df.columns else None

    def draw_bars(ax, m):
        """Draw the mean +/- std bar chart for metric `m` on `ax`."""
        means = grouped[(m, "mean")].astype(float)
        # With a single seed the std is NaN; draw it as a zero-height error bar.
        stds = grouped[(m, "std")].astype(float).fillna(0.0)
        ax.bar(means.index.astype(str), means.values, yerr=stds.values, capsize=5)
        title = (
            f"{m.upper()} (avg over {nseeds} seed{'s' if nseeds!=1 else ''})"
            if nseeds
            else m.upper()
        )
        ax.set_title(title)
        ax.set_ylabel(m)
        ax.grid(axis="y", ls="--", alpha=0.3)
        ax.set_axisbelow(True)

    def maybe_show(fig):
        # The import lives in this helper under its own name. The original
        # imported `_display` inside one branch only, which made the name
        # function-local and raised UnboundLocalError for
        # combine=False, show=True.
        if show:
            from IPython.display import display as ipy_display

            ipy_display(fig)

    if combine:
        n = len(metrics)
        fig, axes = plt.subplots(1, n, figsize=(5 * n, 3), squeeze=False)
        for ax, m in zip(axes[0], metrics):
            draw_bars(ax, m)
        fig.tight_layout()
        maybe_show(fig)
        return fig

    figs = {}
    for m in metrics:
        fig, ax = plt.subplots(figsize=(5, 3))
        draw_bars(ax, m)
        fig.tight_layout()
        maybe_show(fig)
        figs[m] = fig
    return figs

In [19]:
@torch.no_grad()
def reliability_curve(logits, targets, n_bins=15):
    """Return per-bin mean confidence and accuracy for a reliability diagram.

    Samples are sorted by top-1 softmax confidence and cut into `n_bins`
    contiguous, (nearly) equal-count groups; empty groups are skipped.
    Returns two lists of floats: `(bin_confidences, bin_accuracies)`.
    """
    confidence, predicted = torch.softmax(logits, dim=1).max(dim=1)
    hits = predicted.eq(targets).float()

    conf_sorted, order = torch.sort(confidence)
    hits_sorted = hits[order]
    total = len(conf_sorted)

    # Bin boundaries are positions in the sorted order.
    cuts = (
        torch.linspace(0, total, steps=n_bins + 1, device=confidence.device)
        .round()
        .long()
        .tolist()
    )
    spans = [(lo, hi) for lo, hi in zip(cuts[:-1], cuts[1:]) if hi > lo]

    bin_conf = [conf_sorted[lo:hi].mean().item() for lo, hi in spans]
    bin_acc = [hits_sorted[lo:hi].mean().item() for lo, hi in spans]
    return bin_conf, bin_acc

In [20]:
def plot_reliability(logits_dict, n_bins=15, title="Reliability", *, show=False):
    """
    Plot one reliability curve per model on a shared axis and return the Figure.

    Parameters
    ----------
    logits_dict : mapping of model name -> (logits, targets); each pair is
        passed to ``reliability_curve(logits, targets, n_bins)``, which is
        expected to return (bin_conf, bin_acc).
    n_bins : bin count forwarded to ``reliability_curve``.
    title : axis title.
    show : when True, display the Figure through IPython's ``display``.
    """
    fig, ax = plt.subplots(figsize=(5, 5))

    # Dashed diagonal marks where confidence equals accuracy.
    diagonal = np.linspace(0, 1, 101)
    ax.plot(diagonal, diagonal, "--", label="perfect")

    for model_name, pair in logits_dict.items():
        model_logits, model_targets = pair
        bin_conf, bin_acc = reliability_curve(model_logits, model_targets, n_bins)
        ax.plot(bin_conf, bin_acc, marker="o", label=model_name)

    ax.set_title(title)
    ax.set_xlabel("Confidence")
    ax.set_ylabel("Accuracy")
    ax.legend()
    ax.set_axisbelow(True)
    ax.grid(True, ls="--", alpha=0.3)
    fig.tight_layout()

    if not show:
        return fig
    _display(fig)
    return fig


In [21]:


# Mount Google Drive at /content/drive; the Run handler writes its results
# beneath this mount point.
drive.mount('/content/drive')



# ------------------------------------------------------------------
#  Widget definitions
# ------------------------------------------------------------------
# (display label, fraction) pairs offered by the "Label %" selector.
_label_opts = [("0.5%", 0.005), ("1%", 0.01), ("2%", 0.02), ("5%", 0.05)]

# Experiment configuration controls; their values are read by the Run handler.
dataset_dd   = W.Dropdown(options=["cifar10", "mnist"], value="cifar10",
                          description="Dataset")
# All label fractions are selected by default.
labels_ms    = W.SelectMultiple(options=_label_opts,
                                value=tuple(v for _, v in _label_opts),
                                description="Label %")
modes_ms     = W.SelectMultiple(options=["mlp", "pykan", "fastkan"],
                                value=("mlp", "pykan", "fastkan"),
                                description="Models")
# Free-text list of integer seeds, parsed when Run is clicked.
seeds_text   = W.Text(value="42", description="Seeds (comma-sep)")
epochs_int   = W.IntSlider(value=150, min=20, max=300, step=10, description="Epochs")
patience_int = W.IntSlider(value=20, min=5, max=60, step=5, description="Patience")
# Log-scale sliders: min/max are base-10 exponents.
lr_float     = W.FloatLogSlider(value=3e-3, base=10, min=-5, max=-1, step=0.2,
                                description="LR")
wd_float     = W.FloatLogSlider(value=1e-4, base=10, min=-6, max=-2, step=0.2,
                                description="Weight Decay")
smooth_float = W.FloatSlider(value=0.1, min=0.0, max=0.2, step=0.01,
                             description="Label Smooth")
match_cb     = W.Checkbox(value=True, description="Param-match KAN")
ece_bins_int = W.IntSlider(value=15, min=5, max=40, step=1, description="ECE bins")
# "auto" is resolved to a concrete device by _pick_device.
device_dd    = W.Dropdown(options=["auto", "cuda", "cpu"], value="auto",
                          description="Device")

# Output directory, interpreted relative to the /content/drive mount point.
drive_path_text = W.Text(
    value="MyDrive/Colab_KAN_Runs",
    description="Google Drive Path",
    style={'description_width': 'initial'},
    layout=W.Layout(width='50%')
)

run_btn = W.Button(description="Run", button_style="primary", icon="play")
stop_btn = W.Button(description="Stop", button_style="danger", icon="stop")

# Cancellation flag: set by the Stop handler, checked by the Run handler
# before each label fraction.
_cancelled = False
# Output area that captures everything the handlers print or display.
out = W.Output()

# Assemble the control panel row by row and render it.
ui = W.VBox([
    W.HBox([dataset_dd, labels_ms, modes_ms]),
    W.HBox([seeds_text, epochs_int, patience_int]),
    W.HBox([lr_float, wd_float, smooth_float]),
    W.HBox([match_cb, ece_bins_int, device_dd]),
    drive_path_text,
    W.HBox([run_btn, stop_btn]),
    out,
])
display(ui)

# ------------------------------------------------------------------
#  Helper utilities
# ------------------------------------------------------------------
def _pick_device(val: str) -> str:
    """Resolve the device selector value to a concrete device string.

    "auto" becomes "cuda" when CUDA is available and "cpu" otherwise. Any
    other value is returned unchanged so an explicit choice is respected.
    (The previous one-line chained conditional returned "cuda" whenever CUDA
    was available, even when "cpu" had been selected.)
    """
    if val == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return val

# ------------------------------------------------------------------
#  Stop button handler
# ------------------------------------------------------------------
@stop_btn.on_click
def _stop(_button):
    """Stop-button callback: raise the cancellation flag and log a notice."""
    global _cancelled
    notice = "\n[STOP] Cancellation requested – will finish current fraction."
    _cancelled = True
    # Print inside the Output widget so the notice appears below the controls.
    with out:
        print(notice)

# ------------------------------------------------------------------
#  Run button handler
# ------------------------------------------------------------------
def _save_figure(fig, path, label):
    """Write ``fig`` to ``path`` at 200 dpi, close it, and report what was saved."""
    try:
        fig.savefig(path, dpi=200, bbox_inches="tight")
    finally:
        # Close even when saving fails so figures do not accumulate across runs.
        plt.close(fig)
    print(f"✅ Saved {label} → {path}")


@run_btn.on_click
def _run(_b):
    """Run-button callback: run the experiment once per selected label fraction.

    Reads the configuration from the widgets, validates it, then for each
    fraction calls ``run_experiment`` and writes a results CSV, a metric
    summary plot and a reliability plot to the chosen Drive folder. After the
    loop, the per-fraction results are concatenated into one combined CSV and
    displayed. A fraction whose experiment raises is reported and skipped.

    NOTE(review): the cancellation flag is only checked between fractions, and
    this callback runs synchronously — confirm that a Stop click can actually
    be processed while a run is in progress.
    """
    global _cancelled
    _cancelled = False
    import traceback

    with out:
        clear_output(wait=True)

        # ---- validate the selections before creating anything on Drive ----
        modes = list(modes_ms.value)
        if not modes:
            print("❌ Select at least one model.")
            return

        try:
            seeds = [int(s.strip()) for s in seeds_text.value.split(",") if s.strip()]
        except ValueError:
            print("❌ Invalid seeds format. Use comma-separated integers (e.g., 42, 43).")
            return
        if not seeds:
            print("❌ Enter at least one seed.")
            return

        fractions = list(labels_ms.value)
        if not fractions:
            print("❌ Select at least one label fraction.")
            return

        # Normalised, absolute Drive path (always below the mount point)
        gdrive_base_path = Path("/content/drive") / Path(
            drive_path_text.value.strip().strip("/")
        )
        gdrive_base_path.mkdir(parents=True, exist_ok=True)

        # Snapshot the configuration once so every fraction and every file
        # name in this run uses the same settings.
        dataset = dataset_dd.value
        device = _pick_device(device_dd.value)
        epochs = int(epochs_int.value)
        patience = int(patience_int.value)
        lr = float(lr_float.value)
        weight_decay = float(wd_float.value)
        smoothing = float(smooth_float.value)
        match_params = bool(match_cb.value)
        ece_bins = int(ece_bins_int.value)

        # Suffix shared by the per-fraction and combined file names.
        run_tag = (
            f"{'match' if match_params else 'nomatch'}_"
            f"epochs{epochs_int.value}_"
            f"seeds{'-'.join(map(str, seeds))}"
        )

        all_results = []
        print(f"📁 Results will be saved to: {gdrive_base_path}")
        print(f"Running {modes} on {dataset} with seeds: {seeds}")
        print(f"Fractions: {', '.join(f'{f*100:.1f}%' for f in fractions)}")

        for frac in fractions:
            if _cancelled:
                print("\n[STOP] Cancelled mid-run.")
                break

            print(f"\n===== Label fraction {frac*100:.1f}% =====")
            try:
                df, logits_targets_dict = run_experiment(
                    dataset=dataset,
                    label_fraction=frac,
                    modes=modes,
                    seeds=seeds,
                    device=device,
                    epochs=epochs,
                    patience=patience,
                    lr=lr,
                    weight_decay=weight_decay,
                    smoothing=smoothing,
                    match_params=match_params,
                    ece_bins=ece_bins,
                )
            except Exception:
                # Report the failure and move on to the next fraction.
                print("💥 Exception during run_experiment:")
                traceback.print_exc()
                continue

            df["label_fraction"] = frac
            all_results.append(df.copy())

            # ":g" keeps sub-1% fractions distinguishable (0.005 -> "0.5");
            # int() truncated that case to "0".
            run_id = f"{dataset}_labels{frac * 100:g}_{run_tag}"

            csv_path = gdrive_base_path / f"results_{run_id}.csv"
            df.to_csv(csv_path, index=False)
            print(f"✅ Saved CSV → {csv_path}")

            # Metric summary plot
            fig = plot_metric_summary(
                df, metrics=["acc", "ece"], combine=True, show=False
            )
            if fig is not None:
                _save_figure(fig, gdrive_base_path / f"metrics_{run_id}.png", "Metrics Plot")

            # Reliability diagram
            fig = plot_reliability(
                logits_targets_dict,
                n_bins=ece_bins,
                title=f"{dataset.upper()} @ {frac*100:.1f}%",
                show=False,
            )
            if fig is not None:
                _save_figure(
                    fig, gdrive_base_path / f"reliability_{run_id}.png", "Reliability Plot"
                )

        if all_results:
            full = pd.concat(all_results, ignore_index=True)
            full_csv_path = gdrive_base_path / f"combined_{dataset}_{run_tag}.csv"
            full.to_csv(full_csv_path, index=False)
            print(f"\n✅ Saved combined CSV → {full_csv_path}")
            display(full)

        print("\nAll done.")

Mounted at /content/drive


VBox(children=(HBox(children=(Dropdown(description='Dataset', options=('cifar10', 'mnist'), value='cifar10'), …