## Introduction
In this notebook we train a 3D U-Net to segment primary lung tumors in CT scans.

## Imports

* pathlib for easy path handling
* HTML for visualizing volume videos
* torchio for dataset creation
* torch for DataLoaders, optimizer and loss
* pytorch-lightning for training
* numpy for masking
* matplotlib for visualization
* Our 3D model

In [None]:
import os
from pathlib import Path
from typing import List, Optional, Tuple

import torchio as tio
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
import numpy as np

from model import UNet

## Dataset Creation
For each split we loop over all available CT scans, pair each scan with its label, and add them to the subject list.

In [None]:
from pathlib import Path
import torchio as tio

# --- Base split root ---
# Folder that contains the train/, val/ and test/ split directories.
# Defaults to the original author's location; set the LANSCLC_SPLIT_ROOT
# environment variable to point at another copy without editing the notebook.
BASE = Path(
    os.environ.get(
        "LANSCLC_SPLIT_ROOT",
        r"E:\DoNotTouch\projects\LANSCLC\CIS_5810\selected_150_split",
    )
)

def change_img_to_label_path(p: Path) -> Path:
    """
    Map a CT image path to its GTVp label path.

      .../image/Lung_xxx_0000.nii.gz  -->  .../label_gtvp/Lung_xxx.nii.gz

    The label lives in the sibling folder ``label_gtvp`` of the ``image``
    folder. Only a *trailing* ``_0000`` channel suffix is removed from the
    file stem, and the result always ends in ``.nii.gz`` (an uncompressed
    ``Lung_xxx_0000.nii`` maps to ``Lung_xxx.nii.gz``, not ``.nii.nii.gz``).

    Raises:
        ValueError: if ``p`` is not located in a folder named ``image``.
    """
    if p.parent.name != "image":
        raise ValueError(f"Expected parent folder named 'image', got: {p.parent}")
    label_dir = p.parent.with_name("label_gtvp")

    name = p.name
    # Split off the NIfTI extension (if any) so only the stem is edited.
    stem = name
    for ext in (".nii.gz", ".nii"):
        if name.endswith(ext):
            stem = name[: -len(ext)]
            break

    # Remove only the trailing channel suffix; a blanket replace() would also
    # corrupt case IDs that happen to contain "_0000".
    channel_suffix = "_0000"
    if stem.endswith(channel_suffix):
        stem = stem[: -len(channel_suffix)]

    return label_dir / (stem + ".nii.gz")

In [None]:
def build_subjects(split_dir: Path, require_label: bool = True):
    """
    Build TorchIO subjects for one split directory (train/val/test).

    Every ``image/Lung_*_0000.nii.gz`` found recursively under ``split_dir``
    (in sorted order) becomes a subject with a ``CT`` image, plus a ``Label``
    map when the matching ``label_gtvp`` file exists. Cases without a label
    are skipped with a warning when ``require_label`` is True, or kept as
    CT-only subjects when it is False. A summary line is printed at the end.
    """
    subjects = []
    missing = 0

    for ct_file in sorted(split_dir.rglob(r"image/Lung_*_0000.nii.gz")):
        label_file = change_img_to_label_path(ct_file)
        has_label = label_file.exists()

        if not has_label:
            missing += 1
            if require_label:
                print(f"[WARN] Missing label for: {ct_file.name} -> {label_file.name}")
                continue

        # CT is always present; the label is attached only when it exists.
        images = {"CT": tio.ScalarImage(ct_file)}
        if has_label:
            images["Label"] = tio.LabelMap(label_file)
        subjects.append(tio.Subject(**images))

    print(f"[{split_dir.name}] Built {len(subjects)} subjects (missing labels: {missing})")
    return subjects

# --- Build each split ---
train_subjects = build_subjects(BASE / "train", require_label=True)
val_subjects   = build_subjects(BASE / "val",   require_label=True)
# Test labels are required because the test split is scored with Dice at the end
# of the notebook; use require_label=False to also keep unlabeled (CT-only) scans.
test_subjects  = build_subjects(BASE / "test",  require_label=True)

In [None]:
# Sanity check: every CT volume in every split must be in LPS orientation.
# (The original loop referenced an undefined name `subjects`.)
for subject in train_subjects + val_subjects + test_subjects:
    orientation = subject["CT"].orientation
    if orientation != ("L", "P", "S"):
        raise ValueError(
            f"Unexpected CT orientation {orientation} for {subject['CT'].path}; "
            "expected ('L', 'P', 'S')"
        )

We use the same augmentation steps as in the Dataset notebook. <br />
For preprocessing, we use **CropOrPad**, which crops or pads all images and masks to the same shape, followed by rescaling the intensities to $[-1, 1]$. <br />

We use a target shape of $256 \times 256 \times 200$ voxels.

In [None]:
# Deterministic preprocessing shared by every split: fix the spatial shape,
# then map intensities to [-1, 1].
process = tio.Compose([
    tio.CropOrPad((256, 256, 200)),
    tio.RescaleIntensity((-1, 1)),
])

# Random spatial augmentation (scaling and rotation), training split only.
augmentation = tio.RandomAffine(scales=(0.9, 1.1), degrees=(-10, 10))

train_transform = tio.Compose([process, augmentation])
val_transform = process
test_transform = process

Define the train, validation and test datasets. We use 70 subjects for training and 40 for validation. <br />
To help the segmentation network learn, we use a **LabelSampler** that centers patches on background voxels with p=0.2 and on tumor voxels with p=0.8, with a patch size of $96 \times 96 \times 96$.

In [None]:
# One SubjectsDataset per split, each paired with its own transform pipeline.
train_dataset, val_dataset, test_dataset = (
    tio.SubjectsDataset(split_subjects, transform=split_transform)
    for split_subjects, split_transform in (
        (train_subjects, train_transform),
        (val_subjects, val_transform),
        (test_subjects, test_transform),
    )
)

# Patch sampler: 96^3 patches centred on background (label 0) with p=0.2 and on
# tumour (label 1) with p=0.8.
# Unbalanced alternative: tio.data.UniformSampler(patch_size=96)
sampler = tio.data.LabelSampler(
    patch_size=96,
    label_name="Label",
    label_probabilities={0: 0.2, 1: 0.8},
)

Create the queue to draw patches from.<br />

In [None]:
# Patch queues: each queue loads subjects with `num_workers` worker processes and
# serves `samples_per_volume` random patches per subject drawn by `sampler`.
# The worker count was hard-coded to 23; cap it by the available CPUs so the
# notebook does not oversubscribe smaller machines.
QUEUE_WORKERS = min(23, os.cpu_count() or 1)

train_patches_queue = tio.Queue(
    train_dataset,
    max_length=40,
    samples_per_volume=5,
    sampler=sampler,
    num_workers=QUEUE_WORKERS,
)

val_patches_queue = tio.Queue(
    val_dataset,
    max_length=40,
    samples_per_volume=5,
    sampler=sampler,
    num_workers=QUEUE_WORKERS,
)

Define the training and validation data loaders.

In [None]:
use_amp = True  # set False if you don't use autocast

# Batch sizes depend on use_amp (presumably because mixed precision needs less
# activation memory — adjust to your GPU).
train_bs = 4 if use_amp else 2
val_bs   = min(8, 2*train_bs)

# Pinned host memory only helps when copying batches to a CUDA device; on a
# CPU-only machine it just triggers a warning.
pin_memory_ok = torch.cuda.is_available()

# num_workers=0: the tio.Queue already loads subjects in its own worker
# processes, and TorchIO requires the DataLoader around a Queue to use 0 workers.
train_loader = torch.utils.data.DataLoader(
    train_patches_queue, batch_size=train_bs, num_workers=0, pin_memory=pin_memory_ok
)
val_loader = torch.utils.data.DataLoader(
    val_patches_queue,   batch_size=val_bs,   num_workers=0, pin_memory=pin_memory_ok
)

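Finally we can create the segmentation model.

We use the Adam optimizer with a learning rate of 1e-4. The loss is an unweighted binary cross-entropy with logits when the network outputs a single channel, and an unweighted cross-entropy when it outputs several classes; tumor voxels receive no extra weight.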

In [None]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

class Segmenter(pl.LightningModule):
    """
    Lightning wrapper that trains the 3D ``UNet`` for tumour segmentation.

    Args:
        lr: Learning rate for Adam.
        tumor_weight: Loss weight of tumour voxels relative to background.
            ``1.0`` (default) gives the plain, unweighted loss. With a value
            such as ``3.0`` tumour voxels count three times as much:
              * single-channel output: used as ``pos_weight`` of BCE-with-logits;
              * multi-channel output: class weight of index 1 in cross-entropy.
    """

    def __init__(self, lr=1e-4, tumor_weight=1.0):
        super().__init__()
        self.model = UNet()
        # Records lr / tumor_weight in checkpoints so load_from_checkpoint can rebuild the module.
        self.save_hyperparameters(ignore=['model'])
        self.lr = lr
        self.tumor_weight = float(tumor_weight)

    def forward(self, x):
        return self.model(x)

    def _compute_loss_and_pred(self, logits, mask):
        """
        Compute the loss and a hard prediction.

        logits: [N,C,D,H,W] raw network output
        mask:   from TorchIO -> [N,1,D,H,W] with {0,1}
        Returns (loss, pred) with pred a float tensor of shape [N,1,D,H,W].
        """
        use_weight = self.tumor_weight != 1.0
        if logits.shape[1] == 1:
            # binary case: a single sigmoid channel
            target = mask.float()                     # [N,1,D,H,W]
            pos_weight = (
                torch.tensor([self.tumor_weight], device=logits.device)
                if use_weight else None
            )
            loss = F.binary_cross_entropy_with_logits(logits, target, pos_weight=pos_weight)
            pred = (torch.sigmoid(logits) > 0.5).float()   # [N,1,D,H,W]
        else:
            # multi-class case (C>1)
            target = mask.squeeze(1).long()          # [N,D,H,W]
            weight = None
            if use_weight:
                weight = torch.ones(logits.shape[1], device=logits.device)
                weight[1] = self.tumor_weight
            loss = F.cross_entropy(logits, target, weight=weight)
            pred = torch.argmax(logits, dim=1, keepdim=True).float()  # [N,1,D,H,W]
        return loss, pred

    def _shared_step(self, batch):
        """Run the network on a TorchIO patch batch; return (loss, batch_size)."""
        img  = batch["CT"]["data"].float()          # [N,1,D,H,W]
        mask = batch["Label"]["data"]               # [N,1,D,H,W]
        logits = self(img)
        loss, _ = self._compute_loss_and_pred(logits, mask)
        return loss, img.shape[0]

    def training_step(self, batch, batch_idx):
        loss, n = self._shared_step(batch)
        # batch_size is given explicitly: Lightning cannot reliably infer it from
        # TorchIO's nested-dict batches, which affects epoch-level averaging.
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=n)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, n = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=n)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Instantiate the Lightning module (1e-4 is also Segmenter's default learning rate).
model = Segmenter(lr=1e-4)

In [None]:
# Checkpointing: rank checkpoints by validation loss (lower is better) and keep
# up to 100 of them. With max_epochs=100 this likely retains nearly every
# checkpoint; lower save_top_k to save disk space.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=100,
)

In [None]:
# Create the trainer

logger = TensorBoardLogger(save_dir="./logs")

use_gpu = torch.cuda.is_available()
# Mixed precision only when it was requested (use_amp also sized the batches)
# and a GPU is present.
use_mixed_precision = use_gpu and use_amp
trainer = pl.Trainer(
    accelerator="gpu" if use_gpu else "cpu",
    # One GPU, or one CPU process. devices=None is rejected for the CPU
    # accelerator by recent Lightning versions.
    devices=1,
    logger=logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
    max_epochs=100,
    precision=16 if use_mixed_precision else 32,  # AMP on GPU
)

In [None]:
# Train the model; depending on the GPU this can take several hours.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

## Evaluation

In [None]:
from IPython.display import HTML
from celluloid import Camera

First we load the best checkpoint and place the model on the GPU if one is available.

In [None]:
best_path  = checkpoint_callback.best_model_path
best_score = checkpoint_callback.best_model_score.item() if checkpoint_callback.best_model_score is not None else None
print("Best:", best_path, best_score)

if not best_path:
    # best_model_path stays empty until the callback has saved a checkpoint.
    raise RuntimeError(
        "No best checkpoint recorded; run trainer.fit(...) first or set best_path manually."
    )

# Load onto the CPU; the next cell moves the model to the GPU when one is available.
model = Segmenter.load_from_checkpoint(best_path, map_location="cpu")

In [None]:
# Switch to inference mode and place the model on the first GPU when available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.eval().to(device)

### Patch Aggregation
The model was trained patch-wise because full volumes are too large to fit on a typical GPU.
However, we still want a prediction for the whole volume.<br />
We therefore perform *patch aggregation* with TorchIO.

The goal of patch aggregation is to split the image into patches, then compute the segmentation for each patch and finally merge the predictions into the prediction for the full volume.

The pipeline we performed is as follows:
1. Define the **GridSampler(subject, patch_size, patch_overlap)** responsible for dividing the volume into patches. Each patch is defined by its location, accessible via *tio.LOCATION*
2. Define the **GridAggregator(grid_sampler)** which merges the predicted patches back together
3. Compute the prediction on the patches and aggregate them via **aggregator.add_batch(pred, location)**
4. Extract the full prediction via **aggregator.get_output_tensor()**

Additionally, we use a PyTorch DataLoader to run the prediction batch-wise for a nice speed-up.

In [None]:
# Select a validation subject and extract the images and segmentation for evaluation.
# Index the dataset only once: every val_dataset[IDX] call reloads the subject and
# re-runs the preprocessing transform.
IDX = 12
eval_subject = val_dataset[IDX]
mask = eval_subject["Label"]["data"]
imgs = eval_subject["CT"]["data"]

# GridSampler: 96^3 patches, overlap (8, 8, 8) between neighbouring patches
grid_sampler = tio.inference.GridSampler(eval_subject, 96, (8, 8, 8))

In [None]:
# GridAggregator: collects per-patch predictions and stitches them back into one volume
aggregator = tio.inference.GridAggregator(sampler=grid_sampler)

In [None]:
# Batch the grid patches (4 per forward pass) for faster inference
patch_loader = torch.utils.data.DataLoader(dataset=grid_sampler, batch_size=4)

In [None]:
# Prediction: run the model on each batch of grid patches and hand the outputs,
# together with their grid locations, to the aggregator.
with torch.no_grad():
    for batch in patch_loader:
        patch_locations = batch[tio.LOCATION]
        patch_logits = model(batch["CT"]["data"].to(device))
        aggregator.add_batch(patch_logits, patch_locations)

In [None]:
# Stitched full-volume prediction (channel-first, one channel per model output)
output_tensor = aggregator.get_output_tensor()

Finally we can visualize the prediction as usual

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display

# -------- helpers --------
def to_np(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and copied to the CPU first."""
    if not hasattr(x, "detach"):
        return np.asarray(x)
    return x.detach().cpu().numpy()

def make_depth_slicer(vol, D):
    """
    Return a function ``i -> 2D slice`` that indexes ``vol`` along its depth axis.

    ``vol`` may be a torch tensor or anything NumPy accepts. Leading axes are
    reduced to a 3D volume by taking index 0:
      * 5D (e.g. [B,C,H,W,D] or [B,C,D,H,W]) -> [0, 0]
      * 4D (e.g. TorchIO's [C,W,H,D], or [B,H,W,D] / [B,D,H,W]) -> [0]
      * 3D ([H,W,D] or [D,H,W]) -> unchanged
    The depth axis of that 3D volume is the last axis if its length is ``D``,
    otherwise the first axis if its length is ``D``. Slices are ``np.squeeze``d.

    Raises:
        ValueError: if no supported axis has length ``D``.
    """
    # Tensor -> NumPy (detached, on CPU); anything else via np.asarray.
    arr = vol.detach().cpu().numpy() if hasattr(vol, "detach") else np.asarray(vol)

    n_leading = {5: 2, 4: 1, 3: 0}.get(arr.ndim)
    if n_leading is not None:
        vol3 = arr[(0,) * n_leading]
        if vol3.shape[-1] == D:
            return lambda i: np.squeeze(vol3[:, :, i])
        if vol3.shape[0] == D:
            return lambda i: np.squeeze(vol3[i, :, :])
    raise ValueError(f"Can't determine depth axis for shape {arr.shape} with D={D}")

def window_ct_hu(img, wl=-600, ww=1500):
    """
    Apply a CT intensity window and scale the result to [0, 1].

    ``wl`` is the window level and ``ww`` the window width (same units as
    ``img``). Values are clipped to [wl - ww/2, wl + ww/2]; NaNs map to the
    lower bound, i.e. 0 after scaling.
    """
    half_width = ww / 2
    low = wl - half_width
    high = wl + half_width
    windowed = np.clip(np.nan_to_num(img.astype(np.float32), nan=low), low, high)
    return (windowed - low) / (high - low + 1e-6)

def normalize_for_display(img):
    """
    Scale an image slice to [0, 1] for display.

    NaNs become 0. If the slice looks like raw HU (min < -500, max > 200, or a
    1st-99th percentile spread > 800) it is passed through ``window_ct_hu``
    with its default window. Otherwise it is clipped to its 1st-99th
    percentile range and rescaled, or filled with 0.5 when that range is
    (near) zero.
    """
    arr = np.nan_to_num(img.astype(np.float32), nan=0.0)
    p_low, p_high = np.percentile(arr, [1, 99])
    spread = p_high - p_low

    if arr.min() < -500 or arr.max() > 200 or spread > 800:
        return window_ct_hu(arr)
    if spread < 1e-6:
        return np.zeros_like(arr) + 0.5
    return (np.clip(arr, p_low, p_high) - p_low) / (spread + 1e-6)

# ---- orientation: flip top-bottom, then rotate 90 degrees clockwise
def reorient_2d(a):
    """Flip ``a`` top-to-bottom, then rotate it 90 degrees clockwise (display orientation)."""
    flipped = a[::-1, ...]
    return np.rot90(flipped, k=-1)

# -------- data prep --------
# Convert the stitched logits into a label volume (depth on the last axis).
# With a single-channel (sigmoid) head, argmax over channels is always 0 and
# would hide every prediction, so threshold the logit at 0 (== sigmoid > 0.5).
if output_tensor.shape[0] == 1:
    pred_np = to_np((output_tensor[0] > 0).long())
else:
    pred_np = to_np(output_tensor.argmax(0))
D = pred_np.shape[-1]
get_img = make_depth_slicer(imgs, D)
get_gt  = make_depth_slicer(mask, D) if ('mask' in globals() and mask is not None) else None

# -------- figure & artists --------
# One axes, three stacked layers: CT (zorder 0), ground truth (winter, zorder 2)
# and prediction (autumn, zorder 3). Overlays mask out voxels with value <= 0.
fig, ax = plt.subplots()
ax.set_axis_off()

# Layer contents for the first slice (index 0).
ct0 = reorient_2d(normalize_for_display(get_img(0)))
p0 = reorient_2d(pred_np[:, :, 0])

im_ct = ax.imshow(ct0, cmap="bone", vmin=0, vmax=1, zorder=0)
im_pred = ax.imshow(
    np.ma.masked_where(p0 <= 0, p0.astype(float)),
    alpha=0.35, cmap="autumn", interpolation="nearest", zorder=3,
)

im_gt = None
if get_gt is not None:
    gt0 = reorient_2d(get_gt(0))
    im_gt = ax.imshow(
        np.ma.masked_where(gt0 <= 0, gt0.astype(float)),
        alpha=0.35, cmap="winter", interpolation="nearest", zorder=2,
    )

def update(i):
    """Redraw every layer for depth slice ``i`` and return the updated artists."""
    im_ct.set_data(reorient_2d(normalize_for_display(get_img(i))))

    pred_slice = reorient_2d(pred_np[:, :, i])
    im_pred.set_data(np.ma.masked_where(pred_slice <= 0, pred_slice.astype(float)))

    artists = [im_ct, im_pred]
    if im_gt is not None:
        gt_slice = reorient_2d(get_gt(i))
        im_gt.set_data(np.ma.masked_where(gt_slice <= 0, gt_slice.astype(float)))
        artists.append(im_gt)
    return tuple(artists)


def init():
    """Initial frame for FuncAnimation: draw slice 0."""
    return update(0)

# Animate every second slice starting at 1 (init() already drew slice 0),
# 60 ms per frame, played once.
ani = FuncAnimation(
    fig,
    update,
    frames=range(1, D, 2),
    init_func=init,
    interval=60,
    blit=False,
    repeat=False,
)

# Render to an HTML/JS player, then close the figure so that only the
# animation (and no extra static image) is displayed.
html = ani.to_jshtml()
plt.close(fig)
display(HTML(html))

In [None]:
# ---------- Per-split evaluation (now also returns subject_ids) ----------
def _largest_component(mask: np.ndarray) -> np.ndarray:
    """
    Keep only the largest connected foreground component of a binary mask.

    Connectivity is scipy.ndimage.label's default structuring element (face
    neighbours). Returns a uint8 mask of the same shape; an all-zero input
    gives an all-zero output.
    """
    from scipy import ndimage  # lazy: only needed when postprocess_lcc=True

    labeled, n_components = ndimage.label(mask > 0)
    if n_components == 0:
        return np.zeros(mask.shape, dtype=np.uint8)
    sizes = np.bincount(labeled.ravel())
    sizes[0] = 0  # label 0 is background and must never win
    return (labeled == sizes.argmax()).astype(np.uint8)


@torch.no_grad()
def evaluate_split_dice(
    dataset: tio.SubjectsDataset,
    model: torch.nn.Module,
    device: Optional[torch.device] = None,
    patch_size=(96, 96, 96),
    patch_overlap=(32, 32, 32),
    batch_size=2,
    num_workers=0,
    tumor_idx=1,
    pos_thresh=0.5,
    tumor_positive_values=(1,),
    overlap_mode='hann',
    postprocess_lcc: bool = False,
    verbose: bool = True,
) -> Tuple[float, List[float], List[str]]:
    """
    Iterate subjects in `dataset`, predict each full volume and compute per-subject Dice.

    Args:
        dataset: TorchIO dataset whose subjects have 'CT' and 'Label' images.
        model: Segmentation network (switched to eval mode).
        device: Device for inference; defaults to the model's device.
        patch_size, patch_overlap, batch_size, num_workers, overlap_mode:
            forwarded to predict_full_volume_tio.
        tumor_idx, pos_thresh: forwarded to logits_to_binary_pred.
        tumor_positive_values: label values counted as tumour in the ground truth.
        postprocess_lcc: keep only the largest connected component of the prediction.
        verbose: print per-subject and mean Dice.

    Returns:
        (mean_dice, per_subject_dice_list, subject_ids); mean_dice is 0.0 for
        an empty dataset.

    NOTE: a later cell redefines `evaluate_split_dice` (with `print_per_case`
    instead of `postprocess_lcc` / `verbose`); when the notebook runs top to
    bottom, that later definition is the one in effect.
    """
    dices: List[float] = []
    subject_ids: List[str] = []
    model = model.eval()
    if device is None:
        device = getattr(model, 'device', None) or next(model.parameters()).device

    for i in range(len(dataset)):
        subject = dataset[i]
        sid = subject.get('subject_id', f'case_{i}')
        subject_ids.append(sid)

        # predict full logits in native orientation/spacing
        logits_full = predict_full_volume_tio(
            subject, model, device=device,
            patch_size=patch_size, patch_overlap=patch_overlap,
            batch_size=batch_size, num_workers=num_workers,
            overlap_mode=overlap_mode,
        )  # [C, H, W, D]

        # binarize prediction & GT
        pred_bin = logits_to_binary_pred(logits_full, tumor_idx=tumor_idx, pos_thresh=pos_thresh)  # [H,W,D]
        if postprocess_lcc:
            # previously called an undefined `largest_component`
            pred_bin = _largest_component(pred_bin)
        gt_bin   = subject_label_to_binary(subject, tumor_positive_values=tumor_positive_values)   # [H,W,D]

        # Dice
        d = dice_binary(gt_bin, pred_bin)
        dices.append(d)

        if verbose:
            print(f"[{sid}] Dice: {d:.4f}")

    mean_dice = float(np.mean(dices)) if dices else 0.0
    if verbose:
        print(f"\nMean Dice over {len(dices)} subjects: {mean_dice:.4f}")
    return mean_dice, dices, subject_ids


# ---------- Pretty printer for split results ----------
def print_split_report(subject_ids: List[str], dices: List[float], title: str = "Split"):
    """
    Print per-subject Dice scores followed by summary statistics.

    Args:
        subject_ids: Case identifiers, one per entry in ``dices``.
        dices: Per-subject Dice scores.
        title: Label used in the report headings.

    Raises:
        ValueError: if ``subject_ids`` and ``dices`` have different lengths.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(subject_ids) != len(dices):
        raise ValueError(
            f"subject_ids and dices must have the same length "
            f"({len(subject_ids)} != {len(dices)})"
        )
    print(f"\n=== {title} Dice Report ===")
    for sid, d in zip(subject_ids, dices):
        print(f"[{sid}] Dice: {d:.4f}")
    if not dices:
        # min()/max() of an empty array would raise; mean/median would print nan.
        print(f"\n{title} summary: no subjects")
        return
    dices_np = np.asarray(dices, dtype=float)
    print(f"\n{title} summary:")
    print(f"  Mean : {dices_np.mean():.4f}")
    print(f"  Median: {np.median(dices_np):.4f}")
    print(f"  Std  : {dices_np.std(ddof=0):.4f}")
    print(f"  Min  : {dices_np.min():.4f}")
    print(f"  Max  : {dices_np.max():.4f}")

In [None]:
import re
from pathlib import Path
import numpy as np
import torch
import torchio as tio
from typing import Tuple, List, Optional

# ---------- case-id helpers ----------
def _extract_case_id_from_name(name: str) -> str:
    """
    Accepts a filename like:
      - Lung_003_0000.nii.gz  -> '003'
      - Lung_003.nii.gz       -> '003'
    Falls back to the stem if pattern isn't found.
    """
    # Try CT-style: Lung_xxx_0000.nii.gz
    m = re.match(r"^Lung_([^_]+)_0000\.nii(\.gz)?$", name, flags=re.IGNORECASE)
    if m:
        return m.group(1)

    # Try label-style: Lung_xxx.nii.gz
    m = re.match(r"^Lung_([^_]+)\.nii(\.gz)?$", name, flags=re.IGNORECASE)
    if m:
        return m.group(1)

    # Fallback: return the stem (without .nii/.nii.gz)
    # e.g., name='weirdname.nii.gz' -> 'weirdname'
    stem = re.sub(r"\.nii(\.gz)?$", "", name, flags=re.IGNORECASE)
    return stem

def _image_path(subject, key: str) -> Optional[Path]:
    """
    Return the file path of image ``key`` in ``subject``, or None if unusable.

    None is returned when the image is absent or has no path (e.g. it was built
    from an in-memory tensor). For multi-file images the first path is used.
    """
    if key not in subject:
        return None
    image = subject[key]
    raw = getattr(image, "path", None)
    if raw is None and isinstance(image, dict):
        raw = image.get(tio.PATH)
    if isinstance(raw, (list, tuple)):
        raw = raw[0] if raw else None
    # Path("") stringifies to ".", so test the raw value, not the Path object.
    if raw is None or str(raw) in ("", "."):
        return None
    return Path(raw)


def get_case_id_from_subject(subject: tio.Subject) -> str:
    """
    Return a concise case ID (e.g. '003') for ``subject``.

    The CT file name is preferred, then the Label file name. If neither image
    has a usable path, ``subject['subject_id']`` (or 'unknown') is returned.
    """
    for key in ("CT", "Label"):
        path = _image_path(subject, key)
        if path is not None:
            return _extract_case_id_from_name(path.name)

    # Last resort: subject_id field or a generated name
    return str(subject.get('subject_id', 'unknown'))

# ---------- (reuse your existing predict/eval functions) ----------
@torch.no_grad()
def predict_full_volume_tio(
    subject: tio.Subject,
    model: torch.nn.Module,
    device: Optional[torch.device] = None,
    patch_size=(96, 96, 96),
    patch_overlap=(32, 32, 32),
    batch_size=2,
    num_workers=0,
    overlap_mode='hann',
) -> torch.Tensor:
    """
    Predict a whole volume with TorchIO grid (patch) inference.

    Args:
        subject: TorchIO subject with a 'CT' image.
        model: Network mapping [B,1,w,h,d] patches to [B,C,w,h,d] logits; set to eval mode.
        device: torch.device or device string (e.g. "cuda"); defaults to the model's device.
        patch_size, patch_overlap: passed to tio.inference.GridSampler.
        batch_size, num_workers: DataLoader settings for the patch loader.
        overlap_mode: how tio.inference.GridAggregator merges overlapping patches.

    Returns:
        Aggregated logits [C, W, H, D], assembled from CPU copies of the patch logits.
    """
    model.eval()
    if device is None:
        device = getattr(model, 'device', None) or next(model.parameters()).device
    # Normalise strings such as "cuda" to torch.device; `.type` is used below.
    device = torch.device(device)

    sampler = tio.inference.GridSampler(subject, patch_size=patch_size, patch_overlap=patch_overlap)
    aggregator = tio.inference.GridAggregator(sampler, overlap_mode=overlap_mode)

    loader = torch.utils.data.DataLoader(
        sampler, batch_size=batch_size, num_workers=num_workers,
        pin_memory=(device.type == 'cuda'), persistent_workers=(num_workers > 0)
    )
    for batch in loader:
        x = batch['CT'][tio.DATA].to(device)        # [B,1,w,h,d]
        logits = model(x)                            # [B,C,w,h,d]
        aggregator.add_batch(logits.detach().cpu(), batch[tio.LOCATION])

    return aggregator.get_output_tensor()            # [C,W,H,D]

def logits_to_binary_pred(logits_chwhd: torch.Tensor, tumor_idx: int = 1, pos_thresh: float = 0.5) -> np.ndarray:
    C = logits_chwhd.shape[0]
    if C == 1:
        prob = torch.sigmoid(logits_chwhd[0])
        pred = (prob > pos_thresh)
    else:
        pred = (torch.argmax(logits_chwhd, dim=0) == tumor_idx)
    return pred.numpy().astype(np.uint8)

def subject_label_to_binary(subject: tio.Subject, tumor_positive_values=(1,)) -> np.ndarray:
    """
    Return the subject's 'Label' map as a uint8 mask of shape [W, H, D].

    A voxel is 1 when its label value is one of ``tumor_positive_values``.
    """
    label_volume = subject['Label'][tio.DATA][0]        # drop channel: [1,W,H,D] -> [W,H,D]
    label_np = label_volume.detach().cpu().numpy()
    is_tumor = np.isin(label_np, tumor_positive_values)
    return is_tumor.astype(np.uint8)

def dice_binary(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-6, empty_ok_as: float = 1.0) -> float:
    """
    Dice coefficient between two masks; any value > 0 counts as foreground.

    Returns ``empty_ok_as`` when both masks are empty, otherwise
    ``(2 * |A & B| + eps) / (|A| + |B| + eps)``.
    """
    truth = (y_true > 0).astype(np.uint8)
    guess = (y_pred > 0).astype(np.uint8)
    n_truth = truth.sum()
    n_guess = guess.sum()
    if n_truth == 0 and n_guess == 0:
        return float(empty_ok_as)
    overlap = np.bitwise_and(truth, guess).sum()
    return float((2.0 * overlap + eps) / (n_truth + n_guess + eps))

@torch.no_grad()
def evaluate_split_dice(
    dataset: tio.SubjectsDataset,
    model: torch.nn.Module,
    device: Optional[torch.device] = None,
    patch_size=(96,96,96),
    patch_overlap=(32,32,32),
    batch_size=2,
    num_workers=0,
    tumor_idx=1,
    pos_thresh=0.5,
    tumor_positive_values=(1,),
    overlap_mode='hann',
    print_per_case: bool = True,
) -> Tuple[float, List[float], List[str]]:
    """
    Compute per-case Dice for every subject in ``dataset``.

    Each subject is predicted with ``predict_full_volume_tio``, binarised with
    ``logits_to_binary_pred`` and compared with its label
    (``subject_label_to_binary``) via ``dice_binary``. Case IDs come from the
    file names through ``get_case_id_from_subject``.

    Returns:
        (mean Dice -- 0.0 for an empty dataset, per-case Dice list, case IDs)
    """
    if device is None:
        device = getattr(model, 'device', None) or next(model.parameters()).device
    model.eval()

    case_ids: List[str] = []
    scores: List[float] = []

    for index in range(len(dataset)):
        subject = dataset[index]
        case_id = get_case_id_from_subject(subject)  # case number parsed from the file name
        case_ids.append(case_id)

        volume_logits = predict_full_volume_tio(
            subject,
            model,
            device=device,
            patch_size=patch_size,
            patch_overlap=patch_overlap,
            batch_size=batch_size,
            num_workers=num_workers,
            overlap_mode=overlap_mode,
        )  # [C,W,H,D]

        pred_mask = logits_to_binary_pred(volume_logits, tumor_idx=tumor_idx, pos_thresh=pos_thresh)
        true_mask = subject_label_to_binary(subject, tumor_positive_values=tumor_positive_values)

        score = dice_binary(true_mask, pred_mask)
        scores.append(score)
        if print_per_case:
            print(f"[case {case_id}] Dice: {score:.4f}")

    mean_score = float(np.mean(scores)) if scores else 0.0
    if print_per_case:
        print(f"\nMean Dice over {len(scores)} subjects: {mean_score:.4f}")
    return mean_score, scores, case_ids

# =======================
# Usage
# =======================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()

# Inference / scoring settings shared by both split evaluations.
common = {
    "patch_size": (96, 96, 96),
    "patch_overlap": (32, 32, 32),
    "batch_size": 2,
    "num_workers": 0,
    "overlap_mode": "hann",
    "tumor_idx": 1,
    "pos_thresh": 0.5,               # fixed sigmoid threshold
    "tumor_positive_values": (1,),
    "print_per_case": True,
}

print("\nValidation split:")
mean_val, val_dices, val_ids = evaluate_split_dice(val_dataset, model, device=device, **common)

print("\nTest split:")
mean_test, test_dices, test_ids = evaluate_split_dice(test_dataset, model, device=device, **common)