## Introduction
Here we train our tumor segmentation network. <br />

## Imports:

* Pathlib for easy path handling
* torch for tensor handling
* PyTorch Lightning for efficient and easy training implementation
* ModelCheckpoint and TensorBoardLogger for checkpoint saving and logging
* imgaug for data augmentation
* numpy for file loading and array ops
* matplotlib for visualizing some images
* tqdm for a progress bar when validating the model
* celluloid for easy video generation
* Our dataset and model

In [1]:
from pathlib import Path

import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
import imgaug.augmenters as iaa
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from celluloid import Camera

In [2]:
import sys

# Put the folder containing dataset.py and model.py on sys.path
sys.path.insert(0, r"E:/DoNotTouch/projects/LANSCLC/CIS_5810")

import dataset   # loads dataset.py
import model     # loads model.py

LungDataset = dataset.LungDataset
UNet        = model.UNet

print(LungDataset.__module__)  # should print "dataset"

dataset


## Dataset Creation
Here we create the train, validation and test datasets. <br />
Additionally, we define our data augmentation pipeline, which is only applied to the training set.
Subsequently, the train and validation dataloaders are created.

In [3]:
seq = iaa.Sequential([
    iaa.Affine(translate_percent=0.15,  # note: a single number is a fixed 15% shift; (-0.15, 0.15) would sample a random shift per image
               scale=(0.85, 1.15),      # zoom in or out
               rotate=(-45, 45)),       # rotate up to 45 degrees
    iaa.ElasticTransformation()         # elastic transformations
])

In [4]:
# Create the dataset objects
BASE = Path("E:/DoNotTouch/projects/LANSCLC/CIS_5810/selected_150_split")
train_path = BASE / "Preprocessed_for_2D_Unet/train"
val_path = BASE / "Preprocessed_for_2D_Unet/val"
test_path = BASE / "Preprocessed_for_2D_Unet/test"

train_dataset = LungDataset(train_path, seq)
val_dataset = LungDataset(val_path, None)
test_dataset = LungDataset(test_path, None)

print(f"There are {len(train_dataset)} train images, {len(val_dataset)} val images and {len(test_dataset)} test images")

There are 7876 train images, 4236 val images and 5222 test images


## Oversampling to tackle strong class imbalance
Lung tumors are often very small, so we need to make sure that our model does not learn a trivial solution that simply outputs 0 for all voxels.<br />
In this notebook we use oversampling to sample slices that contain a tumor more often.

To do so, we can use the `WeightedRandomSampler` provided by PyTorch, which needs a weight for each sample in the dataset.
Typically there is one weight per class, which means that we need to calculate two weights (one for slices without a tumor and one for slices with a tumor) and create a list that assigns each sample in the dataset its corresponding weight.

First, we create a list containing only the class labels: 1 if a slice's mask contains any tumor pixel, 0 otherwise. Augmentation is temporarily disabled while building this list so that the labels are computed from the unaugmented masks:

In [5]:
from tqdm import tqdm
import numpy as np
from contextlib import contextmanager

@contextmanager
def disable_aug(ds):
    prev = getattr(ds, "augment_params", None)
    try:
        ds.augment_params = None
        yield ds
    finally:
        ds.augment_params = prev

def has_tumor(mask, thr=0.5):
    a = np.asarray(mask)
    a = np.squeeze(a)
    if a.ndim == 3:               # collapse channel dim if present
        a = (a > thr).any(axis=0)
    return int((a > thr).any())

with disable_aug(train_dataset):
    target_list = []
    for i in tqdm(range(len(train_dataset))):
        _, label = train_dataset[i]     # now no imgaug/warpAffine is triggered
        target_list.append(has_tumor(label))

pos = sum(target_list)
print(f"positives: {pos}/{len(target_list)} ({100*pos/len(target_list):.1f}%)")

100%|█████████████████████████████████████████████████████████████████████████████| 7876/7876 [00:12<00:00, 632.45it/s]

positives: 1949/7876 (24.7%)





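Then we calculate the weight for each class: slices without a tumor get weight 1, and slices with a tumor get the ratio of negative to positive slices, so that both classes carry the same total weight. From these two values we create the weight list.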

In [6]:
# !pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless thinc
# !pip install -U --force-reinstall --no-cache-dir numpy==1.26.4 scipy==1.11.4
# !pip install -U imgaug==0.4.0
# !pip install -U "opencv-python==4.8.1.78" "opencv-contrib-python==4.8.1.78" "opencv-python-headless==4.8.1.78"
# # (Only if you truly need thinc/spaCy here)
# # pip install "thinc<8.3"


In [7]:
uniques = np.unique(target_list, return_counts=True)
uniques

(array([0, 1]), array([5927, 1949], dtype=int64))

In [8]:
fraction = uniques[1][0] / uniques[1][1]
fraction

3.0410466906105693

Subsequently, we assign the weight 1 to each slice without a tumor and ~3.04 (the fraction computed above) to each slice with a tumor.

In [9]:
weight_list = []
for target in target_list:
    if target == 0:
        weight_list.append(1)
    else:
        weight_list.append(fraction)
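The loop above can also be written as a single vectorized NumPy expression. A minimal sketch, assuming the labels are the 0/1 values collected in `target_list`; a small hypothetical stand-in list is used here so the block runs on its own. It also shows why this weighting balances the classes: both classes end up with the same total weight.

```python
import numpy as np

# Hypothetical stand-in for `target_list` (0 = no tumor, 1 = tumor)
target_list = [0, 0, 0, 1, 0, 0, 1, 0]

targets = np.asarray(target_list)
n_neg = int((targets == 0).sum())
n_pos = int((targets == 1).sum())
fraction = n_neg / n_pos  # ratio of negative to positive slices

# Weight 1 for slices without a tumor, `fraction` for slices with a tumor
weights = np.where(targets == 1, fraction, 1.0)

# Both classes now carry the same total weight (equal to n_neg)
print(weights[targets == 0].sum(), weights[targets == 1].sum())
```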

In [10]:
weight_list[:50]

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1]

Finally, we create the sampler, which we pass to the DataLoader. We only use a sampler for the train loader: the validation data should keep its natural class distribution so that the validation loss reflects real performance.

In [11]:
sampler = torch.utils.data.sampler.WeightedRandomSampler(weight_list, len(weight_list))
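By default, `WeightedRandomSampler` draws indices with replacement, with probability proportional to each weight. The NumPy sketch below only imitates that behaviour (it does not call PyTorch) to show that, with the class counts printed above (5927 slices without and 1949 with a tumor), roughly half of all drawn slices should contain a tumor.

```python
import numpy as np

n_neg, n_pos = 5927, 1949  # class counts from the np.unique cell above
fraction = n_neg / n_pos
weights = np.concatenate([np.ones(n_neg), np.full(n_pos, fraction)])
labels = np.concatenate([np.zeros(n_neg, dtype=int), np.ones(n_pos, dtype=int)])

# Expected share of tumor slices per draw: total positive weight / total weight
expected_pos = weights[labels == 1].sum() / weights.sum()

# Imitate sampling with replacement, probability proportional to weight
rng = np.random.default_rng(0)
idx = rng.choice(len(weights), size=100_000, replace=True, p=weights / weights.sum())
observed_pos = labels[idx].mean()

print(f"expected: {expected_pos:.3f}, observed: {observed_pos:.3f}")
```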

In [12]:
batch_size = 8
num_workers = 23

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, sampler=sampler, pin_memory=True, persistent_workers=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=True, persistent_workers=True)

We can verify that our sampler works by taking a batch from the train loader and counting how many masks contain at least one tumor pixel (on average, this should be about half of the batch):

In [13]:
verify_sampler = next(iter(train_loader))

In [14]:
(verify_sampler[1][:,0]).sum([1, 2]) > 0  # ~ half the batch size

tensor([ True,  True, False, False, False, False, False,  True])
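A single batch of 8 is a noisy check. If every drawn slice contains a tumor with probability 0.5 (the expectation from the weights), the number of tumor slices per batch follows a Binomial(8, 0.5) distribution, so 3 out of 8 as above is entirely plausible. A small standard-library sketch of that distribution:

```python
from math import comb

batch_size = 8
p = 0.5  # expected probability that a sampled slice contains a tumor

# P(k tumor slices in one batch) for k = 0..8
pmf = [comb(batch_size, k) * p**k * (1 - p) ** (batch_size - k) for k in range(batch_size + 1)]

# Probability of seeing between 2 and 6 tumor slices in a batch
p_2_to_6 = sum(pmf[2:7])
print([round(v, 3) for v in pmf])
print(f"P(2 <= tumor slices <= 6) = {p_2_to_6:.3f}")
```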

## Loss

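We use binary cross-entropy on the raw logits (`nn.BCEWithLogitsLoss`), which combines the sigmoid and the BCE in a single, numerically stable operation.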
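The NumPy sketch below is not PyTorch's implementation; it only checks that the commonly used stable form `max(x, 0) - x*y + log(1 + exp(-|x|))` gives the same value as the naive `-[y log(sigmoid(x)) + (1 - y) log(1 - sigmoid(x))]`, averaged over all pixels as with the default `reduction="mean"`. The stable form stays finite for large logits, where the naive version would evaluate `log(0)`.

```python
import numpy as np

def bce_naive(logits, targets):
    p = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
    return np.mean(-(targets * np.log(p) + (1 - targets) * np.log(1 - p)))

def bce_with_logits(logits, targets):
    # Stable form: never takes the log of a probability that rounded to 0 or 1
    return np.mean(np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits))))

logits = np.array([-3.0, -0.5, 0.0, 0.7, 4.0])
targets = np.array([0.0, 0.0, 1.0, 1.0, 1.0])
print(bce_naive(logits, targets), bce_with_logits(logits, targets))

# A confidently wrong prediction: the loss is ~100 instead of inf
print(bce_with_logits(np.array([100.0]), np.array([0.0])))
```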

## Full Segmentation Model

We now combine everything into the full PyTorch Lightning model.

In [15]:
import torch
import torch.nn as nn
import lightning.pytorch as pl
import numpy as np
import matplotlib.pyplot as plt

class TumorSegmentation(pl.LightningModule):
    def __init__(self, lr: float = 1e-4):
        super().__init__()
        self.model = UNet().float()                 # ensure model params are float32
        self.lr = lr
        self.loss_fn = nn.BCEWithLogitsLoss()       # expects float targets, logits input

    def forward(self, x):
        return self.model(x)                        # logits

    def training_step(self, batch, batch_idx):
        ct, mask = batch
        ct   = ct.float()                           # <— enforce float32
        mask = mask.float()                         # <— enforce float32

        logits = self(ct)                           # (B,1,H,W) logits (float32)
        loss = self.loss_fn(logits, mask)

        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)

        if (batch_idx % 50 == 0) and self.logger is not None:
            self.log_images(ct, logits, mask, "Train")
        return loss

    def validation_step(self, batch, batch_idx):
        ct, mask = batch
        ct   = ct.float()                           # <— enforce float32
        mask = mask.float()                         # <— enforce float32

        logits = self(ct)
        val_loss = self.loss_fn(logits, mask)

        self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True)

        if (batch_idx % 50 == 0) and self.logger is not None:
            self.log_images(ct, logits, mask, "Val")
        return val_loss

    def log_images(self, ct, logits, mask, split_name: str):
        probs = torch.sigmoid(logits)
        pred_bin = (probs > 0.5)

        img = ct[0, 0].detach().cpu().numpy()
        gt  = mask[0, 0].detach().cpu().numpy()
        pr  = pred_bin[0, 0].detach().cpu().numpy()

        fig, axis = plt.subplots(1, 2, figsize=(8, 4))
        axis[0].imshow(img, cmap="bone")
        axis[0].imshow(np.ma.masked_where(gt == 0, gt), alpha=0.6)
        axis[0].set_title("Ground Truth"); axis[0].axis("off")

        axis[1].imshow(img, cmap="bone")
        axis[1].imshow(np.ma.masked_where(pr == 0, pr), alpha=0.6, cmap="autumn")
        axis[1].set_title("Prediction"); axis[1].axis("off")

        self.logger.experiment.add_figure(f"{split_name} Prediction vs Label", fig, self.global_step)
        plt.close(fig)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [16]:
# Instantiate the model
model = TumorSegmentation()

In [17]:
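# Create the checkpoint callback
# The module only logs "train_loss" and "val_loss", so monitor the validation loss (lower is better).
# Note: this callback is re-created with an explicit dirpath in the next cell.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=100,
    mode="min")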

In [18]:
# Create the trainer
# Set `devices` to the number of GPUs to use, or use accelerator="cpu" for CPU training

import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="./ckpts", monitor="val_loss", mode="min", save_top_k=1
)

trainer = pl.Trainer(
    accelerator="gpu", devices=1,            # or accelerator="auto", devices="auto"
    logger=TensorBoardLogger(save_dir="./logs"),
    log_every_n_steps=1,
    callbacks=[checkpoint_callback],
    max_epochs=100
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [19]:
trainer.fit(model, train_loader, val_loader)

C:\ProgramData\Anaconda3\envs\py38\lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:652: Checkpoint directory E:\DoNotTouch\projects\LANSCLC\CIS_5810\2D U-Net\ckpts exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | UNet              | 7.8 M  | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.127    Total estimated model params size (MB)


`Trainer.fit` stopped: `max_epochs=100` reached.


## Evaluation:
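We evaluate the trained model with the Dice score, first pooled over all slices of the validation and test sets, then per patient.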

In [20]:
class DiceScore(torch.nn.Module):
    """
    Compute the Dice score 2 * sum(pred * mask) / (sum(pred) + sum(mask)).
    With probabilities as input, this is a soft Dice score.
    """
    def __init__(self):
        super().__init__()

    def forward(self, pred, mask):

        # Flatten label and prediction tensors
        pred = torch.flatten(pred)
        mask = torch.flatten(mask)

        intersection = (pred * mask).sum()      # numerator (overlap)
        denominator = pred.sum() + mask.sum()   # denominator
        dice = (2 * intersection) / denominator

        return dice
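For reference, the formula in `DiceScore` can be checked by hand on tiny arrays. A minimal NumPy sketch (the arrays are made up for illustration): with binary predictions it is the usual overlap measure; with probabilities, as in the evaluation cells below, it becomes a soft Dice that also rewards confident correct pixels.

```python
import numpy as np

def dice(pred, mask):
    pred = np.ravel(pred).astype(float)
    mask = np.ravel(mask).astype(float)
    return 2.0 * (pred * mask).sum() / (pred.sum() + mask.sum())

mask = np.array([[1, 1], [0, 0]])
hard = np.array([[1, 0], [1, 0]])          # one of two tumor pixels found, one false positive
soft = np.array([[0.9, 0.6], [0.2, 0.1]])  # probabilities

print(dice(hard, mask))  # 2*1 / (2 + 2) = 0.5
print(dice(soft, mask))  # 2*1.5 / (1.8 + 2), about 0.789
```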

In [21]:
from pathlib import Path
import torch

def resolve_ckpt_path(ckpt_cb) -> Path:
    # 1) Prefer the "best" checkpoint
    p = getattr(ckpt_cb, "best_model_path", "") or ""
    if p and Path(p).is_file():
        return Path(p)

    # 2) Fall back to "last"
    p = getattr(ckpt_cb, "last_model_path", "") or ""
    if p and Path(p).is_file():
        return Path(p)

    # 3) Search the callback dirpath for any .ckpt (pick most recent)
    root = Path(getattr(ckpt_cb, "dirpath", "."))  # where ModelCheckpoint saves
    cands = list(root.rglob("*.ckpt")) if root.exists() else []
    if not cands:
        # also search Lightning default logs/checkpoints tree
        cands = list(Path("lightning_logs").rglob("*.ckpt"))
    if not cands:
        raise FileNotFoundError(
            "No checkpoint found. Ensure ModelCheckpoint has dirpath/monitor or use save_last=True."
        )
    return max(cands, key=lambda x: x.stat().st_mtime)

# --- use it ---
ckpt_path = resolve_ckpt_path(checkpoint_callback)
print(f"Loading checkpoint: {ckpt_path}")

model = TumorSegmentation.load_from_checkpoint(
    ckpt_path,
    map_location="cuda" if torch.cuda.is_available() else "cpu",
    # strict=True  # set to False if your model signature changed since saving
)
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Loading checkpoint: E:\DoNotTouch\projects\LANSCLC\CIS_5810\2D U-Net\ckpts\epoch=10-step=10835.ckpt


TumorSegmentation(
  (model): UNet(
    (layer1): DoubleConv(
      (step): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (layer2): DoubleConv(
      (step): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (layer3): DoubleConv(
      (step): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU()
      )
    )
    (layer4): DoubleConv(
      (step): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): 

In [22]:
preds = []
labels = []

for ct_slice, label in tqdm(val_dataset):
    ct_slice = torch.tensor(ct_slice).float().to(device).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        pred = torch.sigmoid(model(ct_slice))  # logits -> probabilities
    preds.append(pred.cpu().numpy())
    labels.append(label)

preds = np.array(preds)
labels = np.array(labels)

100%|█████████████████████████████████████████████████████████████████████████████| 4236/4236 [00:37<00:00, 111.87it/s]


Compute the overall soft Dice score on the validation set (probabilities pooled over all pixels of all slices, no thresholding):

In [23]:
dice_score = DiceScore()(torch.from_numpy(preds), torch.from_numpy(labels).unsqueeze(0).float())
print(f"The Val Dice Score is: {dice_score}")

The Val Dice Score is: 0.4555647075176239


In [24]:
preds = []
labels = []

for ct_slice, label in tqdm(test_dataset):
    ct_slice = torch.tensor(ct_slice).float().to(device).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        pred = torch.sigmoid(model(ct_slice))  # logits -> probabilities
    preds.append(pred.cpu().numpy())
    labels.append(label)

preds = np.array(preds)
labels = np.array(labels)

100%|█████████████████████████████████████████████████████████████████████████████| 5222/5222 [00:46<00:00, 111.33it/s]


Compute the overall soft Dice score on the test set (probabilities pooled over all pixels of all slices, no thresholding):

In [25]:
dice_score = DiceScore()(torch.from_numpy(preds), torch.from_numpy(labels).unsqueeze(0).float())
print(f"The Test Dice Score is: {dice_score}")

The Test Dice Score is: 0.48201602697372437
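The scores above pool every pixel of a split into one overlap measure, so patients with large tumors dominate the result, whereas a per-patient average weights every patient equally. A toy NumPy sketch with two hypothetical patients (binary masks for simplicity) shows how far the two can diverge:

```python
import numpy as np

def dice(pred, gt):
    pred, gt = pred.astype(bool), gt.astype(bool)
    return 2.0 * np.logical_and(pred, gt).sum() / (pred.sum() + gt.sum())

# Hypothetical patient A: large tumor (100 pixels), 90 found, no false positives
gt_a = np.zeros(1000, dtype=bool); gt_a[:100] = True
pr_a = np.zeros(1000, dtype=bool); pr_a[:90] = True

# Hypothetical patient B: small tumor (10 pixels), completely missed
gt_b = np.zeros(1000, dtype=bool); gt_b[:10] = True
pr_b = np.zeros(1000, dtype=bool)

pooled = dice(np.concatenate([pr_a, pr_b]), np.concatenate([gt_a, gt_b]))
per_patient_mean = np.mean([dice(pr_a, gt_a), dice(pr_b, gt_b)])
print(f"pooled: {pooled:.3f}, mean per patient: {per_patient_mean:.3f}")
```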


In [36]:
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch

# ---------- helpers ----------
def _dice_coef(pred_bin: np.ndarray, gt_bin: np.ndarray, eps=1e-7) -> float:
    inter = np.logical_and(pred_bin, gt_bin).sum()
    denom = pred_bin.sum() + gt_bin.sum()
    if denom == 0:
        return 1.0
    return (2.0 * inter + eps) / (denom + eps)

def _get_paths_attr(ds):
    for name in ("all_files", "files", "paths", "slices"):
        if hasattr(ds, name):
            return getattr(ds, name)
    raise AttributeError("Dataset must expose a list of slice paths via "
                         "`all_files`, `files`, `paths`, or `slices`.")

def _pid_from_slice_path(p: Path) -> str:
    # .../<patient>/data/<slice>.npy
    return p.parent.parent.name

def _slice_idx_from_path(p: Path) -> int:
    return int(p.stem)

def _ensure_img_tensor(x) -> torch.Tensor:
    """
    Returns torch.float32 tensor of shape (1,H,W) on CPU.
    Accepts numpy or torch; shapes (H,W) or (1,H,W) or (C,H,W) with C==1.
    """
    if torch.is_tensor(x):
        arr = x.detach().cpu().numpy()
    else:
        arr = np.asarray(x)

    # squeeze extra leading singleton dims but keep spatial
    while arr.ndim > 2 and arr.shape[0] == 1:
        arr = arr[0]
    if arr.ndim == 2:
        arr = arr[None, ...]  # (1,H,W)
    elif arr.ndim == 3 and arr.shape[0] != 1:
        raise ValueError(f"Expected 1 channel, got shape {arr.shape}")

    return torch.from_numpy(arr.astype(np.float32, copy=False))  # (1,H,W)

def _ensure_mask_np(y) -> np.ndarray:
    """
    Returns float32 numpy array of shape (H,W).
    Accepts numpy or torch; shapes (H,W) or (1,H,W).
    """
    if torch.is_tensor(y):
        arr = y.detach().cpu().numpy()
    else:
        arr = np.asarray(y)
    # squeeze a single channel if present
    if arr.ndim == 3 and arr.shape[0] == 1:
        arr = arr[0]
    if arr.ndim != 2:
        raise ValueError(f"Mask must be 2D or (1,H,W); got {arr.shape}")
    return arr.astype(np.float32, copy=False)

# ---------- main ----------
def eval_by_patient(dataset, model, threshold=0.5, device=None, verbose=True):
    """
    Groups slices per patient using dataset slice paths and computes volumetric Dice.
    Works whether dataset returns numpy arrays or torch tensors.
    """
    if device is None:
        device = next(model.parameters()).device

    paths = _get_paths_attr(dataset)
    by_patient = defaultdict(list)  # pid -> list of (slice_idx, prob2d, gt2d)

    model.eval()
    with torch.no_grad():
        for idx in range(len(dataset)):
            x, y = dataset[idx]                     # may be numpy or torch
            img_t = _ensure_img_tensor(x)           # (1,H,W) float32
            gt2d  = _ensure_mask_np(y)              # (H,W) float32

            p = Path(str(paths[idx]))
            pid = _pid_from_slice_path(p)
            sidx = _slice_idx_from_path(p)

            # model expects (N,C,H,W)
            logits = model(img_t.unsqueeze(0).to(device))  # (1,C,H,W)
            if logits.shape[1] == 1:
                prob2d = torch.sigmoid(logits)[0, 0].cpu().numpy()
            else:
                prob2d = torch.softmax(logits, dim=1)[0, 1].cpu().numpy()

            by_patient[pid].append((sidx, prob2d.astype(np.float32, copy=False), gt2d))

    patient_dice = {}
    patient_pred_vols, patient_gt_vols = {}, {}

    for pid, triples in by_patient.items():
        triples.sort(key=lambda t: t[0])
        pred_vol = np.stack([t[1] for t in triples], axis=0)  # (D,H,W)
        gt_vol   = np.stack([t[2] for t in triples], axis=0)

        pred_bin = pred_vol >= float(threshold)
        gt_bin   = gt_vol  >  0.0

        patient_pred_vols[pid] = pred_vol
        patient_gt_vols[pid]   = gt_vol
        patient_dice[pid]      = _dice_coef(pred_bin, gt_bin)

    if verbose and patient_dice:
        vals = np.array(list(patient_dice.values()), dtype=np.float32)
        print(f"Patients: {len(vals)} | mean Dice: {vals.mean():.4f} | median: {np.median(vals):.4f}")
    return patient_dice, patient_pred_vols, patient_gt_vols
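`eval_by_patient` relies on the folder layout `.../<patient>/data/<slice>.npy` and on sorting slices by the integer file stem. Sorting numerically matters: as strings, `'10'` sorts before `'2'`, which would scramble the reconstructed volume. A standalone sketch of just the grouping and ordering step, using hypothetical paths:

```python
from collections import defaultdict
from pathlib import PurePosixPath

# Hypothetical slice paths following the .../<patient>/data/<slice>.npy layout
paths = [
    "val/101/data/10.npy",
    "val/100/data/2.npy",
    "val/101/data/2.npy",
    "val/100/data/0.npy",
    "val/101/data/1.npy",
]

by_patient = defaultdict(list)
for s in paths:
    p = PurePosixPath(s)
    pid = p.parent.parent.name           # patient folder
    by_patient[pid].append(int(p.stem))  # numeric slice index

# Sort numerically per patient (string sorting would give '1', '10', '2')
ordered = {pid: sorted(idxs) for pid, idxs in by_patient.items()}
print(ordered)
```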

In [37]:
val_dice, val_pred_vols, val_gt_vols = eval_by_patient(val_dataset, model, threshold=0.5)
val_dice

Patients: 40 | mean Dice: 0.4413 | median: 0.4363


{'100': 0.4250162654568557,
 '101': 0.6293026307951163,
 '102': 0.6035349351018958,
 '103': 0.671194455391226,
 '104': 1.9580967299401194e-11,
 '105': 0.30859266140590524,
 '106': 0.8355995055635517,
 '107': 0.3701524282180607,
 '108': 1.5071590052523413e-11,
 '109': 0.407997038136017,
 '70': 0.30694551558632244,
 '71': 0.7961047479696349,
 '72': 0.34346685324700554,
 '73': 0.007783477808174566,
 '74': 0.7784842505314735,
 '75': 0.24455346310269413,
 '76': 0.3907680063749732,
 '77': 0.33970826580654967,
 '78': 0.7165822897250351,
 '79': 1.9102196752261658e-11,
 '80': 0.020039363043298983,
 '81': 0.5536780159768501,
 '82': 0.7494719053666176,
 '83': 0.7590961593997995,
 '84': 0.019417475748411125,
 '85': 0.4622822978972588,
 '86': 0.5961730117610221,
 '87': 0.7795941375426427,
 '88': 0.4475446428633086,
 '89': 0.7976995319532635,
 '90': 0.8082623965716049,
 '91': 0.21830585561088253,
 '92': 0.8636750319018018,
 '93': 0.5082494552272063,
 '94': 0.05834937440412446,
 '95': 0.8275118917507

In [28]:
test_dice, test_pred_vols, test_gt_vols = eval_by_patient(test_dataset, model, threshold=0.5)
test_dice

Patients: 40 | mean Dice: 0.5056 | median: 0.5515


{'110': 0.32252733900693786,
 '111': 0.7562136435761173,
 '112': 0.6907679033676378,
 '113': 0.9058605798895382,
 '114': 0.6698140555769457,
 '115': 0.5786007526531877,
 '116': 1.4526438117162603e-11,
 '117': 0.679285460786948,
 '118': 0.4869669538676874,
 '119': 0.35981838819657835,
 '120': 0.44800382500861585,
 '121': 0.2628290484557735,
 '122': 0.45013558300929657,
 '123': 0.20807174889076127,
 '124': 0.8407346491232436,
 '125': 0.14030494379223543,
 '126': 8.39278220723128e-12,
 '127': 0.781235750115246,
 '128': 0.914712861889144,
 '129': 0.7013664766895463,
 '130': 0.4148246056463333,
 '131': 0.5620263942121668,
 '132': 0.7857745966088124,
 '133': 0.6355664044875186,
 '134': 0.541040866228248,
 '135': 0.8101017030864985,
 '136': 5.65227221339785e-12,
 '137': 0.7570604977021198,
 '138': 0.2513274336366003,
 '139': 0.024313561108527417,
 '140': 0.6483870967758139,
 '141': 0.33017715333437553,
 '142': 0.46906245471863117,
 '143': 0.6524839840486826,
 '144': 0.6522639834285195,
 '145'
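The summary line printed by `eval_by_patient` reports only the mean and median. Because a few patients score near zero (complete misses) while most score considerably higher, it can help to report the number of misses separately. A minimal sketch, assuming the `{patient_id: dice}` dictionary returned above; the helper name `summarize_dice` and the 0.05 miss cut-off are hypothetical choices.

```python
import numpy as np

def summarize_dice(patient_dice: dict, fail_below: float = 0.05) -> dict:
    """Aggregate a {patient_id: dice} dict into summary statistics.

    fail_below is an arbitrary, assumed cut-off for counting near-total misses.
    """
    vals = np.array(list(patient_dice.values()), dtype=np.float64)
    hits = vals[vals >= fail_below]
    return {
        "n": int(vals.size),
        "mean": float(vals.mean()),
        "median": float(np.median(vals)),
        "n_fail": int((vals < fail_below).sum()),
        # mean over patients whose tumor was at least partially found
        "mean_excl_fail": float(hits.mean()) if hits.size else float("nan"),
    }

# Toy values (made up, not results from this notebook)
toy = {"a": 0.8, "b": 0.6, "c": 1e-11, "d": 0.4}
print(summarize_dice(toy))
```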

## Visualization

In [29]:
THRESHOLD = 0.5

In [30]:
import nibabel as nib
import cv2

In [31]:
BASE = Path("E:/DoNotTouch/projects/LANSCLC/CIS_5810/selected_150_split")

subject = BASE / "test" / "image" / "Lung_045_0000.nii.gz"
ct = nib.load(subject).get_fdata() / 3071  # rescale HU by a constant (not a z-score)
ct = ct[:, :, 30:]  # drop the first 30 slices along the last (slice) axis
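Dividing by 3071 is a fixed rescaling rather than a z-score standardization. CT intensities are commonly stored in Hounsfield units (HU) spanning -1024 to 3071, so the bottom of that range (roughly air) maps to about -0.33, water to 0 and the top of the range to 1. A small worked sketch, assuming the loaded volume is in HU; the helper `scale_ct` is illustrative and simply mirrors the division above.

```python
import numpy as np

def scale_ct(hu_volume: np.ndarray, max_hu: float = 3071.0) -> np.ndarray:
    """Rescale Hounsfield units by a constant, as `get_fdata() / 3071` does."""
    return hu_volume / max_hu

# Bottom of the range (about air), water, top of the 12-bit HU range
hu = np.array([-1024.0, 0.0, 3071.0])
print(scale_ct(hu))  # about -0.333, 0.0, 1.0

# The crop `ct[:, :, 30:]` removes the first 30 slices of the last axis
vol = np.zeros((512, 512, 120))
print(vol[:, :, 30:].shape)  # (512, 512, 90)
```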

In [32]:
segmentation = []
scan = []

# Run the network slice by slice along the last axis.
# `model` and `device` are defined in earlier cells.
for i in range(ct.shape[-1]):
    img = cv2.resize(ct[:, :, i], (256, 256))  # resize to the network input size
    img = torch.tensor(img)
    scan.append(img)
    img = img.unsqueeze(0).unsqueeze(0).float().to(device)  # (1, 1, 256, 256)

    with torch.no_grad():
        pred = model(img)[0][0].cpu()
    segmentation.append(pred > THRESHOLD)  # binary mask for this slice
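The loop above calls the network once per slice. The same slice-wise procedure can be written as batched inference. A framework-agnostic sketch, in which `predict_fn` is a hypothetical stand-in for the UNet call (taking a (B, H, W) float array and returning per-pixel scores of the same shape) and the slices are assumed to be already resized to the network's input size.

```python
import numpy as np
from typing import Callable

def predict_volume(volume: np.ndarray,
                   predict_fn: Callable[[np.ndarray], np.ndarray],
                   threshold: float = 0.5,
                   batch_size: int = 16) -> np.ndarray:
    """Slice-wise inference over the last axis of an (H, W, D) volume.

    Returns a boolean (D, H, W) mask. `predict_fn` is assumed to map a
    (B, H, W) batch to scores of the same shape.
    """
    slices = np.moveaxis(volume, -1, 0).astype(np.float32)  # (D, H, W)
    masks = []
    for start in range(0, slices.shape[0], batch_size):
        batch = slices[start:start + batch_size]
        scores = predict_fn(batch)
        masks.append(scores > threshold)  # same strict '>' as the loop above
    return np.concatenate(masks, axis=0)

# Usage with an identity "model": the mask is just the thresholded input
vol = np.random.default_rng(0).random((8, 8, 37))
mask = predict_volume(vol, predict_fn=lambda b: b, threshold=0.5, batch_size=16)
print(mask.shape)  # (37, 8, 8)
```

With the UNet, `predict_fn` would need to add the channel axis, move the batch to `device`, run the model under `torch.no_grad()` and return the first output channel as a numpy array; that wrapper is not shown here. Note also that `eval_by_patient` binarizes with `>=` while this visualization uses `>`; the two differ only for scores exactly equal to the threshold.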

In [33]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display
import torch

def reorient_2d(a):
    # flip vertically, then rotate -90° (clockwise)
    return np.rot90(np.flipud(a), k=3)

def to_slice_list(x):
    """
    Return a list of 2D numpy arrays (same shape).
    Handles: torch.Tensor, list/tuple of slices, or a 3D volume (D,H,W) / (H,W,D).
    """
    # torch -> numpy
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()

    if isinstance(x, (list, tuple)):
        # list/tuple of 2D slices (numpy arrays or CPU tensors)
        slices = [np.asarray(s) for s in x]
    else:
        # squeeze singleton axes (e.g., [1, D, H, W] -> [D, H, W])
        arr = np.squeeze(np.asarray(x))
        if arr.ndim == 2:
            slices = [arr]
        elif arr.ndim == 3:
            # heuristic: treat the smallest axis as depth,
            # then move it to the front so we have (D, H, W)
            d_axis = int(np.argmin(arr.shape))
            vol = np.moveaxis(arr, d_axis, 0)  # (D, H, W)
            slices = [vol[i] for i in range(vol.shape[0])]
        else:
            raise ValueError(f"Expected 2D/3D input or list of 2D slices; got array with ndim={arr.ndim}")

    if not slices:
        raise ValueError("No slices to display.")

    # verify uniform shapes
    H, W = np.asarray(slices[0]).shape
    for i, s in enumerate(slices):
        if np.asarray(s).ndim != 2 or np.asarray(s).shape != (H, W):
            raise ValueError(f"Slice {i} has shape {np.asarray(s).shape}, "
                             f"but expected {(H, W)}. Resize/crop first so all slices match.")

    # convert every slice to a contiguous float32 array
    return [np.asarray(s, dtype=np.float32, order="C") for s in slices]

# -------- build consistent slice lists --------
scan_slices = to_slice_list(scan)
seg_slices  = to_slice_list(segmentation)

# optional: ensure seg is binary/integer for masking
seg_slices  = [(s > 0).astype(np.uint8) for s in seg_slices]

N = min(len(scan_slices), len(seg_slices))

# reorient
scan_r = [reorient_2d(s) for s in scan_slices[:N]]
seg_r  = [reorient_2d(s) for s in seg_slices[:N]]

# -------- animate --------
fig, ax = plt.subplots(figsize=(5,5))
im = ax.imshow(scan_r[0], cmap="bone", interpolation="nearest")
ov = ax.imshow(np.ma.masked_where(seg_r[0]==0, seg_r[0]),
               cmap="autumn", alpha=0.5, interpolation="nearest")
ax.axis("off")

def update(i):
    im.set_data(scan_r[i])
    ov.set_data(np.ma.masked_where(seg_r[i]==0, seg_r[i]))
    return im, ov

anim = animation.FuncAnimation(fig, update, frames=range(0, N, 2), interval=80, blit=False)
display(HTML(anim.to_jshtml()))
plt.close(fig)
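Two properties of the helpers in the cell above can be checked in isolation. First, `reorient_2d` (flip vertically, then rotate 90° clockwise) is equivalent to a plain transpose: it only swaps the two in-plane axes. Second, the smallest-axis depth heuristic in `to_slice_list`, which applies only when a 3D array or tensor is passed rather than a list of slices, picks an in-plane axis whenever a volume has more slices than pixels along one side. Passing the depth axis explicitly avoids that. A numpy-only sketch; `split_volume` and its `depth_axis` argument are hypothetical, not part of the notebook.

```python
import numpy as np

def reorient_2d(a):
    # same as the notebook helper: flip vertically, then rotate 90° clockwise
    return np.rot90(np.flipud(a), k=3)

def split_volume(vol: np.ndarray, depth_axis=None):
    """Split a 3D volume into 2D slices along `depth_axis`.

    If depth_axis is None, fall back to the smallest-axis heuristic
    used by to_slice_list.
    """
    if depth_axis is None:
        depth_axis = int(np.argmin(vol.shape))
    vol = np.moveaxis(vol, depth_axis, 0)  # (D, H, W)
    return [vol[i] for i in range(vol.shape[0])]

a = np.arange(6).reshape(2, 3)
print(np.array_equal(reorient_2d(a), a.T))  # True

# 40 slices of 64x64: the heuristic finds the slice axis
print(len(split_volume(np.zeros((64, 64, 40)))))                  # 40
# 100 slices of 64x64: the heuristic picks an in-plane axis instead
print(len(split_volume(np.zeros((64, 64, 100)))))                 # 64
print(len(split_volume(np.zeros((64, 64, 100)), depth_axis=2)))   # 100
```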

