# AgiBot World Diffusion Policy Training Demo

This notebook demonstrates how to use **AgiBotWorldDataset** to run an offline training workflow.
Make sure you have installed all necessary packages before running.


In [1]:
from scripts.image import Image as ImageFeature
from datasets.features.features import register_feature
from datasets.features import Image as DeprecatedImageFeature

# replace image feature with new one
register_feature(ImageFeature, DeprecatedImageFeature.__name__)

Overwriting feature type 'Image' (Image -> Image)


In [2]:
from datasets import load_dataset
from pathlib import Path

TASK_ID = 362
repo_id = f"agibotworld/task_{TASK_ID}"
tgt_path = "/home/ubuntu/lerobot_data/"
dataset_root = f"{tgt_path}/{repo_id}"

dataset = load_dataset("parquet", data_dir=Path(dataset_root) / "data", split="train", streaming=True)

Resolving data files:   0%|          | 0/714 [00:00<?, ?it/s]

In [3]:
next(iter(dataset))

{'observation.images.cam_top_depth': <PIL.TiffImagePlugin.TiffImageFile image mode=F size=640x480 at 0x7883CAF8C3D0>,
 'observation.state': [-1.7821335792541504,
  0.4922644793987274,
  2.12530517578125,
  -1.1757111549377441,
  2.940516948699951,
  -0.885674774646759,
  -1.6433712244033813,
  1.3742921352386475,
  -0.8007281422615051,
  -1.2186731100082397,
  0.3783508837223053,
  -0.48504018783569336,
  -1.4048820734024048,
  -1.4678765535354614,
  34.93333435058594,
  35.022220611572266,
  0.0,
  0.5060499906539917,
  0.663100004196167,
  0.2749898433685303],
 'action': [-1.7821335792541504,
  0.4922644793987274,
  2.12530517578125,
  -1.1757111549377441,
  2.940516948699951,
  -0.885674774646759,
  -1.6433712244033813,
  1.3742921352386475,
  -0.8007281422615051,
  -1.2186731100082397,
  0.3783508837223053,
  -0.48504018783569336,
  -1.4048820734024048,
  -1.4678765535354614,
  0.0,
  0.0,
  0.0,
  0.5060499906539917,
  0.663100004196167,
  0.2749898433685303,
  0.0,
  0.0],
 'epis

In [7]:
dataset.features

{'observation.images.cam_top_depth': Image(mode=None, decode=True, id=None),
 'observation.state': Sequence(feature=Value(dtype='float32', id=None), length=20, id=None),
 'action': Sequence(feature=Value(dtype='float32', id=None), length=22, id=None),
 'episode_index': Value(dtype='int64', id=None),
 'frame_index': Value(dtype='int64', id=None),
 'index': Value(dtype='int64', id=None),
 'task_index': Value(dtype='int64', id=None),
 'timestamp': Value(dtype='float32', id=None)}

In [4]:
dataset

IterableDataset({
    features: ['observation.images.cam_top_depth', 'observation.state', 'action', 'episode_index', 'frame_index', 'index', 'task_index', 'timestamp'],
    num_shards: 714
})

In [5]:

from lerobot.common.datasets.utils import check_timestamps_sync
from lerobot.common.datasets.utils import get_episode_data_index
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata

meta = LeRobotDatasetMetadata(repo_id, dataset_root, True)
episode_data_index = get_episode_data_index(meta.episodes, None)
check_timestamps_sync(dataset, episode_data_index, 30, 1e-4)

Returning existing local_dir `/home/ubuntu/lerobot_data/agibotworld/task_362` as remote repo cannot be accessed in `snapshot_download` (None).


NotImplementedError: Subclasses of Dataset should implement __getitem__.

In [1]:
%load_ext autoreload
%autoreload 2

from scripts.convert_to_lerobot import AgiBotDataset

# Check if the dataset directory exists and has parquet files
parquet_files = list(Path(dataset_root).rglob("*.parquet"))
resume = 0

processed_episodes = len(parquet_files)
print(f"Found {processed_episodes} processed episodes. Resuming from episode {processed_episodes}.")
resume = processed_episodes

# Create or load the dataset
dataset = AgiBotDataset(
    repo_id=repo_id,
    root=dataset_root,  # Use dataset_root instead of tgt_path
    local_files_only=True,
    download_videos=False
)


NameError: name 'Path' is not defined

In [9]:
# Try a more controlled approach to access the dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Create a modified dataset that only loads specific features and frames
try:
    # Define delta timestamps that only load current frames (no history or future)
    delta_timestamps = {
        "observation.state": [0],
        "action": [0]
    }
    
    # Create the dataset with limited features
    modified_dataset = LeRobotDataset(
        repo_id=repo_id,
        root=dataset_root,
        delta_timestamps=delta_timestamps,  # Only load current frames
        download_videos=False,  # Don't try to redownload videos
        local_files_only=True
    )
    
    # Try to access the first item
    first_item = modified_dataset[0]
    print("Successfully accessed first item")
    print(f"Available keys: {list(first_item.keys())}")
    
    # Now try to load one video frame at a time
    for video_key in modified_dataset.meta.video_keys[:1]:  # Start with just the first video
        try:
            # Use the built-in methods to get a specific video frame
            frame = modified_dataset.get_frames(0, video_key, [0])[0]  # Get frame at timestamp 0
            print(f"Successfully loaded frame from {video_key}, shape: {frame.shape}")
        except Exception as e:
            print(f"Error loading {video_key}: {e}")
            import traceback
            traceback.print_exc()
    
except Exception as e:
    print(f"Error with modified dataset: {e}")
    import traceback
    traceback.print_exc()

Returning existing local_dir `/home/ubuntu/lerobot_data/agibotworld/task_327` as remote repo cannot be accessed in `snapshot_download` (None).
Returning existing local_dir `/home/ubuntu/lerobot_data/agibotworld/task_327` as remote repo cannot be accessed in `snapshot_download` (None).


Resolving data files:   0%|          | 0/69 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/169 [00:00<?, ?it/s]

Error with modified dataset: 'NoneType' object has no attribute 'seek'


Traceback (most recent call last):
  File "/tmp/ipykernel_7679/2651213137.py", line 22, in <module>
    first_item = modified_dataset[0]
  File "/home/ubuntu/.local/lib/python3.10/site-packages/lerobot/common/datasets/lerobot_dataset.py", line 645, in __getitem__
    item = self.hf_dataset[idx]
  File "/home/ubuntu/.local/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2861, in __getitem__
    return self._getitem(key)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2846, in _getitem
    formatted_output = format_table(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 633, in format_table
    return formatter(pa_table, query_type=query_type)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/datasets/formatting/formatting.py", line 397, in __call__
    return self.format_row(pa_table)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/datasets/formatting/formatting.py", 

In [6]:
# =============================================
# 1. Imports and Parameter Settings
# =============================================
import torch
import numpy as np

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Parameters
FPS = 30
TASK_ID = 362
training_steps = 5000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths
dataset_path = "/home/ubuntu/lerobot_data/"
# output_path = "/path/to/save/your/checkpoint"

In [None]:
# =============================================
# 2. Dataset Setup
# =============================================
observation_idx = np.array([-1, 0])
action_idx = np.arange(-1, 15)
repo_id = f"agibotworld/task_{TASK_ID}"

delta_timestamps = {
    "observation.images.top_head": (observation_idx / FPS).tolist(),
    "observation.state": (observation_idx / FPS).tolist(),
    "action": (action_idx / FPS).tolist(),
}

dataset = LeRobotDataset(
    repo_id=repo_id,
    root=f"{dataset_path}/{repo_id}",
    delta_timestamps=delta_timestamps,
    local_files_only=True
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
    batch_size=64,
    shuffle=True,
    pin_memory=(device.type == "cuda"),
    drop_last=True,
)

Returning existing local_dir `/home/ubuntu/lerobot_data/agibotworld/task_362` as remote repo cannot be accessed in `snapshot_download` (None).
Returning existing local_dir `/home/ubuntu/lerobot_data/agibotworld/task_362` as remote repo cannot be accessed in `snapshot_download` (None).


Resolving data files:   0%|          | 0/714 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/714 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

If you want to train one robot policy model to master multiple distinct skills, you can use ’MultiLeRobotDataset‘ to load datasets for various tasks into a unified training process.

In [None]:
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
repo_ids = [f"agibotworld/{path.name}" for path in Path(dataset_path).glob("agibotworld/task_*")]
multi_dataset = MultiLeRobotDataset(
    repo_ids=repo_ids,
    root=dataset_path,
    delta_timestamps=delta_timestamps,
    local_files_only=True
)

Let's kick off a simple training with Diffusion Policy:

In [None]:
# =============================================
# 3. Policy Configuration and Initialization
# =============================================
cfg = DiffusionConfig()
cfg.input_shapes = {
    "observation.images.top_head": [3, 480, 640],
    "observation.state": [20],
}
cfg.input_normalization_modes = {
    "observation.images.top_head": "mean_std",
    "observation.state": "min_max",
}
cfg.output_shapes = {
    "action": [22],
}

policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
#policy = DiffusionPolicy(cfg, dataset_stats=multi_dataset.stats)
policy.train()
policy.to(device)

optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)

In [None]:
# =============================================
# 4. Training Loop
# =============================================
step = 0
done = False

while not done:
    for batch in dataloader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        output_dict = policy.forward(batch)
        loss = output_dict["loss"]
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        print(f"Step {step}, Loss: {loss.item():.3f}")
        step += 1
        
        if step >= training_steps:
            done = True
            break


In [None]:
# =============================================
# 5. Save Policy Checkpoint
# =============================================
policy.save_pretrained(output_path)
print(f"Model saved to {output_path}")


Congrats! Now please feel free to explore the AgiBot World!