In [None]:
import logging
import pathlib as pl
import sys
from types import SimpleNamespace

import yaml

from ros_tf2_wrapper import get_populated_tf2_wrapper
from rosbag_to_rlbench import (_setup_rosbag_dataset, 
                               extract_list_from_rosbag_dataset, 
                               add_tfs_to_episodes, 
                               keypoints_from_frame,
                               save_keypoint)

# Path to the YAML config *file* (despite the name, this is not a directory).
# Being relative, it is resolved against the kernel's current working directory.
CONFIG_DIR = './configs/bag_to_rlbench.yaml'

In [None]:
from rlbench.backend.observation import Observation
import numpy as np

def get_rlbench_obs(eps, misc):
    """Build a placeholder RLBench ``Observation`` whose only real payload is ``misc``.

    Every camera field (rgb / depth / point_cloud / mask for the left_shoulder,
    right_shoulder, overhead, wrist and front cameras) is ``None``, and every
    proprioceptive field is filled with dummy values.

    Args:
        eps: Not used by this function. NOTE(review): presumably the episode data
            this observation should eventually be built from -- TODO confirm.
        misc: Stored unchanged as ``Observation.misc``.

    Returns:
        The constructed ``Observation``.
    """
    # Pose format -> x, y, z, qx, qy, qz, qw (identity orientation at (1, 1, 1)).
    dummy_pose = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0])

    # Each array-valued field below gets its own copy of dummy_pose: handing the
    # same ndarray object to several fields would make an in-place edit of one
    # (e.g. obs.joint_positions) silently change the others (e.g. obs.gripper_pose).
    obs = Observation(
        left_shoulder_rgb=None,
        left_shoulder_depth=None,
        left_shoulder_point_cloud=None,
        right_shoulder_rgb=None,
        right_shoulder_depth=None,
        right_shoulder_point_cloud=None,
        overhead_rgb=None,
        overhead_depth=None,
        overhead_point_cloud=None,
        wrist_rgb=None,
        wrist_depth=None,
        wrist_point_cloud=None,
        front_rgb=None,
        front_depth=None,
        front_point_cloud=None,
        left_shoulder_mask=None,
        right_shoulder_mask=None,
        overhead_mask=None,
        wrist_mask=None,
        front_mask=None,
        joint_velocities=dummy_pose.copy(),
        joint_positions=dummy_pose.copy(),
        joint_forces=dummy_pose.copy(),
        gripper_open=0.9,  # placeholder value
        gripper_pose=dummy_pose.copy(),
        # NOTE(review): RLBench typically stores a 4x4 pose matrix here, not a
        # 7-element pose; the 7-vector dummy is kept as-is -- confirm consumers.
        gripper_matrix=dummy_pose.copy(),
        gripper_touch_forces=None,
        gripper_joint_positions=np.ones(2),
        task_low_dim_state=None,
        ignore_collisions=True,  # TODO: fix
        misc=misc,
    )

    return obs

In [None]:
'''
A bag contains topics, each mapped to a name corresponding to its RLBench
    counterpart. Each topic/name is also associated with a frame name, which is
    used to look up the extrinsics of that data point. All extrinsics are
    expressed relative to the reference frame.
By default, the transforms are derived from the /tf and /tf_static topics in the bag.
'''

logging.basicConfig(level=logging.INFO, format="%(message)s",
                    handlers=[logging.StreamHandler(sys.stdout)])

# Load the YAML config *file*. `yaml.load(CONFIG_DIR)` would parse the path string
# itself (returning a plain str), and PyYAML >= 6 rejects `load` without a Loader.
# safe_load additionally refuses to construct arbitrary Python objects.
with open(CONFIG_DIR) as config_file:
    cfg_dict = yaml.safe_load(config_file)
if not isinstance(cfg_dict, dict):
    raise ValueError(f'expected a mapping at the top level of {CONFIG_DIR}, '
                     f'got {type(cfg_dict).__name__}')
# Attribute-style access (cfg.data_dir, ...) for the top-level keys used below.
cfg = SimpleNamespace(**cfg_dict)

data_dir = pl.Path(cfg.data_dir)
out_dir = pl.Path(cfg.save_path)
# YAML yields nested lists; convert once so `@` below is a real matrix product.
# NOTE(review): presumably a 4x4 homogeneous transform -- TODO confirm shape.
reference_frame_T = np.asarray(cfg.reference_frame_T, dtype=float)

# zip() would silently drop unmatched entries; fail loudly instead.
if len(cfg.task_folders) != len(cfg.task_description):
    raise ValueError(f'task_folders ({len(cfg.task_folders)}) and task_description '
                     f'({len(cfg.task_description)}) must have the same length')

# For every task folder and its description
for task_folder, description in zip(cfg.task_folders, cfg.task_description):
    task_dir = data_dir / task_folder
    # Path.glob yields entries in an OS-dependent order; sort so that the
    # demo_idx assigned to each bag is reproducible across runs and machines.
    demo_paths = sorted(task_dir.glob('*.bag'))
    if not demo_paths:
        logging.warning(f'no .bag files found in {task_dir}')
    # For every rosbag/demo in folder
    for demo_idx, demo_path in enumerate(demo_paths):
        logging.info('.'*3 + f'processing: {demo_path}' + '.'*3)
        logging.info(f'depth scale: {cfg.depth_scale}')

        tf_data = get_populated_tf2_wrapper(demo_path)
        # NOTE(review): the hook appears to pre-multiply values fetched for
        # cfg.reference_frame by reference_frame_T -- confirm in ros_tf2_wrapper.
        tf_data.register_getter_hook(lambda x: reference_frame_T @ x, cfg.reference_frame)
        dataset = _setup_rosbag_dataset(demo_path,
                                        cfg.topics_and_names,
                                        time_slack=cfg.topic_time_slack,
                                        reference_topic=cfg.reference_topic)
        episodes = extract_list_from_rosbag_dataset(dataset)
        # Presumably attaches each name's extrinsics (frame from cfg.names_and_frames)
        # expressed in cfg.reference_frame -- TODO confirm in rosbag_to_rlbench.
        add_tfs_to_episodes(tf_data,
                            episodes,
                            cfg.names_and_frames,
                            cfg.reference_frame)
        # Keypoint extraction is currently disabled; to re-enable, call:
        #   keypoints_from_frame(episodes, tf_data, cfg.reference_frame)
        save_keypoint(episodes,
                      out_dir,
                      description,
                      demo_idx,
                      depth_scale=cfg.depth_scale,
                      cameras_used=cfg.cameras_used)
        logging.info('.'*5 + f'Completed demo {demo_idx}' + '.'*5)
    logging.info(f'Done with task {task_folder}')

