In [1]:
import copy
import os
from pathlib import Path
from typing import Tuple

import h5py
import numpy as np
import scipy.spatial.transform
import trimesh

from se3dif.models.loader import load_model
from se3dif.samplers import ApproximatedGrasp_AnnealedLD, Grasp_AnnealedLD
from se3dif.utils import to_numpy, to_torch
from se3dif.visualization import create_gripper_marker

ModuleNotFoundError: No module named 'se3dif.models'

In [None]:
# Root of the ACRONYM dataset; expected to contain `grasps/` and `meshes/`.
# Set the ACRONYM_DATASET_PATH environment variable to point elsewhere
# instead of editing this absolute path in the notebook.
ACRONYM_DATASET_PATH = Path(os.environ.get("ACRONYM_DATASET_PATH", "/home/data/"))

In [None]:
def load_acronym_mesh(
    dataset_path: Path, object_class: str, grasp_uuid: str
) -> trimesh.Trimesh:
    """Load an ACRONYM object mesh, centred at its centroid and scaled per its grasp file.

    The grasp file is located by globbing
    ``<dataset_path>/grasps/<object_class>/<object_class>_<grasp_uuid>*.h5``, so
    ``grasp_uuid`` may be any prefix of the full id as long as it matches exactly
    one file. That HDF5 file supplies the mesh path (``object/file``) and the mesh
    scale (``object/scale``).

    Args:
        dataset_path: Dataset root containing ``grasps/`` and ``meshes/``.
        object_class: Name of an entry inside ``<dataset_path>/grasps``.
        grasp_uuid: Id (or unique id prefix) of the object's grasp file.

    Returns:
        The mesh translated so its centroid sits at the origin, then scaled by
        ``object/scale``. Scenes are concatenated into a single mesh.

    Raises:
        AssertionError: If ``object_class`` is not present under ``grasps/``, or if
            zero or more than one grasp file matches ``grasp_uuid``.
    """
    grasps_dir = dataset_path / "grasps"
    assert object_class in [
        p.name for p in grasps_dir.iterdir()
    ], f"Object class {object_class} not found in dataset path {dataset_path}/grasps"

    pattern = f"{object_class}_{grasp_uuid}*.h5"
    matches = sorted((grasps_dir / object_class).glob(pattern))
    # On a miss, report what was searched for; the (empty) match list says nothing.
    assert len(matches) != 0, f"Grasp file not found: {grasps_dir / object_class / pattern}"
    assert len(matches) < 2, f"Multiple grasp files found: {matches}"
    grasp_file_path = matches[0]

    # Close the HDF5 handle deterministically; read everything needed while open.
    with h5py.File(grasp_file_path, "r") as grasp_data:
        mesh_scale = grasp_data["object"]["scale"][()]
        stored_mesh_path = grasp_data["object"]["file"][()].decode("utf-8")

    # The slice assumes the stored path begins with "meshes/"; strip that prefix
    # and re-anchor the remainder under dataset_path / "meshes".
    mesh_file_path = dataset_path / "meshes" / stored_mesh_path[len("meshes") + 1 :]

    mesh = trimesh.load_mesh(mesh_file_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = trimesh.util.concatenate(mesh.dump())
    mesh = mesh.apply_translation(-mesh.centroid)
    mesh = mesh.apply_scale(mesh_scale)

    return mesh



In [None]:
mesh = load_acronym_mesh(
    dataset_path=ACRONYM_DATASET_PATH,
    object_class="ScrewDriver",
    grasp_uuid="28d",  # id prefix; must match exactly one grasp file
)

# Optional debug view of the loaded mesh with a small coordinate frame:
# scene = trimesh.Scene(mesh)
# scene.add_geometry(trimesh.creation.axis(origin_size=0.005, axis_radius=0.001, axis_length=0.1))
# scene.show()

In [None]:
def get_pointcloud_for_inference(
    mesh,
    random_rotation: bool = False,
    scaling_factor: float = 8.0,
    n_points: int = 1000,
    random_state=None,
):
    """Optionally rotate, then scale and surface-sample, a copy of ``mesh``.

    The input mesh is deep-copied and is not modified.

    Args:
        mesh: Mesh to sample (a ``trimesh.Trimesh`` in this notebook).
        random_rotation: If True, apply a random rotation before scaling.
        scaling_factor: Uniform scale applied after the rotation.
        n_points: Number of surface points to sample.
        random_state: Seed or RNG forwarded to ``Rotation.random`` so the random
            rotation is reproducible; ``None`` keeps the unseeded behaviour.
            It does not seed the surface sampling.

    Returns:
        ``(pointcloud, transformed_mesh, H_rot)``. ``H_rot`` is the 4x4 homogeneous
        rotation that was applied (identity when ``random_rotation`` is False);
        the scaling is not included in it.
    """
    mesh = copy.deepcopy(mesh)

    H_rot = np.eye(4)
    if random_rotation:
        # Passed positionally: SciPy names this argument `random_state` in older
        # releases and `rng` in newer ones.
        rotation = scipy.spatial.transform.Rotation.random(None, random_state)
        H_rot[:3, :3] = rotation.as_matrix()
    mesh.apply_transform(H_rot)

    mesh.apply_scale(scaling_factor)

    pointcloud = mesh.sample(n_points)

    return pointcloud, mesh, H_rot

In [None]:
pointcloud, mesh_transformed, H_rot = get_pointcloud_for_inference(
    mesh,
    n_points=1000,
    scaling_factor=8.0,
    random_rotation=False,
)

# Optional debug view: transformed mesh, sampled points and a coordinate frame.
# scene = trimesh.Scene(mesh_transformed)
# scene.add_geometry(
#     [trimesh.primitives.Sphere(radius=0.01, center=p) for p in pointcloud]
# )
# scene.add_geometry(trimesh.creation.axis(origin_size=0.05, axis_radius=0.01, axis_length=1))
# scene.show()

In [None]:
def get_fitted_grasp_generator(
    pointcloud, model, device, batch, T: int = 70, T_fit: int = 50, k_steps: int = 2
):
    """Condition ``model`` on ``pointcloud`` and build a ``Grasp_AnnealedLD`` sampler.

    Note: this calls ``model.set_latent`` and so presumably mutates ``model``.

    Args:
        pointcloud: Point array for a single object; a leading batch axis is
            added before it is converted to a tensor on ``device``.
        model: se3dif model exposing ``set_latent``.
        device: Device string, e.g. ``"cuda:0"``.
        batch: Forwarded to both ``model.set_latent`` and the sampler.
        T, T_fit, k_steps: Sampler schedule settings forwarded unchanged to
            ``Grasp_AnnealedLD`` (presumably step counts; see the se3dif sampler).

    Returns:
        The configured ``Grasp_AnnealedLD`` instance.
    """
    model.set_latent(to_torch(pointcloud[None, ...], device), batch=batch)
    # Forward the caller's schedule; previously fixed literals made these
    # parameters silently ineffective.
    generator = Grasp_AnnealedLD(
        model, batch=batch, T=T, T_fit=T_fit, k_steps=k_steps, device=device
    )
    return generator

In [None]:
# Inference settings for the pretrained se3dif grasp model.
MODEL = "grasp_dif_multi"  # checkpoint name passed to load_model
DEVICE = "cuda:0"  # device used for the model and the conditioning point cloud
BATCH = 10  # batch size forwarded to the latent encoding and the sampler

In [None]:
model = load_model({"device": DEVICE, "pretrained_model": MODEL})
generator = get_fitted_grasp_generator(pointcloud, model, DEVICE, batch=BATCH)

# Sampled grasps are (batch, 4, 4) poses, presumably expressed in the frame of
# `pointcloud`, i.e. after the rotation H_rot and the uniform scaling.
H_grasps = to_numpy(generator.sample())

# Map the grasps back into the frame of the original `mesh` for display:
# 1) undo the uniform scaling on the translations (8.0 must match the
#    scaling_factor used to build `pointcloud`);
# 2) undo the rotation H_rot. This is a no-op when random_rotation=False
#    (H_rot is the identity) but keeps the overlay correct otherwise.
H_grasps_rescaled = H_grasps.copy()
H_grasps_rescaled[:, :3, 3] /= 8.0
H_grasps_rescaled = np.linalg.inv(H_rot) @ H_grasps_rescaled

In [None]:
# Overlay one gripper marker per sampled grasp on the original (unscaled) mesh.
gripper_markers = [
    create_gripper_marker().apply_transform(H_grasp) for H_grasp in H_grasps_rescaled
]

scene = trimesh.Scene()
scene.add_geometry(mesh)
scene.add_geometry(gripper_markers)

scene.show()