In [None]:
import numpy as np
from dipy.data import get_fnames
from dipy.align.imwarp import SymmetricDiffeomorphicRegistration
from dipy.align.metrics import SSDMetric, CCMetric, EMMetric
import dipy.align.imwarp as imwarp
from dipy.viz import regtools
import matplotlib.pyplot as plt
from imgaug import augmenters as iaa
import cv2

In [None]:
# Landmarks as (x, y) pixel pairs; index 0 is later divided by the width
# ratio and index 1 by the height ratio of the static image.
keypoints = [
    (2837, 1528),
    (2948, 1530),
    (1296, 1396),
]

In [None]:
# fname_moving = get_fnames('reg_o')
# fname_static = get_fnames('reg_c')
# Augmenter only referenced by the commented-out synthetic "moving" image below.
aug = iaa.Affine(scale={"x": (0.5, 0.9), "y": (1.2, 1.5)})

# Reference mask, read as a single-channel image (flag 0).
static_path = "/root/data/gtsf_2.0/registration_test/gtsf_ref_mask.jpg"
static = cv2.imread(static_path, 0)
# cv2.imread returns None instead of raising when the file cannot be read;
# fail here with a clear message rather than on `.shape` below.
if static is None:
    raise FileNotFoundError("Could not read static image: {}".format(static_path))
height, width = static.shape
print(width, height)

# Scale factors between the original resolution and the 800x600 working size.
ratio_width = width / 800.0
ratio_height = height / 600.0
# Keypoints expressed in the 800x600 working grid (truncated to integers).
new_keypoints = np.array([(int(kp[0]/ratio_width), int(kp[1]/ratio_height)) for kp in keypoints])
static = cv2.resize(static, (800, 600))
# Binarise: every non-zero pixel becomes 1.
static[static>0]=1

# moving = aug.augment_image(static)
moving_path = "/root/data/gtsf_2.0/registration_test/gtsf_2.jpg"
moving = cv2.imread(moving_path, 0)
if moving is None:
    raise FileNotFoundError("Could not read moving image: {}".format(moving_path))
moving = cv2.resize(moving, (800, 600))
moving[moving > 0] = 1

In [None]:
# static = cv2.imread("/root/data/gtsf_2.0/registration_test/gtsf_ref.jpg")
# new_keypoints = np.array([(int(kp[0]), int(kp[1])) for kp in keypoints])
# plt.figure(figsize=(20, 10))
# plt.imshow(static)
# plt.scatter(new_keypoints[:, 0], new_keypoints[:, 1])
# plt.show()

In [None]:
# Overlay the static and moving masks; the figure file name is passed as
# 'input_images.png'.
regtools.overlay_images(
    static,
    moving,
    title0='Static',
    title_mid='Overlay',
    title1='Moving',
    fname='input_images.png',
)

Calculate the centroids of the static and moving masks.

In [None]:
# Centroid of each binary mask, stored in (axis-0, axis-1) index order.
xs, ys = np.nonzero(static == 1)
centroid_s = np.array([xs.mean(), ys.mean()])

xm, ym = np.nonzero(moving == 1)
centroid_m = np.array([xm.mean(), ym.mean()])

In [None]:
# Show each mask next to the other with its centroid marked in red.
f, ax = plt.subplots(1, 2, figsize=(20, 10))
for panel, mask, centroid in zip(ax, (static, moving), (centroid_s, centroid_m)):
    panel.imshow(mask)
    # Centroids are stored as (axis-0, axis-1); scatter takes x first, so swap.
    panel.scatter(centroid[1], centroid[0], c="r")
plt.show()

Compute the translation that aligns the moving centroid with the static centroid.

In [None]:
# Offset that carries the moving centroid onto the static one (m to s).
translation = centroid_s - centroid_m
print(translation)

# Shift every foreground index of the moving mask by that offset.
shift_axis0, shift_axis1 = translation
xmt = xm + shift_axis0
ymt = ym + shift_axis1

In [None]:
# Rasterise the translated foreground coordinates into a fresh mask with the
# same shape and dtype as `static`.
moving_translated = np.zeros_like(static)

# Truncate toward zero, matching the behaviour of int() on each coordinate.
rows = np.asarray(xmt).astype(int)
cols = np.asarray(ymt).astype(int)

# Drop pixels shifted outside the image: negative indices would otherwise wrap
# around to the opposite edge, and indices past the edge raise IndexError.
in_bounds = (
    (rows >= 0) & (rows < moving_translated.shape[0])
    & (cols >= 0) & (cols < moving_translated.shape[1])
)
moving_translated[rows[in_bounds], cols[in_bounds]] = 1

In [None]:
# Recompute the moving centroid from the translated mask.
# Note: this rebinds xm, ym and centroid_m to the translated values.
xm, ym = np.nonzero(moving_translated == 1)
centroid_m = np.array([xm.mean(), ym.mean()])

In [None]:
# Same comparison as before, now using the translated moving mask.
f, ax = plt.subplots(1, 2, figsize=(20, 10))
for panel, mask, centroid in zip(
    ax, (static, moving_translated), (centroid_s, centroid_m)
):
    panel.imshow(mask)
    # Centroids are stored as (axis-0, axis-1); scatter takes x first, so swap.
    panel.scatter(centroid[1], centroid[0], c="r")
plt.show()

In [None]:
# Overlay the static mask and the centroid-aligned moving mask.
# A distinct file name is used so the figure from the first overlay
# ('input_images.png') is not overwritten.
regtools.overlay_images(static, moving_translated, 'Static', 'Overlay', 'Moving', 'translated_images.png')

In [None]:
# Sum-of-squared-differences metric, built for the dimensionality of `static`.
dim = len(static.shape)
metric = SSDMetric(dim)

In [None]:
# Iteration counts handed to the registration, one entry per level.
level_iters = [200, 100, 50, 25, 10]
sdr = SymmetricDiffeomorphicRegistration(
    metric=metric,
    level_iters=level_iters,
    inv_iter=50,
)

In [None]:
from time import time

In [None]:
# Run the registration of the translated moving mask onto the static mask and
# report the wall-clock time it took.
start = time()
mapping = sdr.optimize(static=static, moving=moving_translated)
end = time()
elapsed = end - start
print("Duration {} seconds".format(elapsed))

In [None]:
# Visualise the estimated deformation with a grid spacing of 10.
regtools.plot_2d_diffeomorphic_map(
    mapping,
    delta=10,
    fname='diffeomorphic_map.png',
)

In [None]:
# Show the shape of the mapping's `forward` array (presumably the forward
# displacement field — confirm against the dipy docs).
mapping.forward.shape

In [None]:
# Warp the translated moving mask with the mapping (linear interpolation) and
# compare the result against the static mask.
warped_moving = mapping.transform(moving_translated, interpolation='linear')
regtools.overlay_images(
    static,
    warped_moving,
    title0='Static',
    title_mid='Overlay',
    title1='Warped moving',
    fname='direct_warp_result.png',
)

In [None]:
# Apply the inverse transform to the static mask (linear interpolation) and
# compare the result against the translated moving mask.
warped_static = mapping.transform_inverse(static, interpolation='linear')
regtools.overlay_images(
    warped_static,
    moving_translated,
    title0='Warped static',
    title_mid='Overlay',
    title1='Moving',
    fname='inverse_warp_result.png',
)

In [None]:
# Rasterise the rescaled keypoints into an image shaped like `static`, marking
# a small patch per keypoint. Keypoints are (x, y), so index 1 selects the row
# and index 0 the column.
kp_map = np.zeros_like(static)
for kp in new_keypoints:
    # Clamp the slice start at 0: a negative start would wrap around and give
    # an empty slice, silently dropping a keypoint on the top or left border.
    row_start = max(kp[1] - 1, 0)
    col_start = max(kp[0] - 1, 0)
    kp_map[row_start:kp[1]+1, col_start:kp[0]+1] = 1
plt.imshow(kp_map)
plt.show()

In [None]:
# Push the keypoint map through the inverse transform (linear interpolation).
test = mapping.transform_inverse(kp_map, interpolation="linear")

In [None]:
# Draw the warped keypoint map with the translated moving mask blended on top.
overlay_fig, overlay_ax = plt.subplots()
overlay_ax.imshow(test)
overlay_ax.imshow(moving_translated, alpha=0.5)
plt.show()

In [None]:
# Indices of every positive pixel in the warped keypoint map, stacked as a
# 2 x N array: row 0 holds axis-0 indices, row 1 holds axis-1 indices.
warped_keypoints = np.stack(np.nonzero(test > 0))

Map the warped keypoints back to the original image space.

In [None]:
# Reload the images at full resolution for display.
# Note: this rebinds `static` and `moving` to the colour images.
static = cv2.imread("/root/data/gtsf_2.0/registration_test/gtsf_ref.jpg")
moving = cv2.imread("/root/data/gtsf_2.0/registration_test/gtsf_2.jpg")
# cv2.imread returns None instead of raising when a file cannot be read.
if static is None or moving is None:
    raise FileNotFoundError("Could not read the full-resolution static/moving images")
# OpenCV loads colour images as BGR while matplotlib expects RGB; convert so
# the colours are displayed correctly.
static = cv2.cvtColor(static, cv2.COLOR_BGR2RGB)
moving = cv2.cvtColor(moving, cv2.COLOR_BGR2RGB)
keypoints = np.array([(2837, 1528), (2948, 1530), (1296, 1396)])

# The warped keypoints live in the resized, translated moving grid. Undo the
# centroid translation, then scale by the moving image's own full-resolution
# size (not the static image's) to get back to its original pixel space.
work_height, work_width = moving_translated.shape[:2]
scale_x = moving.shape[1] / float(work_width)
scale_y = moving.shape[0] / float(work_height)
warped_x = (warped_keypoints[1, :] - translation[1]) * scale_x
warped_y = (warped_keypoints[0, :] - translation[0]) * scale_y

f, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(static)
ax[0].scatter(keypoints[:, 0], keypoints[:, 1])
ax[0].axis("off")
ax[1].imshow(moving)
ax[1].scatter(warped_x, warped_y)
ax[1].axis("off")
plt.show()