# Segmentation

Train a model to segment tumor(s) in an image, using the provided masks as ground-truth labels.

Architecture: U-Net

## Imports and setup

In [1]:
import torch
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import train_test_split
import nibabel as nib

import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import random
import sys
import os

In [2]:
# Add project root to sys path to allow for package-like imports.
# The root is the parent of the current working directory, so this assumes
# the kernel was started from the notebook's own folder.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Guard the append so that re-running this cell does not pile up duplicates.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
# Seed the random number generators used in this notebook

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Request deterministic cuDNN behaviour and switch off its benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed()

In [4]:
# Set device - prefer CUDA, then MPS (MacOS), and fall back to CPU.
# On a Mac without CUDA this picks the same device as an MPS-only check.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif (
    # Older torch builds have no `torch.backends.mps`; treat that as "no MPS".
    getattr(torch.backends, "mps", None) is not None
    and torch.backends.mps.is_available()
):
    device = torch.device("mps")
else:
    device = torch.device("cpu")

## Load data

In [5]:
# Build the dataset from the lesion folder and its labels CSV
from scripts.load_data import MRIDataset

LESIONS_DIR = "../data/lesions"
LABELS_CSV = "../data/lesions/PROSTATEx_Classes.csv"

dataset = MRIDataset(root_dir=LESIONS_DIR, labels_path=LABELS_CSV)
print(len(dataset))

200


In [6]:
# Map each patient to the dataset indices of their findings, so the
# train/test split can be made per patient rather than per sample.
patient_idxs = defaultdict(list)

for sample_idx, record in enumerate(dataset.samples):
    # The patient id is everything before the first "_Finding" in the finding id.
    patient_key, _, _ = record["finding_id"].partition("_Finding")
    patient_idxs[patient_key].append(sample_idx)

patient_ids = list(patient_idxs)
print(f"Total patients: {len(patient_ids)}")

Total patients: 199


In [7]:
# Split on patient ids, then expand each side back into sample indices.
train_patients, test_patients = train_test_split(
    patient_ids, test_size=0.2, random_state=42
)

# Flatten the per-patient index lists, keeping the patient order of the split.
train_idxs = [i for patient in train_patients for i in patient_idxs[patient]]
test_idxs = [i for patient in test_patients for i in patient_idxs[patient]]

print(f"Train samples: {len(train_idxs)}")
print(f"Test samples: {len(test_idxs)}")

Train samples: 160
Test samples: 40


In [8]:
# Wrap the index lists as Subsets and expose them through DataLoaders.
# Only the training loader shuffles; the test loader keeps a fixed order.
batch_size = 8

train_set = Subset(dataset, train_idxs)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

test_set = Subset(dataset, test_idxs)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

## Preprocess data

## Train model

## Evaluation