In [None]:
# ============================================================
# Cell 1 — Setup (paths + dataset staging)
#
# Run this cell first.
#
# What this cell does
#   1) Selects where training outputs are saved (Drive or VM)
#   2) Prompts for the main results folder name (default: Training_Result)
#   3) Retrieves the dataset ZIP into the VM (no link spam; progress bar only)
#   4) Extracts the ZIP with a progress bar
#
# Exposed variables
#   - RESULTS_ROOT : output root directory
#   - DATASET_DIR  : dataset folder inside VM
#   - DATA_ROOT    : str(DATASET_DIR) for backward compatibility
# ============================================================
# ============================================================
# Cell 1 — Setup (paths + dataset staging)
# ============================================================

!pip install -q gdown

import os
import sys

# -------------------------
# Global experiment config
# -------------------------
# Shared experiment constants; Cell 2 re-uses them when they already exist in globals().
GLOBAL_SEED = 42
IN_CHANNELS = 1
OUT_CLASSES = 3

# --- Determinism guard: this MUST run before any torch import ---
# Abort early instead of silently continuing with an environment that was configured too late.
if "torch" in sys.modules:
    raise RuntimeError("torch was imported before setting CUBLAS_WORKSPACE_CONFIG. Restart runtime and run CELL 1 first.")

# Optional CPU-side stability (reduces nondeterminism from BLAS/OpenMP scheduling)
# setdefault keeps any value the user already exported before starting the runtime.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

# Unlike the thread counts above, these two are overwritten unconditionally.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# NOTE(review): presumably only affects interpreters started after this point
# (hash randomisation of the running kernel is already fixed) — confirm.
os.environ["PYTHONHASHSEED"] = str(GLOBAL_SEED)


import sys
import io
from pathlib import Path
from zipfile import ZipFile
from contextlib import redirect_stdout, redirect_stderr

import random
import numpy as np

import gdown
from tqdm import tqdm  # classic tqdm for consistent console bars



# -------------------------
# CPU-side reproducibility
# (Torch-side determinism will be enabled later, after torch import)
# -------------------------
# Seed the CPU-side generators now; torch is deliberately not imported in this cell.
random.seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)

# The status glyph was stored mis-encoded (UTF-8 bytes shown as Mac Roman); restored to the check mark.
print(f"✅ Cell 1 reproducibility pre-config set (GLOBAL_SEED={GLOBAL_SEED}).")
print("   - CUBLAS_WORKSPACE_CONFIG set for deterministic CUDA matmul (effective only if torch not imported yet).")




# -----------------------------
# 1) Results destination
# -----------------------------
# Ask where training outputs should live; any answer other than "1"/"2" aborts the cell.
print(
    "\nResults destination\n"
    "  1) Google Drive (persistent)\n"
    "  2) Colab VM only (temporary)\n"
)
choice = input("Select 1 or 2: ").strip()
if choice not in {"1", "2"}:
    raise ValueError("Invalid choice. Please rerun and select 1 or 2.")

# Published for later sections: True selects /content/drive/MyDrive as the results parent.
SAVE_TO_DRIVE = (choice == "1")

if SAVE_TO_DRIVE:
    try:
        # Imported lazily so the VM-only path never needs the google.colab package.
        from google.colab import drive
        # keep mount output minimal (no extra chatter)
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            drive.mount("/content/drive", force_remount=True)
    except Exception as e:
        # Any failure (missing package, declined auth, ...) is reported as one actionable error.
        raise RuntimeError(
            "Drive mount failed. This cell is intended for Google Colab.\n"
            f"Details: {e}"
        )


# -----------------------------
# 2) Main results folder name
# -----------------------------
# Prompt for the results folder; an empty answer selects the default name.
DEFAULT_RESULTS_FOLDER = "Training_Result"
folder_in = input(f"Results folder name (default: {DEFAULT_RESULTS_FOLDER}): ").strip()

if not folder_in:
    RESULTS_FOLDER_NAME = DEFAULT_RESULTS_FOLDER
else:
    # Normalise Windows separators and drop leading/trailing slashes so the value
    # is always joined as a relative path below.
    RESULTS_FOLDER_NAME = folder_in.replace("\\", "/").strip().strip("/")
    if not RESULTS_FOLDER_NAME:
        # Input consisted only of slashes/whitespace.
        RESULTS_FOLDER_NAME = DEFAULT_RESULTS_FOLDER

# RESULTS_ROOT is the output root exposed to the rest of the notebook.
if SAVE_TO_DRIVE:
    RESULTS_ROOT = Path("/content/drive/MyDrive") / RESULTS_FOLDER_NAME
else:
    RESULTS_ROOT = Path("/content") / RESULTS_FOLDER_NAME

# Nested names such as "a/b" are allowed; missing parents are created.
RESULTS_ROOT.mkdir(parents=True, exist_ok=True)


# -----------------------------
# 3) Dataset staging in VM
# -----------------------------
# Working folder on the VM that holds both the downloaded ZIP and the extracted dataset.
VM_DATA_ROOT = Path("/content/data_zip")
VM_DATA_ROOT.mkdir(parents=True, exist_ok=True)

ZIP_NAME = "XCT_SegData_CFRP_FFRP_v1.zip"
LOCAL_ZIP_PATH = VM_DATA_ROOT / ZIP_NAME
# The archive is expected to unpack into a top-level folder of this name
# (checked after extraction in the "Ensure dataset is ready" section).
DATASET_DIR = VM_DATA_ROOT / "XCT_SegData_CFRP_FFRP_v1"

# Candidate download links, tried in list order until one yields a file.
GDRIVE_ZIP_LINKS = [
    "https://drive.google.com/file/d/195nIFRoL-tEST9WPPYGGMIlpdACsQXWQ/view?usp=sharing",
    "https://drive.google.com/file/d/1Z1szaBWpD6HYhub3EI9VJAoRPBMGuBQy/view?usp=sharing",
    "https://drive.google.com/file/d/1EYGAN0yg2V-2rNWSEdk0gm-WH4LkGJcS/view?usp=sharing",
    "https://drive.google.com/file/d/1IDxrn-XyTgSMj9hRGo5MRmYJxmtAWcxY/view?usp=sharing",
    "https://drive.google.com/file/d/1inh4XOJdYDQAcT1GigwqlXBpOsOSk78Z/view?usp=sharing",
]

# A copy of the ZIP in the user's own Drive root is only considered when Drive is mounted.
DRIVE_ZIP_FALLBACKS = []
if SAVE_TO_DRIVE:
    DRIVE_ZIP_FALLBACKS = [Path("/content/drive/MyDrive") / ZIP_NAME]


# -----------------------------
# 4) Helpers (progress bars only)
# -----------------------------
class _StderrFilter:
    """
    Drops link spam and gdown info lines from stderr while preserving tqdm bars.
    """
    def __init__(self, real_stderr):
        self._real = real_stderr

    def write(self, s):
        if not s:
            return 0

        # Remove any links completely
        if ("http://" in s) or ("https://" in s):
            return 0

        # Remove common gdown chatter
        if ("From (original):" in s) or ("From (redirected):" in s) or ("To:" in s):
            return 0
        if s.strip() in {"Downloading...", "Downloading"}:
            return 0

        return self._real.write(s)

    def flush(self):
        return self._real.flush()

    def isatty(self):
        return getattr(self._real, "isatty", lambda: False)()

    def fileno(self):
        return getattr(self._real, "fileno", lambda: 2)()

    @property
    def encoding(self):
        return getattr(self._real, "encoding", "utf-8")


def _copy_file_with_progress(src: Path, dst: Path, chunk_size: int = 1024 * 1024) -> None:
    """
    Copy `src` to `dst` in fixed-size chunks while showing a byte-based tqdm bar.

    Parameters
    ----------
    src : Path
        Existing file to read; its size is used as the bar total.
    dst : Path
        Destination file, opened in "wb" mode (an existing file is overwritten).
    chunk_size : int
        Number of bytes requested per read (default 1 MiB).
    """
    bar_options = dict(
        total=src.stat().st_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        desc="Copying",
        dynamic_ncols=True,
    )
    with open(src, "rb") as reader, open(dst, "wb") as writer, tqdm(**bar_options) as bar:
        # iter(callable, sentinel) stops on the first empty read, i.e. at end of file.
        for chunk in iter(lambda: reader.read(chunk_size), b""):
            writer.write(chunk)
            bar.update(len(chunk))


def _download_zip_quiet(urls, out_path: Path) -> bool:
    """
    Try each URL in turn until one yields a valid ZIP archive at `out_path`.

    Keeps the gdown progress bar but removes all link/metadata prints: stdout is
    discarded and stderr is routed through `_StderrFilter` during each attempt.

    Parameters
    ----------
    urls : iterable of str
        Candidate links, tried in order.
    out_path : Path
        Where the archive is written. Any existing file is removed before each attempt.

    Returns
    -------
    bool
        True as soon as one attempt leaves a non-empty file that `zipfile.is_zipfile`
        accepts. False when every URL failed; in that case no file is left at
        `out_path`, so a rerun of the cell does not mistake a broken download
        (e.g. an HTML error page) for a usable archive.
    """
    from zipfile import is_zipfile  # local import: the file only imports ZipFile at module level

    def _discard_output():
        # Best-effort removal; a failure here must not abort the remaining attempts.
        if out_path.exists():
            try:
                out_path.unlink()
            except Exception:
                pass

    for url in urls:
        _discard_output()

        real_stderr = sys.stderr
        sys.stderr = _StderrFilter(real_stderr)
        try:
            with redirect_stdout(io.StringIO()):
                gdown.download(url=url, output=str(out_path), quiet=False, fuzzy=True)
        except Exception:
            # Deliberate best-effort: a failing mirror simply moves on to the next URL.
            pass
        finally:
            # Always restore the real stream, even if gdown raised.
            sys.stderr = real_stderr

        if out_path.exists() and out_path.stat().st_size > 0 and is_zipfile(str(out_path)):
            return True

    # Every mirror failed: do not leave a partial or non-ZIP payload behind.
    _discard_output()
    return False


def _try_drive_fallback(candidates, dst_zip: Path) -> bool:
    for c in candidates:
        if c.exists() and c.stat().st_size > 0:
            _copy_file_with_progress(c, dst_zip)
            return True
    return False


def _unzip_with_progress(zip_path: Path, dest_dir: Path) -> None:
    """
    Extract every member of `zip_path` into `dest_dir` with a tqdm bar.

    The bar total is the sum of the members' uncompressed sizes and advances by
    each member's size once that member has been extracted.
    """
    with ZipFile(zip_path, "r") as archive:
        entries = archive.infolist()
        bar_options = dict(
            total=sum(entry.file_size for entry in entries),
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            desc="Extracting",
            dynamic_ncols=True,
        )
        with tqdm(**bar_options) as bar:
            for entry in entries:
                archive.extract(entry, path=dest_dir)
                bar.update(entry.file_size)


# -----------------------------
# 5) Ensure dataset is ready
# -----------------------------
# Skip all staging work when a non-empty dataset folder is already present on the VM.
if not (DATASET_DIR.exists() and any(DATASET_DIR.iterdir())):
    # Re-use an existing non-empty ZIP; otherwise fetch it (links first, Drive copy second).
    if not (LOCAL_ZIP_PATH.exists() and LOCAL_ZIP_PATH.stat().st_size > 0):
        ok = _download_zip_quiet(GDRIVE_ZIP_LINKS, LOCAL_ZIP_PATH)

        if not ok and SAVE_TO_DRIVE and DRIVE_ZIP_FALLBACKS:
            ok = _try_drive_fallback(DRIVE_ZIP_FALLBACKS, LOCAL_ZIP_PATH)

        if not ok:
            raise RuntimeError("Dataset ZIP could not be retrieved (all sources failed).")

    # Extract next to the ZIP; the archive is expected to contain the DATASET_DIR folder.
    _unzip_with_progress(LOCAL_ZIP_PATH, VM_DATA_ROOT)

    # Guards against an archive whose top-level folder name differs from DATASET_DIR.
    if not (DATASET_DIR.exists() and any(DATASET_DIR.iterdir())):
        raise RuntimeError("Extraction completed but dataset folder is missing/empty.")

# String form kept for backward compatibility with code that expects a str path.
DATA_ROOT = str(DATASET_DIR)



# ============================================================
# Dataset presence + compact dataset tree (no extra output)
# ============================================================

from pathlib import Path

def _count_files(p: Path) -> int:
    try:
        return sum(1 for x in p.iterdir() if x.is_file())
    except Exception:
        return 0

def print_dataset_tree(dataset_dir: Path):
    """
    Print whether `dataset_dir` exists, followed by a compact tree of its domains.

    For each of the known domain folders (CFRP, FFRP_Autoclave, FFRP_Oven) that
    exists, the number of files directly inside its MASK/ and XCT/ sub-folders is
    shown. Domains that are missing on disk are left out of the tree.

    Parameters
    ----------
    dataset_dir : Path or str
        Root folder of the extracted dataset.
    """
    dataset_dir = Path(dataset_dir)
    # Glyphs below were stored mis-encoded (UTF-8 bytes rendered as Mac Roman) and
    # printed as garbage; they are restored to the intended characters.
    print("📂 Checking dataset folder presence...")
    print("Exists?", dataset_dir.exists())

    print("DATASET TREE")
    print(f"{dataset_dir.name}/")

    known_domains = ("CFRP", "FFRP_Autoclave", "FFRP_Oven")
    present = [name for name in known_domains if (dataset_dir / name).exists()]

    for position, domain in enumerate(present):
        # The closing branch glyph belongs to the last domain that is actually
        # printed, and its children must not continue the vertical trunk.
        is_last = position == len(present) - 1
        branch = "└──" if is_last else "├──"
        trunk = "    " if is_last else "│   "

        domain_dir = dataset_dir / domain
        msk_n = _count_files(domain_dir / "MASK")
        xct_n = _count_files(domain_dir / "XCT")

        print(f"{branch} {domain}/")
        print(f"{trunk}├── MASK/  (files: {msk_n})")
        print(f"{trunk}└── XCT/   (files: {xct_n})")

# Show the staged dataset layout once, right after staging.
print_dataset_tree(Path(DATASET_DIR))




In [None]:
# ============================================================
# CELL 2: Experiment selection + split + datasets/dataloaders + RUN_META
# ============================================================

import os
import random
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

import tifffile as tiff

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2  # for PadIfNeeded border_mode


# -----------------------------
# STRICT determinism + TF32 OFF (paper-grade reproducibility)
# -----------------------------
DETERMINISM_DEBUG = True  # keep True while debugging reproducibility

# Disable TF32 / enforce IEEE FP32 (new API; avoids deprecation warnings)
try:
    torch.backends.cuda.matmul.fp32_precision = "ieee"
except Exception:
    pass
try:
    torch.backends.cudnn.conv.fp32_precision = "ieee"
except Exception:
    pass

# cuDNN determinism
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


if DETERMINISM_DEBUG:
    # Do not swallow errors: fail fast if any nondeterministic op is used.
    torch.use_deterministic_algorithms(True)
    try:
        torch.set_deterministic_debug_mode("error")
    except Exception:
        pass

# OpenCV determinism: disable threading + OpenCL
try:
    cv2.setNumThreads(0)
    cv2.ocl.setUseOpenCL(False)
except Exception:
    pass

# Reduce CPU scheduling variability (mostly irrelevant when NUM_WORKERS=0, but safe)
try:
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
except Exception:
    pass


# -----------------------------
# Deterministic seed helpers
# -----------------------------
def reset_all_seeds(seed: int = 42):
    """
    Re-seed every random source used by the notebook.

    Covers PYTHONHASHSEED (environment only), Python `random`, NumPy's global
    generator, PyTorch (CPU and, when available, all CUDA devices) and the OpenCV
    RNG, then re-asserts the cuDNN determinism flags.
    """
    seed_value = int(seed)
    os.environ["PYTHONHASHSEED"] = str(seed_value)

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    # OpenCV keeps its own RNG (albumentations may route through cv2).
    try:
        cv2.setRNGSeed(seed_value)
    except Exception:
        pass

    # Re-assert the cuDNN flags in case another cell changed them.
    try:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception:
        pass


@contextmanager
def fixed_rng(seed: int):
    """
    Run the enclosed block with Python/NumPy/OpenCV RNGs seeded to `seed`.

    On exit the previous Python and NumPy generator states are restored, so code
    outside the block is unaffected. The OpenCV RNG is reset to GLOBAL_SEED instead
    (this function does not capture its prior state). CUDA/torch RNG is
    intentionally left untouched: dataset __getitem__ must stay CPU-only.
    """
    seed = int(seed)
    saved_states = (random.getstate(), np.random.get_state())

    random.seed(seed)
    np.random.seed(seed)
    try:
        # Deterministic geometric ops inside OpenCV.
        cv2.setRNGSeed(seed)
    except Exception:
        pass

    try:
        yield
    finally:
        python_state, numpy_state = saved_states
        random.setstate(python_state)
        np.random.set_state(numpy_state)
        try:
            # Optional: keeps the global OpenCV stream at a known base seed.
            cv2.setRNGSeed(int(GLOBAL_SEED))
        except Exception:
            pass


# -----------------------------
# Global config (safe defaults)
# -----------------------------
# Values exported by Cell 1 are re-used when present; the fallbacks below only
# apply when this cell runs in a kernel where Cell 1 did not define them.
if "DATASET_DIR" not in globals():
    DATASET_DIR = Path("/content/data_zip/XCT_SegData_CFRP_FFRP_v1")
else:
    # Normalise to Path in case an earlier cell stored a string.
    DATASET_DIR = Path(DATASET_DIR)

if "GLOBAL_SEED" not in globals():
    GLOBAL_SEED = 42

if "IN_CHANNELS" not in globals():
    IN_CHANNELS = 1
if "OUT_CLASSES" not in globals():
    OUT_CLASSES = 3

# Loader config (keep here; ablation toggles will be in a separate config cell)
TRAIN_FRAC  = 0.8   # share of a domain's slices assigned to train when it is used for both train and val
PATCH_SIZE  = 256   # side length of the square crops produced by the transforms
BATCH_SIZE  = 4
NUM_WORKERS = 0  # you said this will ALWAYS be 0

# Start the split/loader logic below from a known RNG state.
reset_all_seeds(int(GLOBAL_SEED))


# -----------------------------
# Domains and file pairing
# -----------------------------
# Every domain shares the same layout: <DATASET_DIR>/<domain>/XCT and <DATASET_DIR>/<domain>/MASK.
DOMAIN_INFO = {
    domain_name: {
        "xct_dir":  DATASET_DIR / domain_name / "XCT",
        "mask_dir": DATASET_DIR / domain_name / "MASK",
    }
    for domain_name in ("CFRP", "FFRP_Autoclave", "FFRP_Oven")
}

def collect_pairs(xct_dir: Path, mask_dir: Path) -> List[Tuple[Path, Path]]:
    """
    Pair XCT images with masks that share the same filename stem.

    Only regular files with a .tif/.tiff suffix (case-insensitive) are considered.
    Images without a matching mask are skipped. The result follows the sorted
    order of the XCT paths.
    """
    tiff_suffixes = {".tif", ".tiff"}

    def _sorted_tiffs(folder) -> List[Path]:
        return sorted(
            entry
            for entry in Path(folder).iterdir()
            if entry.is_file() and entry.suffix.lower() in tiff_suffixes
        )

    images = _sorted_tiffs(xct_dir)
    # Built from the sorted list, so for duplicate stems the later path wins.
    mask_by_stem = {mask.stem: mask for mask in _sorted_tiffs(mask_dir)}

    return [
        (image, mask_by_stem[image.stem])
        for image in images
        if image.stem in mask_by_stem
    ]

# Domain name -> list of (xct_path, mask_path) for every domain in DOMAIN_INFO,
# regardless of which experiment is selected later.
ALL_PAIRS_BY_DOMAIN: Dict[str, List[Tuple[Path, Path]]] = {}
for dom, info in DOMAIN_INFO.items():
    # Fail fast on an incomplete dataset folder instead of producing empty splits.
    if not info["xct_dir"].exists() or not info["mask_dir"].exists():
        raise FileNotFoundError(f"Missing domain folders for: {dom}")
    ALL_PAIRS_BY_DOMAIN[dom] = collect_pairs(info["xct_dir"], info["mask_dir"])


# -----------------------------
# Experiment definitions (Core A4 placed right after A3)
# -----------------------------
# Menu key (string typed by the user) -> experiment definition:
#   name          : long experiment identifier
#   short         : short tag, used to build RUN_NAME
#   train_domains : domains that contribute training slices
#   val_domains   : domains that contribute validation slices
EXPERIMENTS = {
    "1":  {"name": "A1_train_CFRP__val_CFRP",               "short": "A1", "train_domains": ["CFRP"],             "val_domains": ["CFRP"]},
    "2":  {"name": "A2_train_Autoclave__val_Autoclave",     "short": "A2", "train_domains": ["FFRP_Autoclave"],   "val_domains": ["FFRP_Autoclave"]},
    "3":  {"name": "A3_train_Oven__val_Oven",               "short": "A3", "train_domains": ["FFRP_Oven"],        "val_domains": ["FFRP_Oven"]},

    # Core experiment (A4): All domains → All domains
    "4":  {"name": "Core_train_AllDomains__val_AllDomains", "short": "A4",
           "train_domains": ["CFRP", "FFRP_Autoclave", "FFRP_Oven"],
           "val_domains":   ["CFRP", "FFRP_Autoclave", "FFRP_Oven"]},

    "5":  {"name": "B1_train_CFRP__val_AllDomains",         "short": "B1", "train_domains": ["CFRP"],           "val_domains": ["CFRP", "FFRP_Autoclave", "FFRP_Oven"]},
    "6":  {"name": "B2_train_Autoclave__val_AllDomains",    "short": "B2", "train_domains": ["FFRP_Autoclave"], "val_domains": ["CFRP", "FFRP_Autoclave", "FFRP_Oven"]},
    "7":  {"name": "B3_train_Oven__val_AllDomains",         "short": "B3", "train_domains": ["FFRP_Oven"],      "val_domains": ["CFRP", "FFRP_Autoclave", "FFRP_Oven"]},

    "8":  {"name": "C1_train_CFRP+Oven__val_Autoclave",     "short": "C1", "train_domains": ["CFRP", "FFRP_Oven"],              "val_domains": ["FFRP_Autoclave"]},
    "9":  {"name": "C2_train_CFRP+Autoclave__val_Oven",     "short": "C2", "train_domains": ["CFRP", "FFRP_Autoclave"],         "val_domains": ["FFRP_Oven"]},
    "10": {"name": "C3_train_Autoclave+Oven__val_CFRP",     "short": "C3", "train_domains": ["FFRP_Autoclave", "FFRP_Oven"],     "val_domains": ["CFRP"]},
}

def show_experiment_menu():
    """
    Print the numbered experiment menu (train domains → validation domains).

    The numbers correspond to the keys of EXPERIMENTS. The arrow glyphs were
    stored mis-encoded (UTF-8 bytes rendered as Mac Roman) and printed as
    garbage; they are restored to the intended "→" here.
    """
    print("\n================ EXPERIMENT MENU ================")
    print("Group A: Single-domain baselines (in-domain)")
    print("  1) A1 : CFRP → CFRP")
    print("  2) A2 : FFRP_Autoclave → FFRP_Autoclave")
    print("  3) A3 : FFRP_Oven → FFRP_Oven")
    print("\nCore experiment (A4): All domains → All domains")
    print("  4) A4 : {CFRP, FFRP_Autoclave, FFRP_Oven} → {CFRP, FFRP_Autoclave, FFRP_Oven}\n")

    print("Group B: Single-domain train, eval on ALL domains")
    print("  5) B1 : CFRP → {CFRP, FFRP_Autoclave, FFRP_Oven}")
    print("  6) B2 : FFRP_Autoclave → {CFRP, FFRP_Autoclave, FFRP_Oven}")
    print("  7) B3 : FFRP_Oven → {CFRP, FFRP_Autoclave, FFRP_Oven}\n")

    print("Group C: Multi-domain training (cross-domain)")
    print("  8)  C1 : {CFRP, FFRP_Oven} → FFRP_Autoclave")
    print("  9)  C2 : {CFRP, FFRP_Autoclave} → FFRP_Oven")
    print(" 10)  C3 : {FFRP_Autoclave, FFRP_Oven} → CFRP")
    print("=================================================\n")

show_experiment_menu()
# The en dash in the prompt was stored mis-encoded and displayed as garbage; restored.
# NOTE: `choice` re-uses the name of Cell 1's destination answer; from here on it
# holds the experiment key.
choice = input("Choose experiment (1–10): ").strip()
if choice not in EXPERIMENTS:
    raise ValueError(f"Invalid experiment choice: {choice}")

exp_cfg = EXPERIMENTS[choice]
EXPERIMENT_NAME  = exp_cfg["name"]
EXPERIMENT_SHORT = exp_cfg["short"]
# Copies, so later edits to these lists cannot alter the EXPERIMENTS table.
TRAIN_DOMAINS    = list(exp_cfg["train_domains"])
VAL_DOMAINS      = list(exp_cfg["val_domains"])

# Run name convention stays unchanged
RUN_NAME = f"{EXPERIMENT_SHORT}_seed{int(GLOBAL_SEED)}"

print("\nSelected setup")
print(f"  Experiment : {EXPERIMENT_NAME}")
print(f"  Short      : {EXPERIMENT_SHORT}")
print(f"  Run name   : {RUN_NAME}")
print(f"  Train doms : {TRAIN_DOMAINS}")
print(f"  Val doms   : {VAL_DOMAINS}")


# -----------------------------
# Per-experiment split
# -----------------------------
def per_experiment_split(all_pairs_by_domain, train_domains, val_domains, seed: int, exp_choice: int):
    """
    Split each used domain's pairs into train/val according to its role.

    Roles:
      - train only  -> every pair goes to train
      - val only    -> roughly (1 - TRAIN_FRAC) of the pairs go to val (at least one)
      - train + val -> roughly TRAIN_FRAC to train, the remainder to val

    Shuffling uses a private `random.Random` per domain, seeded from `seed`,
    `exp_choice` and the domain's position in the sorted DOMAIN_INFO keys, so the
    global RNG state is never consumed.

    Returns
    -------
    (train_pairs_by_domain, val_pairs_by_domain) : tuple of dict
        Domains that end up with no pairs on a side are omitted from that dict.
    """
    used_domains = sorted(set(train_domains + val_domains))
    stable_order = list(sorted(DOMAIN_INFO.keys()))

    def _shuffled_indices(domain, count):
        # Same seed formula for both shuffling roles, so a domain's permutation
        # depends only on (seed, experiment, domain).
        order = list(range(count))
        domain_offset = 100 * (stable_order.index(domain) + 1)
        random.Random(int(seed) + 1000 * int(exp_choice) + domain_offset).shuffle(order)
        return order

    train_split, val_split = {}, {}

    for dom in used_domains:
        pairs = all_pairs_by_domain[dom]
        total = len(pairs)
        for_train = dom in train_domains
        for_val = dom in val_domains

        chosen_train, chosen_val = [], []

        if for_train and for_val:
            order = _shuffled_indices(dom, total)
            cut = int(round(total * TRAIN_FRAC))
            chosen_train, chosen_val = order[:cut], order[cut:]
        elif for_train:
            chosen_train = list(range(total))
        elif for_val:
            order = _shuffled_indices(dom, total)
            keep = max(1, int(round(total * (1.0 - TRAIN_FRAC))))
            chosen_val = order[:keep]

        if chosen_train:
            train_split[dom] = [pairs[i] for i in chosen_train]
        if chosen_val:
            val_split[dom] = [pairs[i] for i in chosen_val]

    return train_split, val_split


def balance_train_pairs(pairs_by_domain, seed: int):
    """
    Downsample every train domain to the size of the smallest non-empty one.

    Sampling uses a private `random.Random(seed + 999)`. When no domain has any
    pairs, the input dict itself is returned unchanged.
    """
    non_empty_sizes = [len(pairs) for pairs in pairs_by_domain.values() if len(pairs) > 0]
    if not non_empty_sizes:
        return pairs_by_domain

    target = min(non_empty_sizes)
    sampler = random.Random(int(seed) + 999)

    # Domains are visited in dict order, which fixes the order of draws from `sampler`.
    return {
        domain: (sampler.sample(pairs, target) if len(pairs) > target else list(pairs))
        for domain, pairs in pairs_by_domain.items()
    }


def balance_val_pairs(pairs_by_domain, seed: int, val_domains):
    """
    Downsample the validation domains to the size of the smallest non-empty one.

    Sampling uses a private `random.Random(seed + 1234)`. With at most one
    non-empty validation domain there is nothing to balance and the input dict
    itself is returned. Otherwise the result has exactly one entry per name in
    `val_domains` (an empty list for domains without pairs).
    """
    active_sizes = {}
    for domain in val_domains:
        size = len(pairs_by_domain.get(domain, []))
        if size > 0:
            active_sizes[domain] = size

    if len(active_sizes) <= 1:
        return pairs_by_domain

    target = min(active_sizes.values())
    sampler = random.Random(int(seed) + 1234)

    def _fit(pairs):
        return sampler.sample(pairs, target) if len(pairs) > target else list(pairs)

    return {domain: _fit(pairs_by_domain.get(domain, [])) for domain in val_domains}


# 1) Split by experiment role, using the numeric menu choice as part of the shuffle seed.
train_pairs_raw, val_pairs_raw = per_experiment_split(
    ALL_PAIRS_BY_DOMAIN, TRAIN_DOMAINS, VAL_DOMAINS,
    seed=int(GLOBAL_SEED),
    exp_choice=int(choice)
)

# 2) Equalise the per-domain counts on each side independently.
train_pairs_by_domain = balance_train_pairs(train_pairs_raw, int(GLOBAL_SEED))
val_pairs_by_domain   = balance_val_pairs(val_pairs_raw, int(GLOBAL_SEED), VAL_DOMAINS)

# Flatten train items
# Each item is a dict with keys "domain", "xct_path" and "mask_path".
train_items = []
for d, pairs in train_pairs_by_domain.items():
    for x_path, m_path in pairs:
        train_items.append({"domain": d, "xct_path": x_path, "mask_path": m_path})

# Val items (kept domain-wise)
# Every validation domain gets an entry, possibly an empty list.
val_items_by_domain = {}
for d in VAL_DOMAINS:
    pairs = val_pairs_by_domain.get(d, [])
    val_items_by_domain[d] = [{"domain": d, "xct_path": x_path, "mask_path": m_path} for x_path, m_path in pairs]

def _count_domain(items, dom):
    return sum(1 for it in items if it["domain"] == dom)

# Total number of validation items across all validation domains (also stored in RUN_META).
val_total = sum(len(v) for v in val_items_by_domain.values())

# Human-readable summary of the balanced split, one line per domain.
print("\nFinal split (after balancing)")
print(f"  Train total: {len(train_items)}")
for d in TRAIN_DOMAINS:
    print(f"    - {d}: {_count_domain(train_items, d)}")
print(f"  Val total  : {val_total}")
for d in VAL_DOMAINS:
    print(f"    - {d}: {len(val_items_by_domain.get(d, []))}")


# -----------------------------
# Transforms (256×256 crop; no resize)
# -----------------------------
def get_transforms(is_train: bool):
    """
    Build the albumentations pipeline for training or evaluation.

    Both pipelines start with PadIfNeeded (reflect-101 border) up to
    PATCH_SIZE x PATCH_SIZE and end with ToTensorV2; nothing is ever resized.
      - is_train=True : horizontal/vertical flips, brightness/contrast jitter, RandomCrop
      - is_train=False: CenterCrop only
    """
    pad_to_patch = A.PadIfNeeded(
        min_height=PATCH_SIZE,
        min_width=PATCH_SIZE,
        border_mode=cv2.BORDER_REFLECT_101,
    )

    if is_train:
        stage = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.20, contrast_limit=0.20, p=0.3),
            A.RandomCrop(height=PATCH_SIZE, width=PATCH_SIZE),
        ]
    else:
        stage = [A.CenterCrop(height=PATCH_SIZE, width=PATCH_SIZE)]

    return A.Compose([pad_to_patch, *stage, ToTensorV2()])

# Built once and shared by every dataset that make_dataloaders() creates.
TRAIN_TRANSFORM = get_transforms(is_train=True)
VAL_TRANSFORM   = get_transforms(is_train=False)


# -----------------------------
# Dataset (per-image min-max; epoch-aware deterministic aug per-sample)
# -----------------------------
def per_image_minmax(img: np.ndarray) -> np.ndarray:
    """
    Rescale one image to [0, 1] using its own minimum and maximum.

    The input is converted to float32 first. When the maximum is not strictly
    greater than the minimum (e.g. a constant image), an all-zero float32 array
    of the same shape is returned instead of dividing by zero.
    """
    as_float = img.astype(np.float32)
    lowest = float(as_float.min())
    highest = float(as_float.max())
    if not (highest > lowest):
        return np.zeros_like(as_float, dtype=np.float32)
    return (as_float - lowest) / (highest - lowest)


class SegmentationDataset(Dataset):
    """
    TIFF-backed segmentation dataset yielding (image_tensor, mask_tensor, domain_str).

    Each item dict must provide "xct_path", "mask_path" and "domain". Images are
    min-max normalised per image before the transform is applied.

    Augmentation is epoch-aware and deterministic. The RNG seed used for sample
    `idx` is base_seed + epoch * 1_000_000 + idx, so
      - within one epoch, the same idx always gets the same augmentation;
      - a different epoch produces a different (but reproducible) augmentation.
    """

    # Spacing between per-epoch seed ranges.
    _EPOCH_STRIDE = 1_000_000

    def __init__(self, items, transform=None, base_seed: int = 0):
        self.items = list(items)
        self.transform = transform
        self.base_seed = int(base_seed)
        self.epoch = 0

    def set_epoch(self, epoch: int):
        """Select the epoch whose augmentation schedule subsequent reads will use."""
        self.epoch = int(epoch)

    def __len__(self):
        return len(self.items)

    def _sample_seed(self, idx) -> int:
        """Seed for sample `idx` in the current epoch."""
        return self.base_seed + (int(self.epoch) * self._EPOCH_STRIDE) + int(idx)

    def __getitem__(self, idx):
        record = self.items[int(idx)]
        xct_file = record["xct_path"]
        mask_file = record["mask_path"]
        domain = record["domain"]

        raw_image = tiff.imread(str(xct_file)).astype(np.float32)
        mask = tiff.imread(str(mask_file)).astype(np.uint8)

        image = per_image_minmax(raw_image)

        # Albumentations expects HWC: give 2-D images a trailing channel axis.
        image_hwc = image if image.ndim != 2 else image[..., None]

        if self.transform is None:
            image_t = torch.from_numpy(image_hwc.transpose(2, 0, 1)).float()
            mask_t = torch.from_numpy(mask).long()
            return image_t, mask_t, domain

        seed = self._sample_seed(idx)

        # Transforms that expose their own seeding hook are re-seeded as well;
        # failures are ignored on purpose (best-effort).
        if hasattr(self.transform, "set_random_seed"):
            try:
                self.transform.set_random_seed(int(seed))
            except Exception:
                pass

        with fixed_rng(seed):
            augmented = self.transform(image=image_hwc, mask=mask)

        return augmented["image"].float(), augmented["mask"].long(), domain


def worker_init_fn(worker_id):
    """Seed NumPy and Python `random` inside a DataLoader worker with GLOBAL_SEED + worker_id."""
    # NUM_WORKERS is always 0 in this setup, so this hook is kept minimal.
    derived_seed = int(GLOBAL_SEED) + int(worker_id)
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(derived_seed)


# -----------------------------
# DataLoader factory (rebuild per model for identical schedule)
# -----------------------------
def make_dataloaders(base_seed: int = None):
    """
    Create fresh datasets and loaders driven by a deterministic generator.

    Call this at the start of EACH model training so every model sees the same
    shuffle schedule. Also records the seed in the module-level
    TRAIN_SHUFFLE_BASE_SEED.

    Parameters
    ----------
    base_seed : int, optional
        Seed for the train dataset and the shuffle generator; defaults to GLOBAL_SEED.

    Returns
    -------
    tuple
        (train_dataset, val_dataset, train_loader, val_loader,
         train_eval_loaders_by_domain, val_loaders_by_domain)
        The per-domain loaders use the validation (center-crop) transform and no
        shuffling; domains without items are omitted.
    """
    global TRAIN_SHUFFLE_BASE_SEED
    seed = int(GLOBAL_SEED if base_seed is None else base_seed)
    TRAIN_SHUFFLE_BASE_SEED = seed

    shared_opts = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

    def _center_crop_loader(items, seed_offset):
        # Evaluation loader: validation transform, fixed order.
        dataset = SegmentationDataset(items, transform=VAL_TRANSFORM, base_seed=seed + seed_offset)
        return DataLoader(dataset, shuffle=False, **shared_opts)

    tr_ds = SegmentationDataset(train_items, transform=TRAIN_TRANSFORM, base_seed=seed)

    val_flat = [item for dom_items in val_items_by_domain.values() for item in dom_items]
    va_ds = SegmentationDataset(val_flat, transform=VAL_TRANSFORM, base_seed=seed + 12345)

    tr_gen = torch.Generator()
    tr_gen.manual_seed(seed)

    tr_loader = DataLoader(
        tr_ds,
        shuffle=True,
        worker_init_fn=worker_init_fn,
        generator=tr_gen,
        **shared_opts,
    )
    # IMPORTANT: keep a handle on the generator so set_training_epoch() can reseed
    # it without relying on DataLoader internals.
    tr_loader._shuffle_gen = tr_gen

    va_loader = DataLoader(
        va_ds,
        shuffle=False,
        worker_init_fn=worker_init_fn,
        **shared_opts,
    )

    # Per-domain eval loaders (center crop, no shuffle)
    tr_eval_by_dom = {}
    for d in TRAIN_DOMAINS:
        dom_items = [it for it in train_items if it["domain"] == d]
        if len(dom_items) > 0:
            tr_eval_by_dom[d] = _center_crop_loader(dom_items, 777)

    val_by_dom = {}
    for d in VAL_DOMAINS:
        dom_items = val_items_by_domain.get(d, [])
        if len(dom_items) > 0:
            val_by_dom[d] = _center_crop_loader(dom_items, 888)

    return tr_ds, va_ds, tr_loader, va_loader, tr_eval_by_dom, val_by_dom

# leave make_dataloaders(...) defined, but DO NOT call it here
# The status glyph was stored mis-encoded (UTF-8 bytes shown as Mac Roman); restored to the check mark.
print("\n✅ Cell 2 ready: splits prepared (train_items / val_items_by_domain).")
print("   NOTE: dataloaders will be built in the training cell for each model.")


def set_training_epoch(epoch: int):
    """
    Prepare the training data pipeline for `epoch`; call at the start of each epoch.

    Effect:
      - the train dataset is told the epoch, which changes its deterministic
        per-sample augmentation seed;
      - the loader's shuffle generator is reseeded with the fixed
        TRAIN_SHUFFLE_BASE_SEED, so the shuffle order is the same in every epoch
        and for every model.

    Relies on the module-level names `train_dataset` and `train_loader`.
    NOTE(review): these are not defined in this cell — presumably the training
    cell assigns them from make_dataloaders(); confirm.
    """
    train_dataset.set_epoch(int(epoch))

    shuffle_gen = getattr(train_loader, "_shuffle_gen", None)
    if not isinstance(shuffle_gen, torch.Generator):
        # Loader was not built by make_dataloaders(): nothing to reseed.
        return
    shuffle_gen.manual_seed(int(TRAIN_SHUFFLE_BASE_SEED))


# -----------------------------
# RUN_META (stored once in Table A later)
# -----------------------------
# Flat description of the run: experiment identity, seed, domain lists and split sizes.
RUN_META = {
    "experiment": EXPERIMENT_NAME,
    "experiment_short": EXPERIMENT_SHORT,
    "run_name": RUN_NAME,
    "seed": int(GLOBAL_SEED),
    # Domain lists are stored as comma-separated strings.
    "train_domains": ",".join(TRAIN_DOMAINS),
    "val_domains": ",".join(VAL_DOMAINS),
    "train_total": len(train_items),
    "val_total": int(val_total),
}
# Per-domain counts are recorded for all three domains; unused domains get 0.
for d in ["CFRP", "FFRP_Autoclave", "FFRP_Oven"]:
    RUN_META[f"n_train_{d}"] = _count_domain(train_items, d)
    RUN_META[f"n_val_{d}"]   = len(val_items_by_domain.get(d, []))


# -----------------------------
# Console table helpers (training loop will use these)
# -----------------------------
def print_epoch_table_header():
    print("")
    print("‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê")
    print("‚îÇ Epoch ‚îÇ Phase ‚îÇ    LR    ‚îÇ LossTot  ‚îÇ LossCls  ‚îÇ LossDice ‚îÇ HardDice  ‚îÇ  HardIoU  ‚îÇ Time(s)‚îÇ")
    print("‚îú‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î§")

def format_epoch_row(epoch: int, phase: str, lr: float,
                     loss_total: float, loss_cls: float, loss_dice: float,
                     hard_dice: float, hard_iou: float, time_sec: float) -> str:
    phase = (phase or "")[:5]
    return (
        "‚îÇ {e:>5d} ‚îÇ {p:<5s} ‚îÇ {lr:>8.2e} ‚îÇ {lt:>8.5f} ‚îÇ {lc:>8.5f} ‚îÇ {ld:>8.5f} ‚îÇ {dc:>9.5f} ‚îÇ {io:>9.5f} ‚îÇ {ts:>6.2f} ‚îÇ"
        .format(e=epoch, p=phase, lr=lr, lt=loss_total, lc=loss_cls, ld=loss_dice, dc=hard_dice, io=hard_iou, ts=time_sec)
    )

def print_epoch_table_footer():
    """Print the bottom border line that closes the per-epoch console table."""
    print("‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò")


# End-of-cell status banner: confirms the split/transform objects exist and
# echoes the run name so it can be checked before the training cell is run.
print("\n‚úÖ Cell 2 ready: splits/transforms defined (train_items / val_items_by_domain / RUN_META).")
print("   NOTE: dataloaders will be built in the training cell for each model.")
print(f"   RUN_NAME = {RUN_NAME}")


In [None]:
# ============================================================
# Visualisation (train + val) — per-domain one row:
#   col1: raw XCT (full image, min–max)
#   col2: raw MASK (full image)  [fixed colors]
#   col3: augmented XCT (PATCH_SIZE×PATCH_SIZE)
#   col4: augmented MASK (PATCH_SIZE×PATCH_SIZE) [fixed colors]
#
# OPTION B INCLUDED:
#   If train_loader is not defined yet (by design in Cell 2),
#   build a LOCAL viz_loader for a quick batch sanity check,
#   without touching training globals.
#
# Determinism safety:
# - NO global seeding here (no random.seed / np.random.seed / torch.manual_seed).
# - Local RNG objects only.
# - If fixed_rng(seed) exists, we use it only as a context manager that restores RNG.
# ============================================================

import random
import numpy as np
import tifffile as tiff
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

import torch
from torch.utils.data import DataLoader

# -----------------------------
# Fixed mask colors (label -> color)
# -----------------------------
# Position i in this list is the colour drawn for mask label i.
MASK_COLOR_ORDER = ["green", "blue", "red"]  # index 0,1,2
MASK_CMAP = ListedColormap(MASK_COLOR_ORDER)
# Bin edges at -0.5 / 0.5 / 1.5 / 2.5 put the integer labels 0, 1 and 2 into
# three separate bins, one per colormap entry (MASK_CMAP.N entries in total).
MASK_NORM = BoundaryNorm([-0.5, 0.5, 1.5, 2.5], MASK_CMAP.N)

# -----------------------------
# Local unpacker (NO dependency on Cell 4.5B)
# -----------------------------
def viz_unpack_batch(b):
    """
    Split a dataloader batch into an (image, mask, domain) triple.

    Supports:
      - dict: {"image":..., "mask":..., "domain":...}; missing keys yield None
      - tuple/list: (x, y) or (x, y, domain, ...); elements after the third
        are ignored and a missing domain is returned as None

    Raises:
      ValueError: if `b` is neither a dict nor a list/tuple, or if a
        list/tuple holds fewer than two elements (previously this surfaced as
        an uninformative IndexError).
    """
    if isinstance(b, dict):
        return b.get("image"), b.get("mask"), b.get("domain")
    if isinstance(b, (list, tuple)):
        if len(b) < 2:
            raise ValueError(
                f"Unsupported batch format: expected at least (x, y), got {len(b)} element(s)"
            )
        domain = b[2] if len(b) > 2 else None
        return b[0], b[1], domain
    raise ValueError("Unsupported batch format")

# -----------------------------
# Helpers: Tensor/NumPy -> 2D arrays for imshow
# -----------------------------
def _as_np_img(x):
    """Return a 2D numpy array for grayscale display."""
    if torch.is_tensor(x):
        arr = x.detach().cpu().numpy()
        if arr.ndim == 3 and arr.shape[0] == 1:      # (1,H,W)
            arr = arr[0]
        elif arr.ndim == 3 and arr.shape[-1] == 1:   # (H,W,1)
            arr = arr[..., 0]
        return arr.astype(np.float32)

    arr = np.asarray(x)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    if arr.ndim == 3 and arr.shape[0] == 1:
        arr = arr[0]
    return arr.astype(np.float32)

def _as_np_mask(m):
    """Return a 2D numpy array for mask display."""
    if torch.is_tensor(m):
        return m.detach().cpu().numpy().astype(np.int32)
    return np.asarray(m).astype(np.int32)

def _ensure_img_ch(img_norm):
    """Make image HxWx1 for Albumentations."""
    return img_norm[..., None] if img_norm.ndim == 2 else img_norm

def _safe_per_image_minmax(x, eps=1e-8):
    """Fallback if per_image_minmax is not defined."""
    x = x.astype(np.float32)
    mn = float(np.min(x))
    mx = float(np.max(x))
    if (mx - mn) < eps:
        return np.zeros_like(x, dtype=np.float32)
    return (x - mn) / (mx - mn)

def _apply_transform_safely(tf, img_ch, mask_raw, seed=None):
    """
    Apply Albumentations transform without permanently changing global RNG:
    - If fixed_rng(seed) exists: use it (restores RNG after).
    - Else: apply tf directly.
    """
    if seed is not None and ("fixed_rng" in globals()) and callable(globals().get("fixed_rng")):
        with fixed_rng(int(seed)):
            return tf(image=img_ch, mask=mask_raw)
    return tf(image=img_ch, mask=mask_raw)

# -----------------------------
# Main visualisation
# -----------------------------
def _plot_domain_rows(split, rows, transform, rng_offset, tf_seed_offset, patch_tag, suptitle):
    """
    Draw one figure with a row per domain and four columns:
    raw XCT, raw mask, transformed XCT, transformed mask.

    Args:
      split          : label used in panel titles ("train" / "val").
      rows           : list of (domain, items) pairs; every `items` list must
                       be non-empty. Nothing is drawn when `rows` is empty.
      transform      : callable applied as transform(image=..., mask=...).
      rng_offset     : added to GLOBAL_SEED to seed the local example picker.
      tf_seed_offset : added to GLOBAL_SEED (+ row index) for the transform seed.
      patch_tag      : text placed before the patch size in the titles of the
                       transformed panels ("aug " for train, "" for val).
      suptitle       : figure title.
    """
    if not rows:
        return

    n_rows = len(rows)
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)

    rng_vis = random.Random(int(GLOBAL_SEED) + rng_offset)  # local RNG only

    # Use the notebook's per_image_minmax when available, else the local fallback.
    normalise = globals().get("per_image_minmax")
    if not callable(normalise):
        normalise = _safe_per_image_minmax

    for r, (d, dom_items) in enumerate(rows):
        ex = rng_vis.choice(dom_items)

        img_raw  = tiff.imread(str(ex["xct_path"])).astype(np.float32)
        mask_raw = tiff.imread(str(ex["mask_path"])).astype(np.int32)

        img_norm = normalise(img_raw)
        img_vis  = np.clip(img_norm, 0.0, 1.0)
        img_ch   = _ensure_img_ch(img_norm)

        seed = int(GLOBAL_SEED) + tf_seed_offset + r  # deterministic for viz only
        aug = _apply_transform_safely(transform, img_ch, mask_raw, seed=seed)

        img_aug  = _as_np_img(aug["image"])
        mask_aug = _as_np_mask(aug["mask"])

        axes[r, 0].imshow(img_vis, cmap="gray", vmin=0, vmax=1)
        axes[r, 0].set_title(f"{d} - {split} XCT (raw)")
        axes[r, 0].axis("off")

        axes[r, 1].imshow(mask_raw, cmap=MASK_CMAP, norm=MASK_NORM)
        axes[r, 1].set_title(f"{d} - {split} MASK (raw)")
        axes[r, 1].axis("off")

        axes[r, 2].imshow(np.clip(img_aug, 0.0, 1.0), cmap="gray", vmin=0, vmax=1)
        axes[r, 2].set_title(f"{d} - {split} XCT ({patch_tag}{PATCH_SIZE}√ó{PATCH_SIZE})")
        axes[r, 2].axis("off")

        axes[r, 3].imshow(mask_aug, cmap=MASK_CMAP, norm=MASK_NORM)
        axes[r, 3].set_title(f"{d} - {split} MASK ({patch_tag}{PATCH_SIZE}√ó{PATCH_SIZE})")
        axes[r, 3].axis("off")

        # Flag label values the fixed 3-colour map cannot represent.
        u = np.unique(mask_raw)
        if np.any((u < 0) | (u > 2)):
            print(f"‚ö†Ô∏è  [{d}] raw mask has labels outside {{0,1,2}}: {u}")

    fig.suptitle(suptitle, fontsize=14)
    fig.tight_layout()
    plt.show()


def show_examples(train_items, val_items_by_domain, train_tf, val_tf):
    """
    Show one randomly chosen example per domain for the train and val splits.

    Two figures are produced (when the split has any domain with items):
    a TRAIN figure using `train_tf` and a VAL figure using `val_tf`. Each row
    shows raw XCT, raw mask, transformed XCT and transformed mask.

    Args:
      train_items         : iterable of dicts with "domain", "xct_path",
                            "mask_path"; items without a domain are skipped.
      val_items_by_domain : dict mapping domain -> list of such dicts.
      train_tf, val_tf    : transforms called as tf(image=..., mask=...).

    Example selection uses local random.Random instances seeded from
    GLOBAL_SEED, so the global RNG state is not reseeded here.
    """
    # ---------- TRAIN ROWS ----------
    # Group once by domain, keeping first-seen domain order and item order.
    train_by_domain = {}
    for it in train_items:
        d = it.get("domain", None)
        if d is not None:
            train_by_domain.setdefault(d, []).append(it)

    _plot_domain_rows(
        split="train",
        rows=list(train_by_domain.items()),
        transform=train_tf,
        rng_offset=123,
        tf_seed_offset=777,
        patch_tag="aug ",
        suptitle="TRAIN: one row per domain (raw & augmented)",
    )

    # ---------- VAL ROWS ----------
    # Prefer the configured VAL_DOMAINS order; fall back to the dict's own keys.
    if "VAL_DOMAINS" in globals():
        val_domains = list(VAL_DOMAINS)
    else:
        val_domains = list(val_items_by_domain.keys())

    val_rows = [
        (d, val_items_by_domain.get(d, []))
        for d in val_domains
        if len(val_items_by_domain.get(d, [])) > 0
    ]

    _plot_domain_rows(
        split="val",
        rows=val_rows,
        transform=val_tf,
        rng_offset=999,
        tf_seed_offset=888,
        patch_tag="",
        suptitle="VAL: one row per domain (raw & transformed)",
    )

# --------------------------
# Visual check now
# --------------------------
show_examples(train_items, val_items_by_domain, TRAIN_TRANSFORM, VAL_TRANSFORM)

# ============================================================
# OPTIONAL: Local viz loader sanity check (Option B)
# - A local viz_loader is built only if train_loader does not exist yet.
# - Does NOT overwrite training globals.
# - One batch is pulled from whichever loader is used and its shapes printed.
# ============================================================
if "train_loader" not in globals():
    viz_ds = SegmentationDataset(train_items, transform=TRAIN_TRANSFORM, base_seed=int(GLOBAL_SEED))
    viz_ds.set_epoch(0)

    # Dedicated generator so shuffling here does not draw from the global torch RNG.
    viz_gen = torch.Generator()
    viz_gen.manual_seed(int(GLOBAL_SEED))

    viz_loader = DataLoader(
        viz_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        # Pinning only helps host->GPU copies; DataLoader warns when it is
        # requested on a runtime without an accelerator.
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=worker_init_fn,
        generator=viz_gen,
    )
    _sanity_name, _sanity_loader = "viz_loader", viz_loader
else:
    _sanity_name, _sanity_loader = "train_loader", train_loader

b = next(iter(_sanity_loader))
xb, yb, db = viz_unpack_batch(b)
dom_ex = sorted(set(db))[:3] if db is not None else "None"
print(f"üîé [{_sanity_name}] Train batch shape:", xb.shape, yb.shape, "| domains example:", dom_ex)


In [None]:
# ============================================================
# CELL 3: Result folder + paths + Excel logger (Table A + Table B)
# ============================================================
import time
from pathlib import Path
import pandas as pd
import warnings
from openpyxl import load_workbook

# --- Safety: required globals from previous cells ---
# Fail fast with a pointer to the cell that defines each prerequisite.
# GLOBAL_SEED is set in Cell 1 (global experiment config), not Cell 2.
assert "RESULTS_ROOT"     in globals(), "RESULTS_ROOT is not defined. Run Cell 1 first."
assert "EXPERIMENT_NAME"  in globals(), "EXPERIMENT_NAME is not defined. Run Cell 2 first."
assert "GLOBAL_SEED"      in globals(), "GLOBAL_SEED is not defined. Run Cell 1 first."
assert "RUN_NAME"         in globals(), "RUN_NAME is not defined. Run Cell 2 first."

# Installs an "ignore" filter for FutureWarning (not scoped to a context manager).
warnings.simplefilter("ignore", FutureWarning)

# ------------------------------------------------------------
# Root where ALL training results will be stored
#   RESULTS_ROOT / <RUN_NAME> /
# ------------------------------------------------------------
TRAIN_ROOT = Path(RESULTS_ROOT)
TRAIN_ROOT.mkdir(parents=True, exist_ok=True)

# For determinism debugging: a timestamped folder avoids mixing old
# checkpoints/logs with a new run. Keep False for the final paper runs.
FORCE_FRESH_RUN_DIR = False  # final paper False

if not FORCE_FRESH_RUN_DIR:
    # Reuse the logical run name as the physical folder name.
    RUN_DIRNAME = RUN_NAME
else:
    run_tag = time.strftime("%Y%m%d_%H%M%S")
    RUN_DIRNAME = f"{RUN_NAME}__{run_tag}"

RUN_ROOT = TRAIN_ROOT.joinpath(RUN_DIRNAME)
RUN_ROOT.mkdir(parents=True, exist_ok=True)

# ------------------------------------------------------------
# Model key → prefix mapping (used for filenames)
# ------------------------------------------------------------
# Canonical model key -> filename prefix.
MODEL_KEY_TO_PREFIX = dict(
    unet="UNet",
    unet3plus="UNet3Plus",
    sslite_v1="SSLite_V1_EnhSemiFull",
    sslite_v2="SSLite_V2_4Branch",
    sslite_v3="SSLite_V3_Baseline",
)
SUPPORTED_MODEL_KEYS = [*MODEL_KEY_TO_PREFIX]

# Optional aliases accepted (to prevent path/log saving crashes).
# Canonical keys map to themselves so every lookup goes through one table.
MODEL_KEY_ALIASES = {
    "unet3+":             "unet3plus",
    "unet3plus":          "unet3plus",
    "ssunet3pluslite_v1": "sslite_v1",
    "ssunet3pluslite-v1": "sslite_v1",
    "sslite_v1":          "sslite_v1",
    "sslite_v2":          "sslite_v2",
    "sslite_v3":          "sslite_v3",
    "unet":               "unet",
}

print("\n‚úÖ Supported models for saving:")
for k, p in MODEL_KEY_TO_PREFIX.items():
    print(f"   ‚Ä¢ key='{k}'  ‚Üí prefix='{p}'")

# ------------------------------------------------------------
# Paths for a given model
# ------------------------------------------------------------
def get_model_paths(model_name: str):
    """
    Build the checkpoint/log paths of one model inside RUN_ROOT.

    Args:
      model_name: model key, matched case-insensitively. Canonical keys are
        "unet", "unet3plus", "sslite_v1", "sslite_v2", "sslite_v3"; entries of
        MODEL_KEY_ALIASES are mapped to their canonical key first.

    Returns:
      dict with "base_dir" (RUN_ROOT) plus "best_ckpt", "last_ckpt",
      "log_csv" and "log_xlsx", each RUN_ROOT / "<prefix>_<suffix>".

    Raises:
      ValueError: if the (alias-resolved) key has no filename prefix.
    """
    requested = str(model_name).lower()
    canonical = MODEL_KEY_ALIASES.get(requested, requested)

    if canonical not in MODEL_KEY_TO_PREFIX:
        raise ValueError(
            f"Unknown model_name='{model_name}'. Supported keys: {SUPPORTED_MODEL_KEYS}"
        )

    prefix = MODEL_KEY_TO_PREFIX[canonical]
    run_dir = RUN_ROOT

    model_paths = {"base_dir": run_dir}
    model_paths["best_ckpt"] = run_dir / f"{prefix}_best_model.pth"
    model_paths["last_ckpt"] = run_dir / f"{prefix}_last_model.pth"
    model_paths["log_csv"]   = run_dir / f"{prefix}_log.csv"
    model_paths["log_xlsx"]  = run_dir / f"{prefix}_log.xlsx"
    return model_paths

# ------------------------------------------------------------
# Hard Dice/IoU columns (val-only in training cell, but always present here)
# NOTE: OUT_CLASSES assumed 3 for explicit columns as requested.
# ------------------------------------------------------------
# Hard Dice/IoU columns: macro averages first, then one column per class.
# The class columns are spelled out for three classes (c0..c2).
HARD_KEYS = [
    "hard_dice_macro", "hard_iou_macro",
    "hard_dice_c0", "hard_dice_c1", "hard_dice_c2",
    "hard_iou_c0", "hard_iou_c1", "hard_iou_c2",
]

# ------------------------------------------------------------
# Table B: epoch-wise result columns, assembled from three groups
#  - Soft dice/IoU (dice_d0 / iou_d0) are not part of the table.
#  - hard_* columns are expected for VAL (training cell may leave train rows empty).
# ------------------------------------------------------------
_TABLE_B_CORE_COLUMNS = [
    "phase",          # "train" / "val"
    "eval_domain",    # typically "ALL"
    "epoch",
    "model",          # model key (e.g. "sslite_v3")
    "lr",
    "time_sec",
    "loss_total",
    "loss_ce",
    "loss_dice",
]

# Optional string summaries.
_TABLE_B_SUMMARY_COLUMNS = ["dice_domain", "iou_domain", "dice_classes", "iou_classes"]

# Run/meta toggles (filled from globals if the row doesn't provide them).
_TABLE_B_META_COLUMNS = [
    "ce_mode",        # "ce" / "focal"
    "abl_fusion",     # "native" / "unet_skip"
    "use_bn",         # bool (BN ablation logging)
    "use_gn",         # bool
    "base_ch",        # int (e.g. 32/64)
    "ds_on",          # bool (deep supervision on/off)
]

TABLE_B_COLUMNS = [
    *_TABLE_B_CORE_COLUMNS,
    *_TABLE_B_SUMMARY_COLUMNS,
    *HARD_KEYS,
    *_TABLE_B_META_COLUMNS,
]

# ------------------------------------------------------------
# Internal: build Table A row (RUN_META + model + LR schedule summary)
# ------------------------------------------------------------
def _build_tableA_row(model_name: str):
    """
    Build a single dict for Table A (one wide row).
    Uses RUN_META if available, plus model & LR schedule info.
    """
    rowA = {}

    # RUN_META: experiment, run_name, seed, domain counts, etc. (from Cell 2)
    if "RUN_META" in globals():
        rowA.update(RUN_META)

    # Ensure some basic fields exist
    rowA.setdefault("experiment", EXPERIMENT_NAME)
    rowA.setdefault("run_name",   RUN_NAME)          # logical name (e.g., C1_seed42)
    rowA.setdefault("seed",       int(GLOBAL_SEED))

    # physical folder trace
    rowA["run_dir"]      = str(RUN_ROOT.name)        # e.g., C1_seed42__20260112_103550
    rowA["run_dir_full"] = str(RUN_ROOT)             # full path
    rowA["results_root"] = str(TRAIN_ROOT)           # full path/root


    raw_key = str(model_name).lower()
    key = MODEL_KEY_ALIASES.get(raw_key, raw_key)

    rowA["model_key"]    = key
    rowA["model_prefix"] = MODEL_KEY_TO_PREFIX.get(key, key)

    # LR schedule summary (if those globals exist ‚Äì from later cells)
    init_lr         = globals().get("INIT_LR", None)
    decay_factor    = globals().get("LR_DECAY_FACTOR", None)
    lr_patience     = globals().get("LR_PATIENCE_EPOCHS", None)
    min_lr          = globals().get("MIN_LR", None)
    early_stop_pat  = globals().get("EARLY_STOP_PATIENCE", None)

    lr_parts = []
    if init_lr is not None:
        lr_parts.append(f"init={init_lr}")
    if decay_factor is not None and lr_patience is not None:
        lr_parts.append(f"/{decay_factor} after {lr_patience} no-improve")
    if min_lr is not None:
        lr_parts.append(f"min={min_lr}")
    if early_stop_pat is not None:
        lr_parts.append(f"early_stop={early_stop_pat}")

    rowA["lr_schedule"] = "; ".join(lr_parts) if lr_parts else ""
    return rowA

# ------------------------------------------------------------
# Internal: ensure Excel + CSV have Table A + Table B header
# ------------------------------------------------------------
def _ensure_log_tables_initialized(model_name: str):
    """
    Create <Prefix>_log.xlsx and <Prefix>_log.csv with:
      - Table A: 1 row at the top
      - Blank row
      - Table B: header row for epoch-wise metrics
    If files already exist, do nothing (we append only).
    """
    paths     = get_model_paths(model_name)
    csv_path  = paths["log_csv"]
    xlsx_path = paths["log_xlsx"]

    # If both files exist, assume structure is already created.
    if xlsx_path.exists() and csv_path.exists():
        return


    rowA = _build_tableA_row(model_name)
    dfA  = pd.DataFrame([rowA])

    # Excel: Table A + empty row + Table B header
    with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
        dfA.to_excel(writer, index=False, sheet_name="Log", startrow=0)
        start_row_B = len(dfA) + 2
        pd.DataFrame(columns=TABLE_B_COLUMNS).to_excel(
            writer, index=False, sheet_name="Log", startrow=start_row_B
        )

    # CSV: same structure
    with open(csv_path, "w", encoding="utf-8") as f:
        f.write(",".join(dfA.columns) + "\n")
        f.write(",".join(str(rowA.get(col, "")) for col in dfA.columns) + "\n")
        f.write("\n")
        f.write(",".join(TABLE_B_COLUMNS) + "\n")

# ------------------------------------------------------------
# Public: append one epoch row to Table B (both CSV and XLSX)
# ------------------------------------------------------------
def append_log_row(model_name: str, row: dict):
    """
    Append a single row (dict) to Table B of:
      <Prefix>_log.csv / <Prefix>_log.xlsx

    Notes:
      - Table A is written once (RUN_META, model, LR schedule); the files are
        created on first use via _ensure_log_tables_initialized(...).
      - Table B rows are appended under the header, in TABLE_B_COLUMNS order;
        keys of `row` that are not Table B columns are ignored.
      - The caller's `row` dict is NOT modified: defaults are filled into a copy.
      - Missing hard_* columns default to "" (train rows can leave them empty).
      - Missing meta columns default to the globals CE_MODE, ABL_FUSION,
        USE_BN, USE_GN, BASE_CH; ds_on defaults to UNET3PLUS_DEEP_SUPERVISION
        for "unet3plus", SSLITE_DEEP_SUPERVISION for "sslite_*", else None.
      - Losses, time and hard metrics are rounded to 5 decimals. "lr" is
        rounded to 5 significant digits instead, because fixed 5-decimal
        rounding turns small learning rates (e.g. 1e-6) into 0.0.

    Expected keys in 'row' (Table B):
      phase, eval_domain, epoch, model, lr, time_sec,
      loss_total, loss_ce, loss_dice,
      dice_domain (str), iou_domain (str),
      dice_classes (str), iou_classes (str),
      hard_dice_macro, hard_iou_macro, hard_dice_c*, hard_iou_c*,
      ce_mode, abl_fusion, use_bn, use_gn, base_ch, ds_on
    """
    def _rounded(value, spec):
        """Round via float formatting; leave empty / non-numeric values as they are."""
        if value is None or value == "":
            return value
        try:
            return float(format(float(value), spec))
        except Exception:
            # Deliberate best-effort: a value that cannot be converted is
            # logged unchanged rather than aborting the training loop.
            return value

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)

        _ensure_log_tables_initialized(model_name)

        paths     = get_model_paths(model_name)
        csv_path  = paths["log_csv"]
        xlsx_path = paths["log_xlsx"]

        # Shallow copy: the defaults below must not leak into the caller's dict.
        row = dict(row)

        # Fill hard keys (optional) if missing
        for k in HARD_KEYS:
            row.setdefault(k, "")

        # Fill meta from globals if not explicitly provided
        cfg = globals()
        row.setdefault("ce_mode",    cfg.get("CE_MODE", None))
        row.setdefault("abl_fusion", cfg.get("ABL_FUSION", None))
        row.setdefault("use_bn",     cfg.get("USE_BN", None))
        row.setdefault("use_gn",     cfg.get("USE_GN", None))
        row.setdefault("base_ch",    cfg.get("BASE_CH", None))

        # Deep-supervision default depends on the model family.
        raw_key = str(model_name).lower()
        key = MODEL_KEY_ALIASES.get(raw_key, raw_key)
        if key == "unet3plus":
            _ds_default = cfg.get("UNET3PLUS_DEEP_SUPERVISION", None)
        elif key.startswith("sslite_"):
            _ds_default = cfg.get("SSLITE_DEEP_SUPERVISION", None)
        else:
            _ds_default = None
        row.setdefault("ds_on", _ds_default)

        # Build rowB with consistent column order
        rowB = {col: row.get(col, None) for col in TABLE_B_COLUMNS}

        # Per-column rounding format ("g" = significant digits, "f" = decimals).
        round_specs = {"lr": ".5g"}
        for col in (
            "time_sec",
            "loss_total", "loss_ce", "loss_dice",
            # hard metrics
            "hard_dice_macro", "hard_iou_macro",
            "hard_dice_c0", "hard_dice_c1", "hard_dice_c2",
            "hard_iou_c0",  "hard_iou_c1",  "hard_iou_c2",
        ):
            round_specs[col] = ".5f"

        for col, spec in round_specs.items():
            if col in rowB:
                rowB[col] = _rounded(rowB[col], spec)

        # Append to CSV (Table B only)
        pd.DataFrame([rowB]).to_csv(csv_path, mode="a", header=False, index=False)

        # Append to XLSX
        wb = load_workbook(xlsx_path)
        ws = wb["Log"]
        ws.append([rowB.get(col) for col in TABLE_B_COLUMNS])
        wb.save(xlsx_path)

# ------------------------------------------------------------
# One-time, plain output structure print (easy to verify paths)
# ------------------------------------------------------------
print("\nüìå Results folders")
print("   RESULTS_ROOT:", TRAIN_ROOT)
print("   RUN_ROOT    :", RUN_ROOT)

root_name = TRAIN_ROOT.name
print("\nExpected files (per model) will appear like this:")
print(f"{root_name}/")
print(f"‚îî‚îÄ‚îÄ {RUN_ROOT.name}/")   # <-- FIX: show actual physical folder
for key, prefix in MODEL_KEY_TO_PREFIX.items():
    print(f"    ‚îú‚îÄ‚îÄ {prefix}_best_model.pth")
    print(f"    ‚îú‚îÄ‚îÄ {prefix}_last_model.pth")
    print(f"    ‚îú‚îÄ‚îÄ {prefix}_log.csv")
    print(f"    ‚îî‚îÄ‚îÄ {prefix}_log.xlsx")
    break

print("\n‚úÖ Cell 3 ready: get_model_paths(...) and append_log_row(...) are available.")


In [None]:
# ============================================================
# CELL 4: Model architectures (official-style + SS-UNet3+ Lite)
#   - UNet          (Ronneberger et al. 2015)  [NO BN/GN]
#   - UNet3Plus     (UNet 3+, Huang et al. 2020; full-scale fusion)
#   - SSUNet3PlusLite_V1  (Enhanced Semi-Full, 5 inputs/level)
#   - SSUNet3PlusLite_V2  (Updated Semi-Full, 4 inputs/level)
#   - SSUNet3PlusLite_V3  (Baseline, current setup)
#
# Upsampling control (GLOBAL, for ALL models):
#   - Set UPSAMPLE_POLICY in Cell 4.5A:
#       "nearest" | "bilinear"
#
# NOTE (fusion ablation for SS-Lite):
#   - ABL_FUSION = "native" | "unet_skip"
# ============================================================

!pip -q install thop
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------------------------------------
# Unified resize policy for ALL models (nearest | bilinear)
# ------------------------------------------------------------
if "UPSAMPLE_POLICY" not in globals():
    UPSAMPLE_POLICY = "nearest"  # safe default if Cell 4.5A not run


def _get_policy() -> str:
    return str(globals().get("UPSAMPLE_POLICY", "nearest")).lower()


def _interp(x: torch.Tensor, size, mode: str) -> torch.Tensor:
    mode = mode.lower()
    if mode in ("nearest", "area", "nearest-exact"):
        return F.interpolate(x, size=size, mode=mode)
    return F.interpolate(x, size=size, mode=mode, align_corners=False)

def _resize_like(x: torch.Tensor, ref: torch.Tensor, mode: str = None) -> torch.Tensor:
    """
    Resize x to match ref spatial size using GLOBAL policy:
      - nearest  : nearest interpolation
      - bilinear : bilinear interpolation
    """
    th, tw = ref.shape[-2:]
    h, w = x.shape[-2:]
    if (h, w) == (th, tw):
        return x

    policy = (mode or _get_policy()).lower()
    if policy not in ("nearest", "bilinear", "area", "nearest-exact"):
        policy = "nearest"

    return _interp(x, size=(th, tw), mode=policy)


def _resize_to(x: torch.Tensor, ref: torch.Tensor, mode: str = None) -> torch.Tensor:
    """Alias of _resize_like: forwards `x`, `ref` and `mode` unchanged and returns its result."""
    return _resize_like(x, ref, mode=mode)

# ------------------------------------------------------------
# Global config: deep supervision toggle for UNet3Plus
# ------------------------------------------------------------
# Keep a value set by an earlier cell; default to deep supervision ON.
UNET3PLUS_DEEP_SUPERVISION = globals().get("UNET3PLUS_DEEP_SUPERVISION", True)

# ------------------------------------------------------------
# Normalization factory (ablation-friendly)
# ------------------------------------------------------------
def make_norm(num_channels: int) -> nn.Module:
    """
    Build the 2-D normalisation layer for `num_channels` feature maps,
    driven by notebook globals (read at call time).

    Resolution order:
      1. NORM_LAYER_2D, then NORM_LAYER: if either is a callable it is used
         as a factory and its result returned.
      2. USE_GN (default False): GroupNorm with the largest group count that
         is <= GN_GROUPS (default 8), <= num_channels and divides
         num_channels; the count never drops below 1.
      3. USE_BN (default True): BatchNorm2d.
      4. Neither enabled: nn.Identity().

    USE_GN takes precedence over USE_BN when both are set.
    """
    cfg = globals()

    for factory_name in ("NORM_LAYER_2D", "NORM_LAYER"):
        factory = cfg.get(factory_name)
        if callable(factory):
            return factory(num_channels)

    use_gn = bool(cfg.get("USE_GN", False))
    use_bn = bool(cfg.get("USE_BN", True))

    if use_gn:
        groups = max(1, min(int(cfg.get("GN_GROUPS", 8)), num_channels))
        # Walk down to a divisor of num_channels. groups == 1 always divides,
        # so the search cannot fail and no BatchNorm fallback is needed.
        while groups > 1 and num_channels % groups != 0:
            groups -= 1
        return nn.GroupNorm(num_groups=groups, num_channels=num_channels)

    return nn.BatchNorm2d(num_channels) if use_bn else nn.Identity()

# ------------------------------------------------------------
# Basic building blocks (UNet3Plus, SS-Lite use these)
# ------------------------------------------------------------
class DoubleConv(nn.Module):
    """
    Two stacked stages of 3x3 Conv (padding 1, no bias) -> make_norm -> ReLU.

    The first stage maps in_ch -> out_ch, the second keeps out_ch. Layers live
    in `self.block` (an nn.Sequential) in that order.
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                make_norm(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.block(x)
        return out

class OutConv(nn.Module):
    """Output head: a single 1x1 convolution mapping in_ch feature maps to out_ch logits."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

# ------------------------------------------------------------
# UNet blocks (NO BN / NO GN) — official-style
# ------------------------------------------------------------
class DoubleConv_UNet(nn.Module):
    """
    Two stacked stages of 3x3 Conv (padding 1, with bias) -> ReLU, without
    any normalisation layer. First stage maps in_ch -> out_ch, second keeps out_ch.
    """
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=True))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.block(x)
        return out

class Down_UNet(nn.Module):
    """Encoder step: 2x2 max-pool with stride 2, followed by DoubleConv_UNet(in_ch, out_ch)."""
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv = DoubleConv_UNet(in_ch, out_ch)

    def forward(self, x):
        pooled = self.pool(x)
        return self.conv(pooled)

class Up_UNet(nn.Module):
    """
    Decoder step: resize the incoming features to the skip's spatial size
    (via _resize_like, i.e. the GLOBAL UPSAMPLE_POLICY), concatenate as
    [skip, resized] along channels, then apply DoubleConv_UNet.

    Args:
      in_ch   : channels of the incoming (deeper) feature map.
      skip_ch : channels of the skip connection.
      out_ch  : channels produced by the block.
    """
    def __init__(self, in_ch: int, skip_ch: int, out_ch: int):
        super().__init__()
        fused_ch = in_ch + skip_ch
        self.conv = DoubleConv_UNet(fused_ch, out_ch)

    def forward(self, x, skip):
        resized = _resize_like(x, skip)
        fused = torch.cat((skip, resized), dim=1)
        return self.conv(fused)

# ------------------------------------------------------------
# 1) UNet (Ronneberger et al. 2015) — NO BN / NO GN
# ------------------------------------------------------------
class UNet(nn.Module):
    """
    U-Net (Ronneberger et al. 2015 style), five resolution levels, built from
    the normalisation-free DoubleConv_UNet / Down_UNet / Up_UNet blocks.

    Channel widths double per level: base_ch, 2x, 4x, 8x, 16x. Upsampling is
    whatever Up_UNet does (controlled ONLY by UPSAMPLE_POLICY).

    Args:
      in_channels : channels of the input image.
      num_classes : number of output logit channels.
      base_ch     : width of the first level.
      **_ignored  : extra keyword arguments are accepted and discarded.
    """
    def __init__(self, in_channels: int = 1, num_classes: int = 3, base_ch: int = 32, **_ignored):
        super().__init__()
        w1, w2, w3, w4, w5 = (base_ch * (2 ** level) for level in range(5))

        # Encoder
        self.inc   = DoubleConv_UNet(in_channels, w1)
        self.down1 = Down_UNet(w1, w2)
        self.down2 = Down_UNet(w2, w3)
        self.down3 = Down_UNet(w3, w4)
        self.down4 = Down_UNet(w4, w5)

        # Decoder: each step consumes the skip of the matching encoder level
        self.up1 = Up_UNet(w5, w4, w4)
        self.up2 = Up_UNet(w4, w3, w3)
        self.up3 = Up_UNet(w3, w2, w2)
        self.up4 = Up_UNet(w2, w1, w1)

        self.outc = OutConv(w1, num_classes)

    def forward(self, x):
        # Encoder pass, remembering every level's output.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))

        # Decoder pass: start at the bottleneck, pair with skips deepest-first.
        out = features.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, features.pop())

        return self.outc(out)

# ------------------------------------------------------------
# 2) UNet3Plus (full-scale multi-scale fusion, official style)
# ------------------------------------------------------------
class UNet3Plus(nn.Module):
    """
    UNet 3+ with full-scale skip connections and optional deep supervision.

    Every decoder level concatenates 1x1-projected copies of ALL five encoder
    outputs plus every deeper decoder output, each resized to that level's
    resolution with `_resize_like` (defined outside this block).

    Args:
        in_channels:      channels of the input image.
        num_classes:      channels of every logit map.
        base_ch:          encoder width at level 0; also the per-branch width
                          after projection.
        deep_supervision: True/False, or None to read the notebook global
                          UNET3PLUS_DEEP_SUPERVISION (default True).

    Returns (forward):
        deep_supervision on  -> list [y0, y1, y2, y3, y4]; y0 is the level-0
                                head, y1..y4 are side heads resized to the
                                input's spatial size.
        deep_supervision off -> the level-0 logits tensor only.
    """
    def __init__(self, in_channels: int = 1, num_classes: int = 3, base_ch: int = 32, deep_supervision=None):
        super().__init__()
        if deep_supervision is None:
            deep_supervision = bool(globals().get("UNET3PLUS_DEEP_SUPERVISION", True))
        self.deep_supervision = deep_supervision

        w0, w1, w2, w3, w4 = (base_ch * mult for mult in (1, 2, 4, 8, 16))

        # Encoder (module creation order is kept: enc0, pool0, enc1, pool1, ...)
        self.enc0 = DoubleConv(in_channels, w0)
        self.pool0 = nn.MaxPool2d(2)
        self.enc1 = DoubleConv(w0, w1)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(w1, w2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(w2, w3)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(w3, w4)

        cat_ch = base_ch      # width of each projected branch
        dec_ch = base_ch * 5  # width of every decoder output

        # 1x1 projections of encoder features (shared by all decoder levels)
        self.proj_e0 = nn.Conv2d(w0, cat_ch, 1)
        self.proj_e1 = nn.Conv2d(w1, cat_ch, 1)
        self.proj_e2 = nn.Conv2d(w2, cat_ch, 1)
        self.proj_e3 = nn.Conv2d(w3, cat_ch, 1)
        self.proj_e4 = nn.Conv2d(w4, cat_ch, 1)

        # 1x1 projections of decoder features fed to shallower levels
        self.proj_d4 = nn.Conv2d(dec_ch, cat_ch, 1)
        self.proj_d3 = nn.Conv2d(dec_ch, cat_ch, 1)
        self.proj_d2 = nn.Conv2d(dec_ch, cat_ch, 1)
        self.proj_d1 = nn.Conv2d(dec_ch, cat_ch, 1)

        # Fusion convs: level k sees 5 encoder branches + (4 - k) decoder branches
        self.conv_d4 = DoubleConv(cat_ch * 5, dec_ch)
        self.conv_d3 = DoubleConv(cat_ch * 6, dec_ch)
        self.conv_d2 = DoubleConv(cat_ch * 7, dec_ch)
        self.conv_d1 = DoubleConv(cat_ch * 8, dec_ch)
        self.conv_d0 = DoubleConv(cat_ch * 9, dec_ch)

        # Heads: side heads exist only with deep supervision; out_d0 always, created last.
        if self.deep_supervision:
            self.out_d4 = nn.Conv2d(dec_ch, num_classes, 1)
            self.out_d3 = nn.Conv2d(dec_ch, num_classes, 1)
            self.out_d2 = nn.Conv2d(dec_ch, num_classes, 1)
            self.out_d1 = nn.Conv2d(dec_ch, num_classes, 1)
        self.out_d0 = nn.Conv2d(dec_ch, num_classes, 1)

    @staticmethod
    def _proj_resize(x, proj_layer, ref):
        """Apply `proj_layer` to `x`, then resize the result to match `ref`."""
        return _resize_like(proj_layer(x), ref)

    def forward(self, x):
        # Encoder path
        e0 = self.enc0(x)
        e1 = self.enc1(self.pool0(e0))
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        encoder_branches = (
            (e0, self.proj_e0),
            (e1, self.proj_e1),
            (e2, self.proj_e2),
            (e3, self.proj_e3),
            (e4, self.proj_e4),
        )

        def fuse(ref, fusion_conv, decoder_branches=()):
            # Concatenation order is fixed: e0..e4 first, then the decoder
            # branches in the order given. The fusion conv weights rely on it.
            parts = [
                self._proj_resize(feat, proj, ref)
                for feat, proj in (*encoder_branches, *decoder_branches)
            ]
            return fusion_conv(torch.cat(parts, dim=1))

        d4 = fuse(e4, self.conv_d4)
        d3 = fuse(e3, self.conv_d3, ((d4, self.proj_d4),))
        d2 = fuse(e2, self.conv_d2, ((d3, self.proj_d3), (d4, self.proj_d4)))
        d1 = fuse(e1, self.conv_d1, ((d2, self.proj_d2), (d3, self.proj_d3), (d4, self.proj_d4)))
        d0 = fuse(
            e0,
            self.conv_d0,
            ((d1, self.proj_d1), (d2, self.proj_d2), (d3, self.proj_d3), (d4, self.proj_d4)),
        )

        main_logits = self.out_d0(d0)
        if not self.deep_supervision:
            return main_logits

        side_logits = [self.out_d1(d1), self.out_d2(d2), self.out_d3(d3), self.out_d4(d4)]
        # Side heads are resized to the input's spatial size.
        return [main_logits] + [_resize_like(logits, x) for logits in side_logits]

# ------------------------------------------------------------
# 3) SS-UNet3+ Lite FAST (decoder variants V1-V3)
# ------------------------------------------------------------
class LiteBlock(nn.Module):
    """
    Lightweight conv block: 1x1 conv -> depthwise 3x3 conv -> 1x1 conv.

    Each convolution is bias-free and followed by `make_norm(...)` (defined
    outside this block) and a shared in-place ReLU.

    Args:
        in_ch:     input channels.
        out_ch:    output channels.
        expansion: hidden width as a multiple of `in_ch` (at least 1 channel).
    """
    def __init__(self, in_ch, out_ch, expansion=1.0):
        super().__init__()
        mid_ch = max(1, int(in_ch * expansion))

        # pointwise: in_ch -> mid_ch
        self.pw1 = nn.Conv2d(in_ch, mid_ch, 1, bias=False)
        self.bn1 = make_norm(mid_ch)
        # depthwise 3x3 (one group per channel), spatial size preserved
        self.dw = nn.Conv2d(mid_ch, mid_ch, 3, padding=1, groups=mid_ch, bias=False)
        self.bn2 = make_norm(mid_ch)
        # pointwise: mid_ch -> out_ch
        self.pw2 = nn.Conv2d(mid_ch, out_ch, 1, bias=False)
        self.bn3 = make_norm(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        stages = (
            (self.pw1, self.bn1),
            (self.dw, self.bn2),
            (self.pw2, self.bn3),
        )
        for conv, norm in stages:
            x = self.act(norm(conv(x)))
        return x

class ScaleSelectiveFusion(nn.Module):
    """
    Fuse several feature maps by averaging their projections.

    Each input gets its own bias-free 1x1 conv to `cat_ch` channels, a
    `make_norm` layer and a ReLU. Every branch after the first is resized to
    the first branch with `_resize_like`, then all branches are averaged
    element-wise.

    Args:
        in_ch_list: channel count of each input, in the order the inputs are
                    passed to `forward`.
        cat_ch:     channel count of every projected branch and of the output.
    """
    def __init__(self, in_ch_list, cat_ch):
        super().__init__()
        self.proj = nn.ModuleList([nn.Conv2d(ch, cat_ch, 1, bias=False) for ch in in_ch_list])
        self.bn = nn.ModuleList([make_norm(cat_ch) for _ in in_ch_list])

    def forward(self, features):
        branches = []
        for idx, (feat, conv, norm) in enumerate(zip(features, self.proj, self.bn)):
            branch = F.relu(norm(conv(feat)))
            if idx > 0:
                # The first projected branch defines the reference size.
                branch = _resize_like(branch, branches[0])
            branches.append(branch)
        stacked = torch.stack(branches, dim=0)
        return stacked.mean(dim=0)

class SSUNet3PlusLiteBase(nn.Module):
    """
    Shared parts of the SS-UNet3+ Lite variants: a five-level `LiteBlock`
    encoder, the output heads, and the deep-supervision logit fusion.
    Subclasses add the `fuse*`/`dec*` decoder modules and `forward`.

    Args:
        in_channels:      channels of the input image.
        num_classes:      channels of every logit map.
        base_ch:          encoder width at level 0 (deeper: x2, x4, x8, x16).
        deep_supervision: create and use the side heads out1..out4.
        ds_w0..ds_w4:     fusion weights for y0..y4; they are rescaled to sum
                          to 1, and replaced by (0.5, 0.15, 0.15, 0.10, 0.10)
                          when their sum is not positive.
    """
    def __init__(self, in_channels, num_classes, base_ch=32, deep_supervision=True,
                 ds_w0=0.5, ds_w1=0.15, ds_w2=0.15, ds_w3=0.10, ds_w4=0.10):
        super().__init__()
        self.deep_supervision = deep_supervision

        # Deep-supervision fusion weights, normalised to sum to 1.
        raw_weights = [float(v) for v in (ds_w0, ds_w1, ds_w2, ds_w3, ds_w4)]
        weight_total = sum(raw_weights)
        if weight_total <= 0:
            raw_weights = [0.5, 0.15, 0.15, 0.10, 0.10]
            weight_total = sum(raw_weights)
        self.ds_w0, self.ds_w1, self.ds_w2, self.ds_w3, self.ds_w4 = [
            v / weight_total for v in raw_weights
        ]

        # Encoder widths, also exposed to subclasses as self.f1..self.f5.
        self.f1 = base_ch
        self.f2 = base_ch * 2
        self.f3 = base_ch * 4
        self.f4 = base_ch * 8
        self.f5 = base_ch * 16

        # Encoder (creation order kept: enc0, pool0, enc1, pool1, ...)
        self.enc0 = LiteBlock(in_channels, self.f1, 1.0)
        self.pool0 = nn.MaxPool2d(2)
        self.enc1 = LiteBlock(self.f1, self.f2, 1.0)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = LiteBlock(self.f2, self.f3, 1.0)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = LiteBlock(self.f3, self.f4, 1.0)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = LiteBlock(self.f4, self.f5, 1.0)

        self.cat_ch = base_ch      # width of fused features
        self.dec_ch = base_ch * 4  # width of decoder outputs

        # Main head always exists; side heads only with deep supervision.
        self.out0 = nn.Conv2d(self.dec_ch, num_classes, 1)
        if self.deep_supervision:
            self.out1 = nn.Conv2d(self.dec_ch, num_classes, 1)
            self.out2 = nn.Conv2d(self.dec_ch, num_classes, 1)
            self.out3 = nn.Conv2d(self.dec_ch, num_classes, 1)
            self.out4 = nn.Conv2d(self.dec_ch, num_classes, 1)

    def _encode(self, x):
        """Run the encoder and return the tuple (e0, e1, e2, e3, e4)."""
        feats = [self.enc0(x)]
        for pool, enc in (
            (self.pool0, self.enc1),
            (self.pool1, self.enc2),
            (self.pool2, self.enc3),
            (self.pool3, self.enc4),
        ):
            feats.append(enc(pool(feats[-1])))
        return tuple(feats)

    def _fuse_logits(self, x, d0, d1, d2, d3, d4):
        """
        Turn the decoder features d0..d4 into the model output.

        - Deep supervision off: the main logits y0 = out0(d0).
        - Deep supervision on, SSLITE_DS_OUTPUT == "list" (case-insensitive):
          the list [y0, y1, y2, y3, y4].
        - Deep supervision on, any other SSLITE_DS_OUTPUT (default "fused"):
          the weighted sum ds_w0*y0 + ... + ds_w4*y4.

        Side features d1..d4 are resized to d0 with `_resize_to` BEFORE their
        head is applied. `x` is accepted for call-site symmetry and not used.
        """
        main = self.out0(d0)
        if not self.deep_supervision:
            return main

        side_heads = ((self.out1, d1), (self.out2, d2), (self.out3, d3), (self.out4, d4))
        outputs = [main] + [head(_resize_to(feat, d0)) for head, feat in side_heads]

        if str(globals().get("SSLITE_DS_OUTPUT", "fused")).lower() == "list":
            return outputs

        weights = (self.ds_w0, self.ds_w1, self.ds_w2, self.ds_w3, self.ds_w4)
        # Accumulate left to right, y0 first, exactly like a chained sum.
        fused = weights[0] * outputs[0]
        for weight, logits in zip(weights[1:], outputs[1:]):
            fused = fused + weight * logits
        return fused

# ---------------- Variant V1: Enhanced Semi-Full (5 inputs each) ------------
class SSUNet3PlusLite_V1(SSUNet3PlusLiteBase):
    """
    Variant V1, "Enhanced Semi-Full": every decoder level fuses five inputs.

    Inputs per level (in fusion order):
        d4: e0, e1, e2, e3, e4
        d3: e1, e2, e3, e4, d4
        d2: e0, e1, e2, e3, d3
        d1: e0, e1, e2, d2, d3
        d0: e0, e1, d1, d2, d3

    With the global ABL_FUSION == "unet_skip" (case-insensitive) only the
    same-level encoder feature and the previous decoder output are kept; the
    other inputs are replaced by zero tensors before fusion.
    """
    def __init__(self, in_channels, num_classes, base_ch=32, deep_supervision=True,
                 ds_w0=0.5, ds_w1=0.15, ds_w2=0.15, ds_w3=0.10, ds_w4=0.10):
        super().__init__(in_channels, num_classes, base_ch, deep_supervision, ds_w0, ds_w1, ds_w2, ds_w3, ds_w4)
        f1, f2, f3, f4, f5 = self.f1, self.f2, self.f3, self.f4, self.f5
        cat_ch, dec_ch = self.cat_ch, self.dec_ch

        # Creation order kept: fuse4, dec4, fuse3, dec3, ..., fuse0, dec0.
        self.fuse4 = ScaleSelectiveFusion([f1, f2, f3, f4, f5], cat_ch)
        self.dec4 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse3 = ScaleSelectiveFusion([f2, f3, f4, f5, dec_ch], cat_ch)
        self.dec3 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse2 = ScaleSelectiveFusion([f1, f2, f3, f4, dec_ch], cat_ch)
        self.dec2 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse1 = ScaleSelectiveFusion([f1, f2, f3, dec_ch, dec_ch], cat_ch)
        self.dec1 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse0 = ScaleSelectiveFusion([f1, f2, dec_ch, dec_ch, dec_ch], cat_ch)
        self.dec0 = LiteBlock(cat_ch, dec_ch, 1.0)

    def forward(self, x):
        e0, e1, e2, e3, e4 = self._encode(x)
        skip_only = str(globals().get("ABL_FUSION", "native")).lower() == "unet_skip"

        def stage(fuse, dec, ref, sources, kept):
            # Resize every source to this level, optionally blank the unused ones.
            feats = [_resize_to(src, ref) for src in sources]
            if skip_only:
                feats = [f if pos in kept else torch.zeros_like(f) for pos, f in enumerate(feats)]
            return dec(fuse(feats))

        d4 = stage(self.fuse4, self.dec4, e4, (e0, e1, e2, e3, e4), (4,))
        d3 = stage(self.fuse3, self.dec3, e3, (e1, e2, e3, e4, d4), (2, 4))
        d2 = stage(self.fuse2, self.dec2, e2, (e0, e1, e2, e3, d3), (2, 4))
        d1 = stage(self.fuse1, self.dec1, e1, (e0, e1, e2, d2, d3), (1, 3))
        d0 = stage(self.fuse0, self.dec0, e0, (e0, e1, d1, d2, d3), (0, 2))

        return self._fuse_logits(x, d0, d1, d2, d3, d4)

# ---------------- Variant V2: Updated Semi-Full (4 inputs each) ------------
class SSUNet3PlusLite_V2(SSUNet3PlusLiteBase):
    """
    Variant V2, "Updated Semi-Full": every decoder level fuses four inputs.

    Inputs per level (in fusion order):
        d4: e1, e2, e3, e4
        d3: e1, e2, e3, d4
        d2: e0, e1, e2, d3
        d1: e0, e1, d2, d3
        d0: e0, e1, d1, d2

    With the global ABL_FUSION == "unet_skip" (case-insensitive) only the
    same-level encoder feature and the previous decoder output are kept; the
    other inputs are replaced by zero tensors before fusion.
    """
    def __init__(self, in_channels, num_classes, base_ch=32, deep_supervision=True,
                 ds_w0=0.5, ds_w1=0.15, ds_w2=0.15, ds_w3=0.10, ds_w4=0.10):
        super().__init__(in_channels, num_classes, base_ch, deep_supervision, ds_w0, ds_w1, ds_w2, ds_w3, ds_w4)
        f1, f2, f3, f4, f5 = self.f1, self.f2, self.f3, self.f4, self.f5
        cat_ch, dec_ch = self.cat_ch, self.dec_ch

        # Creation order kept: fuse4, dec4, fuse3, dec3, ..., fuse0, dec0.
        self.fuse4 = ScaleSelectiveFusion([f2, f3, f4, f5], cat_ch)
        self.dec4 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse3 = ScaleSelectiveFusion([f2, f3, f4, dec_ch], cat_ch)
        self.dec3 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse2 = ScaleSelectiveFusion([f1, f2, f3, dec_ch], cat_ch)
        self.dec2 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse1 = ScaleSelectiveFusion([f1, f2, dec_ch, dec_ch], cat_ch)
        self.dec1 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse0 = ScaleSelectiveFusion([f1, f2, dec_ch, dec_ch], cat_ch)
        self.dec0 = LiteBlock(cat_ch, dec_ch, 1.0)

    def forward(self, x):
        e0, e1, e2, e3, e4 = self._encode(x)
        skip_only = str(globals().get("ABL_FUSION", "native")).lower() == "unet_skip"

        def stage(fuse, dec, ref, sources, kept):
            # Resize every source to this level, optionally blank the unused ones.
            feats = [_resize_to(src, ref) for src in sources]
            if skip_only:
                feats = [f if pos in kept else torch.zeros_like(f) for pos, f in enumerate(feats)]
            return dec(fuse(feats))

        d4 = stage(self.fuse4, self.dec4, e4, (e1, e2, e3, e4), (3,))
        d3 = stage(self.fuse3, self.dec3, e3, (e1, e2, e3, d4), (2, 3))
        d2 = stage(self.fuse2, self.dec2, e2, (e0, e1, e2, d3), (2, 3))
        d1 = stage(self.fuse1, self.dec1, e1, (e0, e1, d2, d3), (1, 2))
        d0 = stage(self.fuse0, self.dec0, e0, (e0, e1, d1, d2), (0, 2))

        return self._fuse_logits(x, d0, d1, d2, d3, d4)

# ---------------- Variant V3: Baseline (current setup) ----------------------
class SSUNet3PlusLite_V3(SSUNet3PlusLiteBase):
    """
    Variant V3, the baseline decoder: three or four inputs per level.

    Inputs per level (in fusion order):
        d4: e2, e3, e4
        d3: e1, e2, e3, d4
        d2: e0, e1, e2, d3
        d1: e0, e1, d2
        d0: e0, e1, d1, d2

    With the global ABL_FUSION == "unet_skip" (case-insensitive) only the
    same-level encoder feature and the previous decoder output are kept; the
    other inputs are replaced by zero tensors before fusion.
    """
    def __init__(self, in_channels, num_classes, base_ch=32, deep_supervision=True,
                 ds_w0=0.5, ds_w1=0.15, ds_w2=0.15, ds_w3=0.10, ds_w4=0.10):
        super().__init__(in_channels, num_classes, base_ch, deep_supervision, ds_w0, ds_w1, ds_w2, ds_w3, ds_w4)
        f1, f2, f3, f4, f5 = self.f1, self.f2, self.f3, self.f4, self.f5
        cat_ch, dec_ch = self.cat_ch, self.dec_ch

        # Creation order kept: fuse4, dec4, fuse3, dec3, ..., fuse0, dec0.
        self.fuse4 = ScaleSelectiveFusion([f3, f4, f5], cat_ch)
        self.dec4 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse3 = ScaleSelectiveFusion([f2, f3, f4, dec_ch], cat_ch)
        self.dec3 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse2 = ScaleSelectiveFusion([f1, f2, f3, dec_ch], cat_ch)
        self.dec2 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse1 = ScaleSelectiveFusion([f1, f2, dec_ch], cat_ch)
        self.dec1 = LiteBlock(cat_ch, dec_ch, 1.0)
        self.fuse0 = ScaleSelectiveFusion([f1, f2, dec_ch, dec_ch], cat_ch)
        self.dec0 = LiteBlock(cat_ch, dec_ch, 1.0)

    def forward(self, x):
        e0, e1, e2, e3, e4 = self._encode(x)
        skip_only = str(globals().get("ABL_FUSION", "native")).lower() == "unet_skip"

        def stage(fuse, dec, ref, sources, kept):
            # Resize every source to this level, optionally blank the unused ones.
            feats = [_resize_to(src, ref) for src in sources]
            if skip_only:
                feats = [f if pos in kept else torch.zeros_like(f) for pos, f in enumerate(feats)]
            return dec(fuse(feats))

        d4 = stage(self.fuse4, self.dec4, e4, (e2, e3, e4), (2,))
        d3 = stage(self.fuse3, self.dec3, e3, (e1, e2, e3, d4), (2, 3))
        d2 = stage(self.fuse2, self.dec2, e2, (e0, e1, e2, d3), (2, 3))
        d1 = stage(self.fuse1, self.dec1, e1, (e0, e1, d2), (1, 2))
        d0 = stage(self.fuse0, self.dec0, e0, (e0, e1, d1, d2), (0, 2))

        return self._fuse_logits(x, d0, d1, d2, d3, d4)

# ------------------------------------------------------------
# Helper factories for Lite models (V1-V3 only)
# ------------------------------------------------------------
def _create_sslite(model_cls, num_classes: int):
    """
    Build one SS-UNet3+ Lite variant from the notebook-level configuration.

    Reads these globals at call time, with the listed fallbacks:
        IN_CHANNELS (1), BASE_CH (32), SSLITE_DEEP_SUPERVISION (True),
        DS_WEIGHTS ([0.5, 0.15, 0.15, 0.10, 0.10]).

    DS_WEIGHTS that is not a list/tuple is replaced by the fallback; shorter
    sequences are padded with zeros and longer ones truncated to five values.

    Args:
        model_cls:   the variant class to instantiate.
        num_classes: number of output classes.

    Returns:
        A freshly constructed `model_cls` instance.
    """
    default_weights = [0.5, 0.15, 0.15, 0.10, 0.10]

    in_ch = globals().get("IN_CHANNELS", 1)
    base_ch_global = globals().get("BASE_CH", 32)
    ds_on = bool(globals().get("SSLITE_DEEP_SUPERVISION", True))

    weights = globals().get("DS_WEIGHTS", default_weights)
    if not isinstance(weights, (list, tuple)):
        weights = default_weights
    # Exactly five weights: pad with zeros, drop any extras.
    weights = (list(weights) + [0, 0, 0, 0, 0])[:5]

    return model_cls(in_ch, num_classes, base_ch_global, ds_on, *weights)


def create_sslite_v1(num_classes: int):
    """Create `SSUNet3PlusLite_V1` configured from the notebook globals."""
    return _create_sslite(SSUNet3PlusLite_V1, num_classes)


def create_sslite_v2(num_classes: int):
    """Create `SSUNet3PlusLite_V2` configured from the notebook globals."""
    return _create_sslite(SSUNet3PlusLite_V2, num_classes)


def create_sslite_v3(num_classes: int):
    """Create `SSUNet3PlusLite_V3` configured from the notebook globals."""
    return _create_sslite(SSUNet3PlusLite_V3, num_classes)

# ------------------------------------------------------------
# Helper: parameter counting + FLOPs/MACs summary
# ------------------------------------------------------------
def count_parameters(model: nn.Module) -> int:
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class _ThopSafeWrapper(nn.Module):
    """
    Adapter that makes a model always return a single value.

    If the wrapped model returns a list or tuple (for example deep-supervision
    outputs), only its first element is returned; any other output is passed
    through unchanged.
    """
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        if isinstance(out, (list, tuple)):
            return out[0]
        return out

def _profile_macs_flops(model: nn.Module, x: torch.Tensor):
    """
    Measure multiply-accumulate operations of one forward pass with `thop`.

    The model is profiled in eval mode under `torch.no_grad()`, wrapped in
    `_ThopSafeWrapper` so list/tuple outputs do not reach the profiler.
    The train/eval flag of every submodule is restored afterwards, so
    profiling does not leave the caller's model switched to eval mode.

    Args:
        model: network to profile.
        x:     example input passed to the model.

    Returns:
        (macs, flops) as floats, where flops is reported as 2 * macs.

    Raises:
        ImportError: if the `thop` package is not installed.
    """
    from thop import profile

    # Remember each module's mode: model.eval() overwrites all of them.
    prior_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        safe_model = _ThopSafeWrapper(model)
        with torch.no_grad():
            macs, _ = profile(safe_model, inputs=(x,), verbose=False)
    finally:
        for module, was_training in prior_modes:
            module.training = was_training

    macs = float(macs)
    flops = 2.0 * macs
    return macs, flops

def print_model_param_flops_summary(in_channels=1, num_classes=3, base_ch=32, input_hw=256, device=None):
    """
    Print trainable-parameter counts and FLOPs/MACs for every model variant.

    One instance each of UNet, UNet3Plus and the three SS-Lite variants is
    built, moved to `device` and profiled on a random input of shape
    (1, in_channels, input_hw, input_hw).

    If the optional `thop` profiler cannot be imported, the table is still
    printed with parameter counts only, so this diagnostic does not abort
    the cell that defines the models.

    Args:
        in_channels: channels of the dummy input and of every model.
        num_classes: output classes of every model.
        base_ch:     base width of every model.
        input_hw:    height and width of the dummy input.
        device:      torch device; defaults to "cuda" when available, else "cpu".
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    sslite_ds = bool(globals().get("SSLITE_DEEP_SUPERVISION", True))

    models = dict(
        UNet=UNet(in_channels=in_channels, num_classes=num_classes, base_ch=base_ch),
        UNet3Plus=UNet3Plus(in_channels=in_channels, num_classes=num_classes, base_ch=base_ch),
        SSLite_V1_EnhSemiFull=SSUNet3PlusLite_V1(in_channels, num_classes, base_ch, deep_supervision=sslite_ds),
        SSLite_V2_4Branch=SSUNet3PlusLite_V2(in_channels, num_classes, base_ch, deep_supervision=sslite_ds),
        SSLite_V3_Baseline=SSUNet3PlusLite_V3(in_channels, num_classes, base_ch, deep_supervision=sslite_ds),
    )

    x = torch.randn(1, in_channels, input_hw, input_hw, device=device)

    print("\nüìä Params + FLOPs (trainable):")
    print(f"   Config: in_channels={in_channels}, num_classes={num_classes}, base_ch={base_ch}, input=1x{in_channels}x{input_hw}x{input_hw}")
    print(f"   UPSAMPLE_POLICY = '{_get_policy()}' (nearest | bilinear)")
    print(f"   UNET3PLUS_DEEP_SUPERVISION = {globals().get('UNET3PLUS_DEEP_SUPERVISION', True)}")
    print(f"   SSLITE_DS_OUTPUT = '{globals().get('SSLITE_DS_OUTPUT','fused')}'")
    print("")

    profiler_available = True
    for name, model in models.items():
        model = model.to(device)
        n_params = count_parameters(model)

        macs = flops = None
        if profiler_available:
            try:
                macs, flops = _profile_macs_flops(model, x)
            except ImportError as exc:
                # thop is optional: report it once, then keep printing parameter counts.
                profiler_available = False
                print(f"   (FLOPs/MACs skipped: {exc})")

        if macs is None:
            print(f"   {name:<26}: {n_params:10,d} params (~{n_params/1e6:6.3f} M) | FLOPs n/a | MACs n/a")
        else:
            print(f"   {name:<26}: {n_params:10,d} params (~{n_params/1e6:6.3f} M) | FLOPs ~ {flops/1e9:7.3f} G | MACs {macs/1e9:7.3f} G")

print("‚úÖ Models defined: UNet, UNet3Plus, SSUNet3PlusLite (V1‚ÄìV3)")
# Summarise the models at the notebook's configured sizes rather than fixed
# literals. The fallbacks (1 / 3 / 32) are the previous hard-coded values, so
# the output is unchanged unless the config globals say otherwise. BASE_CH is
# set in the ablation cell, so it only takes effect once that cell has run.
print_model_param_flops_summary(
    in_channels=globals().get("IN_CHANNELS", 1),
    num_classes=globals().get("OUT_CLASSES", 3),
    base_ch=globals().get("BASE_CH", 32),
    input_hw=256,
)


## Ablation

In [None]:
# ============================================================
# CELL 4.5A: Ablation / Config switchboard (toggles only)
#
# Every value below is read by the model / loss code through globals() at
# call time, so re-run this cell after editing a toggle.
# ============================================================

def _check_choice(name, value, allowed):
    """
    Return `value` if it matches one of `allowed` case-insensitively.

    The readers of these toggles compare lower-cased strings and silently use
    their default branch for anything unrecognised, so a typo would run the
    wrong experiment without any error. Fail here instead.

    Raises:
        ValueError: when `value` is not one of `allowed`.
    """
    if str(value).lower() not in allowed:
        raise ValueError(f"{name}={value!r} is not supported; expected one of {list(allowed)}")
    return value


# ---------------------------
# Upsampling policy (GLOBAL, for ALL models)
# ---------------------------
#   "nearest"  : F.interpolate(..., mode="nearest")
#   "bilinear" : F.interpolate(..., mode="bilinear", align_corners=False)
#

UPSAMPLE_POLICY = "nearest"  # "nearest" | "bilinear"


# ---------------------------
# UNet3Plus deep supervision output
# ---------------------------
UNET3PLUS_DEEP_SUPERVISION = True  # True -> returns [y0,y1,y2,y3,y4], False -> returns y0 only

# ---------------------------
# SS-UNet3+ Lite deep supervision
# ---------------------------
SSLITE_DEEP_SUPERVISION = True

# What SS-Lite returns when DS is ON:
#   "fused" -> returns fused logits tensor (default behavior in your Lite models)
#   "list"  -> returns [y0,y1,y2,y3,y4] like UNet3Plus
SSLITE_DS_OUTPUT = "fused"  # "fused" | "list"

# Deep supervision weights for SS-Lite (used when SSLITE_DS_OUTPUT="fused")
DS_WEIGHTS = [0.8, 0.05, 0.05, 0.05, 0.05]  # [y0, y1, y2, y3, y4]

# Normalize DS weights safely: negatives or a non-positive sum fall back to
# the default weights; otherwise the weights are rescaled to sum to 1.
_ds_sum = float(sum(DS_WEIGHTS))
_ds_has_neg = any(float(w) < 0.0 for w in DS_WEIGHTS)
if _ds_has_neg or _ds_sum <= 0.0:
    print("‚ö†Ô∏è  Warning: DS_WEIGHTS invalid (negatives or sum<=0). Using safe defaults.")
    DS_WEIGHTS = [0.5, 0.15, 0.15, 0.10, 0.10]
else:
    DS_WEIGHTS = [float(w) / _ds_sum for w in DS_WEIGHTS]

# ---------------------------
# Normalization ablation (affects UNet3Plus + SS-Lite blocks that use make_norm)
# ---------------------------
# - USE_GN=True  -> GroupNorm
# - USE_GN=False + USE_BN=True  -> BatchNorm
# - USE_GN=False + USE_BN=False -> NO normalization (Identity)
USE_GN    = False
USE_BN    = True
GN_GROUPS = 16

# ---------------------------
# Width ablation
# ---------------------------
BASE_CH = 32  # e.g., 32 or 64

# ---------------------------
# Fusion ablation (SS-UNet3+ Lite V1-V3 only)
# ---------------------------
#   "native"    -> original version-specific fusion (default)
#   "unet_skip" -> UNet-like same-level skip only (zero out unused inputs)
ABL_FUSION = "native"  # "native" | "unet_skip"

# ---------------------------
# Loss config (if you use hybrid loss elsewhere)
# ---------------------------
CE_WEIGHT   = 0.5
DICE_WEIGHT = 0.5

# ---------------------------
# Validate string toggles (a typo would otherwise silently select the default)
# ---------------------------
_check_choice("SSLITE_DS_OUTPUT", SSLITE_DS_OUTPUT, ("fused", "list"))
_check_choice("ABL_FUSION", ABL_FUSION, ("native", "unet_skip"))


# ---------------------------
# Quick print
# ---------------------------
print("\n‚úÖ Ablation / config summary")
print(f"   UPSAMPLE_POLICY           = '{UPSAMPLE_POLICY}'  (nearest | bilinear)")
print(f"   BASE_CH                   = {BASE_CH}")
print(f"   USE_GN                    = {USE_GN} (GN_GROUPS={GN_GROUPS})")
print(f"   USE_BN                    = {USE_BN}")
print(f"   UNET3PLUS_DEEP_SUPERVISION= {UNET3PLUS_DEEP_SUPERVISION}")
print(f"   SSLITE_DEEP_SUPERVISION   = {SSLITE_DEEP_SUPERVISION}")
print(f"   SSLITE_DS_OUTPUT          = '{SSLITE_DS_OUTPUT}'  (fused | list)")
print(f"   DS_WEIGHTS (y0..y4)       = {DS_WEIGHTS}")
print(f"   ABL_FUSION                = '{ABL_FUSION}'")
print(f"   CE_WEIGHT / DICE_WEIGHT   = {CE_WEIGHT} / {DICE_WEIGHT}")


In [None]:
# ============================================================
# CELL 4.5B: Training helpers (metrics, epoch runner, printing)
#   - merge_logits: handle deep supervision lists
#   - Deterministic CE for segmentation (avoids CUDA nll_loss2d kernel)
#   - SOFT Dice/IoU sums (for dice loss + soft reporting)
#   - HARD Dice/IoU from tp/fp/fn (for tqdm + early stopping + reporting)
#   - _unpack_batch: support dict / tuple batches
#   - _run_epoch_core: shared TRAIN/VAL one-epoch runner
#   - run_epoch_train / run_epoch_val (wrappers + backward-compatible aliases)
#   - format_domain_metrics, format_class_metrics, print_domain_stats
#
# Returns dict keys expected by Cell 5 training loop:
#   time_sec, loss_total, loss_ce, loss_dice,
#   dice_macro, iou_macro, dice_per_class, iou_per_class,          (SOFT)
#   hard_dice_macro, hard_iou_macro, hard_dice_per_class, hard_iou_per_class,
#   domain_stats (optional): per-domain tp/fp/fn tensors
# ============================================================

import time
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm


# ------------------------------------------------------------
# Deep supervision helper
# ------------------------------------------------------------
def merge_logits(logits_any):
    """
    Reduce a model output to the single logits tensor used for metrics.

    Deep-supervision outputs arrive as a list/tuple [y0, y1, ...] in which y0
    is the main full-resolution head, so the FIRST element is returned (not
    the last one). Anything that is not a list/tuple is returned unchanged.
    """
    is_multi_head = isinstance(logits_any, (list, tuple))
    return logits_any[0] if is_multi_head else logits_any


def _ensure_logits_size(logits: torch.Tensor, y: torch.Tensor):
    """
    Return `logits` resized so its last two dimensions match those of `y`.

    When the sizes already agree the input tensor itself is returned.
    Otherwise the interpolation mode comes from the global UPSAMPLE_POLICY
    (lower-cased): "bilinear" uses align_corners=False, "nearest-exact" and
    "area" are passed through, and any other value means "nearest".
    """
    target_hw = y.shape[-2:]
    if logits.shape[-2:] == target_hw:
        return logits

    policy = str(globals().get("UPSAMPLE_POLICY", "nearest")).lower()
    if policy == "bilinear":
        return F.interpolate(logits, size=target_hw, mode="bilinear", align_corners=False)
    if policy not in ("nearest-exact", "area"):
        policy = "nearest"  # also the fallback for unknown values
    return F.interpolate(logits, size=target_hw, mode=policy)



# ------------------------------------------------------------
# Deterministic Cross Entropy (segmentation)
#   - Avoids CUDA nll_loss2d_forward_out_* which may be nondeterministic
#     under torch.use_deterministic_algorithms(True)
# ------------------------------------------------------------
def cross_entropy_det_2d(
    logits: torch.Tensor,
    targets: torch.Tensor,
    weight: torch.Tensor = None,
    ignore_index: int = -100,
    label_smoothing: float = 0.0,
    reduction: str = "mean",
):
    """
    Deterministic replacement for F.cross_entropy for segmentation.

    Built from log_softmax + gather so it does not go through the CUDA
    nll_loss2d kernel. Results follow F.cross_entropy, including:
      * with `weight` and reduction="mean", the sum is divided by the summed
        weights of the non-ignored target pixels (not by the pixel count);
      * with `weight` and `label_smoothing`, the smoothing term is
        class-weighted as well.
    One intentional difference: if every pixel is ignored, reduction="mean"
    returns 0 (still attached to the autograd graph) instead of NaN.

    Args:
        logits:          (B, C, H, W) raw scores.
        targets:         (B, H, W) class indices.
        weight:          optional per-class weights of length C.
        ignore_index:    target value excluded from the loss; None disables it.
        label_smoothing: smoothing factor in [0, 1).
        reduction:       "mean" | "sum" | "none".

    Returns:
        A scalar for "mean"/"sum"; a (B, H, W) tensor for "none" with zeros at
        ignored pixels.

    Raises:
        ValueError: for an unknown `reduction`.
    """
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Unknown reduction: {reduction}")

    logits = logits.float()
    targets = targets.long()

    logp = F.log_softmax(logits, dim=1)  # (B,C,H,W)

    if ignore_index is None:
        valid = torch.ones_like(targets, dtype=torch.bool)
    else:
        valid = (targets != ignore_index)

    # gather needs in-range indices; ignored pixels point at class 0 and are
    # masked out again before any reduction
    gather_t = targets.masked_fill(~valid, 0)

    loss = -logp.gather(1, gather_t.unsqueeze(1)).squeeze(1)  # (B,H,W)

    class_w = None
    pixel_w = None
    if weight is not None:
        class_w = weight.to(logits.device, dtype=logits.dtype)
        pixel_w = class_w[gather_t]  # weight of each pixel's target class
        loss = loss * pixel_w

    if label_smoothing and label_smoothing > 0.0:
        if class_w is None:
            smooth = -logp.mean(dim=1)  # (B,H,W)
        else:
            # class-weighted average of -log p over the class axis
            w_shape = [1, -1] + [1] * (logp.dim() - 2)
            smooth = -(logp * class_w.view(w_shape)).sum(dim=1) / logp.shape[1]
        loss = (1.0 - label_smoothing) * loss + label_smoothing * smooth

    if reduction == "none":
        return loss.masked_fill(~valid, 0.0)

    kept = loss[valid]
    if reduction == "sum":
        return kept.sum()

    # reduction == "mean"
    if kept.numel() == 0:
        # sum of an empty selection is 0 and keeps a grad_fn, so backward() works
        return kept.sum()
    if pixel_w is None:
        return kept.mean()
    # weighted mean: normalise by the total weight of the pixels that count
    return kept.sum() / pixel_w[valid].sum()


# ------------------------------------------------------------
# SOFT Dice/IoU (for LOSS + optional soft reporting)
# ------------------------------------------------------------
def _soft_sums_from_logits(logits: torch.Tensor, targets: torch.Tensor, num_classes: int):
    """
    Accumulate the per-class sums needed for soft Dice and soft IoU.

    The logits are first resized to the target size by `_ensure_logits_size`,
    turned into probabilities with a softmax over dim 1, and compared with the
    one-hot encoded targets. Sums run over the batch and both spatial axes.

    Returns:
        (inter, dice_denom, iou_union), each of shape (num_classes,):
          inter      = sum(p * t)
          dice_denom = sum(p + t)
          iou_union  = sum(p + t - p * t)
    """
    logits = _ensure_logits_size(logits, targets)
    probs = torch.softmax(logits, dim=1)  # (B,C,H,W)

    # one_hot puts the class axis last; move it to dim 1 to line up with probs
    target_1h = F.one_hot(targets.long(), num_classes=num_classes)
    target_1h = target_1h.permute(0, 3, 1, 2).float()

    reduce_over = (0, 2, 3)
    inter = (probs * target_1h).sum(reduce_over)
    dice_denom = (probs + target_1h).sum(reduce_over)
    iou_union = (probs + target_1h - probs * target_1h).sum(reduce_over)
    return inter, dice_denom, iou_union


def _soft_metrics_from_sums(inter, dice_denom, iou_union, eps=1e-6):
    """
    Convert accumulated per-class sums into soft Dice / IoU scores.

    `eps` is added to numerator and denominator, so a class with all-zero
    sums scores 1.

    Returns:
        (dice_per_class, dice_macro, iou_per_class, iou_macro), where the
        macro values are the means of the per-class values.
    """
    dice_per_class = (2.0 * inter + eps) / (dice_denom + eps)
    iou_per_class = (inter + eps) / (iou_union + eps)
    return dice_per_class, dice_per_class.mean(), iou_per_class, iou_per_class.mean()


# Backward-compatible helper name used in older parts/comments
def _metric_from_sums(inter, dice_denom, iou_union, eps=1e-6):
    """
    Backward-compatible helper: per-class soft Dice and IoU only.

    Returns:
        (dice_per_class, iou_per_class) computed from the accumulated sums,
        with `eps` added to every numerator and denominator.
    """
    dice_num = 2.0 * inter + eps
    iou_num = inter + eps
    return dice_num / (dice_denom + eps), iou_num / (iou_union + eps)


# ------------------------------------------------------------
# HARD metrics (tp/fp/fn) for early stopping + tqdm
# ------------------------------------------------------------
@torch.no_grad()
def _hard_counts_from_logits(logits: torch.Tensor, targets: torch.Tensor, num_classes: int):
    """
    Count per-class true positives, false positives and false negatives.

    The logits are resized to the target size by `_ensure_logits_size`, the
    predicted label is the argmax over dim 1, and each class is counted
    one-vs-rest against `targets`.

    Returns:
        (tp, fp, fn): three float32 tensors of shape (num_classes,) on the
        logits' device.
    """
    logits = _ensure_logits_size(logits, targets)
    pred = logits.argmax(dim=1)  # (B,H,W)

    alloc = dict(device=logits.device, dtype=torch.float32)
    tp = torch.zeros(num_classes, **alloc)
    fp = torch.zeros(num_classes, **alloc)
    fn = torch.zeros(num_classes, **alloc)

    for cls in range(num_classes):
        is_pred = pred == cls
        is_true = targets == cls
        tp[cls] += torch.logical_and(is_pred, is_true).sum().float()
        fp[cls] += torch.logical_and(is_pred, torch.logical_not(is_true)).sum().float()
        fn[cls] += torch.logical_and(torch.logical_not(is_pred), is_true).sum().float()

    return tp, fp, fn


def _hard_metrics_from_counts(tp, fp, fn, eps=1e-6):
    """
    Hard Dice / IoU from per-class tp / fp / fn counts.

    Kept as a standalone function because the training-loop cell calls it
    directly when building per-domain strings. `eps` is added to numerator
    and denominator, so a class with no counts scores 1.

    Returns:
        (dice_per_class, dice_macro, iou_per_class, iou_macro)
    """
    errors = fp + fn
    dice_per_class = (2.0 * tp + eps) / (2.0 * tp + errors + eps)
    iou_per_class = (tp + eps) / (tp + errors + eps)
    return dice_per_class, dice_per_class.mean(), iou_per_class, iou_per_class.mean()


# ------------------------------------------------------------
# Batch unpack helper
# ------------------------------------------------------------
def _unpack_batch(batch):
    """
    Normalise a dataloader batch to the triple (x, y, domain).

    Accepted layouts:
      - dict with keys "image", "mask", "domain" (missing keys become None)
      - list/tuple (x, y)            -> domain is None
      - list/tuple (x, y, domain, ...) -> extra items are ignored

    Raises:
        ValueError: for any other type, or a list/tuple with fewer than two items.
    """
    if isinstance(batch, dict):
        return tuple(batch.get(key, None) for key in ("image", "mask", "domain"))

    if isinstance(batch, (list, tuple)) and len(batch) >= 2:
        domain = batch[2] if len(batch) >= 3 else None
        return batch[0], batch[1], domain

    raise ValueError("Unsupported batch format")


# ------------------------------------------------------------
# DS-aware CE + Dice loss (deterministic)
# ------------------------------------------------------------
def _ce_dice_loss_single(
    logits: torch.Tensor,
    targets: torch.Tensor,
    num_classes: int,
    ce_weight: float,
    dice_weight: float,
    dice_eps: float = 1e-6,
    ce_class_weight: torch.Tensor = None,
    ce_ignore_index: int = -100,
    ce_label_smoothing: float = 0.0,
):
    """
    Hybrid loss for one logits tensor: ce_weight * CE + dice_weight * (1 - soft Dice).

    The logits are resized to the target size first. Cross entropy comes from
    `cross_entropy_det_2d` (mean reduction); the Dice term is one minus the
    macro soft Dice obtained from `_soft_sums_from_logits` and
    `_soft_metrics_from_sums`.

    Returns:
        (loss_total, loss_ce, loss_dice) as tensors.
    """
    logits = _ensure_logits_size(logits, targets)

    ce_term = cross_entropy_det_2d(
        logits,
        targets,
        weight=ce_class_weight,
        ignore_index=ce_ignore_index,
        label_smoothing=ce_label_smoothing,
        reduction="mean",
    )

    soft_sums = _soft_sums_from_logits(logits, targets, num_classes=num_classes)
    # Only the macro Dice is needed here; per-class and IoU values are discarded.
    _, macro_dice, _, _ = _soft_metrics_from_sums(*soft_sums, eps=dice_eps)
    dice_term = 1.0 - macro_dice

    combined = float(ce_weight) * ce_term + float(dice_weight) * dice_term
    return combined, ce_term, dice_term


def loss_any(
    logits_any,
    targets,
    num_classes: int,
    ce_weight: float,
    dice_weight: float,
    dice_eps: float = 1e-6,
    ds_weights=None,
    ce_class_weight: torch.Tensor = None,
    ce_ignore_index: int = -100,
    ce_label_smoothing: float = 0.0,
):
    """
    CE + Dice loss for either a single logits tensor or a deep-supervision list.

    If logits_any is a deep supervision list/tuple:
      loss = sum_i w_i * (CE + Dice) over outputs, with the weights normalised
      to sum to 1 (uniform weights if their sum is not positive).
    Otherwise:
      loss = CE + Dice

    Deep-supervision weights are taken from `ds_weights`, else from the global
    `DS_WEIGHTS` if defined, else uniform.

    Returns: (loss_total, loss_ce, loss_dice) as tensors

    Raises:
      ValueError: if logits_any is an empty list/tuple (there is nothing to
        compute a tensor loss from).
      AssertionError: if the number of weights differs from the number of outputs.
    """
    single_kwargs = dict(
        num_classes=num_classes,
        ce_weight=ce_weight,
        dice_weight=dice_weight,
        dice_eps=dice_eps,
        ce_class_weight=ce_class_weight,
        ce_ignore_index=ce_ignore_index,
        ce_label_smoothing=ce_label_smoothing,
    )

    if not isinstance(logits_any, (list, tuple)):
        return _ce_dice_loss_single(logits_any, targets, **single_kwargs)

    outs = list(logits_any)
    k = len(outs)
    if k == 0:
        # Without this guard the accumulators below stay plain Python floats,
        # and the caller then fails on .backward()/.item() with an obscure error.
        raise ValueError("loss_any received an empty deep-supervision output list.")

    if ds_weights is None:
        ds_weights = globals().get("DS_WEIGHTS", None)
    if ds_weights is None:
        ds_weights = [1.0] * k
    assert len(ds_weights) == k, "ds_weights length mismatch"

    # normalize weights (stable scaling)
    s = float(sum(float(w) for w in ds_weights))
    ws = [(float(w) / s) for w in ds_weights] if s > 0 else [1.0 / k] * k

    # Float seeds become tensors after the first weighted addition.
    total = 0.0
    ce_acc = 0.0
    dice_acc = 0.0
    for w, lg in zip(ws, outs):
        lt, lce, ld = _ce_dice_loss_single(lg, targets, **single_kwargs)
        total = total + w * lt
        ce_acc = ce_acc + w * lce
        dice_acc = dice_acc + w * ld

    return total, ce_acc, dice_acc


# ------------------------------------------------------------
# Shared one-epoch runner (TRAIN or VAL)
# ------------------------------------------------------------
def _run_epoch_core(
    model,
    loader,
    optimizer=None,
    train_mode: bool = True,
    device=None,
    num_classes: int = None,
    ce_weight: float = 1.0,
    dice_weight: float = 1.0,
    dice_eps: float = 1e-6,
    ds_weights=None,
    track_by_domain: bool = False,
    show_tqdm: bool = True,
    tqdm_desc: str = "",
    ce_class_weight: torch.Tensor = None,
    ce_ignore_index: int = -100,
    ce_label_smoothing: float = 0.0,
):
    """
    Returns dict with keys expected by Cell 5 training loop:
      time_sec, loss_total, loss_ce, loss_dice,
      dice_macro, iou_macro, dice_per_class, iou_per_class,               (SOFT)
      hard_dice_macro, hard_iou_macro, hard_dice_per_class, hard_iou_per_class,
      domain_stats (optional): per-domain tp/fp/fn
    """
    if device is None:
        device = globals().get("DEVICE", None)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if num_classes is None:
        num_classes = globals().get("NUM_CLASSES", None)
    if num_classes is None:
        raise ValueError("NUM_CLASSES is not defined. Define it before running training.")

    if train_mode:
        model.train()
    else:
        model.eval()

    use_amp = bool(globals().get("USE_AMP", False)) if train_mode else False
    scaler = globals().get("SCALER", None)

    # accumulators (sample-weighted means)
    n_sum = 0
    loss_total_sum = 0.0
    loss_ce_sum = 0.0
    loss_dice_sum = 0.0

    # SOFT sums (exact soft dice/iou)
    soft_inter_sum = torch.zeros(num_classes, device=device, dtype=torch.float32)
    soft_denom_sum = torch.zeros(num_classes, device=device, dtype=torch.float32)
    soft_union_sum = torch.zeros(num_classes, device=device, dtype=torch.float32)

    # HARD sums (tp/fp/fn)
    tp_sum = torch.zeros(num_classes, device=device, dtype=torch.float32)
    fp_sum = torch.zeros(num_classes, device=device, dtype=torch.float32)
    fn_sum = torch.zeros(num_classes, device=device, dtype=torch.float32)

    # Domain stats (HARD only)
    domain_stats = {}

    t0 = time.perf_counter()
    it = loader
    if show_tqdm:
        it = tqdm(loader, desc=tqdm_desc, leave=False)

    for batch in it:
        x, y, domains = _unpack_batch(batch)
        x = x.to(device, non_blocking=False)
        y = y.to(device, non_blocking=False)

        if train_mode:
            optimizer.zero_grad(set_to_none=True)

        # forward + loss
        if train_mode and use_amp:
            with torch.cuda.amp.autocast(enabled=True):
                logits_any = model(x)
                lt, lce, ld = loss_any(
                    logits_any, y,
                    num_classes=num_classes,
                    ce_weight=ce_weight,
                    dice_weight=dice_weight,
                    dice_eps=dice_eps,
                    ds_weights=ds_weights,
                    ce_class_weight=ce_class_weight,
                    ce_ignore_index=ce_ignore_index,
                    ce_label_smoothing=ce_label_smoothing,
                )
        else:
            logits_any = model(x)
            lt, lce, ld = loss_any(
                logits_any, y,
                num_classes=num_classes,
                ce_weight=ce_weight,
                dice_weight=dice_weight,
                dice_eps=dice_eps,
                ds_weights=ds_weights,
                ce_class_weight=ce_class_weight,
                ce_ignore_index=ce_ignore_index,
                ce_label_smoothing=ce_label_smoothing,
            )

        if train_mode:
            if use_amp and (scaler is not None):
                scaler.scale(lt).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                lt.backward()
                optimizer.step()

        bs = int(x.shape[0])
        n_sum += bs
        loss_total_sum += float(lt.item()) * bs
        loss_ce_sum    += float(lce.item()) * bs
        loss_dice_sum  += float(ld.item()) * bs

        # metrics from final/highest-res logits
        logits = merge_logits(logits_any)
        logits = _ensure_logits_size(logits, y)

        # SOFT sums
        inter_b, denom_b, union_b = _soft_sums_from_logits(logits, y, num_classes=num_classes)
        soft_inter_sum += inter_b
        soft_denom_sum += denom_b
        soft_union_sum += union_b

        # HARD counts
        tp_b, fp_b, fn_b = _hard_counts_from_logits(logits, y, num_classes=num_classes)
        tp_sum += tp_b
        fp_sum += fp_b
        fn_sum += fn_b

        # per-domain HARD
        if track_by_domain and (domains is not None):
            dom_list = list(domains)
            for dom in sorted(set(dom_list)):  # deterministic order
                idxs = [i for i, d in enumerate(dom_list) if d == dom]
                if not idxs:
                    continue
                lg_dom = logits[idxs]
                y_dom  = y[idxs]
                tp_d, fp_d, fn_d = _hard_counts_from_logits(lg_dom, y_dom, num_classes=num_classes)

                if dom not in domain_stats:
                    domain_stats[dom] = {
                        "n": 0,
                        "tp": torch.zeros(num_classes, device=device, dtype=torch.float32),
                        "fp": torch.zeros(num_classes, device=device, dtype=torch.float32),
                        "fn": torch.zeros(num_classes, device=device, dtype=torch.float32),
                    }
                domain_stats[dom]["n"]  += len(idxs)
                domain_stats[dom]["tp"] += tp_d
                domain_stats[dom]["fp"] += fp_d
                domain_stats[dom]["fn"] += fn_d

        # tqdm uses HARD macro (stable + matches epoch-end HARD)
        if show_tqdm:
            _, hd_m, _, hi_m = _hard_metrics_from_counts(tp_sum, fp_sum, fn_sum)
            it.set_postfix({
                "loss": loss_total_sum / max(1, n_sum),
                "hdice": float(hd_m),
                "hiou": float(hi_m),
            })

    t1 = time.perf_counter()

    # finalize SOFT metrics
    dice_pc, dice_m, iou_pc, iou_m = _soft_metrics_from_sums(
        soft_inter_sum, soft_denom_sum, soft_union_sum, eps=dice_eps
    )

    # finalize HARD metrics
    hard_dice_pc, hard_dice_m, hard_iou_pc, hard_iou_m = _hard_metrics_from_counts(tp_sum, fp_sum, fn_sum)

    out = {
        "time_sec": float(t1 - t0),

        "loss_total": float(loss_total_sum / max(1, n_sum)),
        "loss_ce":    float(loss_ce_sum    / max(1, n_sum)),
        "loss_dice":  float(loss_dice_sum  / max(1, n_sum)),

        # SOFT
        "dice_per_class": dice_pc.detach().cpu().numpy(),
        "iou_per_class":  iou_pc.detach().cpu().numpy(),
        "dice_macro": float(dice_m.item()),
        "iou_macro":  float(iou_m.item()),

        # HARD
        "hard_dice_per_class": hard_dice_pc.detach().cpu().numpy(),
        "hard_iou_per_class":  hard_iou_pc.detach().cpu().numpy(),
        "hard_dice_macro": float(hard_dice_m.item()),
        "hard_iou_macro":  float(hard_iou_m.item()),

        "n_samples": int(n_sum),
    }

    if track_by_domain:
        out["domain_stats"] = domain_stats

    return out


# ------------------------------------------------------------
# Wrappers (best practice + backward-compatible aliases)
#   Cell 5 calls:
#     run_epoch_train(model, train_loader, optimizer, track_by_domain=True, desc=..., ce_weight=..., dice_weight=...)
#     run_epoch_val(model, val_loader, track_by_domain=True, desc=..., ce_weight=..., dice_weight=...)
# ------------------------------------------------------------
def run_epoch_train(
    model,
    loader,
    optimizer,
    track_by_domain: bool = True,
    desc: str = "train",
    tqdm_desc: str = None,
    device=None,
    num_classes: int = None,
    ce_weight: float = None,
    dice_weight: float = None,
    ce_w: float = None,          # alias
    dice_w: float = None,        # alias
    dice_eps: float = 1e-6,
    ds_weights=None,
    show_tqdm: bool = True,
    **kwargs,
):
    """
    Run one TRAIN epoch via `_run_epoch_core` and return its result dict.

    `ce_w` / `dice_w` are accepted as aliases of `ce_weight` / `dice_weight`;
    the long names win when both are given, and each weight defaults to 1.0.
    `desc` is used as the progress-bar label unless `tqdm_desc` is set.

    Any extra keyword arguments are accepted and ignored for backward
    compatibility. Note that this includes the CE options of the core runner
    (ce_class_weight, ce_ignore_index, ce_label_smoothing): they are not
    forwarded from here.
    """
    # accept both naming styles
    if ce_weight is None:
        ce_weight = 1.0 if ce_w is None else ce_w
    if dice_weight is None:
        dice_weight = 1.0 if dice_w is None else dice_w

    core_args = dict(
        model=model,
        loader=loader,
        optimizer=optimizer,
        train_mode=True,
        device=device,
        num_classes=num_classes,
        ce_weight=ce_weight,
        dice_weight=dice_weight,
        dice_eps=dice_eps,
        ds_weights=ds_weights,
        track_by_domain=track_by_domain,
        show_tqdm=show_tqdm,
        tqdm_desc=desc if tqdm_desc is None else tqdm_desc,
    )
    return _run_epoch_core(**core_args)


@torch.no_grad()
def run_epoch_val(
    model,
    loader,
    track_by_domain: bool = True,
    desc: str = "val",
    tqdm_desc: str = None,
    device=None,
    num_classes: int = None,
    ce_weight: float = None,
    dice_weight: float = None,
    ce_w: float = None,          # alias
    dice_w: float = None,        # alias
    dice_eps: float = 1e-6,
    ds_weights=None,
    show_tqdm: bool = True,
    **kwargs,
):
    """
    Run one VALIDATION epoch (gradients disabled) via `_run_epoch_core`.

    Mirrors `run_epoch_train` without an optimizer: `ce_w` / `dice_w` are
    aliases of `ce_weight` / `dice_weight` (long names win, default 1.0), and
    `desc` labels the progress bar unless `tqdm_desc` is set.

    Extra keyword arguments are accepted and ignored; they are not forwarded
    to the core runner.
    """
    if ce_weight is None:
        ce_weight = 1.0 if ce_w is None else ce_w
    if dice_weight is None:
        dice_weight = 1.0 if dice_w is None else dice_w

    core_args = dict(
        model=model,
        loader=loader,
        optimizer=None,
        train_mode=False,
        device=device,
        num_classes=num_classes,
        ce_weight=ce_weight,
        dice_weight=dice_weight,
        dice_eps=dice_eps,
        ds_weights=ds_weights,
        track_by_domain=track_by_domain,
        show_tqdm=show_tqdm,
        tqdm_desc=desc if tqdm_desc is None else tqdm_desc,
    )
    return _run_epoch_core(**core_args)


# ------------------------------------------------------------
# Formatting helpers
# ------------------------------------------------------------
def format_domain_metrics(domain_stats: dict, num_classes: int, prefer_hard: bool = True):
    """
    Summarise per-domain statistics as macro Dice/IoU.

    Returns:
      { dom: {"dice": float, "iou": float, "n": int} } with domains in sorted
      order; an empty dict when `domain_stats` is empty or None.

    Dice/IoU come from `_hard_metrics_from_counts` when `prefer_hard` is set
    and the entry carries "tp", "fp" and "fn"; otherwise both are NaN.
    `num_classes` is accepted for interface compatibility and is not used here.
    """
    if not domain_stats:
        return {}

    summary = {}
    for name in sorted(domain_stats.keys()):
        stats = domain_stats[name]
        count = int(stats.get("n", 0))
        has_counts = all(field in stats for field in ("tp", "fp", "fn"))

        if not (prefer_hard and has_counts):
            summary[name] = {"dice": float("nan"), "iou": float("nan"), "n": count}
            continue

        _, dice_macro, _, iou_macro = _hard_metrics_from_counts(stats["tp"], stats["fp"], stats["fn"])
        summary[name] = {"dice": float(dice_macro), "iou": float(iou_macro), "n": count}

    return summary


def format_class_metrics(dice_per_class, iou_per_class, class_names=None):
    """
    Pair per-class Dice and IoU values with class names.

    Names come from `class_names`, else the global `CLASS_NAMES` if defined,
    else the stringified class indices "0", "1", ...

    Returns:
      { name: {"dice": float, "iou": float} }
    """
    names = class_names
    if names is None:
        names = globals().get("CLASS_NAMES", None)
    if names is None:
        names = list(map(str, range(len(dice_per_class))))

    return {
        label: {"dice": float(dice_per_class[pos]), "iou": float(iou_per_class[pos])}
        for pos, label in enumerate(names)
    }


def print_domain_stats(domain_stats: dict, num_classes: int, prefer_hard: bool = True, title: str = "Domain metrics"):
    """
    Print one line per domain with its sample count and macro Dice/IoU.

    Prints "<title>: (none)" when `domain_stats` is empty or None; otherwise
    the rows come from `format_domain_metrics`.
    """
    if domain_stats:
        rows = format_domain_metrics(domain_stats, num_classes=num_classes, prefer_hard=prefer_hard)
        print(f"{title}:")
        for dom, row in rows.items():
            print(f"  - {dom:>12s} | n={row['n']:4d} | Dice={row['dice']:.4f} | IoU={row['iou']:.4f}")
    else:
        print(f"{title}: (none)")


## Debug: same shuffle order, different augmentation each epoch

In [None]:
# ============================================================
# DEBUG CELL (v5): PROOF that all models see identical feeding
#   A) verifies SHUFFLE ORDER (sampler indices)
#   B) verifies POST-AUG CONTENT (img+mask fingerprints)
#
# Strategy:
#   - Build a DEBUG loader with num_workers=0 (removes worker nondet)
#   - Seed shuffle with base_seed (fixed; same shuffle order every epoch)
#   - Set dataset epoch for epoch-wise augmentation changes
#
# EXPECTED:
#   SAME epoch: all models match reference (order + post-aug)
#   DIFF epoch: post-aug fingerprints differ (augmentation schedule changes)
# ============================================================

import hashlib
import torch
from torch.utils.data import DataLoader
from itertools import islice

# -----------------------------
# Config (keep small for speed)
# -----------------------------
# Epochs for which the feed is rebuilt and compared.
EPOCHS_TO_CHECK = [0, 1, 2]
# Number of leading batches fingerprinted per epoch.
NUM_BATCHES_TO_CHECK = 5
MAX_ELEMS_PER_TENSOR = 200_000   # reduce if still slow; increase if you want stronger proof
PRINT_FIRST_N_PER_BATCH = 3      # how many sample-fps to show per batch

# Model tags (only used for display); falls back to a single placeholder tag
# when MODEL_LIST has not been defined by an earlier cell.
MODEL_TAGS = list(globals().get("MODEL_LIST", ["REF_MODEL"]))

# Fail fast if the cells defining the seed and the loader factory were not run.
assert "GLOBAL_SEED" in globals(), "GLOBAL_SEED missing."
assert "make_dataloaders" in globals(), "make_dataloaders(...) missing."

# -----------------------------
# Hash helpers
# -----------------------------
def _sha16(b: bytes) -> str:
    return hashlib.sha256(b).hexdigest()[:16]

def _tensor_bytes_subset(t: torch.Tensor, max_elems: int):
    t = t.detach().cpu()
    if not t.is_contiguous():
        t = t.contiguous()
    flat = t.view(-1)
    n = min(int(max_elems), flat.numel())
    flat = flat[:n]
    header = f"{str(t.dtype)}|{tuple(t.shape)}|".encode("utf-8")
    return header + flat.numpy().tobytes()

def _domain_bytes(domain):
    # domain is typically list[str] after collation
    if isinstance(domain, (list, tuple)):
        s = "|".join(map(str, domain))
    else:
        s = str(domain)
    return s.encode("utf-8")

def sample_fingerprint(img_i: torch.Tensor, mask_i: torch.Tensor, dom_i, max_elems: int):
    """
    Fingerprint one sample as a short hash of image + mask + domain.

    The image and mask are serialised with `_tensor_bytes_subset` (at most
    `max_elems` elements each), the domain via str(); the three pieces are
    concatenated in that order and hashed with `_sha16`.
    """
    pieces = (
        _tensor_bytes_subset(img_i, max_elems),
        _tensor_bytes_subset(mask_i, max_elems),
        str(dom_i).encode("utf-8"),
    )
    return _sha16(b"".join(pieces))

# -----------------------------
# Build a DEBUG loader (num_workers=0)
# -----------------------------
def build_debug_train_loader(base_seed: int, epoch: int):
    """
    Build a single-process training loader for the determinism check.

    The training dataset (and the batch size of the regular train loader) come
    from `make_dataloaders(base_seed=...)`; the remaining return values are not
    used. If the dataset exposes `set_epoch`, it is called with `epoch` so the
    augmentation can vary per epoch. Shuffling uses a fresh generator seeded
    with `base_seed`, so the seed is the same for every epoch.

    Returns:
      (train_dataset, debug_loader)
    """
    seed = int(base_seed)
    tr_ds, _va_ds, tr_loader, _va_loader, _tr_eval_by_dom, _val_by_dom = make_dataloaders(base_seed=seed)

    # epoch-aware augmentation
    if hasattr(tr_ds, "set_epoch"):
        tr_ds.set_epoch(int(epoch))

    # deterministic shuffle (fixed across epochs)
    shuffle_gen = torch.Generator()
    shuffle_gen.manual_seed(seed)

    loader_kwargs = dict(
        batch_size=getattr(tr_loader, "batch_size", 1),
        shuffle=True,
        num_workers=0,           # critical: remove multi-worker nondeterminism for proof
        pin_memory=False,
        generator=shuffle_gen,
    )
    return tr_ds, DataLoader(tr_ds, **loader_kwargs)

def get_first_sampler_indices(loader, n_indices: int):
    """Return the first `n_indices` items yielded by `loader.sampler` as a list."""
    # RandomSampler yields indices in shuffled order
    head = islice(iter(loader.sampler), int(n_indices))
    return [index for index in head]

def get_loader_fps(loader, num_batches: int, max_elems: int):
    """
    Fingerprint the samples of the first `num_batches` batches of `loader`.

    Each batch must unpack as (img, mask, domain). Stops early if the loader
    runs out of batches.

    Returns:
      A list with one entry per batch; each entry is the list of per-sample
      fingerprints (stronger than a batch-level hash).
    """
    per_batch = []
    # zip() pulls from range() first, so no batch beyond the limit is fetched.
    for _, (img, mask, domain) in zip(range(int(num_batches)), loader):
        # img: [B,C,H,W], mask: [B,H,W], domain: list[str]
        domain_is_seq = isinstance(domain, (list, tuple))
        per_batch.append([
            sample_fingerprint(
                img[i],
                mask[i],
                domain[i] if domain_is_seq else domain,
                max_elems=max_elems,
            )
            for i in range(img.shape[0])
        ])
    return per_batch

# -----------------------------
# Main check
# -----------------------------
def check_feed_identity_across_models(model_tags, epochs, num_batches, base_seed, max_elems):
    """
    Print a report showing whether independently rebuilt debug loaders feed
    identical data.

    For every epoch in `epochs`:
      1. A reference loader is built with `build_debug_train_loader`; its first
         sampler indices and per-sample fingerprints are recorded.
      2. For every tag in `model_tags` a fresh loader is built the same way and
         compared against the reference (the tag is only used for display; no
         model is instantiated here).
      3. For every epoch after the first, the reference fingerprints are
         compared with those of the preceding entry in `epochs`; identical
         fingerprints trigger a warning that augmentation is not epoch-dependent.

    Args:
      model_tags: iterable of labels to print next to each comparison.
      epochs: sequence of epoch numbers (must support `.index` and indexing).
      num_batches: number of leading batches to fingerprint per loader.
      base_seed: seed passed to the loader builder.
      max_elems: per-tensor element cap passed to the fingerprinting helpers.

    Returns:
      None; results are printed. Reads the global PRINT_FIRST_N_PER_BATCH.
    """
    print("\n================ DEBUG (v5): FEED IDENTITY PROOF ================")
    print(f"GLOBAL_SEED={int(base_seed)} | epochs={list(epochs)} | batches={int(num_batches)}")
    print("We verify TWO things per epoch:")
    print("  (1) sampler indices (shuffle order)")
    print("  (2) per-sample post-augmentation fingerprints (img+mask+domain)")
    print("NOTE: debug loader uses num_workers=0 for strict determinism proof.\n")

    ref = {}  # epoch -> {"idx":..., "fps":...}

    for ep in epochs:
        # Reference
        # NOTE(review): the sampler is iterated once for the indices and the
        # loader is iterated separately for the fingerprints. If both draw from
        # the loader's shared generator, the printed indices may belong to a
        # different permutation than the fingerprinted batches -- confirm
        # against the RandomSampler/DataLoader behaviour.
        _, ref_loader = build_debug_train_loader(base_seed, ep)
        ref_idx = get_first_sampler_indices(ref_loader, n_indices=ref_loader.batch_size * num_batches)
        ref_fps = get_loader_fps(ref_loader, num_batches=num_batches, max_elems=max_elems)
        ref[ep] = {"idx": ref_idx, "fps": ref_fps}

        # Compact signature per epoch for quick compare
        sig = _sha16(("|".join(map(str, ref_idx)) + "||" + "|".join(sum(ref_fps, []))).encode("utf-8"))
        print(f"Epoch {ep}: REF sig={sig} | first_idx[:8]={ref_idx[:8]}")

        # Compare "other models" (rebuild loaders exactly the same way)
        for mtag in model_tags:
            _, ld = build_debug_train_loader(base_seed, ep)
            idx = get_first_sampler_indices(ld, n_indices=ld.batch_size * num_batches)
            fps = get_loader_fps(ld, num_batches=num_batches, max_elems=max_elems)

            ok_order = (idx == ref_idx)
            ok_fps   = (fps == ref_fps)

            status = "OK" if (ok_order and ok_fps) else "MISMATCH"
            if status == "MISMATCH":
                # pinpoint where it diverges
                bad_i = next((i for i in range(min(len(idx), len(ref_idx))) if idx[i] != ref_idx[i]), None)
                bad_b = next((b for b in range(min(len(fps), len(ref_fps))) if fps[b] != ref_fps[b]), None)
                print(f"  - {mtag}: {status} | order={ok_order} (first diff idx pos={bad_i}) | aug={ok_fps} (first diff batch={bad_b})")
                if bad_b is not None:
                    print(f"    ref batch{bad_b} samplefps (first {PRINT_FIRST_N_PER_BATCH}): {ref_fps[bad_b][:PRINT_FIRST_N_PER_BATCH]}")
                    print(f"    got batch{bad_b} samplefps (first {PRINT_FIRST_N_PER_BATCH}): {fps[bad_b][:PRINT_FIRST_N_PER_BATCH]}")
            else:
                print(f"  - {mtag}: {status}")

        # Across-epoch check (augmentation must change)
        # `epochs.index(ep)` finds the first occurrence, so this assumes the
        # epoch numbers in `epochs` are unique.
        if ep != epochs[0]:
            prev = epochs[epochs.index(ep) - 1]
            same_aug = (ref[ep]["fps"] == ref[prev]["fps"])
            if same_aug:
                print(f"  ‚ö†Ô∏è  WARNING: REF post-aug fingerprints identical to epoch {prev}.")
                print("     That means your augmentation is not epoch-dependent in practice.")
                print("     In v5 it SHOULD be, via seed = base_seed + epoch*BIG + idx.")
            else:
                print(f"  ‚úÖ REF post-aug differs vs epoch {prev} (expected).")

        print("")

    print("=================================================================\n")
    print("If you see MISMATCH with num_workers=0, it is a REAL bug in your seed/epoch logic.")
    print("If it matches with num_workers=0 but mismatches in training (num_workers>0),")
    print("then the leak is multi-worker / external-library nondeterminism.")
    print("In that case, keep num_workers>0 for speed, but rely on THIS proof for paper claims.\n")

# Run the feed-identity check with the configuration defined at the top of
# this cell; the report is printed as cell output.
check_feed_identity_across_models(
    model_tags=MODEL_TAGS,
    epochs=EPOCHS_TO_CHECK,
    num_batches=NUM_BATCHES_TO_CHECK,
    base_seed=int(GLOBAL_SEED),
    max_elems=int(MAX_ELEMS_PER_TENSOR),
)


In [None]:
# ============================================================
# CELL 5: Training loop with TRUE resume + fixed deterministic shuffle
# FIXES (no architecture changes):
#   - Domain gaps fully removed from printing/logging (no compute_domain_gaps usage)
#   - Validation reporting + XL saving uses HARD Dice/IoU only (soft removed for val)
#   - Resume training restores optimizer + scaler (+ RNG states)
#   - Shuffle stays ON, deterministic but FIXED across epochs (seed = GLOBAL_SEED)
#   - Augmentation/crop changes every epoch (fallback enforced if dataset isn't epoch-aware yet)
#
# Requires:
#   - Cell 2: make_dataloaders(...)
#   - Cell 3: get_model_paths, append_log_row
#   - Cell 4: UNet, UNet3Plus, create_sslite_v1/v2/v3
#   - Cell 4.5: run_epoch_train, run_epoch_val, _hard_metrics_from_counts
# ============================================================

import time
import random
import numpy as np
import torch
import torch.nn as nn
import hashlib


# ---------- Safety checks ----------
# Each assert verifies that a name defined by an earlier cell exists in the
# notebook namespace, so a skipped cell fails here with a clear message.
assert "make_dataloaders" in globals(), "make_dataloaders not found. Run Cell 2 first."
assert "get_model_paths" in globals() and "append_log_row" in globals(), \
    "get_model_paths / append_log_row not found. Run Cell 3 first."
assert "run_epoch_train" in globals() and "run_epoch_val" in globals(), \
    "run_epoch_train / run_epoch_val not found. Run Cell 4.5 first."
assert "_hard_metrics_from_counts" in globals(), "_hard_metrics_from_counts not found. Run Cell 4.5B first."

# Lite models (V1-V3 only)
assert "create_sslite_v1" in globals() and "create_sslite_v2" in globals() and "create_sslite_v3" in globals(), \
    "Lite wrappers (create_sslite_v1..v3) not found. Run Cell 4 first."

# ---------- Global config for this cell ----------
# Number of segmentation classes; default for create_model() in this cell.
NUM_CLASSES = 3
# Prefer the GPU when one is available.
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# AMP flag (Cell 4.5 reads USE_AMP + SCALER)
USE_AMP = False

# Learning-rate schedule + early stopping (shared for all models)
INIT_LR             = 1e-3   # starting learning rate
LR_DECAY_FACTOR     = 10.0   # the LR is divided by this factor on decay
LR_PATIENCE_EPOCHS  = 5      # epochs without improvement before an LR decay
EARLY_STOP_PATIENCE = 20     # epochs without improvement before stopping
MIN_LR              = 1e-5   # minimum learning rate

print(f"üñ•  Using device: {DEVICE}")
print(f"üìâ Initial learning rate: {INIT_LR}")
print(f"üìâ LR decay: √∑{LR_DECAY_FACTOR} after {LR_PATIENCE_EPOCHS} epochs without improvement")
print(f"üìâ Minimum learning rate: {MIN_LR}")
print(f"üõë Early stopping: stop after {EARLY_STOP_PATIENCE} epochs without improvement")
print(f"‚öôÔ∏è  AMP enabled? {USE_AMP}")


# ---------- Model creation helper ----------
def create_model(model_key: str, num_classes: int = None) -> nn.Module:
    """
    Instantiate a segmentation model from a (case-insensitive) key.

    Recognised keys:
      - "unet", "u-net"                                        -> UNet
      - "unet3plus", "unet3+", "u-net3plus", "u-net3+"         -> UNet3Plus
      - "sslite_vN", "ssunet3pluslite_vN", "ssunet3pluslite-vN" (N = 1..3)
                                                               -> create_sslite_vN

    `num_classes` defaults to the global NUM_CLASSES. UNet and UNet3Plus also
    read the globals IN_CHANNELS (default 1) and BASE_CH (default 32);
    UNet3Plus additionally reads UNET3PLUS_DEEP_SUPERVISION (default False).

    Raises:
      ValueError: if the key is not recognised.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    name = str(model_key).lower()

    # Optional notebook-level settings with their fallbacks.
    in_ch = globals().get("IN_CHANNELS", 1)
    base_ch = globals().get("BASE_CH", 32)
    deep_sup = bool(globals().get("UNET3PLUS_DEEP_SUPERVISION", False))

    if name in ("unet", "u-net"):
        return UNet(in_channels=in_ch, num_classes=num_classes, base_ch=base_ch)

    if name in ("unet3plus", "unet3+", "u-net3plus", "u-net3+"):
        return UNet3Plus(
            in_channels=in_ch,
            num_classes=num_classes,
            base_ch=base_ch,
            deep_supervision=deep_sup,
        )

    if name in ("sslite_v1", "ssunet3pluslite_v1", "ssunet3pluslite-v1"):
        return create_sslite_v1(num_classes=num_classes)

    if name in ("sslite_v2", "ssunet3pluslite_v2", "ssunet3pluslite-v2"):
        return create_sslite_v2(num_classes=num_classes)

    if name in ("sslite_v3", "ssunet3pluslite_v3", "ssunet3pluslite-v3"):
        return create_sslite_v3(num_classes=num_classes)

    raise ValueError(f"Unknown model_key: {model_key}")


# ---------- Small formatting helpers (for logging strings) ----------
def _fmt_domain_strings(domain_stats: dict, num_classes: int):
    if not domain_stats:
        return "", ""

    parts_d, parts_i = [], []
    for dom in ["CFRP", "FFRP_Autoclave", "FFRP_Oven"]:
        if dom not in domain_stats:
            continue
        st = domain_stats[dom]
        tp = st["tp"]
        fp = st["fp"]
        fn = st["fn"]
        dice_pc, dice_m, iou_pc, iou_m = _hard_metrics_from_counts(tp, fp, fn)
        d = float(dice_m)
        i = float(iou_m)
        parts_d.append(f"{dom}:{d:.5f}")
        parts_i.append(f"{dom}:{i:.5f}")

    return "|".join(parts_d), "|".join(parts_i)


def _fmt_class_strings(dice_pc, iou_pc, class_names=None):
    if class_names is None:
        class_names = globals().get("CLASS_NAMES", None)
    if class_names is None:
        class_names = [f"class_{i}" for i in range(len(dice_pc))]

    parts_d, parts_i = [], []
    for i, nm in enumerate(class_names):
        parts_d.append(f"{nm}:{float(dice_pc[i]):.5f}")
        parts_i.append(f"{nm}:{float(iou_pc[i]):.5f}")
    return "|".join(parts_d), "|".join(parts_i)


def _print_epoch_table(train_row=None, val_row=None):
    print("\n[PER-EPOCH SUMMARY]")
    print(f"{'phase':<5}  {'epoch':>3}  {'loss':>9}  {'cls':>9}  {'dice_loss':>9}  {'HDice':>9}  {'HIoU':>9}  {'time[s]':>8}")
    print("-" * 84)

    if train_row is not None:
        print(f"{'train':<5}  {int(train_row['epoch']):3d}  {train_row['loss_total']:9.4f}  {train_row['loss_ce']:9.4f}  {train_row['loss_dice']:9.4f}  "
              f"{float(train_row['hard_dice_macro']):9.4f}  {float(train_row['hard_iou_macro']):9.4f}  {train_row['time_sec']:8.1f}")

    if val_row is not None:
        print(f"{'val':<5}  {int(val_row['epoch']):3d}  {val_row['loss_total']:9.4f}  {val_row['loss_ce']:9.4f}  {val_row['loss_dice']:9.4f}  "
              f"{float(val_row['hard_dice_macro']):9.4f}  {float(val_row['hard_iou_macro']):9.4f}  {val_row['time_sec']:8.1f}")


# ---------- RNG state save/load helpers (true resume) ----------
def _get_rng_state_payload():
    payload = {
        "py_random_state": random.getstate(),
        "np_random_state": np.random.get_state(),
        "torch_rng_state": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        try:
            payload["torch_cuda_rng_state_all"] = torch.cuda.get_rng_state_all()
        except Exception:
            payload["torch_cuda_rng_state_all"] = None
    else:
        payload["torch_cuda_rng_state_all"] = None
    return payload

def _restore_rng_state_payload(payload: dict):
    try:
        if payload.get("py_random_state", None) is not None:
            random.setstate(payload["py_random_state"])

        if payload.get("np_random_state", None) is not None:
            np.random.set_state(payload["np_random_state"])

        # ---- torch CPU RNG (MUST be CPU ByteTensor) ----
        st = payload.get("torch_rng_state", None)
        if st is not None:
            if torch.is_tensor(st):
                st = st.detach().to(device="cpu", dtype=torch.uint8)
            else:
                st = torch.tensor(st, dtype=torch.uint8)
            torch.set_rng_state(st)

        # ---- torch CUDA RNG (list of CPU ByteTensors) ----
        if torch.cuda.is_available():
            st_all = payload.get("torch_cuda_rng_state_all", None)
            if st_all is not None:
                # Sometimes old ckpt may store single tensor; normalize to list
                if torch.is_tensor(st_all):
                    st_all = [st_all]

                fixed = []
                for s in st_all:
                    if s is None:
                        fixed = None
                        break
                    if torch.is_tensor(s):
                        s = s.detach().to(device="cpu", dtype=torch.uint8)
                    else:
                        s = torch.tensor(s, dtype=torch.uint8)
                    fixed.append(s)

                if fixed is not None:
                    torch.cuda.set_rng_state_all(fixed)

    except Exception as e:
        print(f"   ‚ö†Ô∏è  Resume: RNG state restore failed: {e}")



# ---------- Main training function ----------
def run_training_for_model(model_key: str, extra_epochs: int = 10):
    """Train one model, or resume its training, for ``extra_epochs`` more epochs.

    Steps, in order:
      1. Reseed the RNGs from ``GLOBAL_SEED`` and rebuild the dataloaders with
         ``make_dataloaders``.
      2. Build the model, an Adam optimizer and the global AMP ``SCALER``.
      3. Probe whether sample 0 of the training set changes between dataset
         epoch 0 and 1; if it does not, ``train_dataset.base_seed`` is varied
         per epoch as a fallback.
      4. Resume from the LAST checkpoint (model, optimizer, scaler, RNG and
         patience counters); otherwise warm-start the weights from the BEST
         checkpoint; otherwise run an epoch-0 validation and save it as both
         BEST and LAST.
      5. Per epoch: train, validate, log both rows, save BEST when the
         validation hard-Dice (macro) improved, apply LR decay, save LAST,
         then check early stopping.

    LR decay is applied *before* LAST is written so the checkpoint stores the
    learning rate and patience counter the next epoch will actually use; a
    resumed run then continues exactly like an uninterrupted one.

    Args:
        model_key: Model identifier handed to ``create_model``,
            ``get_model_paths`` and ``append_log_row``.
        extra_epochs: Number of epochs to run after the epoch recorded in the
            loaded checkpoint (or after the epoch-0 baseline).

    Side effects:
        Rebinds the global datasets/loaders, ``train_generator`` and
        ``SCALER``; writes checkpoint files and log rows; prints progress.
    """
    # Loaders are rebuilt per model; SCALER is the AMP object used by Cell 4.5.
    global train_dataset, val_dataset, train_loader, val_loader
    global TRAIN_EVAL_LOADERS_BY_DOMAIN, VAL_LOADERS_BY_DOMAIN, train_generator
    global SCALER

    pretty_name = {
        "unet":      "UNet",
        "unet3plus": "UNet3Plus",
        "unet3+":    "UNet3Plus",
        "sslite_v1": "SS-UNet3+ Lite V1 (EnhSemiFull)",
        "sslite_v2": "SS-UNet3+ Lite V2 (4-Branch)",
        "sslite_v3": "SS-UNet3+ Lite V3 (Baseline)",
    }.get(str(model_key).lower(), model_key)

    # Reset RNGs before each model training (keeps cross-model comparability)
    if "reset_all_seeds" in globals():
        reset_all_seeds(int(GLOBAL_SEED))
    else:
        random.seed(int(GLOBAL_SEED))
        np.random.seed(int(GLOBAL_SEED))
        torch.manual_seed(int(GLOBAL_SEED))
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(int(GLOBAL_SEED))

    # Meta toggles (from Cell 4.5 switchboard); only logged / snapshotted here.
    ce_mode   = str(globals().get("CE_MODE", "ce")).lower().strip()
    abl_fus   = str(globals().get("ABL_FUSION", "native")).lower().strip()
    use_bn    = bool(globals().get("USE_BN", True))
    use_gn    = bool(globals().get("USE_GN", False))
    base_ch   = int(globals().get("BASE_CH", 32))
    mk = str(model_key).lower()
    if mk in {"unet3plus", "unet3+", "u-net3plus", "u-net3+"}:
        ds_on = bool(globals().get("UNET3PLUS_DEEP_SUPERVISION", False))
    elif mk in {"sslite_v1", "ssunet3pluslite_v1", "ssunet3pluslite-v1",
                "sslite_v2", "ssunet3pluslite_v2", "ssunet3pluslite-v2",
                "sslite_v3", "ssunet3pluslite_v3", "ssunet3pluslite-v3"}:
        ds_on = bool(globals().get("SSLITE_DEEP_SUPERVISION", True))
    else:
        ds_on = False

    ce_w_cfg   = float(globals().get("CE_WEIGHT", 0.5))
    dice_w_cfg = float(globals().get("DICE_WEIGHT", 0.5))

    print("\n" + "=" * 72)
    print(f"üöÄ Starting (or resuming) training: {pretty_name}  (key='{model_key}')")
    print(f"   Config ‚Üí CE_MODE='{ce_mode}', ABL_FUSION='{abl_fus}', USE_BN={use_bn}, USE_GN={use_gn}, BASE_CH={base_ch}, DS_ON={ds_on}")
    print("=" * 72)

    # --- rebuild loaders per model (keeps SAME epoch schedule across models) ---
    train_dataset, val_dataset, train_loader, val_loader, TRAIN_EVAL_LOADERS_BY_DOMAIN, VAL_LOADERS_BY_DOMAIN = make_dataloaders(
        base_seed=int(GLOBAL_SEED)
    )
    train_generator = getattr(getattr(train_loader, "sampler", None), "generator", None)

    model = create_model(model_key, num_classes=NUM_CLASSES).to(DEVICE)
    params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"   Trainable parameters: {params_count/1e6:.4f} M")

    optimizer = torch.optim.Adam(model.parameters(), lr=INIT_LR)

    # AMP objects used by Cell 4.5
    SCALER = torch.cuda.amp.GradScaler(enabled=USE_AMP and (DEVICE.type == "cuda"))

    paths = get_model_paths(model_key)
    print("   Model/Log paths:")
    for k, v in paths.items():
        print(f"   {k}: {v}")

    best_val_hard_dice_macro = -1.0
    best_epoch               = 0
    current_lr               = INIT_LR
    start_epoch              = 0
    no_improve_streak        = 0
    no_improve_for_lr        = 0

    last_ckpt = paths["last_ckpt"]
    best_ckpt = paths["best_ckpt"]

    def _checkpoint_payload(epoch_idx: int) -> dict:
        """Snapshot of everything a true resume needs, taken at call time."""
        return {
            "epoch": int(epoch_idx),
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scaler_state": SCALER.state_dict() if SCALER is not None else None,
            "rng_state": _get_rng_state_payload(),

            "best_val_hard_dice_macro": float(best_val_hard_dice_macro),
            "best_epoch": int(best_epoch),
            "current_lr": float(current_lr),
            "no_improve_streak": int(no_improve_streak),
            "no_improve_for_lr": int(no_improve_for_lr),

            # meta snapshot
            "ce_mode": ce_mode,
            "abl_fusion": abl_fus,
            "use_bn": use_bn,
            "use_gn": use_gn,
            "base_ch": base_ch,
            "ds_on": ds_on,
        }

    # ---------- Detect whether dataset augmentation changes with epoch ----------
    # If not, we enforce epoch-wise augmentation by changing train_dataset.base_seed per epoch.
    AUG_BIG = 1_000_000
    force_epoch_aug = False

    def _sig(sample):
        """SHA-256 over the raw bytes of a sample's image and mask tensors."""
        img, msk = sample[0], sample[1]
        h = hashlib.sha256()
        h.update(img.detach().cpu().numpy().tobytes())
        h.update(msk.detach().cpu().numpy().tobytes())
        return h.hexdigest()

    try:
        # probe idx=0 twice with epoch 0 and 1
        orig_epoch = getattr(train_dataset, "epoch", 0)
        orig_base_seed = getattr(train_dataset, "base_seed", int(GLOBAL_SEED))
        has_set_epoch = hasattr(train_dataset, "set_epoch")

        try:
            if has_set_epoch:
                train_dataset.set_epoch(0)
            x0a = train_dataset[0]  # (img, mask, dom)
            if has_set_epoch:
                train_dataset.set_epoch(1)
            x0b = train_dataset[0]

            same = (_sig(x0a) == _sig(x0b))
        finally:
            # Put the dataset back even when the probe itself raised.
            if has_set_epoch:
                train_dataset.set_epoch(int(orig_epoch))
            try:
                train_dataset.base_seed = int(orig_base_seed)
            except Exception:
                # base_seed may not be assignable on this dataset; nothing to restore then.
                pass

        if same:
            force_epoch_aug = True
            print("   ‚ö†Ô∏è  Dataset aug seems NOT epoch-aware yet ‚Üí enabling safe fallback: base_seed += epoch*BIG")
        else:
            print("   ‚úÖ Dataset aug is epoch-aware ‚Üí no fallback needed")
    except Exception as e:
        # if probe fails, don't force anything (but say so instead of hiding it)
        force_epoch_aug = False
        print(f"   Aug epoch-awareness probe failed ({e!r}); epoch-seed fallback stays disabled")

    # ---------- Resume logic (TRUE resume: model + optimizer + scaler + RNG) ----------
    if last_ckpt.exists():
        ckpt = torch.load(last_ckpt, map_location=DEVICE, weights_only=False)
        state_dict = ckpt.get("model_state", ckpt.get("state_dict", None))
        if state_dict is not None:
            model.load_state_dict(state_dict)

        # optimizer/scaler restore (best effort: a mismatch is reported, not fatal)
        if "optimizer_state" in ckpt and ckpt["optimizer_state"] is not None:
            try:
                optimizer.load_state_dict(ckpt["optimizer_state"])
            except Exception as e:
                print(f"   ‚ö†Ô∏è  Resume: optimizer_state load failed: {e}")

        if "scaler_state" in ckpt and ckpt["scaler_state"] is not None and SCALER is not None:
            try:
                SCALER.load_state_dict(ckpt["scaler_state"])
            except Exception as e:
                print(f"   ‚ö†Ô∏è  Resume: scaler_state load failed: {e}")

        # RNG restore (best effort)
        if "rng_state" in ckpt and isinstance(ckpt["rng_state"], dict):

            # ---- BEFORE restore: check what's inside ckpt ----
            try:
                st = ckpt["rng_state"].get("torch_rng_state", None)
                print("   [RNG-BEFORE] ckpt torch_rng_state:",
                      type(st),
                      getattr(st, "dtype", None),
                      getattr(st, "device", None))

                st_all = ckpt["rng_state"].get("torch_cuda_rng_state_all", None)
                if st_all is None:
                    print("   [RNG-BEFORE] ckpt cuda_rng_state_all: None")
                else:
                    # can be list or tensor
                    if torch.is_tensor(st_all):
                        print("   [RNG-BEFORE] ckpt cuda_rng_state_all:",
                              type(st_all), st_all.dtype, st_all.device)
                    else:
                        print("   [RNG-BEFORE] ckpt cuda_rng_state_all: list_len=", len(st_all))
                        if len(st_all) > 0 and st_all[0] is not None:
                            print("   [RNG-BEFORE] ckpt cuda_rng_state_all[0]:",
                                  type(st_all[0]),
                                  getattr(st_all[0], "dtype", None),
                                  getattr(st_all[0], "device", None))
            except Exception as e:
                print("   [RNG-BEFORE] check failed:", e)

            _restore_rng_state_payload(ckpt["rng_state"])

            # ---- AFTER restore: confirm current RNG states ----
            try:
                cpu_st = torch.get_rng_state()
                print("   [RNG-AFTER] torch.get_rng_state():", cpu_st.dtype, cpu_st.device)
                if torch.cuda.is_available():
                    cuda0 = torch.cuda.get_rng_state_all()[0]
                    print("   [RNG-AFTER] torch.cuda.get_rng_state_all()[0]:", cuda0.dtype, cuda0.device)
            except Exception as e:
                print("   [RNG-AFTER] check failed:", e)

        prev_epoch               = int(ckpt.get("epoch", 0))
        best_val_hard_dice_macro = float(ckpt.get("best_val_hard_dice_macro", -1.0))
        best_epoch               = int(ckpt.get("best_epoch", prev_epoch))
        current_lr               = float(ckpt.get("current_lr", INIT_LR))
        no_improve_streak        = int(ckpt.get("no_improve_streak", 0))
        no_improve_for_lr        = int(ckpt.get("no_improve_for_lr", 0))
        start_epoch              = prev_epoch + 1

        # ensure optimizer lr matches current_lr
        for g in optimizer.param_groups:
            g["lr"] = float(current_lr)

        print(f"   üîÅ Resume from LAST: epoch {start_epoch} (bestHardDice={best_val_hard_dice_macro:.4f} @epoch {best_epoch}, LR={current_lr:.6f})")

    elif best_ckpt.exists():
        # Weights only: optimizer, scaler, RNG and LR start fresh in this branch.
        ckpt = torch.load(best_ckpt, map_location=DEVICE, weights_only=False)
        state_dict = ckpt.get("model_state", ckpt.get("state_dict", None))
        if state_dict is not None:
            model.load_state_dict(state_dict)

        best_val_hard_dice_macro = float(ckpt.get("best_val_hard_dice_macro", -1.0))
        best_epoch               = int(ckpt.get("best_epoch", ckpt.get("epoch", 0)))
        start_epoch              = best_epoch + 1
        print(f"   üîÅ Warm-start from BEST: next epoch {start_epoch} (bestHardDice={best_val_hard_dice_macro:.4f} @epoch {best_epoch})")

    else:
        print("   üî∞ No checkpoint found ‚Üí epoch 0 baseline validation...")

        with torch.no_grad():
            val_stats0 = run_epoch_val(
                model, val_loader,
                track_by_domain=True,
                desc="Val (epoch 0)",
                ce_weight=ce_w_cfg,
                dice_weight=dice_w_cfg,
            )

        # HARD (this is the only val metric we report/save)
        val_hard_dice_macro0 = float(val_stats0["hard_dice_macro"])
        val_hard_iou_macro0  = float(val_stats0["hard_iou_macro"])

        val_hard_dice_pc0 = val_stats0.get("hard_dice_per_class", None)
        val_hard_iou_pc0  = val_stats0.get("hard_iou_per_class",  None)

        # losses
        val_loss0      = float(val_stats0["loss_total"])
        val_ce0        = float(val_stats0["loss_ce"])
        val_dice_loss0 = float(val_stats0["loss_dice"])
        val_time0      = float(val_stats0["time_sec"])

        val_dice_pc0 = val_stats0["dice_per_class"]
        val_iou_pc0  = val_stats0["iou_per_class"]
        val_dom0     = val_stats0.get("domain_stats", None)

        # print (VAL shows HARD only)
        _print_epoch_table(train_row=None, val_row={
            "epoch": 0,
            "loss_total": val_loss0, "loss_ce": val_ce0, "loss_dice": val_dice_loss0,
            "hard_dice_macro": val_hard_dice_macro0, "hard_iou_macro": val_hard_iou_macro0,
            "time_sec": val_time0
        })

        # strings for Table B
        val_dice_domain_str0, val_iou_domain_str0 = _fmt_domain_strings(val_dom0, num_classes=NUM_CLASSES) if val_dom0 else ("", "")
        val_dice_class_str0,  val_iou_class_str0  = _fmt_class_strings(val_dice_pc0, val_iou_pc0, class_names=globals().get("CLASS_NAMES", None))

        # Log epoch 0 (VAL only)
        row_val0 = {
            "phase": "val",
            "eval_domain": "ALL",
            "epoch": 0,
            "model": model_key,
            "lr": float(INIT_LR),
            "time_sec": val_time0,

            "loss_total": val_loss0,
            "loss_ce": val_ce0,
            "loss_dice": val_dice_loss0,

            "dice_domain": val_dice_domain_str0,
            "iou_domain":  val_iou_domain_str0,
            "dice_classes": val_dice_class_str0,
            "iou_classes":  val_iou_class_str0,

            # HARD columns
            "hard_dice_macro": float(val_hard_dice_macro0),
            "hard_iou_macro":  float(val_hard_iou_macro0),

            # Meta
            "ce_mode": ce_mode,
            "abl_fusion": abl_fus,
            "use_gn": use_gn,
            "base_ch": base_ch,
            "ds_on": ds_on,
            "use_bn": use_bn,

        }

        # per-class hard columns
        for ci in range(NUM_CLASSES):
            row_val0[f"hard_dice_c{ci}"] = float(val_hard_dice_pc0[ci]) if val_hard_dice_pc0 is not None else ""
            row_val0[f"hard_iou_c{ci}"]  = float(val_hard_iou_pc0[ci])  if val_hard_iou_pc0  is not None else ""

        append_log_row(model_key, row_val0)

        # initialize best based on HARD
        best_val_hard_dice_macro = float(val_hard_dice_macro0)
        best_epoch               = 0

        # current_lr is still INIT_LR here, so the snapshot records the initial LR.
        ckpt_payload0 = _checkpoint_payload(0)
        torch.save(ckpt_payload0, best_ckpt)
        torch.save(ckpt_payload0, last_ckpt)
        print(f"   Saved epoch-0 BEST to: {best_ckpt}")
        print(f"   Saved epoch-0 LAST to: {last_ckpt}")

        start_epoch = 1
        current_lr  = float(INIT_LR)

    # ---------- Epoch loop ----------
    max_epoch = start_epoch + int(extra_epochs)
    print(f"\n   Will run epochs {start_epoch}..{max_epoch-1} (extra {extra_epochs} epochs)")

    def _set_epoch_schedule(epoch: int):
        """Prepare dataset epoch, sampler seed and (if enabled) the fallback aug seed."""
        # dataset epoch hook (best effort)
        try:
            if hasattr(train_dataset, "set_epoch"):
                train_dataset.set_epoch(int(epoch))
        except Exception:
            pass

        # deterministic shuffle (fixed across epochs): the sampler generator is
        # reseeded with the same GLOBAL_SEED before every epoch.
        try:
            if train_generator is not None:
                train_generator.manual_seed(int(GLOBAL_SEED))
        except Exception:
            pass

        # fallback epoch-wise augmentation if dataset isn't epoch-aware yet
        if force_epoch_aug:
            try:
                train_dataset.base_seed = int(GLOBAL_SEED) + int(epoch) * int(AUG_BIG)
            except Exception:
                pass

    for epoch in range(start_epoch, max_epoch):
        print(f"\nüìò Epoch {epoch:03d}  (LR={current_lr:.6f})")

        _set_epoch_schedule(epoch)

        # ---- TRAIN ----
        train_stats = run_epoch_train(
            model, train_loader, optimizer,
            track_by_domain=True,
            desc=f"Train {epoch:03d}",
            ce_weight=ce_w_cfg,
            dice_weight=dice_w_cfg,
        )

        train_row = {
            "phase": "train",
            "eval_domain": "ALL",
            "epoch": int(epoch),
            "model": model_key,
            "lr": float(current_lr),
            "time_sec": float(train_stats["time_sec"]),

            "loss_total": float(train_stats["loss_total"]),
            "loss_ce":    float(train_stats["loss_ce"]),
            "loss_dice":  float(train_stats["loss_dice"]),

            # placeholders; filled in below from train_stats
            "dice_domain": "",
            "iou_domain":  "",
            "dice_classes": "",
            "iou_classes":  "",

            # HARD metrics (train also logs hard)
            "hard_dice_macro": float(train_stats["hard_dice_macro"]),
            "hard_iou_macro":  float(train_stats["hard_iou_macro"]),

            "ce_mode": ce_mode,
            "abl_fusion": abl_fus,
            "use_bn": use_bn,
            "use_gn": use_gn,
            "base_ch": base_ch,
            "ds_on": ds_on,

        }
        for ci in range(NUM_CLASSES):
            train_row[f"hard_dice_c{ci}"] = float(train_stats["hard_dice_per_class"][ci])
            train_row[f"hard_iou_c{ci}"]  = float(train_stats["hard_iou_per_class"][ci])

        train_dom = train_stats.get("domain_stats", None)
        if train_dom:
            d_dom_str, i_dom_str = _fmt_domain_strings(train_dom, num_classes=NUM_CLASSES)
            train_row["dice_domain"] = d_dom_str
            train_row["iou_domain"]  = i_dom_str

        d_cls_str, i_cls_str = _fmt_class_strings(
            train_stats["dice_per_class"], train_stats["iou_per_class"],
            class_names=globals().get("CLASS_NAMES", None)
        )
        train_row["dice_classes"] = d_cls_str
        train_row["iou_classes"]  = i_cls_str

        # ---- VAL ----
        with torch.no_grad():
            val_stats = run_epoch_val(
                model, val_loader,
                track_by_domain=True,
                desc=f"Val {epoch:03d}",
                ce_weight=ce_w_cfg,
                dice_weight=dice_w_cfg,
            )

        val_hard_dice_macro = float(val_stats["hard_dice_macro"])
        val_hard_iou_macro  = float(val_stats["hard_iou_macro"])

        val_hard_dice_pc    = val_stats.get("hard_dice_per_class", None)
        val_hard_iou_pc     = val_stats.get("hard_iou_per_class",  None)

        val_row = {
            "phase": "val",
            "eval_domain": "ALL",
            "epoch": int(epoch),
            "model": model_key,
            "lr": float(current_lr),
            "time_sec": float(val_stats["time_sec"]),

            "loss_total": float(val_stats["loss_total"]),
            "loss_ce":    float(val_stats["loss_ce"]),
            "loss_dice":  float(val_stats["loss_dice"]),

            "dice_domain": "",
            "iou_domain":  "",
            "dice_classes": "",
            "iou_classes":  "",

            # HARD columns
            "hard_dice_macro": float(val_hard_dice_macro),
            "hard_iou_macro":  float(val_hard_iou_macro),

            "ce_mode": ce_mode,
            "abl_fusion": abl_fus,
            "use_bn": use_bn,
            "use_gn": use_gn,
            "base_ch": base_ch,
            "ds_on": ds_on,
        }

        for ci in range(NUM_CLASSES):
            val_row[f"hard_dice_c{ci}"] = float(val_hard_dice_pc[ci]) if val_hard_dice_pc is not None else ""
            val_row[f"hard_iou_c{ci}"]  = float(val_hard_iou_pc[ci])  if val_hard_iou_pc  is not None else ""

        val_dom = val_stats.get("domain_stats", None)
        if val_dom:
            d_dom_str, i_dom_str = _fmt_domain_strings(val_dom, num_classes=NUM_CLASSES)
            val_row["dice_domain"] = d_dom_str
            val_row["iou_domain"]  = i_dom_str

        d_cls_str, i_cls_str = _fmt_class_strings(
            val_stats["dice_per_class"], val_stats["iou_per_class"],
            class_names=globals().get("CLASS_NAMES", None)
        )
        val_row["dice_classes"] = d_cls_str
        val_row["iou_classes"]  = i_cls_str

        # ---- Print (same rows that go to XL/CSV) ----
        _print_epoch_table(train_row=train_row, val_row=val_row)

        # ---- XL/CSV write ----
        append_log_row(model_key, train_row)
        append_log_row(model_key, val_row)

        # Best checkpoint selection: HARD val dice macro (1e-6 margin ignores float noise)
        improved = val_hard_dice_macro > best_val_hard_dice_macro + 1e-6

        if improved:
            best_val_hard_dice_macro = float(val_hard_dice_macro)
            best_epoch               = int(epoch)
            no_improve_streak        = 0
            no_improve_for_lr        = 0

            torch.save(_checkpoint_payload(epoch), best_ckpt)
            print(f"   ‚úÖ New BEST (valHardDice={val_hard_dice_macro:.4f}) ‚Üí {best_ckpt}")
        else:
            no_improve_streak += 1
            no_improve_for_lr += 1
            print(f"   ‚ö†Ô∏è  No improvement. Streak={no_improve_streak} (LR-streak={no_improve_for_lr})")

        # LR decay — done BEFORE writing LAST so the checkpoint (current_lr,
        # no_improve_for_lr and the optimizer's param-group lr) reflects what the
        # next epoch will use; otherwise a resumed run repeats an epoch at the
        # old LR that an uninterrupted run would have trained at the decayed LR.
        if no_improve_for_lr >= LR_PATIENCE_EPOCHS and current_lr > MIN_LR + 1e-12:
            new_lr = max(current_lr / LR_DECAY_FACTOR, MIN_LR)
            if new_lr < current_lr:
                current_lr = float(new_lr)
                for g in optimizer.param_groups:
                    g["lr"] = current_lr
                no_improve_for_lr = 0
                print(f"   üîª LR decayed ‚Üí {current_lr:.6f}")

        # Always save LAST checkpoint (with optimizer/scaler)
        torch.save(_checkpoint_payload(epoch), last_ckpt)
        print(f"   üíæ LAST saved ‚Üí {last_ckpt}")

        # Early stop
        if no_improve_streak >= EARLY_STOP_PATIENCE:
            print(f"\nüõë Early stopping: no improvement for {no_improve_streak} epochs.")
            break

    print("\n=== Training finished ===")
    print(f"   Best val HARD Dice (macro) = {best_val_hard_dice_macro:.4f} @ epoch {best_epoch}")
    print(f"   BEST ckpt: {best_ckpt}")
    print(f"   LAST ckpt: {last_ckpt}")
    print(f"   Logs: {paths['log_csv']} & {paths['log_xlsx']}")


# ---------- Interactive selection ----------
# Asks which model(s) to train and for how many extra epochs, validates both
# answers, then runs run_training_for_model once per selected model.
print("\n================ MODEL MENU ================")
print("  1) UNet")
print("  2) UNet3+ (official-style)")
print("  3) SS-UNet3+ Lite V1  (EnhSemiFull)")
print("  4) SS-UNet3+ Lite V2  (4-Branch)")
print("  5) SS-UNet3+ Lite V3  (Baseline)")
print("  9) ALL (train all 5 models one by one)")
print("===========================================\n")

choice = input("üëâ Choose model (1/2/3/4/5 or 9=ALL): ").strip()

choice_map = {
    "1": "unet",
    "2": "unet3plus",
    "3": "sslite_v1",
    "4": "sslite_v2",
    "5": "sslite_v3",
    "9": "ALL",
}

# Fail with an actionable message instead of a bare KeyError on a typo.
if choice not in choice_map:
    raise ValueError(
        f"Invalid model choice {choice!r}; expected one of: {', '.join(choice_map)}."
    )

if choice_map[choice] == "ALL":
    # Derived from the menu mapping so the two can never drift apart.
    selected_models = [key for key in choice_map.values() if key != "ALL"]
else:
    selected_models = [choice_map[choice]]

_extra_epochs_raw = input("‚è±  Extra epochs to train (add on top of previous)? (e.g. 10): ").strip()
try:
    extra_epochs = int(_extra_epochs_raw)
except ValueError:
    raise ValueError(
        f"Extra epochs must be a whole number, got {_extra_epochs_raw!r}."
    ) from None
if extra_epochs < 0:
    raise ValueError(f"Extra epochs must be >= 0, got {extra_epochs}.")

print("\nüìå Will train models:", selected_models)
print(f"   Extra epochs = {extra_epochs}, initial LR = {INIT_LR}")
print(f"   AMP enabled?  {USE_AMP}")

for mk in selected_models:
    run_training_for_model(mk, extra_epochs=extra_epochs)

print("\nüéâ All requested trainings are complete.")


* Download the results folder from the VM


In [None]:
# ================================
# Package the results folder as a ZIP and download it (relies on Cell-1 variables)
# ================================
import os
import shutil
from pathlib import Path
from google.colab import files

# --- These names must already have been defined by Cell 1 ---
for _required_name in ("RESULTS_ROOT", "RESULTS_FOLDER_NAME", "SAVE_TO_DRIVE"):
    assert _required_name in globals(), f"{_required_name} not found. Run Cell 1 first."
del _required_name

# Normalize types in case Cell 1 stored them differently.
RESULTS_FOLDER_NAME = str(RESULTS_FOLDER_NAME)
RESULTS_ROOT = Path(RESULTS_ROOT)

# With Option 2, SAVE_TO_DRIVE is False, but this cell works in both cases.
print("SAVE_TO_DRIVE:", SAVE_TO_DRIVE)
print("RESULTS_FOLDER_NAME:", RESULTS_FOLDER_NAME)
print("RESULTS_ROOT:", RESULTS_ROOT)

assert RESULTS_ROOT.exists() and RESULTS_ROOT.is_dir(), f"Results folder not found: {RESULTS_ROOT}"

# The archive is written to /content/<foldername>.zip so Colab can download it easily.
# Archiving relative to the parent keeps the results folder as the top-level entry.
zip_base = Path("/content", RESULTS_FOLDER_NAME)
zip_path = shutil.make_archive(
    str(zip_base),
    "zip",
    root_dir=str(RESULTS_ROOT.parent),
    base_dir=str(RESULTS_ROOT.name),
)

print("üì¶ Zipped to:", zip_path)

# Trigger the browser download
files.download(zip_path)
