# Project: Brain Tumor Segmentation and Classification 

## Details of Dataset Creation

#### [Data Source - The Cancer Imaging Archive - TCIA](https://www.cancerimagingarchive.net/browse-collections)

- **UCSF-PDGM:** Glioblastoma - 495
- **BRATS-AFRICA:** Glioma - 95
- **MU-Glioma-Post:** Glioma - 203
- **UCSD-PTGBM:** Glioblastoma - 178
- **UPENN-GBM:** Glioblastoma - 630
- **BCBM-RadioGenomics:** Brain Mets - 165
- **Pretreat-MetsToBrain-Masks:** Brain Mets - 200

### Segmentation Dataset
- **Numbers:** *495+95+203+178+630+165+200 = 1966* cases in total, but both the required segmentation and the required MRI sequence were available for only **1388** of them.
- **Segmentation:** A single mask covering the tumor core plus the enhancing area. Where a collection provided two separate masks (e.g. tumor core and enhancing area), they were merged into one mask (tumor with enhancing area).
- **MRI sequence used:** *T1 post-contrast (T1c)*
- A *dataset.json* file was created containing, for each case, the file paths of the MRI sequence and the segmentation, plus other dataset details.

### Classification dataset
- **Numbers:** *Training (972)* – Gliomas: 647, Brain Mets: 325; *Validation (209)* – Gliomas: 139, Brain Mets: 70; *Test (207)* – Gliomas: 138, Brain Mets: 69 (total 1388, i.e. every case of the segmentation dataset)
- Files *train.csv*, *val.csv* and *test.csv* were created; each row holds the class label, image path, segmentation path and case_id of one case.

## Importing Libraries used in this project

In [None]:
import os, math, time, json, random, csv, hashlib, platform, subprocess
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
import nibabel as nib

from monai.config import print_config
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.inferers import SlidingWindowInferer
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.networks.nets import DynUNet
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    RandFlipd,
    RandRotate90d,
    RandAffined,
    AsDiscreted,
    CastToTyped,
)
from monai.utils import set_determinism

# Print MONAI's configuration report (library/dependency versions) into the notebook output.
print_config()

## Experiment Setup

In [None]:
# Reproducibility: seed through MONAI's helper (see monai.utils.set_determinism docs for
# exactly which RNGs / backend flags it configures).
SEED = 42
set_determinism(SEED)

# Device: first CUDA device when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    try:
        gpu_name = torch.cuda.get_device_name(0)
    except Exception:
        # The name is only cosmetic; never fail setup because of it.
        gpu_name = "Unknown CUDA device"
    print(f"Device: {device} ({gpu_name})")
else:
    print(f"Device: {device}")

# Paths (machine-specific; adjust PROJ_ROOT on another host).
PROJ_ROOT = Path("/home/ant/projects/brain_tumor_segmentation")
DUALTASK_ROOT = PROJ_ROOT / "derived" / "unified_dualtask"
TRAIN_CSV = DUALTASK_ROOT / "train.csv"
VAL_CSV = DUALTASK_ROOT / "val.csv"
TEST_CSV = DUALTASK_ROOT / "test.csv"

# Fail fast with a clear message. An explicit raise (unlike `assert`) is not
# stripped when Python runs with -O.
missing = [p for p in (DUALTASK_ROOT, TRAIN_CSV, VAL_CSV, TEST_CSV) if not p.exists()]
if missing:
    raise FileNotFoundError(f"Missing required paths: {', '.join(str(p) for p in missing)}")

# Target voxel spacing for resampling (units follow the image affine, presumably mm).
TARGET_SPACING = (0.8, 0.8, 1.0)
# Training crop size and sliding-window ROI, in voxels after resampling.
PATCH_SIZE = (192, 192, 160)
PATCH_OVERLAP = 0.5  # sliding window overlap fraction

# Checkpoint directory
ckpt_dir = PROJ_ROOT / "runs" / "dualtask_monai_v01"
ckpt_dir.mkdir(parents=True, exist_ok=True)
print("Checkpoint dir:", ckpt_dir)

## Run/Session paths and CSV helper

In [None]:
# One timestamped sub-directory per notebook session, e.g. "20240131-154500".
RUN_ID = f"{datetime.now():%Y%m%d-%H%M%S}"
RUN_DIR = ckpt_dir / RUN_ID
RUN_DIR.mkdir(parents=True, exist_ok=True)

# Artifact files written inside this session's run directory.
METRICS_CSV, CONFIG_JSON, ENV_JSON, QC_CSV = (
    RUN_DIR / file_name
    for file_name in ("metrics.csv", "config.json", "env.json", "qc_epoch_summary.csv")
)

def write_csv_header(path: Path, header: List[str]):
    """
    Open a CSV for appending and write ``header`` if the file is new or empty.

    Args:
        path: CSV location (Path or str); its parent directory must already exist.
        header: column names, written as the first row of a new or zero-byte file.

    Returns:
        (file_handle, csv_writer). The caller owns file_handle and must close it.
    """
    path = Path(path)
    # A zero-byte file (e.g. left by a run that crashed before writing anything)
    # also needs a header; an existence check alone would leave it header-less.
    needs_header = not path.exists() or path.stat().st_size == 0
    f = open(path, "a", newline="", encoding="utf-8")
    try:
        w = csv.writer(f)
        if needs_header:
            w.writerow(header)
            f.flush()
            # Push the header to disk right away so it survives a crash mid-training.
            os.fsync(f.fileno())
    except BaseException:
        # Do not leak the handle if writing the header fails.
        f.close()
        raise
    return f, w

In [None]:
# CSV → list[dict] helpers with validation

from typing import Any

def read_unified_csv(
    path: Path,
    check_files: bool = True,
    drop_invalid_labels: bool = True,
) -> List[Dict[str, Any]]:
    """
    Load one split CSV into a list of MONAI-style dicts.

    Required columns: case_id, class_label, image_path, label_path (others are ignored).

    Args:
        path: CSV file (Path or str).
        check_files: drop rows whose image or label path is empty or does not exist.
        drop_invalid_labels: drop rows whose class_label is missing or not numeric.
            When False, such rows are kept with the sentinel class_label -1.

    Returns:
        [{"case_id", "image", "label", "class_label"}, ...]. Strings are whitespace-stripped
        and class_label is a Python int.

    Raises:
        ValueError: if a required column is missing.
    """
    path = Path(path)
    df = pd.read_csv(path)
    expected_cols = {"case_id", "class_label", "image_path", "label_path"}
    missing = expected_cols - set(df.columns)
    if missing:
        raise ValueError(f"Missing columns in {path.name}: {sorted(missing)}")

    # Clean and coerce types (on a copy, so the parsed frame is not modified in place).
    df = df.copy()
    for col in ("case_id", "image_path", "label_path"):
        df[col] = df[col].astype(str).str.strip()
    numeric_labels = pd.to_numeric(df["class_label"], errors="coerce")
    invalid_label = numeric_labels.isna()
    df["class_label"] = numeric_labels.fillna(-1).astype(int)

    # A -1 sentinel would otherwise be used as a classification target during training.
    if drop_invalid_labels and invalid_label.any():
        n_bad = int(invalid_label.sum())
        print(f"[WARN] {n_bad} rows dropped from {path.name} due to missing/invalid class_label")
        df = df.loc[~invalid_label]

    # Optional: drop rows whose files don't exist
    if check_files:
        def _exists(p: str) -> bool:
            # Path("") means the current directory, which always exists: treat it as missing.
            return bool(p) and Path(p).exists()

        ok = df["image_path"].map(_exists).astype(bool) & df["label_path"].map(_exists).astype(bool)
        bad_mask = ~ok
        if bad_mask.any():
            n_bad = int(bad_mask.sum())
            print(f"[WARN] {n_bad} rows dropped from {path.name} due to missing files")
            df = df.loc[~bad_mask]

    return [
        {"case_id": cid, "image": img, "label": lab, "class_label": int(cls)}
        for cid, img, lab, cls in zip(
            df["case_id"], df["image_path"], df["label_path"], df["class_label"]
        )
    ]

# Materialize the three splits in train -> val -> test order.
train_items, val_items, test_items = (
    read_unified_csv(split_csv) for split_csv in (TRAIN_CSV, VAL_CSV, TEST_CSV)
)

# Cell output: number of usable cases per split.
tuple(len(split) for split in (train_items, val_items, test_items))

## Morphology utilities and QC counters

In [None]:
import scipy.ndimage as ndi
from typing import Optional, List, Tuple

class LabelQC:
    """
    Counter for label-volume shrinkage introduced by label post-processing.

    Args:
        shrink_warn_threshold: tolerated fractional volume loss. With 0.35 a sample is
            flagged when after/before < 0.65.
        verbose: print one line for every flagged sample.

    Attributes (all reset by ``reset()``):
        total: samples seen, including ones whose original label was empty.
        warn: samples flagged as shrunk.
        flagged_examples: (case_id, before, after, ratio) for each flagged sample.
    """

    def __init__(self, shrink_warn_threshold: float = 0.35, verbose: bool = True):
        self.shrink_warn_threshold = float(shrink_warn_threshold)
        self.verbose = bool(verbose)
        self.total: int = 0
        self.warn: int = 0
        self.flagged_examples: List[Tuple[str, int, int, float]] = []

    def reset(self) -> None:
        """Zero both counters and clear the flagged list in place."""
        self.total = self.warn = 0
        self.flagged_examples.clear()

    @property
    def rate(self) -> float:
        """Flagged fraction of all samples seen (denominator floored at 1)."""
        return self.warn / max(1, self.total)

    def update(self, before_voxels: int, after_voxels: int, case_id: str) -> None:
        """Record one sample; an empty original label is counted but can never be flagged."""
        self.total += 1
        if before_voxels > 0:
            ratio = (after_voxels + 1e-6) / (before_voxels + 1e-6)
            min_allowed_ratio = 1.0 - self.shrink_warn_threshold
            if ratio < min_allowed_ratio:
                self.warn += 1
                record = (case_id, int(before_voxels), int(after_voxels), float(ratio))
                self.flagged_examples.append(record)
                if self.verbose:
                    print(f"[QC] label shrinkage: {case_id} before={before_voxels} after={after_voxels} ratio={ratio:.3f}")

    def summary(self) -> None:
        """Print "warn/total" and the rate."""
        print(f"[QC] shrinkage warnings: {self.warn}/{self.total} (rate={self.rate:.4f})")

def binary_dilate_then_erode(mask: np.ndarray, radius_vox: int = 1, connectivity: int = 1) -> np.ndarray:
    """
    Light binary morphological closing (dilate, then erode) to protect thin or
    fragmented labels.

    The mask is zero-padded by ``radius_vox`` voxels before the closing and cropped
    back afterwards. scipy's erosion treats voxels outside the array as background,
    so without padding, foreground touching the array border would be eroded away.
    That case is common for labels cut by random training crops. With padding the
    result always contains every input foreground voxel.

    Args:
        mask: 3D array (D, H, W); nonzero voxels are foreground.
        radius_vox: iterations of dilation and of erosion; <= 0 is a no-op (binarize only).
        connectivity: 1 (faces), 2 (faces+edges) or 3 (faces+edges+corners).

    Returns:
        uint8 mask (0/1) with the same shape as ``mask``.

    Raises:
        ValueError: if ``mask`` is not 3D (checked only when radius_vox > 0).
    """
    if radius_vox <= 0:
        return (mask > 0).astype(np.uint8)
    if mask.ndim != 3:
        raise ValueError(f"Expected 3D mask, got shape {mask.shape}")

    r = int(radius_vox)
    structure = ndi.generate_binary_structure(rank=3, connectivity=int(connectivity))
    # Each iteration moves at most one voxel per axis, so r voxels of padding hold the
    # full dilation and keep the border-induced erosion out of the original region.
    padded = np.pad(mask > 0, r, mode="constant", constant_values=False)
    dil = ndi.binary_dilation(padded, structure=structure, iterations=r)
    ero = ndi.binary_erosion(dil, structure=structure, iterations=r)
    return ero[r:-r, r:-r, r:-r].astype(np.uint8)

## Transforms and Label Handling

In [None]:
# Transforms: spacing standardization and intensity scale
# Labels: post-process after Spacingd to preserve small lesions (no extra soft resampling)

from monai.transforms import MapTransform

class LabelPostProcessd(MapTransform):
    """
    Post-process d["label"] after the spatial transforms.

    - If the label's spatial shape differs from ``d[ref_key]``, resample it onto that grid
      with nearest-neighbour interpolation (values stay discrete, no smoothing).
    - If ``morph_radius > 0``, apply ``binary_dilate_then_erode`` to the first label
      channel; the result is a 0/1 long tensor of shape [1, D, H, W].
    - Record foreground voxel counts before/after as ``qc_before_vox`` / ``qc_after_vox``.

    Only d["label"] is processed; the ``keys`` argument is stored by MapTransform but not
    consulted here.
    """
    def __init__(self, keys, ref_key: str = "image", morph_radius: int = 1, allow_missing_keys: bool = False):
        super().__init__(keys, allow_missing_keys)
        self.ref_key = ref_key
        self.morph_radius = int(morph_radius)

    def __call__(self, data):
        d = dict(data)
        if "label" not in d:
            return d

        label = d["label"]  # assumed channel-first [C, D, H, W] with C=1 — TODO confirm
        img = d.get(self.ref_key, None)
        before_vox = int((label > 0).sum().item())

        # Match label to image grid if needed (nearest to avoid smoothing)
        if img is not None and label.shape[1:] != img.shape[1:]:
            # F.interpolate expects (N, C, *spatial). Given a 4-D [C, D, H, W] tensor it
            # assumes 2-D spatial dims and rejects the 3-element `size`, so add a batch
            # dim for the call and remove it afterwards.
            label = torch.nn.functional.interpolate(
                label.float().unsqueeze(0), size=tuple(img.shape[1:]), mode="nearest"
            ).squeeze(0).long()

        # Optional light morphology
        if self.morph_radius > 0:
            arr = (label > 0).cpu().numpy().astype(np.uint8)  # [C, D, H, W]
            arr = binary_dilate_then_erode(arr[0], radius_vox=self.morph_radius)[None]  # back to [1, ...]
            # NOTE(review): a plain tensor is rebuilt from NumPy here, so metadata attached
            # to the original label object is presumably dropped — confirm it is not needed.
            label = torch.as_tensor(arr, dtype=torch.long, device=d["label"].device)

        after_vox = int((label > 0).sum().item())

        d["label"] = label
        d["qc_before_vox"] = before_vox
        d["qc_after_vox"] = after_vox
        return d


# Common I/O and spacing, shared by the train and val pipelines.
# Steps: load image+label -> channel-first -> float32 tensors -> RAS orientation ->
# resample to TARGET_SPACING. Spacingd modes are per key, in `keys` order:
# image "bilinear", label "nearest" (nearest keeps label values discrete, no smoothing).
common_load = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    EnsureTyped(keys=["image", "label"], dtype=torch.float32),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=TARGET_SPACING, mode=("bilinear", "nearest")),
]

from monai.transforms import RandCropByPosNegLabeld, SpatialPadd

# Training pipeline tail: intensity scaling, random spatial augmentation, patch sampling.
intensity_train = [
    # Fixed window: raw intensities in [0, 3000] mapped to [0, 1]; values outside are clipped.
    # NOTE(review): the source collections may use different T1c intensity scales; confirm
    # one fixed window suits all of them (per-volume normalization is an alternative).
    ScaleIntensityRanged(keys=["image"], a_min=0, a_max=3000, b_min=0.0, b_max=1.0, clip=True),
    # Random flip over spatial axes 0, 1, 2 with p=0.2 (image and label share the draw).
    RandFlipd(keys=["image", "label"], spatial_axis=[0, 1, 2], prob=0.2),
    # Random 90-degree rotation, k up to 3, default rotation plane, p=0.2.
    RandRotate90d(keys=["image", "label"], prob=0.2, max_k=3),
    # Small random affine on the whole volume (before cropping): rotation up to
    # pi/36 (= 5 degrees) per axis, scale +/-10%; labels resampled with nearest.
    # NOTE(review): padding_mode is left at MONAI's default; confirm it suits the label key.
    RandAffined(
        keys=["image", "label"],
        rotate_range=(math.pi/36, math.pi/36, math.pi/36),
        scale_range=(0.1, 0.1, 0.1),
        mode=("bilinear", "nearest"),
        prob=0.2,
    ),
    # Guarantee each spatial dim is at least PATCH_SIZE before cropping.
    SpatialPadd(keys=["image", "label"], spatial_size=PATCH_SIZE),
    # One PATCH_SIZE crop per volume; pos=1 / neg=1 balances label-centred and
    # background-centred crops.
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=PATCH_SIZE,
        pos=1, neg=1, num_samples=1, image_key="image",
        allow_smaller=True,
    ),
]

# Validation pipeline tail: the same fixed intensity window as training, with no
# augmentation and no cropping (whole volumes go to the sliding-window inferer).
intensity_val = [
    ScaleIntensityRanged(keys=["image"], a_min=0, a_max=3000, b_min=0.0, b_max=1.0, clip=True),
]

# Cast classification label(s) to float32 tensors.
class CastClassLabeld(MapTransform):
    """
    Convert the value stored under each configured key (e.g. "class_label") to a
    float32 tensor. Keys absent from the sample are skipped.
    """
    def __init__(self, keys, allow_missing_keys=False):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        d = dict(data)
        # Honour the configured keys instead of a hard-coded "class_label".
        for key in self.keys:
            if key in d:
                d[key] = torch.as_tensor(d[key], dtype=torch.float32)
        return d

# Assemble transforms
# - LabelPostProcessd is the last label transform, so in the train pipeline it runs on the
#   cropped patch (its QC voxel counts are therefore per patch).
# - Post label process uses nearest resample if shapes differ (no soft one-hot resample)
# - Morph radius=1; raise to 2 if QC indicates too many tiny lesions vanish
# - Both pipelines share the same post_label_preserve list (same LabelPostProcessd instance).
# NOTE(review): val_transforms therefore also apply the morphological closing, so validation
#   Dice/HD95 are measured against post-processed labels, not the raw annotations. A separate
#   LabelPostProcessd(..., morph_radius=0) for val/test would avoid this, but best_val_dice
#   values stored in existing checkpoints were computed with the closing applied.
post_label_preserve = [LabelPostProcessd(keys=["label"], ref_key="image", morph_radius=1)]

train_transforms = Compose(common_load + intensity_train + post_label_preserve + [CastClassLabeld(keys=["class_label"])])
val_transforms = Compose(common_load + intensity_val + post_label_preserve + [CastClassLabeld(keys=["class_label"])])

## Datasets and Loaders (with reproducible workers)

In [None]:
# Shrinkage counters for label post-processing: one for train, one for val
# (fed per sample from the training/validation loops).
qc_train, qc_val = (LabelQC(shrink_warn_threshold=0.35, verbose=True) for _ in range(2))

# Reproducible dataloader workers
def seed_worker(worker_id: int):
    """
    DataLoader ``worker_init_fn``: deterministically seed every RNG a worker uses.

    Python ``random``, NumPy's global RNG and torch are seeded with ``SEED + worker_id``.
    The worker's copy of the dataset transform is then re-seeded as well.

    Why the second step: a custom ``worker_init_fn`` replaces MONAI DataLoader's default
    one, which re-seeds randomizable transforms per worker. Those transforms keep their
    own RandomState, which the global seeds above do not touch. Without re-seeding, every
    worker starts from an identical copy of those states and draws identical augmentation
    parameters.
    """
    worker_seed = SEED + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)

    info = torch.utils.data.get_worker_info()
    if info is None:  # not running inside a DataLoader worker
        return
    transform = getattr(info.dataset, "transform", None)
    set_state = getattr(transform, "set_random_state", None)
    if callable(set_state):
        set_state(seed=worker_seed % (2**32))

# Pin host memory only when batches will be copied to a GPU.
use_pin = device.type == "cuda"
# One seeded generator, passed to both loaders, keeps shuffling/worker seeding reproducible.
g = torch.Generator()
g.manual_seed(SEED)

# cache_rate=0.0 caches nothing (every item is transformed on each access);
# raise it to trade RAM for speed.
train_ds = CacheDataset(data=train_items, transform=train_transforms, cache_rate=0.0, num_workers=0)
val_ds = CacheDataset(data=val_items, transform=val_transforms, cache_rate=0.0, num_workers=0)

# Both loaders share every setting except the dataset, shuffling and worker count.
train_loader, val_loader = (
    DataLoader(
        dataset,
        batch_size=1,
        shuffle=do_shuffle,
        num_workers=n_workers,
        pin_memory=use_pin,
        persistent_workers=True,
        worker_init_fn=seed_worker,
        generator=g,
        prefetch_factor=2,
    )
    for dataset, do_shuffle, n_workers in ((train_ds, True, 2), (val_ds, False, 1))
)

## Model: DynUNet backbone + classification head from encoder bottleneck

In [None]:
# Segmentation network: nnU-Net-style DynUNet with 6 resolution stages.
seg_net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,                    # 2 output channels (presumably background / tumour)
    kernel_size=[3] * 6,               # one 3x3x3 kernel per stage
    strides=[1] + [2] * 5,             # full-resolution first stage, then five 2x downsamplings
    upsample_kernel_size=[2] * 5,      # one per decoder upsampling (len(strides) - 1)
    norm_name="instance",
    deep_supervision=False,
).to(device)

# Classification head (lazy: initializes fully-connected layer on first forward)
class LazyClassificationHead(nn.Module):
    """
    Global-average-pool + linear classification head.

    Args:
        num_classes: number of output logits (1 for binary BCE-with-logits).
        in_features: channel count of the incoming feature map. When given, the linear
            layer is created immediately, so its parameters exist when an optimizer is built
            from ``.parameters()``. When None (default), the layer is created on the first
            forward pass. An optimizer built before that forward will NOT contain the
            layer's parameters, so run one forward first.

    Input: 5-D feature map (B, C, D, H, W). Output: logits of shape (B, num_classes).
    """
    def __init__(self, num_classes: int = 1, in_features: Optional[int] = None):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(int(in_features), num_classes) if in_features is not None else None
        self.num_classes = num_classes

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        x = self.pool(feat).flatten(1)  # (B, C)
        if self.fc is None:
            # Lazy path: size the layer from the first input's channel count.
            self.fc = nn.Linear(x.shape[1], self.num_classes).to(x.device)
        return self.fc(x)

# Classification head; its linear layer is built lazily by the first forward pass
# (materialized below, before the optimizer is created).
cls_head = LazyClassificationHead(num_classes=1).to(device)

# Forward hook: stash the hooked module's output for the classification head.
encoder_feat = {"x": None}
def hook_fn(module, input, output):
    encoder_feat["x"] = output

# Attach hook to a stable location in DynUNet
if hasattr(seg_net, "bottleneck"):
    seg_net.bottleneck.register_forward_hook(hook_fn)
elif hasattr(seg_net, "encoder4"):
    seg_net.encoder4.register_forward_hook(hook_fn)
else:
    print("[WARN] Could not attach hook; classification head may not receive features")

# Materialize cls_head.fc BEFORE building the optimizer. An optimizer only tracks the
# parameters that exist when it is constructed. If fc were created later (at the first
# training step), the classifier weights would never be updated.
# Fallback when no features were captured: feed the segmentation output, as the training
# loop does.
with torch.no_grad():
    _was_training = seg_net.training
    seg_net.eval()
    encoder_feat["x"] = None
    _probe_out = seg_net(torch.zeros(1, 1, 64, 64, 64, device=device))
    cls_head(encoder_feat["x"] if encoder_feat["x"] is not None else _probe_out)
    seg_net.train(_was_training)
    encoder_feat["x"] = None
    del _probe_out

# Losses and optimizer
seg_loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
cls_loss_fn = nn.BCEWithLogitsLoss()

params = list(seg_net.parameters()) + list(cls_head.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-5, weight_decay=1e-5)
# NOTE: optimizer states saved before this fix cover the seg parameters only; loading one
# into this optimizer (one extra weight/bias pair) is expected to fail, and the resume
# code then keeps a fresh optimizer state.

# AMP scaler
scaler = GradScaler(enabled=torch.cuda.is_available())

# Segmentation metrics
dice_metric = DiceMetric(include_background=False, reduction="mean")

## Helper: pad a tensor up to the next multiple of `factor` in each spatial dim (+ crop-back helper)

In [None]:
import torch.nn.functional as F

def pad_to_factor(x: torch.Tensor, factor=32, return_pad: bool = False, mode: str = "constant", value: float = 0.0):
    """
    Pad a (B, C, D, H, W) tensor at the high end of each spatial axis so that D, H and W
    become multiples of ``factor``.

    Args:
        x: 5-D tensor.
        factor: one int for all axes, or a 3-sequence (fD, fH, fW).
        return_pad: also return the F.pad tuple (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi).
            The *_lo entries are always 0, so voxel coordinates do not shift.
        mode, value: forwarded to ``F.pad``.

    Returns:
        The padded tensor, or ``(padded, pad)`` when ``return_pad`` is True.
    """
    assert x.dim() == 5, f"Expected 5D tensor (B,C,D,H,W), got shape {tuple(x.shape)}"
    if isinstance(factor, int):
        per_axis = (factor,) * 3
    else:
        assert len(factor) == 3, "factor must be int or 3-tuple"
        per_axis = tuple(factor)

    # Voxels needed on each spatial axis (D, H, W) to reach the next multiple.
    extra_d, extra_h, extra_w = (
        (size + f - 1) // f * f - size for size, f in zip(x.shape[2:], per_axis)
    )
    # F.pad lists the LAST dimension first.
    pad = (0, extra_w, 0, extra_h, 0, extra_d)
    if max(pad) > 0:
        x = F.pad(x, pad, mode=mode, value=value)

    if return_pad:
        return x, pad
    return x

def crop_to_shape(x: torch.Tensor, shape: tuple) -> torch.Tensor:
    """
    Keep the leading (D, H, W) = ``shape`` voxels along each spatial axis of a
    (B, C, D, H, W) tensor. This undoes high-end padding. Sizes larger than the
    tensor are clamped by slicing semantics (nothing is padded).
    """
    assert x.dim() == 5 and len(shape) == 3, "x must be 5D and shape must be (D,H,W)"
    depth, height, width = (int(s) for s in shape)
    return x[:, :, :depth, :height, :width]

## Inferer for validation/test

In [None]:
# Sliding-window inference: PATCH_SIZE windows, one window per forward pass,
# Gaussian-weighted blending where windows overlap.
inferer = SlidingWindowInferer(
    roi_size=PATCH_SIZE, sw_batch_size=1, overlap=PATCH_OVERLAP, mode="gaussian"
)

# Utils
from typing import Dict, Any, Tuple, Optional

def to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """
    Return a new dict in which every torch.Tensor value of ``batch`` has been moved to
    ``device`` (non-blocking copy). Non-tensor values are passed through as-is.
    """
    return {
        key: value.to(device, non_blocking=True) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }

def sliding_infer_padded(
    images: torch.Tensor,
    network: nn.Module,
    factor: int = 32,
    target_shape: Optional[Tuple[int, int, int]] = None,
) -> torch.Tensor:
    """
    Pad ``images`` at the high end to a multiple of ``factor``, run the module-level
    sliding-window ``inferer``, and optionally crop the logits.

    Args:
        images: (B, C, D, H, W) input batch.
        network: model evaluated by the inferer.
        factor: padding multiple for each spatial axis.
        target_shape: (D, H, W) to crop the logits to. When None, the logits are
            returned at the padded size.
    """
    padded = pad_to_factor(images, factor=factor)
    logits = inferer(inputs=padded, network=network)
    if target_shape is None or logits.shape[-3:] == tuple(target_shape):
        return logits
    return crop_to_shape(logits, target_shape)

## Setting up training configuration and saving it.

In [None]:
def file_sha1(path: Path | str) -> str | None:
    try:
        with open(path, "rb") as f:
            return hashlib.sha1(f.read()).hexdigest()
    except Exception:
        return None

# Run configuration, written to CONFIG_JSON (tuples are serialized as JSON arrays).
# NOTE(review): several entries are hand-copied literals, not read from the live objects:
#   "cls_weight" (the training loop hard-codes 0.3), the "augs" values and the "model"
#   fields. Keep them in sync when the code changes.
config = {
    "run_id": RUN_ID,
    "seed": SEED,
    "device": str(device),
    "gpu_name": (torch.cuda.get_device_name(0) if torch.cuda.is_available() else None),
    "target_spacing": tuple(TARGET_SPACING),
    "patch_size": tuple(PATCH_SIZE),
    "inferer": {"roi_size": tuple(PATCH_SIZE), "overlap": float(PATCH_OVERLAP), "mode": "gaussian"},
    "model": {
        "arch": "DynUNet",
        "in_channels": 1,
        "out_channels": 2,
        "deep_supervision": False,
        "norm_name": "instance",
    },
    "loss": {"seg": "DiceCELoss", "cls": "BCEWithLogitsLoss", "cls_weight": 0.3},
    # lr / weight_decay are read back from the optimizer's first param group.
    "optimizer": {
        "type": "AdamW",
        "lr": float(optimizer.param_groups[0]["lr"]),
        "weight_decay": float(optimizer.param_groups[0].get("weight_decay", 0.0)),
    },
    # rot_deg 5 corresponds to the pi/36 rotate_range; scale_pct 10 to scale_range 0.1.
    "augs": {
        "flip_p": 0.2, "rot90_p": 0.2,
        "affine": {"rot_deg": 5, "scale_pct": 10},
        "crop": "RandCropByPosNegLabel",
    },
    # Reflect current pipeline (nearest for labels + light morph).
    "transforms_notes": "Spacingd (image=bilinear, label=nearest), LabelPostProcessd(morph_radius=1)",
    # Split CSV paths plus content hashes (None if a file could not be read).
    "splits": {
        "train_csv": str(TRAIN_CSV), "val_csv": str(VAL_CSV), "test_csv": str(TEST_CSV),
        "train_csv_sha1": file_sha1(TRAIN_CSV), "val_csv_sha1": file_sha1(VAL_CSV), "test_csv_sha1": file_sha1(TEST_CSV),
    },
}

with open(CONFIG_JSON, "w") as f:
    json.dump(config, f, indent=2)
print("Saved config:", CONFIG_JSON)

In [None]:
# Environment snapshot (for full reproducibility)
import sys

env = {
    "python": platform.python_version(),
    "python_executable": sys.executable,
    "os": platform.platform(),
    "torch": torch.__version__,
    "torch_cuda": (torch.version.cuda if torch.cuda.is_available() else None),
    "cudnn": (torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None),
    "monai": __import__("monai").__version__,
    "numpy": np.__version__,
    "pandas": pd.__version__,
    "gpu_count": (torch.cuda.device_count() if torch.cuda.is_available() else 0),
    "gpu_name": (torch.cuda.get_device_name(0) if torch.cuda.is_available() else None),
    "seed": SEED,
}

# List all GPU names if multiple are present
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    env["gpu_names"] = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]

# Exact package set (can be large). Run pip through the kernel's own interpreter:
# a bare `pip` on PATH may belong to a different Python / virtualenv.
try:
    env["pip_freeze"] = subprocess.check_output(
        [sys.executable, "-m", "pip", "freeze"], stderr=subprocess.DEVNULL
    ).decode().splitlines()
except (OSError, subprocess.CalledProcessError) as e:
    # Best effort: record why the package list is missing instead of silently omitting it.
    env["pip_freeze_error"] = repr(e)

with open(ENV_JSON, "w") as f:
    json.dump(env, f, indent=2)
print("Saved env:", ENV_JSON)

## Resume Training: prefer most recent "last.pt", otherwise fall back to "best.pt"

In [None]:
def find_resume_ckpt() -> Path:
    """
    Pick the checkpoint to resume from.

    Preference:
      1. the most recently modified ``last.pt`` among the current RUN_DIR, every run
         sub-directory of ``ckpt_dir`` (``ckpt_dir/*/last.pt``) and ``ckpt_dir`` itself;
      2. ``RUN_DIR/best.pt``;
      3. ``ckpt_dir/best.pt``.

    The training loop writes ``last.pt`` into the session's timestamped RUN_DIR, and a new
    session starts with an empty RUN_DIR. Scanning the sibling run directories is
    therefore what lets "resume from the latest last.pt" find the previous run.

    Raises:
        FileNotFoundError: if no candidate exists.
    """
    run_dir = globals().get("RUN_DIR")

    last_candidates: List[Path] = list(ckpt_dir.glob("*/last.pt")) + [ckpt_dir / "last.pt"]
    if run_dir is not None:
        last_candidates.append(Path(run_dir) / "last.pt")
    existing_last = [p for p in last_candidates if p.is_file()]
    if existing_last:
        return max(existing_last, key=lambda p: p.stat().st_mtime)

    best_candidates: List[Path] = [Path(run_dir) / "best.pt"] if run_dir is not None else []
    best_candidates.append(ckpt_dir / "best.pt")
    for p in best_candidates:
        if p.is_file():
            return p
    raise FileNotFoundError(f"No checkpoint found in {ckpt_dir} (looked for last.pt/best.pt in run/global dirs).")

RESUME_CKPT = find_resume_ckpt()
print(f"Resuming from: {RESUME_CKPT}")

# Make sure cls_head.fc exists before loading its weights (a no-op if it already exists).
# Fallback when no encoder features are captured: feed the segmentation output, as the
# training loop does. Feeding the 1-channel dummy input would size fc for the wrong width.
with torch.no_grad():
    was_training = seg_net.training
    seg_net.eval()
    encoder_feat["x"] = None
    dummy = torch.zeros(1, 1, 64, 64, 64, device=device)
    seg_out = seg_net(dummy)
    feat = encoder_feat["x"] if encoder_feat["x"] is not None else seg_out
    _ = cls_head(feat)  # initializes cls_head.fc lazily
    seg_net.train(was_training)
    encoder_feat["x"] = None
    del seg_out

# Checkpoints saved by this notebook hold only tensors, numbers, strings and nested
# dicts/lists, so weights-only loading suffices and avoids unpickling arbitrary objects.
ckpt = torch.load(RESUME_CKPT, map_location=device, weights_only=True)

missing, unexpected = seg_net.load_state_dict(ckpt["seg"], strict=False)
if missing or unexpected:
    print(f"[WARN] seg_net state_dict mismatches. missing={missing}, unexpected={unexpected}")
missing, unexpected = cls_head.load_state_dict(ckpt["cls"], strict=False)
if missing or unexpected:
    print(f"[WARN] cls_head state_dict mismatches. missing={missing}, unexpected={unexpected}")

# Optimizer / scaler states are best-effort: a mismatch (e.g. a different parameter set)
# only produces a warning, and training continues with fresh state.
if "optimizer" in ckpt:
    try:
        optimizer.load_state_dict(ckpt["optimizer"])
    except Exception as e:
        print(f"[WARN] could not load optimizer state: {e}")
if "scaler" in ckpt:
    try:
        scaler.load_state_dict(ckpt["scaler"])
    except Exception as e:
        print(f"[WARN] could not load scaler state: {e}")

# Epoch planning: continue after the saved epoch for EPOCHS_NEXT more epochs.
start_epoch = int(ckpt.get("epoch", 0)) + 1
best_val_dice = float(ckpt.get("best_val_dice", -1.0))
EPOCHS_NEXT = 50
end_epoch = start_epoch + EPOCHS_NEXT - 1

print(f"Resuming at epoch {start_epoch} (best_val_dice={best_val_dice:.4f}); training until epoch {end_epoch}")

### Open metrics CSV (segmentation + training stats)

In [None]:
# Open (append) this run's metrics CSV; write_csv_header adds the header row when needed.
# The handle stays open for the whole training loop; close it with metrics_f.close().
metrics_f, metrics_w = write_csv_header(
    path=METRICS_CSV,
    header=[
        "epoch", "split", "loss", "dice", "acc", "recall", "f1", "hd95",
        "lr", "seconds", "qc_warns", "qc_total", "qc_rate",
    ],
)

## Segmentation metric helpers (To be used with training/validation loops)

In [None]:
eps = 1e-8  # floor for the F1 denominator when precision + recall == 0

# 95th-percentile Hausdorff distance accumulators: first for train, second for val.
hd95_tr, hd95_val = (
    HausdorffDistanceMetric(include_background=False, reduction="mean", percentile=95)
    for _ in range(2)
)

def new_cm():
    """Return a fresh voxel confusion-matrix counter: tp, fp, fn, tn all set to 0."""
    return dict.fromkeys(("tp", "fp", "fn", "tn"), 0)

def update_cm(cm, ypred, ytrue):
    """
    Add voxel-wise confusion counts to ``cm`` in place.

    A voxel counts as foreground when its value is > 0 in ``ypred`` / ``ytrue``; the two
    tensors are compared elementwise.
    """
    pred_fg = ypred > 0
    true_fg = ytrue > 0
    pred_bg, true_bg = ~pred_fg, ~true_fg
    cells = (
        ("tp", pred_fg & true_fg),
        ("fp", pred_fg & true_bg),
        ("fn", pred_bg & true_fg),
        ("tn", pred_bg & true_bg),
    )
    for key, selected in cells:
        cm[key] += selected.sum().item()

def cm_metrics(cm):
    """
    Return (accuracy, recall, f1) computed from a tp/fp/fn/tn counter dict.

    Denominators are floored: max(1, ...) for the ratios and the module-level ``eps``
    for F1, so an all-zero counter yields (0.0, 0.0, 0.0).
    """
    tp, fp, fn, tn = (cm[k] for k in ("tp", "fp", "fn", "tn"))
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / max(1, total)
    recall = tp / max(1, tp + fn)
    precision = tp / max(1, tp + fp)
    f1 = (2 * precision * recall) / max(eps, precision + recall)
    return accuracy, recall, f1

## Main Training Loop

In [None]:
# Train / Val loops: single CSV writer, stable val Dice, train metrics computed outside
# autograd, and checkpoints that always store the up-to-date best_val_dice.

# Assumes:
# - start_epoch, end_epoch, best_val_dice defined (from resume block); if not, set them here.
# - metrics_f, metrics_w already opened by write_csv_header with header:
#   ["epoch","split","loss","dice","acc","recall","f1","hd95","lr","seconds","qc_warns","qc_total","qc_rate"]

if "start_epoch" not in globals() or "end_epoch" not in globals():
    EPOCHS = 500
    start_epoch, end_epoch = 1, EPOCHS
if "best_val_dice" not in globals():
    best_val_dice = -1.0

val_interval = 1
use_amp = torch.cuda.is_available()
ckpt_root = RUN_DIR if "RUN_DIR" in globals() else ckpt_dir


def save_ckpt(path: Path):
    """Save model, head, optimizer and scaler state plus the current (global) epoch and best_val_dice."""
    torch.save({
        "epoch": epoch,
        "seg": seg_net.state_dict(),
        "cls": cls_head.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "best_val_dice": best_val_dice,
        "config_path": str(CONFIG_JSON) if "CONFIG_JSON" in globals() else None,
        "env_path": str(ENV_JSON) if "ENV_JSON" in globals() else None,
    }, path)


t0_all = time.time()
for epoch in range(start_epoch, end_epoch + 1):
    t_epoch = time.time()

    # ---- TRAIN ----
    cm_train = new_cm()
    seg_net.train(); cls_head.train()
    epoch_loss = 0.0
    num_steps = 0

    for batch in train_loader:
        # QC update per-sample (patch-level)
        for b in decollate_batch(batch):
            qc_train.update(int(b.get("qc_before_vox", 0)), int(b.get("qc_after_vox", 0)), str(b.get("case_id", "?")))
        batch = to_device(batch, device)
        images = batch["image"]
        labels = batch["label"].long()
        class_labels = batch["class_label"].view(-1, 1)

        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type="cuda", enabled=use_amp):
            # segmentation forward (deep_supervision=False -> a single logits tensor)
            encoder_feat["x"] = None
            seg_logits = seg_net(images)

            # classification forward on the hooked encoder features
            feat = encoder_feat["x"] if encoder_feat["x"] is not None else seg_logits
            cls_logits = cls_head(feat)

            # losses
            loss_seg = seg_loss_fn(seg_logits, labels)
            loss_cls = cls_loss_fn(cls_logits, class_labels)
            loss = loss_seg + 0.3 * loss_cls

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Train segmentation metrics on detached logits, outside autograd.
        # argmax(logits) == argmax(softmax(logits)), so the softmax is skipped.
        # NOTE(review): background-only patches (neg crops) have no foreground boundary;
        # check how HausdorffDistanceMetric reports them (nan/inf) before trusting train hd95.
        with torch.no_grad():
            y_pred_tr = torch.argmax(seg_logits.detach(), dim=1, keepdim=True)
            update_cm(cm_train, y_pred_tr, labels)
            hd95_tr(y_pred_tr, labels)

        epoch_loss += loss.item()
        num_steps += 1

    epoch_loss /= max(1, num_steps)
    tr_acc, tr_rec, tr_f1 = cm_metrics(cm_train)
    tr_hd95 = hd95_tr.aggregate().item(); hd95_tr.reset()
    train_lr = optimizer.param_groups[0]["lr"]
    qc_rate_tr = qc_train.rate

    print(f"Epoch {epoch}/{end_epoch} - train loss: {epoch_loss:.4f}")
    print(f"  Train: acc={tr_acc:.4f} rec={tr_rec:.4f} f1={tr_f1:.4f} hd95={tr_hd95:.2f}")
    metrics_w.writerow([
        epoch, "train",
        f"{epoch_loss:.6f}", "", f"{tr_acc:.6f}", f"{tr_rec:.6f}", f"{tr_f1:.6f}", f"{tr_hd95:.6f}",
        f"{train_lr:.6g}", f"{(time.time()-t_epoch):.2f}", qc_train.warn, qc_train.total, f"{qc_rate_tr:.4f}"
    ]); metrics_f.flush()

    # Print this epoch's QC summary BEFORE resetting; printing after the reset always
    # showed 0/0. reset() also clears flagged_examples so they do not grow without bound.
    if epoch % val_interval == 0:
        qc_train.summary()
    qc_train.reset()

    # ---- VAL ----
    if epoch % val_interval == 0:
        cm_val = new_cm()
        seg_net.eval(); cls_head.eval()
        dice_metric.reset()
        val_loss = 0.0
        steps = 0
        with torch.no_grad():
            for batch in val_loader:
                for b in decollate_batch(batch):
                    qc_val.update(int(b.get("qc_before_vox", 0)), int(b.get("qc_after_vox", 0)), str(b.get("case_id", "?")))
                batch = to_device(batch, device)
                images = batch["image"]
                labels = batch["label"].long()
                class_labels = batch["class_label"].view(-1, 1)

                with autocast(device_type="cuda", enabled=use_amp):
                    # padded sliding-window seg
                    images_p = pad_to_factor(images, factor=32)
                    seg_logits = inferer(inputs=images_p, network=seg_net)

                    # classification: one extra full-volume forward to populate encoder features.
                    # NOTE(review): this pass is not windowed (memory-heavy for large volumes),
                    # and the head sees whole-volume context while training used patches.
                    encoder_feat["x"] = None
                    _ = seg_net(images_p)
                    if seg_logits.shape[-3:] != labels.shape[-3:]:
                        seg_logits = crop_to_shape(seg_logits, labels.shape[-3:])

                    feat = encoder_feat["x"] if encoder_feat["x"] is not None else seg_logits
                    cls_logits = cls_head(feat)

                    # per-batch seg metrics
                    y_pred = torch.argmax(seg_logits, dim=1, keepdim=True)
                    dice_metric(y_pred=y_pred, y=labels)
                    update_cm(cm_val, y_pred, labels)
                    hd95_val(y_pred, labels)

                    # per-batch val loss
                    loss_seg = seg_loss_fn(seg_logits, labels)
                    loss_cls = cls_loss_fn(cls_logits, class_labels)
                    loss = loss_seg + 0.3 * loss_cls
                val_loss += loss.item()
                steps += 1

        mean_dice = dice_metric.aggregate().item(); dice_metric.reset()
        val_loss /= max(1, steps)
        vl_acc, vl_rec, vl_f1 = cm_metrics(cm_val)
        vl_hd95 = hd95_val.aggregate().item(); hd95_val.reset()
        qc_rate_val = qc_val.rate

        print(f"  Val loss: {val_loss:.4f} | Val Dice(tumor): {mean_dice:.4f}")
        print(f"  Val  : acc={vl_acc:.4f} rec={vl_rec:.4f} f1={vl_f1:.4f} hd95={vl_hd95:.2f}")
        qc_val.summary()

        metrics_w.writerow([
            epoch, "val",
            f"{val_loss:.6f}", f"{mean_dice:.6f}", f"{vl_acc:.6f}", f"{vl_rec:.6f}", f"{vl_f1:.6f}", f"{vl_hd95:.6f}",
            f"{train_lr:.6g}", f"{(time.time()-t_epoch):.2f}", qc_val.warn, qc_val.total, f"{qc_rate_val:.4f}"
        ]); metrics_f.flush()
        qc_val.reset()

        # ---- Checkpointing ----
        # Update best_val_dice BEFORE writing last.pt so every checkpoint stores the current
        # best. If last.pt kept the previous best on an improving epoch, resuming from it
        # could later overwrite best.pt with a worse model.
        is_best = mean_dice > best_val_dice
        if is_best:
            best_val_dice = mean_dice
        save_ckpt(ckpt_root / "last.pt")
        if is_best:
            save_ckpt(ckpt_root / "best.pt")
            # convenience copy
            save_ckpt(ckpt_dir / "best.pt")
            print(f"  [Saved] best.pt with Dice {best_val_dice:.4f}")

        # encoder-only (save full state; downstream can load encoder subset)
        torch.save({"seg_encoder_compatible": True, "state_dict": seg_net.state_dict()},
                   ckpt_root / "encoder_fullstate.pt")

        # Drop references to the last validation tensors, then release cached GPU memory.
        images = labels = seg_logits = cls_logits = feat = y_pred = images_p = None
        torch.cuda.empty_cache()

print(f"Training done in {(time.time()-t0_all)/60:.1f} min")
# Close the metrics file (kept open for speed; rows were flushed as they were written).
metrics_f.close()