Clearer visualization of feature matches: a handful of matches drawn as colored lines between the query crop and the template (instead of the colorful grid).

In [None]:
import json
import os
from pathlib import Path
from PIL import Image
import gc
import time
from typing import Any
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils._pytree import tree_map
import torch_levenberg_marquardt as tlm
from torchmetrics import Metric
import roma
from tqdm import tqdm
import visu3d as v3d

sys.path.append(os.path.abspath(os.path.join("external", "bop_toolkit")))
sys.path.append(os.path.abspath(os.path.join("external", "dinov2")))
from bop_toolkit_lib import inout

from utils.misc import array_to_tensor, tensor_to_array, tensors_to_arrays
from utils import (
    corresp_util,
    eval_util,
    feature_util,
    knn_util,
    misc as misc_util,
    pnp_util,
    projector_util,
    repre_util,
    vis_util,
    renderer_builder,
    json_util, 
    logging,
    misc,
    structs,
    geometry,
)
from utils.structs import AlignedBox2f, PinholePlaneCameraModel, CameraModel
from utils.misc import warp_image
from utils.run.opts import CommonOpts, InferOpts

In [None]:
# Device on which the extractor and the object representation are placed.
device = "cuda:1"

# Scratch directory that receives the rendered visualizations.
tmpdir = Path("/scratch/jeyan/barreldata/tmp")

# Inputs: camera intrinsics (JSON), the query image and its mask image.
cam_json_path = Path("/scratch/jeyan/barreldata/divedata/barrelddt1/camera.json")
imgpath = Path("/scratch/jeyan/barreldata/divedata/barrelddt1/rgb/cropped0000.png")
maskpath = Path("/scratch/jeyan/barreldata/results/barrelddt1/sam-masks/cropped0000.png")

# Build the feature extractor described by `extractor_name` and move it to
# `device`. The name string carries the whole configuration (model version,
# stride, facet, layer, ...); how it is parsed is up to
# feature_util.make_feature_extractor.
# NOTE(review): stride=14 in the name presumably has to agree with
# grid_cell_size used for the query grid further below -- confirm.
extractor_name = "dinov2_version=vitl14-reg_stride=14_facet=token_layer=18_logbin=0_norm=1"
extractor = feature_util.make_feature_extractor(extractor_name)
extractor.to(device)

# Load the pre-built object representation onto `device`, and keep a NumPy
# copy of it for the visualization cells.
repre_dir = Path("/scratch/jeyan/barreldata/results/barrelddt1/foundpose-output/object_repre")
repre = repre_util.load_object_repre(repre_dir=repre_dir, tensor_device=device)
repre_np = repre_util.convert_object_repre_to_numpy(repre)

# kNN index over the feature cluster centroids (the "visual words"),
# configured from the options stored with the representation.
tfidf_opts = repre.template_desc_opts
visual_words_knn_index = knn_util.KNN(
    k=tfidf_opts.tfidf_knn_k,
    metric=tfidf_opts.tfidf_knn_metric,
)
visual_words_knn_index.fit(repre.feat_cluster_centroids)

# One 1-nearest-neighbour (L2) index per template, fitted only on the feature
# vectors that belong to that template.
template_knn_indices = []
for template_id in range(len(repre.template_cameras_cam_from_model)):
    # Indices of the features assigned to this template.
    tpl_feat_ids = torch.nonzero(repre.feat_to_template_ids == template_id).flatten()

    template_knn_index = knn_util.KNN(k=1, metric="l2")
    template_knn_index.fit(repre.feat_vectors[tpl_feat_ids].cpu())
    template_knn_indices.append(template_knn_index)

# Camera parameters: read the intrinsics from the JSON file.
# No camera transform is passed to the model here (open question from the
# original author: the transform comes from GT, can it stay at identity?).
camjson = json.loads(cam_json_path.read_text())
focal_xy = (camjson["fx"], camjson["fy"])
principal_xy = (camjson["cx"], camjson["cy"])
orig_camera_c2w = PinholePlaneCameraModel(
    camjson["width"], camjson["height"], focal_xy, principal_xy
)

# (width, height) of the original image as reported by the camera model.
orig_image_size = (orig_camera_c2w.width, orig_camera_c2w.height)

# Size of the crop fed to the extractor and the relative padding used when
# the crop camera is constructed.
crop_size = (420, 420)
crop_rel_pad = 0.2

# Grid of points at which the feature map is sampled; the grid spans the
# whole crop and is generated directly on `device`'s side via `.to(device)`.
grid_cell_size = 14
grid_size = crop_size
grid_points = feature_util.generate_grid_points(
    grid_size=grid_size, cell_size=grid_cell_size
).to(device)

# Estimate pose for each object instance.
# Get the input image (pixel values divided by 255.0).
orig_image_np_hwc = np.array(Image.open(imgpath)) / 255.0

# Get the modal mask and amodal bounding box of the instance.
# The mask is read as a single-channel image and divided by 255.0.
orig_mask_modal = np.array(Image.open(maskpath).convert("L")) / 255.0

# Per-column and per-row sums: an entry is > 0 iff that column / row contains
# at least one foreground pixel.
sumvert = np.sum(orig_mask_modal, axis=0)
sumhor = np.sum(orig_mask_modal, axis=1)
mask_cols = np.where(sumvert > 0)[0]
mask_rows = np.where(sumhor > 0)[0]
if mask_cols.size == 0:
    # Previously this surfaced as an opaque IndexError on the [0] lookup.
    raise ValueError(f"Mask {maskpath} contains no foreground pixels.")

left, right = mask_cols[0], mask_cols[-1]
# Image rows grow downwards, so the first foreground row is the TOP edge and
# the last one the BOTTOM edge. The original code had the two swapped, which
# disagrees with the box recomputed after warping (top=box[1], bottom=box[3]).
top, bottom = mask_rows[0], mask_rows[-1]

# Bounding box of the mask (edges are inclusive pixel indices).
orig_box_amodal = AlignedBox2f(
    left=left,
    top=top,
    right=right,
    bottom=bottom,
)

# Optional cropping.
# Square crop box derived from the instance bounding box.
crop_box = misc_util.calc_crop_box(box=orig_box_amodal, make_square=True)

# Virtual camera focused on the crop, rendering into a crop_size viewport.
crop_camera_model_c2w = misc_util.construct_crop_camera(
    box=crop_box,
    camera_model_c2w=orig_camera_c2w,
    viewport_size=crop_size,
    viewport_rel_pad=crop_rel_pad,
)

# INTER_AREA when the crop box is at least as wide as the target viewport,
# INTER_LINEAR otherwise.
if crop_box.width >= crop_camera_model_c2w.width:
    interpolation = cv2.INTER_AREA
else:
    interpolation = cv2.INTER_LINEAR

# Both warps map from the original camera to the virtual crop camera.
warp_cameras = {
    "src_camera": orig_camera_c2w,
    "dst_camera": crop_camera_model_c2w,
}
image_np_hwc = warp_image(
    src_image=orig_image_np_hwc, interpolation=interpolation, **warp_cameras
)
# Nearest-neighbour for the mask so no intermediate values are introduced.
mask_modal = warp_image(
    src_image=orig_mask_modal, interpolation=cv2.INTER_NEAREST, **warp_cameras
)

# The mask was warped into the virtual camera, so the object's bounding box
# has to be recomputed from the warped mask's non-zero pixels.
ys, xs = mask_modal.nonzero()
box = np.array(misc_util.calc_2d_box(xs, ys))
box_left, box_top, box_right, box_bottom = box[:4]
box_amodal = AlignedBox2f(
    left=box_left, top=box_top, right=box_right, bottom=box_bottom
)

# From here on the virtual crop camera is the main camera.
camera_c2w = crop_camera_model_c2w


# Extract feature map from the crop.
# HWC array -> float32 CHW tensor on `device`, then add a batch dimension.
image_tensor_chw = array_to_tensor(image_np_hwc).to(torch.float32).permute(2, 0, 1).to(device)
image_tensor_bchw = image_tensor_chw.unsqueeze(0)
# BxDxHxW
# Pure inference: no gradients are needed, so make sure no autograd graph is
# built in case the extractor does not disable gradient tracking itself.
with torch.no_grad():
    extractor_output = extractor(image_tensor_bchw)
# Feature map of the single image in the batch.
feature_map_chw = extractor_output["feature_maps"][0]


# Keep only points inside the object mask.
mask_modal_tensor = array_to_tensor(mask_modal).to(device)
query_points = feature_util.filter_points_by_mask(
    grid_points, mask_modal_tensor
)

# Subsample query points if we have too many: keep a uniformly random subset
# of at most max_num_queries points.
max_num_queries = 1000000
if query_points.shape[0] > max_num_queries:
    perm = torch.randperm(query_points.shape[0])
    query_points = query_points[perm[: max_num_queries]]
    msg = (
        "Randomly subsampled queries "
        f"({perm.shape[0]} -> {query_points.shape[0]})"
    )
    # The message used to be built and then dropped; report it.
    print(msg)

# Sample the feature map at the selected query points; the result has shape
# (num_points, feat_dims). The image size is passed as (width, height).
crop_h, crop_w = image_np_hwc.shape[:2]
query_features = feature_util.sample_feature_map_at_points(
    feature_map_chw=feature_map_chw,
    points=query_points,
    image_size=(crop_w, crop_h),
)
query_features = query_features.contiguous()

# Potentially project features to a PCA space.
# Projection is applied only when the raw feature dimensionality differs from
# that of the stored object features and projectors are available.
if (
    query_features.shape[1] != repre.feat_vectors.shape[1]
    and len(repre.feat_raw_projectors) != 0
):
    query_features_proj = projector_util.project_features(
        feat_vectors=query_features,
        projectors=repre.feat_raw_projectors,
    ).contiguous()

    # Project the whole feature map as well: flatten CxHxW to (H*W, C),
    # project, then restore the (C', H, W) layout.
    # `reshape` is used instead of `view` because the permuted tensor is not
    # guaranteed to be contiguous; `view` raises in that case, while `reshape`
    # returns the same values (as a view when possible, a copy otherwise).
    _c, _h, _w = feature_map_chw.shape
    feature_map_chw_proj = (
        projector_util.project_features(
            feat_vectors=feature_map_chw.permute(1, 2, 0).reshape(-1, _c),
            projectors=repre.feat_raw_projectors,
        )
        .reshape(_h, _w, -1)
        .permute(2, 0, 1)
    )
else:
    # Dimensions already match (or there is nothing to project with).
    query_features_proj = query_features
    feature_map_chw_proj = feature_map_chw


# Establish 2D-3D correspondences between the query points and the object
# representation.
# Matching configuration: tf-idf template retrieval (5 best templates) and
# cyclic-buddies feature matching (300 best buddies), with debug enabled.
matching_opts = {
    "template_matching_type": "tfidf",
    "feat_matching_type": "cyclic_buddies",
    "top_n_templates": 5,
    "top_k_buddies": 300,
    "debug": True,
}
allcorresp = corresp_util.establish_correspondences(
    query_points=query_points,
    query_features=query_features_proj,
    object_repre=repre,
    template_knn_indices=template_knn_indices,
    visual_words_knn_index=visual_words_knn_index,
    **matching_opts,
)

In [None]:
# Work with the first correspondence set only, converted to NumPy arrays.
corresp = tensors_to_arrays(allcorresp[0])
template_id = corresp["template_id"]
object_repre = repre_np
corresp_top_n = 100

# Left side (RGB crop): 2D points of the corresp_top_n correspondences with
# the highest confidence, in descending confidence order.
conf_order_desc = corresp["coord_conf"].argsort()[::-1]
selected_ids = conf_order_desc[:corresp_top_n]
kpts_left = corresp["coord_2d"][selected_ids]

# Right side (template image): take the matched model vertices, transform
# them with the inverse of the template camera's T_world_from_eye, and
# project them into the template image.
tpl_cameras_m2c = object_repre.template_cameras_cam_from_model[template_id]
all_tpl_vertex_ids = corresp["nn_vertex_ids"]
T_eye_from_world = np.linalg.inv(tpl_cameras_m2c.T_world_from_eye)
all_tpl_vertices_in_c = geometry.transform_3d_points_numpy(
    T_eye_from_world, object_repre.vertices[all_tpl_vertex_ids]
)
all_kpts_right = tpl_cameras_m2c.eye_to_window(all_tpl_vertices_in_c)
# Same selection as on the left so the two point sets stay paired.
kpts_right = all_kpts_right[selected_ids]

In [None]:
# Cast both keypoint sets to integer pixel coordinates for drawing
# (astype(int) truncates toward zero).
kpts_left = kpts_left.astype(int)
kpts_right = kpts_right.astype(int)

In [None]:
# Left image: the crop, scaled by 255 and stored as uint8.
scaled_crop = image_np_hwc * 255
leftimg = scaled_crop.astype(np.uint8)

# Right image: the matched template with its first axis moved to the end
# (axes reordered as (1, 2, 0)).
template_img = repre_np.templates[template_id]
rightimg = np.transpose(template_img, (1, 2, 0))

In [None]:
# Fixed palette of ten 3-component colors used to draw the matches. A fixed
# list keeps the figures reproducible between runs. (The random palette that
# used to be generated first was overwritten immediately and only advanced
# NumPy's global RNG state, so it has been removed.)
colors = [
    [176, 67, 0],
    [107, 23, 135],
    [24, 21, 179],
    [19, 89, 11],
    [82, 56, 9],
    [22, 110, 99],
    [166, 41, 153],
    [135, 110, 36],
    [108, 128, 46],
    [163, 128, 59],
]

In [None]:
# Side-by-side canvas: query crop on the left, template on the right.
both = np.concatenate((leftimg, rightimg), axis=1)
# Template keypoints are shifted right by the width of the left image.
right_x_offset = leftimg.shape[1]
# Draw every 20th of the selected correspondences to keep the figure readable.
match_stride = 20
white = [255, 255, 255]
for i, (kptleft, kptright) in enumerate(zip(kpts_left[::match_stride], kpts_right[::match_stride])):
    # Cycle through the palette so more matches than colors cannot raise
    # an IndexError.
    color = colors[i % len(colors)]
    # Plain Python ints: OpenCV point arguments are safest as built-in ints.
    pt_left = (int(kptleft[0]), int(kptleft[1]))
    pt_right = (int(kptright[0]) + right_x_offset, int(kptright[1]))
    # Each primitive is drawn twice: a slightly larger white one first, then
    # the colored one on top, which gives a white outline.
    both = cv2.line(both, pt_left, pt_right, white, 5)
    both = cv2.line(both, pt_left, pt_right, color, 3)
    for pt in (pt_left, pt_right):
        both = cv2.circle(both, pt, 8, white, -1)
        both = cv2.circle(both, pt, 6, color, -1)
tmp = Image.fromarray(both)
# Make sure the output directory exists before saving.
tmpdir.mkdir(parents=True, exist_ok=True)
tmp.save(tmpdir / "corresp_dino.png")
# Last expression of the cell: displays the image in the notebook.
tmp

In [None]:
# Baseline for comparison: SIFT keypoints matched by brute force between the
# query crop (left) and the template (right).
# Initiate SIFT detector
sift = cv2.SIFT_create(nfeatures=60, contrastThreshold=0.01, edgeThreshold=None)

# Restrict detection to the object: the warped modal mask on the left, and
# every pixel whose grayscale value is non-zero on the right.
left_mask = (mask_modal * 255).astype(np.uint8)
right_mask = ((cv2.cvtColor(rightimg, cv2.COLOR_RGB2GRAY) > 0) * 255).astype(np.uint8)

# find the keypoints and descriptors with SIFT
kp1, des1 = sift.detectAndCompute(leftimg, mask=left_mask)
kp2, des2 = sift.detectAndCompute(rightimg, mask=right_mask)

# BFMatcher with default params
bf = cv2.BFMatcher()
if des1 is None or des2 is None:
    # No descriptors on one side (presumably because no keypoint was found
    # there): there is nothing to match.
    matches = []
else:
    matches = bf.knnMatch(des1, des2, k=2)

# Apply ratio test: keep a match only if the best neighbour is clearly closer
# than the second best. `good` holds one-element lists, the format expected
# by cv2.drawMatchesKnn.
good = []
for pair in matches:
    if len(pair) < 2:
        # Fewer than two neighbours returned; the ratio test is undefined.
        continue
    m, n = pair
    if m.distance < 0.95 * n.distance:
        good.append([m])

# Side-by-side canvas: query crop on the left, template on the right.
both = np.concatenate((leftimg, rightimg), axis=1)
right_x_offset = leftimg.shape[1]
white = [255, 255, 255]
# Only a small hand-picked slice of the good matches is drawn.
for i, goodmatch in enumerate(good[7:12]):
    # Cycle through the palette so the index can never run past its end.
    color = colors[i % len(colors)]
    kptleft = kp1[goodmatch[0].queryIdx].pt
    kptleft = (int(kptleft[0]), int(kptleft[1]))
    kptright = kp2[goodmatch[0].trainIdx].pt
    kptright = (int(kptright[0]), int(kptright[1]))
    pt_left = kptleft
    # Template keypoints are shifted right by the width of the left image.
    pt_right = (kptright[0] + right_x_offset, kptright[1])
    # White primitive first, colored one on top, to get a white outline.
    both = cv2.line(both, pt_left, pt_right, white, 5)
    both = cv2.line(both, pt_left, pt_right, color, 3)
    for pt in (pt_left, pt_right):
        both = cv2.circle(both, pt, 8, white, -1)
        both = cv2.circle(both, pt, 6, color, -1)
tmp = Image.fromarray(both)
# Make sure the output directory exists before saving.
tmpdir.mkdir(parents=True, exist_ok=True)
tmp.save(tmpdir / "corresp_sift.png")
# Last expression of the cell: displays the image in the notebook.
tmp