In [1]:
from collections import namedtuple

import imageio
import matplotlib.pyplot as plt
import numpy as np
import skimage
import SimpleITK as sitk

from scipy import ndimage
from scipy.spatial.transform import Rotation

In [2]:
Status = namedtuple("Status", ["stop_condition", "num_iterations", "metric_value"])

def register_3d_masks(fixed_mask, moving_mask):
    """Rigid 3d registration based on Euclidean Distance Transform of input masks."""

    # Euclidean Distance Transform finds distance of each signal pixel from
    # its nearest background pixel.  Background pixels remain zero.
    fixed_image = ndimage.distance_transform_edt(fixed_mask)
    moving_image = ndimage.distance_transform_edt(moving_mask)

    # Convert from numpy to sitk format.
    fixed_image = sitk.GetImageFromArray(fixed_image)
    moving_image = sitk.GetImageFromArray(moving_image)
    
    registration = sitk.ImageRegistrationMethod()

    # What is appropriate for EDT, which are not intensities?
    # registration.SetMetricAsMeanSquares()  
    registration.SetMetricAsCorrelation()
    
    # Using a Rigid 3D transformation where the rotation is parameterized
    # by a versor, which is a unit quaternion.  Quaternions are more
    # numerically stable for optimization problems compared to rotation
    # matrices or rotation vectors.
    initial_transform = sitk.CenteredTransformInitializer(
        fixed_image, 
        moving_image, 
        sitk.VersorRigid3DTransform(),
    )
    registration.SetInitialTransform(initial_transform)
    
    # The following rescales the angular dimensions so that they are
    # on the same scale as the spatial dimensions.
    registration.SetOptimizerScalesFromPhysicalShift()

    registration.SetOptimizerAsRegularStepGradientDescent(
        learningRate=2.0,
        minStep=1e-8,
        numberOfIterations=1000,
        gradientMagnitudeTolerance=1e-8,
    )

    final_transform = registration.Execute(fixed_image, moving_image)

    stop_condition = registration.GetOptimizerStopConditionDescription()
    num_iterations = registration.GetOptimizerIteration()
    metric_value = registration.GetMetricValue()  
    status = Status(stop_condition, num_iterations, metric_value)
    
    # downcast from generic transform to VersorRigid3DTransform, which
    # allows direct access to translation and rotation parameters.
    return status, final_transform.Downcast()

def resample(mask, transform):
    """Return mask after being transformed."""
    mask = sitk.GetImageFromArray(mask.astype(np.float32))
    mask = sitk.Resample(mask, mask, transform)
    return sitk.GetArrayFromImage(mask)

def print_transform(transform):
    """Print transform in human-understandable terms."""
    translation_xyz = transform.GetTranslation()
    print('translation (z,y,x)', translation_xyz[::-1])

    versor = transform.GetVersor()
    rotation = Rotation.from_quat(versor).as_euler('zyx', degrees=True)
    print('rotation (z,y,x degrees)', rotation)

In [3]:
# Routines for creating testdata for sanity checking.

def create_images(cavity_centers, cavity_radii, atom_centers, atom_radii):
    cavity_extent = max_extent(cavity_centers, cavity_radii)
    atom_extent = max_extent(atom_centers, atom_radii)
    shape = np.ceil(np.maximum(cavity_extent, atom_extent))
    
    # Add 2 for how skimage draws ellipsoids.
    shape = tuple(int(s + 2) for s in shape)
    
    cavity = spheres(cavity_centers, cavity_radii, shape)
    atoms = spheres(atom_centers, atom_radii, shape)
    return cavity, atoms

def max_extent(centers, radii):
    """Find how large an image is needed to enclose a set of spheres."""
    max_extents = np.asarray(centers) + np.asarray(radii)[:,None]
    return max_extents.max(axis=0)

def spheres(centers, radii, shape):
    """Draw spheres at the given centers with the given radii."""
    assert len(centers) == len(radii)
    img = np.zeros(shape, np.bool_)
    for center, radius in zip(centers, radii):
        add_sphere(center, radius, img)
    return img        

def add_sphere(center, radius, img):
    """Draw a single sphere on an image."""
    # NOTE: the center is rounded down to an integer.
    slices = tuple(slice(int(c)-int(radius)-1, int(c)+int(radius)+2) for c in center)
    # NOTE: the radius it NOT rounded in drawing the ellipsoid.
    img[slices] |= skimage.draw.ellipsoid(radius, radius, radius)

In [4]:
# Create some test data.

cavity_centers = [[15, 15, 15], [15, 20, 20]]
cavity_radii = [9, 10]

atom_centers = [[15.2, 15, 15], [15, 15, 23]]
atom_radii = [8, 8]

cavity, atoms = create_images(
    cavity_centers, 
    cavity_radii, 
    atom_centers, 
    atom_radii,
)

In [5]:
status, transform = register_3d_masks(cavity, atoms)
atoms_xform = resample(atoms, transform)
overlap_frac = atoms_xform[cavity].sum() / atoms_xform.sum()

In [6]:
print(status)
print()
print_transform(transform)
print()
print('Overlap fraction:', overlap_frac)

Status(stop_condition='RegularStepGradientDescentOptimizerv4: Gradient magnitude tolerance met after 144 iterations. Gradient magnitude (9.62972e-09) is less than gradient magnitude tolerance (1e-08).', num_iterations=145, metric_value=-0.9160434825611494)

translation (z,y,x) (2.5337928039157717e-05, -3.152382780631124, 1.2240877256535732)
rotation (z,y,x degrees) [-4.49998272e+01 -7.09502729e-06 -1.11285402e-05]

Overlap fraction: 1.0000001
