In [2]:
from copy import deepcopy

import trimesh
import numpy as np
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout

import robosuite.mcts.transforms as tf

from matplotlib import pyplot as plt

%matplotlib inline

# Load Mesh Models

In [3]:
def get_meshes(_asset_dir='./robosuite/models/assets'):
    """Load and pre-process every mesh used by the planner.

    Each object mesh is scaled by its unit factor, re-centred on its centre of
    mass, and subdivided until no face area exceeds ``area_ths``. Every object
    except the table is then rotated into the most probable stable pose
    returned by ``compute_stable_poses``, plus a fixed per-shape correction.
    The Rethink gripper is assembled from its component meshes and appended
    as the last entry under the type name 'left_gripper'.

    Args:
        _asset_dir: root of the robosuite asset tree. Every mesh path is built
            from it. The default is relative, so the notebook must be run from
            the repository root (previously the gripper meshes used an absolute
            path tied to one user's home directory).

    Returns:
        (meshes, mesh_types, mesh_units, mesh_files). ``meshes`` and
        ``mesh_types`` end with the extra 'left_gripper' entry, whereas
        ``mesh_units`` and ``mesh_files`` only describe the object meshes.
    """
    import os

    _mesh_types = ['arch_box',
                   'rect_box',
                   'square_box',
                   'half_cylinder_box',
                   'triangle_box',
                   'twin_tower_goal',
                   'tower_goal',
                   'box_goal',
                   'custom_table']
    # Every object mesh file is named after its type.
    object_mesh_dir = os.path.join(_asset_dir, 'objects', 'meshes')
    _mesh_files = [os.path.join(object_mesh_dir, mesh_type + '.stl') for mesh_type in _mesh_types]
    # Scale factor applied to the raw STL coordinates, one per entry of _mesh_types.
    _mesh_units = [0.001, 0.001, 0.001, 0.001, 0.001, 0.0011, 0.001, 0.0011, 0.01]
    # Faces whose area exceeds this threshold are subdivided.
    area_ths = 0.003

    _meshes = []

    for mesh_type, mesh_file, unit in zip(_mesh_types, _mesh_files, _mesh_units):
        mesh = trimesh.load(mesh_file)

        mesh.apply_scale(unit)
        mesh.apply_translation(-mesh.center_mass)
        # One subdivision pass splits the oversized faces; repeat until none remain.
        while True:
            indices, = np.where(mesh.area_faces > area_ths)
            if len(indices) == 0:
                break
            mesh = mesh.subdivide(indices)

        if "custom_table" not in mesh_type:
            # Start from the most probable stable resting pose found by trimesh.
            stable_poses, probs = mesh.compute_stable_poses(n_samples=100)
            stable_pose = stable_poses[np.argmax(probs)]

            # Per-shape correction about the world x axis.
            if "half_cylinder_box" in mesh_type or "triangle_box" in mesh_type:
                stable_pose = tf.rotation_matrix(-np.pi / 2., [1., 0., 0.]).dot(stable_pose)
            elif "arch_box" in mesh_type:
                stable_pose = tf.rotation_matrix(np.pi / 2., [1., 0., 0.]).dot(stable_pose)

            stable_pose = stable_pose.dot(tf.rotation_matrix(np.pi / 2., [0., 1., 0.]))
        else:
            # The table mesh is used as modelled.
            stable_pose = np.eye(4)
        mesh.apply_transform(stable_pose)
        _meshes.append(mesh)

    # Gripper components and their poses relative to the gripper frame. The finger
    # (standard_narrow) and tip (half_round_tip) meshes appear twice, mirrored in y.
    # A common 0.12255859 is subtracted from every component's z offset.
    gripper_mesh_dir = os.path.join(_asset_dir, 'grippers', 'meshes', 'rethink_gripper')
    gripper_mesh_files = [os.path.join(gripper_mesh_dir, component + '.stl')
                          for component in ['standard_narrow', 'half_round_tip',
                                            'standard_narrow', 'half_round_tip',
                                            'connector_plate', 'electric_gripper_base']]

    gripper_mesh_relative_poses = [{'position': [0, 0.01 + 0.010833, 0.0444 - 0.12255859], 'quat': [0, 0, -1, 0]},
                                   {'position': [0, 0.01 + 0.01725 + 0.010833, 0.0444 + 0.075 - 0.12255859], 'quat': [0, 0, -1, 0]},
                                   {'position': [0, -0.01 - 0.010833, 0.0444 - 0.12255859], 'quat': [0, 0, 0, 1]},
                                   {'position': [0, -0.01 - 0.01725 - 0.010833, 0.0444 + 0.075 - 0.12255859], 'quat': [0, 0, 0, 1]},
                                   {'position': [0, 0, 0.0018 - 0.12255859], 'quat': [0, 0, 0, 1]},
                                   {'position': [0, 0, 0.0194 - 0.12255859], 'quat': [0, 0, 0, 1]}]

    # Move each component into the gripper frame and merge them into one mesh.
    gripper_mesh = None
    for mesh_file, rel_pose in zip(gripper_mesh_files, gripper_mesh_relative_poses):
        gripper_component_mesh = trimesh.load(mesh_file)
        pose = tf.quaternion_matrix(rel_pose['quat'])
        pose[:3, 3] = rel_pose['position']
        gripper_component_mesh.apply_transform(pose)
        if gripper_mesh is None:
            gripper_mesh = gripper_component_mesh
        else:
            gripper_mesh = trimesh.util.concatenate(gripper_mesh, gripper_component_mesh)

    _mesh_types.append('left_gripper')
    _meshes.append(gripper_mesh)
    return _meshes, _mesh_types, _mesh_units, _mesh_files

if __name__=='__main__':
    meshes, mesh_types, mesh_units, mesh_files = get_meshes()

In [4]:
def _sample_surface_points(_mesh, _face_indices, _n_samples):
    """Draw `_n_samples` random points on the faces `_face_indices` of `_mesh`.

    A face is chosen with probability proportional to its area. The point is then a
    random convex combination of that face's three vertices.
    Returns (points, normals), each shaped (3, _n_samples), in the mesh frame.
    """
    face_probs = _mesh.area_faces[_face_indices]
    face_probs = face_probs / np.sum(face_probs)
    sampled = _face_indices[np.random.choice(len(face_probs), _n_samples, p=face_probs)]

    weights = np.random.uniform(size=(1, 3, _n_samples))
    weights = weights / np.sum(weights, keepdims=True, axis=1)
    # vertices[faces[sampled]] is (n, vertex, coord); .T makes it (coord, vertex, n),
    # so weights broadcast over coordinates and are summed over the 3 vertices.
    points = np.sum(weights * _mesh.vertices[_mesh.faces[sampled]].T, axis=1)
    normals = _mesh.face_normals[sampled].T
    return points, normals


def sample_place_pose(_obj1, _obj2, _meshes, _coll_mngr, n_sampling_trial=100):
    """Sample candidate world poses that place `_obj1` resting on `_obj2`.

    Points are sampled, area-weighted, on two sets of faces:
    - faces of `_obj1` whose world normal points down (n_z < -0.99);
    - faces of `_obj2` whose world normal points up (n_z > 0.999).
    For each sampled pair, `_obj1` is moved so its point sits 1e-5 along
    `_obj2`'s face normal from `_obj2`'s point, with the two normals opposed and a
    random rotation about that normal.

    Args:
        _obj1: object to place (uses .name, .mesh_idx, .pose).
        _obj2: supporting object (uses .mesh_idx, .pose).
        _meshes: meshes indexed by .mesh_idx.
        _coll_mngr: trimesh CollisionManager containing `_obj1` under its name.
        n_sampling_trial: number of candidate poses to try.

    Returns:
        A list of length `n_sampling_trial`. Each entry is a 4x4 world pose for
        `_obj1`, or None when the collision manager reports an internal collision
        with `_obj1` at that pose. Every entry is None if either object has no
        suitable face.

    Side effects:
        `_obj1`'s transform in `_coll_mngr` is moved to each candidate during the
        check and restored to `_obj1.pose` afterwards.
    """
    mesh1 = _meshes[_obj1.mesh_idx]
    T1_before = deepcopy(_obj1.pose)
    normals1_world = T1_before[:3, :3].dot(mesh1.face_normals.T)
    surface1_indices = np.where(normals1_world[2, :] < -0.99)[0]

    mesh2 = _meshes[_obj2.mesh_idx]
    T2 = deepcopy(_obj2.pose)
    normals2_world = T2[:3, :3].dot(mesh2.face_normals.T)
    surface2_indices = np.where(normals2_world[2, :] > 0.999)[0]

    # Without candidate faces there is nothing to sample (np.random.choice would fail).
    if len(surface1_indices) == 0 or len(surface2_indices) == 0:
        return [None] * n_sampling_trial

    pnts1, normals1 = _sample_surface_points(mesh1, surface1_indices, n_sampling_trial)
    pnts2, normals2 = _sample_surface_points(mesh2, surface2_indices, n_sampling_trial)

    def compute_T(pnt1, normal1, pnt2, normal2):
        # Contact frame on obj2 (obj2 frame): slightly offset along its normal, z opposing it.
        target_pnt = pnt2 + 1e-5*normal2
        target_normal = -normal2
        T_target = create_random_rotation_mtx_from_z(target_normal)
        T_target[:3, 3] = target_pnt
        # Contact frame on obj1 (obj1 frame): z along obj1's face normal.
        T_source = create_random_rotation_mtx_from_z(normal1)
        T_source[:3, 3] = pnt1

        # World pose of obj1 that makes the two contact frames coincide.
        T1 = T2.dot(T_target.dot(np.linalg.inv(T_source)))
        _coll_mngr.set_transform(_obj1.name, T1)
        if not _coll_mngr.in_collision_internal():
            return T1
        return None

    try:
        return list(map(compute_T, pnts1.T, normals1.T, pnts2.T, normals2.T))
    finally:
        # Do not leave obj1 at the last tested candidate pose.
        _coll_mngr.set_transform(_obj1.name, T1_before)

# Initializing Configuration

In [20]:
class Object(object):
    """A scene object: name, mesh index, 4x4 pose and symbolic logical state.

    `_logical_state` is a dict of predicates keyed by strings such as "on",
    "support", "held" or "static", as read by the functions in this notebook.
    A random RGB colour with alpha 0.3 is drawn once per instance.
    """

    def __init__(self, _name, _mesh_idx, _pose, _logical_state):
        rgb = [np.random.uniform() for _ in range(3)]
        self.name, self.mesh_idx = _name, _mesh_idx
        self.pose, self.logical_state = _pose, _logical_state
        self.color = rgb + [0.3]
        self.color = [np.random.uniform(), np.random.uniform(), np.random.uniform(), 0.3]

def get_obj_idx_by_name(_object_list, _name):
    """Return the index of the first object whose .name equals `_name`, or None."""
    return next((idx for idx, candidate in enumerate(_object_list) if candidate.name == _name), None)

def get_held_object(_object_list):
    """Return the index of the first object whose logical state contains "held", or None."""
    held_indices = [idx for idx, candidate in enumerate(_object_list) if "held" in candidate.logical_state]
    return held_indices[0] if held_indices else None

def update_logical_state(_object_list):
    """Make the "on" / "support" relations between objects mutually consistent.

    Two rules are applied:
    - If A lists B under "on", A's name is added to B's "support" list.
    - If A lists B under "support", A's name is added to B's "on" list.
    Missing lists are created and names already present are not duplicated.
    The logical_state dicts are modified in place; nothing is returned.
    """
    def _add_inverse(_src_obj, _relation, _inverse_relation):
        # Record _src_obj under _inverse_relation of every object it names in _relation.
        for _other_name in _src_obj.logical_state[_relation]:
            _other_state = _object_list[get_obj_idx_by_name(_object_list, _other_name)].logical_state
            if _inverse_relation not in _other_state:
                _other_state[_inverse_relation] = [_src_obj.name]
            elif _src_obj.name not in _other_state[_inverse_relation]:
                _other_state[_inverse_relation].append(_src_obj.name)

    for _obj in _object_list:
        if "on" in _obj.logical_state:
            _add_inverse(_obj, "on", "support")
        if "support" in _obj.logical_state:
            _add_inverse(_obj, "support", "on")

def _random_frame_from_axis(_axis, _axis_col):
    """Build a random right-handed 4x4 rotation whose column `_axis_col` is `_axis`.

    `_axis` is normalized first. The other two columns are derived from a random
    direction, so the frame is random about `_axis`. The translation part is zero.
    Raises ValueError for a zero axis.
    """
    axis = np.asarray(_axis, dtype=float)
    axis_norm = np.linalg.norm(axis)
    if axis_norm == 0.0:
        raise ValueError("axis must be a non-zero 3-vector")
    axis = axis / axis_norm

    # Redraw while the random direction is (nearly) parallel to the axis; otherwise the
    # cross product is ~0 and normalizing it would produce NaNs.
    while True:
        rnd_axis = np.random.uniform(low=-1.0, high=1.0, size=(3,))
        first = np.cross(rnd_axis, axis)
        first_norm = np.linalg.norm(first)
        if first_norm > 1e-6:
            break
    first = first / first_norm
    second = np.cross(axis, first)
    second = second / np.linalg.norm(second)

    # Columns (axis, first, second) are placed cyclically at (k, k+1, k+2) mod 3, so
    # the frame is right-handed (det = +1) whichever column the axis occupies.
    T = np.eye(4)
    T[:3, _axis_col] = axis
    T[:3, (_axis_col + 1) % 3] = first
    T[:3, (_axis_col + 2) % 3] = second
    return T


def create_random_rotation_mtx_from_x(x_axis):
    """Random 4x4 rotation whose x column is the normalized `x_axis`."""
    return _random_frame_from_axis(x_axis, 0)


def create_random_rotation_mtx_from_y(y_axis):
    """Random 4x4 rotation whose y column is the normalized `y_axis`."""
    return _random_frame_from_axis(y_axis, 1)


def create_random_rotation_mtx_from_z(z_axis):
    """Random 4x4 rotation whose z column is the normalized `z_axis`."""
    return _random_frame_from_axis(z_axis, 2)
                    
def configuration_initializer(_meshes, _mesh_types, _mesh_units, _goal_name='tower_goal'):
    """Build the initial scene: table, spawned objects, optional goal object and gripper.

    Args:
        _meshes, _mesh_types, _mesh_units: as returned by ``get_meshes``.
        _goal_name: 'tower_goal' or 'stack_hard'. It selects how many objects of
            each mesh type are spawned. A name containing 'goal' also creates a
            goal Object from the mesh type of the same name.

    Returns:
        (object_list, goal_obj, coll_mngr):
        - object_list: the table first, then the spawned objects, then the gripper
          last.
        - goal_obj: the goal Object, or None. It is neither in object_list nor in
          the collision manager.
        - coll_mngr: a trimesh CollisionManager holding every object of
          object_list at its pose.

    Raises:
        ValueError: if ``_goal_name`` is not a known scenario.
    """
    # Number of objects to spawn per mesh type, indexed like _mesh_types
    # (the trailing entry corresponds to the appended 'left_gripper').
    spawn_counts = {'tower_goal': [0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
                    'stack_hard': [0, 1, 1, 1, 1, 0, 0, 0, 0, 0]}
    # Look up by value; the previous `is` comparison only matched interned strings and
    # left n_obj_per_mesh_types unbound for any other name.
    if _goal_name not in spawn_counts:
        raise ValueError("Unknown goal name %r; expected one of %s" % (_goal_name, sorted(spawn_counts)))
    n_obj_per_mesh_types = spawn_counts[_goal_name]

    _object_list = []
    _coll_mngr = trimesh.collision.CollisionManager()

    # Only the y component is used below; x is derived from the mesh index.
    table_spawn_position = [0.6, 0.2]

    # Create Table object
    table_name = 'custom_table'
    table_pose = np.array([0.6, 0.0, 0.573077])
    table_mesh_idx = _mesh_types.index(table_name)
    table_T = np.eye(4)
    table_T[:3, 3] = table_pose
    table_obj = Object(table_name, table_mesh_idx, table_T, {"static": []})
    # World z of the table's top surface.
    table_height = _meshes[table_mesh_idx].bounds[1][2] + table_pose[2]

    _object_list.append(table_obj)
    _coll_mngr.add_object(table_name, _meshes[table_mesh_idx])
    _coll_mngr.set_transform(table_name, table_T)

    if 'goal' in _goal_name:
        goal_mesh_idx = _mesh_types.index(_goal_name)
        # NOTE(review): z uses the table mesh's own height (bounds extent) rather than
        # table_height, plus a -0.019 offset scaled by unit*1000 - confirm intended.
        goal_pose = np.array([0.4, 0.0, -0.019*_mesh_units[goal_mesh_idx]*1000 + _meshes[table_mesh_idx].bounds[1][2] - _meshes[table_mesh_idx].bounds[0][2]
                              - _meshes[goal_mesh_idx].bounds[0][2]])
        goal_T = np.eye(4)
        goal_T[:3, 3] = goal_pose
        _goal_obj = Object(_goal_name, goal_mesh_idx, goal_T, {"goal": []})
    else:
        _goal_obj = None

    # Lay the spawned objects out on the table: x grows with the mesh index, y with the
    # instance index, and z puts the mesh's lowest point on the table top.
    for mesh_idx, n_obj in enumerate(n_obj_per_mesh_types):
        obj_mesh_idx = mesh_idx
        for obj_idx in range(n_obj):
            obj_name = _mesh_types[mesh_idx] + str(obj_idx)
            new_obj = Object(obj_name, obj_mesh_idx, np.eye(4), {"on": [table_name]})
            new_obj.pose[:3, 3] = [0.2 + mesh_idx*0.08, table_spawn_position[1]+obj_idx*0.08, - _meshes[mesh_idx].bounds[0][2] + table_height]
            _coll_mngr.add_object(obj_name, _meshes[obj_mesh_idx])
            _coll_mngr.set_transform(obj_name, new_obj.pose)
            _object_list.append(new_obj)

    # The gripper starts 0.5 above the table origin and is always the last object.
    gripper_mesh_idx = _mesh_types.index('left_gripper')

    _coll_mngr.add_object('left_gripper', _meshes[gripper_mesh_idx])
    gripper_pose = deepcopy(table_T)
    gripper_pose[2, 3] += 0.5
    _coll_mngr.set_transform('left_gripper', gripper_pose)

    gripper_obj = Object('left_gripper', gripper_mesh_idx, gripper_pose, {'left_gripper': []})
    _object_list.append(gripper_obj)
    return _object_list, _goal_obj, _coll_mngr

if __name__=='__main__':
    object_list, goal_obj, coll_mngr = configuration_initializer(meshes, mesh_types, mesh_units, 'tower_goal')

In [21]:
def visualize(_object_list, _meshes, _goal_obj=None):
    """Show the scene in trimesh's 'gl' viewer.

    Each object's mesh is added under its name at its pose. The goal object, if
    given, is drawn after the other objects.
    """
    mesh_scene = trimesh.Scene()
    drawables = list(_object_list)
    if _goal_obj is not None:
        drawables.append(_goal_obj)
    for _obj in drawables:
        mesh_scene.add_geometry(_meshes[_obj.mesh_idx], node_name=_obj.name, transform=_obj.pose)
    mesh_scene.show(viewer='gl')

# Guarded like the other driver cells so the definition can be reused without
# requiring object_list / meshes / goal_obj to exist.
if __name__ == '__main__':
    visualize(object_list, meshes, goal_obj)

# Compute Possible Logical Actions

In [22]:
def get_possible_actions(_object_list, _meshes, _coll_mngr, _rotation_types=4, side_place_flag=False, goal_obj=None):
    """Enumerate the symbolic pick/place actions available in the current scene.

    Every object's pose is first written into `_coll_mngr`.

    Pick (only while no object is "held"): candidates are objects that
      - have no "support" entry and are not "static" or "done";
      - are not named like a goal or the gripper.
    For each candidate face whose world normal is horizontal (|n_z| < 1e-10),
    `get_grasp_pose` is tried at `_rotation_types` evenly spaced angles. The
    successful grasps of one object form one "pick" action.

    Place (only while an object is "held"): the held object's contact faces
    pointing down in the world (n_z < 0; n_z < 1e-3 for 'rect_box' when
    `side_place_flag` is set) are paired with the upward contact faces
    (n_z > 0.99) of every other non-gripper object. Each pairing is tried at
    `_rotation_types` angles via `get_on_pose`. The valid poses for one target
    object form one "place" action.

    Args:
        _object_list: scene objects (name, mesh_idx, pose, logical_state).
        _meshes: meshes indexed by mesh_idx.
        _coll_mngr: trimesh CollisionManager holding every object by name.
        _rotation_types: number of rotation angles 2*pi*j/_rotation_types tried.
        side_place_flag: allow side placements of 'rect_box' objects.
        goal_obj: unused.

    Returns:
        A list of dicts of two forms:
        - {"type": "pick", "param": name, "grasp_poses", "retreat_poses",
          "gripper_widths"}
        - {"type": "place", "param": target name, "placing_poses"}

    NOTE(review): `_contact_points`, `_contact_faces`, `get_grasp_pose` and
    `get_on_pose` are not defined in the visible cells. Verify that they exist on a
    fresh kernel before relying on the place branch.
    """
    _action_list = []
    # Sync the collision manager with the current object poses before any checks.
    for obj in _object_list:
        _coll_mngr.set_transform(obj.name, obj.pose)

    # Check pick
    if all(["held" not in obj.logical_state for obj in _object_list]):
        for obj1 in _object_list:
            # Only free objects can be picked: nothing rests on them, they are not
            # static/done, and they are neither a goal marker nor the gripper.
            if "support" not in obj1.logical_state and "static" not in obj1.logical_state and \
                    "done" not in obj1.logical_state and "goal" not in obj1.name and 'gripper' not in obj1.name:

                obj1_mesh_idx = obj1.mesh_idx
                obj1_pose = obj1.pose

                # Candidate grasp contacts: face centroids and face normals in the mesh frame.
                obj1_contact_points = np.mean(_meshes[obj1_mesh_idx].vertices[_meshes[obj1_mesh_idx].faces], axis=1)
                obj1_contact_normals = _meshes[obj1_mesh_idx].face_normals
                obj1_contact_normals_world = obj1.pose[:3, :3].dot(obj1_contact_normals.T).T
                # Keep only faces whose normal is horizontal in the world frame.
                obj1_contact_indices, = np.where(np.abs(obj1_contact_normals_world[:, 2]) < 1e-10)

                grasp_poses = []
                retreat_poses = []
                gripper_widths = []

                for i in obj1_contact_indices:
                    # Try each evenly spaced rotation angle for this contact face.
                    for j in range(_rotation_types):
                        pnt1 = obj1_contact_points[i]
                        normal1 = obj1_contact_normals[i]
                        hand_t_grasp, hand_t_retreat, gripper_width = get_grasp_pose(obj1_mesh_idx, pnt1, normal1, 2. * np.pi * j / _rotation_types, obj1_pose, _meshes, _coll_mngr)
                        if hand_t_grasp is not None:
                            grasp_poses.append(hand_t_grasp)
                            retreat_poses.append(hand_t_retreat)
                            gripper_widths.append(gripper_width)

                if len(grasp_poses) > 0:
                    _action_list.append({"type": "pick", "param": obj1.name, "grasp_poses": grasp_poses,
                                         "retreat_poses": retreat_poses, "gripper_widths": gripper_widths})

    # Check place
    held_obj_idx = get_held_object(_object_list)
    if held_obj_idx is not None:
        held_obj = _object_list[held_obj_idx]
        # obj2 is the held object that would be placed.
        obj2 = held_obj
        obj2_mesh_idx = held_obj.mesh_idx
        obj2_contact_points = _contact_points[obj2_mesh_idx]
        obj2_contact_normals = _meshes[obj2_mesh_idx].face_normals[_contact_faces[obj2_mesh_idx]]

        obj2_contact_normals_world = obj2.pose[:3, :3].dot(obj2_contact_normals.T).T
        # Faces of the held object that may touch the support: pointing down in the
        # world, or (rect_box with side_place_flag) also near-vertical ones.
        if side_place_flag and 'rect_box' in obj2.name:
            obj2_contact_indices, = np.where(obj2_contact_normals_world[:, 2] < 1e-3)
        else:
            obj2_contact_indices, = np.where(obj2_contact_normals_world[:, 2] < 0.)

        for obj1 in _object_list:
            if "held" not in obj1.logical_state and 'gripper' not in obj1.name:
                # obj1 is a candidate support for the held object.
                obj1_mesh_idx = obj1.mesh_idx
                obj1_contact_points = _contact_points[obj1_mesh_idx]
                obj1_contact_normals = _meshes[obj1_mesh_idx].face_normals[_contact_faces[obj1_mesh_idx]]

                obj1_contact_normals_world = obj1.pose[:3, :3].dot(obj1_contact_normals.T).T
                # Support faces must point (almost) straight up in the world frame.
                obj1_contact_indices, = np.where(obj1_contact_normals_world[:, 2] > 0.99)

                # Gripper pose expressed in the held object's frame. Assumes the gripper is
                # the last entry of _object_list (configuration_initializer appends it last).
                _rel_gripper = np.linalg.inv(held_obj.pose).dot(_object_list[-1].pose)

                placing_poses = []
                for i in obj1_contact_indices:
                    for j in range(_rotation_types):
                        for k in obj2_contact_indices:
                            pnt1 = obj1_contact_points[i]
                            normal1 = obj1_contact_normals[i]

                            pnt2 = obj2_contact_points[k]
                            normal2 = obj2_contact_normals[k]

                            pose = get_on_pose(obj2.name, pnt2, normal2, 2. * np.pi * j / _rotation_types,
                                               pnt1, normal1, obj1.pose, _coll_mngr, _rel_gripper=_rel_gripper)

                            if pose is not None:
                                placing_poses.append(pose)
                if len(placing_poses) > 0:
                    _action_list.append({"type": "place", "param": obj1.name, "placing_poses": placing_poses})
    return _action_list

[{'type': 'pick', 'param': 'rect_box0'}, {'type': 'pick', 'param': 'square_box0'}, {'type': 'pick', 'param': 'half_cylinder_box0'}]
