# Detection

Predict if an image contains tumor(s) that are either clinically significant or not, where clinical significance refers to having a Gleason score greater than or equal to 7, volume greater than or equal to 0.5cc, and/or extraprostatic extension.

Architecture: MONAI ResNet10

## Imports and setup

In [1]:
import torch
from torch.utils.data import Subset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from monai.networks.nets import resnet
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, confusion_matrix

import numpy as np
from collections import defaultdict
import random
import sys
import os

In [2]:
# Add project root to sys path to allow for package-like imports
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

In [3]:
# Set seeds

def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

In [4]:
# Set device - MacOS
device = torch.device( "mps" if torch.backends.mps.is_available() else "cpu")

## Load data

In [5]:
# Load data
from scripts.load_data import MRIDataset

dataset = MRIDataset(root_dir="../data/lesions", labels_path="../data/lesions/PROSTATEx_Classes.csv")
print(len(dataset))

200


In [6]:
# Patient-level index for train/test split
patient_idxs = defaultdict(list)

for idx, sample in enumerate(dataset.samples):
    finding_id = sample["finding_id"]
    patient_id = finding_id.split("_Finding")[0]
    patient_idxs[patient_id].append(idx)

patient_ids = list(patient_idxs.keys())
print(f"Total patients: {len(patient_ids)}")

Total patients: 199


In [7]:
train_patients, test_patients = train_test_split(
    patient_ids,
    test_size=0.2,
    random_state=42
)

train_idxs = []
test_idxs = []

for pid in train_patients:
    train_idxs.extend(patient_idxs[pid])

for pid in test_patients:
    test_idxs.extend(patient_idxs[pid])

print(f"Train samples: {len(train_idxs)}")
print(f"Test samples: {len(test_idxs)}")

Train samples: 160
Test samples: 40


## Preprocess data

In [8]:
# Helper to iterate over torch Subset

def iter_subset(subset):
    for i in range(len(subset)):
        yield subset[i]

In [9]:
from scripts.preprocess import Compose, ResampleResize, Normalize, ToTensor, CropROI

transform = Compose([
    CropROI(margin=(12, 12, 4)),
    ResampleResize(target_spacing=(0.5, 0.5, 3.0), target_shape=(96, 96, 16)),
    Normalize(),
    ToTensor()
])

train_set = Subset(MRIDataset(
    root_dir="../data/lesions", 
    labels_path="../data/lesions/PROSTATEx_Classes.csv",
    transform=transform
), train_idxs)

test_set = Subset(MRIDataset(
    root_dir="../data/lesions", 
    labels_path="../data/lesions/PROSTATEx_Classes.csv",
    transform=transform
), test_idxs)

In [10]:
# Setup loaders

train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
test_loader = DataLoader(test_set, batch_size=8, shuffle=False)

## Train model

MONAI 3D ResNet
- Input (1, 96, 96, 16)
- Binary classification head (single neuron + sigmoid)

Metrics
- AUC, sensitivity, specificity

Loss: BCEWithLogitsLoss and class weighting

In [11]:
model = resnet.resnet10(
    spatial_dims=3,
    n_input_channels=1,
    num_classes=1,
    conv1_t_stride=1
).to(device)

model

ResNet(
  (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act): ReLU(inplace=True)
  (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=T

In [12]:
# Weighted BCE loss
num_pos = sum([s["cls_label"] for s in iter_subset(train_set)])
num_neg = len(train_set) - num_pos 
pos_weight = torch.tensor(num_neg / num_pos).to(device)

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Adam optimizer
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

  pos_weight = torch.tensor(num_neg / num_pos).to(device)


In [13]:
# Metrics helper
def compute_metrics(y_true, y_pred):
    auc = roc_auc_score(y_true, y_pred)

    # Threshold at 0.5 for binary
    y_bin = (y_pred >= 0.5).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_bin).ravel()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    
    return {
        "AUC": auc,
        "Sensitivity": sensitivity,
        "Specificity": specificity
    }

In [14]:
# TensorBoard logging
writer = SummaryWriter(log_dir="runs/detection")

In [15]:
from scripts.train_model import train_model

trained_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    optimizer=optim,
    criterion=criterion,
    metrics_fn=compute_metrics,
    device=device,
    epochs=10,
    writer=writer,
    task_name="detection",
    path="models/best_detection.pt"
)

Epoch 1/10 | Train loss: 0.9578 | Val loss: 1.3999 | AUC: 0.5269 | Sensitivity: 1.0000 | Specificity: 0.0000
Epoch 2/10 | Train loss: 0.4893 | Val loss: 1.3669 | AUC: 0.5627 | Sensitivity: 0.2353 | Specificity: 0.8261
Epoch 3/10 | Train loss: 0.2056 | Val loss: 1.1189 | AUC: 0.6087 | Sensitivity: 0.4706 | Specificity: 0.7391
Epoch 4/10 | Train loss: 0.1371 | Val loss: 1.8557 | AUC: 0.5754 | Sensitivity: 0.2353 | Specificity: 0.8261
Epoch 5/10 | Train loss: 0.1422 | Val loss: 1.3310 | AUC: 0.5371 | Sensitivity: 0.4706 | Specificity: 0.5217
Epoch 6/10 | Train loss: 0.1574 | Val loss: 1.9262 | AUC: 0.5575 | Sensitivity: 0.3529 | Specificity: 0.7826
Epoch 7/10 | Train loss: 0.1343 | Val loss: 1.3506 | AUC: 0.6087 | Sensitivity: 0.8235 | Specificity: 0.2609
Epoch 8/10 | Train loss: 0.1381 | Val loss: 2.3325 | AUC: 0.6368 | Sensitivity: 0.1765 | Specificity: 1.0000
Epoch 9/10 | Train loss: 0.1553 | Val loss: 2.7709 | AUC: 0.5703 | Sensitivity: 0.1765 | Specificity: 0.8261
Epoch 10/10 | Train

## Evaluation

From `tensorboard`, we observe the following graphs:

**Train loss**

<img src="../assets/detection/detection_Loss_train.svg" width="600" height="200">

**Validation loss**

<img src="../assets/detection/detection_Loss_val.svg" width="600" height="200">

**Validation AUC**

<img src="../assets/detection/detection_AUC_val.svg" width="600" height="200">

**Validation sensitivity**

<img src="../assets/detection/detection_Sensitivity_val.svg" width="600" height="200">

**Validation specificity**

<img src="../assets/detection/detection_Specificity_val.svg" width="600" height="200">