In [None]:
# Notebook setup: render matplotlib figures inline, and reload all imported
# modules (e.g. vecrobotics, fista) before each cell runs so edits to them
# are picked up without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import isaacgym
import isaacgymenvs
from isaacgymenvs.utils.reformat import omegaconf_to_dict, print_dict
from isaacgymenvs.utils.utils import set_np_formatting, set_seed
from isaacgymenvs.utils.rlgames_utils import RLGPUEnv, RLGPUAlgoObserver, get_rlgames_env_creator

from rl_games.common import env_configurations, vecenv
from rl_games.torch_runner import Runner
from rl_games.algos_torch import model_builder

from omegaconf import DictConfig, OmegaConf

import torch
import numpy as np
import matplotlib.pyplot as plt

from vecrobotics import *
from fista import QP, FISTA, ForceQP

In [None]:
# Load the base IsaacGymEnvs config plus the TrifingerNYU task config, then
# override them for a single headless environment using the
# "fingertip_diff_force" command mode.
cfg = OmegaConf.load("../isaacgymenvs/cfg/config.yaml")
cfg.task_name = "TrifingerNYU"
cfg.num_envs = 1
cfg.task = OmegaConf.load("../isaacgymenvs/cfg/task/TrifingerNYU.yaml")
# create_env_thunk sizes the simulation from cfg.task.env.numEnvs, while all
# rollout buffers and actions below are sized from cfg.num_envs. Pin the two
# together explicitly instead of relying on an interpolation in the task YAML.
cfg.task.env.numEnvs = cfg.num_envs
cfg.task.env.command_mode = "fingertip_diff_force"
cfg.headless = True

In [None]:
def create_env_thunk(**kwargs):
    """Create the IsaacGymEnvs task described by the notebook-level ``cfg``.

    All settings are read from the global ``cfg`` and passed positionally to
    ``isaacgymenvs.make``; any ``kwargs`` are forwarded to it unchanged.

    Returns:
        The object returned by ``isaacgymenvs.make``.
    """
    make_args = (
        cfg.seed,
        cfg.task_name,
        cfg.task.env.numEnvs,  # NOTE: env count comes from the task config, not cfg.num_envs
        cfg.sim_device,
        cfg.rl_device,
        cfg.graphics_device_id,
        cfg.headless,
        cfg.multi_gpu,
        cfg.capture_video,
        cfg.force_render,
        cfg,
    )
    return isaacgymenvs.make(*make_args, **kwargs)

In [None]:
# Simulation device taken from the config; rollout buffers and QP tensors
# created in later cells are placed on it.
device = cfg.sim_device
# Build the vectorized environment from the global cfg.
envs = create_env_thunk()

In [None]:
# Recorded lifting demonstration: an object array whose entries are nested
# dicts (indexed as ['policy']['controller']['ft_pos_des'] in the rollout).
# Object arrays require allow_pickle=True, which unpickles the file -- only
# load trusted data this way.
# The context manager closes the .npz file handle once "data" has been read.
with np.load("data/lifting.npz", allow_pickle=True) as lifting_npz:
    lifting_data = lifting_npz["data"]

In [None]:
# Reset the environment, pre-allocate per-step logging buffers for the
# N-step rollout, and take one zero-action step before the rollout starts.
envs.reset_idx(torch.arange(cfg.num_envs))
N = 570  # number of rollout steps
action_buffer = torch.zeros(N, 18, device=device)            # commanded actions
ftip_pos_buffer = torch.zeros(N, 3, 3, device=device)        # fingertip positions (sim frame)
ftip_pos_local_buffer = torch.zeros(N, 3, 3, device=device)  # fingertip positions (object frame)
object_pose_buffer = torch.zeros(N, 7, device=device)        # object pose; [3:] is the quaternion
contact_normals_buffer = []

q_buffer = torch.zeros(N, 9, device=device)   # joint positions
dq_buffer = torch.zeros(N, 9, device=device)  # joint velocities
obs, rwds, resets, info = envs.step(torch.zeros(cfg.num_envs, 18))

In [None]:
def get_cube_contact_normals(ftip_pos, threshold=0.0435):
    """Assign an axis-aligned contact normal to each point near a cube face.

    Each point is attributed to the face along the axis of its
    largest-magnitude coordinate. If that coordinate's magnitude is within
    ``threshold``, the point counts as touching that face.

    Args:
        ftip_pos: (batch, 3) points expressed in the cube's frame (origin
            presumably at the cube centre -- confirm with caller).
        threshold: maximum |coordinate| along the dominant axis for a point
            to count as in contact.

    Returns:
        (batch, 3) float tensor on ``ftip_pos.device``. Rows in contact hold
        a unit vector along the dominant axis pointing INTO the cube (the
        opposite sign of that coordinate); rows not in contact are zeros.
    """
    batch_size = ftip_pos.shape[0]
    contact_normals = torch.zeros(batch_size, 3, device=ftip_pos.device)

    _, max_indices = torch.max(torch.abs(ftip_pos), dim=1)
    # Squeeze only the gathered dim: a bare squeeze() turns a batch of one
    # into a 0-d tensor and breaks the boolean-mask indexing below.
    max_values = torch.gather(ftip_pos, 1, max_indices.unsqueeze(1)).squeeze(1)

    near_face = torch.abs(max_values) <= threshold
    mask_pos = near_face & (max_values > 0)
    mask_neg = near_face & (max_values < 0)

    # contact normal points to the same direction as the contact force, hence into the object
    contact_normals[mask_pos, max_indices[mask_pos]] = -1.0
    contact_normals[mask_neg, max_indices[mask_neg]] = 1.0

    return contact_normals

def get_contact_frame_orn(contact_normals: torch.Tensor):
    """Build contact-frame orientation matrices whose z-axis is the contact normal.

    The orientation is expressed in the same frame as ``contact_normals``
    (the object frame, per the original comment). A seed y-axis is the unit
    vector along the first zero component of the normal; x and y are then
    completed with two cross products. For axis-aligned unit normals this
    yields a right-handed rotation matrix.

    Args:
        contact_normals: (batch, 3) contact normals.

    Returns:
        (batch, 3, 3) tensor whose columns are the frame's x, y and z axes.
        An all-zero normal produces an all-zero matrix.
    """
    normal = contact_normals
    first_zero = (normal == 0).int().argmax(dim=1)
    seed_y = torch.eye(3).to(normal.device)[first_zero]
    # x is orthogonal to both seed_y and the normal; y then closes the frame.
    # A zero normal propagates zeros into both x and y.
    x_axis = torch.cross(seed_y, normal, dim=1)
    y_axis = torch.cross(normal, x_axis, dim=1)
    return torch.stack([x_axis, y_axis, normal], dim=2)

def hstacked_SO3_transform(mat, vec):
    """Apply the 3-column groups of ``mat`` to the matching 3-slices of ``vec``.

    ``mat`` must be 3-D; its last axis is split into groups of three columns.
    Output slice ``[3i:3i+3]`` is ``bmv(mat, v_i)``, where ``v_i`` is ``vec``
    with every entry outside slice i zeroed. Assuming ``bmv`` is a batched
    matrix-vector product, only the i-th column group of ``mat`` contributes
    to slice i -- confirm against vecrobotics.

    Returns:
        Tensor shaped like ``vec`` (``zeros_like(vec)`` filled slice by slice).
    """
    _, _, width = mat.shape
    out = torch.zeros_like(vec)
    for blk in range(width // 3):
        cols = slice(3 * blk, 3 * blk + 3)
        keep = torch.zeros_like(vec)
        keep[:, cols] = 1
        out[:, cols] = bmv(mat, vec * keep)
    return out

def get_force_qp_data(ftip_pos, object_pose, mg, weights=(1, 200, 1e-4)):
    """Assemble the quadratic/linear terms of the fingertip-force QP.

    Each fingertip force f_i is expressed in its contact frame; R_i maps it to
    the object frame and p_i is the fingertip position in the object frame.
    The returned terms satisfy (for w1 == 1, up to a constant):

        f^T Q f + q^T f = w1 * ||sum_i R_i f_i - mg_local||^2
                        + w2 * ||sum_i p_i x (R_i f_i)||^2
                        + w3 * ||f||^2

    Note that ``q`` is not scaled by w1.

    Args:
        ftip_pos: (batch, num_ftip, 3) fingertip positions (presumably world
            frame; converted with world2local).
        object_pose: (batch, 7) object pose; ``[:, 3:]`` is the quaternion
            passed to quat2mat.
        mg: (batch, 3) vector rotated into the object frame to form the
            force-balance target.
        weights: (w1, w2, w3) weights on force balance, torque balance and
            the ||f||^2 regulariser.

    Returns:
        Q: (batch, 3*num_ftip, 3*num_ftip) quadratic term.
        q: linear term, ``-2 * R_vstacked @ mg_local``.
        R_vstacked: (batch, 3*num_ftip, 3), the R_i^T stacked vertically.
        pxR_vstacked: (batch, 3*num_ftip, 3), the (skew(p_i) R_i)^T stacked.
        contact_normals: (batch*num_ftip, 3) from get_cube_contact_normals.
    """
    batch_size, num_ftip, _ = ftip_pos.shape
    num_vars = 3 * num_ftip

    # Fingertip positions in the object frame, flattened to (batch*num_ftip, 3);
    # repeat_interleave pairs every fingertip with its own env's object pose.
    # reshape (not view) because ftip_pos is typically a non-contiguous slice.
    p = world2local(object_pose.repeat_interleave(num_ftip, dim=0),
                    ftip_pos.reshape(-1, 3))
    contact_normals = get_cube_contact_normals(p)
    R = get_contact_frame_orn(contact_normals)
    R_vstacked = R.transpose(1, 2).reshape(-1, num_vars, 3)
    Q1 = R_vstacked @ R_vstacked.transpose(1, 2)

    pxR = vec2skew_sym_mat(p) @ R
    pxR_vstacked = pxR.transpose(1, 2).reshape(-1, num_vars, 3)
    Q2 = pxR_vstacked @ pxR_vstacked.transpose(1, 2)

    w1, w2, w3 = weights
    eye = torch.eye(num_vars, dtype=Q1.dtype, device=Q1.device)
    Q = w1 * Q1 + w2 * Q2 + w3 * eye  # eye broadcasts over the batch

    # A batch whose Q has an all-zero diagonal (no fingertip in contact and
    # w3 == 0) would make the QP singular: fill its diagonal with ones. Its
    # linear term is then zero too, so the solution is f == 0. With
    # non-negative w1, w2 and w3 > 0 the diagonal is already >= w3, so this
    # only triggers when the regulariser is disabled.
    Q_diag = Q.diagonal(dim1=1, dim2=2)  # (batch, num_vars) view into Q
    zero_diag = (Q_diag == 0).all(dim=1)
    Q_diag[zero_diag] = 1.0

    object_orn = quat2mat(object_pose[:, 3:])
    # mg expressed in the object frame (presumably quat2mat maps object -> world)
    mg_local = bmv(object_orn.transpose(1, 2), mg)
    q = -2 * bmv(R_vstacked, mg_local)

    return Q, q, R_vstacked, pxR_vstacked, contact_normals

In [None]:
# Construct the force QP and replay the lifting demonstration. At each step
# the QP (see get_force_qp_data) yields fingertip forces, and the action is
# [desired - current fingertip positions (9) | desired fingertip forces (9)].
num_batches = cfg.num_envs
num_vars = 9
# Box bounds live on the same device as Q/q so the solver never mixes devices.
lb = -10 * torch.ones(num_batches, num_vars, device=device)
ub = 10 * torch.ones(num_batches, num_vars, device=device)
mg = torch.tensor([0, 0, 9.81]).repeat(num_batches, 1).to(device)

prob = ForceQP(num_batches, num_vars, friction_coeff=1.0, device=device)
solver = FISTA(prob, device=device)
max_it = 50

DEMO_STRIDE = 20  # rollout step n replays lifting_data[DEMO_STRIDE * n]
assert DEMO_STRIDE * (N - 1) < len(lifting_data), "lifting_data is too short for N rollout steps"

for n in range(N):
    # `dof_pos` (not `q`) so the joint state is not confused with the QP
    # linear term `q` returned by get_force_qp_data below.
    dof_pos = envs._dof_position
    dq = envs._dof_velocity
    q_buffer[n] = dof_pos[0]
    dq_buffer[n] = dq[0]

    ftip_state = envs._rigid_body_state[:, envs._fingertip_indices]
    ftip_pos = ftip_state[:, :, 0:3]
    ftip_pos_buffer[n] = ftip_pos[0]

    object_pose = envs._object_state_history[0][:, 0:7]
    object_orn = quat2mat(object_pose[:, 3:])
    object_pose_buffer[n] = object_pose[0]

    for i in range(3):
        ftip_pos_local_buffer[n, i] = world2local(object_pose, ftip_pos[:, i, :])

    # set up and solve the force QP from a fresh solver state
    Q, q, R_vstacked, pxR_vstacked, contact_normals = get_force_qp_data(ftip_pos, object_pose, mg)
    prob.set_data(Q, q, lb, ub)
    solver.reset()
    for it in range(max_it):
        solver.step()
    ftip_force_contact_frame = solver.prob.yk.clone()
    contact_normals_buffer.append(contact_normals)

    # convert force: contact frame -> object frame -> world frame
    R = R_vstacked.reshape(-1, 3, 3).transpose(1, 2)
    ftip_force_object_frame = stacked_bmv(R, ftip_force_contact_frame)
    # NOTE(review): object_orn.repeat(3, 1, 1) tiles the whole batch three
    # times; if stacked_bmv expects per-fingertip matrices in env-major order
    # this is only correct for num_envs == 1 -- confirm.
    ftip_force_des = stacked_bmv(object_orn.repeat(3, 1, 1), ftip_force_object_frame)
    ftip_pos_des = torch.tensor(
        lifting_data[DEMO_STRIDE * n]['policy']['controller']['ft_pos_des'],
        dtype=torch.float32,
        device=device,
    )

    action = torch.zeros(cfg.num_envs, 18)
    action[:, :9] = ftip_pos_des.view(cfg.num_envs, 9) - ftip_pos.reshape(cfg.num_envs, 9)
    # Command forces only when every fingertip has a contact normal. Each
    # in-contact row of contact_normals has exactly one non-zero entry, so
    # for a single env this matches the old "3 non-zeros" check.
    if contact_normals.any(dim=1).all():
        action[:, 9:] = ftip_force_des

    action_buffer[n] = action[0]
    obs, rwds, resets, info = envs.step(action)