In [44]:
import os
import sys
sys.path.append(os.path.realpath("."))
from utils.isaac_validator import IsaacValidator, ValidationType
from tap import Tap
import torch
import numpy as np
from utils.hand_model import HandModel
from utils.object_model import ObjectModel
from utils.hand_model_type import (
   HandModelType,
   handmodeltype_to_joint_names,
)
from utils.qpos_pose_conversion import (
   qpos_to_pose,
   qpos_to_translation_quaternion_jointangles,
   pose_to_qpos,
)
from typing import List, Optional, Dict
import math
import random
from utils.seed import set_seed
from utils.joint_angle_targets import (
   compute_optimized_joint_angle_targets,
   OptimizationMethod,
)
from utils.energy import _cal_hand_object_penetration

## PARAMS

In [36]:
mesh_path = "/scr-ssd/ksrini/DexGraspNet/meshdata/"
data_path = "/juno/downloads/dexgraspnet_dataset/2023-07-01_dataset_DESIRED_DIST_TOWARDS_OBJECT_SURFACE_MULTIPLE_STEPS_v2/"
hand_model_type = HandModelType.ALLEGRO_HAND
seed = 102
joint_angle_targets_optimization_method = (
    OptimizationMethod.DESIRED_DIST_TOWARDS_OBJECT_SURFACE_MULTIPLE_STEPS
)
should_canonicalize_hand_pose = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Set Seed

In [4]:
set_seed(seed)

Setting seed: 102


102

## Grasp codes

In [39]:
grasp_code_list = []
mesh_files = set(map(lambda x: os.path.splitext(x)[0], os.listdir(mesh_path)))

for filename in os.listdir(data_path):
    code = filename.split(".")[0]
    assert code in mesh_files
    if code.startswith('sem-'):
        grasp_code_list.append(code)

## Sample and read in data

In [43]:
grasp_code = random.choice(grasp_code_list)
grasp_data_list = np.load(os.path.join(data_path, grasp_code + ".npy"), allow_pickle=True)
print(f"Randomly sampled grasp_code = {grasp_code}")

index = random.randint(0, len(grasp_data_list) - 1)
qpos = grasp_data_list[index]["qpos"]
scale = grasp_data_list[index]["scale"]
print(f"Randomly sampled index = {index}")
print(f"scale = {scale}")

Randomly sampled grasp_code = sem-Book-4dc6344e8e357867a0a7b8ecc42036d5
Randomly sampled index = 346
scale = 0.05999999865889549


## Object model

In [None]:
object_model = ObjectModel(
    data_root_path=mesh_path,
    batch_size_each=1,
    device=device,
)
object_model.initialize([grasp_code])
object_model.object_scale_tensor = torch.tensor(
    scale, dtype=torch.float, device=device
).reshape(object_model.object_scale_tensor.shape)

## Hand model

In [None]:
joint_names = handmodeltype_to_joint_names[hand_model_type]
hand_model = HandModel(hand_model_type, device=device)

hand_pose = qpos_to_pose(
    qpos=qpos, joint_names=joint_names, unsqueeze_batch_dim=True
).to(device)
hand_model.set_parameters(hand_pose)

batch_idx = 0
hand_mesh = hand_model.get_trimesh_data(batch_idx)
object_mesh = object_model.object_mesh_list[batch_idx].copy().apply_scale(scale)

## Visualize hand and object

In [None]:
(hand_mesh + object_mesh).show()

## Visualize hand and object plotly

In [None]:
fig_title = f"Grasp Code: {grasp_code}, Index: {index}"
idx_to_visualize = batch_idx

fig = go.Figure(
    layout=go.Layout(
        scene=dict(
            xaxis=dict(title="X"),
            yaxis=dict(title="Y"),
            zaxis=dict(title="Z"),
            aspectmode="data",
        ),
        showlegend=True,
        title=fig_title,
        autosize=False,
        width=800,
        height=800,
    )
)
plots = [
    *hand_model.get_plotly_data(
        i=idx_to_visualize, opacity=1.0, with_contact_candidates=True
    ),
    *object_model.get_plotly_data(i=idx_to_visualize, opacity=0.5),
]
for plot in plots:
    fig.add_trace(plot)
fig.show()

## Compute optimized canonicalized grasp

In [None]:
if should_canonicalize_hand_pose:
    (
        canonicalized_hand_pose,
        canonicalized_losses,
        canonicalized_debug_infos,
    ) = compute_optimized_canonicalized_hand_pose(
        hand_model=hand_model,
        object_model=object_model,
        device=device,
    )
    hand_model.set_parameters(canonicalized_hand_pose)

    fig_title = f"Canonicalized Grasp Code: {grasp_code}, Index: {index}"
    idx_to_visualize = batch_idx

    fig = go.Figure(
        layout=go.Layout(
            scene=dict(
                xaxis=dict(title="X"),
                yaxis=dict(title="Y"),
                zaxis=dict(title="Z"),
                aspectmode="data",
            ),
            showlegend=True,
            title=fig_title,
            autosize=False,
            width=800,
            height=800,
        )
    )

    canonicalized_target_points = canonicalized_debug_infos[-1]["target_points"]
    canonicalized_contact_points_hand = canonicalized_debug_infos[-1]["contact_points_hand"]
    canonicalized_closest_points_object = (
        canonicalized_debug_infos[-1]['contact_points_hand']
        - canonicalized_debug_infos[-1]['contact_normals'] * canonicalized_debug_infos[-1]['contact_distances'][..., None]
    )

    plots = [
        *hand_model.get_plotly_data(
            i=idx_to_visualize, opacity=1.0, with_contact_candidates=True
        ),
        *object_model.get_plotly_data(i=idx_to_visualize, opacity=0.5),
        go.Scatter3d(
            x=canonicalized_target_points[batch_idx, :, 0].detach().cpu().numpy(),
            y=canonicalized_target_points[batch_idx, :, 1].detach().cpu().numpy(),
            z=canonicalized_target_points[batch_idx, :, 2].detach().cpu().numpy(),
            mode="markers",
            marker=dict(size=10, color="red"),
            name="target_points",
        ),
        go.Scatter3d(
            x=canonicalized_contact_points_hand[batch_idx, :, 0].detach().cpu().numpy(),
            y=canonicalized_contact_points_hand[batch_idx, :, 1].detach().cpu().numpy(),
            z=canonicalized_contact_points_hand[batch_idx, :, 2].detach().cpu().numpy(),
            mode="markers",
            marker=dict(size=10, color="green"),
            name="contact_points_hand",
        ),
        # Draw blue line between closest points and contact points
        *[
            go.Scatter3d(
                x=[
                    canonicalized_closest_points_object[batch_idx, i, 0].detach().cpu().numpy(),
                    canonicalized_contact_points_hand[batch_idx, i, 0].detach().cpu().numpy(),
                ],
                y=[
                    canonicalized_closest_points_object[batch_idx, i, 1].detach().cpu().numpy(),
                    canonicalized_contact_points_hand[batch_idx, i, 1].detach().cpu().numpy(),
                ],
                z=[
                    canonicalized_closest_points_object[batch_idx, i, 2].detach().cpu().numpy(),
                    canonicalized_contact_points_hand[batch_idx, i, 2].detach().cpu().numpy(),
                ],
                mode="lines",
                line=dict(color="blue", width=5),
                name="contact_point_to_closest_point",
            )
            for i in range(canonicalized_closest_points_object.shape[1])
        ],
    ]
    for plot in plots:
        fig.add_trace(plot)
    fig.show()

In [None]:
if should_canonicalize_hand_pose:
    fig = px.line(y=canonicalized_losses)
    fig.update_layout(
        title="Canonicalized Loss vs. Iterations", xaxis_title="Iterations", yaxis_title="Loss"
    )
    fig.show()



## Compute optimized joint angle targets

In [None]:
original_hand_pose = hand_model.hand_pose.detach().clone()
print(f"original_hand_pose[:, 9:] = {original_hand_pose[:, 9:]}")

In [None]:
(
    joint_angle_targets_to_optimize,
    losses,
    debug_infos,
) = compute_optimized_joint_angle_targets(
    optimization_method=joint_angle_targets_optimization_method,
    hand_model=hand_model,
    object_model=object_model,
    device=device,
)
old_debug_info = debug_infos[0]
debug_info = debug_infos[-1]

In [None]:
fig = px.line(y=losses)
fig.update_layout(
    title=f"{joint_angle_targets_optimization_method} Loss vs. Iterations", xaxis_title="Iterations", yaxis_title="Loss"
)
fig.show()

In [None]:
print(f"joint_angle_targets_to_optimize = {joint_angle_targets_to_optimize}")

## Visualize hand pose before and after optimization

In [None]:
# Plotly fig
hand_model.set_parameters(original_hand_pose)
old_hand_model_plotly = hand_model.get_plotly_data(
    i=idx_to_visualize, opacity=1.0, with_contact_candidates=True
)

fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "scene"}, {"type": "scene"}]],
    subplot_titles=("Original", "Optimized"),
)
old_target_points = old_debug_info["target_points"]
old_contact_points_hand = old_debug_info["contact_points_hand"]

plots = [
    *old_hand_model_plotly,
    *object_model.get_plotly_data(i=idx_to_visualize, opacity=0.5),
    go.Scatter3d(
        x=old_target_points[batch_idx, :, 0].detach().cpu().numpy(),
        y=old_target_points[batch_idx, :, 1].detach().cpu().numpy(),
        z=old_target_points[batch_idx, :, 2].detach().cpu().numpy(),
        mode="markers",
        marker=dict(size=10, color="red"),
        name="target_points",
    ),
    go.Scatter3d(
        x=old_contact_points_hand[batch_idx, :, 0].detach().cpu().numpy(),
        y=old_contact_points_hand[batch_idx, :, 1].detach().cpu().numpy(),
        z=old_contact_points_hand[batch_idx, :, 2].detach().cpu().numpy(),
        mode="markers",
        marker=dict(size=10, color="green"),
        name="contact_points_hand",
    ),
]

for plot in plots:
    fig.append_trace(plot, row=1, col=1)

In [None]:

new_hand_pose = original_hand_pose.detach().clone()
new_hand_pose[:, 9:] = joint_angle_targets_to_optimize
hand_model.set_parameters(new_hand_pose)
new_hand_model_plotly = hand_model.get_plotly_data(
    i=idx_to_visualize, opacity=1.0, with_contact_candidates=True
)

new_target_points = debug_info["target_points"]
new_contact_points_hand = debug_info["contact_points_hand"]

plots = [
    *new_hand_model_plotly,
    *object_model.get_plotly_data(i=idx_to_visualize, opacity=0.5),
    go.Scatter3d(
        x=new_target_points[batch_idx, :, 0].detach().cpu().numpy(),
        y=new_target_points[batch_idx, :, 1].detach().cpu().numpy(),
        z=new_target_points[batch_idx, :, 2].detach().cpu().numpy(),
        mode="markers",
        marker=dict(size=10, color="red"),
        name="new_target_points",
    ),
    go.Scatter3d(
        x=new_contact_points_hand[batch_idx, :, 0].detach().cpu().numpy(),
        y=new_contact_points_hand[batch_idx, :, 1].detach().cpu().numpy(),
        z=new_contact_points_hand[batch_idx, :, 2].detach().cpu().numpy(),
        mode="markers",
        marker=dict(size=10, color="green"),
        name="contact_points_hand",
    ),
]

for plot in plots:
    fig.append_trace(plot, row=1, col=2)

fig.update_layout(
    autosize=False,
    width=1600,
    height=800,
    title_text=f"Optimization Method: {joint_angle_targets_optimization_method.name}",
)
fig.show()