# Dataset Visualization
This notebook loads and visualizes a GEECO dataset using the same data loader function that is used during model training.

In [None]:
'''imports'''
import os
from timeit import default_timer as timer

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
import numpy as np
import tensorflow as tf

from data.geeco_gym import pickplace_input_fn_v4
from utils.plotting import create_image_grid

In [None]:
'''path setup'''
root_path = os.environ['GEECO_ROOT']  # raises KeyError if GEECO_ROOT is not set
# dataset to inspect; alternative: 'gym-push-pad2-cube2-v4'
dataset_name = 'gym-pick-pad2-cube2-v4'
dataset_dir = os.path.join(root_path, 'data', dataset_name)
print(dataset_dir)

In [None]:
'''helper_fn'''
def _frame_readout(d, step_idx=0, batch_idx=0, show_vel=False):
    """Format one frame of a raw, per-joint-keyed data dict as a string.

    Every value is looked up as ``d[key][batch_idx][step_idx]``. The readout
    contains the step counter, both gripper finger joint positions, and the
    first three entries (xyz) of the mocap qpos and the object qpos.

    Args:
        d: dict mapping data keys to batched sequences.
        step_idx: time-step index within the sequence.
        batch_idx: batch element index.
        show_vel: if True, also include the seven arm joint velocities as a
            ``vel:`` line. Defaults to False, which keeps the original compact
            readout.

    Returns:
        Newline-separated readout string.
    """
    def _gather(keys):
        # stack one scalar per joint key into a float vector
        vec = np.zeros((len(keys), ))
        for i, k in enumerate(keys):
            vec[i] = d[k][batch_idx][step_idx]
        return vec

    step = d['step'][batch_idx][step_idx]
    # arm: joint velocities
    arm_qvel = _gather([
            'joint_qvel-robot0:shoulder_pan_joint',
            'joint_qvel-robot0:shoulder_lift_joint',
            'joint_qvel-robot0:upperarm_roll_joint',
            'joint_qvel-robot0:elbow_flex_joint',
            'joint_qvel-robot0:forearm_roll_joint',
            'joint_qvel-robot0:wrist_flex_joint',
            'joint_qvel-robot0:wrist_roll_joint',
    ])
    # gripper: joint positions
    gripper_qpos = _gather([
            'joint_qpos-robot0:r_gripper_finger_joint',
            'joint_qpos-robot0:l_gripper_finger_joint',
    ])
    # EE: mocap qpos, only xyz
    ee_qpos = d['mocap_qpos-robot0:mocap'][batch_idx][step_idx][:3]
    # object: qpos, only xyz
    obj_qpos = d['object_qpos-object0:joint'][batch_idx][step_idx][:3]
    if show_vel:
        return "#: %04d\nvel: %s\ngrp: %s\nee: %s\nobj: %s" % (step, arm_qvel, gripper_qpos, ee_qpos, obj_qpos)
    return "#: %04d\ngrp: %s\nee: %s\nobj: %s" % (step, gripper_qpos, ee_qpos, obj_qpos)

def _frame_readout_v2(d, step_idx=0, batch_idx=0):
    """Format state and target values of one frame as a multi-line string.

    Values come from ``d[key][batch_idx][step_idx]``. EE state, object state
    and EE target are cut to their first three entries (xyz).
    """
    def at(key):
        return d[key][batch_idx][step_idx]

    step = at('step')
    _arm_qvel = at('vel_state')  # arm joint velocities: read, but not shown
    fields = (
        step,
        at('grp_state'),     # x_grp: gripper joint positions
        at('ee_state')[:3],  # x_ee: EE state, xyz only
        at('obj_state')[:3], # x_obj: object state, xyz only
        at('grp_target'),    # y_grp: gripper target / command
        at('ee_target')[:3], # y_ee: EE target, xyz only
    )
    return "#: %04d\nx_grp: %s\nx_ee: %s\nx_obj: %s\ny_grp: %s\ny_ee: %s" % fields

def _feature_readout(feature, batch_idx=0, window_idx=0):
    """Summarise the state features of one window frame as a title string.

    Reads ``feature[key][batch_idx][window_idx]`` for the step counter and
    the gripper, end-effector and object states. EE and object states are
    truncated to their first three entries (xyz).
    """
    def pick(key):
        return feature[key][batch_idx][window_idx]

    step = pick('step')
    pick('vel_state')  # velocity state is read but not part of the readout
    grp, ee, obj = pick('grp_state'), pick('ee_state')[:3], pick('obj_state')[:3]
    return "#: %04d\nx_grp: %s\nx_ee: %s\nx_obj: %s" % (step, grp, ee, obj)

def _feature_readout_v4(feature, batch_idx=0, window_idx=0):
    """Summarise one window frame (step, command, EE and object state).

    Reads ``feature[key][batch_idx][window_idx]``. EE and object states are
    truncated to their first three entries (xyz).
    """
    def pick(key):
        return feature[key][batch_idx][window_idx]

    step = pick('step')
    # velocity and gripper state are read but not shown in this variant
    pick('vel_state')
    pick('grp_state')
    cmd = pick('cmd')
    ee, obj = pick('ee_state')[:3], pick('obj_state')[:3]
    return "#: %04d\ncmd: %s\nx_ee: %s\nx_obj: %s" % (step, cmd, ee, obj)

def _label_readout(label, batch_idx=0):
    """Summarise the gripper and end-effector targets of one batch element.

    The EE target is truncated to its first three entries (xyz).
    """
    return "y_grp: %s\ny_ee: %s" % (
        label['grp_target'][batch_idx],
        label['ee_target'][batch_idx][:3],
    )

def _label_readout_v4(label, batch_idx=0):
    """Summarise command, gripper target and EE target of one batch element.

    The EE target is truncated to its first three entries (xyz).
    """
    def row(key):
        return label[key][batch_idx]

    grp, ee, cmd = row('grp_target'), row('ee_target')[:3], row('cmd')
    return "cmd: %s\ny_grp: %s\ny_ee: %s" % (cmd, grp, ee)

In [None]:
'''input_fn'''
# Loader configuration (same input function as used for model training).
split_name = 'debug'
mode = 'eval'
window_size = 4       # frames per sample window
fetch_target = False  # additionally fetch a target frame per sample
shuffle_buffer = 32
batch_size = 2
dataset = pickplace_input_fn_v4(
    dataset_dir=dataset_dir,
    split_name=split_name,
    mode=mode,
    window_size=window_size,
    fetch_target=fetch_target,
    shuffle_buffer=shuffle_buffer,
    batch_size=batch_size,
)
# TF1-style graph pipeline: one-shot iterator plus the next-batch op
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()

In [None]:
'''session setup'''
# Close the session left over from a previous execution of this cell (if
# any) before opening a fresh one.
try:
    sess.close()
except NameError:
    # first execution: `sess` is not defined yet, nothing to close
    pass
sess = tf.InteractiveSession()

In [None]:
'''load a data batch'''
# compact float printing for the readouts below
np.set_printoptions(precision=2, suppress=True)
# time a single fetch through the input pipeline
t_start_fetch = timer()
d = sess.run(data)
t_end_fetch = timer()
t_fetch = t_end_fetch - t_start_fetch
print(">>> Fetched one batch (size=%d, window=%d) in %.04f s!" % (batch_size, window_size, t_fetch))
# the pipeline yields a (features, labels) pair of dicts
feat, lbl = d
for header, struct in ((">>> Feature struct:", feat), (">>> Label struct:", lbl)):
    print(header)
    for key in struct:
        print(key, struct[key].shape)

In [None]:
'''display a data batch'''
# One tile per window frame, plus one for the target frame when it is fetched.
grid_size = window_size + 1 if fetch_target else window_size
# Iterate over the elements actually fetched rather than `batch_size`: the
# batch may be smaller, e.g. a final partial batch if the input pipeline
# does not drop the remainder.
num_fetched = feat['rgb'].shape[0]

for b_idx in range(num_fetched):

    # RGB frames of the window, each titled with its feature readout
    # (_feature_readout gives the older, pre-v4 readout format)
    grid_rgb = create_image_grid(num_examples=grid_size, max_cols_per_row=10, tile_pad=1.2)
    for f_idx in range(window_size):
        grid_rgb[f_idx].imshow(feat['rgb'][b_idx][f_idx])
        grid_rgb[f_idx].set_title(_feature_readout_v4(feat, batch_idx=b_idx, window_idx=f_idx))
    if fetch_target:
        grid_rgb[window_size].imshow(feat['target_rgb'][b_idx])
        grid_rgb[window_size].set_title('target frame\n\n\n')
    plt.show()

    # depth frames; np.squeeze drops singleton dimensions before imshow
    grid_depth = create_image_grid(num_examples=grid_size, max_cols_per_row=10, tile_pad=1.2)
    for f_idx in range(window_size):
        grid_depth[f_idx].imshow(np.squeeze(feat['depth'][b_idx][f_idx]))
    if fetch_target:
        grid_depth[window_size].imshow(np.squeeze(feat['target_depth'][b_idx]))
    plt.show()

    # label readout for this batch element
    # (_label_readout gives the older, pre-v4 readout format)
    print(_label_readout_v4(lbl, batch_idx=b_idx))