In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically reload edited modules (e.g. the local `cto` package) before
# each cell is executed, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import os
import pathlib
import sys
import time
# Make the project root (the parent of the notebook's working directory)
# importable so the local `cto` package resolves. Uses `sys.path` directly
# (`os.sys` is an undocumented implementation detail) and is guarded so that
# re-running this cell does not push duplicate entries onto the search path.
project_path = pathlib.Path('.').absolute().parent
if str(project_path) not in sys.path:
    # Insert after index 0 so the entry Python placed first keeps priority.
    sys.path.insert(1, str(project_path))

In [None]:
import pybullet
import numpy as np
import pinocchio as pin
import trifinger_simulation
from dotmap import DotMap

In [None]:
from cto.mcts.pvmcts import PolicyValueMCTS
from cto.mcts.pvnet import PolicyValueNet, ValueClassifier
from cto.params import get_default_params, update_params
from cto.contact_modes import construct_contact_plan
from cto.envs.trifinger import TriFingerAndCube
from cto.policy import OpenLoopPolicy

In [None]:
# --- Object / planner parameters -------------------------------------------
# URDF of the v2 cube shipped with trifinger_simulation; the default planner
# parameters are derived from it.
object_urdf = str(trifinger_simulation.get_data_dir() / 'cube_v2' / 'cube_v2.urdf')
params = get_default_params(object_urdf)

# --- Simulation setup --------------------------------------------------------
# Robot model: "trifingernyu" is the platform without the arena;
# use "trifinger_meta" instead to include the arena.
finger_type = "trifingernyu"
visualization = True
# Maximum goal orientation difference (radians), used by the randomized
# env.reset variants that are commented out in the episode loop.
max_goal_orn_diff = 0.5 * np.pi

# --- Experiment loop ---------------------------------------------------------
num_episodes = 10
max_budget_mcts = 20
verbose = True
# Path prefix for saved episodes (e.g. "data/"); None disables saving.
log_path_suffix = None

In [None]:
# Fixed start and goal poses of the cube, each given as
# {"position": xyz, "orientation": quaternion}.
# NOTE(review): the quaternions look like (x, y, z, w) order (w last, matching
# pinocchio's XYZQUAT convention used with env.goal later) — confirm.
init_pose_dict = dict(
    position=np.array([0.17930479, 0.06879323, 0.0325]),
    orientation=np.array([0.0, 0.0, 0.24076294, 0.97058395]),
)
goal_pose_dict = dict(
    position=np.array([0.06534854, -0.01630616, 0.0325]),
    orientation=np.array([0.0, 0.0, -0.45309934, 0.89146003]),
)

In [None]:
# Build the simulated TriFinger + cube environment and the open-loop tracking
# policy once; both are reset and reused for every episode below.
env = TriFingerAndCube(params, visualization=visualization,
                       init_difficulty=-1, finger_type=finger_type)
policy = OpenLoopPolicy(env.action_space, env.finger, time_step=0.001)

# An episode is logged only if its final object position error
# ("achieved_goal_position_error") is at most this value.
success_pos_err_tol = 0.02

for episode in range(num_episodes):
    observation_list = []

    # Reset the environment, then record the object poses actually realized by
    # the reset so the plan can later be executed from exactly the same state.
    # Randomized alternatives:
    #   env.reset(init_pose_dict=init_pose_dict, max_goal_orn_diff=max_goal_orn_diff)
    #   env.reset(max_goal_orn_diff=max_goal_orn_diff)
    #   env.reset(goal_pose_dict=goal_pose_dict)
    env.reset(init_pose_dict=init_pose_dict,
              goal_pose_dict=goal_pose_dict)
    xyz, quat = env.get_cube_pose()
    # Keep these per-episode instead of overwriting the config-level
    # init_pose_dict / goal_pose_dict: overwriting makes later episodes (and
    # re-runs of this cell) silently start from drifted poses. np.array copies
    # its input, so later changes to env.goal cannot alias into these dicts.
    episode_init_pose = {"position": np.array(xyz),
                         "orientation": np.array(quat)}
    episode_goal_pose = {"position": np.array(env.goal[:3]),
                         "orientation": np.array(env.goal[3:])}

    # Set up the MCTS problem: move the cube from its current pose to the goal.
    pose_init = env.get_cube_pose_as_SE3()
    pose_goal = pin.XYZQUATToSE3(env.goal)
    desired_poses = [pose_init, pose_goal]
    params = update_params(params, desired_poses, repr="SE3")

    # Plan with an untrained MCTS from root state [[0, 0, 0]]
    # (presumably one zero contact mode per finger — TODO confirm).
    mcts = PolicyValueMCTS(params, env)
    mcts.run(state=[[0, 0, 0]], budget=max_budget_mcts, verbose=False)
    state, sol = mcts.get_solution()

    if state is None:
        print("no solution found in episode {}".format(episode))
        continue

    # A solution was found: reset to the recorded start/goal and execute the
    # plan open-loop in the gym environment.
    observation = env.reset(goal_pose_dict=episode_goal_pose,
                            init_pose_dict=episode_init_pose)
    policy.reset()
    policy.set_trajs(observation, state, sol, params)

    # NOTE(review): neither policy.done nor the env's `done` flag ends this
    # rollout early; it always runs for the full reference trajectory length.
    for _ in range(policy.x_des.shape[0]):
        action = policy.predict(observation)
        observation, reward, episode_done, info = env.step(action)
        observation_list.append({**observation, **policy.get_observation()})

    if not observation_list:
        # Guard: an empty reference trajectory yields no observations to score.
        print("empty reference trajectory in episode {}".format(episode))
        continue

    final_observation = observation_list[-1]
    final_pos_err = final_observation["achieved_goal_position_error"]
    final_orn_err = final_observation["achieved_goal_orientation_error"]
    if verbose:
        print("Final object position error: ", final_pos_err)
        print("Final object orientation error: ", final_orn_err)

    # Only log (near-)successful episodes. Despite its name, `log_path_suffix`
    # is used as a path *prefix*; a timestamp is appended to form the filename.
    if log_path_suffix is not None and final_pos_err <= success_pos_err_tol:
        now = datetime.datetime.now()
        log_path = log_path_suffix + now.strftime("%m%d_%H%M%S")
        log_dir = os.path.dirname(log_path)
        if log_dir:
            # Create the target directory so saving does not fail on a fresh checkout.
            os.makedirs(log_dir, exist_ok=True)
        # observation_list holds dicts, so it is stored as an object array;
        # reload with np.load(..., allow_pickle=True).
        np.savez_compressed(log_path, data=observation_list)
        print("Saved episode {} to {}".format(episode, log_path))
    