In [1]:
from pathlib import Path
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory for this run's outputs. It is created up front (parents included) and
# left as-is when it already exists.
output_directory = Path("outputs/train/example_pusht_diffusion2")
output_directory.mkdir(parents=True, exist_ok=True)

# Run settings. NOTE(review): `training_steps` and `log_freq` are not used by the
# cells visible here; presumably a training loop further down consumes them.
training_steps = 5000
log_freq = 250

# Use the GPU when CUDA is usable and fall back to the CPU otherwise, so the
# notebook does not fail at the first `.to(device)` on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Dataset setup.
# Each sample is assembled from frames at the time offsets below, expressed in
# seconds relative to the current frame (0.0).

# Images and states: the frame 0.1 s before the current one, then the current frame.
observation_offsets = [-0.1, 0.0]

# Actions: the previous action (-0.1), the next action to be executed (0.0) and
# 14 future actions spaced 0.1 s apart (0.1 ... 1.4). All of them are used to
# supervise the policy.
action_offsets = [step / 10 for step in range(-1, 15)]

delta_timestamps = {
    "observation.image": list(observation_offsets),
    "observation.state": list(observation_offsets),
    "action": action_offsets,
}
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

Fetching 212 files: 100%|██████████| 212/212 [00:00<00:00, 16077.69it/s]


In [4]:
# Pinned host memory is requested only when the target device is not the CPU.
use_pinned_memory = device != torch.device("cpu")

# Shuffled batches of 64 samples loaded by 4 worker processes; a trailing
# incomplete batch is dropped.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=use_pinned_memory,
)

In [14]:
# Pull one batch from the loader and print the shape stored under every key.
batch = next(iter(dataloader))
for key in batch:
    print(f"{key}: {batch[key].shape}")

observation.image: torch.Size([64, 2, 3, 96, 96])
observation.state: torch.Size([64, 2, 2])
action: torch.Size([64, 16, 2])
episode_index: torch.Size([64])
frame_index: torch.Size([64])
timestamp: torch.Size([64])
next.reward: torch.Size([64])
next.done: torch.Size([64])
next.success: torch.Size([64])
index: torch.Size([64])
observation.image_is_pad: torch.Size([64, 2])
observation.state_is_pad: torch.Size([64, 2])
action_is_pad: torch.Size([64, 16])


In [15]:
# Build the policy.
# A policy is constructed from a configuration object -- here `DiffusionConfig`.
# Its defaults are already set up for PushT, so nothing is overridden in this
# example; for a different setup, expect to change at least some of them.
cfg = DiffusionConfig()
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)

# Switch to training mode and move the policy to `device`. `train()` returns the
# module itself, so the calls can be chained; as the cell's last expression, the
# result is displayed as the module's repr.
policy.train().to(device)

DiffusionPolicy(
  (normalize_inputs): Normalize(
    (buffer_observation_image): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 3x1x1 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 3x1x1 (cuda:0)]
    )
    (buffer_observation_state): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
    )
  )
  (normalize_targets): Normalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
    )
  )
  (unnormalize_outputs): Unnormalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 2 (cuda:0)]
    )
  )
  (diffusion): Diff

In [16]:
# Transfer every entry of the batch to `device`, keeping the same keys in the same
# order. `non_blocking=True` is forwarded unchanged to each `.to(...)` call.
batch = {
    name: tensor.to(device, non_blocking=True)
    for name, tensor in batch.items()
}

In [17]:
# Run one forward pass by calling the module itself rather than `.forward(...)`:
# `nn.Module.__call__` dispatches to `forward` and also runs any registered hooks,
# which a direct `.forward` call silently skips.
output_dict = policy(batch)
# The returned mapping holds the loss under the "loss" key.
loss = output_dict["loss"]

In [18]:
# Pass the batch through the policy's input-normalization module.
batch = policy.normalize_inputs(batch)

image_keys = policy.expected_image_keys
if image_keys:
    # Combine the tensors of all expected image keys into one tensor, stacked
    # along a new axis inserted at position -4.
    stacked_images = torch.stack([batch[key] for key in image_keys], dim=-4)
    # Build a shallow copy carrying the extra entry, so the mapping returned by
    # `normalize_inputs` is not modified.
    batch = {**batch, "observation.images": stacked_images}

In [20]:
# Pass the batch through the policy's `normalize_targets` module and rebind `batch`
# to the result. Re-running this cell feeds the already-processed batch back in.
# NOTE(review): presumably only the `action` entry is rescaled (the module repr
# printed earlier lists a single `buffer_action` with min/max) -- confirm.
batch = policy.normalize_targets(batch)

In [21]:
# Display the keys currently held by the batch, e.g. to check for the
# "observation.images" entry that is added when the policy expects image keys.
batch.keys()

dict_keys(['observation.image', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.reward', 'next.done', 'next.success', 'index', 'observation.image_is_pad', 'observation.state_is_pad', 'action_is_pad', 'observation.images'])

In [22]:
# Compute the loss by calling `compute_loss` on the policy's `diffusion` sub-module
# directly with the manually prepared batch. This rebinds `loss`, which was first
# assigned from the forward-pass output; the bare expression displays the value.
loss = policy.diffusion.compute_loss(batch)
loss 

tensor(1.2612, device='cuda:0', grad_fn=<MeanBackward0>)