In [None]:
# Copyright (c) 2025 Robert Bosch GmbH
# SPDX-License-Identifier: AGPL-3.0

"""
This is a demo script for the XRefine model, which refines keypoints detected by XFeat on two example images.
"""

import os

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torchvision.io as io
import torchvision.transforms.functional

from dataprocess.data_utils import nearest_neighbor_match

# --- Demo configuration -------------------------------------------------------
# Passed as `top_k` to the keypoint detector for each image.
num_keypoints_to_detect = 2048
# Number of matched keypoint pairs sampled for the figures.
num_keypoints_to_visualize = 7
# Half of the side length (in pixels) of the boxes drawn / cropped around a keypoint.
visu_half_patch_size = 10
img1_path = "example_data/bench_a.png"
img2_path = "example_data/bench_b.png"
figures_dir = "./output"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The figures are written here, so the folder has to exist up front.
os.makedirs(figures_dir, exist_ok=True)

# Read both images as float tensors divided by 255.0, then add a leading batch dimension
img1_tensor = (io.read_image(img1_path).float() / 255.0).unsqueeze(0)
img2_tensor = (io.read_image(img2_path).float() / 255.0).unsqueeze(0)

# PIL versions of the images (batch dimension dropped again) for the plots below
img1_pil = torchvision.transforms.functional.to_pil_image(img1_tensor[0])
img2_pil = torchvision.transforms.functional.to_pil_image(img2_tensor[0])

In [None]:
# Keypoint detection (XFeat) followed by nearest-neighbor descriptor matching

detector = torch.hub.load(
    "verlab/accelerated_features",
    "XFeat",
    pretrained=True,
    top_k=num_keypoints_to_detect,
    trust_repo=True,
)
detector = detector.to(device).eval()

# Run the detector on image 1 first, then on image 2; entry [0] of each result is used
output1, output2 = (
    detector.detectAndCompute(image_batch, top_k=num_keypoints_to_detect)[0]
    for image_batch in (img1_tensor, img2_tensor)
)
keypoints1, descriptors1 = output1["keypoints"], output1["descriptors"]
keypoints2, descriptors2 = output2["keypoints"], output2["descriptors"]

# The matcher is given the descriptors with an extra leading dimension
gf_matching_output = nearest_neighbor_match(descriptors1[None], descriptors2[None])
matches = gf_matching_output["matches0"][0]

# An entry of -1 marks a keypoint of image 1 without a partner; keep only matched
# pairs so that keypoints1[i] and keypoints2[i] correspond to each other.
matches_mask = matches != -1
keypoints1 = keypoints1[matches_mask]
keypoints2 = keypoints2[matches[matches_mask]]

In [None]:
# Perform keypoint refinement

# Move the images to the device the refinement model runs on.
img1_tensor = img1_tensor.to(device)
img2_tensor = img2_tensor.to(device)

xrefine = (
    torch.hub.load("boschresearch/xrefine", "XRefine",
                   pretrained=True,
                   detector="general",
                   variant="small",
                   adjust_only_second_keypoint=False,
                   image_values_are_normalized=True,
                   trust_repo=True,
                   )
    .to(device)
    .eval()
)

# This is pure inference: disable autograd so that no computation graph is built and
# the refined keypoints do not require grad (a tensor that requires grad cannot be
# converted with `.numpy()` later on).
with torch.no_grad():
    # The model receives the matched keypoints and the images without their batch dimension.
    refined_keypoints1, refined_keypoints2 = xrefine(keypoints1, keypoints2, img1_tensor[0], img2_tensor[0])

In [None]:
# Randomly select a few matching keypoints

# Fixed seed so that the same pairs are picked on every run.
np.random.seed(42)
num_matches = min(len(keypoints1), len(keypoints2))
# Sampling without replacement raises a ValueError when more samples are requested
# than there are matches, so never ask for more than what is available.
num_selected = min(num_keypoints_to_visualize, num_matches)
indices = np.random.choice(num_matches, num_selected, replace=False)

# Pick the sampled pairs and convert them to NumPy arrays for plotting.
# `.detach()` keeps the conversion working even if a tensor is still attached to an
# autograd graph.
keypoints1_selected = keypoints1[indices].detach().cpu().numpy()
keypoints2_selected = keypoints2[indices].detach().cpu().numpy()
refined_keypoints1_selected = refined_keypoints1[indices].detach().cpu().numpy()
refined_keypoints2_selected = refined_keypoints2[indices].detach().cpu().numpy()

In [None]:
# Draw the selected patches on their respective images

def draw_patches_on_image(img_pil: PIL.Image, keypoints: np.ndarray, figures_dir: str, img_name: str):
    """Mark keypoints on an image and save the result as a PNG.

    Every keypoint gets a filled red dot and a red square outline with a side length
    of ``2 * visu_half_patch_size`` centered on it. The figure is written to
    ``<figures_dir>/<img_name>_correspondences.png`` and closed afterwards.

    Args:
        img_pil: Image shown as the background.
        keypoints: Keypoints to mark; entries 0 and 1 of each row are used as x and y.
        figures_dir: Directory the figure is written to.
        img_name: Prefix of the output file name.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(img_pil)

    side_length = 2 * visu_half_patch_size

    # Filled red dots at the keypoint locations
    dots = [plt.Circle((kp[0], kp[1]), radius=3, color="red", fill=True) for kp in keypoints]

    # Red square outlines centered on the keypoint locations
    boxes = [
        plt.Rectangle(
            (kp[0] - visu_half_patch_size, kp[1] - visu_half_patch_size),
            side_length,
            side_length,
            linewidth=3,
            edgecolor="red",
            facecolor="none",
        )
        for kp in keypoints
    ]

    # All dots are added before the boxes
    for artist in dots + boxes:
        ax.add_patch(artist)
    ax.axis("off")

    # Write the figure into `figures_dir`
    file_name = img_name + "_correspondences.png"
    plt.savefig(os.path.join(figures_dir, file_name), bbox_inches="tight", pad_inches=0)
    plt.close(fig)


# One overview figure per image, each showing the keypoints selected for that image
draw_patches_on_image(
    img_pil=img1_pil,
    keypoints=keypoints1_selected,
    figures_dir=figures_dir,
    img_name="img1",
)
draw_patches_on_image(
    img_pil=img2_pil,
    keypoints=keypoints2_selected,
    figures_dir=figures_dir,
    img_name="img2",
)

In [None]:
# Visualize the original keypoints and the refined keypoints on their corresponding image patches

def visualize_patches(
    img_pil: PIL.Image, keypoints: np.ndarray, refined_keypoints: np.ndarray, figures_dir: str, img_name: str
) -> None:
    """Save one close-up figure per keypoint showing its original and refined location.

    For every pair of original and refined keypoint, a patch of up to
    ``2 * visu_half_patch_size`` pixels per side is cropped around the original
    keypoint (clamped to the image bounds). The original keypoint is drawn in red,
    the refined one in yellow, and a dotted red 11x11 square is centered on the
    original keypoint. Each figure is written to
    ``<figures_dir>/<img_name>_patch_<i>.png`` and closed afterwards.

    Args:
        img_pil: Image the patches are cropped from.
        keypoints: Original keypoints; entries 0 and 1 of each row are used as x and y.
        refined_keypoints: Refined keypoints, in the same order as ``keypoints``.
        figures_dir: Directory the figures are written to.
        img_name: Prefix of the output file names.
    """
    for i, (kpt, refined_kpt) in enumerate(zip(keypoints, refined_keypoints)):
        # Cropping box around the original keypoint, clamped to the image boundaries
        x_min = max(0, int(kpt[0] - visu_half_patch_size))
        y_min = max(0, int(kpt[1] - visu_half_patch_size))
        x_max = min(img_pil.width, int(kpt[0] + visu_half_patch_size))
        y_max = min(img_pil.height, int(kpt[1] + visu_half_patch_size))

        # Crop the patch
        cropped_patch = img_pil.crop((x_min, y_min, x_max, y_max))

        # Keypoint positions in patch coordinates. They are derived from the actual
        # crop origin: near the left/top image border the box is clamped (and int()
        # truncates the box corner), so the original keypoint does not sit at
        # (visu_half_patch_size, visu_half_patch_size) in general.
        center_x = float(kpt[0]) - x_min
        center_y = float(kpt[1]) - y_min
        refined_x = float(refined_kpt[0]) - x_min
        refined_y = float(refined_kpt[1]) - y_min

        fig, ax = plt.subplots()
        ax.imshow(cropped_patch)
        ax.scatter([center_x], [center_y], c="red", s=165, label="Original Point")
        ax.scatter([refined_x], [refined_y], c="yellow", s=165, label="Updated Point")

        # Dotted red 11x11 square centered on the original point
        rect = plt.Rectangle(
            (center_x - 5.5, center_y - 5.5),
            11,
            11,
            linewidth=6,
            edgecolor="red",
            facecolor="none",
            linestyle="dotted",
        )
        ax.add_patch(rect)

        legend = ax.legend()
        legend.get_frame().set_facecolor("lightgray")
        for text in legend.get_texts():
            text.set_fontsize(20)  # Increase the font size

        ax.axis("off")

        # Save and close this specific figure rather than whatever pyplot considers current
        fig_path = os.path.join(figures_dir, img_name + f"_patch_{i}.png")
        fig.savefig(fig_path, bbox_inches="tight", pad_inches=0)
        plt.close(fig)


# Close-up figures for the selected keypoints of image 1, then of image 2
visualize_patches(
    img_pil=img1_pil,
    keypoints=keypoints1_selected,
    refined_keypoints=refined_keypoints1_selected,
    figures_dir=figures_dir,
    img_name="img1",
)
visualize_patches(
    img_pil=img2_pil,
    keypoints=keypoints2_selected,
    refined_keypoints=refined_keypoints2_selected,
    figures_dir=figures_dir,
    img_name="img2",
)