In [5]:
# Cell 2B: remove any installed torch/torchvision/torchaudio, then reinstall them from the PyTorch CPU-only wheel index
!pip uninstall -y torch torchvision torchaudio
!pip install --no-cache-dir --index-url https://download.pytorch.org/whl/cpu \
    torch torchvision torchaudio


Found existing installation: torch 2.9.0+cpu
Uninstalling torch-2.9.0+cpu:
  Successfully uninstalled torch-2.9.0+cpu
Found existing installation: torchvision 0.24.0+cpu
Uninstalling torchvision-0.24.0+cpu:
  Successfully uninstalled torchvision-0.24.0+cpu
Found existing installation: torchaudio 2.9.0+cpu
Uninstalling torchaudio-2.9.0+cpu:
  Successfully uninstalled torchaudio-2.9.0+cpu
Looking in indexes: https://download.pytorch.org/whl/cpu
Collecting torch
  Downloading https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cpu/torchvision-0.24.0%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (5.9 kB)
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cpu/torchaudio-2.9.0%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.9 kB)
Downloading https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp311-cp311-manylinux_2_28_x86_64.whl 

In [3]:
import os
# Set the PYTORCH_CUDA_ALLOC_CONF environment variable for this kernel process.
# NOTE(review): presumably this has to run before torch first initialises CUDA to
# have any effect — confirm the cell order on a fresh "Restart & Run All".
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

In [1]:
import torch
# Show which torch build the kernel imports and the file it was loaded from
# ("n/a" if the module exposes no __file__).
print("torch:", torch.__version__, "from", getattr(torch, "__file__", "n/a"))
# Show whether this torch build reports a usable CUDA device.
print("cuda available:", torch.cuda.is_available())


torch: 2.6.0+cu124 from /usr/local/lib/python3.11/dist-packages/torch/__init__.py
cuda available: True


In [1]:
# full_convnextv2_with_kan_and_runner.py
# Single-file: ConvNeXtV2 variants (baseline / SE / KAN / SE+KAN) + KAN + training runner for Kaggle dataset.
# Usage: paste into a notebook cell on Kaggle and run. DATA_ROOT defaults to /kaggle/input/brain-tumor-mri-dataset

import os, math, random, warnings
from copy import deepcopy
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, precision_score, recall_score, f1_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader

# -------------------- Basic utilities & deterministic seed --------------------

def ensure_dir(path):
    """Create directory ``path`` together with any missing parents.

    Does nothing when the directory is already there; returns ``None``.
    """
    os.makedirs(name=path, exist_ok=True)

def set_global_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also asks cuDNN for deterministic behaviour (benchmark mode off), and seeds
    every CUDA device when CUDA is available. Failures while touching the cuDNN
    flags or the CUDA generators are ignored on purpose (best effort).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # deterministic cudnn for reproducibility (may slow down)
    try:
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
    except Exception:
        pass

    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.manual_seed_all(seed)
    except Exception:
        pass

def safe_train_test_split(df, **kwargs):
    """Split ``df`` into train/test parts, degrading gracefully when sklearn refuses.

    Attempts, in order:
      1. ``train_test_split(df, **kwargs)`` exactly as requested.
      2. The same call with stratification removed.
      3. A manual shuffled split that only honours ``test_size`` (float fraction or
         absolute count, default 0.15) and ``random_state``.

    Each fallback emits a warning so a silently changed split strategy is visible.

    Args:
        df: DataFrame to split.
        **kwargs: keyword arguments forwarded to ``train_test_split``.

    Returns:
        ``(df_train, df_test)``. For an empty ``df`` two empty copies are returned.
    """
    try:
        return train_test_split(df, **kwargs)
    except Exception as exc:
        warnings.warn(
            f"train_test_split failed ({exc!r}); retrying without stratification.",
            stacklevel=2,
        )

    # Build a separate dict instead of popping, so the caller-visible kwargs stay intact.
    unstratified = {key: value for key, value in kwargs.items() if key != 'stratify'}
    try:
        return train_test_split(df, stratify=None, **unstratified)
    except Exception as exc:
        warnings.warn(
            f"train_test_split failed again ({exc!r}); using a manual shuffled split.",
            stacklevel=2,
        )

    n = len(df)
    if n == 0:
        return df.copy(), df.copy()

    test_size = unstratified.get('test_size', 0.15)
    if isinstance(test_size, float):
        n_test = max(1, int(math.ceil(n * test_size)))
    else:
        n_test = int(test_size)
    if n_test >= n:
        n_test = max(1, n // 10)

    # Honour an integer-like random_state so the fallback is reproducible per call;
    # otherwise fall back to the module-level generator as before.
    seed = unstratified.get('random_state')
    try:
        rng = random.Random(int(seed)) if seed is not None else random
    except (TypeError, ValueError):
        rng = random

    idx = list(range(n))
    rng.shuffle(idx)
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    df_train = df.iloc[train_idx].reset_index(drop=True)
    df_test = df.iloc[test_idx].reset_index(drop=True)
    return df_train, df_test

def get_class_paths(path):
    """Index an image folder laid out as ``path/<label>/<image file>``.

    Every sub-directory of ``path`` is treated as a label; files inside it whose
    extension (case-insensitive) is .jpg, .jpeg, .png or .webp are collected.
    Labels and files are visited in sorted order.

    Returns:
        DataFrame with columns ``path`` and ``label``; an empty frame with those
        columns when ``path`` is not a directory.
    """
    if not os.path.isdir(path):
        return pd.DataFrame(columns=['path','label'])

    image_exts = ('.jpg','.jpeg','.png','.webp')
    found = []  # (file path, label) pairs in discovery order
    for class_name in sorted(os.listdir(path)):
        class_dir = os.path.join(path, class_name)
        if os.path.isdir(class_dir):
            found.extend(
                (os.path.join(class_dir, entry), class_name)
                for entry in sorted(os.listdir(class_dir))
                if os.path.splitext(entry)[1].lower() in image_exts
            )
    return pd.DataFrame({
        'path': [file_path for file_path, _ in found],
        'label': [class_name for _, class_name in found],
    })

class MRIDataset(Dataset):
    """Dataset of image files described by a DataFrame with ``path`` and ``label`` columns.

    Args:
        df: DataFrame with one row per image; its index is reset so rows are
            addressed positionally.
        label2idx: mapping from the values in the ``label`` column to integer class ids.
        transforms: optional callable applied to the RGB PIL image; when it is
            absent (or falsy) the image is converted with ``T.ToTensor()``.
    """

    def __init__(self, df, label2idx, transforms=None):
        self.df = df.reset_index(drop=True)
        self.transforms = transforms
        self.label2idx = label2idx

    def __len__(self):
        """Number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load row ``idx`` and return ``(image, int class id)``."""
        row = self.df.iloc[idx]
        # Open inside a context manager so the underlying file handle is closed
        # right after the pixel data has been converted, instead of lingering
        # until garbage collection.
        with Image.open(row['path']) as src:
            img = src.convert('RGB')
        if self.transforms:
            img = self.transforms(img)
        else:
            img = T.ToTensor()(img)
        lbl = self.label2idx[row['label']]
        return img, int(lbl)

# -------------------- ConvNeXt helper utilities (trunc_normal_, drop_path, LayerNorm, GRN) --------------------

def _trunc_normal_(tensor, mean, std, a, b):
    # truncated normal (copied from many reference impls)
    def norm_cdf(x):
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_.", stacklevel=2)
    l = norm_cdf((a - mean) / std); u = norm_cdf((b - mean) / std)
    tensor.uniform_(2 * l - 1, 2 * u - 1); tensor.erfinv_()
    tensor.mul_(std * math.sqrt(2.0)); tensor.add_(mean)
    tensor.clamp_(min=a, max=b)
    return tensor

def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Truncated-normal in-place initialiser, executed with autograd disabled.

    Delegates to ``_trunc_normal_`` and returns the tensor it returns.
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled

def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    One Bernoulli(1 - drop_prob) draw is made per entry along dim 0 and broadcast
    over the remaining dims. When ``scale_by_keep`` is set and the keep probability
    is positive, kept samples are divided by the keep probability.

    Returns ``x`` itself, untouched, when not training or when ``drop_prob`` is 0.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # (batch, 1, 1, ...) so the mask broadcasts across every non-batch dimension
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask

class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` that follows the module's training flag."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob, self.scale_by_keep = drop_prob, scale_by_keep

    def forward(self, x):
        """Apply ``drop_path`` using the stored probability and ``self.training``."""
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        """Text shown inside the module's repr: the drop probability to 3 decimals."""
        return f"drop_prob={round(self.drop_prob,3):0.3f}"

class LayerNorm(nn.Module):
    """Layer normalisation over the channel axis for two tensor layouts.

    ``channels_last`` normalises the trailing dimension via ``F.layer_norm``;
    ``channels_first`` normalises dim 1 by hand and applies the affine parameters
    broadcast as ``[:, None, None]``. Any other ``data_format`` raises
    ``NotImplementedError`` at construction.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            var = centered.pow(2).mean(1, keepdim=True)  # biased variance over channels
            normed = centered / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]

class GRN(nn.Module):
    """Global response normalisation for tensors whose last dimension has size ``dim``.

    Takes the L2 norm over dims (1, 2), divides it by its mean over the last
    dimension (plus 1e-6), and returns ``gamma * (x * ratio) + beta + x``.
    ``gamma`` and ``beta`` start at zero, so the layer is initially the identity.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        channel_mean = spatial_norm.mean(dim=-1, keepdim=True)
        ratio = spatial_norm / (channel_mean + 1e-6)
        return self.gamma * (x * ratio) + self.beta + x

# -------------------- KAN implementation --------------------

class KANLinear(torch.nn.Module):
    """Linear-style layer whose output is a base branch plus a learnable B-spline branch.

    ``forward`` returns ``F.linear(base_activation(x), base_weight)`` plus
    ``F.linear(b_splines(x), scaled_spline_weight)``. Inputs are flattened to
    ``(-1, in_features)`` and the leading dimensions are restored on the output.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        grid_size: number of intervals the initial ``grid_range`` is divided into.
        spline_order: order of the B-spline bases; the grid is padded by this many
            extra knots on each side.
        scale_noise: magnitude of the uniform noise the spline weights are fitted to
            in ``reset_parameters``.
        scale_base: multiplies the ``a`` argument of the Kaiming init of ``base_weight``.
        scale_spline: multiplies the initial spline weights when the standalone
            scaler is disabled; otherwise multiplies the ``a`` argument of the
            Kaiming init of ``spline_scaler``.
        enable_standalone_scale_spline: when True, a separate ``spline_scaler``
            parameter of shape ``(out_features, in_features)`` multiplies the
            spline weights.
        base_activation: module class, instantiated with no arguments, used in the
            base branch.
        grid_eps: blend factor between the uniform and the data-adaptive grid in
            ``update_grid``.
        grid_range: ``[low, high]`` span of the initial uniform grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot spacing; the knot vector runs from spline_order steps below
        # grid_range[0] to spline_order steps above grid_range[1], and the same
        # knots are replicated for every input feature:
        # shape (in_features, grid_size + 2 * spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        # Buffer, not parameter: it is only changed explicitly by update_grid.
        self.register_buffer("grid", grid)

        # Parameters are allocated uninitialised here and filled by reset_parameters().
        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``base_weight``, ``spline_weight`` and, if enabled, ``spline_scaler``.

        The spline weights are obtained by least-squares fitting (``curve2coeff``)
        small uniform noise sampled at the ``grid_size + 1`` knots that remain after
        dropping ``spline_order`` knots from each end of the grid.
        """
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # noise in [-0.5, 0.5) * scale_noise / grid_size, one value per (knot, in, out)
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """Evaluate the B-spline bases of every input feature at ``x``.

        Args:
            x: tensor of shape ``(batch, in_features)``.

        Returns:
            Contiguous tensor of shape ``(batch, in_features, grid_size + spline_order)``.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the half-open knot interval containing each input.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order by one per iteration, combining neighbouring lower-order
        # bases (this appears to be the Cox-de Boor recursion); each step shortens
        # the last dimension by one.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:]
            )
        assert bases.size() == (x.size(0), self.in_features, self.grid_size + self.spline_order)
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """Least-squares fit of spline coefficients to sampled curve values.

        Args:
            x: sample locations, shape ``(batch, in_features)``.
            y: target values, shape ``(batch, in_features, out_features)``.

        Returns:
            Contiguous tensor of shape
            ``(out_features, in_features, grid_size + spline_order)``.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)
        # One independent least-squares problem per input feature (batched over dim 0).
        A = self.b_splines(x).transpose(0, 1)
        B = y.transpose(0, 1)
        solution = torch.linalg.lstsq(A, B).solution
        # (in, coeff, out) -> (out, in, coeff)
        result = solution.permute(2, 0, 1)
        assert result.size() == (self.out_features, self.in_features, self.grid_size + self.spline_order)
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """``spline_weight`` multiplied by ``spline_scaler`` when the standalone scaler is enabled."""
        return self.spline_weight * (self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0)

    def forward(self, x: torch.Tensor):
        """Apply the layer to ``x`` whose last dimension is ``in_features``.

        Returns a tensor with the same leading dimensions and a last dimension of
        ``out_features``.
        """
        orig_shape = x.shape
        x2 = x.reshape(-1, self.in_features)
        base_output = F.linear(self.base_activation(x2), self.base_weight)
        # Flatten (in_features, n_coeff) on both operands so the spline branch is a
        # single matrix product.
        spline_output = F.linear(
            self.b_splines(x2).view(x2.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        output = output.reshape(*orig_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-place the knots from the batch ``x`` and refit the spline weights.

        The new grid is a ``grid_eps``-weighted blend of a uniform grid spanning the
        batch range (widened by ``margin`` on each side) and an adaptive grid taken
        from evenly spaced positions of the per-feature sorted inputs. It is then
        extended by ``spline_order`` knots on each side. The spline weights are
        refitted to the spline outputs computed with the old grid.

        Args:
            x: tensor of shape ``(batch, in_features)``.
            margin: amount added on each side of the observed input range.

        NOTE(review): the refit targets are computed with ``scaled_spline_weight``
        but the fitted coefficients are written to the unscaled ``spline_weight``;
        with ``enable_standalone_scale_spline=True`` the scaler is then applied
        again in ``forward`` — confirm this is intended.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)
        # Per-(sample, in, out) spline contribution under the current grid and weights.
        splines = self.b_splines(x).permute(1, 0, 2)
        orig_coeff = self.scaled_spline_weight.permute(1, 2, 0)
        unreduced_spline_output = torch.bmm(splines, orig_coeff).permute(1, 0, 2)
        # Sort each feature column independently to pick evenly spaced order statistics.
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[torch.linspace(0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device)]
        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (torch.arange(self.grid_size + 1, dtype=torch.float32, device=x.device).unsqueeze(1) * uniform_step + x_sorted[0] - margin)
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Pad spline_order extra knots below and above, spaced by uniform_step.
        grid = torch.concatenate([
            grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
            grid,
            grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
        ], dim=0)
        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Weight-based regularisation term for this layer.

        Uses the mean absolute spline weight of each (out, in) pair: their sum is the
        "activation" term, and the entropy of their normalised distribution is the
        "entropy" term. Returns the weighted sum of the two.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        # 1e-12 guards the division and the log against zeros.
        p = l1_fake / (regularization_loss_activation + 1e-12)
        regularization_loss_entropy = -torch.sum(p * (p + 1e-12).log())
        return (regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy)


class KAN(torch.nn.Module):
    """Stack of ``KANLinear`` layers, one per consecutive pair in ``layers_hidden``.

    All layers share the same grid/spline hyper-parameters. ``forward`` accepts any
    number of leading dimensions and applies the layers along the last one.
    """

    def __init__(self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1,1]):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Hyper-parameters common to every layer in the stack.
        shared_cfg = dict(
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        self.layers = torch.nn.ModuleList([
            KANLinear(n_in, n_out, **shared_cfg)
            for n_in, n_out in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x: torch.Tensor, update_grid=False):
        """Run ``x`` through every layer; optionally refresh each layer's grid first.

        The input is flattened to 2-D for the layers and the leading dimensions are
        restored on the result.
        """
        lead_dims = x.shape[:-1]
        hidden = x.reshape(-1, x.shape[-1])
        for layer in self.layers:
            if update_grid:
                layer.update_grid(hidden)
            hidden = layer(hidden)
        return hidden.reshape(*lead_dims, hidden.shape[-1])

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Sum of the per-layer regularisation losses (``0`` when there are no layers)."""
        total = 0
        for layer in self.layers:
            total = total + layer.regularization_loss(regularize_activation, regularize_entropy)
        return total

# -------------------- ConvNeXtV2 Blocks --------------------

class ConvNeXtBlockBaseline(nn.Module):
    """Residual block: 7x7 depthwise conv -> LayerNorm -> Linear(4x) -> GELU -> GRN -> Linear.

    Input and output are ``(N, C, H, W)``; the normalisation and pointwise layers
    run in channels-last layout. The residual branch passes through ``DropPath``
    when ``drop_path > 0``, otherwise through an identity.
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        expanded = 4 * dim
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)  # channels_last
        self.pw1 = nn.Linear(dim, expanded)
        self.act = nn.GELU()
        self.grn = GRN(expanded)
        self.pw2 = nn.Linear(expanded, dim)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        residual = x
        y = self.dwconv(x).permute(0, 2, 3, 1)   # (N,H,W,C)
        for layer in (self.norm, self.pw1, self.act, self.grn, self.pw2):
            y = layer(y)
        y = y.permute(0, 3, 1, 2)   # back to (N,C,H,W)
        return residual + self.drop_path(y)

class ConvNeXtBlockSE(nn.Module):
    """ConvNeXt-style residual block followed by a squeeze-and-excitation gate.

    After the depthwise conv / LayerNorm / Linear / GELU / GRN / Linear branch, the
    channels-last features are averaged over H and W, passed through a two-layer
    bottleneck (``dim -> max(1, dim // se_reduction) -> dim``) with a sigmoid, and
    the result rescales the features per channel before the residual add.
    """

    def __init__(self, dim, drop_path=0.0, se_reduction=4):
        super().__init__()
        expanded = 4 * dim
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pw1 = nn.Linear(dim, expanded)
        self.act = nn.GELU()
        self.grn = GRN(expanded)
        self.pw2 = nn.Linear(expanded, dim)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        squeeze_dim = max(1, dim // se_reduction)
        self.se_fc1 = nn.Linear(dim, squeeze_dim)
        self.se_act = nn.GELU()
        self.se_fc2 = nn.Linear(squeeze_dim, dim)
        self.se_sigmoid = nn.Sigmoid()

    def forward(self, x):
        residual = x
        y = self.dwconv(x).permute(0, 2, 3, 1)   # (N,H,W,C)
        for layer in (self.norm, self.pw1, self.act, self.grn, self.pw2):
            y = layer(y)

        # Squeeze: spatial mean -> (N, C); excite: bottleneck + sigmoid gate.
        gate = y.mean(dim=(1, 2))
        gate = self.se_fc2(self.se_act(self.se_fc1(gate)))
        gate = self.se_sigmoid(gate)[:, None, None, :]   # (N,1,1,C)
        y = y * gate

        y = y.permute(0, 3, 1, 2)
        return residual + self.drop_path(y)

class ConvNeXtBlockKAN(nn.Module):
    """ConvNeXt-style residual block whose two pointwise layers are single-layer KANs.

    Branch: 7x7 depthwise conv -> LayerNorm -> KAN(dim -> 4*dim) -> GELU -> GRN ->
    KAN(4*dim -> dim), executed in channels-last layout and added back to the input.
    """

    def __init__(self, dim, drop_path=0.0, kan_grid_size=5, kan_spline_order=3):
        super().__init__()
        expanded = 4 * dim
        kan_cfg = dict(grid_size=kan_grid_size, spline_order=kan_spline_order)
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.kan1 = KAN([dim, expanded], **kan_cfg)
        self.act = nn.GELU()
        self.grn = GRN(expanded)
        self.kan2 = KAN([expanded, dim], **kan_cfg)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        residual = x
        y = self.dwconv(x).permute(0, 2, 3, 1)  # (N,H,W,C)
        y = self.norm(y)
        # KAN.forward flattens the leading dims itself and restores them on the
        # way out, so the (N,H,W,·) tensors can be passed straight through.
        y = self.kan1(y)                        # (N,H,W,4C)
        y = self.grn(self.act(y))
        y = self.kan2(y)                        # (N,H,W,C)
        y = y.permute(0, 3, 1, 2)
        return residual + self.drop_path(y)

class ConvNeXtBlockSEKAN(nn.Module):
    """ConvNeXt-style residual block with KAN pointwise layers and an SE gate.

    Branch: 7x7 depthwise conv -> LayerNorm -> KAN(dim -> 4*dim) -> GELU -> GRN ->
    KAN(4*dim -> dim), then a squeeze-and-excitation gate (spatial mean ->
    ``dim -> max(1, dim // se_reduction) -> dim`` -> sigmoid) rescales the channels
    before the residual add.
    """

    def __init__(self, dim, drop_path=0.0, se_reduction=4, kan_grid_size=5, kan_spline_order=3):
        super().__init__()
        expanded = 4 * dim
        kan_cfg = dict(grid_size=kan_grid_size, spline_order=kan_spline_order)
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.kan1 = KAN([dim, expanded], **kan_cfg)
        self.act = nn.GELU()
        self.grn = GRN(expanded)
        self.kan2 = KAN([expanded, dim], **kan_cfg)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        squeeze_dim = max(1, dim // se_reduction)
        self.se_fc1 = nn.Linear(dim, squeeze_dim)
        self.se_act = nn.GELU()
        self.se_fc2 = nn.Linear(squeeze_dim, dim)
        self.se_sigmoid = nn.Sigmoid()

    def forward(self, x):
        residual = x
        y = self.dwconv(x).permute(0, 2, 3, 1)  # (N,H,W,C)
        y = self.norm(y)
        # KAN.forward flattens the leading dims itself and restores them on the
        # way out, so the (N,H,W,·) tensors can be passed straight through.
        y = self.kan1(y)                        # (N,H,W,4C)
        y = self.grn(self.act(y))
        y = self.kan2(y)                        # (N,H,W,C)

        # Squeeze: spatial mean -> (N, C); excite: bottleneck + sigmoid gate.
        gate = y.mean(dim=(1, 2))
        gate = self.se_fc2(self.se_act(self.se_fc1(gate)))
        gate = self.se_sigmoid(gate)[:, None, None, :]
        y = y * gate

        y = y.permute(0, 3, 1, 2)
        return residual + self.drop_path(y)

# -------------------- ConvNeXtV2 backbone builder --------------------

class ConvNeXtV2Backbone(nn.Module):
    """ConvNeXtV2-style classifier built from an interchangeable block class.

    Layout: a patchify stem (4x4 conv, stride 4, channels-first LayerNorm), then one
    stage per entry of ``depths``; every stage after the first is preceded by a
    LayerNorm + 2x2 stride-2 conv downsample. Features are globally average-pooled,
    layer-normalised and fed to a linear head.

    Args:
        in_chans: number of input image channels.
        num_classes: size of the classification head.
        block_cls: block class, called as ``block_cls(dim=..., drop_path=...)``.
        depths: number of blocks per stage.
        dims: channel width per stage; must have the same length as ``depths``.
        drop_path_rate: final stochastic-depth rate; per-block rates rise linearly
            from 0 to this value across all blocks.
        head_init_scale: factor applied to the head's weight and bias after init.

    Raises:
        ValueError: if ``depths`` and ``dims`` are empty or differ in length.
    """

    def __init__(self, in_chans=3, num_classes=1000, block_cls=ConvNeXtBlockBaseline, depths=[3,3,9,3], dims=[96,192,384,768], drop_path_rate=0.0, head_init_scale=1.0):
        super().__init__()
        # The stage count follows the configuration instead of being fixed at 4, so
        # mismatched configs fail loudly rather than being truncated or raising IndexError.
        if len(depths) == 0 or len(depths) != len(dims):
            raise ValueError(
                f"depths and dims must be non-empty and of equal length, got {len(depths)} and {len(dims)}"
            )
        num_stages = len(depths)
        self.depths = depths

        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(num_stages - 1):
            down = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(down)

        self.stages = nn.ModuleList()
        # One drop-path rate per block, increasing linearly with depth.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(num_stages):
            blocks = [
                block_cls(dim=dims[i], drop_path=dp_rates[cur + j])
                for j in range(depths[i])
            ]
            self.stages.append(nn.Sequential(*blocks))
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for Conv2d / Linear modules."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run all downsample/stage pairs, then pool and normalise to ``(N, C)``."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = stage(downsample(x))
        x = x.mean([-2, -1])   # global average pooling -> (N, C)
        x = self.norm(x)
        return x

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

# -------------------- Training & evaluation helpers --------------------

def compute_metrics_from_logits(logits, targets):
    """Accuracy plus macro precision/recall/F1 from raw logits.

    Predictions are the argmax over dim 1. The sklearn scores are computed with
    ``average='macro'`` and ``zero_division=0``.

    Returns:
        ``(acc, precision, recall, f1, preds)`` where ``preds`` is a NumPy array.
    """
    preds = logits.argmax(dim=1)
    acc = (preds == targets).float().mean().item()

    y_true = targets.cpu().numpy()
    y_pred = preds.cpu().numpy()
    prec, rec, f1 = (
        scorer(y_true, y_pred, average='macro', zero_division=0)
        for scorer in (precision_score, recall_score, f1_score)
    )
    return acc, prec, rec, f1, y_pred

def _sum_regularization_losses(model):
    """Sum ``regularization_loss()`` over the outermost modules that define it.

    A module exposing ``regularization_loss`` is assumed to already account for its
    own sub-modules (as a KAN container does for its layers), so its descendants are
    skipped to avoid counting the same layers twice. Returns ``0.0`` when no module
    provides the method.
    """
    if not hasattr(model, 'modules'):
        return 0.0
    total = 0.0
    covered = set()  # ids of modules already accounted for by an ancestor
    for module in model.modules():  # parents are yielded before their children
        if id(module) in covered or not hasattr(module, 'regularization_loss'):
            continue
        total = total + module.regularization_loss()
        covered.update(id(sub) for sub in module.modules())
    return total


def train_one_epoch(model, loader, criterion, optimizer, device, kan_reg_weight=0.0):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: network to optimise; put into train mode here.
        loader: iterable of ``(imgs, labels)`` batches exposing ``.dataset``.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimiser stepped once per batch.
        device: device the batches are moved to.
        kan_reg_weight: when non-zero, the summed ``regularization_loss()`` of the
            model's modules (each layer counted once) is added to the loss with
            this weight.

    Returns:
        ``(epoch_loss, acc, prec, rec, f1)``. ``epoch_loss`` is the sample-weighted
        loss sum divided by ``len(loader.dataset)``; the metrics are zeros when no
        batch was seen.
    """
    model.train()
    running_loss = 0.0
    all_logits = []; all_targets = []
    for imgs, labels in loader:
        imgs = imgs.to(device); labels = labels.to(device, dtype=torch.long)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        # optional: include KAN regularization if present
        if kan_reg_weight:
            reg = _sum_regularization_losses(model)
            # stays a plain float when no module offers a regularisation term
            if isinstance(reg, torch.Tensor):
                loss = loss + kan_reg_weight * reg
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
        all_logits.append(logits.detach().cpu()); all_targets.append(labels.detach().cpu())
    epoch_loss = running_loss / max(1, len(loader.dataset))
    all_logits = torch.cat(all_logits) if len(all_logits)>0 else torch.tensor([])
    all_targets = torch.cat(all_targets) if len(all_targets)>0 else torch.tensor([])
    if all_logits.numel() == 0:
        return epoch_loss, 0., 0., 0., 0.
    acc, prec, rec, f1, _ = compute_metrics_from_logits(all_logits, all_targets)
    return epoch_loss, acc, prec, rec, f1

def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Returns:
        ``(loss, acc, prec, rec, f1, logits, targets, preds)``. ``loss`` is the
        sample-weighted loss sum divided by ``len(loader.dataset)`` (at least 1);
        ``logits`` and ``targets`` are the concatenated CPU tensors of all batches.
        When the loader yields nothing, the metrics are zeros, the tensors are
        empty and ``preds`` is an empty NumPy array.
    """
    model.eval()
    loss_sum = 0.0
    logit_chunks, target_chunks = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device, dtype=torch.long)
            logits = model(imgs)
            batch_loss = criterion(logits, labels).item()
            loss_sum += batch_loss * imgs.size(0)
            logit_chunks.append(logits.cpu())
            target_chunks.append(labels.cpu())

    val_loss = loss_sum / max(1, len(loader.dataset))
    all_logits = torch.cat(logit_chunks) if logit_chunks else torch.tensor([])
    all_targets = torch.cat(target_chunks) if target_chunks else torch.tensor([])
    if all_logits.numel() == 0:
        return val_loss, 0., 0., 0., 0., all_logits, all_targets, np.array([])

    acc, prec, rec, f1, preds = compute_metrics_from_logits(all_logits, all_targets)
    return val_loss, acc, prec, rec, f1, all_logits, all_targets, preds

def run_experiment(name, block_cls, depths, dims, seed, dataloaders, device,
                   epochs=10, lr=2e-4, weight_decay=1e-4, drop_path_rate=0.1, kan_reg_weight=0.0, outdir='results'):
    """Train one backbone variant, keep its best-validation checkpoint and score it on the test split.

    Args:
        name: experiment tag used in log lines and in every output filename.
        block_cls: block class/factory forwarded to ``ConvNeXtV2Backbone``.
        depths, dims: forwarded unchanged to ``ConvNeXtV2Backbone``.
        seed: passed to ``set_global_seed`` before the model is built.
        dataloaders: dict with ``'train'``, ``'val'`` and ``'test'`` loaders and a
            ``'meta'`` dict holding ``'num_classes'`` and ``'classes'``. The
            position of a name in ``'classes'`` is used as its integer label id.
        device: device the model and batches are placed on.
        epochs: number of training epochs.
        lr, weight_decay: AdamW hyper-parameters.
        drop_path_rate: forwarded to ``ConvNeXtV2Backbone``.
        kan_reg_weight: forwarded to ``train_one_epoch``.
        outdir: directory receiving ``best_<name>.pth``, ``history_<name>.csv``,
            ``<name>_curves.png``, ``<name>_confusion.png`` and ``<name>_report.txt``.

    Returns:
        dict with ``name``, the per-epoch ``history`` dict, and the test-split
        ``test_loss``, ``test_acc``, ``test_prec``, ``test_rec``, ``test_f1``.
    """
    ensure_dir(outdir)
    set_global_seed(seed)
    num_classes = dataloaders['meta']['num_classes']
    model = ConvNeXtV2Backbone(in_chans=3, num_classes=num_classes, block_cls=block_cls,
                               depths=depths, dims=dims, drop_path_rate=drop_path_rate).to(device)
    print(f"[{name}] params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # `verbose` is deprecated for LR schedulers in recent PyTorch; omitting it
    # keeps the same silent behaviour that verbose=False requested.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    best_val_acc = -1.0
    best_path = os.path.join(outdir, f"best_{name}.pth")

    # Column order (epoch, train_*, val_*) is what ends up in the history CSV.
    metric_names = ('loss', 'acc', 'prec', 'rec', 'f1')
    history = {'epoch': []}
    for split in ('train', 'val'):
        for metric in metric_names:
            history[f'{split}_{metric}'] = []

    for epoch in range(1, epochs+1):
        train_stats = train_one_epoch(model, dataloaders['train'], criterion, optimizer, device, kan_reg_weight=kan_reg_weight)
        val_stats = evaluate(model, dataloaders['val'], criterion, device)[:5]
        train_loss, train_acc = train_stats[0], train_stats[1]
        val_loss, val_acc = val_stats[0], val_stats[1]
        scheduler.step(val_loss)
        history['epoch'].append(epoch)
        for metric, train_value, val_value in zip(metric_names, train_stats, val_stats):
            history[f'train_{metric}'].append(train_value)
            history[f'val_{metric}'].append(val_value)
        print(f"[{name}] Epoch {epoch}/{epochs} train_loss {train_loss:.4f} val_loss {val_loss:.4f} train_acc {train_acc:.4f} val_acc {val_acc:.4f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Store a plain Python float: the metric's concrete type depends on
            # compute_metrics_from_logits, and a NumPy scalar inside the
            # checkpoint can be rejected by torch.load's weights-only unpickler.
            torch.save({'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'val_acc': float(val_acc)}, best_path)
            print(f"[{name}] saved best checkpoint (val_acc={val_acc:.4f})")

    # load best for test
    if os.path.exists(best_path):
        ckpt = torch.load(best_path, map_location=device)
        model.load_state_dict(ckpt['state_dict'])
    test_loss, test_acc, test_prec, test_rec, test_f1, logits, targets, preds = evaluate(model, dataloaders['test'], criterion, device)

    # save artifacts
    hist_df = pd.DataFrame(history); hist_df.to_csv(os.path.join(outdir, f"history_{name}.csv"), index=False)

    # curves (best effort: a plotting failure is reported but never aborts the run)
    fig = None
    try:
        fig = plt.figure(figsize=(10,4))
        plt.subplot(1,2,1); plt.plot(history['epoch'], history['train_loss'], label='train'); plt.plot(history['epoch'], history['val_loss'], label='val'); plt.legend(); plt.title(f"{name} Loss")
        plt.subplot(1,2,2); plt.plot(history['epoch'], history['train_acc'], label='train'); plt.plot(history['epoch'], history['val_acc'], label='val'); plt.legend(); plt.title(f"{name} Acc")
        plt.tight_layout(); plt.savefig(os.path.join(outdir, f"{name}_curves.png"))
    except Exception as exc:
        print(f"[{name}] WARNING: could not save training curves: {exc}")
    finally:
        if fig is not None:
            plt.close(fig)  # always release the figure, even when saving failed

    # confusion matrix & report
    if isinstance(preds, np.ndarray) and preds.size>0:
        class_names = dataloaders['meta']['classes']
        # Pin the label set so the matrix and the report always have one
        # row/column per known class, even if the test split misses a class
        # (otherwise target_names would not match and the report would raise).
        label_ids = list(range(len(class_names)))
        y_true = targets.numpy()
        cm = confusion_matrix(y_true, preds, labels=label_ids)
        fig = None
        try:
            fig = plt.figure(figsize=(6,6)); plt.imshow(cm, interpolation='nearest'); plt.title(f'CM {name}'); plt.colorbar()
            ticks = np.arange(len(class_names)); plt.xticks(ticks, class_names, rotation=45); plt.yticks(ticks, class_names); plt.tight_layout()
            plt.savefig(os.path.join(outdir, f"{name}_confusion.png"))
        except Exception as exc:
            print(f"[{name}] WARNING: could not save confusion matrix plot: {exc}")
        finally:
            if fig is not None:
                plt.close(fig)
        rpt = classification_report(y_true, preds, labels=label_ids, target_names=class_names, zero_division=0)
    else:
        rpt = "No predictions to compute report."
    with open(os.path.join(outdir, f"{name}_report.txt"), 'w') as f:
        f.write(f"Test loss: {test_loss:.4f}\nTest acc: {test_acc:.4f}\nPrec: {test_prec:.4f}\nRec: {test_rec:.4f}\nF1: {test_f1:.4f}\n\n")
        f.write("Classification report:\n"); f.write(str(rpt))
    print(f"[{name}] TEST -> loss {test_loss:.4f} acc {test_acc:.4f} prec {test_prec:.4f} rec {test_rec:.4f} f1 {test_f1:.4f}")
    return {'name': name, 'history': history, 'test_loss': test_loss, 'test_acc': test_acc, 'test_prec': test_prec, 'test_rec': test_rec, 'test_f1': test_f1}

# -------------------- Main experiment entrypoint --------------------

if __name__ == '__main__':
    # Script entry point: build the train/val/test splits, then train and
    # evaluate every block variant with identical settings and write a summary.

    # USER CONFIG - change as needed
    DATA_ROOT = '/kaggle/input/brain-tumor-mri-dataset'   # expects Training/ and Testing/ subfolders with class subfolders
    TRAIN_DIR = os.path.join(DATA_ROOT, 'Training')
    TEST_DIR  = os.path.join(DATA_ROOT, 'Testing')
    OUTDIR = 'results_compare_fixed'
    ensure_dir(OUTDIR)

    IMG_SIZE = 160    # set to 160 for very fast debugging, 299/384 for real runs
    BATCH_SIZE = 64
    EPOCHS = 100
    LR = 2e-4
    WEIGHT_DECAY = 1e-4
    NUM_WORKERS = 4
    KAN_REG_WEIGHT = 0.0   # forwarded to run_experiment as kan_reg_weight
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    GLOBAL_SEED = 42
    set_global_seed(GLOBAL_SEED)

    # load dataframes
    tr_df = get_class_paths(TRAIN_DIR)
    ts_df = get_class_paths(TEST_DIR) if os.path.isdir(TEST_DIR) else pd.DataFrame(columns=['path','label'])
    if len(ts_df) == 0:
        # No separate test folder: carve the held-out part out of the training data.
        if len(tr_df) == 0:
            raise RuntimeError(f"No data found in {TRAIN_DIR} nor {TEST_DIR}. Please check your DATA_ROOT.")
        tr_df, ts_df = safe_train_test_split(tr_df, test_size=0.15, stratify=tr_df['label'], random_state=GLOBAL_SEED)
    if len(ts_df) == 0:
        tr_df, ts_df = safe_train_test_split(tr_df, test_size=0.15, stratify=tr_df['label'], random_state=GLOBAL_SEED)
    if len(tr_df) == 0:
        # Fail with a clear message instead of a sampler error further down.
        raise RuntimeError(f"No training data found in {TRAIN_DIR}. Please check your DATA_ROOT.")
    # The held-out frame is halved into validation and test parts.
    valid_df, test_df = safe_train_test_split(ts_df, train_size=0.5, stratify=ts_df['label'] if len(ts_df)>0 else None, random_state=GLOBAL_SEED)

    combined = pd.concat([tr_df, valid_df, test_df], ignore_index=True)
    classes = sorted(combined['label'].unique().tolist())
    if len(classes) == 0:
        raise RuntimeError("No classes detected in the data. Please ensure dataset structure is /<class>/*.jpg.")
    label2idx = {c:i for i,c in enumerate(classes)}
    num_classes = len(classes)

    print("Classes:", classes)
    print(f"Train samples: {len(tr_df)} | Val samples: {len(valid_df)} | Test samples: {len(test_df)}")

    # transforms (augmentation for training only; val/test just resize + normalize)
    train_tfms = T.Compose([
        T.RandomResizedCrop(IMG_SIZE, scale=(0.8,1.0)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(15),
        T.ColorJitter(brightness=0.1, contrast=0.1),
        T.ToTensor(),
        T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
    ])
    val_tfms = T.Compose([T.Resize((IMG_SIZE,IMG_SIZE)), T.ToTensor(), T.Normalize([0.5]*3,[0.5]*3)])

    # datasets & dataloaders
    train_dataset = MRIDataset(tr_df, label2idx, transforms=train_tfms)
    val_dataset   = MRIDataset(valid_df, label2idx, transforms=val_tfms)
    test_dataset  = MRIDataset(test_df, label2idx, transforms=val_tfms)

    set_global_seed(GLOBAL_SEED)
    # Pinned host memory only speeds up host-to-GPU copies, so it is requested
    # only when running on CUDA (avoids the DataLoader warning on CPU-only runs).
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=(DEVICE.type == 'cuda'))
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader   = DataLoader(val_dataset,   shuffle=False, **loader_kwargs)
    test_loader  = DataLoader(test_dataset,  shuffle=False, **loader_kwargs)

    dataloaders = {'train': train_loader, 'val': val_loader, 'test': test_loader, 'meta': {'num_classes': num_classes, 'classes': classes}}

    # model config
    depths = (2,2,6,2)           # smaller
    dims  = (48,96,192,384)      # much smaller channels
    dp = 0.1

    from functools import partial

    # (experiment name, block factory) pairs, run in this order
    variants = [
        ('convnextv2_se_kan', partial(ConvNeXtBlockSEKAN, kan_grid_size=3, kan_spline_order=2)),
        ('convnextv2_kan', partial(ConvNeXtBlockKAN, kan_grid_size=3, kan_spline_order=2)),
        ('convnextv2_baseline', ConvNeXtBlockBaseline),
        ('convnextv2_se', ConvNeXtBlockSE),
    ]

    results = []
    for name, block_cls in variants:
        print("\n" + "="*80)
        print("Running:", name)
        set_global_seed(GLOBAL_SEED)
        res = run_experiment(name=name, block_cls=block_cls, depths=depths, dims=dims, seed=GLOBAL_SEED, dataloaders=dataloaders, device=DEVICE, epochs=EPOCHS, lr=LR, weight_decay=WEIGHT_DECAY, drop_path_rate=dp, kan_reg_weight=KAN_REG_WEIGHT, outdir=OUTDIR)
        results.append(res)

    # summary: one row per variant with its test-split metrics
    summary_cols = ['test_loss', 'test_acc', 'test_prec', 'test_rec', 'test_f1']
    summary_df = pd.DataFrame([{'model': r['name'], **{col: r[col] for col in summary_cols}} for r in results])
    summary_df.to_csv(os.path.join(OUTDIR, 'summary_results.csv'), index=False)
    print("\nSummary results:")
    print(summary_df.to_string(index=False))
    print("All outputs saved to:", OUTDIR)

# End of script


Classes: ['glioma', 'meningioma', 'notumor', 'pituitary']
Train samples: 5712 | Val samples: 655 | Test samples: 656

Running: convnextv2_se_kan
[convnextv2_se_kan] params: 30,989,596




[convnextv2_se_kan] Epoch 1/100 train_loss 1.2349 val_loss 1.2080 train_acc 0.4149 val_acc 0.4305
[convnextv2_se_kan] saved best checkpoint (val_acc=0.4305)
[convnextv2_se_kan] Epoch 2/100 train_loss 0.9547 val_loss 0.9393 train_acc 0.6063 val_acc 0.6351
[convnextv2_se_kan] saved best checkpoint (val_acc=0.6351)
[convnextv2_se_kan] Epoch 3/100 train_loss 0.7745 val_loss 0.7600 train_acc 0.6861 val_acc 0.7008
[convnextv2_se_kan] saved best checkpoint (val_acc=0.7008)
[convnextv2_se_kan] Epoch 4/100 train_loss 0.6484 val_loss 0.7079 train_acc 0.7295 val_acc 0.7191
[convnextv2_se_kan] saved best checkpoint (val_acc=0.7191)
[convnextv2_se_kan] Epoch 5/100 train_loss 0.5624 val_loss 0.6318 train_acc 0.7780 val_acc 0.7557
[convnextv2_se_kan] saved best checkpoint (val_acc=0.7557)
[convnextv2_se_kan] Epoch 6/100 train_loss 0.4977 val_loss 0.6109 train_acc 0.8004 val_acc 0.7603
[convnextv2_se_kan] saved best checkpoint (val_acc=0.7603)
[convnextv2_se_kan] Epoch 7/100 train_loss 0.4908 val_loss



[convnextv2_kan] Epoch 1/100 train_loss 1.2348 val_loss 1.1351 train_acc 0.4125 val_acc 0.5069
[convnextv2_kan] saved best checkpoint (val_acc=0.5069)
[convnextv2_kan] Epoch 2/100 train_loss 1.0210 val_loss 0.8837 train_acc 0.5730 val_acc 0.6336
[convnextv2_kan] saved best checkpoint (val_acc=0.6336)
[convnextv2_kan] Epoch 3/100 train_loss 0.7927 val_loss 0.8092 train_acc 0.6921 val_acc 0.6550
[convnextv2_kan] saved best checkpoint (val_acc=0.6550)
[convnextv2_kan] Epoch 4/100 train_loss 0.6660 val_loss 0.8141 train_acc 0.7274 val_acc 0.6947
[convnextv2_kan] saved best checkpoint (val_acc=0.6947)
[convnextv2_kan] Epoch 5/100 train_loss 0.5969 val_loss 0.7139 train_acc 0.7542 val_acc 0.7130
[convnextv2_kan] saved best checkpoint (val_acc=0.7130)
[convnextv2_kan] Epoch 6/100 train_loss 0.5679 val_loss 0.6255 train_acc 0.7647 val_acc 0.7374
[convnextv2_kan] saved best checkpoint (val_acc=0.7374)
[convnextv2_kan] Epoch 7/100 train_loss 0.5062 val_loss 0.6105 train_acc 0.7948 val_acc 0.7771



[convnextv2_baseline] Epoch 1/100 train_loss 1.2846 val_loss 1.0384 train_acc 0.3683 val_acc 0.5756
[convnextv2_baseline] saved best checkpoint (val_acc=0.5756)
[convnextv2_baseline] Epoch 2/100 train_loss 0.8700 val_loss 0.8399 train_acc 0.6488 val_acc 0.6733
[convnextv2_baseline] saved best checkpoint (val_acc=0.6733)
[convnextv2_baseline] Epoch 3/100 train_loss 0.7101 val_loss 0.7390 train_acc 0.7099 val_acc 0.7206
[convnextv2_baseline] saved best checkpoint (val_acc=0.7206)
[convnextv2_baseline] Epoch 4/100 train_loss 0.5790 val_loss 0.7141 train_acc 0.7679 val_acc 0.7374
[convnextv2_baseline] saved best checkpoint (val_acc=0.7374)
[convnextv2_baseline] Epoch 5/100 train_loss 0.5566 val_loss 0.5889 train_acc 0.7759 val_acc 0.7832
[convnextv2_baseline] saved best checkpoint (val_acc=0.7832)
[convnextv2_baseline] Epoch 6/100 train_loss 0.4924 val_loss 0.6177 train_acc 0.8072 val_acc 0.7893
[convnextv2_baseline] saved best checkpoint (val_acc=0.7893)
[convnextv2_baseline] Epoch 7/100 



[convnextv2_se] Epoch 1/100 train_loss 1.2593 val_loss 1.1481 train_acc 0.3955 val_acc 0.5069
[convnextv2_se] saved best checkpoint (val_acc=0.5069)
[convnextv2_se] Epoch 2/100 train_loss 0.9208 val_loss 0.8074 train_acc 0.6322 val_acc 0.6977
[convnextv2_se] saved best checkpoint (val_acc=0.6977)
[convnextv2_se] Epoch 3/100 train_loss 0.6960 val_loss 0.6841 train_acc 0.7215 val_acc 0.7298
[convnextv2_se] saved best checkpoint (val_acc=0.7298)
[convnextv2_se] Epoch 4/100 train_loss 0.6063 val_loss 0.6367 train_acc 0.7712 val_acc 0.7176
[convnextv2_se] Epoch 5/100 train_loss 0.5427 val_loss 0.6103 train_acc 0.7873 val_acc 0.7573
[convnextv2_se] saved best checkpoint (val_acc=0.7573)
[convnextv2_se] Epoch 6/100 train_loss 0.4861 val_loss 0.4858 train_acc 0.8118 val_acc 0.8183
[convnextv2_se] saved best checkpoint (val_acc=0.8183)
[convnextv2_se] Epoch 7/100 train_loss 0.4476 val_loss 0.5307 train_acc 0.8207 val_acc 0.8214
[convnextv2_se] saved best checkpoint (val_acc=0.8214)
[convnextv2_

# Side-by-side concise comparison — block level and full backbone

Notation: **N** = batch, **C** = channels at that stage, **H,W** = spatial. Shapes shown as `(N,C,H,W)` (channels-first) or `(N,H,W,C)` when channels-last. `pw1` = Linear(C→4C), `pw2` = Linear(4C→C). `KAN1/KAN2` replace `pw1/pw2` when present. `SE` = squeeze-excite gating applied to block output channels.

---

## 1) Block-level flow (rows = step; columns = variant)

| Step | Baseline                                                                                                          | +SE  | +KAN                                                             | +SE + KAN           |
| ---: | ----------------------------------------------------------------------------------------------------------------- | ---- | ---------------------------------------------------------------- | ------------------- |
|    0 | Input `x_in` `(N,C,H,W)`                                                                                          | same | same                                                             | same                |
|    1 | `dwconv 7×7 (groups=C)` → `(N,C,H,W)`                                                                             | same | same                                                             | same                |
|    2 | `permute` → `(N,H,W,C)`                                                                                           | same | same                                                             | same                |
|    3 | `LayerNorm (channels_last)` → `(N,H,W,C)`                                                                         | same | same                                                             | same                |
|    4 | **Projection expand**: `pw1` → `(N,H,W,4C)`                                                                       | same | **KAN1([C→4C])** on flattened `(N*H*W,C)` → reshape `(N,H,W,4C)` | **KAN1**            |
|    5 | `GELU` → `(N,H,W,4C)`                                                                                             | same | same                                                             | same                |
|    6 | `GRN` → `(N,H,W,4C)`                                                                                              | same | same                                                             | same                |
|    7 | **Projection shrink**: `pw2` → `(N,H,W,C)`                                                                        | same | **KAN2([4C→C])** on flattened `(N*H*W,4C)` → reshape `(N,H,W,C)` | **KAN2**            |
|    8 | — (no SE)                                                                                                         | `SE`: `global avg (N,C)` → `FC→act→FC→sigmoid` → scale `(N,1,1,C)`; multiply with `(N,H,W,C)` | — (no SE)                                                        | **SE applied here** (same as +SE) |
|    9 | `permute` → `(N,C,H,W)`                                                                                           | same | same                                                             | same                |
|   10 | `residual + DropPath` → output `(N,C,H,W)`                                                                        | same | same                                                             | same                |

**Summary:** only differences are (a) whether expand/shrink are Linear vs KAN, and (b) whether SE gating is applied after the second proj and before residual.

---

## 2) Full-backbone flow — stage by stage (generic)

| Stage        | Operation                                                           | Output shape (generic) | Variant differences                                  |
| ------------ | ------------------------------------------------------------------- | ---------------------: | ---------------------------------------------------- |
| Input        | image                                                               |          `(N,3,H0,W0)` | —                                                    |
| Stem         | `Conv2d(3→C0, k=4, s=4)` + `LayerNorm(ch_first)`                    |     `(N,C0,H0/4,W0/4)` | —                                                    |
| Stage 0      | `depths[0]` × **block** at `C0`                                     |     `(N,C0,H0/4,W0/4)` | blocks use `pw` or `KAN` and optional SE per variant |
| Downsample 1 | `LayerNorm` + `Conv2d(C0→C1,k=2,s=2)`                               |     `(N,C1,H0/8,W0/8)` | —                                                    |
| Stage 1      | `depths[1]` × blocks at `C1`                                        |     `(N,C1,H0/8,W0/8)` | as above                                             |
| Downsample 2 | `LayerNorm` + `Conv2d(C1→C2,k=2,s=2)`                               |   `(N,C2,H0/16,W0/16)` | —                                                    |
| Stage 2      | `depths[2]` × blocks at `C2`                                        |   `(N,C2,H0/16,W0/16)` | as above                                             |
| Downsample 3 | `LayerNorm` + `Conv2d(C2→C3,k=2,s=2)`                               |   `(N,C3,H0/32,W0/32)` | —                                                    |
| Stage 3      | `depths[3]` × blocks at `C3`                                        |   `(N,C3,H0/32,W0/32)` | as above                                             |
| Head         | `global avg pool → (N,C3)` → `LayerNorm` → `Linear(C3→num_classes)` |      `(N,num_classes)` | —                                                    |

**Where MLP/KAN live:** inside *each* block at every stage at the expand/shrink positions. SE (if used) also inside *each* block after the shrink.

---

## 3) Example numeric trace (your config: `IMG=160`, `dims=(48,96,192,384)`, `depths=(2,2,6,2)`)

* Input `(N,3,160,160)`
* Stem → `(N,48,40,40)`
* Stage0 (2 blocks) operate at `(N,48,40,40)`; inside each block (channels-last) expand → `(N,40,40,192)`, shrink → `(N,40,40,48)`
* Downsample → `(N,96,20,20)`
* Stage1 (2 blocks) operate at `(N,96,20,20)`; expand → `4C=384` etc.
* Downsample → `(N,192,10,10)`
* Stage2 (6 blocks) operate at `(N,192,10,10)`; expand → `768`
* Downsample → `(N,384,5,5)`
* Stage3 (2 blocks) operate at `(N,384,5,5)`; expand → `1536`
* Pool → `(N,384)` → head → `(N,num_classes)`

---

## 4) Quick practical notes (1-line each)

* **SE placement**: correct — after `pw2`/`KAN2` and before residual. That’s standard practice.
* **Why +SE can underperform vs baseline/KAN**: SE applies **global** channel gating which can (a) redundantly suppress KAN’s fine-grained channel structure, (b) introduce optimization conflicts or over-regularization on small datasets. KAN tends to increase expressivity; adding SE may not always help and can sometimes hurt when both interact poorly.
* **KAN cost**: higher memory/compute (it flattens spatial dims and uses spline bases) — expect slower training and larger GPU use.
* **If you want best of both**: try lighter SE (larger reduction, weaker gating), or apply SE only in later stages; experiment.


# Quick defs

* **N** = batch size (number of images in batch).
* **C** = number of channels (feature dimension) at that stage — e.g. 48, 96, 192, 384 in your config.
* **H, W** = spatial height and width. A tensor shape is written `(N, C, H, W)` (channels-first) or `(N, H, W, C)` (channels-last).

---

# Whole-backbone overview (data path)

1. **Input**: `(N, 3, H0, W0)` — e.g. `(N,3,160,160)`.
2. **Stem**: `Conv2d(3 -> C0, kernel=4, stride=4)` → spatial downsample by 4: `(N, C0, H0/4, W0/4)`.
   With your numbers: `(N, 48, 40, 40)`.
3. **Stage 0**: `depths[0]` blocks at `(N, C0, 40,40)`.
4. **Downsample** (between stages): `LayerNorm(channels_first) + Conv2d(Ci -> C{i+1}, kernel=2, stride=2)` → halves spatial dims, raises channels.

   * after 1st downsample: `(N, 96, 20,20)`
   * after 2nd: `(N,192,10,10)`
   * after 3rd: `(N,384,5,5)`
5. **Stages 1..3**: each stage runs `depths[i]` blocks at corresponding `(N, C_i, H_i, W_i)`.
6. **Head**: global average pool over H,W → `(N, C_last)` → `LayerNorm` → `Linear(C_last -> num_classes)` → logits.

---

# Single block — step-by-step (baseline ConvNeXtV2 block)

Start: `x_in (N, C, H, W)`

1. **Depthwise conv 7×7, groups=C**
   `x = dwconv(x_in)` → `(N, C, H, W)` (spatial same, channels unchanged).
2. **Permute to channels-last**
   `x = x.permute(0,2,3,1)` → `(N, H, W, C)`.
3. **LayerNorm (channels_last)**: normalizes over the channel dimension `C` independently at each `(H, W)` position. Output `(N, H, W, C)`.
4. **pw1 — pointwise MLP**: `Linear(C → 4·C)` applied independently at each (H,W) location → `(N, H, W, 4C)`.

   * numeric example: if `C=48` then `4C=192`.
5. **Activation (GELU)** → same shape `(N,H,W,4C)`.
6. **GRN (Global Response Normalization)** — channel-wise modulation that rescales features using their L2 norms across spatial axes → `(N,H,W,4C)`.
7. **pw2 — pointwise MLP**: `Linear(4·C → C)` → `(N,H,W,C)`.
8. **(Optional) SE gating** — *if present* (see variant): compute `se = sigmoid(FC(act(FC(x_global))))` with `x_global = x.mean(dim=(1,2))` giving `(N,C)`. Scale `x *= se.unsqueeze(1).unsqueeze(1)` → `(N,H,W,C)`.
9. **Permute back to channels-first** → `(N, C, H, W)`.
10. **DropPath (stochastic depth) + residual add**: `out = x_in + DropPath(x)` → `(N, C, H, W)`.

**Where the MLPs are in baseline:** the two `Linear` layers in steps 4 and 7 are the block MLP (inverted: expand then shrink).

---

# KAN vs MLP (where they appear)

* **Baseline / SE variants**: use standard `pw1`/`pw2` (torch `nn.Linear`) at steps 4 and 7.
* **KAN variants**: replace `pw1` and `pw2` by `KAN` modules:

  * After LayerNorm, **flatten spatial**: `(N, H, W, C)` → `x_flat (N*H*W, C)`.
  * `KAN([C → 4C])` processes `x_flat` → `(N*H*W, 4C)` → reshape to `(N,H,W,4C)`.
  * Activation + GRN as before.
  * Flatten again and `KAN([4C → C])` → `(N,H,W,C)` → permute back → residual add.
* **Effectively**: where baseline uses simple linear projection per position, KAN uses a learned spline-based transform (data-adaptive nonlinear mapping) in place of those Linear layers.

---

# Example numeric trace (your config: input 160, dims=(48,96,192,384), depths=(2,2,6,2))

* Input: `(N, 3, 160, 160)`
* Stem conv (stride 4): `(N, 48, 40, 40)`
* Stage0 (2 blocks): operate at `(N,48,40,40)`; inside each block (channels-last) `pw1` expands → `(N,40,40,192)` then `pw2` compresses back to `(N,40,40,48)`.
* Downsample → `(N,96,20,20)`
* Stage1 (2 blocks): `(N,96,20,20)`; `pw1` → `4C=384`, `pw2` → `96`.
* Downsample → `(N,192,10,10)`
* Stage2 (6 blocks): `(N,192,10,10)`; `pw1` → `768`, `pw2` → `192`.
* Downsample → `(N,384,5,5)`
* Stage3 (2 blocks): `(N,384,5,5)`; `pw1` → `1536`, `pw2` → `384`.
* Global avg pool → `(N,384)` → head → `(N,num_classes)`.

---

# Variant differences (compact table)

| Variant   |                pw1/pw2 | SE location                                        | Key effect                                                                          |
| --------- | ---------------------: | -------------------------------------------------- | ----------------------------------------------------------------------------------- |
| Baseline  |               `Linear` | none                                               | standard ConvNeXtV2 inverted-MLP per position                                       |
| +SE       |               `Linear` | after `pw2` (block output) before permute/residual | channel gating (squeeze & excite) — rescales channels globally                      |
| +KAN      | `KAN` (both pw1 & pw2) | none                                               | replaces linear projections with adaptive spline-based transforms (more expressive) |
| +SE + KAN |             `KAN` + SE | SE after second KAN (same logical place)           | combines KAN expressivity with channel gating                                       |

---

# Short notes on behavior & why placements matter

* **C and per-block expansion (4·C)**: the inverted MLP expands channel dimension (`C → 4C`) so the block can mix channels nonlinearly per spatial position and then project back. KAN performs that same expand/contract but with spline-adaptive mappings.
* **SE placement**: standard practice is exactly where you used it—on the block *output* channels before adding the residual. It gates the block's produced channel features globally.
* **Why SE sometimes hurts with KAN**:

  * KAN learns *richer, data-conditioned* mappings. A strong global gating (SE) after KAN can **suppress** those nuanced outputs or create conflicting gradients, reducing the benefit.
  * SE is a global coarse gate — if KAN encodes fine discriminative features across channels, SE may mistakenly down-weight helpful channels (especially on small/imbalanced datasets).
* **Memory / compute**: KAN is heavier (it flattens spatial dims and runs more complex `KANLinear` ops). Expect more params and higher GPU usage than simple Linear.

---

# Final short guidance

* **Where MLPs are**: inside each block — `pw1` (C→4C) and `pw2` (4C→C).
* **Where KAN replaces MLPs**: exactly those two spots, applied to flattened `(N*H*W, features)`.
* **SE belongs after pw2/KAN2 and before residual** (your implementation is correct). Whether SE helps depends on data and interaction with KAN—on your MRI set KAN alone gave the best lift; SE added gating that was redundant or harmful.
