In [None]:
import cv2
import numpy as np
import cycpd
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import Ridge, RANSACRegressor
from sklearn.pipeline import make_pipeline
from skimage import transform
from scipy import ndimage

from matplotlib import pyplot as plt

# Large default figures so the side-by-side slice overlays below stay legible.
plt.rcParams.update({"figure.figsize": [20, 10]})

from ai_ct_scans import (
    data_loading,
    keypoint_alignment,
    point_matching,
    phase_correlation_image_processing,
    image_processing_utils,
)
from ai_ct_scans.phase_correlation import shift_nd

In [None]:
# Load full scan data
patient_dir = data_loading.data_root_directory() / "1"
patient_loader = data_loading.PatientLoader(patient_dir)

patient_loader.abdo.scan_1.load_scan()
patient_loader.abdo.scan_2.load_scan()

scan_1_mid = 255
scan_2_mid = 255

# Plot some full scale data
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(patient_loader.abdo.scan_1.full_scan[:, scan_1_mid, :])
axarr[0].title.set_text("First scan")
axarr[1].imshow(patient_loader.abdo.scan_2.full_scan[:, scan_2_mid, :])
axarr[1].title.set_text("Second scan")

# Keypoints are extracted from the first KEYPOINT_CROP slices along axis 1 only.
KEYPOINT_CROP = 320
thresh = 500

# `.copy()` matters: if full_scan is a NumPy array (its use with boolean masks
# and np.percentile suggests so), basic slicing returns a *view*. Thresholding
# that view in place would silently zero voxels in the loaded scans, which the
# alignment and overlay cells later read via `full_scan`.
reference = patient_loader.abdo.scan_1.full_scan[:, :KEYPOINT_CROP, :].copy()
image = patient_loader.abdo.scan_2.full_scan[:, :KEYPOINT_CROP, :].copy()

# Only voxels at or above `thresh` contribute to keypoint detection.
reference[reference < thresh] = 0
image[image < thresh] = 0

In [None]:
# Extract 3D source and target points
KEYPOINT_PERCENTILE = 98  # keep voxels whose filter response is in the top 2%
BORDER = 2  # voxels ignored at each volume face when selecting keypoints
TARGET_STRIDE = 50  # subsampling of reference keypoints passed to CPD
SOURCE_STRIDE = 200  # subsampling of scan-2 keypoints passed to CPD
# Optional zero-crossing masking. A previous version evaluated
# `volume * zero_crossings(volume)` without assigning the product, so the mask
# never took effect; False reproduces those results. Set True to actually
# apply it (the strides above will likely need re-tuning).
APPLY_ZERO_CROSSING_MASK = False


def _extract_keypoints(volume):
    """Return keypoint indices of `volume` as a tuple of per-axis index arrays.

    Runs the project's `lmr` filter (filter_type=None, radius=10), optionally
    masks it with `zero_crossings`, and keeps interior voxels (BORDER voxels
    away from every face) whose response exceeds the KEYPOINT_PERCENTILE-th
    percentile of the whole response. Returned indices are in the coordinate
    frame of the full `volume`, not the cropped interior.
    """
    response = phase_correlation_image_processing.lmr(
        volume, filter_type=None, radius=10
    )
    if APPLY_ZERO_CROSSING_MASK:
        response = response * phase_correlation_image_processing.zero_crossings(
            response, thresh="auto"
        )
    cutoff = np.percentile(response, KEYPOINT_PERCENTILE)
    interior = response[tuple(slice(BORDER, dim - BORDER) for dim in response.shape)]
    # np.where reports indices into the cropped view; shift them back by BORDER
    # so they line up with the full-volume grid the transform is evaluated on.
    return tuple(axis_idx + BORDER for axis_idx in np.where(interior > cutoff))


key_point_coords_1 = _extract_keypoints(reference)
reference = None  # drop the thresholded volume; only the coordinates are needed now

key_point_coords_2 = _extract_keypoints(image)
image = None

print(f"Extracted {len(key_point_coords_1[0])} points from target")
print(f"Extracted {len(key_point_coords_2[0])} points from source")

x = np.stack(key_point_coords_1, axis=1).astype(np.float64)
y = np.stack(key_point_coords_2, axis=1).astype(np.float64)

x = x[::TARGET_STRIDE]
y = y[::SOURCE_STRIDE]

fig = plt.figure()
ax1 = fig.add_subplot(111, projection="3d")
ax1.scatter(x[:, 0], x[:, 1], x[:, 2], color="b", alpha=0.5, s=1.0)

In [None]:
# Align points using Coherent Point Drift. Non-rigidly deform the scan-2 cloud
# `y` towards the reference cloud `x`. Below, non_rigid_out[0] is treated as
# the deformed `y`.
reg = cycpd.deformable_registration(
    X=x, Y=y, alpha=0.1, beta=30, max_iterations=500
)
non_rigid_out = reg.register()
registered_y = non_rigid_out[0]

fig = plt.figure()
ax1 = fig.add_subplot(121, projection="3d")
ax2 = fig.add_subplot(122, projection="3d")

# Left panel: clouds as extracted. Right panel: after deformation.
ax1.scatter(*x.T, color="b", alpha=0.5, s=1.0)
ax1.scatter(*y.T, color="r", alpha=0.5, s=1.0)
ax1.title.set_text("Before")
ax2.scatter(*x.T, color="b", alpha=0.5, s=1.0)
ax2.scatter(*registered_y.T, color="r", alpha=0.5, s=1.0)
ax2.title.set_text("After")

In [None]:
# Match points based on alignment.
# As used here and in the next cell, index array [0] selects rows of the
# (deformed) scan-2 cloud and index array [1] selects rows of `x`.
matched_indices = point_matching.match_indices(x, non_rigid_out[0])
source_rows = matched_indices[0]
target_rows = matched_indices[1]

X = x[target_rows]
# NOTE: `y` is overwritten with its matched subset, so re-running this cell
# without first re-running the extraction/CPD cells re-indexes filtered data.
y = y[source_rows]

In [None]:
# Estimate and apply alignment calculated from extracted points to full image data
to_align = patient_loader.abdo.scan_2.full_scan

x_len, y_len, z_len = to_align.shape

# Keep only correspondences whose reference point lies within MAX_MATCH_DIST
# voxels of its CPD-deformed partner.
MAX_MATCH_DIST = 1.0
dist = np.linalg.norm(X - non_rigid_out[0][matched_indices[0]], axis=1)
close = dist < MAX_MATCH_DIST
target_filtered = X[close].astype(np.float32)
source_filtered = y[close].astype(np.float32)
if len(target_filtered) == 0:
    raise ValueError(
        f"No correspondences closer than {MAX_MATCH_DIST} voxels; "
        "cannot fit the polynomial transform"
    )
print(f"Fitting polynomial transform to {len(target_filtered)} correspondences")

# Fit reference (scan 1) coordinates -> scan 2 coordinates. This is the "pull"
# direction map_coordinates needs: each output voxel looks up where to sample
# the input volume.
poly_trans = make_pipeline(PolynomialFeatures(degree=3), Ridge())
poly_trans.fit(target_filtered, source_filtered)


def apply_poly(points):
    """Map an (n, 3) array of reference voxel coordinates into scan-2 coordinates."""
    return poly_trans.predict(points)


# Evaluate the transform at every output voxel in chunks. A single full-volume
# call would materialise the whole coordinate grid plus its degree-3 polynomial
# features at once. Rows are generated in C order, matching a meshgrid with
# indexing="ij" followed by ravel().
PREDICT_CHUNK = 2_000_000
n_voxels = x_len * y_len * z_len
coords_in_input = np.empty((3, n_voxels), dtype=np.float64)
for start in range(0, n_voxels, PREDICT_CHUNK):
    stop = min(start + PREDICT_CHUNK, n_voxels)
    grid_points = np.stack(
        np.unravel_index(np.arange(start, stop), to_align.shape), axis=1
    ).astype(np.float32)
    coords_in_input[:, start:stop] = apply_poly(grid_points).T
coords_in_input = coords_in_input.reshape(3, x_len, y_len, z_len)

aligned = ndimage.map_coordinates(to_align, coords_in_input)

In [None]:
# Generate a number of examples showing how a 2D slice has been aligned.
# Scan 1 is the reference frame that scan 2 was resampled into above.
abdo_scans = patient_loader.abdo
reference = abdo_scans.scan_1.full_scan

In [None]:
# Overlay reference vs. scan 2 at index 255 along axis 1, before and after alignment.
view = np.s_[:, 255, :]
f, axarr = plt.subplots(1, 2)
panels = (("Before alignment", to_align), ("After non-rigid alignment", aligned))
for ax, (title, moving) in zip(axarr, panels):
    overlay = phase_correlation_image_processing.generate_overlay_2d(
        [
            image_processing_utils.normalise(reference[view]),
            image_processing_utils.normalise(moving[view]),
        ],
        False,
    )
    ax.imshow(overlay)
    ax.title.set_text(title)

In [None]:
# Overlay reference vs. scan 2 at index 255 along axis 0, before and after alignment.
view = np.s_[255, :, :]
f, axarr = plt.subplots(1, 2)
panels = (("Before alignment", to_align), ("After non-rigid alignment", aligned))
for ax, (title, moving) in zip(axarr, panels):
    overlay = phase_correlation_image_processing.generate_overlay_2d(
        [
            image_processing_utils.normalise(reference[view]),
            image_processing_utils.normalise(moving[view]),
        ],
        False,
    )
    ax.imshow(overlay)
    ax.title.set_text(title)

In [None]:
# Overlay reference vs. scan 2 at index 200 along axis 1, before and after alignment.
view = np.s_[:, 200, :]
f, axarr = plt.subplots(1, 2)
panels = (("Before alignment", to_align), ("After non-rigid alignment", aligned))
for ax, (title, moving) in zip(axarr, panels):
    overlay = phase_correlation_image_processing.generate_overlay_2d(
        [
            image_processing_utils.normalise(reference[view]),
            image_processing_utils.normalise(moving[view]),
        ],
        False,
    )
    ax.imshow(overlay)
    ax.title.set_text(title)

In [None]:
# Overlay reference vs. scan 2 at index 300 along axis 1, before and after alignment.
view = np.s_[:, 300, :]
f, axarr = plt.subplots(1, 2)
panels = (("Before alignment", to_align), ("After non-rigid alignment", aligned))
for ax, (title, moving) in zip(axarr, panels):
    overlay = phase_correlation_image_processing.generate_overlay_2d(
        [
            image_processing_utils.normalise(reference[view]),
            image_processing_utils.normalise(moving[view]),
        ],
        False,
    )
    ax.imshow(overlay)
    ax.title.set_text(title)

In [None]:
# Overlay reference vs. scan 2 at index 200 along axis 0, before and after alignment.
view = np.s_[200, :, :]
f, axarr = plt.subplots(1, 2)
panels = (("Before alignment", to_align), ("After non-rigid alignment", aligned))
for ax, (title, moving) in zip(axarr, panels):
    overlay = phase_correlation_image_processing.generate_overlay_2d(
        [
            image_processing_utils.normalise(reference[view]),
            image_processing_utils.normalise(moving[view]),
        ],
        False,
    )
    ax.imshow(overlay)
    ax.title.set_text(title)

In [None]:
# Overlay reference vs. scan 2 at index 400 along axis 0, before and after alignment.
view = np.s_[400, :, :]
f, axarr = plt.subplots(1, 2)
panels = (("Before alignment", to_align), ("After non-rigid alignment", aligned))
for ax, (title, moving) in zip(axarr, panels):
    overlay = phase_correlation_image_processing.generate_overlay_2d(
        [
            image_processing_utils.normalise(reference[view]),
            image_processing_utils.normalise(moving[view]),
        ],
        False,
    )
    ax.imshow(overlay)
    ax.title.set_text(title)

In [None]:
# Overlay reference vs. scan 2 at index 200 along axis 2, before and after alignment.
view = np.s_[:, :, 200]
f, axarr = plt.subplots(1, 2)
panels = (("Before alignment", to_align), ("After non-rigid alignment", aligned))
for ax, (title, moving) in zip(axarr, panels):
    overlay = phase_correlation_image_processing.generate_overlay_2d(
        [
            image_processing_utils.normalise(reference[view]),
            image_processing_utils.normalise(moving[view]),
        ],
        False,
    )
    ax.imshow(overlay)
    ax.title.set_text(title)