Backbone family ladder (mid-size tier)

ResNet-50

DenseNet-121

EfficientNet-B0 (or B3)

ConvNeXt-T

Swin-T

ViT-B/16 (or DeiT-B)

# **DeiT-B**

In [None]:
import os, time, math, random, sys, subprocess, hashlib
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, confusion_matrix, roc_auc_score,
)

# -------------------------
# Install deps (Colab)
# -------------------------
def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


def _safe_import_torch():
    """
    Robust torch import for notebook runtimes.
    Fixes common 'partially initialized module torch has no attribute nn'
    by clearing stale torch modules from sys.modules before import.
    """
    for k in list(sys.modules.keys()):
        if k == "torch" or k.startswith("torch."):
            sys.modules.pop(k, None)

    import torch  # noqa: F401
    import torch.nn as nn  # noqa: F401
    import torch.nn.functional as F  # noqa: F401
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler  # noqa: F401
    return torch, nn, F, Dataset, DataLoader, WeightedRandomSampler


try:
    torch, nn, F, Dataset, DataLoader, WeightedRandomSampler = _safe_import_torch()
except Exception as e:
    raise RuntimeError(
        "Torch import failed. In notebook runtimes, restart the runtime/kernel and run this script again. "
        "Also ensure there is no local file/folder named 'torch'. Original error: " + str(e)
    )

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

try:
    from torchvision import transforms
except Exception:
    pip_install("torchvision")
    from torchvision import transforms

try:
    import timm
except Exception:
    pip_install("timm")
    import timm

# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
    "quick_hash_subset_per_split": 300,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s: return "glioma"
    if "meningioma" in s: return "meningioma"
    if "pituitary" in s: return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s: return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl: sc += 7
            if os.path.basename(p).lower() == "raw": sc += 7
            if "/raw/" in pl or "\\raw\\" in pl: sc += 3
            if "augmented" in pl: sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm["path"] = dfm["path"].astype(str)
    dfm["label"] = dfm["label"].astype(str)
    dfm["source"] = dfm["source"].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(lambda x: os.path.basename(x))
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff

        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start: start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        B, C, H, W = x.shape

        if self.denoise > 0:
            k = 3
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), k, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap)
        mag = lap.abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(mag, (pad, pad, pad, pad), mode="reflect"), k, 1)
        C_map = mag / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * C_map)

        if self.sharpen > 0:
            outs = []
            for c in range(C):
                x_c = x2[:, c: c + 1, :, :]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(outs, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    gamma = random.uniform(0.7, 1.4)
    alpha = random.uniform(0.15, 0.55)
    beta = random.uniform(3.0, 9.0)
    tau = random.uniform(1.8, 3.2)
    blur_k = random.choice([3, 5, 7])
    sharpen = random.uniform(0.0, 0.25)
    denoise = random.uniform(0.0, 0.2)
    return (gamma, alpha, beta, tau, blur_k, sharpen, denoise)


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


def theta_str(th):
    if th is None:
        return "None"
    g, a, b, t, k, sh, dn = th
    return f"(γ={g:.2f}, α={a:.2f}, β={b:.1f}, τ={t:.1f}, k={k}, sh={sh:.2f}, dn={dn:.2f})"


IDENTITY_PRE = nn.Identity().to(DEVICE)


class DeiTBBackbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.deit = timm.create_model("deit_base_patch16_224", pretrained=pretrained)

    def forward(self, x):
        b, c, h, w = x.shape
        x = self.deit.patch_embed(x)

        cls = self.deit.cls_token.expand(b, -1, -1)
        if getattr(self.deit, "dist_token", None) is not None:
            dist = self.deit.dist_token.expand(b, -1, -1)
            x = torch.cat((cls, dist, x), dim=1)
            token_offset = 2
        else:
            x = torch.cat((cls, x), dim=1)
            token_offset = 1

        x = x + self.deit.pos_embed
        x = self.deit.pos_drop(x)

        gh = h // self.deit.patch_embed.patch_size[0]
        gw = w // self.deit.patch_embed.patch_size[1]

        feats = []
        for i, blk in enumerate(self.deit.blocks):
            x = blk(x)
            if i in {2, 5, 8, 11}:
                tok = x[:, token_offset:, :]
                f = tok.transpose(1, 2).reshape(b, tok.shape[-1], gh, gw)
                feats.append(f)
        return feats


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class DeiTB_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = DeiTBBackbone(pretrained=pretrained)
        in_channels = [768, 768, 768, 768]
        out_dim = 512

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids = []
    within_vars = []
    sizes = []
    for c in classes:
        mask = y == c
        e = emb[mask]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y, use_separability=True):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

    sep = 0.0
    if use_separability:
        emb = backbone_frozen(x_n)
        if isinstance(emb, (list, tuple)):
            emb = emb[-1]
            emb = emb.mean(dim=(2, 3))
        sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    if use_separability:
        return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost
    return 0.60 * contrast + 0.35 * dyn_range - 0.5 * cost


def _safe_first_batch(dl):
    try:
        bx, by, *_ = next(iter(dl))
        return bx, by
    except Exception:
        return None, None


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool, use_separability=True):
    bx, by = _safe_first_batch(dl_for_eval)
    if bx is None:
        return None, [], 0.0

    pop = []
    if elite_pool:
        pop.extend(elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)])
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            child = mutate(crossover(p1, p2), p=0.75)
            new_pop.append(child)
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    best_theta = scored[0][1]
    best_fit = float(scored[0][0])
    top = [th for _, th in scored[: CFG["ga_elites"]]]
    return best_theta, top, best_fit


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    out = {}
    try:
        if num_classes == 2:
            out["auc_roc"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        else:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
            for c in range(num_classes):
                yc = (y_true == c).astype(int)
                if yc.sum() > 0 and yc.sum() < len(yc):
                    out[f"auc_class_{c}"] = float(roc_auc_score(yc, p_pred[:, c]))
    except Exception:
        pass
    return out


@torch.no_grad()
def evaluate_full(model, loader, preproc_module, return_gates=False):
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        # Signature-compatible with the original pipeline API.
        # Current DeiT model does not return internal gate tensors.
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        met = {k: np.nan for k in [
            "loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted", "log_loss", "eval_time_s"
        ]}
        return met, np.array([]), np.array([])

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    return met, y_true, p_pred


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None, grad_clip=1.0):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    new_sd = {}
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        new_sd[name] = acc
    for name, t in new_sd.items():
        gsd[name].copy_(t.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader):
    if not pool:
        return None
    best, best_acc = None, -1
    for th in pool[:10]:
        pre = theta_to_module(th).to(DEVICE)
        met, _, _ = evaluate_full(model, val_loader, pre)
        if np.isfinite(met["acc"]) and met["acc"] > best_acc:
            best_acc = met["acc"]
            best = th
    return best


def weighted_aggregate(mets):
    if not mets:
        return {}
    total = sum(w for _, w in mets)
    keys = mets[0][0].keys()
    out = {}
    for k in keys:
        vals = [m[0].get(k, np.nan) for m in mets]
        ws = [m[1] for m in mets]
        out[k] = float(np.average(vals, weights=ws))
    return out


# ========== DATA ==========
print("Downloading datasets...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not detect dataset roots.")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))

train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)

n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_loaders = []
for ds_name, local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
    df_src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        source_id = 0 if ds_name == "ds1" else 1
        t_loader = make_loader(test_df, split[k].tolist(), CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=base_gid + k)
        client_test_loaders.append((ds_name, base_gid + k, t_loader))

# ========== MODEL ==========
global_model = DeiTB_MultiScale(
    num_classes=NUM_CLASSES,
    pretrained=True,
    head_dropout=CFG["head_dropout"],
    cond_dim=CFG["cond_dim"],
    num_clients=CFG["clients_total"],
).to(DEVICE)

set_trainable_for_round(global_model, rnd=1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts = counts1 + counts2
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
class_w = torch.tensor(w, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

# ========== TRAIN (silent per-round) ==========
elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc = -1.0
best_model_state = None
best_theta_ds1 = None
best_theta_ds2 = None

for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []

    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool, use_separability=True)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else IDENTITY_PRE
        else:
            pre_k = IDENTITY_PRE

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = DeiTB_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=False,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
        ).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)

        set_trainable_for_round(local_model, rnd=rnd)
        opt = make_optimizer(local_model)

        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=scheduler, scaler=scaler, grad_clip=CFG["grad_clip"])

        met_loc, _, _ = evaluate_full(local_model, val_loader, pre_k)
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))
        local_rows.append((met_loc, len(val_loader.dataset)))

    wsum = sum(local_weights)
    weights = [w / wsum for w in local_weights]
    trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
    fedavg_update(global_model, local_models, weights, trainable_names)

    if CFG["use_preprocessing"] and elite_pool_ds1:
        best_theta_ds1 = pick_best_theta_from_pool(global_model, elite_pool_ds1, client_loaders[0][2])
    if CFG["use_preprocessing"] and elite_pool_ds2:
        best_theta_ds2 = pick_best_theta_from_pool(global_model, elite_pool_ds2, client_loaders[n_per_ds][2])

    global_metrics = weighted_aggregate(local_rows)
    if np.isfinite(global_metrics.get("acc", np.nan)) and global_metrics["acc"] > best_global_acc:
        best_global_acc = float(global_metrics["acc"])
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# ========== FINAL METRICS ONLY ==========
pre_best_ds1 = theta_to_module(best_theta_ds1).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds1 is not None) else IDENTITY_PRE
pre_best_ds2 = theta_to_module(best_theta_ds2).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds2 is not None) else IDENTITY_PRE

val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = pre_best_ds1 if k < n_per_ds else pre_best_ds2
    met, _, _ = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = pre_best_ds1 if ds == "ds1" else pre_best_ds2
        met, _, _ = evaluate_full(global_model, t_loader, pre)
        mets.append((met, len(t_loader.dataset)))
    return weighted_aggregate(mets)


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([(test_ds1, len(test1)), (test_ds2, len(test2))])

columns = [
    "setting", "split", "dataset",
    "acc", "precision_macro", "recall_macro", "f1_macro",
    "precision_weighted", "recall_weighted", "f1_weighted",
    "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
]

rows = [
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **val_best},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **test_ds1},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **test_ds2},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **global_test},
]

final_df = pd.DataFrame(rows)
for c in columns:
    if c not in final_df.columns:
        final_df[c] = np.nan
final_df = final_df[columns]

print("\nFinal output metrics:")
print(final_df.to_string(index=False))

  self.setter(val)


Downloading datasets...
Downloading from https://www.kaggle.com/api/v1/datasets/download/yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection?dataset_version_number=1...


100%|██████████| 130M/130M [00:06<00:00, 20.3MB/s]

Extracting files...





Downloading from https://www.kaggle.com/api/v1/datasets/download/orvile/pmram-bangladeshi-brain-cancer-mri-dataset?dataset_version_number=2...


100%|██████████| 161M/161M [00:08<00:00, 19.6MB/s]

Extracting files...



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]




Final output metrics:
                        setting split          dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Enhanced FELCM (Best θ ds1/ds2)   VAL ds1+ds2 weighted 0.930490         0.755383      0.941059  0.775142            0.957063         0.930490     0.939449  0.259749                NaN 0.251492     5.450670
    Enhanced FELCM (Best θ ds1)  TEST              ds1 0.867257         0.874717      0.866663  0.861529            0.878006         0.867257     0.863929  0.416327           0.978740 0.410909     2.527019
    Enhanced FELCM (Best θ ds2)  TEST              ds2 0.921327         0.917508      0.918532  0.915784            0.921755         0.921327     0.919280  0.270578           0.986869 0.266299     9.511256
        Enhanced FELCM (Best θ)  TEST  global weighted 0.911788         0.909959      0.909381  0.906212            0.914036         0.911788     0.90951

## **DenseNet-121**

In [None]:
import os, time, math, random, sys, subprocess
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score,
)

# -------------------------
# Install deps (Colab)
# -------------------------
def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


def _safe_import_torch():
    """
    Robust torch import for notebook runtimes.
    Fixes common 'partially initialized module torch has no attribute nn'
    by clearing stale torch modules from sys.modules before import.
    """
    for k in list(sys.modules.keys()):
        if k == "torch" or k.startswith("torch."):
            sys.modules.pop(k, None)

    import torch  # noqa: F401
    import torch.nn as nn  # noqa: F401
    import torch.nn.functional as F  # noqa: F401
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler  # noqa: F401
    return torch, nn, F, Dataset, DataLoader, WeightedRandomSampler


try:
    torch, nn, F, Dataset, DataLoader, WeightedRandomSampler = _safe_import_torch()
except Exception as e:
    raise RuntimeError(
        "Torch import failed. In notebook runtimes, restart the runtime/kernel and run this script again. "
        "Also ensure there is no local file/folder named 'torch'. Original error: " + str(e)
    )

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

try:
    from torchvision import transforms, models
except Exception:
    pip_install("torchvision")
    from torchvision import transforms, models

# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
    "quick_hash_subset_per_split": 300,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm["path"] = dfm["path"].astype(str)
    dfm["label"] = dfm["label"].astype(str)
    dfm["source"] = dfm["source"].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(lambda x: os.path.basename(x))
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff

        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start: start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        B, C, H, W = x.shape

        if self.denoise > 0:
            k = 3
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), k, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap)
        mag = lap.abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(mag, (pad, pad, pad, pad), mode="reflect"), k, 1)
        C_map = mag / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * C_map)

        if self.sharpen > 0:
            outs = []
            for c in range(C):
                x_c = x2[:, c: c + 1, :, :]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(outs, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    gamma = random.uniform(0.7, 1.4)
    alpha = random.uniform(0.15, 0.55)
    beta = random.uniform(3.0, 9.0)
    tau = random.uniform(1.8, 3.2)
    blur_k = random.choice([3, 5, 7])
    sharpen = random.uniform(0.0, 0.25)
    denoise = random.uniform(0.0, 0.2)
    return (gamma, alpha, beta, tau, blur_k, sharpen, denoise)


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


def theta_str(th):
    if th is None:
        return "None"
    g, a, b, t, k, sh, dn = th
    return f"(γ={g:.2f}, α={a:.2f}, β={b:.1f}, τ={t:.1f}, k={k}, sh={sh:.2f}, dn={dn:.2f})"


IDENTITY_PRE = nn.Identity().to(DEVICE)


class DenseNet121Backbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        dnet = models.densenet121(weights=weights)
        self.features = dnet.features

    def forward(self, x):
        feats = []
        x = self.features.conv0(x)
        x = self.features.norm0(x)
        x = self.features.relu0(x)
        x = self.features.pool0(x)

        x = self.features.denseblock1(x)
        feats.append(x)
        x = self.features.transition1(x)

        x = self.features.denseblock2(x)
        feats.append(x)
        x = self.features.transition2(x)

        x = self.features.denseblock3(x)
        feats.append(x)
        x = self.features.transition3(x)

        x = self.features.denseblock4(x)
        x = self.features.norm5(x)
        feats.append(x)
        return feats


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class DenseNet121_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = DenseNet121Backbone(pretrained=pretrained)
        in_channels = [256, 512, 1024, 1024]
        out_dim = 512

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids = []
    within_vars = []
    sizes = []
    for c in classes:
        mask = y == c
        e = emb[mask]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y, use_separability=True):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

    sep = 0.0
    if use_separability:
        emb = backbone_frozen(x_n)
        if isinstance(emb, (list, tuple)):
            emb = emb[-1]
            emb = emb.mean(dim=(2, 3))
        sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    if use_separability:
        return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost
    return 0.60 * contrast + 0.35 * dyn_range - 0.5 * cost


def _safe_first_batch(dl):
    try:
        bx, by, *_ = next(iter(dl))
        return bx, by
    except Exception:
        return None, None


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool, use_separability=True):
    bx, by = _safe_first_batch(dl_for_eval)
    if bx is None:
        return None, [], 0.0

    pop = []
    if elite_pool:
        pop.extend(elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)])
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            child = mutate(crossover(p1, p2), p=0.75)
            new_pop.append(child)
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    best_theta = scored[0][1]
    best_fit = float(scored[0][0])
    top = [th for _, th in scored[: CFG["ga_elites"]]]
    return best_theta, top, best_fit


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    out = {}
    try:
        if num_classes == 2:
            out["auc_roc"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        else:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
            for c in range(num_classes):
                yc = (y_true == c).astype(int)
                if yc.sum() > 0 and yc.sum() < len(yc):
                    out[f"auc_class_{c}"] = float(roc_auc_score(yc, p_pred[:, c]))
    except Exception:
        pass
    return out


@torch.no_grad()
def evaluate_full(model, loader, preproc_module, return_gates=False):
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        met = {k: np.nan for k in [
            "loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted", "log_loss", "eval_time_s"
        ]}
        return met, np.array([]), np.array([])

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    return met, y_true, p_pred


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None, grad_clip=1.0):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    new_sd = {}
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        new_sd[name] = acc
    for name, t in new_sd.items():
        gsd[name].copy_(t.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader):
    if not pool:
        return None
    best, best_acc = None, -1
    for th in pool[:10]:
        pre = theta_to_module(th).to(DEVICE)
        met, _, _ = evaluate_full(model, val_loader, pre)
        if np.isfinite(met["acc"]) and met["acc"] > best_acc:
            best_acc = met["acc"]
            best = th
    return best


def weighted_aggregate(mets):
    if not mets:
        return {}
    keys = mets[0][0].keys()
    out = {}
    for k in keys:
        vals = [m[0].get(k, np.nan) for m in mets]
        ws = [m[1] for m in mets]
        out[k] = float(np.average(vals, weights=ws))
    return out


# ========== DATA ==========
print("Downloading datasets...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not detect dataset roots.")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))

train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)

n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_loaders = []
for ds_name, local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
    df_src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        source_id = 0 if ds_name == "ds1" else 1
        t_loader = make_loader(test_df, split[k].tolist(), CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=base_gid + k)
        client_test_loaders.append((ds_name, base_gid + k, t_loader))

# ========== MODEL ==========
global_model = DenseNet121_MultiScale(
    num_classes=NUM_CLASSES,
    pretrained=True,
    head_dropout=CFG["head_dropout"],
    cond_dim=CFG["cond_dim"],
    num_clients=CFG["clients_total"],
).to(DEVICE)

set_trainable_for_round(global_model, rnd=1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts = counts1 + counts2
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
class_w = torch.tensor(w, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

# ========== TRAIN (silent per-round) ==========
elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc = -1.0
best_model_state = None
best_theta_ds1 = None
best_theta_ds2 = None

for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []

    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool, use_separability=True)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else IDENTITY_PRE
        else:
            pre_k = IDENTITY_PRE

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = DenseNet121_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=False,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
        ).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)

        set_trainable_for_round(local_model, rnd=rnd)
        opt = make_optimizer(local_model)

        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=scheduler, scaler=scaler, grad_clip=CFG["grad_clip"])

        met_loc, _, _ = evaluate_full(local_model, val_loader, pre_k)
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))
        local_rows.append((met_loc, len(val_loader.dataset)))

    wsum = sum(local_weights)
    weights = [w / wsum for w in local_weights]
    trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
    fedavg_update(global_model, local_models, weights, trainable_names)

    if CFG["use_preprocessing"] and elite_pool_ds1:
        best_theta_ds1 = pick_best_theta_from_pool(global_model, elite_pool_ds1, client_loaders[0][2])
    if CFG["use_preprocessing"] and elite_pool_ds2:
        best_theta_ds2 = pick_best_theta_from_pool(global_model, elite_pool_ds2, client_loaders[n_per_ds][2])

    global_metrics = weighted_aggregate(local_rows)
    if np.isfinite(global_metrics.get("acc", np.nan)) and global_metrics["acc"] > best_global_acc:
        best_global_acc = float(global_metrics["acc"])
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# ========== FINAL METRICS ONLY ==========
pre_best_ds1 = theta_to_module(best_theta_ds1).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds1 is not None) else IDENTITY_PRE
pre_best_ds2 = theta_to_module(best_theta_ds2).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds2 is not None) else IDENTITY_PRE

val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = pre_best_ds1 if k < n_per_ds else pre_best_ds2
    met, _, _ = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = pre_best_ds1 if ds == "ds1" else pre_best_ds2
        met, _, _ = evaluate_full(global_model, t_loader, pre)
        mets.append((met, len(t_loader.dataset)))
    return weighted_aggregate(mets)


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([(test_ds1, len(test1)), (test_ds2, len(test2))])

columns = [
    "setting", "split", "dataset",
    "acc", "precision_macro", "recall_macro", "f1_macro",
    "precision_weighted", "recall_weighted", "f1_weighted",
    "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
]

rows = [
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **val_best},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **test_ds1},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **test_ds2},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **global_test},
]

final_df = pd.DataFrame(rows)
for c in columns:
    if c not in final_df.columns:
        final_df[c] = np.nan
final_df = final_df[columns]

print("\nFinal output metrics:")
print(final_df.to_string(index=False))

  self.setter(val)


Downloading datasets...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.
Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


100%|██████████| 30.8M/30.8M [00:00<00:00, 192MB/s]



Final output metrics:
                        setting split          dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Enhanced FELCM (Best θ ds1/ds2)   VAL ds1+ds2 weighted 0.622433         0.716489      0.672254  0.508683            0.898396         0.622433     0.643419  1.256511                NaN 1.271830     3.113227
    Enhanced FELCM (Best θ ds1)  TEST              ds1 0.584071         0.835465      0.580947  0.552589            0.838480         0.584071     0.560216  1.256331           0.896351 1.243431     2.325316
    Enhanced FELCM (Best θ ds2)  TEST              ds2 0.590521         0.820432      0.584873  0.538577            0.835699         0.590521     0.553285  1.318358           0.879787 1.320917     6.863452
        Enhanced FELCM (Best θ)  TEST  global weighted 0.589383         0.823084      0.584180  0.541049            0.836189         0.589383     0.55450

# **ResNet-50**

In [None]:
import os, time, math, random, sys, subprocess
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score,
)

# -------------------------
# Install deps (Colab)
# -------------------------
def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


def _safe_import_torch():
    """
    Robust torch import for notebook runtimes.
    Fixes common 'partially initialized module torch has no attribute nn'
    by clearing stale torch modules from sys.modules before import.
    """
    for k in list(sys.modules.keys()):
        if k == "torch" or k.startswith("torch."):
            sys.modules.pop(k, None)

    import torch  # noqa: F401
    import torch.nn as nn  # noqa: F401
    import torch.nn.functional as F  # noqa: F401
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler  # noqa: F401
    return torch, nn, F, Dataset, DataLoader, WeightedRandomSampler


try:
    torch, nn, F, Dataset, DataLoader, WeightedRandomSampler = _safe_import_torch()
except Exception as e:
    raise RuntimeError(
        "Torch import failed. In notebook runtimes, restart the runtime/kernel and run this script again. "
        "Also ensure there is no local file/folder named 'torch'. Original error: " + str(e)
    )

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

try:
    from torchvision import transforms, models
except Exception:
    pip_install("torchvision")
    from torchvision import transforms, models

# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
    "quick_hash_subset_per_split": 300,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm["path"] = dfm["path"].astype(str)
    dfm["label"] = dfm["label"].astype(str)
    dfm["source"] = dfm["source"].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(lambda x: os.path.basename(x))
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff

        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start: start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        B, C, H, W = x.shape

        if self.denoise > 0:
            k = 3
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), k, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap)
        mag = lap.abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(mag, (pad, pad, pad, pad), mode="reflect"), k, 1)
        C_map = mag / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * C_map)

        if self.sharpen > 0:
            outs = []
            for c in range(C):
                x_c = x2[:, c: c + 1, :, :]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(outs, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    gamma = random.uniform(0.7, 1.4)
    alpha = random.uniform(0.15, 0.55)
    beta = random.uniform(3.0, 9.0)
    tau = random.uniform(1.8, 3.2)
    blur_k = random.choice([3, 5, 7])
    sharpen = random.uniform(0.0, 0.25)
    denoise = random.uniform(0.0, 0.2)
    return (gamma, alpha, beta, tau, blur_k, sharpen, denoise)


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


def theta_str(th):
    if th is None:
        return "None"
    g, a, b, t, k, sh, dn = th
    return f"(γ={g:.2f}, α={a:.2f}, β={b:.1f}, τ={t:.1f}, k={k}, sh={sh:.2f}, dn={dn:.2f})"


IDENTITY_PRE = nn.Identity().to(DEVICE)


class ResNet50Backbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        net = models.resnet50(weights=weights)
        self.conv1 = net.conv1
        self.bn1 = net.bn1
        self.relu = net.relu
        self.maxpool = net.maxpool
        self.layer1 = net.layer1
        self.layer2 = net.layer2
        self.layer3 = net.layer3
        self.layer4 = net.layer4

    def forward(self, x):
        feats = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        feats.append(x)
        x = self.layer2(x)
        feats.append(x)
        x = self.layer3(x)
        feats.append(x)
        x = self.layer4(x)
        feats.append(x)
        return feats


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class ResNet50_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = ResNet50Backbone(pretrained=pretrained)
        in_channels = [256, 512, 1024, 2048]
        out_dim = 512

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids = []
    within_vars = []
    sizes = []
    for c in classes:
        mask = y == c
        e = emb[mask]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y, use_separability=True):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

    sep = 0.0
    if use_separability:
        emb = backbone_frozen(x_n)
        if isinstance(emb, (list, tuple)):
            emb = emb[-1]
            emb = emb.mean(dim=(2, 3))
        sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    if use_separability:
        return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost
    return 0.60 * contrast + 0.35 * dyn_range - 0.5 * cost


def _safe_first_batch(dl):
    try:
        bx, by, *_ = next(iter(dl))
        return bx, by
    except Exception:
        return None, None


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool, use_separability=True):
    bx, by = _safe_first_batch(dl_for_eval)
    if bx is None:
        return None, [], 0.0

    pop = []
    if elite_pool:
        pop.extend(elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)])
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            child = mutate(crossover(p1, p2), p=0.75)
            new_pop.append(child)
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    best_theta = scored[0][1]
    best_fit = float(scored[0][0])
    top = [th for _, th in scored[: CFG["ga_elites"]]]
    return best_theta, top, best_fit


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    out = {}
    try:
        if num_classes == 2:
            out["auc_roc"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        else:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
            for c in range(num_classes):
                yc = (y_true == c).astype(int)
                if yc.sum() > 0 and yc.sum() < len(yc):
                    out[f"auc_class_{c}"] = float(roc_auc_score(yc, p_pred[:, c]))
    except Exception:
        pass
    return out


@torch.no_grad()
def evaluate_full(model, loader, preproc_module, return_gates=False):
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        met = {k: np.nan for k in [
            "loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted", "log_loss", "eval_time_s"
        ]}
        return met, np.array([]), np.array([])

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    return met, y_true, p_pred


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None, grad_clip=1.0):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    new_sd = {}
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        new_sd[name] = acc
    for name, t in new_sd.items():
        gsd[name].copy_(t.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader):
    if not pool:
        return None
    best, best_acc = None, -1
    for th in pool[:10]:
        pre = theta_to_module(th).to(DEVICE)
        met, _, _ = evaluate_full(model, val_loader, pre)
        if np.isfinite(met["acc"]) and met["acc"] > best_acc:
            best_acc = met["acc"]
            best = th
    return best


def weighted_aggregate(mets):
    if not mets:
        return {}
    keys = mets[0][0].keys()
    out = {}
    for k in keys:
        vals = [m[0].get(k, np.nan) for m in mets]
        ws = [m[1] for m in mets]
        out[k] = float(np.average(vals, weights=ws))
    return out


# ========== DATA ==========
print("Downloading datasets...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not detect dataset roots.")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))

train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)

n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_loaders = []
for ds_name, local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
    df_src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        source_id = 0 if ds_name == "ds1" else 1
        t_loader = make_loader(test_df, split[k].tolist(), CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=base_gid + k)
        client_test_loaders.append((ds_name, base_gid + k, t_loader))

# ========== MODEL ==========
global_model = ResNet50_MultiScale(
    num_classes=NUM_CLASSES,
    pretrained=True,
    head_dropout=CFG["head_dropout"],
    cond_dim=CFG["cond_dim"],
    num_clients=CFG["clients_total"],
).to(DEVICE)

set_trainable_for_round(global_model, rnd=1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts = counts1 + counts2
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
class_w = torch.tensor(w, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

# ========== TRAIN (silent per-round) ==========
elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc = -1.0
best_model_state = None
best_theta_ds1 = None
best_theta_ds2 = None

for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []

    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool, use_separability=True)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else IDENTITY_PRE
        else:
            pre_k = IDENTITY_PRE

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = ResNet50_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=False,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
        ).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)

        set_trainable_for_round(local_model, rnd=rnd)
        opt = make_optimizer(local_model)

        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=scheduler, scaler=scaler, grad_clip=CFG["grad_clip"])

        met_loc, _, _ = evaluate_full(local_model, val_loader, pre_k)
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))
        local_rows.append((met_loc, len(val_loader.dataset)))

    wsum = sum(local_weights)
    weights = [w / wsum for w in local_weights]
    trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
    fedavg_update(global_model, local_models, weights, trainable_names)

    if CFG["use_preprocessing"] and elite_pool_ds1:
        best_theta_ds1 = pick_best_theta_from_pool(global_model, elite_pool_ds1, client_loaders[0][2])
    if CFG["use_preprocessing"] and elite_pool_ds2:
        best_theta_ds2 = pick_best_theta_from_pool(global_model, elite_pool_ds2, client_loaders[n_per_ds][2])

    global_metrics = weighted_aggregate(local_rows)
    if np.isfinite(global_metrics.get("acc", np.nan)) and global_metrics["acc"] > best_global_acc:
        best_global_acc = float(global_metrics["acc"])
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# ========== FINAL METRICS ONLY ==========
pre_best_ds1 = theta_to_module(best_theta_ds1).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds1 is not None) else IDENTITY_PRE
pre_best_ds2 = theta_to_module(best_theta_ds2).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds2 is not None) else IDENTITY_PRE

val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = pre_best_ds1 if k < n_per_ds else pre_best_ds2
    met, _, _ = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = pre_best_ds1 if ds == "ds1" else pre_best_ds2
        met, _, _ = evaluate_full(global_model, t_loader, pre)
        mets.append((met, len(t_loader.dataset)))
    return weighted_aggregate(mets)


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([(test_ds1, len(test1)), (test_ds2, len(test2))])

columns = [
    "setting", "split", "dataset",
    "acc", "precision_macro", "recall_macro", "f1_macro",
    "precision_weighted", "recall_weighted", "f1_weighted",
    "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
]

rows = [
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **val_best},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **test_ds1},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **test_ds2},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **global_test},
]

final_df = pd.DataFrame(rows)
for c in columns:
    if c not in final_df.columns:
        final_df[c] = np.nan
final_df = final_df[columns]

print("\nFinal output metrics:")
print(final_df.to_string(index=False))

  self.setter(val)


Downloading datasets...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Downloading from https://www.kaggle.com/api/v1/datasets/download/orvile/pmram-bangladeshi-brain-cancer-mri-dataset?dataset_version_number=2...


100%|██████████| 161M/161M [00:09<00:00, 18.4MB/s]

Extracting files...





Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 148MB/s]



Final output metrics:
                        setting split          dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Enhanced FELCM (Best θ ds1/ds2)   VAL ds1+ds2 weighted 0.739336         0.701299      0.705339  0.612545            0.909203         0.739336     0.775984  0.796141                NaN 0.788748     2.829313
    Enhanced FELCM (Best θ ds1)  TEST              ds1 0.672566         0.785990      0.669170  0.646711            0.788236         0.672566     0.652030  1.046664           0.900620 1.051847     2.016313
    Enhanced FELCM (Best θ ds2)  TEST              ds2 0.734597         0.825398      0.725817  0.712547            0.831364         0.734597     0.723580  0.842489           0.934326 0.843561     6.288494
        Enhanced FELCM (Best θ)  TEST  global weighted 0.723653         0.818446      0.715823  0.700932            0.823755         0.723653     0.71095

# **EfficientNet-B0**

In [None]:
import os, time, math, random, sys, subprocess
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score,
)

# -------------------------
# Install deps (Colab)
# -------------------------
def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


def _safe_import_torch():
    """
    Robust torch import for notebook runtimes.
    Fixes common 'partially initialized module torch has no attribute nn'
    by clearing stale torch modules from sys.modules before import.
    """
    for k in list(sys.modules.keys()):
        if k == "torch" or k.startswith("torch."):
            sys.modules.pop(k, None)

    import torch  # noqa: F401
    import torch.nn as nn  # noqa: F401
    import torch.nn.functional as F  # noqa: F401
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler  # noqa: F401
    return torch, nn, F, Dataset, DataLoader, WeightedRandomSampler


try:
    torch, nn, F, Dataset, DataLoader, WeightedRandomSampler = _safe_import_torch()
except Exception as e:
    raise RuntimeError(
        "Torch import failed. In notebook runtimes, restart the runtime/kernel and run this script again. "
        "Also ensure there is no local file/folder named 'torch'. Original error: " + str(e)
    )

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

try:
    from torchvision import transforms, models
except Exception:
    pip_install("torchvision")
    from torchvision import transforms, models

# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
    "quick_hash_subset_per_split": 300,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm["path"] = dfm["path"].astype(str)
    dfm["label"] = dfm["label"].astype(str)
    dfm["source"] = dfm["source"].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(lambda x: os.path.basename(x))
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff

        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start: start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        B, C, H, W = x.shape

        if self.denoise > 0:
            k = 3
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), k, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap)
        mag = lap.abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(mag, (pad, pad, pad, pad), mode="reflect"), k, 1)
        C_map = mag / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * C_map)

        if self.sharpen > 0:
            outs = []
            for c in range(C):
                x_c = x2[:, c: c + 1, :, :]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(outs, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    gamma = random.uniform(0.7, 1.4)
    alpha = random.uniform(0.15, 0.55)
    beta = random.uniform(3.0, 9.0)
    tau = random.uniform(1.8, 3.2)
    blur_k = random.choice([3, 5, 7])
    sharpen = random.uniform(0.0, 0.25)
    denoise = random.uniform(0.0, 0.2)
    return (gamma, alpha, beta, tau, blur_k, sharpen, denoise)


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


def theta_str(th):
    if th is None:
        return "None"
    g, a, b, t, k, sh, dn = th
    return f"(γ={g:.2f}, α={a:.2f}, β={b:.1f}, τ={t:.1f}, k={k}, sh={sh:.2f}, dn={dn:.2f})"


IDENTITY_PRE = nn.Identity().to(DEVICE)


class EfficientNetB0Backbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        net = models.efficientnet_b0(weights=weights)
        self.features = net.features

    def forward(self, x):
        feats = []
        for i, block in enumerate(self.features):
            x = block(x)
            if i in {2, 3, 5, 8}:
                feats.append(x)
        return feats


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class EfficientNetB0_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = EfficientNetB0Backbone(pretrained=pretrained)
        in_channels = [24, 40, 112, 1280]
        out_dim = 512

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids = []
    within_vars = []
    sizes = []
    for c in classes:
        mask = y == c
        e = emb[mask]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y, use_separability=True):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

    sep = 0.0
    if use_separability:
        emb = backbone_frozen(x_n)
        if isinstance(emb, (list, tuple)):
            emb = emb[-1]
            emb = emb.mean(dim=(2, 3))
        sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    if use_separability:
        return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost
    return 0.60 * contrast + 0.35 * dyn_range - 0.5 * cost


def _safe_first_batch(dl):
    try:
        bx, by, *_ = next(iter(dl))
        return bx, by
    except Exception:
        return None, None


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool, use_separability=True):
    bx, by = _safe_first_batch(dl_for_eval)
    if bx is None:
        return None, [], 0.0

    pop = []
    if elite_pool:
        pop.extend(elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)])
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            child = mutate(crossover(p1, p2), p=0.75)
            new_pop.append(child)
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    best_theta = scored[0][1]
    best_fit = float(scored[0][0])
    top = [th for _, th in scored[: CFG["ga_elites"]]]
    return best_theta, top, best_fit


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    out = {}
    try:
        if num_classes == 2:
            out["auc_roc"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        else:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
            for c in range(num_classes):
                yc = (y_true == c).astype(int)
                if yc.sum() > 0 and yc.sum() < len(yc):
                    out[f"auc_class_{c}"] = float(roc_auc_score(yc, p_pred[:, c]))
    except Exception:
        pass
    return out


@torch.no_grad()
def evaluate_full(model, loader, preproc_module, return_gates=False):
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        met = {k: np.nan for k in [
            "loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted", "log_loss", "eval_time_s"
        ]}
        return met, np.array([]), np.array([])

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    return met, y_true, p_pred


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None, grad_clip=1.0):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    new_sd = {}
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        new_sd[name] = acc
    for name, t in new_sd.items():
        gsd[name].copy_(t.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader):
    if not pool:
        return None
    best, best_acc = None, -1
    for th in pool[:10]:
        pre = theta_to_module(th).to(DEVICE)
        met, _, _ = evaluate_full(model, val_loader, pre)
        if np.isfinite(met["acc"]) and met["acc"] > best_acc:
            best_acc = met["acc"]
            best = th
    return best


def weighted_aggregate(mets):
    if not mets:
        return {}
    keys = mets[0][0].keys()
    out = {}
    for k in keys:
        vals = [m[0].get(k, np.nan) for m in mets]
        ws = [m[1] for m in mets]
        out[k] = float(np.average(vals, weights=ws))
    return out


# ========== DATA ==========
print("Downloading datasets...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not detect dataset roots.")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))

train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)

n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_loaders = []
for ds_name, local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
    df_src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        source_id = 0 if ds_name == "ds1" else 1
        t_loader = make_loader(test_df, split[k].tolist(), CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=base_gid + k)
        client_test_loaders.append((ds_name, base_gid + k, t_loader))

# ========== MODEL ==========
global_model = EfficientNetB0_MultiScale(
    num_classes=NUM_CLASSES,
    pretrained=True,
    head_dropout=CFG["head_dropout"],
    cond_dim=CFG["cond_dim"],
    num_clients=CFG["clients_total"],
).to(DEVICE)

set_trainable_for_round(global_model, rnd=1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts = counts1 + counts2
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
class_w = torch.tensor(w, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

# ========== TRAIN (silent per-round) ==========
elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc = -1.0
best_model_state = None
best_theta_ds1 = None
best_theta_ds2 = None

for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []

    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool, use_separability=True)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else IDENTITY_PRE
        else:
            pre_k = IDENTITY_PRE

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = EfficientNetB0_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=False,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
        ).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)

        set_trainable_for_round(local_model, rnd=rnd)
        opt = make_optimizer(local_model)

        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=scheduler, scaler=scaler, grad_clip=CFG["grad_clip"])

        met_loc, _, _ = evaluate_full(local_model, val_loader, pre_k)
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))
        local_rows.append((met_loc, len(val_loader.dataset)))

    wsum = sum(local_weights)
    weights = [w / wsum for w in local_weights]
    trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
    fedavg_update(global_model, local_models, weights, trainable_names)

    if CFG["use_preprocessing"] and elite_pool_ds1:
        best_theta_ds1 = pick_best_theta_from_pool(global_model, elite_pool_ds1, client_loaders[0][2])
    if CFG["use_preprocessing"] and elite_pool_ds2:
        best_theta_ds2 = pick_best_theta_from_pool(global_model, elite_pool_ds2, client_loaders[n_per_ds][2])

    global_metrics = weighted_aggregate(local_rows)
    if np.isfinite(global_metrics.get("acc", np.nan)) and global_metrics["acc"] > best_global_acc:
        best_global_acc = float(global_metrics["acc"])
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# ========== FINAL METRICS ONLY ==========
pre_best_ds1 = theta_to_module(best_theta_ds1).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds1 is not None) else IDENTITY_PRE
pre_best_ds2 = theta_to_module(best_theta_ds2).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds2 is not None) else IDENTITY_PRE

val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = pre_best_ds1 if k < n_per_ds else pre_best_ds2
    met, _, _ = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = pre_best_ds1 if ds == "ds1" else pre_best_ds2
        met, _, _ = evaluate_full(global_model, t_loader, pre)
        mets.append((met, len(t_loader.dataset)))
    return weighted_aggregate(mets)


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([(test_ds1, len(test1)), (test_ds2, len(test2))])

columns = [
    "setting", "split", "dataset",
    "acc", "precision_macro", "recall_macro", "f1_macro",
    "precision_weighted", "recall_weighted", "f1_weighted",
    "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
]

rows = [
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **val_best},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **test_ds1},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **test_ds2},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **global_test},
]

final_df = pd.DataFrame(rows)
for c in columns:
    if c not in final_df.columns:
        final_df[c] = np.nan
final_df = final_df[columns]

print("\nFinal output metrics:")
print(final_df.to_string(index=False))

  self.setter(val)


Downloading datasets...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.
Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:00<00:00, 136MB/s]



Final output metrics:
                        setting split          dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Enhanced FELCM (Best θ ds1/ds2)   VAL ds1+ds2 weighted 0.725118         0.687829      0.780957  0.624834            0.876695         0.725118     0.749311  0.778805                NaN 0.771438     1.958183
    Enhanced FELCM (Best θ ds1)  TEST              ds1 0.703540         0.817952      0.697915  0.685557            0.825652         0.703540     0.694738  0.905015           0.933505 0.894178     1.528389
    Enhanced FELCM (Best θ ds2)  TEST              ds2 0.745024         0.819638      0.736473  0.720858            0.825738         0.745024     0.731872  0.777580           0.944821 0.773610     4.190761
        Enhanced FELCM (Best θ)  TEST  global weighted 0.737705         0.819340      0.729670  0.714630            0.825723         0.737705     0.72532

# **EfficientNet-B3**

In [None]:
import os, time, math, random, sys, subprocess
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score,
)

# -------------------------
# Install deps (Colab)
# -------------------------
def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


def _safe_import_torch():
    """
    Robust torch import for notebook runtimes.
    Fixes common 'partially initialized module torch has no attribute nn'
    by clearing stale torch modules from sys.modules before import.
    """
    for k in list(sys.modules.keys()):
        if k == "torch" or k.startswith("torch."):
            sys.modules.pop(k, None)

    import torch  # noqa: F401
    import torch.nn as nn  # noqa: F401
    import torch.nn.functional as F  # noqa: F401
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler  # noqa: F401
    return torch, nn, F, Dataset, DataLoader, WeightedRandomSampler


try:
    torch, nn, F, Dataset, DataLoader, WeightedRandomSampler = _safe_import_torch()
except Exception as e:
    raise RuntimeError(
        "Torch import failed. In notebook runtimes, restart the runtime/kernel and run this script again. "
        "Also ensure there is no local file/folder named 'torch'. Original error: " + str(e)
    )

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

try:
    from torchvision import transforms, models
except Exception:
    pip_install("torchvision")
    from torchvision import transforms, models

# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
    "quick_hash_subset_per_split": 300,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm["path"] = dfm["path"].astype(str)
    dfm["label"] = dfm["label"].astype(str)
    dfm["source"] = dfm["source"].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(lambda x: os.path.basename(x))
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff

        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start: start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        B, C, H, W = x.shape

        if self.denoise > 0:
            k = 3
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), k, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap)
        mag = lap.abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(mag, (pad, pad, pad, pad), mode="reflect"), k, 1)
        C_map = mag / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * C_map)

        if self.sharpen > 0:
            outs = []
            for c in range(C):
                x_c = x2[:, c: c + 1, :, :]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(outs, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    gamma = random.uniform(0.7, 1.4)
    alpha = random.uniform(0.15, 0.55)
    beta = random.uniform(3.0, 9.0)
    tau = random.uniform(1.8, 3.2)
    blur_k = random.choice([3, 5, 7])
    sharpen = random.uniform(0.0, 0.25)
    denoise = random.uniform(0.0, 0.2)
    return (gamma, alpha, beta, tau, blur_k, sharpen, denoise)


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


def theta_str(th):
    if th is None:
        return "None"
    g, a, b, t, k, sh, dn = th
    return f"(γ={g:.2f}, α={a:.2f}, β={b:.1f}, τ={t:.1f}, k={k}, sh={sh:.2f}, dn={dn:.2f})"


IDENTITY_PRE = nn.Identity().to(DEVICE)


class EfficientNetB3Backbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.EfficientNet_B3_Weights.IMAGENET1K_V1 if pretrained else None
        net = models.efficientnet_b3(weights=weights)
        self.features = net.features

    def forward(self, x):
        feats = []
        for i, block in enumerate(self.features):
            x = block(x)
            if i in {2, 3, 5, 8}:
                feats.append(x)
        return feats


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class EfficientNetB3_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = EfficientNetB3Backbone(pretrained=pretrained)
        in_channels = [32, 48, 136, 1536]
        out_dim = 512

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids = []
    within_vars = []
    sizes = []
    for c in classes:
        mask = y == c
        e = emb[mask]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y, use_separability=True):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

    sep = 0.0
    if use_separability:
        emb = backbone_frozen(x_n)
        if isinstance(emb, (list, tuple)):
            emb = emb[-1]
            emb = emb.mean(dim=(2, 3))
        sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    if use_separability:
        return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost
    return 0.60 * contrast + 0.35 * dyn_range - 0.5 * cost


def _safe_first_batch(dl):
    try:
        bx, by, *_ = next(iter(dl))
        return bx, by
    except Exception:
        return None, None


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool, use_separability=True):
    bx, by = _safe_first_batch(dl_for_eval)
    if bx is None:
        return None, [], 0.0

    pop = []
    if elite_pool:
        pop.extend(elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)])
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            child = mutate(crossover(p1, p2), p=0.75)
            new_pop.append(child)
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    best_theta = scored[0][1]
    best_fit = float(scored[0][0])
    top = [th for _, th in scored[: CFG["ga_elites"]]]
    return best_theta, top, best_fit


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    out = {}
    try:
        if num_classes == 2:
            out["auc_roc"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        else:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
            for c in range(num_classes):
                yc = (y_true == c).astype(int)
                if yc.sum() > 0 and yc.sum() < len(yc):
                    out[f"auc_class_{c}"] = float(roc_auc_score(yc, p_pred[:, c]))
    except Exception:
        pass
    return out


@torch.no_grad()
def evaluate_full(model, loader, preproc_module, return_gates=False):
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        met = {k: np.nan for k in [
            "loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted", "log_loss", "eval_time_s"
        ]}
        return met, np.array([]), np.array([])

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    return met, y_true, p_pred


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None, grad_clip=1.0):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    new_sd = {}
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        new_sd[name] = acc
    for name, t in new_sd.items():
        gsd[name].copy_(t.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader):
    if not pool:
        return None
    best, best_acc = None, -1
    for th in pool[:10]:
        pre = theta_to_module(th).to(DEVICE)
        met, _, _ = evaluate_full(model, val_loader, pre)
        if np.isfinite(met["acc"]) and met["acc"] > best_acc:
            best_acc = met["acc"]
            best = th
    return best


def weighted_aggregate(mets):
    if not mets:
        return {}
    keys = mets[0][0].keys()
    out = {}
    for k in keys:
        vals = [m[0].get(k, np.nan) for m in mets]
        ws = [m[1] for m in mets]
        out[k] = float(np.average(vals, weights=ws))
    return out


# ========== DATA ==========
print("Downloading datasets...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not detect dataset roots.")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))

train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)

n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_loaders = []
for ds_name, local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
    df_src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        source_id = 0 if ds_name == "ds1" else 1
        t_loader = make_loader(test_df, split[k].tolist(), CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=base_gid + k)
        client_test_loaders.append((ds_name, base_gid + k, t_loader))

# ========== MODEL ==========
global_model = EfficientNetB3_MultiScale(
    num_classes=NUM_CLASSES,
    pretrained=True,
    head_dropout=CFG["head_dropout"],
    cond_dim=CFG["cond_dim"],
    num_clients=CFG["clients_total"],
).to(DEVICE)

set_trainable_for_round(global_model, rnd=1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts = counts1 + counts2
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
class_w = torch.tensor(w, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

# ========== TRAIN (silent per-round) ==========
elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc = -1.0
best_model_state = None
best_theta_ds1 = None
best_theta_ds2 = None

for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []

    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool, use_separability=True)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else IDENTITY_PRE
        else:
            pre_k = IDENTITY_PRE

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = EfficientNetB3_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=False,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
        ).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)

        set_trainable_for_round(local_model, rnd=rnd)
        opt = make_optimizer(local_model)

        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=scheduler, scaler=scaler, grad_clip=CFG["grad_clip"])

        met_loc, _, _ = evaluate_full(local_model, val_loader, pre_k)
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))
        local_rows.append((met_loc, len(val_loader.dataset)))

    wsum = sum(local_weights)
    weights = [w / wsum for w in local_weights]
    trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
    fedavg_update(global_model, local_models, weights, trainable_names)

    if CFG["use_preprocessing"] and elite_pool_ds1:
        best_theta_ds1 = pick_best_theta_from_pool(global_model, elite_pool_ds1, client_loaders[0][2])
    if CFG["use_preprocessing"] and elite_pool_ds2:
        best_theta_ds2 = pick_best_theta_from_pool(global_model, elite_pool_ds2, client_loaders[n_per_ds][2])

    global_metrics = weighted_aggregate(local_rows)
    if np.isfinite(global_metrics.get("acc", np.nan)) and global_metrics["acc"] > best_global_acc:
        best_global_acc = float(global_metrics["acc"])
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# ========== FINAL METRICS ONLY ==========
pre_best_ds1 = theta_to_module(best_theta_ds1).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds1 is not None) else IDENTITY_PRE
pre_best_ds2 = theta_to_module(best_theta_ds2).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds2 is not None) else IDENTITY_PRE

val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = pre_best_ds1 if k < n_per_ds else pre_best_ds2
    met, _, _ = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = pre_best_ds1 if ds == "ds1" else pre_best_ds2
        met, _, _ = evaluate_full(global_model, t_loader, pre)
        mets.append((met, len(t_loader.dataset)))
    return weighted_aggregate(mets)


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([(test_ds1, len(test1)), (test_ds2, len(test2))])

columns = [
    "setting", "split", "dataset",
    "acc", "precision_macro", "recall_macro", "f1_macro",
    "precision_weighted", "recall_weighted", "f1_weighted",
    "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
]

rows = [
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **val_best},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **test_ds1},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **test_ds2},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **global_test},
]

final_df = pd.DataFrame(rows)
for c in columns:
    if c not in final_df.columns:
        final_df[c] = np.nan
final_df = final_df[columns]

print("\nFinal output metrics:")
print(final_df.to_string(index=False))

  self.setter(val)


Downloading datasets...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.
Downloading: "https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b3_rwightman-b3899882.pth


100%|██████████| 47.2M/47.2M [00:00<00:00, 74.7MB/s]



Final output metrics:
                        setting split          dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Enhanced FELCM (Best θ ds1/ds2)   VAL ds1+ds2 weighted 0.878357         0.720307      0.787387  0.704651            0.933898         0.878357     0.902432  0.384337                NaN 0.373188     2.422866
    Enhanced FELCM (Best θ ds1)  TEST              ds1 0.823009         0.840227      0.828367  0.822599            0.846353         0.823009     0.823649  0.510442           0.962610 0.506676     1.961017
    Enhanced FELCM (Best θ ds2)  TEST              ds2 0.890047         0.888522      0.887609  0.885975            0.896895         0.890047     0.891488  0.357523           0.975532 0.355066     5.279348
        Enhanced FELCM (Best θ)  TEST  global weighted 0.878220         0.880001      0.877157  0.874794            0.887978         0.878220     0.87952

# **ConvNeXt-T**

In [None]:
import os, time, math, random, sys, subprocess
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score,
)

# -------------------------
# Install deps (Colab)
# -------------------------
def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


def _safe_import_torch():
    """
    Robust torch import for notebook runtimes.
    Fixes common 'partially initialized module torch has no attribute nn'
    by clearing stale torch modules from sys.modules before import.
    """
    for k in list(sys.modules.keys()):
        if k == "torch" or k.startswith("torch."):
            sys.modules.pop(k, None)

    import torch  # noqa: F401
    import torch.nn as nn  # noqa: F401
    import torch.nn.functional as F  # noqa: F401
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler  # noqa: F401
    return torch, nn, F, Dataset, DataLoader, WeightedRandomSampler


try:
    torch, nn, F, Dataset, DataLoader, WeightedRandomSampler = _safe_import_torch()
except Exception as e:
    raise RuntimeError(
        "Torch import failed. In notebook runtimes, restart the runtime/kernel and run this script again. "
        "Also ensure there is no local file/folder named 'torch'. Original error: " + str(e)
    )

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

try:
    from torchvision import transforms, models
except Exception:
    pip_install("torchvision")
    from torchvision import transforms, models

# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
    "quick_hash_subset_per_split": 300,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm["path"] = dfm["path"].astype(str)
    dfm["label"] = dfm["label"].astype(str)
    dfm["source"] = dfm["source"].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(lambda x: os.path.basename(x))
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff

        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start: start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        B, C, H, W = x.shape

        if self.denoise > 0:
            k = 3
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), k, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap)
        mag = lap.abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(mag, (pad, pad, pad, pad), mode="reflect"), k, 1)
        C_map = mag / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * C_map)

        if self.sharpen > 0:
            outs = []
            for c in range(C):
                x_c = x2[:, c: c + 1, :, :]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(outs, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    gamma = random.uniform(0.7, 1.4)
    alpha = random.uniform(0.15, 0.55)
    beta = random.uniform(3.0, 9.0)
    tau = random.uniform(1.8, 3.2)
    blur_k = random.choice([3, 5, 7])
    sharpen = random.uniform(0.0, 0.25)
    denoise = random.uniform(0.0, 0.2)
    return (gamma, alpha, beta, tau, blur_k, sharpen, denoise)


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


def theta_str(th):
    if th is None:
        return "None"
    g, a, b, t, k, sh, dn = th
    return f"(γ={g:.2f}, α={a:.2f}, β={b:.1f}, τ={t:.1f}, k={k}, sh={sh:.2f}, dn={dn:.2f})"


IDENTITY_PRE = nn.Identity().to(DEVICE)


class ConvNeXtTinyBackbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1 if pretrained else None
        net = models.convnext_tiny(weights=weights)
        self.features = net.features

    def forward(self, x):
        feats = []
        # features layout (torchvision convnext_tiny):
        # 0: stem, 1: stage1, 2: downsample, 3: stage2, 4: downsample,
        # 5: stage3, 6: downsample, 7: stage4
        x = self.features[0](x)
        x = self.features[1](x)
        feats.append(x)

        x = self.features[2](x)
        x = self.features[3](x)
        feats.append(x)

        x = self.features[4](x)
        x = self.features[5](x)
        feats.append(x)

        x = self.features[6](x)
        x = self.features[7](x)
        feats.append(x)
        return feats


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class ConvNeXtTiny_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = ConvNeXtTinyBackbone(pretrained=pretrained)
        in_channels = [96, 192, 384, 768]
        out_dim = 512

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids = []
    within_vars = []
    sizes = []
    for c in classes:
        mask = y == c
        e = emb[mask]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y, use_separability=True):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

    sep = 0.0
    if use_separability:
        emb = backbone_frozen(x_n)
        if isinstance(emb, (list, tuple)):
            emb = emb[-1]
            emb = emb.mean(dim=(2, 3))
        sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    if use_separability:
        return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost
    return 0.60 * contrast + 0.35 * dyn_range - 0.5 * cost


def _safe_first_batch(dl):
    try:
        bx, by, *_ = next(iter(dl))
        return bx, by
    except Exception:
        return None, None


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool, use_separability=True):
    bx, by = _safe_first_batch(dl_for_eval)
    if bx is None:
        return None, [], 0.0

    pop = []
    if elite_pool:
        pop.extend(elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)])
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            child = mutate(crossover(p1, p2), p=0.75)
            new_pop.append(child)
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    best_theta = scored[0][1]
    best_fit = float(scored[0][0])
    top = [th for _, th in scored[: CFG["ga_elites"]]]
    return best_theta, top, best_fit


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    out = {}
    try:
        if num_classes == 2:
            out["auc_roc"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        else:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
            for c in range(num_classes):
                yc = (y_true == c).astype(int)
                if yc.sum() > 0 and yc.sum() < len(yc):
                    out[f"auc_class_{c}"] = float(roc_auc_score(yc, p_pred[:, c]))
    except Exception:
        pass
    return out


@torch.no_grad()
def evaluate_full(model, loader, preproc_module, return_gates=False):
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        met = {k: np.nan for k in [
            "loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted", "log_loss", "eval_time_s"
        ]}
        return met, np.array([]), np.array([])

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    return met, y_true, p_pred


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None, grad_clip=1.0):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    new_sd = {}
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        new_sd[name] = acc
    for name, t in new_sd.items():
        gsd[name].copy_(t.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader):
    if not pool:
        return None
    best, best_acc = None, -1
    for th in pool[:10]:
        pre = theta_to_module(th).to(DEVICE)
        met, _, _ = evaluate_full(model, val_loader, pre)
        if np.isfinite(met["acc"]) and met["acc"] > best_acc:
            best_acc = met["acc"]
            best = th
    return best


def weighted_aggregate(mets):
    if not mets:
        return {}
    keys = mets[0][0].keys()
    out = {}
    for k in keys:
        vals = [m[0].get(k, np.nan) for m in mets]
        ws = [m[1] for m in mets]
        out[k] = float(np.average(vals, weights=ws))
    return out


# ========== DATA ==========
print("Downloading datasets...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not detect dataset roots.")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))

train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)

n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_loaders = []
for ds_name, local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
    df_src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        source_id = 0 if ds_name == "ds1" else 1
        t_loader = make_loader(test_df, split[k].tolist(), CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=base_gid + k)
        client_test_loaders.append((ds_name, base_gid + k, t_loader))

# ========== MODEL ==========
global_model = ConvNeXtTiny_MultiScale(
    num_classes=NUM_CLASSES,
    pretrained=True,
    head_dropout=CFG["head_dropout"],
    cond_dim=CFG["cond_dim"],
    num_clients=CFG["clients_total"],
).to(DEVICE)

set_trainable_for_round(global_model, rnd=1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts = counts1 + counts2
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
class_w = torch.tensor(w, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

# ========== TRAIN (silent per-round) ==========
elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc = -1.0
best_model_state = None
best_theta_ds1 = None
best_theta_ds2 = None

for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []

    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool, use_separability=True)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else IDENTITY_PRE
        else:
            pre_k = IDENTITY_PRE

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = ConvNeXtTiny_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=False,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
        ).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)

        set_trainable_for_round(local_model, rnd=rnd)
        opt = make_optimizer(local_model)

        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=scheduler, scaler=scaler, grad_clip=CFG["grad_clip"])

        met_loc, _, _ = evaluate_full(local_model, val_loader, pre_k)
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))
        local_rows.append((met_loc, len(val_loader.dataset)))

    wsum = sum(local_weights)
    weights = [w / wsum for w in local_weights]
    trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
    fedavg_update(global_model, local_models, weights, trainable_names)

    if CFG["use_preprocessing"] and elite_pool_ds1:
        best_theta_ds1 = pick_best_theta_from_pool(global_model, elite_pool_ds1, client_loaders[0][2])
    if CFG["use_preprocessing"] and elite_pool_ds2:
        best_theta_ds2 = pick_best_theta_from_pool(global_model, elite_pool_ds2, client_loaders[n_per_ds][2])

    global_metrics = weighted_aggregate(local_rows)
    if np.isfinite(global_metrics.get("acc", np.nan)) and global_metrics["acc"] > best_global_acc:
        best_global_acc = float(global_metrics["acc"])
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# ========== FINAL METRICS ONLY ==========
pre_best_ds1 = theta_to_module(best_theta_ds1).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds1 is not None) else IDENTITY_PRE
pre_best_ds2 = theta_to_module(best_theta_ds2).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds2 is not None) else IDENTITY_PRE

val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = pre_best_ds1 if k < n_per_ds else pre_best_ds2
    met, _, _ = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = pre_best_ds1 if ds == "ds1" else pre_best_ds2
        met, _, _ = evaluate_full(global_model, t_loader, pre)
        mets.append((met, len(t_loader.dataset)))
    return weighted_aggregate(mets)


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([(test_ds1, len(test1)), (test_ds2, len(test2))])

columns = [
    "setting", "split", "dataset",
    "acc", "precision_macro", "recall_macro", "f1_macro",
    "precision_weighted", "recall_weighted", "f1_weighted",
    "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
]

rows = [
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **val_best},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **test_ds1},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **test_ds2},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **global_test},
]

final_df = pd.DataFrame(rows)
for c in columns:
    if c not in final_df.columns:
        final_df[c] = np.nan
final_df = final_df[columns]

print("\nFinal output metrics:")
print(final_df.to_string(index=False))

  self.setter(val)


Downloading datasets...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.
Downloading: "https://download.pytorch.org/models/convnext_tiny-983f1562.pth" to /root/.cache/torch/hub/checkpoints/convnext_tiny-983f1562.pth


100%|██████████| 109M/109M [00:00<00:00, 129MB/s]



Final output metrics:
                        setting split          dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Enhanced FELCM (Best θ ds1/ds2)   VAL ds1+ds2 weighted 0.944708         0.798204      0.968297  0.817921            0.965771         0.944708     0.952090  0.218650                NaN 0.216986     3.179074
    Enhanced FELCM (Best θ ds1)  TEST              ds1 0.933628         0.933729      0.931392  0.930280            0.935778         0.933628     0.932495  0.227647           0.986707 0.226494     2.807986
    Enhanced FELCM (Best θ ds2)  TEST              ds2 0.931754         0.930454      0.926140  0.926722            0.932187         0.931754     0.930538  0.252154           0.990907 0.254234     8.024479
        Enhanced FELCM (Best θ)  TEST  global weighted 0.932084         0.931032      0.927067  0.927349            0.932820         0.932084     0.93088

# **Swin-T**

In [None]:
import os, time, math, random, sys, subprocess
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score,
)

# -------------------------
# Install deps (Colab)
# -------------------------
def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


def _safe_import_torch():
    """
    Robust torch import for notebook runtimes.
    Fixes common 'partially initialized module torch has no attribute nn'
    by clearing stale torch modules from sys.modules before import.
    """
    for k in list(sys.modules.keys()):
        if k == "torch" or k.startswith("torch."):
            sys.modules.pop(k, None)

    import torch  # noqa: F401
    import torch.nn as nn  # noqa: F401
    import torch.nn.functional as F  # noqa: F401
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler  # noqa: F401
    return torch, nn, F, Dataset, DataLoader, WeightedRandomSampler


try:
    torch, nn, F, Dataset, DataLoader, WeightedRandomSampler = _safe_import_torch()
except Exception as e:
    raise RuntimeError(
        "Torch import failed. In notebook runtimes, restart the runtime/kernel and run this script again. "
        "Also ensure there is no local file/folder named 'torch'. Original error: " + str(e)
    )

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

try:
    from torchvision import transforms, models
except Exception:
    pip_install("torchvision")
    from torchvision import transforms, models

try:
    import timm
except Exception:
    pip_install("timm")
    import timm


# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
    "quick_hash_subset_per_split": 300,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm["path"] = dfm["path"].astype(str)
    dfm["label"] = dfm["label"].astype(str)
    dfm["source"] = dfm["source"].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(lambda x: os.path.basename(x))
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff

        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start: start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        B, C, H, W = x.shape

        if self.denoise > 0:
            k = 3
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), k, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap)
        mag = lap.abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(mag, (pad, pad, pad, pad), mode="reflect"), k, 1)
        C_map = mag / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * C_map)

        if self.sharpen > 0:
            outs = []
            for c in range(C):
                x_c = x2[:, c: c + 1, :, :]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(outs, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    gamma = random.uniform(0.7, 1.4)
    alpha = random.uniform(0.15, 0.55)
    beta = random.uniform(3.0, 9.0)
    tau = random.uniform(1.8, 3.2)
    blur_k = random.choice([3, 5, 7])
    sharpen = random.uniform(0.0, 0.25)
    denoise = random.uniform(0.0, 0.2)
    return (gamma, alpha, beta, tau, blur_k, sharpen, denoise)


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


def theta_str(th):
    if th is None:
        return "None"
    g, a, b, t, k, sh, dn = th
    return f"(γ={g:.2f}, α={a:.2f}, β={b:.1f}, τ={t:.1f}, k={k}, sh={sh:.2f}, dn={dn:.2f})"


IDENTITY_PRE = nn.Identity().to(DEVICE)


class SwinTinyBackbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.swin = timm.create_model(
            "swin_tiny_patch4_window7_224",
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3),
        )

    @staticmethod
    def _to_nchw(feat):
        # timm Swin features are commonly NHWC (B, H, W, C) while CNN fusers expect NCHW.
        if feat.ndim != 4:
            return feat
        # If already NCHW, keep as-is.
        if feat.shape[1] in {96, 192, 384, 768} and feat.shape[-1] not in {96, 192, 384, 768}:
            return feat
        # Convert NHWC -> NCHW.
        if feat.shape[-1] in {96, 192, 384, 768}:
            return feat.permute(0, 3, 1, 2).contiguous()
        # Fallback: pick the smallest spatial dims as H/W and move channel dim to position 1.
        b, d1, d2, d3 = feat.shape
        spatial = sorted([(1, d1), (2, d2), (3, d3)], key=lambda x: x[1])[:2]
        sp_idx = {i for i, _ in spatial}
        ch_idx = [i for i in (1, 2, 3) if i not in sp_idx][0]
        order = [0, ch_idx] + sorted(sp_idx)
        return feat.permute(*order).contiguous()

    def forward(self, x):
        feats = self.swin(x)
        return [self._to_nchw(f) for f in feats]


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class SwinTiny_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = SwinTinyBackbone(pretrained=pretrained)
        in_channels = [96, 192, 384, 768]
        out_dim = 512

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids = []
    within_vars = []
    sizes = []
    for c in classes:
        mask = y == c
        e = emb[mask]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y, use_separability=True):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

    sep = 0.0
    if use_separability:
        emb = backbone_frozen(x_n)
        if isinstance(emb, (list, tuple)):
            emb = emb[-1]
            emb = emb.mean(dim=(2, 3))
        sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    if use_separability:
        return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost
    return 0.60 * contrast + 0.35 * dyn_range - 0.5 * cost


def _safe_first_batch(dl):
    try:
        bx, by, *_ = next(iter(dl))
        return bx, by
    except Exception:
        return None, None


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool, use_separability=True):
    bx, by = _safe_first_batch(dl_for_eval)
    if bx is None:
        return None, [], 0.0

    pop = []
    if elite_pool:
        pop.extend(elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)])
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            child = mutate(crossover(p1, p2), p=0.75)
            new_pop.append(child)
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    best_theta = scored[0][1]
    best_fit = float(scored[0][0])
    top = [th for _, th in scored[: CFG["ga_elites"]]]
    return best_theta, top, best_fit


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    out = {}
    try:
        if num_classes == 2:
            out["auc_roc"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        else:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
            for c in range(num_classes):
                yc = (y_true == c).astype(int)
                if yc.sum() > 0 and yc.sum() < len(yc):
                    out[f"auc_class_{c}"] = float(roc_auc_score(yc, p_pred[:, c]))
    except Exception:
        pass
    return out


@torch.no_grad()
def evaluate_full(model, loader, preproc_module, return_gates=False):
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        met = {k: np.nan for k in [
            "loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted", "log_loss", "eval_time_s"
        ]}
        return met, np.array([]), np.array([])

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    return met, y_true, p_pred


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None, grad_clip=1.0):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    new_sd = {}
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        new_sd[name] = acc
    for name, t in new_sd.items():
        gsd[name].copy_(t.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader):
    if not pool:
        return None
    best, best_acc = None, -1
    for th in pool[:10]:
        pre = theta_to_module(th).to(DEVICE)
        met, _, _ = evaluate_full(model, val_loader, pre)
        if np.isfinite(met["acc"]) and met["acc"] > best_acc:
            best_acc = met["acc"]
            best = th
    return best


def weighted_aggregate(mets):
    if not mets:
        return {}
    keys = mets[0][0].keys()
    out = {}
    for k in keys:
        vals = [m[0].get(k, np.nan) for m in mets]
        ws = [m[1] for m in mets]
        out[k] = float(np.average(vals, weights=ws))
    return out


# ========== DATA ==========
print("Downloading datasets...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not detect dataset roots.")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))

train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)

n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_loaders = []
for ds_name, local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
    df_src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        source_id = 0 if ds_name == "ds1" else 1
        t_loader = make_loader(test_df, split[k].tolist(), CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=base_gid + k)
        client_test_loaders.append((ds_name, base_gid + k, t_loader))

# ========== MODEL ==========
global_model = SwinTiny_MultiScale(
    num_classes=NUM_CLASSES,
    pretrained=True,
    head_dropout=CFG["head_dropout"],
    cond_dim=CFG["cond_dim"],
    num_clients=CFG["clients_total"],
).to(DEVICE)

set_trainable_for_round(global_model, rnd=1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts = counts1 + counts2
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
class_w = torch.tensor(w, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

# ========== TRAIN (silent per-round) ==========
elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc = -1.0
best_model_state = None
best_theta_ds1 = None
best_theta_ds2 = None

for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []

    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool, use_separability=True)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else IDENTITY_PRE
        else:
            pre_k = IDENTITY_PRE

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = SwinTiny_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=False,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
        ).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)

        set_trainable_for_round(local_model, rnd=rnd)
        opt = make_optimizer(local_model)

        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=scheduler, scaler=scaler, grad_clip=CFG["grad_clip"])

        met_loc, _, _ = evaluate_full(local_model, val_loader, pre_k)
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))
        local_rows.append((met_loc, len(val_loader.dataset)))

    wsum = sum(local_weights)
    weights = [w / wsum for w in local_weights]
    trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
    fedavg_update(global_model, local_models, weights, trainable_names)

    if CFG["use_preprocessing"] and elite_pool_ds1:
        best_theta_ds1 = pick_best_theta_from_pool(global_model, elite_pool_ds1, client_loaders[0][2])
    if CFG["use_preprocessing"] and elite_pool_ds2:
        best_theta_ds2 = pick_best_theta_from_pool(global_model, elite_pool_ds2, client_loaders[n_per_ds][2])

    global_metrics = weighted_aggregate(local_rows)
    if np.isfinite(global_metrics.get("acc", np.nan)) and global_metrics["acc"] > best_global_acc:
        best_global_acc = float(global_metrics["acc"])
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# ========== FINAL METRICS ONLY ==========
pre_best_ds1 = theta_to_module(best_theta_ds1).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds1 is not None) else IDENTITY_PRE
pre_best_ds2 = theta_to_module(best_theta_ds2).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds2 is not None) else IDENTITY_PRE

val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = pre_best_ds1 if k < n_per_ds else pre_best_ds2
    met, _, _ = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = pre_best_ds1 if ds == "ds1" else pre_best_ds2
        met, _, _ = evaluate_full(global_model, t_loader, pre)
        mets.append((met, len(t_loader.dataset)))
    return weighted_aggregate(mets)


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([(test_ds1, len(test1)), (test_ds2, len(test2))])

columns = [
    "setting", "split", "dataset",
    "acc", "precision_macro", "recall_macro", "f1_macro",
    "precision_weighted", "recall_weighted", "f1_weighted",
    "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
]

rows = [
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **val_best},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **test_ds1},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **test_ds2},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **global_test},
]

final_df = pd.DataFrame(rows)
for c in columns:
    if c not in final_df.columns:
        final_df[c] = np.nan
final_df = final_df[columns]

print("\nFinal output metrics:")
print(final_df.to_string(index=False))

  self.setter(val)


Downloading datasets...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]


Final output metrics:
                        setting split          dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Enhanced FELCM (Best θ ds1/ds2)   VAL ds1+ds2 weighted 0.939968         0.762435      0.867149  0.774771            0.966980         0.939968     0.949871  0.234568                NaN 0.230608     3.481177
    Enhanced FELCM (Best θ ds1)  TEST              ds1 0.920354         0.922945      0.919402  0.919116            0.922960         0.920354     0.919661  0.286956           0.982562 0.286032     2.254558
    Enhanced FELCM (Best θ ds2)  TEST              ds2 0.925118         0.923275      0.918953  0.918404            0.925842         0.925118     0.923018  0.274266           0.981405 0.275690     7.025684
        Enhanced FELCM (Best θ)  TEST  global weighted 0.924278         0.923217      0.919032  0.918530            0.925334         0.924278     0.92242

# **ViT-B/16**

In [1]:
import os, time, math, random, sys, subprocess
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score,
)

# -------------------------
# Install deps (Colab)
# -------------------------
def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


def _safe_import_torch():
    """
    Robust torch import for notebook runtimes.
    Fixes common 'partially initialized module torch has no attribute nn'
    by clearing stale torch modules from sys.modules before import.
    """
    for k in list(sys.modules.keys()):
        if k == "torch" or k.startswith("torch."):
            sys.modules.pop(k, None)

    import torch  # noqa: F401
    import torch.nn as nn  # noqa: F401
    import torch.nn.functional as F  # noqa: F401
    from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler  # noqa: F401
    return torch, nn, F, Dataset, DataLoader, WeightedRandomSampler


try:
    torch, nn, F, Dataset, DataLoader, WeightedRandomSampler = _safe_import_torch()
except Exception as e:
    raise RuntimeError(
        "Torch import failed. In notebook runtimes, restart the runtime/kernel and run this script again. "
        "Also ensure there is no local file/folder named 'torch'. Original error: " + str(e)
    )

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

try:
    from torchvision import transforms, models
except Exception:
    pip_install("torchvision")
    from torchvision import transforms, models


# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
    "quick_hash_subset_per_split": 300,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm["path"] = dfm["path"].astype(str)
    dfm["label"] = dfm["label"].astype(str)
    dfm["source"] = dfm["source"].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(lambda x: os.path.basename(x))
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff

        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start: start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        B, C, H, W = x.shape

        if self.denoise > 0:
            k = 3
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), k, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap)
        mag = lap.abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(mag, (pad, pad, pad, pad), mode="reflect"), k, 1)
        C_map = mag / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * C_map)

        if self.sharpen > 0:
            outs = []
            for c in range(C):
                x_c = x2[:, c: c + 1, :, :]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(outs, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    gamma = random.uniform(0.7, 1.4)
    alpha = random.uniform(0.15, 0.55)
    beta = random.uniform(3.0, 9.0)
    tau = random.uniform(1.8, 3.2)
    blur_k = random.choice([3, 5, 7])
    sharpen = random.uniform(0.0, 0.25)
    denoise = random.uniform(0.0, 0.2)
    return (gamma, alpha, beta, tau, blur_k, sharpen, denoise)


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


def theta_str(th):
    if th is None:
        return "None"
    g, a, b, t, k, sh, dn = th
    return f"(γ={g:.2f}, α={a:.2f}, β={b:.1f}, τ={t:.1f}, k={k}, sh={sh:.2f}, dn={dn:.2f})"


IDENTITY_PRE = nn.Identity().to(DEVICE)


class ViTB16Backbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        weights = models.ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vit = models.vit_b_16(weights=weights)

    def forward(self, x):
        b, c, h, w = x.shape
        x = self.vit._process_input(x)
        n = x.shape[0]

        cls = self.vit.class_token.expand(n, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = self.vit.encoder(x)

        tok = x[:, 1:, :]
        gh = h // self.vit.patch_size
        gw = w // self.vit.patch_size
        feat = tok.transpose(1, 2).reshape(b, tok.shape[-1], gh, gw)

        # Keep the same 4-scale interface expected by the fuser.
        f1 = F.interpolate(feat, scale_factor=4.0, mode="bilinear", align_corners=False)
        f2 = F.interpolate(feat, scale_factor=2.0, mode="bilinear", align_corners=False)
        f3 = feat
        f4 = F.avg_pool2d(feat, kernel_size=2, stride=2)
        return [f1, f2, f3, f4]


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class ViTB16_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = ViTB16Backbone(pretrained=pretrained)
        in_channels = [768, 768, 768, 768]
        out_dim = 512

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids = []
    within_vars = []
    sizes = []
    for c in classes:
        mask = y == c
        e = emb[mask]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y, use_separability=True):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

    sep = 0.0
    if use_separability:
        emb = backbone_frozen(x_n)
        if isinstance(emb, (list, tuple)):
            emb = emb[-1]
            emb = emb.mean(dim=(2, 3))
        sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    if use_separability:
        return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost
    return 0.60 * contrast + 0.35 * dyn_range - 0.5 * cost


def _safe_first_batch(dl):
    try:
        bx, by, *_ = next(iter(dl))
        return bx, by
    except Exception:
        return None, None


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool, use_separability=True):
    bx, by = _safe_first_batch(dl_for_eval)
    if bx is None:
        return None, [], 0.0

    pop = []
    if elite_pool:
        pop.extend(elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)])
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            child = mutate(crossover(p1, p2), p=0.75)
            new_pop.append(child)
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    best_theta = scored[0][1]
    best_fit = float(scored[0][0])
    top = [th for _, th in scored[: CFG["ga_elites"]]]
    return best_theta, top, best_fit


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    out = {}
    try:
        if num_classes == 2:
            out["auc_roc"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        else:
            out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
            for c in range(num_classes):
                yc = (y_true == c).astype(int)
                if yc.sum() > 0 and yc.sum() < len(yc):
                    out[f"auc_class_{c}"] = float(roc_auc_score(yc, p_pred[:, c]))
    except Exception:
        pass
    return out


@torch.no_grad()
def evaluate_full(model, loader, preproc_module, return_gates=False):
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        met = {k: np.nan for k in [
            "loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted", "log_loss", "eval_time_s"
        ]}
        return met, np.array([]), np.array([])

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    return met, y_true, p_pred


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None, grad_clip=1.0):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    new_sd = {}
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        new_sd[name] = acc
    for name, t in new_sd.items():
        gsd[name].copy_(t.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader):
    if not pool:
        return None
    best, best_acc = None, -1
    for th in pool[:10]:
        pre = theta_to_module(th).to(DEVICE)
        met, _, _ = evaluate_full(model, val_loader, pre)
        if np.isfinite(met["acc"]) and met["acc"] > best_acc:
            best_acc = met["acc"]
            best = th
    return best


def weighted_aggregate(mets):
    if not mets:
        return {}
    keys = mets[0][0].keys()
    out = {}
    for k in keys:
        vals = [m[0].get(k, np.nan) for m in mets]
        ws = [m[1] for m in mets]
        out[k] = float(np.average(vals, weights=ws))
    return out


# ========== DATA ==========
print("Downloading datasets...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not detect dataset roots.")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))

train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)

n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_loaders = []
for ds_name, local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
    df_src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        source_id = 0 if ds_name == "ds1" else 1
        t_loader = make_loader(test_df, split[k].tolist(), CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=base_gid + k)
        client_test_loaders.append((ds_name, base_gid + k, t_loader))

# ========== MODEL ==========
global_model = ViTB16_MultiScale(
    num_classes=NUM_CLASSES,
    pretrained=True,
    head_dropout=CFG["head_dropout"],
    cond_dim=CFG["cond_dim"],
    num_clients=CFG["clients_total"],
).to(DEVICE)

set_trainable_for_round(global_model, rnd=1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
counts = counts1 + counts2
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
class_w = torch.tensor(w, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

# ========== TRAIN (silent per-round) ==========
elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc = -1.0
best_model_state = None
best_theta_ds1 = None
best_theta_ds2 = None

for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []

    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool, use_separability=True)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else IDENTITY_PRE
        else:
            pre_k = IDENTITY_PRE

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = ViTB16_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=False,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
        ).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)

        set_trainable_for_round(local_model, rnd=rnd)
        opt = make_optimizer(local_model)

        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=scheduler, scaler=scaler, grad_clip=CFG["grad_clip"])

        met_loc, _, _ = evaluate_full(local_model, val_loader, pre_k)
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))
        local_rows.append((met_loc, len(val_loader.dataset)))

    wsum = sum(local_weights)
    weights = [w / wsum for w in local_weights]
    trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
    fedavg_update(global_model, local_models, weights, trainable_names)

    if CFG["use_preprocessing"] and elite_pool_ds1:
        best_theta_ds1 = pick_best_theta_from_pool(global_model, elite_pool_ds1, client_loaders[0][2])
    if CFG["use_preprocessing"] and elite_pool_ds2:
        best_theta_ds2 = pick_best_theta_from_pool(global_model, elite_pool_ds2, client_loaders[n_per_ds][2])

    global_metrics = weighted_aggregate(local_rows)
    if np.isfinite(global_metrics.get("acc", np.nan)) and global_metrics["acc"] > best_global_acc:
        best_global_acc = float(global_metrics["acc"])
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# ========== FINAL METRICS ONLY ==========
pre_best_ds1 = theta_to_module(best_theta_ds1).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds1 is not None) else IDENTITY_PRE
pre_best_ds2 = theta_to_module(best_theta_ds2).to(DEVICE) if (CFG["use_preprocessing"] and best_theta_ds2 is not None) else IDENTITY_PRE

val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = pre_best_ds1 if k < n_per_ds else pre_best_ds2
    met, _, _ = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = pre_best_ds1 if ds == "ds1" else pre_best_ds2
        met, _, _ = evaluate_full(global_model, t_loader, pre)
        mets.append((met, len(t_loader.dataset)))
    return weighted_aggregate(mets)


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([(test_ds1, len(test1)), (test_ds2, len(test2))])

columns = [
    "setting", "split", "dataset",
    "acc", "precision_macro", "recall_macro", "f1_macro",
    "precision_weighted", "recall_weighted", "f1_weighted",
    "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
]

rows = [
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **val_best},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **test_ds1},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **test_ds2},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **global_test},
]

final_df = pd.DataFrame(rows)
for c in columns:
    if c not in final_df.columns:
        final_df[c] = np.nan
final_df = final_df[columns]

print("\nFinal output metrics:")
print(final_df.to_string(index=False))

  self.setter(val)


Downloading datasets...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:01<00:00, 199MB/s]



Final output metrics:
                        setting split          dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Enhanced FELCM (Best θ ds1/ds2)   VAL ds1+ds2 weighted 0.927330         0.758157      0.954026  0.774016            0.960339         0.927330     0.938602  0.276077                NaN 0.273979     7.235911
    Enhanced FELCM (Best θ ds1)  TEST              ds1 0.915929         0.925593      0.916968  0.915256            0.926273         0.915929     0.915114  0.339266           0.986697 0.343778     3.821958
    Enhanced FELCM (Best θ ds2)  TEST              ds2 0.937441         0.937293      0.932836  0.933625            0.938009         0.937441     0.936365  0.247798           0.988816 0.248896    13.663844
        Enhanced FELCM (Best θ)  TEST  global weighted 0.933646         0.935229      0.930037  0.930384            0.935939         0.933646     0.93261