In [1]:
%load_ext autoreload
%autoreload 2

Convert a ManiSkill demonstration dataset (an HDF5 trajectory file) into a set of per-episode `.npz` files compatible with the PyTorch DV3 environment.

In [2]:
import h5py
import numpy as np
import os
from tqdm import tqdm
import pathlib

In [3]:
# Path of the HDF5 trajectory file to convert; "~" is expanded to the current
# user's home directory. A later cell enables depth handling only when the
# substring "rgbd" appears in this path.
input_h5_path = pathlib.Path("~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.rgb.pd_ee_delta_pos.physx_cpu.h5").expanduser()

In [93]:
## Pulled from mani_skill.examples.baselines.diffusion_policy.train_rgbd
import torch
from functools import partial
from h5py import Dataset, File, Group

# Maps the dataset key requested by callers (target) to the key under which the
# data is stored inside each trajectory group of the HDF5 file (source).
# 'dones' and 'rewards' are intentionally left unmapped.
TARGET_KEY_TO_SOURCE_KEY = dict(
    [
        ("states", "env_states"),
        ("observations", "obs"),
        ("success", "success"),
        ("next_observations", "obs"),
        ("actions", "actions"),
    ]
)


def load_content_from_h5_file(file):
    """Recursively read an open h5py object into plain Python containers.

    Args:
        file: An ``h5py.File``, ``h5py.Group`` or ``h5py.Dataset``.

    Returns:
        For a file or group, a ``dict`` mapping each member name to the
        recursively loaded member. For a dataset, the result of ``file[()]``.

    Raises:
        NotImplementedError: If ``file`` is none of the supported h5py types.
    """
    if isinstance(file, (File, Group)):
        return {key: load_content_from_h5_file(file[key]) for key in file.keys()}
    if isinstance(file, Dataset):
        return file[()]
    raise NotImplementedError(f"Unsupported h5 file type: {type(file)}")


def load_hdf5(
    path,
):
    """Load an entire HDF5 file into nested dicts.

    Args:
        path: Location of the HDF5 file, opened read-only.

    Returns:
        The nested structure produced by ``load_content_from_h5_file`` for the
        file's root.
    """
    print("Loading HDF5 file", path)
    # The context manager closes the handle even if reading raises part-way.
    with File(path, "r") as file:
        ret = load_content_from_h5_file(file)
    print("Loaded")
    return ret


def load_traj_hdf5(path, num_traj=None):
    """Load the top-level groups (trajectories) of an HDF5 file.

    Args:
        path: Location of the HDF5 file, opened read-only.
        num_traj: If given, only the first ``num_traj`` top-level keys are
            loaded, after ordering the keys by the integer that follows the
            last ``"_"`` in each name. If ``None``, every key is loaded in the
            order the file reports them.

    Returns:
        ``dict`` mapping each selected top-level key to its loaded content.

    Raises:
        AssertionError: If ``num_traj`` exceeds the number of top-level keys.
    """
    print("Loading HDF5 file", path)
    # The context manager closes the handle even if the assert or a read fails.
    with File(path, "r") as file:
        keys = list(file.keys())
        if num_traj is not None:
            assert num_traj <= len(keys), f"num_traj: {num_traj} > len(keys): {len(keys)}"
            keys = sorted(keys, key=lambda x: int(x.split("_")[-1]))
            keys = keys[:num_traj]
        ret = {key: load_content_from_h5_file(file[key]) for key in keys}
    print("Loaded")
    return ret


def load_demo_dataset(
    path, keys=["observations", "actions", 'success'], num_traj=None, concat=True
):
    """Load selected fields from every trajectory of a demo HDF5 file.

    Args:
        path: Location of the HDF5 file.
        keys: Target keys to load; each must be a key of
            ``TARGET_KEY_TO_SOURCE_KEY``. The default list is never mutated.
        num_traj: Forwarded to ``load_traj_hdf5``; limits how many trajectories
            are loaded.
        concat: When true, fields whose per-trajectory value is a numpy array
            are concatenated along axis 0 across trajectories. Otherwise (or
            for non-array fields) a list with one entry per trajectory is kept.

    Returns:
        ``dict`` mapping each requested target key to either a concatenated
        array or a per-trajectory list.

    Raises:
        ValueError: If the file contains no trajectories.
        AssertionError: If a requested source key is missing from the first
            loaded trajectory.
    """
    raw_data = load_traj_hdf5(path, num_traj)
    # raw_data has keys like: ['traj_0', 'traj_1', ...]
    # raw_data['traj_0'] has keys like: ['actions', 'dones', 'env_states', 'infos', ...]
    if not raw_data:
        raise ValueError(f"No trajectories found in {path}")

    # Use the first loaded trajectory as the reference instead of assuming a
    # group literally named "traj_0" exists in the file.
    ref_name, ref_traj = next(iter(raw_data.items()))
    for key in keys:
        source_key = TARGET_KEY_TO_SOURCE_KEY[key]
        assert source_key in ref_traj, f"key: {source_key} not in {ref_name}: {ref_traj.keys()}"

    dataset = {}
    for target_key in keys:
        source_key = TARGET_KEY_TO_SOURCE_KEY[target_key]
        per_traj = [traj[source_key] for traj in raw_data.values()]
        dataset[target_key] = per_traj

        if isinstance(per_traj[0], np.ndarray) and concat:
            # When the reference trajectory holds more entries than actions,
            # drop the last entry for current-step fields and the first entry
            # for next-step fields.
            # NOTE(review): "next_states" is handled here but has no entry in
            # TARGET_KEY_TO_SOURCE_KEY, so requesting it fails at the lookup above.
            if target_key in ["observations", "states"] and len(per_traj[0]) > len(
                ref_traj["actions"]
            ):
                dataset[target_key] = np.concatenate([t[:-1] for t in per_traj], axis=0)
            elif target_key in ["next_observations", "next_states"] and len(
                per_traj[0]
            ) > len(ref_traj["actions"]):
                dataset[target_key] = np.concatenate([t[1:] for t in per_traj], axis=0)
            else:
                dataset[target_key] = np.concatenate(per_traj, axis=0)

            print("Load", target_key, dataset[target_key].shape)
        else:
            print(
                "Load",
                target_key,
                len(dataset[target_key]),
                type(dataset[target_key][0]),
            )
    return dataset

def convert_obs(obs, concat_fn, transpose_fn, state_obs_extractor, depth = True):
    """Convert a raw observation dict into ``state`` / ``rgb`` (/ ``depth``) arrays.

    Args:
        obs: Observation dict. ``obs["sensor_data"]`` must be a dict whose
            values each provide an ``"rgb"`` entry (and ``"depth"`` when
            ``depth`` is true).
        concat_fn: Callable that merges the list of per-sensor images.
        transpose_fn: Callable applied to the merged image, e.g. to move
            channels first.
        state_obs_extractor: Callable returning the arrays from ``obs`` that
            make up the state vector.
        depth: Whether to also process the ``"depth"`` images.

    Returns:
        ``dict`` with ``"state"`` and ``"rgb"``, plus ``"depth"`` when requested.
    """
    img_dict = obs["sensor_data"]
    image_keys = ["rgb", "depth"] if depth else ["rgb"]

    new_img_dict = {
        key: transpose_fn(
            concat_fn([sensor[key] for sensor in img_dict.values()])
        )  # (C, H, W) or (B, C, H, W)
        for key in image_keys
    }
    if "depth" in new_img_dict and isinstance(new_img_dict['depth'], torch.Tensor): # MS2 vec env uses float16, but gym AsyncVecEnv uses float32
        new_img_dict['depth'] = new_img_dict['depth'].to(torch.float16)

    # Down-cast float64 pieces into a fresh list so the extractor's own list is
    # not modified.
    states_to_stack = [
        arr.astype(np.float32) if arr.dtype == np.float64 else arr
        for arr in state_obs_extractor(obs)
    ]
    try:
        state = np.hstack(states_to_stack)
    except ValueError:
        # hstack rejects a mix of 1-D and 2-D arrays; column_stack turns the
        # 1-D arrays into columns first.
        state = np.column_stack(states_to_stack)
    if state.dtype == np.float64:
        # Integer pieces stacked with float32 promote to float64; cast back
        # rather than dropping into an interactive debugger.
        state = state.astype(np.float32)

    out_dict = {
        "state": state,
        "rgb": new_img_dict["rgb"],
    }

    if "depth" in new_img_dict:
        out_dict["depth"] = new_img_dict["depth"]

    return out_dict

# Keep every trajectory separate (no concatenation) so episodes can be written
# out one file at a time.
h5_dataset = load_demo_dataset(input_h5_path, concat=False)

# convert_obs pre-configured for whole trajectories: per-sensor images are joined
# on the last axis and then moved to channels-first; the state is built from every
# entry under obs["agent"] followed by every entry under obs["extra"].
obs_process_fn = partial(
    convert_obs,
    concat_fn=partial(np.concatenate, axis=-1),
    transpose_fn=partial(np.transpose, axes=(0, 3, 1, 2)),  # (B, H, W, C) -> (B, C, H, W)
    state_obs_extractor=lambda o: [*o["agent"].values(), *o["extra"].values()],
    depth="rgbd" in str(input_h5_path),
)



Loading HDF5 file /home/jstaley_theaiinstitute_com/.maniskill/demos/PickCube-v1/motionplanning/trajectory.rgb.pd_ee_delta_pos.physx_cpu.h5
Loaded
Load observations 1000 <class 'dict'>
Load actions 1000 <class 'numpy.ndarray'>
Load success 1000 <class 'numpy.ndarray'>


In [94]:
# Run the observation converter over each trajectory's raw observation dict.
obs_traj_dict_list = [
    obs_process_fn(raw_obs_traj) for raw_obs_traj in h5_dataset["observations"]
]

# Shapes of the first trajectory: state, rgb, actions, success.
(
    obs_traj_dict_list[0]['state'].shape,
    obs_traj_dict_list[0]['rgb'].shape,
    h5_dataset['actions'][0].shape,
    h5_dataset['success'][0].shape,
)

((75, 29), (75, 3, 128, 128), (74, 4), (74,))

In [89]:
# `output_dir` was not defined by any earlier cell, so a fresh-kernel run failed
# with NameError. Define it here and make sure the directory exists.
output_dir = pathlib.Path("~/converted_npz_demos").expanduser()
os.makedirs(output_dir, exist_ok=True)

for traj_idx, (observations, actions, success) in enumerate(
    zip(obs_traj_dict_list, h5_dataset['actions'], h5_dataset['success'])
):
    # Drop the final observation of each trajectory before pairing with the
    # per-step actions/success; the asserts below verify the lengths then match.
    images, states = observations['rgb'][:-1], observations['state'][:-1]
    print(f'ep{traj_idx:04d}', end=' ')
    assert len(images) == len(states)
    assert len(states) == len(actions)
    assert len(actions) == len(success)

    num_steps = len(images)
    assert num_steps > 0, f"trajectory {traj_idx} has no steps"

    # Episodic flags: only the first / last step of the episode is marked.
    is_first = np.zeros(num_steps, dtype=bool)
    is_last = np.zeros(num_steps, dtype=bool)
    is_first[0] = True
    is_last[-1] = True

    # Every demo is treated as terminating at its final step.
    is_terminal = np.zeros(num_steps, dtype=bool)
    is_terminal[-1] = True

    # Discount is 1.0 everywhere except the terminal final step, where it is 0.0.
    discount = np.ones(num_steps, dtype=np.float32)
    discount[-1] = 0.0

    # Sparse reward: 1.0 on the final step only if that step is flagged successful.
    rewards = np.zeros(num_steps, dtype=np.float32)
    if success[-1]:
        rewards[-1] = 1.0
    print(f'{rewards[-1]:+1.1f}', end=' ')

    # Prepare data for NPZ
    npz_data = {
        'image': images,
        'state': states,
        'reward': rewards,
        'is_first': is_first,
        'is_last': is_last,
        'is_terminal': is_terminal,
        'discount': discount,
        'action': actions,
    }

    # Save to .npz file; the file name carries the episode index and its length.
    output_npz_path = os.path.join(output_dir, f'trajectory_{traj_idx:04d}-{num_steps}.npz')
    np.savez_compressed(output_npz_path, **npz_data)
    print("saving out to ", output_npz_path)

ep0000 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz_demos/trajectory_0000-74.npz
ep0001 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz_demos/trajectory_0001-74.npz
ep0002 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz_demos/trajectory_0002-79.npz
ep0003 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz_demos/trajectory_0003-68.npz
ep0004 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz_demos/trajectory_0004-77.npz
ep0005 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz_demos/trajectory_0005-65.npz
ep0006 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz_demos/trajectory_0006-61.npz
ep0007 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz_demos/trajectory_0007-68.npz
ep0008 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz_demos/trajectory_0008-82.npz
ep0009 +1.0 saving out to  /home/jstaley_theaiinstitute_com/converted_npz