## We have to start somewhere: Registration Initialization

Initialization is a critical aspect of most registration algorithms. Most algorithms are formulated as iterative optimization problems, so the result depends on where the optimization starts.

In many cases we initialize automatically by making assumptions about the image contents and the imaging protocol. For instance, if we expect that the images were acquired with the patient in the same orientation, we can align the geometric centers of the two volumes. If the anatomy is not centered in the image, we can instead align the centers of mass of the image contents (this is what we previously did in [this example](registration1.ipynb)).
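The following small NumPy sketch makes the two centering options concrete by computing the translations they produce. It is not SimpleITK's implementation. It works in index space, assumes unit spacing, identity direction and zero origin, and uses two hypothetical volumes. SimpleITK transforms map fixed-image points to moving-image points, so the initial translation is the moving center minus the fixed center.

```python
import numpy as np

def geometric_center(volume):
    # Center of the voxel grid in index space: (size - 1) / 2 along each axis.
    return (np.array(volume.shape) - 1) / 2.0

def center_of_mass(volume):
    # Intensity-weighted mean of the voxel indices.
    indices = np.indices(volume.shape).reshape(volume.ndim, -1)
    weights = volume.reshape(-1).astype(float)
    return (indices * weights).sum(axis=1) / weights.sum()

# Hypothetical volumes: the same bright cube, once centered and once shifted toward a corner.
fixed_volume = np.zeros((20, 20, 20))
fixed_volume[8:12, 8:12, 8:12] = 1.0
moving_volume = np.zeros((20, 20, 20))
moving_volume[2:6, 3:7, 4:8] = 1.0

# Translation mapping fixed-image points to moving-image points.
geometry_translation = geometric_center(moving_volume) - geometric_center(fixed_volume)
moments_translation = center_of_mass(moving_volume) - center_of_mass(fixed_volume)
print(geometry_translation)
print(moments_translation)
```

Both grids have the same size, so their geometric centers coincide. Only the center-of-mass option detects the off-center anatomy.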

When the orientation difference between the two images is large, this approach will not yield a reasonable initial estimate for the registration.

When working with clinical images, the DICOM tags define the orientation and position of the anatomy in the volume. The tags of interest are:
<ul>
  <li>(0020|0032) Image Position (Patient): coordinates of the first transmitted voxel.</li>
  <li>(0020|0037) Image Orientation (Patient): direction cosines of the first row and the first column with respect to the patient.</li>
  <li>(0018|5100) Patient Position: patient placement on the table
  <ul>
  <li> Head First Prone (HFP)</li>
  <li> Head First Supine (HFS)</li>
  <li> Head First Decubitus Right (HFDR)</li>
  <li> Head First Decubitus Left (HFDL)</li>
  <li> Feet First Prone (FFP)</li>
  <li> Feet First Supine (FFS)</li>
  <li> Feet First Decubitus Right (FFDR)</li>
  <li> Feet First Decubitus Left (FFDL)</li>
  </ul>
  </li>
</ul>
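The Image Orientation (Patient) tag can be turned directly into axis directions. This sketch assumes the tag value arrives as the backslash-separated string stored in the DICOM header, for example as returned by `GetMetaData('0020|0037')`. The helper name `slice_axes_from_iop` is hypothetical. The slice normal is the cross product of the two in-plane directions.

```python
import numpy as np

def slice_axes_from_iop(iop_value):
    # (0020|0037) holds six backslash-separated direction cosines:
    # first the row direction, then the column direction.
    cosines = np.array([float(v) for v in iop_value.split('\\')])
    row_direction = cosines[:3]
    column_direction = cosines[3:]
    # The slice normal completes a right-handed frame.
    slice_normal = np.cross(row_direction, column_direction)
    return row_direction, column_direction, slice_normal

# Example value for an axial acquisition (rows along x, columns along y).
row_dir, col_dir, normal = slice_axes_from_iop('1\\0\\0\\0\\1\\0')
print(normal)
```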

The patient position is entered manually by the CT/MR operator and can therefore be erroneous. For example, recording HFP instead of FFP results in a $180^\circ$ orientation error.
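Here is a worked instance of such an error. Assume the DICOM patient coordinate system (x toward the patient's left, y toward posterior, z toward superior) and a patient lying prone or supine. Swapping head-first and feet-first then amounts to turning the patient end-over-end about the vertical (anterior-posterior) axis, which is a $180^\circ$ rotation about $y$. That rotation equals the product of $180^\circ$ rotations about $x$ and $z$, which is why it appears among the x/z orientation changes considered below:

$$
R_y(180^\circ)=\begin{pmatrix}-1&0&0\\0&1&0\\0&0&-1\end{pmatrix}
=\underbrace{\begin{pmatrix}1&0&0\\0&-1&0\\0&0&-1\end{pmatrix}}_{R_x(180^\circ)}
\underbrace{\begin{pmatrix}-1&0&0\\0&-1&0\\0&0&1\end{pmatrix}}_{R_z(180^\circ)}
$$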

In [None]:
import SimpleITK as sitk
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

# IPython.html.widgets is deprecated, the widgets now live in the ipywidgets package
from ipywidgets import interact, fixed
# display some HTML output
from IPython.display import display, HTML

import os

OUTPUT_DIR = 'Output'
INPUT_DIR = 'Data'
# results are written to OUTPUT_DIR, make sure it exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

# this notebook works best using ITK-SNAP as your viewer
# please provide the full path to the application below
%env SITK_SHOW_COMMAND /Applications/ITK-SNAP.app/Contents/MacOS/ITK-SNAP

In [None]:
def save_transform_and_image(transform, fixed_image, moving_image, outputfile_prefix):
    '''
    Write the given transformation to file, resample the moving_image onto the fixed_image's grid and save the
    result to file.
    
    Args:
        transform (SimpleITK Transform): transform that maps points from the fixed image coordinate system to the
                                         moving image coordinate system.
        fixed_image (SimpleITK Image): resample onto the spatial grid defined by this image.
        moving_image (SimpleITK Image): resample this image.
        outputfile_prefix (string): transform is written to outputfile_prefix.tfm and resampled image is written to 
                                    outputfile_prefix.mhd.
    '''
    resample = sitk.ResampleImageFilter()
    resample.SetReferenceImage(fixed_image)
    # SimpleITK supports several interpolation options.
    resample.SetInterpolator(sitk.sitkLinear)
    resample.SetTransform(transform)
    sitk.WriteImage(resample.Execute(moving_image), outputfile_prefix + '.mhd')
    sitk.WriteTransform(transform, outputfile_prefix + '.tfm')

In [None]:
from IPython.display import clear_output

#callback invoked when the StartEvent happens, sets up our new data
def start_plot():
    global metric_values, multires_iterations 
    metric_values = []
    multires_iterations = []

#callback invoked when the EndEvent happens, do cleanup of data and figure
def end_plot():
    global metric_values, multires_iterations  
    del metric_values
    del multires_iterations
    # close the figure, we don't want to get a duplicate of the plot later on
    plt.close()

#callback invoked when the IterationEvent happens, update our data and display new figure    
def plot_value(registration_method, plot_title):
    global metric_values, multires_iterations
    metric_values.append(registration_method.GetMetricValue())
    # clear the output area (wait=True, to reduce flickering), and plot current data
    clear_output(wait=True)    
    plt.plot(metric_values, 'r')
    plt.plot(multires_iterations, [metric_values[index] for index in multires_iterations], 'b*')
    plt.title(plot_title)
    plt.xlabel('Iteration Number',fontsize=12)
    plt.ylabel('Metric Value',fontsize=12)
    plt.show()
    
#callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the 
#metric_values list. We assume this event happens before the first IterationEvent on the next resolution.
def update_multires_iterations():
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))    

## Loading Data

In [None]:
data_directory = os.path.join(INPUT_DIR,'CIRS057A')

# global variables 'selected_series_moving/fixed' are updated by the interact function
selected_series_fixed = ''
selected_series_moving = ''
def DICOM_series_dropdown_callback(fixed_image, moving_image, series_dictionary):    
    global selected_series_fixed
    global selected_series_moving
    # print some information about the series from the meta-data dictionary
    # DICOM standard part 6, Data Dictionary: http://medical.nema.org/medical/dicom/current/output/pdf/part06.pdf
    img_fixed = sitk.ReadImage(series_dictionary[fixed_image][0])
    img_moving = sitk.ReadImage(series_dictionary[moving_image][0])
    # there are many interesting tags in a DICOM file, print out some of them
    tags_to_print = {'0010|0010': 'Patient name: ', 
                     '0008|0060' : 'Modality: ',
                     '0020|0032' : 'Image Position (Patient): ',
                     '0020|0037' : 'Image Orientation (Patient): ',
                     '0018|5100' : 'Patient Position: '}
    html_table = []
    html_table.append('<table><tr><td><b>Tag</b></td><td><b>Fixed Image</b></td><td><b>Moving Image</b></td></tr>')
    for tag in tags_to_print:
        # use an empty string if the tag isn't in the meta-data dictionary
        fixed_tag = img_fixed.GetMetaData(tag) if img_fixed.HasMetaDataKey(tag) else ''
        moving_tag = img_moving.GetMetaData(tag) if img_moving.HasMetaDataKey(tag) else ''
        html_table.append('<tr><td>' + tags_to_print[tag] + 
                          '</td><td>' + fixed_tag + 
                          '</td><td>' + moving_tag + '</td></tr>')
    html_table.append('</table>')
    display(HTML(''.join(html_table)))
    selected_series_fixed = fixed_image
    selected_series_moving = moving_image

# the directory contains multiple DICOM studies/series, store
# them in a dictionary with the key being the series ID
reader = sitk.ImageSeriesReader()
series_file_names = {}
series_IDs = reader.GetGDCMSeriesIDs(data_directory)
# check that we have at least one series
if series_IDs:
    for series in series_IDs:
        series_file_names[series] = reader.GetGDCMSeriesFileNames(data_directory, series)
    
    interact(DICOM_series_dropdown_callback, fixed_image=series_IDs, moving_image =series_IDs, series_dictionary=fixed(series_file_names)); 
else:
    print('Data directory does not contain any DICOM series.')

In [None]:
# continue only if the previous cell was run and the series were selected
if 'selected_series_fixed' in globals() and 'selected_series_moving' in globals():
    reader.SetFileNames(series_file_names[selected_series_fixed])
    fixed_image = reader.Execute()
    reader.SetFileNames(series_file_names[selected_series_moving])
    original_moving_image = reader.Execute()
    # look at our images and their alignment (the two instances of ITK-SNAP share a spatially linked cursor)
    sitk.Show(fixed_image)
    sitk.Show(original_moving_image)

In [None]:
def orientation_selection_dropdown_callback(orientation, image, orientations_dictionary):
    global moving_image
    
    resample = sitk.ResampleImageFilter()
    resample.SetReferenceImage(image)
    resample.SetInterpolator(sitk.sitkLinear)
    transform = sitk.Euler3DTransform()
    transform.SetCenter(image.TransformContinuousIndexToPhysicalPoint([(index-1)/2.0 for index in image.GetSize()]))
    transform.SetMatrix(orientations_dictionary[orientation])
    resample.SetTransform(transform)
    moving_image = resample.Execute(image)
    
possible_orientation_changes = {'x=0, z=90': (0,-1,0,1,0,0,0,0,1),
                                'x=0, z=-90': (0,1,0,-1,0,0,0,0,1),
                                'x=0, z=180': (-1,0,0,0,-1,0,0,0,1),
                                'x=180, z=0': (1,0,0,0,-1,0,0,0,-1),
                                'x=180, z=90': (0,-1,0,-1,0,0,0,0,-1),
                                'x=180, z=-90': (0,1,0,1,0,0,0,0,-1),
                                'x=180, z=180': (-1,0,0,0,1,0,0,0,-1)}    
all_orientations = possible_orientation_changes.copy()
all_orientations['x=0, z=0'] = (1,0,0,0,1,0,0,0,1)

# moving_image is set by the callback whenever a new orientation is selected
moving_image = None
interact(orientation_selection_dropdown_callback, orientation=list(all_orientations.keys()), image=fixed(original_moving_image), orientations_dictionary=fixed(all_orientations)); 
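The hard-coded matrices above are easier to trust once they have been checked. This NumPy sketch is not part of the original notebook. It rebuilds each entry from its label as the product $R_x R_z$ and confirms that every matrix is a proper rotation (orthonormal, determinant +1). The composition order $R_x R_z$ is an assumption chosen because it reproduces these entries. It is not necessarily the angle convention used by `sitk.Euler3DTransform`. The helper names are hypothetical.

```python
import numpy as np

def rotation_x(degrees):
    # Rotation about the x axis by the given angle in degrees.
    a = np.deg2rad(degrees)
    return np.array([[1, 0, 0],
                     [0, np.cos(a), -np.sin(a)],
                     [0, np.sin(a), np.cos(a)]])

def rotation_z(degrees):
    # Rotation about the z axis by the given angle in degrees.
    a = np.deg2rad(degrees)
    return np.array([[np.cos(a), -np.sin(a), 0],
                     [np.sin(a), np.cos(a), 0],
                     [0, 0, 1]])

def matrix_from_label(label):
    # Parse a label such as 'x=180, z=-90' and build R_x @ R_z (assumed composition order).
    angles = dict(part.strip().split('=') for part in label.split(','))
    matrix = rotation_x(float(angles['x'])) @ rotation_z(float(angles['z']))
    # Round away floating-point noise (cos(90 deg) is ~6e-17) to get exact 0/+-1 entries.
    return tuple(int(v) for v in np.rint(matrix).flatten())

def is_proper_rotation(matrix):
    # Orthonormal columns and determinant +1 (no reflection).
    m = np.reshape(matrix, (3, 3))
    return np.allclose(m @ m.T, np.eye(3)) and np.isclose(np.linalg.det(m), 1.0)

# Copy of all_orientations from the cell above (row-major 3x3 matrices).
orientations = {'x=0, z=0': (1,0,0,0,1,0,0,0,1),
                'x=0, z=90': (0,-1,0,1,0,0,0,0,1),
                'x=0, z=-90': (0,1,0,-1,0,0,0,0,1),
                'x=0, z=180': (-1,0,0,0,-1,0,0,0,1),
                'x=180, z=0': (1,0,0,0,-1,0,0,0,-1),
                'x=180, z=90': (0,-1,0,-1,0,0,0,0,-1),
                'x=180, z=-90': (0,1,0,1,0,0,0,0,-1),
                'x=180, z=180': (-1,0,0,0,1,0,0,0,-1)}

mismatches = [label for label, matrix in orientations.items() if matrix_from_label(label) != matrix]
improper = [label for label, matrix in orientations.items() if not is_proper_rotation(matrix)]
print('label mismatches:', mismatches, 'improper rotations:', improper)
```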

In [None]:
sitk.Show(moving_image, title='Moving Image')
sitk.Show(original_moving_image, title= 'Original Moving Image')

In [None]:
def multires_registration(fixed_image, moving_image, initial_transform, output_prefix):
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100, estimateLearningRate=registration_method.Once)
    registration_method.SetOptimizerScalesFromPhysicalShift() 
    registration_method.SetInitialTransform(initial_transform)
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas = [2,1,0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    registration_method.RemoveAllCommands()
    registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
    registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
    registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 
    registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_value(registration_method, 'Multi Scale'))

    final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                                  sitk.Cast(moving_image, sitk.sitkFloat32))

    save_transform_and_image(final_transform, fixed_image, moving_image, os.path.join(OUTPUT_DIR, output_prefix))
    sitk.Show(fixed_image)
    sitk.Show(sitk.ReadImage(os.path.join(OUTPUT_DIR, output_prefix + '.mhd')))
    print('Final metric value: {0}'.format(registration_method.GetMetricValue()))
    print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))

## Initialize as usual (assumes orientation is similar)

In [None]:
# initialize
initial_transform = sitk.CenteredTransformInitializer(sitk.Cast(fixed_image,moving_image.GetPixelIDValue()), 
                                                      moving_image, 
                                                      sitk.Euler3DTransform(), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)
# register
multires_registration(fixed_image, moving_image, initial_transform, 'final-standardInitialization')

## Initialize using all orientations

We want to account for significant orientation differences caused by an erroneous patient position (HFS...), so we evaluate the similarity measure at eight locations corresponding to the various orientation differences. This can be done in two ways, illustrated below:
<ul>
<li>Use the ImageRegistrationMethod.MetricEvaluate() method.</li>
<li>Use the Exhaustive optimizer.</li>
</ul>

The former approach is more computationally intensive because it constructs and configures a metric object each time it is invoked. It is therefore more appropriate when the set of parameter values we want to evaluate is not on a rectilinear grid in the parameter space. When the parameter values do lie on a rectilinear grid, the latter approach is more appropriate and more computationally efficient.

In both cases we use the CenteredTransformInitializer to obtain the initial translation.

### MetricEvaluate

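To use the MetricEvaluate method, we create an ImageRegistrationMethod and set its metric and interpolator. We then iterate over all parameter settings, setting the initial transform and evaluating the metric for each one. The smallest similarity measure value corresponds to the best parameter settings. SimpleITK's Mattes mutual information metric returns the negated mutual information, so lower values indicate better alignment.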
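The selection loop described above can be sketched in plain NumPy to show its logic without SimpleITK. The sketch makes these substitutions: nearest-neighbour resampling of a cubic volume about its center replaces `sitk.ResampleImageFilter`, and mean squared difference replaces Mattes mutual information. Like the negated mutual information, the mean squared difference is lower for better alignment. The volumes are synthetic, and the helper names `resample_about_center` and `mean_squared_difference` are hypothetical.

```python
import numpy as np

def resample_about_center(volume, rotation):
    # Nearest-neighbour resampling of a cubic volume: output[p] = volume[R (p - c) + c],
    # mirroring how a transform maps fixed-grid points into the moving image.
    center = (np.array(volume.shape) - 1) / 2.0
    points = np.indices(volume.shape).reshape(3, -1) - center[:, None]
    mapped = np.rint(rotation @ points + center[:, None]).astype(int)
    return volume[mapped[0], mapped[1], mapped[2]].reshape(volume.shape)

def mean_squared_difference(a, b):
    # Stand-in similarity metric: lower is better.
    return float(np.mean((a - b) ** 2))

orientations = {'x=0, z=0': (1,0,0,0,1,0,0,0,1),
                'x=0, z=90': (0,-1,0,1,0,0,0,0,1),
                'x=0, z=-90': (0,1,0,-1,0,0,0,0,1),
                'x=0, z=180': (-1,0,0,0,-1,0,0,0,1),
                'x=180, z=0': (1,0,0,0,-1,0,0,0,-1),
                'x=180, z=90': (0,-1,0,-1,0,0,0,0,-1),
                'x=180, z=-90': (0,1,0,1,0,0,0,0,-1),
                'x=180, z=180': (-1,0,0,0,1,0,0,0,-1)}

# Synthetic data: the moving volume is the fixed volume seen under one of the candidate orientations.
rng = np.random.default_rng(0)
fixed_volume = rng.random((8, 8, 8))
true_rotation = np.array(orientations['x=180, z=90']).reshape(3, 3)
moving_volume = resample_about_center(fixed_volume, true_rotation.T)

# Evaluate the metric once per candidate orientation and keep the smallest value.
best_label, best_value = None, np.inf
for label, matrix in orientations.items():
    rotation = np.array(matrix, dtype=float).reshape(3, 3)
    value = mean_squared_difference(fixed_volume, resample_about_center(moving_volume, rotation))
    if value < best_value:
        best_label, best_value = label, value
print(best_label, best_value)
```

Each candidate is a single metric evaluation, so the candidate set can be any list of parameter values. Nothing requires them to lie on a grid.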

In [None]:
%%timeit -r1 -n1
#the magic above will time a single run of this cell

initial_transform = sitk.Euler3DTransform(sitk.CenteredTransformInitializer(sitk.Cast(fixed_image,moving_image.GetPixelIDValue()), 
                                                                            moving_image, 
                                                                            sitk.Euler3DTransform(), 
                                                                            sitk.CenteredTransformInitializerFilter.GEOMETRY))
# registration framework setup
registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)
registration_method.SetInterpolator(sitk.sitkLinear)

# cast the images once, instead of on every metric evaluation
fixed_image_float = sitk.Cast(fixed_image, sitk.sitkFloat32)
moving_image_float = sitk.Cast(moving_image, sitk.sitkFloat32)

# evaluate the similarity metric using the eight possible orientations, translation is the same for all
registration_method.SetInitialTransform(initial_transform, inPlace=False)
best_orientation = (1,0,0,0,1,0,0,0,1)
best_similarity_value = registration_method.MetricEvaluate(fixed_image_float, moving_image_float)
# iterate over all other rotation parameter settings
for orientation in possible_orientation_changes.values():
    initial_transform.SetMatrix(orientation)
    registration_method.SetInitialTransform(initial_transform, inPlace=False)
    current_similarity_value = registration_method.MetricEvaluate(fixed_image_float, moving_image_float)
    if current_similarity_value < best_similarity_value:
        best_similarity_value = current_similarity_value
        best_orientation = orientation

initial_transform.SetMatrix(best_orientation)

# register
multires_registration(fixed_image, moving_image, initial_transform, 'final-robustInitializationMetricEvaluate')

### Exhaustive optimizer

The exhaustive optimizer evaluates the similarity measure using a grid overlaid on the parameter space.
The grid is centered on the parameter values set by SetInitialTransform, and the locations of its vertices are determined by the <b>numberOfSteps</b>, <b>stepLength</b> and <b>optimizer scales</b>. To quote the documentation of this class: "a side of the region is stepLength*(2*numberOfSteps[d]+1)*scaling[d]."

This approach performs superfluous evaluations. It uses 15 evaluations (3 values for rotation around the x axis and 5 for rotation around the z axis), compared to the 8 evaluations in the MetricEvaluate approach. The extra evaluations arise because rotations by $-\pi$ and $\pi$ about the same axis are identical.
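The evaluation count can be checked by enumerating the grid. This NumPy sketch assumes that parameter d takes the values initial[d] + k * stepLength * scales[d] for k = -numberOfSteps[d], ..., numberOfSteps[d]. It also assumes the rotation convention R = Rz Rx Ry for the Euler angles, although the count of distinct rotations here is the same under either order. For simplicity the initial parameters are set to zero. The helper names are hypothetical.

```python
import itertools
import numpy as np

def exhaustive_grid(initial, number_of_steps, step_length, scales):
    # Each parameter d takes the values initial[d] + k * step_length * scales[d],
    # for k = -number_of_steps[d], ..., number_of_steps[d].
    axes = [initial[d] + step_length * scales[d] * np.arange(-n, n + 1)
            for d, n in enumerate(number_of_steps)]
    return list(itertools.product(*axes))

def rotation_matrix(angle_x, angle_y, angle_z):
    # Assumed convention: R = Rz @ Rx @ Ry (rotate about y, then x, then z).
    cx, sx = np.cos(angle_x), np.sin(angle_x)
    cy, sy = np.cos(angle_y), np.sin(angle_y)
    cz, sz = np.cos(angle_z), np.sin(angle_z)
    rx = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    ry = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    rz = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    return rz @ rx @ ry

# Settings used with the exhaustive optimizer: [angle_x, angle_y, angle_z, t_x, t_y, t_z].
grid = exhaustive_grid(initial=[0.0] * 6, number_of_steps=[1, 0, 2, 0, 0, 0],
                       step_length=np.pi, scales=[1, 1, 0.5, 1, 1, 1])
# Round to exact 0/+-1 entries so that -pi and pi give the same matrix.
distinct_rotations = {tuple(np.rint(rotation_matrix(*p[:3])).astype(int).flatten()) for p in grid}
print(len(grid), len(distinct_rotations))
```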

In [None]:
%%timeit -r1 -n1
#the magic above will time a single run of this cell

initial_transform = sitk.CenteredTransformInitializer(sitk.Cast(fixed_image,moving_image.GetPixelIDValue()), 
                                                      moving_image, 
                                                      sitk.Euler3DTransform(), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)
registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)
registration_method.SetInterpolator(sitk.sitkLinear)
# the order of parameters for the Euler3DTransform is [angle_x, angle_y, angle_z, t_x, t_y, t_z]
registration_method.SetOptimizerAsExhaustive(numberOfSteps=[1,0,2,0,0,0], stepLength=np.pi)
registration_method.SetOptimizerScales([1,1,0.5,1,1,1])
# do the registration in-place so that the initial_transform is modified
registration_method.SetInitialTransform(initial_transform, inPlace=True)
registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                            sitk.Cast(moving_image, sitk.sitkFloat32))

multires_registration(fixed_image, moving_image, initial_transform, 'final-robustInitialization')