In [None]:
import SimpleITK as sitk
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import util
import registration_callbacks
import os
from IPython.display import display, HTML 

In [None]:
# Directory holding the input volumes (head_view0/1_cropped.mhd) that are read below.
dirname_data = "./datasets/2019-03-05/"

### Load the images

In [None]:
# Read both views with sitk.sitkFloat32 as the requested pixel type.
# os.path.join (instead of plain string concatenation) keeps the path correct whether or not
# dirname_data ends with a path separator.
view0 = sitk.ReadImage(os.path.normpath(os.path.join(dirname_data, 'head_view0_cropped.mhd')), sitk.sitkFloat32)
view1 = sitk.ReadImage(os.path.normpath(os.path.join(dirname_data, 'head_view1_cropped.mhd')), sitk.sitkFloat32)

# The names view0 and view1 have a meaning in some context and they play the role of fixed and moving images
# in the registration, so we use both naming conventions and they alias the relevant images.
fixed_image = view0
moving_image = view1

In [None]:
# Report the spacing of view0 and the grid size of each view.
print(f"Voxel size, um: {view0.GetSpacing()}")

print(*(f"{name} size, px: {img.GetSize()}" for name, img in (("View0", view0), ("View1", view1))), sep="\n")

# Value range of each view, as computed by util.get_minmax.
print(util.get_minmax(view0))
print(util.get_minmax(view1))

# Overlay the two (still unregistered) views into one RGB image and display it with util.show_mips.
img_merge = util.merge_images_rgb(view0, view1)

fig = plt.figure(figsize=(10, 10))
util.show_mips(
    img_merge,
    "Before registration: view0 (fixed, green), view1 (moving, magenta), overlap (white).",
)

## Registration flow

We register in three steps:

0. Basic initialization, centering the two volumes.
1. Rigid registration, global transformation.
2. Affine registration on top of the rigid result, constrained so that scaling is allowed only along the x and z axes (the y scale and the translation are kept fixed); shearing is not constrained.

### Initialization

In [None]:
# Initial rigid transform: CenteredTransformInitializer in GEOMETRY mode aligns the geometric
# centers of the fixed and moving image grids. The returned generic transform is wrapped back
# into an Euler3DTransform so the rigid stage can optimize it directly.
rigid_transform = sitk.Euler3DTransform(
    sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )
)

### Rigid registration

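We use a single resolution level; a multi-resolution pyramid is not needed here (Occam's razor: prefer the simplest approach that works). The registration is done in place, so `rigid_transform` is updated directly. Re-running the cell therefore starts from the previously optimized transform rather than from the initialization.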

In [None]:
# Single-resolution rigid registration of moving_image (view1) onto fixed_image (view0).
# rigid_transform is optimized in place (inPlace=True), so re-running this cell continues from
# the previously optimized transform rather than from the centered initialization.

# Fixed seed for the RANDOM metric sampling. Without it SimpleITK seeds from the wall clock,
# so each run samples different voxels and yields a slightly different result.
METRIC_SAMPLING_SEED = 42

registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsCorrelation()
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
# Evaluate the metric on a random 1% of the voxels (fraction in (0, 1]).
registration_method.SetMetricSamplingPercentage(0.01, METRIC_SAMPLING_SEED)
registration_method.SetInterpolator(sitk.sitkLinear)

registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=300)
# Scale the step size differently for each parameter, this is critical!!!
registration_method.SetOptimizerScalesFromPhysicalShift()

registration_method.SetInitialTransform(rigid_transform, inPlace=True)

# Live plotting of the metric value during optimization.
registration_method.AddCommand(sitk.sitkStartEvent, registration_callbacks.metric_start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, registration_callbacks.metric_end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent,
                               registration_callbacks.metric_update_multires_iterations)
registration_method.AddCommand(sitk.sitkIterationEvent,
                               lambda: registration_callbacks.metric_plot_values(registration_method))

registration_method.Execute(fixed=fixed_image, moving=moving_image)
print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))

In [None]:
# Resample the moving view onto view0's grid with the optimized rigid transform:
# linear interpolation, default value 0.0 outside the moving image, output pixel type of view0.
view1_registered = sitk.Resample(
    view1,
    view0,
    rigid_transform,
    sitk.sitkLinear,
    0.0,
    view0.GetPixelID(),
)

# Overlay fixed and resampled moving view to judge the rigid alignment.
img_merge = util.merge_images_rgb(view0, view1_registered)

fig = plt.figure(figsize=(10, 10))
util.show_mips(
    img_merge,
    "Rigid registration: view0 (fixed, green), view1 (moving, magenta), overlap (white).",
    scalebar=10,
)

### Affine transformation

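Refine the result with an affine transformation that is optimized on top of the rigid one: the rigid transform is passed as the moving initial transform and stays fixed, and the affine uses the same center point. `SetOptimizerWeights` restricts which parameters are optimized. Scaling in x and z and the off-diagonal (shear/rotation) terms are optimized; scaling in y and translation are not.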

The affine transformation is represented as a homogeneous matrix:

$$
\left[\begin{array}{cccc}
a_1 & a_2 & a_3 & a_{10}\\
a_4 & a_5 & a_6 & a_{11}\\
a_7 & a_8 & a_9 & a_{12}\\
0   & 0   & 0   & 1
\end{array}\right]
$$

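We don't want to optimize parameters $a_5$ (scaling in y) and $a_{10}, a_{11}, a_{12}$ (translation). SimpleITK's 12-element parameter vector holds the matrix entries in row-major order followed by the translation. These parameters are therefore the entries with 0-based indices 4, 9, 10 and 11, and they get weight 0.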

In [None]:
# Constrained affine refinement on top of the rigid result.
# rigid_transform is passed as the moving initial transform (kept fixed); only
# affine_transform is optimized, using the same center as the rigid transform.
affine_transform = sitk.AffineTransform(3)
affine_transform.SetCenter(rigid_transform.GetCenter())

# Fixed seed for the RANDOM metric sampling so the result is reproducible across runs
# (SimpleITK's default seed is taken from the wall clock).
METRIC_SAMPLING_SEED = 42

registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsCorrelation()
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.01, METRIC_SAMPLING_SEED)
registration_method.SetInterpolator(sitk.sitkLinear)
registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=300)

# Constrain the transformation to optimize only the desired parameters.
# AffineTransform parameters: the 3x3 matrix in row-major order (a1..a9), then the
# translation (a10..a12). Weight 0 freezes a5 (y scale) and the three translation terms.
registration_method.SetOptimizerWeights([1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0])
# Scale the step size differently for each parameter, this is critical!!!
registration_method.SetOptimizerScalesFromPhysicalShift()

registration_method.SetInitialTransform(affine_transform, inPlace=True)
registration_method.SetMovingInitialTransform(rigid_transform)

# Live plotting of the metric value during optimization.
registration_method.AddCommand(sitk.sitkStartEvent, registration_callbacks.metric_start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, registration_callbacks.metric_end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent,
                               registration_callbacks.metric_update_multires_iterations)
registration_method.AddCommand(sitk.sitkIterationEvent,
                               lambda: registration_callbacks.metric_plot_values(registration_method))

registration_method.Execute(fixed=fixed_image, moving=moving_image)

# Need to compose the transformations after registration.
# With a moving initial transform, the registration maps fixed -> moving as
# rigid(affine(x)): the optimized affine is applied first, then the fixed rigid transform.
# sitk.CompositeTransform([T0, T1]) evaluates T0(T1(x)) (last in the list is applied first),
# so the rigid transform must be listed first and the affine second.
final_transform = sitk.CompositeTransform([rigid_transform, affine_transform])

print('Optimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
print('Final metric value: {0}'.format(registration_method.GetMetricValue()))

In [None]:
# Resample the moving view onto view0's grid with the composed rigid + affine transform
# (linear interpolation, default value 0.0, output pixel type of view0).
view1_registered = sitk.Resample(
    view1,
    view0,
    final_transform,
    sitk.sitkLinear,
    0.0,
    view0.GetPixelID(),
)

# Overlay fixed and resampled moving view to judge the final alignment.
img_merge = util.merge_images_rgb(view0, view1_registered)

fig = plt.figure(figsize=(10, 10))
util.show_mips(
    img_merge,
    "Rigid-Scale: view0 (fixed, green), view1 (moving, magenta), overlap (white).",
    scalebar=10,
)

In [None]:
# Print the textual description of the final composite transform.
print(str(final_transform))