In [5]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
# Route HTTPS traffic through a local proxy, but only as a fallback: if the
# environment already configures an HTTPS proxy (either spelling), leave it alone
# instead of overwriting it with this machine-specific address.
# NOTE(review): 127.0.0.1:7890 is a developer-local endpoint -- confirm it is still wanted.
if not (os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY')):
    os.environ['https_proxy'] = 'http://127.0.0.1:7890'

from env.envs.hookScene import HookScene
from env.envs.mugScene import MugScene
from env.envs.spoonScene import SpoonScene
from scripts.hooking import get_reference as get_hooking_reference
from scripts.hanging import get_reference as get_hanging_reference
from scripts.scooping import get_reference as get_scooping_reference
from magic.match import match

In [6]:
# --- Experiment configuration -------------------------------------------------
# Task to run; the cells below branch on 'hooking', 'hanging' and 'scooping'.
task_name = 'hooking'

# Object used as the tool; the mug and spoon ids are aliases of the same value.
tool_name = 'scissors'
mug_id = spoon_id = tool_name

# Feature-backend switches passed through to `match` as keyword arguments.
sd_dino, dift = False, True

In [None]:
# Load the reference data for the selected task. Each task's `get_reference`
# returns a differently shaped tuple, so the unpacking is task-specific.
if task_name == 'scooping':
    # Plane handed to the scooping reference loader.
    plane_normal = np.array([0, -1, 0])
    plane_origin = np.array([0, -0.35, 0])
    (
        reference_img,
        reference_contact_center,
        reference_collide_center,
        reference_grasp_center,
        reference_pixel_to_3d_fn,
        reference_init_pose,
        reference_pcd_o3d,
    ) = get_scooping_reference(
        plane_normal=plane_normal,
        plane_origin=plane_origin,
        manual_center=False,
    )
elif task_name == 'hanging':
    (
        reference_img,
        reference_depth_img,
        reference_camera,
        reference_contact_center,
        reference_scene,
    ) = get_hanging_reference()
    # The hanging reference yields no grasp center; define it explicitly as None.
    reference_grasp_center = None
elif task_name == 'hooking':
    (
        reference_img,
        reference_depth_img,
        reference_camera,
        reference_contact_center,
        reference_grasp_center,
        reference_scene,
        reference_init_pose,
        reference_pose,
    ) = get_hooking_reference()
else:
    raise ValueError('Invalid task name')

In [None]:
# Build the target scene for the selected task and capture the image that will be
# matched against the reference image.
if task_name == 'hooking':
    target_scene = HookScene(
        tool_name,
        'box',
    )

    # Hide the environment visuals for the capture and always restore them, even
    # if the capture raises, so the scene is not left in the hidden state.
    target_scene.hide_env_visual()
    try:
        target_img, target_depth_img, target_camera = target_scene.get_picture(
            direction='+z',
            # NOTE(review): 1.414 looks like an approximation of sqrt(2) -- confirm the intended camera offset.
            additional_translation=np.array([0.6, -0.6, 0.2-0.375*1.414]),
            debug_viewer=False,
            get_depth=True
        )
    finally:
        target_scene.unhide_env_visual()
elif task_name == 'hanging':
    target_scene = MugScene(
        mug_id,
        add_robot=True,
        fps=480
    )

    target_scene.hide_env_visual()
    try:
        target_img, target_depth_img, target_camera = target_scene.get_picture(direction='+x', debug_viewer=False, get_depth=True)
    finally:
        target_scene.unhide_env_visual()
elif task_name == 'scooping':
    # Same plane values as the ones used when loading the scooping reference.
    plane_origin = np.array([0, -0.35, 0])
    plane_normal = np.array([0, -1, 0])
    target_scene = SpoonScene(
        spoon_id,
        fps=480,
        add_robot=True,
        radius=0.035
    )

    target_img, target_pixel_to_3d_fn = target_scene.get_slice(plane_origin, plane_normal)
else:
    # Previously any unrecognised task silently fell through to the scooping branch;
    # fail loudly instead, matching the reference-loading cell.
    raise ValueError('Invalid task name')

# Run the matcher in feature-only mode on the reference/target pair, using the
# backend switches chosen in the configuration cell.
results, resized_imgs, downsampled_images = match(
    reference_img,
    target_img,
    reference_contact_center,
    reference_grasp_center,
    only_compute_dino_feature=True,
    sd_dino=sd_dino,
    dift=dift,
)
# Move axis 1 to the end so the last axis can be indexed/summed as the feature vector.
# NOTE(review): presumably (N, C, H, W) -> (N, H, W, C) -- confirm against `match`.
results = results.permute(0, 2, 3, 1)

In [None]:
import matplotlib 
matplotlib.use('TKAgg')
%matplotlib inline

fig, axes = plt.subplots(3, 5, figsize=(2 * resized_imgs[0].size[0] / 100, resized_imgs[0].size[1] / 100), dpi=100)
for i in range(3):
    for j in range(5):
        axes[i][j].axis('off')

axes[0][0].imshow(resized_imgs[0])
source_point = reference_contact_center
resized_source_point = np.array(source_point) / reference_img.size[1] * resized_imgs[0].size[0]
axes[0][0].scatter(resized_source_point[0], resized_source_point[1], c='r', s=50)


import cv2
from tqdm import tqdm 

for rotation_index in tqdm(range(0, 12)):
    downsampled_source_point = (np.array(source_point) / reference_img.size[1] * downsampled_images[0].size[1]).astype(int)
    source_feature = results[0][downsampled_source_point[1], downsampled_source_point[0]]

    target_feature = results[rotation_index+1]
    heatmap = torch.sum(target_feature * source_feature, dim=-1).cpu().numpy()

    # resize heatmap to target image size
    heatmap = cv2.resize(heatmap, (resized_imgs[0].size[0], resized_imgs[0].size[1]), interpolation=cv2.INTER_LINEAR)
    # gaussian filter
    heatmap = cv2.GaussianBlur(heatmap, (25, 25), 0)
    heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))

    # boost the heatmap where the feature is high for visualization
    heatmap = np.power(heatmap, 10)

    # get the coordinates of the maximum value in the heatmap
    max_index = np.unravel_index(heatmap.argmax(), heatmap.shape)
    target_point = max_index

    # set the size of the canvas to be the same as the img
    i = rotation_index // 4
    j = rotation_index % 4 + 1
    axes[i][j].imshow(resized_imgs[rotation_index+1])
    axes[i][j].imshow(heatmap, alpha=0.8)
    axes[i][j].scatter(target_point[1], target_point[0], c='r', s=50, marker='*')

plt.show()