In [1]:
import sys
import os
# Put the parent directory on the module search path before the project imports run.
# NOTE(review): presumably this lets the local `lerobot` checkout be imported when the
# notebook is launched from a subfolder - confirm against the repository layout.
sys.path.append('../')

In [2]:
from pathlib import Path
from pprint import pprint

import imageio
import torch

import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Show every dataset identifier exposed by the package; the recorded output is a list
# of 'lerobot/<name>' strings, pretty-printed one per line.
print("List of available datasets:")
pprint(lerobot.available_datasets)

List of available datasets:
['lerobot/aloha_sim_insertion_human',
 'lerobot/aloha_sim_insertion_scripted',
 'lerobot/aloha_sim_transfer_cube_human',
 'lerobot/aloha_sim_transfer_cube_scripted',
 'lerobot/aloha_sim_insertion_human_image',
 'lerobot/aloha_sim_insertion_scripted_image',
 'lerobot/aloha_sim_transfer_cube_human_image',
 'lerobot/aloha_sim_transfer_cube_scripted_image',
 'lerobot/pusht',
 'lerobot/pusht_image',
 'lerobot/xarm_lift_medium',
 'lerobot/xarm_lift_medium_replay',
 'lerobot/xarm_push_medium',
 'lerobot/xarm_push_medium_replay',
 'lerobot/xarm_lift_medium_image',
 'lerobot/xarm_lift_medium_replay_image',
 'lerobot/xarm_push_medium_image',
 'lerobot/xarm_push_medium_replay_image',
 'lerobot/aloha_static_battery',
 'lerobot/aloha_static_candy',
 'lerobot/aloha_static_coffee',
 'lerobot/aloha_static_coffee_new',
 'lerobot/aloha_static_cups_open',
 'lerobot/aloha_static_fork_pick_up',
 'lerobot/aloha_static_pingpong_test',
 'lerobot/aloha_static_pro_pencil',
 'lerobot/

In [4]:
# Pick one identifier from the list printed above and build the dataset from it.
# The recorded output shows the data being downloaded and a train split being generated
# when this cell ran.
repo_id = 'lerobot/pusht_image'
dataset = LeRobotDataset(repo_id)

Downloading readme: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:00<00:00, 733B/s]
Downloading data: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 32.0M/32.0M [00:01<00:00, 19.5MB/s]
Generating train split: 25650 examples [00:00, 298149.31 examples/s]


In [20]:
# Print, one per line: the dataset's own summary, the wrapped `hf_dataset` object,
# and the configured image transforms (None in the recorded run).
print(dataset, dataset.hf_dataset, dataset.image_transforms, sep="\n")

LeRobotDataset(
  Repository ID: 'lerobot/pusht_image',
  Split: 'train',
  Number of Samples: 25650,
  Number of Episodes: 206,
  Type: image (.png),
  Recorded Frames per Second: 10,
  Camera Keys: ['observation.image'],
  Video Frame Keys: N/A,
  Transformations: None,
  Codebase Version: v1.6,
)
Dataset({
    features: ['observation.image', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.reward', 'next.done', 'next.success', 'index'],
    num_rows: 25650
})
None


In [6]:
# Basic statistics: mean episode length, recording frame rate and the keys under
# which camera images are stored. The `{expr=}` fields echo the expression text, so
# they are kept verbatim to leave the printed output unchanged.
frames_per_episode = dataset.num_samples / dataset.num_episodes
print(
    f"\naverage number of frames per episode: {frames_per_episode:.3f}",
    f"frames per second used during data collection: {dataset.fps=}",
    f"keys to access images from cameras: {dataset.camera_keys=}\n",
    sep="\n",
)


average number of frames per episode: 124.515
frames per second used during data collection: dataset.fps=10
keys to access images from cameras: dataset.camera_keys=['observation.image']



In [7]:
# Locate the first episode inside the flat sample index. `episode_data_index` holds
# per-episode "from" and "to" entries; `.item()` turns each into a plain Python number.
# `to_idx` is used as an exclusive upper bound by the frame-loading cell that follows.
episode_index = 0
episode_bounds = dataset.episode_data_index
from_idx = episode_bounds["from"][episode_index].item()
to_idx = episode_bounds["to"][episode_index].item()

print(f"episode {episode_index} start from index {from_idx} to index {to_idx}")

episode 0 start from index 0 to index 161


In [15]:
# Gather the camera image of every sample in the selected episode and make it
# video-ready in a single pass. Images are served as float32 tensors in [0, 1] with a
# channel-first (c, h, w) layout, following the PyTorch convention; for visualisation
# each one is scaled to uint8 in [0, 255] and permuted to channel-last (h, w, c)
# before being handed over as a numpy array.
frames = [
    (dataset[idx]["observation.image"] * 255).type(torch.uint8).permute((1, 2, 0)).numpy()
    for idx in range(from_idx, to_idx)
]

In [9]:
# Write the collected frames to an mp4 file (at the dataset's recording frame rate) so
# the episode can be inspected visually. The output folder is created on demand.
output_dir = "outputs/examples/1_load_lerobot_dataset"
Path(output_dir).mkdir(parents=True, exist_ok=True)
video_path = f"{output_dir}/episode_{episode_index}.mp4"
imageio.mimsave(video_path, frames, fps=dataset.fps)

In [10]:
from IPython.display import Video

# Display the saved episode video inline; embed=True stores the video data in the
# notebook output itself instead of linking to the file on disk.
Video(video_path, embed=True, width=640, height=360)

In [32]:
# Take a single sample and, for each of its fields, report the container type and
# length of the whole column next to the type and value of this sample's entry.
record = dataset[1]

for k, v in record.items():
    # Index the dataset by column name only once per key. In the recorded run this
    # lookup returned a list with one entry per sample (25650 items), so evaluating
    # it twice inside the f-string needlessly materialised every column twice.
    column = dataset[k]
    print(f"{k}'s type: {type(column), len(column)}, {k}'s record type: {type(v)}, {v}")

observation.image's type: (<class 'list'>, 25650), observation.image's record type: <class 'torch.Tensor'>, tensor([[[1.0000, 0.9725, 0.9725,  ..., 0.9725, 0.9725, 1.0000],
         [0.9725, 0.8706, 0.9137,  ..., 0.9137, 0.8706, 0.9725],
         [0.9686, 0.9137, 1.0000,  ..., 1.0000, 0.9137, 0.9686],
         ...,
         [0.9686, 0.9137, 1.0000,  ..., 1.0000, 0.9137, 0.9686],
         [0.9725, 0.8706, 0.9137,  ..., 0.9137, 0.8706, 0.9725],
         [1.0000, 0.9725, 0.9725,  ..., 0.9725, 0.9725, 1.0000]],

        [[1.0000, 0.9725, 0.9725,  ..., 0.9725, 0.9725, 1.0000],
         [0.9725, 0.8706, 0.9137,  ..., 0.9137, 0.8706, 0.9725],
         [0.9686, 0.9137, 1.0000,  ..., 1.0000, 0.9137, 0.9686],
         ...,
         [0.9686, 0.9137, 1.0000,  ..., 1.0000, 0.9137, 0.9686],
         [0.9725, 0.8706, 0.9137,  ..., 0.9137, 0.8706, 0.9725],
         [1.0000, 0.9725, 0.9725,  ..., 0.9725, 0.9725, 1.0000]],

        [[1.0000, 0.9725, 0.9725,  ..., 0.9725, 0.9725, 1.0000],
         [0.972

# Example Data Processing

For many machine learning applications we need to load a history of past observations or a trajectory of
future actions. Our datasets can load previous and future frames for each key/modality, using timestamp
offsets relative to the currently loaded frame. For instance:

In [14]:
# Per-key time offsets, in seconds relative to the current frame (negative = past,
# 0 = current frame, positive = future), passed to LeRobotDataset in the next cell.
# NOTE(review): the dataset summary reports 10 frames per second (0.1 s between
# frames), so the -0.02 and -0.01 offsets below do not fall on a recorded frame
# time - confirm how the loader resolves such offsets.
delta_timestamps = {
    # loads 4 images: 1 second before the current frame, 500 ms before, 200 ms before, and the current frame
    "observation.image": [-1, -0.5, -0.20, 0],
    # loads 8 state vectors: 1.5 s, 1 s, 500 ms, 200 ms, 100 ms, 20 ms and 10 ms before, and the current frame
    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
    # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
    "action": [t / dataset.fps for t in range(64)],
}

In [15]:
# Rebuild the dataset with the temporal offsets so that every sample stacks several
# time steps per key along a new leading dimension.
# NOTE: this rebinds `dataset`; cells above that are re-run after this one will see
# the stacked samples rather than single frames.
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
# The `{expr=}` fields print the expression text followed by its value, so the
# expressions are spelled out in full on purpose.
print(f"\n{dataset[0]['observation.image'].shape=}")  # (4,c,h,w): 4 image time steps
print(f"{dataset[0]['observation.state'].shape=}")  # (8,c): 8 state time steps, c = state dimension
print(f"{dataset[0]['action'].shape=}\n")  # (64,c): 64 action time steps, c = action dimension

Fetching 806 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 806/806 [00:00<00:00, 22509.33it/s]


dataset[0]['observation.image'].shape=torch.Size([4, 3, 84, 84])
dataset[0]['observation.state'].shape=torch.Size([8, 4])
dataset[0]['action'].shape=torch.Size([64, 4])






In [16]:
# The dataset object is handed to a standard PyTorch DataLoader unchanged, so the
# usual batching, shuffling and sampler machinery applies to it directly.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)

# Fetch just the first batch and report the shape of each stacked entry.
# Expected layouts: image (32,4,c,h,w), state (32,8,c), action (32,64,c).
batch = next(iter(dataloader), None)
if batch is not None:
    for key in ("observation.image", "observation.state", "action"):
        # Hand-built equivalent of f"{batch['<key>'].shape=}": expression text, "=", repr.
        print(f"batch[{key!r}].shape={batch[key].shape!r}")

batch['observation.image'].shape=torch.Size([32, 4, 3, 84, 84])
batch['observation.state'].shape=torch.Size([32, 8, 4])
batch['action'].shape=torch.Size([32, 64, 4])
