In [2]:
import os
from pathlib import Path

import h5py
import trimesh
import yaml

from se3dif.visualization import create_gripper_marker

from grasping_benchmarks.algorithms.se3dif import Se3DifGraspPlanner
from grasping_benchmarks.base import CameraData

ModuleNotFoundError: No module named 'grasping_benchmarks.algorthims'

In [None]:
# Root of the ACRONYM dataset; load_acronym_mesh expects `grasps/` and `meshes/` inside it.
# Set the ACRONYM_DATASET_PATH environment variable to point elsewhere instead of
# editing this hard-coded absolute path.
ACRONYM_DATASET_PATH = Path(os.environ.get("ACRONYM_DATASET_PATH", "/home/data/"))

In [None]:
def load_acronym_mesh(
    dataset_path: Path, object_class: str, grasp_uuid: str
) -> trimesh.Trimesh:
    """Loads the mesh associated with the given grasp from the ACRONYM dataset.

    The mesh is translated so its centroid sits at the origin, then scaled by the
    per-object scale stored in the grasp file.

    Args:
        dataset_path (Path): Path to the dataset; must contain ``grasps/`` and ``meshes/``.
        object_class (str): Object class, i.e. the name of a sub-directory of ``grasps/``.
        grasp_uuid (str): UUID of the grasp (a unique prefix is enough); matched
            against ``<object_class>_<grasp_uuid>*.h5``.

    Returns:
        trimesh.Trimesh: The centered and scaled mesh. A mesh file that loads as a
        ``trimesh.Scene`` is concatenated into a single mesh.

    Raises:
        AssertionError: If ``object_class`` is not a sub-directory of ``grasps/``,
            or if zero or more than one grasp file matches.
    """
    grasps_dir = dataset_path / "grasps"
    assert object_class in [
        p.name for p in grasps_dir.iterdir()
    ], f"Object class {object_class} not found in dataset path {dataset_path}/grasps"

    pattern = f"{object_class}_{grasp_uuid}*.h5"
    matches = list((grasps_dir / object_class).glob(pattern))

    # On a miss, report the searched pattern; the (empty) match list itself says nothing.
    assert len(matches) != 0, f"Grasp file not found: {grasps_dir / object_class / pattern}"
    assert len(matches) < 2, f"Multiple grasp files found: {matches}"
    grasp_file_path = matches[0]

    # Read the two fields we need and close the HDF5 file right away.
    with h5py.File(grasp_file_path, "r") as grasp_data:
        mesh_scale = grasp_data["object"]["scale"][()]
        stored_mesh_path = grasp_data["object"]["file"][()].decode("utf-8")

    # The slice assumes the stored path starts with "meshes/"; drop that prefix and
    # re-root the remainder under dataset_path / "meshes".
    mesh_file_path = dataset_path / "meshes" / stored_mesh_path[len("meshes") + 1 :]

    mesh = trimesh.load_mesh(mesh_file_path)
    if isinstance(mesh, trimesh.Scene):
        # Multi-geometry files load as a Scene; merge its geometries into one mesh.
        mesh = trimesh.util.concatenate(mesh.dump())
    mesh = mesh.apply_translation(-mesh.centroid)
    mesh = mesh.apply_scale(mesh_scale)

    return mesh

# Example object: the ACRONYM grasp file matching "ScrewDriver_28d*.h5".
OBJECT_CLASS = "ScrewDriver"
GRASP_UUID_PREFIX = "28d"
# NOTE(review): extra uniform scale applied on top of the dataset's own scale; the
# reason for 8.0 is not recorded here — confirm against the planner's expected object size.
OBJECT_EXTRA_SCALE = 8.0

mesh = load_acronym_mesh(ACRONYM_DATASET_PATH, OBJECT_CLASS, GRASP_UUID_PREFIX)
mesh.apply_scale(OBJECT_EXTRA_SCALE)

In [None]:
# Planner configuration, resolved relative to the notebook's working directory.
CONFIG_PATH = Path.cwd().parent / "cfg" / "base_config.yaml"

# safe_load: yaml.load without an explicit Loader raises TypeError on PyYAML >= 6
# (and on 5.x warns and can construct arbitrary Python objects).
with open(CONFIG_PATH, "r") as f:
    cfg = yaml.safe_load(f)

planner = Se3DifGraspPlanner(cfg)

In [None]:
# Number of surface points sampled from the mesh to stand in for a camera point cloud.
# NOTE(review): mesh.sample draws random surface points and is not seeded here, so the
# point cloud (and therefore the planned grasps) differ between runs.
N_POINTCLOUD_SAMPLES = 1000
camera_data = CameraData(pointcloud=mesh.sample(N_POINTCLOUD_SAMPLES))

In [None]:
# Plan grasps on the sampled point cloud.
# NOTE(review): the visualization cell below indexes each element of `grasps` as a
# homogeneous pose (H[:3, 3], H[:3, 2]) and passes it to apply_transform — confirm
# that matches what plan_grasp returns.
grasps = planner.plan_grasp(camera_data)

In [None]:
def visualize_grasp_axis(H_grasp, gripper_height=None, point_radius=0.004, axis_radius=0.002):
    """Build a marker mesh showing a grasp's origin and axis.

    The marker is a sphere at the grasp origin plus a cylinder that starts at the
    origin and extends along the z-axis of the grasp frame.

    Args:
        H_grasp: Homogeneous grasp pose. Only the translation ``H_grasp[:3, 3]`` and
            the z-axis column ``H_grasp[:3, 2]`` are read.
        gripper_height: Length of the drawn axis. Any falsy value (``None`` or ``0``)
            falls back to 0.02.
        point_radius: Radius of the sphere marking the grasp origin.
        axis_radius: Radius of the axis cylinder.

    Returns:
        trimesh.Trimesh: The sphere and cylinder concatenated into one mesh.
    """
    # Note: `or` means an explicit 0 is also replaced by the default length.
    gripper_height = gripper_height or 0.02

    origin = H_grasp[:3, 3]
    grasp_point = trimesh.primitives.Sphere(radius=point_radius, center=origin)

    # assumption: grasp axis is along z-axis of gripper coordinate system
    axis_end = origin + gripper_height * H_grasp[:3, 2]
    grasp_vector = trimesh.creation.cylinder(radius=axis_radius, segment=(origin, axis_end))

    return trimesh.util.concatenate([grasp_point, grasp_vector])


# Overlay every planned grasp on the object: a gripper marker plus the grasp axis.
# NOTE(review): this axis length is 50x visualize_grasp_axis's 0.02 default, presumably
# so it stays visible next to the 8x-enlarged mesh — confirm the intended value.
GRASP_AXIS_LENGTH = 1

scene = trimesh.Scene()
scene.add_geometry(mesh)

for H_grasp in grasps:
    scene.add_geometry(create_gripper_marker().apply_transform(H_grasp))
    scene.add_geometry(visualize_grasp_axis(H_grasp, gripper_height=GRASP_AXIS_LENGTH))

scene.show()
