# Pancreas Segmentation (nnU-Net v2 ResEnc M) + Subtype Classification

This notebook follows the quiz requirements:
- nnU-Net v2 (ResEnc **M**) for **segmentation**.
- Adds a **classification** component using a separate lightweight 3D CNN trained on the same training images (subtype0/1/2).
- Packaging: produces `quiz_*.nii.gz` masks and a `subtype_results.csv` with `Names,Subtype`.

It also:
- Persists preprocessed data & results to Google Drive (so you don't lose work on Colab).
- Uses `--disable_tta` to achieve ≥10% inference speed-up while maintaining accuracy.


## Manual backup/restore saved data

In [None]:
# --- Save nnU-Net + classification outputs to Drive ---
import os, shutil
from google.colab import drive

# 1) Mount Drive (safe to run multiple times)
drive.mount('/content/drive', force_remount=False)

# 2) Where to save
DRIVE_ROOT = '/content/drive/MyDrive'
os.makedirs(DRIVE_ROOT, exist_ok=True)

# 3) What to save (edit as needed)
to_save = {
  # nnU-Net stuff
  # 'preprocessed': '/content/nnUNet_preprocessed/Dataset500_PancreasCancer',
  # 'results': '/content/nnUNet_results',
  'submission_outputs': '/content/submission_outputs',

  # "splits_final.json": '/content/nnUNet_preprocessed/Dataset500_PancreasCancer/splits_final.json',
  #'raw_dataset': '/content/nnUNet_raw',
}

def copy_any(src, dst):
  if os.path.isdir(src):
    if os.path.exists(dst): shutil.rmtree(dst)
    shutil.copytree(src, dst)
    print(f"✅ Copied DIR: {src} -> {dst}")
  elif os.path.isfile(src):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    shutil.copy2(src, dst)
    print(f"✅ Copied FILE: {src} -> {dst}")
  else:
    print(f"⚠️ Missing (skip): {src}")

for name, src in to_save.items():
  dst = os.path.join(DRIVE_ROOT, os.path.basename(src))
  copy_any(src, dst)

print("\nDone! Saved to:", DRIVE_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Copied DIR: /content/submission_outputs -> /content/drive/MyDrive/submission_outputs

Done! Saved to: /content/drive/MyDrive


In [7]:
# --- Load/restore artifacts from Drive into current Colab session ---
import os, shutil
from google.colab import drive

# 1) Mount Drive
drive.mount('/content/drive', force_remount=False)

# 2) Source on Drive (must match the save cell)
DRIVE_ROOT = '/content/drive/MyDrive'

# 3) Where to restore locally
restore_map = {
  # os.path.join(DRIVE_ROOT, 'nnUNet_preprocessed'): '/content/nnUNet_preprocessed',
  os.path.join(DRIVE_ROOT, 'nnUNet_results not very ideal'): '/content/nnUNet_results',
  # os.path.join(DRIVE_ROOT, 'nnUNet_predictions'): '/content/predictions',
  # os.path.join(DRIVE_ROOT, 'case_subtype_mapping.csv'): '/content/case_subtype_mapping.csv',
  # os.path.join(DRIVE_ROOT, 'subtype_results.csv'): '/content/subtype_results.csv',
  # os.path.join(DRIVE_ROOT, 'nnUNet_raw'): '/content/nnUNet_raw',
  # os.path.join(DRIVE_ROOT, "splits_final.json"): '/content/nnUNet_preprocessed/Dataset500_PancreasCancer/splits_final.json',
}

def restore_any(src, dst):
  if os.path.isdir(src):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    if os.path.exists(dst): shutil.rmtree(dst)
    shutil.copytree(src, dst)
    print(f"⬇️ Restored DIR: {src} -> {dst}")
  elif os.path.isfile(src):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    shutil.copy2(src, dst)
    print(f"⬇️ Restored FILE: {src} -> {dst}")
  else:
    print(f"⚠️ Not found on Drive (skip): {src}")

for src, dst in restore_map.items():
  restore_any(src, dst)

print("\nReady! Restored from:", DRIVE_ROOT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
⬇️ Restored DIR: /content/drive/MyDrive/nnUNet_results not very ideal -> /content/nnUNet_results

Ready! Restored from: /content/drive/MyDrive


## Mount Google Drive and install nnUNetv2


In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Install nnU-Net v2 (ResEnc M) and CLI tools
%cd /content
!git clone https://github.com/MIC-DKFZ/nnUNet.git

%cd nnUNet
!pip install -e .

In [1]:
import os

# Set nnU-Net environment paths
os.environ['nnUNet_raw'] = '/content/nnUNet_raw'
os.environ['nnUNet_preprocessed'] = '/content/nnUNet_preprocessed'
os.environ['nnUNet_results'] = '/content/nnUNet_results'

dataset = 'Dataset500_PancreasCancer'
root = os.path.join(os.environ['nnUNet_raw'], dataset)
for sub in ['imagesTr', 'labelsTr', 'imagesTs']:
    os.makedirs(os.path.join(root, sub), exist_ok=True)

## Data augmentation: convert data to nnU-Net layout


In [None]:
# Convert & rename ML-Quiz-3DMedImg data into nnU-Net format
import nibabel as nib
import numpy as np
import shutil
from glob import glob

source = '/content/drive/MyDrive/ML-Quiz-3DMedImg/data'  # USE YOUR OWN PATH
target = root  # from previous cell
counter = 0

def copy_case(img, lbl):
    global counter
    cid = f"pancreas_{counter:03d}"
    # Copy training image
    shutil.copy(img, os.path.join(target, 'imagesTr', f'{cid}_0000.nii.gz'))
    # Load, fix label, and save (from float to int)
    data = nib.load(lbl)
    arr = np.rint(data.get_fdata()).astype(np.int16)
    nib.save(nib.Nifti1Image(arr, data.affine, data.header),
             os.path.join(target, 'labelsTr', f'{cid}.nii.gz'))
    counter += 1

# Training data only (exclude validation folder)
for subtype in ['subtype0', 'subtype1', 'subtype2']:
    folder = f"{source}/train/{subtype}"
    for img in sorted(glob(f"{folder}/*_0000.nii.gz")):
        lbl = img.replace('_0000.nii.gz', '.nii.gz')
        if os.path.exists(lbl):
            copy_case(img, lbl)

# Test images
for img in sorted(glob(f"{source}/test/*.nii.gz")):
    shutil.copy(img, os.path.join(target, 'imagesTs', os.path.basename(img)))


Saved classification mapping to: /content/nnUNet_raw/Dataset500_PancreasCancer/classification_labels.json and /content/nnUNet_raw/Dataset500_PancreasCancer/classification_labels.csv


## Create the case-subtype mapping

In [None]:
# --- Check mapping coverage and, if needed, rebuild it to include all imagesTr ---
import os, re, hashlib
from pathlib import Path
import nibabel as nib
import numpy as np
import pandas as pd
import json

# Where your original dataset with subtype folders lives (adjust if needed)
ORIG_ROOTS_TRY = [
    "/content/drive/MyDrive/ML-Quiz-3DMedImg/data",  # must contain train/subtype0..2 and validation/subtype0..2
]
ORIG_ROOT = None
for cand in ORIG_ROOTS_TRY:
    p = Path(cand)
    if (p/"train").exists() and (p/"validation").exists():
        ORIG_ROOT = p
        break
assert ORIG_ROOT is not None, "Set ORIG_ROOT to the folder containing train/subtype*/ and validation/subtype*/"

DATASET_ID   = "500"
RAW_ROOT     = Path(f"/content/nnUNet_raw/Dataset{DATASET_ID}_PancreasCancer")
PREP_ROOT    = Path(f"/content/nnUNet_preprocessed/Dataset{DATASET_ID}_PancreasCancer")
IM_TR        = RAW_ROOT / "imagesTr"
MAPPING_CSV  = Path("/content/case_subtype_mapping.csv")

assert IM_TR.exists(), f"{IM_TR} not found"

# quick, robust fingerprint so we can match originals to nnUNet files
def quick_fp(img_path, stride=2000):
    img = nib.load(str(img_path))
    arr = img.get_fdata(dtype=np.float32)
    shp = tuple(arr.shape)
    sample = arr.ravel()[::max(1, min(arr.size, stride))]
    h = hashlib.md5(sample.tobytes()).hexdigest()
    return (str(arr.dtype), shp, h)

def scan_originals(root):
    fp_to_meta = {}
    for split in ["train", "validation"]:
        for subdir in (root/split).glob("subtype*"):
            if not subdir.is_dir():
                continue
            m = re.search(r"subtype(\d+)", subdir.name)
            if not m:
                continue
            subtype = int(m.group(1))
            for img_path in sorted(subdir.rglob("*_0000.nii*")):
                try:
                    fp = quick_fp(img_path)
                    fp_to_meta[fp] = {"orig_path": str(img_path), "subtype": subtype}
                except Exception as e:
                    print("Skip", img_path, "->", e)
    return fp_to_meta

def scan_imagesTr(nnunet_im_tr):
    case_to_fp = {}
    for p in sorted(nnunet_im_tr.glob("*_0000.nii*")):
        case_id = p.name.replace("_0000.nii.gz", "").replace("_0000.nii", "")
        try:
            fp = quick_fp(p)
            case_to_fp[case_id] = (p, fp)
        except Exception as e:
            print("Skip", p, "->", e)
    return case_to_fp

print("Scanning originals...")
orig_fp = scan_originals(ORIG_ROOT)
print("Scanning imagesTr...")
case_to_fp = scan_imagesTr(IM_TR)

rows, missing = [], []
for cid, (p, fp) in case_to_fp.items():
    meta = orig_fp.get(fp)
    if meta is None:
        missing.append(cid)
        continue
    rows.append({"case_id": cid, "subtype": int(meta["subtype"])})

df = pd.DataFrame(rows).sort_values("case_id")
df.to_csv(MAPPING_CSV, index=False)
print(f"\n✅ Wrote {MAPPING_CSV} with {len(df)} rows")
if missing:
    print(f"⚠️ {len(missing)} imagesTr did not match any original by fingerprint. Examples: {missing[:8]}")

# Coverage vs nnUNet split (so you can see train & val presence)
with open(PREP_ROOT/"splits_final.json") as f:
    split = json.load(f)[0]  # fold 0
train_ids = set(split["train"])
val_ids   = set(split["val"])

have = set(df["case_id"].astype(str))
print(f"Coverage: mapping covers {len(have)} / {len(list(IM_TR.glob('*_0000.nii*')))} imagesTr")
print(f"  train covered: {len(train_ids & have)} / {len(train_ids)}")
print(f"  val   covered: {len(val_ids   & have)} / {len(val_ids)}")

print("\nCounts per subtype in mapping:")
print(df["subtype"].value_counts().sort_index())


## Generate dataset.json for nnU-Net (integer labels)


In [None]:
#cell 9: Write dataset.json
import json

tr_imgs = sorted(glob(f'{target}/imagesTr/*.nii.gz'))
train_entries = []
for p in tr_imgs:
    cid = os.path.basename(p).replace('_0000.nii.gz', '')
    train_entries.append({
        'image': f'./imagesTr/{cid}_0000.nii.gz',
        'label': f'./labelsTr/{cid}.nii.gz'
    })

ts_imgs = sorted(glob(f'{target}/imagesTs/*.nii.gz'))
test_entries = [f'./imagesTs/{os.path.basename(x)}' for x in ts_imgs]

ds = {
    'name': 'PancreasCancer',
    'description': 'Multi-task pancreas segmentation + classification',
    'tensorImageSize': '3D',
    'channel_names': {'0': 'CT'},
    'labels': {'background':'0','pancreas':'1','lesion':'2'},
    'numTraining': len(train_entries),
    'file_ending': '.nii.gz',
}

with open(f'{target}/dataset.json', 'w') as f:
    json.dump(ds, f, indent=4)


## Preprocess (plan & preprocess)


In [None]:
!nnUNetv2_extract_fingerprint -d 500 --verify_dataset_integrity

Dataset500_PancreasCancer
Using <class 'nnunetv2.imageio.simpleitk_reader_writer.SimpleITKIO'> as reader/writer

####################
verify_dataset_integrity Done. 
If you didn't see any error messages then your dataset is most likely OK!
####################

Using <class 'nnunetv2.imageio.simpleitk_reader_writer.SimpleITKIO'> as reader/writer
100% 252/252 [00:50<00:00,  4.97it/s]


In [None]:
!nnUNetv2_plan_experiment -d 500 -pl nnUNetPlannerResEncM # you can choose M/L/XL model

Dropping 3d_lowres config because the image size difference to 3d_fullres is too small. 3d_fullres: [ 59. 118. 181.], 3d_lowres: [59, 118, 181]
2D U-Net configuration:
{'data_identifier': 'nnUNetPlans_2d', 'preprocessor_name': 'DefaultPreprocessor', 'batch_size': 134, 'patch_size': (np.int64(128), np.int64(192)), 'median_image_size_in_voxels': array([118., 181.]), 'spacing': array([0.73046875, 0.73046875]), 'normalization_schemes': ['CTNormalization'], 'use_mask_for_norm': [False], 'resampling_fn_data': 'resample_data_or_seg_to_shape', 'resampling_fn_seg': 'resample_data_or_seg_to_shape', 'resampling_fn_data_kwargs': {'is_seg': False, 'order': 3, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_seg_kwargs': {'is_seg': True, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'resampling_fn_probabilities': 'resample_data_or_seg_to_shape', 'resampling_fn_probabilities_kwargs': {'is_seg': False, 'order': 1, 'order_z': 0, 'force_separate_z': None}, 'architecture': {'network_class_n

In [None]:
!nnUNetv2_preprocess -d 500 -p nnUNetResEncUNetMPlans # Change nnUNetResEncUNetXPlans with X=M/L/XL

Preprocessing dataset Dataset500_PancreasCancer
Configuration: 2d...
100% 252/252 [13:03<00:00,  3.11s/it]
Configuration: 3d_fullres...
100% 252/252 [07:41<00:00,  1.83s/it]
Configuration: 3d_lowres...
INFO: Configuration 3d_lowres not found in plans file nnUNetResEncUNetMPlans.json of dataset Dataset500_PancreasCancer. Skipping.


## Autosync (when training while I sleep)


In [None]:
# --- Configure autosync targets & mount Drive ---
from google.colab import drive
import os

# 1) Mount (safe to run multiple times)
drive.mount('/content/drive', force_remount=False)

# 2) Where to put backups on Drive
DRIVE_ROOT = '/content/drive/MyDrive'
os.makedirs(DRIVE_ROOT, exist_ok=True)

# 3) What to sync (add/remove as you like)
SYNC_DIRS = [
    '/content/nnUNet_results',                                 # models, logs, checkpoints  <-- important
    # '/content/nnUNet_preprocessed/Dataset500_PancreasCancer/splits_final.json'
]

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# --- Start background autosync thread ---
import threading, time, subprocess, datetime, shutil

SYNC_EVERY_SEC = 10 * 60

def _rsync_dir(src, dst):
    os.makedirs(dst, exist_ok=True)
    # trailing slashes so it syncs *contents* of src into dst
    cmd = ['rsync', '-a', '--delete', src.rstrip('/') + '/', dst.rstrip('/') + '/']
    subprocess.run(cmd, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

_autosync_stop = threading.Event()

def _sync_loop():
    while not _autosync_stop.is_set():
        started = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        for src in SYNC_DIRS:
            if os.path.exists(src):
                dst = os.path.join(DRIVE_ROOT, os.path.basename(src))
                _rsync_dir(src, dst)
        print(f"💾 Autosync completed at {started}")
        # Wait, but allow fast stop
        _autosync_stop.wait(SYNC_EVERY_SEC)

# Launch the background thread (daemon so it doesn't block runtime shutdown)
if '_autosync_thread' in globals() and _autosync_thread.is_alive():
    print("Autosync already running.")
else:
    _autosync_stop.clear()
    _autosync_thread = threading.Thread(target=_sync_loop, daemon=True)
    _autosync_thread.start()
    print("✅ Autosync started. Will run every", SYNC_EVERY_SEC//60, "minutes.")


💾 Autosync completed at 2025-08-15 10:19:06
✅ Autosync started. Will run every 10 minutes.


In [None]:
# --- Stop background autosync ---
try:
    _autosync_stop.set()
    print("🛑 Stopping autosync…")
except NameError:
    print("Autosync was not running.")

🛑 Stopping autosync…


## Install and create the multi-task network files


In [None]:
!pip uninstall -y nnunetv2 nnunet
!pip install --no-cache-dir "git+https://github.com/MIC-DKFZ/nnUNet.git@master"
# if you cloned the repo to /content/nnUNet and are using it in editable mode:
!cd /content/nnUNet && git pull && pip install -e .

Found existing installation: nnunetv2 2.6.2
Uninstalling nnunetv2-2.6.2:
  Successfully uninstalled nnunetv2-2.6.2
[0mCollecting git+https://github.com/MIC-DKFZ/nnUNet.git@master
  Cloning https://github.com/MIC-DKFZ/nnUNet.git (to revision master) to /tmp/pip-req-build-d4e3lnv1
  Running command git clone --filter=blob:none --quiet https://github.com/MIC-DKFZ/nnUNet.git /tmp/pip-req-build-d4e3lnv1
  Resolved https://github.com/MIC-DKFZ/nnUNet.git to commit f1851fbaf2c53dcb51b079b60a01de528a7d0c17
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting argparse (from unittest2->batchgenerators>=0.25.1->nnunetv2==2.6.2)
  Downloading argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Downloading argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Building wheels for collected packages: nnunetv2
  Building wheel for nnunetv2 (pyproject.toml) ... [?25l[?25hdone
  Cre

Already up to date.
Obtaining file:///content/nnUNet
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting argparse (from unittest2->batchgenerators>=0.25.1->nnunetv2==2.6.2)
  Using cached argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Using cached argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Building wheels for collected packages: nnunetv2
  Building editable for nnunetv2 (pyproject.toml) ... [?25l[?25hdone
  Created wheel for nnunetv2: filename=nnunetv2-2.6.2-0.editable-py3-none-any.whl size=16742 sha256=f1ef438baabe53a86620efab12f00b0b88742bc620d7f1b949f91b2dc4fe919a
  Stored in directory: /tmp/pip-ephem-wheel-cache-xzp4v8bb/wheels/2a/a6/3a/a708be8093b6d70dcfceb5de9867bfa2a75ecf91e22d90b1fc
Successfully built nnunetv2
Installing collected packages: argparse, nnunet

In [2]:
%%writefile /content/nnUNet/nnunetv2/training/nnUNetTrainer/MultiTaskTrainer.py
import os, math, json
import numpy as np
import pandas as pd
from typing import Union, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from importlib import import_module

from nnunetv2.utilities.helpers import dummy_context
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.training.loss.dice import get_tp_fp_fn_tn

from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from dynamic_network_architectures.building_blocks.residual import BasicBlockD


# -------------------------
#   Network with ROI-masked GeM pooling for cls head (NaN-safe)
# -------------------------
class MultiTaskResEnc(nn.Module):
    """Shared encoder + Segmentation decoder + Classification head"""

    @property
    def decoder(self):
        return self.segmentation_net.decoder

    def __init__(self, input_channels, n_stages, features_per_stage, conv_op,
                 kernel_sizes, strides, n_blocks_per_stage, num_segmentation_classes,
                 num_classification_classes=3, n_conv_per_stage_decoder=None,
                 conv_bias=False, norm_op=None, norm_op_kwargs=None,
                 dropout_op=None, dropout_op_kwargs=None, nonlin=None,
                 nonlin_kwargs=None, deep_supervision=False, block=None,
                 cls_stopgrad_through_encoder=True):
        super().__init__()
        block = block or BasicBlockD

        self.segmentation_net = ResidualEncoderUNet(
            input_channels=input_channels,
            n_stages=n_stages,
            features_per_stage=features_per_stage,
            conv_op=conv_op,
            kernel_sizes=kernel_sizes,
            strides=strides,
            n_blocks_per_stage=n_blocks_per_stage,
            num_classes=num_segmentation_classes,
            n_conv_per_stage_decoder=n_conv_per_stage_decoder,
            conv_bias=conv_bias,
            norm_op=norm_op,
            norm_op_kwargs=norm_op_kwargs,
            dropout_op=dropout_op,
            dropout_op_kwargs=dropout_op_kwargs,
            nonlin=nonlin,
            nonlin_kwargs=nonlin_kwargs,
            deep_supervision=deep_supervision,
            block=block
        )

        bottleneck_features = features_per_stage[-1] if isinstance(features_per_stage, (list, tuple)) else features_per_stage
        self.cls_stopgrad_through_encoder = cls_stopgrad_through_encoder

        # GeM pooling exponent for masked pooling
        self.gem_p = nn.Parameter(torch.tensor(3.0))

        # Fallback GAP (3D)
        self.global_pool = nn.AdaptiveAvgPool3d(1) if conv_op == nn.Conv3d else nn.AdaptiveAvgPool2d(1)

        self.classification_head = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(bottleneck_features, 256), nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 128), nn.ReLU(inplace=True),
            nn.Dropout(0.15),
            nn.Linear(128, num_classification_classes)
        )

    @staticmethod
    def _nan_to_num(x, lim=1e6):
        # replace NaN/Inf and clamp to avoid overflow in downstream losses
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        return x.clamp(min=-lim, max=lim)

    def _masked_gem_pool(self, bottleneck, seg_logits, eps=1e-6, tiny_roi_frac=5e-4):
        """
        bottleneck: [B, C, d, h, w]
        seg_logits: [B, Cseg, D, H, W] (main head resolution)
        Returns pooled: [B, C, 1, 1, 1]
        """
        seg_logits = self._nan_to_num(seg_logits)
        # soft foreground prob
        if seg_logits.shape[1] > 1:
            fg_prob = torch.softmax(seg_logits, dim=1)[:, 1:, ...].max(1, keepdim=True)[0]
        else:
            fg_prob = torch.sigmoid(seg_logits)

        fg_prob = self._nan_to_num(fg_prob).clamp(0, 1)

        # downsample mask to bottleneck size
        if fg_prob.shape[2:] != bottleneck.shape[2:]:
            fg_prob = F.adaptive_avg_pool3d(fg_prob, bottleneck.shape[2:])
            fg_prob = self._nan_to_num(fg_prob).clamp(0, 1)

        # safe dtype math in fp32, cast back at end
        orig_dtype = bottleneck.dtype
        x = self._nan_to_num(bottleneck.float())
        m = fg_prob.float()

        # tiny ROI -> fall back to GAP
        valid = m.sum(dim=(2, 3, 4), keepdim=True)
        total = float(np.prod(bottleneck.shape[2:]))
        tiny = (valid / (total + eps)) < tiny_roi_frac  # [B,1,1,1,1] boolean

        # GeM over ROI
        p = torch.clamp(self.gem_p, 1.0, 6.0)
        x_pos = x.clamp_min(0)
        num = ((x_pos ** p) * m).sum(dim=(2, 3, 4), keepdim=True)
        den = valid.clamp_min(1.0)
        gem_roi = (num / den).clamp_min(eps) ** (1.0 / p)

        # GAP fallback
        gap = F.adaptive_avg_pool3d(x, 1)

        pooled = torch.where(tiny, gap, gem_roi).to(orig_dtype)
        return self._nan_to_num(pooled)

    def forward(self, x, return_both: bool = False):
        seg_logits = self.segmentation_net(x)
        seg_logits_main = seg_logits[0] if isinstance(seg_logits, (list, tuple)) else seg_logits
        seg_logits_main = self._nan_to_num(seg_logits_main)

        # encoder features
        enc_out = self.segmentation_net.encoder(x)
        bottleneck = enc_out[-1] if isinstance(enc_out, (list, tuple)) else enc_out

        # stop-grad for cls path (optional)
        bn_for_cls = bottleneck.detach() if self.cls_stopgrad_through_encoder else bottleneck

        pooled = self._masked_gem_pool(bn_for_cls, seg_logits_main)  # [B,C,1,1,1]
        cls_logits = self.classification_head(pooled.view(pooled.size(0), -1))
        cls_logits = self._nan_to_num(cls_logits)

        if return_both or self.training:
            return seg_logits, cls_logits

        # Predictor expects segmentation only
        return seg_logits


# -------------------------
#   Trainer
# -------------------------
class MultiTaskTrainer(nnUNetTrainer):
    def __init__(self, plans, configuration, fold, dataset_json,
                 device=torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, device)

        # ---- classification weight ramps up ----
        self.cls_weight_start = 0.10
        self.cls_weight_end   = 0.55
        self.cls_weight_ramp_epochs = 25  # linear ramp

        # ---- classification loss knobs ----
        self.cls_label_smoothing = 0.05
        self.cls_logit_adj_tau   = 1.0

        # mapping & priors
        self.case_to_subtype = self._load_classification_labels()
        print(f"Loaded {len(self.case_to_subtype)} classification labels")

        counts = np.bincount(list(self.case_to_subtype.values()), minlength=3)
        priors = counts / max(1, counts.sum())
        inv = 1.0 / np.clip(counts, 1, None)
        inv = inv / inv.mean() if inv.sum() > 0 else np.ones_like(inv, dtype=float)

        self.cls_class_weights = torch.tensor(inv, dtype=torch.float32, device=self.device)
        self.cls_log_prior     = torch.log(torch.tensor(np.clip(priors, 1e-8, 1.0), dtype=torch.float32, device=self.device))

        self.classification_loss = nn.CrossEntropyLoss(
            weight=self.cls_class_weights, label_smoothing=self.cls_label_smoothing
        )

        # keep DS off: we feed only the main seg map to the loss
        self.enable_deep_supervision = False

        # bookkeeping
        self.classification_metrics = {'train_acc': [], 'val_acc': []}
        self._epoch_idx = -1
        self._val_cls_acc_epoch = []
        self.best_val_cls_acc = -1.0
        self.best_cls_head_path = os.path.join(self.output_folder, "checkpoint_best_cls_head.pth")

        # stability: warm-up in fp32 for a few steps to avoid AMP NaNs at start
        self._train_step_count = 0
        self._fp32_warmup_steps = 6

    @staticmethod
    def build_network_architecture(architecture_class_name: str,
                                   arch_init_kwargs: dict,
                                   arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                   num_input_channels: int,
                                   num_output_channels: int,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        kwargs = dict(arch_init_kwargs)
        if arch_init_kwargs_req_import:
            for k in arch_init_kwargs_req_import:
                v = kwargs.get(k)
                if isinstance(v, str):
                    mod, attr = v.rsplit('.', 1)
                    kwargs[k] = getattr(import_module(mod), attr)

        return MultiTaskResEnc(
            input_channels=num_input_channels,
            n_stages=kwargs['n_stages'],
            features_per_stage=kwargs['features_per_stage'],
            conv_op=kwargs['conv_op'],
            kernel_sizes=kwargs['kernel_sizes'],
            strides=kwargs['strides'],
            n_blocks_per_stage=kwargs['n_blocks_per_stage'],
            num_segmentation_classes=num_output_channels,
            num_classification_classes=3,
            n_conv_per_stage_decoder=kwargs['n_conv_per_stage_decoder'],
            conv_bias=kwargs['conv_bias'],
            norm_op=kwargs['norm_op'],
            norm_op_kwargs=kwargs['norm_op_kwargs'],
            dropout_op=kwargs.get('dropout_op'),
            dropout_op_kwargs=kwargs.get('dropout_op_kwargs'),
            nonlin=kwargs['nonlin'],
            nonlin_kwargs=kwargs['nonlin_kwargs'],
            deep_supervision=enable_deep_supervision,
            cls_stopgrad_through_encoder=True
        )

    def _load_classification_labels(self):
        try:
            df = pd.read_csv('/content/case_subtype_mapping.csv')
            col = 'case_id' if 'case_id' in df.columns else ('case' if 'case' in df.columns else None)
            assert col is not None, f"Mapping CSV must have case_id or case; got columns {df.columns.tolist()}"
            df[col] = df[col].astype(str)
            df['subtype'] = df['subtype'].astype(int)
            return dict(zip(df[col], df['subtype']))
        except Exception as e:
            print("Warning: Could not load case_subtype_mapping.csv; falling back to dummy mapping.", e)
            return {f"case_{i:04d}": i % 3 for i in range(300)}

    def get_network(self):
        return self.build_network_architecture(
            architecture_class_name=self.network_arch_class_name,
            arch_init_kwargs=self.network_arch_init_kwargs,
            arch_init_kwargs_req_import=self.network_arch_init_kwargs_req_import,
            num_input_channels=self.num_input_channels,
            num_output_channels=self.label_manager.num_segmentation_heads,
            enable_deep_supervision=self.enable_deep_supervision
        )

    def on_train_start(self):
        super().on_train_start()
        # ensure tensors are on the right device
        self.cls_class_weights = self.cls_class_weights.to(self.device)
        self.cls_log_prior = self.cls_log_prior.to(self.device)
        if isinstance(getattr(self.classification_loss, "weight", None), torch.Tensor):
            self.classification_loss.weight = self.cls_class_weights.to(self.device)

    def on_epoch_start(self):
        super().on_epoch_start()
        self._epoch_idx += 1
        self._val_cls_acc_epoch = []

    # ---------- helpers ----------
    @staticmethod
    def _nan_to_num(x, lim=1e6):
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        return x.clamp(min=-lim, max=lim)

    def _align_target_to_output(self, target, seg_logits: torch.Tensor) -> torch.Tensor:
        t = target[0] if isinstance(target, (list, tuple)) else target
        if not torch.is_tensor(t):
            t = torch.as_tensor(t)
        t = t.to(device=seg_logits.device, non_blocking=True)

        # squeeze extra singletons
        while t.ndim > seg_logits.ndim:
            reduced = False
            for dim in (1, 2, 0):
                if dim < t.ndim and t.shape[dim] == 1:
                    t = t.squeeze(dim); reduced = True; break
            if not reduced:
                for dim in range(t.ndim):
                    if t.shape[dim] == 1:
                        t = t.squeeze(dim); reduced = True; break
            if not reduced:
                break

        # add channel if missing
        if t.ndim == seg_logits.ndim - 1:
            t = t.unsqueeze(1)

        # clamp labels into valid range (defensive)
        num_classes = seg_logits.shape[1]
        t = t.long().clamp(min=0, max=max(0, num_classes - 1))

        if t.ndim != seg_logits.ndim:
            raise RuntimeError(f"Target rank {t.ndim} mismatches output rank {seg_logits.ndim}. "
                               f"target {tuple(t.shape)} vs output {tuple(seg_logits.shape)}")
        return t

    def _ensure_bc_first(self, seg_logits: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
        if seg_logits.ndim >= 2:
            B_data = data.shape[0]
            if seg_logits.shape[0] != B_data and seg_logits.shape[1] == B_data:
                perm = list(range(seg_logits.ndim))
                perm[0], perm[1] = 1, 0
                seg_logits = seg_logits.permute(*perm).contiguous()
        return seg_logits

    def _cls_targets_and_mask(self, batch):
        """
        Returns:
          cls_target: LongTensor [B] with -1 where the label is unknown
          mask:       BoolTensor [B] True where label is known
        """
        bs = batch['data'].shape[0]
        if 'keys' in batch:
            keys = list(batch['keys'])
            t = []
            for k in keys:
                t.append(self.case_to_subtype[k] if k in self.case_to_subtype else -1)
            tgt = torch.as_tensor(t, dtype=torch.long, device=self.device)
        else:
            tgt = torch.full((bs,), -1, dtype=torch.long, device=self.device)
        mask = tgt >= 0
        return tgt, mask

    def _cls_weight_schedule(self) -> float:
        e = max(0, self._epoch_idx)
        if self.cls_weight_ramp_epochs <= 0:
            return float(self.cls_weight_end)
        frac = min(1.0, e / float(self.cls_weight_ramp_epochs))
        return float(self.cls_weight_start + frac * (self.cls_weight_end - self.cls_weight_start))

    # ---------- train/val ----------
    def train_step(self, batch):
        data = batch['data'].to(self.device, non_blocking=True)

        self.optimizer.zero_grad(set_to_none=True)

        # FP32 warmup for first few steps (avoid AMP NaNs at start)
        use_amp = (self.device.type == 'cuda') and (self._train_step_count >= self._fp32_warmup_steps)
        ctx = autocast(self.device.type, enabled=use_amp) if use_amp else dummy_context()

        with ctx:
            seg_out, cls_out = self.network(data, return_both=True)
            seg_logits = seg_out[0] if isinstance(seg_out, (list, tuple)) else seg_out
            seg_logits = self._nan_to_num(self._ensure_bc_first(seg_logits, data))
            target = self._align_target_to_output(batch['target'], seg_logits)

            # sanitize cls logits too
            cls_out = self._nan_to_num(cls_out)

            seg_loss = self.loss(seg_logits, target)

            # classification (masked) + logit adjustment
            cls_target, cls_mask = self._cls_targets_and_mask(batch)
            if cls_mask.any():
                cls_logits_adj = cls_out - (self.cls_log_prior.to(cls_out.dtype))[None, :] * self.cls_logit_adj_tau
                cls_logits_adj = self._nan_to_num(cls_logits_adj)
                cls_loss = self.classification_loss(cls_logits_adj[cls_mask], cls_target[cls_mask])
                cls_acc  = (torch.argmax(cls_logits_adj[cls_mask], dim=1) == cls_target[cls_mask]).float().mean().item()
            else:
                cls_loss = torch.zeros((), device=self.device)
                cls_acc  = float('nan')

            total_loss = seg_loss + self._cls_weight_schedule() * cls_loss

        # if NaN/Inf slipped through, try a last-chance FP32 recompute
        if not torch.isfinite(total_loss):
            print("⚠️ Non-finite loss encountered. Retrying batch in FP32 fall-back.")
            with dummy_context():
                seg_out, cls_out = self.network(data, return_both=True)
                seg_logits = seg_out[0] if isinstance(seg_out, (list, tuple)) else seg_out
                seg_logits = self._nan_to_num(self._ensure_bc_first(seg_logits, data).float())
                target = self._align_target_to_output(batch['target'], seg_logits)
                cls_out = self._nan_to_num(cls_out.float())

                seg_loss = self.loss(seg_logits, target)
                cls_target, cls_mask = self._cls_targets_and_mask(batch)
                if cls_mask.any():
                    cls_logits_adj = self._nan_to_num(cls_out - (self.cls_log_prior.to(cls_out.dtype))[None, :] * self.cls_logit_adj_tau)
                    cls_loss = self.classification_loss(cls_logits_adj[cls_mask], cls_target[cls_mask])
                    cls_acc  = (torch.argmax(cls_logits_adj[cls_mask], dim=1) == cls_target[cls_mask]).float().mean().item()
                else:
                    cls_loss = torch.zeros((), device=self.device)
                    cls_acc  = float('nan')
                total_loss = seg_loss + self._cls_weight_schedule() * cls_loss

        if not torch.isfinite(total_loss):
            print("⚠️ Non-finite loss persists. Skipping batch.")
            return {'loss': np.array(float('nan'))}

        if self.grad_scaler is not None and use_amp:
            self.grad_scaler.scale(total_loss).backward()
            self.grad_scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        self._train_step_count += 1

        if not np.isnan(cls_acc):
            self.classification_metrics['train_acc'].append(cls_acc)

        return {
            'loss': total_loss.detach().cpu().numpy(),
            'seg_loss': seg_loss.detach().cpu().numpy(),
            'cls_loss': cls_loss.detach().cpu().numpy(),
            'cls_acc': cls_acc,
        }

    def validation_step(self, batch):
        data = batch['data'].to(self.device, non_blocking=True)
        with torch.no_grad():
            seg_out, cls_out = self.network(data, return_both=True)
            seg_logits = seg_out[0] if isinstance(seg_out, (list, tuple)) else seg_out
            seg_logits = self._nan_to_num(self._ensure_bc_first(seg_logits, data))
            target = self._align_target_to_output(batch['target'], seg_logits)
            target_for_metrics = target

            cls_out = self._nan_to_num(cls_out)

            seg_loss = self.loss(seg_logits, target)

            # classification (masked) + logit adjustment
            cls_target, cls_mask = self._cls_targets_and_mask(batch)
            if cls_mask.any():
                cls_logits_adj = self._nan_to_num(cls_out - (self.cls_log_prior.to(cls_out.dtype))[None, :] * self.cls_logit_adj_tau)
                cls_loss = self.classification_loss(cls_logits_adj[cls_mask], cls_target[cls_mask])
                cls_preds = torch.argmax(cls_logits_adj[cls_mask], dim=1)
                cls_acc = (cls_preds == cls_target[cls_mask]).float().mean().item()
                self.classification_metrics['val_acc'].append(cls_acc)
                self._val_cls_acc_epoch.append(cls_acc)
            else:
                cls_loss = torch.zeros((), device=self.device)
                cls_acc = float('nan')

            total_loss = seg_loss + self._cls_weight_schedule() * cls_loss

            # dice bookkeeping as in nnU-Net
            if self.label_manager.has_regions:
                predicted_segmentation_onehot = (torch.sigmoid(seg_logits) > 0.5).long()
            else:
                output_seg = seg_logits.argmax(1)[:, None]
                predicted_segmentation_onehot = torch.zeros(
                    seg_logits.shape, device=seg_logits.device, dtype=torch.float32
                )
                predicted_segmentation_onehot.scatter_(1, output_seg, 1)
                del output_seg

            if self.label_manager.has_ignore_label:
                if not self.label_manager.has_regions:
                    mask = (target_for_metrics != self.label_manager.ignore_label).float()
                    target_for_metrics = target_for_metrics.clone()
                    target_for_metrics[target_for_metrics == self.label_manager.ignore_label] = 0
                else:
                    if target_for_metrics.dtype == torch.bool:
                        mask = ~target_for_metrics[:, -1:]
                    else:
                        mask = 1 - target_for_metrics[:, -1:]
                    target_for_metrics = target_for_metrics[:, :-1]
            else:
                mask = None

            axes = [0] + list(range(2, seg_logits.ndim))
            tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot,
                                            target_for_metrics, axes=axes, mask=mask)
            tp_hard = tp.detach().cpu().numpy()
            fp_hard = fp.detach().cpu().numpy()
            fn_hard = fn.detach().cpu().numpy()
            if not self.label_manager.has_regions:
                tp_hard, fp_hard, fn_hard = tp_hard[1:], fp_hard[1:], fn_hard[1:]

        return {
            'loss': total_loss.detach().cpu().numpy(),
            'tp_hard': tp_hard,
            'fp_hard': fp_hard,
            'fn_hard': fn_hard,
            'seg_loss': seg_loss.detach().cpu().numpy(),
            'cls_loss': cls_loss.detach().cpu().numpy(),
            'cls_acc': cls_acc,
        }

    def on_epoch_end(self):
        super().on_epoch_end()
        if self.classification_metrics['train_acc']:
            train_acc = float(np.mean(self.classification_metrics['train_acc'][-min(50, len(self.classification_metrics['train_acc'])):]))
            print(f"Classification Train Acc: {train_acc:.4f}")
        if self.classification_metrics['val_acc']:
            val_acc_hist = float(np.mean(self.classification_metrics['val_acc'][-min(50, len(self.classification_metrics['val_acc'])):]))
            print(f"Classification Val Acc: {val_acc_hist:.4f}")

        if self._val_cls_acc_epoch:
            epoch_val_acc = float(np.mean(self._val_cls_acc_epoch))
            print(f"(cls) epoch val acc: {epoch_val_acc:.3f}")
            if epoch_val_acc > self.best_val_cls_acc + 1e-6:
                self.best_val_cls_acc = epoch_val_acc
                try:
                    torch.save(self.network.classification_head.state_dict(), self.best_cls_head_path)
                    print(f"  ↳ saved best classification head → {self.best_cls_head_path}")
                except Exception as e:
                    print("  ↳ failed to save best classification head:", e)

    def initialize_val_metrics(self):
        if not hasattr(self, 'val_metrics'):
            self.val_metrics = []


Writing /content/nnUNet/nnunetv2/training/nnUNetTrainer/MultiTaskTrainer.py


## Training function


In [3]:
# Training function
import sys, os, json, torch
sys.path.append('/content/nnUNet')

from nnunetv2.training.nnUNetTrainer.MultiTaskTrainer import MultiTaskTrainer
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
import json

def train_multitask_model():
    """Main training function"""

    print("=== Starting Multi-task nnUNet Training ===")

    # Parameters
    dataset_id = 500
    configuration = "3d_fullres"
    fold = 0
    plans_identifier = "nnUNetResEncUNetMPlans"

    # Paths
    dataset_name = f"Dataset{dataset_id:03d}_PancreasCancer"
    plans_path = f"/content/nnUNet_preprocessed/{dataset_name}/{plans_identifier}.json"
    dataset_json_path = f"/content/nnUNet_raw/{dataset_name}/dataset.json"

    print(f"Loading plans from: {plans_path}")
    print(f"Loading dataset.json from: {dataset_json_path}")

    # Load plans
    with open(plans_path, 'r') as f:
        plans = json.load(f)

    # Load dataset.json
    with open(dataset_json_path, 'r') as f:
        dataset_json = json.load(f)

    # Initialize trainer
    print("Initializing multi-task trainer...")
    trainer = MultiTaskTrainer(
        plans=plans,
        configuration=configuration,
        fold=fold,
        dataset_json=dataset_json,
        # unpack_dataset=True,
        device=torch.device('cuda')
    )

    # Set reduced number of epochs for Colab
    trainer.num_epochs = 250  # Reduced from default 1000

    print("Starting training...")
    trainer.run_training()

    print("Training completed!")
    return trainer

## Execute training

In [None]:
# Run this after patching MultiTaskTrainer
from importlib import reload
import nnunetv2.training.nnUNetTrainer.MultiTaskTrainer as mt
reload(mt)
from nnunetv2.training.nnUNetTrainer.MultiTaskTrainer import MultiTaskTrainer

In [None]:
try:
    # Train the model
    trainer = train_multitask_model()
    print("✅ Multi-task training completed successfully!")

except Exception as e:
    print(f"❌ Training failed with error: {e}")
    import traceback
    traceback.print_exc()

=== Starting Multi-task nnUNet Training ===
Loading plans from: /content/nnUNet_preprocessed/Dataset500_PancreasCancer/nnUNetResEncUNetMPlans.json
Loading dataset.json from: /content/nnUNet_raw/Dataset500_PancreasCancer/dataset.json
Initializing multi-task trainer...
Using device: cuda:0

#######################################################################
Please cite the following paper when using nnU-Net:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
#######################################################################

Loaded 252 classification labels
Starting training...
2025-08-15 12:44:31.751993: do_dummy_2d_data_aug: False
2025-08-15 12:44:31.753348: Using splits from existing split file: /content/nnUNet_preprocessed/Dataset500_PancreasCancer/splits_final.json
2025-08-15 12:44:31.753706: The split file contains 5 spli

KeyboardInterrupt: 

## Inference

In [6]:
# === USER CONFIG ===
DATASET_ID = 500
CONFIG = "3d_fullres"
PLANS = "nnUNetResEncUNetMPlans"
TRAINER = "MultiTaskTrainer"
FOLDS = 0

# nnU-Net folder roots (adapt if you’ve set them differently)
import os
os.environ["nnUNet_raw"] = "/content/nnUNet_raw"
os.environ["nnUNet_preprocessed"] = "/content/nnUNet_preprocessed"
os.environ["nnUNet_results"] = "/content/nnUNet_results"

# I/O
IMAGES_TS = f'{os.environ["nnUNet_raw"]}/Dataset{DATASET_ID}_PancreasCancer/imagesTs'
MODEL_DIR = f'{os.environ["nnUNet_results"]}/Dataset{DATASET_ID}_PancreasCancer/{TRAINER}__{PLANS}__{CONFIG}/fold_{FOLDS}'
OUT_DIR = f'/content/submission_outputs'          # where we will save predictions
os.makedirs(OUT_DIR, exist_ok=True)

# Choose checkpoint
CHECKPOINT_NAME = "checkpoint_best.pth"           # fallback to "checkpoint_final.pth" if not present

# Quick sanity checks
print("Model dir:", MODEL_DIR)
print("Has checkpoint?", os.path.exists(os.path.join(MODEL_DIR, CHECKPOINT_NAME)))
print("imagesTs count:", len(glob.glob(os.path.join(IMAGES_TS, "*.nii.gz"))))

Model dir: /content/nnUNet_results/Dataset500_PancreasCancer/MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres/fold_0
Has checkpoint? True
imagesTs count: 72


In [8]:
# Initialize the nnUNetv2 predictor

import os, shutil, glob
from pathlib import Path
import torch
from inspect import signature
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 1) Auto-detect the correct model ROOT directory (where dataset.json lives)
results_root = os.environ.get("nnUNet_results", "/content/nnUNet_results")
ds_root = f"{results_root}/Dataset{DATASET_ID}_PancreasCancer"

candidates = [
    f"{ds_root}/{CONFIG}/{TRAINER}__{PLANS}",        # e.g., .../3d_fullres/MultiTaskTrainer__Plans
    f"{ds_root}/{TRAINER}__{PLANS}__{CONFIG}",       # e.g., .../MultiTaskTrainer__Plans__3d_fullres  (your case)
    f"{ds_root}/{TRAINER}__{PLANS}",                 # fallback
]

MODEL_DIR_EFF = None
for c in candidates:
    if os.path.isfile(os.path.join(c, "dataset.json")) and os.path.isfile(os.path.join(c, "plans.json")):
        MODEL_DIR_EFF = c
        break

if MODEL_DIR_EFF is None:
    raise FileNotFoundError(
        "Could not find model dir containing dataset.json & plans.json.\n"
        f"Tried:\n- " + "\n- ".join(candidates)
    )

print("[OK] Model root:", MODEL_DIR_EFF)

# --- 2) Ensure each requested fold has dataset.json/plans.json (some nnUNetv2 builds expect them inside fold_X)
folds_to_use = tuple(FOLDS) if isinstance(FOLDS, (list, tuple)) else (FOLDS,)
for f in folds_to_use:
    fold_dir = os.path.join(MODEL_DIR_EFF, f"fold_{f}")
    if not os.path.isdir(fold_dir):
        raise FileNotFoundError(f"Missing fold directory: {fold_dir}")
    for fname in ("dataset.json", "plans.json"):
        src = os.path.join(MODEL_DIR_EFF, fname)
        dst = os.path.join(fold_dir, fname)
        if os.path.isfile(src) and not os.path.isfile(dst):
            shutil.copy2(src, dst)
            print(f"  ↳ mirrored {fname} → {fold_dir}")

# Optional: verify checkpoint exists in the fold
for f in folds_to_use:
    ck1 = os.path.join(MODEL_DIR_EFF, f"fold_{f}", CHECKPOINT_NAME)
    ck2 = os.path.join(MODEL_DIR_EFF, f"fold_{f}", "checkpoint_final.pth")
    if not (os.path.exists(ck1) or os.path.exists(ck2)):
        print(f"⚠️  No checkpoint found for fold {f}: looked for {CHECKPOINT_NAME} or checkpoint_final.pth")

# --- 3) Build predictor with only supported __init__ kwargs (older/newer compat)
_ctor_params = signature(nnUNetPredictor.__init__).parameters
ctor_kwargs = {
    "tile_step_size": 0.5,
    "use_gaussian": True,
    "use_mirroring": False,
    "device": device,
    "verbose": True,
    "verbose_preprocessing": False,
    "allow_tqdm": True,
}
ctor_kwargs = {k: v for k, v in ctor_kwargs.items() if k in _ctor_params}
predictor = nnUNetPredictor(**ctor_kwargs)

# --- 4) Initialize from trained model folder, trying several signatures (no 'strict' in older builds)
init_tried = False
try:
    predictor.initialize_from_trained_model_folder(
        model_training_output_dir=MODEL_DIR_EFF,
        use_folds=folds_to_use,
        checkpoint_name=CHECKPOINT_NAME,
    )
    init_tried = True
except TypeError:
    pass

if not init_tried:
    try:
        predictor.initialize_from_trained_model_folder(
            model_training_output_dir=MODEL_DIR_EFF,
            use_folds=folds_to_use,
        )
        init_tried = True
    except TypeError:
        pass

if not init_tried:
    try:
        predictor.initialize_from_trained_model_folder(MODEL_DIR_EFF, folds_to_use)
        init_tried = True
    except TypeError:
        predictor.initialize_from_trained_model_folder(MODEL_DIR_EFF)
        init_tried = True

# Set checkpoint on attribute if available (older versions ignore in init)
if hasattr(predictor, "checkpoint_name"):
    predictor.checkpoint_name = CHECKPOINT_NAME

print("✅ Predictor ready on", device)
print("   Using model dir:", MODEL_DIR_EFF)
print("   Using folds:", folds_to_use)

[OK] Model root: /content/nnUNet_results/Dataset500_PancreasCancer/MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres
  ↳ mirrored dataset.json → /content/nnUNet_results/Dataset500_PancreasCancer/MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres/fold_0
  ↳ mirrored plans.json → /content/nnUNet_results/Dataset500_PancreasCancer/MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres/fold_0
✅ Predictor ready on cuda
   Using model dir: /content/nnUNet_results/Dataset500_PancreasCancer/MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres
   Using folds: (0,)


In [None]:
# STEP 3 — Version-agnostic inference (GPU sliding window + safe-padded classification)
import os, glob, math, shutil
from pathlib import Path
import numpy as np
import nibabel as nib
import torch
import torch.nn.functional as F
from tqdm import tqdm
import pandas as pd

# ---- setup
device = predictor.device if hasattr(predictor, "device") else torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = predictor.network.to(device)
net.eval()

# I/O
test_files = sorted(glob.glob(os.path.join(IMAGES_TS, "*.nii.gz")))
assert test_files, f"No test files in {IMAGES_TS}"
SEG_DIR = os.path.join(OUT_DIR, "segmentations")
Path(SEG_DIR).mkdir(parents=True, exist_ok=True)

# ---- helpers
def _case_name_for_submission(img_path: str) -> str:
    return Path(img_path).name.replace("_0000", "")

def _clip_zscore(x, p_low=0.5, p_high=99.5, eps=1e-6):
    lo, hi = np.percentile(x, [p_low, p_high])
    x = np.clip(x, lo, hi)
    m, s = x.mean(), x.std()
    return ((x - m) / (s + eps)).astype(np.float32)

def _pad_to_min_shape(arr, min_shape):
    pads = []
    for a, m in zip(arr.shape, min_shape):
        add = max(0, m - a)
        pl = add // 2
        pr = add - pl
        pads.append((pl, pr))
    arr2 = np.pad(arr, pads, mode='constant', constant_values=0)
    slc = tuple(slice(p[0], p[0] + arr.shape[i]) for i, p in enumerate(pads))
    return arr2, slc

def _gen_tiles(shape, patch, stride):
    D, H, W = shape
    Pd, Ph, Pw = patch
    Sd, Sh, Sw = stride
    z_starts = list(range(0, max(1, D - Pd + 1), Sd))
    y_starts = list(range(0, max(1, H - Ph + 1), Sh))
    x_starts = list(range(0, max(1, W - Pw + 1), Sw))
    if z_starts[-1] != max(0, D - Pd): z_starts.append(max(0, D - Pd))
    if y_starts[-1] != max(0, H - Ph): y_starts.append(max(0, H - Ph))
    if x_starts[-1] != max(0, W - Pw): x_starts.append(max(0, W - Pw))
    for z in z_starts:
        for y in y_starts:
            for x in x_starts:
                yield z, y, x

def _gaussian_weight(patch):
    zz = torch.linspace(-1, 1, steps=patch[0], dtype=torch.float32)[:, None, None]
    yy = torch.linspace(-1, 1, steps=patch[1], dtype=torch.float32)[None, :, None]
    xx = torch.linspace(-1, 1, steps=patch[2], dtype=torch.float32)[None, None, :]
    g = torch.exp(-0.5 * (zz**2 + yy**2 + xx**2))
    g /= g.max()
    return g

def _get_divisors_from_plans(model_dir_eff, fallback=(32, 32, 32)):
    try:
        from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
        pm = PlansManager(os.path.join(model_dir_eff, "plans.json"))
        stg = pm.get_stage_from_scale_factor(1.0)
        props = pm.get_properties_of_stage(stg)
        if "num_pool_per_axis" in props:
            npp = props["num_pool_per_axis"]
            return tuple(int(2 ** int(v)) for v in npp)
        if "pool_op_kernel_sizes" in props:
            ks = props["pool_op_kernel_sizes"]
            acc = np.array([1, 1, 1], dtype=int)
            for s in ks:
                acc *= np.array(s, dtype=int)
            return tuple(int(v) for v in acc)
    except Exception:
        pass
    return fallback

def _pad_to_multiples_torch(vol_5d, multiples):  # vol_5d: [B,C,D,H,W] float32 on device
    _, _, D, H, W = vol_5d.shape
    md, mh, mw = multiples
    pad_d = (md - (D % md)) % md
    pad_h = (mh - (H % mh)) % mh
    pad_w = (mw - (W % mw)) % mw
    # F.pad order: (wL, wR, hL, hR, dL, dR)
    pad = (0, pad_w, 0, pad_h, 0, pad_d)
    if any(pad):
        vol_5d = F.pad(vol_5d, pad, mode="constant", value=0.0)
    return vol_5d

# ---- patch/stride from plans or fallback
try:
    from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
    pl = PlansManager(os.path.join(MODEL_DIR_EFF, "plans.json"))
    stg = pl.get_stage_from_scale_factor(1.0)
    PATCH = tuple(pl.get_properties_of_stage(stg)['patch_size'])
except Exception:
    PATCH = (80, 160, 160)
STRIDE = tuple(max(1, int(p * 0.5)) for p in PATCH)

gauss_w = _gaussian_weight(PATCH).to(device)

# ---- main loop
cls_rows = []

for img_path in tqdm(test_files, desc="Predicting"):
    # load & normalize
    img_nii = nib.load(img_path)
    img = img_nii.get_fdata().astype(np.float32)         # [D,H,W]
    img_n = _clip_zscore(img)

    # pad for sliding-window
    img_pad, undo_pad = _pad_to_min_shape(img_n, PATCH)
    D, H, W = img_pad.shape

    # probe for num classes
    with torch.no_grad():
        probe_np = img_pad[:min(PATCH[0], D), :min(PATCH[1], H), :min(PATCH[2], W)]
        probe = torch.from_numpy(probe_np[None, None]).to(device=device, dtype=torch.float32)
        out_probe = net(probe)
        if isinstance(out_probe, (list, tuple)):
            out_probe = out_probe[0]
        C = out_probe.shape[1]

    # accumulators
    logits_acc = torch.zeros((C, D, H, W), device=device, dtype=torch.float32)
    weight_acc = torch.zeros((1, D, H, W), device=device, dtype=torch.float32)

    # sliding-window on GPU with Gaussian blending
    for z, y, x in _gen_tiles((D, H, W), PATCH, STRIDE):
        patch_np = img_pad[z:z+PATCH[0], y:y+PATCH[1], x:x+PATCH[2]]
        patch_t = torch.from_numpy(patch_np[None, None]).to(device=device, dtype=torch.float32)
        with torch.no_grad():
            out = net(patch_t)
            seg_logits = out[0] if isinstance(out, (list, tuple)) else out  # [1,C,Pd,Ph,Pw]
            seg_logits = seg_logits.to(dtype=torch.float32)
        w = gauss_w
        logits_acc[:, z:z+PATCH[0], y:y+PATCH[1], x:x+PATCH[2]] += (seg_logits[0] * w)
        weight_acc[:, z:z+PATCH[0], y:y+PATCH[1], x:x+PATCH[2]] += w

    # normalize, crop back, argmax on GPU
    logits_acc = logits_acc / weight_acc.clamp_min(1e-6)
    zsl, ysl, xsl = undo_pad
    logits_acc = logits_acc[:, zsl, ysl, xsl]
    seg_pred = torch.argmax(logits_acc, dim=0).to(torch.uint8)
    seg_np = seg_pred.detach().cpu().numpy()

    # save segmentation
    out_name = _case_name_for_submission(img_path)
    nib.save(nib.Nifti1Image(seg_np.astype(np.uint8), img_nii.affine, img_nii.header),
             os.path.join(SEG_DIR, out_name))

    # ---- classification: full-volume pass with safe padding (global; no need to unpad)
    divisors = _get_divisors_from_plans(MODEL_DIR_EFF, fallback=(32, 32, 32))
    vol_np = _clip_zscore(img).astype(np.float32)
    vol_t = torch.from_numpy(vol_np[None, None]).to(device=device, dtype=torch.float32)  # [1,1,D,H,W]
    vol_t = _pad_to_multiples_torch(vol_t, divisors)

    with torch.no_grad():
        seg_out_full, cls_logits = net(vol_t, return_both=True)
        subtype = int(torch.argmax(cls_logits, dim=1).item())

    cls_rows.append({"Names": out_name, "Subtype": subtype})

# write classification CSV
CSV_PATH = os.path.join(OUT_DIR, "subtype_results.csv")
pd.DataFrame(cls_rows).to_csv(CSV_PATH, index=False)

print("Saved segmentations to:", SEG_DIR)
print("Saved classification to:", CSV_PATH)


Predicting: 100%|██████████| 72/72 [06:32<00:00,  5.45s/it]

Saved segmentations to: /content/submission_outputs/segmentations
Saved classification to: /content/submission_outputs/subtype_results.csv





## Visualize segmentation prediction

In [None]:
# DIAG + VIS: verify files, then show & save overlays for up to 2 cases
import os, glob
from pathlib import Path
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

SEG_DIR = os.path.join(OUT_DIR, "segmentations")
seg_files = sorted(glob.glob(os.path.join(SEG_DIR, "*.nii.gz")))
ct_files  = sorted(glob.glob(os.path.join(IMAGES_TS, "*.nii.gz")))

print(f"Found {len(seg_files)} segmentations in {SEG_DIR}")
print(f"Found {len(ct_files)} CTs in {IMAGES_TS}")

# Helper: map seg name -> expected CT path
def ct_for_seg(seg_name: str):
    base = seg_name.replace(".nii.gz", "")
    cand1 = os.path.join(IMAGES_TS, f"{base}_0000.nii.gz")
    cand2 = os.path.join(IMAGES_TS, f"{base}.nii.gz")
    if os.path.exists(cand1): return cand1
    if os.path.exists(cand2): return cand2
    return None

def pick_slice_z(seg_3d):
    # prioritize lesion (2), then pancreas (1); otherwise middle slice
    for label in (2, 1):
        idx = np.argwhere(seg_3d == label)
        if idx.size > 0:
            # idx columns: x,y,z for (X,Y,Z) layout
            z = int(np.median(idx[:, 2]))
            return z
    return seg_3d.shape[2] // 2

def show_case(ct_path, seg_path, save_png_path=None):
    ct_nii = nib.load(ct_path); ct = ct_nii.get_fdata().astype(np.float32)   # shape (X,Y,Z)
    seg_nii = nib.load(seg_path); seg = seg_nii.get_fdata().astype(np.uint8) # shape (X,Y,Z)
    assert ct.shape[:3] == seg.shape[:3], f"Shape mismatch: CT {ct.shape} vs SEG {seg.shape}"

    z = pick_slice_z(seg)
    ct_slice  = ct[:, :, z]
    seg_slice = seg[:, :, z]

    # robust window for CT
    lo, hi = np.percentile(ct_slice, [1, 99])
    ct_disp = np.clip(ct_slice, lo, hi)
    ct_disp = (ct_disp - ct_disp.min()) / (ct_disp.max() - ct_disp.min() + 1e-6)

    # overlay: R=pancreas(1), G=lesion(2)
    overlay = np.zeros((ct_disp.shape[0], ct_disp.shape[1], 3), dtype=np.float32)
    overlay[..., 0] = (seg_slice == 1).astype(np.float32) * 0.8
    overlay[..., 1] = (seg_slice == 2).astype(np.float32) * 0.9

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(ct_disp, cmap='gray'); axes[0].set_title('CT (axial)'); axes[0].axis('off')
    axes[1].imshow(seg_slice, interpolation='nearest'); axes[1].set_title('Seg labels (0/1/2)'); axes[1].axis('off')
    axes[2].imshow(ct_disp, cmap='gray'); axes[2].imshow(overlay, alpha=0.35); axes[2].set_title('Overlay'); axes[2].axis('off')
    fig.suptitle(f"{Path(seg_path).name} @ z={z}", y=1.02)
    plt.tight_layout()
    plt.show()

    if save_png_path:
        fig.savefig(save_png_path, dpi=120, bbox_inches='tight')
        print("Saved preview →", save_png_path)

# Pick which segmentations to show (first two by default)
to_show = [Path(p).name for p in seg_files[:2]]
if not to_show:
    raise RuntimeError(f"No segmentation files found in {SEG_DIR}. Did Step 3 finish successfully?")

PREVIEW_DIR = os.path.join(OUT_DIR, "previews")
os.makedirs(PREVIEW_DIR, exist_ok=True)

for name in to_show:
    ct_path = ct_for_seg(name)
    if ct_path is None:
        print(f"⚠️ Could not find CT for {name} in imagesTs (tried *_0000.nii.gz and same name). Skipping.")
        continue
    seg_path = os.path.join(SEG_DIR, name)
    png_out  = os.path.join(PREVIEW_DIR, f"{Path(name).stem}.png")
    print(f"Visualizing: seg={seg_path}  ct={ct_path}")
    show_case(ct_path, seg_path, save_png_path=png_out)

print("Previews in:", PREVIEW_DIR)


Found 72 segmentations in /content/submission_outputs/segmentations
Found 72 CTs in /content/nnUNet_raw/Dataset500_PancreasCancer/imagesTs
Visualizing: seg=/content/submission_outputs/segmentations/quiz_037.nii.gz  ct=/content/nnUNet_raw/Dataset500_PancreasCancer/imagesTs/quiz_037_0000.nii.gz
Saved preview → /content/submission_outputs/previews/quiz_037.nii.png
Visualizing: seg=/content/submission_outputs/segmentations/quiz_045.nii.gz  ct=/content/nnUNet_raw/Dataset500_PancreasCancer/imagesTs/quiz_045_0000.nii.gz
Saved preview → /content/submission_outputs/previews/quiz_045.nii.png
Previews in: /content/submission_outputs/previews


## Evaluate segmentation with sensible metrics (Dice + boundary)

In [None]:
# pip once (safe to re-run)
!pip -q install nibabel surface-distance pandas numpy scipy

In [None]:
# ==== EVAL for folder layout: data/validation/subtype{0,1,2} ====
# Edit this path to where your tree lives:
VALID_ROOT = "/content/drive/MyDrive/ML-Quiz-3DMedImg/data/validation"   # e.g. /content/data/validation/subtype0/...

import os, glob, math
from pathlib import Path
import numpy as np
import nibabel as nib
import torch
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm

assert os.path.isdir(VALID_ROOT), f"Not found: {VALID_ROOT}"

# ----- collect (case_id, img_path, gt_path, subtype_gt) from validation/subtype{0,1,2} -----
pairs = []
for subname, sublab in (("subtype0", 0), ("subtype1", 1), ("subtype2", 2)):
    subdir = os.path.join(VALID_ROOT, subname)
    if not os.path.isdir(subdir):
        continue
    # images are *_0000.nii.gz; masks are the same stem without _0000
    for ip in sorted(glob.glob(os.path.join(subdir, "*_0000.nii.gz"))):
        stem = Path(ip).name.replace("_0000.nii.gz", "")
        gp = os.path.join(subdir, f"{stem}.nii.gz")
        if os.path.isfile(gp):
            pairs.append((stem, ip, gp, sublab))
        else:
            print(f"⚠️ GT missing for {ip}, expected {gp}")

assert len(pairs) > 0, f"No (image,label) pairs found under {VALID_ROOT}/subtype{{0,1,2}}"
print(f"[Eval] Found {len(pairs)} cases under {VALID_ROOT}")

# ----- setup network -----
device = predictor.device if hasattr(predictor, "device") else torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = predictor.network.to(device)
net.eval()

# patch/stride from plans or safe fallback
try:
    from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
    pl = PlansManager(os.path.join(MODEL_DIR_EFF, "plans.json"))
    stg = pl.get_stage_from_scale_factor(1.0)
    PATCH = tuple(pl.get_properties_of_stage(stg)['patch_size'])
except Exception:
    PATCH = (80, 160, 160)
STRIDE = tuple(max(1, int(p * 0.5)) for p in PATCH)

# ----- helpers (same as we used for inference) -----
def _clip_zscore(x, p_low=0.5, p_high=99.5, eps=1e-6):
    lo, hi = np.percentile(x, [p_low, p_high])
    x = np.clip(x, lo, hi)
    m, s = x.mean(), x.std()
    return ((x - m) / (s + eps)).astype(np.float32)

def _pad_to_min_shape(arr, min_shape):
    pads=[];
    for a, m in zip(arr.shape, min_shape):
        add = max(0, m - a); pl = add // 2; pr = add - pl
        pads.append((pl, pr))
    arr2 = np.pad(arr, pads, mode='constant', constant_values=0)
    slc = tuple(slice(p[0], p[0] + arr.shape[i]) for i,p in enumerate(pads))
    return arr2, slc

def _gen_tiles(shape, patch, stride):
    D,H,W = shape; Pd,Ph,Pw = patch; Sd,Sh,Sw = stride
    z_starts = list(range(0, max(1, D - Pd + 1), Sd))
    y_starts = list(range(0, max(1, H - Ph + 1), Sh))
    x_starts = list(range(0, max(1, W - Pw + 1), Sw))
    if z_starts[-1] != max(0, D-Pd): z_starts.append(max(0, D-Pd))
    if y_starts[-1] != max(0, H-Ph): y_starts.append(max(0, H-Ph))
    if x_starts[-1] != max(0, W-Pw): x_starts.append(max(0, W-Pw))
    for z in z_starts:
        for y in y_starts:
            for x in x_starts:
                yield z,y,x

def _gaussian_weight(patch):
    zz = torch.linspace(-1,1,steps=patch[0], dtype=torch.float32)[:,None,None]
    yy = torch.linspace(-1,1,steps=patch[1], dtype=torch.float32)[None,:,None]
    xx = torch.linspace(-1,1,steps=patch[2], dtype=torch.float32)[None,None,:]
    g = torch.exp(-0.5*(zz**2 + yy**2 + xx**2)); g /= g.max()
    return g

gauss_w = _gaussian_weight(PATCH).to(device)

def predict_segmentation(img_3d_float32):
    img_n = _clip_zscore(img_3d_float32)
    img_pad, undo = _pad_to_min_shape(img_n, PATCH)
    D,H,W = img_pad.shape

    # probe for C
    with torch.no_grad():
        probe = torch.from_numpy(img_pad[:min(PATCH[0],D), :min(PATCH[1],H), :min(PATCH[2],W)][None,None]).to(device)
        out_probe = net(probe)
        if isinstance(out_probe, (list, tuple)): out_probe = out_probe[0]
        C = out_probe.shape[1]

    logits_acc = torch.zeros((C,D,H,W), device=device, dtype=torch.float32)
    weight_acc = torch.zeros((1,D,H,W), device=device, dtype=torch.float32)

    for z,y,x in _gen_tiles((D,H,W), PATCH, STRIDE):
        patch = torch.from_numpy(img_pad[z:z+PATCH[0], y:y+PATCH[1], x:x+PATCH[2]][None,None]).to(device)
        with torch.no_grad():
            out = net(patch)
            seg_logits = out[0] if isinstance(out, (list, tuple)) else out
            seg_logits = seg_logits.to(dtype=torch.float32)
        w = gauss_w
        logits_acc[:, z:z+PATCH[0], y:y+PATCH[1], x:x+PATCH[2]] += seg_logits[0]*w
        weight_acc[:, z:z+PATCH[0], y:y+PATCH[1], x:x+PATCH[2]] += w

    logits_acc = logits_acc / weight_acc.clamp_min(1e-6)
    zsl, ysl, xsl = undo
    logits_acc = logits_acc[:, zsl, ysl, xsl]
    pred = torch.argmax(logits_acc, dim=0).to(torch.uint8).detach().cpu().numpy()
    return pred

def _get_divisors_from_plans(model_dir_eff, fallback=(32,32,32)):
    try:
        from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
        pm = PlansManager(os.path.join(model_dir_eff, "plans.json"))
        stg = pm.get_stage_from_scale_factor(1.0)
        props = pm.get_properties_of_stage(stg)
        if "num_pool_per_axis" in props:
            npp = props["num_pool_per_axis"]
            return tuple(int(2**int(v)) for v in npp)
        if "pool_op_kernel_sizes" in props:
            ks = props["pool_op_kernel_sizes"]
            acc = np.array([1,1,1], dtype=int)
            for s in ks: acc *= np.array(s, dtype=int)
            return tuple(int(v) for v in acc)
    except Exception:
        pass
    return fallback

def _pad_to_multiples_torch(vol_5d, multiples):
    _,_,D,H,W = vol_5d.shape
    md,mh,mw = multiples
    pd = (md - (D % md)) % md
    ph = (mh - (H % mh)) % mh
    pw = (mw - (W % mw)) % mw
    pad = (0,pw, 0,ph, 0,pd)  # (wL,wR,hL,hR,dL,dR)
    return F.pad(vol_5d, pad, mode="constant", value=0.0) if any(pad) else vol_5d

divisors = _get_divisors_from_plans(MODEL_DIR_EFF, fallback=(32,32,32))

def predict_subtype(img_3d_float32):
    vol = _clip_zscore(img_3d_float32).astype(np.float32)
    vol_t = torch.from_numpy(vol[None,None]).to(device=device, dtype=torch.float32)
    vol_t = _pad_to_multiples_torch(vol_t, divisors)
    with torch.no_grad():
        seg_out, cls_logits = net(vol_t, return_both=True)
        subtype = int(torch.argmax(cls_logits, dim=1).item())
    return subtype

# ----- metrics -----
def dice_bin(pred, gt):
    i = (pred & gt).sum()
    u = pred.sum() + gt.sum()
    return (2.0 * i) / (u + 1e-8)

def compute_case_metrics(pred_lbl, gt_lbl):
    pred_pan = (pred_lbl > 0).astype(np.uint8)
    gt_pan   = (gt_lbl > 0).astype(np.uint8)
    dice_pan = dice_bin(pred_pan, gt_pan)
    pred_les = (pred_lbl == 2).astype(np.uint8)
    gt_les   = (gt_lbl == 2).astype(np.uint8)
    dice_les = dice_bin(pred_les, gt_les)
    return float(dice_pan), float(dice_les)

def macro_f1_from_cm(cm):
    K = cm.shape[0]
    f1s=[]
    for k in range(K):
        tp = cm[k,k]
        fp = cm[:,k].sum() - tp
        fn = cm[k,:].sum() - tp
        prec = tp / (tp + fp + 1e-8)
        rec  = tp / (tp + fn + 1e-8)
        f1s.append(2*prec*rec / (prec+rec+1e-8))
    return float(np.mean(f1s))

# ----- run evaluation -----
rows=[]; y_true=[]; y_pred=[]
for cid, ipath, gpath, subtype_gt in tqdm(pairs, desc="Evaluating"):
    img = nib.load(ipath).get_fdata().astype(np.float32)
    gt  = nib.load(gpath).get_fdata().astype(np.uint8)

    pred = predict_segmentation(img)
    dice_pan, dice_les = compute_case_metrics(pred, gt)

    subtype_pred = predict_subtype(img)

    rows.append({
        "case": cid,
        "dice_pancreas": dice_pan,
        "dice_lesion": dice_les,
        "subtype_pred": subtype_pred,
        "subtype_gt": subtype_gt,
    })
    y_true.append(int(subtype_gt))
    y_pred.append(int(subtype_pred))

# ----- write results -----
EVAL_DIR = os.path.join(OUT_DIR, "eval_val")
os.makedirs(EVAL_DIR, exist_ok=True)

per_case_csv = os.path.join(EVAL_DIR, "validation_per_case.csv")
pd.DataFrame(rows).to_csv(per_case_csv, index=False)

dice_pan_mean = float(np.mean([r["dice_pancreas"] for r in rows]))
dice_les_mean = float(np.mean([r["dice_lesion"] for r in rows]))

summary = {
    "dice_pancreas_mean": dice_pan_mean,
    "dice_lesion_mean": dice_les_mean,
}

from sklearn.metrics import accuracy_score, confusion_matrix
acc = float(accuracy_score(y_true, y_pred))
cm  = confusion_matrix(y_true, y_pred, labels=[0,1,2])
f1m = macro_f1_from_cm(cm.astype(np.float64))
summary.update({
    "cls_accuracy": acc,
    "cls_macro_f1": f1m,
})
pd.DataFrame([summary]).to_csv(os.path.join(EVAL_DIR, "validation_summary.csv"), index=False)

cm_df = pd.DataFrame(cm, index=[f"gt_{k}" for k in [0,1,2]],
                        columns=[f"pred_{k}" for k in [0,1,2]])
cm_df.to_csv(os.path.join(EVAL_DIR, "cls_confusion_matrix.csv"))

print(f"\n[Seg] Mean Dice — Pancreas: {dice_pan_mean:.4f} | Lesion: {dice_les_mean:.4f}")
print(f"[Cls] Acc: {acc:.4f} | Macro-F1: {f1m:.4f}")
print("[Eval] Per-case →", per_case_csv)
print("[Eval] Summary  →", os.path.join(EVAL_DIR, "validation_summary.csv"))
print("[Eval] ConfMat  →", os.path.join(EVAL_DIR, "cls_confusion_matrix.csv"))


[Eval] Found 36 cases under /content/drive/MyDrive/ML-Quiz-3DMedImg/data/validation


Evaluating: 100%|██████████| 36/36 [02:15<00:00,  3.76s/it]


[Seg] Mean Dice — Pancreas: 0.4587 | Lesion: 0.1371
[Cls] Acc: 0.3333 | Macro-F1: 0.1667
[Eval] Per-case → /content/submission_outputs/eval_val/validation_per_case.csv
[Eval] Summary  → /content/submission_outputs/eval_val/validation_summary.csv
[Eval] ConfMat  → /content/submission_outputs/eval_val/cls_confusion_matrix.csv





## Fine-tune

In [14]:
%%writefile /content/nnUNet/nnunetv2/training/nnUNetTrainer/MultiTaskTrainerFT.py
"""
Fine-tuning trainer for the multi-task nnUNetv2 model.

Key changes vs. MultiTaskTrainer:
  • enables gradient flow from the classification head into the encoder
  • increases classification loss weight and ramps it faster
  • (optional) freezes the segmentation decoder to protect seg quality
  • short, low-LR run suitable for quick fine-tuning
  • loads a pretrained checkpoint from env var: NNUNET_PRETRAINED
"""

import os
import torch
from importlib import import_module

from .MultiTaskTrainer import MultiTaskTrainer, MultiTaskResEnc


def _maybe_get_state_dict(ckpt):
    """Accept common nnU-Net/torch checkpoint formats and return a state_dict."""
    if isinstance(ckpt, dict):
        for k in ("network_state_dict", "state_dict", "network_weights", "model"):
            if k in ckpt and isinstance(ckpt[k], dict):
                return ckpt[k]
    return ckpt  # assume it's already a state_dict


class MultiTaskTrainerFT(MultiTaskTrainer):
    """
    Fine-tune variant for classification improvement with minimal seg drift.
    """
    def __init__(self, plans, configuration, fold, dataset_json, device=torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, device)

        # ↑ Stronger/faster emphasis on classification during FT
        self.cls_weight_start = 0.20
        self.cls_weight_end = 0.80
        self.cls_weight_ramp_epochs = 10

        # Short run (cap epochs if those attrs exist)
        for a in ("num_epochs", "max_num_epochs"):
            if hasattr(self, a):
                setattr(self, a, min(int(getattr(self, a)), 15))

        # Safer base LR
        if hasattr(self, "initial_lr"):
            self.initial_lr = min(float(self.initial_lr), 5e-3)

        # Path to pretrained weights (set via env var before launching training)
        self.pretrained_path = os.environ.get("NNUNET_PRETRAINED", "").strip()

    @staticmethod
    def build_network_architecture(architecture_class_name: str,
                                   arch_init_kwargs: dict,
                                   arch_init_kwargs_req_import,
                                   num_input_channels: int,
                                   num_output_channels: int,
                                   enable_deep_supervision: bool = True):
        """
        Same as parent, but force cls_stopgrad_through_encoder=False so the
        classifier can adapt encoder features during fine-tuning.
        """
        kwargs = dict(arch_init_kwargs)
        if arch_init_kwargs_req_import:
            for k in arch_init_kwargs_req_import:
                v = kwargs.get(k)
                if isinstance(v, str):
                    mod, attr = v.rsplit('.', 1)
                    kwargs[k] = getattr(import_module(mod), attr)

        return MultiTaskResEnc(
            input_channels=num_input_channels,
            n_stages=kwargs['n_stages'],
            features_per_stage=kwargs['features_per_stage'],
            conv_op=kwargs['conv_op'],
            kernel_sizes=kwargs['kernel_sizes'],
            strides=kwargs['strides'],
            n_blocks_per_stage=kwargs['n_blocks_per_stage'],
            num_segmentation_classes=num_output_channels,
            num_classification_classes=3,
            n_conv_per_stage_decoder=kwargs['n_conv_per_stage_decoder'],
            conv_bias=kwargs['conv_bias'],
            norm_op=kwargs['norm_op'],
            norm_op_kwargs=kwargs['norm_op_kwargs'],
            dropout_op=kwargs.get('dropout_op'),
            dropout_op_kwargs=kwargs.get('dropout_op_kwargs'),
            nonlin=kwargs['nonlin'],
            nonlin_kwargs=kwargs['nonlin_kwargs'],
            deep_supervision=enable_deep_supervision,
            cls_stopgrad_through_encoder=False,   # ← key FT change
        )

    def on_train_start(self):
        """
        Freeze seg decoder (optional), load pretrained weights, then call parent
        to set up optimizer etc. Finally, reduce LR in the optimizer param groups.
        """
        # 1) Freeze seg decoder to protect segmentation quality during FT
        if hasattr(self.network, "segmentation_net") and hasattr(self.network.segmentation_net, "decoder"):
            for p in self.network.segmentation_net.decoder.parameters():
                p.requires_grad = False

        # 2) Load pretrained weights if a path is provided
        if self.pretrained_path and os.path.isfile(self.pretrained_path):
            try:
                ckpt = torch.load(self.pretrained_path, map_location=self.device)
                sd = _maybe_get_state_dict(ckpt)
                result = self.network.load_state_dict(sd, strict=False)
                # Be robust to both tuple and LoadStateDictResult
                missing = getattr(result, "missing_keys", [])
                unexpected = getattr(result, "unexpected_keys", [])
                print(f"[FT] Loaded pretrained weights from: {self.pretrained_path}")
                if missing:
                    print("[FT]  missing keys (truncated):", missing[:6], "...")
                if unexpected:
                    print("[FT]  unexpected keys (truncated):", unexpected[:6], "...")
            except Exception as e:
                print("[FT] WARNING: failed to load pretrained weights:", e)

        # 3) Continue with standard nnU-Net setup
        super().on_train_start()

        # 4) Lower LR for stability
        if getattr(self, "optimizer", None) is not None:
            for g in self.optimizer.param_groups:
                g['lr'] = min(g.get('lr', getattr(self, "initial_lr", 5e-3)), 5e-3)

        trainable = sum(p.numel() for p in self.network.parameters() if p.requires_grad)
        print(f"[FT] Trainable params: {trainable/1e6:.2f}M (decoder frozen)")


Writing /content/nnUNet/nnunetv2/training/nnUNetTrainer/MultiTaskTrainerFT.py


In [11]:
# Fill these according to your run
DATASET_ID = 500
FOLD = 0

MODEL_DIR = f"/content/nnUNet_results/Dataset{DATASET_ID}_PancreasCancer/MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres/fold_{FOLD}"
PRETRAIN = os.path.join(MODEL_DIR, "checkpoint_best.pth")  # or checkpoint_final.pth

# New output root (keeps FT separate)
OUT_ROOT_FT = f"/content/nnUNet_results/Dataset{DATASET_ID}_PancreasCancer/FT_MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres"

print("Pretrain:", PRETRAIN)
print("FT outputs →", OUT_ROOT_FT)

Pretrain: /content/nnUNet_results/Dataset500_PancreasCancer/MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres/fold_0/checkpoint_best.pth
FT outputs → /content/nnUNet_results/Dataset500_PancreasCancer/FT_MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres


In [15]:
# set your dataset / fold and the checkpoint you want to start from
DATASET_ID = 500
FOLD = 0
PRETRAIN = f"/content/nnUNet_results/Dataset{DATASET_ID}_PancreasCancer/MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres/fold_{FOLD}/checkpoint_best.pth"

import os
assert os.path.isfile(PRETRAIN), f"Checkpoint not found: {PRETRAIN}"
os.environ["NNUNET_PRETRAINED"] = PRETRAIN
print("NNUNET_PRETRAINED =", os.environ["NNUNET_PRETRAINED"])

# launch FT (no --pretrained_weights; the trainer loads env var internally)
!nnUNetv2_train {DATASET_ID} 3d_fullres {FOLD} -tr MultiTaskTrainerFT -p nnUNetResEncUNetMPlans --npz


NNUNET_PRETRAINED = /content/nnUNet_results/Dataset500_PancreasCancer/MultiTaskTrainer__nnUNetResEncUNetMPlans__3d_fullres/fold_0/checkpoint_best.pth
Using device: cuda:0

#######################################################################
Please cite the following paper when using nnU-Net:
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
#######################################################################

Loaded 201 classification labels
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in 

In [16]:
# === CLS SANITY CHECK on validation/subtype{0,1,2} ===
# Edit if needed:
VALID_ROOT = "/content/drive/MyDrive/ML-Quiz-3DMedImg/data/validation"

import os, glob, json
import numpy as np, nibabel as nib, pandas as pd
import torch, torch.nn.functional as F
from pathlib import Path
from sklearn.metrics import accuracy_score, confusion_matrix

# Speed tip from your warning:
torch.set_float32_matmul_precision('high')

device = predictor.device if hasattr(predictor, "device") else torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = predictor.network.to(device).eval()

# --- pull target spacing & divisors from plans.json ---
def _target_spacing_from_plans(model_dir_eff, fallback=(1.5,1.5,1.5)):
    with open(os.path.join(model_dir_eff, "plans.json")) as f:
        pl = json.load(f)
    for path in [
        ("configurations","3d_fullres","stages","0","target_spacing"),
        ("configurations","3d_fullres","stages","0","spacing"),
        ("stages","0","target_spacing"),
        ("stages","0","spacing"),
    ]:
        try:
            d = pl
            for k in path[:-1]: d = d[k]
            if path[-1] in d: return tuple(d[path[-1]])
        except Exception: pass
    return fallback

def _get_divisors_from_plans(model_dir_eff, fallback=(32,32,32)):
    try:
        from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
        pm = PlansManager(os.path.join(model_dir_eff, "plans.json"))
        stg = pm.get_stage_from_scale_factor(1.0)
        props = pm.get_properties_of_stage(stg)
        if "num_pool_per_axis" in props:
            return tuple(int(2**int(v)) for v in props["num_pool_per_axis"])
        if "pool_op_kernel_sizes" in props:
            acc = np.array([1,1,1], int)
            for s in props["pool_op_kernel_sizes"]:
                acc *= np.array(s, int)
            return tuple(int(v) for v in acc)
    except Exception:
        pass
    return fallback

def _resample_img_to_spacing(img_nii, tgt_spacing, device):
    vol = img_nii.get_fdata().astype(np.float32)
    spac = img_nii.header.get_zooms()[:3]
    scale = np.array(spac, np.float32) / np.array(tgt_spacing, np.float32)
    out_shape = np.maximum(1, np.round(np.array(vol.shape) * scale)).astype(int)
    t = torch.from_numpy(vol[None,None]).to(device, torch.float32)
    t = F.interpolate(t, size=tuple(out_shape.tolist()), mode="trilinear", align_corners=False)
    return t.squeeze().detach().cpu().numpy()

def _pad_to_multiples_torch(vol_5d, multiples):
    _,_,D,H,W = vol_5d.shape
    md,mh,mw = multiples
    pd = (md - (D % md)) % md
    ph = (mh - (H % mh)) % mh
    pw = (mw - (W % mw)) % mw
    pad = (0,pw, 0,ph, 0,pd)
    return F.pad(vol_5d, pad, mode="constant", value=0.0) if any(pad) else vol_5d

# --- log priors (same as training) ---
CSV_MAPPING = "/content/case_subtype_mapping.csv"
def _load_log_prior(device):
    if os.path.isfile(CSV_MAPPING):
        df = pd.read_csv(CSV_MAPPING)
        col = 'case_id' if 'case_id' in df.columns else ('case' if 'case' in df.columns else None)
        if col is not None and 'subtype' in df.columns:
            cnt = np.bincount(df['subtype'].astype(int).values, minlength=3)
        else:
            cnt = np.array([1,1,1], int)
    else:
        cnt = np.array([1,1,1], int)
    priors = cnt / max(1, cnt.sum())
    return torch.log(torch.tensor(np.clip(priors, 1e-8, 1.0), device=device, dtype=torch.float32))

log_prior = _load_log_prior(device)
TAU = 1.0
print("Using TAU =", TAU, "log_prior =", log_prior.detach().cpu().numpy())

# --- collect val cases ---
pairs = []
for subname, lab in (("subtype0",0), ("subtype1",1), ("subtype2",2)):
    subdir = os.path.join(VALID_ROOT, subname)
    if not os.path.isdir(subdir): continue
    for ip in sorted(glob.glob(os.path.join(subdir, "*_0000.nii.gz"))):
        pairs.append((ip, lab))
assert pairs, f"No *_0000.nii.gz under {VALID_ROOT}/subtype*/"

tgt_spacing = _target_spacing_from_plans(MODEL_DIR_EFF)
divs = _get_divisors_from_plans(MODEL_DIR_EFF)

y_true, logits_raw, logits_adj, names = [], [], [], []

with torch.no_grad():
    for ip, lab in pairs:
        img_rs = _resample_img_to_spacing(nib.load(ip), tgt_spacing, device)
        vol = ((img_rs - img_rs.mean()) / (img_rs.std() + 1e-6)).astype(np.float32)
        vt = torch.from_numpy(vol[None,None]).to(device, torch.float32)
        vt = _pad_to_multiples_torch(vt, divs)
        _, cls_logits = net(vt, return_both=True)
        logits_raw.append(cls_logits.squeeze(0).detach().cpu().numpy())
        la = (cls_logits - TAU * log_prior[None,:]).squeeze(0).detach().cpu().numpy()
        logits_adj.append(la)
        y_true.append(lab)
        names.append(Path(ip).name.replace("_0000.nii.gz",""))

y_true = np.array(y_true, int)
lr = np.vstack(logits_raw)
la = np.vstack(logits_adj)
pred_raw = lr.argmax(1)
pred_adj = la.argmax(1)

def macro_f1(cm):
    K = cm.shape[0]; f1 = []
    for k in range(K):
        tp = cm[k,k]; fp = cm[:,k].sum()-tp; fn = cm[k,:].sum()-tp
        p = tp/(tp+fp+1e-8); r = tp/(tp+fn+1e-8)
        f1.append(2*p*r/(p+r+1e-8))
    return float(np.mean(f1))

from sklearn.metrics import accuracy_score
cm_raw = confusion_matrix(y_true, pred_raw, labels=[0,1,2])
cm_adj = confusion_matrix(y_true, pred_adj, labels=[0,1,2])

print("\nPred counts (raw):", np.bincount(pred_raw, minlength=3))
print("Pred counts (adj):", np.bincount(pred_adj, minlength=3))
print("Acc raw/adj: ", accuracy_score(y_true, pred_raw), accuracy_score(y_true, pred_adj))
print("Macro-F1 raw/adj:", macro_f1(cm_raw), macro_f1(cm_adj))
print("Mean logits raw:", lr.mean(0), "\nMean logits adj:", la.mean(0))

# Save per-case debug CSV
dbg = pd.DataFrame({
    "case": names,
    "y_true": y_true,
    "pred_raw": pred_raw,
    "pred_adj": pred_adj,
    "logit0_raw": lr[:,0], "logit1_raw": lr[:,1], "logit2_raw": lr[:,2],
    "logit0_adj": la[:,0], "logit1_adj": la[:,1], "logit2_adj": la[:,2],
})
dbg_path = os.path.join(OUT_DIR, "eval_val", "cls_debug_logits.csv")
os.makedirs(os.path.dirname(dbg_path), exist_ok=True)
dbg.to_csv(dbg_path, index=False)
print("Wrote:", dbg_path)

Using TAU = 1.0 log_prior = [-1.4114846 -0.8724881 -1.0837972]

Pred counts (raw): [ 0  0 36]
Pred counts (adj): [ 0  0 36]
Acc raw/adj:  0.3333333333333333 0.3333333333333333
Macro-F1 raw/adj: 0.16666666534722221 0.16666666534722221
Mean logits raw: [-0.43842143  0.1370004   0.23765898] 
Mean logits adj: [0.97306305 1.0094886  1.3214566 ]
Wrote: /content/submission_outputs/eval_val/cls_debug_logits.csv


In [17]:
# === τ sweep for macro-F1 (keeps logits computed above) ===
from sklearn.metrics import accuracy_score, confusion_matrix
taus = [0.0, 0.25, 0.5, 1.0, 1.5, 2.0]
best = None

lp = log_prior.detach().cpu().numpy()
def mf1_for_tau(t):
    preds = np.argmax(lr - t*lp[None,:], axis=1)
    cm = confusion_matrix(y_true, preds, labels=[0,1,2])
    return accuracy_score(y_true, preds), macro_f1(cm), preds, cm

for t in taus:
    acc, f1, preds, cm = mf1_for_tau(t)
    print(f"tau={t:.2f} acc={acc:.3f} macro-F1={f1:.3f}  counts={np.bincount(preds, minlength=3)}")
    if best is None or f1 > best[1]:
        best = (t, f1, acc, preds, cm)

print("\nBest τ:", best[0], " macro-F1=", best[1], " acc=", best[2])

tau=0.00 acc=0.333 macro-F1=0.167  counts=[ 0  0 36]
tau=0.25 acc=0.333 macro-F1=0.167  counts=[ 0  0 36]
tau=0.50 acc=0.333 macro-F1=0.167  counts=[ 0  0 36]
tau=1.00 acc=0.333 macro-F1=0.167  counts=[ 0  0 36]
tau=1.50 acc=0.333 macro-F1=0.167  counts=[ 0  0 36]
tau=2.00 acc=0.333 macro-F1=0.167  counts=[ 0  0 36]

Best τ: 0.0  macro-F1= 0.16666666534722221  acc= 0.3333333333333333
