In [None]:
import isaacgym
import isaacgymenvs
from isaacgymenvs.utils.reformat import omegaconf_to_dict, print_dict
from isaacgymenvs.utils.utils import set_np_formatting, set_seed
from isaacgymenvs.utils.rlgames_utils import RLGPUEnv, RLGPUAlgoObserver, get_rlgames_env_creator

from rl_games.common import env_configurations, vecenv
from rl_games.torch_runner import Runner
from rl_games.algos_torch import model_builder

from omegaconf import DictConfig, OmegaConf

import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Build the run configuration by loading the YAML files directly with OmegaConf.
# Both paths are relative, so they are resolved against the kernel's working directory.
cfg = OmegaConf.load("../isaacgymenvs/cfg/config.yaml")
cfg.task_name = "TrifingerNYU"
# NOTE(review): the environment is created from cfg.task.env.numEnvs, not cfg.num_envs;
# presumably the task YAML derives numEnvs from num_envs -- confirm.
cfg.num_envs = 1
# Replace the task node with the TrifingerNYU task configuration.
cfg.task = OmegaConf.load("../isaacgymenvs/cfg/task/TrifingerNYU.yaml")
# Forwarded to isaacgymenvs.make as the headless flag when the environment is created.
cfg.headless = True

In [None]:
# Device used for the rollout buffers allocated later in the notebook; the same
# cfg.sim_device value is also handed to isaacgymenvs.make when the environment is built.
device = cfg.sim_device

In [None]:
def create_env_thunk(**kwargs):
    """Create the environment described by the notebook-level ``cfg``.

    The positional arguments of ``isaacgymenvs.make`` are read from ``cfg`` in the
    order listed below, followed by ``cfg`` itself. Any keyword arguments given to
    this function are forwarded to ``isaacgymenvs.make`` untouched.

    Returns:
        The object returned by ``isaacgymenvs.make``.
    """
    make_args = (
        cfg.seed,
        cfg.task_name,
        cfg.task.env.numEnvs,
        cfg.sim_device,
        cfg.rl_device,
        cfg.graphics_device_id,
        cfg.headless,
        cfg.multi_gpu,
        cfg.capture_video,
        cfg.force_render,
        cfg,
    )
    return isaacgymenvs.make(*make_args, **kwargs)

In [None]:
# Instantiate the environment once; later cells read its state tensors and step it.
envs = create_env_thunk()

In [None]:
from isaacgym.torch_utils import *

def bmv(mat: torch.Tensor, vec: torch.Tensor):
    """Batched matrix-vector product.

    Args:
        mat: Tensor of shape ``(..., i, j)``.
        vec: Tensor of shape ``(..., j)`` whose leading dimensions broadcast
            against those of ``mat``.

    Returns:
        Tensor of shape ``(..., i)`` holding ``mat @ vec`` for every batch entry.
        The common ``(b, i, j)`` x ``(b, j)`` case behaves as before; the ellipsis
        additionally allows zero or several leading batch dimensions.
    """
    return torch.einsum('...ij,...j->...i', mat, vec)

def quat2mat(quat: torch.Tensor):
    """Convert quaternions in ``(x, y, z, w)`` order to 3x3 rotation matrices.

    The conversion uses the unit-quaternion formula; the input is not normalized
    here, so callers are expected to pass unit quaternions.

    Args:
        quat: Tensor of shape ``(..., 4)`` with components ordered ``x, y, z, w``
            along the last dimension.

    Returns:
        Tensor of shape ``(..., 3, 3)`` with one rotation matrix per quaternion.
    """
    x, y, z, w = torch.unbind(quat, dim=-1)

    x2, y2, z2 = x**2, y**2, z**2
    wx, wy, wz = w*x, w*y, w*z
    xy, xz, yz = x*y, x*z, y*z

    # Stack the nine entries in row-major order along a new last axis. Plain
    # elementwise ops broadcast over every leading dimension, so no per-sample
    # mapping (vmap) is required.
    rotation_matrix = torch.stack([
        1-2*y2-2*z2, 2*(xy-wz), 2*(xz+wy),
        2*(xy+wz), 1-2*x2-2*z2, 2*(yz-wx),
        2*(xz-wy), 2*(yz+wx), 1-2*x2-2*y2],
        dim=-1,
    )
    return rotation_matrix.reshape(*quat.shape[:-1], 3, 3)

def local2world(
    local_frame_pose: torch.Tensor,
    position_local: torch.Tensor
):
    """Express points given in a local frame in world coordinates.

    Args:
        local_frame_pose: Batched pose; columns 0:3 are the frame origin and
            columns 3:7 the frame orientation quaternion (as consumed by ``quat2mat``).
        position_local: Batched 3-vectors expressed in the local frame.

    Returns:
        ``origin + R @ position_local`` for every batch entry.
    """
    origin, orientation = local_frame_pose[:, :3], local_frame_pose[:, 3:7]
    return origin + bmv(quat2mat(orientation), position_local)

def world2local(
    local_frame_pose: torch.Tensor,
    position_world: torch.Tensor
):
    """Express world-frame points in the coordinates of a local frame.

    Args:
        local_frame_pose: Batched pose; columns 0:3 are the frame origin and
            columns 3:7 the frame orientation quaternion (as consumed by ``quat2mat``).
        position_world: Batched 3-vectors expressed in the world frame.

    Returns:
        ``R^T @ position_world - R^T @ origin`` for every batch entry.
    """
    origin = local_frame_pose[:, :3]
    # The inverse of a rotation matrix is its transpose.
    world_to_local_rot = quat2mat(local_frame_pose[:, 3:7]).transpose(1, 2)

    rotated_origin = bmv(world_to_local_rot, origin)
    rotated_point = bmv(world_to_local_rot, position_world)
    return -rotated_origin + rotated_point


In [None]:
# Pre-allocate the per-step recording buffers directly on the target device
# (avoids creating each tensor on the CPU first and then copying it over).
N = 500  # number of steps recorded by the rollout loop
# One 9-component action per step.
action_buffer = torch.zeros(N, 9, device=device)
# Per step: a 3x3 block of fingertip positions (fingertip index x xyz), world frame.
ftip_pos_buffer = torch.zeros(N, 3, 3, device=device)
# Same layout, with positions expressed relative to the object pose.
ftip_pos_local_buffer = torch.zeros(N, 3, 3, device=device)

In [None]:
# Roll the environment forward N steps with a simple "move fingertips toward the
# object" action, recording the action and fingertip positions of env 0 each step.
for n in range(N):
    # Joint state handles; not used further in this cell.
    q = envs._dof_position
    dq = envs._dof_velocity

    # First three state entries of each fingertip body, taken as its world position.
    fingertip_state = envs._rigid_body_state[:, envs._fingertip_indices]
    fingertip_position = fingertip_state[:, :, 0:3]
    ftip_pos_buffer[n] = fingertip_position[0]

    # Most recent object state: columns 0:3 position, 3:7 orientation quaternion.
    object_pose = envs._object_state_history[0][:, 0:7]
    object_pos = object_pose[:, 0:3]
    # Flatten per environment using the tensor's own batch size instead of
    # cfg.num_envs, which may differ from the env count the simulation was built with.
    pos_diff = object_pos.repeat(1, 3) - fingertip_position.flatten(start_dim=1)

    for i in range(3):
        # Store env 0 explicitly; writing the full (num_envs, 3) result into a
        # (3,) slot only works when there is exactly one environment.
        ftip_pos_local_buffer[n, i] = world2local(object_pose, fingertip_position[:, i, :])[0]

    # Scale by the largest absolute component (over all envs); the clamp avoids a
    # 0/0 -> NaN action when every fingertip coincides with the object position.
    max_abs_val = torch.max(torch.abs(pos_diff)).clamp_min(1e-8)
    normalized_vec = pos_diff / max_abs_val
    # NOTE(review): normalized_vec already lies in [-1, 1], so this maps it to
    # [-3, 1]; confirm that this offset/scale is the intended action range.
    action = 2 * normalized_vec - 1

    action_buffer[n] = action[0]
    obs, rwds, resets, info = envs.step(action)
    

In [None]:
def get_cube_contact_normals(ftip_pos, threshold=0.0435):
    """Assign an axis-aligned unit normal to each point based on its dominant axis.

    For every row, the coordinate with the largest magnitude is selected. If that
    magnitude is non-zero and at most ``threshold``, the output row is set to +1 or
    -1 on that axis (matching the coordinate's sign); otherwise the row stays zero.

    Args:
        ftip_pos: Tensor of shape ``(batch, 3)`` with point coordinates.
        threshold: Largest dominant-axis magnitude that still yields a normal.
            NOTE(review): the physical meaning of the default 0.0435 is not stated
            in this file -- confirm against the cube/fingertip dimensions.

    Returns:
        Tensor of shape ``(batch, 3)`` on the same device as ``ftip_pos``.
    """
    batch_size = len(ftip_pos)
    # Allocate on the input's device: the masks and indices below live there, and
    # indexing a CPU tensor with CUDA index tensors raises an error.
    contact_normals = torch.zeros(batch_size, 3, device=ftip_pos.device)

    _, max_indices = torch.max(torch.abs(ftip_pos), dim=1)
    # squeeze(1) rather than squeeze(): a batch of one must remain 1-D so the
    # boolean masks keep a batch dimension.
    max_values = torch.gather(ftip_pos, 1, max_indices.unsqueeze(1)).squeeze(1)

    within_threshold = torch.abs(max_values) <= threshold
    mask_pos = within_threshold & (max_values > 0)
    mask_neg = within_threshold & (max_values < 0)

    contact_normals[mask_pos, max_indices[mask_pos]] = 1.0
    contact_normals[mask_neg, max_indices[mask_neg]] = -1.0

    return contact_normals

In [None]:
# Take the recorded object-frame positions of fingertip index 0 for every step
# and compute one contact normal per recorded step.
ftip_pos = ftip_pos_local_buffer[:, 0, :]
contact_normals = get_cube_contact_normals(ftip_pos)