## Choices, Choices, Choices: Registration Settings

The performance of most registration algorithms depends on a large number of parameter settings. For optimal performance you will need to customize your settings, turning all the knobs to their "optimal" positions:
<img src="knobs.jpg" style="width:700px"/>
<font size="1"> [This image was originally posted to Flickr and downloaded from wikimedia commons https://commons.wikimedia.org/wiki/File:TASCAM_M-520_knobs.jpg]</font>

This notebook illustrates the use of reference data (a.k.a. "gold" standard) to empirically tune a registration framework for a specific usage. The appropriate settings depend on the characteristics of your images (anatomy, modality, physical spacing...) and on the clinical needs.

Also keep in mind that the optimal settings are not necessarily those that provide the most accurate results.

The optimal settings are task specific and should:
<ul>
<li>Provide sufficient accuracy.</li>
<li>Complete the computation in the allotted time.</li>
</ul>
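One way to make these two criteria operational is to treat accuracy as a constraint and time as the objective: among the configurations whose TRE is acceptable and that finish within the time budget, pick the fastest. The sketch below assumes you have already measured the mean TRE and run time of each candidate configuration; the `Trial` type, the `select_settings` helper, and all the numbers are hypothetical placeholders, not results from this notebook.

```python
from typing import NamedTuple, Optional, Sequence


class Trial(NamedTuple):
    """Outcome of one registration run (hypothetical bookkeeping type)."""
    name: str
    mean_tre_mm: float
    run_time_s: float


def select_settings(trials: Sequence[Trial], max_tre_mm: float,
                    time_budget_s: float) -> Optional[Trial]:
    """Return the fastest trial that is accurate enough and within the time budget.

    Returns None when no trial satisfies both requirements.
    """
    acceptable = [t for t in trials
                  if t.mean_tre_mm <= max_tre_mm and t.run_time_s <= time_budget_s]
    if not acceptable:
        return None
    # Among acceptable settings prefer speed, break ties by accuracy.
    return min(acceptable, key=lambda t: (t.run_time_s, t.mean_tre_mm))


# Illustrative numbers only, not measured values.
trials = [
    Trial('most accurate', mean_tre_mm=0.8, run_time_s=120.0),
    Trial('balanced', mean_tre_mm=1.5, run_time_s=20.0),
    Trial('fastest', mean_tre_mm=4.0, run_time_s=5.0),
]
print(select_settings(trials, max_tre_mm=2.0, time_budget_s=60.0).name)  # balanced
```

With these placeholder numbers the most accurate configuration is rejected because it exceeds the time budget, which is exactly the distinction between "optimal" and "most accurate" made above.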

We will be using the training data from the Retrospective Image Registration Evaluation (<a href="http://www.insight-journal.org/rire/">RIRE</a>) project.

In [None]:
import SimpleITK as sitk
import os

OUTPUT_DIR = 'Output'

# This notebook requires ITK-SNAP to be installed on your machine; provide the full path to the application below.
%env SITK_SHOW_COMMAND /Applications/ITK-SNAP.app/Contents/MacOS/ITK-SNAP 

In [None]:
from PyQt4 import QtGui
from scipy import linalg 
import numpy as np


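def get_file_selection(dialog_title=''):
    '''
    Use Qt to display a file selection dialog that returns the full path to the file.
    '''
    # Reuse an existing QApplication if there is one; Qt allows only a single instance.
    app = QtGui.QApplication.instance() or QtGui.QApplication([])
    return str(QtGui.QFileDialog.getOpenFileName(caption=dialog_title))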


def load_RIRE_ground_truth(file_name):
    '''
    Load the point sets defining the ground truth transformations for the RIRE training dataset.

    Args:
        file_name (str): RIRE ground truth file name. File format is specific to the RIRE training data, with
                         the actual data expected to be in lines 15-22 (zero-based), one fiducial per line.
    Returns:
        Two lists of tuples representing the points in the "left" and "right" coordinate systems.
    '''
    with open(file_name, 'r') as fp:
        lines = fp.readlines()
    l = []
    r = []
    # Fiducial information is in lines 15-22 (zero-based). The first entry on each line (the point
    # number) is skipped, it is followed by the "left" (x, y, z) and "right" (x, y, z) coordinates.
    for line in lines[15:23]:
        coordinates = line.split()
        l.append((float(coordinates[1]), float(coordinates[2]), float(coordinates[3])))
        r.append((float(coordinates[4]), float(coordinates[5]), float(coordinates[6])))
    return (l, r)


def absolute_orientation_m(points_in_left, points_in_right):
    '''
    Absolute orientation using a matrix to represent the rotation. Solution is due to
    S. Umeyama, "Least-Squares Estimation of Transformation Parameters 
    Between Two Point Patterns", IEEE Trans. Pattern Anal. Machine Intell., vol. 13(4): 376-380.
    
    This is a refinement of the method proposed by Arun, Huang and Blostein, ensuring that the 
    rotation matrix is indeed a rotation and not a reflection. 
    
    Args:
        points_in_left (list(tuple)): Set of points corresponding to points_in_right in a different coordinate system.
        points_in_right (list(tuple)): Set of points corresponding to points_in_left in a different coordinate system.
        
    Returns:
        R,t (numpy.ndarray, numpy.array): Rigid transformation that maps points_in_left onto points_in_right.
                                          R*points_in_left + t = points_in_right
    '''
    num_points = len(points_in_left)
    dim_points = len(points_in_left[0])
    # Cursory check that the number of points is sufficient.
    if num_points < dim_points:
        raise ValueError('Number of points must be greater/equal {0}.'.format(dim_points))

    # Construct matrices out of the two point sets for easy manipulation.
    left_mat = np.array(points_in_left).T
    right_mat = np.array(points_in_right).T

    # Center both data sets on the mean.
    left_mean = left_mat.mean(1)
    right_mean = right_mat.mean(1)
    left_M = left_mat - np.tile(left_mean, (num_points, 1)).T
    right_M = right_mat - np.tile(right_mean, (num_points, 1)).T

    M = left_M.dot(right_M.T)
    U, S, Vt = linalg.svd(M)
    V = Vt.T
    # V * diag(1,...,1,det(U*V)) * U'; the diagonal matrix ensures that we have a rotation and not a reflection.
    R = V.dot(np.diag([1.0] * (dim_points - 1) + [linalg.det(U.dot(V))])).dot(U.T)
    t = right_mean - R.dot(left_mean)
    return R, t


def generate_random_pointset(image, num_points):
    '''
    Generate a random set (uniform sample) of points in the given image's domain.
    
    Args:
        image (SimpleITK.Image): Domain in which points are created.
        num_points (int): Number of points to generate.
        
    Returns:
        A list of points (tuples).
    '''
    # Continuous random uniform point indexes inside the image bounds.
    point_indexes = np.multiply(np.tile(image.GetSize(), (num_points, 1)), np.random.random((num_points, image.GetDimension())))
    pointset_list = point_indexes.tolist()
    # Get a list of physical points corresponding to the indexes.
    return [image.TransformContinuousIndexToPhysicalPoint(point_index) for point_index in pointset_list]


def registration_errors(tx, reference_fixed_point_list, reference_moving_point_list):
    '''
    Distances between points transformed by the given transformation and their
    location in another coordinate system. When the points are only used to evaluate
    registration accuracy (not used in the registration) this is the target registration
    error (TRE).

    Args:
        tx (SimpleITK.Transform): The transform we want to evaluate.
        reference_fixed_point_list (list(tuple-like)): Points in fixed image coordinate system.
        reference_moving_point_list (list(tuple-like)): Points in moving image coordinate system.

    Returns:
        (mean, std, min, max, errors) (float, float, float, float, [float]): TRE statistics and original TREs.
    '''
    errors = [linalg.norm(np.array(tx.TransformPoint(p_fixed)) - np.array(p_moving))
              for p_fixed, p_moving in zip(reference_fixed_point_list, reference_moving_point_list)]
    return (np.mean(errors), np.std(errors), np.min(errors), np.max(errors), errors)
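A quick way to convince yourself that the SVD-based solution in `absolute_orientation_m` recovers a rigid transformation is to apply a known rotation and translation to random points and solve for them again. The sketch below re-implements the same steps (centering, cross-covariance, SVD, reflection guard) in plain NumPy so that it runs on its own; the helper name `umeyama_rigid` and the synthetic data are assumptions for illustration, and it expects (N, 3) arrays rather than lists of tuples.

```python
import numpy as np


def umeyama_rigid(left, right):
    """Least-squares rigid transform (R, t) such that R @ left_i + t ~= right_i.

    left, right: (N, 3) arrays of corresponding points.
    """
    left_mean = left.mean(axis=0)
    right_mean = right.mean(axis=0)
    # Cross-covariance of the centered point sets (3x3), sum of l_i r_i^T.
    M = (left - left_mean).T @ (right - right_mean)
    U, _, Vt = np.linalg.svd(M)
    V = Vt.T
    # The diagonal correction turns a possible reflection into a proper rotation.
    D = np.diag([1.0, 1.0, np.linalg.det(V @ U.T)])
    R = V @ D @ U.T
    t = right_mean - R @ left_mean
    return R, t


rng = np.random.default_rng(0)
theta = np.deg2rad(30.0)
R_true = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                   [np.sin(theta),  np.cos(theta), 0.0],
                   [0.0,            0.0,           1.0]])
t_true = np.array([10.0, -5.0, 2.5])

left = rng.uniform(-100.0, 100.0, size=(8, 3))
right = left @ R_true.T + t_true  # row-wise R_true @ p + t_true

R_est, t_est = umeyama_rigid(left, right)
print(np.allclose(R_est, R_true), np.allclose(t_est, t_true))  # True True
```

With noise-free correspondences the transform is recovered to numerical precision. With the real RIRE fiducials the residual at the fiducials (FRE) is small but non-zero, which is why the notebook prints it as a sanity check.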

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import clear_output


def save_transform_and_image(transform, fixed_image, moving_image, outputfile_prefix):
    '''
    Write the given transformation to file, resample the moving_image onto the fixed_image's grid and save the
    result to file.

    Args:
        transform (SimpleITK Transform): transform that maps points from the fixed image coordinate system to the moving.
        fixed_image (SimpleITK Image): resample onto the spatial grid defined by this image.
        moving_image (SimpleITK Image): resample this image.
        outputfile_prefix (string): transform is written to outputfile_prefix.tfm and resampled image is written to
                                    outputfile_prefix.mhd.
    '''
    resample = sitk.ResampleImageFilter()
    resample.SetReferenceImage(fixed_image)
    # SimpleITK supports several interpolation options.
    resample.SetInterpolator(sitk.sitkLinear)
    resample.SetTransform(transform)
    sitk.WriteImage(resample.Execute(moving_image), outputfile_prefix+'.mhd')
    sitk.WriteTransform(transform, outputfile_prefix+'.tfm')
    
    
#callback invoked when the StartEvent happens, sets up our new data
def start_plot():
    global metric_values, multires_iterations, reference_mean_values
    global reference_min_values, reference_max_values
    
    metric_values = []
    multires_iterations = []
    reference_mean_values = []
    reference_min_values = []
    reference_max_values = []

#callback invoked when the EndEvent happens, do cleanup of data and figure
def end_plot():
    global metric_values, multires_iterations, reference_mean_values
    global reference_min_values, reference_max_values
    
    del metric_values
    del multires_iterations
    del reference_mean_values
    del reference_min_values
    del reference_max_values
    # Close the figure; we don't want to get a duplicate of the plot later on.
    plt.close()

#callback invoked when the IterationEvent happens, update our data and display new figure    
def plot_values(registration_method, fixed_points, moving_points):
    global metric_values, multires_iterations, reference_mean_values
    global reference_min_values, reference_max_values
    
    metric_values.append(registration_method.GetMetricValue())
    # Compute and store TRE statistics (mean, min, max).
    current_transform = sitk.Transform(registration_method.GetInitialTransform())
    current_transform.SetParameters(registration_method.GetOptimizerPosition())
    current_transform.AddTransform(registration_method.GetMovingInitialTransform())
    current_transform.AddTransform(registration_method.GetFixedInitialTransform().GetInverse())
    mean_error, _, min_error, max_error, _ = registration_errors(current_transform, fixed_points, moving_points)
    reference_mean_values.append(mean_error)
    reference_min_values.append(min_error)
    reference_max_values.append(max_error)

    # Clear the output area (wait=True, to reduce flickering), and plot current data.
    clear_output(wait=True)
    # Plot the similarity metric values.
    plt.subplot(1, 2, 1)
    plt.plot(metric_values, 'r')
    plt.plot(multires_iterations, [metric_values[index] for index in multires_iterations], 'b*')
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    # Plot the TRE mean value and the [min-max] range.
    plt.subplot(1, 2, 2)
    plt.plot(reference_mean_values, color='black', label='mean')
    plt.fill_between(range(len(reference_mean_values)), reference_min_values, reference_max_values,
                     facecolor='red', alpha=0.5)
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('TRE [mm]', fontsize=12)
    plt.legend()

    # Adjust the spacing between subplots so that the axis labels don't overlap.
    plt.tight_layout()
    plt.show()
    
#callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the 
#metric_values list. We assume this event happens before the first IterationEvent on the next resolution.
def update_multires_iterations():
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))
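The callbacks above share state through module-level globals and rely on the ordering assumption stated in the last comment: the multi-resolution event fires before the first iteration event of the next level. The sketch below simulates such an event sequence in plain Python (no SimpleITK) to show what `multires_iterations` ends up holding, using a small class instead of globals. The class name `RegistrationLog`, the simulated event stream, and the metric values are hypothetical.

```python
class RegistrationLog:
    """Collects metric values and the iteration index at which each resolution level starts."""

    def __init__(self):
        self.metric_values = []
        self.multires_iterations = []

    def on_multires_event(self):
        # Index that the first metric value of the new level will occupy.
        self.multires_iterations.append(len(self.metric_values))

    def on_iteration_event(self, metric_value):
        self.metric_values.append(metric_value)


# Hypothetical event stream: ('level', None) marks a resolution change, ('iter', value) an
# optimizer iteration. Assumes a level event precedes the iterations of every level.
events = [('level', None), ('iter', -0.30), ('iter', -0.35),
          ('level', None), ('iter', -0.25), ('iter', -0.32), ('iter', -0.36),
          ('level', None), ('iter', -0.28)]

log = RegistrationLog()
for kind, value in events:
    if kind == 'level':
        log.on_multires_event()
    else:
        log.on_iteration_event(value)

print(log.multires_iterations)  # [0, 2, 5]
```

The stored indices point at the first value of each level, which is where the blue stars are drawn; in the made-up values the metric jumps up at each level change. With SimpleITK, bound methods can be registered the same way as the functions above, e.g. `registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, log.on_multires_event)`.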
        

### Select the RIRE datasets and the associated reference data

In [None]:
# Load the data.
fixed_image_file_name = get_file_selection('Select fixed image')
moving_image_file_name = get_file_selection('Select moving image')
reference_file_name = get_file_selection('Select ground truth file')

fixed_image = sitk.ReadImage(fixed_image_file_name)
moving_image = sitk.ReadImage(moving_image_file_name)
fixed_fiducial_points, moving_fiducial_points = load_RIRE_ground_truth(reference_file_name)

# Estimate the reference_transform defined by the RIRE fiducials and check that the FRE makes sense (low).
R, t = absolute_orientation_m(fixed_fiducial_points, moving_fiducial_points)
reference_transform = sitk.Euler3DTransform()
reference_transform.SetMatrix(R.flatten())
reference_transform.SetTranslation(t)
reference_errors_mean, reference_errors_std, _, reference_errors_max, _ = registration_errors(reference_transform, fixed_fiducial_points, moving_fiducial_points)
print('Reference data errors (FRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(reference_errors_mean, reference_errors_std, reference_errors_max))

# Generate a reference dataset from the reference transformation
# (corresponding points in the fixed and moving images).
fixed_points = generate_random_pointset(image=fixed_image, num_points=1000)
moving_points = [reference_transform.TransformPoint(p) for p in fixed_points]

# Look at the data prior to registration.
sitk.Show(fixed_image, 'Fixed Image')
sitk.Show(moving_image, 'Moving Image')
pre_errors_mean, pre_errors_std, _, pre_errors_max, _ = registration_errors(sitk.Euler3DTransform(), fixed_points, moving_points)
print('Initial errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(pre_errors_mean, pre_errors_std, pre_errors_max))
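The reference point set is produced by `generate_random_pointset`, which draws continuous indices uniformly in [0, size) along each axis and lets `TransformContinuousIndexToPhysicalPoint` map them into physical space. Under the ITK/SimpleITK convention that mapping is p = origin + direction · (spacing ⊙ index). The NumPy sketch below reproduces the same sampling without SimpleITK; the helper `random_physical_points` and the image geometry are hypothetical and not read from the RIRE data.

```python
import numpy as np


def random_physical_points(size, spacing, origin, direction, num_points, rng):
    """Uniformly sample continuous indices in [0, size) and map them to physical space.

    Uses the convention p = origin + direction @ (spacing * index), with direction given
    as a flattened row-major matrix like SimpleITK's GetDirection().
    """
    size = np.asarray(size, dtype=float)
    spacing = np.asarray(spacing, dtype=float)
    origin = np.asarray(origin, dtype=float)
    dim = len(size)
    direction = np.asarray(direction, dtype=float).reshape(dim, dim)
    # Same sampling as generate_random_pointset: size * U[0, 1) per axis.
    indexes = rng.random((num_points, dim)) * size
    # Row-wise version of origin + direction @ (spacing * index).
    return origin + (indexes * spacing) @ direction.T


# Hypothetical geometry for illustration only.
rng = np.random.default_rng(42)
points = random_physical_points(size=(512, 512, 29), spacing=(0.65, 0.65, 4.0),
                                origin=(0.0, 0.0, 0.0), direction=np.eye(3).flatten(),
                                num_points=1000, rng=rng)
print(points.shape)  # (1000, 3)
```

Because the points cover the whole fixed image domain, the resulting TRE summarizes the error over the entire field of view rather than only at the eight fiducials.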

### Initial Alignment

We use the CenteredTransformInitializer. Should we use the GEOMETRY-based version or the MOMENTS-based one?
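The two modes differ in which "center" they align: GEOMETRY uses the geometric center of each image, MOMENTS uses the intensity-weighted center of mass. They coincide only when the imaged object is centered in the field of view. The NumPy sketch below computes both centers for a synthetic volume with an off-center bright cube; the helper names and the volume are assumptions, and unlike the real initializer it works in index units rather than physical coordinates.

```python
import numpy as np


def geometric_center(volume):
    """Center of the index grid (the GEOMETRY notion of center, in index units)."""
    return (np.array(volume.shape) - 1) / 2.0


def intensity_center_of_mass(volume):
    """Intensity-weighted centroid (the MOMENTS notion of center, in index units)."""
    grids = np.indices(volume.shape)
    total = volume.sum()
    return np.array([(g * volume).sum() / total for g in grids])


volume = np.zeros((40, 40, 40))
volume[5:15, 5:15, 20:30] = 1.0  # bright cube away from the middle of the grid

print(geometric_center(volume).tolist())          # [19.5, 19.5, 19.5]
print(intensity_center_of_mass(volume).tolist())  # [9.5, 9.5, 24.5]
```

For multi-modality pairs the intensity distributions differ between images, so which of the two centers gives the better starting point is something to check against the reference TRE rather than assume.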

In [None]:
initial_transform = sitk.CenteredTransformInitializer(sitk.Cast(fixed_image,moving_image.GetPixelIDValue()), 
                                                      moving_image, 
                                                      sitk.Euler3DTransform(), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)

sitk.Show(fixed_image, 'Fixed Image')
sitk.Show(sitk.Resample(moving_image, fixed_image, initial_transform, sitk.sitkLinear, 0.0, moving_image.GetPixelIDValue()), 'Moving Image Initial Alignment')
initial_errors_mean, initial_errors_std, _, initial_errors_max, _ = registration_errors(initial_transform, fixed_points, moving_points)
print('Initial alignment errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(initial_errors_mean, initial_errors_std, initial_errors_max))

## Final registration

Possible choices for a simple rigid multi-modality registration framework (<b>300</b> component combinations, in addition to parameter settings for each of the components):
<ul>
<li>Similarity metric, 2 options (Mattes MI, JointHistogram MI):
<ul>
  <li>Number of histogram bins.</li>
  <li>Sampling strategy, 3 options (NONE, REGULAR, RANDOM)</li>
  <li>Sampling percentage.</li>
</ul>
</li>
<li>Interpolator, 10 options (sitkNearestNeighbor, sitkLinear, sitkGaussian, sitkBSpline,...)</li>
<li>Optimizer, 5 options (GradientDescent, GradientDescentLineSearch, RegularStepGradientDescent...): 
<ul>
  <li>Number of iterations.</li>
  <li>Learning rate (step size along parameter space traversal direction).</li>
</ul>
</li>
</ul>
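The figure of 300 is the product of the discrete component choices listed above (2 metrics × 3 sampling strategies × 10 interpolators × 5 optimizers), before any numeric parameter is considered. The enumeration below makes this concrete; only the options named in the list are spelled out, and the remaining interpolators and optimizers are represented by numbered placeholders rather than guessed names.

```python
import itertools

metrics = ['MattesMI', 'JointHistogramMI']
sampling_strategies = ['NONE', 'REGULAR', 'RANDOM']
# Only some of the 10 interpolators and 5 optimizers are named in the text;
# the rest are numbered placeholders.
interpolators = ['sitkNearestNeighbor', 'sitkLinear', 'sitkGaussian', 'sitkBSpline'] + \
    ['interpolator_{}'.format(i) for i in range(6)]
optimizers = ['GradientDescent', 'GradientDescentLineSearch', 'RegularStepGradientDescent'] + \
    ['optimizer_{}'.format(i) for i in range(2)]

combinations = list(itertools.product(metrics, sampling_strategies, interpolators, optimizers))
print(len(combinations))  # 300
```

Trying even a few values for each numeric parameter (histogram bins, sampling percentage, number of iterations, learning rate) multiplies this count further, which is why the reference TRE is used to guide the search.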

In this example we will plot the similarity metric's value and, more importantly, the TREs for our reference data. A good choice for the former should be reflected by the latter. That is, the TREs should go down as the similarity metric value goes down (not necessarily at the same rate).
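One way to quantify "the TREs should go down as the similarity metric goes down" is a rank correlation between the two recorded series (for example `metric_values` and `reference_mean_values` from the callbacks): it is insensitive to the two decreasing at different rates. The pure-Python Spearman sketch below assumes there are no tied values; the helper names and the sample series are made up for illustration.

```python
def ranks(values):
    """Rank of each value (0 = smallest); assumes no ties."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, index in enumerate(order):
        result[index] = rank
    return result


def spearman_rho(x, y):
    """Spearman rank correlation of two equally long sequences without ties."""
    n = len(x)
    if n != len(y) or n < 2:
        raise ValueError('need two sequences of equal length >= 2')
    d_squared = sum((rx - ry) ** 2 for rx, ry in zip(ranks(x), ranks(y)))
    return 1.0 - 6.0 * d_squared / (n * (n * n - 1))


# Made-up series: both decrease, at different rates.
metric = [-0.10, -0.20, -0.25, -0.28, -0.30]
tre_mm = [12.0, 6.0, 3.5, 2.1, 1.9]
print(spearman_rho(metric, tre_mm))  # 1.0
```

A value close to 1 indicates that lowering the metric tracks lowering the TRE; a value near 0 or negative suggests the metric settings are not a good proxy for accuracy on this data.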

Finally, we are also interested in timing our registration. IPython allows us to do this with minimal effort using the <a href="http://ipython.org/ipython-doc/stable/interactive/magics.html?highlight=timeit#magic-timeit">timeit</a> cell magic (IPython has a set of predefined functions that use a command-line syntax, referred to as magic functions).
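Outside IPython, for example in a script that loops over candidate settings, the standard library's `time.perf_counter` gives a comparable single-run wall-clock measurement while keeping the result available for later accuracy analysis. A minimal sketch; `timed` and `run_registration` are hypothetical names, and `run_registration` only stands in for a function wrapping the registration cell.

```python
import time


def timed(function, *args, **kwargs):
    """Run function once and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = function(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


def run_registration(num_iterations):
    # Hypothetical stand-in for the registration cell; any callable works here.
    return sum(i * i for i in range(num_iterations))


result, seconds = timed(run_registration, 100000)
print('result={}, took {:.3f} s'.format(result, seconds))
```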

In [None]:
%%timeit -r1 -n1
# The arguments to the timeit magic specify that this cell should only be run once. Running it multiple
# times to get performance statistics is also possible, but takes time. If you want to analyze the accuracy
# results from multiple runs you will have to modify the code to save them instead of just printing them out.

registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)
registration_method.SetInterpolator(sitk.sitkNearestNeighbor) #2. Replace with sitkLinear
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100) #1. Increase to 1000
registration_method.SetOptimizerScalesFromPhysicalShift() 
# Don't optimize in-place; we would like to run this cell multiple times.
registration_method.SetInitialTransform(initial_transform, inPlace=False)

# Clear all callbacks, if any, and hook up our display.
registration_method.RemoveAllCommands()
registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method, fixed_points, moving_points))

final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                              sitk.Cast(moving_image, sitk.sitkFloat32))

save_transform_and_image(final_transform, fixed_image, moving_image, os.path.join(OUTPUT_DIR, 'RIRE-final'))
sitk.Show(fixed_image)
sitk.Show(sitk.ReadImage(os.path.join(OUTPUT_DIR, 'RIRE-final.mhd')))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))
print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
final_errors_mean, final_errors_std, _, final_errors_max,_ = registration_errors(final_transform, fixed_points, moving_points)
print('Final alignment errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))

### Now using the built in multi-resolution framework

Perform registration using the same settings as above, but take advantage of the multi-resolution framework which provides a significant speedup with minimal effort (3 lines of code).

It should be noted that when using this framework the similarity metric value will not necessarily decrease between resolutions; we only expect it to decrease within each resolution. This is not an issue, as we are actually observing the values of a different function at each resolution (the images are shrunk and smoothed differently at each level).

The example below shows that registration is improving even though the similarity metric value increases when changing resolution levels.
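Most of the speedup comes from the coarse levels: with shrink factors [4,2,1] each axis is subsampled by that factor, so for a 3D image the first two levels contain roughly 1/64 and 1/8 of the original voxels. The arithmetic below uses a hypothetical 256×256×26 volume; rounding the shrunk size down (with a minimum of one voxel) is an assumption about how the pyramid is built, not something taken from the SimpleITK documentation.

```python
def level_sizes(size, shrink_factors):
    """Approximate image size at each pyramid level (assumes floor division, minimum of 1)."""
    return [tuple(max(1, s // f) for s in size) for f in shrink_factors]


def voxel_count(size):
    """Number of voxels in an image of the given size."""
    count = 1
    for s in size:
        count *= s
    return count


full_size = (256, 256, 26)  # hypothetical volume size
for factor, size in zip([4, 2, 1], level_sizes(full_size, [4, 2, 1])):
    print(factor, size, voxel_count(size))
```

With this size the levels contain 24576, 212992 and 1703936 voxels respectively; the coarsest level is not exactly 1/64 of the full image because 26 is not divisible by 4.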

In [None]:
%%timeit -r1 -n1
# The arguments to the timeit magic specify that this cell should only be run once. Running it multiple
# times to get performance statistics is also possible, but takes time. If you want to analyze the accuracy
# results from multiple runs you will have to modify the code to save them instead of just printing them out.

registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)
registration_method.SetInterpolator(sitk.sitkLinear)
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100)
registration_method.SetOptimizerScalesFromPhysicalShift() 
registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4,2,1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
# Don't optimize in-place; we would like to run this cell multiple times.
registration_method.SetInitialTransform(initial_transform, inPlace=False)

registration_method.RemoveAllCommands()
registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method, fixed_points, moving_points))


final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                              sitk.Cast(moving_image, sitk.sitkFloat32))

save_transform_and_image(final_transform, fixed_image, moving_image, os.path.join(OUTPUT_DIR, 'RIRE-final-2'))
sitk.Show(fixed_image)
sitk.Show(sitk.ReadImage(os.path.join(OUTPUT_DIR, 'RIRE-final-2.mhd')))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))
print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
final_errors_mean, final_errors_std, _, final_errors_max,_ = registration_errors(final_transform, fixed_points, moving_points)
print('Final alignment errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(final_errors_mean, final_errors_std, final_errors_max))