# Project: Brain Tumor Segmentation and Classification 

## Details of Dataset Creation

#### [Data Source - The Cancer Imaging Archive - TCIA -- CLICK HERE](https://www.cancerimagingarchive.net/browse-collections)
#### [Data preparation and full training and evaluation code -- CLICK HERE](https://github.com/kundan1974/DHAI-Brain-Segmentation)
- **UCSF-PDGM:** Glioblastoma - 495
- **BRATS-AFRICA:** Glioma - 95
- **MU-Glioma-Post:** Glioma - 203
- **UCSD-PTGBM:** Glioblastoma - 178
- **UPENN-GBM:** Glioblastoma - 630
- **BCBM-RadioGenomics:** Brain Mets - 165
- **Pretreat-MetsToBrain-Masks:** Brain Mets - 200

### Segmentation Dataset
- **Numbers:** *495+95+203+178+630+165+200 = 1966* cases in total, of which **1388** had both the desired segmentation and the desired MRI sequence.
- **Segmentation:** Tumor core plus enhancing area as a single segmentation mask. Where two masks were provided (tumor core and enhancing area), they were combined into a single mask: tumor with enhancing area.
- **MRI sequence** used: *T1 Contrast*
- File *dataset.json* was created, holding the file paths for the MRI sequence and segmentation along with other dataset details.
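
The mask-merging rule described above can be sketched with NumPy. This is a minimal illustration on toy arrays, not the project's preparation code; the array names `tumor_core` and `enhancing` are hypothetical, and short 1D arrays stand in for 3D volumes.

```python
import numpy as np

# Toy 1D "masks" standing in for 3D volumes; names are hypothetical.
tumor_core = np.array([0, 1, 1, 0, 0], dtype=np.uint8)
enhancing = np.array([0, 0, 1, 1, 0], dtype=np.uint8)

# A voxel belongs to the combined mask if either source mask marks it.
combined = np.logical_or(tumor_core > 0, enhancing > 0).astype(np.uint8)
print(combined)
```

The union keeps every voxel that appears in either mask, which matches the "tumor with enhancing area" target described above.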

### Classification dataset
- **Numbers:**
  - *Training (972)*: Gliomas: 647, Brain Mets: 325
  - *Validation (209)*: Gliomas: 139, Brain Mets: 70
  - *Test (207)*: Gliomas: 138, Brain Mets: 69
- Files *train.csv*, *val.csv* and *test.csv* were created, each containing class labels, image path, segmentation path and case_id.
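
As a quick sanity check, the counts listed above can be re-added in a few lines. The sketch below only uses the numbers stated in this section; the dictionary layout and key names are illustrative choices.

```python
# Case counts per collection, in the order listed above.
collection_counts = [495, 95, 203, 178, 630, 165, 200]
total_cases = sum(collection_counts)  # 1966

# Per-split class counts for the classification dataset.
splits = {
    "train": {"glioma": 647, "mets": 325},
    "val": {"glioma": 139, "mets": 70},
    "test": {"glioma": 138, "mets": 69},
}
usable_cases = sum(sum(counts.values()) for counts in splits.values())  # 1388

for name, counts in splits.items():
    n = sum(counts.values())
    share = n / usable_cases
    mets_fraction = counts["mets"] / n
    print(f"{name}: n={n}, share of usable cases={share:.1%}, mets fraction={mets_fraction:.3f}")
```

The three splits add up to the 1388 usable cases, in roughly a 70/15/15 proportion, with about one third brain metastases in each split.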

## 🔧 Imports and Environment Setup

This cell imports all the required **standard libraries**, **PyTorch modules**, and **MONAI components** for our brain tumor segmentation and classification pipeline. Here's a breakdown:

### 📦 Standard Python & OS Libraries
- `os`, `math`, `time`, `json`, `random`, `csv`, `hashlib`, `platform`, `subprocess`: Used for file operations, randomization, timing, system information, and scripting utilities.
- `datetime`, `pathlib.Path`: Helpful for managing timestamps and path structures in a platform-independent way.
- `typing`: Provides type hints for better code clarity and development support.

### 🔢 Numerical & Data Libraries
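- `numpy`, `pandas`: Handle numerical arrays and structured data tables, respectively.
- `scipy.ndimage`: Provides the binary morphology operations (dilation, erosion) used later for label post-processing.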

### 🔥 PyTorch Core
- `torch`, `torch.nn`, `torch.nn.functional`: For building, training, and evaluating deep learning models.
- `torch.amp.autocast`, `GradScaler`: For enabling mixed precision training to speed up training and reduce GPU memory usage.

### 🧠 Medical Imaging and Segmentation Tools (MONAI)
- `nibabel`: For loading and working with medical image formats like NIfTI (`.nii`, `.nii.gz`).
- `monai.config.print_config()`: Prints MONAI, PyTorch, and environment versions for reproducibility.
- `monai.data`: Includes `CacheDataset` and `DataLoader` for efficient data handling and batch loading.
- `monai.inferers.SlidingWindowInferer`: Used for patch-wise inference on large 3D volumes.
- `monai.losses.DiceCELoss`: Hybrid Dice + CrossEntropy loss commonly used in segmentation tasks.
- `monai.metrics`: Includes metrics such as Dice and Hausdorff for evaluating segmentation accuracy.
- `monai.networks.nets.DynUNet`: A dynamic, flexible 3D U-Net model used for medical segmentation.
- `monai.transforms`: Provides various preprocessing and augmentation transforms specifically tailored for medical image analysis.
- `monai.utils.set_determinism`: Ensures reproducible results by fixing seeds and deterministic behavior.

### ✅ Purpose
This cell sets the stage for all subsequent steps — including data preparation, model building, training, evaluation, and inference — by importing all dependencies and printing the current MONAI + PyTorch environment configuration.

In [None]:
import os, math, time, json, random, csv, hashlib, platform, subprocess
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

import scipy.ndimage as ndi

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
import nibabel as nib

from monai.config import print_config
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.inferers import SlidingWindowInferer
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.networks.nets import DynUNet
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    RandFlipd,
    RandRotate90d,
    RandAffined,
    AsDiscreted,
    CastToTyped,
)
from monai.utils import set_determinism

print_config()

## ⚙️ Reproducibility, Device Setup, and Configuration Paths

This cell establishes essential configurations to ensure reproducibility, select compute device (GPU/CPU), and define file paths and model parameters.

### 🔁 Reproducibility
- `SEED = 42` sets a fixed seed for random number generation.
- `set_determinism(SEED)` ensures reproducibility by:
  - Seeding all necessary libraries (NumPy, PyTorch, etc.)
  - Enabling deterministic behavior for CUDA operations (when applicable).

### 💻 Device Selection
- `torch.device(...)` selects GPU (`cuda`) if available, otherwise defaults to CPU.
- If using CUDA, it prints the name of the GPU (e.g., “NVIDIA A6000”).
- This helps verify that model training/inference will utilize available hardware acceleration.

### 📂 Dataset and Project Paths
- `PROJ_ROOT`: Root directory of the project on the local system.
- `DUALTASK_ROOT`: Subdirectory that contains derived data for dual-task learning (segmentation + classification).
- `TRAIN_CSV`, `VAL_CSV`, `TEST_CSV`: Point to training, validation, and test CSV files, respectively.
- An assertion ensures all these paths exist before proceeding.

### 🧊 Spatial and Patch Parameters
- `TARGET_SPACING`: The voxel spacing to which all MRI volumes will be resampled. Standardizing spacing is crucial for 3D medical image processing.
- `PATCH_SIZE`: Size of the 3D patch that will be extracted from each image for model training and inference. Used by the sliding window.
- `PATCH_OVERLAP`: Defines the overlap fraction between adjacent patches during inference (0.5 = 50% overlap).
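
To get a feel for what these values imply at inference time, the sketch below estimates how many sliding-window positions a volume needs. It is a rough approximation under stated assumptions: the helper `windows_per_axis` and the example `volume_shape` are hypothetical, and MONAI's exact tiling logic may differ in edge cases.

```python
import math

def windows_per_axis(dim: int, patch: int, overlap: float) -> int:
    """Approximate number of sliding-window positions along one axis."""
    if dim <= patch:
        return 1  # a single window covers the (padded) axis
    step = max(1, int(patch * (1.0 - overlap)))
    return math.ceil((dim - patch) / step) + 1

PATCH_SIZE = (192, 192, 160)
PATCH_OVERLAP = 0.5
volume_shape = (240, 240, 155)  # hypothetical resampled volume, for illustration only

per_axis = [windows_per_axis(d, p, PATCH_OVERLAP) for d, p in zip(volume_shape, PATCH_SIZE)]
n_windows = math.prod(per_axis)
print(per_axis, n_windows)
```

A larger overlap shrinks the step between windows, so the number of forward passes per volume grows quickly in 3D.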

### 💾 Checkpoint Directory
- `ckpt_dir`: Directory to save model weights and logs.
- `mkdir(..., exist_ok=True)` creates the directory if it doesn't already exist.
- Path is printed to confirm where checkpoints will be stored during training and evaluation.

In [None]:
# Reproducibility
SEED = 42
set_determinism(SEED)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    try:
        gpu_name = torch.cuda.get_device_name(0)
    except Exception:
        gpu_name = "Unknown CUDA device"
    print(f"Device: {device} ({gpu_name})")
else:
    print(f"Device: {device}")

# Paths
PROJ_ROOT = Path("/home/ant/projects/brain_tumor_segmentation")
DUALTASK_ROOT = PROJ_ROOT / "derived" / "unified_dualtask"
TRAIN_CSV = DUALTASK_ROOT / "train.csv"
VAL_CSV = DUALTASK_ROOT / "val.csv"
TEST_CSV = DUALTASK_ROOT / "test.csv"

# Basic path checks 
missing = [p for p in [DUALTASK_ROOT, TRAIN_CSV, VAL_CSV, TEST_CSV] if not p.exists()]
assert not missing, f"Missing required paths: {', '.join(str(p) for p in missing)}"

# Target spacing and patch params
TARGET_SPACING = (0.8, 0.8, 1.0)
PATCH_SIZE = (192, 192, 160)
PATCH_OVERLAP = 0.5  # sliding window overlap

# Checkpoint directory
ckpt_dir = PROJ_ROOT / "runs" / "dualtask_monai_v01"
ckpt_dir.mkdir(parents=True, exist_ok=True)
print("Checkpoint dir:", ckpt_dir)

## 🗂️ Run-Specific Directory & Logging Setup

This cell sets up unique directories and file paths for the current training or inference run, allowing us to store and track artifacts in an organized way.

### 🆔 Unique Run ID & Directory
- `RUN_ID`: A timestamp-based string (e.g., `"20250907-103845"`) that uniquely identifies each run.
- `RUN_DIR`: Subfolder under the checkpoint directory (`ckpt_dir`) specific to this run.
- `mkdir(..., exist_ok=True)`: Creates the folder and does not raise an error if it already exists.

### 📁 Paths for Artifacts
Within `RUN_DIR`, we predefine the following key output file paths:
- `METRICS_CSV`: To store evaluation metrics per epoch (e.g., Dice score, loss, etc.)
- `CONFIG_JSON`: To store the configuration parameters used during this run.
- `ENV_JSON`: To log software and hardware environment info (e.g., PyTorch/MONAI version).
- `QC_CSV`: To summarize quality control or performance across epochs.

These logs help with reproducibility, analysis, and comparison of different experiments.

### 📝 CSV Utility Function
The `write_csv_header()` function:
- Opens a CSV file in append mode.
- Writes the header row only if the file doesn't already exist.
- After writing the header, flushes and syncs the file so the header reaches disk immediately.
- Returns the file handle and CSV writer object so that the caller can continue logging rows and close it when done.

This is useful for writing structured logs (like metrics or quality control summaries) across multiple epochs or batches.
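
A usage sketch of this append-with-header pattern follows. The helper is restated under a different, hypothetical name (`open_csv_for_append`) so the snippet runs on its own; the column name `val_dice` and the metric values are made up for illustration.

```python
import csv
import os
import tempfile
from pathlib import Path

def open_csv_for_append(path: Path, header):
    """Open a CSV for appending; write the header only when the file is new."""
    is_new = not path.exists()
    f = open(path, "a", newline="")
    w = csv.writer(f)
    if is_new:
        w.writerow(header)
        f.flush()
        os.fsync(f.fileno())
    return f, w

with tempfile.TemporaryDirectory() as tmp:
    metrics_path = Path(tmp) / "metrics.csv"
    # Two separate open/append/close cycles: the header is written only once.
    for epoch, dice in [(1, 0.50), (2, 0.55)]:  # made-up values
        f, w = open_csv_for_append(metrics_path, ["epoch", "val_dice"])
        try:
            w.writerow([epoch, dice])
        finally:
            f.close()  # the caller owns the file handle
    rows = metrics_path.read_text().splitlines()

print(rows)
```

Because the helper returns an open handle, wrapping the writes in `try`/`finally` is one way to make sure the file is closed even if logging fails midway.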

In [None]:
RUN_ID = datetime.now().strftime("%Y%m%d-%H%M%S")
RUN_DIR = ckpt_dir / RUN_ID
RUN_DIR.mkdir(parents=True, exist_ok=True)

# Common artifact paths for this run
METRICS_CSV = RUN_DIR / "metrics.csv"
CONFIG_JSON = RUN_DIR / "config.json"
ENV_JSON = RUN_DIR / "env.json"
QC_CSV = RUN_DIR / "qc_epoch_summary.csv"

def write_csv_header(path: Path, header: List[str]):
    """
    Open a CSV for appending and write the header if the file doesn't exist yet.
    Returns (file_handle, csv_writer). Caller is responsible for closing the file_handle.
    """
    is_new = not path.exists()
    f = open(path, "a", newline="")
    w = csv.writer(f)
    if is_new:
        w.writerow(header)
        f.flush()
        os.fsync(f.fileno())
    return f, w

## 📄 CSV → List[Dict] Conversion with Validation

This cell defines a helper function to **read and validate dataset CSV files** (train/val/test), ensuring that each entry is correctly formatted and usable in the MONAI pipeline.

### 🧠 Function: `read_unified_csv(path, check_files=True)`
Reads a CSV (e.g., `train.csv`, `val.csv`, `test.csv`) and converts each row into a dictionary compatible with MONAI’s `Dataset` format. Key components:

#### ✅ Validation
- Checks for required columns: `case_id`, `class_label`, `image_path`, `label_path`.
- Raises an error if any expected column is missing.

#### 🧹 Cleaning & Type Coercion
- Trims whitespace from paths and IDs.
- Ensures `class_label` is a clean integer (invalid/missing values are coerced to `-1`).

#### ❌ Optional File Existence Check
- If `check_files=True`, verifies that both `image_path` and `label_path` point to valid files.
- Drops any rows with missing files and prints a warning showing how many were removed.

#### 🔁 Output Format
- Returns a `List[Dict]`, where each dictionary contains:
  - `"case_id"` – Unique identifier for the subject/case
  - `"image"` – Path to the NIfTI image file
  - `"label"` – Path to the segmentation label file
  - `"class_label"` – Integer class label for the classification task (glioma vs. brain metastasis); `-1` marks an invalid or missing value

This format is compatible with `monai.data.CacheDataset`.

### 📦 Dataset Preparation
- `train_items`, `val_items`, and `test_items` are populated by calling `read_unified_csv(...)` for each respective CSV file.
- The final line outputs the number of valid samples in each split as a tuple:  
  `(num_train, num_val, num_test)`
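
The label coercion step can be seen in isolation on a toy column. This is a small sketch using the same pandas calls as the function; the sample values are hypothetical, and the final filtering step is only a suggestion, not something `read_unified_csv` does.

```python
import pandas as pd

raw = pd.Series(["0", "1", "abc", None])  # hypothetical raw column values

# Unparseable or missing entries become NaN, then the -1 sentinel.
clean = pd.to_numeric(raw, errors="coerce").fillna(-1).astype(int)
print(clean.tolist())

# Optional follow-up (not part of the original function): flag rows with the sentinel.
valid_mask = clean >= 0
print(valid_mask.tolist())
```

Since the sentinel is a regular integer, it is worth checking for `-1` before the labels reach a loss function.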

In [None]:
def read_unified_csv(path: Path, check_files: bool = True) -> List[Dict[str, Any]]:
    df = pd.read_csv(path)
    expected_cols = {"case_id", "class_label", "image_path", "label_path"}
    missing = expected_cols - set(df.columns)
    assert not missing, f"Missing columns in {path.name}: {sorted(missing)}"

    # Clean and coerce types
    df = df.copy()
    df["case_id"] = df["case_id"].astype(str).str.strip()
    df["image_path"] = df["image_path"].astype(str).str.strip()
    df["label_path"] = df["label_path"].astype(str).str.strip()
    df["class_label"] = pd.to_numeric(df["class_label"], errors="coerce").fillna(-1).astype(int)

    # Optional: drop rows whose files don’t exist
    if check_files:
        def exists(p: str) -> bool:
            return Path(p).exists()
        bad_mask = (~df["image_path"].map(exists)) | (~df["label_path"].map(exists))
        if bad_mask.any():
            n_bad = int(bad_mask.sum())
            print(f"[WARN] {n_bad} rows dropped from {path.name} due to missing files")
            df = df.loc[~bad_mask]

    # Build MONAI-style dicts
    items: List[Dict[str, Any]] = [
        {
            "case_id": r["case_id"],
            "image": r["image_path"],
            "label": r["label_path"],
            "class_label": int(r["class_label"]),
        }
        for _, r in df.iterrows()
    ]
    return items

train_items = read_unified_csv(TRAIN_CSV)
val_items = read_unified_csv(VAL_CSV)
test_items = read_unified_csv(TEST_CSV)

len(train_items), len(val_items), len(test_items)

## 🧪 Label Quality Control (QC) and Morphological Protection

This cell defines tools to assess and preserve label integrity across preprocessing or patch-based transformations, especially important for 3D medical segmentation tasks.

---

### 🧩 `LabelQC` Class – Shrinkage Detection

Tracks how much segmentation labels shrink during processing (e.g., spatial cropping or transformations), which could harm training and evaluation.

#### 🔍 Key Features:
- `shrink_warn_threshold`: Warns if **label volume is reduced by more than this fraction** (default 35% loss).
- `update(...)`: Compares voxel count before and after a transformation for a given `case_id`. If too much shrinkage is detected, it:
  - Increments warning count.
  - Saves the case info (`case_id`, voxel counts, shrinkage ratio).
  - Prints a warning if `verbose=True`.

#### 📊 Usage:
This class can be instantiated and used to:
1. Track voxel shrinkage per case.
2. Summarize how often labels were affected.
3. Collect examples of problematic cases.
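
The warning rule boils down to one ratio comparison. The standalone function below mirrors that check for illustration; `is_shrink_warning` is a hypothetical name, not part of the notebook.

```python
def is_shrink_warning(before_voxels: int, after_voxels: int, threshold: float = 0.35) -> bool:
    """True if the label kept less than (1 - threshold) of its original voxels."""
    if before_voxels <= 0:
        return False  # nothing to compare against
    ratio = (after_voxels + 1e-6) / (before_voxels + 1e-6)
    return ratio < (1.0 - threshold)

print(is_shrink_warning(1000, 600))  # ratio 0.60 is below 0.65
print(is_shrink_warning(1000, 700))  # ratio 0.70 is not below 0.65
```

With the default threshold of 0.35, a case is flagged once fewer than 65% of its labeled voxels survive.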

---

### 🧪 `binary_dilate_then_erode(...)` – Morphological Closing Operation

Applies **binary dilation followed by erosion** to "close" small gaps in segmentation masks, helping preserve fragile or fragmented structures.

### 🔍 Inside the `binary_dilate_then_erode` Function

| Step                     | What's Happening                                              | Why It's Useful                                      |
|--------------------------|---------------------------------------------------------------|------------------------------------------------------|
| `mask > 0`               | Converts the mask into a binary format (1 = tumor, 0 = background) | Helps clearly define foreground (tumor) vs background |
| `generate_binary_structure(...)` | Defines the neighborhood connectivity: 6, 18, or 26 neighbors | Controls how "connected" the dilation/erosion should be |
| `binary_dilation(...)`  | Expands the foreground region slightly outward                | Fills small holes and connects broken regions        |
| `binary_erosion(...)`   | Shrinks the region back to near-original size                 | Keeps filled gaps but removes over-expansion         |
| `astype(np.uint8)`      | Converts the output to binary mask format (0s and 1s)         | Ensures compatibility with later processing steps    |

#### 🧬 Parameters:
- `mask`: 3D NumPy array representing a binary label mask.
- `radius_vox`: Number of iterations for dilation/erosion. `0` means no operation.
- `connectivity`: Determines voxel connectivity (1 = faces, 2 = faces + edges, 3 = faces + edges + corners).

#### 🧼 Purpose:
- Helps keep small or thin labeled regions from fragmenting or vanishing after resampling.
- Helps maintain thin structures (e.g., tumor edges) which may be disconnected in preprocessing.

#### 🔁 Output:
Returns a cleaned and connected binary mask (`np.uint8`) after closing.
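
A toy example makes the effect of closing concrete. The sketch below uses the same `scipy.ndimage` calls on a small synthetic cube with a one-voxel hole; the array and its size are made up for illustration.

```python
import numpy as np
import scipy.ndimage as ndi

mask = np.zeros((7, 7, 7), dtype=np.uint8)
mask[2:5, 2:5, 2:5] = 1   # 3x3x3 cube, kept away from the array border
mask[3, 3, 3] = 0         # one-voxel hole in the middle

structure = ndi.generate_binary_structure(rank=3, connectivity=1)  # 6-neighborhood
dilated = ndi.binary_dilation(mask > 0, structure=structure, iterations=1)
closed = ndi.binary_erosion(dilated, structure=structure, iterations=1).astype(np.uint8)

print(int(mask.sum()), int(closed.sum()))
```

The hole is filled while the outer extent of the cube is unchanged. Note that foreground touching the array border can be eroded away, because voxels outside the array count as background by default.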

---

### ✅ Why This Matters
During patch-based segmentation or augmentations, small structures like tumors can be lost if not preserved. These tools:
- Warn about significant loss of label volume.
- Allow light morphological cleanup of masks while maintaining spatial integrity.

In [None]:
class LabelQC:
    """
    Tracks label shrinkage across patches/cases.
    - shrink_warn_threshold: fraction of volume lost (e.g., 0.35 → warn if after < 65% of before)
    - verbose: if True, prints a line for each warning
    """
    def __init__(self, shrink_warn_threshold: float = 0.35, verbose: bool = True):
        self.shrink_warn_threshold = float(shrink_warn_threshold)
        self.verbose = bool(verbose)
        self.total: int = 0
        self.warn: int = 0
        self.flagged_examples: List[Tuple[str, int, int, float]] = []  # (case_id, before, after, ratio)

    def reset(self) -> None:
        self.total = 0
        self.warn = 0
        self.flagged_examples.clear()

    @property
    def rate(self) -> float:
        return self.warn / max(1, self.total)

    def update(self, before_voxels: int, after_voxels: int, case_id: str) -> None:
        self.total += 1
        if before_voxels <= 0:
            return
        ratio = (after_voxels + 1e-6) / (before_voxels + 1e-6)
        if ratio < (1.0 - self.shrink_warn_threshold):
            self.warn += 1
            self.flagged_examples.append((case_id, int(before_voxels), int(after_voxels), float(ratio)))
            if self.verbose:
                print(f"[QC] label shrinkage: {case_id} before={before_voxels} after={after_voxels} ratio={ratio:.3f}")

    def summary(self) -> None:
        print(f"[QC] shrinkage warnings: {self.warn}/{self.total} (rate={self.rate:.4f})")

def binary_dilate_then_erode(mask: np.ndarray, radius_vox: int = 1, connectivity: int = 1) -> np.ndarray:
    """
    Light morphological close (dilate then erode) to protect thin/fragmented labels.
    - mask: 3D array (D,H,W). Nonzeros are treated as foreground.
    - radius_vox: number of iterations for dilation and erosion. 0 → no-op.
    - connectivity: 1 (faces), 2 (faces+edges), or 3 (faces+edges+corners).
    Returns uint8 mask (0/1).
    """
    if radius_vox <= 0:
        return (mask > 0).astype(np.uint8)
    if mask.ndim != 3:
        raise ValueError(f"Expected 3D mask, got shape {mask.shape}")

    mask_bool = (mask > 0)
    structure = ndi.generate_binary_structure(rank=3, connectivity=int(connectivity))
    dil = ndi.binary_dilation(mask_bool, structure=structure, iterations=int(radius_vox))
    ero = ndi.binary_erosion(dil, structure=structure, iterations=int(radius_vox))
    return ero.astype(np.uint8)

## 🔄 Data Transforms: Preprocessing, Augmentation, and Label Postprocessing

This cell defines the full preprocessing and augmentation pipeline for **training and validation** using MONAI’s transform framework. It includes custom logic to ensure tumor labels are preserved even after spatial operations.

---

### 🛠️ Custom Transform: `LabelPostProcessd`

This transform is applied **after spacing and augmentation**, specifically to the **segmentation label mask**:

#### 🔍 What it does:
1. **Shape Matching**: If the label and image shapes differ, it resizes the label to match the image using **nearest-neighbor interpolation** (avoids smoothing or partial volume effects).
2. **Light Morphology**: Applies `binary_dilate_then_erode()` to prevent thin or small tumors from vanishing due to interpolation, cropping, or resampling.
3. **QC Tracking**: Adds two keys to the output dict:
   - `qc_before_vox`: Number of voxels labeled as tumor *before* processing
   - `qc_after_vox`: Number of voxels labeled as tumor *after* processing
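
One detail of the shape-matching step is worth illustrating: `torch.nn.functional.interpolate` expects a batch dimension in front of the channel dimension, so a channel-first label of shape `[1, D, H, W]` needs a temporary leading axis when resizing over three spatial dimensions. A minimal sketch on a toy tensor (the sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

label = torch.zeros(1, 2, 2, 2)   # channel-first label [C=1, D, H, W]
label[0, 0, 0, 0] = 1.0           # a single foreground voxel

# Add a temporary batch axis -> [1, 1, D, H, W], resize, then drop it again.
resized = F.interpolate(label.unsqueeze(0), size=(4, 4, 4), mode="nearest").squeeze(0)

print(tuple(resized.shape), resized.unique().tolist())
```

Nearest-neighbor mode only copies existing values, so the resized label stays strictly binary.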

---

### 🧪 Transform Components Breakdown

#### 🔄 `common_load`
Shared transforms for both training and validation:
- `LoadImaged`: Load NIfTI images and labels.
- `EnsureChannelFirstd`: Converts image shape from `[D, H, W]` to `[1, D, H, W]`.
- `EnsureTyped`: Ensures tensors are in PyTorch format.
- `Orientationd`: Aligns all images to RAS orientation.
- `Spacingd`: Resamples images and labels to common voxel spacing:
  - **Image**: uses `bilinear` mode (trilinear interpolation for 3D volumes).
  - **Label**: uses nearest-neighbor to avoid label corruption.

#### 🧠 `intensity_train`
Training-only transforms:
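- `ScaleIntensityRanged`: Clips raw intensity values to the range 0–3000 (arbitrary MRI intensity units) and rescales them to `[0, 1]`.
- `RandFlipd`: Random flip (20% chance); with `spatial_axis=[0, 1, 2]`, the flip is applied along all three axes together.
- `RandRotate90d`: Random rotation by a multiple of 90° (20% chance).
- `RandAffined`: Random affine transformation (rotation up to ±5°, scaling up to ±10%; 20% chance).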
- `SpatialPadd`: Pads to ensure minimum patch size before cropping.
- `RandCropByPosNegLabeld`: Extracts patches with a mix of positive (tumor) and negative (background) regions.

#### 📏 `intensity_val`
Validation-only transforms:
- Only includes `ScaleIntensityRanged` to standardize intensities, without applying any random augmentations.
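
The intensity mapping itself is a simple linear rescale with clipping. The NumPy sketch below reproduces the arithmetic for the parameters used here; it is an illustration of the formula, not MONAI's implementation, and the sample intensities are made up.

```python
import numpy as np

def scale_intensity_range(x, a_min=0.0, a_max=3000.0, b_min=0.0, b_max=1.0):
    """Linearly map [a_min, a_max] to [b_min, b_max] and clip values outside it."""
    x = np.asarray(x, dtype=np.float32)
    y = (x - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    return np.clip(y, b_min, b_max)

scaled = scale_intensity_range([-100, 0, 1500, 3000, 4500])
print(scaled)
```

Values below 0 and above 3000 saturate at the ends of the output range, so very bright voxels lose contrast above the upper bound.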

---

### 🏷️ `CastClassLabeld` – Classification Label Casting

A custom transform that converts `"class_label"` to a PyTorch `float32` tensor. Binary classification losses such as `BCEWithLogitsLoss` expect floating-point targets, so this cast is required for compatibility.

---

### 🧩 Compose Final Transform Pipelines

- `train_transforms`: Combines loading, augmentation, label postprocessing, and classification label casting.
- `val_transforms`: Same loading, intensity scaling, and label postprocessing, but without random augmentations or patch cropping, so validation runs on full volumes.

Each sample (image, label, class_label) is processed through these transformations before being fed into the model.

---

### ✅ Summary

This setup ensures that:
- Input images are normalized and aligned in space.
- Labels are protected from degradation during spatial transforms.
- Both segmentation and classification targets are properly prepared.

In [None]:
# Transforms: spacing standardization and intensity scale
# Labels: post-process after Spacingd to preserve small lesions (no extra soft resampling)

from monai.transforms import MapTransform

class LabelPostProcessd(MapTransform):
    """
    - Ensures label shape matches image shape (nearest-neighbor up/down-sample if needed)
    - Optional light morphology (dilate→erode) to protect thin/speck lesions
    - Records QC counts: qc_before_vox, qc_after_vox
    """
    def __init__(self, keys, ref_key: str = "image", morph_radius: int = 1, allow_missing_keys: bool = False):
        super().__init__(keys, allow_missing_keys)
        self.ref_key = ref_key
        self.morph_radius = int(morph_radius)

    def __call__(self, data):
        d = dict(data)
        if "label" not in d:
            return d

        label = d["label"]  # Tensor [1, D, H, W]
        img = d.get(self.ref_key, None)
        before_vox = int((label > 0).sum().item())

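        # Match label to image grid if needed (nearest to avoid smoothing).
        # interpolate expects a batch dim: [1, D, H, W] -> [1, 1, D, H, W] -> [1, D, H, W]
        if img is not None and label.shape[1:] != img.shape[1:]:
            label = torch.nn.functional.interpolate(
                label.float().unsqueeze(0), size=tuple(img.shape[1:]), mode="nearest"
            ).squeeze(0).long()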

        # Optional light morphology
        if self.morph_radius > 0:
            arr = (label > 0).cpu().numpy().astype(np.uint8)  # [1, D, H, W]
            arr = binary_dilate_then_erode(arr[0], radius_vox=self.morph_radius)[None]  # back to [1, ...]
            label = torch.as_tensor(arr, dtype=torch.long, device=d["label"].device)

        after_vox = int((label > 0).sum().item())

        d["label"] = label
        d["qc_before_vox"] = before_vox
        d["qc_after_vox"] = after_vox
        return d


# Common I/O and spacing (labels via nearest to avoid smoothing)
common_load = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    EnsureTyped(keys=["image", "label"], dtype=torch.float32),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=TARGET_SPACING, mode=("bilinear", "nearest")),
]

from monai.transforms import RandCropByPosNegLabeld, SpatialPadd

# Training intensity + spatial augs
intensity_train = [
    ScaleIntensityRanged(keys=["image"], a_min=0, a_max=3000, b_min=0.0, b_max=1.0, clip=True),
    RandFlipd(keys=["image", "label"], spatial_axis=[0, 1, 2], prob=0.2),
    RandRotate90d(keys=["image", "label"], prob=0.2, max_k=3),
    RandAffined(
        keys=["image", "label"],
        rotate_range=(math.pi/36, math.pi/36, math.pi/36),
        scale_range=(0.1, 0.1, 0.1),
        mode=("bilinear", "nearest"),
        prob=0.2,
    ),
    SpatialPadd(keys=["image", "label"], spatial_size=PATCH_SIZE),
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=PATCH_SIZE,
        pos=1, neg=1, num_samples=1, image_key="image",
        allow_smaller=True,
    ),
]

# Validation intensity only
intensity_val = [
    ScaleIntensityRanged(keys=["image"], a_min=0, a_max=3000, b_min=0.0, b_max=1.0, clip=True),
]

# Cast classification label
class CastClassLabeld(MapTransform):
    def __init__(self, keys, allow_missing_keys=False):
        super().__init__(keys, allow_missing_keys)
    def __call__(self, data):
        d = dict(data)
        if "class_label" in d:
            d["class_label"] = torch.as_tensor(d["class_label"], dtype=torch.float32)
        return d

# Assemble transforms
# - Post label process uses nearest resample if shapes differ (no soft one-hot resample)
# - Morph radius=1; raise to 2 if QC indicates too many tiny lesions vanish
post_label_preserve = [LabelPostProcessd(keys=["label"], ref_key="image", morph_radius=1)]

train_transforms = Compose(common_load + intensity_train + post_label_preserve + [CastClassLabeld(keys=["class_label"])])
val_transforms = Compose(common_load + intensity_val + post_label_preserve + [CastClassLabeld(keys=["class_label"])])

## 📦 Dataset and DataLoader Setup with Quality Control and Reproducibility

This cell prepares the datasets and data loaders for both training and validation. It also initializes the label shrinkage QC trackers and ensures reproducible data loading.

---

### 🧪 Quality Control Trackers
- `qc_train` and `qc_val` are instances of `LabelQC` (defined earlier).
- They track how often tumor labels shrink significantly after preprocessing (e.g., due to cropping or augmentation).
- `shrink_warn_threshold=0.35`: Warn if more than 35% of the tumor voxels are lost.
- `verbose=True`: Logs detailed info for each case that triggers a warning.

---

### 🔁 Reproducible DataLoader Workers
To ensure consistent behavior across runs (especially in a multi-worker setup), we define:

#### 🔧 `seed_worker(worker_id)`
- Seeds NumPy, Python's `random`, and PyTorch with `SEED + worker_id`, so each worker gets a distinct but repeatable random stream.
- Helps make augmentations, crops, and shuffles repeatable when an experiment is re-run with the same worker configuration.

#### 🎲 Random Generator
- `torch.Generator()` with `.manual_seed(SEED)` is used to further ensure deterministic shuffling and batch sampling.
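
The effect of offsetting the seed by the worker id can be shown with the standard library alone. This sketch uses `random.Random` as a stand-in for the per-worker generators; `worker_stream` is a hypothetical helper for illustration.

```python
import random

SEED = 42

def worker_stream(worker_id: int, n: int = 3):
    """First n draws of a generator seeded with SEED + worker_id."""
    rng = random.Random(SEED + worker_id)
    return [rng.random() for _ in range(n)]

same_worker_repeatable = worker_stream(0) == worker_stream(0)
workers_differ = worker_stream(0) != worker_stream(1)
print(same_worker_repeatable, workers_differ)
```

Each worker sees its own sequence, yet re-running with the same `SEED` and worker id reproduces it.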

---

### 🗂️ Dataset Setup with `CacheDataset`
- **`CacheDataset`**: A MONAI dataset that caches transformed data in memory to reduce preprocessing overhead.
- `cache_rate=0.0`: Disables caching to reduce RAM usage (safe for large datasets or limited memory). Set to `>0.0` to enable partial caching later if needed.
- `train_ds` and `val_ds` use the respective transforms (`train_transforms`, `val_transforms`) defined earlier.

---

### 🚚 DataLoaders
Used to feed data into the model during training and validation.

#### 🔧 Parameters:
- `batch_size=1`: Each batch contains one 3D volume (common for volumetric segmentation due to GPU memory limits).
- `shuffle=True` for training (randomized batches), `False` for validation.
- `num_workers`: Number of subprocesses used for data loading. Higher = faster loading (if memory allows).
- `pin_memory=True`: Speeds up host-to-GPU data transfer when using CUDA.
- `persistent_workers=True`: Keeps DataLoader workers alive between epochs for efficiency.
- `worker_init_fn=seed_worker`: Ensures reproducible behavior in worker processes.
- `prefetch_factor=2`: Controls how many batches to preload per worker (improves throughput).

---

### ✅ Summary
This setup ensures:
- Labels are monitored for quality.
- Data loading is fast and reproducible.
- Datasets are transformed correctly and efficiently passed to the model.

In [None]:
# QC accumulators (patch-level stats already added in transforms and training loop)
qc_train = LabelQC(shrink_warn_threshold=0.35, verbose=True)
qc_val = LabelQC(shrink_warn_threshold=0.35, verbose=True)

# Reproducible dataloader workers
def seed_worker(worker_id: int):
    base_seed = SEED
    np.random.seed(base_seed + worker_id)
    random.seed(base_seed + worker_id)
    torch.manual_seed(base_seed + worker_id)

use_pin = (device.type == "cuda")
g = torch.Generator()
g.manual_seed(SEED)

# Datasets (cache_rate=0.0 to avoid RAM pressure; switch to >0.0 if you want caching)
train_ds = CacheDataset(data=train_items, transform=train_transforms, cache_rate=0.0, num_workers=0)
val_ds   = CacheDataset(data=val_items,   transform=val_transforms,   cache_rate=0.0, num_workers=0)

# DataLoaders
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=2,
    pin_memory=use_pin,
    persistent_workers=True,
    worker_init_fn=seed_worker,
    generator=g,
    prefetch_factor=2,
)

val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    pin_memory=use_pin,
    persistent_workers=True,
    worker_init_fn=seed_worker,
    generator=g,
    prefetch_factor=2,
)

## 🧠 Model Architecture: 3D Segmentation + Classification Pipeline

This cell defines the **core model architecture**, combining a 3D segmentation network with a classification head, and sets up losses, optimizer, and evaluation metrics.

---

### 🧱 1. Segmentation Network: `DynUNet` (nnU-Net-like)

We use `DynUNet` from MONAI — a highly flexible and modular 3D U-Net backbone inspired by nnU-Net.

#### ⚙️ Configuration:
- `spatial_dims=3`: For 3D volumetric data (e.g. brain MRIs).
- `in_channels=1`: Input is a single-channel MRI image.
- `out_channels=2`: Two output channels (background and tumor logits) for binary segmentation.
- `kernel_size`: Defines convolution kernel sizes at each stage (6 total).
- `strides`: Downsampling steps (progressively reduce spatial resolution).
- `upsample_kernel_size`: Used in the decoder for upsampling.
- `norm_name="instance"`: Applies instance normalization.
- `deep_supervision=False`: Only the final output is used for loss calculation.

The model is transferred to the active `device` (GPU or CPU).
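
The stride list determines how far the encoder shrinks the input. The short calculation below assumes each stride divides the spatial size exactly, which holds when the patch size is a multiple of the total downsampling factor; it is an estimate, not output read from the network.

```python
import math

strides = [1, 2, 2, 2, 2, 2]
PATCH_SIZE = (192, 192, 160)

total_downsampling = math.prod(strides)  # 32
bottleneck_shape = tuple(s // total_downsampling for s in PATCH_SIZE)
divisible = all(s % total_downsampling == 0 for s in PATCH_SIZE)
print(total_downsampling, bottleneck_shape, divisible)
```

Under that assumption, a `192 x 192 x 160` patch reaches the bottleneck at roughly `6 x 6 x 5` voxels, and the patch size divides evenly by 32 on every axis.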

---

### 🎯 2. Classification Head: `LazyClassificationHead`

A lightweight classification module that predicts **Glioma vs Mets** using features extracted from the encoder’s bottleneck.

#### 🔍 Design:
- `AdaptiveAvgPool3d(1)`: Reduces the 3D feature map to a 1×1×1 spatial dimension (global pooling).
- `fc`: A fully connected layer, initialized lazily on first forward pass to match input size.
- `num_classes=1`: Binary classification output (logits for Mets probability).
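
One consequence of lazy initialization is worth keeping in mind: until the first forward pass, the head has no trainable parameters, so an optimizer built before that point would not see the `fc` weights. The sketch below demonstrates this with a restated copy of the head (named `LazyHead` here) and a hypothetical feature-map shape; whether the notebook runs a warm-up pass before creating the optimizer is not shown in this section.

```python
import torch
import torch.nn as nn

class LazyHead(nn.Module):
    """Restated copy of the lazy classification head, for illustration."""
    def __init__(self, num_classes: int = 1):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = None
        self.num_classes = num_classes

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        x = self.pool(feat).flatten(1)
        if self.fc is None:
            self.fc = nn.Linear(x.shape[1], self.num_classes).to(x.device)
        return self.fc(x)

head = LazyHead()
n_before = len(list(head.parameters()))   # pooling has no parameters

dummy_feat = torch.zeros(1, 8, 2, 2, 2)   # hypothetical bottleneck feature map
logits = head(dummy_feat)                 # first forward creates the linear layer
n_after = len(list(head.parameters()))    # weight and bias are now registered

print(n_before, n_after, tuple(logits.shape))
```

A common way to handle this is to run one dummy forward pass through the network and head before constructing the optimizer.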

---

### 🪝 3. Feature Hook for Classification

To allow classification without interfering with segmentation, we use a **forward hook** to tap into the encoder’s bottleneck output:
- The hook captures intermediate features (`encoder_feat["x"]`) for use in the classification head.
- Attached to either `seg_net.bottleneck` or fallback `encoder4`, depending on DynUNet structure.

---

### ⚖️ 4. Loss Functions

- `seg_loss_fn`: `DiceCELoss` combines Dice loss (for overlap) and CrossEntropy (for voxel-wise classification).
  - `to_onehot_y=True`, `softmax=True`: Converts target to one-hot and applies softmax to logits.
- `cls_loss_fn`: `BCEWithLogitsLoss` for binary classification from raw logits (Glioma vs Mets).

---

### 🧪 5. Optimizer and AMP

- `AdamW`: Optimizer for both segmentation and classification parameters.
  - `lr=1e-5`, `weight_decay=1e-5`: Stable learning rate and regularization.
- `GradScaler`: Enables **Automatic Mixed Precision (AMP)** to reduce memory usage and speed up training on GPUs.

---

### 📊 6. Evaluation Metric

- `dice_metric`: MONAI’s Dice metric (averaged across non-background classes).
  - `include_background=False`: Only evaluates tumor class (not background).
  - `reduction="mean"`: Computes mean Dice score across batch or volume.
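
As a rough illustration of what the Dice score measures, the sketch below computes it for two flat binary masks in plain Python. The helper `dice_score` and the toy masks are assumptions made for this example. MONAI's `DiceMetric` works on batched tensors and may treat empty masks differently from the convention chosen here.

```python
def dice_score(pred, target):
    """Dice = 2*|P ∩ T| / (|P| + |T|) for two flat binary masks."""
    inter = sum(p and t for p, t in zip(pred, target))
    denom = sum(pred) + sum(target)
    # Convention chosen for this sketch: two empty masks count as a perfect match.
    return 1.0 if denom == 0 else 2.0 * inter / denom

pred   = [0, 1, 1, 1, 0, 0]
target = [0, 0, 1, 1, 1, 0]
score = dice_score(pred, target)
print(f"{score:.4f}")  # 2 overlapping voxels, 3 + 3 foreground voxels -> 0.6667
```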

---

### ✅ Summary

This cell finalizes the dual-task learning setup:
- **Segmentation**: Tumor mask prediction from MRI.
- **Classification**: Glioma vs Mets nature of the tumor.
- Optimized jointly using shared encoder features.

In [None]:
# Segmentation network (nnU-Net-like DynUNet)
seg_net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    kernel_size=[3, 3, 3, 3, 3, 3],     # 6 stages
    strides=[1, 2, 2, 2, 2, 2],         # length matches kernel_size
    upsample_kernel_size=[2, 2, 2, 2, 2],
    norm_name="instance",
    deep_supervision=False,
).to(device)

# Classification head (lazy: initializes fully-connected layer on first forward)
class LazyClassificationHead(nn.Module):
    def __init__(self, num_classes: int = 1):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = None
        self.num_classes = num_classes

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        x = self.pool(feat).flatten(1)
        if self.fc is None:
            self.fc = nn.Linear(x.shape[1], self.num_classes).to(x.device)
        return self.fc(x)

cls_head = LazyClassificationHead(num_classes=1).to(device)

# Hook to capture encoder bottleneck features for classification
encoder_feat = {"x": None}
def hook_fn(module, input, output):
    encoder_feat["x"] = output

# Attach hook to a stable location in DynUNet
if hasattr(seg_net, "bottleneck"):
    seg_net.bottleneck.register_forward_hook(hook_fn)
elif hasattr(seg_net, "encoder4"):
    seg_net.encoder4.register_forward_hook(hook_fn)
else:
    print("[WARN] Could not attach hook; classification head may not receive features")

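# Losses and optimizer
seg_loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
cls_loss_fn = nn.BCEWithLogitsLoss()

# Materialize the lazy cls_head.fc before building the optimizer. Otherwise
# cls_head.parameters() is empty here and the head would never be updated.
# Note: optimizer states saved without the fc parameters will no longer match.
with torch.no_grad():
    seg_net.eval()
    _ = seg_net(torch.zeros(1, 1, 64, 64, 64, device=device))
    if encoder_feat["x"] is not None:
        _ = cls_head(encoder_feat["x"])
    seg_net.train()

params = list(seg_net.parameters()) + list(cls_head.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-5, weight_decay=1e-5)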

# AMP scaler
scaler = GradScaler(enabled=torch.cuda.is_available())

# Segmentation metrics
dice_metric = DiceMetric(include_background=False, reduction="mean")

## 📐 Tensor Padding and Cropping Utilities for 3D Volumes

This cell defines two utility functions that **make input volumes compatible** with the model architecture. DynUNet's strided convolutions halve the spatial resolution at each downsampling stage, so every spatial dimension should be a multiple of the total downsampling factor (32 for the five stride-2 stages configured above).

---

### 🔲 `pad_to_factor(...)`

Pads a 5D tensor of shape `(B, C, D, H, W)` so that its **depth (D), height (H), and width (W)** are divisible by a given `factor`.

#### ⚙️ Why is this needed?
Many segmentation models (like U-Net, DynUNet) perform multiple downsampling operations (via strides or pooling). If the input dimensions aren’t divisible by the downsampling factor, shape mismatches may occur during upsampling in the decoder.

#### 🧠 How it works:
- Calculates the **next multiple** of `factor` for each spatial dimension.
- Adds padding only to the **right, bottom, and back** sides — so the origin stays unchanged (important for spatial alignment with labels).
- Supports a configurable padding `mode` (default `"constant"`); the fill `value` (default `0.0`) applies to constant padding.

#### 🔁 Parameters:
- `x`: Input tensor of shape `(B, C, D, H, W)`.
- `factor`: Scalar (e.g., 32) or tuple `(fD, fH, fW)`.
- `return_pad`: If `True`, returns both the padded tensor and the padding applied.

#### 🧾 Example:
If an input has shape `[1, 1, 123, 256, 249]` and factor is `32`, it will be padded to `[1, 1, 128, 256, 256]`.
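
The arithmetic behind this example can be checked without any tensors. The sketch below is a plain-Python illustration; `next_multiple` and `right_pad_amounts` are hypothetical helper names, not functions from the notebook. It reproduces the padded shape and the pad tuple in the last-dimension-first order that `torch.nn.functional.pad` expects.

```python
def next_multiple(size: int, factor: int) -> int:
    """Smallest multiple of `factor` that is >= `size`."""
    return ((size + factor - 1) // factor) * factor

def right_pad_amounts(spatial, factor=32):
    """Return the padded (D, H, W) and an F.pad-style tuple (last dim first)."""
    padded = tuple(next_multiple(s, factor) for s in spatial)
    pd, ph, pw = (p - s for p, s in zip(padded, spatial))
    # (W_left, W_right, H_left, H_right, D_left, D_right): right-side padding only
    return padded, (0, pw, 0, ph, 0, pd)

padded, pad = right_pad_amounts((123, 256, 249), factor=32)
print(padded)  # (128, 256, 256)
print(pad)     # (0, 7, 0, 0, 0, 5)
```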

---

### ✂️ `crop_to_shape(...)`

Crops a padded 5D tensor `(B, C, D, H, W)` **back down** to a target shape `(D, H, W)`, used typically **after inference** to remove padding and recover the original dimensions.

#### 🔧 How it works:
- Performs a simple slice from the beginning (`:Dz, :Hy, :Wx`) along each spatial axis.
- Ensures that output spatial dimensions match exactly with the ground truth.

#### 🧾 Example:
If the segmentation output is `[1, 2, 128, 256, 256]` (two output channels) and the original spatial shape was `(123, 240, 240)`, this function crops it back to `[1, 2, 123, 240, 240]`.
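
Padding only on the trailing side is what makes this crop a simple start-anchored slice. A one-dimensional analogue, using plain lists and hypothetical helper names, shows the round trip:

```python
def pad_right(seq, factor, value=0):
    """Right-pad a list so its length is a multiple of `factor`."""
    target = -(-len(seq) // factor) * factor  # ceiling division
    return seq + [value] * (target - len(seq))

def crop_from_start(seq, length):
    """Keep the first `length` items, discarding the right-side padding."""
    return seq[:length]

original = [3, 1, 4, 1, 5]
padded = pad_right(original, factor=4)
restored = crop_from_start(padded, len(original))
print(padded)    # [3, 1, 4, 1, 5, 0, 0, 0]
print(restored)  # [3, 1, 4, 1, 5]
```

Because no elements are added at the front, index `i` in the padded sequence still refers to index `i` in the original.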

---

### ✅ Summary

These two functions work together to:
- 🧱 **Prepare inputs** for model inference or training (via `pad_to_factor`)
- ✂️ **Restore outputs** back to the original shape (via `crop_to_shape`)

In [None]:
import torch.nn.functional as F

def pad_to_factor(x: torch.Tensor, factor=32, return_pad: bool = False, mode: str = "constant", value: float = 0.0):
    """
    Pads a 5D tensor (B, C, D, H, W) so each spatial dim is a multiple of `factor`.
    - factor: int or (fD, fH, fW)
    - Pads only on the "right/bottom/back" to avoid shifting coordinates
    - If return_pad=True, also returns the pad tuple (Wl, Wr, Hl, Hr, Dl, Dr)
    """
    assert x.dim() == 5, f"Expected 5D tensor (B,C,D,H,W), got shape {tuple(x.shape)}"
    if isinstance(factor, int):
        fD = fH = fW = factor
    else:
        assert len(factor) == 3, "factor must be int or 3-tuple"
        fD, fH, fW = factor

    B, C, D, H, W = x.shape
    def next_m(s, f): return ((s + f - 1) // f) * f
    Dn, Hn, Wn = next_m(D, fD), next_m(H, fH), next_m(W, fW)
    pd, ph, pw = Dn - D, Hn - H, Wn - W
    pad = (0, pw, 0, ph, 0, pd)  # (W_left, W_right, H_left, H_right, D_left, D_right)

    if any(p > 0 for p in pad):
        x = F.pad(x, pad, mode=mode, value=value)

    return (x, pad) if return_pad else x

def crop_to_shape(x: torch.Tensor, shape: tuple) -> torch.Tensor:
    """
    Crops tensor x (B,C,D,H,W) to spatial shape `shape` = (D,H,W), slicing from the start along each dim.
    """
    assert x.dim() == 5 and len(shape) == 3, "x must be 5D and shape must be (D,H,W)"
    Dz, Hy, Wx = map(int, shape)
    return x[..., :Dz, :Hy, :Wx]

## 🔍 Sliding Window Inference Setup + Inference Utilities

This cell sets up the **inference engine** for 3D volumetric segmentation using a patch-based strategy and defines helper functions to:
1. Move data to the correct device.
2. Apply inference on padded volumes with automatic cropping to restore original shape.

---

### 🧠 `SlidingWindowInferer`

The `SlidingWindowInferer` allows running inference on large 3D MRI volumes by dividing them into smaller **overlapping patches**, running the model on each, and combining the outputs.

#### 🧰 Parameters:
- `roi_size=PATCH_SIZE`: Size of the 3D patch used during inference.
- `sw_batch_size=1`: Number of patches to process in a mini-batch.
- `overlap=PATCH_OVERLAP`: Amount of overlap between adjacent patches (e.g., 0.5 = 50% overlap). This improves prediction consistency across patch borders.
- `mode="gaussian"`: Overlapping areas are blended using a Gaussian weighting (gives smooth transitions).

This is especially useful when the full 3D volume does not fit in GPU memory.
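
To get a feel for how many patches a given overlap produces, the following simplified sketch lists window start positions along one axis. The function `window_starts` and the sizes used (256 voxels, ROI of 128, overlap 0.5) are assumptions for illustration. MONAI's actual tiling logic may differ in details.

```python
import math

def window_starts(image_size: int, roi_size: int, overlap: float):
    """Start indices of overlapping windows along one axis (simplified)."""
    if image_size <= roi_size:
        return [0]  # a single window covers the whole axis
    step = max(1, int(roi_size * (1.0 - overlap)))
    count = math.ceil((image_size - roi_size) / step) + 1
    # Clamp the last window so it ends exactly at the image border.
    return [min(i * step, image_size - roi_size) for i in range(count)]

starts = window_starts(image_size=256, roi_size=128, overlap=0.5)
print(starts)  # [0, 64, 128]
```

With three windows per axis, a 256×256×256 volume would be covered by 3 × 3 × 3 = 27 patches under these assumptions.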

---

### 🔌 `to_device(batch, device)`

Utility function that:
- Transfers all `torch.Tensor` values in a dictionary (`batch`) to the specified `device` (e.g., GPU).
- Non-tensor values (e.g., strings, integers) are passed through unchanged.

Used during validation/inference when batches are loaded by the DataLoader on CPU but need to be moved to GPU.

---

### 📦 `sliding_infer_padded(...)`

A wrapper for inference that:
1. **Pads** the input volume so it is compatible with the model (dims divisible by `factor` like 32).
2. Runs **sliding window inference** using the `inferer` object.
3. **Crops** the output back to original shape (if `target_shape` is provided).

#### 🔁 Parameters:
- `images`: Input tensor of shape `(B, C, D, H, W)`.
- `network`: The segmentation model (e.g., DynUNet).
- `factor`: The factor to pad to (default: 32).
- `target_shape`: Optional. If set, crops back to the original spatial size.

This utility ensures robust and memory-efficient inference on full volumes, even if their sizes are irregular.

---

### ✅ Summary

These tools ensure:
- Memory-efficient inference on large 3D volumes.
- Compatibility with model requirements via dynamic padding.
- Accurate spatial alignment by restoring original image dimensions after inference.

In [None]:
inferer = SlidingWindowInferer(
    roi_size=PATCH_SIZE,
    sw_batch_size=1,
    overlap=PATCH_OVERLAP,
    mode="gaussian",
)

# Utils
from typing import Dict, Any, Tuple, Optional

def to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            out[k] = v.to(device, non_blocking=True)
        else:
            out[k] = v
    return out

def sliding_infer_padded(
    images: torch.Tensor,
    network: nn.Module,
    factor: int = 32,
    target_shape: Optional[Tuple[int, int, int]] = None,
) -> torch.Tensor:
    images_p, _ = pad_to_factor(images, factor=factor, return_pad=True)
    logits = inferer(inputs=images_p, network=network)
    if target_shape is not None and logits.shape[-3:] != tuple(target_shape):
        logits = crop_to_shape(logits, target_shape)
    return logits

## 🧾 Save Configuration Metadata for Reproducibility

This cell captures all important settings and hyperparameters used in the current run and writes them to a JSON file (`config.json`) for **reproducibility, auditing, and sharing**.

---

### 🔍 `file_sha1(...)`: File Fingerprint Utility
A helper function that:
- Computes a **SHA-1 hash** of the file contents at the given path.
- Returns `None` if the file cannot be read.
- Used to **verify integrity and uniqueness** of the CSV files used for train/val/test splits.
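
For very large files, reading the whole file into memory at once may be undesirable. The sketch below is an optional chunked variant, not the notebook's `file_sha1`. The name `file_sha1_chunked`, the chunk size, and the demo CSV contents are all made up for illustration. It produces the same digest as hashing the full contents in one call.

```python
import hashlib
import os
import tempfile

def file_sha1_chunked(path, chunk_size=1 << 20):
    """SHA-1 of a file, read in chunks so large files need little memory."""
    h = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Demo on a throwaway CSV (contents are invented for this example).
payload = b"case_id,class_label\ncase_001,0\n"
with tempfile.NamedTemporaryFile("wb", suffix=".csv", delete=False) as tmp:
    tmp.write(payload)
    tmp_path = tmp.name

digest = file_sha1_chunked(tmp_path, chunk_size=8)
same = digest == hashlib.sha1(payload).hexdigest()
os.remove(tmp_path)
print(digest, same)
```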

---

### 🧠 `config`: Experiment Metadata Dictionary

This dictionary stores **all key configuration parameters** for this training run:

#### 🔑 Run Info
- `run_id`: Unique timestamped ID.
- `seed`: Random seed used for reproducibility.
- `device`: CPU or GPU.
- `gpu_name`: Name of the active CUDA GPU (if available).

#### 🧭 Preprocessing
- `target_spacing`: Desired voxel spacing after resampling.
- `patch_size`: Patch dimensions used during training and inference.
- `inferer`: Settings for the sliding window inferer (ROI size, overlap, blending mode).

#### 🏗️ Model Architecture
- Model type (`DynUNet`) and its core parameters (input/output channels, normalization, deep supervision).

#### ⚖️ Loss Functions
- `seg`: Segmentation loss (`DiceCELoss`)
- `cls`: Classification loss (`BCEWithLogitsLoss`)
- `cls_weight`: Weight applied to the classification loss when it is added to the segmentation loss (0.3, matching the training loop).

#### 🚀 Optimizer
- Type (`AdamW`), learning rate, and weight decay.

#### 🔁 Augmentations
- Flip and rotation probabilities.
- Affine transformation settings (rotation in degrees, scaling in %).
- Cropping strategy.

#### 🧪 Transforms Notes
- Describes how images and labels were processed:
  - Image spacing via bilinear interpolation.
  - Label spacing via nearest neighbor (to avoid smoothing).
  - LabelPostProcessd with morphological protection (`morph_radius=1`).

#### 🗂️ Dataset Splits
- Paths to training/validation/testing CSVs.
- SHA-1 hashes of each CSV to ensure data integrity and version control.

---

### 💾 Save to File
The dictionary is saved as formatted JSON (`config.json`) in the current run directory. This file:
- Serves as a **record of hyperparameters and pipeline settings**.
- Can be used to **reproduce the exact experiment later**.
- Can be shared or logged in experiment tracking systems.
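
One detail worth keeping in mind when reloading the file: JSON has no tuple type, so tuple-valued settings come back as lists. A small sketch with hypothetical values (the real ones come from the notebook's constants):

```python
import json

# Hypothetical values chosen for illustration only.
cfg = {"seed": 42, "patch_size": (128, 128, 128), "overlap": 0.5}

text = json.dumps(cfg, indent=2)
loaded = json.loads(text)

# Tuples are serialized as JSON arrays and deserialized as Python lists.
print(loaded["patch_size"])
patch_size = tuple(loaded["patch_size"])  # convert back before reuse
```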

---

### ✅ Summary

Capturing and saving configuration metadata like this ensures that future reviewers (or our future selves) can understand and recreate the model, training conditions, and dataset splits exactly as they were.

In [None]:
def file_sha1(path: Path | str) -> str | None:
    try:
        with open(path, "rb") as f:
            return hashlib.sha1(f.read()).hexdigest()
    except Exception:
        return None

config = {
    "run_id": RUN_ID,
    "seed": SEED,
    "device": str(device),
    "gpu_name": (torch.cuda.get_device_name(0) if torch.cuda.is_available() else None),
    "target_spacing": tuple(TARGET_SPACING),
    "patch_size": tuple(PATCH_SIZE),
    "inferer": {"roi_size": tuple(PATCH_SIZE), "overlap": float(PATCH_OVERLAP), "mode": "gaussian"},
    "model": {
        "arch": "DynUNet",
        "in_channels": 1,
        "out_channels": 2,
        "deep_supervision": False,
        "norm_name": "instance",
    },
    "loss": {"seg": "DiceCELoss", "cls": "BCEWithLogitsLoss", "cls_weight": 0.3},
    "optimizer": {
        "type": "AdamW",
        "lr": float(optimizer.param_groups[0]["lr"]),
        "weight_decay": float(optimizer.param_groups[0].get("weight_decay", 0.0)),
    },
    "augs": {
        "flip_p": 0.2, "rot90_p": 0.2,
        "affine": {"rot_deg": 5, "scale_pct": 10},
        "crop": "RandCropByPosNegLabel",
    },
    # Reflect current pipeline (nearest for labels + light morph).
    "transforms_notes": "Spacingd (image=bilinear, label=nearest), LabelPostProcessd(morph_radius=1)",
    "splits": {
        "train_csv": str(TRAIN_CSV), "val_csv": str(VAL_CSV), "test_csv": str(TEST_CSV),
        "train_csv_sha1": file_sha1(TRAIN_CSV), "val_csv_sha1": file_sha1(VAL_CSV), "test_csv_sha1": file_sha1(TEST_CSV),
    },
}

with open(CONFIG_JSON, "w") as f:
    json.dump(config, f, indent=2)
print("Saved config:", CONFIG_JSON)

In [None]:
# Environment snapshot (for full reproducibility)
env = {
    "python": platform.python_version(),
    "os": platform.platform(),
    "torch": torch.__version__,
    "torch_cuda": (torch.version.cuda if torch.cuda.is_available() else None),
    "cudnn": (torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None),
    "monai": __import__("monai").__version__,
    "numpy": np.__version__,
    "pandas": pd.__version__,
    "gpu_count": (torch.cuda.device_count() if torch.cuda.is_available() else 0),
    "gpu_name": (torch.cuda.get_device_name(0) if torch.cuda.is_available() else None),
    "seed": SEED,
}

# List all GPU names if multiple are present
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    env["gpu_names"] = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]

# Optional: exact package set (can be large)
try:
    env["pip_freeze"] = subprocess.check_output(["pip", "freeze"]).decode().splitlines()
except Exception:
    pass

with open(ENV_JSON, "w") as f:
    json.dump(env, f, indent=2)
print("Saved env:", ENV_JSON)

## 🔁 Training & Validation Loop with Resume Logic, Metrics, and Checkpointing

This block implements the **core training and validation loop** for the joint segmentation + classification model. It includes logic for:
- Resuming from checkpoints
- Training over multiple epochs
- Tracking segmentation and classification metrics
- Saving model states based on validation performance

### 💾 1. Resume from Last or Best Checkpoint
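- `find_resume_ckpt()`: Checks `last.pt`, then `best.pt`, first in the current run directory (`RUN_DIR`) and then in the fallback directory (`ckpt_dir`); it returns the first file that exists and raises `FileNotFoundError` if none is found.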
- Loads model weights, optimizer state, scaler (for AMP), and tracking variables like `start_epoch`, `best_val_dice`.
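
The search order can be tried out in isolation. In this sketch the directories are temporary stand-ins for `RUN_DIR` and `ckpt_dir`, and `first_existing` is a hypothetical helper. It shows that a run-scoped `best.pt` wins over a global `last.pt` when no run-scoped `last.pt` exists.

```python
import tempfile
from pathlib import Path

def first_existing(candidates):
    """Return the first path that exists, or None if none do."""
    return next((p for p in candidates if p.exists()), None)

with tempfile.TemporaryDirectory() as root:
    run_dir = Path(root) / "run"      # stands in for RUN_DIR
    global_dir = Path(root) / "ckpt"  # stands in for ckpt_dir
    run_dir.mkdir()
    global_dir.mkdir()
    (run_dir / "best.pt").touch()
    (global_dir / "last.pt").touch()

    order = [run_dir / "last.pt", run_dir / "best.pt",
             global_dir / "last.pt", global_dir / "best.pt"]
    chosen = first_existing(order)
    chosen_rel = chosen.relative_to(root).as_posix()

print(chosen_rel)  # run/best.pt
```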

In [None]:
def find_resume_ckpt() -> Path:
    candidates = []
    # Prefer run-scoped last/best if present
    if "RUN_DIR" in globals():
        candidates += [RUN_DIR / "last.pt", RUN_DIR / "best.pt"]
    # Fallback to global dir
    candidates += [ckpt_dir / "last.pt", ckpt_dir / "best.pt"]
    for p in candidates:
        if p.exists():
            return p
    raise FileNotFoundError(f"No checkpoint found in {ckpt_dir} (looked for last.pt/best.pt in run/global dirs).")

RESUME_CKPT = find_resume_ckpt()
print(f"Resuming from: {RESUME_CKPT}")

# Initialize lazy cls_head.fc before loading (captures encoder channels safely)
with torch.no_grad():
    was_training = seg_net.training
    seg_net.eval()
    encoder_feat["x"] = None
    dummy = torch.zeros(1, 1, 64, 64, 64, device=device)
    _ = seg_net(dummy)
    feat = encoder_feat["x"] if encoder_feat["x"] is not None else dummy
    _ = cls_head(feat)  # initializes cls_head.fc lazily
    seg_net.train(was_training)

ckpt = torch.load(RESUME_CKPT, map_location=device)

missing, unexpected = seg_net.load_state_dict(ckpt["seg"], strict=False)
if missing or unexpected:
    print(f"[WARN] seg_net state_dict mismatches. missing={missing}, unexpected={unexpected}")
missing, unexpected = cls_head.load_state_dict(ckpt["cls"], strict=False)
if missing or unexpected:
    print(f"[WARN] cls_head state_dict mismatches. missing={missing}, unexpected={unexpected}")

if "optimizer" in ckpt:
    try:
        optimizer.load_state_dict(ckpt["optimizer"])
    except Exception as e:
        print(f"[WARN] could not load optimizer state: {e}")
if "scaler" in ckpt:
    try:
        scaler.load_state_dict(ckpt["scaler"])
    except Exception as e:
        print(f"[WARN] could not load scaler state: {e}")

# Epoch planning
start_epoch = int(ckpt.get("epoch", 0)) + 1
best_val_dice = float(ckpt.get("best_val_dice", -1.0))
EPOCHS_NEXT = 50
end_epoch = start_epoch + EPOCHS_NEXT - 1

print(f"Resuming at epoch {start_epoch} (best_val_dice={best_val_dice:.4f}); training until epoch {end_epoch}")

### 📊 2. Metric Setup
- `metrics_f`, `metrics_w`: File handle and writer for logging per-epoch metrics to CSV.
- `HausdorffDistanceMetric`: For spatial accuracy of segmentation (`hd95`).
- `new_cm()` and `update_cm()`: Voxel-level confusion matrix tracker (TP, FP, FN, TN) comparing the predicted tumor mask with the label mask; `cm_metrics()` derives Accuracy, Recall, and F1 from these counts.
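
A worked example with made-up counts shows how the three numbers relate. The function name `metrics_from_counts` is an assumption for this sketch; the formulas follow the same max-guarded pattern as the notebook's `cm_metrics`.

```python
def metrics_from_counts(tp, fp, fn, tn, eps=1e-8):
    """Accuracy, recall and F1 from raw confusion-matrix counts."""
    acc = (tp + tn) / max(1, tp + fp + fn + tn)
    rec = tp / max(1, tp + fn)
    pre = tp / max(1, tp + fp)
    f1 = (2 * pre * rec) / max(eps, pre + rec)
    return acc, rec, f1

# Invented counts: 80 true positives, 20 false positives,
# 20 false negatives, 880 true negatives.
acc, rec, f1 = metrics_from_counts(tp=80, fp=20, fn=20, tn=880)
print(f"acc={acc:.3f} rec={rec:.3f} f1={f1:.3f}")  # acc=0.960 rec=0.800 f1=0.800
```

Accuracy is dominated by the large number of true negatives, which is why recall and F1 tend to be more informative when the positive class is rare.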

In [None]:
# Open metrics CSV (segmentation + training stats)
# Close with: metrics_f.close() at the end of training
metrics_f, metrics_w = write_csv_header(
    METRICS_CSV,
    ["epoch","split","loss","dice","acc","recall","f1","hd95","lr","seconds","qc_warns","qc_total","qc_rate"]
)

In [None]:
eps = 1e-8
hd95_tr = HausdorffDistanceMetric(include_background=False, reduction="mean", percentile=95)
hd95_val = HausdorffDistanceMetric(include_background=False, reduction="mean", percentile=95)

def new_cm():
    return {"tp": 0, "fp": 0, "fn": 0, "tn": 0}

def update_cm(cm, ypred, ytrue):
    yp = (ypred > 0).to(torch.bool)
    yt = (ytrue > 0).to(torch.bool)
    cm["tp"] += (yp & yt).sum().item()
    cm["fp"] += (yp & ~yt).sum().item()
    cm["fn"] += (~yp & yt).sum().item()
    cm["tn"] += (~yp & ~yt).sum().item()

def cm_metrics(cm):
    tp, fp, fn, tn = cm["tp"], cm["fp"], cm["fn"], cm["tn"]
    acc = (tp + tn) / max(1, tp + fp + fn + tn)
    rec = tp / max(1, tp + fn)
    pre = tp / max(1, tp + fp)
    f1 = (2 * pre * rec) / max(eps, pre + rec)
    return acc, rec, f1

## Main Training Loop

#### 🔄 Lazy Initialization of `cls_head.fc`
Since the classification head is lazily defined based on encoder output shape:
- A dummy input is passed through `seg_net`, and its bottleneck output is used to initialize `cls_head`.

---

### 🚀 3. Training Loop

For each epoch from `start_epoch` to `end_epoch`:

#### 🔧 Training Phase
- Sets both `seg_net` and `cls_head` to `train()` mode.
- Iterates over the training DataLoader:
  - Moves data to the correct device.
  - Updates label shrinkage quality control (`qc_train`).
  - Runs segmentation forward pass.
  - Captures encoder features for classification.
  - Computes segmentation loss (`DiceCELoss`) and classification loss (`BCEWithLogitsLoss`), combined with a weight (0.3).
  - Computes segmentation metrics: predicted mask, confusion matrix, HD95.
  - Updates model weights using AMP (Autocast + GradScaler).
- Logs metrics per epoch: accuracy, recall, F1, HD95, QC rate, and total loss.
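
Written out, the combined objective described above takes the following form. The symbols $B$ (batch size), $y_i$ (class label), and $z_i$ (classification logit) are notation introduced here for illustration, and the equal weighting of the Dice and cross-entropy terms assumes `DiceCELoss` is used with its default weights.

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{seg}} + \lambda_{\text{cls}}\,\mathcal{L}_{\text{cls}}, \qquad \lambda_{\text{cls}} = 0.3
$$

$$
\mathcal{L}_{\text{seg}} = \mathcal{L}_{\text{Dice}} + \mathcal{L}_{\text{CE}}, \qquad
\mathcal{L}_{\text{cls}} = -\frac{1}{B}\sum_{i=1}^{B}\Big[\, y_i \log \sigma(z_i) + (1 - y_i)\log\big(1 - \sigma(z_i)\big) \Big]
$$

Here $\sigma$ is the logistic sigmoid, which `BCEWithLogitsLoss` applies internally to the raw logits.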

---

### 🧪 4. Validation Phase (Every `val_interval` epochs)

- Sets model to `eval()` mode, disables gradient calculation.
- For each batch in the validation set:
  - Pads input to compatible dimensions (`pad_to_factor()`).
  - Runs **sliding window inference** via `SlidingWindowInferer`.
  - Crops back to original shape (`crop_to_shape()`).
  - Extracts encoder features for classification.
  - Computes validation loss (segmentation plus weighted classification loss), Dice score, HD95, and voxel-wise accuracy/recall/F1 of the predicted mask.
  - Updates QC statistics (`qc_val`).
- Writes validation metrics to CSV.

---

### 💾 5. Checkpointing

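After every validation pass:
- Saves the current state to `last.pt`.
- Saves `encoder_fullstate.pt`, which stores the full `seg_net` state dict so downstream code can load the encoder subset.

If validation Dice improves:
- Saves `best.pt` in the run directory, plus a convenience copy in `ckpt_dir`.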

Checkpoints include:
- Epoch number
- Model weights (`seg`, `cls`)
- Optimizer and scaler states
- Best validation Dice score
- Paths to config and environment metadata
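
The "always refresh last, save best only on improvement" policy can be traced with a few invented Dice values. This is a plain-Python sketch of the bookkeeping only; `checkpoint_plan` is a hypothetical helper and no files are written.

```python
def checkpoint_plan(val_dice_per_epoch, best=-1.0):
    """For each epoch, list which checkpoint files would be written."""
    plan = []
    for epoch, dice in enumerate(val_dice_per_epoch, start=1):
        files = ["last.pt"]   # refreshed after every validation pass
        if dice > best:       # strict improvement only
            best = dice
            files.append("best.pt")
        plan.append((epoch, files, best))
    return plan

# Made-up validation Dice values for four epochs.
plan = checkpoint_plan([0.61, 0.58, 0.66, 0.66])
for row in plan:
    print(row)
```

A tie with the current best (the fourth epoch here) does not overwrite `best.pt`, because the comparison is strict.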

---

### 🧹 6. Cleanup and Finalization

- Deletes per-batch tensors and empties the CUDA cache after each validation pass to free GPU memory.
- Logs total training time in minutes.
- Closes the metrics CSV file safely.

---

### ✅ Summary

This loop handles:
- Dual-task training (segmentation + classification)
- Accurate metric tracking and logging
- Reproducible resume functionality
- Intelligent model saving based on best performance
- Efficient handling of 3D volumetric data (with padding + sliding window inference)

In [None]:
# Train / Val loops (refactored: correct metrics, single CSV writer, stable val Dice)

# Assumes:
# - start_epoch, end_epoch, best_val_dice defined (from resume block); if not, set them here.
# - metrics_f, metrics_w already opened by write_csv_header with header:
#   ["epoch","split","loss","dice","acc","recall","f1","hd95","lr","seconds","qc_warns","qc_total","qc_rate"]

if "start_epoch" not in globals() or "end_epoch" not in globals():
    EPOCHS = 500
    start_epoch, end_epoch = 1, EPOCHS
if "best_val_dice" not in globals():
    best_val_dice = -1.0

val_interval = 1

t0_all = time.time()
for epoch in range(start_epoch, end_epoch + 1):
    t_epoch = time.time()

    # ---- TRAIN ----
    cm_train = new_cm()
    seg_net.train(); cls_head.train()
    epoch_loss = 0.0
    num_steps = 0

    for batch in train_loader:
        # QC update per-sample (patch-level)
        for b in decollate_batch(batch):
            qc_train.update(int(b.get("qc_before_vox", 0)), int(b.get("qc_after_vox", 0)), str(b.get("case_id", "?")))
        batch = to_device(batch, device)
        images = batch["image"]
        labels = batch["label"].long()
        class_labels = batch["class_label"].view(-1, 1).float()  # BCEWithLogitsLoss expects float targets

        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            # segmentation forward
            encoder_feat["x"] = None
            seg_logits = seg_net(images)                 # deep_supervision=False
            seg_logits_main = seg_logits

            # train segmentation metrics (per-batch)
            y_pred_tr = torch.argmax(torch.softmax(seg_logits_main, dim=1), dim=1, keepdim=True)
            update_cm(cm_train, y_pred_tr, labels)
            hd95_tr(y_pred_tr, labels)

            # classification forward
            feat = encoder_feat["x"] if encoder_feat["x"] is not None else seg_logits_main
            cls_logits = cls_head(feat)

            # losses
            loss_seg = seg_loss_fn(seg_logits_main, labels)
            loss_cls = cls_loss_fn(cls_logits, class_labels)
            loss = loss_seg + 0.3 * loss_cls

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_steps += 1

    epoch_loss /= max(1, num_steps)
    tr_acc, tr_rec, tr_f1 = cm_metrics(cm_train)
    tr_hd95 = hd95_tr.aggregate().item(); hd95_tr.reset()
    train_lr = optimizer.param_groups[0]["lr"]
    qc_rate_tr = (qc_train.warn / max(1, qc_train.total))

    print(f"Epoch {epoch}/{end_epoch} - train loss: {epoch_loss:.4f}")
    print(f"  Train: acc={tr_acc:.4f} rec={tr_rec:.4f} f1={tr_f1:.4f} hd95={tr_hd95:.2f}")
    metrics_w.writerow([
        epoch, "train",
        f"{epoch_loss:.6f}", "", f"{tr_acc:.6f}", f"{tr_rec:.6f}", f"{tr_f1:.6f}", f"{tr_hd95:.6f}",
        f"{train_lr:.6g}", f"{(time.time()-t_epoch):.2f}", qc_train.warn, qc_train.total, f"{qc_rate_tr:.4f}"
    ]); metrics_f.flush()

    # print the QC summary first, then reset so the stats stay epoch-local
    if epoch % val_interval == 0:
        qc_train.summary()
    qc_train.total = qc_train.warn = 0

    # ---- VAL ----
    if epoch % val_interval == 0:
        cm_val = new_cm()
        seg_net.eval(); cls_head.eval()
        dice_metric.reset()
        val_loss = 0.0
        steps = 0
        with torch.no_grad():
            for batch in val_loader:
                for b in decollate_batch(batch):
                    qc_val.update(int(b.get("qc_before_vox", 0)), int(b.get("qc_after_vox", 0)), str(b.get("case_id", "?")))
                batch = to_device(batch, device)
                images = batch["image"]
                labels = batch["label"].long()
                class_labels = batch["class_label"].view(-1, 1).float()  # BCEWithLogitsLoss expects float targets

                with autocast(device_type="cuda", enabled=torch.cuda.is_available()):
                    # padded sliding-window seg
                    images_p = pad_to_factor(images, factor=32)
                    seg_logits = inferer(inputs=images_p, network=seg_net)

                    # classification (populate encoder features on same padded grid)
                    encoder_feat["x"] = None
                    _ = seg_net(images_p)
                    if seg_logits.shape[-3:] != labels.shape[-3:]:
                        seg_logits = crop_to_shape(seg_logits, labels.shape[-3:])

                    feat = encoder_feat["x"] if encoder_feat["x"] is not None else seg_logits
                    cls_logits = cls_head(feat)

                    # per-batch seg metrics
                    y_pred = torch.argmax(torch.softmax(seg_logits, dim=1), dim=1, keepdim=True)
                    dice_metric(y_pred=y_pred, y=labels)
                    update_cm(cm_val, y_pred, labels)
                    hd95_val(y_pred, labels)

                    # per-batch val loss
                    loss_seg = seg_loss_fn(seg_logits, labels)
                    loss_cls = cls_loss_fn(cls_logits, class_labels)
                    loss = loss_seg + 0.3 * loss_cls
                val_loss += loss.item()
                steps += 1

        mean_dice = dice_metric.aggregate().item(); dice_metric.reset()
        val_loss /= max(1, steps)
        vl_acc, vl_rec, vl_f1 = cm_metrics(cm_val)
        vl_hd95 = hd95_val.aggregate().item(); hd95_val.reset()
        qc_rate_val = (qc_val.warn / max(1, qc_val.total))

        print(f"  Val loss: {val_loss:.4f} | Val Dice(tumor): {mean_dice:.4f}")
        print(f"  Val  : acc={vl_acc:.4f} rec={vl_rec:.4f} f1={vl_f1:.4f} hd95={vl_hd95:.2f}")
        qc_val.summary()

        metrics_w.writerow([
            epoch, "val",
            f"{val_loss:.6f}", f"{mean_dice:.6f}", f"{vl_acc:.6f}", f"{vl_rec:.6f}", f"{vl_f1:.6f}", f"{vl_hd95:.6f}",
            f"{train_lr:.6g}", f"{(time.time()-t_epoch):.2f}", qc_val.warn, qc_val.total, f"{qc_rate_val:.4f}"
        ]); metrics_f.flush()
        qc_val.total = qc_val.warn = 0

        # ---- Checkpointing ----
        def save_ckpt(path: Path):
            torch.save({
                "epoch": epoch,
                "seg": seg_net.state_dict(),
                "cls": cls_head.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
                "best_val_dice": best_val_dice,
                "config_path": str(CONFIG_JSON) if "CONFIG_JSON" in globals() else None,
                "env_path": str(ENV_JSON) if "ENV_JSON" in globals() else None,
            }, path)

        # Update the best score first so that last.pt records the current best_val_dice
        is_best = mean_dice > best_val_dice
        if is_best:
            best_val_dice = mean_dice

        run_last = (RUN_DIR if "RUN_DIR" in globals() else ckpt_dir) / "last.pt"
        save_ckpt(run_last)

        if is_best:
            run_best = (RUN_DIR if "RUN_DIR" in globals() else ckpt_dir) / "best.pt"
            save_ckpt(run_best)
            # convenience copy
            save_ckpt(ckpt_dir / "best.pt")
            print(f"  [Saved] best.pt with Dice {best_val_dice:.4f}")

        # encoder-only (save full state; downstream can load encoder subset)
        torch.save({"seg_encoder_compatible": True, "state_dict": seg_net.state_dict()},
                   (RUN_DIR if "RUN_DIR" in globals() else ckpt_dir) / "encoder_fullstate.pt")

        # Cleanup
        try:
            del images, labels, seg_logits, cls_logits, feat, y_pred
        except NameError:  # some names may be undefined if the loader was empty
            pass
        try:
            del images_p
        except NameError:
            pass
        torch.cuda.empty_cache()

print(f"Training done in {(time.time()-t0_all)/60:.1f} min")
# Close the metrics file (kept open for speed)
try:
    metrics_f.close()
except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
    pass