In [16]:
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
import torch

In [17]:
# Set up the dataset.
delta_timestamps = {
    # Load the previous image and state at -0.1 seconds before current frame,
    # then load current image and state corresponding to 0.0 second.
    "observation.image": [-0.1, 0.0],
    "observation.state": [-0.1, 0.0],
    # Load the previous action (-0.1), the next action to be executed (0.0),
    # and 14 future actions with a 0.1 seconds spacing. All these actions will be
    # used to supervise the policy.
    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 418 files:   0%|          | 0/418 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/206 [00:00<?, ?it/s]

In [18]:
device = torch.device("mps")
cfg = DiffusionConfig()
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
policy.to(device)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)


In [None]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    # num_workers=4,
    batch_size=64,
    shuffle=True,
    pin_memory=device != torch.device("cpu"),
    drop_last=True,
)

In [9]:
policy

DiffusionPolicy(
  (normalize_inputs): Normalize(
    (buffer_observation_image): ParameterDict(
        (mean): Parameter containing: [torch.mps.FloatTensor of size 3x1x1]
        (std): Parameter containing: [torch.mps.FloatTensor of size 3x1x1]
    )
    (buffer_observation_state): ParameterDict(
        (max): Parameter containing: [torch.mps.FloatTensor of size 2]
        (min): Parameter containing: [torch.mps.FloatTensor of size 2]
    )
  )
  (normalize_targets): Normalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.mps.FloatTensor of size 2]
        (min): Parameter containing: [torch.mps.FloatTensor of size 2]
    )
  )
  (unnormalize_outputs): Unnormalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.mps.FloatTensor of size 2]
        (min): Parameter containing: [torch.mps.FloatTensor of size 2]
    )
  )
  (diffusion): DiffusionModel(
    (rgb_encoder): DiffusionRgbEncoder(
      (center_crop): CenterC

In [12]:
dataset[0]

{'observation.image': tensor([[[[0.9922, 0.9647, 0.9647,  ..., 0.9647, 0.9647, 0.9922],
           [0.9647, 0.9059, 0.9059,  ..., 0.9059, 0.9059, 0.9647],
           [0.9647, 0.9059, 0.9647,  ..., 0.9922, 0.9059, 0.9647],
           ...,
           [0.9647, 0.9059, 0.9922,  ..., 0.9922, 0.9059, 0.9647],
           [0.9647, 0.9059, 0.9059,  ..., 0.9059, 0.9059, 0.9647],
           [0.9922, 0.9647, 0.9647,  ..., 0.9647, 0.9647, 0.9922]],
 
          [[0.9922, 0.9647, 0.9647,  ..., 0.9647, 0.9647, 0.9922],
           [0.9647, 0.9059, 0.9059,  ..., 0.9059, 0.9059, 0.9647],
           [0.9647, 0.9059, 0.9647,  ..., 0.9922, 0.9059, 0.9647],
           ...,
           [0.9647, 0.9059, 0.9922,  ..., 0.9922, 0.9059, 0.9647],
           [0.9647, 0.9059, 0.9059,  ..., 0.9059, 0.9059, 0.9647],
           [0.9922, 0.9647, 0.9647,  ..., 0.9647, 0.9647, 0.9922]],
 
          [[0.9922, 0.9647, 0.9647,  ..., 0.9647, 0.9647, 0.9922],
           [0.9647, 0.9059, 0.9059,  ..., 0.9059, 0.9059, 0.9647],
   

In [20]:
training_steps = 5000

step = 0
done = False
while not done:
    for batch in dataloader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        output_dict = policy.forward(batch)
        loss = output_dict["loss"]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 250 == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [14]:
loss

tensor(1.2831, device='mps:0', grad_fn=<MeanBackward0>)

In [15]:
step

0

In [21]:
policy.tags

AttributeError: 'DiffusionPolicy' object has no attribute 'tags'