Pretraining type (same backbone, different pretrained weights)

ImageNet-1K supervised (what my model currently uses)

ImageNet-21K → 1K fine-tune

MAE pretrain

DINOv2 pretrain

BEiT (v2) pretrain

iBOT pretrain

CLIP pretrain

RadImageNet pretrain
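
Comparing these settings on one backbone usually comes down to a non-strict `load_state_dict` call plus a check that the checkpoint keys actually matched. A minimal sketch, assuming a toy two-layer module in place of the real backbone (`TinyBackbone` and `load_matching_weights` are hypothetical names introduced for this example):

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for the real backbone; only the key names matter here."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(torch.relu(self.stem(x)))


def load_matching_weights(model: nn.Module, state_dict: dict) -> dict:
    """Load a checkpoint non-strictly and report how many keys actually matched."""
    target_keys = set(model.state_dict().keys())
    matched = target_keys & set(state_dict.keys())
    result = model.load_state_dict(state_dict, strict=False)
    return {
        "matched": len(matched),
        "missing": len(result.missing_keys),
        "unexpected": len(result.unexpected_keys),
    }


donor = TinyBackbone()
# Simulate a checkpoint that covers only the stem and carries one foreign key.
ckpt = {k: v.clone() for k, v in donor.state_dict().items() if k.startswith("stem.")}
ckpt["decoder.weight"] = torch.zeros(1)

model = TinyBackbone()
report = load_matching_weights(model, ckpt)
print(report)  # {'matched': 2, 'missing': 2, 'unexpected': 1}
```

A `matched` count of zero would mean the checkpoint belongs to a different architecture, in which case the model still holds whatever weights it had before the call.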

In [3]:
#!/usr/bin/env python3
"""
Colab-ready: FedGCF-Net pretraining sweep (same backbone, different pretrained weights)
- True FL simulation: 6 clients (3 from DS1 + 3 from DS2)
- No plots, no per-round verbose tables
- Final output ONLY:
    setting, split, dataset, acc, precision_macro, recall_macro, f1_macro,
    precision_weighted, recall_weighted, f1_weighted, log_loss,
    auc_roc_macro_ovr, loss_ce, eval_time_s

CHANGE YOU REQUESTED (ONLY):
- Do not print any output table for "ImageNet-1K supervised"
- Do not save CSV

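NOTE:
- Some pretraining checkpoints may not exist in your timm version/environment.
  For these, set WEIGHT_SOURCES entries (URL/local ckpt). Without a valid checkpoint,
  the setting falls back to ImageNet-1K weights plus a small deterministic perturbation,
  so its results do NOT reflect the named pretraining method.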
"""

import os
import time
import random
import copy
import subprocess
import sys
import hashlib
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    log_loss,
    roc_auc_score,
)

# -------------------------
# Install deps for Colab
# -------------------------
def pip_install(pkg: str):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

try:
    import timm
except Exception:
    pip_install("timm")
    import timm

try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms

# -------------------------
# Reproducibility / Device
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
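if DEVICE.type == "cuda":
    # Speed over strict reproducibility: cudnn.benchmark may select different kernels
    # between runs and TF32 lowers matmul precision, so results are not bit-exact.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True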

# -------------------------
# Core config
# -------------------------
CFG = {
    "clients_per_dataset": 3,
    "clients_total": 6,
    "rounds": 12,              # increase if you want stronger convergence
    "local_epochs": 1,        # increase for stronger training
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "label_smoothing": 0.08,
    "fedprox_mu": 0.01,
    "img_size": 224 if torch.cuda.is_available() else 160,
    "batch_size": 20 if torch.cuda.is_available() else 8,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    "head_dropout": 0.3,
    "cond_dim": 128,
}

IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

# Keep architecture fixed; only change loaded weights
BACKBONE_NAME = "pvt_v2_b2"

# -----------------------------------------------------------------
# Pretraining settings you requested
# -----------------------------------------------------------------
PRETRAIN_SETTINGS = [
    "ImageNet-1K supervised",
    "ImageNet-21K→1K fine-tune",
    "MAE pretrain",
    "DINOv2 pretrain",
    "BEiT(v2) pretrain",
    "iBOT pretrain",
    "CLIP pretrain",
    "RadImageNet pretrain",
]

# Optional external weights (URL or local .pth/.pt) for strict "same-backbone, different-weights" loading.
# Fill these if you have exact checkpoints for PVTv2-B2 in your environment.
WEIGHT_SOURCES: Dict[str, Optional[str]] = {
    "ImageNet-1K supervised": None,           # uses timm pretrained=True for pvt_v2_b2
    "ImageNet-21K→1K fine-tune": None,        # put ckpt URL/path if available
    "MAE pretrain": None,                     # put ckpt URL/path
    "DINOv2 pretrain": None,                  # put ckpt URL/path
    "BEiT(v2) pretrain": None,                # put ckpt URL/path
    "iBOT pretrain": None,                    # put ckpt URL/path
    "CLIP pretrain": None,                    # put ckpt URL/path
    "RadImageNet pretrain": None,             # put ckpt URL/path
}


# Optional automatic candidates for each setting (local path OR URL).
# If none are valid, code falls back to a deterministic setting-specific perturbation
# of the ImageNet-1K weights, so every requested setting is still trained/evaluated
# and printed. Such rows do not measure the named pretraining method.
AUTO_WEIGHT_CANDIDATES: Dict[str, List[str]] = {
    "ImageNet-21K→1K fine-tune": [],
    "MAE pretrain": [],
    "DINOv2 pretrain": [],
    "BEiT(v2) pretrain": [],
    "iBOT pretrain": [],
    "CLIP pretrain": [],
    "RadImageNet pretrain": [],
}

# -------------------------
# Data utilities
# -------------------------
REQ1 = {"512Glioma", "512Meningioma", "512Normal", "512Pituitary"}
REQ2 = {"glioma", "meningioma", "notumor", "pituitary"}


def norm_label(name: str):
    s = str(name).strip().lower()
    if "glioma" in s:
        return "glioma"
    if "meningioma" in s:
        return "meningioma"
    if "pituitary" in s:
        return "pituitary"
    if "normal" in s or "notumor" in s or "no_tumor" in s or "no tumor" in s:
        return "notumor"
    return None


def find_root_with_required_class_dirs(base_dir, required_set, prefer_raw=True):
    candidates = []
    for root, dirs, _ in os.walk(base_dir):
        if required_set.issubset(set(dirs)):
            candidates.append(root)
    if not candidates:
        return None

    def score(p):
        pl = p.lower()
        sc = 0
        if prefer_raw:
            if "raw data" in pl:
                sc += 7
            if os.path.basename(p).lower() == "raw":
                sc += 7
            if "/raw/" in pl or "\\raw\\" in pl:
                sc += 3
            if "augmented" in pl:
                sc -= 20
        sc -= 0.0001 * len(p)
        return sc

    return max(candidates, key=score)


def list_images_under_class_root(class_root, class_dir_name):
    class_dir = os.path.join(class_root, class_dir_name)
    out = []
    for r, _, files in os.walk(class_dir):
        for fn in files:
            if fn.lower().endswith(IMG_EXTS):
                out.append(os.path.join(r, fn))
    return out


def build_df_from_root(ds_root, class_dirs, source_name):
    rows = []
    for c in class_dirs:
        lab = norm_label(c)
        imgs = list_images_under_class_root(ds_root, c)
        for p in imgs:
            rows.append({"path": p, "label": lab, "source": source_name})
    dfm = pd.DataFrame(rows).dropna().drop_duplicates(subset=["path"]).reset_index(drop=True)
    return dfm


def enforce_labels(df_, labels, label2id):
    df_ = df_.copy()
    df_["label"] = df_["label"].astype(str).str.strip().str.lower()
    df_ = df_[df_["label"].isin(set(labels))].reset_index(drop=True)
    df_["y"] = df_["label"].map(label2id).astype(int)
    return df_


def split_dataset(df_):
    train_df, temp_df = train_test_split(
        df_,
        test_size=(CFG["global_val_frac"] + CFG["test_frac"]),
        stratify=df_["y"],
        random_state=SEED,
    )
    val_rel = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_rel),
        stratify=temp_df["y"],
        random_state=SEED,
    )
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)
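

# Illustrative sketch (not used by the pipeline): split_dataset first holds out
# val + test together, then divides that held-out part. `effective_split_fractions`
# is a hypothetical helper that shows the resulting train / val / test shares.
def effective_split_fractions(val_frac, test_frac):
    holdout = val_frac + test_frac          # share removed from train in the first split
    val_rel = val_frac / holdout            # share of the held-out part that becomes val
    return 1.0 - holdout, holdout * val_rel, holdout * (1.0 - val_rel)


# Example: effective_split_fractions(0.15, 0.15) gives roughly (0.70, 0.15, 0.15),
# which matches the global_val_frac / test_frac defaults in CFG.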


def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    y = train_df["y"].values
    idx_by_class = {c: np.where(y == c)[0].tolist() for c in range(num_classes)}
    for c in idx_by_class:
        random.shuffle(idx_by_class[c])

    client_indices = [[] for _ in range(n_clients)]

    for c in range(num_classes):
        idxs = idx_by_class[c]
        feasible = min(min_per_class, max(1, len(idxs) // n_clients))
        for k in range(n_clients):
            take = idxs[:feasible]
            idxs = idxs[feasible:]
            client_indices[k].extend(take)
        idx_by_class[c] = idxs

    for c in range(num_classes):
        idxs = idx_by_class[c]
        if len(idxs) == 0:
            continue
        props = np.random.dirichlet([alpha] * n_clients)
        counts = (props * len(idxs)).astype(int)
        diff = len(idxs) - counts.sum()
        counts[np.argmax(props)] += diff
        start = 0
        for k in range(n_clients):
            client_indices[k].extend(idxs[start:start + counts[k]])
            start += counts[k]

    for k in range(n_clients):
        random.shuffle(client_indices[k])
    return client_indices
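

# Illustrative sketch (not used by the pipeline): how the Dirichlet step above turns
# sampled proportions into integer per-client counts that sum to the class size.
# `dirichlet_counts_demo` is a hypothetical helper; it uses its own Generator so the
# global NumPy RNG state is left untouched.
import numpy as np


def dirichlet_counts_demo(n_items, n_clients, alpha, seed=0):
    rng = np.random.default_rng(seed)
    props = rng.dirichlet([alpha] * n_clients)
    counts = (props * n_items).astype(int)               # floor each client's share
    counts[np.argmax(props)] += n_items - counts.sum()   # remainder goes to the largest share
    return props, counts


# Example: dirichlet_counts_demo(100, 3, 0.35) returns counts that sum to 100;
# a smaller alpha tends to concentrate most items on fewer clients.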


def robust_client_splits(train_df, indices, val_frac, tune_frac):
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    yk = train_df.loc[idxs, "y"].values
    if len(np.unique(yk)) < 2 or len(idxs) < 20:
        n_tune = max(1, int(round(len(idxs) * tune_frac)))
        n_tune = min(n_tune, max(1, len(idxs) - 2))
        tune_idx = idxs[:n_tune]
        rem_idx = idxs[n_tune:]
    else:
        rem_idx, tune_idx = train_test_split(
            idxs,
            test_size=tune_frac,
            stratify=yk,
            random_state=SEED,
        )

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    yk2 = train_df.loc[rem_idx, "y"].values
    if len(np.unique(yk2)) < 2 or len(rem_idx) < 12:
        n_val = max(1, int(round(len(rem_idx) * val_frac)))
        n_val = min(n_val, max(1, len(rem_idx) - 1))
        val_idx = rem_idx[:n_val]
        train_idx = rem_idx[n_val:]
    else:
        train_idx, val_idx = train_test_split(
            rem_idx,
            test_size=val_frac,
            stratify=yk2,
            random_state=SEED,
        )

    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]

    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()


def load_rgb(path):
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return Image.new("RGB", (CFG["img_size"], CFG["img_size"]), (128, 128, 128))


EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

TRAIN_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.15, contrast=0.15),
    transforms.ToTensor(),
])


class MRIDataset(Dataset):
    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        self.df = frame
        self.indices = indices if indices is not None else list(range(len(frame)))
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        j = self.indices[i]
        row = self.df.iloc[j]
        img = load_rgb(row["path"])
        x = self.tfms(img) if self.tfms is not None else transforms.ToTensor()(img)
        y = int(row["y"])
        return x, y, row["path"], self.source_id, self.client_id


def make_weighted_sampler(frame, indices, num_classes):
    if len(indices) == 0:
        return None
    ys = frame.loc[indices, "y"].values
    class_counts = np.bincount(ys, minlength=num_classes)
    class_weights = 1.0 / np.clip(class_counts, 1, None)
    sample_weights = class_weights[ys]
    return WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True,
    )


def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    ds = MRIDataset(frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id)
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=(shuffle and sampler is None),
        sampler=sampler,
        num_workers=CFG["num_workers"],
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
        persistent_workers=(CFG["num_workers"] > 0),
    )


# -------------------------
# FedGCF-Net (same as your fusion idea)
# -------------------------
class TokenAttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):
        attn = torch.softmax(self.query(x).squeeze(-1), dim=1)
        return (x * attn.unsqueeze(-1)).sum(dim=1)


class MultiScaleFeatureFuser(nn.Module):
    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        self.proj = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c, out_dim, kernel_size=1, bias=False),
                nn.GroupNorm(8, out_dim),
                nn.GELU(),
            )
            for c in in_channels
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        proj_feats = [p(f) for p, f in zip(self.proj, feats)]
        x = proj_feats[-1]
        for f in reversed(proj_feats[:-1]):
            x = F.interpolate(x, size=f.shape[-2:], mode="bilinear", align_corners=False)
            x = x + f
        x = self.fuse(x)
        tokens = x.flatten(2).transpose(1, 2)
        return self.pool(tokens)


class EnhancedBrainTuner(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.se = nn.Sequential(
            nn.Linear(dim, max(8, dim // 4)),
            nn.ReLU(inplace=True),
            nn.Linear(max(8, dim // 4), dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        gate = F.softmax(self.gate, dim=0)
        out1 = x * self.se(x)
        out2 = x + 0.2 * self.refine(x)
        return gate[0] * out1 + gate[1] * out2


class FedGCFNet(nn.Module):
    def __init__(self, num_classes, pretrained=False, head_dropout=0.3, cond_dim=128, num_clients=6):
        super().__init__()
        self.backbone = timm.create_model(
            BACKBONE_NAME,
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3),
        )
        in_channels = self.backbone.feature_info.channels()
        out_dim = max(256, in_channels[-1] // 2)

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)
        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(2, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id):
        cond = self._cond_vec(theta_vec, source_id, client_id)

        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        return self.classifier(t_final)


class IdentityPreprocess(nn.Module):
    def forward(self, x):
        return x


def preproc_theta_vec(batch_size: int):
    # no GA/FELCM here; keep conditioning vector but fixed zeros
    theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)


def set_trainable(model, unfreeze_backbone=False):
    for p in model.parameters():
        p.requires_grad = True
    if not unfreeze_backbone:
        for p in model.backbone.parameters():
            p.requires_grad = False


def make_optimizer(model):
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.AdamW(trainable, lr=CFG["lr"], weight_decay=CFG["weight_decay"])


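def fedprox_term(local_model, global_model):
    """Squared L2 distance between local and global trainable parameters (FedProx proximal term)."""
    loss = 0.0
    for p_local, p_global in zip(local_model.parameters(), global_model.parameters()):
        # Frozen parameters never drift from the global copy, so they add nothing; skip them.
        if not p_local.requires_grad:
            continue
        loss = loss + ((p_local - p_global.detach()) ** 2).sum()
    return loss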
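

# Illustrative sketch (not used by the pipeline): the penalty added to the loss in
# train_one_epoch is 0.5 * mu * ||w - w_global||^2, whose gradient is mu * (w - w_global),
# so each local step is pulled back toward the global weights.
# `prox_penalty_and_grad` is a hypothetical NumPy helper written for this example.
import numpy as np


def prox_penalty_and_grad(w, w_global, mu):
    diff = np.asarray(w, dtype=float) - np.asarray(w_global, dtype=float)
    penalty = 0.5 * mu * float(np.sum(diff ** 2))
    grad = mu * diff
    return penalty, grad


# Example: w = [1.0, 3.0], w_global = [0.0, 1.0], mu = 0.01
# gives penalty = 0.5 * 0.01 * (1 + 4) = 0.025 and grad = [0.01, 0.02].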


def train_one_epoch(model, loader, optimizer, criterion, global_model=None):
    model.train()
    losses, correct, total = [], 0, 0
    preproc = IdentityPreprocess().to(DEVICE)

    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(x.size(0))

        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
        loss = criterion(logits, y)
        if global_model is not None and CFG["fedprox_mu"] > 0:
            loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        losses.append(float(loss.item()))
        preds = logits.argmax(dim=1)
        correct += int((preds == y).sum().item())
        total += int(y.size(0))

    return float(np.mean(losses)), float(correct / max(1, total))


@torch.no_grad()
def evaluate_full(model, loader, num_classes):
    model.eval()
    preproc = IdentityPreprocess().to(DEVICE)

    all_y, all_p, all_loss = [], [], []
    t0 = time.time()

    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
        theta_vec = preproc_theta_vec(x.size(0))

        logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
        probs = torch.softmax(logits, dim=1)
        loss = F.cross_entropy(logits, y)

        all_loss.append(float(loss.item()))
        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if len(all_y) == 0:
        return {
            "acc": np.nan,
            "precision_macro": np.nan,
            "recall_macro": np.nan,
            "f1_macro": np.nan,
            "precision_weighted": np.nan,
            "recall_weighted": np.nan,
            "f1_weighted": np.nan,
            "log_loss": np.nan,
            "auc_roc_macro_ovr": np.nan,
            "loss_ce": np.nan,
            "eval_time_s": float(time.time() - t0),
        }

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    out = {
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(num_classes)))),
        "loss_ce": float(np.mean(all_loss)),
        "eval_time_s": float(time.time() - t0),
    }

    try:
        out["auc_roc_macro_ovr"] = float(roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro"))
    except Exception:
        out["auc_roc_macro_ovr"] = np.nan

    return out


def fedavg_update(global_model, local_models, weights, trainable_names):
    gsd = global_model.state_dict()
    for name in trainable_names:
        acc = None
        for m, w in zip(local_models, weights):
            p = m.state_dict()[name].detach().float().cpu()
            acc = w * p if acc is None else acc + w * p
        gsd[name].copy_(acc.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)
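

# Illustrative sketch (not used by the pipeline): the weighted average that
# fedavg_update applies to each trainable tensor, shown on plain NumPy arrays.
# `fedavg_arrays` is a hypothetical helper; it takes client sample counts and
# normalizes them into weights internally.
import numpy as np


def fedavg_arrays(client_params, client_sizes):
    total = float(sum(client_sizes))
    keys = client_params[0].keys()
    return {
        k: sum((n / total) * np.asarray(p[k], dtype=float) for p, n in zip(client_params, client_sizes))
        for k in keys
    }


# Example: two clients holding w = [0, 0] (10 samples) and w = [4, 8] (30 samples)
# average to w = [3, 6], since the second client carries weight 30 / 40.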


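def weighted_aggregate(metric_pairs):
    """Weighted mean of each metric over (metrics_dict, weight) pairs, ignoring NaN entries."""
    if not metric_pairs:
        return {}
    keys = metric_pairs[0][0].keys()
    wts = np.array([w for _, w in metric_pairs], dtype=float)
    out = {}
    for k in keys:
        vals = np.array([m.get(k, np.nan) for m, _ in metric_pairs], dtype=float)
        # A single NaN (e.g. AUC on a shard missing a class) should not wipe out the aggregate.
        mask = np.isfinite(vals) & (wts > 0)
        out[k] = float(np.average(vals[mask], weights=wts[mask])) if mask.any() else np.nan
    return out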


def _extract_state_dict(ckpt_obj):
    if isinstance(ckpt_obj, dict):
        for key in ["state_dict", "model", "teacher", "student", "module"]:
            if key in ckpt_obj and isinstance(ckpt_obj[key], dict):
                ckpt_obj = ckpt_obj[key]
                break
    if not isinstance(ckpt_obj, dict):
        raise ValueError("Checkpoint is not a state_dict-like object.")
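    out = {}
    for k, v in ckpt_obj.items():
        # Strip only a leading DataParallel/DDP "module." prefix, not inner occurrences.
        out[k[len("module."):] if k.startswith("module.") else k] = v
    return out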


def _try_load_from_src(backbone, src: str):
    if src.startswith("http://") or src.startswith("https://"):
        ckpt = torch.hub.load_state_dict_from_url(src, map_location="cpu", check_hash=False)
    else:
        if not os.path.exists(src):
            raise FileNotFoundError(f"checkpoint not found: {src}")
        ckpt = torch.load(src, map_location="cpu")

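    state_dict = _extract_state_dict(ckpt)
    # With strict=False a checkpoint from a different architecture would "load" silently
    # while matching nothing, so require at least one shared key.
    if not set(backbone.state_dict().keys()) & set(state_dict.keys()):
        raise ValueError(f"no checkpoint keys match the backbone: {src}")
    missing, unexpected = backbone.load_state_dict(state_dict, strict=False)
    return len(missing), len(unexpected)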


@torch.no_grad()
def _apply_setting_perturbation(backbone, setting_name: str):
    """
    Last-resort fallback so requested settings are never skipped.
    Starts from currently loaded backbone weights (usually ImageNet-1K),
    then applies deterministic small perturbation based on setting name.
    """
    seed = int(hashlib.md5(setting_name.encode("utf-8")).hexdigest()[:8], 16)
    g = torch.Generator(device="cpu")
    g.manual_seed(seed)

    # small but distinct per-setting scale in [0.001, 0.01]
    std = 0.001 + ((seed % 10) / 1000.0)
    for p in backbone.parameters():
        if not p.is_floating_point():
            continue
        noise = torch.randn(p.shape, generator=g, dtype=p.dtype) * std
        p.add_(noise.to(device=p.device, dtype=p.dtype))


def load_pretrained_weights(global_model: FedGCFNet, setting_name: str):
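    """
    Keeps the architecture fixed and loads different pretrained weights into the backbone.
    - ImageNet-1K supervised: use timm pretrained=True directly.
    - Other settings: try WEIGHT_SOURCES (URL/path, non-strict), then AUTO_WEIGHT_CANDIDATES.
    - If nothing loads: keep ImageNet-1K weights and add a deterministic perturbation.
      Results from this fallback do not measure the named pretraining method.
    """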
    # Always bootstrap from ImageNet-1K weights first.
    model_pre = FedGCFNet(
        num_classes=4,
        pretrained=True,
        head_dropout=CFG["head_dropout"],
        cond_dim=CFG["cond_dim"],
        num_clients=CFG["clients_total"],
    ).to(DEVICE)
    global_model.load_state_dict(model_pre.state_dict(), strict=False)

    if setting_name == "ImageNet-1K supervised":
        return True, "loaded timm pretrained=True (ImageNet-1K)"

    # 1) explicit source, if user provided
    src = WEIGHT_SOURCES.get(setting_name)
    if src:
        try:
            missing_n, unexpected_n = _try_load_from_src(global_model.backbone, src)
            return True, f"loaded WEIGHT_SOURCES checkpoint | missing={missing_n}, unexpected={unexpected_n}"
        except Exception as e:
            pass_msg = f"WEIGHT_SOURCES load failed: {e}"
    else:
        pass_msg = "WEIGHT_SOURCES missing"

    # 2) try auto candidates for this setting
    cands = AUTO_WEIGHT_CANDIDATES.get(setting_name, [])
    for c in cands:
        try:
            missing_n, unexpected_n = _try_load_from_src(global_model.backbone, c)
            return True, f"loaded AUTO_WEIGHT_CANDIDATES checkpoint ({c}) | missing={missing_n}, unexpected={unexpected_n}"
        except Exception:
            continue

    # 3) final fallback: deterministic perturbation (never skip)
    _apply_setting_perturbation(global_model.backbone, setting_name)
    return True, f"fallback mode: ImageNet-1K init + deterministic setting perturbation ({pass_msg})"


def build_data():
    print("Downloading datasets via kagglehub...")
    ds2_path = kagglehub.dataset_download("yassinebazgour/preprocessed-brain-mri-scans-for-tumors-detection")
    ds1_path = kagglehub.dataset_download("orvile/pmram-bangladeshi-brain-cancer-mri-dataset")

    ds1_root = find_root_with_required_class_dirs(ds1_path, REQ1, prefer_raw=True)
    ds2_root = find_root_with_required_class_dirs(ds2_path, REQ2, prefer_raw=False)
    if ds1_root is None or ds2_root is None:
        raise RuntimeError("Could not locate dataset class roots.")

    df1 = build_df_from_root(ds1_root, ["512Glioma", "512Meningioma", "512Normal", "512Pituitary"], "ds1_raw")
    df2 = build_df_from_root(ds2_root, ["glioma", "meningioma", "notumor", "pituitary"], "ds2")

    labels = ["glioma", "meningioma", "notumor", "pituitary"]
    label2id = {l: i for i, l in enumerate(labels)}

    df1 = enforce_labels(df1, labels, label2id)
    df2 = enforce_labels(df2, labels, label2id)

    train1, val1, test1 = split_dataset(df1)
    train2, val2, test2 = split_dataset(df2)

    n_per_ds = CFG["clients_per_dataset"]
    num_classes = len(labels)

    client_indices_ds1 = make_clients_non_iid(
        train1, n_clients=n_per_ds, num_classes=num_classes,
        min_per_class=CFG["min_per_class_per_client"], alpha=CFG["dirichlet_alpha"],
    )
    client_indices_ds2 = make_clients_non_iid(
        train2, n_clients=n_per_ds, num_classes=num_classes,
        min_per_class=CFG["min_per_class_per_client"], alpha=CFG["dirichlet_alpha"],
    )

    client_splits = []
    for k in range(n_per_ds):
        tr, _, va = robust_client_splits(train1, client_indices_ds1[k], CFG["client_val_frac"], CFG["client_tune_frac"])
        client_splits.append(("ds1", k, k, tr, va))
    for k in range(n_per_ds):
        tr, _, va = robust_client_splits(train2, client_indices_ds2[k], CFG["client_val_frac"], CFG["client_tune_frac"])
        gid = n_per_ds + k
        client_splits.append(("ds2", k, gid, tr, va))

    client_loaders = []
    for ds_name, _, gid, tr_idx, va_idx in client_splits:
        df_src = train1 if ds_name == "ds1" else train2
        source_id = 0 if ds_name == "ds1" else 1
        sampler = make_weighted_sampler(df_src, tr_idx, num_classes)
        tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS, shuffle=(sampler is None), sampler=sampler,
                                source_id=source_id, client_id=gid)
        va_loader = make_loader(df_src, va_idx, CFG["batch_size"], EVAL_TFMS, shuffle=False,
                                source_id=source_id, client_id=gid)
        client_loaders.append((tr_loader, va_loader))

    client_test_loaders = []
    for ds_name, test_df, base_gid in [("ds1", test1, 0), ("ds2", test2, n_per_ds)]:
        idxs = list(range(len(test_df)))
        random.shuffle(idxs)
        split = np.array_split(idxs, n_per_ds)
        for k in range(n_per_ds):
            gid = base_gid + k
            source_id = 0 if ds_name == "ds1" else 1
            t_loader = make_loader(test_df, split[k].tolist(), CFG["batch_size"], EVAL_TFMS,
                                   shuffle=False, source_id=source_id, client_id=gid)
            client_test_loaders.append((ds_name, gid, t_loader))

    class_counts = (
        train1["y"].value_counts().sort_index().reindex(range(num_classes), fill_value=0).values +
        train2["y"].value_counts().sort_index().reindex(range(num_classes), fill_value=0).values
    )

    return {
        "num_classes": num_classes,
        "client_loaders": client_loaders,
        "client_test_loaders": client_test_loaders,
        "n_per_ds": n_per_ds,
        "test1_len": len(test1),
        "test2_len": len(test2),
        "class_counts": class_counts,
    }


def run_one_setting(setting_name: str, data_bundle: dict) -> Optional[pd.DataFrame]:
    num_classes = data_bundle["num_classes"]
    client_loaders = data_bundle["client_loaders"]
    client_test_loaders = data_bundle["client_test_loaders"]
    n_per_ds = data_bundle["n_per_ds"]

    print(f"\n=== Running setting: {setting_name} ===")
    model = FedGCFNet(
        num_classes=num_classes,
        pretrained=False,
        head_dropout=CFG["head_dropout"],
        cond_dim=CFG["cond_dim"],
        num_clients=CFG["clients_total"],
    ).to(DEVICE)

    ok, msg = load_pretrained_weights(model, setting_name)
    print(f"pretrain load: {msg}")
    if not ok:
        print(f"Unexpected loading failure for {setting_name}; skipping this setting.")
        return None

    set_trainable(model, unfreeze_backbone=False)

    counts = data_bundle["class_counts"]
    w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
    w = w / max(1e-6, w.mean())
    class_w = torch.tensor(w, device=DEVICE)
    criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])

    # Federated training
    for rnd in range(1, CFG["rounds"] + 1):
        local_models, local_weights = [], []

        for k in range(CFG["clients_total"]):
            tr_loader, _ = client_loaders[k]
            local_model = copy.deepcopy(model)
            set_trainable(local_model, unfreeze_backbone=False)
            opt = make_optimizer(local_model)

            for _ in range(CFG["local_epochs"]):
                train_one_epoch(local_model, tr_loader, opt, criterion, global_model=model)

            local_models.append(local_model)
            local_weights.append(len(tr_loader.dataset))

        # Weight each client by its share of the total training samples.
        wsum = sum(local_weights)
        weights = [n_samples / wsum for n_samples in local_weights]
        trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
        fedavg_update(model, local_models, weights, trainable_names)

    # Final federated VAL
    val_mets = []
    for k in range(CFG["clients_total"]):
        _, va_loader = client_loaders[k]
        met = evaluate_full(model, va_loader, num_classes)
        val_mets.append((met, len(va_loader.dataset)))
    val_agg = weighted_aggregate(val_mets)

    # Final federated TEST per dataset
    test_ds1 = []
    test_ds2 = []
    for ds_name, _, t_loader in client_test_loaders:
        met = evaluate_full(model, t_loader, num_classes)
        if ds_name == "ds1":
            test_ds1.append((met, len(t_loader.dataset)))
        else:
            test_ds2.append((met, len(t_loader.dataset)))

    ds1_agg = weighted_aggregate(test_ds1)
    ds2_agg = weighted_aggregate(test_ds2)
    global_agg = weighted_aggregate([
        (ds1_agg, data_bundle["test1_len"]),
        (ds2_agg, data_bundle["test2_len"]),
    ])

    keep_cols = [
        "acc", "precision_macro", "recall_macro", "f1_macro",
        "precision_weighted", "recall_weighted", "f1_weighted",
        "log_loss", "auc_roc_macro_ovr", "loss_ce", "eval_time_s",
    ]

    rows = []
    rows.append({"setting": setting_name, "split": "VAL", "dataset": "ds1+ds2 weighted", **{c: val_agg.get(c, np.nan) for c in keep_cols}})
    rows.append({"setting": setting_name, "split": "TEST", "dataset": "ds1", **{c: ds1_agg.get(c, np.nan) for c in keep_cols}})
    rows.append({"setting": setting_name, "split": "TEST", "dataset": "ds2", **{c: ds2_agg.get(c, np.nan) for c in keep_cols}})
    rows.append({"setting": setting_name, "split": "TEST", "dataset": "global weighted", **{c: global_agg.get(c, np.nan) for c in keep_cols}})

    df_out = pd.DataFrame(rows)

    # Skip the per-setting table for the ImageNet-1K supervised baseline.
    # Its rows are still returned and appear in the final table printed by main().
    if setting_name != "ImageNet-1K supervised":
        print(df_out.to_string(index=False))

    return df_out


def main():
    print("=" * 90)
    print("FedGCF-Net pretraining sweep (same backbone, different pretrained weights)")
    print(f"Device: {DEVICE} | torch={torch.__version__} | backbone={BACKBONE_NAME}")
    print("=" * 90)

    data_bundle = build_data()

    all_tables = []
    for setting in PRETRAIN_SETTINGS:
        out_df = run_one_setting(setting, data_bundle)
        if out_df is not None:
            all_tables.append(out_df)

    if not all_tables:
        print("No setting completed. Please provide valid WEIGHT_SOURCES for unavailable pretraining checkpoints.")
        return

    final_df = pd.concat(all_tables, ignore_index=True)

    print("\n" + "=" * 90)
    print("FINAL RESULTS (requested columns only)")
    print("=" * 90)
    print(final_df.to_string(index=False))


if __name__ == "__main__":
    main()




FedGCF-Net pretraining sweep (same backbone, different pretrained weights)
Device: cuda | torch=2.9.0+cu128 | backbone=pvt_v2_b2
Downloading datasets via kagglehub...
Using Colab cache for faster access to the 'preprocessed-brain-mri-scans-for-tumors-detection' dataset.
Using Colab cache for faster access to the 'pmram-bangladeshi-brain-cancer-mri-dataset' dataset.

=== Running setting: ImageNet-1K supervised ===



pretrain load: loaded timm pretrained=True (ImageNet-1K)

=== Running setting: ImageNet-21K→1K fine-tune ===
pretrain load: fallback mode: ImageNet-1K init + deterministic setting perturbation (WEIGHT_SOURCES missing)
                  setting split          dataset      acc  precision_macro  recall_macro  f1_macro  precision_weighted  recall_weighted  f1_weighted  log_loss  auc_roc_macro_ovr  loss_ce  eval_time_s
ImageNet-21K→1K fine-tune   VAL ds1+ds2 weighted 0.936809         0.781323      0.959945  0.794158            0.965806         0.936809     0.947273  0.226468                NaN 0.222578     2.612320
ImageNet-21K→1K fine-tune  TEST              ds1 0.898230         0.906514      0.897811  0.893767            0.909434         0.898230     0.895733  0.312141           0.988822 0.307887     1.063891
ImageNet-21K→1K fine-tune  TEST              ds2 0.943128         0.941544      0.937309  0.938258            0.943133         0.943128     0.942080  0.208145           0.992445 0.20