# Image Registration Tests

In [None]:
from pathlib import Path
import time
import json
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib import patches, patheffects
import ipywidgets as widgets
from IPython import display

In [None]:
import sys
sys.path.insert(0, '../src')

from registration import register_images_ft, register_images_ecc, register_images_fm
from utils import xyxy_to_xywh, xywh_xyxy

In [None]:
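# Auto-reload imported modules (e.g. those under ../src) before each cell runs, so edits to external functions take effect without restarting the kernel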
%load_ext autoreload
%autoreload 2

## Helper Functions

In [None]:
def show_img(img, ax=None, figsize=(12, 12), title=""):
    """Display ``img`` on a matplotlib axis with the axis frame and grid hidden.

    Args:
        img: Image array passed straight to ``ax.imshow`` with ``cmap='gray'``.
        ax: Axis to draw on. If None, a new figure and axis of size
            ``figsize`` are created.
        figsize: Figure size, used only when a new axis is created.
        title: Title set on the axis.

    Returns:
        The axis the image was drawn on.
    """
    # Compare against None explicitly rather than relying on object truthiness.
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    ax.imshow(img, cmap='gray')
    ax.set_axis_off()
    ax.set_title(title)
    ax.grid(False)

    return ax


def plot_img_with_bboxes(img, gt_boxes=None, pred_boxes=None, ax=None, figsize=(12, 12), title=""):
    """Show ``img`` and outline ground-truth boxes in red and predicted boxes in blue.

    Either box collection may be None to skip it. Each box is forwarded
    unchanged to ``draw_rect``. Returns the axis that was drawn on.
    """
    axis = show_img(img, ax=ax, figsize=figsize, title=title)

    # Ground-truth boxes are added before predicted ones.
    for box_group, edge_color in ((gt_boxes, 'red'), (pred_boxes, 'blue')):
        if box_group is None:
            continue
        for box in box_group:
            draw_rect(axis, box, color=edge_color)

    return axis


def draw_rect(ax, b, color='red'):
    """Outline box ``b`` on ``ax`` and return the added Rectangle patch.

    ``b`` is indexed as (x, y, width, height): its first two values are the
    anchor corner and its last two the extent passed to ``Rectangle``.
    The outline is unfilled, half-transparent and 2 points wide.
    """
    rect = patches.Rectangle(b[:2], *b[-2:], fill=False, alpha=0.5, edgecolor=color, lw=2)
    # add_patch returns the patch it added; hand it back so callers can restyle it.
    return ax.add_patch(rect)

## Load Data

In [None]:
# Root of the raw dataset and the COCO-style annotation file inside it.
data_path: Path = Path("../data/raw/setup_1/")
ann_path: Path = data_path.joinpath("instances_coco.json")

In [None]:
# Load the annotation file. JSON text is UTF-8, so decode it explicitly
# instead of relying on the platform's default encoding.
with open(ann_path, encoding="utf-8") as f:
    coco_data = json.load(f)

In [None]:
# Split the loaded annotation payload into its image records and annotation records.
images, anns = coco_data["images"], coco_data["annotations"]

In [None]:
templ_idx = 0

templ_img_info = images[templ_idx]

templ_anns = [ann for ann in anns if ann["image_id"] == templ_img_info["id"]]
templ_path = data_path / templ_img_info["file_name"]

templ_img = cv2.imread(str(templ_path))
# cv2.imread returns None instead of raising when the file cannot be read.
if templ_img is None:
    raise FileNotFoundError(f"Could not read template image: {templ_path}")

# Box lists in [x, y, width, height] order (converted with xywh_xyxy later on).
templ_bboxes = [ann["bbox"] for ann in templ_anns]

# cv2 loads images in BGR order; convert a copy to RGB so matplotlib shows true colors.
_ = plot_img_with_bboxes(cv2.cvtColor(templ_img, cv2.COLOR_BGR2RGB), templ_bboxes)

In [None]:
target_idx = 8

img_info = images[target_idx]

target_anns = [ann for ann in anns if ann["image_id"] == img_info["id"]]
img_path = data_path / img_info["file_name"]

img = cv2.imread(str(img_path))
# cv2.imread returns None instead of raising when the file cannot be read.
if img is None:
    raise FileNotFoundError(f"Could not read target image: {img_path}")

bboxes = [ann["bbox"] for ann in target_anns]

# Overlay the template's boxes as-is (before registration) on the target image.
# cv2 loads BGR; convert a copy to RGB so matplotlib shows true colors.
_ = plot_img_with_bboxes(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), bboxes, templ_bboxes, title="GT boxes (red), boxes from template (blue)")

## Registration

In [None]:
# Register the target image against the template (both in grayscale), then map
# the template's boxes into the target frame and compare them with the GT boxes.
templ_img_g = cv2.cvtColor(templ_img, cv2.COLOR_BGR2GRAY)
img_g = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

print("Template image shape: ", templ_img_g.shape)
print("Target image shape: ", img_g.shape)

# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
mat = register_images_fm(img_g, templ_img_g)
duration = time.perf_counter() - start
print(f"Duration: {duration:.3f} s")

# cv2.perspectiveTransform only accepts float input, so force float64; the
# reshape keeps the (N, 4) layout even when there are no template boxes.
templ_bboxes_xyxy = np.array(
    [xywh_xyxy(box) for box in templ_bboxes], dtype=np.float64
).reshape(-1, 4)

templ_p1s = templ_bboxes_xyxy[:, :2]
templ_p2s = templ_bboxes_xyxy[:, 2:]

# All first corners followed by all second corners, so every point is warped
# in one call and can be split back apart by position afterwards.
templ_ps = np.vstack((templ_p1s, templ_p2s))

if mat is None:
    print("Registration failed")
    # Reset explicitly so a result left over from an earlier run is never reused.
    templ_bboxes_reg = None
else:
    # NOTE(review): assumes mat is a 2x3 affine matrix (lifted to 3x3 below, as
    # before); a 3x3 matrix is used unchanged.
    reg_mat = np.asarray(mat, dtype=np.float64)
    if reg_mat.shape == (2, 3):
        pers_mat = np.vstack((reg_mat, [0.0, 0.0, 1.0]))
    else:
        pers_mat = reg_mat

    if templ_ps.size:
        templ_ps_reg = cv2.perspectiveTransform(templ_ps[np.newaxis], pers_mat)[0]
    else:
        # No template boxes: nothing to warp.
        templ_ps_reg = templ_ps.copy()

    n_boxes = len(templ_bboxes_xyxy)
    templ_p1s_reg = templ_ps_reg[:n_boxes]
    templ_p2s_reg = templ_ps_reg[n_boxes:]

    templ_bboxes_xyxy_reg = np.hstack((templ_p1s_reg, templ_p2s_reg)).tolist()
    templ_bboxes_reg = [xyxy_to_xywh(box) for box in templ_bboxes_xyxy_reg]

    # cv2 loads BGR; convert a copy to RGB so matplotlib shows true colors.
    _ = plot_img_with_bboxes(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), bboxes, templ_bboxes_reg, title="GT boxes (red), registered boxes from template (blue)")