## Registering FlyWire to MCNS using ITKElastix

### Requirements

For this tutorial we will need:
- `fw_synapse_points.feather` (49 MB): the FlyWire synapse cloud (30x subsampled)
- `mcns_synapse_points.feather` (64 MB): t-bars for the entire male CNS (10x subsampled)
- [`ITKElastix`](https://github.com/InsightSoftwareConsortium/ITKElastix): `pip install itk-elastix`
- `pandas`, `numpy`, `matplotlib`, `scipy`

In [None]:
!wget https://flyem.mrc-lmb.cam.ac.uk/flyconnectome/imagereg_workshop/mcns_synapse_points.feather
!wget https://flyem.mrc-lmb.cam.ac.uk/flyconnectome/imagereg_workshop/fw_synapse_points.feather

In [None]:
%%capture
!pip install pandas numpy matplotlib scipy itk-elastix

In [None]:
import itk

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from scipy import ndimage

### Step 1: Generate the image volume from the synapse clouds

In [None]:
# Load the male CNS t-bars data (this is 8x8x8nm)
points_mcns = pd.read_feather("mcns_synapse_points.feather")

# Convert to microns
points_mcns = points_mcns[["x", "y", "z"]].values / 125

# Convert to 1x1x1um resolution voxels
points_mcns = (points_mcns // 1).astype(int)

# Create a 3D histogram
mx = points_mcns.max(axis=0) + 1
img_mcns = np.histogramdd(
    points_mcns, bins=[np.arange(mx[0] + 1), np.arange(mx[1] + 1), np.arange(mx[2] + 1)]
)[0]

# Trim image to just the brain and remove some ventral space
img_mcns_brain = img_mcns[:, :450, :400]

# Smooth
img_mcns_brain_smooth = ndimage.gaussian_filter(img_mcns_brain, sigma=1)

# Normalize (note that we normalize to the 99th percentile to avoid outliers)
img_mcns_brain_smooth = np.clip(
    img_mcns_brain_smooth / np.percentile(img_mcns_brain_smooth, 99) * 255, 0, 255
).astype(np.uint8)

plt.imshow(img_mcns_brain_smooth.sum(axis=2).T, cmap="gray")
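
Why the 99th percentile rather than the maximum? A handful of unusually dense voxels would otherwise set the intensity scale and push everything else towards black. The toy sketch below uses synthetic data (a flat volume plus one extreme outlier voxel, not the real synapse clouds) to compare the two; `to_uint8` is a hypothetical helper that mirrors the normalization line in the cell above.

```python
import numpy as np


def to_uint8(img, pct=None):
    """Rescale `img` to 0-255, using the `pct` percentile as upper bound if given, else the max."""
    upper = np.percentile(img, pct) if pct is not None else img.max()
    return np.clip(img / upper * 255, 0, 255).astype(np.uint8)


# Synthetic volume: "tissue" values between 10 and 11 plus one extreme outlier voxel
rng = np.random.default_rng(0)
img = 10 + rng.random((50, 50, 50))
img[0, 0, 0] = 1e4

by_max = to_uint8(img)
by_pct = to_uint8(img, pct=99)

# Max-normalization leaves the tissue almost black; the percentile keeps it bright
print("median (max-normalized):", np.median(by_max))
print("median (99th percentile):", np.median(by_pct))
```

The outlier itself simply saturates at 255 with the percentile approach, which is harmless for registration.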

In [None]:
# This data is in nm resolution
points_fw = pd.read_feather(
    "fw_synapse_points.feather"
).values

# Convert to 1x1x1um resolution voxels
points_fw = (points_fw // 1000).astype(int)

# Create a 3D histogram
mx = points_fw.max(axis=0) + 1
img_fw = np.histogramdd(
    points_fw, bins=[np.arange(mx[0] + 1), np.arange(mx[1] + 1), np.arange(mx[2] + 1)]
)[0]

# Smooth
img_fw_smooth = ndimage.gaussian_filter(img_fw, sigma=1)

# Normalize (note that we normalize to the 99th percentile to avoid outliers)
img_fw_smooth = np.clip(
    img_fw_smooth / np.percentile(img_fw_smooth, 99) * 255, 0, 255
).astype(np.uint8)

# Plot frontal view
plt.imshow(img_fw_smooth.sum(axis=2).T, cmap="gray")
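
The MCNS and FlyWire cells repeat the same bin, smooth, normalize recipe. If you want to experiment with other voxel sizes or smoothing, it can help to wrap it in a function. `points_to_volume` below is a hypothetical helper (not part of any library); instead of `np.histogramdd` it counts points per voxel with `np.ravel_multi_index` + `np.bincount`, which yields the same counts for unit-width bins and tends to be faster on large clouds.

```python
import numpy as np
from scipy import ndimage


def points_to_volume(points, scale, sigma=1.0, pct=99):
    """Turn an (N, 3) point cloud into a smoothed uint8 density volume.

    points: coordinates in the cloud's native units
    scale:  divisor converting native units into output voxels
            (e.g. 125 for 8 nm voxels -> 1 um, 1000 for nm -> 1 um)
    """
    vox = np.floor(np.asarray(points, dtype=float) / scale).astype(int)
    # Drop points with negative coordinates (np.histogramdd with edges from 0 ignores them too)
    vox = vox[(vox >= 0).all(axis=1)]
    shape = tuple(vox.max(axis=0) + 1)

    # Count points per voxel via flat indices
    flat = np.ravel_multi_index(tuple(vox.T), shape)
    counts = np.bincount(flat, minlength=int(np.prod(shape))).reshape(shape).astype(np.float64)

    smooth = ndimage.gaussian_filter(counts, sigma=sigma)
    upper = np.percentile(smooth, pct) or 1.0  # guard against an all-zero percentile
    return np.clip(smooth / upper * 255, 0, 255).astype(np.uint8)
```

For example, `points_to_volume(pd.read_feather("fw_synapse_points.feather").values, scale=1000)` should give the same volume as `img_fw_smooth` (assuming non-negative coordinates). For MCNS you would pass the `x`, `y`, `z` columns with `scale=125` and then apply the same `[:, :450, :400]` crop.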

In [None]:
# Free some memory (relevant on Google Colab)
del points_mcns
del points_fw
del img_fw
del img_mcns
del img_mcns_brain

### Register MCNS to FlyWire

We will run the registration with multiple stages, using the spatial transformation result from the current stage to initialize registration at the next stage:
1. A rigid registration (translation and rotation)
2. Affine transform (translation, rotation, scale, shear)
3. Coarse B-spline non-rigid transformation
4. Fine B-spline non-rigid transformation
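
Note that elastix's "rigid" transform (`EulerTransform`) includes rotations as well as translations, so distances between points are preserved and the determinant of its matrix is 1; an affine transform can additionally scale and shear. The NumPy sketch below illustrates this with toy 4x4 homogeneous matrices (made-up values, not parameters estimated by elastix).

```python
import numpy as np


def rigid_matrix(angle_z, translation):
    """4x4 homogeneous matrix: rotation about z, then translation."""
    c, s = np.cos(angle_z), np.sin(angle_z)
    m = np.eye(4)
    m[:3, :3] = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
    m[:3, 3] = translation
    return m


def affine_matrix(rigid, scale, shear_xy):
    """Apply per-axis scaling and an x-y shear first, then the rigid transform."""
    a = np.eye(4)
    a[:3, :3] = np.diag(scale)
    a[0, 1] = shear_xy
    return rigid @ a


def apply(m, pts):
    """Apply a 4x4 homogeneous matrix to an (N, 3) array of points."""
    homog = np.c_[pts, np.ones(len(pts))]
    return (homog @ m.T)[:, :3]


def dist(p):
    """Distance between the first two points."""
    return np.linalg.norm(p[1] - p[0])


pts = np.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [0.0, 5.0, 2.0]])
R = rigid_matrix(np.pi / 6, [3.0, -2.0, 1.0])
A = affine_matrix(R, scale=[1.2, 0.9, 1.0], shear_xy=0.1)

print("original:", dist(pts))
print("rigid:   ", dist(apply(R, pts)), "det =", np.linalg.det(R[:3, :3]))
print("affine:  ", dist(apply(A, pts)), "det =", np.linalg.det(A[:3, :3]))
```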


Here we first perform the above registrations one by one to show you what's happening at each step. That said, you can also run them all in one go by adding several parameter maps to the same parameter object; the resulting transforms are then written to sequential transform files: `TransformParameters.0.txt`, `TransformParameters.1.txt`, etc. I haven't fully explored this, but it should also be possible to produce a single transform by passing the affine transform into the non-rigid registration as the initial transform and running the non-rigid stages with a grid spacing schedule (elastix's `GridSpacingSchedule` parameter).
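
When several stages are written to sequential files, each `TransformParameters.N.txt` points at its predecessor through the `InitialTransformParametersFileName` entry, and the first one says `"NoInitialTransform"`. The standard-library sketch below follows that chain back to the first stage. `read_parameter_file` and `transform_chain` are hypothetical helpers with a deliberately simple parser, not a complete reader of the elastix format.

```python
import re
from pathlib import Path

# Matches one "(Key value value ...)" entry, the syntax of elastix parameter files
_ENTRY = re.compile(r"^\((\w+)\s+(.*)\)$")


def read_parameter_file(path):
    """Parse an elastix-style parameter file into {key: [values as strings]}."""
    params = {}
    for line in Path(path).read_text().splitlines():
        line = line.split("//")[0].strip()  # drop "//" comments
        match = _ENTRY.match(line)
        if match:
            key, raw = match.groups()
            # Quoted strings may contain spaces; everything else is whitespace-separated
            params[key] = [v.strip('"') for v in re.findall(r'"[^"]*"|\S+', raw)]
    return params


def transform_chain(last_file):
    """Follow InitialTransformParametersFileName back from `last_file`.

    Returns (file name, transform name) pairs ordered from first to last stage.
    """
    chain = []
    path = Path(last_file)
    while True:
        params = read_parameter_file(path)
        chain.append((path.name, params["Transform"][0]))
        prev = params.get("InitialTransformParametersFileName", ["NoInitialTransform"])[0]
        if prev == "NoInitialTransform":
            return chain[::-1]
        prev = Path(prev)
        # Accept absolute paths as well as paths relative to the current file's folder
        path = prev if prev.is_absolute() else path.parent / prev
```

Pointed at the last file of a single multi-stage run, this should list the stages (e.g. `EulerTransform`, `AffineTransform`, ...) in the order they were applied.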

In [None]:
# Where to store the output (registered images, log files and transform parameter files)
# Note that we will re-use this path and output files will be overwritten!
output_path = Path("./elastix_output/").expanduser()
output_path.mkdir(exist_ok=True)  # Create the directory if it does not exist

Before we get started, some general background on how `elastix` and `ITKElastix` operate:

The `elastix` command line tool takes, in addition to an output directory (`-out`):
- the "fixed" image (`-f`)
- the "moving" image (`-m`) that is deformed to match the fixed image
- one or more parameter files (`-p`)

Each parameter file describes a single transform; passing several parameter files (repeating `-p`) is how you chain rigid and non-rigid registrations in a single run.
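
The sketch below assembles a chained `elastix` call, with one `-p` per stage, as an argument list. `elastix_command` and all file names are placeholders for illustration; you would hand the list to `subprocess.run` if you have the `elastix` binary installed.

```python
from pathlib import Path


def elastix_command(fixed, moving, out_dir, parameter_files):
    """Build an elastix CLI call; each -p adds one stage, run in the order given."""
    cmd = ["elastix", "-f", str(fixed), "-m", str(moving), "-out", str(out_dir)]
    for p in parameter_files:
        cmd += ["-p", str(p)]
    return cmd


# Placeholder file names for illustration only
cmd = elastix_command(
    "fw_fixed.nii",
    "mcns_moving.nii",
    Path("elastix_output"),
    ["rigid.txt", "affine.txt", "bspline_coarse.txt"],
)
print(" ".join(cmd))
```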

`ITKElastix` wraps `elastix` through:
- one or more `ParameterMap`s, each containing the settings for a single registration
- a single `ParameterObject`, which combines one or more `ParameterMap`s
- `elastix_registration_method()`, which invokes `elastix` with the given images and parameter object


With that out of the way, let's get cracking! First up: the **rigid transform**!

In [None]:
# Number of resolutions
resolutions = 4

# Generate a parameter object
parameter_object = itk.ParameterObject.New()

# Initialize a default parameter map for the rigid transform
parameter_map_rigid = parameter_object.GetDefaultParameterMap("rigid", resolutions)

# You can set the parameters as if this were a dictionary. Note that values are
# stored as tuples of strings, so pass strings and wrap single values in a 1-tuple
parameter_map_rigid["MaximumNumberOfIterations"] = ("999",)  # increase the number of iterations
parameter_map_rigid["MaximumStepLength"] = ("3.0",)  # increase the max step length

# Add the parameter map to the parameter object
parameter_object.AddParameterMap(parameter_map_rigid)

# Uncomment to print the parameter object to see the parameters
# print(parameter_object, flush=True)

# Perform the registration (~20s on my laptop, ~3mins on Colab)
this_out = output_path / "transform1-rigid"
this_out.mkdir(exist_ok=True)  # Create the directory if it does not exist
registered_image1, params1 = itk.elastix_registration_method(
    img_fw_smooth.astype(np.float32),  # FlyWire as the fixed image
    img_mcns_brain_smooth.astype(np.float32),  # MCNS as the moving image
    parameter_object=parameter_object,
    log_to_console=False,  # Set to true to see the output in this notebook
    log_to_file=True,
    log_file_name="elastix.log",
    output_directory=str(this_out),
)
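
Because the parameter map stores every value as a tuple of strings, a small helper that converts ordinary Python values can save some typing and avoid forgotten trailing commas. `set_parameters` below is a hypothetical helper; it works on anything dict-like, so it is shown on a plain `dict`, on the assumption that the ITK parameter map accepts the same item assignment used in the cell above.

```python
def _to_str(value):
    """elastix uses "true"/"false" for booleans; everything else becomes str(value)."""
    if isinstance(value, bool):  # check bool before int, since bool subclasses int
        return "true" if value else "false"
    return str(value)


def set_parameters(parameter_map, **params):
    """Set parameters, converting scalars and sequences to tuples of strings."""
    for key, value in params.items():
        if isinstance(value, (list, tuple)):
            parameter_map[key] = tuple(_to_str(v) for v in value)
        else:
            parameter_map[key] = (_to_str(value),)
    return parameter_map


pmap = {}
set_parameters(
    pmap,
    MaximumNumberOfIterations=999,
    MaximumStepLength=3.0,
    NewSamplesEveryIteration=True,
    FinalGridSpacingInPhysicalUnits=[32, 32, 32],  # one value per dimension
)
print(pmap)
```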

In [None]:
def blend_images(img1, img2):
    """Additive blending of two images."""
    # Map the greyscale projections to greens (img1) and red-purples (img2)
    img1 = plt.get_cmap("Greens")(1 - (img1.sum(axis=2) / img2.sum(axis=2).max()))
    img2 = plt.get_cmap("RdPu")(1 - (img2.sum(axis=2) / img2.sum(axis=2).max()))
    # Add the two RGBA images and rescale so that the maximum is 1
    return (img1 + img2) / (img1 + img2).max()


def plot_images(*images):
    """Plot images side by side."""
    fig, axes = plt.subplots(
        1, len(images) + 3, figsize=((len(images) + 3) * 2.5, 5), sharex=True, sharey=True
    )

    # Always show the moving image first
    axes[0].imshow(img_mcns_brain_smooth.sum(axis=2), cmap="gray")
    axes[0].set_title("MCNS (moving)")

    # Then show the individual steps
    for i, img in enumerate(images):
        axes[i + 1].imshow(img.sum(axis=2), cmap="gray")
        axes[i + 1].set_title(f"step {i + 1}")

    # Finally show the blend of the final and the fixed image
    axes[-2].imshow(blend_images(images[-1], img_fw_smooth))
    axes[-2].set_title("blend")

    # And the fixed image
    axes[-1].imshow(img_fw_smooth.sum(axis=2), cmap="gray")
    axes[-1].set_title("FlyWire (fixed)")

    plt.tight_layout()

    return axes

axes = plot_images(registered_image1)
axes[1].set_title("MCNS after rigid")

That looks decent!

If your alignment ends up incomplete (i.e. the moving image doesn't overlap well with the fixed image), you may want to increase the step length, max number of iterations and/or the number of resolutions.
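
What does "number of resolutions" mean? elastix registers a series of progressively less blurred versions of the images (an image pyramid), coarsest first, and uses each level's result to initialise the next one. The sketch below builds such a smoothing pyramid with `scipy.ndimage.gaussian_filter`; the schedule (halving the blur at each level) is an illustrative assumption, not elastix's exact settings, which are controlled by parameters such as `FixedImagePyramid` and `ImagePyramidSchedule`.

```python
import numpy as np
from scipy import ndimage


def smoothing_pyramid(img, levels):
    """Return `levels` smoothed copies of `img`, coarsest (most blurred) first.

    Level k (0 = coarsest) uses sigma = 0.5 * 2 ** (levels - 1 - k) voxels,
    so the finest level is only lightly smoothed. Illustrative schedule only.
    """
    img = np.asarray(img, dtype=np.float32)
    return [
        ndimage.gaussian_filter(img, sigma=0.5 * 2 ** (levels - 1 - k))
        for k in range(levels)
    ]


rng = np.random.default_rng(0)
noise = rng.random((40, 40, 40))
pyramid = smoothing_pyramid(noise, levels=4)  # sigmas 4, 2, 1, 0.5
# More blur means less voxel-to-voxel variation, so the std grows from coarse to fine
print([round(float(level.std()), 4) for level in pyramid])
```

More resolutions add coarser levels, which helps the optimizer capture large offsets before it looks at fine detail.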

For now we will continue with the **affine transform**!

In [None]:
# Number of resolutions
resolutions = 6

# Create a new parameter object
parameter_object = itk.ParameterObject.New()

# Parameter map for the affine transform
parameter_map_affine = parameter_object.GetDefaultParameterMap("affine", resolutions)
parameter_map_affine["MaximumNumberOfIterations"] = ("3000",)  # up from 999 in the rigid step
parameter_map_affine["MaximumStepLength"] = ("3.0",)

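# Compose the affine with the initial (rigid) transform; elastix accepts "Compose" (the default) or "Add"
parameter_map_affine["HowToCombineTransforms"] = ("Compose",)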

# Add the parameter map to the parameter object
parameter_object.AddParameterMap(parameter_map_affine)

# Perform the registration (~30s on my laptop, ~4mins on Colab)
this_out = output_path / "transform2-affine"
this_out.mkdir(exist_ok=True)  # Create the directory if it does not exist
registered_image2, params2 = itk.elastix_registration_method(
    img_fw_smooth.astype(np.float32),
    img_mcns_brain_smooth.astype(np.float32),
    parameter_object=parameter_object,
    initial_transform_parameter_object=params1,
    log_to_console=False,
    log_to_file=True,
    log_file_name="elastix.log",
    output_directory=str(this_out),
)

In [None]:
axes = plot_images(registered_image1, registered_image2)
axes[1].set_title("MCNS after rigid")
axes[2].set_title("MCNS after affine")

Again, that looks good: the MCNS is now correctly scaled. For non-rigid transforms you may want to start with a larger `FinalGridSpacingInPhysicalUnits` (i.e. a coarser control-point grid) to first match coarser features and then work your way down.

Here, we will demo two steps to show the issue with overfitting:

The first step will use:
1. A high-ish `FinalGridSpacingInPhysicalUnits` value, which will match coarser structures
2. A much smaller `MaximumStepLength` than in the previous registrations
3. More spatial samples (`NumberOfSpatialSamples`) for evaluating the cost function
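
Why does the grid spacing matter so much? It sets how many B-spline control points, and hence free parameters, the transform has. The rough estimate below assumes about `ceil(extent / spacing) + 3` control points per axis for a cubic B-spline and one displacement per axis at each control point; elastix's exact grid layout may differ slightly, and the volume shape is hypothetical (substitute `img_fw_smooth.shape`, since one voxel is 1 um here).

```python
import math


def bspline_parameter_count(shape_um, grid_spacing_um, spline_order=3):
    """Rough number of B-spline transform parameters for a 3D image.

    Assumes ceil(extent / spacing) + spline_order control points per axis and
    one displacement per axis at every control point (an approximation).
    """
    n_points = 1
    for extent in shape_um:
        n_points *= math.ceil(extent / grid_spacing_um) + spline_order
    return len(shape_um) * n_points


shape = (600, 300, 250)  # hypothetical volume size in um
coarse = bspline_parameter_count(shape, 32)
fine = bspline_parameter_count(shape, 8)
print(f"coarse (32 um): {coarse:,} parameters")
print(f"fine (8 um):    {fine:,} parameters ({fine / coarse:.0f}x more)")
```

Going from a 32 to an 8 um grid multiplies the number of parameters by roughly 4^3, which is what gives the fine step the freedom to overfit.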

In [None]:
# Number of resolutions
resolutions = 6

# Create a new parameter object
parameter_object = itk.ParameterObject.New()

# Parameter map for the bspline transform
parameter_map_bspline = parameter_object.GetDefaultParameterMap("bspline", resolutions)
parameter_map_bspline["FinalGridSpacingInPhysicalUnits"] = ("32",)  # 20 is the default
parameter_map_bspline["MaximumNumberOfIterations"] = ("999",)
parameter_map_bspline["MaximumStepLength"] = ("0.1",)

# Sample 16000 random coordinates (2048 is the default), re-drawn every iteration
parameter_map_bspline["ImageSampler"] = ("RandomCoordinate",)
parameter_map_bspline["NumberOfSpatialSamples"] = ("16000",)
parameter_map_bspline["NewSamplesEveryIteration"] = ("true",)

# Add the parameter map to the parameter object
parameter_object.AddParameterMap(parameter_map_bspline)

# Perform the registration (~2mins on my laptop, ~25mins on Colab)
this_out = output_path / "transform3-bspline"
this_out.mkdir(exist_ok=True)  # Create the directory if it does not exist
registered_image3, params3 = itk.elastix_registration_method(
    img_fw_smooth.astype(np.float32),
    img_mcns_brain_smooth.astype(np.float32),
    parameter_object=parameter_object,
    initial_transform_parameter_object=params2,
    log_to_console=False,
    log_to_file=True,
    log_file_name="elastix.log",
    output_directory=str(this_out),
)

In [None]:
# Inspect the results
axes = plot_images(registered_image1, registered_image2, registered_image3)
axes[1].set_title("MCNS after rigid")
axes[2].set_title("MCNS after affine", size=10)
axes[3].set_title("MCNS after coarse bspline", size=10)

The second step will use a much lower value for `FinalGridSpacingInPhysicalUnits` in an attempt to match smaller features. Note that this step can take several minutes to run.

In [None]:
# Number of resolutions
resolutions = 6

# Create a new parameter object
parameter_object = itk.ParameterObject.New()

# Parameter map for the bspline transform
parameter_map_bspline2 = parameter_object.GetDefaultParameterMap("bspline", resolutions)
parameter_map_bspline2["FinalGridSpacingInPhysicalUnits"] = ("8",)  # 20 is the default
parameter_map_bspline2["NumberOfSpatialSamples"] = ("8192",)  # 2048 is the default
parameter_map_bspline2["MaximumNumberOfIterations"] = ("999",)
parameter_map_bspline2["MaximumStepLength"] = ("0.1",)

# Add the parameter map to the parameter object
parameter_object.AddParameterMap(parameter_map_bspline2)

# Perform the registration (~2mins on my laptop, ~16mins on Colab)
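this_out = output_path / "transform4-bspline"
this_out.mkdir(exist_ok=True)  # Create the directory if it does not exist
registered_image4, params4 = itk.elastix_registration_method(
    img_fw_smooth.astype(np.float32),
    # Pass the result of the previous registration as the moving image
    # (so params4 only describes this fine B-spline step)
    registered_image3,
    parameter_object=parameter_object,
    log_to_console=False,
    log_to_file=True,
    log_file_name="elastix.log",
    output_directory=str(this_out),
)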

In [None]:
axes = plot_images(
    registered_image1, registered_image2, registered_image3, registered_image4
)
axes[1].set_title("MCNS after rigid", size=8)
axes[2].set_title("MCNS after affine", size=8)
axes[3].set_title("MCNS after coarse bspline", size=8)
axes[4].set_title("MCNS after fine bspline", size=8)

That doesn't look too bad, but let's have a closer look at our finely registered image:

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(10, 7), sharex=True, sharey=True)

# Slice through the antennal lobe
ix = 40
x = slice(400, 600)
y = slice(100, 300)
axes[0].imshow(registered_image3[x, y, ix].T, cmap="gray")
axes[1].imshow(registered_image4[x, y, ix].T, cmap="gray")
axes[2].imshow(img_fw_smooth[x, y, ix].T, cmap="gray")

axes[0].set_title("MCNS after coarse bspline")
axes[1].set_title("MCNS after fine bspline")
axes[2].set_title("FlyWire target")

for ax in axes:
    ax.set_axis_off()

plt.tight_layout()

Notice the stretches in the antennal lobe after the fine B-spline step? If you try to match small features with a fine control-point grid (here a `FinalGridSpacingInPhysicalUnits` of 8), the transform has enough freedom to overfit, and you get a sort of "wobble" in your registered image!

#### Bonus round: Check quality

Before we stop, let's have a final look at the quality of the transforms:

To assess the quality of a transform, we can calculate its spatial Jacobian and the determinant of that Jacobian. The determinant is especially useful because it quantifies the amount of local compression or expansion.

Values smaller than 1 indicate local compression, values larger than 1 indicate local expansion, and 1 means volume preservation. The measure is quantitative: a value of 1.1 means a 10% increase in volume. If this value deviates substantially from 1, you may be worried (but maybe not if this is what you expect for your application). In case it is negative you have “foldings” in your transformation, and you definitely should be worried. For more information see the [elastix manual](https://elastix.lumc.nl/download/elastix-5.0.1-manual.pdf).
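
To build some intuition for these numbers, the NumPy sketch below computes det(I + du/dx) for a synthetic displacement field u, i.e. the same quantity on a toy scale. `jacobian_determinant` is a hypothetical helper; the real transform is evaluated with `itk.transformix_jacobian`. A 10% stretch along x gives 1.1 everywhere, and a displacement that flips the x axis gives a negative value, i.e. a fold.

```python
import numpy as np


def jacobian_determinant(disp, spacing=(1.0, 1.0, 1.0)):
    """det(I + du/dx) for a displacement field `disp` of shape (X, Y, Z, 3)."""
    jac = np.zeros(disp.shape[:3] + (3, 3))
    for i in range(3):
        # np.gradient returns d(u_i)/dx_j for j = 0, 1, 2
        for j, grad in enumerate(np.gradient(disp[..., i], *spacing)):
            jac[..., i, j] = grad
    jac += np.eye(3)  # T(x) = x + u(x), so J = I + du/dx
    return np.linalg.det(jac)


x, y, z = np.meshgrid(np.arange(10.0), np.arange(8.0), np.arange(6.0), indexing="ij")
zeros = np.zeros_like(x)

stretch = np.stack([0.1 * x, zeros, zeros], axis=-1)  # 10% expansion along x
flip = np.stack([-2.0 * x, zeros, zeros], axis=-1)  # maps x to -x: folds space

det_stretch = jacobian_determinant(stretch)
det_flip = jacobian_determinant(flip)
print("stretch:", det_stretch.min(), det_stretch.max())
print("folded voxels:", np.sum(det_flip < 0), "of", det_flip.size)
```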

Evaluate the Jacobian matrix and its determinant for the coarse -> fine bspline registration:

In [None]:
# Convert the moving image to an ITK image
moving_image = itk.image_from_array(registered_image3)

# Calculate the spatial Jacobian matrix and its determinant (returned as a tuple)
jacobians = itk.transformix_jacobian(moving_image, params4)

# Convert both to NumPy arrays for further calculations
spatial_jacobian = np.asarray(jacobians[0]).astype(np.float32)
det_spatial_jacobian = np.asarray(jacobians[1]).astype(np.float32)

n_folds = np.sum(det_spatial_jacobian < 0)
frac_folds = n_folds / det_spatial_jacobian.size

print(f"Number of foldings in transformation: {n_folds} ({frac_folds:.2%})")

In [None]:
# Where are those folds?
axim = plt.imshow(det_spatial_jacobian.min(axis=2).T, cmap="coolwarm")
fig.colorbar(axim, location="right", anchor=(0, 0.3), shrink=0.7, label="Jacobian det.")
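
Beyond counting folds, it can help to summarize how much of the volume is strongly compressed or expanded. `summarize_jacobian` below is a hypothetical helper, and its 10% tolerance is an arbitrary choice; it is shown on toy values.

```python
import numpy as np


def summarize_jacobian(det, tol=0.1):
    """Fractions of voxels that fold, or compress/expand by more than `tol`."""
    det = np.asarray(det, dtype=float)
    return {
        "folded": float(np.mean(det < 0)),
        "compressed": float(np.mean((det >= 0) & (det < 1 - tol))),
        "expanded": float(np.mean(det > 1 + tol)),
        "min": float(det.min()),
        "max": float(det.max()),
    }


# Toy determinant values: one fold, one strong compression, two strong expansions
example = np.array([-0.2, 0.5, 0.95, 1.0, 1.05, 1.3, 2.0, 1.0])
print(summarize_jacobian(example))
```

Calling `summarize_jacobian(det_spatial_jacobian)` gives the same breakdown for the fine B-spline step. Keep in mind that much of the volume is background, where the determinant says little about how the tissue itself was deformed.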