# In [None]:  (notebook cell marker left over from the export)
# ======================================================================================
# KAGGLE NOTEBOOK: TRUE FL + GA-FELCM + PVTv2-B2 (FUSION) — 12 Clients (3x4 datasets)
# DS1/DS2/DS3: kagglehub download
# DS4: Kaggle dataset input path (YOU UPLOADED) -> /kaggle/input/datasets/mdzubayerahmadshibly/ds4mine
# Outputs: /kaggle/working/outputs/
# ======================================================================================

import os, time, math, random, sys, subprocess, hashlib
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    log_loss, confusion_matrix, roc_auc_score,
    roc_curve, precision_recall_curve
)

def pip_install(pkg):
    """Install ``pkg`` with pip for the interpreter that is running this script.

    Executes ``<sys.executable> -m pip -q install <pkg>``;
    ``subprocess.check_call`` raises ``CalledProcessError`` when pip exits
    with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "-q", "install", pkg]
    subprocess.check_call(command)

# timm
try:
    import timm
except Exception:
    pip_install("timm")
    import timm

# kagglehub
try:
    import kagglehub
except Exception:
    pip_install("kagglehub")
    import kagglehub

from torchvision import transforms

try:
    from IPython.display import display
except Exception:
    display = print

# -------------------------
# Repro / device
# -------------------------
# Seed every RNG this script draws from (Python, NumPy, torch CPU and CUDA).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Single device handle used by the rest of the script.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    # Speed-oriented backend settings: cuDNN autotuning on, deterministic
    # mode explicitly off, TF32 allowed for matmul and cuDNN.
    # NOTE(review): with these flags GPU runs are presumably not bit-exact
    # between executions even though the seeds above are fixed.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    # Best-effort: any failure of the call above is ignored on purpose.
    pass

# Run banner.
print("=" * 92)
print("KAGGLE: TRUE FL + GA-FELCM + PVTv2-B2 (FUSION) — 12 Clients (3x4 datasets) | AUG=ON")
print("=" * 92)
print(f"DEVICE: {DEVICE} | torch={torch.__version__}")
print("=" * 92)

# Global experiment configuration, read by key throughout the script.
# img_size / batch_size / num_workers take smaller values when CUDA is not
# available. Dataset base paths ("ds1_base" ... "ds4_base") are added in STEP 0.
CFG = {
    # --- federation layout ---
    "clients_per_dataset": 3,
    "clients_total": 12,
    "rounds": 12,
    "local_epochs": 2,
    # --- optimisation ---
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "warmup_epochs": 1,
    "label_smoothing": 0.08,
    "grad_clip": 1.0,
    "fedprox_mu": 0.01,
    # --- data loading ---
    "img_size": 224 if torch.cuda.is_available() else 160,
    "batch_size": 20 if torch.cuda.is_available() else 10,
    "num_workers": 2 if torch.cuda.is_available() else 0,
    # --- split fractions ---
    "global_val_frac": 0.15,
    "test_frac": 0.15,
    "client_val_frac": 0.12,
    "client_tune_frac": 0.12,
    # --- non-IID client partitioning ---
    "min_per_class_per_client": 5,
    "dirichlet_alpha": 0.35,
    # --- preprocessing / genetic algorithm ---
    "use_preprocessing": True,
    "use_ga": True,
    "ga_pop": 10,
    "ga_gens": 5,
    "ga_elites": 3,
    "elite_pool_max": 18,
    "use_augmentation": True,
    # --- model ---
    "cond_dim": 128,
    "head_dropout": 0.3,
    "unfreeze_after_round": 3,
    "unfreeze_lr_mult": 0.10,
    "unfreeze_tail_frac": 0.17,
    # --- diagnostics ---
    "quick_hash_subset_per_split": 300,
    "preproc_val_sample_n": 500,
    "before_after_n": 12,
}

# Output directory (created if missing) and the artefact paths inside it.
OUTDIR = "/kaggle/working/outputs"
os.makedirs(OUTDIR, exist_ok=True)
MODEL_PATH = os.path.join(OUTDIR, "FL_GAFELCM_PVTv2B2_FUSION_checkpoint.pth")
CSV_PATH = os.path.join(OUTDIR, "ALL_OUTPUTS_AND_METRICS.csv")

# Lower-case file suffixes treated as images (used with str.endswith).
IMG_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")

# ImageNet per-channel mean/std on DEVICE, shaped (1, 3, 1, 1) so they
# broadcast over a batch of 3-channel image tensors.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

# Accumulator of row dicts appended by add_table_to_csv().
ALL_ROWS = []

def add_table_to_csv(df, table_name):
    """Append every row of ``df`` to the global ``ALL_ROWS`` list.

    A copy of ``df`` gets a leading ``table_name`` column holding
    ``table_name``; each row of that copy is stored as a plain dict.
    The caller's frame is not modified.
    """
    tagged = df.copy()
    tagged.insert(0, "table_name", table_name)
    ALL_ROWS.extend(record.to_dict() for _, record in tagged.iterrows())

def print_table(df, title):
    """Print ``title`` framed by two 92-dash rules, then show ``df`` via ``display``."""
    rule = "-" * 92
    for line in ("\n" + rule, title, rule):
        print(line)
    display(df)

# -------------------------
# Helper: resolve nested roots
# -------------------------
def pick_probable_root(base_dir: str, max_depth: int = 6, exts=None) -> str:
    """
    Resolve the real data root when a download wraps everything in one folder.

    Starting at ``base_dir``, descend while the current directory holds no
    image files directly and has exactly one sub-directory. The previous
    implementation tested for images anywhere below the current directory
    (recursive scan), so it returned ``base_dir`` whenever the dataset
    contained any image and never actually descended.

    Args:
        base_dir: Directory to start from (anything ``str()`` accepts).
        max_depth: Maximum number of wrapper levels to descend.
        exts: File suffixes that count as images; defaults to ``IMG_EXTS``.

    Returns:
        The resolved directory as a string. ``str(base_dir)`` is returned
        unchanged when it cannot be listed (missing, not a directory,
        permission error) so the caller's own "no images found" check can
        report the problem.
    """
    if exts is None:
        exts = IMG_EXTS
    suffixes = tuple(str(e).lower() for e in exts)

    cur = str(base_dir)
    for _ in range(max_depth):
        try:
            entries = list(Path(cur).iterdir())
        except OSError:
            # Unlistable directory: leave the decision to the caller.
            break
        # Only files directly in `cur` count; nested images must not stop
        # the descent through wrapper folders.
        if any(p.is_file() and p.name.lower().endswith(suffixes) for p in entries):
            break
        subdirs = [p for p in entries if p.is_dir()]
        if len(subdirs) != 1:
            break
        cur = str(subdirs[0])
    return cur

# ======================================================================================
# STEP 0: DOWNLOAD DS1/DS2/DS3 (kagglehub) + DS4 (Kaggle input path)
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 0: DATASETS (DS1/2/3 via kagglehub) + DS4 from Kaggle input path")
print("=" * 92)

# NOTE: Kagglehub needs Internet ON in Kaggle notebook settings for downloads.
# dataset_download returns the local path of each downloaded dataset.
ds1_path = kagglehub.dataset_download("alamshihab075/brain-tumor-mri-dataset-for-deep-learning")
ds2_path = kagglehub.dataset_download("zehrakucuker/brain-tumor-mri-images-classification-dataset")
ds3_path = kagglehub.dataset_download("chubskuy/brain-tumor-image")

# Your uploaded Kaggle dataset path:
ds4_path = "/kaggle/input/datasets/mdzubayerahmadshibly/ds4mine"

# Normalise each base path through pick_probable_root (handles wrapper folders).
ds1_path = pick_probable_root(ds1_path)
ds2_path = pick_probable_root(ds2_path)
ds3_path = pick_probable_root(ds3_path)
ds4_path = pick_probable_root(ds4_path)

print("✅ DS1:", ds1_path)
print("✅ DS2:", ds2_path)
print("✅ DS3:", ds3_path)
print("✅ DS4:", ds4_path)

# Record the resolved roots in the shared config.
CFG["ds1_base"] = ds1_path
CFG["ds2_base"] = ds2_path
CFG["ds3_base"] = ds3_path
CFG["ds4_base"] = ds4_path

# Dataset order matters: DATASET_NAMES.index(name) is used later as source_id.
DATASET_NAMES = ["ds1", "ds2", "ds3", "ds4"]
DATASET_BASES = {
    "ds1": CFG["ds1_base"],
    "ds2": CFG["ds2_base"],
    "ds3": CFG["ds3_base"],
    "ds4": CFG["ds4_base"],
}

# ======================================================================================
# STEP 1: DISCOVER + MERGE DATASET IMAGES BY CLASS
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 1: DISCOVER + MERGE DATASET IMAGES BY CLASS")
print("=" * 92)

def norm_label(name: str):
    """Map a free-form folder/file name to a canonical class label.

    The name is stringified, stripped, lower-cased and has ``_`` / ``-``
    replaced by spaces. The first matching rule wins, checked in this order:
    glioma, meningioma, pituitary, notumor (any of "normal", "no tumor",
    "notumor", "no tumour"). Returns ``None`` when nothing matches.
    """
    text = str(name).strip().lower()
    for separator in ("_", "-"):
        text = text.replace(separator, " ")

    # (canonical label, substrings that select it) -- order is significant.
    rules = (
        ("glioma", ("glioma",)),
        ("meningioma", ("meningioma",)),
        ("pituitary", ("pituitary",)),
        ("notumor", ("normal", "no tumor", "notumor", "no tumour")),
    )
    for canonical, needles in rules:
        if any(needle in text for needle in needles):
            return canonical
    return None

def infer_label_from_path(path):
    """Return the class label implied by ``path``, or ``None``.

    Path components are tested with ``norm_label`` from the last component
    (the file name) back towards the root; the first component that yields
    a label decides.
    """
    candidates = (norm_label(segment) for segment in reversed(Path(path).parts))
    return next((label for label in candidates if label is not None), None)

def build_df_from_dataset_tree(base_dir, source_name):
    """Index every class-labelled image under ``base_dir`` into a DataFrame.

    Files are kept when their name ends with one of ``IMG_EXTS`` and
    ``infer_label_from_path`` finds a label in their path. The tree is
    walked in sorted order (directories and file names): raw ``os.walk``
    order is filesystem dependent, and the seeded splits computed later
    from this frame are only reproducible if the row order is stable.

    Args:
        base_dir: Root directory of one dataset.
        source_name: Tag stored in the ``source`` column and used in the
            printed summary.

    Returns:
        DataFrame with string columns ``path``, ``label``, ``source`` and
        ``filename`` (basename of ``path``), de-duplicated on ``path``.

    Raises:
        RuntimeError: If no labelled image is found under ``base_dir``.
    """
    rows = []
    for root, dirs, files in os.walk(base_dir):
        # Sorting `dirs` in place fixes the order os.walk descends in.
        dirs.sort()
        for fn in sorted(files):
            if not fn.lower().endswith(IMG_EXTS):
                continue
            p = os.path.join(root, fn)
            lab = infer_label_from_path(p)
            if lab is None:
                continue
            rows.append({"path": p, "label": lab, "source": source_name})

    if not rows:
        raise RuntimeError(
            f"No class-labeled images found under {base_dir}. Expected path segments containing "
            f"glioma/meningioma/pituitary/normal(notumor)."
        )

    dfm = pd.DataFrame(rows).dropna().reset_index(drop=True)
    for col in ("path", "label", "source"):
        dfm[col] = dfm[col].astype(str)
    dfm = dfm.drop_duplicates(subset=["path"]).reset_index(drop=True)
    dfm["filename"] = dfm["path"].map(os.path.basename)

    # One-line per-class summary in a fixed class order.
    counts = dfm["label"].value_counts().reindex(["glioma", "meningioma", "notumor", "pituitary"], fill_value=0)
    print(f"{source_name}: total images = {len(dfm)} | " +
          ", ".join([f"{k}:{int(v)}" for k, v in counts.items()]))
    return dfm

# Index every dataset into its own DataFrame, keyed by dataset name.
all_dfs = {}
for ds_name in DATASET_NAMES:
    all_dfs[ds_name] = build_df_from_dataset_tree(DATASET_BASES[ds_name], ds_name)

df1, df2, df3, df4 = all_dfs["ds1"], all_dfs["ds2"], all_dfs["ds3"], all_dfs["ds4"]

# Canonical class list; a label's position in the list is its integer id.
labels = ["glioma", "meningioma", "notumor", "pituitary"]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for l, i in label2id.items()}
NUM_CLASSES = len(labels)

def enforce_labels(df_):
    """Return a cleaned copy of ``df_`` restricted to the known classes.

    The ``label`` column is stringified, stripped and lower-cased; rows whose
    label is not in the global ``labels`` list are dropped, the index is
    reset, and an integer ``y`` column is added via ``label2id``.
    """
    normalised = df_["label"].astype(str).str.strip().str.lower()
    cleaned = df_.assign(label=normalised)
    known = cleaned["label"].isin(set(labels))
    cleaned = cleaned.loc[known].reset_index(drop=True)
    cleaned["y"] = cleaned["label"].map(label2id).astype(int)
    return cleaned

# Clean labels and attach the integer target column for every dataset, then
# rebuild the name -> frame mapping so it points at the cleaned frames.
df1 = enforce_labels(df1)
df2 = enforce_labels(df2)
df3 = enforce_labels(df3)
df4 = enforce_labels(df4)
all_dfs = {"ds1": df1, "ds2": df2, "ds3": df3, "ds4": df4}

# ======================================================================================
# STEP 2: TRAIN/VAL/TEST SPLIT
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 2: TRAIN/VAL/TEST SPLIT (PER DATASET)")
print("=" * 92)

def split_dataset(df_):
    """Split ``df_`` into stratified train / val / test frames.

    A hold-out of ``global_val_frac + test_frac`` is carved off first
    (stratified on ``y``), then divided into validation and test in the same
    proportion as the two configured fractions. Both splits use
    ``random_state=SEED``.

    Returns:
        ``(train_df, val_df, test_df)``, each with a fresh 0..n-1 index.
    """
    holdout_frac = CFG["global_val_frac"] + CFG["test_frac"]
    train_part, holdout_part = train_test_split(
        df_,
        test_size=holdout_frac,
        stratify=df_["y"],
        random_state=SEED,
    )

    # Share of the hold-out that goes to validation; the rest is test.
    val_share = CFG["global_val_frac"] / (CFG["global_val_frac"] + CFG["test_frac"])
    val_part, test_part = train_test_split(
        holdout_part,
        test_size=(1 - val_share),
        stratify=holdout_part["y"],
        random_state=SEED,
    )
    return tuple(part.reset_index(drop=True) for part in (train_part, val_part, test_part))

# Per-dataset split frames, keyed by dataset name.
train_frames, val_frames, test_frames = {}, {}, {}
for ds_name in DATASET_NAMES:
    tr, va, te = split_dataset(all_dfs[ds_name])
    train_frames[ds_name] = tr
    val_frames[ds_name] = va
    test_frames[ds_name] = te
    print(f"{ds_name.upper()} TRAIN={len(tr)} | VAL={len(va)} | TEST={len(te)}")

# ======================================================================================
# STEP 2.5: LEAKAGE CHECKS
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 2.5: SANITY / LEAKAGE CHECKS (PER DATASET)")
print("=" * 92)

def split_overlap_checks(train_df, val_df, test_df):
    """Count path and filename collisions between the three splits.

    Each frame must have ``path`` and ``filename`` columns. Returns a dict
    with, in order: pairwise ``path`` overlaps, the number of unique paths
    per split, and pairwise ``filename`` overlaps.
    """
    frames = (train_df, val_df, test_df)
    tr_paths, va_paths, te_paths = (set(f["path"].tolist()) for f in frames)
    tr_names, va_names, te_names = (set(f["filename"].tolist()) for f in frames)

    return {
        "path_overlap_train_val": len(tr_paths & va_paths),
        "path_overlap_train_test": len(tr_paths & te_paths),
        "path_overlap_val_test": len(va_paths & te_paths),
        "unique_paths_train": len(tr_paths),
        "unique_paths_val": len(va_paths),
        "unique_paths_test": len(te_paths),
        "filename_overlap_train_val": len(tr_names & va_names),
        "filename_overlap_train_test": len(tr_names & te_names),
        "filename_overlap_val_test": len(va_names & te_names),
    }

def md5_file(path, max_bytes=2_000_000):
    """Return the MD5 hex digest of the first ``max_bytes`` bytes of ``path``.

    Only a prefix of the file is hashed, so two files that share their first
    ``max_bytes`` bytes get the same digest. Any error while reading or
    hashing yields ``None`` instead of raising.
    """
    digest = hashlib.md5()
    try:
        with open(path, "rb") as handle:
            head = handle.read(max_bytes)
            digest.update(head)
        result = digest.hexdigest()
    except Exception:
        result = None
    return result

def quick_hash_subset(frame, n=300):
    """Hash a random sample of at most ``n`` files listed in ``frame["path"]``.

    Rows are drawn without replacement with ``np.random.choice`` (global
    NumPy RNG). Files that ``md5_file`` cannot hash are skipped. Returns the
    set of digests, or an empty set when there is nothing to sample.
    """
    sample_size = min(n, len(frame))
    if sample_size <= 0:
        return set()

    positions = np.random.choice(len(frame), size=sample_size, replace=False)
    digests = (md5_file(frame.iloc[pos]["path"]) for pos in positions)
    return {value for value in digests if value is not None}

def leakage_report(name, tr, va, te):
    """Build, print and record the leakage summary for one dataset.

    Combines the exact path/filename overlap counts from
    ``split_overlap_checks`` with overlap counts between random hashed
    subsets of each split (``CFG["quick_hash_subset_per_split"]`` files per
    split). The one-row table is printed and appended to the CSV rows under
    ``leakage_sanity_<name>``.
    """
    path_checks = split_overlap_checks(tr, va, te)

    subset_n = int(CFG["quick_hash_subset_per_split"])
    # Sampled in train, val, test order (each call consumes the NumPy RNG).
    train_hashes, val_hashes, test_hashes = (
        quick_hash_subset(part, subset_n) for part in (tr, va, te)
    )

    hash_checks = {
        "subset_hash_train_val": len(train_hashes & val_hashes),
        "subset_hash_train_test": len(train_hashes & test_hashes),
        "subset_hash_val_test": len(val_hashes & test_hashes),
        "subset_hash_n_train": len(train_hashes),
        "subset_hash_n_val": len(val_hashes),
        "subset_hash_n_test": len(test_hashes),
    }

    summary = pd.concat([pd.DataFrame([path_checks]), pd.DataFrame([hash_checks])], axis=1)
    print_table(summary, f"Leakage / Sanity Summary — {name}")
    add_table_to_csv(summary, f"leakage_sanity_{name}")

# One leakage/sanity table per dataset.
for ds_name in DATASET_NAMES:
    leakage_report(ds_name, train_frames[ds_name], val_frames[ds_name], test_frames[ds_name])

# ======================================================================================
# STEP 3: NON-IID CLIENT PARTITIONING
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 3: NON-IID CLIENT PARTITIONING (3 clients per dataset => 12 total)")
print("=" * 92)

def make_clients_non_iid(train_df, n_clients, num_classes, min_per_class=5, alpha=0.35):
    """Partition the rows of ``train_df`` into ``n_clients`` label-skewed shards.

    Works on row positions (0..len-1) of ``train_df["y"]``:

    1. Each class's positions are shuffled (``random``).
    2. Every client first receives the same small quota of each class:
       ``min(min_per_class, max(1, class_size // n_clients))`` positions.
    3. What is left of each class is divided according to one
       ``Dirichlet(alpha)`` draw (``np.random``); the rounding remainder goes
       to the client with the largest share.
    4. Each client's list is shuffled.

    Returns:
        A list of ``n_clients`` lists of integer row positions; every row
        appears in exactly one list.
    """
    targets = train_df["y"].values
    shards = [[] for _ in range(n_clients)]

    # Shuffled pool of positions per class, in class-id order.
    pools = []
    for cls in range(num_classes):
        members = np.where(targets == cls)[0].tolist()
        random.shuffle(members)
        pools.append(members)

    # Guaranteed quota: consecutive equal-size slices from the head of each pool.
    for cls, members in enumerate(pools):
        quota = min(min_per_class, max(1, len(members) // n_clients))
        for k in range(n_clients):
            shards[k].extend(members[k * quota:(k + 1) * quota])
        pools[cls] = members[n_clients * quota:]

    # Dirichlet split of the remainder of each class.
    for leftover in pools:
        if not leftover:
            continue
        shares = np.random.dirichlet([alpha] * n_clients)
        counts = (shares * len(leftover)).astype(int)
        counts[np.argmax(shares)] += len(leftover) - counts.sum()

        offset = 0
        for k, take in enumerate(counts):
            shards[k].extend(leftover[offset: offset + take])
            offset += take

    for shard in shards:
        random.shuffle(shard)
    return shards

def robust_client_splits(train_df, indices, val_frac, tune_frac):
    """Split one client's row indices into train / tune / val index lists.

    The tune subset is carved off first, then the validation subset is taken
    from what remains. A stratified ``train_test_split`` (seeded with
    ``SEED``) is used when the subset has at least two classes and is large
    enough (20 rows for tune, 12 for val); otherwise the first rows of the
    given order are taken.

    The stratified split can still be infeasible, e.g. when a class has a
    single sample or the held-out count is smaller than the number of
    classes. scikit-learn raises ``ValueError`` in that case; this function
    then falls back to the same head-slice split instead of aborting.

    Args:
        train_df: Frame holding the integer ``y`` column; ``indices`` are
            looked up with ``.loc``.
        indices: Row indices owned by the client.
        val_frac: Fraction of the post-tune remainder used for validation.
        tune_frac: Fraction of the client's rows used for tuning.

    Returns:
        ``(train_idx, tune_idx, val_idx)`` as lists. With fewer than 3 rows
        all three lists are the full index list (deliberate overlap so no
        split is empty).
    """
    idxs = np.array(indices, dtype=int)
    if len(idxs) < 3:
        return idxs.tolist(), idxs.tolist(), idxs.tolist()

    def _head_split(arr, frac, reserve):
        # Take roughly `frac` of the rows from the head, at least 1, while
        # leaving `reserve` rows behind when the array is long enough.
        n_take = max(1, int(round(len(arr) * frac)))
        n_take = min(n_take, max(1, len(arr) - reserve))
        return arr[n_take:], arr[:n_take]

    def _split(arr, frac, min_len, reserve):
        # Returns (kept, held_out); stratified when feasible.
        y_arr = train_df.loc[arr, "y"].values
        if len(np.unique(y_arr)) >= 2 and len(arr) >= min_len:
            try:
                kept, held = train_test_split(
                    arr, test_size=frac, stratify=y_arr, random_state=SEED
                )
                return kept, held
            except ValueError:
                # Stratification impossible for this class mix / size.
                pass
        return _head_split(arr, frac, reserve)

    rem_idx, tune_idx = _split(idxs, tune_frac, min_len=20, reserve=2)

    if len(rem_idx) < 2:
        return rem_idx.tolist(), tune_idx.tolist(), rem_idx.tolist()

    train_idx, val_idx = _split(rem_idx, val_frac, min_len=12, reserve=1)

    # Never hand back an empty train or val split.
    if len(train_idx) == 0:
        train_idx = val_idx[:]
    if len(val_idx) == 0:
        val_idx = train_idx[:1]
    return train_idx.tolist(), tune_idx.tolist(), val_idx.tolist()

# Build the per-client index splits. Each dataset contributes n_per_ds clients;
# gid is the global client id (base_gid + local index k).
n_per_ds = CFG["clients_per_dataset"]
client_splits = []        # (ds_name, local_id, gid, train_idx, tune_idx, val_idx)
client_test_splits = []   # (ds_name, local_id, gid, test_idx)
base_gid = 0

for ds_name in DATASET_NAMES:
    train_df = train_frames[ds_name]
    test_df  = test_frames[ds_name]

    # Label-skewed partition of this dataset's training rows.
    client_indices = make_clients_non_iid(
        train_df,
        n_clients=n_per_ds,
        num_classes=NUM_CLASSES,
        min_per_class=CFG["min_per_class_per_client"],
        alpha=CFG["dirichlet_alpha"],
    )

    for k in range(n_per_ds):
        tr, tune, va = robust_client_splits(train_df, client_indices[k], CFG["client_val_frac"], CFG["client_tune_frac"])
        gid = base_gid + k
        client_splits.append((ds_name, k, gid, tr, tune, va))
        print(f"{ds_name.upper()} Client {k} (gid {gid}): train={len(tr)} tune={len(tune)} val={len(va)}")

    # The dataset's test rows are shuffled and cut into n_per_ds near-equal
    # chunks, one per client (no label skew applied here).
    idxs = list(range(len(test_df)))
    random.shuffle(idxs)
    split = np.array_split(idxs, n_per_ds)
    for k in range(n_per_ds):
        client_test_splits.append((ds_name, k, base_gid + k, split[k].tolist()))

    base_gid += n_per_ds

# Global client id -> name of the dataset the client belongs to.
gid_to_ds = {gid: ds_name for (ds_name, _, gid, _, _, _) in client_splits}

def client_distribution_table():
    """Summarise every client's split sizes and training-set class counts.

    Reads the global ``client_splits`` / ``train_frames`` and returns one row
    per client: ``client``, ``dataset``, the train/tune/val sizes, then one
    count column per entry of ``labels`` (0 for classes the client lacks).
    """
    records = []
    for ds_name, _local_id, gid, tr_idx, tune_idx, val_idx in client_splits:
        per_class = (
            train_frames[ds_name]
            .loc[tr_idx, "label"]
            .value_counts()
            .reindex(labels, fill_value=0)
        )
        record = {
            "client": f"client_{gid}",
            "dataset": ds_name,
            "total_train": len(tr_idx),
            "total_tune": len(tune_idx),
            "total_val": len(val_idx),
        }
        for lab in labels:
            record[lab] = int(per_class[lab])
        records.append(record)
    return pd.DataFrame(records)

# Show the per-client class distribution and queue it for the CSV export.
dist_df = client_distribution_table()
print_table(dist_df, "Client class distribution (Non-IID, per dataset)")
add_table_to_csv(dist_df, "client_distribution")

# ======================================================================================
# STEP 4: DATA LOADERS
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 4: DATA LOADERS (AUG ON) + IMAGENET NORM")
print("=" * 92)

def load_rgb(path):
    """Load ``path`` as an RGB ``PIL.Image``.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixel data has been converted, rather than whenever the
    intermediate image object happens to be collected.

    Unreadable or corrupt files do not raise: a uniform grey
    ``img_size`` x ``img_size`` placeholder is returned instead, so the
    sample keeps its label but carries no image content.
    """
    try:
        with Image.open(path) as img:
            return img.convert("RGB")
    except Exception:
        side = CFG["img_size"]
        return Image.new("RGB", (side, side), (128, 128, 128))

# Evaluation pipeline: resize to the configured square size, convert to tensor.
EVAL_TFMS = transforms.Compose([
    transforms.Resize((CFG["img_size"], CFG["img_size"])),
    transforms.ToTensor(),
])

# Training pipeline: same resize plus random flip / rotation / colour jitter
# when augmentation is enabled; otherwise identical to EVAL_TFMS.
if CFG["use_augmentation"]:
    TRAIN_TFMS = transforms.Compose([
        transforms.Resize((CFG["img_size"], CFG["img_size"])),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
    ])
else:
    TRAIN_TFMS = EVAL_TFMS

class MRIDataset(Dataset):
    """Dataset over selected rows of a frame with ``path`` and ``y`` columns.

    ``indices`` are positional (used with ``.iloc``); when omitted every row
    of ``frame`` is used. ``source_id`` and ``client_id`` are coerced to
    ``int`` and returned unchanged with every sample.
    """

    def __init__(self, frame, indices=None, tfms=None, source_id=0, client_id=0):
        if indices is None:
            indices = list(range(len(frame)))
        self.df = frame
        self.indices = indices
        self.tfms = tfms
        self.source_id = int(source_id)
        self.client_id = int(client_id)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        """Return ``(image_tensor, label, path, source_id, client_id)`` for sample ``i``."""
        record = self.df.iloc[self.indices[i]]
        image = load_rgb(record["path"])
        # Without an explicit transform the image is only converted to a tensor.
        if self.tfms is not None:
            tensor = self.tfms(image)
        else:
            tensor = transforms.ToTensor()(image)
        target = int(record["y"])
        return tensor, target, record["path"], self.source_id, self.client_id

def make_weighted_sampler(frame, indices, num_classes):
    """Build a class-balancing ``WeightedRandomSampler`` for ``indices``.

    Each selected row gets weight ``1 / count(its class)`` where counts are
    taken over the selected rows only (absent classes are counted as 1 to
    avoid division by zero). The sampler draws ``len(indices)`` samples with
    replacement. Returns ``None`` when ``indices`` is empty.
    """
    if len(indices) == 0:
        return None

    targets = frame.loc[indices, "y"].values
    per_class = np.bincount(targets, minlength=num_classes)
    inverse_freq = 1.0 / np.clip(per_class, 1, None)
    per_sample = inverse_freq[targets]

    return WeightedRandomSampler(
        weights=torch.DoubleTensor(per_sample),
        num_samples=len(per_sample),
        replacement=True,
    )

def make_loader(frame, indices, bs, tfms, shuffle=False, sampler=None, source_id=0, client_id=0):
    """Wrap the selected rows of ``frame`` in an ``MRIDataset`` and a ``DataLoader``.

    ``shuffle`` is only honoured when no ``sampler`` is given. Worker count
    comes from ``CFG["num_workers"]``; pinned memory is enabled on CUDA; the
    last partial batch is kept.

    NOTE(review): ``persistent_workers`` is on whenever workers are used, so
    every loader created here presumably keeps its worker processes alive
    after its first iteration -- confirm this is acceptable given how many
    loaders the script creates.
    """
    dataset = MRIDataset(
        frame, indices=indices, tfms=tfms, source_id=source_id, client_id=client_id
    )
    n_workers = CFG["num_workers"]
    loader_options = {
        "batch_size": bs,
        "shuffle": (shuffle and sampler is None),
        "sampler": sampler,
        "num_workers": n_workers,
        "pin_memory": (DEVICE.type == "cuda"),
        "drop_last": False,
        "persistent_workers": (n_workers > 0),
    }
    return DataLoader(dataset, **loader_options)

# One (train, tune, val) loader triple per client, in client_splits order.
client_loaders = []
for (ds_name, local_id, gid, tr_idx, tune_idx, val_idx) in client_splits:
    df_src = train_frames[ds_name]
    # source_id is the dataset's position in DATASET_NAMES.
    source_id = DATASET_NAMES.index(ds_name)
    sampler = make_weighted_sampler(df_src, tr_idx, NUM_CLASSES)

    # Training loader: class-balanced sampler when available, else plain shuffle.
    tr_loader = make_loader(df_src, tr_idx, CFG["batch_size"], TRAIN_TFMS,
                            shuffle=(sampler is None), sampler=sampler,
                            source_id=source_id, client_id=gid)

    # Tune loader: un-augmented, shuffled; falls back to the training indices
    # when the tune split is empty.
    tune_loader = make_loader(df_src, tune_idx if len(tune_idx) else tr_idx[:max(1, len(tr_idx))],
                              CFG["batch_size"], EVAL_TFMS, shuffle=True,
                              source_id=source_id, client_id=gid)

    # Val loader: un-augmented, ordered; falls back to at most one batch worth
    # of training indices when the val split is empty.
    val_loader = make_loader(df_src, val_idx if len(val_idx) else tr_idx[:max(1, min(len(tr_idx), CFG["batch_size"]))],
                             CFG["batch_size"], EVAL_TFMS, shuffle=False,
                             source_id=source_id, client_id=gid)

    client_loaders.append((tr_loader, tune_loader, val_loader))

# One test loader per client over its chunk of the dataset's test frame.
client_test_loaders = []
for (ds_name, local_id, gid, test_idx) in client_test_splits:
    df_src = test_frames[ds_name]
    source_id = DATASET_NAMES.index(ds_name)
    t_loader = make_loader(df_src, test_idx, CFG["batch_size"], EVAL_TFMS,
                           shuffle=False, source_id=source_id, client_id=gid)
    client_test_loaders.append((ds_name, local_id, gid, t_loader))

print(f"Augmentation: {'ON ✅' if CFG['use_augmentation'] else 'OFF'}")
print(f"Preprocessing: {'ON ✅' if CFG['use_preprocessing'] else 'OFF'}")

# ======================================================================================
# STEP 5: GA-TUNED ENHANCED FELCM PREPROCESSOR
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 5: GA-TUNED ENHANCED FELCM PREPROCESSOR")
print("=" * 92)

# Human-readable description of each preprocessing parameter, keyed by the
# short names also used in theta_str().
THETA_FULLFORMS = {
    "gamma": "Power transform exponent (γ)",
    "alpha": "Local contrast weight (α)",
    "beta": "Contrast sharpness (β)",
    "tau": "Robust clipping threshold (τ)",
    "k": "Blur kernel size (k) for local contrast map",
    "sh": "Sharpen strength (sh)",
    "dn": "Denoise strength (dn)",
}

class EnhancedFELCM(nn.Module):
    """Parameter-free image enhancement module controlled by seven scalars.

    Pipeline applied to a 4-D ``(B, C, H, W)`` batch:

    1. optional 3x3 box-blur blend (``denoise``),
    2. per-image, per-channel standardisation clipped to ``[-tau, tau]``,
    3. sign-preserving power transform with exponent ``gamma``,
    4. local-contrast boost: Laplacian magnitude of the channel mean divided
       by its ``blur_k`` box blur, squashed with ``tanh(beta * .)`` and added
       with weight ``alpha``,
    5. optional per-channel 3x3 sharpening blend (``sharpen``),
    6. per-image, per-channel min-max rescale to ``[0, 1]``.

    An even ``blur_k`` is bumped to the next odd value at call time.
    """

    def __init__(self, gamma=1.0, alpha=0.35, beta=6.0, tau=2.5, blur_k=7, sharpen=0.0, denoise=0.0):
        super().__init__()
        self.gamma = float(gamma)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.tau = float(tau)
        self.blur_k = int(blur_k)
        self.sharpen = float(sharpen)
        self.denoise = float(denoise)

        # Fixed 3x3 kernels stored as buffers so they follow .to(device).
        laplacian = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("lap", laplacian.view(1, 1, 3, 3))

        sharpening = torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]], dtype=torch.float32)
        self.register_buffer("sharp_kernel", sharpening.view(1, 1, 3, 3))

    @staticmethod
    def _conv3x3_reflect(t, kernel):
        """3x3 convolution of a single-channel map with reflect padding."""
        return F.conv2d(F.pad(t, (1, 1, 1, 1), mode="reflect"), kernel)

    @staticmethod
    def _box_blur(t, size):
        """Stride-1 average pooling of odd ``size`` with reflect padding."""
        half = size // 2
        return F.avg_pool2d(F.pad(t, (half, half, half, half), mode="reflect"), size, 1)

    def forward(self, x):
        eps = 1e-6
        _, n_channels, _, _ = x.shape

        # 1) light denoising blend
        if self.denoise > 0:
            smoothed = self._box_blur(x, 3)
            x = x * (1 - self.denoise) + smoothed * self.denoise

        # 2) robust standardisation
        mean = x.mean(dim=(2, 3), keepdim=True)
        std = x.std(dim=(2, 3), keepdim=True).clamp_min(eps)
        z = (x - mean) / std
        z = z.clamp(-self.tau, self.tau)

        # 3) signed power transform
        shaped = torch.sign(z) * torch.pow(torch.abs(z).clamp_min(eps), self.gamma)

        # 4) local contrast map from the channel-averaged image
        luminance = shaped.mean(dim=1, keepdim=True)
        edge_mag = self._conv3x3_reflect(luminance, self.lap).abs()
        window = self.blur_k if self.blur_k % 2 == 1 else self.blur_k + 1
        contrast_map = edge_mag / (self._box_blur(edge_mag, window) + eps)
        boosted = shaped + self.alpha * torch.tanh(self.beta * contrast_map)

        # 5) per-channel sharpening blend
        if self.sharpen > 0:
            sharpened = []
            for ch in range(n_channels):
                plane = boosted[:, ch: ch + 1, :, :]
                crisp = self._conv3x3_reflect(plane, self.sharp_kernel)
                sharpened.append(plane * (1 - self.sharpen) + crisp * self.sharpen)
            boosted = torch.cat(sharpened, dim=1)

        # 6) min-max rescale per image and channel
        low = boosted.amin(dim=(2, 3), keepdim=True)
        high = boosted.amax(dim=(2, 3), keepdim=True)
        scaled = (boosted - low) / (high - low + eps)
        return scaled.clamp(0, 1)

def theta_to_module(theta):
    """Instantiate ``EnhancedFELCM`` from a parameter tuple.

    ``theta`` is unpacked positionally, i.e. in the constructor's order:
    ``(gamma, alpha, beta, tau, blur_k, sharpen, denoise)``.
    """
    module = EnhancedFELCM(*theta)
    return module

def random_theta():
    """Draw a random parameter tuple from the initial search ranges.

    Uses the global ``random`` module. Returns
    ``(gamma, alpha, beta, tau, blur_k, sharpen, denoise)`` with
    ``blur_k`` chosen from ``{3, 5, 7}`` and the other entries uniform in
    their respective ranges.
    """
    # Draw order matters for seeded reproducibility.
    uniform = random.uniform
    gamma, alpha, beta, tau = (
        uniform(low, high)
        for low, high in ((0.7, 1.4), (0.15, 0.55), (3.0, 9.0), (1.8, 3.2))
    )
    blur_k = random.choice([3, 5, 7])
    sharpen = uniform(0.0, 0.25)
    denoise = uniform(0.0, 0.2)
    return (gamma, alpha, beta, tau, blur_k, sharpen, denoise)

def mutate(theta, p=0.8):
    """Return a randomly perturbed copy of ``theta`` with probability ``p``.

    With probability ``1 - p`` the input tuple is returned as is. Otherwise
    each continuous parameter receives Gaussian noise (``np.random``) and is
    clipped to its allowed range, and with probability 0.3 the blur kernel
    is re-drawn from ``{3, 5, 7}`` (``random``).
    """
    if random.random() > p:
        return theta

    def jitter(value, sigma, low, high):
        # Gaussian step followed by clipping to the parameter's bounds.
        return float(np.clip(value + np.random.normal(0, sigma), low, high))

    gamma, alpha, beta, tau, kernel, sharpen, denoise = theta
    gamma = jitter(gamma, 0.06, 0.6, 1.5)
    alpha = jitter(alpha, 0.05, 0.08, 0.7)
    beta = jitter(beta, 0.5, 2.0, 11.0)
    tau = jitter(tau, 0.2, 1.5, 3.8)
    if random.random() < 0.3:
        kernel = random.choice([3, 5, 7])
    sharpen = jitter(sharpen, 0.04, 0.0, 0.35)
    denoise = jitter(denoise, 0.03, 0.0, 0.3)
    return (gamma, alpha, beta, tau, int(kernel), sharpen, denoise)

def crossover(t1, t2):
    """Uniform crossover: each position is taken from ``t1`` or ``t2`` at random.

    Uses the global ``random`` module; the result is as long as the shorter
    parent.
    """
    child = []
    for gene_a, gene_b in zip(t1, t2):
        child.append(random.choice([gene_a, gene_b]))
    return tuple(child)

def theta_str(th):
    """Format a parameter tuple for logging; ``None`` becomes the string ``"None"``."""
    if th is None:
        return "None"
    gamma, alpha, beta, tau, kernel, sharpen, denoise = th
    fields = (
        f"γ={gamma:.2f}",
        f"α={alpha:.2f}",
        f"β={beta:.1f}",
        f"τ={tau:.1f}",
        f"k={kernel}",
        f"sh={sharpen:.2f}",
        f"dn={denoise:.2f}",
    )
    return "(" + ", ".join(fields) + ")"

# No-op preprocessing module on DEVICE (presumably the stand-in used when
# FELCM preprocessing is disabled -- usage is outside this section).
IDENTITY_PRE = nn.Identity().to(DEVICE)

# ======================================================================================
# STEP 6: MODEL (PVTv2-B2 + MULTI-SCALE FUSION)
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 6: MODEL (PVTv2-B2 + MULTI-SCALE FUSION)")
print("=" * 92)

# timm model name passed to timm.create_model for the backbone.
BACKBONE_NAME = "pvt_v2_b2"

class TokenAttentionPooling(nn.Module):
    """Pool a token sequence ``[B, N, C]`` into ``[B, C]`` with learned weights.

    A single linear layer scores every token; the scores are soft-maxed over
    the token axis and used as weights of a weighted sum.
    """

    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, 1)

    def forward(self, x):  # x: [B, N, C]
        scores = self.query(x).squeeze(-1)               # [B, N]
        weights = torch.softmax(scores, dim=1)            # [B, N], sums to 1 per row
        weighted = x * weights.unsqueeze(-1)              # [B, N, C]
        return weighted.sum(dim=1)                        # [B, C]

class MultiScaleFeatureFuser(nn.Module):
    """Fuse a list of feature maps into one pooled vector of size ``out_dim``.

    Every input map is projected to ``out_dim`` channels (1x1 conv,
    GroupNorm(8), GELU). Starting from the last map, the running result is
    bilinearly resized to the previous map's spatial size and summed with
    it, back to the first map. The sum goes through a 3x3 conv block and is
    pooled over spatial positions with ``TokenAttentionPooling``.
    """

    def __init__(self, in_channels: List[int], out_dim: int):
        super().__init__()
        projections = []
        for channels in in_channels:
            projections.append(
                nn.Sequential(
                    nn.Conv2d(channels, out_dim, kernel_size=1, bias=False),
                    nn.GroupNorm(8, out_dim),
                    nn.GELU(),
                )
            )
        self.proj = nn.ModuleList(projections)
        self.fuse = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(8, out_dim),
            nn.GELU(),
        )
        self.pool = TokenAttentionPooling(out_dim)

    def forward(self, feats):
        projected = [layer(feat) for layer, feat in zip(self.proj, feats)]

        # Top-down pathway: start at the last map, walk back to the first.
        merged = projected[-1]
        for finer in projected[-2::-1]:
            merged = F.interpolate(merged, size=finer.shape[-2:], mode="bilinear", align_corners=False)
            merged = merged + finer

        merged = self.fuse(merged)
        _b, _c, _h, _w = merged.shape  # fused map must be 4-D
        tokens = merged.flatten(2).transpose(1, 2)  # [B, HW, C]
        return self.pool(tokens)

class EnhancedBrainTuner(nn.Module):
    """Feature refinement head mixing two branches with a learned softmax gate.

    Branch 1 rescales the input with a squeeze-excite style sigmoid mask;
    branch 2 adds a small residual (factor 0.2) from a LayerNorm + MLP block.
    The two-element ``gate`` parameter starts at equal weights.
    """

    def __init__(self, dim, dropout=0.1):
        super().__init__()
        hidden = max(8, dim // 4)
        self.se = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, dim),
            nn.Sigmoid(),
        )
        self.refine = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.gate = nn.Parameter(torch.ones(2) / 2)

    def forward(self, x):
        mix = F.softmax(self.gate, dim=0)
        excited = x * self.se(x)
        refined = x + 0.2 * self.refine(x)
        return mix[0] * excited + mix[1] * refined

class PVTv2B2_MultiScale(nn.Module):
    """Classifier built on a timm feature backbone with conditioned fusion gates.

    The model receives two versions of each image (``x_raw_n`` and
    ``x_fel_n``), a 7-element preprocessing parameter vector and the integer
    source / client ids. A conditioning vector derived from those three
    inputs drives three sigmoid gates that blend the two views at the input
    level (``g0``), after multi-scale fusion (``g1``) and after the tuner
    (``g2``).

    NOTE(review): the defaults ``num_clients=6`` / ``num_sources=2`` are
    smaller than the 12 clients and 4 datasets configured in this script;
    callers presumably pass explicit sizes -- confirm, since embedding
    lookups with ids beyond the table size fail.
    """

    def __init__(self, num_classes, pretrained=True, head_dropout=0.3, cond_dim=128, num_clients=6, num_sources=2):
        super().__init__()
        # Backbone returns one feature map per requested stage (4 stages).
        self.backbone = timm.create_model(
            BACKBONE_NAME,
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3),
        )
        in_channels = self.backbone.feature_info.channels()
        # Width of the fused embedding: half the last stage's channels, at least 256.
        out_dim = max(256, in_channels[-1] // 2)

        self.fuser = MultiScaleFeatureFuser(in_channels, out_dim)
        self.tuner = EnhancedBrainTuner(out_dim, dropout=0.1)

        self.classifier = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(head_dropout),
            nn.Linear(out_dim, max(64, out_dim // 2)),
            nn.GELU(),
            nn.Dropout(head_dropout * 0.5),
            nn.Linear(max(64, out_dim // 2), num_classes),
        )

        # Conditioning: MLP over the 7 preprocessing parameters plus learned
        # embeddings for the data source and the client.
        self.theta_mlp = nn.Sequential(
            nn.Linear(7, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        )
        self.source_emb = nn.Embedding(num_sources, cond_dim)
        self.client_emb = nn.Embedding(num_clients, cond_dim)
        self.cond_norm = nn.LayerNorm(cond_dim)

        # Gate heads: 3 values (one per input channel), then out_dim values
        # each for the mid- and late-level blends.
        self.gate_early = nn.Linear(cond_dim, 3)
        self.gate_mid = nn.Linear(cond_dim, out_dim)
        self.gate_late = nn.Linear(cond_dim, out_dim)

    def _cond_vec(self, theta_vec, source_id, client_id):
        """Return the layer-normalised sum of the theta MLP output and both embeddings."""
        cond = self.theta_mlp(theta_vec)
        cond = cond + self.source_emb(source_id) + self.client_emb(client_id)
        return self.cond_norm(cond)

    def forward(self, x_raw_n, x_fel_n, theta_vec, source_id, client_id, return_gates=False):
        """Compute class logits; optionally also return the three gate tensors.

        Args:
            x_raw_n, x_fel_n: Two image batches of identical shape with 3
                channels (presumably the normalised raw and FELCM-processed
                images -- confirm against the caller).
            theta_vec: Float tensor whose last dimension is 7.
            source_id, client_id: Integer index tensors for the embeddings.
            return_gates: When true, return ``(logits, {"g0", "g1", "g2"})``.
        """
        cond = self._cond_vec(theta_vec, source_id, client_id)

        # Early fusion: per-channel blend of the two input views.
        g0 = torch.sigmoid(self.gate_early(cond)).view(-1, 3, 1, 1)
        x0 = (1 - g0) * x_raw_n + g0 * x_fel_n

        # The backbone runs twice: on the blended input and on x_fel_n alone.
        feats0 = self.backbone(x0)
        feats1 = self.backbone(x_fel_n)

        f0 = self.fuser(feats0)
        f1 = self.fuser(feats1)

        # Mid fusion: per-feature blend of the two fused embeddings.
        g1 = torch.sigmoid(self.gate_mid(cond))
        f_mid = (1 - g1) * f0 + g1 * f1

        t0 = self.tuner(f0)
        t1 = self.tuner(f1)
        t_mid = self.tuner(f_mid)

        # Late fusion: blend the tuned mid embedding with the mean of the two
        # tuned single-view embeddings.
        t_views = 0.5 * (t0 + t1)
        g2 = torch.sigmoid(self.gate_late(cond))
        t_final = (1 - g2) * t_mid + g2 * t_views

        logits = self.classifier(t_final)

        if return_gates:
            return logits, {"g0": g0, "g1": g1, "g2": g2}
        return logits

def count_params(model):
    """Return ``(total, trainable)`` parameter element counts for ``model``."""
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

def set_trainable_for_round(model, rnd):
    """Set ``requires_grad`` flags for federated round ``rnd``.

    The backbone is frozen and every non-backbone parameter is made
    trainable. From round ``CFG["unfreeze_after_round"]`` onwards the last
    ``CFG["unfreeze_tail_frac"]`` share of the backbone's parameter tensors
    (at least one) is unfrozen as well.
    """
    model.backbone.requires_grad_(False)
    for name, param in model.named_parameters():
        if not name.startswith("backbone."):
            param.requires_grad = True

    if rnd < CFG["unfreeze_after_round"]:
        return

    backbone_params = list(model.backbone.parameters())
    if not backbone_params:
        return
    tail_len = max(1, int(len(backbone_params) * CFG["unfreeze_tail_frac"]))
    for param in backbone_params[-tail_len:]:
        param.requires_grad = True

def make_optimizer(model):
    """Create an AdamW optimizer over the currently trainable parameters.

    Non-backbone parameters use ``CFG["lr"]``; trainable backbone parameters
    use that rate scaled by ``CFG["unfreeze_lr_mult"]``. Weight decay comes
    from ``CFG["weight_decay"]``. Groups without parameters are omitted.
    """
    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    backbone_params = [p for name, p in trainable if name.startswith("backbone.")]
    head_params = [p for name, p in trainable if not name.startswith("backbone.")]

    param_groups = []
    if head_params:
        param_groups.append({"params": head_params, "lr": CFG["lr"]})
    if backbone_params:
        param_groups.append({"params": backbone_params, "lr": CFG["lr"] * CFG["unfreeze_lr_mult"]})
    return torch.optim.AdamW(param_groups, weight_decay=CFG["weight_decay"])

# ======================================================================================
# STEP 7: GA FITNESS (ROBUST)
# ======================================================================================
# Section banner for the GA fitness helpers defined below.
print("\n" + "=" * 92)
print("STEP 7: GA FITNESS (ROBUST)")
print("=" * 92)

@torch.no_grad()
def enhanced_separability_score(emb, y):
    """Return a between/within class-scatter ratio for embeddings *emb*.

    Classes with fewer than two samples are ignored. Returns ``0.0`` when
    fewer than two usable classes remain. ``between`` is the size-weighted
    squared distance of each class centroid from the (unweighted) mean of the
    centroids; ``within`` is the mean over classes of the average squared
    distance of samples to their centroid.
    """
    eps = 1e-6
    labels_long = y.long()
    present = torch.unique(labels_long)
    if len(present) < 2:
        return 0.0

    # One (centroid, within-class spread, sample count) triple per usable class.
    per_class = []
    for cls in present:
        members = emb[labels_long == cls]
        count = members.size(0)
        if count < 2:
            continue
        centre = members.mean(dim=0)
        spread = (members - centre).pow(2).sum(dim=1).mean().item()
        per_class.append((centre, spread, count))

    if len(per_class) < 2:
        return 0.0

    centres = torch.stack([centre for centre, _, _ in per_class], dim=0)
    grand_mean = centres.mean(dim=0)

    between = 0
    for centre, (_, _, count) in zip(centres, per_class):
        between += count * (centre - grand_mean).pow(2).sum().item()

    within = float(np.mean([spread for _, spread, _ in per_class]))
    return float(between / (within + eps))

@torch.no_grad()
def ga_fitness(theta, backbone_frozen, batch_x, batch_y, use_separability=True):
    """Score one preprocessing parameter vector *theta* on a single batch.

    The score combines Laplacian edge contrast and dynamic range of the
    preprocessed batch, minus a parameter cost; with ``use_separability`` it
    also adds the class-separability of ``backbone_frozen`` embeddings.

    Args:
        theta: 7-element sequence unpacked as ``(g, a, b, t, k, sh, dn)``.
        backbone_frozen: feature extractor called on the normalised batch.
        batch_x, batch_y: images and labels, moved to ``DEVICE`` here.
        use_separability: include the embedding separability term.

    Returns:
        The fitness value (higher is better).
    """
    pre = theta_to_module(theta).to(DEVICE)
    images = batch_x.to(DEVICE)
    targets = batch_y.to(DEVICE)

    enhanced = pre(images)

    # Image-quality terms: mean absolute Laplacian response and value range.
    luminance = enhanced.mean(dim=1, keepdim=True)
    padded = F.pad(luminance, (1, 1, 1, 1), mode="reflect")
    lap = F.conv2d(padded, pre.lap).abs()
    contrast = float(lap.mean().item())
    dyn_range = float((enhanced.max() - enhanced.min()).item())

    normed = (enhanced - IMAGENET_MEAN) / IMAGENET_STD

    sep = 0.0
    if use_separability:
        emb = backbone_frozen(normed)
        if isinstance(emb, (list, tuple)):
            # Multi-scale output: take the last feature map, average over H and W.
            emb = emb[-1].mean(dim=(2, 3))
        sep = enhanced_separability_score(emb, targets)

    g, a, b, t, k, sh, dn = theta
    cost = (0.03 * abs(g - 1.0) + 0.05 * a + 0.01 * (b / 10.0) +
            0.02 * abs(t - 2.5) + 0.02 * sh + 0.02 * dn)

    if not use_separability:
        return 0.60 * contrast + 0.35 * dyn_range - 0.5 * cost
    return 0.35 * contrast + 0.15 * dyn_range + 1.35 * sep - 0.5 * cost

def _safe_first_batch(dl):
    try:
        it = iter(dl)
        bx, by, *_ = next(it)
        return bx, by
    except Exception:
        return None, None

def run_ga_for_client(backbone_frozen, dl_for_eval, elite_pool, use_separability=True):
    """Run the genetic search for preprocessing parameters on one client.

    Fitness is measured on the first batch of *dl_for_eval* (cut to
    ``CFG["batch_size"]`` samples). The population is seeded with up to half
    its size from the front of *elite_pool* and filled with random thetas.

    Returns:
        ``(best_theta, top_thetas, best_fitness)``; ``(None, [], 0.0)`` when
        no batch could be read.
    """
    bx, by = _safe_first_batch(dl_for_eval)
    if bx is None:
        return None, [], 0.0

    population = []
    if elite_pool:
        population += elite_pool[: min(len(elite_pool), CFG["ga_pop"] // 2)]
    while len(population) < CFG["ga_pop"]:
        population.append(random_theta())

    bx = bx[: CFG["batch_size"]].contiguous()
    by = by[: CFG["batch_size"]].contiguous()

    def ranked(candidates):
        """Score every candidate on the fixed batch and sort best-first."""
        pairs = [(ga_fitness(th, backbone_frozen, bx, by, use_separability), th) for th in candidates]
        pairs.sort(key=lambda pair: pair[0], reverse=True)
        return pairs

    for _gen in range(CFG["ga_gens"]):
        elites = [th for _, th in ranked(population)[: CFG["ga_elites"]]]
        # Parents are drawn from the elites plus the leading slice of the
        # current (unsorted) population.
        parent_pool = elites + population[: max(2, CFG["ga_pop"] // 2)]

        offspring = elites[:]
        while len(offspring) < CFG["ga_pop"]:
            p1, p2 = random.sample(parent_pool, 2)
            offspring.append(mutate(crossover(p1, p2), p=0.75))
        population = offspring

    final = ranked(population)
    best_fit, best_theta = final[0]
    top = [th for _, th in final[: CFG["ga_elites"]]]
    return best_theta, top, float(best_fit)

# ======================================================================================
# STEP 8: TRAIN / EVAL UTILITIES
# ======================================================================================
# Console banner: marks in the run log where the train/eval helpers begin.
print("\n" + "=" * 92)
print("STEP 8: TRAIN / EVAL UTILITIES (FULL METRICS)")
print("=" * 92)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Return a ``LambdaLR`` with linear warmup followed by cosine decay.

    The LR multiplier rises linearly from 0 to 1 over ``num_warmup_steps``
    steps, then follows half a cosine down to 0 at ``num_training_steps``.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(step):
        if step >= num_warmup_steps:
            progress = float(step - num_warmup_steps) / decay_span
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        return float(step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def preproc_theta_vec(preproc_module, batch_size):
    """Return a ``(batch_size, 7)`` float32 tensor describing *preproc_module*.

    When the module has a ``gamma`` attribute, the row holds its gamma, alpha,
    beta, tau, ``blur_k / 7``, sharpen and denoise values; otherwise the row
    is all zeros. The tensor is created on ``DEVICE``.
    """
    if not hasattr(preproc_module, "gamma"):
        theta = torch.zeros(7, device=DEVICE, dtype=torch.float32)
    else:
        values = [
            preproc_module.gamma,
            preproc_module.alpha,
            preproc_module.beta,
            preproc_module.tau,
            float(preproc_module.blur_k) / 7.0,
            preproc_module.sharpen,
            preproc_module.denoise,
        ]
        theta = torch.tensor(values, device=DEVICE, dtype=torch.float32)
    return theta.unsqueeze(0).repeat(batch_size, 1)

def gate_entropy(gate):
    """Return the element-wise binary entropy (in bits) of *gate*.

    Values are clamped to ``[1e-6, 1 - 1e-6]`` first so the logs stay finite.
    """
    eps = 1e-6
    on = gate.clamp(eps, 1 - eps)
    off = 1 - on
    return -(on * torch.log2(on) + off * torch.log2(off))

def summarize_gate_stats(gate_stats, num_classes):
    """Summarise gate activations collected during evaluation.

    Args:
        gate_stats: list of ``(gates, labels)`` pairs, where ``gates`` maps
            ``"g0"``/``"g1"``/``"g2"`` to tensors and ``labels`` is the
            matching label tensor.
        num_classes: number of class ids to report per-class statistics for.

    Returns:
        Dict with, per gate, the overall mean and mean entropy plus the same
        two values for every class that occurs in the labels.
    """
    gate_names = ("g0", "g1", "g2")
    collected = {name: [] for name in gate_names}
    label_chunks = []
    for gates, y_cpu in gate_stats:
        for name in gate_names:
            collected[name].append(gates[name].detach().cpu())
        label_chunks.append(y_cpu)
    labels_cat = torch.cat(label_chunks, dim=0)

    gate_metrics = {}
    for name in gate_names:
        g = torch.cat(collected[name], dim=0)
        # Per-sample entropy: average over every non-batch dimension.
        non_batch_dims = list(range(1, g.ndim))
        ent = gate_entropy(g).mean(dim=non_batch_dims)
        gate_metrics[f"{name}_mean"] = float(g.mean().item())
        gate_metrics[f"{name}_entropy_mean"] = float(ent.mean().item())

        for c in range(num_classes):
            in_class = labels_cat == c
            if not in_class.any():
                continue
            gate_metrics[f"{name}_mean_c{c}"] = float(g[in_class].mean().item())
            gate_metrics[f"{name}_entropy_c{c}"] = float(ent[in_class].mean().item())
    return gate_metrics

@torch.no_grad()
def _auc_metrics(y_true, p_pred, num_classes):
    """Compute ROC-AUC metrics, skipping any that cannot be computed.

    Args:
        y_true: 1-D array of integer class labels.
        p_pred: ``(n_samples, num_classes)`` array of predicted probabilities.
        num_classes: number of classes.

    Returns:
        Dict with ``"auc_roc"`` for binary problems, or
        ``"auc_roc_macro_ovr"`` plus one ``"auc_class_<c>"`` per class for
        multi-class problems. A metric that cannot be computed is omitted.
    """
    out = {}

    if num_classes == 2:
        try:
            out["auc_roc"] = float(roc_auc_score(y_true, p_pred[:, 1]))
        except Exception:
            # Best-effort metric (e.g. only one class present): leave it out.
            pass
        return out

    # Each metric is guarded separately: the macro OVR score fails when a class
    # is missing from y_true, which must not discard the per-class AUCs that
    # are still well defined.
    try:
        out["auc_roc_macro_ovr"] = float(
            roc_auc_score(y_true, p_pred, multi_class="ovr", average="macro")
        )
    except Exception:
        pass

    for c in range(num_classes):
        yc = (y_true == c).astype(int)
        # One-vs-rest AUC needs both positives and negatives.
        if not 0 < yc.sum() < len(yc):
            continue
        try:
            out[f"auc_class_{c}"] = float(roc_auc_score(yc, p_pred[:, c]))
        except Exception:
            continue
    return out

@torch.no_grad()
def evaluate_full(model, loader, preproc_module, return_gates=False):
    """Evaluate *model* on *loader* and return the full metric set.

    Each batch is passed through *preproc_module*; both the raw and the
    preprocessed images are ImageNet-normalised and fed to the model together
    with the preprocessing parameter vector, source id and client id.

    Args:
        model: network called as ``model(x_raw, x_fel, theta, source_id, client_id)``.
        loader: iterable of ``(x, y, _, source_id, client_id)`` batches.
        preproc_module: preprocessing module applied to the raw images.
        return_gates: also collect gate tensors and add their summary
            statistics to the metrics.

    Returns:
        ``(metrics, y_true, p_pred)``. For an empty loader every metric except
        ``eval_time_s`` is NaN and both arrays are empty.
    """
    t0 = time.time()
    model.eval()
    preproc_module.eval()

    all_y, all_p = [], []
    # Summed (not averaged) CE and sample count, so that loss_ce is an exact
    # per-sample mean even when the final batch is smaller than the others.
    loss_sum, n_seen = 0.0, 0
    gate_stats = []
    has_any = False

    for x, y, _, source_id, client_id in loader:
        has_any = True
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        x_p = preproc_module(x)
        x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
        x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD

        theta_vec = preproc_theta_vec(preproc_module, x.size(0))
        if return_gates:
            logits, gates = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id, return_gates=True)
            # Move gates to CPU right away instead of holding every batch's
            # gate tensors on the device until the end of the pass.
            gates_cpu = {name: g.detach().cpu() for name, g in gates.items()}
            gate_stats.append((gates_cpu, y.detach().cpu()))
        else:
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)

        probs = torch.softmax(logits, dim=1)
        loss_sum += float(F.cross_entropy(logits, y, reduction="sum").item())
        n_seen += int(y.size(0))

        all_y.append(y.detach().cpu().numpy())
        all_p.append(probs.detach().cpu().numpy())

    if not has_any:
        met = {
            "loss_ce": np.nan,
            "acc": np.nan,
            "precision_macro": np.nan,
            "recall_macro": np.nan,
            "f1_macro": np.nan,
            "precision_weighted": np.nan,
            "recall_weighted": np.nan,
            "f1_weighted": np.nan,
            "log_loss": np.nan,
            "eval_time_s": float(time.time() - t0),
        }
        return met, np.array([]), np.array([])

    y_true = np.concatenate(all_y)
    p_pred = np.concatenate(all_p)
    y_hat = np.argmax(p_pred, axis=1)

    met = {
        "loss_ce": float(loss_sum / max(1, n_seen)),
        "acc": float(accuracy_score(y_true, y_hat)),
        "precision_macro": float(precision_score(y_true, y_hat, average="macro", zero_division=0)),
        "recall_macro": float(recall_score(y_true, y_hat, average="macro", zero_division=0)),
        "f1_macro": float(f1_score(y_true, y_hat, average="macro", zero_division=0)),
        "precision_weighted": float(precision_score(y_true, y_hat, average="weighted", zero_division=0)),
        "recall_weighted": float(recall_score(y_true, y_hat, average="weighted", zero_division=0)),
        "f1_weighted": float(f1_score(y_true, y_hat, average="weighted", zero_division=0)),
        "log_loss": float(log_loss(y_true, p_pred, labels=list(range(NUM_CLASSES)))),
        "eval_time_s": float(time.time() - t0),
    }
    met.update(_auc_metrics(y_true, p_pred, NUM_CLASSES))
    if return_gates and gate_stats:
        met.update(summarize_gate_stats(gate_stats, NUM_CLASSES))
    return met, y_true, p_pred

def fedprox_term(local_model, global_model):
    """Return the summed squared difference between paired parameters.

    Parameters are paired in ``parameters()`` order; the global side is
    detached so gradients only flow into *local_model*. Returns ``0.0`` when
    there are no parameters.
    """
    pairs = zip(local_model.parameters(), global_model.parameters())
    return sum((((w_loc - w_glob.detach()) ** 2).sum() for w_loc, w_glob in pairs), 0.0)

def train_one_epoch(model, loader, optimizer, preproc_module, criterion, global_model=None, scheduler=None, scaler=None,
                    grad_clip=1.0):
    """Train *model* for one pass over *loader*.

    Args:
        model: network called as ``model(x_raw, x_fel, theta, source_id, client_id)``.
        loader: iterable of ``(x, y, _, source_id, client_id)`` batches.
        optimizer: optimizer stepped once per batch.
        preproc_module: preprocessing module; kept in eval mode.
        criterion: loss applied to ``(logits, y)``.
        global_model: when given and ``CFG["fedprox_mu"] > 0``, a FedProx
            proximal penalty towards this model is added to the loss.
        scheduler: optional LR scheduler, stepped once per batch.
        scaler: optional gradient scaler; autocast is enabled only when set.
        grad_clip: max gradient norm; falsy or non-positive disables clipping.

    Returns:
        ``(mean_batch_loss, accuracy, elapsed_seconds)``.
    """
    model.train()
    preproc_module.eval()
    losses, correct, total = [], 0, 0
    t0 = time.time()

    for x, y, _, source_id, client_id in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        source_id = source_id.to(DEVICE, non_blocking=True)
        client_id = client_id.to(DEVICE, non_blocking=True)

        # Mixed precision is tied to the scaler: no scaler means full precision.
        with torch.amp.autocast(device_type=DEVICE.type, enabled=(scaler is not None)):
            x_p = preproc_module(x)
            x_raw_n = (x - IMAGENET_MEAN) / IMAGENET_STD
            x_fel_n = (x_p - IMAGENET_MEAN) / IMAGENET_STD
            theta_vec = preproc_theta_vec(preproc_module, x.size(0))
            logits = model(x_raw_n, x_fel_n, theta_vec, source_id, client_id)
            loss = criterion(logits, y)
            if global_model is not None and CFG["fedprox_mu"] > 0:
                # FedProx: (mu / 2) * squared distance to the global weights.
                loss = loss + 0.5 * CFG["fedprox_mu"] * fedprox_term(model, global_model)

        optimizer.zero_grad(set_to_none=True)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip and grad_clip > 0:
                # Unscale first so the clip threshold applies to true gradient norms.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip and grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        # The schedule is defined in optimizer steps, so it advances every batch.
        if scheduler is not None:
            scheduler.step()

        losses.append(float(loss.item()))
        preds = logits.argmax(dim=1)
        correct += int((preds == y).sum().item())
        total += int(y.size(0))

    return float(np.mean(losses)), float(correct / max(1, total)), float(time.time() - t0)

def fedavg_update(global_model, local_models, weights, trainable_names):
    """Overwrite selected entries of *global_model* with a weighted average.

    Args:
        global_model: model updated in place.
        local_models: client models to average.
        weights: one aggregation weight per client model (used as given; the
            caller is expected to normalise them).
        trainable_names: state-dict keys to average; all other entries of the
            global model are left untouched.

    Raises:
        ValueError: if there are no client models or the number of weights
            does not match the number of models.
    """
    local_models = list(local_models)
    weights = list(weights)
    if not local_models or len(local_models) != len(weights):
        raise ValueError(
            f"fedavg_update needs one weight per local model "
            f"(got {len(local_models)} models, {len(weights)} weights)"
        )

    gsd = global_model.state_dict()
    # Build each client's state dict once: state_dict() reconstructs the whole
    # mapping on every call, so calling it per (name, model) pair is wasteful.
    local_sds = [m.state_dict() for m in local_models]

    # Average everything first, then copy, so a bad name cannot leave the
    # global model half-updated.
    new_sd = {}
    for name in trainable_names:
        acc = None
        for sd, w in zip(local_sds, weights):
            p = sd[name].detach().float().cpu()
            acc = (w * p) if acc is None else (acc + w * p)
        new_sd[name] = acc

    for name, t in new_sd.items():
        gsd[name].copy_(t.to(gsd[name].device).type_as(gsd[name]))
    global_model.load_state_dict(gsd)

@torch.no_grad()
def pick_best_theta_from_pool(model, pool, val_loader, max_candidates=10):
    """Pick the theta from the front of *pool* with the best validation accuracy.

    Only the first ``max_candidates`` entries are evaluated. Returns
    ``(best_theta, best_acc)``; ``(None, None)`` for an empty pool and
    ``(None, -1)`` when no candidate produced a finite accuracy.
    """
    if not pool:
        return None, None

    best, best_acc = None, -1
    for candidate in pool[:max_candidates]:
        pre = theta_to_module(candidate).to(DEVICE)
        met, _, _ = evaluate_full(model, val_loader, pre)
        acc = met["acc"]
        if not np.isfinite(acc) or acc <= best_acc:
            continue
        best, best_acc = candidate, acc
    return best, best_acc

# ======================================================================================
# STEP 9: INITIALIZE GLOBAL MODEL
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 9: INITIALIZING GLOBAL MODEL")
print("=" * 92)

# The global model: the only one built with pretrained=True here.
global_model = PVTv2B2_MultiScale(
    num_classes=NUM_CLASSES,
    pretrained=True,
    head_dropout=CFG["head_dropout"],
    cond_dim=CFG["cond_dim"],
    num_clients=CFG["clients_total"],
    num_sources=len(DATASET_NAMES),
).to(DEVICE)

# Apply the round-1 freezing policy so the parameter report below reflects it.
set_trainable_for_round(global_model, rnd=1)

total_params, tuned_params = count_params(global_model)
print(f"Backbone: {BACKBONE_NAME} | Total params: {total_params:,} | Trainable: {tuned_params:,} ({(tuned_params/total_params)*100:.2f}%)")

# NOTE: Module.eval() returns the module itself, so `backbone_frozen` is an
# alias of global_model.backbone (not a copy); freezing it here also freezes
# the global model's backbone parameters.
backbone_frozen = global_model.backbone.eval()
for p in backbone_frozen.parameters():
    p.requires_grad = False

# Inverse-frequency class weights over the training frames of all datasets,
# rescaled so the weights average to 1. Counts are clipped to >= 1 so classes
# with no samples do not divide by zero.
counts = np.zeros(NUM_CLASSES, dtype=np.int64)
for ds_name in DATASET_NAMES:
    counts += train_frames[ds_name]["y"].value_counts().sort_index().reindex(range(NUM_CLASSES), fill_value=0).values
w = (counts.sum() / np.clip(counts, 1, None)).astype(np.float32)
w = w / max(1e-6, w.mean())
class_w = torch.tensor(w, device=DEVICE)

criterion = nn.CrossEntropyLoss(weight=class_w, label_smoothing=CFG["label_smoothing"])
# A gradient scaler (and therefore mixed precision in training) is only used on CUDA.
scaler = torch.amp.GradScaler("cuda") if DEVICE.type == "cuda" else None

# Record every CFG entry plus a few descriptive rows as a hyperparameter table.
hp_rows = [{"hp_name": k, "hp_value": str(v)} for k, v in CFG.items()]
hp_rows += [
    {"hp_name": "GA_theta_ranges",
     "hp_value": "gamma∈[0.7,1.4], alpha∈[0.15,0.55], beta∈[3,9], tau∈[1.8,3.2], k∈{3,5,7}, sh∈[0,0.25], dn∈[0,0.2]"},
    {"hp_name": "theta_fullforms", "hp_value": str(THETA_FULLFORMS)},
    {"hp_name": "backbone_name", "hp_value": BACKBONE_NAME},
    {"hp_name": "norm", "hp_value": "ImageNet mean/std"},
]
hp_df = pd.DataFrame(hp_rows)
print_table(hp_df, "Hyperparameters / Search Space")
add_table_to_csv(hp_df, "hyperparameters")

# ======================================================================================
# STEP 10: FEDERATED TRAINING
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 10: FEDERATED TRAINING (NO CENTRAL VAL/TEST)")
print("=" * 92)

# Per-dataset pool of good thetas, newest first; shared by that dataset's clients.
elite_pools = {ds: [] for ds in DATASET_NAMES}
history_global = []
history_local = []

best_global_acc = -1.0
best_model_state = None
best_thetas = {ds: None for ds in DATASET_NAMES}
# Thetas selected in the same round as `best_model_state`, so the restored
# checkpoint can be paired with the preprocessing it was selected with.
best_thetas_at_best_round = None
best_round_saved = None

t_global_start = time.time()

for rnd in range(1, CFG["rounds"] + 1):
    round_t0 = time.time()
    local_models = []
    local_weights = []
    local_rows = []

    print(f"\n{'=' * 92}\nROUND {rnd}/{CFG['rounds']}\n{'=' * 92}")

    for k in range(CFG["clients_total"]):
        tr_loader, tune_loader, val_loader = client_loaders[k]
        ds_name = gid_to_ds[k]
        elite_pool = elite_pools[ds_name]

        # --- GA search for this client's preprocessing parameters ---
        ga_t0 = time.time()
        if CFG["use_preprocessing"] and CFG["use_ga"]:
            best_theta, top_thetas, best_fit = run_ga_for_client(
                backbone_frozen, tune_loader, elite_pool, use_separability=True
            )
            # New elites go to the FRONT of the pool. Every reader (GA seeding,
            # pick_best_theta_from_pool) takes entries from the front and the
            # pool is truncated at the back, so appending would mean a full
            # pool never accepts a new elite again.
            elite_pool[:] = (list(top_thetas) + elite_pool)[: CFG["elite_pool_max"]]
            pre_k = theta_to_module(best_theta).to(DEVICE) if best_theta is not None else IDENTITY_PRE
        else:
            best_theta, best_fit = None, 0.0
            pre_k = IDENTITY_PRE
        ga_time = float(time.time() - ga_t0)

        elite_pools[ds_name] = elite_pool

        # --- Local training starts from a copy of the current global weights ---
        local_model = PVTv2B2_MultiScale(
            num_classes=NUM_CLASSES,
            pretrained=False,
            head_dropout=CFG["head_dropout"],
            cond_dim=CFG["cond_dim"],
            num_clients=CFG["clients_total"],
            num_sources=len(DATASET_NAMES),
        ).to(DEVICE)
        local_model.load_state_dict(global_model.state_dict(), strict=True)

        set_trainable_for_round(local_model, rnd=rnd)
        opt = make_optimizer(local_model)

        total_steps = max(1, len(tr_loader) * CFG["local_epochs"])
        warmup_steps = max(1, len(tr_loader) * CFG["warmup_epochs"])
        scheduler = get_cosine_schedule_with_warmup(opt, warmup_steps, total_steps)

        tr_losses, tr_accs, tr_time = [], [], 0.0
        for _ in range(CFG["local_epochs"]):
            loss_ep, acc_ep, t_ep = train_one_epoch(
                local_model,
                tr_loader,
                opt,
                pre_k,
                criterion,
                global_model=global_model,
                scheduler=scheduler,
                scaler=scaler,
                grad_clip=CFG["grad_clip"],
            )
            tr_losses.append(loss_ep)
            tr_accs.append(acc_ep)
            tr_time += t_ep

        met_loc, _, _ = evaluate_full(local_model, val_loader, pre_k, return_gates=True)

        local_models.append(local_model)
        local_weights.append(len(tr_loader.dataset))

        if best_theta is not None:
            g, a, b, t, kk, sh, dn = best_theta
        else:
            g = a = b = t = kk = sh = dn = None

        row = {
            "round": rnd,
            "client": f"client_{k}",
            "dataset": ds_name,
            "ga_best_fit_score": float(best_fit),
            "ga_time_s": ga_time,
            "theta_str": theta_str(best_theta),
            "gamma_power": g,
            "alpha_contrast_weight": a,
            "beta_contrast_sharpness": b,
            "tau_clip": t,
            "k_blur_kernel_size": kk,
            "sh_sharpen_strength": sh,
            "dn_denoise_strength": dn,
            "train_loss": float(np.mean(tr_losses)),
            "train_acc": float(np.mean(tr_accs)),
            "train_time_s": float(tr_time),
            **{f"val_{k2}": v2 for k2, v2 in met_loc.items()},
        }
        local_rows.append(row)

        auc_val = row.get("val_auc_roc_macro_ovr", row.get("val_auc_roc", np.nan))
        print(
            f"Client {k} ({ds_name}) | train_acc={row['train_acc']:.4f} | "
            f"val_acc={row['val_acc']:.4f} | val_f1={row['val_f1_macro']:.4f} | "
            f"val_auc={auc_val:.4f} | val_logloss={row['val_log_loss']:.4f} | "
            f"GA_fit={row['ga_best_fit_score']:.3f} | ga_time={row['ga_time_s']:.1f}s | theta={row['theta_str']}"
        )

    # --- Aggregate: FedAvg over the trainable parameters, weighted by train-set size ---
    wsum = sum(local_weights)
    weights = [w / wsum for w in local_weights]
    trainable_names = [n for n, p in local_models[0].named_parameters() if p.requires_grad]
    fedavg_update(global_model, local_models, weights, trainable_names)

    # --- Federated validation: average the clients' local metrics by val-set size ---
    local_val_rows = pd.DataFrame(local_rows)
    local_val_rows["val_size"] = [len(client_loaders[i][2].dataset) for i in range(CFG["clients_total"])]
    total_val = local_val_rows["val_size"].sum()

    def weighted_avg(key):
        """Val-size-weighted mean of column *key*, ignoring unusable clients."""
        if total_val == 0:
            return np.nan
        vals = local_val_rows[key].to_numpy(dtype=float)
        wts = local_val_rows["val_size"].to_numpy(dtype=float)
        # A client with an empty val split reports NaN metrics with weight 0;
        # NaN * 0 is still NaN, so such rows must be dropped, not just
        # down-weighted, or the round's global metrics become NaN.
        usable = np.isfinite(vals) & (wts > 0)
        if not usable.any():
            return np.nan
        return float(np.average(vals[usable], weights=wts[usable]))

    global_metrics = {
        "acc": weighted_avg("val_acc"),
        "f1_macro": weighted_avg("val_f1_macro"),
        "precision_macro": weighted_avg("val_precision_macro"),
        "recall_macro": weighted_avg("val_recall_macro"),
        "log_loss": weighted_avg("val_log_loss"),
        "loss_ce": weighted_avg("val_loss_ce"),
        "eval_time_s": weighted_avg("val_eval_time_s"),
    }

    # Pick one theta per dataset for the freshly aggregated global model,
    # evaluated on the val split of that dataset's first client.
    if CFG["use_preprocessing"]:
        for ds_name in DATASET_NAMES:
            if elite_pools[ds_name]:
                rep_gid = DATASET_NAMES.index(ds_name) * n_per_ds
                best_thetas[ds_name], _ = pick_best_theta_from_pool(
                    global_model, elite_pools[ds_name], client_loaders[rep_gid][2]
                )

    history_local.extend(local_rows)
    history_global.append({
        "round": rnd,
        "round_time_s": float(time.time() - round_t0),
        "global_thetas": str({ds: theta_str(best_thetas[ds]) for ds in DATASET_NAMES}),
        **{f"global_{k2}": v2 for k2, v2 in global_metrics.items()},
    })

    if np.isfinite(global_metrics["acc"]) and global_metrics["acc"] > best_global_acc:
        best_global_acc = float(global_metrics["acc"])
        best_model_state = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}
        # Snapshot the thetas chosen for exactly this global model.
        best_thetas_at_best_round = dict(best_thetas)
        best_round_saved = rnd

    print("\n" + "-" * 92)
    print(
        f"GLOBAL VAL (Round {rnd}) | acc={global_metrics['acc']:.4f} | f1={global_metrics['f1_macro']:.4f} | "
        f"logloss={global_metrics['log_loss']:.4f} | loss_ce={global_metrics['loss_ce']:.4f} | "
        f"round_time={history_global[-1]['round_time_s']:.1f}s | "
        f"thetas={str({ds: theta_str(best_thetas[ds]) for ds in DATASET_NAMES})}"
    )
    print(f"BEST SO FAR (by ACC) | best_val_acc={best_global_acc:.4f} at round={best_round_saved}")
    print("-" * 92)

if best_model_state is not None:
    global_model.load_state_dict({k: v.to(DEVICE) for k, v in best_model_state.items()})
    # Restore the thetas that were selected for the restored checkpoint;
    # otherwise the best-round weights would be evaluated with the last
    # round's preprocessing parameters.
    if best_thetas_at_best_round is not None:
        best_thetas.update(best_thetas_at_best_round)

t_total = float(time.time() - t_global_start)
print("\n" + "=" * 92)
print(f"TRAINING COMPLETE ✅ | total_time={t_total:.1f}s | best_val_acc={best_global_acc:.4f} | best_round={best_round_saved}")
print("=" * 92)

glob_df = pd.DataFrame(history_global)
loc_df  = pd.DataFrame(history_local)
print_table(glob_df, "GLOBAL per-round metrics")
print_table(loc_df.head(30), "LOCAL per-client per-round metrics (head)")
add_table_to_csv(glob_df, "global_round_metrics_full")
add_table_to_csv(loc_df, "client_round_metrics_full")

# ======================================================================================
# STEP 11: FINAL EVALUATION (FEDERATED VAL + TEST)
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 11: FINAL EVALUATION (FEDERATED VAL + TEST)")
print("=" * 92)

# Preprocessing module used for each dataset in the final evaluation: built
# from that dataset's selected theta, or the identity module when
# preprocessing is disabled or no theta was selected.
pre_by_ds = {
    ds_name: (theta_to_module(best_thetas[ds_name]).to(DEVICE)
              if (CFG["use_preprocessing"] and best_thetas[ds_name] is not None)
              else IDENTITY_PRE)
    for ds_name in DATASET_NAMES
}

def weighted_aggregate(mets):
    """Weighted-average a list of metric dicts.

    Args:
        mets: sequence of ``(id, metrics_dict, weight)`` triples.

    Returns:
        Dict mapping every metric key found in any entry to its weighted
        mean. Entries that lack a key, hold a non-finite value for it, or
        have a non-positive weight are left out of that key's average; a key
        with no usable entry maps to NaN. Returns ``{}`` when *mets* is empty
        or the weights sum to zero.
    """
    if not mets:
        return {}
    total = sum(w for _, _, w in mets)
    if total == 0:
        return {}

    # Union of keys in first-seen order: taking keys from the first entry only
    # would drop every metric when that entry happens to be an empty dict.
    keys = list(dict.fromkeys(key for _, metrics, _ in mets for key in metrics))

    out = {}
    for key in keys:
        usable = [
            (float(metrics[key]), float(w))
            for _, metrics, w in mets
            if key in metrics and w > 0 and np.isfinite(metrics[key])
        ]
        if not usable:
            out[key] = float("nan")
            continue
        # NaN * 0 is NaN, so unusable entries are filtered out above rather
        # than merely given zero weight.
        vals, wts = zip(*usable)
        out[key] = float(np.average(vals, weights=wts))
    return out

# Federated validation of the final global model: evaluate on every client's
# val split with its dataset's preprocessing, then average by val-set size.
val_metrics_clients = []
for k in range(CFG["clients_total"]):
    _, _, val_loader = client_loaders[k]
    ds_name = gid_to_ds[k]
    met, _, _ = evaluate_full(global_model, val_loader, pre_by_ds[ds_name])
    val_metrics_clients.append((k, met, len(val_loader.dataset)))
val_best = weighted_aggregate(val_metrics_clients)

def eval_test_per_dataset(ds_name):
    """Evaluate the global model on every test loader belonging to *ds_name*.

    Returns:
        ``(aggregate, details)``: ``aggregate`` is the test-set-size-weighted
        mean of the per-client metrics and ``details`` is a list of
        ``(metrics, n_samples, y_true, p_pred, gid)`` tuples, one per client.
        Returns ``({}, [])`` when the dataset has no test loaders.
    """
    details = []
    for ds, _local_id, gid, t_loader in client_test_loaders:
        if ds == ds_name:
            met, y_true, p_pred = evaluate_full(global_model, t_loader, pre_by_ds[ds], return_gates=True)
            details.append((met, len(t_loader.dataset), y_true, p_pred, gid))

    if not details:
        return {}, []

    triples = [(i, entry[0], entry[1]) for i, entry in enumerate(details)]
    return weighted_aggregate(triples), details

# Per-dataset test results: aggregated metrics and the per-client details.
test_by_ds = {}
test_detail_by_ds = {}
for ds_name in DATASET_NAMES:
    tm, td = eval_test_per_dataset(ds_name)
    test_by_ds[ds_name] = tm
    test_detail_by_ds[ds_name] = td

# NOTE(review): the key "ds1" is hard-coded; this raises KeyError if
# DATASET_NAMES does not contain it.
ds1_mets = test_detail_by_ds["ds1"]

# Global test metrics: per-dataset aggregates weighted by each dataset's
# test-frame length.
global_test = weighted_aggregate([
    (i, test_by_ds[ds_name], len(test_frames[ds_name])) for i, ds_name in enumerate(DATASET_NAMES)
])

def compact_metrics(m):
    """Return the headline metrics from *m* as floats, in a fixed column order.

    Keys missing from *m* are skipped; keys not listed here are dropped.
    """
    ordered_keys = (
        "acc", "precision_macro", "recall_macro", "f1_macro",
        "precision_weighted", "recall_weighted", "f1_weighted",
        "log_loss",
        "auc_roc_macro_ovr", "auc_roc",
        "loss_ce", "eval_time_s",
    )
    compact = {}
    for key in ordered_keys:
        if key in m:
            compact[key] = float(m[key])
    return compact

# Summary table: one VAL row (all datasets, weighted), one TEST row per
# dataset, and one TEST row for the globally weighted aggregate.
paper_rows = [
    {"setting": "Enhanced FELCM (Best θ per dataset)", "split": "VAL", "dataset": "all datasets weighted", **compact_metrics(val_best)}
]
for ds_name in DATASET_NAMES:
    paper_rows.append({
        "setting": f"Enhanced FELCM (Best θ {ds_name})",
        "split": "TEST",
        "dataset": ds_name,
        **compact_metrics(test_by_ds[ds_name])
    })
paper_rows.append({"setting": "Enhanced FELCM (Best θ)", "split": "TEST", "dataset": "global weighted", **compact_metrics(global_test)})

paper_df = pd.DataFrame(paper_rows)
print_table(paper_df, "VAL+TEST tables (federated, per-dataset + global)")
add_table_to_csv(paper_df, "paper_ready_metrics")

def pick_auc(m):
    """Return the AUC stored in *m*, preferring the multi-class macro OVR value.

    Falls back to the binary ``"auc_roc"`` entry, then to NaN.
    """
    for key in ("auc_roc_macro_ovr", "auc_roc"):
        if key in m:
            return float(m[key])
    return np.nan

# Short-name metrics table (acc / pre / rec / f1 / logloss / auc_roc): the
# weighted validation row first, then one test row per dataset, then the
# globally weighted test row. Missing metrics are reported as NaN.
explicit_rows = []
explicit_rows.append({
    "dataset": "all_datasets_val_weighted",
    "acc": float(val_best.get("acc", np.nan)),
    "pre": float(val_best.get("precision_macro", np.nan)),
    "rec": float(val_best.get("recall_macro", np.nan)),
    "f1": float(val_best.get("f1_macro", np.nan)),
    "logloss": float(val_best.get("log_loss", np.nan)),
    "auc_roc": pick_auc(val_best),
})
for ds_name in DATASET_NAMES:
    met = test_by_ds.get(ds_name, {})
    explicit_rows.append({
        "dataset": f"{ds_name}_test",
        "acc": float(met.get("acc", np.nan)),
        "pre": float(met.get("precision_macro", np.nan)),
        "rec": float(met.get("recall_macro", np.nan)),
        "f1": float(met.get("f1_macro", np.nan)),
        "logloss": float(met.get("log_loss", np.nan)),
        "auc_roc": pick_auc(met),
    })
explicit_rows.append({
    "dataset": "global_test_weighted",
    "acc": float(global_test.get("acc", np.nan)),
    "pre": float(global_test.get("precision_macro", np.nan)),
    "rec": float(global_test.get("recall_macro", np.nan)),
    "f1": float(global_test.get("f1_macro", np.nan)),
    "logloss": float(global_test.get("log_loss", np.nan)),
    "auc_roc": pick_auc(global_test),
})

explicit_metrics_df = pd.DataFrame(explicit_rows)
print_table(explicit_metrics_df, "Requested explicit metrics (acc, pre, rec, f1, logloss, auc_roc)")
add_table_to_csv(explicit_metrics_df, "requested_explicit_metrics")

# Log which round's checkpoint and which thetas the reported numbers refer to.
print("\nPaper selection summary:")
print(f"- Best round (by federated VAL accuracy): round={best_round_saved} | best_val_acc={best_global_acc:.4f}")
for ds_name in DATASET_NAMES:
    print(f"- Best θ {ds_name}: {theta_str(best_thetas[ds_name])}")

# ======================================================================================
# STEP 12: PREPROCESSING VALIDATION (DS1 VAL SAMPLE)
# ======================================================================================
# Console banner: marks in the run log where preprocessing validation begins.
print("\n" + "=" * 92)
print("STEP 12: PREPROCESSING VALIDATION (DS1 VAL SAMPLE)")
print("=" * 92)

@torch.no_grad()
def entropy_per_image(x01):
    """Return the Shannon entropy (bits) of each image's 8-bit gray histogram.

    Channels are averaged into a gray image, scaled by 255 and cast to
    ``uint8`` before building a 256-bin histogram per image.
    """
    gray_u8 = (x01.mean(dim=1).detach().cpu().numpy() * 255).astype(np.uint8)

    def _shannon_bits(img):
        """Entropy of one uint8 image from its normalised histogram."""
        hist = np.bincount(img.ravel(), minlength=256).astype(np.float32)
        probs = hist / np.clip(hist.sum(), 1, None)
        probs = probs[probs > 0]
        return float(-(probs * np.log2(probs)).sum())

    return np.array([_shannon_bits(img) for img in gray_u8])

@torch.no_grad()
def edge_energy(x01, lap_kernel):
    """Return each image's mean absolute response to *lap_kernel*.

    Channels are averaged into a single gray channel, which is reflect-padded
    by one pixel and convolved with the kernel (cast to the input's device
    and dtype). The result is a NumPy array with one value per image.
    """
    kernel = lap_kernel.to(device=x01.device, dtype=x01.dtype)
    luminance = x01.mean(dim=1, keepdim=True)
    padded = F.pad(luminance, (1, 1, 1, 1), mode="reflect")
    response = F.conv2d(padded, kernel).abs()
    per_image = response.mean(dim=(1, 2, 3))
    return per_image.detach().cpu().numpy()

@torch.no_grad()
def contrast_proxy(x01):
    """Return the standard deviation of each image's channel-averaged gray values.

    The result is a NumPy array with one value per image.
    """
    luminance = x01.mean(dim=1)
    spread = torch.std(luminance, dim=(1, 2))
    return spread.detach().cpu().numpy()

@torch.no_grad()
def run_preproc_validation(frame, preproc, sample_n=500):
    """Compare image statistics before and after *preproc* on a random sample.

    Up to ``sample_n`` rows of *frame* are drawn without replacement, loaded
    with the evaluation transforms and pushed through *preproc* (output
    clamped to ``[0, 1]``). Edge energy, histogram entropy and contrast are
    measured on both versions.

    Returns:
        ``(per_image_df, summary_df, x_before, x_after)``; two empty frames
        and two ``None`` values when there is nothing to sample.
    """
    n = min(sample_n, len(frame))
    if n <= 0:
        return pd.DataFrame(), pd.DataFrame(), None, None

    picked = np.random.choice(len(frame), size=n, replace=False)
    ds = MRIDataset(frame, indices=picked.tolist(), tfms=EVAL_TFMS, source_id=0, client_id=0)

    images = []
    for i in range(len(ds)):
        img, _, *_rest = ds[i]
        images.append(img)
    x = torch.stack(images).to(DEVICE)

    x_after = preproc(x).clamp(0, 1)
    # Use the module's own Laplacian kernel when it has one, otherwise the
    # kernel of a default EnhancedFELCM instance.
    if hasattr(preproc, "lap"):
        lap_kernel = preproc.lap
    else:
        lap_kernel = EnhancedFELCM().to(DEVICE).lap

    ee_before = edge_energy(x, lap_kernel)
    ee_after = edge_energy(x_after, lap_kernel)
    ent_before = entropy_per_image(x)
    ent_after = entropy_per_image(x_after)
    con_before = contrast_proxy(x)
    con_after = contrast_proxy(x_after)

    columns = {
        "edge_energy_before": ee_before,
        "edge_energy_after": ee_after,
        "entropy_before": ent_before,
        "entropy_after": ent_after,
        "contrast_before": con_before,
        "contrast_after": con_after,
    }
    # Derived columns; the ratio's denominator is floored to avoid dividing by zero.
    columns["edge_gain_ratio"] = ee_after / np.clip(ee_before, 1e-9, None)
    columns["entropy_delta"] = ent_after - ent_before
    columns["contrast_delta"] = con_after - con_before
    dfm = pd.DataFrame(columns)

    stats = dfm.agg(["mean", "std", "min", "max"]).T
    summary = stats.reset_index().rename(columns={"index": "metric"})
    return dfm, summary, x, x_after

# Run the before/after preprocessing comparison on a sample of the DS1
# validation frame; the summary stays an empty frame when preprocessing is off.
# NOTE(review): the "ds1" key is hard-coded for both the frame and the module.
preproc_summary_df = pd.DataFrame()
if CFG["use_preprocessing"]:
    preproc_df, preproc_summary_df, _, _ = run_preproc_validation(val_frames["ds1"], pre_by_ds["ds1"], CFG["preproc_val_sample_n"])
    print_table(preproc_summary_df, "Preprocessing validation summary (DS1 VAL sample)")
    add_table_to_csv(preproc_summary_df, "preprocessing_validation_summary_ds1")

# ======================================================================================
# STEP 13: CURVES + CONFUSION + CALIBRATION TABLES (NO PLOTS)
# ======================================================================================
# Console banner: marks in the run log where the curve/calibration tables begin.
print("\n" + "=" * 92)
print("STEP 13: COMPUTE CURVES + CONFUSION + CALIBRATION TABLES (NO PLOTS)")
print("=" * 92)

def multiclass_calibration_curve(y_true, p_pred, n_bins=12):
    """Bin top-class confidence and report per-bin confidence/accuracy.

    Args:
        y_true: Integer class label per sample.
        p_pred: 2-D array of per-class scores, one row per sample.
        n_bins: Number of equal-width confidence bins over [0, 1].

    Returns:
        Tuple ``(bin_conf, bin_acc, bin_count)`` of arrays with one entry per
        bin: mean top-class confidence, mean correctness, and sample count.
        Bins with no samples hold NaN for both means and 0 for the count.
    """
    top_prob = np.max(p_pred, axis=1)
    top_class = np.argmax(p_pred, axis=1)
    hit = (top_class == y_true).astype(np.float32)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # np.digitize numbers bins from 1; a confidence of exactly 1.0 falls past
    # the last edge, so the clip folds it into the final bin.
    slot = np.clip(np.digitize(top_prob, edges) - 1, 0, n_bins - 1)

    mean_conf, mean_acc, counts = [], [], []
    for b in range(n_bins):
        members = slot == b
        k = int(members.sum())
        counts.append(k)
        if k:
            mean_conf.append(top_prob[members].mean())
            mean_acc.append(hit[members].mean())
        else:
            mean_conf.append(np.nan)
            mean_acc.append(np.nan)
    return np.array(mean_conf), np.array(mean_acc), np.array(counts)

# Build the ROC / PR / confusion / calibration tables from the first DS1 entry.
if len(ds1_mets):
    met0, n0, y_true_test, p_test_best, gid0 = ds1_mets[0]
    if len(y_true_test):
        roc_rows, pr_rows = [], []
        for c in range(NUM_CLASSES):
            # One-vs-rest target for class c.
            yc = (y_true_test == c).astype(int)
            # Skip classes that are absent or cover every sample.
            if yc.sum() in (0, len(yc)):
                continue

            fpr, tpr, thr_roc = roc_curve(yc, p_test_best[:, c])
            roc_rows.extend(
                {"class": labels[c], "fpr": float(f), "tpr": float(t), "threshold": float(th)}
                for f, t, th in zip(fpr, tpr, thr_roc)
            )

            prec, rec, thr_pr = precision_recall_curve(yc, p_test_best[:, c])
            # Points without a matching threshold get NaN in the threshold column.
            thr_padded = np.concatenate(
                [thr_pr, np.full(max(len(prec) - len(thr_pr), 0), np.nan)]
            )
            pr_rows.extend(
                {"class": labels[c], "precision": float(p), "recall": float(r), "threshold": float(th)}
                for p, r, th in zip(prec, rec, thr_padded)
            )

        roc_df = pd.DataFrame(roc_rows)
        pr_df = pd.DataFrame(pr_rows)
        print_table(roc_df.head(20), "ROC curve points (DS1 TEST, first 20 rows)")
        print_table(pr_df.head(20), "PR curve points (DS1 TEST, first 20 rows)")
        add_table_to_csv(roc_df, "roc_curve_points_ds1")
        add_table_to_csv(pr_df, "pr_curve_points_ds1")

        # Confusion matrix: raw counts plus a row-normalized copy.
        y_hat_test = np.argmax(p_test_best, axis=1)
        cm_counts = confusion_matrix(y_true_test, y_hat_test, labels=list(range(NUM_CLASSES)))
        # Clip row totals to at least 1 so classes with no samples do not divide by zero.
        cm_norm = cm_counts / np.clip(cm_counts.sum(axis=1, keepdims=True), 1, None)

        true_col = {"index": "true"}
        cm_counts_df = pd.DataFrame(cm_counts, index=labels, columns=labels).reset_index().rename(columns=true_col)
        cm_norm_df = pd.DataFrame(cm_norm, index=labels, columns=labels).reset_index().rename(columns=true_col)
        print_table(cm_counts_df, "Confusion matrix counts (DS1 TEST)")
        print_table(cm_norm_df, "Confusion matrix row-normalized (DS1 TEST)")
        add_table_to_csv(cm_counts_df, "confusion_counts_ds1")
        add_table_to_csv(cm_norm_df, "confusion_norm_ds1")

        # Calibration table over 12 confidence bins.
        bin_conf, bin_acc, bin_n = multiclass_calibration_curve(y_true_test, p_test_best, n_bins=12)
        cal_df = pd.DataFrame({"bin_confidence": bin_conf, "bin_accuracy": bin_acc, "bin_count": bin_n})
        print_table(cal_df, "Calibration bins table (DS1)")
        add_table_to_csv(cal_df, "calibration_bins_ds1")

# ======================================================================================
# STEP 14: THETA EVOLUTION TABLE
# ======================================================================================
print("\n" + "=" * 92)
print("STEP 14: THETA EVOLUTION TABLE")
print("=" * 92)

# θ parameter columns expected in the per-client, per-round history (loc_df).
theta_cols = [
    "gamma_power",
    "alpha_contrast_weight",
    "beta_contrast_sharpness",
    "tau_clip",
    "k_blur_kernel_size",
    "sh_sharpen_strength",
    "dn_denoise_strength",
]
# Restrict to the columns that actually exist: selecting a missing column
# after groupby raises KeyError, so coercion and aggregation share one list.
present_theta_cols = [c for c in theta_cols if c in loc_df.columns]
for c in present_theta_cols:
    # Non-numeric entries become NaN and are ignored by the mean below.
    loc_df[c] = pd.to_numeric(loc_df[c], errors="coerce")

if "round" in loc_df.columns and present_theta_cols:
    # Average each θ parameter across clients within every round.
    theta_evo = (
        loc_df.groupby("round")[present_theta_cols]
        .mean(numeric_only=True)
        .reset_index()
    )
else:
    # Nothing to aggregate: emit an empty table with the expected header.
    theta_evo = pd.DataFrame(columns=["round"] + present_theta_cols)
print_table(theta_evo, "Mean best-θ parameters over rounds (clients averaged)")
add_table_to_csv(theta_evo, "theta_evolution_mean")

# ======================================================================================
# STEP 15: SAVE CHECKPOINT + ONE CSV
# ======================================================================================
print(f"\n{'=' * 92}")
print("STEP 15: SAVING ONLY TWO FILES (CHECKPOINT + ONE CSV)")
print(92 * "=")

# The checkpoint is assembled group by group; key order matches insertion order.
# Model weights are detached and moved to CPU before being stored.
checkpoint = {
    "state_dict": {
        name: tensor.detach().cpu()
        for name, tensor in global_model.state_dict().items()
    },
}
# Run configuration and environment.
checkpoint.update(
    config=CFG,
    seed=SEED,
    device_used=str(DEVICE),
    dataset_roots=DATASET_BASES,
)
# Label vocabulary and backbone identity.
checkpoint.update(
    labels=labels,
    label2id=label2id,
    id2label=id2label,
    num_classes=NUM_CLASSES,
    backbone_name=BACKBONE_NAME,
)
# Model-selection outcome and the chosen θ per dataset.
checkpoint.update(
    best_round_saved=best_round_saved,
    best_val_acc=best_global_acc,
    best_thetas=best_thetas,
    best_theta_str_by_dataset={
        ds_name: theta_str(best_thetas[ds_name]) for ds_name in DATASET_NAMES
    },
    theta_fullforms=THETA_FULLFORMS,
)
# Client partitions and training histories (column -> list of values).
checkpoint.update(
    client_splits=client_splits,
    client_test_splits=client_test_splits,
    history_global=glob_df.to_dict(orient="list"),
    history_local=loc_df.to_dict(orient="list"),
)
# Final metrics; the preprocessing summary is {} when that table has no rows.
checkpoint.update(
    final_val_federated=val_best,
    final_test_by_dataset=test_by_ds,
    final_test_global_weighted=global_test,
    preprocessing_validation_summary_ds1=(
        preproc_summary_df.to_dict(orient="list") if len(preproc_summary_df) else {}
    ),
    total_training_time_s=t_total,
)

torch.save(checkpoint, MODEL_PATH)
print(f"✅ Saved checkpoint: {MODEL_PATH}")

# Every table collected in ALL_ROWS goes into a single CSV file.
all_df = pd.DataFrame(ALL_ROWS)
all_df.to_csv(CSV_PATH, index=False)
print(f"✅ Saved CSV (ALL outputs): {CSV_PATH}")

print("\nDONE ✅ (KAGGLE, TRUE FL SIMULATION, 12 clients (3x4 datasets), rounds=12, Preprocessing+GA, Augmentation, Fusion, PVTv2-B2, no plots)")


KAGGLE: TRUE FL + GA-FELCM + PVTv2-B2 (FUSION) — 12 Clients (3x4 datasets) | AUG=ON
DEVICE: cuda | torch=2.8.0+cu126

STEP 0: DATASETS (DS1/2/3 via kagglehub) + DS4 from Kaggle input path
✅ DS1: /kaggle/input/datasets/alamshihab075/brain-tumor-mri-dataset-for-deep-learning
✅ DS2: /kaggle/input/datasets/zehrakucuker/brain-tumor-mri-images-classification-dataset
✅ DS3: /kaggle/input/datasets/chubskuy/brain-tumor-image
✅ DS4: /kaggle/input/datasets/mdzubayerahmadshibly/ds4mine

STEP 1: DISCOVER + MERGE DATASET IMAGES BY CLASS
ds1: total images = 9257 | glioma:3293, meningioma:3593, notumor:811, pituitary:1560
ds2: total images = 11615 | glioma:3768, meningioma:3806, notumor:0, pituitary:4041
ds3: total images = 7023 | glioma:1621, meningioma:1645, notumor:2000, pituitary:1757
ds4: total images = 12064 | glioma:3773, meningioma:2729, notumor:2432, pituitary:3130

STEP 2: TRAIN/VAL/TEST SPLIT (PER DATASET)
DS1 TRAIN=6479 | VAL=1389 | TEST=1389
DS2 TRAIN=8130 | VAL=1742 | TEST=1743
DS3 TRAIN

Unnamed: 0,path_overlap_train_val,path_overlap_train_test,path_overlap_val_test,unique_paths_train,unique_paths_val,unique_paths_test,filename_overlap_train_val,filename_overlap_train_test,filename_overlap_val_test,subset_hash_train_val,subset_hash_train_test,subset_hash_val_test,subset_hash_n_train,subset_hash_n_val,subset_hash_n_test
0,0,0,0,6479,1389,1389,0,0,0,5,4,5,300,300,297



--------------------------------------------------------------------------------------------
Leakage / Sanity Summary — ds2
--------------------------------------------------------------------------------------------


Unnamed: 0,path_overlap_train_val,path_overlap_train_test,path_overlap_val_test,unique_paths_train,unique_paths_val,unique_paths_test,filename_overlap_train_val,filename_overlap_train_test,filename_overlap_val_test,subset_hash_train_val,subset_hash_train_test,subset_hash_val_test,subset_hash_n_train,subset_hash_n_val,subset_hash_n_test
0,0,0,0,8130,1742,1743,589,621,203,2,2,1,299,299,298



--------------------------------------------------------------------------------------------
Leakage / Sanity Summary — ds3
--------------------------------------------------------------------------------------------


Unnamed: 0,path_overlap_train_val,path_overlap_train_test,path_overlap_val_test,unique_paths_train,unique_paths_val,unique_paths_test,filename_overlap_train_val,filename_overlap_train_test,filename_overlap_val_test,subset_hash_train_val,subset_hash_train_test,subset_hash_val_test,subset_hash_n_train,subset_hash_n_val,subset_hash_n_test
0,0,0,0,4916,1053,1054,0,0,0,3,3,2,298,300,296



--------------------------------------------------------------------------------------------
Leakage / Sanity Summary — ds4
--------------------------------------------------------------------------------------------


Unnamed: 0,path_overlap_train_val,path_overlap_train_test,path_overlap_val_test,unique_paths_train,unique_paths_val,unique_paths_test,filename_overlap_train_val,filename_overlap_train_test,filename_overlap_val_test,subset_hash_train_val,subset_hash_train_test,subset_hash_val_test,subset_hash_n_train,subset_hash_n_val,subset_hash_n_test
0,0,0,0,8444,1810,1810,204,214,71,2,1,0,299,300,300



STEP 3: NON-IID CLIENT PARTITIONING (3 clients per dataset => 12 total)
DS1 Client 0 (gid 0): train=1129 tune=176 val=155
DS1 Client 1 (gid 1): train=2446 tune=380 val=334
DS1 Client 2 (gid 2): train=1438 tune=224 val=197
DS2 Client 0 (gid 3): train=548 tune=86 val=75
DS2 Client 1 (gid 4): train=2178 tune=338 val=297
DS2 Client 2 (gid 5): train=3568 tune=553 val=487
DS3 Client 0 (gid 6): train=781 tune=122 val=107
DS3 Client 1 (gid 7): train=1317 tune=205 val=180
DS3 Client 2 (gid 8): train=1706 tune=265 val=233
DS4 Client 0 (gid 9): train=2875 tune=446 val=393
DS4 Client 1 (gid 10): train=1672 tune=260 val=228
DS4 Client 2 (gid 11): train=1989 tune=309 val=272

--------------------------------------------------------------------------------------------
Client class distribution (Non-IID, per dataset)
--------------------------------------------------------------------------------------------


Unnamed: 0,client,dataset,total_train,total_tune,total_val,glioma,meningioma,notumor,pituitary
0,client_0,ds1,1129,176,155,37,4,421,667
1,client_1,ds1,2446,380,334,1738,526,14,168
2,client_2,ds1,1438,224,197,10,1415,3,10
3,client_3,ds2,548,86,75,158,5,0,385
4,client_4,ds2,2178,338,297,112,266,0,1800
5,client_5,ds2,3568,553,487,1772,1792,0,4
6,client_6,ds3,781,122,107,3,81,623,74
7,client_7,ds3,1317,205,180,59,593,209,456
8,client_8,ds3,1706,265,233,816,216,251,423
9,client_9,ds4,2875,446,393,263,1391,1199,22



STEP 4: DATA LOADERS (AUG ON) + IMAGENET NORM
Augmentation: ON ✅
Preprocessing: ON ✅

STEP 5: GA-TUNED ENHANCED FELCM PREPROCESSOR

STEP 6: MODEL (PVTv2-B2 + MULTI-SCALE FUSION)

STEP 7: GA FITNESS (ROBUST)

STEP 8: TRAIN / EVAL UTILITIES (FULL METRICS)

STEP 9: INITIALIZING GLOBAL MODEL


model.safetensors:   0%|          | 0.00/101M [00:00<?, ?B/s]

Backbone: pvt_v2_b2 | Total params: 25,990,026 | Trainable: 1,140,170 (4.39%)

--------------------------------------------------------------------------------------------
Hyperparameters / Search Space
--------------------------------------------------------------------------------------------


Unnamed: 0,hp_name,hp_value
0,clients_per_dataset,3
1,clients_total,12
2,rounds,12
3,local_epochs,2
4,lr,0.001
5,weight_decay,0.0005
6,warmup_epochs,1
7,label_smoothing,0.08
8,grad_clip,1.0
9,fedprox_mu,0.01



STEP 10: FEDERATED TRAINING (NO CENTRAL VAL/TEST)

ROUND 1/12
Client 0 (ds1) | train_acc=0.7879 | val_acc=0.9742 | val_f1=0.7372 | val_auc=0.9945 | val_logloss=0.1241 | GA_fit=9.638 | ga_time=12.1s | theta=(γ=1.24, α=0.47, β=3.1, τ=2.7, k=7, sh=0.00, dn=0.03)
Client 1 (ds1) | train_acc=0.8013 | val_acc=0.9431 | val_f1=0.9340 | val_auc=0.9889 | val_logloss=0.2335 | GA_fit=10.714 | ga_time=5.8s | theta=(γ=1.32, α=0.20, β=6.0, τ=3.1, k=5, sh=0.11, dn=0.28)
Client 2 (ds1) | train_acc=0.8383 | val_acc=0.9746 | val_f1=0.3307 | val_auc=0.8940 | val_logloss=0.1943 | GA_fit=0.198 | ga_time=5.9s | theta=(γ=0.78, α=0.46, β=3.5, τ=2.8, k=7, sh=0.28, dn=0.04)
Client 3 (ds2) | train_acc=0.6870 | val_acc=0.9067 | val_f1=0.5905 | val_auc=nan | val_logloss=0.3020 | GA_fit=6.868 | ga_time=6.0s | theta=(γ=0.64, α=0.52, β=5.4, τ=2.7, k=3, sh=0.29, dn=0.10)
Client 4 (ds2) | train_acc=0.7631 | val_acc=0.9428 | val_f1=0.8897 | val_auc=nan | val_logloss=0.2321 | GA_fit=0.209 | ga_time=6.0s | theta=(γ=0.67, α

Unnamed: 0,round,round_time_s,global_thetas,global_acc,global_f1_macro,global_precision_macro,global_recall_macro,global_log_loss,global_loss_ce,global_eval_time_s
0,1,806.945952,"{'ds1': '(γ=1.24, α=0.47, β=3.1, τ=2.7, k=7, s...",0.927654,0.773983,0.765457,0.803862,0.259516,0.257144,5.753748
1,2,723.766821,"{'ds1': '(γ=1.32, α=0.20, β=6.0, τ=3.1, k=5, s...",0.960784,0.851495,0.843712,0.879925,0.180107,0.178522,3.551848
2,3,781.09907,"{'ds1': '(γ=1.25, α=0.57, β=2.3, τ=2.2, k=5, s...",0.971602,0.868322,0.862058,0.87817,0.156187,0.15474,3.549849
3,4,763.970252,"{'ds1': '(γ=1.24, α=0.47, β=3.1, τ=2.7, k=7, s...",0.973293,0.898991,0.892093,0.915847,0.148511,0.14695,3.559128
4,5,762.685288,"{'ds1': '(γ=1.24, α=0.47, β=3.1, τ=2.7, k=7, s...",0.977011,0.914691,0.905577,0.933317,0.14249,0.141086,3.557618
5,6,762.572541,"{'ds1': '(γ=1.24, α=0.47, β=3.1, τ=2.7, k=7, s...",0.979716,0.907016,0.893641,0.932473,0.12749,0.126646,3.559277
6,7,761.740906,"{'ds1': '(γ=1.24, α=0.47, β=3.1, τ=2.7, k=7, s...",0.983435,0.933408,0.923432,0.949613,0.123804,0.123712,3.563437
7,8,760.979891,"{'ds1': '(γ=1.19, α=0.23, β=6.3, τ=3.1, k=7, s...",0.982082,0.884094,0.876849,0.900796,0.127243,0.126081,3.547035
8,9,761.152471,"{'ds1': '(γ=1.24, α=0.47, β=3.1, τ=2.7, k=7, s...",0.98073,0.920547,0.920137,0.923269,0.128611,0.127449,3.556252
9,10,760.474721,"{'ds1': '(γ=1.24, α=0.47, β=3.1, τ=2.7, k=7, s...",0.983773,0.938251,0.933468,0.94948,0.117907,0.11708,3.555869



--------------------------------------------------------------------------------------------
LOCAL per-client per-round metrics (head)
--------------------------------------------------------------------------------------------


Unnamed: 0,round,client,dataset,ga_best_fit_score,ga_time_s,theta_str,gamma_power,alpha_contrast_weight,beta_contrast_sharpness,tau_clip,...,val_g2_mean,val_g2_entropy_mean,val_g2_mean_c0,val_g2_entropy_c0,val_g2_mean_c1,val_g2_entropy_c1,val_g2_mean_c2,val_g2_entropy_c2,val_g2_mean_c3,val_g2_entropy_c3
0,1,client_0,ds1,9.638477,12.094798,"(γ=1.24, α=0.47, β=3.1, τ=2.7, k=7, sh=0.00, d...",1.244318,0.473141,3.055522,2.742403,...,0.499455,0.943852,0.499455,0.943852,0.499455,0.943852,0.499455,0.943852,0.499455,0.943852
1,1,client_1,ds1,10.714323,5.845026,"(γ=1.32, α=0.20, β=6.0, τ=3.1, k=5, sh=0.11, d...",1.320571,0.203079,6.02262,3.131826,...,0.497293,0.948486,0.497293,0.948486,0.497293,0.948486,0.497293,0.948487,0.497293,0.948486
2,1,client_2,ds1,0.198072,5.948863,"(γ=0.78, α=0.46, β=3.5, τ=2.8, k=7, sh=0.28, d...",0.77729,0.457708,3.50938,2.842538,...,0.49224,0.945733,0.49224,0.945732,0.49224,0.945733,0.49224,0.945732,0.49224,0.945732
3,1,client_3,ds2,6.867913,6.014931,"(γ=0.64, α=0.52, β=5.4, τ=2.7, k=3, sh=0.29, d...",0.641079,0.524832,5.440878,2.665325,...,0.495601,0.945834,0.495601,0.945834,0.495601,0.945834,,,0.495601,0.945834
4,1,client_4,ds2,0.209249,5.965567,"(γ=0.67, α=0.56, β=5.9, τ=2.5, k=3, sh=0.33, d...",0.66765,0.56285,5.855335,2.518214,...,0.488902,0.938683,0.488902,0.938683,0.488902,0.938683,,,0.488902,0.938683
5,1,client_5,ds2,4.163899,5.987546,"(γ=0.61, α=0.52, β=5.8, τ=2.8, k=7, sh=0.28, d...",0.614877,0.522026,5.792319,2.784551,...,0.4898,0.942876,0.4898,0.942876,0.4898,0.942876,,,0.4898,0.942876
6,1,client_6,ds3,13.143195,5.949632,"(γ=1.13, α=0.20, β=4.5, τ=3.3, k=3, sh=0.12, d...",1.133613,0.199126,4.511599,3.323043,...,0.485627,0.943933,0.485627,0.943933,0.485627,0.943933,0.485627,0.943933,0.485627,0.943933
7,1,client_7,ds3,9.825009,5.999177,"(γ=1.16, α=0.36, β=4.5, τ=2.4, k=3, sh=0.12, d...",1.16348,0.361975,4.511599,2.410383,...,0.490888,0.949444,0.490888,0.949444,0.490888,0.949444,0.490888,0.949444,0.490888,0.949444
8,1,client_8,ds3,12.263569,6.002362,"(γ=1.18, α=0.53, β=5.2, τ=3.0, k=5, sh=0.23, d...",1.178856,0.530371,5.195688,3.038469,...,0.486209,0.943719,0.486209,0.943719,0.486209,0.943719,0.486209,0.943719,0.486209,0.943719
9,1,client_9,ds4,6.640467,5.926919,"(γ=0.92, α=0.52, β=3.9, τ=2.2, k=3, sh=0.27, d...",0.921893,0.519938,3.864403,2.184408,...,0.476817,0.936757,0.476817,0.936757,0.476817,0.936757,0.476817,0.936757,0.476817,0.936757



STEP 11: FINAL EVALUATION (FEDERATED VAL + TEST)

--------------------------------------------------------------------------------------------
VAL+TEST tables (federated, per-dataset + global)
--------------------------------------------------------------------------------------------


Unnamed: 0,setting,split,dataset,acc,precision_macro,recall_macro,f1_macro,precision_weighted,recall_weighted,f1_weighted,log_loss,auc_roc_macro_ovr,loss_ce,eval_time_s
0,Enhanced FELCM (Best θ per dataset),VAL,all datasets weighted,0.983097,0.844038,0.924235,0.86374,0.988753,0.983097,0.985377,0.131958,,0.131386,3.54891
1,Enhanced FELCM (Best θ ds1),TEST,ds1,0.978402,0.981584,0.977547,0.979278,0.979055,0.978402,0.978391,0.136622,0.998678,0.135022,6.445197
2,Enhanced FELCM (Best θ ds2),TEST,ds2,0.987378,0.905945,0.905039,0.905469,0.988551,0.987378,0.987939,0.12167,,0.1206,7.165533
3,Enhanced FELCM (Best θ ds3),TEST,ds3,0.991461,0.991283,0.991519,0.99133,0.991659,0.991461,0.991491,0.088016,0.999698,0.087542,5.38899
4,Enhanced FELCM (Best θ ds4),TEST,ds4,0.961326,0.962196,0.961883,0.961765,0.961832,0.961326,0.9613,0.179371,0.996109,0.18337,8.016314
5,Enhanced FELCM (Best θ),TEST,global weighted,0.978152,0.955448,0.954197,0.954654,0.978832,0.978152,0.97831,0.136636,,0.137078,6.9432



--------------------------------------------------------------------------------------------
Requested explicit metrics (acc, pre, rec, f1, logloss, auc_roc)
--------------------------------------------------------------------------------------------


Unnamed: 0,dataset,acc,pre,rec,f1,logloss,auc_roc
0,all_datasets_val_weighted,0.983097,0.844038,0.924235,0.86374,0.131958,
1,ds1_test,0.978402,0.981584,0.977547,0.979278,0.136622,0.998678
2,ds2_test,0.987378,0.905945,0.905039,0.905469,0.12167,
3,ds3_test,0.991461,0.991283,0.991519,0.99133,0.088016,0.999698
4,ds4_test,0.961326,0.962196,0.961883,0.961765,0.179371,0.996109
5,global_test_weighted,0.978152,0.955448,0.954197,0.954654,0.136636,



Paper selection summary:
- Best round (by federated VAL accuracy): round=12 | best_val_acc=0.9885
- Best θ ds1: (γ=1.36, α=0.08, β=6.8, τ=3.1, k=5, sh=0.05, dn=0.20)
- Best θ ds2: (γ=0.68, α=0.44, β=4.5, τ=2.5, k=7, sh=0.27, dn=0.12)
- Best θ ds3: (γ=1.18, α=0.53, β=5.2, τ=3.0, k=5, sh=0.23, dn=0.04)
- Best θ ds4: (γ=0.97, α=0.49, β=2.0, τ=2.1, k=3, sh=0.34, dn=0.01)

STEP 12: PREPROCESSING VALIDATION (DS1 VAL SAMPLE)

--------------------------------------------------------------------------------------------
Preprocessing validation summary (DS1 VAL sample)
--------------------------------------------------------------------------------------------


Unnamed: 0,metric,mean,std,min,max
0,edge_energy_before,0.033557,0.011519,0.010264,0.11001
1,edge_energy_after,0.045955,0.010286,0.018803,0.102962
2,entropy_before,5.6202,0.832229,2.852003,7.564816
3,entropy_after,5.989759,0.673774,3.300354,7.341709
4,contrast_before,0.168932,0.034004,0.089726,0.356358
5,contrast_after,0.191185,0.016224,0.156567,0.316045
6,edge_gain_ratio,1.437564,0.296954,0.866943,3.060668
7,entropy_delta,0.369558,0.206458,-0.234723,0.920488
8,contrast_delta,0.022253,0.028139,-0.05547,0.108779



STEP 13: COMPUTE CURVES + CONFUSION + CALIBRATION TABLES (NO PLOTS)

--------------------------------------------------------------------------------------------
ROC curve points (DS1 TEST, first 20 rows)
--------------------------------------------------------------------------------------------


Unnamed: 0,class,fpr,tpr,threshold
0,glioma,0.0,0.0,inf
1,glioma,0.0,0.005435,0.912852
2,glioma,0.0,0.016304,0.911762
3,glioma,0.003584,0.016304,0.911719
4,glioma,0.003584,0.608696,0.908241
5,glioma,0.007168,0.608696,0.90824
6,glioma,0.007168,0.983696,0.903709
7,glioma,0.010753,0.983696,0.902528
8,glioma,0.010753,0.98913,0.900419
9,glioma,0.021505,0.98913,0.743146



--------------------------------------------------------------------------------------------
PR curve points (DS1 TEST, first 20 rows)
--------------------------------------------------------------------------------------------


Unnamed: 0,class,precision,recall,threshold
0,glioma,0.397408,1.0,0.006892
1,glioma,0.399132,1.0,0.007197
2,glioma,0.4,1.0,0.00721
3,glioma,0.400871,1.0,0.007232
4,glioma,0.401747,1.0,0.007319
5,glioma,0.402626,1.0,0.007361
6,glioma,0.403509,1.0,0.007382
7,glioma,0.404396,1.0,0.007513
8,glioma,0.405286,1.0,0.007594
9,glioma,0.406181,1.0,0.007615



--------------------------------------------------------------------------------------------
Confusion matrix counts (DS1 TEST)
--------------------------------------------------------------------------------------------


Unnamed: 0,true,glioma,meningioma,notumor,pituitary
0,glioma,183,1,0,0
1,meningioma,6,155,0,2
2,notumor,2,0,39,0
3,pituitary,0,0,1,74



--------------------------------------------------------------------------------------------
Confusion matrix row-normalized (DS1 TEST)
--------------------------------------------------------------------------------------------


Unnamed: 0,true,glioma,meningioma,notumor,pituitary
0,glioma,0.994565,0.005435,0.0,0.0
1,meningioma,0.03681,0.95092,0.0,0.01227
2,notumor,0.04878,0.0,0.95122,0.0
3,pituitary,0.0,0.0,0.013333,0.986667



--------------------------------------------------------------------------------------------
Calibration bins table (DS1)
--------------------------------------------------------------------------------------------


Unnamed: 0,bin_confidence,bin_accuracy,bin_count
0,,,0
1,,,0
2,,,0
3,,,0
4,,,0
5,0.493923,1.0,2
6,0.542857,0.0,3
7,0.62928,0.333333,3
8,0.721814,0.8,5
9,0.765013,0.666667,3



STEP 14: THETA EVOLUTION TABLE

--------------------------------------------------------------------------------------------
Mean best-θ parameters over rounds (clients averaged)
--------------------------------------------------------------------------------------------


Unnamed: 0,round,gamma_power,alpha_contrast_weight,beta_contrast_sharpness,tau_clip,k_blur_kernel_size,sh_sharpen_strength,dn_denoise_strength
0,1,0.95913,0.460884,4.353408,2.650644,4.333333,0.226634,0.080778
1,2,1.013108,0.396479,4.199809,2.772723,3.833333,0.223487,0.108809
2,3,0.985409,0.409887,4.608343,2.782583,4.0,0.196151,0.082068
3,4,0.97739,0.458337,4.74927,2.707201,3.833333,0.211109,0.069298
4,5,0.983716,0.367041,4.490318,2.877774,4.666667,0.166931,0.111062
5,6,1.019666,0.34766,4.876448,2.865048,4.0,0.139473,0.122805
6,7,0.993556,0.361964,4.992942,2.60764,3.833333,0.174861,0.096345
7,8,1.03247,0.325638,5.218503,2.772185,4.0,0.125115,0.09217
8,9,1.054699,0.282786,4.788884,2.925929,4.333333,0.142183,0.1255
9,10,1.067166,0.359066,4.611436,2.664418,4.166667,0.114938,0.09609



STEP 15: SAVING ONLY TWO FILES (CHECKPOINT + ONE CSV)
✅ Saved checkpoint: /kaggle/working/outputs/FL_GAFELCM_PVTv2B2_FUSION_checkpoint.pth
✅ Saved CSV (ALL outputs): /kaggle/working/outputs/ALL_OUTPUTS_AND_METRICS.csv

DONE ✅ (KAGGLE, TRUE FL SIMULATION, 12 clients (3x4 datasets), rounds=12, Preprocessing+GA, Augmentation, Fusion, PVTv2-B2, no plots)


In [None]:
import os
from pathlib import Path
from collections import Counter
from PIL import Image

# File suffixes (lower-case, with leading dot) treated as images.
IMG_EXTS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".bmp",
    ".tif",
    ".tiff",
    ".webp",
}

# Dataset name -> root directory, read from CFG["<name>_base"]
# (or put your paths directly).
DATASETS = {ds_key: CFG[f"{ds_key}_base"] for ds_key in ("ds1", "ds2", "ds3", "ds4")}

def iter_images(root, exts=None):
    """Yield every image file found under *root*, searching recursively.

    Args:
        root: Directory to scan (string or path-like).
        exts: Optional file suffix, or iterable of suffixes, to accept. Each
            must include the leading dot (e.g. ".png") and is compared
            case-insensitively. When omitted, the module-level ``IMG_EXTS``
            set is used, which is the original behavior.

    Yields:
        ``pathlib.Path`` objects for regular files whose suffix is accepted.
    """
    if exts is None:
        allowed = IMG_EXTS
    else:
        # A bare string would otherwise be iterated character by character.
        if isinstance(exts, str):
            exts = (exts,)
        allowed = {e.lower() for e in exts}

    for p in Path(root).rglob("*"):
        # is_file() filters out directories whose names look like images.
        if p.is_file() and p.suffix.lower() in allowed:
            yield p

def image_size_counts(root, max_images=None):
    """Tally image dimensions for the image files found under *root*.

    Args:
        root: Directory passed to ``iter_images``.
        max_images: If not None, stop once this many files have been visited
            (readable or not).

    Returns:
        Tuple ``(counts, n_ok, n_bad)``: a ``Counter`` keyed by
        ``"<width>x<height>"``, the number of files opened successfully, and
        the number that raised while being opened or measured.
    """
    size_hist = Counter()
    readable = 0
    unreadable = 0
    for visited, img_path in enumerate(iter_images(root)):
        if max_images is not None and visited >= max_images:
            break
        try:
            with Image.open(img_path) as handle:
                width, height = handle.size
        except Exception:
            # Best-effort scan: count the failure and keep going.
            unreadable += 1
        else:
            size_hist[f"{width}x{height}"] += 1
            readable += 1
    return size_hist, readable, unreadable

# Set to an integer (e.g. 2000) to cap the files scanned per dataset for a
# faster run; None scans everything.
MAX_IMAGES = None

# Print a size report per dataset: read statistics, then the 20 most common sizes.
for name, path in DATASETS.items():
    cnt, n_ok, n_bad = image_size_counts(path, max_images=MAX_IMAGES)
    print(f"\n{'=' * 70}")
    print("{} | root: {}".format(name.upper(), path))
    print("read_ok={} | read_failed={} | unique_sizes={}".format(n_ok, n_bad, len(cnt)))
    print("Top image sizes:")
    for size, c in cnt.most_common(20):
        print("  {:>10}  ->  {}".format(size, c))



DS1 | root: /kaggle/input/datasets/alamshihab075/brain-tumor-mri-dataset-for-deep-learning
read_ok=9257 | read_failed=0 | unique_sizes=325
Top image sizes:
     512x512  ->  8214
     225x225  ->  119
     630x630  ->  66
     442x442  ->  42
     236x236  ->  33
     256x256  ->  25
     201x251  ->  23
     201x250  ->  22
     214x236  ->  18
     468x444  ->  15
     504x540  ->  13
     359x449  ->  13
     550x664  ->  12
     393x400  ->  12
     200x252  ->  11
     220x275  ->  11
     400x442  ->  11
     350x350  ->  10
     442x454  ->  10
     232x217  ->  9

DS2 | root: /kaggle/input/datasets/zehrakucuker/brain-tumor-mri-images-classification-dataset
read_ok=15605 | read_failed=0 | unique_sizes=387
Top image sizes:
     512x512  ->  11006
     225x225  ->  630
     630x630  ->  181
     236x236  ->  180
     201x251  ->  109
     442x442  ->  98
     228x221  ->  92
     300x168  ->  87
     232x217  ->  87
     150x198  ->  80
     428x417  ->  78
     200x252  ->  77
 