**“Local-only (no aggregation)”** means:

### Meaning

Each client (a hospital, site, or dataset partition) **trains its own model only on its own data**, and **nothing is shared or combined across clients**.

So:

* **No FedAvg** (no averaging of weights)
* **No global model**
* **No federation rounds**
* You end up with **6 separate models** (client_0 … client_5)

### Why it’s used in your paper

It answers:
“Is federated collaboration actually helping, or would each site do fine alone?”

Usually you show:

* **Average performance across clients**
* **Worst client performance**
* Maybe **DS1 average vs DS2 average**
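
The three summaries listed above can be computed from a table of per-client metrics. The sketch below is only an illustration: the accuracy values are hypothetical placeholders (not results from the run), `summarize` is a helper name invented here, and the client names simply follow the `ds1_client_0` naming pattern.

```python
from statistics import mean, stdev

# Hypothetical per-client test accuracies (placeholders, not results from the run).
client_acc = {
    "ds1_client_0": 0.80,
    "ds1_client_1": 0.70,
    "ds1_client_2": 0.60,
    "ds2_client_3": 0.90,
    "ds2_client_4": 0.50,
    "ds2_client_5": 0.70,
}


def summarize(metrics: dict) -> dict:
    """Return mean, sample std, worst client, and per-dataset means."""
    values = list(metrics.values())
    worst_name = min(metrics, key=metrics.get)
    by_dataset = {}
    for name, value in metrics.items():
        dataset = name.split("_client_")[0]  # "ds1" or "ds2"
        by_dataset.setdefault(dataset, []).append(value)
    return {
        "mean": mean(values),
        "std": stdev(values),
        "worst_client": worst_name,
        "worst_value": metrics[worst_name],
        "dataset_mean": {d: mean(v) for d, v in by_dataset.items()},
    }


summary = summarize(client_acc)
print(f"acc = {summary['mean']:.3f} ± {summary['std']:.3f}")
print(f"worst = {summary['worst_client']} ({summary['worst_value']:.3f})")
print(summary["dataset_mean"])
```

Reporting the worst client next to the mean matters because a good average can hide one site that performs poorly on its own.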

### How to implement with your setup (conceptually)

For each client (k):

1. Initialize model weights
2. Train on that client’s train loader only
3. Evaluate on that client’s test loader
4. Save metrics

Then report all 6 results (mean ± std).
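
The four steps can be sketched end to end on a toy problem. This is a minimal illustration, assuming a one-parameter linear model and synthetic data instead of the real image model; `make_client_data`, `train_local`, and `mse` are hypothetical helpers, not functions from your notebook.

```python
import random
from statistics import mean, stdev


def make_client_data(slope, n, rng):
    """Toy 1-D regression data y = slope * x + noise for one client (hypothetical)."""
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, slope * x + rng.gauss(0.0, 0.05)) for x in xs]


def train_local(data, epochs=200, lr=0.1):
    """Steps 1-2: initialize a weight, then fit it on this client's data only."""
    w = 0.0
    for _ in range(epochs):
        grad = mean(2.0 * (w * x - y) * x for x, y in data)
        w -= lr * grad
    return w


def mse(w, data):
    """Step 3: mean squared error on this client's own test split."""
    return mean((w * x - y) ** 2 for x, y in data)


rng = random.Random(42)
results = []
for k in range(6):
    slope = 1.0 + 0.5 * k  # each client draws from its own distribution
    train = make_client_data(slope, 80, rng)
    test = make_client_data(slope, 20, rng)
    w_k = train_local(train)  # no weights ever leave the client
    results.append({"client": k, "w": w_k, "test_mse": mse(w_k, test)})  # step 4

losses = [r["test_mse"] for r in results]
print(f"test MSE = {mean(losses):.4f} ± {stdev(losses):.4f} over {len(results)} clients")
```

Each client ends with a different weight because nothing ties the six models together, which is the defining property of the local-only baseline.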

### How to write it in one line

> **Local-only**: independent training at each client without any server aggregation or parameter sharing.

If you want, I can tell you the exact minimal changes in your loop (basically: remove `fedavg_update()` and stop copying `global_model` into locals each round).
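
As a rough sketch of that change, the loop below uses toy scalar weights and a single `aggregate` switch. The `fedavg_update` here is a hypothetical stand-in (an unweighted mean) and may not match the signature of the function in your code; `local_step` and `run` are likewise invented for illustration.

```python
from statistics import mean


def local_step(w, target, lr=0.5):
    """One toy gradient step pulling weight w toward this client's optimum."""
    return w - lr * (w - target)


def fedavg_update(client_weights):
    """Hypothetical stand-in: unweighted average of the client weights."""
    return mean(client_weights)


def run(targets, rounds=20, aggregate=True):
    global_w = 0.0
    local_ws = [0.0 for _ in targets]
    for _ in range(rounds):
        if aggregate:
            # Federated: every client starts the round from the global weight.
            local_ws = [global_w for _ in targets]
        # Both settings: each client takes a step on its own objective.
        local_ws = [local_step(w, t) for w, t in zip(local_ws, targets)]
        if aggregate:
            # Federated: the server averages the client weights.
            global_w = fedavg_update(local_ws)
    return local_ws, global_w


targets = [1.0, 2.0, 3.0]
fed_ws, fed_global = run(targets, aggregate=True)
loc_ws, _ = run(targets, aggregate=False)
print("FedAvg global weight:", round(fed_global, 3))
print("Local-only weights:  ", [round(w, 3) for w in loc_ws])
```

With `aggregate=True` the run settles on one compromise weight near the mean of the client optima; with `aggregate=False` each weight stays near its own client's optimum, so skipping the two `if aggregate:` branches is the whole difference.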


In [1]:
import os
import sys
import time
import math
import random
import subprocess
import hashlib
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    log_loss,
    roc_auc_score,
)


def pip_install(pkg: str):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


try:
    import timm
except Exception:
    pip_install("timm")
    import timm

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms

# -------------------------
# Reproducibility + Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
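if DEVICE.type == "cuda":
    # Favors speed over bit-exact reproducibility: cuDNN autotuning and TF32
    # can make GPU runs differ slightly between executions despite the fixed SEED.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True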
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

print("=" * 92)
print("LOCAL-ONLY (NO AGGREGATION) + GA-FELCM + PVTv2-B2 (FUSION) | TEST METRICS ONLY")
print("=" * 92)
print(f"DEVICE: {DEVICE} | torch={torch.__version__}")
print("=" * 92)

CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "local_epochs": 4,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "img_size": 224 if torch.cuda.is_available() else 160,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "head_dropout": 0.3,
    "cond_dim": 128,
}

OUTDIR = "/content/outputs"
os.makedirs(OUTDIR, exist_ok=True)
CSV_PATH = os.path.join(OUTDIR, "LOCAL_ONLY_TEST_METRICS.csv")

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

# -------------------------
# Dataset discovery
# -------------------------
REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
NUM_CLASSES = len(labels)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].apply(os.path.basename)
    dfm["y"] = dfm["label"].map(label2id).astype(int)
    return dfm


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])
    client_indices = [[] for _ in range(n_clients)]

    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff
        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start:start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(idxs, test_size=tune_frac, stratify=yk, random_state=SEED)

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(rem_idx, test_size=val_frac, stratify=yk2, random_state=SEED)

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


# -------------------------
# Data loading
# -------------------------
def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])
TRAIN_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.15, contrast=0.15),
    transforms.ToTensor(),
])


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        x = self.tfms(load_rgb(row["path"]))
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(torch.DoubleTensor(sample_weights), len(sample_weights), replacement=True)


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


# -------------------------
# Preprocessor + model
# -------------------------
class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        # Stored and passed into the conditioning vector and GA cost; forward() does not apply it.
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        if self.denoise > 0:
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), 3, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap).abs()
        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(lap, (pad, pad, pad, pad), mode="reflect"), k, 1)
        c_map = lap / (blur + eps)

        x2 = x1 + self.alpha * torch.tanh(self.beta * c_map)
        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        x3 = (x2 - mn) / (mx - mn + eps)
        return x3.clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    return (
        random.uniform(0.7, 1.4),
        random.uniform(0.15, 0.55),
        random.uniform(3.0, 9.0),
        random.uniform(1.8, 3.2),
        random.choice([3, 5, 7]),
        random.uniform(0.0, 0.25),
        random.uniform(0.0, 0.2),
    )


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(nn.Conv2d(c, out_dim, 1, bias=False), nn.GroupNorm(8, out_dim), nn.GELU())
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class PVTv2B2_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = timm.create_model("pvt_v2_b2", pretrained=pretrained, features_only=True, out_indices=(0, 1, 2, 3))
        in_channels = self.backbone.feature_info.channels()
        out_dim = max(256, in_channels[-1] // 2)
        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )
        self.theta_mlp = nn.Sequential(nn.Linear(7, cond_dim), nn.GELU(), nn.Linear(cond_dim, cond_dim))
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)
        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        return self.cond_norm(self.theta_mlp(theta_vec) + self.source_emb(source_id) + self.client_emb(client_id))

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)
        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n
        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)
        f0, f1 = self.fuser(feats0), self.fuser(feats1)
        g1 = torch.sigmoid(self.gate_mid(cond))
        f = (1 - g1) * f0 + g1 * f1
        return self.classifier(f)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor([
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ], device=DEVICE, dtype=torch.float32)
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


@torch.no_grad()
def enhanced_separability_score(emb, y):
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0
    centroids, within_vars, sizes = [], [], []
    for c in classes:
        e = emb[y == c]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        var = (e - mu).pow(2).sum(dim=1).mean().item()
        centroids.append(mu)
        within_vars.append(var)
        sizes.append(e.size(0))
    if len(centroids) < 2:
        return 0.0
    centroids = torch.stack(centroids, dim=0)
    global_mean = centroids.mean(dim=0)
    between = sum(n * (c - global_mean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else 1e-6
    return float(between / (within + 1e-6))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)
    x_p = pre(x)

    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    x_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
    emb = backbone_frozen(x_n)
    if isinstance(emb, (list, tuple)):
        emb = emb[-1].mean(dim=(2, 3))
    sep = enhanced_separability_score(emb, y)

    g, a, b, t, _, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)
    return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost


def run_ga_for_client(backbone_frozen, dl_for_eval):
    try:
        bx, by, *_ = next(iter(dl_for_eval))
    except Exception:
        return None

    pop = [random_theta() for _ in range(CFG["ga_pop"])]
    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    for _ in range(CFG["ga_gens"]):
        scored = [(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop]
        scored.sort(key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            new_pop.append(mutate(crossover(p1, p2), p=0.75))
        pop = new_pop

    scored = [(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop]
    scored.sort(key=lambda x: x[0], reverse=True)
    return scored[0][1]


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, scheduler=None, scaler=None):
    model.train()
    preproc_module.eval()
    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


@torch.no_grad()
def evaluate_full(model, loader, preproc_module):
    t0 = time.time()
    model.eval()
    preproc_module.eval()
    all_y, all_p, all_loss = [], [], []

    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
        probs = torch.softmax(logits, dim=1)

        all_loss.append(float(F.cross_entropy(logits, y).item()))
        all_y.append(y.cpu().numpy())
        all_p.append(probs.cpu().numpy())

    y_true = np.concatenate(all_y) if all_y else np.array([])
    p_pred = np.concatenate(all_p) if all_p else np.array([])
    if len(y_true) == 0:
        return {k: np.nan for k in [
            "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted",
            "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
        ]}

    y_hat = np.argmax(p_pred, axis=1)
    out = {
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "loss_ce": float(np.mean(all_loss)),
        "eval_time_s": float(time.time() - t0),
    }
    try:
        out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
    except Exception:
        out["auc_roc_macro_ovr"] = np.nan
    return out


if __name__ == "__main__":
    print("Downloading datasets...")
    ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
    ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

    ds1_root = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
    ds2_root = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)
    if ds1_root is None or ds2_root is None:
        raise RuntimeError("Could not discover dataset roots")

    df1 = build_df_from_root(ds1_root, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw")
    df2 = build_df_from_root(ds2_root, ["glioma", "meningioma", "notumor", "pituitary"], "ds2")

    train1, _, test1 = split_dataset(df1)
    train2, _, test2 = split_dataset(df2)

    n_per_ds = CFG["clients_per_dataset"]
    client_idx_ds1 = make_clients_non_iid(train1, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])
    client_idx_ds2 = make_clients_non_iid(train2, n_per_ds, NUM_CLASSES, CFG["min_per_class_per_client"], CFG["dirichlet_alpha"])

    client_splits = []
    for k in range(n_per_ds):
        tr, tune, _ = robust_client_splits(train1, client_idx_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
        client_splits.append(("ds1", k, k, tr, tune))
    for k in range(n_per_ds):
        tr, tune, _ = robust_client_splits(train2, client_idx_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
        client_splits.append(("ds2", k, n_per_ds + k, tr, tune))

    client_test_splits = []
    for ds_name, test_df, gid_base in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
        idxs = list(range(len(test_df)))
        random.shuffle(idxs)
        split = np.array_split(idxs, n_per_ds)
        for k in range(n_per_ds):
            client_test_splits.append((ds_name, k, gid_base + k, split[k].tolist()))

    counts1 = train1["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
    counts2 = train2["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
    counts = counts1 + counts2
    w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
    w = w / max(1e-6, w.mean())
    class_w = torch.tensor(w, device=DEVICE)
    criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
    scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

    rows = []
    ds_accumulator: Dict[str, List[Tuple[dict, int]]] = {"ds1": [], "ds2": []}

    for ds_name, _, gid, tr_idx, tune_idx in client_splits:
        print(f"\nTraining local model for client_{gid} ({ds_name})...")
        df_train = train1 if ds_name == "ds1" else train2
        df_test = test1 if ds_name == "ds1" else test2
        source_id = 0 if ds_name == "ds1" else 1

        sampler = make_weighted_sampler(df_train, tr_idx, NUM_CLASSES)
        tr_loader = make_loader(df_train, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler, source_id=source_id, client_id=gid)
        tune_loader = make_loader(df_train, tune_idx if len(tune_idx) else tr_idx, CFG["batch_size"], EVAL_TFMS, shuffle=True, source_id=source_id, client_id=gid)

        test_idx = next(x[3] for x in client_test_splits if x[2] == gid)
        test_loader = make_loader(df_test, test_idx, CFG["batch_size"], EVAL_TFMS, shuffle=False, source_id=source_id, client_id=gid)

        model = PVTv2B2_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=True,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
        ).to(DEVICE)

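        # Freeze the backbone; only the modules outside it are trained.
        for p in model.backbone.parameters():
            p.requires_grad = False

        # Same module object as model.backbone (not a copy). Note that model.train()
        # inside train_one_epoch() switches it back to train mode.
        backbone_frozen = model.backbone.eval()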

        theta = run_ga_for_client(backbone_frozen, tune_loader) if (CFG["use_preprocessing"] and CFG["use_ga"]) else None
        pre = theta_to_module(theta).to(DEVICE) if theta is not None else nn.Identity().to(DEVICE)

        params = [p for p in model.parameters() if p.requires_grad]
        opt = torch.optim.AdamW(params, lr=CFG["lr"], weight_decay=CFG["weight_decay"])
        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        for _ in range(CFG["local_epochs"]):
            train_one_epoch(model, tr_loader, opt, pre, criterion, scheduler=scheduler, scaler=scaler)

        met = evaluate_full(model, test_loader, pre)
        row = {
            "setting": "Local-only (no aggregation)",
            "split": "TEST",
            "dataset": f"{ds_name}_client_{gid}",
            **{k: met.get(k, np.nan) for k in [
                "acc", "precision_macro", "recall_macro", "f1_macro",
                "precision_weighted", "recall_weighted", "f1_weighted",
                "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
            ]},
        }
        rows.append(row)
        ds_accumulator[ds_name].append((met, len(test_loader.dataset)))

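    def weighted_avg(parts):
        """Sample-count-weighted mean of each metric, skipping NaN entries."""
        if not parts:
            return {}
        keys = [
            "acc", "precision_macro", "recall_macro", "f1_macro",
            "precision_weighted", "recall_weighted", "f1_weighted",
            "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
        ]
        out = {}
        for k in keys:
            vals = np.array([m.get(k, np.nan) for m, _ in parts], dtype=float)
            ws = np.array([w for _, w in parts], dtype=float)
            # A single NaN (e.g. an undefined AUC) would otherwise turn the whole average into NaN.
            ok = ~np.isnan(vals) & (ws > 0)
            out[k] = float(np.average(vals[ok], weights=ws[ok])) if ok.any() else float("nan")
        return out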

    ds1_agg = weighted_avg(ds_accumulator["ds1"])
    ds2_agg = weighted_avg(ds_accumulator["ds2"])
    global_agg = weighted_avg(ds_accumulator["ds1"] + ds_accumulator["ds2"])

    for dname, m in [("ds1_mean", ds1_agg), ("ds2_mean", ds2_agg), ("global_weighted", global_agg)]:
        rows.append({
            "setting": "Local-only (no aggregation)",
            "split": "TEST",
            "dataset": dname,
            **{k: m.get(k, np.nan) for k in [
                "acc", "precision_macro", "recall_macro", "f1_macro",
                "precision_weighted", "recall_weighted", "f1_weighted",
                "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
            ]},
        })

    out_df = pd.DataFrame(rows, columns=[
        "setting", "split", "dataset", "acc", "precision_macro", "recall_macro", "f1_macro",
        "precision_weighted", "recall_weighted", "f1_weighted", "log_loss",
        "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
    ])

    print("\nFinal TEST metrics (requested format):")
    print(out_df.to_string(index=False))

    out_df.to_csv(CSV_PATH, index=False)
    print(f"\nSaved: {CSV_PATH}")




LOCAL-ONLY (NO AGGREGATION) + GA-FELCM + PVTv2-B2 (FUSION) | TEST METRICS ONLY
DEVICE: cpu | torch=2.9.0+cpu
Downloading datasets...
Downloading from https://www.kaggle.com/api/v1/datasets/download/yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection?dataset_version_number=1...


100%|██████████| 130M/130M [00:00<00:00, 173MB/s]

Extracting files...





Downloading from https://www.kaggle.com/api/v1/datasets/download/orvile/pmram-bangladeshi-brain-cancer-mri-dataset?dataset_version_number=2...


100%|██████████| 161M/161M [00:01<00:00, 133MB/s]

Extracting files...






Training local model for client_0 (ds1)...




Training local model for client_1 (ds1)...

Training local model for client_2 (ds1)...

Training local model for client_3 (ds2)...

Training local model for client_4 (ds2)...





Training local model for client_5 (ds2)...

Final TEST metrics (requested format):
                    setting split         dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Local-only (no aggregation)  TEST    ds1_client_0 0.736842         0.793881      0.740198  0.678583            0.800831         0.736842     0.681907  0.849699           0.933505 0.864920    21.507455
Local-only (no aggregation)  TEST    ds1_client_1 0.706667         0.816667      0.735726  0.706815            0.817778         0.706667     0.683605  0.922017           0.941527 0.915738    21.081559
Local-only (no aggregation)  TEST    ds1_client_2 0.626667         0.689286      0.547269  0.466917            0.668762         0.626667     0.520929  1.090790           0.800643 1.145110    21.619763
Local-only (no aggregation)  TEST    ds2_client_3 0.548295         0.682418      0.538990  0.461983            0