In [5]:
import SimpleITK as sitk
import numpy as np
from ipywidgets import interact, fixed
import matplotlib.pyplot as plt

In [6]:
def display_displacement_scaling_effect(s, original_x_mat, original_y_mat, tx, original_control_point_displacements):
    """Plot a 2D point grid before and after applying a scaled transform.

    The transform's parameters are overwritten with
    ``s * original_control_point_displacements``. Every point
    ``(original_x_mat[idx], original_y_mat[idx])`` is then mapped through the
    transform. Original points are drawn as blue circles and transformed points
    as red triangles, on the current pyplot axes. Both axes are limited to
    [-2.5, 2.5].

    Args:
        s: Scale factor applied to the displacement parameters.
        original_x_mat: Array of x coordinates of the sample points.
        original_y_mat: Array of y coordinates, indexed with the same indices
            as ``original_x_mat``.
        tx: Transform object (a SimpleITK transform in this notebook). Its
            parameters are modified in place and are left at the scaled values.
        original_control_point_displacements: Unscaled parameter vector.
            Presumably a numpy array, so that ``s * ...`` scales elementwise.

    Raises:
        ValueError: If ``tx`` is not two dimensional.
    """
    if tx.GetDimension() != 2:
        raise ValueError("display_displacement_scaling_effect only works in 2D")

    # Draw the untouched grid first, before the transform is modified.
    plt.scatter(original_x_mat, original_y_mat, marker="o", color="blue", label="original points")

    tx.SetParameters(s * original_control_point_displacements)
    mapped = [
        tx.TransformPoint((x_value, original_y_mat[idx]))
        for idx, x_value in np.ndenumerate(original_x_mat)
    ]
    # Pair-unpacking enforces that each transformed point is 2D.
    moved_x = [px for px, _ in mapped]
    moved_y = [py for _, py in mapped]

    plt.scatter(moved_x, moved_y, marker="^", color="red", label="transformed points")
    # Legend is anchored just above the axes so it never hides any points.
    plt.legend(loc=(0.25, 1.01))
    plt.xlim((-2.5, 2.5))
    plt.ylim((-2.5, 2.5))

In [9]:
# Create the transformation (when working with images it is easier to use the BSplineTransformInitializer function
# or its object oriented counterpart BSplineTransformInitializerFilter).
dimension = 2
spline_order = 3
direction_matrix_row_major = [1.0, 0.0, 0.0, 1.0]  # identity, mesh is axis aligned
origin = [-1.0, -1.0]
domain_physical_dimensions = [2, 2]
mesh_size = [4, 3]

bspline = sitk.BSplineTransform(dimension, spline_order)
bspline.SetTransformDomainOrigin(origin)
bspline.SetTransformDomainDirection(direction_matrix_row_major)
bspline.SetTransformDomainPhysicalDimensions(domain_physical_dimensions)
bspline.SetTransformDomainMeshSize(mesh_size)

# Random displacement of the control points. The first half of the parameter
# vector is filled with x displacements and the second half with y
# displacements. Specifying them separately lets us play with each component:
# just multiply one of them by zero to see the effect.
# A seeded generator makes the displacements (and therefore the plots below)
# reproducible after "Restart & Run All". It also makes re-running this cell
# give the same result.
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)
num_coefficients_per_axis = len(bspline.GetParameters()) // 2
x_displacement = rng.random(num_coefficients_per_axis)
y_displacement = rng.random(num_coefficients_per_axis)
original_control_point_displacements = np.concatenate([x_displacement, y_displacement])
bspline.SetParameters(original_control_point_displacements)

# Apply the BSpline transformation to a grid of points.
# Starting the point set exactly at the origin of the BSpline mesh is problematic
# because these points are considered outside the transformation's domain.
# Remove the epsilon below and see what happens.
numSamplesX = 10
numSamplesY = 20

coordsX = np.linspace(
    origin[0] + np.finfo(float).eps,
    origin[0] + domain_physical_dimensions[0],
    numSamplesX,
)
coordsY = np.linspace(
    origin[1] + np.finfo(float).eps,
    origin[1] + domain_physical_dimensions[1],
    numSamplesY,
)
XX, YY = np.meshgrid(coordsX, coordsY)

In [10]:
# Explore how scaling the random B-spline control point displacements by s
# deforms the sample grid. `fixed` keeps the remaining arguments out of the
# widget UI, so only `s` gets a slider. The trailing semicolon suppresses the
# repr of the function returned by `interact`, which would otherwise clutter
# the cell output.
interact(
    display_displacement_scaling_effect,
    s=(-1.5, 1.5),
    original_x_mat=fixed(XX),
    original_y_mat=fixed(YY),
    tx=fixed(bspline),
    original_control_point_displacements=fixed(original_control_point_displacements),
);

interactive(children=(FloatSlider(value=0.0, description='s', max=1.5, min=-1.5), Output()), _dom_classes=('wi…

<function __main__.display_displacement_scaling_effect(s, original_x_mat, original_y_mat, tx, original_control_point_displacements)>

In [11]:
# Create the displacement field.

# When working with images the safer thing to do is use the image based constructor,
# sitk.DisplacementFieldTransform(my_image). All the fixed parameters will be set correctly and the displacement
# field is initialized using the vectors stored in the image. SimpleITK requires that the image's pixel type be
# sitk.sitkVectorFloat64.
displacement = sitk.DisplacementFieldTransform(2)
field_size = [10, 20]
field_origin = [-1.0, -1.0]
# Physical distance from the first to the last grid sample along each axis.
# The spacing is derived from it, so changing field_size keeps the grid
# spanning [-1, 1] x [-1, 1] instead of silently drifting (field_size >= 2).
field_physical_extent = [2.0, 2.0]
field_spacing = [
    extent / (size - 1) for extent, size in zip(field_physical_extent, field_size)
]
field_direction = [1.0, 0.0, 0.0, 1.0]  # direction cosine matrix (row major order)

# Concatenate all the information into a single list
displacement.SetFixedParameters(
    field_size + field_origin + field_spacing + field_direction
)
# Set the interpolator, either sitkLinear which is default or nearest neighbor
displacement.SetInterpolator(sitk.sitkNearestNeighbor)

# Seeded generator: reproducible displacements under "Restart & Run All", and
# re-running this cell gives the same field.
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)
originalDisplacements = rng.random(len(displacement.GetParameters()))
displacement.SetParameters(originalDisplacements)

# Sample points placed exactly on the displacement field's grid nodes.
coordsX = np.linspace(
    field_origin[0],
    field_origin[0] + (field_size[0] - 1) * field_spacing[0],
    field_size[0],
)
coordsY = np.linspace(
    field_origin[1],
    field_origin[1] + (field_size[1] - 1) * field_spacing[1],
    field_size[1],
)
XX, YY = np.meshgrid(coordsX, coordsY)

In [12]:
# Repeat the scaling experiment with the dense displacement field. The same
# plotting callback is reused; only the transform, its parameters and the
# sample grid differ. `fixed` keeps these arguments out of the widget UI, so
# only `s` gets a slider. The trailing semicolon suppresses the returned repr.
interact(
    display_displacement_scaling_effect,
    tx=fixed(displacement),
    original_control_point_displacements=fixed(originalDisplacements),
    original_x_mat=fixed(XX),
    original_y_mat=fixed(YY),
    s=(-1.5, 1.5),
);

interactive(children=(FloatSlider(value=0.0, description='s', max=1.5, min=-1.5), Output()), _dom_classes=('wi…