In [1]:
from collections import namedtuple

import imageio
import matplotlib.pyplot as plt
import numpy as np
import skimage
import SimpleITK as sitk

from scipy import ndimage
from scipy.spatial.transform import Rotation as R

Status = namedtuple("Status", ["stop_condition", "num_iterations", "metric_value"])
Transform = namedtuple("Transform", ["translation", "rotation_euler_degrees"])

In [2]:
def register_3d(fixed_image, moving_image):
    """Rigidly register ``moving_image`` onto ``fixed_image`` with SimpleITK.

    Both inputs are numpy arrays; they are cast to ``float32`` and wrapped as
    SimpleITK images. A ``VersorRigid3DTransform`` is initialised with
    ``CenteredTransformInitializer`` and optimised with regular-step gradient
    descent against a mean-squares metric.

    Parameters
    ----------
    fixed_image : numpy.ndarray
        Reference volume.
    moving_image : numpy.ndarray
        Volume to be aligned to the reference.

    Returns
    -------
    status : Status
        Optimizer stop-condition description, iteration count and final
        metric value.
    result : Transform
        ``translation`` is the transform's translation with its components
        reversed relative to SimpleITK's order (presumably to line up with
        numpy's axis order -- confirm). ``rotation_euler_degrees`` is the
        versor converted to ``'zyx'`` Euler angles in degrees.
    """
    # Wrap the numpy volumes as float32 SimpleITK images.
    fixed_sitk = sitk.GetImageFromArray(fixed_image.astype(np.float32))
    moving_sitk = sitk.GetImageFromArray(moving_image.astype(np.float32))

    # Starting guess: a rigid transform initialised from the two images.
    start_transform = sitk.CenteredTransformInitializer(
        fixed_sitk,
        moving_sitk,
        sitk.VersorRigid3DTransform(),
    )

    method = sitk.ImageRegistrationMethod()
    method.SetInitialTransform(start_transform)
    method.SetMetricAsMeanSquares()
    method.SetOptimizerScalesFromPhysicalShift()
    method.SetOptimizerAsRegularStepGradientDescent(
        learningRate=2.0,
        minStep=1e-8,
        numberOfIterations=1000,
        gradientMagnitudeTolerance=1e-8,
    )

    optimized = method.Execute(fixed_sitk, moving_sitk)

    # Collect optimizer diagnostics after the run has finished.
    status = Status(
        stop_condition=method.GetOptimizerStopConditionDescription(),
        num_iterations=method.GetOptimizerIteration(),
        metric_value=method.GetMetricValue(),
    )

    # Execute() hands back a generic transform; downcasting exposes the
    # rigid-specific accessors for translation and versor.
    rigid = optimized.Downcast()

    # Reverse the translation components before converting to numpy.
    translation = np.array(rigid.GetTranslation()[::-1])
    rotation_euler_degrees = R.from_quat(rigid.GetVersor()).as_euler(
        'zyx', degrees=True
    )

    result = Transform(
        translation=translation,
        rotation_euler_degrees=rotation_euler_degrees,
    )
    return status, result

In [3]:
def create_images(cavity_centers, cavity_radii, atom_centers, atom_radii,
                  shape=None):
    """Build distance-transform volumes for a cavity and a set of atoms.

    Each group of spheres is rasterised into a boolean volume with
    ``spheres`` and then passed through
    ``scipy.ndimage.distance_transform_edt``.

    Parameters
    ----------
    cavity_centers, atom_centers : sequence of coordinate sequences
        Sphere centers for the cavity and the atoms.
    cavity_radii, atom_radii : sequence of numbers
        One radius per center.
    shape : tuple of int, optional
        Shape of the output volumes. When omitted it is derived from the
        largest ``center + radius`` per axis across both groups, rounded up,
        plus 2.

    Returns
    -------
    cavity_edt, atoms_edt : numpy.ndarray
        Euclidean distance transforms of the cavity mask and the atom mask.
    """
    if shape is None:
        # Per-axis upper bound that contains every sphere from both groups.
        upper_bound = np.maximum(
            max_extent(cavity_centers, cavity_radii),
            max_extent(atom_centers, atom_radii),
        )
        shape = tuple(s.astype(np.int64) + 2 for s in np.ceil(upper_bound))

    def _edt_of(centers, radii):
        # Rasterise one group of spheres and take its distance transform.
        return ndimage.distance_transform_edt(spheres(centers, radii, shape))

    return (
        _edt_of(cavity_centers, cavity_radii),
        _edt_of(atom_centers, atom_radii),
    )

def max_extent(centers, radii):
    """Return the per-axis maximum of ``center + radius`` over all spheres.

    ``centers`` is a sequence of coordinate sequences and ``radii`` holds one
    radius per center; each radius is added to every coordinate of its center
    before the maximum is taken across spheres.
    """
    center_array = np.asarray(centers)
    # Column vector so each radius broadcasts across its center's coordinates.
    radius_column = np.asarray(radii)[:, np.newaxis]
    upper_bounds = center_array + radius_column
    return np.max(upper_bounds, axis=0)

def spheres(centers, radii, shape):
    """Return a boolean volume of ``shape`` containing the union of spheres.

    ``centers`` and ``radii`` must have the same length; each pair is drawn
    into the volume with ``add_sphere``.
    """
    assert len(centers) == len(radii)
    mask = np.zeros(shape, dtype=np.bool_)
    for sphere_center, sphere_radius in zip(centers, radii):
        add_sphere(center=sphere_center, radius=sphere_radius, img=mask)
    return mask

def add_sphere(center, radius, img):
    """OR a solid sphere into ``img`` in place, clipped to the image bounds.

    The sphere is generated with ``skimage.draw.ellipsoid(radius, radius,
    radius)``. Its placement is derived from the shape of the array that call
    actually returns, rather than from a separately recomputed extent, so the
    source and destination regions always have matching shapes.

    Parameters
    ----------
    center : sequence of numbers
        Per-axis voxel coordinates of the sphere center. Each coordinate is
        rounded to the nearest integer voxel.
    radius : number
        Sphere radius, forwarded unchanged to ``skimage.draw.ellipsoid``.
    img : numpy.ndarray
        Boolean volume that is modified in place.

    Notes
    -----
    Parts of the sphere that fall outside ``img`` are discarded. A sphere
    lying entirely outside the volume leaves ``img`` untouched.
    """
    ball = skimage.draw.ellipsoid(radius, radius, radius)

    img_region = []
    ball_region = []
    for coord, ball_len, img_len in zip(center, ball.shape, img.shape):
        # The ball sits centered in its array, so its middle voxel goes on
        # the (rounded) center coordinate.
        start = int(np.round(coord)) - ball_len // 2
        stop = start + ball_len

        # Clip to the image; a negative start would otherwise wrap around.
        lo = max(start, 0)
        hi = min(stop, img_len)
        if lo >= hi:
            # No overlap along this axis: nothing to draw.
            return

        img_region.append(slice(lo, hi))
        # Matching sub-window of the ball for the clipped image window.
        ball_region.append(slice(lo - start, hi - start))

    img[tuple(img_region)] |= ball[tuple(ball_region)]
    
def make_slice(lo, hi):
    """Return ``slice(lo, hi)`` with both bounds converted to ``np.int64``.

    ``lo`` is converted directly with ``np.int64``; ``hi`` is rounded up with
    ``np.ceil`` before the conversion.
    """
    start = np.int64(lo)
    stop = np.int64(np.ceil(hi))
    return slice(start, stop)

In [4]:
# Example geometry: two spheres for the cavity and two for the atoms.
cavity_centers = [
    [15, 15, 15],
    [15, 20, 20],
]
cavity_radii = [10, 10]

atom_centers = [
    [15, 15, 15],
    [15, 15, 23],
]
atom_radii = [7, 7]

# Rasterise both groups and take their distance transforms.
cavity_edt, atoms_edt = create_images(
    cavity_centers=cavity_centers,
    cavity_radii=cavity_radii,
    atom_centers=atom_centers,
    atom_radii=atom_radii,
)

In [5]:
# Register the atoms volume (moving) onto the cavity volume (fixed).
status, transform = register_3d(
    fixed_image=cavity_edt,
    moving_image=atoms_edt,
)

In [6]:
# Display the optimizer's stop condition, iteration count and final metric value.
status

Status(stop_condition='RegularStepGradientDescentOptimizerv4: Gradient magnitude tolerance met after 157 iterations. Gradient magnitude (9.97968e-09) is less than gradient magnitude tolerance (1e-08).', num_iterations=158, metric_value=1.4387584857741298)

In [7]:
# Display the recovered translation and the rotation as Euler angles in degrees.
transform

Transform(translation=array([-6.23405288e-17, -2.49999995e+00,  1.50228265e+00]), rotation_euler_degrees=array([-45.00015592,   0.        ,   0.        ]))