Backbone scale ablation (PVTv2 family)

PVTv2-B0

PVTv2-B1

PVTv2-B3

PVTv2-B4

(optional) PVTv2-B5

# **PVTv2-B0**

In [2]:
import os, time, math, random, sys, subprocess, hashlib
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score
)

# ============================================================
# TRUE FL + GA-FELCM + PVTv2-B0 (FUSION) — 6 Clients (3+3)
# Preprocessing ON + Augmentation ON + Fusion ON
# CHANGES REQUESTED:
#  - backbone -> PVTv2-B0
#  - rounds -> 12
#  - output -> only VAL/TEST metrics
# Saves:
#  - checkpoint
#  - one CSV with only VAL/TEST metrics
# ============================================================


def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

try:
    import timm
except Exception:
    pip_install("timm")
    import timm

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,  # requested
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224 if torch.cuda.is_available() else 160,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
}

OUTDIR = "/content/outputs"
os.makedirs(OUTDIR, exist_ok=True)
MODEL_PATH = os.path.join(OUTDIR, "FL_GAFELCM_PVTv2B0_FUSION_checkpoint.pth")
CSV_PATH = os.path.join(OUTDIR, "VAL_TEST_METRICS_ONLY.csv")

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "notumor" in s or "no tumor" in s or "no_tumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(os.path.basename)
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


print("Downloading datasets via kagglehub...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not locate required dataset roots")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if not idxs:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        counts[np.argmax(props)] += len(idxs) - counts.sum()
        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start : start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(idxs, test_size=tune_frac, stratify=yk, random_state=SEED)

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(rem_idx, test_size=val_frac, stratify=yk2, random_state=SEED)

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_test_splits = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        client_test_splits.append((ds_name, k, base_gid + k, split[k].tolist()))


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([transforms.Resize((CFG["img_size"], CFG["img_size"])), transforms.ToTensor()])
if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        return x, int(row["y"]), row["path"], self.source_id, self.client_id


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(ds, batch_size=bs, shuffle=(shuffle and sampler is None), sampler=sampler,
                      num_workers=CFG["num_workers"], pin_memory=(DEVICE.type == "cuda"), drop_last=False,
                      persistent_workers=(CFG["num_workers"] > 0))


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    return WeightedRandomSampler(torch.DoubleTensor(class_weights[ys]), num_samples=len(ys), replacement=True)


client_loaders = []
for ds_name, _, gid, tr_idx, tune_idx, val_idx in client_splits:
    src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(src, tune_idx if tune_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(src, val_idx if val_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, _, gid, test_idx in client_test_splits:
    src = test1 if ds_name == "ds1" else test2
    source_id = 0 if ds_name == "ds1" else 1
    client_test_loaders.append((ds_name, gid, make_loader(src, test_idx, CFG["batch_size"], EVAL_TFMS, source_id=source_id, client_id=gid)))


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)
        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))
        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        if self.denoise > 0:
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), 3, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise
        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)
        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap).abs()
        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(lap, (pad, pad, pad, pad), mode="reflect"), k, 1)
        x2 = x1 + self.alpha * torch.tanh(self.beta * (lap / (blur + eps)))
        if self.sharpen > 0:
            outs = []
            for c in range(x2.shape[1]):
                xc = x2[:, c:c+1]
                xs = F.conv2d(F.pad(xc, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(xc * (1 - self.sharpen) + xs * self.sharpen)
            x2 = torch.cat(outs, dim=1)
        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        return ((x2 - mn) / (mx - mn + eps)).clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    return (
        random.uniform(0.7, 1.4),
        random.uniform(0.15, 0.55),
        random.uniform(3.0, 9.0),
        random.uniform(1.8, 3.2),
        random.choice([3, 5, 7]),
        random.uniform(0.0, 0.25),
        random.uniform(0.0, 0.2),
    )


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(nn.Conv2d(c, out_dim, 1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU()) for c in in_channels
        ])
        self.fuse = nn.Sequential(nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU())
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(nn.Linear(dim, max(8, dim // 4)), nn.ReLU(inplace=True), nn.Linear(max(8, dim // 4), dim), nn.Sigmoid())
        self.refine = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim, dim))
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


BACKBONE_NAME = "pvt_v2_b0"  # requested


class PVTv2B0_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = timm.create_model(BACKBONE_NAME, pretrained=pretrained, features_only=True, out_indices=(0, 1, 2, 3))
        in_channels = self.backbone.feature_info.channels()
        out_dim = max(256, in_channels[-1] // 2)
        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)
        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim), nn.Dropout(head_dropout), nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(), nn.Dropout(head_dropout * 0.5), nn.Linear(max(64, out_dim // 2), num_classes)
        )
        self.theta_mlp = nn.Sequential(nn.Linear(7, cond_dim), nn.GELU(), nn.Linear(cond_dim, cond_dim))
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)
        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)
        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n
        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)
        f0, f1 = self.fuser(feats0), self.fuser(feats1)
        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1
        t0, t1, t_mid = self.tuner(f0), self.tuner(f1), self.tuner(f_mid)
        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views
        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (bb_params if n.startswith("backbone.") else head_params).append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma, preproc_module.alpha, preproc_module.beta, preproc_module.tau,
            float(preproc_module.blur_k) / 7.0, preproc_module.sharpen, preproc_module.denoise
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def evaluate_full(model, loader, preproc_module):
    model.eval()
    preproc_module.eval()
    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)
        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
        probs = torch.softmax(logits, dim=1)
        all_loss.append(float(F.cross_entropy(logits, y).item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if not all_y:
        return {k: np.nan for k in ["loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro", "log_loss"]}

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)
    out = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
    }
    try:
        out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
    except Exception:
        pass
    return out


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None):
    model.train()
    preproc_module.eval()
    losses = []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            optimizer.step()

        if scheduler is not None:
            scheduler.step()
        losses.append(float(loss.item()))

    return float(np.mean(losses))


def ga_fitness(theta, backbone_frozen, batch_x, batch_y):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)
    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())
    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
    emb = backbone_frozen(x_n)
    if isinstance(emb, (list, tuple)):
        emb = emb[-1].mean(dim=(2, 3))
    y = y.long()
    classes = torch.unique(y)
    sep = 0.0
    if len(classes) >= 2:
        centroids, vars_, sizes = [], [], []
        for c in classes:
            e = emb[y == c]
            if e.size(0) < 2:
                continue
            mu = e.mean(dim=0)
            centroids.append(mu)
            vars_.append((e - mu).pow(2).sum(dim=1).mean().item())
            sizes.append(e.size(0))
        if len(centroids) >= 2:
            centroids = torch.stack(centroids, dim=0)
            gm = centroids.mean(dim=0)
            between = sum(n * (c - gm).pow(2).sum().item() for c, n in zip(centroids, sizes))
            within = float(np.mean(vars_)) if vars_ else 1e-6
            sep = float(between / (within + 1e-6))
    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool):
    try:
        bx, by, *_ = next(iter(dl_for_eval))
    except Exception:
        return None, [], 0.0
    pop = elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)]
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())
    bx, by = bx[: CFG["batch_size"]].contiguous(), by[: CFG["batch_size"]].contiguous()
    for _ in range(CFG["ga_gens"]):
        scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            new_pop.append(mutate(crossover(p1, p2), p=0.75))
        pop = new_pop
    scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
    return scored[0][1], [th for _, th in scored[: CFG["ga_elites"]]], float(scored[0][0])


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        gsd[name].copy_(acc.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


def weighted_aggregate(mets):
    total = sum(w for _, _, w in mets)
    if total == 0:
        return {}
    keys = mets[0][1].keys()
    return {k: float(np.average([m[1].get(k, np.nan) for m in mets], weights=[m[2] for m in mets])) for k in keys}


print("Initializing global model...")
global_model = PVTv2B0_MultiScale(NUM_CLASSES, pretrained=True, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
set_trainable_for_round(global_model, 1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values + \
         train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
criterion = nn.CrossEntropyLoss(weight=torch.tensor(w, device=DEVICE), label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc, best_round_saved = -1.0, None
best_model_state = None

print(f"Training TRUE FL for {CFG['rounds']} rounds with backbone={BACKBONE_NAME} ...")
for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []
    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else nn.Identity().to(DEVICE)
        else:
            best_theta = None
            pre_k = nn.Identity().to(DEVICE)

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = PVTv2B0_MultiScale(NUM_CLASSES, pretrained=False, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)
        set_trainable_for_round(local_model, rnd)
        opt = make_optimizer(local_model)
        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        sched = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=sched, scaler=scaler)

        met = evaluate_full(local_model, val_loader, pre_k)
        local_rows.append({"val_acc": met.get("acc", np.nan), "val_size": len(val_loader.dataset)})
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))

    wsum = sum(local_weights)
    fedavg_update(global_model, local_models, [w / wsum for w in local_weights], [n for n, p in local_models[0].named_parameters() if p.requires_grad])

    lv = pd.DataFrame(local_rows)
    round_acc = float(np.average(lv["val_acc"], weights=lv["val_size"])) if lv["val_size"].sum() > 0 else np.nan
    if np.isfinite(round_acc) and round_acc > best_global_acc:
        best_global_acc = round_acc
        best_round_saved = rnd
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

print(f"Done training. Best federated VAL accuracy={best_global_acc:.4f} at round {best_round_saved}")

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# Re-run GA with the restored best global model so final VAL/TEST uses GA-FELCM (not Identity preproc)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

best_theta_by_client = {}
for k in range(CFG["clients_total"]):
    _, tune_loader, _ = client_loaders[k]
    ds_name = "ds1" if k < n_per_ds else "ds2"
    elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2
    th, _, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
    best_theta_by_client[k] = th


def preproc_for_client(client_id: int):
    th = best_theta_by_client.get(client_id)
    return theta_to_module(th).to(DEVICE) if th is not None else nn.Identity().to(DEVICE)


val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = preproc_for_client(k)
    met = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((k, met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = preproc_for_client(gid)
        met = evaluate_full(global_model, t_loader, pre)
        mets.append((gid, met, len(t_loader.dataset)))
    return weighted_aggregate(mets) if mets else {}


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([
    (0, test_ds1, len(test1)),
    (1, test_ds2, len(test2)),
])


def compact_metrics(m):
    keep = ["acc", "precision_macro", "recall_macro", "f1_macro", "log_loss", "loss_ce", "auc_roc_macro_ovr"]
    return {k: float(m[k]) for k in keep if k in m}


val_test_df = pd.DataFrame([
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **compact_metrics(val_best)},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **compact_metrics(test_ds1)},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **compact_metrics(test_ds2)},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **compact_metrics(global_test)},
])

print("\nFINAL OUTPUT (VAL/TEST metrics only):")
print(val_test_df)

checkpoint = {
    "state_dict": {k: v.detach().cpu() for k, v in global_model.state_dict().items()},
    "config": CFG,
    "seed": SEED,
    "device_used": str(DEVICE),
    "dataset1_raw_root": DS1_ROOT,
    "dataset2_root": DS2_ROOT,
    "labels": labels,
    "label2id": label2id,
    "id2label": id2label,
    "num_classes": NUM_CLASSES,
    "backbone_name": BACKBONE_NAME,
    "best_round_saved": best_round_saved,
    "best_val_acc": best_global_acc,
    "final_val_federated": val_best,
    "final_test_ds1": test_ds1,
    "final_test_ds2": test_ds2,
    "final_test_global_weighted": global_test,
}

torch.save(checkpoint, MODEL_PATH)
val_test_df.to_csv(CSV_PATH, index=False)
print(f"✅ Saved checkpoint: {MODEL_PATH}")
print(f"✅ Saved VAL/TEST-only CSV: {CSV_PATH}")

Downloading datasets via kagglehub...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.
Initializing global model...


model.safetensors:   0%|          | 0.00/14.7M [00:00<?, ?B/s]

Training TRUE FL for 12 rounds with backbone=pvt_v2_b0 ...
Done training. Best federated VAL accuracy=0.9700 at round 10

FINAL OUTPUT (VAL/TEST metrics only):
                           setting split           dataset       acc  \
0  Enhanced FELCM (Best θ ds1/ds2)   VAL  ds1+ds2 weighted  0.909953   
1      Enhanced FELCM (Best θ ds1)  TEST               ds1  0.898230   
2      Enhanced FELCM (Best θ ds2)  TEST               ds2  0.923223   
3          Enhanced FELCM (Best θ)  TEST   global weighted  0.918813   

   precision_macro  recall_macro  f1_macro  log_loss   loss_ce  \
0         0.761724      0.938551  0.759451  0.277318  0.270269   
1         0.902717      0.898933  0.896712  0.380071  0.379986   
2         0.918617      0.917311  0.916350  0.271831  0.275057   
3         0.915812      0.914069  0.912885  0.290927  0.293569   

   auc_roc_macro_ovr  
0                NaN  
1           0.981358  
2           0.985628  
3           0.984875  
✅ Saved checkpoint: /content/outp

# **PVTv2-B1**



In [3]:
import os, time, math, random, sys, subprocess, hashlib
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score
)

# ============================================================
# TRUE FL + GA-FELCM + PVTv2-B1 (FUSION) — 6 Clients (3+3)
# Preprocessing ON + Augmentation ON + Fusion ON
# CHANGES REQUESTED:
#  - backbone -> PVTv2-B1
#  - rounds -> 12
#  - output -> only VAL/TEST metrics
# Saves:
#  - checkpoint
#  - one CSV with only VAL/TEST metrics
# ============================================================


def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

try:
    import timm
except Exception:
    pip_install("timm")
    import timm

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,  # requested
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224 if torch.cuda.is_available() else 160,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
}

OUTDIR = "/content/outputs"
os.makedirs(OUTDIR, exist_ok=True)
MODEL_PATH = os.path.join(OUTDIR, "FL_GAFELCM_PVTv2B1_FUSION_checkpoint.pth")
CSV_PATH = os.path.join(OUTDIR, "VAL_TEST_METRICS_ONLY.csv")

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "notumor" in s or "no tumor" in s or "no_tumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(os.path.basename)
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


print("Downloading datasets via kagglehub...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not locate required dataset roots")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if not idxs:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        counts[np.argmax(props)] += len(idxs) - counts.sum()
        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start : start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(idxs, test_size=tune_frac, stratify=yk, random_state=SEED)

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(rem_idx, test_size=val_frac, stratify=yk2, random_state=SEED)

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_test_splits = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        client_test_splits.append((ds_name, k, base_gid + k, split[k].tolist()))


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([transforms.Resize((CFG["img_size"], CFG["img_size"])), transforms.ToTensor()])
if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        return x, int(row["y"]), row["path"], self.source_id, self.client_id


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(ds, batch_size=bs, shuffle=(shuffle and sampler is None), sampler=sampler,
                      num_workers=CFG["num_workers"], pin_memory=(DEVICE.type == "cuda"), drop_last=False,
                      persistent_workers=(CFG["num_workers"] > 0))


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    return WeightedRandomSampler(torch.DoubleTensor(class_weights[ys]), num_samples=len(ys), replacement=True)


client_loaders = []
for ds_name, _, gid, tr_idx, tune_idx, val_idx in client_splits:
    src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(src, tune_idx if tune_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(src, val_idx if val_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, _, gid, test_idx in client_test_splits:
    src = test1 if ds_name == "ds1" else test2
    source_id = 0 if ds_name == "ds1" else 1
    client_test_loaders.append((ds_name, gid, make_loader(src, test_idx, CFG["batch_size"], EVAL_TFMS, source_id=source_id, client_id=gid)))


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)
        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))
        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        if self.denoise > 0:
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), 3, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise
        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)
        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap).abs()
        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(lap, (pad, pad, pad, pad), mode="reflect"), k, 1)
        x2 = x1 + self.alpha * torch.tanh(self.beta * (lap / (blur + eps)))
        if self.sharpen > 0:
            outs = []
            for c in range(x2.shape[1]):
                xc = x2[:, c:c+1]
                xs = F.conv2d(F.pad(xc, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(xc * (1 - self.sharpen) + xs * self.sharpen)
            x2 = torch.cat(outs, dim=1)
        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        return ((x2 - mn) / (mx - mn + eps)).clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    return (
        random.uniform(0.7, 1.4),
        random.uniform(0.15, 0.55),
        random.uniform(3.0, 9.0),
        random.uniform(1.8, 3.2),
        random.choice([3, 5, 7]),
        random.uniform(0.0, 0.25),
        random.uniform(0.0, 0.2),
    )


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(nn.Conv2d(c, out_dim, 1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU()) for c in in_channels
        ])
        self.fuse = nn.Sequential(nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU())
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(nn.Linear(dim, max(8, dim // 4)), nn.ReLU(inplace=True), nn.Linear(max(8, dim // 4), dim), nn.Sigmoid())
        self.refine = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim, dim))
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


BACKBONE_NAME = "pvt_v2_b1"  # requested


class PVTv2B1_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = timm.create_model(BACKBONE_NAME, pretrained=pretrained, features_only=True, out_indices=(0, 1, 2, 3))
        in_channels = self.backbone.feature_info.channels()
        out_dim = max(256, in_channels[-1] // 2)
        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)
        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim), nn.Dropout(head_dropout), nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(), nn.Dropout(head_dropout * 0.5), nn.Linear(max(64, out_dim // 2), num_classes)
        )
        self.theta_mlp = nn.Sequential(nn.Linear(7, cond_dim), nn.GELU(), nn.Linear(cond_dim, cond_dim))
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)
        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)
        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n
        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)
        f0, f1 = self.fuser(feats0), self.fuser(feats1)
        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1
        t0, t1, t_mid = self.tuner(f0), self.tuner(f1), self.tuner(f_mid)
        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views
        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (bb_params if n.startswith("backbone.") else head_params).append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma, preproc_module.alpha, preproc_module.beta, preproc_module.tau,
            float(preproc_module.blur_k) / 7.0, preproc_module.sharpen, preproc_module.denoise
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def evaluate_full(model, loader, preproc_module):
    model.eval()
    preproc_module.eval()
    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)
        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
        probs = torch.softmax(logits, dim=1)
        all_loss.append(float(F.cross_entropy(logits, y).item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if not all_y:
        return {k: np.nan for k in ["loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro", "log_loss"]}

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)
    out = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
    }
    try:
        out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
    except Exception:
        pass
    return out


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None):
    model.train()
    preproc_module.eval()
    losses = []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            optimizer.step()

        if scheduler is not None:
            scheduler.step()
        losses.append(float(loss.item()))

    return float(np.mean(losses))


def ga_fitness(theta, backbone_frozen, batch_x, batch_y):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)
    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())
    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
    emb = backbone_frozen(x_n)
    if isinstance(emb, (list, tuple)):
        emb = emb[-1].mean(dim=(2, 3))
    y = y.long()
    classes = torch.unique(y)
    sep = 0.0
    if len(classes) >= 2:
        centroids, vars_, sizes = [], [], []
        for c in classes:
            e = emb[y == c]
            if e.size(0) < 2:
                continue
            mu = e.mean(dim=0)
            centroids.append(mu)
            vars_.append((e - mu).pow(2).sum(dim=1).mean().item())
            sizes.append(e.size(0))
        if len(centroids) >= 2:
            centroids = torch.stack(centroids, dim=0)
            gm = centroids.mean(dim=0)
            between = sum(n * (c - gm).pow(2).sum().item() for c, n in zip(centroids, sizes))
            within = float(np.mean(vars_)) if vars_ else 1e-6
            sep = float(between / (within + 1e-6))
    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool):
    try:
        bx, by, *_ = next(iter(dl_for_eval))
    except Exception:
        return None, [], 0.0
    pop = elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)]
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())
    bx, by = bx[: CFG["batch_size"]].contiguous(), by[: CFG["batch_size"]].contiguous()
    for _ in range(CFG["ga_gens"]):
        scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            new_pop.append(mutate(crossover(p1, p2), p=0.75))
        pop = new_pop
    scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
    return scored[0][1], [th for _, th in scored[: CFG["ga_elites"]]], float(scored[0][0])


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        gsd[name].copy_(acc.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


def weighted_aggregate(mets):
    total = sum(w for _, _, w in mets)
    if total == 0:
        return {}
    keys = mets[0][1].keys()
    return {k: float(np.average([m[1].get(k, np.nan) for m in mets], weights=[m[2] for m in mets])) for k in keys}


print("Initializing global model...")
global_model = PVTv2B1_MultiScale(NUM_CLASSES, pretrained=True, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
set_trainable_for_round(global_model, 1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values + \
         train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
criterion = nn.CrossEntropyLoss(weight=torch.tensor(w, device=DEVICE), label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc, best_round_saved = -1.0, None
best_model_state = None

print(f"Training TRUE FL for {CFG['rounds']} rounds with backbone={BACKBONE_NAME} ...")
for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []
    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else nn.Identity().to(DEVICE)
        else:
            best_theta = None
            pre_k = nn.Identity().to(DEVICE)

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = PVTv2B1_MultiScale(NUM_CLASSES, pretrained=False, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)
        set_trainable_for_round(local_model, rnd)
        opt = make_optimizer(local_model)
        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        sched = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=sched, scaler=scaler)

        met = evaluate_full(local_model, val_loader, pre_k)
        local_rows.append({"val_acc": met.get("acc", np.nan), "val_size": len(val_loader.dataset)})
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))

    wsum = sum(local_weights)
    fedavg_update(global_model, local_models, [w / wsum for w in local_weights], [n for n, p in local_models[0].named_parameters() if p.requires_grad])

    lv = pd.DataFrame(local_rows)
    round_acc = float(np.average(lv["val_acc"], weights=lv["val_size"])) if lv["val_size"].sum() > 0 else np.nan
    if np.isfinite(round_acc) and round_acc > best_global_acc:
        best_global_acc = round_acc
        best_round_saved = rnd
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

print(f"Done training. Best federated VAL accuracy={best_global_acc:.4f} at round {best_round_saved}")

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# Re-run GA with the restored best global model so final VAL/TEST uses GA-FELCM (not Identity preproc)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

best_theta_by_client = {}
for k in range(CFG["clients_total"]):
    _, tune_loader, _ = client_loaders[k]
    ds_name = "ds1" if k < n_per_ds else "ds2"
    elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2
    th, _, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
    best_theta_by_client[k] = th


def preproc_for_client(client_id: int):
    th = best_theta_by_client.get(client_id)
    return theta_to_module(th).to(DEVICE) if th is not None else nn.Identity().to(DEVICE)


val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = preproc_for_client(k)
    met = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((k, met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = preproc_for_client(gid)
        met = evaluate_full(global_model, t_loader, pre)
        mets.append((gid, met, len(t_loader.dataset)))
    return weighted_aggregate(mets) if mets else {}


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([
    (0, test_ds1, len(test1)),
    (1, test_ds2, len(test2)),
])


def compact_metrics(m):
    keep = ["acc", "precision_macro", "recall_macro", "f1_macro", "log_loss", "loss_ce", "auc_roc_macro_ovr"]
    return {k: float(m[k]) for k in keep if k in m}


val_test_df = pd.DataFrame([
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **compact_metrics(val_best)},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **compact_metrics(test_ds1)},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **compact_metrics(test_ds2)},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **compact_metrics(global_test)},
])

print("\nFINAL OUTPUT (VAL/TEST metrics only):")
print(val_test_df)

checkpoint = {
    "state_dict": {k: v.detach().cpu() for k, v in global_model.state_dict().items()},
    "config": CFG,
    "seed": SEED,
    "device_used": str(DEVICE),
    "dataset1_raw_root": DS1_ROOT,
    "dataset2_root": DS2_ROOT,
    "labels": labels,
    "label2id": label2id,
    "id2label": id2label,
    "num_classes": NUM_CLASSES,
    "backbone_name": BACKBONE_NAME,
    "best_round_saved": best_round_saved,
    "best_val_acc": best_global_acc,
    "final_val_federated": val_best,
    "final_test_ds1": test_ds1,
    "final_test_ds2": test_ds2,
    "final_test_global_weighted": global_test,
}

torch.save(checkpoint, MODEL_PATH)
val_test_df.to_csv(CSV_PATH, index=False)
print(f"✅ Saved checkpoint: {MODEL_PATH}")
print(f"✅ Saved VAL/TEST-only CSV: {CSV_PATH}")

Downloading datasets via kagglehub...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.
Initializing global model...


model.safetensors:   0%|          | 0.00/56.1M [00:00<?, ?B/s]

Training TRUE FL for 12 rounds with backbone=pvt_v2_b1 ...
Done training. Best federated VAL accuracy=0.9779 at round 10

FINAL OUTPUT (VAL/TEST metrics only):
                           setting split           dataset       acc  \
0  Enhanced FELCM (Best θ ds1/ds2)   VAL  ds1+ds2 weighted  0.960506   
1      Enhanced FELCM (Best θ ds1)  TEST               ds1  0.942478   
2      Enhanced FELCM (Best θ ds2)  TEST               ds2  0.936493   
3          Enhanced FELCM (Best θ)  TEST   global weighted  0.937549   

   precision_macro  recall_macro  f1_macro  log_loss   loss_ce  \
0         0.826924      0.956931  0.853176  0.168795  0.163554   
1         0.944453      0.944455  0.942591  0.222971  0.215454   
2         0.934290      0.936575  0.931809  0.225085  0.225703   
3         0.936083      0.937965  0.933711  0.224712  0.223895   

   auc_roc_macro_ovr  
0                NaN  
1           0.985592  
2           0.993379  
3           0.992005  
✅ Saved checkpoint: /content/outp

# **PVTv2-B3**



In [4]:
import os, time, math, random, sys, subprocess, hashlib
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score
)

# ============================================================
# TRUE FL + GA-FELCM + PVTv2-B3 (FUSION) — 6 Clients (3+3)
# Preprocessing ON + Augmentation ON + Fusion ON
# CHANGES REQUESTED:
#  - backbone -> PVTv2-B3
#  - rounds -> 12
#  - output -> only VAL/TEST metrics
# Saves:
#  - checkpoint
#  - one CSV with only VAL/TEST metrics
# ============================================================


def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

try:
    import timm
except Exception:
    pip_install("timm")
    import timm

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,  # requested
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224 if torch.cuda.is_available() else 160,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
}

OUTDIR = "/content/outputs"
os.makedirs(OUTDIR, exist_ok=True)
MODEL_PATH = os.path.join(OUTDIR, "FL_GAFELCM_PVTv2B3_FUSION_checkpoint.pth")
CSV_PATH = os.path.join(OUTDIR, "VAL_TEST_METRICS_ONLY.csv")

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "notumor" in s or "no tumor" in s or "no_tumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(os.path.basename)
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


print("Downloading datasets via kagglehub...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not locate required dataset roots")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if not idxs:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        counts[np.argmax(props)] += len(idxs) - counts.sum()
        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start : start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(idxs, test_size=tune_frac, stratify=yk, random_state=SEED)

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(rem_idx, test_size=val_frac, stratify=yk2, random_state=SEED)

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_test_splits = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        client_test_splits.append((ds_name, k, base_gid + k, split[k].tolist()))


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([transforms.Resize((CFG["img_size"], CFG["img_size"])), transforms.ToTensor()])
if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        return x, int(row["y"]), row["path"], self.source_id, self.client_id


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(ds, batch_size=bs, shuffle=(shuffle and sampler is None), sampler=sampler,
                      num_workers=CFG["num_workers"], pin_memory=(DEVICE.type == "cuda"), drop_last=False,
                      persistent_workers=(CFG["num_workers"] > 0))


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    return WeightedRandomSampler(torch.DoubleTensor(class_weights[ys]), num_samples=len(ys), replacement=True)


client_loaders = []
for ds_name, _, gid, tr_idx, tune_idx, val_idx in client_splits:
    src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(src, tune_idx if tune_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(src, val_idx if val_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, _, gid, test_idx in client_test_splits:
    src = test1 if ds_name == "ds1" else test2
    source_id = 0 if ds_name == "ds1" else 1
    client_test_loaders.append((ds_name, gid, make_loader(src, test_idx, CFG["batch_size"], EVAL_TFMS, source_id=source_id, client_id=gid)))


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)
        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))
        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        if self.denoise > 0:
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), 3, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise
        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)
        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap).abs()
        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(lap, (pad, pad, pad, pad), mode="reflect"), k, 1)
        x2 = x1 + self.alpha * torch.tanh(self.beta * (lap / (blur + eps)))
        if self.sharpen > 0:
            outs = []
            for c in range(x2.shape[1]):
                xc = x2[:, c:c+1]
                xs = F.conv2d(F.pad(xc, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(xc * (1 - self.sharpen) + xs * self.sharpen)
            x2 = torch.cat(outs, dim=1)
        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        return ((x2 - mn) / (mx - mn + eps)).clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    return (
        random.uniform(0.7, 1.4),
        random.uniform(0.15, 0.55),
        random.uniform(3.0, 9.0),
        random.uniform(1.8, 3.2),
        random.choice([3, 5, 7]),
        random.uniform(0.0, 0.25),
        random.uniform(0.0, 0.2),
    )


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(nn.Conv2d(c, out_dim, 1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU()) for c in in_channels
        ])
        self.fuse = nn.Sequential(nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU())
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(nn.Linear(dim, max(8, dim // 4)), nn.ReLU(inplace=True), nn.Linear(max(8, dim // 4), dim), nn.Sigmoid())
        self.refine = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim, dim))
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


BACKBONE_NAME = "pvt_v2_b3"  # requested


class PVTv2B3_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = timm.create_model(BACKBONE_NAME, pretrained=pretrained, features_only=True, out_indices=(0, 1, 2, 3))
        in_channels = self.backbone.feature_info.channels()
        out_dim = max(256, in_channels[-1] // 2)
        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)
        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim), nn.Dropout(head_dropout), nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(), nn.Dropout(head_dropout * 0.5), nn.Linear(max(64, out_dim // 2), num_classes)
        )
        self.theta_mlp = nn.Sequential(nn.Linear(7, cond_dim), nn.GELU(), nn.Linear(cond_dim, cond_dim))
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)
        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)
        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n
        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)
        f0, f1 = self.fuser(feats0), self.fuser(feats1)
        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1
        t0, t1, t_mid = self.tuner(f0), self.tuner(f1), self.tuner(f_mid)
        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views
        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (bb_params if n.startswith("backbone.") else head_params).append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma, preproc_module.alpha, preproc_module.beta, preproc_module.tau,
            float(preproc_module.blur_k) / 7.0, preproc_module.sharpen, preproc_module.denoise
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def evaluate_full(model, loader, preproc_module):
    model.eval()
    preproc_module.eval()
    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)
        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
        probs = torch.softmax(logits, dim=1)
        all_loss.append(float(F.cross_entropy(logits, y).item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if not all_y:
        return {k: np.nan for k in ["loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro", "log_loss"]}

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)
    out = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
    }
    try:
        out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
    except Exception:
        pass
    return out


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None):
    model.train()
    preproc_module.eval()
    losses = []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            optimizer.step()

        if scheduler is not None:
            scheduler.step()
        losses.append(float(loss.item()))

    return float(np.mean(losses))


def ga_fitness(theta, backbone_frozen, batch_x, batch_y):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)
    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())
    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
    emb = backbone_frozen(x_n)
    if isinstance(emb, (list, tuple)):
        emb = emb[-1].mean(dim=(2, 3))
    y = y.long()
    classes = torch.unique(y)
    sep = 0.0
    if len(classes) >= 2:
        centroids, vars_, sizes = [], [], []
        for c in classes:
            e = emb[y == c]
            if e.size(0) < 2:
                continue
            mu = e.mean(dim=0)
            centroids.append(mu)
            vars_.append((e - mu).pow(2).sum(dim=1).mean().item())
            sizes.append(e.size(0))
        if len(centroids) >= 2:
            centroids = torch.stack(centroids, dim=0)
            gm = centroids.mean(dim=0)
            between = sum(n * (c - gm).pow(2).sum().item() for c, n in zip(centroids, sizes))
            within = float(np.mean(vars_)) if vars_ else 1e-6
            sep = float(between / (within + 1e-6))
    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool):
    try:
        bx, by, *_ = next(iter(dl_for_eval))
    except Exception:
        return None, [], 0.0
    pop = elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)]
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())
    bx, by = bx[: CFG["batch_size"]].contiguous(), by[: CFG["batch_size"]].contiguous()
    for _ in range(CFG["ga_gens"]):
        scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            new_pop.append(mutate(crossover(p1, p2), p=0.75))
        pop = new_pop
    scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
    return scored[0][1], [th for _, th in scored[: CFG["ga_elites"]]], float(scored[0][0])


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        gsd[name].copy_(acc.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


def weighted_aggregate(mets):
    total = sum(w for _, _, w in mets)
    if total == 0:
        return {}
    keys = mets[0][1].keys()
    return {k: float(np.average([m[1].get(k, np.nan) for m in mets], weights=[m[2] for m in mets])) for k in keys}


print("Initializing global model...")
global_model = PVTv2B3_MultiScale(NUM_CLASSES, pretrained=True, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
set_trainable_for_round(global_model, 1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values + \
         train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
criterion = nn.CrossEntropyLoss(weight=torch.tensor(w, device=DEVICE), label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc, best_round_saved = -1.0, None
best_model_state = None

print(f"Training TRUE FL for {CFG['rounds']} rounds with backbone={BACKBONE_NAME} ...")
for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []
    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else nn.Identity().to(DEVICE)
        else:
            best_theta = None
            pre_k = nn.Identity().to(DEVICE)

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = PVTv2B3_MultiScale(NUM_CLASSES, pretrained=False, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)
        set_trainable_for_round(local_model, rnd)
        opt = make_optimizer(local_model)
        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        sched = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=sched, scaler=scaler)

        met = evaluate_full(local_model, val_loader, pre_k)
        local_rows.append({"val_acc": met.get("acc", np.nan), "val_size": len(val_loader.dataset)})
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))

    wsum = sum(local_weights)
    fedavg_update(global_model, local_models, [w / wsum for w in local_weights], [n for n, p in local_models[0].named_parameters() if p.requires_grad])

    lv = pd.DataFrame(local_rows)
    round_acc = float(np.average(lv["val_acc"], weights=lv["val_size"])) if lv["val_size"].sum() > 0 else np.nan
    if np.isfinite(round_acc) and round_acc > best_global_acc:
        best_global_acc = round_acc
        best_round_saved = rnd
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

print(f"Done training. Best federated VAL accuracy={best_global_acc:.4f} at round {best_round_saved}")

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# Re-run GA with the restored best global model so final VAL/TEST uses GA-FELCM (not Identity preproc)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

best_theta_by_client = {}
for k in range(CFG["clients_total"]):
    _, tune_loader, _ = client_loaders[k]
    ds_name = "ds1" if k < n_per_ds else "ds2"
    elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2
    th, _, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
    best_theta_by_client[k] = th


def preproc_for_client(client_id: int):
    th = best_theta_by_client.get(client_id)
    return theta_to_module(th).to(DEVICE) if th is not None else nn.Identity().to(DEVICE)


val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = preproc_for_client(k)
    met = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((k, met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = preproc_for_client(gid)
        met = evaluate_full(global_model, t_loader, pre)
        mets.append((gid, met, len(t_loader.dataset)))
    return weighted_aggregate(mets) if mets else {}


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([
    (0, test_ds1, len(test1)),
    (1, test_ds2, len(test2)),
])


def compact_metrics(m):
    keep = ["acc", "precision_macro", "recall_macro", "f1_macro", "log_loss", "loss_ce", "auc_roc_macro_ovr"]
    return {k: float(m[k]) for k in keep if k in m}


val_test_df = pd.DataFrame([
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **compact_metrics(val_best)},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **compact_metrics(test_ds1)},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **compact_metrics(test_ds2)},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **compact_metrics(global_test)},
])

print("\nFINAL OUTPUT (VAL/TEST metrics only):")
print(val_test_df)

checkpoint = {
    "state_dict": {k: v.detach().cpu() for k, v in global_model.state_dict().items()},
    "config": CFG,
    "seed": SEED,
    "device_used": str(DEVICE),
    "dataset1_raw_root": DS1_ROOT,
    "dataset2_root": DS2_ROOT,
    "labels": labels,
    "label2id": label2id,
    "id2label": id2label,
    "num_classes": NUM_CLASSES,
    "backbone_name": BACKBONE_NAME,
    "best_round_saved": best_round_saved,
    "best_val_acc": best_global_acc,
    "final_val_federated": val_best,
    "final_test_ds1": test_ds1,
    "final_test_ds2": test_ds2,
    "final_test_global_weighted": global_test,
}

torch.save(checkpoint, MODEL_PATH)
val_test_df.to_csv(CSV_PATH, index=False)
print(f"✅ Saved checkpoint: {MODEL_PATH}")
print(f"✅ Saved VAL/TEST-only CSV: {CSV_PATH}")

Downloading datasets via kagglehub...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.
Initializing global model...




model.safetensors:   0%|          | 0.00/181M [00:00<?, ?B/s]

Training TRUE FL for 12 rounds with backbone=pvt_v2_b3 ...




Done training. Best federated VAL accuracy=0.9842 at round 12

FINAL OUTPUT (VAL/TEST metrics only):
                           setting split           dataset       acc  \
0  Enhanced FELCM (Best θ ds1/ds2)   VAL  ds1+ds2 weighted  0.946288   
1      Enhanced FELCM (Best θ ds1)  TEST               ds1  0.938053   
2      Enhanced FELCM (Best θ ds2)  TEST               ds2  0.949763   
3          Enhanced FELCM (Best θ)  TEST   global weighted  0.947697   

   precision_macro  recall_macro  f1_macro  log_loss   loss_ce  \
0         0.765838      0.943700  0.777604  0.199099  0.194695   
1         0.943312      0.933781  0.935748  0.272568  0.270744   
2         0.948912      0.945429  0.945884  0.184047  0.185038   
3         0.947924      0.943374  0.944096  0.199664  0.200159   

   auc_roc_macro_ovr  
0                NaN  
1           0.983110  
2           0.992767  
3           0.991063  
✅ Saved checkpoint: /content/outputs/FL_GAFELCM_PVTv2B3_FUSION_checkpoint.pth
✅ Saved VAL/TE

# **PVTv2-B4**



In [5]:
import os, time, math, random, sys, subprocess, hashlib
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score
)

# ============================================================
# TRUE FL + GA-FELCM + PVTv2-B4 (FUSION) — 6 Clients (3+3)
# Preprocessing ON + Augmentation ON + Fusion ON
# CHANGES REQUESTED:
#  - backbone -> PVTv2-B4
#  - rounds -> 12
#  - output -> only VAL/TEST metrics
# Saves:
#  - checkpoint
#  - one CSV with only VAL/TEST metrics
# ============================================================


def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

try:
    import timm
except Exception:
    pip_install("timm")
    import timm

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,  # requested
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224 if torch.cuda.is_available() else 160,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
}

OUTDIR = "/content/outputs"
os.makedirs(OUTDIR, exist_ok=True)
MODEL_PATH = os.path.join(OUTDIR, "FL_GAFELCM_PVTv2B4_FUSION_checkpoint.pth")
CSV_PATH = os.path.join(OUTDIR, "VAL_TEST_METRICS_ONLY.csv")

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "notumor" in s or "no tumor" in s or "no_tumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(os.path.basename)
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


print("Downloading datasets via kagglehub...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not locate required dataset roots")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if not idxs:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        counts[np.argmax(props)] += len(idxs) - counts.sum()
        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start : start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(idxs, test_size=tune_frac, stratify=yk, random_state=SEED)

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(rem_idx, test_size=val_frac, stratify=yk2, random_state=SEED)

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_test_splits = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        client_test_splits.append((ds_name, k, base_gid + k, split[k].tolist()))


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([transforms.Resize((CFG["img_size"], CFG["img_size"])), transforms.ToTensor()])
if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        return x, int(row["y"]), row["path"], self.source_id, self.client_id


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(ds, batch_size=bs, shuffle=(shuffle and sampler is None), sampler=sampler,
                      num_workers=CFG["num_workers"], pin_memory=(DEVICE.type == "cuda"), drop_last=False,
                      persistent_workers=(CFG["num_workers"] > 0))


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    return WeightedRandomSampler(torch.DoubleTensor(class_weights[ys]), num_samples=len(ys), replacement=True)


client_loaders = []
for ds_name, _, gid, tr_idx, tune_idx, val_idx in client_splits:
    src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(src, tune_idx if tune_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(src, val_idx if val_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, _, gid, test_idx in client_test_splits:
    src = test1 if ds_name == "ds1" else test2
    source_id = 0 if ds_name == "ds1" else 1
    client_test_loaders.append((ds_name, gid, make_loader(src, test_idx, CFG["batch_size"], EVAL_TFMS, source_id=source_id, client_id=gid)))


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)
        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))
        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        if self.denoise > 0:
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), 3, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise
        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)
        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap).abs()
        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(lap, (pad, pad, pad, pad), mode="reflect"), k, 1)
        x2 = x1 + self.alpha * torch.tanh(self.beta * (lap / (blur + eps)))
        if self.sharpen > 0:
            outs = []
            for c in range(x2.shape[1]):
                xc = x2[:, c:c+1]
                xs = F.conv2d(F.pad(xc, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(xc * (1 - self.sharpen) + xs * self.sharpen)
            x2 = torch.cat(outs, dim=1)
        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        return ((x2 - mn) / (mx - mn + eps)).clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    return (
        random.uniform(0.7, 1.4),
        random.uniform(0.15, 0.55),
        random.uniform(3.0, 9.0),
        random.uniform(1.8, 3.2),
        random.choice([3, 5, 7]),
        random.uniform(0.0, 0.25),
        random.uniform(0.0, 0.2),
    )


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(nn.Conv2d(c, out_dim, 1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU()) for c in in_channels
        ])
        self.fuse = nn.Sequential(nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU())
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(nn.Linear(dim, max(8, dim // 4)), nn.ReLU(inplace=True), nn.Linear(max(8, dim // 4), dim), nn.Sigmoid())
        self.refine = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim, dim))
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


BACKBONE_NAME = "pvt_v2_b4"  # requested


class PVTv2B4_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = timm.create_model(BACKBONE_NAME, pretrained=pretrained, features_only=True, out_indices=(0, 1, 2, 3))
        in_channels = self.backbone.feature_info.channels()
        out_dim = max(256, in_channels[-1] // 2)
        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)
        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim), nn.Dropout(head_dropout), nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(), nn.Dropout(head_dropout * 0.5), nn.Linear(max(64, out_dim // 2), num_classes)
        )
        self.theta_mlp = nn.Sequential(nn.Linear(7, cond_dim), nn.GELU(), nn.Linear(cond_dim, cond_dim))
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)
        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)
        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n
        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)
        f0, f1 = self.fuser(feats0), self.fuser(feats1)
        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1
        t0, t1, t_mid = self.tuner(f0), self.tuner(f1), self.tuner(f_mid)
        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views
        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (bb_params if n.startswith("backbone.") else head_params).append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma, preproc_module.alpha, preproc_module.beta, preproc_module.tau,
            float(preproc_module.blur_k) / 7.0, preproc_module.sharpen, preproc_module.denoise
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def evaluate_full(model, loader, preproc_module):
    model.eval()
    preproc_module.eval()
    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)
        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
        probs = torch.softmax(logits, dim=1)
        all_loss.append(float(F.cross_entropy(logits, y).item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if not all_y:
        return {k: np.nan for k in ["loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro", "log_loss"]}

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)
    out = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
    }
    try:
        out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
    except Exception:
        pass
    return out


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None):
    model.train()
    preproc_module.eval()
    losses = []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            optimizer.step()

        if scheduler is not None:
            scheduler.step()
        losses.append(float(loss.item()))

    return float(np.mean(losses))


def ga_fitness(theta, backbone_frozen, batch_x, batch_y):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)
    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())
    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
    emb = backbone_frozen(x_n)
    if isinstance(emb, (list, tuple)):
        emb = emb[-1].mean(dim=(2, 3))
    y = y.long()
    classes = torch.unique(y)
    sep = 0.0
    if len(classes) >= 2:
        centroids, vars_, sizes = [], [], []
        for c in classes:
            e = emb[y == c]
            if e.size(0) < 2:
                continue
            mu = e.mean(dim=0)
            centroids.append(mu)
            vars_.append((e - mu).pow(2).sum(dim=1).mean().item())
            sizes.append(e.size(0))
        if len(centroids) >= 2:
            centroids = torch.stack(centroids, dim=0)
            gm = centroids.mean(dim=0)
            between = sum(n * (c - gm).pow(2).sum().item() for c, n in zip(centroids, sizes))
            within = float(np.mean(vars_)) if vars_ else 1e-6
            sep = float(between / (within + 1e-6))
    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool):
    try:
        bx, by, *_ = next(iter(dl_for_eval))
    except Exception:
        return None, [], 0.0
    pop = elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)]
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())
    bx, by = bx[: CFG["batch_size"]].contiguous(), by[: CFG["batch_size"]].contiguous()
    for _ in range(CFG["ga_gens"]):
        scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            new_pop.append(mutate(crossover(p1, p2), p=0.75))
        pop = new_pop
    scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
    return scored[0][1], [th for _, th in scored[: CFG["ga_elites"]]], float(scored[0][0])


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        gsd[name].copy_(acc.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


def weighted_aggregate(mets):
    total = sum(w for _, _, w in mets)
    if total == 0:
        return {}
    keys = mets[0][1].keys()
    return {k: float(np.average([m[1].get(k, np.nan) for m in mets], weights=[m[2] for m in mets])) for k in keys}


print("Initializing global model...")
global_model = PVTv2B4_MultiScale(NUM_CLASSES, pretrained=True, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
set_trainable_for_round(global_model, 1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values + \
         train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
criterion = nn.CrossEntropyLoss(weight=torch.tensor(w, device=DEVICE), label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc, best_round_saved = -1.0, None
best_model_state = None

print(f"Training TRUE FL for {CFG['rounds']} rounds with backbone={BACKBONE_NAME} ...")
for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []
    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else nn.Identity().to(DEVICE)
        else:
            best_theta = None
            pre_k = nn.Identity().to(DEVICE)

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = PVTv2B4_MultiScale(NUM_CLASSES, pretrained=False, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)
        set_trainable_for_round(local_model, rnd)
        opt = make_optimizer(local_model)
        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        sched = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=sched, scaler=scaler)

        met = evaluate_full(local_model, val_loader, pre_k)
        local_rows.append({"val_acc": met.get("acc", np.nan), "val_size": len(val_loader.dataset)})
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))

    wsum = sum(local_weights)
    fedavg_update(global_model, local_models, [w / wsum for w in local_weights], [n for n, p in local_models[0].named_parameters() if p.requires_grad])

    lv = pd.DataFrame(local_rows)
    round_acc = float(np.average(lv["val_acc"], weights=lv["val_size"])) if lv["val_size"].sum() > 0 else np.nan
    if np.isfinite(round_acc) and round_acc > best_global_acc:
        best_global_acc = round_acc
        best_round_saved = rnd
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

print(f"Done training. Best federated VAL accuracy={best_global_acc:.4f} at round {best_round_saved}")

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# Re-run GA with the restored best global model so final VAL/TEST uses GA-FELCM (not Identity preproc)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

best_theta_by_client = {}
for k in range(CFG["clients_total"]):
    _, tune_loader, _ = client_loaders[k]
    ds_name = "ds1" if k < n_per_ds else "ds2"
    elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2
    th, _, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
    best_theta_by_client[k] = th


def preproc_for_client(client_id: int):
    th = best_theta_by_client.get(client_id)
    return theta_to_module(th).to(DEVICE) if th is not None else nn.Identity().to(DEVICE)


val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = preproc_for_client(k)
    met = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((k, met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = preproc_for_client(gid)
        met = evaluate_full(global_model, t_loader, pre)
        mets.append((gid, met, len(t_loader.dataset)))
    return weighted_aggregate(mets) if mets else {}


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([
    (0, test_ds1, len(test1)),
    (1, test_ds2, len(test2)),
])


def compact_metrics(m):
    keep = ["acc", "precision_macro", "recall_macro", "f1_macro", "log_loss", "loss_ce", "auc_roc_macro_ovr"]
    return {k: float(m[k]) for k in keep if k in m}


val_test_df = pd.DataFrame([
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **compact_metrics(val_best)},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **compact_metrics(test_ds1)},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **compact_metrics(test_ds2)},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **compact_metrics(global_test)},
])

print("\nFINAL OUTPUT (VAL/TEST metrics only):")
print(val_test_df)

checkpoint = {
    "state_dict": {k: v.detach().cpu() for k, v in global_model.state_dict().items()},
    "config": CFG,
    "seed": SEED,
    "device_used": str(DEVICE),
    "dataset1_raw_root": DS1_ROOT,
    "dataset2_root": DS2_ROOT,
    "labels": labels,
    "label2id": label2id,
    "id2label": id2label,
    "num_classes": NUM_CLASSES,
    "backbone_name": BACKBONE_NAME,
    "best_round_saved": best_round_saved,
    "best_val_acc": best_global_acc,
    "final_val_federated": val_best,
    "final_test_ds1": test_ds1,
    "final_test_ds2": test_ds2,
    "final_test_global_weighted": global_test,
}

torch.save(checkpoint, MODEL_PATH)
val_test_df.to_csv(CSV_PATH, index=False)
print(f"✅ Saved checkpoint: {MODEL_PATH}")
print(f"✅ Saved VAL/TEST-only CSV: {CSV_PATH}")

Downloading datasets via kagglehub...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.
Initializing global model...


model.safetensors:   0%|          | 0.00/250M [00:00<?, ?B/s]

Training TRUE FL for 12 rounds with backbone=pvt_v2_b4 ...
Done training. Best federated VAL accuracy=0.9858 at round 12

FINAL OUTPUT (VAL/TEST metrics only):
                           setting split           dataset       acc  \
0  Enhanced FELCM (Best θ ds1/ds2)   VAL  ds1+ds2 weighted  0.939968   
1      Enhanced FELCM (Best θ ds1)  TEST               ds1  0.933628   
2      Enhanced FELCM (Best θ ds2)  TEST               ds2  0.945972   
3          Enhanced FELCM (Best θ)  TEST   global weighted  0.943794   

   precision_macro  recall_macro  f1_macro  log_loss   loss_ce  \
0         0.769559      0.872017  0.775791  0.202708  0.202017   
1         0.938777      0.934354  0.934505  0.262773  0.263721   
2         0.943232      0.941978  0.942226  0.191857  0.192667   
3         0.942446      0.940633  0.940864  0.204368  0.205203   

   auc_roc_macro_ovr  
0                NaN  
1           0.988137  
2           0.993236  
3           0.992336  
✅ Saved checkpoint: /content/outp

# **PVTv2-B5**

In [1]:
import os, time, math, random, sys, subprocess, hashlib
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, roc_auc_score
)

# ============================================================
# TRUE FL + GA-FELCM + PVTv2-B5 (FUSION) — 6 Clients (3+3)
# Preprocessing ON + Augmentation ON + Fusion ON
# CHANGES REQUESTED:
#  - backbone -> PVTv2-B5
#  - rounds -> 12
#  - output -> only VAL/TEST metrics
# Saves:
#  - checkpoint
#  - one CSV with only VAL/TEST metrics
# ============================================================


def pip_install(pkg):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

try:
    import timm
except Exception:
    pip_install("timm")
    import timm

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,  # requested
    "local_epochs": 2,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    "img_size": 224 if torch.cuda.is_available() else 160,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
}

OUTDIR = "/content/outputs"
os.makedirs(OUTDIR, exist_ok=True)
MODEL_PATH = os.path.join(OUTDIR, "FL_GAFELCM_PVTv2B5_FUSION_checkpoint.pth")
CSV_PATH = os.path.join(OUTDIR, "VAL_TEST_METRICS_ONLY.csv")

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "notumor" in s or "no tumor" in s or "no_tumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(os.path.basename)
    return dfm


def enforce_labels(df_):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


print("Downloading datasets via kagglehub...")
ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}

DS1_ROOT = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
DS2_ROOT = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)

if DS1_ROOT is None or DS2_ROOT is None:
    raise RuntimeError("Could not locate required dataset roots")

df1 = enforce_labels(build_df_from_root(DS1_ROOT, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw"))
df2 = enforce_labels(build_df_from_root(DS2_ROOT, ["glioma", "meningioma", "notumor", "pituitary"], "ds2"))


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


train1, val1, test1 = split_dataset(df1)
train2, val2, test2 = split_dataset(df2)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]
    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if not idxs:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        counts[np.argmax(props)] += len(idxs) - counts.sum()
        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start : start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(idxs, test_size=tune_frac, stratify=yk, random_state=SEED)

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(rem_idx, test_size=val_frac, stratify=yk2, random_state=SEED)

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


n_per_ds = CFG["clients_per_dataset"]
client_indices_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
client_indices_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

client_splits = []
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds1", k, k, tr, tune, va))
for k in range(n_per_ds):
    tr, tune, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
    client_splits.append(("ds2", k, n_per_ds + k, tr, tune, va))

client_test_splits = []
for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        client_test_splits.append((ds_name, k, base_gid + k, split[k].tolist()))


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([transforms.Resize((CFG["img_size"], CFG["img_size"])), transforms.ToTensor()])
if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        return x, int(row["y"]), row["path"], self.source_id, self.client_id


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(ds, batch_size=bs, shuffle=(shuffle and sampler is None), sampler=sampler,
                      num_workers=CFG["num_workers"], pin_memory=(DEVICE.type == "cuda"), drop_last=False,
                      persistent_workers=(CFG["num_workers"] > 0))


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    return WeightedRandomSampler(torch.DoubleTensor(class_weights[ys]), num_samples=len(ys), replacement=True)


client_loaders = []
for ds_name, _, gid, tr_idx, tune_idx, val_idx in client_splits:
    src = train1 if ds_name == "ds1" else train2
    source_id = 0 if ds_name == "ds1" else 1
    sampler = make_weighted_sampler(src, tr_idx, NUM_CLASSES)
    tr_loader = make_loader(src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
    tune_loader = make_loader(src, tune_idx if tune_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)
    val_loader = make_loader(src, val_idx if val_idx else tr_idx[:1], CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)
    client_loaders.append((tr_loader, tune_loader, val_loader))

client_test_loaders = []
for ds_name, _, gid, test_idx in client_test_splits:
    src = test1 if ds_name == "ds1" else test2
    source_id = 0 if ds_name == "ds1" else 1
    client_test_loaders.append((ds_name, gid, make_loader(src, test_idx, CFG["batch_size"], EVAL_TFMS, source_id=source_id, client_id=gid)))


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)
        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))
        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        if self.denoise > 0:
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), 3, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise
        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)
        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap).abs()
        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(lap, (pad, pad, pad, pad), mode="reflect"), k, 1)
        x2 = x1 + self.alpha * torch.tanh(self.beta * (lap / (blur + eps)))
        if self.sharpen > 0:
            outs = []
            for c in range(x2.shape[1]):
                xc = x2[:, c:c+1]
                xs = F.conv2d(F.pad(xc, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                outs.append(xc * (1 - self.sharpen) + xs * self.sharpen)
            x2 = torch.cat(outs, dim=1)
        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        return ((x2 - mn) / (mx - mn + eps)).clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    return (
        random.uniform(0.7, 1.4),
        random.uniform(0.15, 0.55),
        random.uniform(3.0, 9.0),
        random.uniform(1.8, 3.2),
        random.choice([3, 5, 7]),
        random.uniform(0.0, 0.25),
        random.uniform(0.0, 0.2),
    )


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(nn.Conv2d(c, out_dim, 1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU()) for c in in_channels
        ])
        self.fuse = nn.Sequential(nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU())
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(nn.Linear(dim, max(8, dim // 4)), nn.ReLU(inplace=True), nn.Linear(max(8, dim // 4), dim), nn.Sigmoid())
        self.refine = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim, dim))
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


BACKBONE_NAME = "pvt_v2_b5"  # requested


class PVTv2B5_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = timm.create_model(BACKBONE_NAME, pretrained=pretrained, features_only=True, out_indices=(0, 1, 2, 3))
        in_channels = self.backbone.feature_info.channels()
        out_dim = max(256, in_channels[-1] // 2)
        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)
        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim), nn.Dropout(head_dropout), nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(), nn.Dropout(head_dropout * 0.5), nn.Linear(max(64, out_dim // 2), num_classes)
        )
        self.theta_mlp = nn.Sequential(nn.Linear(7, cond_dim), nn.GELU(), nn.Linear(cond_dim, cond_dim))
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)
        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)
        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n
        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)
        f0, f1 = self.fuser(feats0), self.fuser(feats1)
        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1
        t0, t1, t_mid = self.tuner(f0), self.tuner(f1), self.tuner(f_mid)
        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views
        return self.classifier(t_final)


def set_trainable_for_round(model, rnd):
    for p in model.backbone.parameters():
        p.requires_grad = False
    for n, p in model.named_parameters():
        if not n.startswith("backbone."):
            p.requires_grad = True
    if rnd >= CFG["unfreeze_after_round"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (bb_params if n.startswith("backbone.") else head_params).append(p)
    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma, preproc_module.alpha, preproc_module.beta, preproc_module.tau,
            float(preproc_module.blur_k) / 7.0, preproc_module.sharpen, preproc_module.denoise
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def evaluate_full(model, loader, preproc_module):
    model.eval()
    preproc_module.eval()
    all_y, all_p, all_loss = [], [], []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)
        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
        probs = torch.softmax(logits, dim=1)
        all_loss.append(float(F.cross_entropy(logits, y).item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if not all_y:
        return {k: np.nan for k in ["loss_ce", "acc", "precision_macro", "recall_macro", "f1_macro", "log_loss"]}

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)
    out = {
        "loss_ce": float(np.mean(all_loss)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
    }
    try:
        out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
    except Exception:
        pass
    return out


def fedprox_term(local_model, global_model):
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        loss += ((p_local - p_global.detach()) ** 2).sum()
    return loss


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None):
    model.train()
    preproc_module.eval()
    losses = []
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            optimizer.step()

        if scheduler is not None:
            scheduler.step()
        losses.append(float(loss.item()))

    return float(np.mean(losses))


def ga_fitness(theta, backbone_frozen, batch_x, batch_y):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)
    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())
    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
    emb = backbone_frozen(x_n)
    if isinstance(emb, (list, tuple)):
        emb = emb[-1].mean(dim=(2, 3))
    y = y.long()
    classes = torch.unique(y)
    sep = 0.0
    if len(classes) >= 2:
        centroids, vars_, sizes = [], [], []
        for c in classes:
            e = emb[y == c]
            if e.size(0) < 2:
                continue
            mu = e.mean(dim=0)
            centroids.append(mu)
            vars_.append((e - mu).pow(2).sum(dim=1).mean().item())
            sizes.append(e.size(0))
        if len(centroids) >= 2:
            centroids = torch.stack(centroids, dim=0)
            gm = centroids.mean(dim=0)
            between = sum(n * (c - gm).pow(2).sum().item() for c, n in zip(centroids, sizes))
            within = float(np.mean(vars_)) if vars_ else 1e-6
            sep = float(between / (within + 1e-6))
    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost


def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool):
    try:
        bx, by, *_ = next(iter(dl_for_eval))
    except Exception:
        return None, [], 0.0
    pop = elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)]
    while len(pop) < CFG["ga_pop"]:
        pop.append(random_theta())
    bx, by = bx[: CFG["batch_size"]].contiguous(), by[: CFG["batch_size"]].contiguous()
    for _ in range(CFG["ga_gens"]):
        scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            new_pop.append(mutate(crossover(p1, p2), p=0.75))
        pop = new_pop
    scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
    return scored[0][1], [th for _, th in scored[: CFG["ga_elites"]]], float(scored[0][0])


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        gsd[name].copy_(acc.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)


def weighted_aggregate(mets):
    total = sum(w for _, _, w in mets)
    if total == 0:
        return {}
    keys = mets[0][1].keys()
    return {k: float(np.average([m[1].get(k, np.nan) for m in mets], weights=[m[2] for m in mets])) for k in keys}


print("Initializing global model...")
global_model = PVTv2B5_MultiScale(NUM_CLASSES, pretrained=True, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
set_trainable_for_round(global_model, 1)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

counts = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values + \
         train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
criterion = nn.CrossEntropyLoss(weight=torch.tensor(w, device=DEVICE), label_smoothing=CFG["label_smoothing"])
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

elite_pool_ds1, elite_pool_ds2 = [], []
best_global_acc, best_round_saved = -1.0, None
best_model_state = None

print(f"Training TRUE FL for {CFG['rounds']} rounds with backbone={BACKBONE_NAME} ...")
for rnd in range(1, CFG["rounds"] + 1):
    local_models, local_weights, local_rows = [], [], []
    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = "ds1" if k < n_per_ds else "ds2"
        elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2

        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
            elite_pool.extend(top_thetas)
            elite_pool[:] = elite_pool[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else nn.Identity().to(DEVICE)
        else:
            best_theta = None
            pre_k = nn.Identity().to(DEVICE)

        if ds_name == "ds1":
            elite_pool_ds1 = elite_pool
        else:
            elite_pool_ds2 = elite_pool

        local_model = PVTv2B5_MultiScale(NUM_CLASSES, pretrained=False, head_dropout=CFG["head_dropout"], cond_dim=CFG["cond_dim"], num_clients=CFG["clients_total"]).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)
        set_trainable_for_round(local_model, rnd)
        opt = make_optimizer(local_model)
        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        sched = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(local_model, tr_loader, opt, pre_k, criterion, global_model=global_model, scheduler=sched, scaler=scaler)

        met = evaluate_full(local_model, val_loader, pre_k)
        local_rows.append({"val_acc": met.get("acc", np.nan), "val_size": len(val_loader.dataset)})
        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))

    wsum = sum(local_weights)
    fedavg_update(global_model, local_models, [w / wsum for w in local_weights], [n for n, p in local_models[0].named_parameters() if p.requires_grad])

    lv = pd.DataFrame(local_rows)
    round_acc = float(np.average(lv["val_acc"], weights=lv["val_size"])) if lv["val_size"].sum() > 0 else np.nan
    if np.isfinite(round_acc) and round_acc > best_global_acc:
        best_global_acc = round_acc
        best_round_saved = rnd
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

print(f"Done training. Best federated VAL accuracy={best_global_acc:.4f} at round {best_round_saved}")

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})

# Re-run GA with the restored best global model so final VAL/TEST uses GA-FELCM (not Identity preproc)
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

best_theta_by_client = {}
for k in range(CFG["clients_total"]):
    _, tune_loader, _ = client_loaders[k]
    ds_name = "ds1" if k < n_per_ds else "ds2"
    elite_pool = elite_pool_ds1 if ds_name == "ds1" else elite_pool_ds2
    th, _, _ = run_ga_for_client(backbone_frozen, tune_loader, elite_pool)
    best_theta_by_client[k] = th


def preproc_for_client(client_id: int):
    th = best_theta_by_client.get(client_id)
    return theta_to_module(th).to(DEVICE) if th is not None else nn.Identity().to(DEVICE)


val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    pre = preproc_for_client(k)
    met = evaluate_full(global_model, val_loader, pre)
    val_metrics_clients.append((k, met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)


def eval_test_per_dataset(ds_name):
    mets = []
    for ds, gid, t_loader in client_test_loaders:
        if ds != ds_name:
            continue
        pre = preproc_for_client(gid)
        met = evaluate_full(global_model, t_loader, pre)
        mets.append((gid, met, len(t_loader.dataset)))
    return weighted_aggregate(mets) if mets else {}


test_ds1 = eval_test_per_dataset("ds1")
test_ds2 = eval_test_per_dataset("ds2")
global_test = weighted_aggregate([
    (0, test_ds1, len(test1)),
    (1, test_ds2, len(test2)),
])


def compact_metrics(m):
    keep = ["acc", "precision_macro", "recall_macro", "f1_macro", "log_loss", "loss_ce", "auc_roc_macro_ovr"]
    return {k: float(m[k]) for k in keep if k in m}


val_test_df = pd.DataFrame([
    {"setting": "Enhanced FELCM (Best θ ds1/ds2)", "split": "VAL", "dataset": "ds1+ds2 weighted", **compact_metrics(val_best)},
    {"setting": "Enhanced FELCM (Best θ ds1)", "split": "TEST", "dataset": "ds1", **compact_metrics(test_ds1)},
    {"setting": "Enhanced FELCM (Best θ ds2)", "split": "TEST", "dataset": "ds2", **compact_metrics(test_ds2)},
    {"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **compact_metrics(global_test)},
])

print("\nFINAL OUTPUT (VAL/TEST metrics only):")
print(val_test_df)

checkpoint = {
    "state_dict": {k: v.detach().cpu() for k, v in global_model.state_dict().items()},
    "config": CFG,
    "seed": SEED,
    "device_used": str(DEVICE),
    "dataset1_raw_root": DS1_ROOT,
    "dataset2_root": DS2_ROOT,
    "labels": labels,
    "label2id": label2id,
    "id2label": id2label,
    "num_classes": NUM_CLASSES,
    "backbone_name": BACKBONE_NAME,
    "best_round_saved": best_round_saved,
    "best_val_acc": best_global_acc,
    "final_val_federated": val_best,
    "final_test_ds1": test_ds1,
    "final_test_ds2": test_ds2,
    "final_test_global_weighted": global_test,
}

torch.save(checkpoint, MODEL_PATH)
val_test_df.to_csv(CSV_PATH, index=False)
print(f"✅ Saved checkpoint: {MODEL_PATH}")
print(f"✅ Saved VAL/TEST-only CSV: {CSV_PATH}")

  self.setter(val)


Downloading datasets via kagglehub...
Downloading from https://www.kaggle.com/api/v1/datasets/download/yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection?dataset_version_number=1...


100%|██████████| 130M/130M [00:00<00:00, 187MB/s]

Extracting files...





Downloading from https://www.kaggle.com/api/v1/datasets/download/orvile/pmram-bangladeshi-brain-cancer-mri-dataset?dataset_version_number=2...


100%|██████████| 161M/161M [00:01<00:00, 104MB/s]

Extracting files...





Initializing global model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/328M [00:00<?, ?B/s]

Training TRUE FL for 12 rounds with backbone=pvt_v2_b5 ...
Done training. Best federated VAL accuracy=0.9779 at round 10

FINAL OUTPUT (VAL/TEST metrics only):
                           setting split           dataset       acc  \
0  Enhanced FELCM (Best θ ds1/ds2)   VAL  ds1+ds2 weighted  0.908373   
1      Enhanced FELCM (Best θ ds1)  TEST               ds1  0.907080   
2      Enhanced FELCM (Best θ ds2)  TEST               ds2  0.939336   
3          Enhanced FELCM (Best θ)  TEST   global weighted  0.933646   

   precision_macro  recall_macro  f1_macro  log_loss   loss_ce  \
0         0.738625      0.926706  0.742622  0.277366  0.269406   
1         0.903363      0.908528  0.903041  0.288221  0.288536   
2         0.934881      0.933925  0.933910  0.217846  0.218045   
3         0.929320      0.929444  0.928464  0.230262  0.230481   

   auc_roc_macro_ovr  
0                NaN  
1           0.986038  
2           0.992015  
3           0.990961  
✅ Saved checkpoint: /content/outp