In [1]:
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
import torch
import numpy as np
device = torch.device("cuda")


In [2]:
# Set up the dataset.
delta_timestamps = {
    # Load the previous image and state at -0.1 seconds before current frame,
    # then load current image and state corresponding to 0.0 second.
    "observation.image": [-0.1, 0.0],
    "observation.state": [-0.1, 0.0],
    # Load the previous action (-0.1), the next action to be executed (0.0),
    # and 14 future actions with a 0.1 seconds spacing. All these actions will be
    # used to supervise the policy.
    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 418 files:   0%|          | 0/418 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/206 [00:00<?, ?it/s]

In [3]:
cfg = DiffusionConfig()
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
policy.to(device)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)


In [4]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=64,
    shuffle=True,
    pin_memory=device != torch.device("cpu"),
    drop_last=True,
)

In [6]:
training_steps = 5000

step = 0
done = False
while not done:
    for batch in dataloader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        output_dict = policy.forward(batch)
        loss = output_dict["loss"]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 250 == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break

  return F.conv2d(input, weight, bias, self.stride,
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


step: 0 loss: 1.089
step: 250 loss: 0.056
step: 500 loss: 0.060
step: 750 loss: 0.046
step: 1000 loss: 0.084
step: 1250 loss: 0.044
step: 1500 loss: 0.059
step: 1750 loss: 0.031
step: 2000 loss: 0.053
step: 2250 loss: 0.029
step: 2500 loss: 0.064
step: 2750 loss: 0.043
step: 3000 loss: 0.033
step: 3250 loss: 0.031
step: 3500 loss: 0.041
step: 3750 loss: 0.051
step: 4000 loss: 0.040
step: 4250 loss: 0.045
step: 4500 loss: 0.040
step: 4750 loss: 0.042


In [7]:
import datetime
policy.save_pretrained(f"ckpts/pusht_diffusion{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")

In [None]:
policy = DiffusionPolicy.from_pretrained("ckpts/pusht_diffusion2025-01-09_22-37-03").to(device)


Loading weights from local directory


AttributeError: 'DiffusionPolicy' object has no attribute 'pin_memory'

In [4]:
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.utils.utils import init_hydra_config
cfg = init_hydra_config(lerobot.__path__[0] + "/configs/env/pusht.yaml")
# env = make_env(cfg, n_envs=1)


In [5]:
cfg.env

{'name': 'pusht', 'task': 'PushT-v0', 'image_size': 96, 'state_dim': 2, 'action_dim': 2, 'fps': '${fps}', 'episode_length': 300, 'gym': {'obs_type': 'pixels_agent_pos', 'render_mode': 'rgb_array', 'visualization_width': 384, 'visualization_height': 384}}

In [6]:
import gymnasium as gym
import mediapy as media
import gym_pusht
# gym.envs.pprint_registry()
env = gym.make_vec('gym_pusht/PushT-v0', **dict(cfg.env.get("gym", {})))
obs, info = env.reset()
# obs['pixels'] = obs['pixels'][np.newaxis,...]


In [15]:
obs

OrderedDict([('agent_pos', array([[ 85., 209.]])),
             ('pixels',
              array([[[[255, 255, 255],
                       [248, 248, 248],
                       [248, 248, 248],
                       ...,
                       [248, 248, 248],
                       [248, 248, 248],
                       [255, 255, 255]],
              
                      [[248, 248, 248],
                       [222, 222, 222],
                       [233, 233, 233],
                       ...,
                       [233, 233, 233],
                       [222, 222, 222],
                       [248, 248, 248]],
              
                      [[247, 247, 247],
                       [233, 233, 233],
                       [255, 255, 255],
                       ...,
                       [255, 255, 255],
                       [233, 233, 233],
                       [247, 247, 247]],
              
                      ...,
              
                      [[247, 24

In [None]:
# Note: profiling needs to be run as root
import time

frames = []
obs, info = env.reset()
prep_obs = preprocess_observation(obs)
prep_obs = {key: prep_obs[key].to(device, non_blocking=True) for key in prep_obs}
frames.append(obs['pixels'][0])

terminated =  truncated = False
step = 0
while step < 16 and (not truncated or terminated):
    print(f"Step {step}")
    t0 = time.time()
    prep_obs = preprocess_observation(obs)
    prep_obs = {key: prep_obs[key].to(device, non_blocking=True) for key in prep_obs}
    t1 = time.time()
    print(f"Preprocessing: {t1-t0}s")
    # with record_function("model_inference"):
    with torch.inference_mode():
        action = policy.select_action(prep_obs )
    t2 = time.time()
    print(f"Inference: {t2-t1}s")

    obs, reward, terminated, truncated, info = env.step(action.cpu().numpy())
    t3 = time.time()
    print(f"Env step: {t3-t2}s")
    frames.append(obs['pixels'][0])
    step += 1
    # obs['pixels'] = obs['pixels'][np.newaxis,...]


Step 0
Preprocessing: 0.0018219947814941406s


STAGE:2025-01-14 17:35:12 571282:571282 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


Inference: 5.500082492828369s
Env step: 0.009531974792480469s
Step 1
Preprocessing: 0.0008671283721923828s
Inference: 0.0027136802673339844s
Env step: 0.012134075164794922s
Step 2
Preprocessing: 0.0010561943054199219s
Inference: 0.0034317970275878906s
Env step: 0.012519359588623047s
Step 3
Preprocessing: 0.0010619163513183594s
Inference: 0.003720521926879883s
Env step: 0.012822389602661133s
Step 4
Preprocessing: 0.0011029243469238281s
Inference: 0.003754138946533203s
Env step: 0.012717723846435547s
Step 5
Preprocessing: 0.0011451244354248047s
Inference: 0.0036971569061279297s
Env step: 0.012772083282470703s
Step 6
Preprocessing: 0.0010874271392822266s
Inference: 0.0036847591400146484s
Env step: 0.01273655891418457s
Step 7
Preprocessing: 0.0011467933654785156s
Inference: 0.0037162303924560547s
Env step: 0.012680530548095703s
Step 8
Preprocessing: 0.0010848045349121094s
Inference: 5.399359464645386s
Env step: 0.013236761093139648s
Step 9
Preprocessing: 0.0007200241088867188s
Inference: 0

STAGE:2025-01-14 17:35:24 571282:571282 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2025-01-14 17:35:24 571282:571282 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


In [37]:
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     aten::lift_fresh         0.00%     161.000us         0.00%     161.000us       4.472us            36  
                        aten::permute         0.03%       3.134ms         0.04%       4.403ms      10.584us           416  
                     aten::as_strided         0.62%      68.105ms         0.62%      68.105ms       1.750us         38918  
                     aten::contiguous         0.01%       1.631ms         0.35%      37.968ms      90.833us           418  
                          aten::clone         0.04%       4.509ms         0.33%      36.403ms      87.089us           418  
        

In [25]:
next(policy.parameters()).device

device(type='cuda', index=0)

In [31]:
prof.key_averages()

[<FunctionEventAvg key=aten::lift_fresh self_cpu_time=143.000us cpu_time=3.972us  self_cuda_time=0.000us cuda_time=0.000us input_shapes= cpu_memory_usage=0 cuda_memory_usage=0>,
 <FunctionEventAvg key=aten::permute self_cpu_time=3.309ms cpu_time=10.945us  self_cuda_time=0.000us cuda_time=0.000us input_shapes= cpu_memory_usage=0 cuda_memory_usage=0>,
 <FunctionEventAvg key=aten::as_strided self_cpu_time=68.182ms cpu_time=1.752us  self_cuda_time=0.000us cuda_time=0.000us input_shapes= cpu_memory_usage=0 cuda_memory_usage=0>,
 <FunctionEventAvg key=aten::contiguous self_cpu_time=1.712ms cpu_time=78.462us  self_cuda_time=0.000us cuda_time=0.000us input_shapes= cpu_memory_usage=0 cuda_memory_usage=0>,
 <FunctionEventAvg key=aten::clone self_cpu_time=4.342ms cpu_time=74.835us  self_cuda_time=0.000us cuda_time=0.000us input_shapes= cpu_memory_usage=0 cuda_memory_usage=0>,
 <FunctionEventAvg key=aten::empty_like self_cpu_time=2.385ms cpu_time=19.447us  self_cuda_time=0.000us cuda_time=0.000us 

In [40]:
media.show_video(frames)

0
This browser does not support the video tag.


In [16]:
prep_obs = preprocess_observation(obs)

prep_obs_d = {key: prep_obs[key].to(device, non_blocking=True) for key in prep_obs}

In [52]:
from importlib import reload
import lerobot.common.envs.utils
reload(lerobot.common.envs.utils)
from lerobot.common.envs.utils import preprocess_observation

In [18]:
policy.select_action(prep_obs_d)

tensor([[394.0484, 339.3702]], device='cuda:0')

In [20]:
action

tensor([[393.5116, 339.3528]], device='cuda:0')

In [22]:
env.reset()
env.step(action)

AlreadyPendingCallError: Calling `reset_async` while waiting for a pending call to `step` to complete