In [27]:
import os
import h5py
import numpy as np
import random
import SimpleITK as sitk
from skimage.transform import resize
from pathlib import Path

In [28]:
# Configuration
# Input paths
# Raw strings are required here: in a normal literal "\p", "\m", "\i" and "\s" are
# invalid escape sequences (Python warns about them), and a folder whose name starts
# with e.g. "t" or "n" would silently turn into a tab/newline character.
# NOTE(review): these are machine-specific absolute paths; consider deriving them
# from a configurable data directory.
image_root = Path(r"C:\project\mama-mia\images")
mask_root = Path(r"C:\project\mama-mia\segmentations/expert")
# Output HDF5 file
output_h5 = Path("../data/processed/mama-mia_selected_slices.h5")
output_h5.parent.mkdir(parents=True, exist_ok=True)

# Selection parameters
ratio = (1, 1)  # positive:negative slice ratio
min_tumor_pixels = 20  # at least this many pixels in mask -> positive slice
min_image_std = 0.05  # minimum std-dev for negative
max_slice_distance = 10  # max distance from tumor slices for negatives
fallback_ratio = 0.3  # for patients without any tumor slices
seed = 42
random.seed(seed)

# Volume preprocessing
DESIRED_PHASES = 6  # number of phases/channels to keep (pad or truncate)
new_spacing = (1.0, 1.0, 1.0)
linear_interp = sitk.sitkLinear  # used for intensity volumes
nearest_interp = sitk.sitkNearestNeighbor  # used for label masks

# Output slice shape
target_shape = (128, 128)  # (height, width)

In [29]:
# Helper Functions
def resample_image(itk_image, spacing, interpolator):
    """
    Resample a SimpleITK image onto a grid with the requested voxel spacing.

    The output keeps the origin and direction of the input; only the spacing and,
    consequently, the number of voxels per axis change. The result is isotropic
    only if all entries of `spacing` are equal.

    Parameters
    ----------
    itk_image : SimpleITK image to resample.
    spacing : sequence of float
        Target spacing, one entry per image axis.
    interpolator : SimpleITK interpolator constant
        e.g. `sitk.sitkLinear` for intensities, `sitk.sitkNearestNeighbor` for labels.

    Returns
    -------
    The resampled SimpleITK image.
    """
    orig_size = itk_image.GetSize()
    orig_spacing = itk_image.GetSpacing()
    # Keep the physical extent constant: voxels * spacing stays (approximately) the same.
    # Iterating over the axes actually present avoids hard-coding three dimensions, and
    # the lower bound of 1 prevents an empty axis when a thin dimension rounds to zero.
    new_size = [
        max(1, int(round(n_voxels * (old_sp / new_sp))))
        for n_voxels, old_sp, new_sp in zip(orig_size, orig_spacing, spacing)
    ]
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(spacing)
    resampler.SetSize(new_size)
    resampler.SetOutputOrigin(itk_image.GetOrigin())
    resampler.SetOutputDirection(itk_image.GetDirection())
    resampler.SetInterpolator(interpolator)
    return resampler.Execute(itk_image)

In [30]:
# Gather Patients
# rglob yields entries in filesystem order, which is not guaranteed to be stable
# across machines. Sorting makes the patient order (and therefore the row order of
# the HDF5 output, and which file wins on a duplicate mask id) reproducible.
image_files = sorted(image_root.rglob("*.nii*"))
mask_files = sorted(mask_root.rglob("*.nii*"))

# Images are grouped by the (lower-cased) name of their parent directory.
images_by_patient = {}
for img in image_files:
    pid = img.parent.name.lower()
    images_by_patient.setdefault(pid, []).append(img)

# Masks are keyed by the (lower-cased) file name up to the first dot.
# If two mask files map to the same id, the later one in sorted order is kept.
masks_by_patient = {}
for m in mask_files:
    pid = m.name.split('.')[0].lower()
    masks_by_patient[pid] = m

# Only patients that have both at least one image and a mask are processed.
patients = sorted(pid for pid in images_by_patient if pid in masks_by_patient)
print(f"Found {len(patients)} patients with both images and masks.")

Found 1506 patients with both images and masks.


In [31]:
# Determine Slice Dimensions
if target_shape:
    # A fixed output shape is configured, so no volume has to be read or resampled.
    h, w = target_shape
else:
    # Without a fixed shape, use first patient's first phase to get dimensions.
    if not patients:
        raise RuntimeError(
            "No patients with both images and masks were found; "
            "cannot infer slice dimensions."
        )
    test_img = sitk.ReadImage(str(images_by_patient[patients[0]][0]))
    res_test = resample_image(test_img, new_spacing, linear_interp)
    # The array's trailing two axes are taken as (height, width).
    h, w = sitk.GetArrayFromImage(res_test).shape[1:]

In [32]:
# Create Extensible HDF5 Datasets
# Both datasets start with zero rows and are unbounded along axis 0, so slices can
# be appended patient by patient. Each chunk holds exactly one slice.
# Opening with "w" truncates any existing file at `output_h5`.
h5f = h5py.File(output_h5, "w")

img_ds = h5f.create_dataset(
    "images",
    dtype="float32",
    shape=(0, DESIRED_PHASES, h, w),
    maxshape=(None, DESIRED_PHASES, h, w),
    chunks=(1, DESIRED_PHASES, h, w),
)

mask_ds = h5f.create_dataset(
    "masks",
    dtype="uint8",
    shape=(0, 1, h, w),
    maxshape=(None, 1, h, w),
    chunks=(1, 1, h, w),
)

In [34]:
# Append Utility
def plural(n, s):
    """Return "<n> <s>" with an "s" appended to the noun unless n equals 1."""
    return f"{n} {s}" + ("s" if n != 1 else "")
def append_slices(images, masks):
    """
    Append a batch of images and masks to the module-level HDF5 datasets.

    Parameters
    ----------
    images : array whose first axis is the number of slices; written to `img_ds`.
    masks : array with the same first-axis length as `images`; written to `mask_ds`.

    Raises
    ------
    ValueError
        If `images` and `masks` do not contain the same number of slices.
    """
    n_new = images.shape[0]
    # Validate before touching the file: otherwise both datasets would already be
    # resized (and the images written) by the time the mask write fails.
    if masks.shape[0] != n_new:
        raise ValueError(
            f"images and masks must have the same number of slices "
            f"(got {n_new} and {masks.shape[0]})"
        )
    old_n = img_ds.shape[0]
    new_total = old_n + n_new
    img_ds.resize(new_total, axis=0)
    mask_ds.resize(new_total, axis=0)
    img_ds[old_n:new_total, ...] = images
    mask_ds[old_n:new_total, ...] = masks
    print(f"  -> Dataset resized: now {new_total} total slices.")

In [None]:
def select_negative_slices(volumes, mask_arr, tumor_slices, ratio, min_image_std, Z, fallback_ratio, max_slice_distance):
    """
    Select negative (tumor-free) slice indices for training.

    With tumor slices present:
    - candidates are slices with an empty mask and a standard deviation of at
      least `min_image_std` in the first volume;
    - candidates within `max_slice_distance` slices (inclusive) of a tumor slice
      are preferred, closest first, ties broken by higher contrast;
    - if those are not enough, remaining candidates are added by descending contrast.

    Without tumor slices (fallback): the highest-contrast slices with an empty mask
    are returned; `min_image_std` is not applied in this branch.

    Parameters
    ----------
    volumes : sequence of arrays indexed by slice; only `volumes[0]` is used.
    mask_arr : array indexed by slice; a slice counts as empty when its sum is 0.
    tumor_slices : list of int
        Indices of positive slices.
    ratio : (positive, negative)
        Desired positive:negative ratio; negatives = ratio[1] * len(tumor_slices) // ratio[0].
    min_image_std : float
        Minimum per-slice standard deviation for a negative candidate.
    Z : int
        Number of slices to scan.
    fallback_ratio : float
        Fraction of Z to select in the fallback branch (at least one slice).
    max_slice_distance : int
        Largest slice distance to a tumor slice that still counts as "near".

    Returns
    -------
    set of int
        Selected negative slice indices.
    """
    reference = volumes[0]

    if not tumor_slices:
        # Fallback when there are no tumor slices (e.g., healthy subjects):
        # pick a fixed number of highest-contrast slices without any mask.
        fb_count = max(1, int(Z * fallback_ratio))
        background = [
            (z, reference[z].std())
            for z in range(Z)
            if mask_arr[z].sum() == 0
        ]
        background.sort(key=lambda item: -item[1])
        return {z for z, _ in background[:fb_count]}

    tumor_set = set(tumor_slices)  # O(1) membership tests
    near = []  # (z, contrast, distance) for slices within max_slice_distance
    far = []   # (z, contrast) for slices beyond max_slice_distance
    for z in range(Z):
        # Skip slices that contain tumor or have any mask pixels at all
        if z in tumor_set or mask_arr[z].sum() > 0:
            continue
        contrast = reference[z].std()
        # Skip slices that are too uniform (low contrast)
        if contrast < min_image_std:
            continue
        dist = min(abs(z - t) for t in tumor_set)
        # The boundary is inclusive: a slice exactly max_slice_distance away is
        # still "near". (Deriving this from `max_slice_distance - dist > 0` would
        # wrongly push boundary slices into the out-of-range group.)
        if dist <= max_slice_distance:
            near.append((z, contrast, dist))
        else:
            far.append((z, contrast))

    # Nearest first, then highest contrast; sorts are stable so ties keep slice order.
    near.sort(key=lambda item: (item[2], -item[1]))
    far.sort(key=lambda item: -item[1])

    # Number of negatives in proportion to the number of tumor slices
    desired_neg = ratio[1] * len(tumor_slices) // ratio[0]
    selected = [z for z, *_ in near[:desired_neg]]

    # If not enough nearby slices exist, fill up with out-of-range slices
    if len(selected) < desired_neg:
        selected += [z for z, _ in far[:desired_neg - len(selected)]]

    return set(selected)

In [35]:
# Process and Write Slices
# For every patient: load and z-score each phase volume, load the mask, pick
# positive and negative slices, resize them to (h, w) and append them to the HDF5
# datasets. The try/finally guarantees the HDF5 file is closed even if a patient
# fails, so the slices written so far are flushed and the file handle is released.
total_written = 0
print("Processing patients and writing slices...")
try:
    for idx_pid, pid in enumerate(patients, start=1):
        print(f"Processing patient {idx_pid}/{len(patients)}: '{pid}'")
        # Load and preprocess volumes (one per phase, in sorted file order)
        volumes = []
        for img_path in sorted(images_by_patient[pid]):
            img = sitk.ReadImage(str(img_path))
            res_img = resample_image(img, new_spacing, linear_interp)
            arr = sitk.GetArrayFromImage(res_img).astype(np.float32)
            # Per-volume z-score; the epsilon guards against constant volumes
            arr = (arr - arr.mean()) / (arr.std() + 1e-8)
            volumes.append(arr)

        # Load and preprocess mask (nearest neighbour keeps label values intact)
        m = sitk.ReadImage(str(masks_by_patient[pid]))
        res_m = resample_image(m, new_spacing, nearest_interp)
        mask_arr = sitk.GetArrayFromImage(res_m).astype(np.uint8)

        # Image phases and mask are resampled independently; if their grids differ,
        # slice z of the mask does not correspond to slice z of the image. Skip such
        # patients explicitly instead of writing misaligned pairs or failing later.
        if any(vol.shape != mask_arr.shape for vol in volumes):
            print(
                f"  WARNING: volume/mask shape mismatch for patient '{pid}' "
                f"(volumes: {[vol.shape for vol in volumes]}, mask: {mask_arr.shape}); skipping.\n"
            )
            continue

        Z = volumes[0].shape[0]
        # Identify positive slices
        tumor_slices = [z for z in range(Z) if mask_arr[z].sum() >= min_tumor_pixels]

        # Identify negative slices via helper function
        neg_slices = select_negative_slices(
            volumes, mask_arr, tumor_slices,
            ratio, min_image_std, Z, fallback_ratio,
            max_slice_distance
        )

        print(f"  Found {len(tumor_slices)} positive and {len(neg_slices)} negative slices.")

        # Build batch for this patient: positives first, then negatives in slice order
        selected_indices = tumor_slices + sorted(neg_slices)
        n_sel = len(selected_indices)
        batch_imgs = np.zeros((n_sel, DESIRED_PHASES, h, w), dtype=np.float32)
        batch_masks = np.zeros((n_sel, 1, h, w), dtype=np.uint8)

        for i, z in enumerate(selected_indices, start=1):
            # Stack phases, then zero-pad or truncate to exactly DESIRED_PHASES channels
            stack = np.stack([vol[z] for vol in volumes], axis=0)
            if stack.shape[0] < DESIRED_PHASES:
                pad_w = ((0, DESIRED_PHASES - stack.shape[0]), (0, 0), (0, 0))
                stack = np.pad(stack, pad_w, mode='constant')
            else:
                stack = stack[:DESIRED_PHASES]
            # Resize each channel with linear interpolation
            if target_shape:
                resized = np.zeros((DESIRED_PHASES, h, w), dtype=np.float32)
                for c in range(DESIRED_PHASES):
                    resized[c] = resize(stack[c], (h, w), order=1, preserve_range=True)
                stack = resized

            mask_slice = mask_arr[z]
            if target_shape:
                # Nearest-neighbour without anti-aliasing: smoothing before
                # down-sampling would blur the label values. Passing it explicitly
                # avoids depending on the library's version-specific default.
                mask_slice = resize(
                    mask_slice, (h, w), order=0, preserve_range=True, anti_aliasing=False
                ).astype(np.uint8)

            batch_imgs[i-1] = stack
            batch_masks[i-1, 0] = mask_slice
            print(f"    Processing slice {i}/{n_sel} (index={z})", end='\r')

        print()  # newline after slice loop

        # Append this patient's batch
        append_slices(batch_imgs, batch_masks)
        total_written += n_sel
        print(f"  Appended {n_sel} slices for patient '{pid}' (total written: {total_written}).\n")
finally:
    h5f.close()
print("HDF5 container creation complete.")

Processing patients and writing slices...
Processing patient 1/1506: 'duke_001'
  Found 20 positives and 20 negatives slices.
    Processing slice 40/40 (index=84)
  -> Dataset resized: now 40 total slices.
  Appended 40 slices for patient 'duke_001' (total written: 40).

Processing patient 2/1506: 'duke_002'
  Found 14 positives and 14 negatives slices.
    Processing slice 28/28 (index=113)
  -> Dataset resized: now 68 total slices.
  Appended 28 slices for patient 'duke_002' (total written: 68).

Processing patient 3/1506: 'duke_005'
  Found 47 positives and 47 negatives slices.
    Processing slice 94/94 (index=136)
  -> Dataset resized: now 162 total slices.
  Appended 94 slices for patient 'duke_005' (total written: 162).

Processing patient 4/1506: 'duke_009'
  Found 25 positives and 25 negatives slices.
    Processing slice 50/50 (index=115)
  -> Dataset resized: now 212 total slices.
  Appended 50 slices for patient 'duke_009' (total written: 212).

Processing patient 5/15