In [None]:
# Render matplotlib figures inline. No visible cell in this notebook plots;
# NOTE(review): candidate for removal if nothing downstream draws figures.
%matplotlib inline
# Reload all imported modules before executing each cell, so edits to the
# project sources (e.g. contact_mcts) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib
import sys

# Make the project's `src/` directory importable. The project root is taken to
# be the parent of the current working directory, so this assumes the notebook
# is launched from a direct subdirectory of the project root.
project_path = pathlib.Path('.').absolute().parent
python_path = project_path / 'src'
# Use the public `sys` module (`os.sys` only works because `os` happens to
# import it internally), and guard the insert so re-running this cell does not
# stack duplicate entries on sys.path.
if str(python_path) not in sys.path:
    sys.path.insert(1, str(python_path))

In [None]:
from dotmap import DotMap
from tqdm.notebook import trange
import torch
    
import numpy as np
import pinocchio as pin
import pybullet

In [None]:
from contact_mcts.envs.fingers import FingerDoubleAndBox
from contact_mcts.objects import Cube
from contact_mcts.params import get_default_params, update_params
from contact_mcts.pvmcts import MCTSDataset, PolicyValueMCTS
from contact_mcts.trajectory import generate_random_poses
from robot_properties_nyu_finger.config import NYUFingerDoubleConfig0, NYUFingerDoubleConfig1

In [None]:
# The two NYU finger robot configurations and the box URDF shipped with the
# project's environment resources are handed to get_default_params.
robot_config = [NYUFingerDoubleConfig0(), NYUFingerDoubleConfig1()]
object_urdf = str(python_path.joinpath('contact_mcts', 'envs', 'resources', 'box.urdf'))
params = get_default_params(object_urdf, robot_config)

In [None]:
# Task size: each sampled task is a sequence of `n_desired_poses` object poses
# (the poses themselves are sampled later, in the data-generation loop).
n_desired_poses = 3
# params.num_contact_modes is set to a fixed number of contact modes for each
# of the (n_desired_poses - 1) pose-to-pose transitions.
contact_modes_per_transition = 6
params.num_contact_modes = contact_modes_per_transition * (n_desired_poses - 1)

# Sampling bounds passed to generate_random_poses, written as one
# (low, high) row per pose dimension and then split into low/high vectors.
# The 6-element layout is presumably [x, y, z, roll, pitch, yaw] -- TODO confirm
# against generate_random_poses. lb/ub presumably bound the poses themselves and
# diff_lb/diff_ub the change between consecutive poses.
lb, ub = np.array([
    [-0.1,   0.1],
    [-0.1,   0.1],
    [0.1,    0.1],
    [0.,     0.],
    [0.,     0.],
    [-np.pi, np.pi],
]).T.copy()

diff_lb, diff_ub = np.array([
    [-0.05,      0.05],
    [-0.05,      0.05],
    [0,          0],
    [0.,         0.],
    [0.,         0.],
    [-np.pi / 4, np.pi / 4],
]).T.copy()

In [None]:
# Sample `ntasks` random pose-sequence tasks, run MCTS on each, and pool the
# search data returned by mcts.get_data() from every task that has a solution.
states = []
values = []
action_probs = []
goals = []
eps = 1e-3  # not used by any visible cell; kept in case later cells read it
failed_tasks = []  # desired-pose sequences for which get_solution() returned None

ntasks = 300
max_budget = 200  # search budget passed to mcts.train

# TODO: seed the RNG(s) used by generate_random_poses / PolicyValueMCTS if this
# dataset needs to be reproducible.
for _ in trange(ntasks):
    desired_poses = generate_random_poses(n_desired_poses, lb, ub, diff_lb, diff_ub)
    params = update_params(params, desired_poses)
    # The box starts at the first desired pose; SE3ToXYZQUAT gives the XYZ
    # position followed by a quaternion.
    pose_init = pin.SE3ToXYZQUAT(params.desired_poses[0])
    box_pos = pose_init[:3]
    box_orn = pose_init[3:]
    env = FingerDoubleAndBox(params, box_pos, box_orn, server=pybullet.DIRECT)

    # Always close the environment (it owns a pybullet DIRECT server), even if
    # training raises or the cell is interrupted, so simulator instances do not
    # accumulate across tasks.
    try:
        mcts = PolicyValueMCTS(params, env)
        state = [[0, 0]]
        mcts.train(state, budget=max_budget, verbose=False)
        best_state, _ = mcts.get_solution()

        if best_state is None:
            print('failed')
            failed_tasks.append(desired_poses)
        else:
            # Query the search data once rather than once per field.
            task_data = mcts.get_data()
            states += task_data[0]
            values += task_data[1]
            action_probs += task_data[2]
            goals += task_data[3]
    finally:
        env.close()

In [None]:
# Bundle the pooled search data and save it under <project>/data/, anchored to
# the same project root used for src/ rather than to the current directory.
# torch.save pickles the whole MCTSDataset object, so loading it back requires
# torch.load(..., weights_only=False) on PyTorch versions whose default is
# weights_only=True -- only load such files from trusted sources.
data = MCTSDataset(states, values, action_probs, goals)
data_path = project_path / 'data' / 'data1.pt'
data_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(data, data_path)