# Alignment using Keypoint Detection Methods

In [None]:
import cv2
import numpy as np

from matplotlib import pyplot as plt

# Use large figures so the side-by-side scan slices below are easy to inspect.
plt.rcParams.update({"figure.figsize": [20, 10]})

from ai_ct_scans import data_loading, keypoint_alignment

In [None]:
# Load the full-resolution data for both thorax scans of the first patient.
dl = data_loading.MultiPatientLoader()
thorax = dl.patients[0].thorax
for scan in (thorax.scan_1, thorax.scan_2):
    scan.load_scan()

In [None]:
# Show slice 255 along axis 1 of each raw (un-normalised) scan, side by side.
f, axarr = plt.subplots(1, 2)
panels = (
    (dl.patients[0].thorax.scan_1, "First scan"),
    (dl.patients[0].thorax.scan_2, "Second scan"),
)
for axis, (scan, title) in zip(axarr, panels):
    axis.imshow(scan.full_scan[:, 255, :])
    axis.title.set_text(title)

In [None]:
# Take the same slice (index 255 along axis 1) from each scan, min-max rescale
# it to 0..255 and cast to uint8. The 8-bit images feed cv2.drawKeypoints and
# the keypoint_alignment helpers below (which presumably expect 8-bit input —
# confirm against keypoint_alignment).
slice_index = 255  # same slice index as the raw-data plot above
full_images = [
    cv2.normalize(
        scan.full_scan[:, slice_index, :],
        None,  # let OpenCV allocate the output array
        0,
        255,
        cv2.NORM_MINMAX,
    ).astype("uint8")
    for scan in (dl.patients[0].thorax.scan_1, dl.patients[0].thorax.scan_2)
]

In [None]:
# Detect keypoints and compute their descriptors on each normalised slice.
key_points_1, descriptors_1 = keypoint_alignment.get_keypoints_and_descriptors(
    full_images[0]
)
key_points_2, descriptors_2 = keypoint_alignment.get_keypoints_and_descriptors(
    full_images[1]
)

# Render the detected keypoints over each slice for visual inspection.
kps1 = cv2.drawKeypoints(full_images[0], key_points_1, np.zeros((0, 0)))
kps2 = cv2.drawKeypoints(full_images[1], key_points_2, np.zeros((0, 0)))

f, axarr = plt.subplots(1, 2)
for axis, rendered in zip(axarr, (kps1, kps2)):
    axis.imshow(rendered)

In [None]:
# Match descriptors between the two slices and visualise the retained matches.
# NOTE(review): good_match_sets is passed to cv2.drawMatchesKnn, so it is
# assumed to be a list of per-query match lists — confirm against
# keypoint_alignment.match_descriptors.
good_match_sets = keypoint_alignment.match_descriptors(descriptors_1, descriptors_2)

# Zero-pad both images on the bottom/right only, up to a common height and
# width. Padding never touches the top/left edges, so the keypoint coordinates
# computed on the unpadded images stay valid. The equal sizes are also needed
# by cv2.addWeighted in the before/after comparison.
common_height = max(image.shape[0] for image in full_images[:2])
common_width = max(image.shape[1] for image in full_images[:2])
resized_image_1, resized_image_2 = (
    cv2.copyMakeBorder(
        image,
        0,
        common_height - image.shape[0],
        0,
        common_width - image.shape[1],
        cv2.BORDER_CONSTANT,
    )
    for image in full_images[:2]
)

img = cv2.drawMatchesKnn(
    resized_image_1,
    key_points_1,
    resized_image_2,  # key_points_2 belong to the second scan's image
    key_points_2,
    good_match_sets,
    None,
    flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
)
plt.imshow(img)

In [None]:
# Align the second slice using the first as the second argument (presumably the
# reference — confirm against keypoint_alignment.align_image), then show 50/50
# blends of the two slices before and after alignment.
aligned_image = keypoint_alignment.align_image(full_images[1], full_images[0])

f, axarr = plt.subplots(1, 2)
overlays = (
    ("Before alignment", resized_image_1, resized_image_2),
    ("After alignment", full_images[0], aligned_image),
)
for axis, (title, base, other) in zip(axarr, overlays):
    axis.imshow(cv2.addWeighted(base, 0.5, other, 0.5, 0.0))
    axis.title.set_text(title)