In [None]:
from copy import deepcopy

import trimesh
import numpy as np
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout

import robosuite.mcts.transforms as tf

from matplotlib import pyplot as plt

%matplotlib inline

# Load Mesh Models

In [None]:
def get_meshes(_mesh_types=None, _mesh_files=None, _mesh_units=None, _area_ths=0.003, _tbl_area_ths=0.003, _rotation_types=4,
               _gripper_mesh_dir='/home/kj/robosuite/robosuite/models/assets/grippers/meshes/rethink_gripper'):
    """Load the object/table/goal meshes and assemble the gripper mesh.

    Each object mesh is scaled by its unit, translated so its centre of mass is at the
    origin, and subdivided until no face area exceeds ``_area_ths`` (table meshes are
    additionally refined down to ``_tbl_area_ths``). Contact candidates are face centroids:
    the top faces for a table, the bottom (downward-facing) faces for a goal, and every
    face for any other mesh.

    Args:
        _mesh_types: mesh type names; defaults to the built-in list of nine types.
        _mesh_files: STL paths, parallel to ``_mesh_types``.
        _mesh_units: scale factors, parallel to ``_mesh_types``.
        _area_ths: maximum face area allowed after subdivision.
        _tbl_area_ths: maximum face area for meshes whose type contains ``'table'``.
        _rotation_types: passed through unchanged in the return value.
        _gripper_mesh_dir: directory holding the gripper STL files. The default keeps the
            original hard-coded location; override it on other machines.

    Returns:
        ``(_mesh_types, _mesh_files, _mesh_units, _meshes, _rotation_types,
        _contact_faces, _contact_points)``. ``_mesh_types`` and ``_meshes`` have one extra
        trailing entry for ``'left_gripper'``; the contact lists do not.
    """
    if _mesh_types is None:
        _mesh_types = ['arch_box',
                       'rect_box',
                       'square_box',
                       'half_cylinder_box',
                       'triangle_box',
                       'twin_tower_goal',
                       'tower_goal',
                       'box_goal',
                       'custom_table']
    if _mesh_files is None:
        _mesh_files = ['./robosuite/models/assets/objects/meshes/arch_box.stl',
                       './robosuite/models/assets/objects/meshes/rect_box.stl',
                       './robosuite/models/assets/objects/meshes/square_box.stl',
                       './robosuite/models/assets/objects/meshes/half_cylinder_box.stl',
                       './robosuite/models/assets/objects/meshes/triangle_box.stl',
                       './robosuite/models/assets/objects/meshes/twin_tower_goal.stl',
                       './robosuite/models/assets/objects/meshes/tower_goal.stl',
                       './robosuite/models/assets/objects/meshes/box_goal.stl',
                       './robosuite/models/assets/objects/meshes/custom_table.stl']
    if _mesh_units is None:
        _mesh_units = [0.001, 0.001, 0.001, 0.001, 0.001, 0.0011, 0.001, 0.0011, 0.01]

    # Work on a copy: 'left_gripper' is appended below and must not leak into a
    # list owned by the caller (repeated calls would otherwise keep growing it).
    _mesh_types = list(_mesh_types)

    def _subdivide_until(_mesh, _max_area):
        # Repeatedly split the faces that are still larger than _max_area.
        while True:
            indices, = np.where(_mesh.area_faces > _max_area)
            if len(indices) == 0:
                return _mesh
            _mesh = _mesh.subdivide(indices)

    _meshes = []
    _contact_points = []
    _contact_faces = []

    for mesh_type, mesh_file, unit in zip(_mesh_types, _mesh_files, _mesh_units):
        mesh = trimesh.load(mesh_file)

        mesh.apply_scale(unit)
        mesh.apply_translation(-mesh.center_mass)
        mesh = _subdivide_until(mesh, _area_ths)

        if 'table' in mesh_type:
            mesh = _subdivide_until(mesh, _tbl_area_ths)

            points = np.mean(mesh.vertices[mesh.faces], axis=1)
            # find top surfaces of the mesh
            faces, = np.where(points[:, 2] == np.max(points, axis=0)[2])
            points = points[faces]
        elif 'goal' in mesh_type:
            points = np.mean(mesh.vertices[mesh.faces], axis=1)
            # find bottom surfaces of the mesh
            faces, = np.where(
                np.logical_and(mesh.face_normals[:, 2] < - 0.99, points[:, 2] == np.min(points, axis=0)[2]))
            points = points[faces]
        else:
            faces = np.arange(mesh.faces.shape[0])
            points = np.mean(mesh.vertices[mesh.faces], axis=1)

        _meshes.append(mesh)
        _contact_faces.append(faces)
        _contact_points.append(points)

    # The gripper is assembled from six parts (two fingers, two tips, plate, base).
    gripper_part_names = ['standard_narrow.stl',
                          'half_round_tip.stl',
                          'standard_narrow.stl',
                          'half_round_tip.stl',
                          'connector_plate.stl',
                          'electric_gripper_base.stl']
    gripper_dir = _gripper_mesh_dir.rstrip('/')
    gripper_mesh_files = ['{}/{}'.format(gripper_dir, part_name) for part_name in gripper_part_names]

    gripper_mesh_relative_poses = [{'position': [0, 0.01 + 0.010833, 0.0444 - 0.12255859], 'quat': [0, 0, -1, 0]},
                                   {'position': [0, 0.01 + 0.01725 + 0.010833, 0.0444 + 0.075 - 0.12255859], 'quat': [0, 0, -1, 0]},
                                   {'position': [0, -0.01 - 0.010833, 0.0444 - 0.12255859], 'quat': [0, 0, 0, 1]},
                                   {'position': [0, -0.01 - 0.01725 - 0.010833, 0.0444 + 0.075 - 0.12255859], 'quat': [0, 0, 0, 1]},
                                   {'position': [0, 0, 0.0018 - 0.12255859], 'quat': [0, 0, 0, 1]},
                                   {'position': [0, 0, 0.0194 - 0.12255859], 'quat': [0, 0, 0, 1]}]

    _gripper_mesh = None
    for mesh_file, rel_pose in zip(gripper_mesh_files, gripper_mesh_relative_poses):
        _gripper_component_mesh = trimesh.load(mesh_file)
        # Place each part at its relative pose before merging it into one mesh.
        pose = tf.quaternion_matrix(rel_pose['quat'])
        pose[:3, 3] = rel_pose['position']
        _gripper_component_mesh.apply_transform(pose)
        if _gripper_mesh is None:
            _gripper_mesh = _gripper_component_mesh
        else:
            _gripper_mesh = trimesh.util.concatenate(_gripper_mesh, _gripper_component_mesh)

    _mesh_types.append('left_gripper')
    _meshes.append(_gripper_mesh)

    return _mesh_types, _mesh_files, _mesh_units, _meshes, _rotation_types, _contact_faces, _contact_points


class Object(object):
    """A scene entity: a name, an index into the mesh list, a pose and a logical state.

    A random RGB colour with a fixed alpha of 0.3 is drawn from ``np.random`` on
    construction.
    """

    def __init__(self, _name, _mesh_idx, _pose, _logical_state):
        # Three independent uniform draws, in R, G, B order.
        rgb = [np.random.uniform() for _ in range(3)]
        self.name, self.mesh_idx = _name, _mesh_idx
        self.pose, self.logical_state = _pose, _logical_state
        self.color = rgb + [0.3]


def get_obj_idx_by_name(_object_list, _name):
    """Return the index of the first object whose ``name`` equals ``_name``, or ``None``."""
    matching_positions = (position for position, candidate in enumerate(_object_list)
                          if candidate.name == _name)
    return next(matching_positions, None)


def get_held_object(_object_list):
    """Return the index of the first object whose logical state has a ``"held"`` key, or ``None``."""
    position = 0
    for candidate in _object_list:
        if "held" in candidate.logical_state:
            return position
        position += 1
    return None


def update_logical_state(_object_list):
    """Make the ``"on"`` / ``"support"`` relations of all objects mutually consistent.

    For every object listed under another object's ``"on"``, the first object's name is
    added to that object's ``"support"`` list (created if missing), and vice versa.
    The logical-state dictionaries are modified in place; nothing is returned.
    """

    def _ensure_listed(target_name, relation, member_name):
        # Add member_name to target's relation list unless it is already there.
        target_state = _object_list[get_obj_idx_by_name(_object_list, target_name)].logical_state
        members = target_state.setdefault(relation, [])
        if member_name not in members:
            members.append(member_name)

    for entity in _object_list:
        # "on" entries are mirrored as "support" on the other object, then the reverse.
        for own_relation, mirrored_relation in (("on", "support"), ("support", "on")):
            if own_relation in entity.logical_state:
                for other_name in entity.logical_state[own_relation]:
                    _ensure_listed(other_name, mirrored_relation, entity.name)


def rotation_matrix_from_z_x(z_axis):
    """Build a 4x4 homogeneous matrix whose third column is ``z_axis``.

    The y column is ``z_axis x ref`` (normalised), where ``ref`` is the first of the
    world x, y, z axes giving a non-zero cross product; the x column is ``y x z_axis``
    (normalised). ``z_axis`` itself is written into the matrix as given, without
    normalisation.
    """

    def _length(vec):
        return np.sqrt(np.sum(vec ** 2))

    reference_axes = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
    last_attempt = len(reference_axes) - 1
    for attempt, reference in enumerate(reference_axes):
        y_axis = np.cross(z_axis, reference)
        # The final candidate is used unconditionally, as a last resort.
        if attempt == last_attempt or _length(y_axis) > 0.:
            break

    y_axis = y_axis / _length(y_axis)
    x_axis = np.cross(y_axis, z_axis)
    x_axis = x_axis / _length(x_axis)

    T = np.eye(4)
    for column, axis in enumerate((x_axis, y_axis, z_axis)):
        T[:3, column] = axis
    return T


def rotation_matrix_from_y_x(y_axis):
    """Build a 4x4 homogeneous matrix whose second column is ``y_axis``.

    The z column is ``ref x y_axis`` (normalised), where ``ref`` is the first of the
    world x, y, z axes giving a non-zero cross product; the x column is ``y_axis x z``
    (normalised). ``y_axis`` itself is written into the matrix as given, without
    normalisation.
    """

    def _length(vec):
        return np.sqrt(np.sum(vec ** 2))

    reference_axes = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
    last_attempt = len(reference_axes) - 1
    for attempt, reference in enumerate(reference_axes):
        z_axis = np.cross(reference, y_axis)
        # The final candidate is used unconditionally, as a last resort.
        if attempt == last_attempt or _length(z_axis) > 0.:
            break

    z_axis = z_axis / _length(z_axis)
    x_axis = np.cross(y_axis, z_axis)
    x_axis = x_axis / _length(x_axis)

    rot_mtx = np.eye(4)
    for column, axis in enumerate((x_axis, y_axis, z_axis)):
        rot_mtx[:3, column] = axis
    return rot_mtx


def transform_matrix2pose(pose_mtx):
    """Convert a 4x4 homogeneous matrix into a ``Pose(position, orientation)``.

    The quaternion comes from ``tf.quaternion_from_matrix`` and its four components are
    passed, in order, as ``x``, ``y``, ``z``, ``w``.
    """
    # NOTE(review): Quaternion, Point and Pose are not imported in the visible cells --
    # confirm they are in scope before calling this helper.
    quat = tf.quaternion_from_matrix(pose_mtx)
    orientation = Quaternion(x=quat[0], y=quat[1], z=quat[2], w=quat[3])
    tx, ty, tz = pose_mtx[0, 3], pose_mtx[1, 3], pose_mtx[2, 3]
    return Pose(Point(x=tx, y=ty, z=tz), orientation)


def pose2transform_matrix(pose):
    """Convert an object with ``orientation`` (x, y, z, w) and ``position`` (x, y, z) into a 4x4 matrix.

    The rotation block is the upper-left 3x3 of ``tf.quaternion_matrix`` applied to
    ``[x, y, z, w]``; the last column holds the position.
    """
    rot, trans = pose.orientation, pose.position

    pose_mtx = np.eye(4)
    pose_mtx[:3, :3] = tf.quaternion_matrix([rot.x, rot.y, rot.z, rot.w])[:3, :3]
    pose_mtx[:3, 3] = [trans.x, trans.y, trans.z]
    return pose_mtx


def get_grasp_pose(_mesh_idx1, _pnt1, _normal1, _rotation1, _pose1, _meshes, _coll_mngr):
    """Compute a collision-free grasp and retreat pose for a contact point on a mesh.

    A ray is cast from just inside the surface at ``_pnt1`` along ``-_normal1`` to find a
    second surface point; the two points define the gripper closing axis (the grasp
    frame's y axis) and the gripper width. The grasp frame is rotated by ``_rotation1``
    about that axis, and the retreat pose is the grasp pose moved back 0.15 along the
    frame's z axis.

    Args:
        _mesh_idx1: index of the object mesh in ``_meshes``.
        _pnt1, _normal1: contact point and face normal in the mesh frame.
        _rotation1: rotation angle about the closing axis.
        _pose1: 4x4 pose of the object; both returned poses are pre-multiplied by it.
        _meshes: list of meshes.
        _coll_mngr: collision manager; the ``'left_gripper'`` transform is overwritten.

    Returns:
        ``(hand_t_grasp, hand_t_retreat, gripper_width)``, or ``(None, None, None)`` when
        the ray hits nothing, the two contact points coincide, or either pose collides.
    """
    mesh1 = _meshes[_mesh_idx1]
    locations, _, _ = mesh1.ray.intersects_location(
        ray_origins=[_pnt1 - 1e-5 * _normal1],
        ray_directions=[-_normal1])

    # No opposite surface found: there is nothing to close the gripper on.
    if len(locations) == 0:
        return None, None, None

    _pnt2 = locations[0]
    _pnt_center = (_pnt1 + _pnt2) / 2.

    y_axis = _pnt2 - _pnt1
    gripper_width = np.sqrt(np.sum(y_axis ** 2))
    # A zero-length closing axis cannot be normalised (it would produce NaN poses).
    if gripper_width == 0.:
        return None, None, None
    y_axis = y_axis / gripper_width

    t_grasp = rotation_matrix_from_y_x(y_axis).dot(tf.rotation_matrix(_rotation1, [0, 1, 0]))
    t_grasp[:3, 3] = _pnt_center

    t_retreat = deepcopy(t_grasp)
    approaching_dir = t_retreat[:3, 2]
    t_retreat[:3, 3] = _pnt_center - 15e-2 * approaching_dir
    hand_t_grasp = _pose1.dot(t_grasp)
    hand_t_retreat = _pose1.dot(t_retreat)

    # The retreat pose is checked first, then the grasp pose; both must be collision-free.
    _coll_mngr.set_transform('left_gripper', hand_t_retreat)
    if _coll_mngr.in_collision_internal():
        return None, None, None
    _coll_mngr.set_transform('left_gripper', hand_t_grasp)
    if _coll_mngr.in_collision_internal():
        return None, None, None

    return hand_t_grasp, hand_t_retreat, gripper_width


def get_on_pose(_name1, _pnt1, _normal1, _rotation1, _pnt2, _normal2, _pose2, _coll_mngr, _rel_gripper=None):
    """Compute the pose that puts contact ``(_pnt1, _normal1)`` of object ``_name1`` onto ``(_pnt2, _normal2)``.

    A frame is built at the first contact with z along ``-_normal1`` (rotated by
    ``_rotation1`` about z, origin offset by ``1e-6`` along the normal) and another at
    the second contact with z along ``_normal2``. The result is
    ``_pose2 . support_frame . inv(contact_frame)``.

    The collision manager transform of ``_name1`` is set to the candidate pose, and, when
    ``_rel_gripper`` is given, ``'left_gripper'`` is set to ``pose . _rel_gripper``.

    Returns:
        The 4x4 pose if the collision manager reports no internal collision, else ``None``.
    """
    contact_frame = rotation_matrix_from_z_x(-_normal1).dot(tf.rotation_matrix(_rotation1, [0, 0, 1]))
    contact_frame[:3, 3] = _pnt1 + 1e-6 * _normal1

    support_frame = rotation_matrix_from_z_x(_normal2)
    support_frame[:3, 3] = _pnt2

    _T1 = _pose2.dot(support_frame.dot(np.linalg.inv(contact_frame)))
    _coll_mngr.set_transform(_name1, _T1)
    if _rel_gripper is not None:
        # The gripper moves rigidly with the object.
        _coll_mngr.set_transform('left_gripper', _T1.dot(_rel_gripper))

    if _coll_mngr.in_collision_internal():
        return None
    return _T1


def configuration_initializer(_mesh_types, _meshes, _mesh_units, _rotation_types, _contact_faces, _contact_points,
                              goal_name='tower_goal'):
    """Create the initial scene: table, randomly placed objects, optional goal, and gripper.

    Objects are spawned on table contact points lying inside a square region around
    ``table_spawn_position``; for each object one collision-free pose that is more than
    0.12 away from every already placed object is chosen at random with ``np.random``.

    Args:
        _mesh_types, _meshes, _mesh_units, _rotation_types, _contact_faces, _contact_points:
            values as returned by ``get_meshes``.
        goal_name: one of ``'tower_goal'``, ``'twin_tower_goal'``, ``'box_goal'``,
            ``'stack_easy'``, ``'stack_hard'``, ``'regular_shapes'``, ``'round_shapes'``,
            ``'debug_config'``.

    Returns:
        ``(_object_list, _goal_obj, _contact_points, _contact_faces, _coll_mngr,
        table_spawn_position, table_spawn_bnd_size, n_obj_per_mesh_types)``. The returned
        contact lists are new lists: the table entry is restricted to the spawn region
        (or replaced by the goal's contact points expressed in the table frame when a
        goal mesh is used). The input lists are left untouched.

    Raises:
        ValueError: if ``goal_name`` is not a known configuration.
    """
    # Number of objects to spawn per mesh type (indexed like _mesh_types).
    n_obj_by_goal = {
        'tower_goal': [0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        'twin_tower_goal': [0, 2, 2, 0, 2, 0, 0, 0, 0, 0],
        'box_goal': [0, 2, 2, 0, 0, 0, 0, 0, 0, 0],
        'stack_easy': [0, 2, 2, 0, 0, 0, 0, 0, 0, 0],
        'stack_hard': [0, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        'regular_shapes': [0, 2, 2, 0, 0, 0, 0, 0, 0, 0],
        'round_shapes': [2, 0, 0, 2, 2, 0, 0, 0, 0, 0],
        'debug_config': [0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
    }
    # Compare by value (not identity) and fail loudly on an unknown name.
    if goal_name not in n_obj_by_goal:
        raise ValueError('goal name is wrong!!! got {!r}, expected one of {}'.format(
            goal_name, sorted(n_obj_by_goal)))
    n_obj_per_mesh_types = list(n_obj_by_goal[goal_name])

    # Shallow copies: entries are replaced below, and the caller's lists must stay intact
    # so that re-running this function on the same inputs gives the same result.
    _contact_points = list(_contact_points)
    _contact_faces = list(_contact_faces)

    _object_list = []
    _coll_mngr = trimesh.collision.CollisionManager()

    has_goal_mesh = 'goal' in goal_name
    table_spawn_position = [0.6, 0.2]
    table_spawn_bnd_size = 0.1 if has_goal_mesh else 0.13

    table_name = 'custom_table'
    table_pose = np.array([0.6, 0.0, 0.573077])
    table_mesh_idx = _mesh_types.index(table_name)
    table_T = np.eye(4)
    table_T[:3, 3] = table_pose
    table_obj = Object(table_name, table_mesh_idx, table_T, {"static": []})

    _object_list.append(table_obj)
    _coll_mngr.add_object(table_name, _meshes[table_mesh_idx])
    _coll_mngr.set_transform(table_name, table_T)

    # Keep only upward-facing table contacts inside the square spawn region.
    table_contact_points = _contact_points[table_mesh_idx]
    table_contact_normals = _meshes[table_mesh_idx].face_normals[_contact_faces[table_mesh_idx]]

    table_contact_points_world = table_T[:3, :3].dot(table_contact_points.T).T + table_T[:3, 3]
    table_contact_normals_world = table_T[:3, :3].dot(table_contact_normals.T).T
    table_contact_indices, = np.where(
        np.logical_and(
            np.logical_and(
                table_contact_normals_world[:, 2] > 0.,
                np.abs(table_spawn_position[0] - table_contact_points_world[:, 0]) < table_spawn_bnd_size
            ), np.abs(table_spawn_position[1] - table_contact_points_world[:, 1]) < table_spawn_bnd_size
        )
    )

    _contact_points[table_mesh_idx] = _contact_points[table_mesh_idx][table_contact_indices]
    _contact_faces[table_mesh_idx] = _contact_faces[table_mesh_idx][table_contact_indices]

    if has_goal_mesh:
        goal_mesh_idx = _mesh_types.index(goal_name)
        goal_pose = np.array([0.4, 0.0, -0.019*_mesh_units[goal_mesh_idx]*1000 + _meshes[table_mesh_idx].bounds[1][2] - _meshes[table_mesh_idx].bounds[0][2]
                              - _meshes[goal_mesh_idx].bounds[0][2]])
        goal_T = np.eye(4)
        goal_T[:3, 3] = goal_pose
        _goal_obj = Object(goal_name, goal_mesh_idx, goal_T, {"goal": []})
    else:
        _goal_obj = None

    # The spawn-region contacts do not change while objects are placed, so compute them once.
    spawn_contact_points = _contact_points[table_mesh_idx]
    spawn_contact_normals = _meshes[table_mesh_idx].face_normals[_contact_faces[table_mesh_idx]]
    spawn_contact_normals_world = table_T[:3, :3].dot(spawn_contact_normals.T).T
    spawn_contact_indices, = np.where(spawn_contact_normals_world[:, 2] > 0.)

    for mesh_idx, n_obj in enumerate(n_obj_per_mesh_types):
        obj_mesh_idx = mesh_idx
        for obj_idx in range(n_obj):
            obj_name = _mesh_types[mesh_idx] + str(obj_idx)

            # Find stable pose
            if "custom_table" not in _mesh_types[mesh_idx]:
                stable_poses, probs = _meshes[mesh_idx].compute_stable_poses(n_samples=100)
                stable_pose_idx = np.argmax(probs)
                stable_pose = stable_poses[stable_pose_idx]

                if "half_cylinder_box" in _mesh_types[mesh_idx] or "triangle_box" in _mesh_types[mesh_idx]:
                    stable_pose = tf.rotation_matrix(-np.pi / 2., [1., 0., 0.]).dot(stable_pose)
                elif "arch_box" in _mesh_types[mesh_idx]:
                    stable_pose = tf.rotation_matrix(np.pi / 2., [1., 0., 0.]).dot(stable_pose)

                stable_pose = stable_pose.dot(tf.rotation_matrix(np.pi / 2., [0., 1., 0.]))
            else:
                stable_pose = np.eye(4)

            new_obj = Object(obj_name, obj_mesh_idx, stable_pose, {"on": [table_name]})
            _coll_mngr.add_object(obj_name, _meshes[obj_mesh_idx])

            # Contacts of the new object that face downwards in its stable pose.
            new_obj_contact_points = _contact_points[obj_mesh_idx]
            new_obj_contact_normals = _meshes[obj_mesh_idx].face_normals[_contact_faces[obj_mesh_idx]]

            new_obj_contact_normals_world = stable_pose[:3, :3].dot(new_obj_contact_normals.T).T
            new_obj_contact_indices, = np.where(new_obj_contact_normals_world[:, 2] < -0.99)

            pose1_list = []
            for i in new_obj_contact_indices:
                for j in range(_rotation_types):
                    for k in spawn_contact_indices:
                        pnt1 = new_obj_contact_points[i]
                        normal1 = new_obj_contact_normals[i]

                        pnt2 = spawn_contact_points[k]
                        normal2 = spawn_contact_normals[k]

                        pose1 = get_on_pose(new_obj.name, pnt1, normal1, 2.*np.pi*j/_rotation_types, pnt2, normal2,
                                            table_T, _coll_mngr)
                        if pose1 is not None:
                            # Distance to the closest already placed object (exact name
                            # comparison instead of a substring test).
                            min_dist1 = np.min([np.sqrt(np.sum(np.square(tmp_obj.pose[:3, 3] - pose1[:3, 3])))
                                                for tmp_obj in _object_list if tmp_obj.name != new_obj.name])
                            if min_dist1 > 0.12:
                                pose1_list.append(pose1)

            # If no candidate pose was found the object keeps its stable pose.
            if len(pose1_list) > 0:
                init_pose_idx = np.random.choice(len(pose1_list), 1)[0]
                new_obj.pose = pose1_list[init_pose_idx]
                _coll_mngr.set_transform(new_obj.name, new_obj.pose)
            _object_list.append(new_obj)
            update_logical_state(_object_list)

    if _goal_obj is not None:
        # Replace the table contacts by the goal's contact points, expressed in the table frame.
        transform_g_t = np.linalg.inv(table_obj.pose).dot(_goal_obj.pose)
        _contact_points[table_mesh_idx] = transform_g_t[:3, :3].dot(_contact_points[goal_mesh_idx].T).T \
                                         + transform_g_t[:3, 3]
        _contact_faces[table_mesh_idx] = _contact_faces[table_mesh_idx][:len(_contact_faces[goal_mesh_idx])]
        # this is the most tricky part...

    gripper_mesh_idx = _mesh_types.index('left_gripper')

    _coll_mngr.add_object('left_gripper', _meshes[gripper_mesh_idx])
    # The gripper starts 0.5 above the table pose.
    gripper_pose = deepcopy(table_T)
    gripper_pose[2, 3] += 0.5
    _coll_mngr.set_transform('left_gripper', gripper_pose)

    _gripper_obj = Object('left_gripper', gripper_mesh_idx, gripper_pose, {'left_gripper': []})
    _object_list.append(_gripper_obj)
    return _object_list, _goal_obj, _contact_points, _contact_faces, _coll_mngr, table_spawn_position, \
           table_spawn_bnd_size, n_obj_per_mesh_types


def visualize(_object_list, _meshes, _goal_obj=None):
    """Show every object (and the goal object, if given) in a trimesh scene using the ``'gl'`` viewer."""
    to_draw = list(_object_list)
    if _goal_obj is not None:
        to_draw.append(_goal_obj)

    mesh_scene = trimesh.Scene()
    for item in to_draw:
        mesh_scene.add_geometry(_meshes[item.mesh_idx], node_name=item.name, transform=item.pose)
    mesh_scene.show(viewer='gl')

# Initialization of models

In [None]:
# Seed NumPy's global RNG: Object colours and the random choice of initial poses draw from np.random.
np.random.seed(1)
# task_name = 'tower_goal'
task_name = 'stack_hard'
# _area_ths=1. is far above the 0.003 default, so only faces with area > 1 are subdivided here;
# table meshes are still refined with the default _tbl_area_ths inside get_meshes.
mesh_types, mesh_files, mesh_units, meshes, rotation_types, contact_faces, contact_points = get_meshes(_area_ths=1.)
# contact_points / contact_faces are rebound to the values returned by the initializer;
# the spawn position and spawn bound size are discarded.
initial_object_list, goal_obj, contact_points, contact_faces, coll_mngr, _, _, n_obj_per_mesh_types = \
    configuration_initializer(mesh_types, meshes, mesh_units, rotation_types, contact_faces, contact_points,
                              goal_name=task_name)

# Opens the trimesh 'gl' viewer on the initial scene.
visualize(initial_object_list,meshes,goal_obj)

# Module Tests

## Sample place pose

In [None]:
def get_possible_actions(_object_list, _meshes, _coll_mngr, _contact_points, _contact_faces, _rotation_types, side_place_flag=False, goal_obj=None):
    """Enumerate the pick and place actions available in the current scene.

    The collision manager is first synchronised with the object poses.

    * Pick: only when nothing is held. Every object that supports nothing, is not
      static/done, and is neither a goal nor the gripper is tried; grasp contacts are the
      faces whose world normal is horizontal (|z| < 1e-10).
    * Place: only when something is held. The held object's downward-facing contacts are
      matched against the upward-facing contacts (normal z > 0.99) of every other object.

    Args:
        _object_list: scene objects; the gripper is looked up by the name
            ``'left_gripper'`` (falling back to the last element).
        _meshes, _contact_points, _contact_faces: per-mesh data indexed by ``mesh_idx``.
        _coll_mngr: collision manager; its transforms are modified while probing.
        _rotation_types: number of evenly spaced rotation angles tried per contact pair.
        side_place_flag: if true, ``rect_box`` objects may also be placed on contacts
            whose world normal z is below 1e-3.
        goal_obj: unused.

    Returns:
        A list of action dictionaries: ``{"type": "pick", "param", "grasp_poses",
        "retreat_poses", "gripper_widths"}`` or ``{"type": "place", "param",
        "placing_poses"}``.
    """
    _action_list = []
    for obj in _object_list:
        _coll_mngr.set_transform(obj.name, obj.pose)

    rotation_angles = [2. * np.pi * j / _rotation_types for j in range(_rotation_types)]

    # Check pick
    if all("held" not in obj.logical_state for obj in _object_list):
        for obj1 in _object_list:
            pickable = "support" not in obj1.logical_state and "static" not in obj1.logical_state and \
                "done" not in obj1.logical_state and "goal" not in obj1.name and 'gripper' not in obj1.name
            if not pickable:
                continue

            obj1_mesh_idx = obj1.mesh_idx
            obj1_pose = obj1.pose
            obj1_contact_points = _contact_points[obj1_mesh_idx]
            obj1_contact_normals = _meshes[obj1_mesh_idx].face_normals[_contact_faces[obj1_mesh_idx]]

            # Only side faces (horizontal normal in the world frame) can be grasped.
            obj1_contact_normals_world = obj1.pose[:3, :3].dot(obj1_contact_normals.T).T
            obj1_contact_indices, = np.where(np.abs(obj1_contact_normals_world[:, 2]) < 1e-10)

            grasp_poses = []
            retreat_poses = []
            gripper_widths = []

            for i in obj1_contact_indices:
                pnt1 = obj1_contact_points[i]
                normal1 = obj1_contact_normals[i]
                for angle in rotation_angles:
                    hand_t_grasp, hand_t_retreat, gripper_width = get_grasp_pose(
                        obj1_mesh_idx, pnt1, normal1, angle, obj1_pose, _meshes, _coll_mngr)
                    if hand_t_grasp is not None:
                        grasp_poses.append(hand_t_grasp)
                        retreat_poses.append(hand_t_retreat)
                        gripper_widths.append(gripper_width)

            if len(grasp_poses) > 0:
                _action_list.append({"type": "pick", "param": obj1.name, "grasp_poses": grasp_poses,
                                     "retreat_poses": retreat_poses, "gripper_widths": gripper_widths})

    # Check place
    held_obj_idx = get_held_object(_object_list)
    if held_obj_idx is not None:
        held_obj = _object_list[held_obj_idx]
        obj2 = held_obj
        obj2_mesh_idx = held_obj.mesh_idx
        obj2_contact_points = _contact_points[obj2_mesh_idx]
        obj2_contact_normals = _meshes[obj2_mesh_idx].face_normals[_contact_faces[obj2_mesh_idx]]

        obj2_contact_normals_world = obj2.pose[:3, :3].dot(obj2_contact_normals.T).T
        if side_place_flag and 'rect_box' in obj2.name:
            obj2_contact_indices, = np.where(obj2_contact_normals_world[:, 2] < 1e-3)
        else:
            obj2_contact_indices, = np.where(obj2_contact_normals_world[:, 2] < 0.)

        # Gripper pose relative to the held object. It does not depend on the support
        # candidate, so it is computed once instead of once per candidate.
        gripper_idx = get_obj_idx_by_name(_object_list, 'left_gripper')
        gripper_obj = _object_list[-1] if gripper_idx is None else _object_list[gripper_idx]
        _rel_gripper = np.linalg.inv(held_obj.pose).dot(gripper_obj.pose)

        for obj1 in _object_list:
            if "held" in obj1.logical_state or 'gripper' in obj1.name:
                continue

            obj1_mesh_idx = obj1.mesh_idx
            obj1_contact_points = _contact_points[obj1_mesh_idx]
            obj1_contact_normals = _meshes[obj1_mesh_idx].face_normals[_contact_faces[obj1_mesh_idx]]

            # Only (almost exactly) upward-facing support contacts are considered.
            obj1_contact_normals_world = obj1.pose[:3, :3].dot(obj1_contact_normals.T).T
            obj1_contact_indices, = np.where(obj1_contact_normals_world[:, 2] > 0.99)

            placing_poses = []
            for i in obj1_contact_indices:
                pnt1 = obj1_contact_points[i]
                normal1 = obj1_contact_normals[i]
                for angle in rotation_angles:
                    for k in obj2_contact_indices:
                        pnt2 = obj2_contact_points[k]
                        normal2 = obj2_contact_normals[k]

                        pose = get_on_pose(obj2.name, pnt2, normal2, angle,
                                           pnt1, normal1, obj1.pose, _coll_mngr, _rel_gripper=_rel_gripper)

                        if pose is not None:
                            placing_poses.append(pose)
            if len(placing_poses) > 0:
                _action_list.append({"type": "place", "param": obj1.name, "placing_poses": placing_poses})
    return _action_list

def get_possible_transitions(_object_list, _action, _physical_checker=None):
    """Return the list of successor object lists reachable by applying ``_action``.

    * ``"pick"``: one successor per grasp pose. The gripper takes the grasp pose, the
      picked object is removed from the ``"support"`` list of everything it was on, and
      its logical state becomes ``{"held": []}``.
    * ``"place"``: one successor per placing pose. The held object takes the pose (the
      gripper follows rigidly) and its logical state becomes
      ``{"on": [target], "done": []}``. If ``_physical_checker`` is given, only successors
      for which ``_physical_checker(_object_list, _action, successor)`` is truthy are kept.

    Any other action type yields an empty list. ``_object_list`` itself is not modified;
    successors are deep copies.
    """
    _set_of_next_object_list = []
    # Compare by value: identity comparison of strings only works by interning accident.
    action_type = _action["type"]

    if action_type == "pick":
        pick_obj_idx = get_obj_idx_by_name(_object_list, _action['param'])
        gripper_obj_idx = get_obj_idx_by_name(_object_list, 'left_gripper')
        for new_hand_pose in _action["grasp_poses"]:
            _next_object_list = deepcopy(_object_list)
            _next_object_list[gripper_obj_idx].pose = new_hand_pose

            picked_obj = _next_object_list[pick_obj_idx]
            # Detach from every support; otherwise update_logical_state would re-create
            # the "on" relation from a stale "support" entry.
            for support_name in picked_obj.logical_state.get("on", []):
                support_obj_idx = get_obj_idx_by_name(_next_object_list, support_name)
                if support_obj_idx is None:
                    continue
                supported = _next_object_list[support_obj_idx].logical_state.get("support", [])
                if picked_obj.name in supported:
                    supported.remove(picked_obj.name)

            picked_obj.logical_state.clear()
            picked_obj.logical_state["held"] = []
            update_logical_state(_next_object_list)

            _set_of_next_object_list.append(_next_object_list)

    elif action_type == "place":
        place_obj_idx = get_obj_idx_by_name(_object_list, _action['param'])
        held_obj_idx = get_held_object(_object_list)
        gripper_obj_idx = get_obj_idx_by_name(_object_list, 'left_gripper')
        for new_pose in _action["placing_poses"]:
            _next_object_list = deepcopy(_object_list)
            held_obj = _next_object_list[held_obj_idx]
            gripper_obj = _next_object_list[gripper_obj_idx]

            # Keep the gripper at the same pose relative to the object it holds.
            rel_gripper = np.linalg.inv(held_obj.pose).dot(gripper_obj.pose)
            gripper_obj.pose = new_pose.dot(rel_gripper)
            held_obj.pose = new_pose
            held_obj.logical_state.clear()
            held_obj.logical_state["on"] = [_next_object_list[place_obj_idx].name]
            held_obj.logical_state["done"] = []
            update_logical_state(_next_object_list)

            if _physical_checker is None or _physical_checker(_object_list, _action, _next_object_list):
                _set_of_next_object_list.append(_next_object_list)

    return _set_of_next_object_list

def get_reward(_obj_list, _action, _goal_obj, _next_obj_list, _meshes):
    """Reward for the transition ``_obj_list -> _next_obj_list`` under ``_action``.

    * ``-inf`` when ``_next_obj_list`` is ``None``.
    * ``0.0`` for an action whose type contains ``"pick"``.
    * For a type containing ``"place"`` without a goal: the increase of the maximum
      object height (``pose[2, 3]``), ignoring the gripper and, before the action, held
      objects.
    * For a type containing ``"place"`` with a goal: the change in the summed
      ``1 / (1 + 1e3 * max|signed distance to the goal mesh|)`` score over all objects
      that are neither the table nor the gripper.
    * ``None`` for any other action type.
    """
    if _next_obj_list is None:
        return -np.inf

    action_type = _action["type"]
    if "pick" in action_type:
        return 0.0
    if "place" not in action_type:
        return None

    if _goal_obj is None:
        heights_before = [item.pose[2, 3] for item in _obj_list
                          if 'gripper' not in item.name and "held" not in item.logical_state]
        curr_height = np.max(heights_before)

        heights_after = [item.pose[2, 3] for item in _next_obj_list if 'gripper' not in item.name]
        next_height = np.max(heights_after)
        return next_height - curr_height

    goal_mesh_copied = deepcopy(_meshes[_goal_obj.mesh_idx])
    goal_mesh_copied.apply_transform(_goal_obj.pose)

    rew = 0.
    # The score of the current scene is subtracted first, then the next scene's is added.
    for sign, scene in ((-1., _obj_list), (1., _next_obj_list)):
        for item in scene:
            if "table" in item.name or 'gripper' in item.name:
                continue
            placed_mesh = deepcopy(_meshes[item.mesh_idx])
            placed_mesh.apply_transform(item.pose)

            signed_distance = trimesh.proximity.signed_distance(goal_mesh_copied, placed_mesh.vertices)
            rew += sign * (1 / (1 + 1e3 * np.max(np.abs(signed_distance))))

    return rew
        
        
def check_stability(obj_idx, object_list, meshes, com1, volume1, margin=0.015):
    """Recursively check that a stack's centre of mass rests safely on each support below it.

    NOTE: this function is defined again in the next notebook cell with ``margin=0.005``;
    that later definition is the one in effect after running all cells.

    Args:
        obj_idx: index of the object whose support is checked.
        object_list: scene objects with ``logical_state`` and ``pose``.
        meshes: meshes indexed by ``mesh_idx``.
        com1: centre of mass (world frame) of everything stacked from ``obj_idx`` upwards.
        volume1: volume associated with ``com1`` (used as the weight when merging).
        margin: required distance between the projected centre of mass and the
            x/y bounds of the support mesh.

    Returns:
        ``True`` if the object is not ``"on"`` anything, or if the closest point on the
        first support lies on an upward face (normal z > 0.999) at least ``margin`` inside
        the support's bounds and the same holds recursively below; ``False`` otherwise.
    """
    if "on" not in object_list[obj_idx].logical_state:
        return True

    # Only the first support is considered.
    support_name = object_list[obj_idx].logical_state["on"][0]
    support_obj_idx = get_obj_idx_by_name(object_list, support_name)
    mesh2 = deepcopy(meshes[object_list[support_obj_idx].mesh_idx])
    mesh2.apply_transform(object_list[support_obj_idx].pose)

    closest_points, _, surface_idx = trimesh.proximity.closest_point(mesh2, [com1])
    project_point2 = closest_points[0]
    lower_bound, upper_bound = mesh2.bounds[0], mesh2.bounds[1]
    safe_com = (mesh2.face_normals[surface_idx[0]][2] > 0.999
                and lower_bound[0] + margin < project_point2[0] < upper_bound[0] - margin
                and lower_bound[1] + margin < project_point2[1] < upper_bound[1] - margin)
    if not safe_com:
        return False

    # Merge the support into the stack (volume-weighted) and continue one level down,
    # keeping the caller's margin for every level.
    merged_com = (com1 * volume1 + mesh2.center_mass * mesh2.volume) / (volume1 + mesh2.volume)
    merged_volume = volume1 + mesh2.volume
    return check_stability(support_obj_idx, object_list, meshes, merged_com, merged_volume, margin)


def geometry_based_physical_checker(obj_list, action, next_obj_list, meshes, network=None):
    """Geometric plausibility check for a transition.

    NOTE: this function is defined again in the next notebook cell; that later
    definition is the one in effect after running all cells.

    Args:
        obj_list: objects before the action (used to find the held object).
        action: action dictionary; only ``action['type']`` is read.
        next_obj_list: objects after the action.
        meshes: meshes indexed by ``mesh_idx``.
        network: unused.

    Returns:
        ``True`` for ``'pick'``; for ``'place'`` the result of ``check_stability`` on the
        placed object; ``False`` for any other action type.
    """
    # Compare by value: identity comparison of strings only works by interning accident.
    if action['type'] == 'pick':
        return True
    if action['type'] == 'place':
        pick_idx = get_held_object(obj_list)
        mesh1 = deepcopy(meshes[next_obj_list[pick_idx].mesh_idx])
        mesh1.apply_transform(next_obj_list[pick_idx].pose)
        com1 = deepcopy(mesh1.center_mass)
        volume1 = deepcopy(mesh1.volume)
        return check_stability(pick_idx, next_obj_list, meshes, com1, volume1)
    return False

In [None]:
def check_stability(obj_idx, object_list, meshes, com1, volume1, margin=0.005):
    """Recursively check that a stack's centre of mass rests safely on each support below it.

    NOTE: this redefines the ``check_stability`` of the previous notebook cell with a
    smaller default ``margin`` (0.005 instead of 0.015).

    Args:
        obj_idx: index of the object whose support is checked.
        object_list: scene objects with ``logical_state`` and ``pose``.
        meshes: meshes indexed by ``mesh_idx``.
        com1: centre of mass (world frame) of everything stacked from ``obj_idx`` upwards.
        volume1: volume associated with ``com1`` (used as the weight when merging).
        margin: required distance between the projected centre of mass and the
            x/y bounds of the support mesh.

    Returns:
        ``True`` if the object is not ``"on"`` anything, or if the closest point on the
        first support lies on an upward face (normal z > 0.999) at least ``margin`` inside
        the support's bounds and the same holds recursively below; ``False`` otherwise.
    """
    if "on" not in object_list[obj_idx].logical_state:
        return True

    # Only the first support is considered.
    support_name = object_list[obj_idx].logical_state["on"][0]
    support_obj_idx = get_obj_idx_by_name(object_list, support_name)
    mesh2 = deepcopy(meshes[object_list[support_obj_idx].mesh_idx])
    mesh2.apply_transform(object_list[support_obj_idx].pose)

    closest_points, _, surface_idx = trimesh.proximity.closest_point(mesh2, [com1])
    project_point2 = closest_points[0]
    lower_bound, upper_bound = mesh2.bounds[0], mesh2.bounds[1]
    safe_com = (mesh2.face_normals[surface_idx[0]][2] > 0.999
                and lower_bound[0] + margin < project_point2[0] < upper_bound[0] - margin
                and lower_bound[1] + margin < project_point2[1] < upper_bound[1] - margin)
    if not safe_com:
        return False

    # Merge the support into the stack (volume-weighted) and continue one level down,
    # keeping the caller's margin for every level.
    merged_com = (com1 * volume1 + mesh2.center_mass * mesh2.volume) / (volume1 + mesh2.volume)
    merged_volume = volume1 + mesh2.volume
    return check_stability(support_obj_idx, object_list, meshes, merged_com, merged_volume, margin)


def geometry_based_physical_checker(obj_list, action, next_obj_list, meshes, network=None):
    """Geometric plausibility check for a transition.

    NOTE: this redefines the ``geometry_based_physical_checker`` of the previous
    notebook cell.

    Args:
        obj_list: objects before the action (used to find the held object).
        action: action dictionary; only ``action['type']`` is read.
        next_obj_list: objects after the action.
        meshes: meshes indexed by ``mesh_idx``.
        network: unused.

    Returns:
        ``True`` for ``'pick'``; for ``'place'`` the result of ``check_stability`` on the
        placed object; ``False`` for any other action type.
    """
    # Compare by value: identity comparison of strings only works by interning accident.
    if action['type'] == 'pick':
        return True
    if action['type'] == 'place':
        pick_idx = get_held_object(obj_list)
        mesh1 = deepcopy(meshes[next_obj_list[pick_idx].mesh_idx])
        mesh1.apply_transform(next_obj_list[pick_idx].pose)
        com1 = deepcopy(mesh1.center_mass)
        volume1 = deepcopy(mesh1.volume)
        return check_stability(pick_idx, next_obj_list, meshes, com1, volume1)
    return False


def _bai_select(values, radius, visits):
    """Pick an arm from confidence bounds ``values +/- radius``.

    For every arm ``k`` the gap ``max_{i != k} upper[i] - lower[k]`` is
    computed; ``b`` is the arm with the smallest gap and ``u`` the arm with
    the largest upper bound. ``u`` is returned when ``b`` has strictly more
    visits than ``u``, otherwise ``b``. Requires at least two arms.
    """
    upper_bounds = values + radius
    lower_bounds = values - radius
    gaps = [np.max(np.delete(upper_bounds, k)) - lower_bounds[k] for k in range(len(values))]
    b = np.argmin(gaps)
    u = np.argmax(upper_bounds)
    return u if visits[b] > visits[u] else b


def sampler(_exploration_method, _action_values, _visits, _depth, _indices=None, eps=0.):
    """Select one candidate index according to an exploration strategy.

    Args:
        _exploration_method: Dict with a ``'method'`` key, one of ``'random'``,
            ``'ucb'``, ``'bai_ucb'``, ``'bai_perturb'`` or ``'greedy'``. All
            methods except ``'random'`` and ``'greedy'`` also read ``'param'``
            (exploration constant, divided by ``max(_depth, 1)``).
        _action_values: Value estimate per candidate. Infinite entries are
            treated as 0 for the selection; the input is not modified.
        _visits: Visit count per candidate.
        _depth: Depth of the node being expanded.
        _indices: Optional subset of candidate positions to choose among.
        eps: Probability of overriding the method with a uniform random pick.

    Returns:
        The chosen position in ``_action_values``. When ``_indices`` is given
        the result is one of its entries.

    Raises:
        ValueError: If the method name is not recognised (and the epsilon
            draw did not already trigger a random pick).
    """
    if _indices is not None:
        selected_action_values = [_action_values[_index] for _index in _indices]
        selected_visits = [_visits[_index] for _index in _indices]
    else:
        selected_action_values = _action_values
        selected_visits = _visits

    # np.array always copies, so zeroing the infinities never touches the
    # caller's data.
    selected_action_values = np.array(selected_action_values)
    selected_action_values[np.isinf(selected_action_values)] = 0.

    method = _exploration_method['method']
    # The uniform draw happens unconditionally, so the random stream is
    # consumed the same way whatever the method.
    explore_randomly = eps > np.random.uniform()

    if explore_randomly or method == 'random':
        selected_idx = np.random.choice(len(selected_action_values), size=1)[0]
    elif method == 'ucb':
        c = _exploration_method['param'] / np.maximum(_depth, 1)
        radius = c * np.sqrt(1. / np.maximum(1., selected_visits))
        selected_idx = np.argmax(selected_action_values + radius)
    elif method in ('bai_ucb', 'bai_perturb'):
        if len(selected_visits) == 1:
            selected_idx = 0
        else:
            c = _exploration_method['param'] / np.maximum(_depth, 1)
            radius = c * np.sqrt(1. / np.maximum(1., selected_visits))
            if method == 'bai_perturb':
                # Randomly scale (and possibly flip) each arm's interval.
                radius = radius * np.random.normal(size=(len(selected_visits)))
            selected_idx = _bai_select(selected_action_values, radius, selected_visits)
    elif method == 'greedy':
        selected_idx = np.argmax(selected_action_values)
    else:
        raise ValueError("unknown exploration method: {!r}".format(method))

    if _indices is not None:
        return _indices[selected_idx]
    return selected_idx


class Tree(object):
    """Search tree that alternates state nodes and action nodes.

    The tree is stored in ``self.Tree`` (a ``networkx.DiGraph``). Node ids are
    consecutive integers; node 0 is the root state.

    * State nodes carry ``depth``, ``state`` (an object list), ``reward``,
      ``value`` and ``visit``.
    * Action nodes carry the same keys (``reward`` is always ``0.``) plus
      ``action``, ``done``, ``next_states`` and ``next_rewards`` (the
      candidate outcomes of the action, sorted by ascending reward).

    ``value`` starts at ``-inf`` and is raised as better returns are found.
    """

    def __init__(self,
                 _init_obj_list,
                 _max_depth,
                 _coll_mngr,
                 _meshes,
                 _contact_points,
                 _contact_faces,
                 _rotation_types,
                 _min_visit=1,
                 _goal_obj=None,
                 _physcial_constraint_checker=geometry_based_physical_checker,
                 _exploration=None):
        """Create the tree with a single root state node.

        Args:
            _init_obj_list: Object list stored as the root state.
            _max_depth: State nodes at this depth or deeper are not expanded.
            _coll_mngr, _meshes, _contact_points, _contact_faces,
            _rotation_types: Forwarded to ``get_possible_actions`` when a
                state node is expanded.
            _min_visit: Stored as ``self.min_visit``.
            _goal_obj: Goal passed to ``get_possible_actions`` and
                ``get_reward``. Side placement is enabled only when it is
                ``None``.
            _physcial_constraint_checker: Callable
                ``(obj_list, action, next_obj_list, meshes, network=...)``
                used to filter transitions. (The misspelling is part of the
                public keyword and is kept for compatibility.)
            _exploration: Dict understood by ``sampler``; defaults to
                ``{'method': 'random'}``.
        """
        self.Tree = nx.DiGraph()
        self.max_depth = _max_depth
        # NOTE(review): stored but not read anywhere in this class;
        # action_exploration uses its own hard-coded threshold of 1.
        self.min_visit = _min_visit
        self.Tree.add_node(0, depth=0, state=_init_obj_list, reward=0, value=-np.inf, visit=0)

        self.coll_mngr = _coll_mngr
        self.meshes = _meshes
        self.contact_points = _contact_points
        self.contact_faces = _contact_faces
        self.rotation_types = _rotation_types

        self.goal_obj = _goal_obj
        self.side_place_flag = _goal_obj is None
        self.physcial_constraint_checker = _physcial_constraint_checker
        self.network = None

        if _exploration is None:
            _exploration = {'method': 'random'}
        self.exploration_method = _exploration

    def _add_state_node(self, action_node, depth, state, reward):
        """Append a fresh state node below ``action_node`` and return its id."""
        child_node = self.Tree.number_of_nodes()
        self.Tree.add_node(child_node, depth=depth, state=state, reward=reward, value=-np.inf, visit=0)
        self.Tree.add_edge(action_node, child_node)
        return child_node

    def _ensure_action_nodes(self, state_node, depth, obj_list):
        """Return the action nodes below ``state_node``, creating them if absent.

        One action node is created per possible action that has at least one
        transition accepted by the physical checker. The returned list is
        empty when no such action exists.
        """
        action_nodes = [action_node for action_node in self.Tree.neighbors(state_node)
                        if self.Tree.nodes[action_node]['reward'] == 0.]
        if len(action_nodes) > 0:
            return action_nodes

        # Use the instance's own configuration rather than notebook globals
        # that merely happen to share these names.
        action_list = get_possible_actions(obj_list, self.meshes, self.coll_mngr, self.contact_points,
                                           self.contact_faces, self.rotation_types,
                                           side_place_flag=self.side_place_flag,
                                           goal_obj=self.goal_obj)

        def physical_checker(x, y, z):
            return self.physcial_constraint_checker(x, y, z, self.meshes, network=self.network)

        for action in action_list:
            set_of_next_obj_list = get_possible_transitions(obj_list, action, _physical_checker=physical_checker)
            if len(set_of_next_obj_list) == 0:
                continue

            reward_list = [get_reward(obj_list, action, self.goal_obj, next_obj_list, self.meshes)
                           for next_obj_list in set_of_next_obj_list]

            # NOTE(review): ascending order means index 0 (the first outcome
            # action_exploration expands) is the lowest reward - confirm
            # this is intended.
            sort_indices = np.argsort(reward_list)
            reward_list = [reward_list[i] for i in sort_indices]
            set_of_next_obj_list = [set_of_next_obj_list[i] for i in sort_indices]

            child_action_node = self.Tree.number_of_nodes()
            self.Tree.add_node(child_action_node,
                               depth=depth,
                               state=obj_list,
                               action=action,
                               reward=0.,
                               value=-np.inf,
                               done=False,
                               visit=0,
                               next_states=set_of_next_obj_list,
                               next_rewards=reward_list)
            self.Tree.add_edge(state_node, child_action_node)

        return [action_node for action_node in self.Tree.neighbors(state_node)]

    def exploration(self, state_node):
        """Run one sampled rollout step from a state node.

        Increments the node's visit count, expands it if it has no action
        nodes yet, samples one action node and recurses through
        ``action_exploration``.

        Returns:
            The maximum value over the node's action nodes after the step, or
            ``0.0`` when the node is at ``max_depth``, has no state, or has no
            feasible action.
        """
        node_attrs = self.Tree.nodes[state_node]
        depth = node_attrs['depth']
        visit = node_attrs['visit']
        node_attrs['visit'] = visit + 1

        if depth >= self.max_depth:
            return 0.0
        obj_list = node_attrs['state']
        if obj_list is None:
            return 0.0

        action_nodes = self._ensure_action_nodes(state_node, depth, obj_list)
        if len(action_nodes) == 0:
            return 0.0

        action_values = [self.Tree.nodes[action_node]['value'] for action_node in action_nodes]
        action_visits = [self.Tree.nodes[action_node]['visit'] for action_node in action_nodes]
        action_list = [self.Tree.nodes[action_node]['action'] for action_node in action_nodes]

        # Random-pick probability decays as 1/visit (pre-increment count),
        # clipped to [0.01, 1].
        eps = np.maximum(np.minimum(1., 1 / np.maximum(visit, 1)), 0.01)

        # For place actions, restrict the choice to table placements when a
        # goal is set and to non-table placements otherwise; fall back to all
        # actions when the restricted set is empty.
        candidate_indices = None
        if np.any(['place' in action['type'] for action in action_list]):
            if self.goal_obj is not None:
                candidate_indices = [action_idx for action_idx, action in enumerate(action_list)
                                     if 'table' in action['param']]
            else:
                candidate_indices = [action_idx for action_idx, action in enumerate(action_list)
                                     if 'table' not in action['param']]
            if len(candidate_indices) == 0:
                candidate_indices = None

        selected_idx = sampler(self.exploration_method, action_values, action_visits, depth,
                               _indices=candidate_indices, eps=eps)
        selected_action_value = action_values[selected_idx]
        selected_action_value_new = self.action_exploration(action_nodes[selected_idx])

        if selected_action_value < selected_action_value_new:
            action_values[selected_idx] = selected_action_value_new
            node_attrs['value'] = np.max(action_values)
        return np.max(action_values)

    def action_exploration(self, action_node):
        """Run one sampled rollout step from an action node.

        Child state nodes are created lazily in the order of
        ``next_states``: the first on the first visit, and the next one once
        the most recently added child has been visited more than once. While
        unexpanded outcomes remain, the most recently added child is
        followed; afterwards a child is chosen with ``sampler``.

        Returns:
            The action node's value after the step (the best
            ``reward + return`` seen so far).
        """
        node_attrs = self.Tree.nodes[action_node]
        action_value = node_attrs['value']
        depth = node_attrs['depth']
        visit = node_attrs['visit']
        next_states = node_attrs['next_states']
        next_rewards = node_attrs['next_rewards']
        node_attrs['visit'] = visit + 1

        if len(list(self.Tree.neighbors(action_node))) == 0:
            self._add_state_node(action_node, depth + 1, next_states[0], next_rewards[0])

        next_state_nodes = list(self.Tree.neighbors(action_node))
        next_state_visits = [self.Tree.nodes[next_state_node]['visit'] for next_state_node in next_state_nodes]

        minimum_visits = 1
        if len(next_state_nodes) < len(next_states):
            if next_state_visits[-1] > minimum_visits:
                next_state_idx = len(next_state_nodes)
                self._add_state_node(action_node, depth + 1,
                                     next_states[next_state_idx], next_rewards[next_state_idx])
                next_state_nodes = list(self.Tree.neighbors(action_node))
            # Neighbours are listed in insertion order, so the last one is the
            # most recently created child.
            next_state_node = next_state_nodes[-1]
        else:
            eps = np.maximum(np.minimum(1., 1 / np.maximum(visit, 1)), 0.01)
            next_state_values = [self.Tree.nodes[next_state_node]['value'] for next_state_node in next_state_nodes]
            selected_idx = sampler(self.exploration_method, next_state_values, next_state_visits, depth, eps=eps)
            next_state_node = next_state_nodes[selected_idx]

        reward = self.Tree.nodes[next_state_node]['reward']
        action_value_new = reward + self.exploration(next_state_node)
        if action_value < action_value_new:
            action_value = action_value_new
            node_attrs['value'] = action_value

        return action_value

    def exhaustive_search(self, state_node):
        """Expand and evaluate every action below a state node, recursively.

        Returns:
            The maximum value over the node's action nodes, or ``0.0`` when
            the node is at ``max_depth``, has no state, or has no feasible
            action.
        """
        node_attrs = self.Tree.nodes[state_node]
        depth = node_attrs['depth']
        node_attrs['visit'] = node_attrs['visit'] + 1

        if depth >= self.max_depth:
            return 0.0
        obj_list = node_attrs['state']
        if obj_list is None:
            return 0.0

        action_nodes = self._ensure_action_nodes(state_node, depth, obj_list)
        if len(action_nodes) == 0:
            return 0.0

        for action_node in action_nodes:
            self.action_exhaustive_search(action_node)

        action_values = [self.Tree.nodes[action_node]['value'] for action_node in action_nodes]
        node_attrs['value'] = np.max(action_values)
        return np.max(action_values)

    def action_exhaustive_search(self, action_node):
        """Expand every outcome of an action node and evaluate each one.

        Children already present (from an earlier search or from
        ``action_exploration``, which adds them in ``next_states`` order) are
        reused; only the missing outcomes get new state nodes.

        Returns:
            The maximum ``reward + exhaustive_search(child)`` over all
            children; this is also stored as the node's value.
        """
        node_attrs = self.Tree.nodes[action_node]
        depth = node_attrs['depth']
        node_attrs['visit'] = node_attrs['visit'] + 1

        next_states = node_attrs['next_states']
        next_rewards = node_attrs['next_rewards']

        # Skip outcomes that already have a child so repeated calls do not
        # duplicate state nodes.
        n_existing = len(list(self.Tree.neighbors(action_node)))
        for next_state, next_reward in zip(next_states[n_existing:], next_rewards[n_existing:]):
            self._add_state_node(action_node, depth + 1, next_state, next_reward)

        value_list = []
        for next_state_node in list(self.Tree.neighbors(action_node)):
            reward = self.Tree.nodes[next_state_node]['reward']
            value_list.append(reward + self.exhaustive_search(next_state_node))

        node_attrs['value'] = np.max(value_list)
        print(depth, " is finished")
        return np.max(value_list)

    def get_best_path(self, start_node=0):
        """Follow the highest-value child from ``start_node`` down to a leaf.

        Returns:
            List of node ids, starting with ``start_node`` and ending at a
            node without children.
        """
        path = [start_node]
        next_nodes = list(self.Tree.neighbors(start_node))
        while len(next_nodes) > 0:
            best_idx = np.argmax([self.Tree.nodes[next_node]['value'] for next_node in next_nodes])
            path.append(next_nodes[best_idx])
            next_nodes = list(self.Tree.neighbors(path[-1]))
        return path

    def visualize(self):
        """Draw the visited part of the tree with graphviz's ``dot`` layout.

        Only nodes with a positive visit count are shown; each is labelled
        with its depth, visit count, reward and value.
        """
        visited_nodes = [n for n in self.Tree.nodes if self.Tree.nodes[n]['visit'] > 0]
        visited_tree = self.Tree.subgraph(visited_nodes)
        labels = {
            n: 'depth:{:d}\nvisit:{:d}\nreward:{:.4f}\nvalue:{:.4f}'.format(visited_tree.nodes[n]['depth'],
                                                                           visited_tree.nodes[n]['visit'],
                                                                           visited_tree.nodes[n]['reward'],
                                                                           visited_tree.nodes[n]['value'])
            for n in visited_tree.nodes}

        # A single figure: a second plt.figure() call would leave this one
        # empty and draw on a default-sized canvas instead.
        plt.figure(figsize=(32, 64))
        plt.title('')

        pos = graphviz_layout(visited_tree, prog='dot')
        nx.draw(visited_tree, pos, labels=labels, node_shape="s", node_color="none",
                bbox=dict(facecolor="skyblue", edgecolor='black', boxstyle='round,pad=0.2'))
        plt.show()

In [None]:
# Build the search tree from the initial scene. The depth limit is twice the
# sum of n_obj_per_mesh_types (presumably one pick plus one place per object
# - TODO confirm against where n_obj_per_mesh_types is defined).
mcts = Tree(initial_object_list, np.sum(n_obj_per_mesh_types) * 2, coll_mngr, meshes, contact_points,
            contact_faces, rotation_types, _goal_obj=goal_obj,
            _exploration={'method': 'random', 'param': 1.})

In [None]:
# Expand and evaluate the whole tree from the root (node 0); the cell output
# is the value returned for the root.
mcts.exhaustive_search(state_node=0)

In [None]:
# Plot every node of the tree that has been visited at least once.
mcts.visualize()

In [None]:
# Node ids along the highest-value path, starting at the root (node 0).
best_path_indices = mcts.get_best_path(0)

In [None]:
# Object list stored on the last node of the best path.
object_list = mcts.Tree.nodes[best_path_indices[-1]]['state']

In [None]:
# Show that final object list together with the goal (visualize is not
# defined in this part of the notebook - presumably an earlier cell).
visualize(object_list,meshes,_goal_obj=goal_obj)