In [None]:
import copy
import time
import torch
import numpy as np
import os
import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from structformer.data.tokenizer import Tokenizer
from structformer.evaluation.test_object_selection_network import ObjectSelectionInference
from structformer.evaluation.test_structformer import PriorInference
from structformer.utils.rearrangement import show_pcs_with_predictions, get_initial_scene_idxs, evaluate_target_object_predictions, save_img, show_pcs_with_labels, test_new_vis, show_pcs
from structformer.evaluation.inference import PointCloudRearrangement

In [None]:
# point cloud utils
from pc_utils import depth2pc

# tabletop environment
import sys  # needed for sys.path below; the imports cell does not import it

# `__file__` is not defined inside a Jupyter kernel, so fall back to the kernel's
# working directory (typically the directory that contains this notebook).
try:
    FILE_PATH = os.path.dirname(os.path.abspath(__file__))
except NameError:
    FILE_PATH = os.getcwd()

# Make the TabletopTidyingUp simulator packages importable relative to this notebook.
sys.path.append(os.path.join(FILE_PATH, '../..', 'TabletopTidyingUp/pybullet_ur5_robotiq'))
from custom_env import TableTopTidyingUpEnv, get_contact_objects
from utilities import Camera, Camera_front_top
sys.path.append(os.path.join(FILE_PATH, '../..', 'TabletopTidyingUp'))
from collect_template_list import scene_list

In [None]:
def setupEnvironment(args=None, data_dir='/ssd/disk', object_split='inference'):
    """Create, configure and reset the tabletop tidying-up simulation environment.

    Args:
        args: Optional namespace (e.g. from argparse). Only ``args.gui_off`` is read.
            When ``args`` is None or has no ``gui_off`` attribute, the GUI is enabled.
        data_dir: Root directory containing the object model datasets (pybullet URDF
            models, YCB and HouseCat6D, judging by the sub-directory names below).
        object_split: Value stored under ``'split'`` in the objects config passed to
            ``TableTopTidyingUpEnv`` (the original used ``'inference'``).

    Returns:
        The ``TableTopTidyingUpEnv`` instance, after ``env.reset()`` has been called.
    """
    # Positional camera arguments are passed straight through to the project's
    # Camera classes; see utilities.py for their meaning.
    camera_top = Camera((0, 0, 1.45), 0.02, 2, (480, 360), 60)
    camera_front_top = Camera_front_top((0.5, 0, 1.3), 0.02, 2, (480, 360), 60)

    objects_cfg = {
        'paths': {
            'pybullet_object_path': os.path.join(data_dir, 'pybullet-URDF-models/urdf_models/models'),
            'ycb_object_path': os.path.join(data_dir, 'YCB_dataset'),
            'housecat_object_path': os.path.join(data_dir, 'housecat6d/obj_models_small_size_final'),
        },
        'split': object_split,
    }

    # The notebook calls this function with no arguments, so `args` may be None.
    # Default to GUI on instead of raising AttributeError on `None.gui_off`.
    gui_on = not getattr(args, 'gui_off', False)
    env = TableTopTidyingUpEnv(objects_cfg, camera_top, camera_front_top, vis=gui_on, gripper_type='85')

    # NOTE(review): `p` is never imported in this notebook. These calls presumably expect
    # `import pybullet as p` to have been run in the imports cell — add it there.
    p.resetDebugVisualizerCamera(2.0, -270., -60., (0., 0., 0.))
    p.configureDebugVisualizer(p.COV_ENABLE_SHADOWS, 1)  # Shadows on/off
    p.addUserDebugLine([0, -0.5, 0], [0, -0.5, 1.1], [0, 1, 0])

    env.reset()
    return env

In [None]:
# Build the simulation environment and spawn a randomly chosen scene template whose
# name starts with 'D'.
import random  # used by random.choice below; the imports cell does not import it

# Pass an explicit namespace so this works even though setupEnvironment reads `args.gui_off`.
env = setupEnvironment(argparse.Namespace(gui_off=False))

scenes = sorted(s for s in scene_list.keys() if s.startswith('D'))
selected_scene = random.choice(scenes)
print('Selected scene: %s' % selected_scene)


def _split_size_prefix(object_name):
    """Return ``(name, size)`` for a scene-template object name.

    Names containing 'small' or 'large' get that size, with the 'small_'/'large_'
    prefix removed. Every other name is 'medium' and is returned unchanged.
    """
    if 'small' in object_name:
        return object_name.replace('small_', ''), 'small'
    if 'large' in object_name:
        return object_name.replace('large_', ''), 'large'
    return object_name, 'medium'


# Build a new list rather than editing scene_list[selected_scene] in place. The old loop
# stripped the prefixes from the shared template list itself, so re-running this cell
# (or re-selecting the same scene) silently turned every object into 'medium'.
objects = [list(_split_size_prefix(name)) for name in scene_list[selected_scene]]
selected_objects = objects

env.spawn_objects(selected_objects)
env.arrange_objects(random=True)

In [None]:
# Take a fresh observation and pull out the top camera's RGB and depth images.
# NOTE(review): env.reset() runs after the spawn/arrange cell; confirm that reset keeps
# the arranged objects, otherwise these images will not show the spawned scene.
obs = env.reset()
top_view = obs['top']
rgb, depth = top_view['rgb'], top_view['depth']

In [None]:
# Paths to the dataset, the two trained checkpoints and the directory config.
# These must be filled in before running this cell.
dataset_base_dir = ''
object_selection_model_dir = ''
pose_generation_model_dir = ''
dirs_config = ''

# Fail with a clear message instead of an obscure error from OmegaConf.load('').
_unset_paths = [name for name, value in (
    ('dataset_base_dir', dataset_base_dir),
    ('object_selection_model_dir', object_selection_model_dir),
    ('pose_generation_model_dir', pose_generation_model_dir),
    ('dirs_config', dirs_config),
) if not value]
if _unset_paths:
    raise ValueError('Set these paths before running this cell: {}'.format(', '.join(_unset_paths)))

dirs_cfg = OmegaConf.load(dirs_config)
dirs_cfg.dataset_base_dir = dataset_base_dir
OmegaConf.resolve(dirs_cfg)

# load models
object_selection_inference = ObjectSelectionInference(object_selection_model_dir, dirs_cfg)
pose_generation_inference = PriorInference(pose_generation_model_dir, dirs_cfg)

test_dataset = object_selection_inference.dataset
initial_scene_idxs = get_initial_scene_idxs(test_dataset)

# Index of the test-dataset scene to run. The original batch script iterated over the
# dataset and skipped indices not in `initial_scene_idxs`; here we only warn.
idx = 0
if idx not in initial_scene_idxs:
    print("WARNING: idx {} is not one of the initial scene indices".format(idx))

filename, _ = test_dataset.get_data_index(idx)
# Drops a fixed 4-character prefix and 3-character suffix from the file name
# (presumably something like 'data' and '.h5' — confirm against the dataset files).
scene_id = os.path.split(filename)[1][4:-3]
print("-"*50)
print("Scene No.{}".format(scene_id))

# retrieve data
init_datum = test_dataset.get_raw_data(idx)
goal_specification = init_datum["goal_specification"]
# The first 5 tokens of the structured sentence specify the structure; the remaining
# tokens specify which objects to select.
object_selection_structured_sentence = init_datum["sentence"][5:]
structure_specification_structured_sentence = init_datum["sentence"][:5]
object_selection_natural_sentence = object_selection_inference.tokenizer.convert_to_natural_sentence(
    object_selection_structured_sentence)
structure_specification_natural_sentence = object_selection_inference.tokenizer.convert_structure_params_to_natural_language(structure_specification_structured_sentence)

In [None]:
# object selection
predictions, gts = object_selection_inference.predict_target_objects(init_datum)

# Predictions are presumably given only for the first len(predictions) objects in
# init_datum, so truncate the point clouds to match.
all_obj_xyzs = init_datum["xyzs"][:len(predictions)]
all_obj_rgbs = init_datum["rgbs"][:len(predictions)]
obj_idxs = [i for i, l in enumerate(predictions) if l == 1.0]
if not obj_idxs:
    # The original `continue` was left over from a per-scene loop and is a SyntaxError
    # at cell level. Stop explicitly so later cells don't run on an empty selection.
    raise RuntimeError(
        "No objects were selected for scene {} (idx {}); choose another idx.".format(scene_id, idx))
other_obj_idxs = [i for i, l in enumerate(predictions) if l == 0.0]
obj_xyzs = [all_obj_xyzs[i] for i in obj_idxs]
obj_rgbs = [all_obj_rgbs[i] for i in obj_idxs]
other_obj_xyzs = [all_obj_xyzs[i] for i in other_obj_idxs]
other_obj_rgbs = [all_obj_rgbs[i] for i in other_obj_idxs]

print("\nSelect objects to rearrange...")
print("Instruction:", object_selection_natural_sentence)
print("Visualize groundtruth (dot color) and prediction (ring color)")
show_pcs_with_predictions(all_obj_xyzs, all_obj_rgbs,
                          gts, predictions, add_table=True, side_view=True)
print("Visualize object to rearrange")
show_pcs(obj_xyzs, obj_rgbs, side_view=True, add_table=True)

In [None]:
# pose generation
# Number of rearrangement samples ("beams") decoded in parallel and visualized below.
# This name was used but never defined (NameError); adjust the value as needed.
beam_size = 3

# The model was trained with fixed maximum counts of query/distractor objects; truncate extras.
max_num_objects = pose_generation_inference.cfg.dataset.max_num_objects
max_num_other_objects = pose_generation_inference.cfg.dataset.max_num_other_objects
if len(obj_xyzs) > max_num_objects:
    print("WARNING: reducing the number of \"query\" objects because this model is trained with a maximum of {} \"query\" objects. Train a new model if a larger number is needed.".format(max_num_objects))
    obj_xyzs = obj_xyzs[:max_num_objects]
    obj_rgbs = obj_rgbs[:max_num_objects]
if len(other_obj_xyzs) > max_num_other_objects:
    print("WARNING: reducing the number of \"distractor\" objects because this model is trained with a maximum of {} \"distractor\" objects. Train a new model if a larger number is needed.".format(max_num_other_objects))
    other_obj_xyzs = other_obj_xyzs[:max_num_other_objects]
    other_obj_rgbs = other_obj_rgbs[:max_num_other_objects]

pose_generation_datum = pose_generation_inference.dataset.prepare_test_data(obj_xyzs, obj_rgbs,
                                                                            other_obj_xyzs, other_obj_rgbs,
                                                                            goal_specification["shape"])

# Each beam gets its own deep copy because the decoding below writes predicted poses
# into the datum dicts. Each rearrangement wraps the same dict object as its beam.
beam_data = [copy.deepcopy(pose_generation_datum) for _ in range(beam_size)]
beam_pc_rearrangements = [PointCloudRearrangement(datum) for datum in beam_data]

# autoregressive decoding
num_target_objects = beam_pc_rearrangements[0].num_target_objects

# First predict the structure pose. Object predictions from this pass are not used.
beam_goal_struct_pose, _ = pose_generation_inference.limited_batch_inference(beam_data)
for b, datum in enumerate(beam_data):
    datum["struct_x_inputs"] = [beam_goal_struct_pose[b][0]]
    datum["struct_y_inputs"] = [beam_goal_struct_pose[b][1]]
    datum["struct_z_inputs"] = [beam_goal_struct_pose[b][2]]
    datum["struct_theta_inputs"] = [beam_goal_struct_pose[b][3:]]

# Then predict each object's pose in turn, writing each prediction back into the datum
# before the next inference call.
beam_goal_obj_poses = []
for obj_idx in range(num_target_objects):
    _, target_object_preds = pose_generation_inference.limited_batch_inference(beam_data)
    beam_goal_obj_poses.append(target_object_preds[:, obj_idx])
    for b, datum in enumerate(beam_data):
        obj_pose = target_object_preds[b][obj_idx]
        datum["obj_x_inputs"][obj_idx] = obj_pose[0]
        datum["obj_y_inputs"][obj_idx] = obj_pose[1]
        datum["obj_z_inputs"][obj_idx] = obj_pose[2]
        datum["obj_theta_inputs"][obj_idx] = obj_pose[3:]
# Stack the per-object predictions along a new leading (object) axis, then swap axes so
# the layout is (batch size, number of target objects, pose dim).
beam_goal_obj_poses = np.swapaxes(np.stack(beam_goal_obj_poses, axis=0), 1, 0)

# move pc
for pc_rearrangement, struct_pose, obj_poses in zip(beam_pc_rearrangements, beam_goal_struct_pose, beam_goal_obj_poses):
    pc_rearrangement.set_goal_poses(struct_pose, obj_poses)
    pc_rearrangement.rearrange()

print("\nRearrange \"query\" objects...")
print("Instruction:", structure_specification_natural_sentence)
for pi, pc_rearrangement in enumerate(beam_pc_rearrangements):
    print("Visualize rearranged scene sample {}".format(pi))
    pc_rearrangement.visualize("goal", add_other_objects=True, add_table=True, side_view=True)