In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import cv2
import pydicom
from skimage import exposure
from skimage.restoration import denoise_nl_means, estimate_sigma
from scipy.spatial.distance import directed_hausdorff
import logging
import re

In [2]:
# Configuration
BASE_SEGMENTATION_DIR = "datasets\segmentation\PROSTATEx"      # Directory containing segmentation masks
BASE_IMAGES_DIR = "datasets\segmentation2\PROSTATEx"          # Directory containing image slices

SEGMENTATION_FILENAME = "1-1.dcm"                                     # Segmentation mask filename per patient
IMAGE_FILENAME_PATTERN = r".*\.dcm$"                                  # Pattern to match image slices

IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH = 256, 256, 19                       # Adjust IMG_DEPTH based on number of slices per patient
BATCH_SIZE = 2
EPOCHS = 50

In [7]:
def load_dicom_image(file_path):
    """
    Load a DICOM file and return the image array.
    """
    try:
        dicom = pydicom.dcmread(file_path)
        image = dicom.pixel_array.astype(np.float32)
        # Rescale if necessary based on DICOM metadata
        if 'RescaleSlope' in dicom and 'RescaleIntercept' in dicom:
            slope = dicom.RescaleSlope
            intercept = dicom.RescaleIntercept
            image = image * slope + intercept
        return image
    except Exception as e:
        print(f"Error loading DICOM file {file_path}: {e}")
        return None

def denoise_image(image):
    """
    Apply Non-Local Means denoising to the image.
    
    Parameters:
    - image: 2D numpy array
    
    Returns:
    - denoised image
    """
    # Estimate the noise standard deviation from the noisy image
    sigma_est = np.mean(estimate_sigma(image))
    denoised = denoise_nl_means(image, h=1.15 * sigma_est, fast_mode=True,
                                patch_size=5, patch_distance=3)
    return denoised

def normalize_image(image, method='z-score'):
    """
    Normalize the image using the specified method.
    
    Parameters:
    - image: 2D numpy array
    - method: 'z-score' or 'minmax'
    
    Returns:
    - normalized image
    """
    if method == 'z-score':
        mean = np.mean(image)
        std = np.std(image)
        normalized = (image - mean) / (std + 1e-8)
    elif method == 'minmax':
        min_val = np.min(image)
        max_val = np.max(image)
        normalized = (image - min_val) / (max_val - min_val + 1e-8)
    else:
        raise ValueError("Unknown normalization method")
    return normalized

In [4]:
def random_flip_3d(volume, mask):
    """
    Randomly flip the volume and mask along the horizontal and/or vertical axes.
    """
    # Flip along the horizontal axis
    if np.random.rand() < 0.5:
        volume = np.flip(volume, axis=2)  # Assuming axis=2 is the width
        mask = np.flip(mask, axis=2)
    
    # Flip along the vertical axis
    if np.random.rand() < 0.5:
        volume = np.flip(volume, axis=1)  # Assuming axis=1 is the height
        mask = np.flip(mask, axis=1)
    
    return volume, mask

def random_rotate_3d(volume, mask):
    """
    Randomly rotate the volume and mask by 90 degrees around the depth axis.
    """
    k = np.random.choice([0, 1, 2, 3])
    volume = np.rot90(volume, k, axes=(1, 2))
    mask = np.rot90(mask, k, axes=(1, 2))
    return volume, mask

def random_scale_3d(volume, mask, scale_range=(0.9, 1.1)):
    """
    Randomly scale the volume and mask.
    Note: Scaling in 3D is more complex; for simplicity, we'll scale each slice individually.
    """
    scaled_volume = []
    scaled_mask = []
    for slice_img, slice_mask in zip(volume, mask):
        scale = np.random.uniform(*scale_range)
        height, width = slice_img.shape
        new_h, new_w = int(height * scale), int(width * scale)
        slice_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        slice_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        
        # Crop or pad to maintain original size
        if scale < 1.0:
            pad_h = (height - new_h) // 2
            pad_w = (width - new_w) // 2
            slice_img = np.pad(slice_img, ((pad_h, height - new_h - pad_h), 
                                           (pad_w, width - new_w - pad_w)), mode='constant')
            slice_mask = np.pad(slice_mask, ((pad_h, height - new_h - pad_h), 
                                           (pad_w, width - new_w - pad_w)), mode='constant')
        else:
            start_h = (new_h - height) // 2
            start_w = (new_w - width) // 2
            slice_img = slice_img[start_h:start_h + height, start_w:start_w + width]
            slice_mask = slice_mask[start_h:start_h + height, start_w:start_w + width]
        
        scaled_volume.append(slice_img)
        scaled_mask.append(slice_mask)
    
    scaled_volume = np.array(scaled_volume)
    scaled_mask = np.array(scaled_mask)
    
    return scaled_volume, scaled_mask

def augment_data_3d(volume, mask):
    """
    Apply a series of random 3D augmentations to the volume and mask.
    """
    volume, mask = random_flip_3d(volume, mask)
    volume, mask = random_rotate_3d(volume, mask)
    volume, mask = random_scale_3d(volume, mask)
    return volume, mask

In [5]:
def load_dataset(segmentation_dir, images_dir, segmentation_filename, img_height, img_width, img_depth):
    """
    Load images and masks from separate directories, preprocess them, and return as numpy arrays.
    
    Parameters:
    - segmentation_dir: Directory containing segmentation masks
    - images_dir: Directory containing image slices
    - segmentation_filename: Filename of the segmentation mask per patient
    - img_height: Desired image height after resizing
    - img_width: Desired image width after resizing
    - img_depth: Number of slices per patient
    
    Returns:
    - images: Numpy array of preprocessed image volumes
    - masks: Numpy array of corresponding segmentation masks
    """
    images = []
    masks = []
    
    # List of patients based on segmentation directory
    patients = os.listdir(segmentation_dir)
    
    for patient in patients:
        patient_seg_dir = os.path.join(segmentation_dir, patient)
        #print(patient_seg_dir)
        if not os.path.isdir(patient_seg_dir):
            print(f"Skipping {patient}")
            continue  # Skip if not a directory
        
        # Path to segmentation mask
        mask_path = os.path.join(patient_seg_dir, segmentation_filename)
        # Sanitize the path cause the file name has a number in it
        mask_path = mask_path.replace("\\", "/")
        #print(mask_path)
        if not os.path.exists(mask_path):
            #print(f"Segmentation mask not found for patient: {patient}")
            continue
        else:
            print(f"Segmentation mask found for patient: {patient}")

        # Load segmentation mask
        mask_volume = load_dicom_image(mask_path)
        if mask_volume is None:
            #print(f"Failed to load mask for patient: {patient}")
            continue
        else:
            print(f"Mask loaded for patient: {patient}")
        
        # Assuming mask_volume is a 3D array with shape (depth, height, width)
        # If it's stored as multi-frame DICOM, ensure it's loaded correctly
        # Here, we assume it's loaded as a 3D numpy array
        if len(mask_volume.shape) == 2:
            # Single slice mask; expand to 3D
            mask_volume = np.expand_dims(mask_volume, axis=0)
        
        # Load image slices
        patient_img_dir = os.path.join(images_dir, patient)

        # Sanitize the path cause the file name has a number in it
        patient_img_dir = patient_img_dir.replace("\\", "/")
        if not os.path.exists(patient_img_dir):
            #print(f"Image directory not found for patient: {patient}")
            continue
        else:
            print(f"Image directory found for patient: {patient}")
        
        # List all DICOM files in image directory
        img_files = sorted([f for f in os.listdir(patient_img_dir) if re.match(IMAGE_FILENAME_PATTERN, f)])

        # Sanitize the path cause the file name has a number in it
        img_files = [f.replace("\\", "/") for f in img_files]

        if len(img_files) == 0:
            #print(f"No image slices found for patient: {patient}")
            continue
        else:
            print(f"Image slices found for patient: {patient}")
        
        # Load each slice
        img_slices = []
        for img_file in img_files:
            img_path = os.path.join(patient_img_dir, img_file)
            img_path = img_path.replace("\\", "/")
            img_slice = load_dicom_image(img_path)
            if img_slice is None:
                print(f"Failed to load image slice: {img_path}")
                break
            img_slices.append(img_slice)
        
        if len(img_slices) != img_depth:
            print(f"Unexpected number of slices for patient: {patient}. Expected {img_depth}, got {len(img_slices)}.")
            continue
        
        # Stack slices to form a 3D volume
        img_volume = np.stack(img_slices, axis=0)  # Shape: (depth, height, width)
        
        # Preprocessing
        img_volume = denoise_image(img_volume)
        img_volume = normalize_image(img_volume, method='z-score')
        
        # Resize each slice
        img_volume_resized = []
        mask_volume_resized = []
        for slice_img, slice_mask in zip(img_volume, mask_volume):
            slice_img = cv2.resize(slice_img, (img_width, img_height), interpolation=cv2.INTER_LINEAR)
            slice_mask = cv2.resize(slice_mask, (img_width, img_height), interpolation=cv2.INTER_NEAREST)
            img_volume_resized.append(slice_img)
            mask_volume_resized.append(slice_mask)
        
        img_volume_resized = np.array(img_volume_resized)
        mask_volume_resized = np.array(mask_volume_resized)
        
        # Ensure masks are binary
        mask_volume_resized = (mask_volume_resized > 0.5).astype(np.uint8)
        
        images.append(img_volume_resized)
        masks.append(mask_volume_resized)

        print(images)
        print(masks)
    
    images = np.array(images)
    masks = np.array(masks)
    
    return images, masks

In [8]:
# Load the dataset
images, masks = load_dataset(
    segmentation_dir=BASE_SEGMENTATION_DIR,
    images_dir=BASE_IMAGES_DIR,
    segmentation_filename=SEGMENTATION_FILENAME,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    img_depth=IMG_DEPTH
)

print(f"Loaded {images.shape[0]} patients.")
print(f"Image shape: {images.shape}")  # Expected: (num_patients, depth, height, width)
print(f"Mask shape: {masks.shape}")    # Expected: (num_patients, depth, height, width)

Skipping LICENSE
Segmentation mask found for patient: ProstateX-0004
Mask loaded for patient: ProstateX-0004
Image directory found for patient: ProstateX-0004
Image slices found for patient: ProstateX-0004
[array([[[-1.405619  , -1.4061869 , -1.4071844 , ..., -0.69648004,
         -0.7282918 , -1.043816  ],
        [-1.3961741 , -1.3966874 , -1.4084216 , ...,  0.7658621 ,
          0.81251776,  0.149917  ],
        [-1.4035134 , -0.7072568 ,  0.15798044, ..., -0.606623  ,
         -0.482183  , -0.4548376 ],
        ...,
        [-1.4083972 , -1.4083357 , -1.4082916 , ..., -1.4035094 ,
         -1.4060359 , -1.4072194 ],
        [-1.4084176 , -1.408393  , -1.4083695 , ..., -1.4063904 ,
         -1.4071817 , -1.4075536 ],
        [-1.4084206 , -1.4084145 , -1.4083848 , ..., -1.4075558 ,
         -1.4079317 , -1.4078405 ]],

       [[-1.4059147 , -1.4063776 , -1.4073238 , ..., -0.60558796,
         -0.69428533, -0.97405005],
        [-1.4078151 , -1.4019612 , -1.4084216 , ...,  1.7776711 

In [None]:
X_train, X_val, y_train, y_val = train_test_split(
    images, masks, test_size=0.2, random_state=42
)

print(f"Training set: {X_train.shape}, {y_train.shape}")
print(f"Validation set: {X_val.shape}, {y_val.shape}")

In [None]:
def build_unet(input_shape=(256, 256, 1)):
    inputs = keras.Input(shape=input_shape)
    
    # Encoder
    c1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    c1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D((2, 2))(c1)
    
    c2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
    c2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D((2, 2))(c2)
    
    c3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(p2)
    c3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c3)
    p3 = layers.MaxPooling2D((2, 2))(c3)
    
    c4 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(p3)
    c4 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(c4)
    p4 = layers.MaxPooling2D(pool_size=(2, 2))(c4)
    
    # Bottleneck
    c5 = layers.Conv2D(1024, (3, 3), activation='relu', padding='same')(p4)
    c5 = layers.Conv2D(1024, (3, 3), activation='relu', padding='same')(c5)
    
    # Decoder
    u6 = layers.Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = layers.concatenate([u6, c4])
    c6 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(u6)
    c6 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(c6)
    
    u7 = layers.Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = layers.concatenate([u7, c3])
    c7 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(u7)
    c7 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c7)
    
    u8 = layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = layers.concatenate([u8, c2])
    c8 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(u8)
    c8 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c8)
    
    u9 = layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = layers.concatenate([u9, c1])
    c9 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(u9)
    c9 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c9)
    
    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)
    
    model = keras.Model(inputs=[inputs], outputs=[outputs])
    return model

# Build the model
model = build_unet()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

In [None]:
# Calculate steps per epoch
steps_per_epoch = len(X_train) // BATCH_SIZE
validation_steps = len(X_val) // BATCH_SIZE

# Define callbacks
callbacks = [
    keras.callbacks.ModelCheckpoint("prostate_segmentation_best.keras", save_best_only=True, monitor='val_loss'),
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
]

# Train the model
history = model.fit(
    train_gen,
    steps_per_epoch=steps_per_epoch,
    epochs=EPOCHS,
    validation_data=val_gen,
    validation_steps=validation_steps,
    callbacks=callbacks
)

In [None]:
def dice_coefficient(y_true, y_pred):
    """
    Compute Dice Similarity Coefficient.
    
    Parameters:
    - y_true: Ground truth mask
    - y_pred: Predicted mask
    
    Returns:
    - Dice coefficient
    """
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2. * intersection) / (np.sum(y_true_f) + np.sum(y_pred_f) + 1e-8)

def hausdorff_distance(y_true, y_pred):
    """
    Compute Hausdorff Distance between two binary masks.
    
    Parameters:
    - y_true: Ground truth mask
    - y_pred: Predicted mask
    
    Returns:
    - Hausdorff distance
    """
    y_true_pts = np.argwhere(y_true)
    y_pred_pts = np.argwhere(y_pred)
    
    if len(y_true_pts) == 0 or len(y_pred_pts) == 0:
        return np.inf
    
    forward_hd = directed_hausdorff(y_true_pts, y_pred_pts)[0]
    backward_hd = directed_hausdorff(y_pred_pts, y_true_pts)[0]
    
    return max(forward_hd, backward_hd)

In [None]:
# Load the best model
model.load_weights("prostate_segmentation_best.h5")

# Predict on the validation set
val_predictions = model.predict(X_val[:])

# Binarize predictions
val_predictions_bin = (val_predictions > 0.5).astype(np.uint8)

# Compute metrics
dice_scores = []
hausdorff_scores = []

for i in range(len(y_val)):
    dice = dice_coefficient(y_val[i], val_predictions_bin[i, :, :, 0])
    hd = hausdorff_distance(y_val[i], val_predictions_bin[i, :, :, 0])
    dice_scores.append(dice)
    hausdorff_scores.append(hd)

print(f"Mean Dice Coefficient on Validation Set: {np.mean(dice_scores):.4f}")
print(f"Mean Hausdorff Distance on Validation Set: {np.mean(hausdorff_scores):.4f} mm")

In [None]:
# Save the model
model.save("prostate_segmentation.keras")
print("Model saved as prostate_segmentation.keras")

In [None]:
# Plot training & validation loss values
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(loc='upper right')

# Plot training & validation accuracy values
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(loc='lower right')

plt.tight_layout()
plt.show()