In [None]:
import os, sys
sys.path.append('/root/catkin_ws/src/primitives/')
import pickle
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import axes3d
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import trimesh
import networkx

from open3d import JVisualizer

import copy
import time
import argparse
import numpy as np
from multiprocessing import Process, Pipe, Queue
import pickle
import rospy
import copy
import signal
import open3d
from IPython import embed

from airobot import Robot
from airobot.utils import pb_util
from airobot.sensor.camera.rgbdcam_pybullet import RGBDCameraPybullet
from airobot.utils import common
import pybullet as p

from helper import util
from macro_actions import ClosedLoopMacroActions, YumiGelslimPybulet
# from closed_loop_eval import SingleArmPrimitives, DualArmPrimitives

from yacs.config import CfgNode as CN
from closed_loop_experiments import get_cfg_defaults

In [None]:
# Load one dual-arm grasp sample; the commented path is an alternative
# (non-"_fixed") sample that can be swapped in.
grasp_sample_path = '/root/catkin_ws/src/primitives/data/grasp/face_ind_test_0_fixed/2022.pkl'
# grasp_sample_path = '/root/catkin_ws/src/primitives/data/grasp/face_ind_test_0/1721.pkl'
with open(grasp_sample_path, 'rb') as grasp_file:
    grasp_data = pickle.load(grasp_file)

print(grasp_data.keys())

In [None]:
# Load one single-arm pull sample; the commented path is an alternative
# (non-"_fixed") sample that can be swapped in.
pull_sample_path = '/root/catkin_ws/src/primitives/data/pull/face_ind_large_0_fixed/3.pkl'
# pull_sample_path = '/root/catkin_ws/src/primitives/data/pull/face_ind_large_0/3.pkl'
with open(pull_sample_path, 'rb') as pull_file:
    pull_data = pickle.load(pull_file)

print(pull_data.keys())

# Training data: 
## Inputs: 
Initial observation, different representations:
- ```'start'```: pose ```[x_pos, y_pos, z_pos, x_ori, y_ori, z_ori, w_ori]```
- ```'keypoints_start'```: 3D location of box corners at start pose
- ```'obs':'pcd_pts'```: Point clouds of the box from 3 different viewpoints, all expressed in the same global coordinate system. They can be fused into a single cloud with `np.concatenate` (see below).

Goal, different representations:
- ```'goal'```: pose
- ```'keypoints_goal'```: 3D location of box corners at goal pose

In [None]:
# Start observation: pose, box-corner keypoints, and the per-viewpoint point
# clouds fused into a single (N, 3)-style array by stacking along axis 0.
start, keypoints_start = pull_data['start'], pull_data['keypoints_start']
pcd_pts = pull_data['obs']['pcd_pts']
pcd_pts_start = np.concatenate(pcd_pts, axis=0)

# Goal: pose and box-corner keypoints.
goal, keypoints_goal = pull_data['goal'], pull_data['keypoints_goal']

# Outputs:
## Pulling/Pushing (single arm)
Robot palm pose in the object frame for the active arm. The active arm is currently chosen based on which side of the table the object starts on. **TODO:** possibly also predict which arm is active once we move to more diverse data. For now, all pulling is done with the right arm.
- ```'contact_obj_frame'```: pose ```[x_pos, y_pos, z_pos, x_ori, y_ori, z_ori, w_ori]```, specified with respect to the coordinate system located at the object center of mass at the start pose

## Grasping/Pivoting (dual arm)

Right and left robot palm poses in the object frame
- ```'contact_obj_frame':'right'```: right palm pose ```[x_pos, y_pos, z_pos, x_ori, y_ori, z_ori, w_ori]```, specified with respect to the coordinate system located at the object center of mass at the start pose
- ```'contact_obj_frame':'left'```: left palm pose ```[x_pos, y_pos, z_pos, x_ori, y_ori, z_ori, w_ori]```, specified with respect to the coordinate system located at the object center of mass at the start pose

In [None]:
# Right palm contact for pulling (single arm): one pose list in the object frame.
pull_contact_r = pull_data['contact_obj_frame']

# Both palm contacts for grasping (dual arm): a dict keyed by 'right' / 'left'.
grasp_contact_r = grasp_data['contact_obj_frame']['right']
grasp_contact_l = grasp_data['contact_obj_frame']['left']

# Previously `contact_r` was assigned the pull contact and then silently
# overwritten by the grasp contact. Distinct names keep both values; the generic
# names stay bound to the grasp contacts they ended up holding before.
contact_r, contact_l = grasp_contact_r, grasp_contact_l

# Visualization

In [None]:
# Pull-dataset alternative:
# with open('/root/catkin_ws/src/primitives/data/pull/face_ind_large_0/metadata.pkl', 'rb') as mf:
#     metadata = pickle.load(mf)
# NOTE(review): `grasp_data` is loaded from `face_ind_test_0_fixed`, but this
# metadata comes from `face_ind_test_0`; confirm both directories share it.
with open('/root/catkin_ws/src/primitives/data/grasp/face_ind_test_0/metadata.pkl', 'rb') as mf:
    metadata = pickle.load(mf)

# Previously only the label was printed; show the keys it announces.
print('Metadata keys:', list(metadata.keys()))

# Values consumed by the visualization and simulation cells below.
dynamics_info = metadata['dynamics']  # passed to p.changeDynamics for the gel link
mesh_file = metadata['mesh_file']  # object mesh used by vis_palms and the planner
palm_mesh_file = '/root/catkin_ws/src/config/descriptions/meshes/mpalm/mpalms_all_coarse.stl'
table_mesh_file = '/root/catkin_ws/src/config/descriptions/meshes/table/table_top.stl'
cfg = metadata['cfg']  # experiment config (TIP_TO_WRIST_TF, *_INIT, OBJECT_POSE_3, ...)

# Visualize contact on object

In [None]:
def _pose_list_to_matrix(pose_list):
    """Return the 4x4 homogeneous transform for a pose list.

    Args:
        pose_list: ``[x_pos, y_pos, z_pos, x_ori, y_ori, z_ori, w_ori]``, i.e. a
            position followed by a quaternion (the dataset's ``'start'`` format).

    Returns:
        np.ndarray of shape (4, 4) with the rotation from ``common.quat2rot`` in
        the upper-left 3x3 block, the position in the last column, and
        ``[0, 0, 0, 1]`` as the bottom row.
    """
    transform = np.eye(4)
    transform[:3, :3] = common.quat2rot(pose_list[3:])
    transform[:3, 3] = pose_list[:3]
    return transform


def _tip_pose_world(start_pose_list, contact_obj_list, contact_world_list, use_world_frame):
    """Return the palm-tip pose in the world frame as a pose-stamped object.

    With ``use_world_frame`` the stored world-frame contact is used directly.
    Otherwise the object-frame contact is mapped into the world frame through
    the object's start pose.
    """
    if use_world_frame:
        return util.list2pose_stamped(contact_world_list)
    return util.convert_reference_frame(
        pose_source=util.list2pose_stamped(contact_obj_list),
        pose_frame_target=util.unit_pose(),
        pose_frame_source=util.list2pose_stamped(start_pose_list))


def _wrist_matrix_from_tip(tip_pose):
    """Return the 4x4 world transform of the wrist for a world-frame tip pose.

    Applies the tip-to-wrist offset ``cfg.TIP_TO_WRIST_TF``. ``cfg`` is the
    module-level config loaded from the dataset metadata.
    """
    wrist_pose = util.convert_reference_frame(
        pose_source=util.list2pose_stamped(cfg.TIP_TO_WRIST_TF),
        pose_frame_target=util.unit_pose(),
        pose_frame_source=tip_pose)
    return _pose_list_to_matrix(util.pose_stamped2list(wrist_pose))


def vis_palms(data, name='pull', use_world_frame=True):
    """Build a trimesh scene: object at its start pose, palm(s) at contact, and the table.

    Relies on the module-level ``mesh_file``, ``palm_mesh_file``,
    ``table_mesh_file`` and ``cfg`` defined in the metadata cell.

    Args:
        data: one dataset sample (dict) with ``'start'``, ``'contact_obj_frame'``
            and ``'contact_world_frame'``. For ``name == 'pull'`` the contact
            entries are single pose lists (right palm only). For any other name
            they are dicts keyed by ``'right'`` and ``'left'``.
        name: ``'pull'`` draws only the right palm; anything else draws both.
        use_world_frame: if True (the default, matching the previous behaviour),
            place the palms from ``'contact_world_frame'``. If False, place them
            from ``'contact_obj_frame'`` mapped through the start pose. This lets
            the object-frame labels used as training outputs be checked visually.
            Previously the object-frame result was computed and then discarded.

    Returns:
        trimesh.Scene containing, in order: the object, the right palm, the left
        palm (dual-arm only), and the table.
    """
    obj_mesh = trimesh.load_mesh(mesh_file)
    obj_mesh.apply_transform(_pose_list_to_matrix(data['start']))
    table_mesh = trimesh.load_mesh(table_mesh_file)

    contact_obj = data['contact_obj_frame']
    contact_world = data['contact_world_frame']
    if name == 'pull':
        palm_contacts = [(contact_obj, contact_world)]
    else:
        # Right first, then left, to keep the original scene ordering.
        palm_contacts = [(contact_obj[side], contact_world[side]) for side in ('right', 'left')]

    palm_meshes = []
    for contact_obj_list, contact_world_list in palm_contacts:
        tip_pose = _tip_pose_world(data['start'], contact_obj_list, contact_world_list, use_world_frame)
        # The palm mesh is placed at the wrist pose derived from the tip contact.
        palm_mesh = trimesh.load_mesh(palm_mesh_file)
        palm_mesh.apply_transform(_wrist_matrix_from_tip(tip_pose))
        palm_meshes.append(palm_mesh)

    return trimesh.Scene([obj_mesh] + palm_meshes + [table_mesh])

In [None]:
# Render the dual-arm grasp sample; use the commented call for the pull sample instead.
# scene = vis_palms(pull_data, name='pull')
scene = vis_palms(data=grasp_data, name='grasp')
scene.show(viewer='gl')

# Visualize point cloud

In [None]:
# Point-cloud observations of the pull sample: the full cloud object and the
# per-viewpoint point arrays, the latter stacked into one array along axis 0.
pull_obs = pull_data['obs']
pcd = pull_obs['pcd_full']
pcd_pts = pull_obs['pcd_pts']
obj_pcd_pts_cat = np.concatenate(pcd_pts, axis=0)

In [None]:
np.shape(obj_pcd_pts_cat)  # (num_points, num_coords) of the fused cloud

In [None]:
# Scatter3d defaults to mode='lines+markers', which would join consecutive
# points with line segments; a point cloud should be drawn as markers only.
fig = go.Figure(data=go.Scatter3d(
    x=obj_pcd_pts_cat[:, 0],
    y=obj_pcd_pts_cat[:, 1],
    z=obj_pcd_pts_cat[:, 2],
    mode='markers',
    marker=dict(size=2)))  # small markers so dense clouds stay readable

# For 3D traces the axis ranges live on the scene. update_xaxes/update_yaxes only
# touch 2D cartesian axes, and calling them after show() had no visible effect.
fig.update_scenes(xaxis_range=[-0.25, 0.7], yaxis_range=[-0.75, 0.75])

fig.show()

In [None]:
# visualizer = JVisualizer()
# visualizer.add_geometry([pcd])
# visualizer.show()

# open3d.visualization.draw_geometries([pcd])

# Visualize image observations

In [None]:
im_fig = plt.figure()
rgb_views = pull_data['obs']['rgb']

# One panel per camera viewpoint, stacked vertically in a 3x1 grid.
ax1, ax2, ax3 = [im_fig.add_subplot(3, 1, row) for row in (1, 2, 3)]
for view_idx, view_ax in enumerate((ax1, ax2, ax3)):
    view_ax.imshow(rgb_views[view_idx])

plt.show()

# Visualize keypoints

In [None]:
# Box-corner keypoints at the start pose, drawn as a 3D scatter.
kp_start = pull_data['keypoints_start']
kp_fig_0 = go.Figure(
    data=go.Scatter3d(x=kp_start[:, 0], y=kp_start[:, 1], z=kp_start[:, 2], mode='markers'))

kp_fig_0.show()

In [None]:
kp_fig = make_subplots(rows=1, cols=2, specs=[[{"type": "scene"}, {"type": "scene"}]])

# Left panel: keypoints at the start pose; right panel: keypoints at the goal pose.
for panel_col, kp_key in enumerate(('keypoints_start', 'keypoints_goal'), start=1):
    kp_xyz = pull_data[kp_key]
    kp_fig.add_trace(
        go.Scatter3d(x=kp_xyz[:, 0], y=kp_xyz[:, 1], z=kp_xyz[:, 2], mode='markers'),
        row=1, col=panel_col)

kp_fig.show()

In [None]:
# TODO: try visualizing the meshes together with the table to show the poses

# Visualize trajectory in PyBullet

In [None]:
# Sample to execute in simulation. A deep copy, so changes made through `data`
# cannot alter `grasp_data`. The planner-args cell indexes
# data['contact_obj_frame']['right'/'left'], so this must be a grasp sample.
data = copy.deepcopy(grasp_data)

In [None]:
# # with open ('/root/catkin_ws/src/primitives/data/grasp/face_ind_test_0/2.pkl', 'rb') as f:
# #     data = pickle.load(f)

# with open ('/root/catkin_ws/src/primitives/data/pull/face_ind_large_0_fixed/105.pkl', 'rb') as f:
#     data = pickle.load(f)
    
# print(data.keys())

In [None]:
# Register this process as a ROS node named "test"; presumably needed by the
# ROS-backed robot/planner objects created below — TODO confirm.
rospy.init_node("test")

In [None]:
# airobot YuMi ('yumi_palms' model) with the arm options below.
yumi_arm_cfg = {'render': True, 'self_collision': False, 'rt_simulation': True}
yumi = Robot('yumi_palms', arm_cfg=yumi_arm_cfg)

In [None]:
# Move the arms to airobot's home configuration.
yumi.arm.go_home()

In [None]:
# Initial joint positions from the config: right-arm values combined with the
# left-arm values (presumably list concatenation — TODO confirm types).
init_jpos = cfg.RIGHT_INIT + cfg.LEFT_INIT
yumi.arm.set_jpos(init_jpos)

In [None]:
# Link index on the YuMi body whose contact dynamics are overridden; presumably
# the GelSlim palm link — TODO confirm against the URDF.
gel_id = 12

# Copy the recorded contact parameters from the dataset metadata onto that link.
gel_dynamics = {
    key: dynamics_info[key]
    for key in ('restitution', 'contactStiffness', 'contactDamping', 'rollingFriction')
}
p.changeDynamics(yumi.arm.robot_id, gel_id, **gel_dynamics)

In [None]:
# Wrap the simulated robot with the GelSlim-palm helper, configured from cfg.
yumi_gs = YumiGelslimPybulet(yumi, cfg)

In [None]:
# Spawn the box at cfg.OBJECT_POSE_3 (position = first 3 entries, orientation = rest).
box_urdf = '/root/catkin_ws/src/config/descriptions/urdf/realsense_box.urdf'
box_init_pos, box_init_ori = cfg.OBJECT_POSE_3[0:3], cfg.OBJECT_POSE_3[3:]
box_id = pb_util.load_urdf(box_urdf, box_init_pos, box_init_ori)

In [None]:
# Planner arguments for a dual-arm grasp. For a pull sample, use instead:
#   'palm_pose_r_object': util.list2pose_stamped(data['contact_obj_frame'])
#   'palm_pose_l_object': util.list2pose_stamped(cfg.PALM_LEFT)
new_args = {
    'object_pose1_world': util.list2pose_stamped(data['start']),
    'object_pose2_world': util.list2pose_stamped(data['goal']),
    'primitive_name': 'grasp',
    'palm_pose_r_object': util.list2pose_stamped(data['contact_obj_frame']['right']),
    'palm_pose_l_object': util.list2pose_stamped(data['contact_obj_frame']['left']),
    'object': None,
    'init': True,
    'N': 60,
    'table_face': 0,
}

In [None]:
# Alias, not a copy: later changes to planner_args also change new_args.
planner_args = new_args

In [None]:
config_pkg_path = '/root/catkin_ws/src/config/'

# Closed-loop executor for the box in the PyBullet client, without replanning.
action_planner = ClosedLoopMacroActions(
    cfg, yumi_gs, box_id, pb_util.PB_CLIENT, config_pkg_path,
    object_mesh_file=mesh_file,
    replan=False)

# Alternative: planner_args = data['planner_args']
primitive_name = planner_args['primitive_name']
object_start_pose_list = data['start']
object_goal_pose_list = data['goal']

# Put the box at the sample's start pose (position, then quaternion) before executing.
start_pos, start_quat = object_start_pose_list[:3], object_start_pose_list[3:]
pb_util.reset_body(body_id=box_id, base_pos=start_pos, base_quat=start_quat)

In [None]:
# Run the selected primitive in the PyBullet simulation with the arguments built above.
result = action_planner.execute(primitive_name=primitive_name, execute_args=planner_args)

In [None]:
# Show whatever execute() returned (its structure is defined by ClosedLoopMacroActions).
print(result)

In [None]:
# Remove the box from the PyBullet world (e.g. before loading it again).
pb_util.remove_body(box_id)

In [None]:
# Reset the arms to the initial configuration cfg.RIGHT_INIT + cfg.LEFT_INIT.
yumi_gs.update_joints(cfg.RIGHT_INIT + cfg.LEFT_INIT)