## Working with medical images
---
## Part 2 - registration
Image registration is built around a simple iterative loop:

1. Resample the moving image into the target image space
2. Calculate a metric describing how similar the images look
3. Use an optimiser to update the transformation
4. Repeat
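Here is a toy illustration of the four steps. It is a 1D registration that only optimises a translation, not how SimpleITK implements the loop. The "images" are made-up Gaussian profiles, the metric is the mean squared difference, and the gradient is estimated by finite differences.

```python
import numpy as np

# Made-up 1D "images": the moving profile is the fixed one shifted by 4 units
x = np.arange(100, dtype=float)
fixed = np.exp(-((x - 50.0) ** 2) / (2 * 8.0**2))
moving = np.exp(-((x - 54.0) ** 2) / (2 * 8.0**2))


def resample(image, shift):
    """Step 1: sample the moving image at the transformed points x + shift."""
    return np.interp(x + shift, x, image, left=0.0, right=0.0)


def metric(shift):
    """Step 2: mean squared intensity difference."""
    return np.mean((fixed - resample(moving, shift)) ** 2)


shift = 0.0            # initial transform parameter (a translation)
learning_rate = 300.0
eps = 1e-3

for iteration in range(200):  # Step 4: repeat
    # Step 3: central-difference gradient, then a gradient descent update
    grad = (metric(shift + eps) - metric(shift - eps)) / (2 * eps)
    shift -= learning_rate * grad

print(f"Recovered shift: {shift:.2f}")
```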

#### To tell the whole truth...
Step 1 is not *quite* what happens.

Both images are resampled into a **virtual reference image space**. This allows us to apply a transform to both images simultaneously.

Why would we want to do this?

It lets us define the virtual domain however we want; it does not have to be a grid of points.
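The idea can be sketched in 1D. Each image gets its own transform from the virtual domain, and the virtual domain here is just a random scatter of points. The images, the two transforms (`fixed_transform`, `moving_transform`) and the point set are all made up for illustration.

```python
import numpy as np

# Made-up 1D images defined on their own sample grids
grid = np.arange(100, dtype=float)
fixed = np.exp(-((grid - 40.0) ** 2) / 50.0)
moving = np.exp(-((grid - 60.0) ** 2) / 50.0)

# The virtual domain is not a grid: random points in a region of interest
rng = np.random.default_rng(0)
virtual_points = rng.uniform(20.0, 80.0, size=500)


def fixed_transform(p):
    """Hypothetical map from virtual points into the fixed image."""
    return p - 10.0


def moving_transform(p):
    """Hypothetical map from virtual points into the moving image."""
    return p + 10.0


# Sample BOTH images at their transformed virtual points
fixed_vals = np.interp(fixed_transform(virtual_points), grid, fixed)
moving_vals = np.interp(moving_transform(virtual_points), grid, moving)

# The metric is evaluated in the virtual space
mse = np.mean((fixed_vals - moving_vals) ** 2)
print(f"MSE in virtual space: {mse:.3e}")
```

Neither transform aligns the images on its own, but together they map the two peaks onto the same virtual points, so the metric is (numerically) zero.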

In [None]:
# Only run this if you don't have platipy
#!pip install git+https://github.com/pyplati/platipy

In [None]:
# Only run this if you haven't already cloned the repo
#!git clone https://github.com/InghamPhysics/coding-club
#import os
#os.chdir('./coding-club/medical-images') 

In [None]:
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

from platipy.imaging.visualisation.tools import ImageVisualiser

%matplotlib notebook

To start with, let's go back to the spheres we generated in the first notebook.

In [None]:
# Read original image
img_original = sitk.ReadImage("./output/spheres_original.nii.gz", sitk.sitkUInt8)
# Read image with modified spacing
img_modified = sitk.ReadImage("./output/spheres_modified_spacing.nii.gz", sitk.sitkUInt8)

We know that they aren't really aligned if we just resample.

In [None]:
# Resample
img_modified_res = sitk.Resample(img_modified, img_original, interpolator=sitk.sitkNearestNeighbor)

# Visualise
vis = ImageVisualiser(img_original, window=[0,1], figure_size_in=5)
vis.add_comparison_overlay(img_modified_res)
fig = vis.show()

So let's make an ITK-style pipeline for registering these images.

In [None]:
# Create the pipeline
reg_pipe = sitk.ImageRegistrationMethod()

# Set the initial transform (translation only)
reg_pipe.SetInitialTransform(sitk.TranslationTransform(3))

# A multi-level scheme
reg_pipe.SetShrinkFactorsPerLevel((8,4,2))
reg_pipe.SetSmoothingSigmasPerLevel((2,1,0))
reg_pipe.SetSmoothingSigmasAreSpecifiedInPhysicalUnits(True)


# Compare images using mean squared intensity difference
# We calculate this at each point
reg_pipe.SetMetricAsMeanSquares()
reg_pipe.SetMetricSamplingPercentage(1)
reg_pipe.SetMetricSamplingStrategy(sitk.ImageRegistrationMethod.REGULAR)

# Use gradient descent
reg_pipe.SetOptimizerAsGradientDescentLineSearch(
    learningRate=1,
    numberOfIterations=25
)
# Estimate parameter scales from the physical shift each parameter causes
reg_pipe.SetOptimizerScalesFromPhysicalShift()

# Interpolate using nearest neighbour
reg_pipe.SetInterpolator(sitk.sitkNearestNeighbor)
    
# The registration method needs floating-point voxel values (float32 or float64)
transform_translation = reg_pipe.Execute(
    fixed=sitk.Cast(img_original, sitk.sitkFloat64),
    moving=sitk.Cast(img_modified, sitk.sitkFloat64)
)

In [None]:
print(transform_translation)

Now we resample the moving image, this time using the transform we just optimised.

In [None]:
img_modified_translated = sitk.Resample(
    img_modified,
    img_original,
    transform = transform_translation,
    interpolator = sitk.sitkNearestNeighbor,
    defaultPixelValue = 0
)

In [None]:
# Visualise
vis = ImageVisualiser(img_original, window=[0,1], figure_size_in=5)
vis.add_comparison_overlay(img_modified_translated)
fig = vis.show()

Clearly, this isn't a great registration. A major problem is that we aren't accounting for the spatial "stretching" caused by changing the image spacing earlier.

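Fortunately, there are transformations that include this kind of scaling. `ScaleVersor3DTransform`, for example, combines a rotation (stored as a versor), a translation and a separate scale factor along each axis.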

In [None]:
# Create the pipeline
reg_pipe = sitk.ImageRegistrationMethod()

# Set the initial transform (rotation, translation and per-axis scaling)
reg_pipe.SetInitialTransform(sitk.ScaleVersor3DTransform())

# A multi-level scheme
reg_pipe.SetShrinkFactorsPerLevel((8,4,2))
reg_pipe.SetSmoothingSigmasPerLevel((2,1,0))
reg_pipe.SetSmoothingSigmasAreSpecifiedInPhysicalUnits(True)


# Compare images using mean squared intensity difference
# We calculate this at each point
reg_pipe.SetMetricAsMeanSquares()
reg_pipe.SetMetricSamplingPercentage(1)
reg_pipe.SetMetricSamplingStrategy(sitk.ImageRegistrationMethod.REGULAR)

# Use gradient descent
reg_pipe.SetOptimizerAsGradientDescentLineSearch(
    learningRate=1,
    numberOfIterations=25
)
# Estimate parameter scales from the physical shift each parameter causes
reg_pipe.SetOptimizerScalesFromPhysicalShift()

# Interpolate using nearest neighbour
reg_pipe.SetInterpolator(sitk.sitkNearestNeighbor)
                           
transform_scaleversor = reg_pipe.Execute(
    fixed=sitk.Cast(img_original, sitk.sitkFloat64),
    moving=sitk.Cast(img_modified, sitk.sitkFloat64)
)

In [None]:
print(transform_scaleversor)

In [None]:
img_modified_scaleversor = sitk.Resample(
    img_modified,
    img_original,
    transform = transform_scaleversor,
    interpolator = sitk.sitkNearestNeighbor,
    defaultPixelValue = 0
)

In [None]:
# Visualise
vis = ImageVisualiser(img_original, window=[0,1], figure_size_in=5)
vis.add_comparison_overlay(img_modified_scaleversor)
fig = vis.show()

A little bit better, but because the initial parameters are so far from the ideal values, it is hard for the optimiser to reach the global minimum.

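Another issue is that with binary images like ours, the squared difference at each voxel can only be 0 or 1, so the metric is just the fraction of mismatched voxels. With nearest-neighbour interpolation, this changes in discrete jumps as the transform changes, which gives the optimiser very little gradient information to work with. Platipy has a helper to convert masks into images that are better suited to registration.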
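A toy 1D illustration of the problem: for a binary step with nearest-neighbour interpolation, small shifts do not change the metric at all, so the gradient is zero. A smooth profile with linear interpolation gives a metric that grows steadily with the shift. Both profiles are made up.

```python
import numpy as np

x = np.arange(100, dtype=float)

# Made-up 1D "masks": a binary step and a smooth Gaussian profile
binary = (np.abs(x - 50) < 10).astype(float)
smooth = np.exp(-((x - 50) ** 2) / (2 * 10.0**2))


def mse_nearest(image, shift):
    """MSE between an image and a shifted copy (nearest-neighbour)."""
    idx = np.clip(np.round(x + shift).astype(int), 0, len(x) - 1)
    return np.mean((image - image[idx]) ** 2)


def mse_linear(image, shift):
    """MSE between an image and a shifted copy (linear interpolation)."""
    return np.mean((image - np.interp(x + shift, x, image)) ** 2)


for shift in (0.1, 0.2, 0.3):
    print(
        f"shift {shift}: binary+NN {mse_nearest(binary, shift):.4f}, "
        f"smooth+linear {mse_linear(smooth, shift):.6f}"
    )
```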

In [None]:
from platipy.imaging.registration.registration import convert_mask_to_reg_structure

In [None]:
reg_struct_original = convert_mask_to_reg_structure(img_original)

vis = ImageVisualiser(reg_struct_original, window=(0,1), figure_size_in=5)
fig = vis.show()

In [None]:
reg_struct_modified = convert_mask_to_reg_structure(img_modified)

vis = ImageVisualiser(reg_struct_modified, window=(0,1), figure_size_in=5)
fig = vis.show()

In [None]:
# Initial alignment: match the image centres of mass (moments)
alignment_tfm = sitk.CenteredTransformInitializer(
    reg_struct_original,
    reg_struct_modified,
    sitk.Euler3DTransform(),
    sitk.CenteredTransformInitializerFilter.MOMENTS,
)

# Create the pipeline
reg_pipe = sitk.ImageRegistrationMethod()

# Set the transform to optimise
reg_pipe.SetInitialTransform(sitk.ScaleVersor3DTransform())

# Apply the initial alignment to the moving image first (it is not optimised)
reg_pipe.SetMovingInitialTransform(alignment_tfm)

# Only sample in the "spheres"
# reg_pipe.SetMetricFixedMask( img_original )
# reg_pipe.SetMetricMovingMask( img_modified )

# A multi-level scheme
reg_pipe.SetShrinkFactorsPerLevel((8,4,2))
reg_pipe.SetSmoothingSigmasPerLevel((2,1,0))
reg_pipe.SetSmoothingSigmasAreSpecifiedInPhysicalUnits(True)


# Compare images using mean squared intensity difference
# We calculate this at each point
reg_pipe.SetMetricAsMeanSquares()
reg_pipe.SetMetricSamplingPercentage(1)
reg_pipe.SetMetricSamplingStrategy(sitk.ImageRegistrationMethod.REGULAR)

# Use gradient descent
reg_pipe.SetOptimizerAsGradientDescentLineSearch(
    learningRate=1,
    numberOfIterations=25
)
# Estimate parameter scales from the physical shift each parameter causes
reg_pipe.SetOptimizerScalesFromPhysicalShift()

# Interpolate using nearest neighbour
reg_pipe.SetInterpolator(sitk.sitkNearestNeighbor)
                           
transform_scaleversor_2 = reg_pipe.Execute(
    fixed=reg_struct_original,
    moving=reg_struct_modified 
)

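# Combine the transforms: SimpleITK applies the last transform in the list first,
# so points go through the optimised transform and then the initial alignment
combined_transform = sitk.CompositeTransform([alignment_tfm, transform_scaleversor_2])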

In [None]:
print(combined_transform)

In [None]:
img_modified_scaleversor_2 = sitk.Resample(
    img_modified,
    img_original,
    transform = combined_transform,
    interpolator = sitk.sitkNearestNeighbor,
    defaultPixelValue = 0
)

In [None]:
# Visualise
vis = ImageVisualiser(img_original, window=[0,1], figure_size_in=5)
vis.add_comparison_overlay(img_modified_scaleversor_2)

fig = vis.show()

## Platipy - another abstraction layer

It can be a bit of a hassle doing image registration like this.

Platipy has some useful tools to make this process a lot easier.

In [None]:
from platipy.imaging.registration.registration import initial_registration, transform_propagation

In [None]:
_, transform_platipy = initial_registration(
    fixed_image = reg_struct_original,
    moving_image = reg_struct_modified,
    reg_method = "ScaleVersor",
    default_value = 0,
    shrink_factors = [8,4,2],
    optimiser = 'gradient_descent_line_search',
    final_interp = sitk.sitkNearestNeighbor
)

In [None]:
# Propagate transform to the modified spheres image

img_modified_scaleversor_3 = transform_propagation(
    fixed_image = img_original,
    moving_image = img_modified,
    transform = transform_platipy,
    structure = True
)

In [None]:
# Visualise
vis = ImageVisualiser(img_original, window=[0,1], figure_size_in=5)
vis.add_comparison_overlay(img_modified_scaleversor_3)

fig = vis.show()

## "Real" images

Finally, let's check out how we can perform registration on real patient imaging.

In [None]:
# Small utility function
from platipy.imaging.utils.tools import get_com

In [None]:
# We have some contoured RT imaging

img_ct_atlas = sitk.ReadImage("./input/HN_CT_ATLAS.nii.gz")
struct_ctv_atlas = sitk.ReadImage("./input/HN_CTV_ATLAS.nii.gz")

vis = ImageVisualiser(img_ct_atlas, cut=get_com(struct_ctv_atlas), figure_size_in=5)
vis.add_contour({"CTV":struct_ctv_atlas})
fig = vis.show()

In [None]:
# We also have a PET-CT scan, without any contours

img_ct = sitk.ReadImage("./input/HN_CT.nii.gz")
img_pt = sitk.ReadImage("./input/HN_PT.nii.gz")

img_pt_res = sitk.Resample(img_pt, img_ct)

vis = ImageVisualiser(img_ct, figure_size_in=5)
vis.add_scalar_overlay(img_pt_res, name="PET value", colormap=plt.cm.magma, max_value=50000)
fig = vis.show()

In [None]:
# Register the planning CT to the CT from the PET-CT

img_ct_atlas_rigid, tfm_rigid = initial_registration(
    fixed_image = img_ct,
    moving_image = img_ct_atlas,
    reg_method = "Similarity",
    default_value = -1000,
    final_interp = sitk.sitkLinear
)

In [None]:
# Visualise

vis = ImageVisualiser(img_ct, figure_size_in=5)
vis.add_comparison_overlay(img_ct_atlas_rigid)
fig = vis.show()

### Deformable registration

To account for non-linear deformations, we can use a deformable image registration (DIR) algorithm.

One great option is the fast symmetric-forces demons algorithm.
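To give a rough idea of how demons works, here is a toy 1D sketch, not platipy's or SimpleITK's implementation. At each iteration, a force is computed from the intensity difference and the average of the fixed and warped-moving gradients (the "symmetric" force). It is added to a displacement field, which is then smoothed with a Gaussian. Real implementations work in 3D and add refinements such as multiple resolution levels. The profiles, `sigma` and iteration count are made up.

```python
import numpy as np

# Made-up 1D images: the moving profile is the fixed one shifted by +5 samples
x = np.arange(200, dtype=float)
fixed = np.exp(-((x - 100.0) ** 2) / (2 * 15.0**2))
moving = np.exp(-((x - 105.0) ** 2) / (2 * 15.0**2))


def gaussian_smooth(field, sigma=2.0):
    """Regularise the displacement field with a normalised Gaussian kernel."""
    radius = int(np.ceil(3 * sigma))
    k = np.arange(-radius, radius + 1, dtype=float)
    kernel = np.exp(-(k**2) / (2 * sigma**2))
    return np.convolve(field, kernel / kernel.sum(), mode="same")


u = np.zeros_like(x)  # displacement: moving(x + u) should match fixed(x)
grad_fixed = np.gradient(fixed)

for _ in range(200):
    warped = np.interp(x + u, x, moving)
    diff = fixed - warped
    # Symmetric force: average of the fixed and warped-moving gradients
    grad = 0.5 * (grad_fixed + np.gradient(warped))
    denom = grad**2 + diff**2
    update = np.divide(diff * grad, denom, out=np.zeros_like(x), where=denom > 1e-9)
    u = gaussian_smooth(u + update)

initial_mse = np.mean((fixed - moving) ** 2)
final_mse = np.mean((fixed - np.interp(x + u, x, moving)) ** 2)
print(f"MSE before: {initial_mse:.2e}, after: {final_mse:.2e}")
```

The normalisation `diff * grad / (grad**2 + diff**2)` caps each update at half a sample, which keeps the iterations stable. The Gaussian smoothing is what stops the displacement field from becoming arbitrarily irregular.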

In [None]:
from platipy.imaging.registration.registration import fast_symmetric_forces_demons_registration, apply_field

In [None]:
img_ct_atlas_dir, tfm_dir = fast_symmetric_forces_demons_registration(
    fixed_image = img_ct,
    moving_image = img_ct_atlas_rigid,
    ncores = 8
)

In [None]:
# Visualise

vis = ImageVisualiser(img_ct, figure_size_in=5)
vis.add_comparison_overlay(img_ct_atlas_dir)
fig = vis.show()

### Propagating transformations

Finally, we can apply these transformations to the contours we have.

In [None]:
struct_ctv_atlas_rigid = transform_propagation(img_ct, struct_ctv_atlas, tfm_rigid, structure=True)
struct_ctv_atlas_dir = apply_field(struct_ctv_atlas_rigid, tfm_dir, structure=True)

In [None]:
# Overlay on PET-CT

vis = ImageVisualiser(img_ct, cut=get_com(struct_ctv_atlas_dir), figure_size_in=5)
vis.add_scalar_overlay(img_pt_res, name="PET value", colormap=plt.cm.magma, max_value=50000)
vis.add_contour({"CTV":struct_ctv_atlas_dir}, color='blue')
fig = vis.show()

In [None]:
# In practice, we could use this contour to extract information from the PET

img_pt_masked = sitk.Mask(img_pt_res, struct_ctv_atlas_dir)

f = sitk.LabelIntensityStatisticsImageFilter()
f.Execute(struct_ctv_atlas_dir, img_pt_res)

# Voxel volume in mL (spacing is in mm, and 1 mL = 1000 mm^3)
voxel_vol_ml = np.prod(img_pt_res.GetSpacing()) / 1000

vol = f.GetNumberOfPixels(1) * voxel_vol_ml
max_act = f.GetMaximum(1)
mean_act = f.GetMean(1)
tot_act = f.GetSum(1) * voxel_vol_ml

print("CTV information:")
print(f"Volume:        {vol:.2f} mL")
print(f"Max. activity: {max_act:.2f} Bq/mL")
print(f"Mean activity: {mean_act:.2f} Bq/mL")
print(f"Tot. activity: {tot_act/1e6:.2f} MBq")
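The unit arithmetic is easy to cross-check with plain NumPy. The volume is the voxel count times the voxel volume, and the total activity is the sum of concentrations (Bq/mL) times the voxel volume (mL). The helper `roi_stats` is hypothetical. Like the print statements above, it assumes PET values in Bq/mL. The cube of activity is made up.

```python
import numpy as np


def roi_stats(activity, mask, spacing_mm):
    """Hypothetical helper: volume and activity statistics inside a mask.

    activity   -- array of activity concentrations (Bq/mL)
    mask       -- boolean array of the same shape (True inside the ROI)
    spacing_mm -- voxel spacing in mm, e.g. (sx, sy, sz)
    """
    voxel_ml = np.prod(spacing_mm) / 1000.0  # 1 mL = 1000 mm^3
    values = activity[mask]
    return {
        "volume_ml": values.size * voxel_ml,
        "max_bq_per_ml": values.max(),
        "mean_bq_per_ml": values.mean(),
        "total_bq": values.sum() * voxel_ml,  # (Bq/mL) * mL = Bq
    }


# Made-up example: a 10^3 cube of 2 mm voxels at 5000 Bq/mL
activity = np.zeros((20, 20, 20))
mask = np.zeros((20, 20, 20), dtype=bool)
mask[5:15, 5:15, 5:15] = True
activity[mask] = 5000.0

stats = roi_stats(activity, mask, (2.0, 2.0, 2.0))
print(stats)
```

A useful sanity check is that the total activity equals the mean concentration times the volume.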