# Segmentation

Train a 3D U-Net to segment tumor(s) in MRI volumes, using the lesion masks as ground truth.

Architecture: MONAI UNet

## Imports and setup

In [1]:
import torch
from torch.utils.data import Subset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
from monai.networks.nets.unet import UNet
from monai.losses.dice import DiceLoss
from monai.metrics.meandice import DiceMetric

import numpy as np
from collections import defaultdict
import random
import sys
import os

In [2]:
# Add the project root (the parent of the current working directory) to
# sys.path so that `scripts.*` can be imported like a package.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Guard against duplicates: re-running this cell would otherwise append the
# same entry to sys.path again every time.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
# Seed every RNG this notebook relies on

def set_seed(seed=42):
    """Seed the PyTorch, NumPy and Python RNGs and pin the cuDNN flags.

    Args:
        seed: Integer handed to each library's seeding function.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    # Switch cuDNN to deterministic mode and turn its benchmark mode off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed()

In [4]:
# Set device: prefer a CUDA GPU, then Apple's MPS backend, and fall back to CPU.
# Checking CUDA first keeps the notebook usable on non-macOS machines; on a
# Mac without CUDA the result is the same MPS/CPU choice as before.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

## Load data

In [5]:
# Load the lesion dataset (no transform attached at this stage).
# This import relies on the project root having been added to sys.path above.
from scripts.load_data import MRIDataset

dataset = MRIDataset(
    root_dir="../data/lesions",
    labels_path="../data/lesions/PROSTATEx_Classes.csv",
)
print(len(dataset))

200


In [6]:
# Group dataset indices by patient so the train/test split can be done per patient
patient_idxs = defaultdict(list)

for sample_idx, sample_info in enumerate(dataset.samples):
    # The patient ID is the part of the finding ID that precedes "_Finding".
    patient_key = sample_info["finding_id"].partition("_Finding")[0]
    patient_idxs[patient_key].append(sample_idx)

patient_ids = list(patient_idxs)
print(f"Total patients: {len(patient_ids)}")

Total patients: 199


In [7]:
# Excluding zero-mask samples
from functools import lru_cache


@lru_cache(maxsize=None)
def has_nonzero_mask(idx):
    """Return True when the mask of ``dataset[idx]`` has a positive sum.

    Results are memoised per index. The same indices are checked again when
    the train/test index lists are built, and every uncached call indexes
    ``dataset`` anew (presumably re-loading the sample — verify against
    MRIDataset). Re-run this cell if ``dataset`` is rebuilt, otherwise the
    cache will hold results for the old dataset.

    Args:
        idx: Integer index into the module-level ``dataset``.

    Returns:
        bool: ``True`` if ``dataset[idx]["mask"].sum() > 0``, else ``False``.
    """
    return bool(dataset[idx]["mask"].sum() > 0)


# Keep a patient as soon as one of their findings has a non-empty mask.
seg_patients = [
    patient_id
    for patient_id, indices in patient_idxs.items()
    if any(has_nonzero_mask(i) for i in indices)
]

print(f"Patients with lesions: {len(seg_patients)}")

Patients with lesions: 199


In [8]:
# Split on patient IDs so that no patient contributes samples to both sets
train_patients, test_patients = train_test_split(seg_patients, test_size=0.2, random_state=42)

# Expand each patient back into its sample indices, keeping only non-empty masks
train_idxs = [
    idx
    for pid in train_patients
    for idx in patient_idxs[pid]
    if has_nonzero_mask(idx)
]
test_idxs = [
    idx
    for pid in test_patients
    for idx in patient_idxs[pid]
    if has_nonzero_mask(idx)
]

print(f"Train samples: {len(train_idxs)}")
print(f"Test samples: {len(test_idxs)}")

Train samples: 160
Test samples: 40


In [9]:
# Wrap the selected indices in torch Subsets and batch them.
# NOTE: train_set/test_set and both loaders are re-assigned further down,
# once the preprocessing transform is attached to the dataset.
train_set = Subset(dataset=dataset, indices=train_idxs)
test_set = Subset(dataset=dataset, indices=test_idxs)

train_loader = DataLoader(dataset=train_set, batch_size=8, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=8, shuffle=False)

## Preprocess data

In [10]:
# Helper to iterate over torch Subset

def iter_subset(subset):
    """Lazily yield ``subset[0]``, ``subset[1]``, ... up to ``len(subset) - 1``.

    Args:
        subset: Any object that supports ``len()`` and integer indexing.

    Yields:
        The item stored at each position, in index order.
    """
    yield from (subset[position] for position in range(len(subset)))

In [11]:
from scripts.preprocess import Compose, ResampleResize, Normalize, ToTensor, CropROI

# Preprocessing pipeline; the steps are listed in the order passed to Compose.
transform = Compose([
    CropROI(margin=(12, 12, 4)),
    ResampleResize(target_spacing=(0.5, 0.5, 3.0), target_shape=(96, 96, 16)),
    Normalize(),
    ToTensor()
])

# Build the transformed dataset once and share it between both subsets,
# instead of constructing two identical MRIDataset instances.
transformed_dataset = MRIDataset(
    root_dir="../data/lesions",
    labels_path="../data/lesions/PROSTATEx_Classes.csv",
    transform=transform
)

# train_idxs / test_idxs were computed against `dataset`, so the transformed
# dataset must have the same number of samples for them to remain valid.
assert len(transformed_dataset) == len(dataset), "index lists no longer match the dataset"

train_set = Subset(transformed_dataset, train_idxs)
test_set = Subset(transformed_dataset, test_idxs)

In [12]:
# Rebuild the loaders on top of the transformed subsets:
# the training loader shuffles, the test loader keeps a fixed order.
train_loader, test_loader = (
    DataLoader(split, batch_size=8, shuffle=do_shuffle)
    for split, do_shuffle in ((train_set, True), (test_set, False))
)

## Train model

In [13]:
# 3D U-Net with a single input channel and a single output channel
unet_kwargs = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = UNet(**unet_kwargs).to(device)

model

UNet(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (unit1): Convolution(
          (conv): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): ResidualUnit(
          (conv): Se

In [14]:
# Combined BCE-with-logits + Dice loss; both terms are fed the raw logits
dice = DiceLoss(sigmoid=True)
bce = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]).to(device))

def loss_fn(logits, targets):
    """Return the weighted sum 0.7 * BCE + 0.3 * Dice for one batch.

    Args:
        logits: Raw model outputs, passed unchanged to both loss terms.
        targets: Ground-truth masks, passed unchanged to both loss terms.

    Returns:
        The combined loss value.
    """
    bce_term = bce(logits, targets)
    dice_term = dice(logits, targets)
    return 0.7 * bce_term + 0.3 * dice_term

# Dice score as the evaluation metric
dice_metric = DiceMetric(
    ignore_empty=True,
    reduction='mean',
    include_background=True,
)

# Adam optimizer over all model parameters
optim = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [15]:
# TensorBoard logging: the writer's log directory is runs/segmentation
tensorboard_dir = "runs/segmentation"
writer = SummaryWriter(log_dir=tensorboard_dir)

In [16]:
from scripts.train_model import train_seg_model

# NOTE(review): the held-out test loader is also passed as the validation
# loader. If train_seg_model uses the validation Dice to pick the checkpoint
# saved at `path` (presumably — confirm in scripts/train_model.py), the model
# is selected on the test set; consider a separate validation split.
trained_model = train_seg_model(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    criterion=loss_fn,
    optimizer=optim,
    dice_metric=dice_metric,
    device=device,
    epochs=10,
    writer=writer,
    task_name="segmentation",
    path="models/best_segmentation.pt"
)

# SummaryWriter buffers events and only flushes them periodically; flush
# explicitly so the TensorBoard curves are complete once training returns.
writer.flush()

Epoch 1/10 | Train loss: 0.9026 | Val loss: 0.8745 | Val Dice: 0.0012
Epoch 2/10 | Train loss: 0.8581 | Val loss: 0.8426 | Val Dice: 0.0013
Epoch 3/10 | Train loss: 0.8345 | Val loss: 0.8264 | Val Dice: 0.0013
Epoch 4/10 | Train loss: 0.8222 | Val loss: 0.8175 | Val Dice: 0.0013
Epoch 5/10 | Train loss: 0.8151 | Val loss: 0.8121 | Val Dice: 0.0012
Epoch 6/10 | Train loss: 0.8105 | Val loss: 0.8082 | Val Dice: 0.0012
Epoch 7/10 | Train loss: 0.8070 | Val loss: 0.8051 | Val Dice: 0.0013
Epoch 8/10 | Train loss: 0.8042 | Val loss: 0.8025 | Val Dice: 0.0012
Epoch 9/10 | Train loss: 0.8017 | Val loss: 0.8001 | Val Dice: 0.0012
Epoch 10/10 | Train loss: 0.7995 | Val loss: 0.7981 | Val Dice: 0.0012
Best val Dice: 0.0013


## Evaluation

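From `tensorboard`, we observe the following training curves. Both losses decrease steadily over the 10 epochs (train 0.9026 → 0.7995, val 0.8745 → 0.7981), but the validation Dice stays flat at roughly 0.001 (best: 0.0013), so the model is not yet producing masks that meaningfully overlap the ground truth.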

**Train loss**

<img src="../assets/segmentation/segmentation_Loss_train.svg" width="600" height="200">

**Val loss**

<img src="../assets/segmentation/segmentation_Loss_val.svg" width="600" height="200">

**Val Dice**

<img src="../assets/segmentation/segmentation_Dice_val.svg" width="600" height="200">