# Predicting

This notebook implements the prediction of new labels and their evaluation. While the default nnUNet procedure involves an automatic search for the best configuration and model followed by ensemble prediction, here a single model was manually selected after evaluating different configurations on a single fold in order to reduce training and inference time. This model was then trained on the entire training dataset and used to predict the test images.
For more details see `nnUNet/documentation/how_to_use_nnunet.md`.

1. Prediction with the selected model
2. Postprocessing
3. Evaluating the performance of the model


In [None]:
from pathlib import Path
import os

current_dir = Path.cwd()

# Default paths for the datasets
os.environ["nnUNet_raw"] = str(current_dir / "nnUNet_raw")
os.environ["nnUNet_preprocessed"] = str(current_dir / "nnUNet_preprocessed")
os.environ["nnUNet_results"] = str(current_dir / "nnUNet_results")

data_name = "Dataset101_FemurCorrected"
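The CLI calls below refer to the dataset by its numeric ID (`-d 101`), while the paths use the folder name stored in `data_name`. nnUNet dataset folders are named `DatasetXXX_Name` with a three-digit ID, so the ID can be derived from the name. The helper `dataset_id_from_name` is a hypothetical sketch for keeping the two consistent. It is not part of nnUNet.

```python
import re

def dataset_id_from_name(dataset_name: str) -> int:
    """Extract the numeric ID from an nnUNet dataset folder name such as 'Dataset101_FemurCorrected'."""
    match = re.fullmatch(r"Dataset(\d{3})_\w+", dataset_name)
    if match is None:
        raise ValueError(f"Unexpected dataset name: {dataset_name!r}")
    return int(match.group(1))

print(dataset_id_from_name("Dataset101_FemurCorrected"))  # → 101
```

The result could, for example, be inserted into the command strings as `-d {dataset_id_from_name(data_name)}` instead of the hard-coded `101`.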

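Optionally, the following cell installs the pretrained 2D and 3D models that performed best in this project for femur CT segmentation. `nnUNetv2_install_pretrained_model_from_zip` takes the path to a zip archive exported with `nnUNetv2_export_model_to_zip`, so `3D` and `2D` must point to the corresponding archives.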

In [None]:
!nnUNetv2_install_pretrained_model_from_zip 3D
!nnUNetv2_install_pretrained_model_from_zip 2D

## 1. Prediction with the selected model

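Predict the labels of all test images in `nnUNet_raw/{data_name}/imagesTs` and store them in `predictions/not_postprocessed`. Here `-d 101` selects the dataset, `-c 3d_fullres` the configuration and `-f all` the model trained on the entire training set.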


In [None]:
command = f"nnUNetv2_predict -i nnUNet_raw/{data_name}/imagesTs -o predictions/not_postprocessed -d 101 -f all -c 3d_fullres"
!{command}
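To check that every test case was predicted, the input and output names can be compared. nnUNet input images carry a four-digit channel suffix (`CASE_0000.nrrd`), whereas predictions are named after the case only (`CASE.nrrd`). The sketch below assumes this naming and the `.nrrd` file ending; `expected_prediction_names` and `missing_predictions` are hypothetical helpers.

```python
import os

def expected_prediction_names(image_names, file_ending=".nrrd"):
    """Map nnUNet input image names (CASE_XXXX<ending>) to prediction names (CASE<ending>)."""
    cases = set()
    for name in image_names:
        if not name.endswith(file_ending):
            continue  # skip unrelated files, e.g. JSON metadata
        stem = name[: -len(file_ending)]
        case_id, sep, channel = stem.rpartition("_")
        # The 4-digit channel suffix (_0000, _0001, ...) is dropped in the prediction name
        if sep and len(channel) == 4 and channel.isdigit():
            cases.add(case_id + file_ending)
    return sorted(cases)

def missing_predictions(images_dir, predictions_dir, file_ending=".nrrd"):
    """Return the expected prediction files that are not present in predictions_dir."""
    expected = expected_prediction_names(os.listdir(images_dir), file_ending)
    present = set(os.listdir(predictions_dir))
    return [name for name in expected if name not in present]

print(expected_prediction_names(["19L_fall_0000.nrrd", "plans.json"]))  # → ['19L_fall.nrrd']
```

`missing_predictions(f"nnUNet_raw/{data_name}/imagesTs", "predictions/not_postprocessed")` should return an empty list if all test cases were predicted.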

## 2. Postprocessing

To apply postprocessing, `nnUNetv2_find_best_configuration` must first be run to generate the postprocessing file. Since nnUNet determines the postprocessing from the cross-validation results, the selected configuration must have been trained on all five folds (0 to 4) beforehand.


In [None]:
!nnUNetv2_find_best_configuration 101 -c 3d_fullres --disable_ensembling

The postprocessing file can now be found at `nnUNet_results/{data_name}/nnUNetTrainer__nnUNetPlans__3d_fullres/crossval_results_folds_0_1_2_3_4/postprocessing.pkl`. The following cell applies it to the predictions in `predictions/not_postprocessed` and stores the postprocessed results in `predictions/postprocessed`.


In [None]:
processing_file_path = Path(f"nnUNet_results/{data_name}/nnUNetTrainer__nnUNetPlans__3d_fullres/crossval_results_folds_0_1_2_3_4/postprocessing.pkl").resolve()

!nnUNetv2_apply_postprocessing -i predictions/not_postprocessed -o predictions/postprocessed -pp_pkl_file {processing_file_path}
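To our understanding, the postprocessing chosen by `nnUNetv2_find_best_configuration` tests on the cross-validation predictions whether removing all but the largest connected component (first for the foreground as a whole, then per label) improves the Dice score, and only keeps the steps that help. The sketch below illustrates that operation for a binary mask with `scipy.ndimage.label`. It is a simplified stand-in, not nnUNet's own implementation, and `keep_largest_component` is a hypothetical name.

```python
import numpy as np
from scipy import ndimage

def keep_largest_component(segmentation: np.ndarray) -> np.ndarray:
    """Keep only the largest connected foreground component (face connectivity, scipy's default)."""
    foreground = segmentation != 0
    labeled, num_components = ndimage.label(foreground)
    if num_components == 0:
        return segmentation.copy()
    sizes = np.bincount(labeled.ravel())
    sizes[0] = 0  # ignore the background label
    largest = sizes.argmax()
    return np.where(labeled == largest, segmentation, 0).astype(segmentation.dtype)

toy = np.zeros((6, 6), dtype=np.uint8)
toy[0:3, 0:3] = 1  # 9-voxel component
toy[5, 5] = 1      # isolated 1-voxel component
cleaned = keep_largest_component(toy)
print(int(cleaned.sum()), int(cleaned[5, 5]))  # → 9 0
```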

## 3. Evaluation

In [None]:
import SimpleITK as sitk
import numpy as np
import pandas as pd
import hvplot.pandas
import panel as pn
import holoviews as hv
import matplotlib.pyplot as plt
import re
import skimage.measure

The evaluation starts by computing the quantitative performance measures: Dice score, sensitivity, specificity and average surface distance (ASD).
First, select the prediction directories to evaluate; each directory is treated as a separate model (`Model_1`, `Model_2`, ...).

In [None]:
# Define the directories containing predictions from different models
prediction_dirs = [
    'predictions/postprocessed'
]
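For reference, these are the standard definitions of the metrics computed in the next cell for a binary segmentation (bone = foreground). $TP$, $FP$, $TN$ and $FN$ count true/false positive/negative voxels, $S_P$ and $S_G$ are the surface voxels of the prediction and the ground truth, and $d(x, S)$ is the Euclidean distance in physical units from voxel $x$ to the closest voxel of $S$. The code adds a small $\epsilon$ (`1e-5`) to avoid division by zero. ASD is taken here as the mean of the two directed average distances; some authors instead average over the union of both surfaces.

$$
\mathrm{Dice} = \frac{2\,TP}{2\,TP + FP + FN}, \qquad
\mathrm{Sensitivity} = \frac{TP}{TP + FN}, \qquad
\mathrm{Specificity} = \frac{TN}{TN + FP}
$$

$$
\mathrm{ASD} = \frac{1}{2}\left(\frac{1}{|S_P|}\sum_{p \in S_P} d(p, S_G) + \frac{1}{|S_G|}\sum_{g \in S_G} d(g, S_P)\right)
$$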

In [None]:
def load_nrrd(file_path):
    image = sitk.ReadImage(file_path)
    array = sitk.GetArrayFromImage(image)
    return array

def calculate_dice_coefficient(pred, truth, epsilon=1e-5):
    intersection = np.sum((pred == truth) & (pred != 0))
    volume_sum = np.sum(pred != 0) + np.sum(truth != 0)
    dice = (2. * intersection + epsilon) / (volume_sum + epsilon)
    return dice

def calculate_sensitivity(pred, truth):
    true_positive = np.sum((pred == 1) & (truth == 1))
    false_negative = np.sum((pred == 0) & (truth == 1))
    sensitivity = true_positive / (true_positive + false_negative + 1e-5)
    return sensitivity

def calculate_specificity(pred, truth):
    true_negative = np.sum((pred == 0) & (truth == 0))
    false_positive = np.sum((pred == 1) & (truth == 0))
    specificity = true_negative / (true_negative + false_positive + 1e-5)
    return specificity

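# Average Surface Distance (in mm if the spacing is given in mm)
def calculate_asd(pred, truth, spacing=None):
    pred_image = sitk.GetImageFromArray((pred != 0).astype(np.uint8))
    truth_image = sitk.GetImageFromArray((truth != 0).astype(np.uint8))
    # GetImageFromArray discards the voxel spacing; restore it (x, y, z order, as returned by GetSpacing())
    if spacing is not None:
        pred_image.SetSpacing(spacing)
        truth_image.SetSpacing(spacing)

    # One-voxel-thick surfaces of both segmentations
    pred_surface = sitk.LabelContour(pred_image)
    truth_surface = sitk.LabelContour(truth_image)

    # Distance of every voxel to the nearest surface voxel
    pred_distance_map = sitk.SignedMaurerDistanceMap(pred_surface, squaredDistance=False, useImageSpacing=True)
    truth_distance_map = sitk.SignedMaurerDistanceMap(truth_surface, squaredDistance=False, useImageSpacing=True)

    # Convert to arrays (absolute values, since only the distance matters)
    pred_distance_map_array = np.abs(sitk.GetArrayFromImage(pred_distance_map))
    truth_distance_map_array = np.abs(sitk.GetArrayFromImage(truth_distance_map))
    pred_surface_array = sitk.GetArrayFromImage(pred_surface)
    truth_surface_array = sitk.GetArrayFromImage(truth_surface)

    # Sample the distances on the surface voxels only, not on the whole segmented volume
    pred2truth = truth_distance_map_array[pred_surface_array != 0]
    truth2pred = pred_distance_map_array[truth_surface_array != 0]

    asd = (np.mean(pred2truth) + np.mean(truth2pred)) / 2.0
    return asd

def compare_labels(prediction_dirs, ground_truth_dir, file_ending=".nrrd"):
    all_results = []

    for model_idx, predicted_dir in enumerate(prediction_dirs):
        # Only consider segmentation files; nnUNet prediction folders can also contain JSON files (e.g. plans.json)
        pred_files = sorted(f for f in os.listdir(predicted_dir) if f.endswith(file_ending))

        for pred_file in pred_files:
            pred_path = os.path.join(predicted_dir, pred_file)
            # Match the ground truth by file name instead of relying on the sort order of two directories
            truth_path = os.path.join(ground_truth_dir, pred_file)

            pred_array = load_nrrd(pred_path)
            truth_image = sitk.ReadImage(truth_path)
            truth_array = sitk.GetArrayFromImage(truth_image)

            dice_score = calculate_dice_coefficient(pred_array, truth_array)
            sensitivity = calculate_sensitivity(pred_array, truth_array)
            specificity = calculate_specificity(pred_array, truth_array)
            # Pass the voxel spacing so that ASD is reported in physical units
            asd = calculate_asd(pred_array, truth_array, spacing=truth_image.GetSpacing())

            all_results.append({
                'model': f'Model_{model_idx + 1}',
                'file': pred_file,
                'DICE': dice_score,
                'Sensitivity': sensitivity,
                'Specificity': specificity,
                'ASD': asd,
                'pred_array': pred_array,
                'truth_array': truth_array
            })

    return all_results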

ground_truth_labels_dir = f"nnUNet_raw/{data_name}/labelsTs"

results = compare_labels(prediction_dirs, ground_truth_labels_dir)
df = pd.DataFrame(results)

average_metrics = df.groupby("model")[["DICE", "Sensitivity", "Specificity", "ASD"]].mean()
print(average_metrics)
average_metrics.to_csv("average_metrics.csv")
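A quick sanity check of the metric definitions on a tiny, hand-checkable example. This is a standalone NumPy re-implementation under the binary-label assumption, not the functions above; `confusion_counts`, `surface_voxels` and `average_surface_distance` are hypothetical names. The brute-force ASD compares all pairs of surface voxels, so it only suits small arrays, but it shows that ASD depends on the voxel spacing. For arrays obtained with `sitk.GetArrayFromImage`, the spacing in array axis order is `image.GetSpacing()[::-1]`.

```python
import numpy as np

def confusion_counts(pred, truth):
    """TP, FP, FN, TN for binary masks (any non-zero value counts as foreground)."""
    pred, truth = pred != 0, truth != 0
    tp = int(np.sum(pred & truth))
    fp = int(np.sum(pred & ~truth))
    fn = int(np.sum(~pred & truth))
    tn = int(np.sum(~pred & ~truth))
    return tp, fp, fn, tn

def surface_voxels(mask):
    """Coordinates of foreground voxels with a face-neighbour in the background (outside counts as background)."""
    mask = mask != 0
    padded = np.pad(mask, 1, constant_values=False)
    core = tuple(slice(1, -1) for _ in range(mask.ndim))
    interior = mask.copy()
    for axis in range(mask.ndim):
        for shift in (-1, 1):
            interior &= np.roll(padded, shift, axis=axis)[core]
    return np.argwhere(mask & ~interior)

def average_surface_distance(pred, truth, spacing):
    """Brute-force ASD; spacing must follow the array axis order."""
    scale = np.asarray(spacing, dtype=float)
    sp = surface_voxels(pred) * scale
    sg = surface_voxels(truth) * scale
    if len(sp) == 0 or len(sg) == 0:
        return float("nan")
    d = np.linalg.norm(sp[:, None, :] - sg[None, :, :], axis=-1)
    return float(0.5 * (d.min(axis=1).mean() + d.min(axis=0).mean()))

truth = np.zeros((4, 4), dtype=np.uint8)
truth[1:3, 1:3] = 1  # 4 ground-truth voxels
pred = np.zeros((4, 4), dtype=np.uint8)
pred[1:3, 1:4] = 1   # 6 predicted voxels, 4 of them correct

tp, fp, fn, tn = confusion_counts(pred, truth)
print(tp, fp, fn, tn)                              # → 4 2 0 10
print(2 * tp / (2 * tp + fp + fn), tp / (tp + fn))  # → 0.8 1.0
print(round(average_surface_distance(pred, truth, (1.0, 1.0)), 4))  # → 0.1667
print(round(average_surface_distance(pred, truth, (1.0, 2.0)), 4))  # → 0.3333
```

Doubling the spacing along the axis where the prediction overshoots doubles the ASD, which is why the spacing has to be restored before computing distance maps.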

Next, an interactive plot is generated, allowing one to switch between different models as well as different performance metrics. The plot displays the distribution of each metric across the images as a histogram.


In [None]:
pn.extension()

# Switch between models and metrics
model_selector = pn.widgets.Select(name='Model', options=df['model'].unique().tolist())
metric_selector = pn.widgets.Select(name='Metric', options=['DICE', 'Sensitivity', 'Specificity', 'ASD'])

@pn.depends(model_selector, metric_selector)
def plot_histogram(model, metric):
    filtered_df = df[df['model'] == model]
    # Use the metric name as-is; str.capitalize() would turn 'DICE' into 'Dice' and 'ASD' into 'Asd'
    histogram = filtered_df.hvplot.hist(y=metric, bins=15, alpha=0.7, height=400, width=600, title=f'{metric} Distribution for {model}')
    return histogram

histogram_layout = pn.Column(
    pn.Row(model_selector, metric_selector),
    plot_histogram
)

histogram_layout.servable()

Here an interactive boxplot is generated, displaying the distribution of each metric across the images, side by side for the different models.


In [None]:
pn.extension()

# Switch between metrics
metric_selector = pn.widgets.Select(name='Metric', options=['DICE', 'Sensitivity', 'Specificity', 'ASD'])

@pn.depends(metric_selector)
def plot_boxplot(metric):
    boxplot = df.hvplot.box(y=metric, by='model', height=400, width=600, title=f'{metric} Boxplot Across Models')
    return boxplot

boxplot_layout = pn.Column(
    metric_selector,
    plot_boxplot
)
boxplot_layout.servable()
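The numbers behind the plots can also be reported directly. In a standard boxplot the box spans the first to third quartile with the median inside, and `groupby(...).describe()` in pandas returns exactly these quantiles per model. The Dice values below are made up purely to illustrate the call; with the real results the equivalent is `df.groupby("model")[["DICE", "Sensitivity", "Specificity", "ASD"]].describe()`.

```python
import pandas as pd

# Made-up Dice values for two hypothetical models, only to illustrate the call
toy = pd.DataFrame({
    "model": ["Model_1"] * 4 + ["Model_2"] * 4,
    "DICE": [0.90, 0.92, 0.94, 0.96, 0.80, 0.85, 0.90, 0.95],
})

# count, mean, std, min, 25%, 50% (median), 75% and max per model
summary = toy.groupby("model")["DICE"].describe()
summary["IQR"] = summary["75%"] - summary["25%"]  # height of the box
print(summary[["25%", "50%", "75%", "IQR"]].round(4))
```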

The following code visualizes the prediction errors for a selected image and model, with a slider to step through the slices. Each slice shows the ground truth label in gray, overlaid with an error map: dark violet pixels mark regions where the model incorrectly predicted bone (false positives), while light violet pixels mark regions where it incorrectly predicted background (false negatives).


In [None]:
# First select here the desired model and image
model = "Model_1"
image_name = "19L_fall.nrrd"

df_filtered = df[(df['model'] == model) & (df['file'] == image_name)]

truth_array = df_filtered.iloc[0]['truth_array']
pred_array = df_filtered.iloc[0]['pred_array']

In [None]:
def save_error_overlay_images(pred_array, truth_array, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    
    for slice_idx in range(pred_array.shape[0]):
        error_map = np.zeros_like(pred_array[slice_idx], dtype=np.uint8)
        error_map[(pred_array[slice_idx] != truth_array[slice_idx]) & (truth_array[slice_idx] == 1)] = 1
        error_map[(pred_array[slice_idx] != truth_array[slice_idx]) & (truth_array[slice_idx] == 0)] = 2

        plt.figure(figsize=(8, 8))
        plt.imshow(truth_array[slice_idx], cmap='gray', alpha=0.5)
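        # Fixed colour scale (0 = correct, 1 = false negative, 2 = false positive) so colours mean the same on every slice
        plt.imshow(error_map, cmap='cool', alpha=0.5, vmin=0, vmax=2)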
        plt.title(f'Error Overlay on Slice {slice_idx}')
        plt.axis('off')
        plt.savefig(os.path.join(output_dir, f"slice_{slice_idx}.png"), bbox_inches='tight', pad_inches=0)
        plt.close()

output_dir = "evaluation/images/error_overlays"
save_error_overlay_images(pred_array, truth_array, output_dir)
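Instead of stepping through every slice, the slices with the most errors can be located first. The sketch below uses the same coding as the overlay cell (1 = false negative, 2 = false positive) in vectorized form and ranks the slices along the first array axis. `error_map` and `slices_by_error_count` are hypothetical helper names, and binary labels are assumed.

```python
import numpy as np

def error_map(pred, truth):
    """0 = correct, 1 = false negative (bone missed), 2 = false positive (bone predicted on background)."""
    pred, truth = pred != 0, truth != 0
    errors = np.zeros(pred.shape, dtype=np.uint8)
    errors[~pred & truth] = 1
    errors[pred & ~truth] = 2
    return errors

def slices_by_error_count(pred, truth):
    """Slice indices along axis 0, ordered from most to fewest misclassified voxels (ties by index)."""
    counts = (error_map(pred, truth) != 0).reshape(pred.shape[0], -1).sum(axis=1)
    order = np.argsort(-counts, kind="stable")
    return order, counts

# Toy volume with 3 slices of 2x2 voxels
truth = np.zeros((3, 2, 2), dtype=np.uint8)
truth[:, 0, :] = 1
pred = truth.copy()
pred[1, 0, 0] = 0  # one false negative in slice 1
pred[2, 1, :] = 1  # two false positives in slice 2
order, counts = slices_by_error_count(pred, truth)
print(order.tolist(), counts.tolist())  # → [2, 1, 0] [0, 1, 2]
```

With the arrays selected above, `order, counts = slices_by_error_count(pred_array, truth_array)` lists in `order[:5]` the slider positions worth inspecting first, since the overlay images are indexed by slice.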

In [None]:
def extract_slice_index(filename):
    match = re.search(r'slice_(\d+)\.png', filename)
    return int(match.group(1)) if match else -1

# Build the paths from output_dir so that they point to the folder the PNGs were saved in
image_files = sorted([os.path.join(output_dir, file) for file in os.listdir(output_dir) if file.endswith('.png')],
                     key=lambda x: extract_slice_index(x))

def view_image(slice_idx):
    return pn.pane.PNG(image_files[slice_idx], width=400, height=400)

# Select slices
slice_slider = pn.widgets.IntSlider(name='Slice Index', start=0, end=len(image_files) - 1, step=1, value=0)
interactive_view = pn.bind(view_image, slice_idx=slice_slider)
layout = pn.Column(slice_slider, interactive_view)
layout.servable()

Complementing the error overlay, the following contour visualization gives a clearer picture of the model's accuracy, since most errors are expected to lie along the outer edges of the bone. Green contours show the ground truth labels, red contours the predicted labels.


In [None]:
def save_contour_images(pred_array, truth_array, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    
    for slice_idx in range(pred_array.shape[0]):
        plt.figure(figsize=(8, 8))
        plt.imshow(truth_array[slice_idx], cmap='gray', alpha=0.5)

        pred_contour = skimage.measure.find_contours(pred_array[slice_idx], level=0.5)
        truth_contour = skimage.measure.find_contours(truth_array[slice_idx], level=0.5)

        for contour in pred_contour:
            plt.plot(contour[:, 1], contour[:, 0], 'r', linewidth=2)
        for contour in truth_contour:
            plt.plot(contour[:, 1], contour[:, 0], 'g', linewidth=2)

        plt.title(f'Contours on Slice {slice_idx}')
        plt.axis('off')
        plt.savefig(os.path.join(output_dir, f"slice_{slice_idx}.png"), bbox_inches='tight', pad_inches=0)
        plt.close()

output_dir = "evaluation/images/contour_overlays"
save_contour_images(pred_array, truth_array, output_dir)
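The expectation that most errors lie along the bone edges can be checked numerically. The sketch below computes the share of misclassified voxels that lie within `tolerance` face-connected steps (city-block distance in voxels, ignoring the physical spacing) of the ground-truth boundary. Voxels at the array edge are not treated as boundary. `dilate` and `fraction_of_errors_near_boundary` are hypothetical helpers, and binary labels are assumed.

```python
import numpy as np

def dilate(mask):
    """One step of binary dilation with face-connected neighbours (no wrap-around at the array edges)."""
    padded = np.pad(mask, 1, constant_values=False)
    core = tuple(slice(1, -1) for _ in range(mask.ndim))
    out = mask.copy()
    for axis in range(mask.ndim):
        for shift in (-1, 1):
            out |= np.roll(padded, shift, axis=axis)[core]
    return out

def fraction_of_errors_near_boundary(pred, truth, tolerance=1):
    """Share of misclassified voxels within `tolerance` steps of the ground-truth boundary."""
    pred, truth = pred != 0, truth != 0
    errors = pred != truth
    if not errors.any():
        return float("nan")
    near = truth & dilate(~truth)  # ground-truth boundary voxels
    for _ in range(tolerance):
        near = dilate(near)        # widen the band by one voxel per step
    return float((errors & near).sum() / errors.sum())

# Toy example: a 5x5 square with one error next to its edge and one in its centre
truth = np.zeros((7, 7), dtype=np.uint8)
truth[1:6, 1:6] = 1
pred = truth.copy()
pred[0, 3] = 1  # false positive just outside the boundary
pred[3, 3] = 0  # false negative in the centre
print(fraction_of_errors_near_boundary(pred, truth, tolerance=1))  # → 0.5
print(fraction_of_errors_near_boundary(pred, truth, tolerance=2))  # → 1.0
```

Applied to the selected case as `fraction_of_errors_near_boundary(pred_array, truth_array, tolerance=2)`, a value close to 1 would support the edge-concentration hypothesis.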

In [None]:
def extract_slice_index(filename):
    match = re.search(r'slice_(\d+)\.png', filename)
    return int(match.group(1)) if match else -1

# Build the paths from output_dir so that they point to the folder the PNGs were saved in
image_files = sorted([os.path.join(output_dir, file) for file in os.listdir(output_dir) if file.endswith('.png')],
                     key=lambda x: extract_slice_index(x))

def view_image(slice_idx):
    return pn.pane.PNG(image_files[slice_idx], width=400, height=400)

slice_slider = pn.widgets.IntSlider(name='Slice Index', start=0, end=len(image_files) - 1, step=1, value=0)

interactive_view = pn.bind(view_image, slice_idx=slice_slider)
layout = pn.Column(slice_slider, interactive_view)
layout.servable()