In [None]:
import sys
sys.path.append('../SPIE2019_COURSE/')

In [None]:
import os
import time
import numpy as np
import pandas as pd
import gc
import SimpleITK as sitk
import registration_gui as rgui
import utilities 
from downloaddata import fetch_data as fdata
from ipywidgets import interact, fixed
import matplotlib.pyplot as plt

%matplotlib inline

## Configuration

In [None]:
# Input data layout: every input folder is <source_path> + <sub-directory>;
# the sub-directory names carry their own leading '/'.
source_path = './data'
images_dir, mask_dir, landmarks_dir = '/images', '/masks', '/landmarks'

# Output folder for transforms, resampled images/masks and landmarks
# (the trailing underscore is part of the folder name).
results_path = './final_results_'

# Patients are numbered consecutively 1..n_patients
n_patients = 6
patient_ids = range(1, n_patients + 1)

## Registration and evaluation functions

In [None]:
def save_transform(path_to_save, file_name, transform):
    '''
    Save a SimpleITK transform in TFM format, creating the target directory if needed.

    Args:
    - path_to_save (string): Directory in which to write the TFM file. It is created
      (including parents) when missing; an empty string means the current directory.
    - file_name (string): Name of the TFM file.
    - transform (SimpleITK.SimpleITK.Transform): SimpleITK transform to save.

    Returns:
    - None
    '''
    # exist_ok=True avoids the race between "does it exist?" and "create it";
    # os.makedirs('') would raise, so the empty path (cwd) is skipped.
    if path_to_save:
        os.makedirs(path_to_save, exist_ok=True)

    # Save transformation as TFM file
    sitk.WriteTransform(transform, os.path.join(path_to_save, file_name))

    gc.collect()
    return None

def save_image(path_to_save, file_name, image):
    '''
    Save a SimpleITK image (MHD format in this notebook), creating the target directory if needed.

    Args:
    - path_to_save (string): Directory in which to write the image. It is created
      (including parents) when missing; an empty string means the current directory.
    - file_name (string): Name of the image file; its extension selects the format.
    - image (SimpleITK.SimpleITK.Image): SimpleITK image to save.

    Returns:
    - None
    '''
    # exist_ok=True avoids the race between "does it exist?" and "create it";
    # os.makedirs('') would raise, so the empty path (cwd) is skipped.
    if path_to_save:
        os.makedirs(path_to_save, exist_ok=True)

    # Save the image (MHD for the file names used in this notebook)
    sitk.WriteImage(image, os.path.join(path_to_save, file_name))

    gc.collect()
    return None

def save_landmarks(path_to_save, file_name, landmark):
    '''
    Save points in PTS.TXT format: a '#X Y Z ' header line followed by one
    space-separated "x y z" line per point.

    Args:
    - path_to_save (string): Directory in which to write the PTS.TXT file. It is created
      (including parents) when missing; an empty string means the current directory.
    - file_name (string): Name of the PTS.TXT file.
    - landmark (list(tuple)): Points to save; only the first three coordinates of each
      point are written.

    Returns:
    - None
    '''
    # exist_ok=True avoids the race between "does it exist?" and "create it";
    # os.makedirs('') would raise, so the empty path (cwd) is skipped.
    if path_to_save:
        os.makedirs(path_to_save, exist_ok=True)

    with open(os.path.join(path_to_save, file_name), "w") as text_file:
        # Header line expected by the PTS.TXT readers (kept with its trailing space)
        text_file.write('#X Y Z ' + '\n')
        for point in landmark:
            text_file.write(' '.join(str(point[axis]) for axis in range(3)) + '\n')

    gc.collect()
    return None

In [None]:
def iteration_callback_ffd(filter):
    '''
    Registration observer: overwrite the current console line with the metric
    value reported by the registration method at this iteration.

    Args:
    - filter (SimpleITK.ImageRegistrationMethod): Registration method being monitored.

    Returns:
    - None
    '''
    current_metric = filter.GetMetricValue()
    print('\rRegistration progress -> {0:.2f}'.format(current_metric), end='')

def _set_similarity_metric(registration_method, similarity_function):
    '''
    Configure the similarity metric of a registration method from its name.

    Args:
    - registration_method (SimpleITK.ImageRegistrationMethod): Method to configure.
    - similarity_function (str): One of 'mean_squares', 'mattes_mutual_information',
      'correlation', 'ants_neighborhood_correlation', 'joint_histogram_mutual_information'.

    Raises:
    - ValueError: If the name is not recognised.
    '''
    if similarity_function == 'mean_squares':
        registration_method.SetMetricAsMeanSquares()
    elif similarity_function == 'mattes_mutual_information':
        registration_method.SetMetricAsMattesMutualInformation()
    elif similarity_function == 'correlation':
        registration_method.SetMetricAsCorrelation()
    elif similarity_function == 'ants_neighborhood_correlation':
        registration_method.SetMetricAsANTSNeighborhoodCorrelation(radius=1)
    elif similarity_function == 'joint_histogram_mutual_information':
        registration_method.SetMetricAsJointHistogramMutualInformation()
    else:
        raise ValueError('Invalid similarity function')


def _set_optimizer(registration_method, optimizer, n_iterations):
    '''
    Configure the optimizer of a registration method from its name.

    Args:
    - registration_method (SimpleITK.ImageRegistrationMethod): Method to configure.
    - optimizer (str): One of 'amoeba', 'one_plus_one', 'powell', 'step_gradient_descent',
      'gradient_line_search', 'gradient_descent', 'gradient_descent_line_search'
      (the historical misspelling 'dradient_descent_line_search' is also accepted), 'lbfgs2'.
    - n_iterations (int): Maximum number of optimizer iterations.

    Raises:
    - ValueError: If the name is not recognised.
    '''
    if optimizer == 'amoeba':
        registration_method.SetOptimizerAsAmoeba(simplexDelta=0.1, numberOfIterations=n_iterations)
    elif optimizer == 'one_plus_one':
        registration_method.SetOptimizerAsOnePlusOneEvolutionary(numberOfIterations=n_iterations)
    elif optimizer == 'powell':
        registration_method.SetOptimizerAsPowell(numberOfIterations=n_iterations)
    elif optimizer == 'step_gradient_descent':
        # NOTE(review): minStep equals learningRate, so the optimizer may stop as soon as the
        # step length is first reduced; confirm this setting is intended.
        registration_method.SetOptimizerAsRegularStepGradientDescent(learningRate=0.1, minStep=0.1, numberOfIterations=n_iterations)
    elif optimizer == 'gradient_line_search':
        registration_method.SetOptimizerAsConjugateGradientLineSearch(learningRate=0.01, numberOfIterations=n_iterations)
    elif optimizer == 'gradient_descent':
        registration_method.SetOptimizerAsGradientDescent(learningRate=0.01, numberOfIterations=n_iterations)
    elif optimizer in ('gradient_descent_line_search', 'dradient_descent_line_search'):
        # 'dradient_...' was the original (misspelled) key; kept for backward compatibility.
        registration_method.SetOptimizerAsGradientDescentLineSearch(learningRate=0.01, numberOfIterations=n_iterations)
    elif optimizer == 'lbfgs2':
        registration_method.SetOptimizerAsLBFGS2(numberOfIterations=n_iterations)
    else:
        raise ValueError('Invalid optimization function')


def registration(original_images_path, original_masks_path, transformed_files_path, patient_id, params):
    '''
    Register the 50 time point of a patient onto its 00 time point with a multi-resolution
    B-spline free form deformation (FFD), then save the transform and resampled results.

    The 00.mhd image/mask is the fixed pair and the 50.mhd image/mask the moving pair.
    Written below `transformed_files_path`:
    - transform/0<id>/final_transform.tfm: the final transform,
    - images/0<id>/50.mhd: the moving image resampled onto the fixed image grid,
    - masks/0<id>/50.mhd: the moving mask (label 1) resampled onto the fixed image grid.

    Args:
    - original_images_path (str): Path to the original images.
    - original_masks_path (str): Path to the original masks.
    - transformed_files_path (str): Path to save the transformed files.
    - patient_id (int): Id of the patient.
    - params (dict): Parameters of the algorithm:
        - 'grid_physical_spacing' (float): control-point spacing of the finest B-spline grid,
          in the images' physical units.
        - 'similarity_function' (str): see `_set_similarity_metric`.
        - 'optimizer' (str): see `_set_optimizer`.
        - 'max_optimizer_iterations' (int): maximum optimizer iterations.
        - 'scale_parameter_and_smoothing_sigma_max_power' (int >= 1): number of pyramid levels.
        - 'interpolator' (str): 'linear' or 'bspline'.
        - 'sampling_seed' (int, optional): seed of the random metric sampling. Defaults to
          sitk.sitkWallClock (SimpleITK's default, i.e. not reproducible).

    Returns:
    - None

    Raises:
    - ValueError: Unknown similarity function, interpolator or optimizer, or fewer than one level.
    '''
    patient_dir = '/0' + str(patient_id)

    # Load the 00 and 50 time points (images as float32 for the metric computation)
    images = []
    masks = []
    for i in [0, 5]:
        images.append(sitk.ReadImage(original_images_path + patient_dir + '/{}0.mhd'.format(i), sitk.sitkFloat32))
        masks.append(sitk.ReadImage(original_masks_path + patient_dir + '/{}0.mhd'.format(i)))

    fixed_index = 0
    moving_index = 1

    fixed_image = images[fixed_index]
    fixed_image_mask = masks[fixed_index] == 1

    moving_image = images[moving_index]
    moving_image_mask = masks[moving_index] == 1

    n_levels = params['scale_parameter_and_smoothing_sigma_max_power']
    if n_levels < 1:
        raise ValueError('scale_parameter_and_smoothing_sigma_max_power must be >= 1')

    registration_method = sitk.ImageRegistrationMethod()

    # Number of B-spline mesh cells per axis for the finest control grid, derived from the
    # physical extent of the fixed image and the requested control-point spacing.
    grid_physical_spacing = [params['grid_physical_spacing']] * fixed_image.GetDimension()
    image_physical_size = [size * spacing for size, spacing in zip(fixed_image.GetSize(), fixed_image.GetSpacing())]
    mesh_size = [int(image_size / grid_spacing + 0.5)
                 for image_size, grid_spacing in zip(image_physical_size, grid_physical_spacing)]
    # The mesh is doubled at every finer pyramid level, so start from 1/2**(n_levels-1) of the
    # finest size (1/4 for three levels). Keep at least one cell per axis.
    coarsest_divisor = 2 ** (n_levels - 1)
    mesh_size = [max(1, int(sz / coarsest_divisor + 0.5)) for sz in mesh_size]

    initial_transform = sitk.BSplineTransformInitializer(image1=fixed_image,
                                                         transformDomainMeshSize=mesh_size, order=3)
    # The BSpline-specific initializer refines the mesh per level; the number of scale factors
    # must equal the number of pyramid levels ([1, 2, 4] for three levels).
    registration_method.SetInitialTransformAsBSpline(initial_transform,
                                                     inPlace=False,
                                                     scaleFactors=[2 ** i for i in range(n_levels)])

    _set_similarity_metric(registration_method, params['similarity_function'])

    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01, params.get('sampling_seed', sitk.sitkWallClock))
    registration_method.SetMetricFixedMask(fixed_image_mask)

    # Pyramid: shrink factors [.., 4, 2, 1] and smoothing sigmas [.., 2, 1, 0] (physical units)
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[2 ** i for i in range(n_levels - 1, -1, -1)])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2 ** i for i in range(n_levels - 2, -1, -1)] + [0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    interpolators = {'linear': sitk.sitkLinear, 'bspline': sitk.sitkBSpline}
    if params['interpolator'] not in interpolators:
        raise ValueError('Invalid interpolator')
    interpolator = interpolators[params['interpolator']]
    registration_method.SetInterpolator(interpolator)

    _set_optimizer(registration_method, params['optimizer'], params['max_optimizer_iterations'])

    registration_method.AddCommand(sitk.sitkIterationEvent, lambda: iteration_callback_ffd(registration_method))

    # Compute the registration
    final_transformation = registration_method.Execute(fixed_image, moving_image)
    stop_condition = registration_method.GetOptimizerStopConditionDescription()
    print('\nOptimizer\'s stopping condition, {0}'.format(stop_condition))

    save_transform(transformed_files_path + '/transform' + patient_dir, 'final_transform.tfm', final_transformation)

    # Resample the intensity image with the registration interpolator and keep its own pixel
    # type. Using nearest neighbour with the mask pixel type (uint8) would corrupt the values.
    transformed_image = sitk.Resample(moving_image, fixed_image, final_transformation,
                                      interpolator, 0.0, moving_image.GetPixelID())
    save_image(transformed_files_path + '/images' + patient_dir, '50.mhd', transformed_image)

    # Nearest neighbour for the mask so that no new label values are introduced
    transformed_mask = sitk.Resample(moving_image_mask, fixed_image, final_transformation,
                                     sitk.sitkNearestNeighbor, 0.0, moving_image_mask.GetPixelID())
    save_image(transformed_files_path + '/masks' + patient_dir, '50.mhd', transformed_mask)

    gc.collect()
    return None

def _surface_distances(reference_segmentation, seg):
    '''
    Symmetric surface distances between two binary segmentations, in physical units.

    For each contour pixel of `seg` the distance to the surface of `reference_segmentation`
    is collected, and vice versa.

    Args:
    - reference_segmentation (SimpleITK.Image): Binary reference segmentation.
    - seg (SimpleITK.Image): Binary segmentation to compare, on the same grid.

    Returns:
    - list: seg->reference distances followed by reference->seg distances.
    '''
    statistics_image_filter = sitk.StatisticsImageFilter()

    def one_sided(distance_map_of_other, surface):
        # Keep the distance-map values on `surface` only (zero elsewhere)
        surface_distance_map = distance_map_of_other * sitk.Cast(surface, sitk.sitkFloat32)
        # Number of surface pixels = sum of the binary contour image
        statistics_image_filter.Execute(surface)
        n_surface_pixels = int(statistics_image_filter.GetSum())
        values = sitk.GetArrayViewFromImage(surface_distance_map)
        distances = list(values[values != 0])
        # Surface pixels lying exactly on the other surface have distance 0 and were dropped above
        return distances + list(np.zeros(n_surface_pixels - len(distances)))

    # Absolute distance maps (inside/outside sign is irrelevant). useImageSpacing=True gives
    # physical distances, consistent with the Hausdorff distance and the TRE; the default
    # (False) would measure in voxels.
    reference_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(reference_segmentation, squaredDistance=False, useImageSpacing=True))
    segmented_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(seg, squaredDistance=False, useImageSpacing=True))
    reference_surface = sitk.LabelContour(reference_segmentation)
    segmented_surface = sitk.LabelContour(seg)

    seg2ref_distances = one_sided(reference_distance_map, segmented_surface)
    ref2seg_distances = one_sided(segmented_distance_map, reference_surface)
    return seg2ref_distances + ref2seg_distances


def evaluate_registration(images_files_path, mask_files_path, landmark_files_path, transformed_files_path, patient_id):
    '''
    Evaluate the registration of a patient (00 fixed, 50 moving) with landmark and mask metrics.

    The function also plots the TRE before/after registration. It saves the fixed landmarks
    mapped through the transform to <transformed_files_path>/landmarks/0<id>/00.pts.txt.

    Args:
    - images_files_path (str): Path to the image files.
    - mask_files_path (str): Path to the mask files.
    - landmark_files_path (str): Path to the landmarks files.
    - transformed_files_path (str): Path to the transformed files (must contain the transform
      written by `registration`).
    - patient_id (int): Id of the patient.

    Returns:
    - results (dict): Metrics of the registration:
        'TRE' (per-landmark target registration errors), 'JI' (Jaccard), 'DC' (Dice),
        'VS' (absolute volume similarity), 'HD' (Hausdorff distance),
        'SD' (list of symmetric surface distances) and 'R' (weighted composite score).
    '''
    patient_dir = '/0' + str(patient_id)
    # Label of the structure evaluated (the same label `registration` uses for its masks)
    label = 1

    # Load images, masks and landmarks of the 00 and 50 time points
    images = []
    masks = []
    landmarks = []
    for i in [0, 5]:
        images.append(sitk.ReadImage(images_files_path + patient_dir + '/{}0.mhd'.format(i), sitk.sitkFloat32))
        masks.append(sitk.ReadImage(mask_files_path + patient_dir + '/{}0.mhd'.format(i)))
        landmarks.append(utilities.read_POPI_points(landmark_files_path + patient_dir + '/{}0.pts.txt'.format(i)))

    transformation = sitk.ReadTransform(transformed_files_path + '/transform' + patient_dir + '/final_transform.tfm')

    fixed_index = 0
    moving_index = 1
    fixed_landmarks = landmarks[fixed_index]
    moving_landmarks = landmarks[moving_index]

    results = {}

    # Landmark-based evaluation: identity transform (before) vs. estimated transform (after)
    initial_TRE = utilities.target_registration_errors(sitk.Transform(), fixed_landmarks, moving_landmarks)
    final_TRE = utilities.target_registration_errors(transformation, fixed_landmarks, moving_landmarks)
    results['TRE'] = final_TRE

    plt.figure(figsize=(10, 7))
    plt.hist(initial_TRE, bins=20, alpha=0.5, label='before registration', color='blue')
    plt.hist(final_TRE, bins=20, alpha=0.5, label='after registration', color='green')
    plt.xlabel('TRE')
    plt.ylabel('Number of landmarks')
    plt.legend()
    plt.title('TRE histogram for patient {}'.format(patient_id))
    plt.show()

    # Distribution of errors as a function of the point location, on a shared colour scale
    plt.figure()
    initial_errors = utilities.target_registration_errors(sitk.Transform(), fixed_landmarks, moving_landmarks, display_errors=True)
    utilities.target_registration_errors(transformation, fixed_landmarks, moving_landmarks,
                                         min_err=min(initial_errors), max_err=max(initial_errors), display_errors=True)
    plt.show()

    # Transfer the moving segmentation onto the fixed grid; nearest neighbour so no new labels appear
    transformed_labels = sitk.Resample(masks[moving_index],
                                       images[fixed_index],
                                       transformation,
                                       sitk.sitkNearestNeighbor,
                                       0.0,
                                       masks[moving_index].GetPixelID())

    # Binarize both masks on the same label so the reference and the registered segmentation
    # are comparable (and share the same pixel type for the overlap filters).
    reference_segmentation = masks[fixed_index] == label
    seg = transformed_labels == label

    # Save the fixed landmarks mapped through the transform.
    # NOTE(review): these are the 00 landmarks expressed in the space the transform maps to, yet
    # the file is named 00.pts.txt; confirm this naming is what downstream consumers expect.
    transformed_landmark = [transformation.TransformPoint(p) for p in fixed_landmarks]
    save_landmarks(transformed_files_path + '/landmarks' + patient_dir, '00.pts.txt', transformed_landmark)

    # Overlap measures (single label, so the combined all-labels measures are used)
    overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
    overlap_measures_filter.Execute(reference_segmentation, seg)
    results['JI'] = overlap_measures_filter.GetJaccardCoefficient()
    results['DC'] = overlap_measures_filter.GetDiceCoefficient()
    results['VS'] = np.abs(overlap_measures_filter.GetVolumeSimilarity())

    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
    hausdorff_distance_filter.Execute(reference_segmentation, seg)
    results['HD'] = hausdorff_distance_filter.GetHausdorffDistance()

    results['SD'] = _surface_distances(reference_segmentation, seg)

    # Weighted composite score (lower is better for each term)
    results['R'] = 0.2 * np.mean(results['TRE']) + 0.3 * results['HD'] + 0.5 * 100 * results['VS']

    print('JI: {:.4f}'.format(results['JI']))
    print('DC: {:.4f}'.format(results['DC']))
    print('VS: {:.4f}'.format(results['VS']))
    print('HD: {:.4f}'.format(results['HD']))
    print('SD: {:.4f}'.format(np.mean(results['SD'])))
    print('R: {:.4f}'.format(results['R']))

    return results

## Parameters selection

In [None]:
# Hyperparameters of the B-spline FFD registration (accepted values are listed in `registration`)
params = dict(
    grid_physical_spacing=30.0,
    similarity_function='correlation',
    optimizer='lbfgs2',
    max_optimizer_iterations=5000,
    scale_parameter_and_smoothing_sigma_max_power=3,
    interpolator='linear',
)

## Register and evaluate each patient

In [None]:
# Register and evaluate every patient, collecting one metrics dict per patient in results_ls
t_start = time.time()

images_path = source_path + images_dir
masks_path = source_path + mask_dir
landmarks_path = source_path + landmarks_dir

results_ls = []
for patient_id in patient_ids:
    print('Patient {}'.format(patient_id))

    registration(images_path, masks_path, results_path, patient_id, params)
    results = evaluate_registration(images_path, masks_path, landmarks_path, results_path, patient_id)
    results_ls.append(results)

    gc.collect()
    print('\n\n')
    print('---------------------------------------------------------------')

t_end = time.time()
print('Total computation time: {:.2f}mins'.format((t_end - t_start) / 60))

In [None]:
# One row per patient. The per-landmark TRE and per-surface-point SD lists are reduced
# to their mean so that every column holds a scalar.
results_df = pd.DataFrame(results_ls).assign(
    TRE=lambda frame: frame['TRE'].apply(lambda values: np.mean(values)),
    SD=lambda frame: frame['SD'].apply(lambda values: np.mean(values)),
)
results_df

In [None]:
# Average of every metric over all patients, one "NAME: value" line each
print('\n'.join('{}: {:.4f}'.format(metric_name, np.mean(results_df[metric_name]))
                for metric_name in ('TRE', 'JI', 'DC', 'VS', 'HD', 'SD', 'R')))