In [None]:
import cv2
import onnxruntime as ort
import numpy as np
import torch
import matplotlib.pyplot as plt

import sys
sys.path.append("../")

from likl.models.line_matcher import WunschLineMatcher
from likl.misc.process import (Tp_map_to_line_torch,
                            convert_kp2d_pred,
                            extract_descriptors)
from likl.misc.common import time_sync

from likl.dataset.transforms.homo_transforms import sample_homography
from likl.dataset.transforms.photometric_transforms import random_saturation
from likl.misc.visualize_utils import (plot_images, plot_keypoints, plot_keypoint_matches, plot_lines)
from likl.misc.visualize_utils import plot_color_line_matches, plot_line_matches, plot_color_lines
from likl.misc.geometry_utils import (warp_points, warp_lines, clip_line_to_boundaries)
from likl.misc.metrics import get_line_distance

In [None]:
# Report the ONNX Runtime execution providers available in this environment and
# the device type returned by ort.get_device().
for ort_info in (ort.get_available_providers(), ort.get_device()):
    print(ort_info)

In [None]:
# Post-processing settings read by onnx_inference():
#   points_cfg -> arguments for convert_kp2d_pred plus the per-image keypoint cap (top_k)
#   lines_cfg  -> arguments for Tp_map_to_line_torch and WunschLineMatcher.sample_line_points
points_cfg = dict(
    grid_size=8,
    cross_ratio=2,
    nms_radius=0,
    detect_thresh=0.5,
    top_k=1000,
)
lines_cfg = dict(
    score_thresh=0.3,
    len_thresh=3,
    num_samples=5,
    sample_min_dist=8,
)

# Cache of ONNX Runtime sessions keyed by model path, so repeated calls (one per
# image) do not reload the model. Re-running this cell resets the cache.
_ORT_SESSIONS = {}


def _get_ort_session(model_path):
    """Return a cached ort.InferenceSession for model_path, creating it on first use."""
    sess = _ORT_SESSIONS.get(model_path)
    if sess is None:
        sess = ort.InferenceSession(model_path)
        _ORT_SESSIONS[model_path] = sess
    return sess


def onnx_inference(input, model_path='/home/hxb/experiments/likl.onnx'):
    """Run the exported LIKL ONNX model on a batch of images and post-process its outputs.

    Args:
        input: array fed directly to the ONNX session. The caller in this notebook
            passes a float32 NCHW array (``img.transpose(2, 0, 1)[np.newaxis, ...]``);
            dims 2 and 3 are taken as the image (H, W).
        model_path: path of the exported ONNX model. Defaults to the previously
            hard-coded location so existing calls keep working.

    Returns:
        dict with one list entry per image in the batch:
          "batch_pts": numpy keypoints, sorted by descending score (column 2),
              capped at points_cfg["top_k"], columns reordered to (x, y, score);
          "batch_pts_desc": numpy descriptors for those keypoints;
          "batch_lines" / "batch_lines_score": numpy line segments and their scores;
          "batch_lines_desc": numpy line descriptors reshaped to
              (num_lines, lines_cfg["num_samples"], -1), or [] if no line was found;
          "batch_valid_points": validity output of sample_line_points, or [].
    """
    ort_sess = _get_ort_session(model_path)
    print(ort_sess.get_providers())
    ort_input = {ort_sess.get_inputs()[0].name: input}

    t1 = time_sync()
    ort_output = ort_sess.run(None, ort_input)
    print("[Info]: ONNX infer time = ", time_sync() - t1)

    line_maps, pts_maps, desc_map = ort_output
    pts_maps = torch.from_numpy(pts_maps)
    line_maps = torch.from_numpy(line_maps)
    desc_map = torch.from_numpy(desc_map)

    # (H, W) of the input. The previous `shape[2:3]` kept only H and dropped W.
    img_size = tuple(input.shape[2:4])
    batch_size = input.shape[0]

    # ---- Keypoints and their descriptors ----
    pts = convert_kp2d_pred(pts_maps,
                            points_cfg["grid_size"],
                            points_cfg["cross_ratio"],
                            points_cfg["detect_thresh"],
                            points_cfg["nms_radius"])
    batch_pts = []
    batch_pts_desc = []
    for i in range(batch_size):
        pts_pred = pts[i].cpu().numpy()
        # Keep the top_k points with the highest value in column 2 (the score).
        order = np.argsort(pts_pred[:, 2])[::-1]
        pts_pred = pts_pred[order[:points_cfg["top_k"]]]
        # Descriptors are sampled with the (h, w) coordinates, so swap afterwards.
        pts_desc = extract_descriptors(pts_pred[:, :2], desc_map[i], img_size)
        batch_pts.append(pts_pred[:, [1, 0, 2]])  # hw format ==> xy format
        batch_pts_desc.append(pts_desc.cpu().numpy())

    # ---- Line segments ----
    raw_lines, _, raw_scores = Tp_map_to_line_torch(
        line_maps,
        score_thresh=lines_cfg["score_thresh"],
        len_thresh=lines_cfg["len_thresh"],
        image_size=img_size,
        valid_mask=None,
        with_sig=True)

    # ---- Line descriptors from points sampled along each line ----
    batch_lines = []
    batch_lines_scores = []
    batch_lines_desc = []
    batch_valid_points = []
    for i in range(len(raw_lines)):
        lines = raw_lines[i].cpu().numpy()
        scores = raw_scores[i].cpu().numpy()
        batch_lines.append(lines)
        batch_lines_scores.append(scores)

        if lines.shape[0] == 0:
            # No lines detected: keep the per-image slots aligned with empty entries.
            batch_lines_desc.append([])
            batch_valid_points.append([])
            continue

        line_points, valid_points = WunschLineMatcher.sample_line_points(
            lines,
            lines_cfg["num_samples"],
            lines_cfg["sample_min_dist"])
        lines_desc = extract_descriptors(line_points, desc_map[i], img_size)
        lines_desc = lines_desc.reshape(lines.shape[0], lines_cfg["num_samples"], -1)
        batch_lines_desc.append(lines_desc.cpu().numpy())
        batch_valid_points.append(valid_points)

    return {
        "batch_pts": batch_pts,
        "batch_pts_desc": batch_pts_desc,
        "batch_lines": batch_lines,
        "batch_lines_score": batch_lines_scores,
        "batch_lines_desc": batch_lines_desc,
        "batch_valid_points": batch_valid_points,
    }

In [None]:
# Build the evaluation pair: a reference image and a warped, colour-jittered copy of it.
# `homo` is the homography used for the warp and serves as ground truth in later cells.
# NOTE(review): sample_homography and random_saturation appear to be stochastic and
# are not seeded here, so results change between runs. Seed their RNG if
# reproducibility matters.
IMG_SIZE = 480  # side length (px) of the square images fed to the model
IMAGE_PATH = "../asset/00036796.jpg"

homo = sample_homography([IMG_SIZE, IMG_SIZE], inverse=True)
img = cv2.imread(IMAGE_PATH)
if img is None:
    # cv2.imread returns None for a missing/unreadable file instead of raising.
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
show_ref_image = img
show_target_image = cv2.warpPerspective(show_ref_image, homo, (IMG_SIZE, IMG_SIZE), flags=cv2.INTER_LINEAR)

# OpenCV loads BGR; convert to RGB for display and for the model input.
show_ref_image = cv2.cvtColor(show_ref_image, cv2.COLOR_BGR2RGB)
show_target_image = cv2.cvtColor(show_target_image, cv2.COLOR_BGR2RGB)
show_target_image = random_saturation(0.4)(show_target_image)

# Rescale pixel values to [-1, 1] (assuming 8-bit [0, 255] input).
ref_image = show_ref_image.astype(np.float32) / 127.5 - 1
target_image = show_target_image.astype(np.float32) / 127.5 - 1

# HWC -> 1xCxHxW (a batch of one image) for the ONNX model.
outputs1 = onnx_inference(ref_image.transpose(2, 0, 1)[np.newaxis, ...])
outputs2 = onnx_inference(target_image.transpose(2, 0, 1)[np.newaxis, ...])

points1, desc1 = outputs1["batch_pts"], outputs1["batch_pts_desc"]
points2, desc2 = outputs2["batch_pts"], outputs2["batch_pts_desc"]

batch_lines1 = outputs1["batch_lines"][0]
batch_lines2 = outputs2["batch_lines"][0]

# Clip the detected lines to the image area. The returned mask is discarded, presumably
# so that row i still corresponds to line descriptor i. TODO confirm that
# clip_line_to_boundaries preserves the number of lines.
batch_lines1, _ = clip_line_to_boundaries(batch_lines1, show_ref_image.shape[:2], 5)
batch_lines2, _ = clip_line_to_boundaries(batch_lines2, show_target_image.shape[:2], 5)

In [None]:

# Overview figures: detected keypoints, then detected line segments.
# Lines are flipped on the last axis before plotting, as everywhere in this notebook.
plot_images(
    [show_ref_image, show_target_image],
    size=1, dpi=600, pad=0.0,
)
plot_keypoints(
    [points1[0][:, :2], points2[0][:, :2]],
    marker="P", ps=0.5,
)

plot_images(
    [show_ref_image, show_target_image],
    size=1, dpi=600, pad=0.0,
)
plot_lines(
    [batch_lines1[..., ::-1], batch_lines2[..., ::-1]],
    ps=0.2, lw=0.4,
)

## Match keypoints

In [None]:
# Mutual nearest-neighbour matching of keypoint descriptors (L2 distance, cross-checked).
kps1, kps2 = points1[0][:, :2], points2[0][:, :2]
des1, des2 = desc1[0], desc2[0]
# Distinct name: `matcher` is reused later for the line matcher.
kp_matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
putative_match = kp_matcher.match(des1, des2)
# dtype=int keeps the fancy indexing valid when there are no matches
# (np.array([]) would be float64 and cannot index an array).
query_idx = np.array([m.queryIdx for m in putative_match], dtype=int)
train_idx = np.array([m.trainIdx for m in putative_match], dtype=int)
match_keypoints = kps1[query_idx, :]
match_warped_keypoints = kps2[train_idx, :]

In [None]:
# Visualise every putative (cross-checked) keypoint match.
# The stray `plt.subplots_adjust()` that used to precede this created an extra,
# empty current figure before plot_images opened its own, so it was removed.
plot_images([show_ref_image, show_target_image])
plot_keypoint_matches(match_keypoints, match_warped_keypoints)

In [None]:

# Project each matched reference keypoint with the ground-truth homography and
# measure (Euclidean, image coordinates) how far it lands from its matched keypoint.
gt_match_warped_kps = warp_points(match_keypoints, homo, "xy")
match_residuals = gt_match_warped_kps - match_warped_keypoints
dist = np.linalg.norm(match_residuals, axis=1)

In [None]:
# Split keypoint matches into correct (reprojection error <= threshold) and wrong ones,
# and report the mean matching accuracy (MMA).
KP_INLIER_THRESH = 3  # max reprojection error for a keypoint match to count as correct

kp_inlier_mask = dist <= KP_INLIER_THRESH
num_putative = match_keypoints.shape[0]
print("putative_match: ", num_putative)
inlier_match_keypoints = match_keypoints[kp_inlier_mask]
error_match_keypoints = match_keypoints[~kp_inlier_mask]
inlier_match_warped_keypoints = match_warped_keypoints[kp_inlier_mask]
error_match_warped_keypoints = match_warped_keypoints[~kp_inlier_mask]
print("inlier_match: ", inlier_match_keypoints.shape[0])
# MMA is undefined without matches; report NaN instead of raising ZeroDivisionError.
mma = inlier_match_keypoints.shape[0] / num_putative if num_putative else float("nan")
print("MMA: ", mma)
# Per-match connector colour for the plot below: lime = correct, red = wrong.
color = ["lime" if is_inlier else "red" for is_inlier in kp_inlier_mask]

In [None]:
# All detected keypoints in blue, then correct (lime) and wrong (red) matches on top,
# followed by connectors coloured per match.
plot_images([show_ref_image, show_target_image], size=1, dpi=600, pad=0.)

plot_keypoints([points1[0][:, :2], points2[0][:, :2]], ps=0.5, marker="P", colors="blue")

for kp_pair, kp_colour in (
    ([inlier_match_keypoints, inlier_match_warped_keypoints], "lime"),
    ([error_match_keypoints, error_match_warped_keypoints], "red"),
):
    plot_keypoints(kp_pair, ps=0.5, marker="P", colors=kp_colour)

plot_keypoint_matches(
    match_keypoints[:],
    match_warped_keypoints[:],
    color=color[:],
    lw=0.15,
    ps=0.0001,
    a=0.7,
)

## Match lines

In [None]:
# Line matcher used below to match line descriptors between the two images.
# NOTE(review): the meaning of the positional `True` is defined by WunschLineMatcher's
# constructor, which is not visible here. Confirm it and consider passing it by keyword.
matcher = WunschLineMatcher(True)

In [None]:
# Match line descriptors between the two images. `line_matches[i]` is the index of the
# target line matched to reference line i, or -1 when line i is unmatched.
line_desc1, line_desc2 = (out["batch_lines_desc"][0] for out in (outputs1, outputs2))
print(line_desc1.shape)
valid_points1, valid_points2 = (out["batch_valid_points"][0] for out in (outputs1, outputs2))
line_matches = matcher.match(
    *(torch.from_numpy(arr) for arr in (line_desc1, line_desc2, valid_points1, valid_points2))
)

# Keep only matched pairs, aligned row-by-row.
valid_matches = line_matches != -1
match_indices = line_matches[valid_matches]
print(match_indices)
matched_lines1, matched_lines2 = batch_lines1[valid_matches], batch_lines2[match_indices]

In [None]:
# Draw matched line pairs colour-coded per match; the colours used are returned in
# `colors`. The last axis is flipped before plotting, as for every line plot here.
plot_images([show_ref_image, show_target_image])
colors = plot_color_line_matches(
    [matched_lines1[..., ::-1], matched_lines2[..., ::-1]],
    lw=2,
    return_color=True,
)

In [None]:
# Verify line matches with the ground-truth homography: warp each matched reference line
# into the target image, drop matches whose warped line is rejected by the boundary clip,
# and call a match correct when the warped line is close to ITS OWN matched target line.
LINE_INLIER_THRESH = 5  # max orthogonal line distance for a line match to count as correct

warped_matched_lines1 = warp_lines(matched_lines1, homo)
warped_matched_lines1, mask = clip_line_to_boundaries(warped_matched_lines1, show_target_image.shape[:2], 5)
warped_matched_lines1 = warped_matched_lines1[mask]
# NOTE: overwrites the match arrays in place. Re-run the line-matching cell before
# re-running this one.
matched_lines2 = matched_lines2[mask]
matched_lines1 = matched_lines1[mask]
line_dist = get_line_distance(
            warped_matched_lines1, matched_lines2, "orth")
# line_dist is a pairwise [i, j] matrix (the previous `np.min(..., axis=1)` relied on that).
# Row-minimum accepted a match whenever the warped line was near ANY matched target line;
# the diagonal checks each pair against its actual partner matched_lines2[i].
idx = np.diagonal(line_dist) < LINE_INLIER_THRESH

In [None]:
import seaborn as sns
import matplotlib

def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1., save_file=None):
    """Draw connectors between matched lines on two existing image axes of the current figure.

    Each line is represented by one 2-D point (in this notebook the caller passes the mean
    of the line's endpoints). Connector i goes from kpts0[i] on axes ``indices[0]`` to
    kpts1[i] on axes ``indices[1]`` and is drawn in figure coordinates.

    Note: this definition shadows the ``plot_line_matches`` imported from
    likl.misc.visualize_utils for the rest of the notebook.

    Args:
        kpts0, kpts1: corresponding middle points of the lines, shape (N, 2), in the
            data coordinates of their respective axes.
        color: a single matplotlib colour (name or RGB(A) tuple) applied to every
            connector, or a sequence of N per-match colours. Random hues if None.
        lw: width of the connectors; nothing is drawn when lw <= 0.
        indices: indices into ``fig.axes`` of the two images to connect.
        a: alpha opacity of the connectors.
        save_file: if given, the figure is saved to this path.
    """
    fig = plt.gcf()
    axes = fig.axes
    assert len(axes) > max(indices)
    ax0, ax1 = axes[indices[0]], axes[indices[1]]
    # Draw once so the data -> display transforms reflect the final layout.
    fig.canvas.draw()

    assert len(kpts0) == len(kpts1)
    num_matches = len(kpts0)
    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(num_matches)).tolist()
    elif matplotlib.colors.is_color_like(color):
        # One colour for all matches. The old check only broadcast non-tuple/list values,
        # so a single RGB tuple was indexed per match (yielding scalar "colours").
        color = [color] * num_matches

    if lw > 0 and num_matches > 0:
        # Transform the points into the figure coordinate system.
        to_figure = fig.transFigure.inverted()
        fkpts0 = to_figure.transform(ax0.transData.transform(kpts0))
        fkpts1 = to_figure.transform(ax1.transData.transform(kpts1))
        for i in range(num_matches):
            fig.add_artist(matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]),
                zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw,
                alpha=a))

    # Freeze the axes so later autoscaling cannot shift them under the connectors.
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)
    if save_file:
        fig.savefig(save_file, bbox_inches='tight', pad_inches=0.0)


# Final figure: all detected lines in blue, then correct (lime) and wrong (red) line
# matches on top, plus connectors between the middle points of each matched pair.
plot_images([show_ref_image, show_target_image], size=1, dpi=600, pad=0.)

inlier_match_lines1, inlier_match_lines2 = matched_lines1[idx], matched_lines2[idx]
error_match_lines1, error_match_lines2 = matched_lines1[~idx], matched_lines2[~idx]

plot_lines([batch_lines1[..., ::-1], batch_lines2[..., ::-1]],
           line_colors="blue", point_colors="blue", ps=0.1, lw=0.2)

colors = ["lime" if idx[i] else "red" for i in range(len(matched_lines1))]

for line_pair, shade in (
    ((inlier_match_lines1, inlier_match_lines2), "lime"),
    ((error_match_lines1, error_match_lines2), "red"),
):
    plot_lines([line_pair[0][..., ::-1], line_pair[1][..., ::-1]],
               line_colors=shade, point_colors=shade, ps=0.1, lw=0.2)

plot_line_matches(
    matched_lines1[..., ::-1].mean(1),
    matched_lines2[..., ::-1].mean(1),
    color=colors,
    lw=0.2,
    a=0.8,
)