Restructured version of the same-named script in the `scripts` directory. Used to understand the code better and analyze what's going on.

In [None]:
import copy
from typing import Tuple

# isaac gym has to be imported before torch
from isaac_evaluation.grasp_quality_evaluation import GraspSuccessEvaluator
import trimesh

import scipy.spatial.transform
import numpy as np
from se3dif.datasets import AcronymGraspsDirectory
from se3dif.models.loader import load_model
from se3dif.samplers import ApproximatedGrasp_AnnealedLD, Grasp_AnnealedLD
from se3dif.utils import to_numpy, to_torch
from se3dif.visualization import grasp_visualization

import torch

In [None]:
# Notebook configuration. The comments record the values used for the three example
# runs in the README of the original script and that script's argparse defaults.
N_GRASPS = 10               # argument values from README: 10,10,10. default in argparse: 200
OBJ_ID = 0                  # argument values from README: 0,10,12.  default in argparse: 0
OBJ_CLASS = "ScrewDriver"   # argument values from README: 'ScrewDriver','grasp_dif_mugs','Mug'. default in argparse: "Laptop"
N_ENVS = 30                 # hardcoded in the code with default value 30 in __main__
DEVICE = "cuda:0"           # argument values from README: n.a., n.a., n.a. default in argparse: "cuda:0"
EVAL_SIM = False            # argument values from README: n.a., n.a., n.a. default in argparse: False
MODEL = "grasp_dif_multi"   # argument values from README: n.a., 'grasp_dif_mugs', n.a. default in argparse: 'grasp_dif_multi'
BATCH = 10                  # hardcoded in code with default value 10 in get_approximated_grasp_diffusion_field
# NOTE(review): N_GRASPS and N_ENVS are not read by any cell of this notebook, and EVAL_SIM
# only by the commented-out simulation cell. The number of sampled grasps is set by BATCH,
# which is passed as `batch` to the model and the sampler.

In [None]:
def get_object_data(
    obj_id: int,
    obj_class: str,
    n_points: int = 1000,
    scale: float = 8.0,
) -> Tuple["NpArray[N, 3]", trimesh.Trimesh, "NpArray[4, 4]", "NpArray[4, 4]"]:
    """Load the mesh of the specified object and sample a pointcloud from it.

    The pointcloud and the mesh are rotated by the same random rotation, scaled by
    ``scale`` (8 by default) and then translated so that the pointcloud has zero mean.

    Args:
        obj_id (int): Index of the object in ``AcronymGraspsDirectory(...).avail_obj``.
        obj_class (str): Object class, passed as ``data_type`` to ``AcronymGraspsDirectory``.
        n_points (int): Number of points sampled from the mesh surface. Defaults to 1000.
        scale (float): Uniform scale applied after the rotation. Defaults to 8.0.

    Returns:
        Tuple["NpArray[N, 3]", trimesh.Trimesh, "NpArray[4, 4]", "NpArray[4, 4]"]:
            - The ``(n_points, 3)`` pointcloud, one point per row.
            - The mesh, transformed in place.
            - The homogeneous transformation matrix of the rotation.
            - The homogeneous transformation matrix of the centering translation. Its
              translation column is the negated mean of the rotated, scaled pointcloud.
    """
    # get mesh of object and sample pointcloud
    acronym_grasps = AcronymGraspsDirectory(data_type=obj_class)
    mesh = acronym_grasps.avail_obj[obj_id].load_mesh()
    pointcloud = mesh.sample(n_points)

    # get a random rotation
    H_rot = np.eye(4)
    H_rot[:3, :3] = scipy.spatial.transform.Rotation.random().as_matrix()

    # apply the rotation to the pointcloud and mesh; points are rows, so p -> R p
    # for every point is pointcloud @ R.T
    pointcloud = pointcloud @ H_rot[:3, :3].T
    mesh.apply_transform(H_rot)

    # scale the pointcloud and mesh (both about the origin)
    pointcloud *= scale
    mesh.apply_scale(scale)

    # translate both so the pointcloud is zero mean; the mean is taken after scaling,
    # so H_trans lives in the scaled frame
    H_trans = np.eye(4)
    H_trans[:3, -1] = -pointcloud.mean(axis=0)
    pointcloud += H_trans[:3, -1]
    mesh.apply_transform(H_trans)

    return pointcloud, mesh, H_rot, H_trans


In [None]:
def get_fitted_grasp_generator(pointcloud, model, device=None, batch=None):
    """Condition ``model`` on a pointcloud and build a ``Grasp_AnnealedLD`` sampler for it.

    Args:
        pointcloud: Object pointcloud as a numpy array. Presumably ``(N, 3)``, as returned
            by ``get_object_data``. A leading batch axis is added before it is passed
            to the model.
        model: Grasp model, e.g. from ``load_model``. Its latent is set in place.
        device: Torch device for the pointcloud and the sampler. ``None`` (default)
            uses the module-level ``DEVICE``.
        batch: Batch size passed to ``model.set_latent`` and to the sampler. ``None``
            (default) uses the module-level ``BATCH``. Judging by the shape noted on
            ``generator.sample()``, this is the number of grasps drawn per call.

    Returns:
        The configured ``Grasp_AnnealedLD`` generator.
    """
    # Resolve at call time (not as default values) so later edits of the globals apply.
    device = DEVICE if device is None else device
    batch = BATCH if batch is None else batch

    # pointcloud[None, ...] adds the leading batch dimension the model expects.
    model.set_latent(to_torch(pointcloud[None, ...], device), batch=batch)
    return Grasp_AnnealedLD(
        model, batch=batch, T=70, T_fit=50, k_steps=2, device=device
    )

In [None]:
def visualize_grasp(H_grasp, pointcloud, mesh, H_trans, scale=8.0):
    """Build a scene of the grasps, pointcloud and mesh with centering and scaling undone.

    ``get_object_data`` scales the object by ``scale`` and then centers it with ``H_trans``.
    This function undoes both, the shift first and then the scale, for the grasps, the
    pointcloud AND the mesh, so the three stay aligned. The result is the randomly rotated
    object at its un-centered position and original size. The simulator also spawns the
    object without the centering shift.

    None of the inputs are modified; the function works on copies.

    Args:
        H_grasp: Grasp poses as homogeneous transforms of shape ``(..., 4, 4)``
            (torch tensor or numpy array) in the centered, scaled frame.
        pointcloud: Centered, scaled pointcloud (numpy array with xyz in the last axis).
        mesh: Centered, scaled ``trimesh.Trimesh``.
        H_trans: 4x4 homogeneous centering translation returned by ``get_object_data``.
        scale: Scale factor that was applied to the object. Defaults to 8.0.

    Returns:
        The scene returned by ``grasp_visualization.visualize_grasps`` (not shown).
    """
    inv_scale = 1.0 / scale
    shift = np.asarray(H_trans)[:3, -1]

    # Copy the grasps (detached, same device/dtype) so the caller's tensor is untouched.
    H_vis = torch.as_tensor(H_grasp).detach().clone()
    # counteract the translational shift (as the spawned model in simulation will still have it)
    H_vis[..., :3, -1] -= torch.as_tensor(shift, dtype=H_vis.dtype, device=H_vis.device)
    H_vis[..., :3, -1] *= inv_scale

    # Apply the same un-shift and rescale to the pointcloud and mesh so everything stays aligned.
    pointcloud_vis = (pointcloud - shift) * inv_scale
    mesh_vis = mesh.copy()
    mesh_vis.apply_translation(-shift)
    mesh_vis.apply_scale(inv_scale)

    scene = grasp_visualization.visualize_grasps(
        Hs=to_numpy(H_vis), p_cloud=pointcloud_vis, mesh=mesh_vis, show=False
    )
    return scene

In [None]:
# Load the object, load the pretrained model, condition it on the pointcloud and draw grasps.
pointcloud, mesh, H_rot, H_trans = get_object_data(obj_id=OBJ_ID, obj_class=OBJ_CLASS)
model = load_model(dict(device=DEVICE, pretrained_model=MODEL))
generator = get_fitted_grasp_generator(pointcloud=pointcloud, model=model)
H_grasp = generator.sample()  # torch.Size([10, 4, 4]) with BATCH = 10

# Optional alternative rendering (disabled):
# scene = visualize_grasp(H_grasp,pointcloud,mesh, H_trans)
# scene.show()

In [None]:
# Draw every sampled grasp as a gripper marker together with the object mesh.
# Grasps and mesh are both still in the centered, scaled frame, so they line up.

# trimesh needs host numpy matrices; H_grasp is a torch tensor that may live on the GPU.
H_grasp_np = to_numpy(H_grasp)

# One color value per grasp in [0, 1], mapped to the blue channel below (all black for now).
color = np.zeros(H_grasp_np.shape[0])

## Grips
grips = [
    grasp_visualization.create_gripper_marker(
        color=[0, 0, int(c * 254)], scale=1
    ).apply_transform(H)
    for H, c in zip(H_grasp_np, color)
]

## Visualize grips and the object (grips alone if no mesh is available)
scene = trimesh.Scene(grips if mesh is None else [mesh] + grips)
scene.show()

In [None]:
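# if (EVAL_SIM):
#     ## Evaluate Grasps in Simulation ##
#     # NOTE(review): before enabling, `rot_quad` must be defined. It is presumably the
#     # quaternion of the random object rotation, e.g.
#     # scipy.spatial.transform.Rotation.from_matrix(H_rot[:3, :3]).as_quat().
#     # The simulated object is spawned without the centering shift, so the grasps
#     # presumably need to be shifted back by H_trans first. Confirm whether the 1/8
#     # rescaling also applies here.
#     num_eval_envs = 10
#     evaluator = GraspSuccessEvaluator(OBJ_CLASS, n_envs=num_eval_envs, idxs=[OBJ_ID] * num_eval_envs, viewer=True, device=DEVICE, \
#                                         rotations=[rot_quad]*num_eval_envs, enable_rel_trafo=False)
#     succes_rate = evaluator.eval_set_of_grasps(H_grasp)
#     print('Success cases : {}'.format(succes_rate))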