# Train Policy

## Introduction

This notebook is based on:
https://github.com/huggingface/lerobot/blob/main/examples/3_train_policy.py

## Preparation

In [1]:
from pathlib import Path

import gym_pusht  # noqa: F401
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.types import FeatureType

In [2]:
# Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device
device = torch.device("mps")

# Number of offline training steps (we'll only do offline training for this example.)
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 5000
log_freq = 1

# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
# creating the policy:
#   - input/output shapes: to properly size the policy
#   - dataset stats: for normalization and denormalization of input/outputs
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)

# We can now instantiate our policy with this config and the dataset stats.
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)



DiffusionPolicy(
  (normalize_inputs): Normalize(
    (buffer_observation_image): ParameterDict(
        (mean): Parameter containing: [torch.mps.FloatTensor of size 3x1x1]
        (std): Parameter containing: [torch.mps.FloatTensor of size 3x1x1]
    )
    (buffer_observation_state): ParameterDict(
        (max): Parameter containing: [torch.mps.FloatTensor of size 2]
        (min): Parameter containing: [torch.mps.FloatTensor of size 2]
    )
  )
  (normalize_targets): Normalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.mps.FloatTensor of size 2]
        (min): Parameter containing: [torch.mps.FloatTensor of size 2]
    )
  )
  (unnormalize_outputs): Unnormalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.mps.FloatTensor of size 2]
        (min): Parameter containing: [torch.mps.FloatTensor of size 2]
    )
  )
  (diffusion): DiffusionModel(
    (rgb_encoder): DiffusionRgbEncoder(
      (center_crop): CenterC

In [3]:
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
# which can differ for inputs, outputs and rewards (if there are some).
delta_timestamps = {
    "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
    "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
    "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
}

# In this case with the standard configuration for Diffusion Policy, it is equivalent to this:
assert delta_timestamps == {
    # Load the previous image and state at -0.1 seconds before current frame,
    # then load current image and state corresponding to 0.0 second.
    "observation.image": [-0.1, 0.0],
    "observation.state": [-0.1, 0.0],
    # Load the previous action (-0.1), the next action to be executed (0.0),
    # and 14 future actions with a 0.1 seconds spacing. All these actions will be
    # used to supervise the policy.
    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}

# We can then instantiate the dataset with these delta_timestamps configuration.
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

Resolving data files:   0%|          | 0/206 [00:00<?, ?it/s]

In [4]:
# Then we create our optimizer and dataloader for offline training.
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=64,
    shuffle=True,
    pin_memory=device.type != "cpu",
    drop_last=True,
)

## Training

In [5]:
# Run training loop.
step = 0
done = False
while not done:
    for batch in dataloader:
        batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        loss, _ = policy.forward(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break

# Save a policy checkpoint.
policy.save_pretrained(output_directory)

step: 0 loss: 1.070
step: 1 loss: 6.440
step: 2 loss: 1.135
step: 3 loss: 1.418
step: 4 loss: 1.061
step: 5 loss: 0.978
step: 6 loss: 1.109
step: 7 loss: 1.026
step: 8 loss: 0.976
step: 9 loss: 0.948
step: 10 loss: 0.974
step: 11 loss: 1.041
step: 12 loss: 0.991
step: 13 loss: 0.982
step: 14 loss: 0.932
step: 15 loss: 0.900
step: 16 loss: 0.906
step: 17 loss: 0.902
step: 18 loss: 0.939
step: 19 loss: 0.886
step: 20 loss: 0.838
step: 21 loss: 0.807
step: 22 loss: 0.759
step: 23 loss: 0.642
step: 24 loss: 0.557
step: 25 loss: 0.488
step: 26 loss: 0.473
step: 27 loss: 0.439
step: 28 loss: 0.414
step: 29 loss: 0.333
step: 30 loss: 0.390
step: 31 loss: 0.339
step: 32 loss: 0.304
step: 33 loss: 0.230
step: 34 loss: 0.272
step: 35 loss: 0.248
step: 36 loss: 0.214
step: 37 loss: 0.231
step: 38 loss: 0.198
step: 39 loss: 0.225
step: 40 loss: 0.233
step: 41 loss: 0.223
step: 42 loss: 0.218
step: 43 loss: 0.197
step: 44 loss: 0.198
step: 45 loss: 0.193
step: 46 loss: 0.248
step: 47 loss: 0.188
st