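This notebook walks through basic inference with se3dif (SE(3)-DiffusionFields): it loads a segmented object point cloud, centres and rescales it, samples grasp poses with a pretrained model, and visualizes the grasps together with the point cloud.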

In [31]:
from pathlib import Path

import numpy as np
import trimesh

from se3dif.models.loader import load_model
from se3dif.samplers import Grasp_AnnealedLD
from se3dif.utils import to_numpy, to_torch
from se3dif.visualization import create_gripper_marker

In [32]:
# Load one segmented point cloud of the `011_banana` object from the test data.
# `sorted` makes the choice of file deterministic: `Path.glob` yields entries
# in filesystem order, which can differ between machines.
POINTCLOUD_DIR = (
    Path.cwd().parent / "test_data" / "pointcloud_segmented_points" / "011_banana"
)
pointcloud_files = sorted(POINTCLOUD_DIR.glob("*.npy"))
if not pointcloud_files:
    raise FileNotFoundError(f"No .npy point clouds found in {POINTCLOUD_DIR}")

pointcloud = np.load(pointcloud_files[0])

# Later cells treat each row as one 3-D point.
assert pointcloud.ndim == 2 and pointcloud.shape[1] == 3, pointcloud.shape

pointcloud.shape

(2190, 3)

In [33]:
N_POINTS = 1000
SEED = 0
# 0.3 is roughly the average max extent of the objects in the dataset.
# NOTE(review): the extra factor of 8 presumably matches the scale the
# pretrained model was trained at — confirm against the se3dif data loader.
AVG_OBJECT_MAX_EXTENT = 0.3
MODEL_SCALE = 8
TARGET_MAX_SIZE = AVG_OBJECT_MAX_EXTENT * MODEL_SCALE

rng = np.random.default_rng(SEED)

# Subsample to a fixed number of points. Sample without replacement so no point
# is duplicated; only allow replacement when the cloud has too few points.
sample_idx = rng.choice(
    len(pointcloud), size=N_POINTS, replace=len(pointcloud) < N_POINTS
)
pointcloud = pointcloud[sample_idx]

# Centre the cloud on its mean. In-place ops keep the loaded dtype; fancy
# indexing above already produced a copy, so the loaded array is untouched.
pointcloud_offset = np.mean(pointcloud, axis=0)
pointcloud -= pointcloud_offset

# Uniformly rescale so the largest axis-aligned extent equals TARGET_MAX_SIZE.
pointcloud_extents = np.max(pointcloud, axis=0) - np.min(pointcloud, axis=0)
assert np.max(pointcloud_extents) > 0, "point cloud has zero extent; cannot rescale"
scale_factor = TARGET_MAX_SIZE / np.max(pointcloud_extents)
pointcloud *= scale_factor

In [34]:
# NOTE(review): "cuda:0" requires a CUDA GPU; running on CPU is untested here.
DEVICE = "cuda:0"
MODEL = "partial_grasp_dif"
# Batch size for the model latent and the sampler (presumably the number of
# grasps returned by `generator.sample()` — confirm).
BATCH = 10

# Sampler hyper-parameters, passed through unchanged to Grasp_AnnealedLD.
# NOTE(review): presumably T = annealing steps, T_fit = extra steps at the final
# noise level, k_steps = Langevin steps per level — confirm in se3dif.samplers.
SAMPLER_T = 70
SAMPLER_T_FIT = 50
SAMPLER_K_STEPS = 2

model = load_model({"device": DEVICE, "pretrained_model": MODEL})
# `[None, ...]` adds a leading axis of size 1 in front of the point array.
model.set_latent(to_torch(pointcloud[None, ...], DEVICE), batch=BATCH)
generator = Grasp_AnnealedLD(
    model,
    batch=BATCH,
    T=SAMPLER_T,
    T_fit=SAMPLER_T_FIT,
    k_steps=SAMPLER_K_STEPS,
    device=DEVICE,
)

In [35]:
# Run the sampler and convert the result to a NumPy array for visualization.
H_grasps = to_numpy(generator.sample())
# The visualization cell uses each grasp as a 4x4 homogeneous transform
# (apply_transform, [:3, 3] translation, [:3, 2] axis), so fail early otherwise.
assert H_grasps.ndim == 3 and H_grasps.shape[1:] == (4, 4), H_grasps.shape

In [36]:
def visualize_grasp_axis(
    H_grasp, gripper_height=None, point_radius=0.004, axis_radius=0.002
):
    """Build a marker mesh showing a grasp's origin and its z-axis.

    Args:
        H_grasp: 4x4 homogeneous grasp transform. ``H_grasp[:3, 3]`` is used as
            the grasp point and ``H_grasp[:3, 2]`` as the axis direction.
        gripper_height: Scale applied to the z-axis column to get the drawn
            segment (its length when the rotation block is orthonormal).
            Falsy values (None, 0) fall back to 0.02.
        point_radius: Radius of the sphere drawn at the grasp point.
        axis_radius: Radius of the cylinder drawn along the axis.

    Returns:
        The sphere and cylinder merged via ``trimesh.util.concatenate``.
    """
    # `or` (not `is None`) is kept on purpose: a zero height would give a
    # degenerate zero-length cylinder segment.
    gripper_height = gripper_height or 0.02
    origin = H_grasp[:3, 3]
    # assumption: grasp axis is along z-axis of gripper coordinate system
    axis_end = origin + gripper_height * H_grasp[:3, 2]

    grasp_point = trimesh.primitives.Sphere(radius=point_radius, center=origin)
    grasp_vector = trimesh.creation.cylinder(
        radius=axis_radius, segment=(origin, axis_end)
    )

    return trimesh.util.concatenate([grasp_point, grasp_vector])


# Visualization sizes, in the same units as the rescaled point cloud.
POINT_RADIUS = 0.02
GRASP_AXIS_LENGTH = 1

# Merge all point spheres into one mesh so the scene holds a single node for
# the cloud instead of one node per point.
pointcloud_mesh = trimesh.util.concatenate(
    [trimesh.primitives.Sphere(radius=POINT_RADIUS, center=point) for point in pointcloud]
)

scene = trimesh.Scene()
scene.add_geometry(pointcloud_mesh)

# One gripper marker plus one axis marker per sampled grasp.
for H_grasp in H_grasps:
    scene.add_geometry(create_gripper_marker().apply_transform(H_grasp))
    scene.add_geometry(visualize_grasp_axis(H_grasp, GRASP_AXIS_LENGTH))

scene.show()