In [None]:
import cv2
import cycpd
import numpy as np
from matplotlib import pyplot as plt
from scipy import interpolate, ndimage
from skimage import transform
from sklearn.linear_model import RANSACRegressor, Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

plt.rcParams["figure.figsize"] = [20, 10]

from ai_ct_scans import (
    data_loading,
    keypoint_alignment,
    point_matching,
    phase_correlation_image_processing,
    image_processing_utils,
)

In [None]:
# Load full scan data
patient_dir = data_loading.data_root_directory() / "1"
patient_loader = data_loading.PatientLoader(patient_dir)

patient_loader.abdo.scan_1.load_scan()
patient_loader.abdo.scan_2.load_scan()

# Index along axis 1 of each full scan at which the 2D comparison slice is taken.
# NOTE(review): "mid" suggests the centre of that axis — confirm against scan shape.
scan_1_mid = 255
scan_2_mid = 255

# Plot some full scale data
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(patient_loader.abdo.scan_1.full_scan[:, scan_1_mid, :])
axarr[0].title.set_text("First scan")
axarr[1].imshow(patient_loader.abdo.scan_2.full_scan[:, scan_2_mid, :])
axarr[1].title.set_text("Second scan")

# Copy the slices: basic slicing of a NumPy array returns a view, so the
# in-place thresholding below would otherwise also overwrite voxels of the
# loaded `full_scan` volumes.
reference = np.array(patient_loader.abdo.scan_1.full_scan[:, scan_1_mid, :], copy=True)
image = np.array(patient_loader.abdo.scan_2.full_scan[:, scan_2_mid, :], copy=True)

# Pixels with intensity below `thresh` are set to 0 in both slices.
thresh = 500
reference[reference < thresh] = 0
image[image < thresh] = 0

In [None]:
# Key-point extraction settings.
LMR_RADIUS = 10  # `radius` passed to phase_correlation_image_processing.lmr
KEYPOINT_BORDER = 2  # pixels ignored along every image edge
KEYPOINT_THRESH = 250  # filtered response above which a pixel becomes a key point
# NOTE(review): the previous version of this cell evaluated
# `filtered * zero_crossings(filtered, thresh='auto')` without assigning the
# result, so the zero-crossing mask was never applied. The threshold above was
# tuned with that no-op in place, so the mask stays off by default; set this to
# True to actually apply it.
APPLY_ZERO_CROSSING_MASK = False


def extract_key_points(
    img,
    radius=LMR_RADIUS,
    border=KEYPOINT_BORDER,
    thresh=KEYPOINT_THRESH,
    apply_zero_crossing_mask=APPLY_ZERO_CROSSING_MASK,
):
    """Return ``(rows, cols)`` index arrays of pixels whose ``lmr`` response exceeds ``thresh``.

    Pixels within ``border`` of any edge are ignored. The returned indices are
    expressed in the coordinate frame of the full ``img`` (not of the trimmed
    interior), so they can be used directly as image coordinates.
    """
    filtered = phase_correlation_image_processing.lmr(img, filter_type=None, radius=radius)
    if apply_zero_crossing_mask:
        filtered = filtered * phase_correlation_image_processing.zero_crossings(
            filtered, thresh="auto"
        )
    n_rows, n_cols = filtered.shape
    interior = filtered[border : n_rows - border, border : n_cols - border]
    rows, cols = np.where(interior > thresh)
    # np.where indexes the trimmed interior; shift back into full-image coordinates.
    return rows + border, cols + border


key_point_coords_1 = extract_key_points(reference)
key_point_coords_2 = extract_key_points(image)

print(f"Extracted {len(key_point_coords_1[0])} target points")
print(f"Extracted {len(key_point_coords_2[0])} source points")

# np.where yields (row, col); reorder to (col, row), i.e. (x, y), which is the
# coordinate order skimage's warp uses for its coordinate maps.
x = np.column_stack((key_point_coords_1[1], key_point_coords_1[0])).astype(np.float64)
y = np.column_stack((key_point_coords_2[1], key_point_coords_2[0])).astype(np.float64)

In [None]:
# Plot the extracted source and target points before alignment
f, ax = plt.subplots()
ax.scatter(x[:, 0], x[:, 1], color="b", alpha=0.5, s=1.0, label="target (first scan)")
ax.scatter(y[:, 0], y[:, 1], color="r", alpha=0.5, s=1.0, label="source (second scan)")
ax.set(title="Key points before alignment", xlabel="column (x)", ylabel="row (y)")
# Markers are tiny (s=1), so enlarge them in the legend to keep it readable.
ax.legend(markerscale=10);

In [None]:
# Deform the source point set (y) onto the target point set (x) using cycpd's
# deformable Coherent Point Drift; register() returns a tuple whose first
# element is the transformed source points.
cpd = cycpd.deformable_registration(
    X=x,
    Y=y,
    alpha=0.05,
    beta=30,
    max_iterations=100,
)
non_rigid_out = cpd.register()

# Side-by-side view: target (blue) against source (red) before and after CPD.
registered_y = non_rigid_out[0]
f, axarr = plt.subplots(1, 2)
for panel, (moving, title) in zip(axarr, ((y, "Before"), (registered_y, "After"))):
    panel.scatter(x[:, 0], x[:, 1], color="b", alpha=0.5, s=1.0)
    panel.scatter(moving[:, 0], moving[:, 1], color="r", alpha=0.5, s=1.0)
    panel.title.set_text(title)

In [None]:
# Match target points against the CPD-registered source points.
# As used here and in the filtering cell, element 0 of the result indexes the
# registered source points (same ordering as `y`) and element 1 indexes `x`.
matched_indices = point_matching.match_indices(x, non_rigid_out[0])
source_idx = matched_indices[0]
target_idx = matched_indices[1]

# NOTE(review): `y` is rebound to its matched subset, so this cell is not
# idempotent — re-running it without first re-running the key-point extraction
# cell re-indexes points that were already matched.
X = x[target_idx]
y = y[source_idx]

In [None]:
# Estimate a non-rigid transform: a degree-3 polynomial (ridge regression) that
# maps matched reference coordinates X to second-scan coordinates y. warp uses
# it as its inverse map (output coordinates -> input coordinates).
poly_trans = make_pipeline(PolynomialFeatures(degree=3), Ridge())
poly_trans.fit(X, y)

non_rigid_aligned = transform.warp(image, poly_trans.predict)
non_rigid_aligned = image_processing_utils.normalise(non_rigid_aligned)

# Estimate a "rigid" transform. NOTE: ProjectiveTransform is a full homography,
# so this is more general than a strictly rigid motion.
rigid_homography = transform.ProjectiveTransform()
# estimate() reports failure via its return value rather than raising; don't
# warp with an unfitted transform.
if not rigid_homography.estimate(X, y):
    raise RuntimeError("ProjectiveTransform estimation from matched points failed")
rigid_aligned = transform.warp(image, rigid_homography)
rigid_aligned = image_processing_utils.normalise(rigid_aligned)

normalised_reference = image_processing_utils.normalise(reference)

# Plot overlays of the reference against each version of the second scan.
panels = [
    ("Before alignment", image_processing_utils.normalise(image)),
    ("After rigid alignment", rigid_aligned),
    ("After non-rigid alignment", non_rigid_aligned),
]
f, axarr = plt.subplots(1, len(panels))
for ax, (title, moving) in zip(axarr, panels):
    ax.imshow(
        phase_correlation_image_processing.generate_overlay_2d(
            [normalised_reference, moving], False
        )
    )
    ax.title.set_text(title)

In [None]:
# Filter out bad matches: keep only pairs whose target point X lies within
# MAX_MATCH_DIST pixels of its CPD-registered source point.
MAX_MATCH_DIST = 1.0
# A projective transform has 8 degrees of freedom, so it needs >= 4 point pairs.
MIN_MATCHES = 4

dist = np.linalg.norm(X - non_rigid_out[0][matched_indices[0]], axis=1)
good_match = dist < MAX_MATCH_DIST

X_filtered = X[good_match]
y_filtered = y[good_match]

print(f"Kept {len(X_filtered)} of {len(X)} matched point pairs")
if len(X_filtered) < MIN_MATCHES:
    raise RuntimeError(
        f"Only {len(X_filtered)} matches within {MAX_MATCH_DIST} px; "
        f"need at least {MIN_MATCHES} to fit the transforms"
    )

# Non-rigid transform: degree-5 polynomial mapping reference coords to image coords.
# NOTE(review): raw pixel coordinates raised to the 5th power are badly scaled;
# consider standardising the inputs before PolynomialFeatures.
poly_trans = make_pipeline(PolynomialFeatures(degree=5), Ridge())
poly_trans.fit(X_filtered, y_filtered)

non_rigid_aligned = transform.warp(image, poly_trans.predict)
non_rigid_aligned = image_processing_utils.normalise(non_rigid_aligned)

# Projective ("rigid") transform; estimate() signals failure via its return value.
rigid_homography = transform.ProjectiveTransform()
if not rigid_homography.estimate(X_filtered, y_filtered):
    raise RuntimeError("ProjectiveTransform estimation from filtered matches failed")
rigid_aligned = transform.warp(image, rigid_homography)
rigid_aligned = image_processing_utils.normalise(rigid_aligned)

# Plot overlays of the reference against each version of the second scan.
panels = [
    ("Before alignment", image_processing_utils.normalise(image)),
    ("After rigid alignment", rigid_aligned),
    ("After non-rigid alignment", non_rigid_aligned),
]
f, axarr = plt.subplots(1, len(panels))
for ax, (title, moving) in zip(axarr, panels):
    ax.imshow(
        phase_correlation_image_processing.generate_overlay_2d(
            [normalised_reference, moving], False
        )
    )
    ax.title.set_text(title)

In [None]:
# Non-rigid transform fitted robustly: RANSAC around a degree-5 polynomial
# regression on the distance-filtered matches. random_state is fixed so RANSAC's
# random subsampling, and hence the fitted transform, is reproducible.
poly_trans = make_pipeline(
    PolynomialFeatures(degree=5), RANSACRegressor(random_state=0)
)
poly_trans.fit(X_filtered, y_filtered)

non_rigid_aligned = transform.warp(image, poly_trans.predict)
non_rigid_aligned = image_processing_utils.normalise(non_rigid_aligned)

# Projective ("rigid") transform.
# NOTE(review): this uses all matches (X, y), whereas the non-rigid fit above and
# the previous cell use X_filtered / y_filtered — confirm this is intended.
rigid_homography = transform.ProjectiveTransform()
if not rigid_homography.estimate(X, y):
    raise RuntimeError("ProjectiveTransform estimation from matched points failed")
rigid_aligned = transform.warp(image, rigid_homography)
rigid_aligned = image_processing_utils.normalise(rigid_aligned)

normalised_reference = image_processing_utils.normalise(reference)

# Plot overlays of the reference against each version of the second scan.
panels = [
    ("Before alignment", image_processing_utils.normalise(image)),
    ("After rigid alignment", rigid_aligned),
    ("After non-rigid alignment", non_rigid_aligned),
]
f, axarr = plt.subplots(1, len(panels))
for ax, (title, moving) in zip(axarr, panels):
    ax.imshow(
        phase_correlation_image_processing.generate_overlay_2d(
            [normalised_reference, moving], False
        )
    )
    ax.title.set_text(title)

In [None]:
# Try fitting a spline as an alignment transform (further work required).
# Each smoothing spline models one component of the displacement
# (target - source) as a function of the target point's (x=col, y=row) position:
# x_spline for the x/column component, y_spline for the y/row component.
SPLINE_DEGREE = 2
SPLINE_SMOOTHING = 1e5

x_spline, y_spline = [
    interpolate.SmoothBivariateSpline(
        X_filtered[:, 0],
        X_filtered[:, 1],
        X_filtered[:, axis] - y_filtered[:, axis],
        kx=SPLINE_DEGREE,
        ky=SPLINE_DEGREE,
        s=SPLINE_SMOOTHING,
    )
    for axis in (0, 1)
]


def apply_spline(xy):
    """Map an output (reference-frame) pixel to its input (second-scan) pixel.

    ``ndimage.geometric_transform`` passes output coordinates in array-index
    order, i.e. ``(row, col)``, and expects input coordinates back in the same
    order. The displacement splines were fitted on ``(x=col, y=row)`` points and
    model ``target - source``, so the source position is the output position
    minus the predicted displacement.
    """
    row, col = xy
    # .ev evaluates at points (not on a grid) and yields a 0-d array, which
    # converts cleanly to a float coordinate.
    src_col = col - float(x_spline.ev(col, row))
    src_row = row - float(y_spline.ev(col, row))
    return src_row, src_col


# Resample the second scan through the spline mapping. geometric_transform
# invokes the Python mapping once per output element, so this is slow.
non_rigid_spline_aligned = ndimage.geometric_transform(image, apply_spline)
non_rigid_spline_aligned = image_processing_utils.normalise(non_rigid_spline_aligned)

f, ax = plt.subplots()
ax.imshow(
    phase_correlation_image_processing.generate_overlay_2d(
        [normalised_reference, non_rigid_spline_aligned], False
    )
)
ax.title.set_text("After spline alignment")