In [None]:
!pip install -U scikit-learn

In [None]:
import sys
sys.path.append('../SPIE2019_COURSE/')

In [None]:
import os
import glob
import time
from random import shuffle
import numpy as np
import pandas as pd
import json
import uuid
import gc
from sklearn.model_selection import ParameterGrid
import SimpleITK as sitk
import registration_gui as rgui
import utilities 
from downloaddata import fetch_data as fdata
from ipywidgets import interact, fixed
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Path to data
source_path = './data'
# Directory of images inside source_path (leading '/' because paths are built by string concatenation)
images_dir = '/images'
# Directory of masks inside source_path
mask_dir = '/masks'
# Directory of landmarks inside source_path
landmarks_dir = '/landmarks'

# Path to the results
results_path = './results' 
# Directory of the image registration results inside results_path
registration_results_dir = '/registration' 
# Directory of the hyperparameter search results inside results_path
hp_search_results_dir = '/hp_results'

# Identifier of the current parameter grid; used as a sub-directory name so that
# the results of different grids are kept apart
params_grid_id = '02'

# Number of patients; patient ids run from 1 to n_patients (inclusive)
n_patients = 6
patient_ids = range(1, n_patients + 1)

In [None]:
# Hyperparameter search space: every key maps to the list of values to try
params_grid = {
    'grid_physical_spacing': [10.0, 20.0, 30.0],
    'similarity_function': ['mean_squares', 'mattes_mutual_information', 'correlation', 'joint_histogram_mutual_information'],
    'optimizer': ['lbfgs2'],
    'max_optimizer_iterations': [5000],
    'scale_parameter_and_smoothing_sigma_max_power': [1, 2, 3],
    'interpolator': ['linear', 'bspline']
}

# Expand the grid into the list of all combinations and shuffle it so the
# search visits the combinations in random order
grid = list(ParameterGrid(params_grid))
shuffle(grid)

print('Total combinations: {}'.format(len(grid)))

# Rough run-time estimate in hours: every combination is run once per patient.
# time_per_config is presumably the expected hours for one patient registration.
time_per_config = 0.2
est_time = len(grid) * time_per_config * n_patients
print('Estimated time {:.2f}h'.format(est_time))

In [None]:
# Resume support: if the kernel dies or the machine runs out of memory, load the
# parameter combinations that were already computed and remove them from `grid`.

# Load the previously computed results
files_to_load = glob.glob(os.path.join(results_path + hp_search_results_dir, params_grid_id) + '/*.json')

# Every file holds a JSON string that contains the str() of a results dict
# (that is how save_results writes it), hence the two decoding steps.
# NOTE(security): eval() executes arbitrary code. Only load result files that
# were written by this notebook.
data = []
for results_file in files_to_load:
    with open(results_file, 'r') as fp:
        data.append(eval(json.load(fp)))

# Parameter combinations that already have a results file
params_ls = [entry['params'] for entry in data if 'params' in entry]

# Keep only the parameter combinations that have not been computed yet
grid = [x for x in grid if x not in params_ls]

del params_ls
gc.collect()

print('Total combinations: {}'.format(len(grid)))

# Rough run-time estimate in hours: every combination is run once per patient
time_per_config = 0.2
est_time = len(grid) * time_per_config * n_patients
print('Estimated time {:.2f}h'.format(est_time))

In [None]:
def save_results(path_to_save, file_name, results):
    '''
    Save the results of one parameters combination to a JSON file.

    The file contains a single JSON string holding ``str(results)`` (the Python
    representation of the dictionary), not a JSON object. Readers therefore have
    to decode the JSON string first and then evaluate its content.

    Args:
    - path_to_save (string): Directory to save the JSON file in. It is created
      (including parents) if it does not exist.
    - file_name (string): Name of the JSON file.
    - results (dict): Dictionary with the results.

    Returns:
    - None
    '''
    # exist_ok avoids the race between checking for the directory and creating it
    os.makedirs(path_to_save, exist_ok=True)

    # Save results as JSON file
    with open(os.path.join(path_to_save, file_name), 'w') as fp:
        json.dump(str(results), fp)

    gc.collect()
    return None

def save_transform(path_to_save, file_name, transform):
    '''
    Save the final transform to a TFM file.

    Args:
    - path_to_save (string): Directory to save the TFM file in. It is created
      (including parents) if it does not exist.
    - file_name (string): Name of the TFM file without extension; '.tfm' is appended.
    - transform (SimpleITK.SimpleITK.Transform): SimpleITK transform to save.

    Returns:
    - None
    '''
    # exist_ok avoids the race between checking for the directory and creating it
    os.makedirs(path_to_save, exist_ok=True)

    # Save transformation as TFM file
    sitk.WriteTransform(transform, os.path.join(path_to_save, file_name + '.tfm'))

    gc.collect()
    return None

In [None]:
def iteration_callback_ffd(filter):
    '''
    Progress callback for the registration: prints the current metric value.

    The line starts with a carriage return and has no trailing newline, so
    consecutive calls overwrite each other on the same output line.

    Args:
    - filter: Object exposing GetMetricValue() (the registration method).
    '''
    progress_template = '\rRegistration progress -> {0:.2f}'
    current_metric = filter.GetMetricValue()
    print(progress_template.format(current_metric), end='')

def free_form_deformation_registration(images, masks, params):
    '''
    Computes the free form deformation (BSpline) algorithm for the registration of the images.

    images[0] is used as the fixed image and images[1] as the moving image. Only
    the fixed image mask (masks[0] == 1) is passed to the registration; the
    moving mask is not used.

    Args:
    - images (list(SimpleITK.Image)): Images to feed the registration algorithm.
    - masks (list(SimpleITK.Image)): Masks to feed the registration algorithm.
    - params (dict): Dictionary with the parameters to use in the algorithm. Keys read:
        - 'grid_physical_spacing': spacing between BSpline control points, applied to every axis.
        - 'similarity_function': 'mean_squares', 'mattes_mutual_information', 'correlation',
          'ants_neighborhood_correlation' or 'joint_histogram_mutual_information'.
        - 'scale_parameter_and_smoothing_sigma_max_power': number of pyramid levels; shrink
          factors are 2**(n-1), ..., 2, 1 and smoothing sigmas 2**(n-2), ..., 1, 0.
        - 'interpolator': 'linear' or 'bspline'.
        - 'optimizer': 'amoeba', 'one_plus_one', 'powell', 'step_gradient_descent',
          'gradient_line_search', 'gradient_descent', 'gradient_descent_line_search' or 'lbfgs2'.
        - 'max_optimizer_iterations': iteration limit handed to the optimizer.

    Returns:
    - final_transformation (SimpleITK.Transform)
    - stop_condition (str)

    Raises:
    - ValueError: If the similarity function, interpolator or optimizer name is not recognised.
    '''
    fixed_index = 0
    moving_index = 1

    fixed_image = images[fixed_index]
    fixed_image_mask = masks[fixed_index] == 1

    moving_image = images[moving_index]

    registration_method = sitk.ImageRegistrationMethod()

    # Determine the number of BSpline control points using the physical
    # spacing we want for the finest resolution control grid
    # (a control point every params['grid_physical_spacing'] along each axis).
    grid_physical_spacing = [params['grid_physical_spacing']] * 3
    image_physical_size = [size*spacing for size, spacing in zip(fixed_image.GetSize(), fixed_image.GetSpacing())]
    mesh_size = [int(image_size/grid_spacing + 0.5)
                 for image_size, grid_spacing in zip(image_physical_size, grid_physical_spacing)]
    # The starting mesh size will be 1/4 of the original, it will be refined by
    # the multi-resolution framework.
    mesh_size = [int(sz/4 + 0.5) for sz in mesh_size]

    initial_transform = sitk.BSplineTransformInitializer(image1=fixed_image,
                                                         transformDomainMeshSize=mesh_size, order=3)
    # Instead of the standard SetInitialTransform we use the BSpline specific method which also
    # accepts the scaleFactors parameter to refine the BSpline mesh. In this case we start with
    # the given mesh_size at the highest pyramid level then we double it in the next lower level and
    # in the full resolution image we use a mesh that is four times the original size.
    # NOTE(review): scaleFactors and the division by 4 above assume three pyramid levels, while
    # the number of shrink/smoothing levels set below follows
    # params['scale_parameter_and_smoothing_sigma_max_power'] -- confirm how SimpleITK behaves
    # when the two lengths differ.
    registration_method.SetInitialTransformAsBSpline(initial_transform,
                                                     inPlace=False,
                                                     scaleFactors=[1, 2, 4])

    # Selecting similarity function
    similarity_function = params['similarity_function']
    if similarity_function == 'mean_squares':
        registration_method.SetMetricAsMeanSquares()
    elif similarity_function == 'mattes_mutual_information':
        registration_method.SetMetricAsMattesMutualInformation()
    elif similarity_function == 'correlation':
        registration_method.SetMetricAsCorrelation()
    elif similarity_function == 'ants_neighborhood_correlation':
        registration_method.SetMetricAsANTSNeighborhoodCorrelation(radius=1)
    elif similarity_function == 'joint_histogram_mutual_information':
        registration_method.SetMetricAsJointHistogramMutualInformation()
    else:
        raise ValueError('Invalid similarity function')

    # The metric is evaluated on a random 1% sample of the voxels inside the fixed mask
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    registration_method.SetMetricFixedMask(fixed_image_mask)

    # Multi-resolution pyramid, from coarsest to finest level
    max_power = params['scale_parameter_and_smoothing_sigma_max_power']
    shrink_factors = [2 ** i for i in range(max_power - 1, -1, -1)]
    smoothing_sigmas = [2 ** i for i in range(max_power - 2, -1, -1)] + [0]
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=shrink_factors)
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=smoothing_sigmas)
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Selecting interpolator
    interpolator = params['interpolator']
    if interpolator == 'linear':
        registration_method.SetInterpolator(sitk.sitkLinear)
    elif interpolator == 'bspline':
        registration_method.SetInterpolator(sitk.sitkBSpline)
    else:
        raise ValueError('Invalid interpolator')

    # Selecting optimizer
    optimizer = params['optimizer']
    max_iterations = params['max_optimizer_iterations']
    if optimizer == 'amoeba':
        registration_method.SetOptimizerAsAmoeba(simplexDelta=0.1, numberOfIterations=max_iterations)
    elif optimizer == 'one_plus_one':
        registration_method.SetOptimizerAsOnePlusOneEvolutionary(numberOfIterations=max_iterations)
    elif optimizer == 'powell':
        registration_method.SetOptimizerAsPowell(numberOfIterations=max_iterations)
    elif optimizer == 'step_gradient_descent':
        registration_method.SetOptimizerAsRegularStepGradientDescent(learningRate=0.1, minStep=0.1, numberOfIterations=max_iterations)
    elif optimizer == 'gradient_line_search':
        registration_method.SetOptimizerAsConjugateGradientLineSearch(learningRate=0.01, numberOfIterations=max_iterations)
    elif optimizer == 'gradient_descent':
        registration_method.SetOptimizerAsGradientDescent(learningRate=0.01, numberOfIterations=max_iterations)
    elif optimizer in ('gradient_descent_line_search', 'dradient_descent_line_search'):
        # The misspelled 'dradient_...' key is what earlier versions accepted; it is
        # kept so that previously saved parameter sets remain usable.
        registration_method.SetOptimizerAsGradientDescentLineSearch(learningRate=0.01, numberOfIterations=max_iterations)
    elif optimizer == 'lbfgs2':
        registration_method.SetOptimizerAsLBFGS2(numberOfIterations=max_iterations)
    else:
        raise ValueError('Invalid optimization function')

    # Report the metric value after every optimizer iteration
    registration_method.AddCommand(sitk.sitkIterationEvent, lambda: iteration_callback_ffd(registration_method))

    final_transformation = registration_method.Execute(fixed_image, moving_image)
    stop_condition = registration_method.GetOptimizerStopConditionDescription()
    print('\nOptimizer\'s stopping condition, {0}'.format(stop_condition))

    return final_transformation, stop_condition


def registration(original_images_path, original_masks_path, transformed_files_path, patient_id, params):
    '''
    Read the two images and masks of one patient, register them with the free
    form deformation algorithm and write the resulting transform to disk.

    The files '00.mhd' and '50.mhd' inside the patient sub-directory ('/0<patient_id>')
    are read; the first one becomes the fixed image, the second one the moving image.

    Args:
    - original_images_path (str): Path to the original images.
    - original_masks_path (str): Path to the original masks.
    - transformed_files_path (str): Path to save the transformed files.
    - patient_id (int): Id of the patient.
    - params (dict): Dictionary with the parameters to use in the algorithm.

    Returns:
    - stop_condition (str)
    '''
    patient_subdir = '/0' + str(patient_id)

    # Read image and mask of every phase, image first
    images, masks = [], []
    for phase in [0, 5]:
        phase_file = '/{}0.mhd'.format(phase)
        images.append(sitk.ReadImage(original_images_path + patient_subdir + phase_file, sitk.sitkFloat32))
        masks.append(sitk.ReadImage(original_masks_path + patient_subdir + phase_file))

    # Run the registration
    final_transform, stop_condition = free_form_deformation_registration(images, masks, params)

    # Persist the resulting transform in the patient's output directory
    save_transform(transformed_files_path + patient_subdir, 'final_transform', final_transform)
    gc.collect()
    return stop_condition

def evaluate_registration(images_files_path, mask_files_path, landmark_files_path, transformed_files_path, patient_id):
    '''
    Evaluates the performance of the registration computing different metrics.

    The saved transform of the patient is applied to the moving mask and to the
    landmarks, and the outcome is compared against the fixed mask and landmarks.
    Both the reference and the transformed segmentation are reduced to label 1
    before any mask based measure is computed.

    Args:
    - images_files_path (str): Path to the image files.
    - mask_files_path (str): Path to the mask files.
    - landmark_files_path (str): Path to the landmarks files.
    - transformed_files_path (str): Path to the transformed files.
    - patient_id (int): Id of the patient.

    Returns:
    - results (dict): Dictionary with all the evaluation metrics:
        - 'TRE': target registration errors returned by utilities.target_registration_errors.
        - 'JI', 'DC', 'VS': Jaccard coefficient, Dice coefficient and volume similarity.
        - 'HD': Hausdorff distance.
        - 'SD': list with all symmetric surface distances.
        - 'R': 0.2*mean(TRE) + 0.3*HD + 0.5*100*|VS|.
    '''
    # Define fixed and moving index
    fixed_index = 0
    moving_index = 1
    # Label value of the structure that is evaluated
    label = 1

    # Load images, masks and landmarks
    patient_dir = '/0' + str(patient_id)
    images = []
    masks = []
    landmarks = []
    for i in [0, 5]:
        image_file_name = images_files_path + patient_dir + '/{}0.mhd'.format(i)
        mask_file_name = mask_files_path + patient_dir + '/{}0.mhd'.format(i)
        landmarks_file_name = landmark_files_path + patient_dir + '/{}0.pts.txt'.format(i)
        images.append(sitk.ReadImage(image_file_name, sitk.sitkFloat32))
        masks.append(sitk.ReadImage(mask_file_name))
        landmarks.append(utilities.read_POPI_points(landmarks_file_name))

    # Load transformation
    transformation = sitk.ReadTransform(transformed_files_path + patient_dir + '/final_transform.tfm')

    # Create dictionary to store all the relevant information
    results = {}

    # Compute the evaluation criteria with landmarks
    final_TRE = utilities.target_registration_errors(transformation, landmarks[fixed_index], landmarks[moving_index])

    # Save TRE
    results['TRE'] = final_TRE

    # Transfer the segmentation via the estimated transformation.
    # Nearest Neighbor interpolation so we don't introduce new labels.
    transformed_labels = sitk.Resample(masks[moving_index],
                                       images[fixed_index],
                                       transformation,
                                       sitk.sitkNearestNeighbor,
                                       0.0,
                                       masks[moving_index].GetPixelID())

    # Reference and transformed segmentation, both reduced to the evaluated label.
    # Thresholding the reference as well keeps the two images comparable and makes
    # the surface pixel counts below (sum over a 0/1 contour image) valid even if
    # the mask file holds additional labels.
    reference_segmentation = masks[fixed_index] == label
    seg = transformed_labels == label

    # Compute the evaluation criteria with masks

    # Note that for the overlap measures filter, because we are dealing with a single label we
    # use the combined, all labels, evaluation measures without passing a specific label to the methods.
    overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()
    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()

    # Use the absolute values of the distance map to compute the surface distances (distance map sign, outside or inside
    # relationship, is irrelevant)
    reference_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(reference_segmentation, squaredDistance=False))
    reference_surface = sitk.LabelContour(reference_segmentation)
    statistics_image_filter = sitk.StatisticsImageFilter()

    # Get the number of pixels in the reference surface by counting all pixels that are 1.
    statistics_image_filter.Execute(reference_surface)
    num_reference_surface_pixels = int(statistics_image_filter.GetSum())

    # Overlap measures
    overlap_measures_filter.Execute(reference_segmentation, seg)
    results['JI'] = overlap_measures_filter.GetJaccardCoefficient()
    results['DC'] = overlap_measures_filter.GetDiceCoefficient()
    results['VS'] = overlap_measures_filter.GetVolumeSimilarity()

    # Hausdorff distance
    hausdorff_distance_filter.Execute(reference_segmentation, seg)
    results['HD'] = hausdorff_distance_filter.GetHausdorffDistance()

    # Symmetric surface distance measures
    segmented_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(seg, squaredDistance=False))
    segmented_surface = sitk.LabelContour(seg)

    # Multiply the binary surface segmentations with the distance maps. The resulting distance
    # maps contain non-zero values only on the surface (they can also contain zero on the surface)
    seg2ref_distance_map = reference_distance_map*sitk.Cast(segmented_surface, sitk.sitkFloat32)
    ref2seg_distance_map = segmented_distance_map*sitk.Cast(reference_surface, sitk.sitkFloat32)

    # Get the number of pixels in the segmented surface by counting all pixels that are 1.
    statistics_image_filter.Execute(segmented_surface)
    num_segmented_surface_pixels = int(statistics_image_filter.GetSum())

    # Get all non-zero distances and then add zero distances if required, so that
    # every surface pixel contributes exactly one distance.
    seg2ref_distance_map_arr = sitk.GetArrayViewFromImage(seg2ref_distance_map)
    seg2ref_distances = list(seg2ref_distance_map_arr[seg2ref_distance_map_arr != 0])
    seg2ref_distances = seg2ref_distances + \
                        list(np.zeros(num_segmented_surface_pixels - len(seg2ref_distances)))
    ref2seg_distance_map_arr = sitk.GetArrayViewFromImage(ref2seg_distance_map)
    ref2seg_distances = list(ref2seg_distance_map_arr[ref2seg_distance_map_arr != 0])
    ref2seg_distances = ref2seg_distances + \
                        list(np.zeros(num_reference_surface_pixels - len(ref2seg_distances)))

    all_surface_distances = seg2ref_distances + ref2seg_distances

    results['SD'] = all_surface_distances

    # Single score combining landmark error, Hausdorff distance and volume similarity
    results['R'] = 0.2*np.mean(results['TRE'])+0.3*np.mean(results['HD'])+0.5*100*np.abs(results['VS'])
    return results

def statistical_info_from_results(results, ids=None):
    '''
    Compute some statistical information from the results of all patients.

    For each of the metrics JI, DC, SD, TRE, HD, VS and R the mean, median and
    standard deviation over the patients are stored under '<metric>_mean',
    '<metric>_median' and '<metric>_std'. SD and TRE are lists per patient and
    are first reduced to their per-patient mean.

    Args:
    - results (dict): Dictionary with the all the results. It must contain one
      entry 'patient_0<id>' per patient id, each holding the seven metrics.
      The dictionary is modified in place.
    - ids (iterable of int, optional): Patient ids to aggregate over. Defaults
      to the notebook-level `patient_ids`.

    Returns:
    - results (dict): Dictionary with the all the results and the statistical info
      (the same object that was passed in).
    '''
    if ids is None:
        ids = patient_ids
    # Materialise once so that single-pass iterables can be reused for every metric
    ids = list(ids)

    # (metric key, True if the per-patient value is a list that is averaged first).
    # The order fixes the insertion order of the new keys.
    metrics = [('JI', False), ('DC', False), ('SD', True), ('TRE', True),
               ('HD', False), ('VS', False), ('R', False)]

    for metric, average_per_patient in metrics:
        values = [results['patient_0' + str(patient_id)][metric] for patient_id in ids]
        if average_per_patient:
            values = [np.mean(patient_values) for patient_values in values]

        results[metric + '_mean'] = np.mean(values)
        results[metric + '_median'] = np.median(values)
        results[metric + '_std'] = np.std(values)

    return results

In [None]:
# Hyperparameter search: run and evaluate the registration of every patient for
# each remaining parameter combination and write one JSON results file per combination.
# To print the output on the terminal instead of the notebook, redirect
# sys.stdout (e.g. to '/dev/stdout') before running this cell.

for idx, params in enumerate(grid):
    t_start = time.time()

    # Every parameter combination is identified by a fresh random id
    params_id = uuid.uuid4()
    print('parameters {}, id: {}, params: {}'.format(idx, params_id, params))

    # Everything worth keeping for this combination goes into the results dict
    results = {'id': str(params_id), 'params': params}

    # Transforms of this combination are written to and read back from this directory
    registration_output_path = os.path.join(results_path + registration_results_dir,
                                            params_grid_id,
                                            str(params_id))

    # Register and evaluate one patient after the other
    for patient_id in patient_ids:
        print('params_id: {}, patient: {}'.format(params_id, patient_id))

        stop_condition = registration(source_path + images_dir,
                                      source_path + mask_dir,
                                      registration_output_path,
                                      patient_id, params)

        registration_evaluation_results = evaluate_registration(source_path + images_dir,
                                                                source_path + mask_dir,
                                                                source_path + landmarks_dir,
                                                                registration_output_path,
                                                                patient_id)

        # Keep stop condition and metrics under the patient's key
        patient_key = 'patient_0' + str(patient_id)
        results[patient_key + '_stop_cond'] = stop_condition
        results[patient_key] = registration_evaluation_results
        gc.collect()

    # Add mean/median/std over all patients
    results = statistical_info_from_results(results)

    # Wall-clock time of the whole combination, in minutes
    t_end = time.time()
    results['computation_time_min'] = (t_end - t_start) / 60

    # Write the results of this parameters combination to its own file
    path_to_save = os.path.join(results_path + hp_search_results_dir, params_grid_id)
    file_name = str(params_id) + '.json'
    save_results(path_to_save, file_name, results)
    gc.collect()
    print('\n')