Fixed-gate ablation (to prove that “adaptive gating” matters)

Run the same training/eval but force:

g0=0.5, g1=0.5, g2=0.5 (constant)
Compare vs full model.
If full > fixed → your gains come from learned gating, not just having two paths.

In [1]:
import os
import sys
import time
import math
import random
import subprocess
from typing import List, Tuple

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    log_loss,
    roc_auc_score,
)


def pip_install(pkg: str):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


try:
    import timm
except Exception:
    pip_install("timm")
    import timm

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms

# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# -------------------------
# Configuration
# -------------------------
CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224 if torch.cuda.is_available() else 160,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
}

OUTDIR = "/content/outputs"
os.makedirs(OUTDIR, exist_ok=True)
CSV_PATH = os.path.join(OUTDIR, "fixed_gates_val_test_metrics.csv")

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm["path"] = dfm["path"].astype(str)
    dfm["label"] = dfm["label"].astype(str)
    dfm["source"] = dfm["source"].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]

    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff

        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start: start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        B, C, H, W = x.shape

        if self.denoise > 0:
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), 3, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = (x - mu) / sd
        x0 = x0.clamp(-self.tau, self.tau)

        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap)
        mag = lap.abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(mag, (pad, pad, pad, pad), mode="reflect"), k, 1)
        C_map = mag / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * C_map)

        if self.sharpen > 0:
            outs = []
            for c in range(C):
                x_c = x2[:, c: c + 1, :, :]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(outs, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    return (
        random.uniform(0.7, 1.4),
        random.uniform(0.15, 0.55),
        random.uniform(3.0, 9.0),
        random.uniform(1.8, 3.2),
        random.choice([3, 5, 7]),
        random.uniform(0.0, 0.25),
        random.uniform(0.0, 0.2),
    )


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids = []
    within_vars = []
    sizes = []

    for c in classes:
        mask = y == c
        e = emb[mask]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
    emb = backbone_frozen(x_n)
    if isinstance(emb, (list, tuple)):
        emb = emb[-1].mean(dim=(2, 3))
    sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) +
            0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool):
    try:
        bx, by, *_ = next(iter(dl_for_eval))
    except Exception:
        return None, []

    pop = []
    if elite_pool:
        pop.extend(elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)])
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            new_pop.append(mutate(crossover(p1, p2), p=0.75))
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    best_theta = scored[0][1]
    top = [th for _, th in scored[: CFG["ga_elites"]]]
    return best_theta, top


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            ) for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class PVTv2B2_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()

        self.backbone = timm.create_model(
            "pvt_v2_b2",
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3),
        )
        in_channels = self.backbone.feature_info.channels()
        out_dim = max(256, in_channels[-1] // 2)

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        # Constant gates only: g0 = g1 = g2 = 0.5
        x0 = 0.5 * x_raw_n + 0.5 * x_fel_n
        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)
        f_mid = 0.5 * f0 + 0.5 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        t_final = 0.5 * t_mid + 0.5 * t_views
        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)

    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor(
            [
                preproc_module.gamma,
                preproc_module.alpha,
                preproc_module.beta,
                preproc_module.tau,
                float(preproc_module.blur_k) / 7.0,
                preproc_module.sharpen,
                preproc_module.denoise,
            ],
            device=DEVICE,
            dtype=torch.float32,
        )
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    out = {}
    try:
        if num_classes == 2:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        else:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
    except Exception:
        out["auc_roc_macro_ovr"] = np.nan
    return out


@torch.no_grad()
def evaluate_full(model, loader, preproc_module):
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        return {
            "acc": np.nan,
            "precision_macro": np.nan,
            "recall_macro": np.nan,
            "f1_macro": np.nan,
            "precision_weighted": np.nan,
            "recall_weighted": np.nan,
            "f1_weighted": np.nan,
            "log_loss": np.nan,
            "auc_roc_macro_ovr": np.nan,
            "loss_ce": np.nan,
            "eval_time_s": float(time.time() - t0),
        }

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "loss_ce": float(np.mean(all_loss)),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    return met


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            optimizer.step()
        if scheduler is not None:
            scheduler.step()


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        gsd[name].copy_(acc.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader, max_candidates=10):
    if not pool:
        return None
    cand = pool[:max_candidates]
    best, best_acc = None, -1
    for th in cand:
        pre = theta_to_module(th).to(DEVICE)
        met = evaluate_full(model, val_loader, pre)
        if np.isfinite(met["acc"]) and met["acc"] > best_acc:
            best_acc = met["acc"]
            best = th
    return best


def weighted_aggregate(metric_triplets: List[Tuple[int, dict, int]]):
    if not metric_triplets:
        return {}
    total = sum(w for _, _, w in metric_triplets)
    keys = metric_triplets[0][1].keys()
    out = {}
    for k in keys:
        vals = [m[1].get(k, np.nan) for m in metric_triplets]
        wts = [m[2] for m in metric_triplets]
        out[k] = float(np.average(vals, weights=wts)) if total > 0 else np.nan
    return out


def prepare_data():
    ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
    ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

    req1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
    req2 = {"glioma", "meningioma", "notumor", "pituitary"}

    ds1_root = find_root_with_required_class_dirs(ds1_path, req1, prefer_raw=True)
    ds2_root = find_root_with_required_class_dirs(ds2_path, req2, prefer_raw=False)
    if ds1_root is None or ds2_root is None:
        raise RuntimeError("Dataset roots not found.")

    df1 = enforce_labels(build_df_from_root(ds1_root, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
    df2 = enforce_labels(build_df_from_root(ds2_root, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))

    train1, val1, test1 = split_dataset(df1)
    train2, val2, test2 = split_dataset(df2)

    n_per_ds = CFG["clients_per_dataset"]
    client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
    client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

    client_splits = []
    for k in range(n_per_ds):
        tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
        client_splits.append(("ds1", k, k, tr, tune, va))
    for k in range(n_per_ds):
        tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
        client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

    client_test_splits = []
    for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
        idxs = list(range(len(test_df)))
        random.shuffle(idxs)
        split = np.array_split(idxs, n_per_ds)
        for k in range(n_per_ds):
            client_test_splits.append((ds_name, k, base_gid + k, split[k].tolist()))

    client_loaders = []
    for ds_name, local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
        df_src = train1 if ds_name == "ds1" else train2
        source_id = 0 if ds_name == "ds1" else 1

        sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)
        tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
        tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
        val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:max(1, min(len(tr_idx), CFG["batch_size"]))], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
        client_loaders.append((tr_loader, tune_loader, val_loader))

    client_test_loaders = []
    for ds_name, local_id, gid, test_idx in client_test_splits:
        df_src = test1 if ds_name == "ds1" else test2
        source_id = 0 if ds_name == "ds1" else 1
        t_loader = make_loader(df_src, test_idx, CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
        client_test_loaders.append((ds_name, local_id, gid, t_loader))

    counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
    counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values

    return {
        "client_loaders": client_loaders,
        "client_test_loaders": client_test_loaders,
        "n_per_ds": n_per_ds,
        "counts": counts1 + counts2,
        "test1_len": len(test1),
        "test2_len": len(test2),
    }


def run_experiment(data_bundle, setting_name: str = "fixed_gates_0.5"):
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    client_loaders = data_bundle["client_loaders"]
    client_test_loaders = data_bundle["client_test_loaders"]
    n_per_ds = data_bundle["n_per_ds"]

    model = PVTv2B2_MultiScale(
        num_classes=NUM_CLASSES,
        pretrained=True,
        head_dropout=CFG["head_dropout"],
        cond_dim=CFG["cond_dim"],
        num_clients=CFG["clients_total"],
    ).to(DEVICE)

    set_trainable_for_round(model, rnd=1)

    backbone_frozen = model.backbone.eval()
    for p in backbone_frozen.parameters():
        p.requires_grad = False

    counts = data_bundle["counts"]
    w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
    w = w / max(1e-6, w.mean())
    class_w = torch.tensor(w, device=DEVICE)
    criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
    scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

    elite_pool_ds1, elite_pool_ds2 = [], []
    best_global_acc = -1.0
    best_model_state = None
    best_theta_ds1 = None
    best_theta_ds2 = None

    for rnd in range(1, CFG["rounds"] + 1):
        local_models = []
        local_weights = []
        local_val_rows = []

        for k in range(CFG["clients_total"]):
            tr_loader, tune_loader, val_loader = client_loaders[k]
            ds_name = "ds1" if k < n_per_ds else "ds2"
            elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

            if CFG["use_preprocessing"] and CFG["use_ga"]:
                best_theta, top_thetas = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
                elite_pool.extend(top_thetas)
                elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
                pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else nn.Identity().to(DEVICE)
            else:
                pre_k = nn.Identity().to(DEVICE)

            if ds_name == "ds1":
                elite_pool_ds1 = elite_pool
            else:
                elite_pool_ds2 = elite_pool

            local_model = PVTv2B2_MultiScale(
                num_classes=NUM_CLASSES,
                pretrained=False,
                head_dropout=CFG["head_dropout"],
                cond_dim=CFG["cond_dim"],
                num_clients=CFG["clients_total"],
            ).to(DEVICE)
            local_model.load_state_dict(model.state_dict(), strict=True)

            set_trainable_for_round(local_model, rnd=rnd)
            opt = make_optimizer(local_model)
            total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
            warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
            scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

            for _ in range(CFG["local_epochs"]):
                train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, model, scheduler, scaler)

            met_loc = evaluate_full(local_model, val_loader, pre_k)
            local_val_rows.append((k, met_loc, len(val_loader.dataset)))

            local_models.append(local_model)
            local_weights.append(len(tr_loader.dataset))

        wsum = sum(local_weights)
        weights = [w / wsum for w in local_weights]
        trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
        fedavg_update(model, local_models, weights, trainable_names)

        val_best = weighted_aggregate(local_val_rows)
        if np.isfinite(val_best["acc"]) and val_best["acc"] > best_global_acc:
            best_global_acc = float(val_best["acc"])
            best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        if CFG["use_preprocessing"] and elite_pool_ds1:
            best_theta_ds1 = pick_best_theta_from_pool(model, elite_pool_ds1, client_loaders[0][2])
        if CFG["use_preprocessing"] and elite_pool_ds2:
            best_theta_ds2 = pick_best_theta_from_pool(model, elite_pool_ds2, client_loaders[n_per_ds][2])

    if best_model_state is not None:
        model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

    pre_best_ds1 = theta_to_module(best_theta_ds1).to(DEVICE) if best_theta_ds1 is not None else nn.Identity().to(DEVICE)
    pre_best_ds2 = theta_to_module(best_theta_ds2).to(DEVICE) if best_theta_ds2 is not None else nn.Identity().to(DEVICE)

    val_metrics_clients = []
    for k in range(CFG["clients_total"]):
        _, _, val_loader = client_loaders[k]
        pre = pre_best_ds1 if k < n_per_ds else pre_best_ds2
        met = evaluate_full(model, val_loader, pre)
        val_metrics_clients.append((k, met, len(val_loader.dataset)))
    val_best = weighted_aggregate(val_metrics_clients)

    def eval_test_per_dataset(ds_name):
        mets = []
        for ds, _, _, t_loader in client_test_loaders:
            if ds != ds_name:
                continue
            pre = pre_best_ds1 if ds == "ds1" else pre_best_ds2
            met = evaluate_full(model, t_loader, pre)
            mets.append((0, met, len(t_loader.dataset)))
        return weighted_aggregate(mets)

    test_ds1 = eval_test_per_dataset("ds1")
    test_ds2 = eval_test_per_dataset("ds2")
    global_test = weighted_aggregate([
        (0, test_ds1, data_bundle["test1_len"]),
        (1, test_ds2, data_bundle["test2_len"]),
    ])

    metric_cols = [
        "acc", "precision_macro", "recall_macro", "f1_macro", "precision_weighted",
        "recall_weighted", "f1_weighted", "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s"
    ]

    def row(split, dataset, m):
        r = {"setting": setting_name, "split": split, "dataset": dataset}
        for c in metric_cols:
            r[c] = float(m.get(c, np.nan)) if isinstance(m, dict) else np.nan
        return r

    return [
        row("VAL", "ds1+ds2 weighted", val_best),
        row("TEST", "ds1", test_ds1),
        row("TEST", "ds2", test_ds2),
        row("TEST", "global weighted", global_test),
    ]


def main():
    data_bundle = prepare_data()

    rows = []
    rows.extend(run_experiment(data_bundle, setting_name="fixed_gates_0.5"))

    out_df = pd.DataFrame(rows, columns=[
        "setting", "split", "dataset", "acc", "precision_macro", "recall_macro", "f1_macro",
        "precision_weighted", "recall_weighted", "f1_weighted", "log_loss", "auc_roc_macro_ovr",
        "loss_ce", "eval_time_s"
    ])

    out_df.to_csv(CSV_PATH, index=False)
    print(out_df)
    print(f"Saved: {CSV_PATH}")


if __name__ == "__main__":
    main()


  self.setter(val)


Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/101M [00:00<?, ?B/s]



           setting split           dataset       acc  precision_macro  \
0  fixed_gates_0.5   VAL  ds1+ds2 weighted  0.932070         0.778172   
1  fixed_gates_0.5  TEST               ds1  0.938053         0.941976   
2  fixed_gates_0.5  TEST               ds2  0.937441         0.936130   
3  fixed_gates_0.5  TEST   global weighted  0.937549         0.937162   

   recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  \
0      0.940230  0.789071            0.972431         0.932070     0.947113   
1      0.934708  0.934669            0.942502         0.938053     0.936798   
2      0.933919  0.933560            0.939898         0.937441     0.937307   
3      0.934058  0.933756            0.940358         0.937549     0.937217   

   log_loss  auc_roc_macro_ovr   loss_ce  eval_time_s  
0  0.238275                NaN  0.238878     2.869883  
1  0.221698           0.990906  0.219100     2.738322  
2  0.223606           0.989408  0.224826     7.678782  
3  0.223269   