In [27]:
import os, sys
sys.path.append('/root/catkin_ws/src/primitives/')
import pickle
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import axes3d
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import trimesh
import networkx

from open3d import JVisualizer

import copy
import time
import argparse
import numpy as np
from multiprocessing import Process, Pipe, Queue
import pickle
import rospy
import copy
import signal
import open3d
from IPython import embed

from yacs.config import CfgNode as CN
from closed_loop_experiments import get_cfg_defaults

from airobot import Robot
from airobot.utils import pb_util
from airobot.sensor.camera.rgbdcam_pybullet import RGBDCameraPybullet
from airobot.utils import common
import pybullet as p

from helper import util
from macro_actions import ClosedLoopMacroActions, YumiGelslimPybulet
# from closed_loop_eval import SingleArmPrimitives, DualArmPrimitives

In [2]:
sys.path.append('/root/training/')

import os
import argparse
import time
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from data_loader import DataLoader
from model import VAE
from util import to_var, save_state, load_net_state, load_seed, load_opt_state

In [None]:
with open ('/root/catkin_ws/src/primitives/data/grasp/face_ind_test_0/2.pkl', 'rb') as f:
    grasp_data = pickle.load(f)
    
print(grasp_data.keys())

In [3]:
with open ('/root/catkin_ws/src/primitives/data/pull/face_ind_large_0_fixed/1003.pkl', 'rb') as f:
    pull_data = pickle.load(f)
    
print(pull_data.keys())

['goal', 'contact_obj_frame', 'contact_world_frame', 'transformation', 'start', 'result', 'keypoints_start', 'contact_pcd', 'keypoints_goal', 'obs']


# Training data: 
## Inputs: 
Initial observation, different representations:
- ```'start'```: pose ```[x_pos, y_pos, z_pos, x_ori, y_ori, z_ori, w_ori]```
- ```'keypoints_start'```: 3D location of box corners at start pose
- ```'obs':'pcd_pts'```: Point cloud of box from 3 different viewpoints, with same global coordinate sys. Can be fused with np.concatenate (see below)

Goal, different representations:
- ```'goal'```: pose
- ```'keypoints_goal'```: 3D location of box corners at goal pose

In [4]:
# start observation
start = pull_data['start']
keypoints_start = pull_data['keypoints_start']
pcd_pts = pull_data['obs']['pcd_pts']
pcd_pts_start = np.concatenate(pcd_pts, axis=0)

# goal
goal = pull_data['goal']
keypoints_goal = pull_data['keypoints_goal']

# Outputs:
## Pulling/Pushing (single arm)
Robot palm pose in the object frame, for active arm -- active arm currently based on which side of the table the object starts on. TODO perhaps includes predicting which arm is active, once we move to more diverse data. For now, everything for pulling happens with the right arm
- ```'contact_obj_frame'```: pose ```[x_pos, y_pos, z_pos, x_ori, y_ori, z_ori, w_ori]```, specified with respect to the coordinate system located at the object center of mass at the start pose

## Grasping/Pivoting (dual arm)

Right and left robot palm pose in the object frame
- ```'contact_obj_frame':'right'```: right palm pose ```[x_pos, y_pos, z_pos, x_ori, y_ori, z_ori, w_ori]```, specified with respect to the coordinate system located at the object center of mass at the start pose
- ```'contact_obj_frame':'left'```: left palm pose ```[x_pos, y_pos, z_pos, x_ori, y_ori, z_ori, w_ori]```, specified with respect to the coordinate system located at the object center of mass at the start pose

In [5]:
# right palm contact for pulling
contact_r = pull_data['contact_obj_frame']

# both palms contact for grasping
contact_r = grasp_data['contact_obj_frame']['right']
contact_l = grasp_data['contact_obj_frame']['left']

NameError: name 'grasp_data' is not defined

# Training

In [6]:
## helpers

# load minibatch
# def load_minibatch(M, datatype='pose')


## setup

In [7]:
# load model architecture (encoder and decoder)
# setup optimizer

# setup loss function

## train loop

# for epoch in num_epochs:
# for minibatch in minibatch_size / data_size:
# forward pass, compute loss, backprop, optimizer.step
# look at loss

# Testing

In [None]:
## setup

# distribution to sample from for start/goal (same as in training)

# load trained model
vae = VAE(14, 7, 3, 0.0003)
load_net_state(vae, '/root/training/saved_models/pose_init_small_batch_0/pose_init_small_batch_0_epoch_10.pt')

## eval

In [11]:
# sample from the latent space and use the decoder/generator to produce contacts, visualize below
x = torch.from_numpy(np.asarray(start+goal, dtype=np.float32))
output = vae.forward(x)

In [13]:
output[1]

tensor([ 0.0358, -0.0496,  0.0493,  2.2993,  2.8404,  2.3054,  2.3719],
       grad_fn=<AddBackward0>)

In [17]:
pos = output[1][:3].data.cpu().numpy()
ori = output[1][3:].data.cpu().numpy()

In [18]:
ori = ori/np.linalg.norm(ori)

In [19]:
print(pos)
print(ori)

[ 0.03581019 -0.04964388  0.04928012]
[0.46648335 0.57625514 0.46771353 0.4812144 ]


In [20]:
pull_data_eval = copy.deepcopy(pull_data)
pull_data_eval['contact_obj_frame'] = pos.tolist() + ori.tolist()

# Visualization

In [24]:
with open('/root/catkin_ws/src/primitives/data/pull/face_ind_large_0/metadata.pkl', 'rb') as mf:
    metadata = pickle.load(mf)
# with open('/root/catkin_ws/src/primitives/data/grasp/face_ind_test_0/metadata.pkl', 'rb') as mf:
#     metadata = pickle.load(mf)


print('Metadata keys: ')
dynamics_info = metadata['dynamics']
mesh_file = metadata['mesh_file']
palm_mesh_file = '/root/catkin_ws/src/config/descriptions/meshes/mpalm/mpalms_all_coarse.stl'
table_mesh_file = '/root/catkin_ws/src/config/descriptions/meshes/table/table_top.stl'
cfg = metadata['cfg']

Metadata keys: 


# Visualize contact on object

In [28]:
def vis_palms(data, name='pull'):
    obj_mesh = trimesh.load_mesh(mesh_file)
    r_palm_mesh = trimesh.load_mesh(palm_mesh_file)
    l_palm_mesh = trimesh.load_mesh(palm_mesh_file)
    table_mesh = trimesh.load_mesh(table_mesh_file)
    
    obj_pos_world = data['start'][:3]
    obj_ori_world = data['start'][3:]
    obj_ori_mat = common.quat2rot(obj_ori_world)
    h_trans = np.zeros((4, 4))
    h_trans[:3, :3] = obj_ori_mat
    h_trans[:-1, -1] = obj_pos_world
    h_trans[-1, -1] = 1

    obj_mesh.apply_transform(h_trans)
    if name == 'pull':
        tip_contact_r_obj = util.list2pose_stamped(data['contact_obj_frame'])
        tip_contact_r = util.convert_reference_frame(
            pose_source=tip_contact_r_obj,
            pose_frame_target=util.unit_pose(),
            pose_frame_source=util.list2pose_stamped(data['start']))

        wrist_contact_r = util.convert_reference_frame(
            pose_source=util.list2pose_stamped(cfg.TIP_TO_WRIST_TF),
            pose_frame_target=util.unit_pose(),
            pose_frame_source=tip_contact_r)

        wrist_contact_r_list = util.pose_stamped2list(wrist_contact_r)
        
        palm_pos_world_r = wrist_contact_r_list[:3]
        palm_ori_world_r = wrist_contact_r_list[3:]
        palm_ori_mat = common.quat2rot(palm_ori_world_r)
        h_trans = np.zeros((4, 4))
        h_trans[:3, :3] = palm_ori_mat
        h_trans[:-1, -1] = palm_pos_world_r
        h_trans[-1, -1] = 1

        r_palm_mesh.apply_transform(h_trans)      
        
        scene = trimesh.Scene([obj_mesh, r_palm_mesh, table_mesh])        
    else:
        tip_contact_r_obj = util.list2pose_stamped(data['contact_obj_frame']['right'])
        tip_contact_l_obj = util.list2pose_stamped(data['contact_obj_frame']['left'])

        tip_contact_r = util.convert_reference_frame(
            pose_source=tip_contact_r_obj,
            pose_frame_target=util.unit_pose(),
            pose_frame_source=util.list2pose_stamped(data['start']))
            
        tip_contact_l = util.convert_reference_frame(
            pose_source=tip_contact_l_obj,
            pose_frame_target=util.unit_pose(),
            pose_frame_source=util.list2pose_stamped(data['start']))            
            
        wrist_contact_r = util.convert_reference_frame(
            pose_source=util.list2pose_stamped(cfg.TIP_TO_WRIST_TF),
            pose_frame_target=util.unit_pose(),
            pose_frame_source=tip_contact_r)

        wrist_contact_l = util.convert_reference_frame(
            pose_source=util.list2pose_stamped(cfg.TIP_TO_WRIST_TF),
            pose_frame_target=util.unit_pose(),
            pose_frame_source=tip_contact_l)

        wrist_contact_r_list = util.pose_stamped2list(wrist_contact_r)
        wrist_contact_l_list = util.pose_stamped2list(wrist_contact_l)
        
        palm_pos_world_r = wrist_contact_r_list[:3]
        palm_ori_world_r = wrist_contact_r_list[3:]
        palm_ori_mat = common.quat2rot(palm_ori_world_r)
        h_trans = np.zeros((4, 4))
        h_trans[:3, :3] = palm_ori_mat
        h_trans[:-1, -1] = palm_pos_world_r
        h_trans[-1, -1] = 1

        r_palm_mesh.apply_transform(h_trans)
        
        palm_pos_world_l = wrist_contact_l_list[:3]
        palm_ori_world_l = wrist_contact_l_list[3:]
        palm_ori_mat = common.quat2rot(palm_ori_world_l)
        h_trans = np.zeros((4, 4))
        h_trans[:3, :3] = palm_ori_mat
        h_trans[:-1, -1] = palm_pos_world_l
        h_trans[-1, -1] = 1

        l_palm_mesh.apply_transform(h_trans)        
        
        scene = trimesh.Scene([obj_mesh, r_palm_mesh, l_palm_mesh, table_mesh])
    return scene

In [29]:
scene = vis_palms(pull_data_eval, name='pull')
# scene = vis_palms(grasp_data, name='grasp')

scene.show(viewer='gl')

SceneViewer(width=1800, height=1350)

In [31]:
print(pull_data['contact_obj_frame'])
print(pull_data_eval['contact_obj_frame'])

[0.04103590457255263, -0.03249568473897843, 0.026030000299215317, 0.37224725603208564, 0.6011921326635795, 0.6011921326635795, 0.3722472560320857]
[0.0358101949095726, -0.04964388161897659, 0.04928012192249298, 0.46648335456848145, 0.5762551426887512, 0.46771353483200073, 0.48121440410614014]
