In [34]:
import sys
sys.path.append('../')

%load_ext autoreload
%autoreload 2

from custom_datasets.rollout_push_any import RolloutPushAnyDataset
from lerobot.rollout_datasets.episode_stores import EpisodeVideoStore

import re
from pathlib import Path
import zarr
import numpy as np
import torch
import torch.nn.functional as F
import imageio
from IPython.display import Video

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
# Root directory of the source dataset. NOTE(review): this is an absolute,
# machine-specific path; consider making it configurable.
data_path = "/home/sm/Datasets/lerobot/pusht"
output_directory = Path(data_path)
# Create the dataset root if it is missing (no-op when it already exists).
output_directory.mkdir(parents=True, exist_ok=True)

# Sub-directory for video files, created alongside the dataset root.
video_directory = output_directory / 'videos'
video_directory.mkdir(parents=True, exist_ok=True)

In [9]:
# Open the zarr hierarchy at data_path read-only and grab its 'data' member.
mode = 'r'
root_group = zarr.open(data_path, mode=mode)
data_group = root_group['data']

In [10]:
# Build the episode video store from the dataset root directory.
# NOTE(review): the cell output suggests this attaches to an existing zarr
# store; the exact behavior lives in EpisodeVideoStore, not in this notebook.
store = EpisodeVideoStore.create_from_path(output_directory)

Connected to the exising zarr


In [11]:
# Convert the store into the dataset object used by every cell below
# (indexed per frame; exposes episode_data_index, hf_dataset and num_episodes).
dataset = store.convert_to_lerobot_dataset()

In [62]:
def get_from_to_idx(dataset, episode_index):
    """Look up the frame-index range that belongs to one episode.

    Args:
        dataset: Object exposing ``episode_data_index``, a mapping with
            ``"from"`` and ``"to"`` entries indexable by episode number.
        episode_index: Position of the episode in ``episode_data_index``.

    Returns:
        ``(from_idx, to_idx)`` as plain Python numbers (via ``.item()``).
        Callers in this notebook treat ``to_idx`` as exclusive.
    """
    bounds = dataset.episode_data_index
    first = bounds["from"][episode_index].item()
    last = bounds["to"][episode_index].item()
    return first, last

def read_frames(dataset, episode_index):
    """Collect every image of one episode into a single tensor.

    Args:
        dataset: Per-frame indexable dataset whose items carry an
            ``"observation.image"`` tensor.
        episode_index: Episode whose frames should be loaded.

    Returns:
        The frames stacked along a new leading dimension, in frame order.
    """
    from_idx, to_idx = get_from_to_idx(dataset, episode_index)

    print(f"episode {episode_index} start from index {from_idx} to index {to_idx}")

    images = []
    for frame_idx in range(from_idx, to_idx):
        images.append(dataset[frame_idx]["observation.image"])

    # Stack the per-frame tensors into one tensor with time as axis 0.
    return torch.stack(images)

def read_frames_as_image_numpy(dataset, episode_index):
    """Load one episode and return its frames as a list of uint8 images.

    The episode is read with ``read_frames`` and converted with
    ``torch_frames_to_video``; the resulting array is then split along its
    first axis so that each list element is a single frame.
    """
    video = torch_frames_to_video(read_frames(dataset, episode_index))
    return list(video)
    
def torch_frames_to_video(frames):
    """Convert float video frames to a displayable uint8 numpy array.

    Args:
        frames: Tensor laid out as (T, C, H, W) with float values expected in
            the range [0, 1] (the PyTorch channel-first convention).

    Returns:
        ``numpy.ndarray`` of dtype uint8 laid out as (T, H, W, C) with values
        in [0, 255].
    """
    # Detach and move to CPU so .numpy() also works for CUDA or grad-tracking
    # tensors, then switch from channel-first to channel-last.
    frames = frames.detach().cpu().permute(0, 2, 3, 1)
    # Round before casting: a plain float -> uint8 cast truncates, so a value
    # such as 250.99998 (float32 round-trip error) would become 250 instead of
    # 251. Clamp so slightly out-of-range inputs saturate instead of wrapping.
    return (frames * 255).round().clamp(0, 255).to(torch.uint8).numpy()

def resize_image_tensor(obs, size=(224, 224)):
    """Bilinearly resize a batch of frames to a fixed spatial size.

    Args:
        obs: 4-D tensor laid out as (T, C, H, W).
        size: Target ``(height, width)``. Defaults to ``(224, 224)``, the
            value that was previously hard-coded.

    Returns:
        Tensor of shape (T, C, size[0], size[1]).

    Raises:
        AssertionError: If ``obs`` is not 4-dimensional.
    """
    assert obs.dim() == 4, "expected a (T, C, H, W) tensor"  # T, C, H, W
    # align_corners=False is the library default for bilinear mode; spelling
    # it out keeps the behavior explicit.
    return F.interpolate(obs, size=list(size), mode='bilinear', align_corners=False)

# Create Zarr group

In [43]:
def extract_episode_code(path):
    """Extract the episode code from a video path ending in ``episode_<code>.mp4``.

    Args:
        path: Video file path as a string or ``pathlib.Path``.

    Returns:
        The ``<code>`` part of the file name, e.g. ``"tvox3kNM"``.

    Raises:
        ValueError: If ``path`` does not end with ``episode_<code>.mp4``.
    """
    # str() lets callers pass Path objects; re.search only accepts strings.
    match = re.search(r'episode_(\w+)\.mp4$', str(path))
    if match is None:
        # Fail with a clear message instead of AttributeError on None.group.
        raise ValueError(f"Cannot extract episode code from path: {path!r}")
    return match.group(1)

In [45]:
def get_image_with_path_and_episode_index(dataset, episode_idx):
    """Return one episode's frames together with the code from its video path.

    The video path is read from the first frame's ``'observation.image'``
    entry in ``dataset.hf_dataset``, and the code is parsed from that path.

    Returns:
        Dict with ``'episode_code'`` (str) and ``'frames'`` (stacked tensor
        produced by ``read_frames``).
    """
    first_idx, _ = get_from_to_idx(dataset, episode_idx)
    first_image_entry = dataset.hf_dataset[first_idx]['observation.image']
    code = extract_episode_code(first_image_entry['path'])
    episode_frames = read_frames(dataset, episode_idx)
    return {'episode_code': code, 'frames': episode_frames}

In [46]:
# Smoke-test on the first episode: load its frames and print the episode code
# parsed from the backing video file name.
episode_idx = 0
output = get_image_with_path_and_episode_index(dataset, episode_idx)
print(output['episode_code'])

episode 0 start from index 0 to index 271
tvox3kNM
[tensor([[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 0.9843, 0.8510,  ..., 0.8353, 1.0000, 1.0000],
         [1.0000, 0.8510, 0.8118,  ..., 0.8196, 0.8471, 1.0000],
         ...,
         [1.0000, 0.8431, 0.8196,  ..., 0.8196, 0.8549, 1.0000],
         [1.0000, 1.0000, 0.8549,  ..., 0.8549, 1.0000, 1.0000],
         [1.0000, 0.9882, 1.0000,  ..., 0.9922, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 0.9843, 0.8510,  ..., 0.8353, 1.0000, 1.0000],
         [1.0000, 0.8510, 0.8118,  ..., 0.8196, 0.8471, 1.0000],
         ...,
         [1.0000, 0.8431, 0.8196,  ..., 0.8196, 0.8549, 1.0000],
         [1.0000, 1.0000, 0.8549,  ..., 0.8549, 1.0000, 1.0000],
         [1.0000, 0.9882, 1.0000,  ..., 0.9922, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 0.9843, 0.8510,  ..., 0.8353, 1.0000, 1.0000],
      

In [24]:
# Destination directory for the re-exported frames.
# NOTE(review): absolute, machine-specific path; consider making it configurable.
new_data_path = '/home/sm/Datasets/lerobot/pusht_tmp'
data_directory = Path(new_data_path)
# Create the destination if it is missing (no-op when it already exists).
data_directory.mkdir(parents=True, exist_ok=True)

In [26]:
# Open the destination as a zarr hierarchy. Per zarr's documented persistence
# modes, mode='w' creates the store and overwrites anything already there, so
# re-running this cell discards previously exported frames.
image_data_root = zarr.open(new_data_path, mode='w')
# Parent group that holds one sub-group per episode.
frame_group = image_data_root.create_group('frames')

In [52]:
# Inspect the first sample to see which keys and tensors a dataset item carries.
dataset[0]

{'action': tensor([274.7782, 247.4380]),
 'episode_index': tensor(0),
 'frame_index': tensor(0),
 'next.done': tensor(False),
 'next.reward': tensor(0.2579),
 'next.success': tensor(False),
 'observation.image': tensor([[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9843, 0.8510,  ..., 0.8353, 1.0000, 1.0000],
          [1.0000, 0.8510, 0.8118,  ..., 0.8196, 0.8471, 1.0000],
          ...,
          [1.0000, 0.8431, 0.8196,  ..., 0.8196, 0.8549, 1.0000],
          [1.0000, 1.0000, 0.8549,  ..., 0.8549, 1.0000, 1.0000],
          [1.0000, 0.9882, 1.0000,  ..., 0.9922, 1.0000, 1.0000]],
 
         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9843, 0.8510,  ..., 0.8353, 1.0000, 1.0000],
          [1.0000, 0.8510, 0.8118,  ..., 0.8196, 0.8471, 1.0000],
          ...,
          [1.0000, 0.8431, 0.8196,  ..., 0.8196, 0.8549, 1.0000],
          [1.0000, 1.0000, 0.8549,  ..., 0.8549, 1.0000, 1.0000],
          [1.0000, 0.9882, 1.0

In [None]:
# Export every episode's frames into its own zarr group under `frame_group`.
for ep_idx in range(dataset.num_episodes):
    print(ep_idx)
    # read_frames_as_image_numpy returns a list of per-frame uint8 arrays in
    # channel-last layout; stack them into one (T, H, W, C) array.
    frames = np.stack(read_frames_as_image_numpy(dataset, ep_idx), axis=0)
    # require_group reuses an existing group, so the cell can be re-run
    # without failing on groups created by an earlier (partial) run.
    episode_group = frame_group.require_group(f"episode_{ep_idx}")
    # One chunk per frame, derived from the real frame shape. The previous
    # hard-coded (1, 3, 224, 224) assumed channel-first 224x224 frames and did
    # not line up with the channel-last data written here.
    episode_group.create_dataset(
        'frames',
        data=frames,
        chunks=(1, *frames.shape[1:]),
        dtype='uint8',
        overwrite=True,
    )

0
episode 0 start from index 0 to index 271
1
episode 1 start from index 271 to index 465
2
episode 2 start from index 465 to index 728
3
episode 3 start from index 728 to index 936
4
episode 4 start from index 936 to index 1236
5
episode 5 start from index 1236 to index 1536
6
episode 6 start from index 1536 to index 1691
7
episode 7 start from index 1691 to index 1876
8
episode 8 start from index 1876 to index 2176
9
episode 9 start from index 2176 to index 2476
10
episode 10 start from index 2476 to index 2726
11
episode 11 start from index 2726 to index 3026
12
episode 12 start from index 3026 to index 3154
13
episode 13 start from index 3154 to index 3306
14
episode 14 start from index 3306 to index 3606
15
episode 15 start from index 3606 to index 3906
16
episode 16 start from index 3906 to index 4033
17
episode 17 start from index 4033 to index 4249
18
episode 18 start from index 4249 to index 4394
19
episode 19 start from index 4394 to index 4528
20
episode 20 start from index 