In [None]:
# Copyright 2023 DeepMind Technologies Limited
# Copyright 2023 Massachusetts Institute of Technology (M.I.T.)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Estimating 6D object poses of YCB objects from arbitrary RGB-D images using the 3D Neural Embedding Likelihood (3DNEL)

### Introduction

This notebook illustrates how to do pose estimation on arbitrary RGB-D images containing YCB objects using 3DNEL, introduced in the ICCV 2023 paper [3D Neural Embedding Likelihood: Probabilistic Inverse Graphics for Robust 6D Pose Estimation](https://arxiv.org/abs/2302.03744).

In [None]:
import os

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import taichi as ti
import torch
from threednel.bop.data import BOPTestDataset, RGBDImage
from threednel.bop.detector import Detector

### Setup
3DNEL makes effective use of GPUs through a combination of JAX, taichi and PyTorch for parallel likelihood evaluation.

We first initialize taichi on the CUDA backend, and get the data directory from the environment variable. We then initialize a detector, which we can use to estimate the 6D object poses of YCB objects in arbitrary RGB-D images.

In [None]:
# Resolve the data directory first, so a missing environment variable is
# reported with an actionable message before any GPU initialization happens.
data_directory = os.environ.get("BOP_DATA_DIR")
if data_directory is None:
  # Keep raising KeyError (what os.environ[...] would raise), but explain
  # how to fix it instead of only echoing the variable name.
  raise KeyError(
      "The BOP_DATA_DIR environment variable is not set; set it to the data "
      "directory used by this notebook before running it."
  )
# Initialize Taichi with CUDA backend
ti.init(arch=ti.cuda)

In [None]:
# Gather the detector settings in one place, then build the detector that the
# later cells use to do pose estimation on RGB-D images with YCB objects.
# NOTE(review): the three n_passes_* values presumably control how many passes
# each inference stage runs -- confirm against Detector.
detector_settings = dict(
    data_directory=data_directory,
    n_passes_pose_hypotheses=1,
    n_passes_icp=1,
    n_passes_finetune=1,
)
detector = Detector(**detector_settings)

### Initialize an RGBDImage object using an input RGB-D image.
Our detector takes an RGBDImage object as input. Here we use an image from the [YCB-V dataset as part of the BOP Challenge](https://bop.felk.cvut.cz/datasets/) as an example. However, in general 3DNEL can be applied to arbitrary RGB-D images to estimate the 6D poses of the YCB objects in the scene.

In [None]:
# Construct the BOPTestDataset object from the data directory.
# NOTE(review): load_detector_crops=True presumably loads precomputed 2D
# detection crops alongside the images -- confirm against BOPTestDataset.
data = BOPTestDataset(
    data_directory=data_directory,
    load_detector_crops=True,
)
# Select scene 48, then take the entry at position 1 of its image indices.
# Assuming img_indices is a regular 0-based sequence, this is the SECOND
# listed image of the scene, not the first.
# NOTE(review): the previous comment said "1st image"; if the first image was
# intended, the index below should be 0 -- confirm before changing.
scene_id = 48
test_scene = data[scene_id]
img_id = test_scene.img_indices[1]
bop_img = test_scene[img_id]

Next we construct an RGBDImage object as input to our detector.
An RGBDImage object can be constructed from an RGB-D image and known camera intrinsics.
Pose estimation with 3DNEL assumes knowledge of the number of objects and object classes in the scene. These are specified using `bop_obj_indices`, which is an array with elements ranging from 1 to 21. Refer to the [YCB object models](https://bop.felk.cvut.cz/media/data/bop_datasets/ycbv_models.zip) for the object indices of different supported objects.
Empirically we find that filling in missing values in the depth map helps with performance.
3DNEL can optionally take 2D detection results as part of the annotations to help with pose hypotheses generation, although it works even without initial 2D detections.

In [None]:
# Wrap the BOP test dataset image as the RGBDImage object the detector takes.
# The object indices are converted to a NumPy array up front.
obj_indices = np.array(bop_img.bop_obj_indices)
# Depth-filling options: fill in missing values in the depth map, using
# max_depth as part of that filling.
depth_fill_options = dict(fill_in_depth=True, max_depth=1260.0)
test_img = RGBDImage(
    rgb=bop_img.rgb,
    depth=bop_img.depth,
    intrinsics=bop_img.intrinsics,
    bop_obj_indices=obj_indices,
    annotations=bop_img.annotations,
    **depth_fill_options,
)

### Estimating 6D object poses and visualizing results

In [None]:
# The detector has a simple interface for estimating 6D object poses.
# Use an explicit seed so the random key is identical on every
# Restart & Run All; change SEED to draw a different key.
SEED = 0
scale_factor = 0.25
detection_results = detector.detect(
    img=test_img,
    key=jax.random.PRNGKey(SEED),
    scale_factor=scale_factor,
)

In [None]:
# We can visualize the query embedding maps for different objects.
n_objects = len(test_img.bop_obj_indices)
# squeeze=False keeps `ax` a 2-D array of shape (1, n_objects) even when there
# is a single object; without it plt.subplots returns a bare Axes for a 1x1
# grid and indexing it would fail.
fig, ax = plt.subplots(
    1,
    n_objects,
    figsize=(10 * n_objects, 10),
    squeeze=False,
)
# Fetch the embeddings from the device once instead of once per object.
query_embeddings = jax.device_get(detection_results.query_embeddings)
for ii in range(n_objects):
  # The third axis of the query embeddings is indexed per object.
  emb_vis = detector.bop_surfemb.surfemb_model.get_emb_vis(
      torch.from_numpy(query_embeddings[:, :, ii])
  )
  ax[0, ii].imshow(emb_vis.cpu().numpy())
  ax[0, ii].axis('off')

fig.tight_layout()

In [None]:
# We can visualize the estimated 3D scene descriptions in terms of 6D object poses.
gl_renderer = test_img.get_renderer(data_directory)
# Render the inferred poses and the ground-truth poses with the same renderer
# and the same object index list (0 .. number of objects - 1).
rendered_data = gl_renderer.render_single(
    detection_results.inferred_poses,
    list(range(len(test_img.bop_obj_indices))),
)
gt_rendered_data = gl_renderer.render_single(
    jnp.array(bop_img.get_gt_poses()),
    list(range(len(test_img.bop_obj_indices))),
)
# The last channel of model_xyz is what gets plotted as depth.
gt_depth = gt_rendered_data.model_xyz[..., -1]
estimated_depth = rendered_data.model_xyz[..., -1]
# Both depth panels share the ground-truth value range so that their colors
# are directly comparable.
low = gt_depth.min()
high = gt_depth.max()
depth_style = dict(cmap="turbo", vmin=low, vmax=high)

fig, ax = plt.subplots(1, 4, figsize=(40, 10))
ax[0].imshow(test_img.rgb)
ax[1].imshow(gt_depth, **depth_style)
ax[2].imshow(estimated_depth, **depth_style)
ax[3].imshow(rendered_data.obj_ids)

panel_titles = (
    'Input scene',
    'Ground-truth 3D scene description',
    'Estimated 3D scene description',
    'Estimated object segmentation',
)
# Hide the axes and set the title exactly once per panel.
for panel, title in zip(ax, panel_titles):
  panel.axis('off')
  panel.set_title(title, fontsize=40)
fig.tight_layout()