# AgiBot World Diffusion Policy Training Demo

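This notebook demonstrates how to load **AgiBot World** task data with LeRobot's `LeRobotDataset` and train a Diffusion Policy on it offline.
Before running, make sure `lerobot` and its dependencies are installed and that the task data is available locally under `dataset_path` (the notebook loads it with `local_files_only=True`).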


In [1]:
# =============================================
# 1. Imports and Parameter Settings
# =============================================
import torch
import numpy as np

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Parameters
FPS = 30                 # frame rate used to convert frame offsets to seconds
TASK_ID = 327            # AgiBot World task to train on
training_steps = 5000    # total optimizer steps
SEED = 0                 # fixed seed so shuffling and diffusion noise are reproducible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs up front: the DataLoader shuffles with torch's global RNG
# (num_workers=0), and training itself is stochastic.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Paths
# NOTE: absolute, machine-specific location of the local LeRobot-format data; edit for your setup.
dataset_path = "/home/ubuntu/lerobot_data/"
# Checkpoint directory used by the "Save Policy Checkpoint" cell; it must be
# defined here, otherwise that cell raises NameError on a fresh kernel.
output_path = f"outputs/train/diffusion_task_{TASK_ID}"

In [None]:
# =============================================
# 2. Dataset Setup
# =============================================
# Frame offsets relative to the current frame:
#   observations -> previous and current frame (-1, 0)
#   actions      -> 16 frames, from one frame back (-1) to 14 frames ahead
observation_idx = np.array([-1, 0])
action_idx = np.arange(-1, 15)
repo_id = f"agibotworld/task_{TASK_ID}"

# Divide by FPS to turn frame offsets into time offsets in seconds.
obs_offsets_s = observation_idx / FPS
action_offsets_s = action_idx / FPS

delta_timestamps = {
    "observation.images.top_head": obs_offsets_s.tolist(),
    "observation.state": obs_offsets_s.tolist(),
    "action": action_offsets_s.tolist(),
}

# Load the single task from local disk only (no download attempt).
task_root = f"{dataset_path}/{repo_id}"
dataset = LeRobotDataset(
    repo_id=repo_id,
    root=task_root,
    delta_timestamps=delta_timestamps,
    local_files_only=True,
)

# Pinned host memory only helps when batches are copied to a CUDA device.
use_pinned_memory = device.type == "cuda"
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=use_pinned_memory,
)

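If you want to train a single robot policy to master multiple distinct skills, you can use `MultiLeRobotDataset` to load the datasets of several tasks into one unified training process. To actually train on it, build the `DataLoader` from `multi_dataset` and pass `multi_dataset.stats` to the policy instead of the single-task equivalents.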

In [None]:
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

# Discover every locally available AgiBot World task directory.
# - Only directories are kept, so stray files matching the pattern are ignored.
# - The result is sorted because filesystem listing order is arbitrary; sorting
#   gives `repo_ids` the same order on every machine and every run.
task_dirs = sorted(
    path for path in Path(dataset_path).glob("agibotworld/task_*") if path.is_dir()
)
if not task_dirs:
    # Fail early with an actionable message instead of an obscure error downstream.
    raise FileNotFoundError(
        f"No task directories matching 'agibotworld/task_*' found under {dataset_path!r}"
    )

repo_ids = [f"agibotworld/{task_dir.name}" for task_dir in task_dirs]
multi_dataset = MultiLeRobotDataset(
    repo_ids=repo_ids,
    root=dataset_path,
    delta_timestamps=delta_timestamps,
    local_files_only=True
)

Next, configure a Diffusion Policy and run a simple training loop on the single-task dataset:

In [None]:
# =============================================
# 3. Policy Configuration and Initialization
# =============================================
# Start from the default Diffusion Policy configuration and override the
# feature specification so it uses the same keys as `delta_timestamps`.
# NOTE(review): the observation/action windows requested in `delta_timestamps`
# (2 and 16 frames) presumably have to agree with the config's default
# observation/horizon lengths; confirm against DiffusionConfig.
image_key = "observation.images.top_head"
state_key = "observation.state"

cfg = DiffusionConfig()
# NOTE(review): image resolution, 20-dim state and 22-dim action are assumed to
# match this task's data; verify against the dataset metadata.
cfg.input_shapes = {image_key: [3, 480, 640], state_key: [20]}
cfg.input_normalization_modes = {image_key: "mean_std", state_key: "min_max"}
cfg.output_shapes = {"action": [22]}

# Normalization statistics come from the single-task dataset.
# To train on all tasks, pass `multi_dataset.stats` here instead, and build
# the DataLoader from `multi_dataset` as well so stats and data match.
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
policy.train().to(device)

optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)

In [None]:
# =============================================
# 4. Training Loop
# =============================================
# Print the loss only every `log_every` steps (plus the final step), so a
# long run does not flood the notebook output. Calling `.item()` less often
# also avoids forcing a device sync on every step.
log_every = 100

step = 0
# Nothing to do when no training steps are requested.
done = step >= training_steps

while not done:
    batches_this_epoch = 0
    for batch in dataloader:
        batches_this_epoch += 1
        # Move only tensors to the device. NOTE(review): batches may also carry
        # non-tensor entries (presumably e.g. task strings, depending on the
        # lerobot version); those are passed through unchanged.
        batch = {
            k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
        output_dict = policy.forward(batch)
        loss = output_dict["loss"]

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_every == 0 or step + 1 == training_steps:
            print(f"Step {step}, Loss: {loss.item():.3f}")
        step += 1

        if step >= training_steps:
            done = True
            break

    # With drop_last=True, a dataset smaller than one batch yields no batches.
    # Without this guard, the outer `while` loop would spin forever.
    if batches_this_epoch == 0:
        raise RuntimeError(
            "DataLoader produced no batches: the dataset may be empty or "
            "smaller than batch_size while drop_last=True."
        )


In [None]:
# =============================================
# 5. Save Policy Checkpoint
# =============================================
# Persist the trained policy to the directory configured as `output_path`
# and report where it went.
checkpoint_dir = output_path
policy.save_pretrained(checkpoint_dir)
print(f"Model saved to {checkpoint_dir}")


Congratulations! You have trained and saved a Diffusion Policy. Feel free to explore more of AgiBot World!