**“Centralized GCF-Net”** means: you train **the same FedGCF-Net architecture/process**, but **NOT federated**.

### What you do in centralized GCF-Net

* **Merge the training data** (DS1 + DS2 train splits) into **one single dataset**
* Train **one model on one machine** (standard training)
* Validate/test on DS1 and DS2 test splits (same as you do now)

### What you do NOT do

* No clients
* No FedAvg aggregation
* No FedProx term (unless you keep it, but usually you turn it off)
* No per-client conditioning (you can set `client_id=0` for all, or remove client embedding)

### Why it’s useful

* It’s the **upper bound** when data sharing is allowed.
* It shows: “If we could centralize data


In [1]:
import os
import time
import math
import random
import sys
import subprocess
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    log_loss,
    roc_auc_score,
)


def pip_install(pkg: str):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])


try:
    import timm
except Exception:
    pip_install("timm")
    import timm

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms


SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

CFG = {
    # centralized runtime defaults (faster than FL while keeping architecture/process)
    "epochs": 6,
    "early_stop_patience": 2,
    "batch_size": 32 if torch.cuda.is_available() else 8,
    "img_size": 224 if torch.cuda.is_available() else 128,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,

    # mimic prior split style
    "global_val_frac": 0.15,
    "test_frac": 0.15,

    # preprocessing/GA
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 6,
    "ga_gens": 3,
    "ga_elites": 2,
    "ga_sample_size": 256,

    # model/training speed
    "backbone_name": "pvt_v2_b2",
    "head_dropout": 0.3,
    "cond_dim": 128,
    "freeze_backbone": True,
    "unfreeze_after_epoch": 3,
    "unfreeze_tail_frac": 0.15,
    "unfreeze_lr_mult": 0.10,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}
LABELS = ["glioma", "meningioma", "notumor", "pituitary"]
LABEL2ID = {l: i for i, l in enumerate(LABELS)}
NUM_CLASSES = len(LABELS)


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "no_tumor" in s or "no tumor" in s or "notumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(path):
        p = path.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in p:
                sc += 7
            if os.path.basename(path).lower() == "raw":
                sc += 7
            if "/raw/" in p or "\\raw\\" in p:
                sc += 3
            if "augmented" in p:
                sc -= 20
        sc -= 0.0001 * len(path)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        for path in list_images_under_class_root(ds_root, c):
            rows.append({"path": path, "label": lab, "source": source_name})
    df = pd.DataFrame(rows).dropna().drop_duplicates(subset=["path"]).reset_index(drop=True)
    df["label"] = df["label"].astype(str).str.strip().str.lower()
    df = df[df["label"].isin(set(LABELS))].reset_index(drop=True)
    df["y"] = df["label"].map(LABEL2ID).astype(int)
    return df


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

TRAIN_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.15, contrast=0.15),
    transforms.ToTensor(),
])


class MRIDataset(Dataset):
    def __init__(self, frame, tfms=None):
        self.df = frame.reset_index(drop=True)
        self.tfms = tfms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        source_id = 0 if row["source"] == "ds1_raw" else 1
        return x, y, source_id


def make_weighted_sampler(frame, num_classes):
    ys = frame["y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(torch.DoubleTensor(sample_weights), len(sample_weights), replacement=True)


def make_loader(frame, bs, tfms, shuffle=False, sampler=None):
    ds = MRIDataset(frame, tfms=tfms)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


class EnhancedFELCM(nn.Module):
    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", lap.view(1, 1, 3, 3))
        sharp = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharp.view(1, 1, 3, 3))

    def forward(self, x):
        eps = 1e-6
        b, c, _, _ = x.shape

        if self.denoise > 0:
            x_blur = F.avg_pool2d(F.pad(x, (1, 1, 1, 1), mode="reflect"), 3, 1)
            x = x * (1 - self.denoise) + x_blur * self.denoise

        mu = x.mean(dim=(2, 3), keepdim=True)
        sd = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        x0 = ((x - mu) / sd).clamp(-self.tau, self.tau)
        x1 = torch.sign(x0) * torch.pow(torch.abs(x0).clamp_min(eps), self.gamma)

        gray = x1.mean(dim=1, keepdim=True)
        lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), self.lap).abs()

        k = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        pad = k // 2
        blur = F.avg_pool2d(F.pad(lap, (pad, pad, pad, pad), mode="reflect"), k, 1)
        c_map = lap / (blur + eps)
        x2 = x1 + self.alpha * torch.tanh(self.beta * c_map)

        if self.sharpen > 0:
            out = []
            for i in range(c):
                x_c = x2[:, i : i + 1]
                x_sharp = F.conv2d(F.pad(x_c, (1, 1, 1, 1), mode="reflect"), self.sharp_kernel)
                out.append(x_c * (1 - self.sharpen) + x_sharp * self.sharpen)
            x2 = torch.cat(out, dim=1)

        mn = x2.amin(dim=(2, 3), keepdim=True)
        mx = x2.amax(dim=(2, 3), keepdim=True)
        return ((x2 - mn) / (mx - mn + eps)).clamp(0, 1)


def theta_to_module(theta):
    return EnhancedFELCM(*theta)


def random_theta():
    return (
        random.uniform(0.7, 1.4),
        random.uniform(0.15, 0.55),
        random.uniform(3.0, 9.0),
        random.uniform(1.8, 3.2),
        random.choice([3, 5, 7]),
        random.uniform(0.0, 0.25),
        random.uniform(0.0, 0.2),
    )


def mutate(theta, p=0.8):
    if random.random() > p:
        return theta
    g, a, b, t, k, sh, dn = theta
    g = float(np.clip(g + np.random.normal(0, 0.06), 0.6, 1.5))
    a = float(np.clip(a + np.random.normal(0, 0.05), 0.08, 0.7))
    b = float(np.clip(b + np.random.normal(0, 0.5), 2.0, 11.0))
    t = float(np.clip(t + np.random.normal(0, 0.2), 1.5, 3.8))
    if random.random() < 0.3:
        k = random.choice([3, 5, 7])
    sh = float(np.clip(sh + np.random.normal(0, 0.04), 0.0, 0.35))
    dn = float(np.clip(dn + np.random.normal(0, 0.03), 0.0, 0.3))
    return (g, a, b, t, int(k), sh, dn)


def crossover(t1, t2):
    return tuple(random.choice([a, b]) for a, b in zip(t1, t2))


class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        return self.pool(x.flatten(2).transpose(1, 2))


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class PVTv2B2_MultiScale(nn.Module):
    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128):
        super().__init__()
        self.backbone = timm.create_model(
            CFG["backbone_name"],
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3),
        )
        in_channels = self.backbone.feature_info.channels()
        out_dim = max(256, in_channels[-1] // 2)

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)
        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        # centralized: source + theta conditioning only (no client conditioning)
        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id):
        return self.cond_norm(self.theta_mlp(theta_vec) + self.source_emb(source_id))

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id):
        cond = self._cond_vec(theta_vec, source_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views
        return self.classifier(t_final)


def preproc_theta_vec(preproc_module, batch_size):
    if hasattr(preproc_module, "gamma"):
        theta = torch.tensor(
            [
                preproc_module.gamma,
                preproc_module.alpha,
                preproc_module.beta,
                preproc_module.tau,
                float(preproc_module.blur_k) / 7.0,
                preproc_module.sharpen,
                preproc_module.denoise,
            ],
            device=DEVICE,
            dtype=torch.float32,
        )
    else:
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


def set_trainable_for_epoch(model, epoch):
    if not CFG["freeze_backbone"]:
        for p in model.backbone.parameters():
            p.requires_grad = True
        return

    for p in model.backbone.parameters():
        p.requires_grad = False

    if epoch >= CFG["unfreeze_after_epoch"]:
        params = list(model.backbone.parameters())
        tail_n = max(1, int(len(params) * CFG["unfreeze_tail_frac"]))
        for p in params[-tail_n:]:
            p.requires_grad = True


def make_optimizer(model):
    head_params, bb_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("backbone."):
            bb_params.append(p)
        else:
            head_params.append(p)

    groups = []
    if head_params:
        groups.append({"params": head_params, "lr": CFG["lr"]})
    if bb_params:
        groups.append({"params": bb_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(groups, weight_decay=CFG["weight_decay"])


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


@torch.no_grad()
def enhanced_separability_score(emb, y):
    eps = 1e-6
    y = y.long()
    classes = torch.unique(y)
    if len(classes) < 2:
        return 0.0

    centroids, within_vars, sizes = [], [], []
    for c in classes:
        e = emb[y == c]
        if e.size(0) < 2:
            continue
        mu = e.mean(dim=0)
        centroids.append(mu)
        within_vars.append((e - mu).pow(2).sum(dim=1).mean().item())
        sizes.append(e.size(0))

    if len(centroids) < 2:
        return 0.0

    centroids = torch.stack(centroids, dim=0)
    gmean = centroids.mean(dim=0)
    between = sum(n * (c - gmean).pow(2).sum().item() for c, n in zip(centroids, sizes))
    within = float(np.mean(within_vars)) if within_vars else eps
    return float(between / (within + eps))


@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y):
    pre = theta_to_module(theta).to(DEVICE)
    x = batch_x.to(DEVICE)
    y = batch_y.to(DEVICE)

    x_p = pre(x)
    gray = x_p.mean(dim=1, keepdim=True)
    lap = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode="reflect"), pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((x_p.max() - x_p.min()).item())

    emb = backbone_frozen((x_p - IMAGENET_MEAN) / IMAGENET_STD)
    if isinstance(emb, (list, tuple)):
        emb = emb[-1].mean(dim=(2, 3))
    sep = enhanced_separability_score(emb, y)

    g, a, b, t, k, sh, dn = theta
    cost = 0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) + 0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn
    return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost


def run_ga(backbone_frozen, dl_for_eval):
    try:
        bx, by, _ = next(iter(dl_for_eval))
    except Exception:
        return None

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    pop = [random_theta() for _ in range(CFG["ga_pop"])]
    for _ in range(CFG["ga_gens"]):
        scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
        elites = [th for _, th in scored[: CFG["ga_elites"]]]
        new_pop = elites[:]
        while len(new_pop) < CFG["ga_pop"]:
            p1, p2 = random.sample(elites + pop[: max(2, CFG["ga_pop"] // 2)], 2)
            new_pop.append(mutate(crossover(p1, p2), p=0.75))
        pop = new_pop

    scored = sorted([(ga_fitness(th, backbone_frozen, bx, by), th) for th in pop], key=lambda x: x[0], reverse=True)
    return scored[0][1]


def train_one_epoch(model, loader, optimizer, preproc_module, criterion, scheduler=None, scaler=None):
    model.train()
    preproc_module.eval()
    for x, y, source_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)

        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            logits = model(x_raw_n, x_fel_n, preproc_theta_vec(preproc_module, x.size(0)), source_id)
            loss = criterion(logits, y)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
            optimizer.step()

        if scheduler is not None:
            scheduler.step()


@torch.no_grad()
def evaluate(model, loader, preproc_module):
    model.eval()
    preproc_module.eval()
    t0 = time.time()

    all_y, all_p, all_loss = [], [], []
    for x, y, source_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        logits = model(x_raw_n, x_fel_n, preproc_theta_vec(preproc_module, x.size(0)), source_id)
        probs = torch.softmax(logits, dim=1)

        all_loss.append(float(F.cross_entropy(logits, y).item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    return {
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "auc_roc_macro_ovr": float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro")),
        "loss_ce": float(np.mean(all_loss)),
        "eval_time_s": float(time.time() - t0),
    }


def weighted_merge(metrics_a, n_a, metrics_b, n_b):
    out = {}
    total = n_a + n_b
    for k in metrics_a.keys():
        out[k] = float((metrics_a[k] * n_a + metrics_b[k] * n_b) / total)
    return out


def main():
    ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
    ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

    ds1_root = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
    ds2_root = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)
    if ds1_root is None or ds2_root is None:
        raise RuntimeError("Failed to locate dataset roots.")

    df1 = build_df_from_root(ds1_root, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw")
    df2 = build_df_from_root(ds2_root, ["glioma", "meningioma", "notumor", "pituitary"], "ds2")

    train1, val1, test1 = split_dataset(df1)
    train2, val2, test2 = split_dataset(df2)

    train_all = pd.concat([train1, train2], axis=0).reset_index(drop=True)
    val_all = pd.concat([val1, val2], axis=0).reset_index(drop=True)

    sampler = make_weighted_sampler(train_all, NUM_CLASSES)
    train_loader = make_loader(train_all, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler)
    val_loader = make_loader(val_all, CFG["batch_size"], EVAL_TFMS, shuffle=False)
    test1_loader = make_loader(test1, CFG["batch_size"], EVAL_TFMS, shuffle=False)
    test2_loader = make_loader(test2, CFG["batch_size"], EVAL_TFMS, shuffle=False)

    ga_frame = train_all.sample(min(CFG["ga_sample_size"], len(train_all)), random_state=SEED).reset_index(drop=True)
    ga_loader = make_loader(ga_frame, CFG["batch_size"], EVAL_TFMS, shuffle=False)

    model = PVTv2B2_MultiScale(
        num_classes=NUM_CLASSES,
        pretrained=True,
        head_dropout=CFG["head_dropout"],
        cond_dim=CFG["cond_dim"],
    ).to(DEVICE)

    # class-balanced loss
    counts = train_all["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
    w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
    w = w / max(1e-6, w.mean())
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(w, device=DEVICE), label_smoothing=CFG["label_smoothing"])

    preproc = nn.Identity().to(DEVICE)
    if CFG["use_preprocessing"]:
        if CFG["use_ga"]:
            backbone_frozen = model.backbone.eval()
            for p in backbone_frozen.parameters():
                p.requires_grad = False
            best_theta = run_ga(backbone_frozen, ga_loader)
            preproc = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else nn.Identity().to(DEVICE)
        else:
            preproc = EnhancedFELCM().to(DEVICE)

    scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

    best_state = None
    best_val_f1 = -1.0
    stale_epochs = 0

    for epoch in range(1, CFG["epochs"] + 1):
        set_trainable_for_epoch(model, epoch)
        optimizer = make_optimizer(model)
        total_steps = max(1, len(train_loader))
        warmup_steps = max(1, int(len(train_loader) * CFG["warmup_epochs"]))
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

        train_one_epoch(model, train_loader, optimizer, preproc, criterion, scheduler=scheduler, scaler=scaler)

        val_metrics = evaluate(model, val_loader, preproc)
        if val_metrics["f1_macro"] > best_val_f1:
            best_val_f1 = val_metrics["f1_macro"]
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs >= CFG["early_stop_patience"]:
                break

    if best_state is not None:
        model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()})

    test_ds1 = evaluate(model, test1_loader, preproc)
    test_ds2 = evaluate(model, test2_loader, preproc)
    test_global = weighted_merge(test_ds1, len(test1), test_ds2, len(test2))

    cols = [
        "setting",
        "split",
        "dataset",
        "acc",
        "precision_macro",
        "recall_macro",
        "f1_macro",
        "precision_weighted",
        "recall_weighted",
        "f1_weighted",
        "log_loss",
        "auc_roc_macro_ovr",
        "loss_ce",
        "eval_time_s",
    ]
    out = pd.DataFrame([
        {"setting": "Centralized GCF-Net", "split": "TEST", "dataset": "ds1", **test_ds1},
        {"setting": "Centralized GCF-Net", "split": "TEST", "dataset": "ds2", **test_ds2},
        {"setting": "Centralized GCF-Net", "split": "TEST", "dataset": "global_weighted", **test_global},
    ])[cols]

    print(out.to_string(index=False))


if __name__ == "__main__":
    main()


  self.setter(val)


Downloading from https://www.kaggle.com/api/v1/datasets/download/yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection?dataset_version_number=1...


100%|██████████| 130M/130M [00:06<00:00, 19.9MB/s]

Extracting files...





Downloading from https://www.kaggle.com/api/v1/datasets/download/orvile/pmram-bangladeshi-brain-cancer-mri-dataset?dataset_version_number=2...


100%|██████████| 161M/161M [00:09<00:00, 17.1MB/s]


Extracting files...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/101M [00:00<?, ?B/s]

            setting split         dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
Centralized GCF-Net  TEST             ds1 0.973451         0.973441      0.973360  0.973207            0.973908         0.973451     0.973488  0.172691           0.993911 0.164041     6.403807
Centralized GCF-Net  TEST             ds2 0.964929         0.963347      0.963957  0.963530            0.965235         0.964929     0.964966  0.175786           0.995404 0.175774    18.914482
Centralized GCF-Net  TEST global_weighted 0.966432         0.965128      0.965616  0.965238            0.966766         0.966432     0.966469  0.175240           0.995140 0.173704    16.707290
