In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
from pprint import pprint

import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.factory import make_dataset


from hydra import compose, initialize
from omegaconf import OmegaConf

# context initialization
with initialize(version_base=None, config_path="../configs", job_name="test_app"):
    cfg = compose(config_name="default")
    print(OmegaConf.to_yaml(cfg))

  from .autonotebook import tqdm as notebook_tqdm


data dir None
resume: false
device: cuda
use_amp: false
seed: 100000
dataset_repo_id: lerobot/pusht
video_backend: pyav
training:
  offline_steps: 200000
  num_workers: 4
  batch_size: 64
  eval_freq: 10000
  log_freq: 200
  save_checkpoint: true
  save_freq: 100000
  online_steps: 0
  online_rollout_n_episodes: 1
  online_rollout_batch_size: 1
  online_steps_between_rollouts: 1
  online_sampling_ratio: 0.5
  online_env_seed: null
  online_buffer_capacity: null
  online_buffer_seed_size: 0
  do_online_rollout_async: false
  image_transforms:
    enable: false
    max_num_transforms: 3
    random_order: false
    brightness:
      weight: 1
      min_max:
      - 0.8
      - 1.2
    contrast:
      weight: 1
      min_max:
      - 0.8
      - 1.2
    saturation:
      weight: 1
      min_max:
      - 0.5
      - 1.5
    hue:
      weight: 1
      min_max:
      - -0.05
      - 0.05
    sharpness:
      weight: 1
      min_max:
      - 0.8
      - 1.2
  grad_clip_norm: 10
  lr: 0.0001
  

In [3]:
# get the path to the dataset
import pandas as pd
import numpy as np
from pathlib import Path
env_name = 'pusht' # 'pinpad' # 'robosuite'

# base_path = Path(f"~/workspace/lerobot/local/{env_name}/original").expanduser()
# base_path = Path(f"~/workspace/fastrl/logs/HD_pinpad_four_1/a").expanduser()
imi = 50
AI = False
tdmpc = True

def get_files(env_name, imi, AI=False, tdmpc=False, resize=False):
    if tdmpc:
        bp = f"~/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_{imi}_sparse/"
        od = f"~/workspace/lerobot/local/{env_name}/tdmpc{imi}"
        assert not AI
    else:    
        if AI:
            bp = f"~/workspace/fastrl/logs/AD_pusht_{imi}/"
            od = f"~/workspace/lerobot/local/{env_name}/A{imi}"
        else:
            bp = f"~/workspace/fastrl/logs/HD_pusht_{imi}/"
            od = f"~/workspace/lerobot/local/{env_name}/{imi}"

    if resize:
        od = od + "_96x96"

    base_path = Path(bp).expanduser()
    out_dir = Path(od).expanduser()

    # print(base_path)
# list all the files in the dataset
    folders = list(base_path.glob("*"))

    files = []
    for f in folders:
        files.extend((base_path / f).glob("*"))
    return files, out_dir

files, out_dir = get_files(env_name, imi, AI=AI, tdmpc=tdmpc)

print(files)

# print the keys
data = np.load(files[0])
# convert to a dictionary NOTE: this is necessary to make the arrays writeable for some reason
data = dict(data)
for k,v in data.items():
    print(k, v.shape)

# print("Setting last is_terminal to true")
# data["is_terminal"][-1] = True; data['is_last'][-1] = True

[PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/transfered/76.npz'), PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/transfered/15.npz'), PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/transfered/57.npz'), PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/transfered/23.npz'), PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/transfered/65.npz'), PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/transfered/47.npz'), PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/transfered/84.npz'), PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/transfered/10.npz'), PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/transfered/13.npz'), PosixPath('/home/j/workspace/fastrl/logs/demonstrations/TDMPC_pusht_HD_50_sparse/

In [4]:
import tqdm
import torch
import einops
import shutil
from PIL import Image as PILImage
import cv2

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.scripts.push_dataset_to_hub import save_meta_data
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
from lerobot.common.datasets.utils import hf_transform_to_torch
from datasets import Dataset, Features, Image, Sequence, Value

def to_hf_dataset(data_dict, video):
    features = {}

    if video:
        features["observation.image"] = VideoFrame()
    else:
        features["observation.image"] = Image()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.reward"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)
    # TODO(rcadene): add success
    # features["next.success"] = Value(dtype='bool', id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset

def files_to_data_dict(files):
    data_dicts = []
    for data_fn in files:
        print(f"Processing {data_fn}", end='...')
        data = np.load(data_fn)
        data = dict(data); 
        data["is_terminal"][-1] = True
        data_dicts.append(data)
    print()
    big_data_dict = {}
    for k in data_dicts[0].keys():
        big_data_dict[k] = np.concatenate([d[k] for d in data_dicts], axis=0)
        print(k, big_data_dict[k].shape)
        # if 'reward' in big_data_dict:
        #     for kk in ['reward', 'is_terminal', 'is_last']:
        #         print(f"\t{kk} {sum(big_data_dict[kk])}", end='  ')
    return big_data_dict

# big_data_dict = files_to_data_dict(files)

In [5]:
# img = big_data_dict['image'][100]
# from matplotlib import pyplot as plt
# # img = np.random.random((64, 64, 3))
# for img in big_data_dict['image'][:100]:
#     plt.imshow(img, interpolation='nearest')
#     plt.show()


In [6]:
def fastrl_to_hf(big_data_dict, out_dir):
    video = False; fps = 10; video_path = None; debug = False
    ep_dicts = []
    episode_data_index = {"from": [], "to": []}

    id_from = 0
    id_to = 0
    ep_idx = 0
    data = big_data_dict
    total_frames = data["action"].shape[0]
# for i in tqdm.tqdm(range(total_frames)):
    for i in range(total_frames):
        id_to += 1

        if not data["is_terminal"][i]:
            continue

    # print("found terminal step")

        num_frames = id_to - id_from

        image = torch.tensor(data["image"][id_from:id_to])
    # image = einops.rearrange(image, "b h w c -> b h w c")
    # image = einops.rearrange(image, "b c h w -> b h w c")
        state = torch.tensor(data["state"][id_from:id_to, :2]) if ("state" in data) else torch.zeros(num_frames, 1)
    # state = torch.tensor(data["vector_state"][id_from:id_to]) if ("vector_state" in data) else torch.zeros(num_frames, 1)
        action = (torch.tensor(data["action"][id_from:id_to]) + 1) * 256
    # action = torch.tensor(data["action"][id_from:id_to])
    # TODO(rcadene): we have a missing last frame which is the observation when the env is done
    # it is critical to have this frame for tdmpc to predict a "done observation/state"
    # next_image = torch.tensor(data["next_observations"]["rgb"][id_from:id_to])
    # next_state = torch.tensor(data["next_observations"]["state"][id_from:id_to])
        next_reward = torch.tensor(data["reward"][id_from:id_to])
        next_done = torch.tensor(data["is_terminal"][id_from:id_to])

        ep_dict = {}

        imgs_array = [x.numpy() for x in image]
        img_key = "observation.image"
        if video:
        # save png images in temporary directory
            tmp_imgs_dir = out_dir / "tmp_images"
            tmp_imgs_dir.mkdir(parents=True, exist_ok=True)

            for i in range(len(imgs_array)):
                img = PILImage.fromarray(imgs_array[i])
                img.save(str(tmp_imgs_dir / f"frame_{i:06d}.png"), quality=100)

        # encode images to a mp4 video
            fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
            video_path = out_dir / "videos" / fname
            encode_video_frames(tmp_imgs_dir, video_path, fps)

        # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)

        # store the reference to the video frame
            ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
        else:
        # ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
            ep_dict[img_key] = imgs_array

        ep_dict["observation.state"] = state
        ep_dict["action"] = action
        ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
        ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
        ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
    # ep_dict["next.observation.image"] = next_image
    # ep_dict["next.observation.state"] = next_state
        ep_dict["next.reward"] = next_reward
        ep_dict["next.done"] = next_done
        ep_dicts.append(ep_dict)

        episode_data_index["from"].append(id_from)
        episode_data_index["to"].append(id_from + num_frames)

        id_from = id_to
        ep_idx += 1

    # process first episode only
        if debug:
            break
    if len(ep_dicts) == 0:
        print("No terminal step found in the dataset")
    else:
        for k,v in ep_dicts[0].items():
            print(k, ep_dicts[0][k].shape if hasattr(ep_dicts[0][k], 'shape') else len(ep_dicts[0][k]), ep_dicts[-1][k].shape if hasattr(ep_dicts[-1][k], 'shape') else len(ep_dicts[-1][k]))

        # convert things to
        data_dict = concatenate_episodes(ep_dicts)
        data_dict, episode_data_index

        for k,v in data_dict.items():
            print(k, v.shape if hasattr(v, 'shape') else len(v), type(v))

        hf_dataset = to_hf_dataset(data_dict, video)

        info = {"fps": fps, "video": video}

        if video_path: 
            print(f"video path: {video_path}")
        lerobot_dataset = LeRobotDataset.from_preloaded(
            repo_id=env_name,
            hf_dataset=hf_dataset,
            episode_data_index=episode_data_index,
            info=info,
            videos_dir=video_path,
            )


        hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
        hf_dataset.save_to_disk(str(out_dir / "train"))
    # print(lerobot_dataset)
    stats = compute_stats(lerobot_dataset, batch_size=16, num_workers=1)
    save_meta_data(info, stats, episode_data_index, out_dir / "meta_data")
    return stats

# stats = fastrl_to_hf(big_data_dict, out_dir)


In [7]:
# Conversion from tdmpc to fastrl format does not work as of 10/6/24
# # imis = [4, 5, 6, 7, 9, 10] if AI else [3,4,5,6,7,8,9,10]
# imis = [50, 51, 52]
# for ai_tag in [False]:
#     for imi in imis:
#         files, out_dir = get_files(env_name, imi, AI=ai_tag, tdmpc=True)
#         if files:
#             big_data_dict = files_to_data_dict(files)
#             print(f"Attempting to write to {out_dir}")
#             stats = fastrl_to_hf(big_data_dict, out_dir)
#             # for k,v in stats.items():
#             #     print(k, v)
#         else: print(f"Could not find files for {imi} AI {ai_tag}")

In [8]:
import cv2
import numpy as np

def resize_images(bdd):
    for k in ['pixels', 'image']:
        if k in bdd:
            print(f"Original {k} shape:", bdd[k].shape)

            # Reshape if necessary (assuming the images are in NHWC format)
            if bdd[k].shape[-1] != 3:
                bdd[k] = np.transpose(bdd[k], (0, 2, 3, 1))
        
            # Get the original dimensions
            n, h, w, c = bdd[k].shape
        
            # Resize to 96x96
            resized = np.zeros((n, 96, 96, c), dtype=bdd[k].dtype)
            for i in range(n):
                resized[i] = cv2.resize(bdd[k][i], (96, 96), interpolation=cv2.INTER_CUBIC)
        
        # Update the dictionary with resized images
            bdd[k] = resized
            
            print(f"Resized {k} shape:", bdd[k].shape)
    else:
        print(f"Key '{k}' not found in big_data_dict")

In [9]:
# imis = [4, 5, 6, 7, 9, 10] if AI else [3,4,5,6,7,8,9,10]
RESIZE_TO_96x96 = True
imis = [11] #[11, 12, 13, 14]
for ai_tag in [True, False]:
    for imi in imis:
        files, out_dir = get_files(env_name, imi, AI=ai_tag, resize=RESIZE_TO_96x96)
        if files:
            big_data_dict = files_to_data_dict(files)
            if RESIZE_TO_96x96: resize_images(big_data_dict)
            print(f"Attempting to write to {out_dir}")
            stats = fastrl_to_hf(big_data_dict, out_dir)
            # for k,v in stats.items():
            #     print(k, v)
        else: print(f"Could not find files for {imi} AI {ai_tag}")

Processing /home/j/workspace/fastrl/logs/AD_pusht_11/final_eps/20241007T001356-3afa760b6d1a4f09aaf54436128e8e4d-34.npz...Processing /home/j/workspace/fastrl/logs/AD_pusht_11/final_eps/20241007T001352-3ca3766c9d134f9dabf263e4420ef287-52.npz...Processing /home/j/workspace/fastrl/logs/AD_pusht_11/final_eps/20241007T001308-72e18d71415f40be8591b0e93edb4b80-74.npz...Processing /home/j/workspace/fastrl/logs/AD_pusht_11/final_eps/20241007T001357-568d705091b546388984bd4c0ef4a055-89.npz...Processing /home/j/workspace/fastrl/logs/AD_pusht_11/final_eps/20241007T001343-0fa8ffee30f343518e16a5b4eb5f46e2-59.npz...Processing /home/j/workspace/fastrl/logs/AD_pusht_11/final_eps/20241007T001304-b850f91462d34e81b8f0357bc800698d-21.npz...Processing /home/j/workspace/fastrl/logs/AD_pusht_11/final_eps/20241007T001355-2f2e9fca9348430f8719234d9336cdd0-138.npz...Processing /home/j/workspace/fastrl/logs/AD_pusht_11/final_eps/20241007T001322-6d9c21de25c041cba56c330a84e07972-38.npz...Processing /home/j/workspace/fa

Saving the dataset (1/1 shards): 100%|██████████| 9393/9393 [00:00<00:00, 114990.42 examples/s]
Compute mean, min, max: 100%|█████████▉| 587/588 [00:07<00:00, 78.04it/s]
Compute std: 100%|█████████▉| 587/588 [00:07<00:00, 82.22it/s]


Processing /home/j/workspace/fastrl/logs/HD_pusht_11/10-04_20-55-32_2024/20241004T210226-636687f528644c57adba5dfdab10b902-145.npz...Processing /home/j/workspace/fastrl/logs/HD_pusht_11/10-04_20-55-32_2024/20241004T211354-77d1fc1a5b4a4021878931891df041e4-177.npz...Processing /home/j/workspace/fastrl/logs/HD_pusht_11/10-04_20-55-32_2024/20241004T211142-e3d88728976447dc80c072b7a9f7e3a2-98.npz...Processing /home/j/workspace/fastrl/logs/HD_pusht_11/10-04_20-55-32_2024/20241004T211028-5af2fe7e1a304ea2a12b119855a8d854-159.npz...Processing /home/j/workspace/fastrl/logs/HD_pusht_11/10-04_20-55-32_2024/20241004T211304-03ebc8b59fb74d72bacd0c9d0d5f0f22-188.npz...Processing /home/j/workspace/fastrl/logs/HD_pusht_11/10-04_20-55-32_2024/20241004T210151-ff22daa185a2426e92414677c560bc5f-108.npz...Processing /home/j/workspace/fastrl/logs/HD_pusht_11/10-04_20-55-32_2024/20241004T211540-adcf5507f1b743caaedbf3a4e6a5ce4f-105.npz...Processing /home/j/workspace/fastrl/logs/HD_pusht_11/10-04_20-55-32_2024/2024

Saving the dataset (1/1 shards): 100%|██████████| 10662/10662 [00:00<00:00, 101156.73 examples/s]
Compute mean, min, max: 100%|█████████▉| 666/667 [00:08<00:00, 79.50it/s]
Compute std: 100%|█████████▉| 666/667 [00:08<00:00, 80.54it/s]


In [8]:
video = False; fps = 20; video_path = None; debug = False
ep_dicts = []
episode_data_index = {"from": [], "to": []}

id_from = 0
id_to = 0
ep_idx = 0
data = big_data_dict
total_frames = data["action"].shape[0]
# for i in tqdm.tqdm(range(total_frames)):
for i in range(total_frames):
    id_to += 1

    if not data["is_terminal"][i]:
        continue

# print("found terminal step")

    num_frames = id_to - id_from

    image = torch.tensor(data["image"][id_from:id_to])
# image = einops.rearrange(image, "b h w c -> b h w c")
# image = einops.rearrange(image, "b c h w -> b h w c")
    state = torch.tensor(data["state"][id_from:id_to, :2]) if ("state" in data) else torch.zeros(num_frames, 1)
# state = torch.tensor(data["vector_state"][id_from:id_to]) if ("vector_state" in data) else torch.zeros(num_frames, 1)
    action = (torch.tensor(data["action"][id_from:id_to]) + 1) * 256
# action = torch.tensor(data["action"][id_from:id_to])
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(data["next_observations"]["rgb"][id_from:id_to])
# next_state = torch.tensor(data["next_observations"]["state"][id_from:id_to])
    next_reward = torch.tensor(data["reward"][id_from:id_to])
    next_done = torch.tensor(data["is_terminal"][id_from:id_to])

    ep_dict = {}

    imgs_array = [x.numpy() for x in image]
    img_key = "observation.image"
    if video:
    # save png images in temporary directory
        tmp_imgs_dir = out_dir / "tmp_images"
        tmp_imgs_dir.mkdir(parents=True, exist_ok=True)

        for i in range(len(imgs_array)):
            img = PILImage.fromarray(imgs_array[i])
            img.save(str(tmp_imgs_dir / f"frame_{i:06d}.png"), quality=100)

    # encode images to a mp4 video
        fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
        video_path = out_dir / "videos" / fname
        encode_video_frames(tmp_imgs_dir, video_path, fps)

    # clean temporary images directory
        shutil.rmtree(tmp_imgs_dir)

    # store the reference to the video frame
        ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
    else:
    # ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
        ep_dict[img_key] = imgs_array

    ep_dict["observation.state"] = state
    ep_dict["action"] = action
    ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
    ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
    ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = next_image
# ep_dict["next.observation.state"] = next_state
    ep_dict["next.reward"] = next_reward
    ep_dict["next.done"] = next_done
    ep_dicts.append(ep_dict)

    episode_data_index["from"].append(id_from)
    episode_data_index["to"].append(id_from + num_frames)

    id_from = id_to
    ep_idx += 1

# process first episode only
    if debug:
        break
if len(ep_dicts) == 0:
    print("No terminal step found in the dataset")
else:
    for k,v in ep_dicts[0].items():
        print(k, ep_dicts[0][k].shape if hasattr(ep_dicts[0][k], 'shape') else len(ep_dicts[0][k]), ep_dicts[-1][k].shape if hasattr(ep_dicts[-1][k], 'shape') else len(ep_dicts[-1][k]))

    # convert things to
    data_dict = concatenate_episodes(ep_dicts)

observation.image 301 246
observation.state torch.Size([301, 2]) torch.Size([246, 2])
action torch.Size([301, 2]) torch.Size([246, 2])
episode_index torch.Size([301]) torch.Size([246])
frame_index torch.Size([301]) torch.Size([246])
timestamp torch.Size([301]) torch.Size([246])
next.reward torch.Size([301]) torch.Size([246])
next.done torch.Size([301]) torch.Size([246])


In [9]:
ep_dicts[0]['observation.image'][0].shape


for f,t in zip(episode_data_index['from'], episode_data_index['to']):
    print(f, t, data_dict['action'][f:t].shape)

0 301 torch.Size([301, 2])
301 589 torch.Size([288, 2])
589 820 torch.Size([231, 2])
820 1007 torch.Size([187, 2])
1007 1130 torch.Size([123, 2])
1130 1412 torch.Size([282, 2])
1412 1713 torch.Size([301, 2])
1713 2014 torch.Size([301, 2])
2014 2315 torch.Size([301, 2])
2315 2616 torch.Size([301, 2])
2616 2764 torch.Size([148, 2])
2764 2949 torch.Size([185, 2])
2949 3088 torch.Size([139, 2])
3088 3242 torch.Size([154, 2])
3242 3444 torch.Size([202, 2])
3444 3549 torch.Size([105, 2])
3549 3850 torch.Size([301, 2])
3850 4113 torch.Size([263, 2])
4113 4261 torch.Size([148, 2])
4261 4537 torch.Size([276, 2])
4537 4790 torch.Size([253, 2])
4790 5002 torch.Size([212, 2])
5002 5303 torch.Size([301, 2])
5303 5387 torch.Size([84, 2])
5387 5688 torch.Size([301, 2])
5688 5895 torch.Size([207, 2])
5895 6026 torch.Size([131, 2])
6026 6205 torch.Size([179, 2])
6205 6506 torch.Size([301, 2])
6506 6727 torch.Size([221, 2])
6727 6984 torch.Size([257, 2])
6984 7125 torch.Size([141, 2])
7125 7251 torch.Si