<h1 align="center">It's About Time</h1>

When developing a registration algorithm, or when selecting parameter values for an existing algorithm, our choices are dictated by two, often opposing, constraints:
<ul>
<li>Required accuracy.</li>
<li>Allotted time.</li>
</ul>

As the goal of registration is to align multiple data elements into the same coordinate system, it is only natural that the primary focus is on accuracy. In most cases the reported accuracy is obtained without constraining the algorithm's execution time, although a complete evaluation should also report the corresponding running times. Ignoring time is appropriate for longitudinal studies, where we have the benefit of loose time constraints and a registration taking an hour is perfectly acceptable. At the other end of the spectrum is intra-operative registration, which is expected to complete within seconds or minutes. The underlying reasons for these tight timing constraints are the detrimental effects of prolonged anesthesia and the increased cost of operating room time. At the same time, simply completing on time without sufficient accuracy is also unacceptable.

This notebook illustrates a straightforward approach for offsetting the computational complexity of registration via preprocessing and increased memory usage. 

The computational cost of registration is primarily associated with interpolation, which is required for evaluating the similarity metric. Ideally, we would like to use the fastest possible interpolation method, nearest neighbor. Unfortunately, nearest neighbor interpolation most often yields sub-optimal results. A straightforward solution is to pre-operatively create a super-sampled version of the moving-image using higher order interpolation*. We then perform registration between the fixed-image and the super-sampled moving-image, using nearest neighbor interpolation.

Tallying up time and memory usage we see that:

<table>
  <tr><td></td> <td><b>time</b></td><td><b>memory</b></td></tr>
  <tr><td><b>pre-operative</b></td> <td>increase</td><td>increase</td></tr>
  <tr><td><b>intra-operative</b></td> <td>decrease</td><td>increase</td></tr>
</table><br><br>  


<font size="-1">*A better approach is to use single image super-resolution techniques such as the one described in A. Rueda, N. Malpica, E. Romero, "Single-image super-resolution of brain MR images using overcomplete dictionaries", <i>Med Image Anal.</i>, 17(1):113-132, 2013.</font> 
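A one-dimensional example can illustrate the super-sampling strategy described above. The following sketch is a simplified stand-in, not this notebook's implementation. It uses plain Python and a sine curve as a hypothetical "moving image", and linear interpolation stands in for the higher order (B-spline) interpolator. The coarse signal is super-sampled once by a factor of 8, which plays the role of the pre-operative step. Afterwards, nearest neighbor lookups on the fine grid deviate far less from the interpolant than nearest neighbor lookups on the original grid.

```python
import math

def linear_interp(samples, spacing, x):
    '''Linear interpolation of a 1D grid with samples at 0, spacing, 2*spacing, ...'''
    pos = x / spacing
    i = min(max(int(math.floor(pos)), 0), len(samples) - 2)
    frac = pos - i
    return (1.0 - frac) * samples[i] + frac * samples[i + 1]

def nearest_interp(samples, spacing, x):
    '''Nearest neighbor lookup: a single grid read, no arithmetic on neighbors.'''
    i = int(round(x / spacing))
    return samples[min(max(i, 0), len(samples) - 1)]

def supersample(samples, spacing, factor):
    '''Pre-compute a grid 'factor' times finer using the (more expensive) linear interpolator.'''
    fine_spacing = spacing / factor
    n_fine = (len(samples) - 1) * factor + 1
    fine_samples = [linear_interp(samples, spacing, k * fine_spacing) for k in range(n_fine)]
    return fine_samples, fine_spacing

# hypothetical coarse 1D "moving image": sin(x) sampled every 1.0 unit on [0, 10]
coarse_spacing = 1.0
coarse = [math.sin(i * coarse_spacing) for i in range(11)]

# "pre-operative" step: super-sample by a factor of 8
fine, fine_spacing = supersample(coarse, coarse_spacing, 8)

# "intra-operative" step: nearest neighbor lookups compared against the linear interpolant
queries = [j * 0.01 for j in range(1001)]
interpolant = [linear_interp(coarse, coarse_spacing, x) for x in queries]
coarse_err = max(abs(nearest_interp(coarse, coarse_spacing, x) - v) for x, v in zip(queries, interpolant))
fine_err = max(abs(nearest_interp(fine, fine_spacing, x) - v) for x, v in zip(queries, interpolant))
print('nearest neighbor, original grid, max deviation: {:.3f}'.format(coarse_err))
print('nearest neighbor, super-sampled grid, max deviation: {:.3f}'.format(fine_err))
```

On the fine grid, the nearest neighbor deviation is at most the local slope times half the fine spacing. It therefore shrinks roughly as 1/k for a super-sampling factor k, while memory grows by a factor of k per axis.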


In [None]:
#__future__ imports must be the first statement in the cell
from __future__ import print_function

import SimpleITK as sitk

#utility method that either downloads data from the MIDAS repository or
#if already downloaded returns the file name for reading from disk (cached data)
from downloaddata import fetch_midas_data as fdata

#always write output to a separate directory, we don't want to pollute the source directory 
import os
OUTPUT_DIR = 'Output'

## Utility functions

A number of utility functions for loading the RIRE points, estimating the transformation and generating
our own reference data.

In [None]:
import numpy as np
import scipy.linalg as linalg

def load_RIRE_points(file_name):
    '''
    Load the point sets defining the ground truth transformations for the RIRE training dataset.

    Args: 
        file_name (str): RIRE ground truth file name. File format is specific to the RIRE training data, with
                         the actual data expected to be in lines 16-23 (counting from one).
    Returns:
        Two lists of tuples representing the points in the "left" and "right" coordinate systems.
    '''
    #use a context manager so the file is closed after reading
    with open(file_name, 'r') as fp:
        lines = fp.readlines()
    l = []
    r = []
    
    #fiducial information is in zero-based lines 15-22; on each line the coordinates start at the
    #second entry: left x, y, z followed by right x, y, z
    for line in lines[15:23]:
        coordinates = line.split()
        l.append((float(coordinates[1]), float(coordinates[2]), float(coordinates[3])))
        r.append((float(coordinates[4]), float(coordinates[5]), float(coordinates[6])))
    return (l, r)


def absolute_orientation_m(points_in_left, points_in_right):
    '''
    Absolute orientation using a matrix to represent the rotation. Solution is due to
    S. Umeyama, "Least-Squares Estimation of Transformation Parameters 
    Between Two Point Patterns", IEEE Trans. Pattern Anal. Machine Intell., vol. 13(4): 376-380.
    
    This is a refinement of the method proposed by Arun, Huang and Blostein, ensuring that the 
    rotation matrix is indeed a rotation and not a reflection. 
    
    Args:
        points_in_left (list(tuple)): Set of points corresponding to points_in_right in a different coordinate system.
        points_in_right (list(tuple)): Set of points corresponding to points_in_left in a different coordinate system.
        
    Returns:
        R,t (numpy.ndarray, numpy.array): Rigid transformation that maps points_in_left onto points_in_right.
                                          R*points_in_left + t = points_in_right
    '''
    num_points = len(points_in_left)
    dim_points = len(points_in_left[0])
    
    #cursory check that the number of points is sufficient
    if num_points<dim_points:      
        raise ValueError('Number of points must be greater/equal {0}.'.format(dim_points))

    #construct matrices out of the two point sets for easy manipulation
    left_mat = np.array(points_in_left).T
    right_mat = np.array(points_in_right).T
     
    #center both data sets on the mean
    left_mean = left_mat.mean(1)
    right_mean = right_mat.mean(1)
    left_M = left_mat - np.tile(left_mean, (num_points, 1)).T     
    right_M = right_mat - np.tile(right_mean, (num_points, 1)).T     
    
    M = left_M.dot(right_M.T)               
    U,S,Vt = linalg.svd(M)
    V=Vt.T
    
    #V * diag(1,1,det(U*V)) * U' - diagonal matrix ensures that we have a rotation and not a reflection
    R = V.dot(np.diag((1,1,linalg.det(U.dot(V))))).dot(U.T) 
    t = right_mean - R.dot(left_mean) 
    return R,t


def generate_random_pointset(image, num_points):
    '''
    Generate a random set (uniform sample) of points in the given image's domain.
    
    Args:
        image (SimpleITK.Image): Domain in which points are created.
        num_points (int): Number of points to generate.
        
    Returns:
        A list of points (tuples).
    '''
    
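    #continuous random uniform point indexes inside the image bounds, limited to [0, size-1] along
    #each axis so that no point extends past the outermost voxel centers
    point_indexes = np.multiply(np.tile(np.array(image.GetSize()) - 1, (num_points, 1)), np.random.random((num_points, image.GetDimension())))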
    pointset_list = point_indexes.tolist()
    
    #get a list of physical points corresponding to the indexes
    return [image.TransformContinuousIndexToPhysicalPoint(point_index) for point_index in pointset_list]


def registration_errors(tx, reference_fixed_point_list, reference_moving_point_list):
    '''
    Distances between points transformed by the given transformation and their
    location in another coordinate system. When the points are only used to evaluate
    registration accuracy (not used in the registration) this is the target registration
    error (TRE).
    
    Args:
        tx (SimpleITK.Transform): The transform we want to evaluate.
        reference_fixed_point_list (list(tuple-like)): Points in fixed image coordinate system.
        reference_moving_point_list (list(tuple-like)): Points in moving image coordinate system.

    Returns:
        (mean, std, min, max, errors) (float, float, float, float, [float]): TRE statistics and original TREs.
    '''
    errors = [linalg.norm(np.array(tx.TransformPoint(p_fixed)) - np.array(p_moving))
              for p_fixed, p_moving in zip(reference_fixed_point_list, reference_moving_point_list)]
    return (np.mean(errors), np.std(errors), np.min(errors), np.max(errors), errors) 
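A quick sanity check for absolute_orientation_m is to map a set of points with a known rigid transformation and confirm that the estimate recovers it. So that the check runs on its own, the sketch below restates the same SVD solution in a compact form. It uses numpy.linalg instead of scipy.linalg and replaces np.tile with broadcasting. The helper names absolute_orientation and rotation_zx are hypothetical, and the eight synthetic points only mimic the number of RIRE fiducials.

```python
import numpy as np

def absolute_orientation(points_in_left, points_in_right):
    '''
    Least-squares rigid fit R, t such that R*left + t ~= right (SVD solution with
    Umeyama's determinant correction). Points are given as N x 3 array-likes.
    '''
    left = np.asarray(points_in_left, dtype=float)
    right = np.asarray(points_in_right, dtype=float)
    left_mean = left.mean(axis=0)
    right_mean = right.mean(axis=0)
    # cross-covariance of the centered point sets; broadcasting replaces np.tile
    M = (left - left_mean).T.dot(right - right_mean)
    U, _, Vt = np.linalg.svd(M)
    V = Vt.T
    # diag(1, 1, det(V*U')) forces det(R) = +1, a rotation and not a reflection
    R = V.dot(np.diag([1.0, 1.0, np.linalg.det(V.dot(U.T))])).dot(U.T)
    t = right_mean - R.dot(left_mean)
    return R, t

def rotation_zx(theta_z, theta_x):
    '''Rotation about x by theta_x followed by rotation about z by theta_z (radians).'''
    cz, sz = np.cos(theta_z), np.sin(theta_z)
    cx, sx = np.cos(theta_x), np.sin(theta_x)
    rot_z = np.array([[cz, -sz, 0.0], [sz, cz, 0.0], [0.0, 0.0, 1.0]])
    rot_x = np.array([[1.0, 0.0, 0.0], [0.0, cx, -sx], [0.0, sx, cx]])
    return rot_z.dot(rot_x)

# eight synthetic points mapped by a known rigid transform (noise free)
point_rng = np.random.RandomState(0)
left_points = point_rng.uniform(-100.0, 100.0, size=(8, 3))
R_true = rotation_zx(0.3, -0.2)
t_true = np.array([10.0, -5.0, 2.5])
right_points = left_points.dot(R_true.T) + t_true

R_est, t_est = absolute_orientation(left_points, right_points)
print('max |R_est - R_true|: {:.2e}'.format(np.abs(R_est - R_true).max()))
print('max |t_est - t_true|: {:.2e}'.format(np.abs(t_est - t_true).max()))
```

Because the data are noise free, both errors should be at the level of floating point round-off. Passing mirrored points (for example, x negated) still yields a matrix with determinant +1, which is the purpose of the diagonal correction.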

A number of utility callback functions for image display and for plotting the similarity metric and target 
registration errors during registration.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

#interact and fixed now live in the ipywidgets package (formerly IPython.html.widgets)
from ipywidgets import interact, fixed
from IPython.display import clear_output

#callback invoked by the interact ipython method for scrolling through the image stacks of
#the two images (moving and fixed)
def display_images(fixed_image_z, moving_image_z, resampled_moving_image_z, fixed_npa, moving_npa, resmapled_moving_npa):
    #create a figure with three subplots and the specified size
    plt.subplots(1,3,figsize=(10,8))
    
    #draw the fixed image in the first subplot
    plt.subplot(1,3,1)
    plt.imshow(fixed_npa[fixed_image_z,:,:],cmap=plt.cm.Greys_r);
    plt.title('fixed image')
    plt.axis('off')
    
    #draw the moving image in the second subplot
    plt.subplot(1,3,2)
    plt.imshow(moving_npa[moving_image_z,:,:],cmap=plt.cm.Greys_r);
    plt.title('moving image')
    plt.axis('off')

    #draw the moving resampled image in the third subplot
    plt.subplot(1,3,3)
    plt.imshow(resmapled_moving_npa[resampled_moving_image_z,:,:],cmap=plt.cm.Greys_r);
    plt.title('resampled moving image')
    plt.axis('off')

        
#callback invoked by the ipython interact method for scrolling and modifying the alpha blending
#of an image stack of two images that occupy the same physical space. 
def display_images_with_alpha(image_z, alpha, fixed, moving):
    img = (1.0 - alpha)*fixed[:,:,image_z] + alpha*moving[:,:,image_z] 
    plt.imshow(sitk.GetArrayFromImage(img),cmap=plt.cm.Greys_r);
    plt.axis('off')
    
    
#callback invoked when the StartEvent happens, sets up our new data
def start_plot():
    global metric_values, multires_iterations, reference_mean_values
    global reference_min_values, reference_max_values
    
    metric_values = []
    multires_iterations = []
    reference_mean_values = []
    reference_min_values = []
    reference_max_values = []

#callback invoked when the EndEvent happens, do cleanup of data and figure
def end_plot():
    global metric_values, multires_iterations, reference_mean_values
    global reference_min_values, reference_max_values
    
    del metric_values
    del multires_iterations
    del reference_mean_values
    del reference_min_values
    del reference_max_values
    
    #close figure, we don't want to get a duplicate of the plot later on
    plt.close()

#callback invoked when the IterationEvent happens, update our data and display new figure    
def plot_values(registration_method, fixed_points, moving_points):
    global metric_values, multires_iterations, reference_mean_values
    global reference_min_values, reference_max_values
    
    metric_values.append(registration_method.GetMetricValue())
    
    #compute and store TRE statistics (mean, min, max)
    current_transform = sitk.Transform(registration_method.GetInitialTransform())
    current_transform.SetParameters(registration_method.GetOptimizerPosition())
    current_transform.AddTransform(registration_method.GetMovingInitialTransform())
    current_transform.AddTransform(registration_method.GetFixedInitialTransform().GetInverse())
    mean_error, _, min_error, max_error, _ = registration_errors(current_transform, fixed_points, moving_points)
    reference_mean_values.append(mean_error)
    reference_min_values.append(min_error)
    reference_max_values.append(max_error)
                                       
    #clear the output area (wait=True, to reduce flickering), and plot current data
    clear_output(wait=True)
    
    #plot the similarity metric values
    plt.subplot(1,2,1)
    plt.plot(metric_values, 'r')
    plt.plot(multires_iterations, [metric_values[index] for index in multires_iterations], 'b*')
    plt.xlabel('Iteration Number',fontsize=12)
    plt.ylabel('Metric Value',fontsize=12)
    
    #plot the TRE mean value and the [min-max] range
    plt.subplot(1,2,2)
    plt.plot(reference_mean_values, color='black', label='mean')
    plt.fill_between(range(len(reference_mean_values)), reference_min_values, reference_max_values, 
                     facecolor='red', alpha=0.5)
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('TRE [mm]', fontsize=12)
    plt.legend()
    
    #adjust the spacing between subplots so that the axis labels don't overlap
    plt.tight_layout()
    plt.show()
    
#callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the 
#metric_values list. We assume this event happens before the first IterationEvent on the next resolution.
def update_multires_iterations():
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))

## Data preparation

### Read Images

We first read the images, casting the pixel type to one supported by registration (Float32 or Float64).

In [None]:
fixed_image =  sitk.ReadImage(fdata("training_001_ct.mha"), sitk.sitkFloat32)
moving_image = sitk.ReadImage(fdata("training_001_mr_T1.mha"), sitk.sitkFloat32) 

### Resample moving image

We now resample our moving image to a very fine spatial resolution (0.5mm isotropic voxels) using B-spline interpolation.

In [None]:
#isotropic voxels with 0.5mm spacing
voxel_edge_sizes = [0.5]*moving_image.GetDimension()

#create resampled image
original_size = moving_image.GetSize()
original_spacing = moving_image.GetSpacing()

resampled_image_size = [int(spacing/voxel_edge_size*size) 
                        for spacing, size, voxel_edge_size in zip(original_spacing, original_size, voxel_edge_sizes)]  

resampled_moving_image = sitk.Image(resampled_image_size, moving_image.GetPixelIDValue())
resampled_moving_image.SetSpacing(voxel_edge_sizes)
resampled_moving_image.SetOrigin(moving_image.GetOrigin())
resampled_moving_image.SetDirection(moving_image.GetDirection())

#resample original image using identity transform
resample = sitk.ResampleImageFilter()
resample.SetReferenceImage(resampled_moving_image)                      
resample.SetInterpolator(sitk.sitkBSpline)  
resample.SetTransform(sitk.Transform())
resampled_moving_image = resample.Execute(moving_image)

print('Original image size and spacing: {0} {1}'.format(original_size, original_spacing)) 
print('Resampled image size and spacing: {0} {1}'.format(resampled_moving_image.GetSize(), 
                                                         resampled_moving_image.GetSpacing()))
print('Memory ratio: 1 : {0}'.format((np.array(resampled_image_size)/np.array(original_size).astype(float)).prod())) 
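The memory increase listed in the table at the top of this notebook can be estimated before touching any data. Each axis grows by the ratio of the original spacing to the new spacing, so the voxel count, and with it memory, grows by the product of these ratios. The sketch below applies the per-axis size rule from the cell above to a hypothetical 256x256x26 float32 volume with 1.25x1.25x4.0mm spacing. These numbers are illustrative and are not read from the RIRE data. The helper names are also ours.

```python
def supersampled_size(size, spacing, new_spacing):
    '''Per-axis size rule used in the cell above: int(spacing / new_spacing * size).'''
    return [int(s / ns * n) for n, s, ns in zip(size, spacing, new_spacing)]

def num_voxels(size):
    '''Total number of voxels in an image of the given size.'''
    total = 1
    for n in size:
        total *= n
    return total

# hypothetical geometry, not read from the RIRE data
example_size = [256, 256, 26]
example_spacing = [1.25, 1.25, 4.0]
bytes_per_voxel = 4  # sitkFloat32

example_new_size = supersampled_size(example_size, example_spacing, [0.5, 0.5, 0.5])
example_bytes = num_voxels(example_size) * bytes_per_voxel
example_new_bytes = num_voxels(example_new_size) * bytes_per_voxel
print(example_new_size)                                   # [640, 640, 208]
print(example_new_bytes / float(example_bytes))           # 50.0
print('{0:.1f} MiB'.format(example_new_bytes / 2.0**20))  # 325.0 MiB
```

Here the per-axis factors are 2.5, 2.5 and 8, so a 6.5 MiB volume grows to 325 MiB. Almost all of the growth comes from the thick slices.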

Another option for resampling an image, without any transformation, is to use the ExpandImageFilter or its functional form, sitk.Expand. This filter accepts the interpolation method and integral expansion factors. It is a bit less flexible than the resample filter, as we have less control over the resulting image's spacing. On the other hand, it requires less effort from the developer: a single line of code, as compared to the cell above:

resampled_moving_image = sitk.Expand(moving_image, [2,2,8], sitk.sitkBSpline)
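As we understand the ExpandImageFilter, each axis's output spacing is the input spacing divided by its factor, and its size is multiplied by that factor. An arbitrary target spacing such as 0.5mm is therefore only reachable when the spacing ratio happens to be an integer. The hypothetical helpers below pick the smallest integral factors that are at least as fine as a target spacing. They use an illustrative (not RIRE) 1.25x1.25x4.0mm spacing and also evaluate the [2,2,8] factors used above.

```python
import math

def expand_factors(spacing, target_spacing):
    '''Smallest integral factors whose expanded spacing is at most the target spacing.'''
    return [max(1, int(math.ceil(s / t))) for s, t in zip(spacing, target_spacing)]

def expanded_geometry(size, spacing, factors):
    '''Size and spacing after expansion: size multiplied, spacing divided by each factor.'''
    new_size = [n * f for n, f in zip(size, factors)]
    new_spacing = [s / f for s, f in zip(spacing, factors)]
    return new_size, new_spacing

# hypothetical geometry, not read from the RIRE data
example_size = [256, 256, 26]
example_spacing = [1.25, 1.25, 4.0]

factors = expand_factors(example_spacing, [0.5, 0.5, 0.5])
print(factors)                                                      # [3, 3, 8]
print(expanded_geometry(example_size, example_spacing, factors))
print(expanded_geometry(example_size, example_spacing, [2, 2, 8]))  # ([512, 512, 208], [0.625, 0.625, 0.5])
```

With this spacing, the in-plane axes end up either finer (about 0.417mm) or coarser (0.625mm) than the 0.5mm target, whereas the resample filter can hit the target exactly.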

What about computational efficiency?

### Visually inspect our images

In [None]:
interact(display_images, 
         fixed_image_z=(0,fixed_image.GetSize()[2]-1), 
         moving_image_z=(0,moving_image.GetSize()[2]-1),
         resampled_moving_image_z = (0,resampled_moving_image.GetSize()[2]-1),
         fixed_npa = fixed(sitk.GetArrayFromImage(fixed_image)), 
         moving_npa=fixed(sitk.GetArrayFromImage(moving_image)),
         resmapled_moving_npa=fixed(sitk.GetArrayFromImage(resampled_moving_image)));

### Generate reference data

The RIRE reference (ground truth) data consists of a set of corresponding points in the fixed and moving coordinate systems. These points were obtained from fiducials embedded in the patient's skull and are thus sparse (eight points). We use these to compute the rigid transformation between the two coordinate systems, and then generate a dense reference. This generated reference data is closer to the data you would use for registration evaluation (similar to the freely available <a href="http://www.creatis.insa-lyon.fr/rio/popi-model?action=show&redirect=popi">Validation Data for Deformable Image Registration of the Lungs</a>).  


In [None]:
fixed_fiducial_points, moving_fiducial_points = load_RIRE_points(fdata("ct_T1.standard"))

#estimate the reference_transform defined by the RIRE fiducials and check that the FRE makes sense (low) 
R, t = absolute_orientation_m(fixed_fiducial_points, moving_fiducial_points)
reference_transform = sitk.Euler3DTransform()
reference_transform.SetMatrix(R.flatten())
reference_transform.SetTranslation(t)
reference_errors_mean, reference_errors_std, _, reference_errors_max,_ = registration_errors(reference_transform, fixed_fiducial_points, moving_fiducial_points)
print('Reference data errors (FRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(reference_errors_mean, reference_errors_std, reference_errors_max))

#generate a reference dataset from the reference transformation 
#(corresponding points in the fixed and moving images)
fixed_points = generate_random_pointset(image=fixed_image, num_points=1000)
moving_points = [reference_transform.TransformPoint(p) for p in fixed_points]    

pre_errors_mean, pre_errors_std, _, pre_errors_max, _ = registration_errors(sitk.Euler3DTransform(), fixed_points, moving_points)
print('Initial errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(pre_errors_mean, pre_errors_std, pre_errors_max))

## Initial Alignment

Use the CenteredTransformInitializer to align the centers of the two volumes and set the center of rotation to the center of the fixed image. We then visually inspect the alignment and quantify the error using our reference data.

In [None]:
initial_transform = sitk.CenteredTransformInitializer(fixed_image, 
                                                      moving_image, 
                                                      sitk.Euler3DTransform(), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)

moving_resampled = sitk.Resample(moving_image, fixed_image, initial_transform, sitk.sitkLinear, 0.0, moving_image.GetPixelIDValue())

interact(display_images_with_alpha, image_z=(0,fixed_image.GetSize()[2]-1), alpha=(0.0,1.0,0.05), fixed = fixed(fixed_image), moving=fixed(moving_resampled));

In [None]:
pre_errors_mean, pre_errors_std, _, pre_errors_max, _ = registration_errors(initial_transform, fixed_points, moving_points)
print('Initial errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(pre_errors_mean, pre_errors_std, pre_errors_max))
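Our reading of the GEOMETRY mode is as follows. The initializer places the center of rotation at the physical center of the fixed image and sets the translation to the difference between the moving and fixed image centers. The resulting transform therefore maps the fixed center onto the moving center. The plain-Python sketch below computes these two quantities under the simplifying assumption of an identity direction matrix. The (origin, spacing, size) values and helper names are hypothetical. In the notebook, image.TransformContinuousIndexToPhysicalPoint performs the index to physical mapping and also accounts for the direction cosines.

```python
def continuous_index_to_physical(origin, spacing, index):
    '''Physical point of a continuous index, assuming an identity direction matrix.'''
    return [o + s * i for o, s, i in zip(origin, spacing, index)]

def geometric_center(origin, spacing, size):
    '''Center of the voxel-center grid: continuous index (size - 1) / 2 along each axis.'''
    return continuous_index_to_physical(origin, spacing, [(n - 1) / 2.0 for n in size])

def geometry_initialization(fixed_geometry, moving_geometry):
    '''Rotation center and translation for a rigid transform mapping fixed to moving points.'''
    fixed_center = geometric_center(*fixed_geometry)
    moving_center = geometric_center(*moving_geometry)
    translation = [m - f for m, f in zip(moving_center, fixed_center)]
    return fixed_center, translation

# hypothetical (origin, spacing, size) triplets, not the RIRE images
example_fixed_geometry = ([0.0, 0.0, 0.0], [0.5, 0.5, 2.0], [512, 512, 40])
example_moving_geometry = ([10.0, -20.0, 5.0], [1.0, 1.0, 4.0], [256, 256, 20])

center, translation = geometry_initialization(example_fixed_geometry, example_moving_geometry)
print(center)       # [127.75, 127.75, 39.0]
print(translation)  # [9.75, -20.25, 4.0]
```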

## Registration

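To compare the original moving-image with linear interpolation against the super-sampled moving-image with nearest neighbor interpolation, we use the following instantiation of the registration framework, parameterized by the moving image and the interpolator.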

We instrumented our code with callbacks that provide visual feedback on the progress of registration. In this case, we plot two quantities: the value of the similarity metric and the actual TREs (mean and range). The former is relevant for all registration tasks; the latter is only available if you have a reference data set, which we do.

In [None]:
def register_images(fixed_image, moving_image, initial_transform, interpolator):

    registration_method = sitk.ImageRegistrationMethod()
    
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.REGULAR)
    registration_method.SetMetricSamplingPercentage(0.01)
    
    registration_method.SetInterpolator(interpolator) 
    
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=1000) 
    registration_method.SetOptimizerScalesFromPhysicalShift() 
    
    #don't optimize in-place, not nice to change user's input
    registration_method.SetInitialTransform(initial_transform, inPlace=False)
    
    #connect callbacks to registration events, allowing us to plot during registration
    registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
    registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
    registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method, fixed_points, moving_points))

    final_transform = registration_method.Execute(fixed_image, moving_image)
    stopping_condition = registration_method.GetOptimizerStopConditionDescription()
    return (final_transform, stopping_condition)
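register_images uses Mattes mutual information with 50 histogram bins, evaluated on a regular 1% sample of the fixed image's voxels. To show what such a metric measures, the sketch below computes a plain joint-histogram mutual information in pure Python. It is a simplification, not ITK's Mattes implementation, which additionally smooths the histograms with B-spline Parzen windows so that the metric is differentiable with respect to the transform parameters. The intensities are hypothetical. The "aligned" moving values are an inverted-contrast copy of the fixed values, which mimics a multi-modal pair. The "misaligned" ones are the same values shifted by half the sample count.

```python
import math
import random
from collections import Counter

def mutual_information(fixed_values, moving_values, num_bins=50):
    '''Plug-in mutual information (nats) from a joint histogram with equal-width bins.'''
    def bin_indexes(values):
        lo, hi = min(values), max(values)
        width = (hi - lo) / float(num_bins) or 1.0
        return [min(int((v - lo) / width), num_bins - 1) for v in values]

    fixed_bins = bin_indexes(fixed_values)
    moving_bins = bin_indexes(moving_values)
    n = float(len(fixed_bins))
    joint = Counter(zip(fixed_bins, moving_bins))
    fixed_marginal = Counter(fixed_bins)
    moving_marginal = Counter(moving_bins)
    mi = 0.0
    for (i, j), count in joint.items():
        p_ij = count / n
        mi += p_ij * math.log(p_ij / ((fixed_marginal[i] / n) * (moving_marginal[j] / n)))
    return mi

# hypothetical intensities for 5000 metric samples
py_rng = random.Random(0)
fixed_values = [py_rng.uniform(0.0, 1000.0) for _ in range(5000)]
aligned_values = [1000.0 - v for v in fixed_values]
misaligned_values = aligned_values[2500:] + aligned_values[:2500]

mi_aligned = mutual_information(fixed_values, aligned_values)
mi_misaligned = mutual_information(fixed_values, misaligned_values)
print('MI aligned: {:.3f}, MI misaligned: {:.3f}'.format(mi_aligned, mi_misaligned))
```

Mutual information does not require the intensities to match, only to be predictable from one another. This is why it suits the CT/MR pair used here.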

IPython allows us to time our code with minimal effort using the <a href="http://ipython.org/ipython-doc/stable/interactive/magics.html?highlight=timeit#magic-timeit">timeit</a> cell magic (IPython has a set of predefined functions, referred to as magic functions, that use a command line syntax). The options -r1 -n1 run the cell body exactly once (one repeat of one loop), as each run is a complete registration.
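Outside of IPython, the standard library's timeit module provides the same kind of clock. The same applies when the result of the timed code is needed afterwards. As far as we know, names assigned inside a %%timeit cell are not kept in the notebook's namespace. A small helper that both times a single call and returns its result can therefore be handy. The helper name time_once is our own and is not part of IPython or SimpleITK. Like -r1 -n1, it measures one run only.

```python
import timeit

def time_once(func, *args, **kwargs):
    '''Call func(*args, **kwargs) once; return (result, elapsed wall-clock seconds).'''
    start = timeit.default_timer()
    result = func(*args, **kwargs)
    elapsed = timeit.default_timer() - start
    return result, elapsed

# stand-in workload; in this notebook one could instead call, for example,
# time_once(register_images, fixed_image, moving_image, initial_transform, sitk.sitkLinear)
total, seconds = time_once(sum, range(1000000))
print('result: {0}, elapsed: {1:.4f}s'.format(total, seconds))
```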

We start by running the registration using the original image data and linear interpolation:

In [None]:
%%timeit -r1 -n1

final_transform, stopping_condition = register_images(fixed_image, moving_image, initial_transform, sitk.sitkLinear)
errors_mean, errors_std, _, errors_max, _ = registration_errors(final_transform, fixed_points, moving_points)
print('Optimizer\'s stopping condition, {0}'.format(stopping_condition))
print('Errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(errors_mean, errors_std, errors_max))

We now run the registration using the resampled image and nearest neighbor interpolation:

In [None]:
%%timeit -r1 -n1

final_transform, stopping_condition = register_images(fixed_image, resampled_moving_image, 
                                                      initial_transform, sitk.sitkNearestNeighbor)
errors_mean, errors_std, _, errors_max, _ = registration_errors(final_transform, fixed_points, moving_points)
print('Optimizer\'s stopping condition, {0}'.format(stopping_condition))
print('Errors (TRE) in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(errors_mean, errors_std, errors_max))
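A rough way to see where the intra-operative savings come from is to count how many grid values each interpolator combines per metric sample. For separable kernels, this count is the per-axis support raised to the image dimension. In 3D that gives 1 for nearest neighbor, 2^3 = 8 for linear and 4^3 = 64 for cubic B-spline interpolation. The sketch below is only an operation count under these assumptions and uses a hypothetical 512x512x29 fixed image sampled at 1%, as in register_images. Actual running times also depend on the metric and gradient computations and on memory access. The super-sampled moving image is considerably larger, so the measured speedup may well be smaller than the ratio of these counts.

```python
def grid_values_per_sample(dimension, support_per_axis):
    '''Grid values combined by a separable interpolator with the given per-axis support.'''
    return support_per_axis ** dimension

# per-axis support: nearest neighbor 1, linear 2, cubic B-spline 4
supports = {'nearest neighbor': 1, 'linear': 2, 'cubic B-spline': 4}

# hypothetical fixed image of 512 x 512 x 29 voxels, sampled at roughly 1%
num_metric_samples = int(512 * 512 * 29 * 0.01)
for name, support in sorted(supports.items(), key=lambda item: item[1]):
    per_sample = grid_values_per_sample(3, support)
    print('{0}: {1} values per sample, {2} per metric evaluation'.format(
        name, per_sample, per_sample * num_metric_samples))
```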