<a href="https://colab.research.google.com/github/lalvar02/Segmentation_and_Classification_Pancreatic_Scans/blob/main/Copy_of_nnUNet_pancreas_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pancreas Seg+Cls with nnUNetv2

Builds and trains an nnU-Net v2 model with an added classification head, for joint pancreas/lesion segmentation and lesion-subtype classification.

In [None]:
# Check GPU availability
import torch, os, sys, subprocess
!nvidia-smi
print(torch.cuda.get_device_name(0))

/bin/bash: line 1: nvidia-smi: command not found


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

## Environment Setup & Imports
This cell installs the packages the notebook depends on.

**Purpose**:
- Ensure `nnunetv2` is available in the environment, together with `nibabel`, `SimpleITK`, `pandas` (pinned to 2.2.2) and `scikit-image`.
- PyTorch, NumPy, the nnU-Net utilities and the other dependencies are imported in the cells that use them.

In [None]:
# Install nnU-Net v2 plus the I/O and data libraries used below (pandas pinned to 2.2.2).
!pip -q install --no-input nnunetv2 nibabel SimpleITK pandas==2.2.2 scikit-image --upgrade


## Dataset and Path Configuration
This cell creates the folder layout nnU-Net expects and points the `nnUNet_raw`, `nnUNet_preprocessed` and `nnUNet_results` environment variables at it.

- **Required folder layout**:
```
  working/
    nnUNet_raw/
      Dataset777_PancreasSegCls/
        imagesTr/
        labelsTr/
        imagesTs/
        imagesVal/
        labelsVal/
    nnUNet_preprocessed/
      Dataset777_PancreasSegCls/
    nnUNet_results/
      Dataset777_PancreasSegCls/
```


In [None]:
import os

# Mount Google Drive. INPUT_ROOT must point at the Drive folder that holds the
# train/, validation/ and test/ source folders; change it to match your layout.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
INPUT_ROOT = "/content/drive/MyDrive/ColabData"

# Dataset identity; nnU-Net expects dataset folders named Dataset<ID>_<Name>
DS_ID = 777  # arbitrary unused ID
DATASET_NAME = f"Dataset{DS_ID}_PancreasSegCls"

# Model identity used by the training and inference cells (they must agree)
PLANS_NAME   = "nnUNetResEncUNetMPlans" # the same plans as training
CONFIG       = "3d_fullres"             # or "2d" - must match how you trained
TRAINER_NAME = "TrainerSegCls"          # your custom trainer class name used in training
FOLD         = 0

# nnU-Net v2 root folders, stored on Google Drive (under INPUT_ROOT/working) so they
# survive the Colab session; the nnU-Net CLI tools read them from these env variables.
BASE = os.path.join(INPUT_ROOT, "working")
os.environ["nnUNet_raw"] = os.path.join(BASE, "nnUNet_raw")
os.environ["nnUNet_preprocessed"] = os.path.join(BASE, "nnUNet_preprocessed")
os.environ["nnUNet_results"] = os.path.join(BASE, "nnUNet_results")

for k in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results"):
    os.makedirs(os.environ[k], exist_ok=True)

# Per-dataset folders inside each nnU-Net root
RAW_DS_DIR = os.path.join(os.environ["nnUNet_raw"], DATASET_NAME)
PREPRO_DS_DIR = os.path.join(os.environ["nnUNet_preprocessed"], DATASET_NAME)
RESULTS_DS_DIR = os.path.join(os.environ["nnUNet_results"], DATASET_NAME)

# for the training images and labels (lower-case aliases kept for convenience)
IMAGES_TR = imagesTr = os.path.join(RAW_DS_DIR, "imagesTr")
LABELS_TR = labelsTr = os.path.join(RAW_DS_DIR, "labelsTr")
# for the test images
IMAGES_TS = imagesTs = os.path.join(RAW_DS_DIR, "imagesTs")
# for the validation images and labels (held out; not listed in dataset.json)
IMAGES_VAL = os.path.join(RAW_DS_DIR, "imagesVal")
LABELS_VAL = os.path.join(RAW_DS_DIR, "labelsVal")
for d in (IMAGES_TR, LABELS_TR, IMAGES_VAL, LABELS_VAL, IMAGES_TS):
    os.makedirs(d, exist_ok=True)

print("RAW_DS_DIR:", RAW_DS_DIR)

Mounted at /content/drive
RAW_DS_DIR: /content/drive/MyDrive/ColabData/working/nnUNet_raw/Dataset777_PancreasSegCls



## Convert input data to nnU-Net v2 format

- We copy **training** images into `imagesTr/` and labels into `labelsTr/`.
- We copy **test** images into `imagesTs/`.
- We copy **validation** images into `imagesVal/` and labels into `labelsVal/`, kept separate from the training data.
- We generate a `dataset.json` with the right channels/labels and cases.
- We also build a CSV mapping **case -> subtype** for the classification head.


In [None]:
import os, re, json, shutil, pandas as pd

# Source folders under INPUT_ROOT: train/ and validation/ contain subtype0..2
# sub-folders, while test/ holds the unlabeled images directly.
train_root, val_root, test_root = (
    os.path.join(INPUT_ROOT, split_name) for split_name in ("train", "validation", "test")
)

def cp(src, dst):
    """Copy ``src`` to ``dst`` (with metadata) unless ``dst`` already exists.

    Existing destinations are left untouched, so re-running the conversion is
    cheap, but a source file that changed will NOT be re-copied. Missing parent
    folders of ``dst`` are created.
    """
    if os.path.exists(dst):
        return
    parent = os.path.dirname(dst)
    if parent:  # a bare filename has no parent; os.makedirs("") would raise
        os.makedirs(parent, exist_ok=True)
    shutil.copy2(src, dst)

# Source sub-folder names; the integer suffix of each name is the class label.
subtypes = [f"subtype{idx}" for idx in range(3)]
# One {"case", "subtype"} row per copied image, filled by process_split below.
train_cls_rows = []
val_cls_rows = []

def process_split(split_root, out_img_dir, out_lbl_dir, cls_rows):
    """Copy one split (train or validation) into nnU-Net layout and record subtypes.

    Expects ``split_root/<subtype>/`` folders (names from ``subtypes``) holding
    ``<case>_0000.nii.gz`` images and optional ``<case>.nii.gz`` labels.

    * Images go to ``out_img_dir`` (keeping the ``_0000`` channel suffix).
    * Labels go to ``out_lbl_dir`` (no channel suffix).
    * One ``{"case": ..., "subtype": ...}`` dict per image is appended to ``cls_rows``
      (mutated in place).

    Missing subtype folders are skipped. Images without a label are still copied,
    but a warning is printed.
    """
    img_suffix = "_0000.nii.gz"
    for st in subtypes:
        st_dir = os.path.join(split_root, st)
        if not os.path.isdir(st_dir):
            continue
        subtype_idx = int(st.replace("subtype", ""))
        # sorted() makes the CSV row order reproducible across runs and filesystems
        for f in sorted(os.listdir(st_dir)):
            if not f.endswith(img_suffix):
                continue
            # strip only the trailing suffix (e.g. quiz_0_041_0000.nii.gz -> quiz_0_041)
            case_id = f[:-len(img_suffix)]
            cp(os.path.join(st_dir, f), os.path.join(out_img_dir, case_id + img_suffix))
            lbl_path = os.path.join(st_dir, case_id + ".nii.gz")
            if os.path.exists(lbl_path):
                cp(lbl_path, os.path.join(out_lbl_dir, case_id + ".nii.gz"))
            else:
                print(f"[process_split] WARNING: no label for {case_id} in {st_dir}")
            cls_rows.append({"case": case_id, "subtype": subtype_idx})

# TRAIN → imagesTr/labelsTr
process_split(train_root, IMAGES_TR, LABELS_TR, train_cls_rows)
# VAL   → imagesVal/labelsVal  (NOT mixed into imagesTr/labelsTr)
process_split(val_root,   IMAGES_VAL, LABELS_VAL, val_cls_rows)

IMG_SUFFIX = "_0000.nii.gz"  # nnU-Net channel suffix for the single CT channel

# TEST images → imagesTs (no labels)
if os.path.isdir(test_root):
    for f in sorted(os.listdir(test_root)):
        if f.endswith(IMG_SUFFIX):
            case_id = f[:-len(IMG_SUFFIX)]
            cp(os.path.join(test_root, f), os.path.join(IMAGES_TS, case_id + IMG_SUFFIX))

# Write classification CSVs. Explicit columns keep the "case,subtype" header even
# when a split produced no rows, so downstream readers can always rely on it.
CLS_COLUMNS = ["case", "subtype"]
train_csv = os.path.join(RAW_DS_DIR, "train_subtypes.csv")
val_csv   = os.path.join(RAW_DS_DIR, "val_subtypes.csv")
both_csv  = os.path.join(RAW_DS_DIR, "trainval_subtypes.csv")  # for convenience if you need one file

train_cls_df = pd.DataFrame(train_cls_rows, columns=CLS_COLUMNS)
val_cls_df   = pd.DataFrame(val_cls_rows, columns=CLS_COLUMNS)
train_cls_df.drop_duplicates().to_csv(train_csv, index=False)
val_cls_df.drop_duplicates().to_csv(val_csv, index=False)
pd.concat([train_cls_df, val_cls_df], ignore_index=True)\
  .drop_duplicates().to_csv(both_csv, index=False)

print("Wrote:", train_csv, "rows:", len(train_cls_rows))
print("Wrote:", val_csv,   "rows:", len(val_cls_rows))
print("Wrote:", both_csv)

# Build dataset.json (TRAIN ONLY). Per the nnU-Net v2 dataset-format docs, v2 relies on
# channel_names / labels / numTraining / file_ending; "modality", "training" and
# "test" are v1-style extras kept for reference.
training_list = [
    {"image": f"./imagesTr/{f}", "label": f"./labelsTr/{f[:-len(IMG_SUFFIX)]}.nii.gz"}
    for f in sorted(os.listdir(IMAGES_TR)) if f.endswith(IMG_SUFFIX)
]
dataset_json = {
    "name": "PancreasSegCls",
    "description": "Pancreas ROI segmentation (0 bg, 1 pancreas, 2 lesion) + subtype classification (0/1/2).",
    "tensorImageSize": "3D",
    "reference": "Local",
    "licence": "CC-BY-NC",
    "labels": {"background": 0, "pancreas": 1, "lesion": 2},
    "modality": {"0": "CT"},
    "channel_names": {"0": "CT"},
    "numTraining": len(training_list),
    "file_ending": ".nii.gz",
    "training": training_list,  # ONLY train here
    "test": [f"./imagesTs/{f}" for f in sorted(os.listdir(IMAGES_TS)) if f.endswith(IMG_SUFFIX)]
}
with open(os.path.join(RAW_DS_DIR, "dataset.json"), "w", encoding="utf-8") as fh:
    json.dump(dataset_json, fh, indent=2)

print("dataset.json written.")
print("Train:", len(training_list),
      "imagesVal:", len([x for x in os.listdir(IMAGES_VAL) if x.endswith(IMG_SUFFIX)]),
      "labelsVal:", len([x for x in os.listdir(LABELS_VAL) if x.endswith('.nii') or x.endswith('.nii.gz')]),
      "imagesTs:", len([x for x in os.listdir(IMAGES_TS) if x.endswith(IMG_SUFFIX)]))


Wrote: /content/drive/MyDrive/ColabData/working/nnUNet_raw/Dataset777_PancreasSegCls/trainval_subtypes.csv with 288 rows
dataset.json written.
imagesTr: 288 labelsTr: 288 imagesTs: 72


## Label Sanitization for nnU-Net (force `{0,1,2}` and `uint8`)

**Purpose**  
This utility ensures that all segmentation label files in the `labelsTr` and `labelsVal` directories conform to nnU-Net’s expected format:
- Only contain the class IDs `0`, `1`, and `2`.
- Stored as `uint8` integer type rather than floating-point.

In [None]:
# --- Sanitize nnU-Net labels  --- #
import os, shutil
import numpy as np
import nibabel as nib

# Label folders to sanitize: training labels and the held-out validation labels.
LABELS_DIRS = [
    LABELS_TR,
    LABELS_VAL,
]

def snap_to_labels(arr, valid=(0.0, 1.0, 2.0)):
    """Map every element of ``arr`` to the nearest value in ``valid``.

    Parameters
    ----------
    arr : array-like of numbers, any shape (e.g. a label volume read as float).
    valid : sequence of label values; each must be representable as uint8.

    Returns
    -------
    np.ndarray (uint8) with the shape of ``arr`` whose entries are taken from
    ``valid``. Ties resolve to the earlier entry of ``valid``.
    """
    labels = np.asarray(valid, dtype=np.float32)
    arr = np.asarray(arr, dtype=np.float32)
    diffs = np.stack([np.abs(arr - v) for v in labels], axis=0)  # [len(valid), *arr.shape]
    nearest = np.argmin(diffs, axis=0)                           # position in `valid`
    # Index back into `valid` so non-contiguous label sets such as (0, 1, 3)
    # yield their actual values rather than their positions.
    return labels[nearest].astype(np.uint8)

def sanitize_dir(LABELS_DIR, max_snap_distance=0.5):
    """Force every NIfTI label map in ``LABELS_DIR`` to uint8 with values in {0, 1, 2}.

    For each ``.nii`` / ``.nii.gz`` file:
      * files that already contain only {0, 1, 2} and are stored as uint8 are skipped;
      * files whose values all lie within ``max_snap_distance`` of a valid label
        (e.g. 1.0000153 from float storage) are snapped to the nearest label,
        backed up once to ``<file>.bak``, and rewritten as uint8 with the original
        affine and header;
      * files with values farther than that from every valid label (e.g. 3, 255,
        -1, NaN) are left unchanged and reported under ``issues``, instead of being
        silently merged into the nearest class.
    A final pass re-reads every file and lists those still holding other values.

    Parameters
    ----------
    LABELS_DIR : str
        Folder containing label files.
    max_snap_distance : float
        Largest distance from a valid label that is still snapped.

    Returns
    -------
    dict with keys ``changed`` ([(fname, before, after)]), ``issues``
    ([(fname, message)]) and ``bad`` ([(fname, values_or_error)]).
    """
    assert os.path.isdir(LABELS_DIR), f"Not found: {LABELS_DIR}"
    valid = (0.0, 1.0, 2.0)
    valid_arr = np.asarray(valid, dtype=np.float32)
    changed_files = []
    issues = []

    # accept both .nii.gz and .nii
    label_files = [f for f in sorted(os.listdir(LABELS_DIR))
                   if f.lower().endswith(".nii.gz") or f.lower().endswith(".nii")]

    for fname in label_files:
        fpath = os.path.join(LABELS_DIR, fname)
        try:
            img = nib.load(fpath)
            data = img.get_fdata(dtype=np.float32)
            uniq = np.unique(data)

            # If already clean {0,1,2} and dtype is uint8, skip fast
            already_clean = np.all(np.isin(uniq, valid)) and img.get_data_dtype() == np.uint8
            if already_clean:
                continue

            # Refuse to guess: distance of every distinct value to its nearest valid
            # label; "not <=" also catches NaN.
            nearest_dist = np.min(np.abs(uniq[:, None] - valid_arr[None, :]), axis=1)
            far_values = uniq[~(nearest_dist <= max_snap_distance)]
            if far_values.size:
                issues.append((fname, f"Values farther than {max_snap_distance} from any valid label "
                                      f"{list(valid)}; file left unchanged: {far_values.tolist()[:10]}"))
                continue

            # Snap + validate
            snapped = snap_to_labels(data, valid=valid)
            new_uniq = np.unique(snapped)
            if not np.all(np.isin(new_uniq, [0, 1, 2])):
                issues.append((fname, f"Unexpected labels after snap: {new_uniq.tolist()}"))
                continue

            # Backup once
            bak = fpath + ".bak"
            if not os.path.exists(bak):
                shutil.copy2(fpath, bak)

            # Preserve affine & header; store as uint8
            hdr = img.header.copy()
            hdr.set_data_dtype(np.uint8)
            out_img = nib.Nifti1Image(snapped, img.affine, header=hdr)
            nib.save(out_img, fpath)

            changed_files.append((fname, uniq.tolist(), new_uniq.tolist()))
        except Exception as e:
            # best-effort: record the failure and continue with the other files
            issues.append((fname, f"ERROR: {e}"))

    # Reporting
    print(f"\n[{LABELS_DIR}] Sanitization complete. Modified {len(changed_files)} file(s).")
    if changed_files:
        print("Examples of changes:")
        for i, (fn, before, after) in enumerate(changed_files[:10], 1):
            print(f"  {i:02d}) {fn}: {before} -> {after}")

    if issues:
        print("\nWarnings/Issues:")
        for it in issues[:20]:
            print(" ", it)
        if len(issues) > 20:
            print(f"  ... ({len(issues)-20} more)")

    # Verification pass
    bad = []
    for fname in label_files:
        try:
            data = nib.load(os.path.join(LABELS_DIR, fname)).get_fdata(dtype=np.float32)
            uniq = np.unique(data)
            if not np.all(np.isin(uniq, valid)):
                bad.append((fname, uniq.tolist()))
        except Exception as e:
            bad.append((fname, f"ERROR reading: {e}"))
    print(f"Verification: {len(bad)} file(s) still have unexpected labels.")
    if bad[:10]:
        print(" First few:", bad[:10])

    return {"changed": changed_files, "issues": issues, "bad": bad}

# Sanitize every label folder in LABELS_DIRS, accumulating per-folder counts.
overall = dict.fromkeys(("changed", "issues", "bad"), 0)
for d in LABELS_DIRS:
    if not (d and os.path.isdir(d)):
        print(f"\n[skip] Not found: {d}")
        continue
    res = sanitize_dir(d)
    for key in overall:
        overall[key] += len(res[key])

print("\n=== Overall summary ===")
for title, key in (("Total modified files:", "changed"),
                   ("Total issues:", "issues"),
                   ("Total remaining bad:", "bad")):
    print(title, overall[key])


Sanitization complete. Modified 288 file(s).
Examples of changes:
  01) quiz_0_041.nii.gz: [0.0, 1.0, 2.0] -> [0, 1, 2]
  02) quiz_0_060.nii.gz: [0.0, 1.0000152587890625, 2.0] -> [0, 1, 2]
  03) quiz_0_066.nii.gz: [0.0, 1.0000152587890625, 2.0] -> [0, 1, 2]
  04) quiz_0_070.nii.gz: [0.0, 1.0, 2.0] -> [0, 1, 2]
  05) quiz_0_077.nii.gz: [0.0, 1.0000152587890625, 2.0] -> [0, 1, 2]
  06) quiz_0_117.nii.gz: [0.0, 1.0000152587890625, 2.0] -> [0, 1, 2]
  07) quiz_0_126.nii.gz: [0.0, 1.0000152587890625, 2.0] -> [0, 1, 2]
  08) quiz_0_139.nii.gz: [0.0, 1.0000152587890625, 2.0] -> [0, 1, 2]
  09) quiz_0_145.nii.gz: [0.0, 1.0000152587890625, 2.0] -> [0, 1, 2]
  10) quiz_0_150.nii.gz: [0.0, 1.0000152587890625, 2.0] -> [0, 1, 2]

Verification: 0 file(s) still have unexpected labels.



## Custom Trainer: add a classification head

- Wrap the ResEnc M network and attach a **global pooling + linear** head for 3-way subtype logits.
- Compute **segmentation loss** (Dice+CE, default) + **classification CE** (with weight `cls_lambda`).  
- We obtain the **subtype target** by looking up the case name in the CSV given by the `CASE_TO_SUBTYPE_CSV` environment variable (the training cell sets it to `train_subtypes.csv`).


In [None]:
%%writefile /content/custom_trainer_cls.py
# save as custom_trainer_cls.py
import os
import torch, torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any, List, Tuple, Union
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.training.loss.dice import get_tp_fp_fn_tn
from nnunetv2.utilities.helpers import dummy_context
from torch import autocast

class GlobalPoolHead(nn.Module):
    """Global-average-pool a feature map and project it to class logits.

    Accepts 2D feature maps ``[N, C, H, W]`` or 3D ones ``[N, C, D, H, W]`` with
    ``C == in_ch`` and returns logits of shape ``[N, n_classes]``.
    """

    def __init__(self, in_ch: int, n_classes: int = 3):
        super().__init__()
        # Not used by forward (which calls the functional pools); kept so the
        # module's attributes stay the same.
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(in_ch, n_classes)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        rank = feat.dim()
        if rank not in (4, 5):
            raise ValueError(f"Unexpected feature rank {rank}, expected 4 or 5")
        pool_fn = F.adaptive_avg_pool3d if rank == 5 else F.adaptive_avg_pool2d
        pooled = pool_fn(feat, 1).flatten(1)  # [N, C, 1, ...] -> [N, C]
        return self.fc(pooled)

class TrainerSegCls(nnUNetTrainer):
    """nnU-Net trainer that adds a subtype-classification head to the segmentation net.

    A forward hook on the last encoder stage caches the bottleneck feature map.
    ``GlobalPoolHead`` pools it and predicts one of 3 subtypes. The optimised loss is
    ``seg_loss + cls_lambda * CrossEntropy(cls_logits, subtype)``, where ``seg_loss``
    is nnU-Net's own segmentation loss (``self.loss``).

    Environment variables read in ``initialize``:
      * ``CLS_LAMBDA`` - weight of the classification loss (default 0.5).
      * ``CASE_TO_SUBTYPE_CSV`` - CSV with ``case`` and ``subtype`` columns. Cases
        missing from it, or a missing CSV, give subtype 0.

    The head weights are written to ``self.output_folder`` as:
      * ``cls_head_latest.pth`` - every epoch;
      * ``cls_head_best.pth`` - whenever ``checkpoint_best.pth`` is rewritten;
      * ``cls_head_final.pth`` - at the end of training.
    """

    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
                 device: torch.device = torch.device("cuda")):
        super().__init__(plans, configuration, fold, dataset_json, device)
        self.classifier_head: nn.Module = None
        self._cached_bottleneck: torch.Tensor | None = None
        self._hook_handle = None
        self.ce_cls = nn.CrossEntropyLoss()
        self.cls_lambda = 0.5          # overridden by the CLS_LAMBDA env var in initialize()
        self.case_to_subtype = {}      # filled from CASE_TO_SUBTYPE_CSV in initialize()
        self._best_ckpt_mtime = None   # last seen mtime of checkpoint_best.pth

        # To speed up training during debugging
        #self.num_epochs = 2
        #self.enable_deep_supervision = False  # optional: faster
        #self.save_every = 0

    @staticmethod
    def _infer_out_channels(stage: nn.Module, default: int = 320) -> int:
        """Return the number of output channels of an encoder stage.

        Tries common attribute names first, then the last conv in ``stage.convs``.
        Falls back to ``default`` (320, the fallback intended for ResEnc M).
        """
        for attr in ("output_channels", "out_channels", "num_features"):
            if hasattr(stage, attr):
                return int(getattr(stage, attr))
        convs = getattr(stage, "convs", None)
        if convs:
            return int(convs[-1].out_channels)
        return default

    def initialize(self):
        """Build the nnU-Net components, then attach and configure the classification head.

        After ``super().initialize()`` this method:
          * finds the bottleneck width and creates the head;
          * registers the head's parameters with ``self.optimizer`` so they are trained;
          * hooks the last encoder stage;
          * reads CLS_LAMBDA / CASE_TO_SUBTYPE_CSV;
          * restores ``cls_head_latest.pth`` if present.

        Raises RuntimeError if the network has no ``encoder.stages``.
        """
        # nnU-Net builds self.network, self.optimizer, self.loss, ... here.
        super().initialize()

        # Inspect the unwrapped network: a DDP wrapper does not expose `.encoder`.
        plain_net = self._plain_network()
        if not (hasattr(plain_net, "encoder") and hasattr(plain_net.encoder, "stages")):
            raise RuntimeError("Unexpected nnU-Net network structure: encoder/stages not found.")
        last_stage = plain_net.encoder.stages[-1]
        ch = self._infer_out_channels(last_stage)

        # classification head
        self.classifier_head = GlobalPoolHead(ch, n_classes=3).to(self.device)

        # The optimizer was created by super().initialize() from the network's
        # parameters only; without this param group the head would never be updated.
        # NOTE: optimizer states saved before this group existed have a different
        # number of param groups, so such checkpoints cannot be resumed with --c.
        self.optimizer.add_param_group({"params": list(self.classifier_head.parameters())})

        # hook to cache bottleneck tensor on forward
        def _save_bottleneck(module, inp, out):
            t = out
            # some stages return tuples/lists; pick the last Tensor inside
            while isinstance(t, (list, tuple)):
                t = t[-1]
            if not isinstance(t, torch.Tensor):
                raise RuntimeError(f"Encoder hook returned non-tensor: {type(t)}")
            self._cached_bottleneck = t
        if self._hook_handle is not None:
            self._hook_handle.remove()
        self._hook_handle = last_stage.register_forward_hook(_save_bottleneck)

        # env-driven config for the subtype CSV + loss weight
        import pandas as pd
        self.cls_lambda = float(os.environ.get("CLS_LAMBDA", self.cls_lambda))
        csv_path = os.environ.get("CASE_TO_SUBTYPE_CSV", "")
        if os.path.isfile(csv_path):
            df = pd.read_csv(csv_path)
            self.case_to_subtype = {str(c): int(s) for c, s in zip(df["case"], df["subtype"])}
        else:
            print("[TrainerSegCls] WARNING: xxxx_subtypes.csv not found. Using zeros for cls labels.")

        # Remember an existing checkpoint_best.pth (e.g. when resuming) so that
        # on_epoch_end only mirrors the head when that file is actually rewritten.
        best_ckpt = os.path.join(self.output_folder, "checkpoint_best.pth")
        self._best_ckpt_mtime = os.path.getmtime(best_ckpt) if os.path.isfile(best_ckpt) else None

        # Resume cls head if available.
        # NOTE(review): this also loads an old head on a fresh (non --c) run in the same folder.
        latest_head = os.path.join(self.output_folder, "cls_head_latest.pth")
        if os.path.isfile(latest_head):
            try:
                self.classifier_head.load_state_dict(
                    torch.load(latest_head, map_location=self.device), strict=False
                )
                print(f"[TrainerSegCls] Loaded classification head from {latest_head}")
            except Exception as e:
                print(f"[TrainerSegCls] WARNING: could not load {latest_head}: {e}")

    # unwrapping helper (recommended if you use torch.compile or DDP)
    def _plain_network(self):
        """Return ``self.network`` without DDP (``.module``) or torch.compile wrappers."""
        net = self.network
        if hasattr(net, "module"):  # DDP
            net = net.module
        try:
            from torch._dynamo import OptimizedModule
            if isinstance(net, OptimizedModule):  # torch.compile
                net = net._orig_mod
        except Exception:
            pass
        return net

    def _cls_targets_from_keys(self, keys, device, batch_size=None):
        """Map the batch's case keys to a LongTensor of subtype labels on ``device``.

        Unknown cases map to 0. If ``batch_size`` is given and the number of keys
        differs from it (including ``keys is None``), all-zero targets are returned.
        """
        if keys is None:
            ks = []
        else:
            try:
                import numpy as np
                ks = np.array(keys).ravel().tolist()  # handles list/tuple/np.array scalars
            except Exception:
                ks = list(keys) if isinstance(keys, (list, tuple)) else [keys]

        # Map to integer subtypes (default 0 if missing)
        targets_list = [int(self.case_to_subtype.get(str(k), 0)) for k in ks]

        # Length mismatch (or empty): safe fallback of zeros
        if batch_size is not None and len(targets_list) != batch_size:
            targets_list = [0] * batch_size

        return torch.as_tensor(targets_list, device=device, dtype=torch.long)

    def _forward_seg_cls(self, batch: dict, data: torch.Tensor, target):
        """Run network + head on ``data`` and return ``(seg_logits, total_loss)``.

        Called from inside the autocast context of train_step / validation_step.
        """
        # Clear first so a feature map from a previous batch can never be reused.
        self._cached_bottleneck = None
        seg_logits = self.network(data)             # triggers our hook
        feat = self._cached_bottleneck
        if feat is None:
            feat = seg_logits[0] if isinstance(seg_logits, (list, tuple)) else seg_logits
        cls_logits = self.classifier_head(feat)

        seg_loss = self.loss(seg_logits, target)    # keep base seg loss unchanged
        cls_target = self._cls_targets_from_keys(batch.get('keys', None), cls_logits.device,
                                                 batch_size=cls_logits.shape[0])
        cls_loss = self.ce_cls(cls_logits, cls_target)
        return seg_logits, seg_loss + self.cls_lambda * cls_loss

    def train_step(self, batch: dict) -> dict:
        """One optimisation step on the joint segmentation + classification loss."""
        data, target = batch['data'], batch['target']
        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        self.optimizer.zero_grad(set_to_none=True)
        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            _, l = self._forward_seg_cls(batch, data, target)

        # Clip network AND head gradients (the head is in the optimizer as well).
        params = list(self.network.parameters()) + list(self.classifier_head.parameters())
        if self.grad_scaler is not None:
            self.grad_scaler.scale(l).backward()
            self.grad_scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(params, 12)
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            l.backward()
            torch.nn.utils.clip_grad_norm_(params, 12)
            self.optimizer.step()
        return {'loss': l.detach().cpu().numpy()}

    def validation_step(self, batch: dict) -> dict:
        """Compute the joint loss plus the tp/fp/fn counts nnU-Net uses for its online Dice."""
        data, target = batch['data'], batch['target']
        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        # Autocast can be annoying
        # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            seg_logits, l = self._forward_seg_cls(batch, data, target)

        # used for Dice calculation and validation logging
        output = seg_logits

        # we only need the output with the highest output resolution (if DS enabled)
        if self.enable_deep_supervision:
            output = output[0]
            target = target[0]

        # the following is needed for online evaluation. Fake dice (green line)
        axes = [0] + list(range(2, output.ndim))

        if self.label_manager.has_regions:
            predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
        else:
            # no need for softmax
            output_seg = output.argmax(1)[:, None]
            predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
            predicted_segmentation_onehot.scatter_(1, output_seg, 1)
            del output_seg

        if self.label_manager.has_ignore_label:
            if not self.label_manager.has_regions:
                mask = (target != self.label_manager.ignore_label).float()
                # CAREFUL that you don't rely on target after this line!
                target[target == self.label_manager.ignore_label] = 0
            else:
                if target.dtype == torch.bool:
                    mask = ~target[:, -1:]
                else:
                    mask = 1 - target[:, -1:]
                # CAREFUL that you don't rely on target after this line!
                target = target[:, :-1]
        else:
            mask = None

        tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)

        tp_hard = tp.detach().cpu().numpy()
        fp_hard = fp.detach().cpu().numpy()
        fn_hard = fn.detach().cpu().numpy()
        if not self.label_manager.has_regions:
            # [1:] in order to remove background
            tp_hard = tp_hard[1:]
            fp_hard = fp_hard[1:]
            fn_hard = fn_hard[1:]

        return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}

    def on_epoch_end(self):
        """Run nnU-Net's epoch bookkeeping, then persist the head (latest, and best if updated)."""
        # Let nnU-Net do its normal bookkeeping & checkpointing first
        super().on_epoch_end()

        # 1) Always save the latest cls head
        latest_head = os.path.join(self.output_folder, "cls_head_latest.pth")
        torch.save(self.classifier_head.state_dict(), latest_head)

        # 2) If checkpoint_best.pth was (re)written this epoch, mirror a best head
        best_ckpt = os.path.join(self.output_folder, "checkpoint_best.pth")
        if os.path.isfile(best_ckpt):
            mtime = os.path.getmtime(best_ckpt)
            if self._best_ckpt_mtime is None or mtime > self._best_ckpt_mtime:
                best_head = os.path.join(self.output_folder, "cls_head_best.pth")
                torch.save(self.classifier_head.state_dict(), best_head)
                self._best_ckpt_mtime = mtime
                print(f"[TrainerSegCls] Saved best classification head → {best_head}")

    def on_train_end(self):
        """Let nnU-Net finish (e.g. checkpoint_final), then save ``cls_head_final.pth``."""
        super().on_train_end()
        final_head = os.path.join(self.output_folder, "cls_head_final.pth")
        torch.save(self.classifier_head.state_dict(), final_head)
        print(f"[TrainerSegCls] Saved final classification head → {final_head}")


Overwriting /content/custom_trainer_cls.py


## Register Custom Trainer so `-tr TrainerSegCls` Works

**Purpose**  
Make your custom trainer class discoverable by nnU-Net’s CLI (`nnUNetv2_train` / `nnUNetv2_predict`) so you can call it with `-tr TrainerSegCls`.

In [None]:
# Register custom trainer inside nnUNet so -tr TrainerSegCls works
import os, shutil, pathlib, sys, textwrap, importlib

import nnunetv2
pkg_dir = pathlib.Path(nnunetv2.__file__).parent
trainer_pkg = pkg_dir / "training" / "nnUNetTrainer"
trainer_pkg.mkdir(parents=True, exist_ok=True)

src = "/content/custom_trainer_cls.py"  # written by the %%writefile cell above
dst = trainer_pkg / "trainer_segcls.py"
assert os.path.exists(src), f"Custom trainer not found at {src}."
shutil.copy2(src, dst)

# Also expose the class from the package __init__. Append the import only if it is
# not there yet, so re-running this cell does not stack duplicate import lines.
init_py = trainer_pkg / "__init__.py"
import_line = "from .trainer_segcls import TrainerSegCls"
existing = init_py.read_text(encoding="utf-8") if init_py.exists() else ""
if not any(line.strip() == import_line for line in existing.splitlines()):
    with open(init_py, "a", encoding="utf-8") as f:
        f.write("\n" + import_line + "\n")

importlib.invalidate_caches()
print("Custom trainer registered at:", dst)
print("You can now use: -tr TrainerSegCls")


Custom trainer registered at: /usr/local/lib/python3.11/dist-packages/nnunetv2/training/nnUNetTrainer/trainer_segcls.py
You can now use: -tr TrainerSegCls


## Plan & Preprocess the Dataset (nnU-Net v2)

This cell runs nnU-Net’s **planning + preprocessing** step for your dataset.

In [None]:

# Plan with the ResEnc M planner (its plans file is the nnUNetResEncUNetMPlans used for
# training below) and preprocess. No -c is given, so nnU-Net preprocesses its default
# configurations; patch and batch sizes are chosen automatically by the planner.
# --verify_dataset_integrity checks the raw dataset before planning.
!nnUNetv2_plan_and_preprocess -d {DS_ID} -pl nnUNetPlannerResEncM --verify_dataset_integrity


Re-run only the planning step (without preprocessing the data again), e.g. after changing planner settings. See the ResEnc presets documentation:

https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md#point_right-we-recommend-nnu-net-resenc-l-as-the-new-default-nnu-net-configuration-point_left

In [None]:
# Re-run only the experiment planning with the ResEnc M planner (no preprocessing).
!nnUNetv2_plan_experiment -d {DS_ID} -pl nnUNetPlannerResEncM

Dropping 3d_lowres config because the image size difference to 3d_fullres is too small. 3d_fullres: [ 59.  117.  180.5], 3d_lowres: [59, 117, 180]
2D U-Net configuration:
{'data_identifier': 'nnUNetPlans_2d', 'preprocessor_name': 'DefaultPreprocessor', 'batch_size': 134, 'patch_size': (np.int64(128), np.int64(192)), 'median_image_size_in_voxels': array([117. , 180.5]), 'spacing': array([0.73242188, 0.73242188]), 'normalization_schemes': ['CTNormalization'], 'use_mask_for_norm': [False], 'resampling_fn_data': 'resample_data_or_seg_to_shape', 'resampling_fn_seg': 'resample_data_or_seg_to_shape', 'resampling_fn_data_kwargs': {'is_seg': False, 'order': 3, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_seg_kwargs': {'is_seg': True, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_probabilities': 'resample_data_or_seg_to_shape', 'resampling_fn_probabilities_kwargs': {'is_seg': False, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'architecture': {'network_cl

## Launch Training (Custom Trainer)

**Purpose**  
Configure runtime parameters for your **custom nnU-Net v2 trainer** and start a training run.

In [None]:
# Restart-recovery cell: re-mount Drive, reinstall nnU-Net and restore its path variables.
# NOTE(review): DS_ID, RAW_DS_DIR, etc. used by the next cells still come from the
# configuration cell, and on a fresh VM the trainer-registration cell must be re-run
# (it copies the trainer into the installed nnunetv2 package).
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install nnU-Net v2 and dependencies (unpinned here, unlike the setup cell)
!pip install nnunetv2 nibabel simpleitk --quiet

# Set nnU-Net environment variables (same paths the configuration cell derives from BASE)
import os
os.environ["nnUNet_raw"] = "/content/drive/MyDrive/ColabData/working/nnUNet_raw"
os.environ["nnUNet_preprocessed"] = "/content/drive/MyDrive/ColabData/working/nnUNet_preprocessed"
os.environ["nnUNet_results"] = "/content/drive/MyDrive/ColabData/working/nnUNet_results"


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:

# Register our custom trainer class via PYTHONPATH
import sys
sys.path.append("/content")

# Set environment variables for nnU-Net v2 custom trainer
os.environ["CASE_TO_SUBTYPE_CSV"] = os.path.join(RAW_DS_DIR, "train_subtypes.csv")
os.environ["CLS_LAMBDA"] = "0.3"

# To prevent multithreading errors in Windows
# os.environ["nnUNet_n_proc_DA"] = "0"

# For clearer error messages while debugging
# os.environ["nnUNet_compile"] = "0"

# Train 3d_fullres -tr TrainerSegCls
!nnUNetv2_train {DS_ID} 3d_fullres 0 -tr TrainerSegCls -p nnUNetResEncUNetMPlans --npz --c


Using device: cuda:0

#######################################################################
Please cite the following paper when using nnU-Net:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
#######################################################################

2025-08-15 18:10:23.207594: Using torch.compile...
[TrainerSegCls] TRAIN with CE weights: [1.3548387289047241, 0.7924528121948242, 1.0]
2025-08-15 18:10:25.581164: do_dummy_2d_data_aug: False
2025-08-15 18:10:25.599903: Using splits from existing split file: /content/drive/MyDrive/ColabData/working/nnUNet_preprocessed/Dataset777_PancreasSegCls/splits_final.json
2025-08-15 18:10:25.610543: The split file contains 5 splits.
2025-08-15 18:10:25.613472: Desired fold for training: 0
2025-08-15 18:10:25.622649: This split has 201 training and 51 validation cases.
using pin_me

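## Inference on the validation set
- Predict segmentations for `imagesVal/` using **nnUNetv2_predict** and score them against `labelsVal/`
- Predict subtypes using **nnUNetPredictor** together with the trained classification head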

In [None]:
# Predict segmentations
MODEL_BASE   = os.path.join(RESULTS_DS_DIR, f"{TRAINER_NAME}__{PLANS_NAME}__{CONFIG}")
MODEL_FOLD   = os.path.join(MODEL_BASE, f"fold_{FOLD}")

PRED_DIR = os.path.join(MODEL_FOLD, "predictions")  # where you want the CSV saved (can reuse PRED_DIR)
os.makedirs(PRED_DIR, exist_ok=True)

# To prevent multithreading errors in Windows
# os.environ["nnUNet_n_proc_DA"] = "0"

# -tr custom_trainer_cls.TrainerSegCls
!nnUNetv2_predict -i {IMAGES_VAL} -o {PRED_DIR} -d {DS_ID} -c 3d_fullres -f 0 -tr TrainerSegCls -p nnUNetResEncUNetMPlans --disable_tta --save_probabilities

# Evaluate the results to get the summary.json with Dice scores
!nnUNetv2_evaluate_simple {LABELS_VAL} {PRED_DIR} -l 1 2

In [None]:
# Predict classification

MODEL_BASE   = os.path.join(RESULTS_DS_DIR, f"{TRAINER_NAME}__{PLANS_NAME}__{CONFIG}")
MODEL_FOLD   = os.path.join(MODEL_BASE, f"fold_{FOLD}")

TEST_DIR     = IMAGES_VAL  # validation images (quiz_XXX_0000.nii.gz), same input as the segmentation cell

import os, json, glob, csv, warnings
import torch
import torch.nn as nn
import torch.nn.functional as F

from nnunetv2.paths import nnUNet_results
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p

# use same results tree that training created; prefer the best checkpoint
CKPT         = join(MODEL_FOLD, "checkpoint_best.pth")
if not os.path.isfile(CKPT):
    CKPT = join(MODEL_FOLD, "checkpoint_final.pth")
assert os.path.isfile(CKPT), f"Model checkpoint not found: {CKPT}"

# Use the classification head saved together with the chosen checkpoint
# (TrainerSegCls writes cls_head_best.pth / cls_head_final.pth next to them).
CLS_HEAD_WEIGHTS = join(
    MODEL_FOLD,
    "cls_head_best.pth" if os.path.basename(CKPT) == "checkpoint_best.pth" else "cls_head_final.pth",
)

plans_json_path   = join(MODEL_BASE, "plans.json")
dataset_json_path = join(MODEL_BASE, "dataset.json")
assert os.path.isfile(plans_json_path) and os.path.isfile(dataset_json_path), "plans.json/dataset.json missing in model folder"

with open(plans_json_path, "r") as fh:
    plans_dict = json.load(fh)
with open(dataset_json_path, "r") as fh:
    dataset_json = json.load(fh)

plans_manager      = PlansManager(plans_dict)
configuration_mgr  = plans_manager.get_configuration(CONFIG)
label_manager      = plans_manager.get_label_manager(dataset_json)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use nnUNetPredictor only to do preprocessing & sliding-window — we do NOT export segmentations here
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,
    perform_everything_on_device=True,
    device=device,
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=False
)
# MODEL_BASE points at: .../nnUNet_results/DATASET_NAME/TRAINER__PLANS__CONFIG
# and CKPT is "checkpoint_best.pth" or "checkpoint_final.pth" in fold_{FOLD}
predictor.initialize_from_trained_model_folder(
    model_training_output_dir=MODEL_BASE,
    use_folds=(FOLD,),
    checkpoint_name=os.path.basename(CKPT),  # e.g. "checkpoint_best.pth"
)

# Build the segmentation net exactly like training (we’ll only use its encoder features)
num_input_channels  = 1  # CT single-channel; adjust if you trained multi-channel
num_output_channels = label_manager.num_segmentation_heads

predictor.network.to(device)

# Load seg weights (so encoder features match training)
ckpt = torch.load(CKPT, map_location=device, weights_only=False)
new_state = {}
for k, v in ckpt['network_weights'].items():
    key = k[7:] if k.startswith("module.") and not next(iter(predictor.network.state_dict())).startswith("module.") else k
    new_state[key] = v
missing, unexpected = predictor.network.load_state_dict(new_state, strict=False)
if missing:    warnings.warn(f"Missing seg keys: {missing[:10]}{'...' if len(missing)>10 else ''}")
if unexpected: warnings.warn(f"Unexpected seg keys: {unexpected[:10]}{'...' if len(unexpected)>10 else ''}")

# Classification head (dimension-agnostic) — LazyLinear infers in_features on first call
class GlobalPoolHead(nn.Module):
    """Global-average-pool a feature map and map it to class logits.

    Accepts 2D feature maps (rank 4, N×C×H×W) or 3D feature maps (rank 5, N×C×D×H×W).
    The spatial axes are averaged away, leaving an N×C vector that a lazily-sized linear
    layer turns into N×n_classes logits.

    Args:
        n_classes: number of output logits per sample.
    """

    def __init__(self, n_classes: int = 3):
        super().__init__()
        # Attribute name "fc" determines the state-dict keys ("fc.weight", "fc.bias").
        self.fc = nn.LazyLinear(n_classes)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        rank = feat.dim()
        if rank not in (4, 5):
            raise ValueError(f"Unexpected feature rank {feat.dim()} (expected 4 or 5)")
        pool = F.adaptive_avg_pool3d if rank == 5 else F.adaptive_avg_pool2d
        pooled = pool(feat, 1).flatten(1)
        return self.fc(pooled)

# Build the classification head and try to load its trained weights.
cls_head = GlobalPoolHead(n_classes=3).to(device)
if CLS_HEAD_WEIGHTS and os.path.isfile(CLS_HEAD_WEIGHTS):
    try:
        cls_sd = torch.load(CLS_HEAD_WEIGHTS, map_location=device)
        # strict=False tolerates partial matches, but a key-name mismatch (e.g. a prefix other
        # than "fc.") would load nothing and silently keep the random init — report it.
        cls_missing, cls_unexpected = cls_head.load_state_dict(cls_sd, strict=False)
        if cls_missing:
            warnings.warn(f"Classification head weights are missing keys {list(cls_missing)}; "
                          "those parameters keep their random init (predictions may be unreliable).")
        if cls_unexpected:
            warnings.warn(f"Unexpected keys in classification head weights: {list(cls_unexpected)}")
    except Exception as e:
        warnings.warn(f"Could not load classification head weights: {e}\nProceeding with random init.")
else:
    warnings.warn("Classification head weights not found; proceeding with random init (predictions will be unreliable).")

# Forward hook on the last encoder stage. Every time the network runs (once per
# sliding-window patch), the stage output is passed through cls_head. The resulting
# per-patch logits are collected in _cached_logits so the inference loop can aggregate them
# per case.
_cached_logits = []

def _bottleneck_to_logits(module, inp, out):
    """Forward-hook callback that appends cls_head logits for the stage output.

    If the output is a (possibly nested) list/tuple, the last element is taken at each
    level. Non-tensor outputs are ignored. Returns None, so the stage output is left
    unchanged.
    """
    feats = out
    while isinstance(feats, (list, tuple)):
        feats = feats[-1]
    if not isinstance(feats, torch.Tensor):
        return
    with torch.no_grad():
        patch_logits = cls_head(feats).detach()  # [B, 3]
    _cached_logits.append(patch_logits)

enc_last = predictor.network.encoder.stages[-1]
hook_handle = enc_last.register_forward_hook(_bottleneck_to_logits)

# Run every image through the predictor (the hook collects patch logits), reduce them to
# one subtype per case and write Names,Subtype rows to subtype_results.csv.
maybe_mkdir_p(PRED_DIR)
nii_ext = dataset_json.get("file_ending", ".nii.gz")
# Channel-0 images are named <case>_0000<file_ending>; derive the pattern from file_ending
# so globbing, case-id stripping and the CSV names all agree.
channel0_suffix = f"_0000{nii_ext}"

test_imgs = sorted(glob.glob(os.path.join(TEST_DIR, f"*{channel0_suffix}")))
assert len(test_imgs) > 0, f"No test images found in {TEST_DIR}"

image_reader = SimpleITKIO()  # reused for every case
csv_path = os.path.join(PRED_DIR, "subtype_results.csv")
try:
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Names", "Subtype"])

        for img_path in test_imgs:
            # Strip only the trailing suffix (str.replace could also hit a match mid-name).
            case_id = os.path.basename(img_path)[:-len(channel0_suffix)]  # quiz_XXX
            # reset aggregator
            _cached_logits.clear()

            # Preprocess + sliding window forward (this fills _cached_logits via hook)
            with torch.no_grad():
                img_np, props = image_reader.read_images([img_path])  # props include spacing, etc.
                # This internally preprocesses then calls the same sliding-window logic.
                _ = predictor.predict_single_npy_array(img_np, props, None, None, False)

            # Aggregate classification logits over all windows → case-level class
            if not _cached_logits:
                warnings.warn(f"No classification logits captured for {case_id}; defaulting to 0")
                subtype = 0
            else:
                case_logits = torch.cat(_cached_logits, dim=0).mean(dim=0)  # [3]
                subtype = int(case_logits.argmax().item())

            writer.writerow([f"{case_id}{nii_ext}", subtype])
finally:
    # Detach the classification hook so later uses of predictor.network run unmodified.
    hook_handle.remove()

print(f"Wrote CSV: {csv_path}")


# **Evaluation**

*   Segmentation
    - Dice similarity coefficient (DSC) for the whole pancreas (labels 1 and 2 combined) and for the lesion (label 2)

*   Classification
    - Macro-average F1 score

In [None]:
## --------- Calculate DSC scores ---------- ##
import os, re, json, numpy as np
import SimpleITK as sitk  # pip install SimpleITK if missing

LESION_LABEL = 2

def strip_case(p):
    """Return the nnU-Net case id for a file path.

    Drops the directory, a trailing .nii / .nii.gz (case-insensitive) and then a trailing
    four-digit channel suffix such as _0000.
    """
    base = os.path.basename(p)
    without_ext = re.sub(r"\.nii(\.gz)?$", "", base, flags=re.I)
    return re.sub(r"_[0-9]{4}$", "", without_ext)  # drop _0000

def read_nii(p):
    """Read an image file with SimpleITK and return its pixel data as a NumPy array."""
    image = sitk.ReadImage(p)
    voxels = sitk.GetArrayFromImage(image)
    return voxels

def dice_bin(pred, gt, eps=1e-6):
    """Binary Dice coefficient between two masks.

    Both inputs are cast to bool. Returns (2·|P∩G| + eps) / (|P| + |G| + eps), which equals
    (2·TP + eps) / (2·TP + FP + FN + eps). When both masks are empty the result is 1.0.
    """
    p = pred.astype(bool)
    g = gt.astype(bool)
    overlap = np.count_nonzero(p & g)
    total = np.count_nonzero(p) + np.count_nonzero(g)  # == 2*TP + FP + FN
    return float((2 * overlap + eps) / (total + eps))

# Build GT <-> Pred pairs (by case id).
# Index the prediction folder once instead of re-listing it for every GT file. PRED_DIR can
# also contain non-NIfTI outputs, which the extension filter skips. Sorting keeps the
# pairing and case order deterministic.
pred_index = {}
for pf in sorted(os.listdir(PRED_DIR)):
    if pf.endswith((".nii", ".nii.gz")):
        pred_index.setdefault(strip_case(pf), os.path.join(PRED_DIR, pf))

gt_files = sorted(f for f in os.listdir(LABELS_VAL) if f.endswith((".nii", ".nii.gz")))
pairs, missing = [], []
for f in gt_files:
    cid = strip_case(f)
    pred = pred_index.get(cid)
    if pred is None:
        missing.append(cid)
    else:
        pairs.append((os.path.join(LABELS_VAL, f), pred))

whole, les_all, les_pos = [], [], []
for gt_p, pr_p in pairs:
    gt = read_nii(gt_p); pr = read_nii(pr_p)
    if gt.shape != pr.shape:
        raise ValueError(f"Shape mismatch for {os.path.basename(gt_p)}: GT {gt.shape} vs PRED {pr.shape}")
    # Whole-pancreas = union(1,2) -> label > 0
    whole.append(dice_bin(pr > 0, gt > 0))
    # Lesion = label 2
    gL, pL = (gt == LESION_LABEL), (pr == LESION_LABEL)
    dL = dice_bin(pL, gL)
    # For a lesion-negative case with no predicted lesion, dice_bin returns 1.0 (both
    # masks empty). The "across all cases" mean therefore includes those perfect scores.
    les_all.append(dL)
    if gL.any(): les_pos.append(dL)

print("n_eval =", len(pairs), " | n_missing =", len(missing))
if missing:
    print("Missing examples:", missing[:10], "..." if len(missing) > 10 else "")
# Guard the means so an empty list prints nan instead of triggering a NumPy warning.
print(f"Whole-pancreas DSC (mean): {np.mean(whole) if whole else float('nan'):.4f}")
print(f"Lesion DSC (mean; lesion-positive cases): {np.mean(les_pos) if les_pos else float('nan'):.4f}")
print(f"Lesion DSC (mean; across all cases): {np.mean(les_all) if les_all else float('nan'):.4f}")


In [1]:
## --------- Calculate macro-average F1 score ---------- ##
import csv, os, re, json
from collections import OrderedDict
from sklearn.metrics import f1_score, confusion_matrix, classification_report

# === CONFIG: set these two paths ===
# PRED_DIR and RAW_DS_DIR come from earlier cells. On a fresh kernel, fail with an
# actionable message instead of a bare NameError.
for _required_name in ("PRED_DIR", "RAW_DS_DIR"):
    if _required_name not in globals():
        raise NameError(f"{_required_name} is not defined - run the path-configuration "
                        "and inference cells above before this one.")

PRED_CSV = os.path.join(PRED_DIR, "subtype_results.csv")
GT_CSV   = os.path.join(RAW_DS_DIR, "val_subtypes.csv")

# (optional) If you want to restrict to exactly the validation IDs:
VAL_IDS = None  # e.g., ['quiz_0_168', 'quiz_0_041', ...] or load from splits_final.json

def strip_case(name: str) -> str:
    """Normalise a file name or path to its case id.

    ``name`` is converted to str, reduced to its basename, and has a trailing
    .nii/.nii.gz (case-insensitive) and then a trailing _NNNN channel suffix removed.
    """
    stem = os.path.basename(str(name))
    # Apply in order: extension first, then the channel suffix (e.g. _0000), if present.
    for pattern, flags in ((r"\.nii(\.gz)?$", re.I), (r"_[0-9]{4}$", 0)):
        stem = re.sub(pattern, "", stem, flags=flags)
    return stem

def read_csv_dict(path: str):
    """Read a CSV file with a header row into a list of dicts, one per data row.

    The file is decoded as ``utf-8-sig``, so a leading UTF-8 byte-order mark (as written by
    e.g. Excel) is stripped. Otherwise it would be glued onto the first column name and
    break header lookups such as ``get_col(row, ("case", ...))``. Plain UTF-8 files decode
    identically.
    """
    with open(path, "r", newline="", encoding="utf-8-sig") as f:
        return list(csv.DictReader(f))

def get_col(row: dict, prefs):
    """Return ``row[k]`` for the first candidate column name ``k`` in ``prefs`` found in ``row``.

    Raises:
        KeyError: if none of the candidate names is a key of ``row``.
    """
    present = [k for k in prefs if k in row]
    if not present:
        raise KeyError(f"Expected one of columns {prefs}, got {list(row.keys())}")
    return row[present[0]]

# load predictions: Names, Subtype
pred_rows = read_csv_dict(PRED_CSV)
pred_map = OrderedDict()
for r in pred_rows:
    name = strip_case(get_col(r, ("Names","Name","case","Case","id","ID")))
    lab  = int(get_col(r, ("Subtype","subtype","label","Label","class","Class")))
    pred_map[name] = lab

# load ground truth: case, subtype
gt_rows = read_csv_dict(GT_CSV)
gt_map = OrderedDict()
for r in gt_rows:
    name = strip_case(get_col(r, ("case","Case","Names","Name","id","ID")))
    lab  = int(get_col(r, ("subtype","Subtype","label","Label","class","Class")))
    gt_map[name] = lab

# Cases to score: GT ids (optionally filtered to VAL_IDS) that also have a prediction.
gt_ids = set(gt_map)
if VAL_IDS is not None:
    gt_ids &= set(map(str, VAL_IDS))
common = sorted(gt_ids & set(pred_map))
# Report GT cases without a prediction instead of dropping them silently.
no_pred = sorted(gt_ids - set(pred_map))
if no_pred:
    print(f"Warning: {len(no_pred)} GT case(s) have no prediction and are excluded, e.g. {no_pred[:5]}")

if not common:
    raise RuntimeError("No overlapping case IDs between PRED_CSV and GT_CSV. Check filenames/columns.")

y_true = [gt_map[k] for k in common]
y_pred = [pred_map[k] for k in common]

# sklearn's macro average covers the union of labels seen in y_true and y_pred, so report
# that count rather than assuming 3 classes.
class_labels = sorted(set(y_true) | set(y_pred))
macro_f1 = f1_score(y_true, y_pred, average="macro")
print(f"Macro-average F1 ({len(class_labels)} classes) on {len(common)} cases: {macro_f1:.4f}")

# (nice to have) per-class and confusion matrix
print("\nPer-class report:\n", classification_report(y_true, y_pred, digits=3))
print("Confusion matrix (labels in order):", class_labels)
print(confusion_matrix(y_true, y_pred, labels=class_labels))


NameError: name 'PRED_DIR' is not defined