<table width="100%">
<tr style="background-color: red;"><td><font color="white">SimpleITK conventions:</font></td></tr>
<tr><td>
<ul>
<li>The dimensionality and pixel type of the registered images are required to be the same (2D/2D or 3D/3D).</li>
<li>Supported pixel types are sitkFloat32 and sitkFloat64 (use the SimpleITK <a href="http://www.itk.org/SimpleITKDoxygen/html/namespaceitk_1_1simple.html#af8c9d7cc96a299a05890e9c3db911885">Cast()</a> function if your image's pixel type is something else).</li>
</ul>
</td></tr>
</table>

## ITK v4 Registration Components 
<img src="ITKv4RegistrationComponentsDiagram.svg" style="width:700px"/>

### Optimizer types

The SimpleITK registration framework supports several optimizer types via the SetOptimizerAsX() methods. These include:

<ul>
  <li>
  <a href="http://www.itk.org/Doxygen/html/classitk_1_1ExhaustiveOptimizerv4.html">Exhaustive</a>
  </li>
  <li>
  Variations on gradient descent:
  <ul>
    <li>
    <a href="http://www.itk.org/Doxygen/html/classitk_1_1GradientDescentOptimizerv4Template.html">GradientDescent</a>
    </li>
    <li>
    <a href="http://www.itk.org/Doxygen/html/classitk_1_1GradientDescentLineSearchOptimizerv4Template.html">GradientDescentLineSearch</a>
    </li>
    <li>
    <a href="http://www.itk.org/Doxygen/html/classitk_1_1RegularStepGradientDescentOptimizerv4.html">RegularStepGradientDescent</a>
    </li>
  </ul>
  </li>
  <li>
    <a href="http://www.itk.org/Doxygen/html/classitk_1_1ConjugateGradientLineSearchOptimizerv4Template.html">ConjugateGradientLineSearch</a> 
  </li>
  <li>
  <a href="http://www.itk.org/Doxygen/html/classitk_1_1LBFGSBOptimizerv4.html">LBFGSB</a> (Limited memory Broyden–  Fletcher–Goldfarb–Shanno-Byrd) - supports the use of simple constraints ($l\leq x \leq u$)  
  </li>
</ul>

 


### Similarity metric types

The SimpleITK registration framework supports several metric types via the SetMetricAsX() methods, these include:

<ul>
  <li>
  <a href="http://www.itk.org/Doxygen/html/classitk_1_1MeanSquaresImageToImageMetricv4.html">MeanSquares</a>
  </li>
  <li>
  <a href="http://www.itk.org/Doxygen/html/classitk_1_1DemonsImageToImageMetricv4.html">Demons</a>
  </li>
  <li>
  <a href="http://www.itk.org/Doxygen/html/classitk_1_1CorrelationImageToImageMetricv4.html">Correlation</a>
  </li>
  <li>
  <a href="http://www.itk.org/Doxygen/html/classitk_1_1ANTSNeighborhoodCorrelationImageToImageMetricv4.html">ANTSNeighborhoodCorrelation</a>
  </li>
  <a href="http://www.itk.org/Doxygen/html/classitk_1_1JointHistogramMutualInformationImageToImageMetricv4.html">JointHistogramMutualInformation</a>
  </li>
  </li>
  <a href="http://www.itk.org/Doxygen/html/classitk_1_1MattesMutualInformationImageToImageMetricv4.html">MattesMutualInformation</a>
  </li>
</ul>

In [None]:
import SimpleITK as sitk
import matplotlib.pyplot as plt
%matplotlib inline

from IPython.html.widgets import interact, fixed
     #display some html output
from IPython.display import display, HTML 

import os

OUTPUT_DIR = 'Output'
INPUT_DIR = 'Data'

    #this notebook requires that ITK-SNAP be installed on your machine, provide the full path to the application below
%env SITK_SHOW_COMMAND /Applications/ITK-SNAP.app/Contents/MacOS/ITK-SNAP 

In [None]:
def save_transform_and_image(transform, fixed_image, moving_image, outputfile_prefix):
    '''
    Write the given transformation to file, resample the moving_image onto the fixed_image's grid and save the
    result to file.

    The directory part of outputfile_prefix is created if it does not exist yet.

    Args:
        transform (SimpleITK Transform): transform that maps points from the fixed image coordinate system to the moving.
        fixed_image (SimpleITK Image): resample onto the spatial grid defined by this image.
        moving_image (SimpleITK Image): resample this image.
        outputfile_prefix (string): transform is written to outputfile_prefix.tfm and resampled image is written to
                                    outputfile_prefix.mhd.
    '''
    # Create the destination directory if needed so the two writes below have somewhere to go.
    output_directory = os.path.dirname(outputfile_prefix)
    if output_directory and not os.path.isdir(output_directory):
        os.makedirs(output_directory)

    resample = sitk.ResampleImageFilter()
    resample.SetReferenceImage(fixed_image)
    # SimpleITK supports several interpolation options.
    resample.SetInterpolator(sitk.sitkLinear)
    resample.SetTransform(transform)
    sitk.WriteImage(resample.Execute(moving_image), outputfile_prefix + '.mhd')
    sitk.WriteTransform(transform, outputfile_prefix + '.tfm')

## Loading Data

In [None]:
data_directory = os.path.join(INPUT_DIR, 'CIRS057A')

# Series IDs chosen in the dropdowns; the interact callback below overwrites these
# module-level globals with the user's selection.
selected_series_fixed = selected_series_moving = ''
def DICOM_series_dropdown_callback(fixed_image, moving_image, series_dictionary):
    '''
    Show selected DICOM meta-data of the chosen fixed and moving series and record the selection.

    Used as the interact() callback for the two series dropdowns. Only the first file of each
    series is read, which is enough to access its meta-data dictionary.

    Args:
        fixed_image (str): series ID selected for the fixed image.
        moving_image (str): series ID selected for the moving image.
        series_dictionary (dict): maps series ID to the list of file names in that series.

    Side effects:
        Displays an HTML table and sets the globals selected_series_fixed and selected_series_moving.
    '''
    global selected_series_fixed
    global selected_series_moving
    # Print some information about the series from the meta-data dictionary, see DICOM standard part 6,
    # Data Dictionary: http://medical.nema.org/medical/dicom/current/output/pdf/part06.pdf
    img_fixed = sitk.ReadImage(series_dictionary[fixed_image][0])
    img_moving = sitk.ReadImage(series_dictionary[moving_image][0])
    # There are many interesting tags in a DICOM file, print out some of them. A list of (tag, label)
    # pairs keeps the rows in this order on every Python version.
    tags_to_print = [('0010|0010', 'Patient name: '),
                     ('0008|0060', 'Modality: '),
                     ('0008|0021', 'Series date: '),
                     ('0008|0031', 'Series time:'),
                     ('0008|0080', 'Institution name: ')]

    def tag_value(image, tag):
        # Tags that are not in the image's meta-data dictionary are shown as empty cells.
        return image.GetMetaData(tag) if image.HasMetaDataKey(tag) else ''

    html_table = ['<table><tr><td><b>Tag</b></td><td><b>Fixed Image</b></td><td><b>Moving Image</b></td></tr>']
    for tag, label in tags_to_print:
        html_table.append('<tr><td>' + label +
                          '</td><td>' + tag_value(img_fixed, tag) +
                          '</td><td>' + tag_value(img_moving, tag) + '</td></tr>')
    html_table.append('</table>')
    display(HTML(''.join(html_table)))
    selected_series_fixed = fixed_image
    selected_series_moving = moving_image

             #directory contains multiple DICOM studies/series, store
             #in dictionary with key being the seriesID
reader = sitk.ImageSeriesReader()
series_file_names = {}
series_IDs = reader.GetGDCMSeriesIDs(data_directory)
            #check that we have at least one series
if series_IDs:
    for series in series_IDs:
        series_file_names[series] = reader.GetGDCMSeriesFileNames(data_directory, series)
    
    interact(DICOM_series_dropdown_callback, fixed_image=series_IDs, moving_image =series_IDs, series_dictionary=fixed(series_file_names)); 
else:
    print('Data directory does not contain any DICOM series.')

In [None]:
         #continue only if the previous cell was run and the series were selected
if 'selected_series_fixed' in globals() and 'selected_series_moving' in globals():
    reader.SetFileNames(series_file_names[selected_series_fixed])
    fixed_image = reader.Execute()
    reader.SetFileNames(series_file_names[selected_series_moving])
    moving_image = reader.Execute()
          #look at our images and their alignment (the two instances of ITK-SNAP share a spatially linked cursor)
    sitk.Show(fixed_image)
    sitk.Show(moving_image)

## Initial Alignment

A reasonable guesstimate for the initial translational alignment can be obtained by using
the CenteredTransformInitializer (functional interface to the CenteredTransformInitializerFilter). 

The resulting transformation is centered with respect to the fixed image and the
translation aligns the centers of the two images. There are two options for
defining the centers of the images, either the physical centers
of the two data sets (GEOMETRY), or the centers defined by the intensity 
moments (MOMENTS).

Two things to note about this filter: it requires that the fixed and moving images
have the same pixel type, even though this is not algorithmically required, and its
return type is the generic SimpleITK.Transform.

In [None]:
# The initializer needs both images to share a pixel type, so pass a copy of the fixed image cast to the
# moving image's type. GEOMETRY aligns the physical centers of the two images.
fixed_image_as_moving_type = sitk.Cast(fixed_image, moving_image.GetPixelIDValue())
initial_transform = sitk.CenteredTransformInitializer(fixed_image_as_moving_type,
                                                      moving_image,
                                                      sitk.Euler3DTransform(),
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)

initial_alignment_prefix = os.path.join(OUTPUT_DIR, 'initialAlignment')
save_transform_and_image(initial_transform, fixed_image, moving_image, initial_alignment_prefix)

# Now look at the images after their initial alignment.
sitk.Show(fixed_image)
sitk.Show(sitk.ReadImage(initial_alignment_prefix + '.mhd'))

## Final registration

### Version 1
<ul>
<li> Single scale - no pyramid structure.</li>
<li> Initial transformation is not modified in place. Resulting transformation's type is sitk.Transform. </li>
</ul>

Illustrate the need for scaling the gradient differently for each parameter:
<ul>
<li> SetOptimizerScalesFromIndexShift</li>
<li> SetOptimizerScalesFromPhysicalShift </li>
<li> SetOptimizerScalesFromJacobian </li>
</ul>

In [None]:
# Version 1: single scale registration (one level, no shrinking) without optimizer scales.
pyramid_shrink_factors = [1]

registration_method = sitk.ImageRegistrationMethod()

# 1. Similarity metric setting (MR/CT registration so use mutual information).
#    NOTE(review): RANDOM sampling is used without an explicit seed here, so results may vary between
#    runs — confirm whether the installed SimpleITK accepts a seed for SetMetricSamplingPercentage.
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)

# 2. Interpolator setting
registration_method.SetInterpolator(sitk.sitkLinear)

# 3. Optimizer settings
#    learningRate is the step size in the opposite direction of the gradient, equal for all parameters.
#    No optimizer scales are set, which is what this cell uses to illustrate the need for scaling the
#    gradient differently for each parameter.
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100)
#registration_method.SetOptimizerScalesFromIndexShift() #uncomment this line if you want registration to succeed

# 4. Registration framework settings (multi-resolution framework).
#    inPlace=False: initial_transform is not modified; the result is a generic sitk.Transform.
registration_method.SetShrinkFactorsPerLevel(pyramid_shrink_factors)
registration_method.SetInitialTransform(initial_transform, inPlace=False)

# The registration runs on sitkFloat32 copies of the images (a supported registration pixel type).
final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                              sitk.Cast(moving_image, sitk.sitkFloat32))

# 5. After registration, report some interesting registration information, save the results and look at them.
print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))

save_transform_and_image(final_transform, fixed_image, moving_image, os.path.join(OUTPUT_DIR, 'finalAlignment-v1'))

sitk.Show(fixed_image)
sitk.Show(sitk.ReadImage(os.path.join(OUTPUT_DIR, 'finalAlignment-v1.mhd')))

### Version 2

<ul>
<li> Multi scale - specify both scale, and how much to smooth with respect to original image.</li>
<li> Initial transformation modified in place, so in the end we have the same type of transformation in hand.</li>
</ul>

In [None]:
# Version 2: multi scale registration (three levels, shrink factors and smoothing sigmas per level).
pyramid_shrink_factors = [4,2,1]
smoothing_sigmas = [2,1,0]

registration_method = sitk.ImageRegistrationMethod()

# 1. Similarity metric setting (MR/CT registration so use mutual information)
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)

# 2. Interpolator setting
registration_method.SetInterpolator(sitk.sitkLinear)

# 3. Optimizer settings (with the default learning rate estimation approach, Once, the estimate is only
#    relevant for the first level).
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100) #, estimateLearningRate=registration_method.EachIteration)
registration_method.SetOptimizerScalesFromIndexShift()

# 4. Registration framework settings (multi-resolution framework)
#    Without an inPlace argument the initial transform is modified in place. Hand the registration a freshly
#    computed initial alignment rather than initial_transform itself, so that initial_transform stays the
#    untouched initial alignment for the cells below and re-running this cell starts from the same point.
optimized_transform = sitk.CenteredTransformInitializer(sitk.Cast(fixed_image, moving_image.GetPixelIDValue()),
                                                        moving_image,
                                                        sitk.Euler3DTransform(),
                                                        sitk.CenteredTransformInitializerFilter.GEOMETRY)
registration_method.SetInitialTransform(optimized_transform)
registration_method.SetShrinkFactorsPerLevel(pyramid_shrink_factors)
registration_method.SetSmoothingSigmasPerLevel(smoothing_sigmas)
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32),
                                              sitk.Cast(moving_image, sitk.sitkFloat32))

# 5. After registration, report some interesting registration information, save the results and look at them
print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))

save_transform_and_image(final_transform, fixed_image, moving_image, os.path.join(OUTPUT_DIR, 'finalAlignment-v2'))

sitk.Show(fixed_image)
sitk.Show(sitk.ReadImage(os.path.join(OUTPUT_DIR, 'finalAlignment-v2.mhd')))

## Insight Into Registration

Up to this point we have only shown you how to configure the registration and obtain the final results. This provides limited insight into what is actually happening during the optimization process. We now introduce the use of callbacks to facilitate a more detailed understanding of the process.

Also, the use of callbacks creates an illusory effect: the same registration appears to be slower when it runs without feedback than when progress is displayed during the process.

The interesting registration events that you can attach your code to include:
<ul>
<li> sitkStartEvent - registration starts</li>
<li> sitkEndEvent - registration ends</li>
<li> sitkIterationEvent - single iteration</li>
<li> sitkMultiResolutionIterationEvent  - when we change resolution</li>
</ul>


In [None]:
from IPython.display import clear_output

#callback invoked when the StartEvent happens, sets up our new data
def start_plot():
    global metric_values, multires_iterations 
    metric_values = []
    multires_iterations = []

#callback invoked when the EndEvent happens, do cleanup of data and figure
def end_plot():
    global metric_values, multires_iterations  
    del metric_values
    del multires_iterations
          #close figure, we don't want to get a duplicate of the plot latter on
    plt.close()

#callback invoked when the IterationEvent happens, update our data and display new figure    
def plot_value(registration_method, plot_title):
    global metric_values, multires_iterations
    metric_values.append(registration_method.GetMetricValue())
         #clear the output area (wait=True, to reduce flickering), and plot current data
    clear_output(wait=True)    
    plt.plot(metric_values, 'r')
    plt.plot(multires_iterations, [metric_values[index] for index in multires_iterations], 'b*')
    plt.title(plot_title)
    plt.xlabel('Iteration Number',fontsize=12)
    plt.ylabel('Metric Value',fontsize=12)
    plt.show()
    
#callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the 
#metric_values list. We assume this event happens before the first IterationEvent on the next resolution.
def update_multires_iterations():
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))
    

### Version 1 with display

In [None]:
# Single scale registration (as in Version 1, but with optimizer scales set), plotting the metric value
# after every iteration via the callbacks defined above.
pyramid_shrink_factors = [1]
registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)
registration_method.SetInterpolator(sitk.sitkLinear)
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100)
registration_method.SetOptimizerScalesFromIndexShift() 
registration_method.SetShrinkFactorsPerLevel(pyramid_shrink_factors)
# NOTE(review): this starts from whatever initial_transform currently holds. If an earlier cell optimized
# initial_transform in place, the starting point is no longer the initial alignment — confirm.
registration_method.SetInitialTransform(initial_transform, inPlace=False)

# Start without any previously attached observers, then hook up the plotting callbacks.
registration_method.RemoveAllCommands()
registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_value(registration_method, 'Single Scale'))

final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                              sitk.Cast(moving_image, sitk.sitkFloat32))

print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))

### Version 2 with display

In [None]:
# Multi scale registration (as in Version 2, with the learning rate estimated at every iteration), plotting
# the metric value after every iteration and marking the start of each resolution level.
pyramid_shrink_factors = [4,2,1]
smoothing_sigmas = [2,1,0]
registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01)
registration_method.SetInterpolator(sitk.sitkLinear)
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100, estimateLearningRate=registration_method.EachIteration)
registration_method.SetOptimizerScalesFromIndexShift()
# Without an inPlace argument the initial transform is modified in place. Optimize a freshly computed initial
# alignment instead of initial_transform itself, so this cell always starts from the initial alignment (also
# when re-run) and leaves initial_transform unchanged.
optimized_transform = sitk.CenteredTransformInitializer(sitk.Cast(fixed_image, moving_image.GetPixelIDValue()),
                                                        moving_image,
                                                        sitk.Euler3DTransform(),
                                                        sitk.CenteredTransformInitializerFilter.GEOMETRY)
registration_method.SetInitialTransform(optimized_transform)
registration_method.SetShrinkFactorsPerLevel(pyramid_shrink_factors)
registration_method.SetSmoothingSigmasPerLevel(smoothing_sigmas)
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Start without any previously attached observers, then hook up the plotting callbacks.
registration_method.RemoveAllCommands()
registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations)
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_value(registration_method, 'Multi Scale'))

final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32),
                                              sitk.Cast(moving_image, sitk.sitkFloat32))

print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))
