In [1]:
from pathlib import Path

import gym_pusht  # noqa: F401
import gymnasium as gym
import imageio
import numpy
import torch
from huggingface_hub import snapshot_download
import einops

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.utils import (
    get_device_from_parameters,
    get_dtype_from_parameters,
    populate_queues,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the pretrained diffusion policy for the PushT environment.
# To download it from the Hugging Face Hub instead, uncomment the next line
# and comment out the local path assignment below it.
# pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
# Active choice: evaluate a policy from the local outputs/train folder.
pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device 

device(type='cuda')

In [4]:
# Load the pretrained policy, switch it to evaluation mode and move it to `device`.
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
policy.eval()
policy.to(device)
pass  # no-op final statement: keeps the value of `policy.to(device)` from being echoed as cell output

Loading weights from local directory


In [5]:
# Initialize the evaluation environment with two observation types:
# an image of the scene ("pixels") and the state/position of the agent
# ("agent_pos"). `max_episode_steps=300` makes the environment stop
# automatically after 300 interactions/steps.
env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    max_episode_steps=300,
)

In [47]:
# Reset the policy and the environment to prepare for a rollout.
# The environment reset is seeded with 42.
policy.reset()
numpy_observation, info = env.reset(seed=42)

In [48]:
# Print the name and array shape of every entry in the initial observation.
for name, array in numpy_observation.items():
    print(name, array.shape)

pixels (96, 96, 3)
agent_pos (2,)


In [49]:
# Agent position: NumPy -> float32 tensor -> target device -> add a leading
# batch dimension (required to forward the policy).
state = (
    torch.from_numpy(numpy_observation["agent_pos"])
    .to(torch.float32)
    .to(device, non_blocking=True)
    .unsqueeze(0)
)

# Image: NumPy -> float32 tensor, rescaled by 1/255 (from [0, 255] to [0, 1]),
# then permuted from channel-last (H, W, C) to channel-first (C, H, W),
# moved to the target device and given a leading batch dimension.
image = (
    torch.from_numpy(numpy_observation["pixels"])
    .to(torch.float32)
    .div(255)
    .permute(2, 0, 1)
    .to(device, non_blocking=True)
    .unsqueeze(0)
)

# Policy input dictionary, keyed by the names the policy is given here.
observation = {"observation.state": state, "observation.image": image}

# Display the shapes of both inputs.
state.shape, image.shape

(torch.Size([1, 2]), torch.Size([1, 3, 96, 96]))

In [53]:
# Predict the next action with respect to the current observation.
# The call runs under `torch.inference_mode()`, i.e. without autograd tracking.
with torch.inference_mode():
    action = policy.select_action(observation)

# Display the shape of the returned action tensor.
action.shape

torch.Size([1, 2])

In [55]:
# Prepare the action for the environment: drop the leading batch dimension,
# move the tensor to the CPU and convert it to a NumPy array.
numpy_action = action.squeeze(0).to("cpu").numpy()
numpy_action

array([175.30476, 421.1422 ], dtype=float32)

In [10]:
"""Select a single action given environment observations.

This method handles caching a history of observations and an action trajectory generated by the
underlying diffusion model. Here's how it works:
    - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
    copied `n_obs_steps` times to fill the cache).
    - The diffusion model generates `horizon` steps worth of actions.
    - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
Schematically this looks like:
    ----------------------------------------------------------------------------------------------
    (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
    |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
    |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
    |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
    |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
    ----------------------------------------------------------------------------------------------
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
"horizon" may not be the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
pass 

In [11]:
# Start from the observation dictionary; the cells that follow rebind `batch`
# to transformed versions of it.
batch=observation

In [12]:
# Manual walk-through: normalize the inputs, gather the image keys into a single
# "observation.images" entry, then push the result into the policy's queues.
# NOTE(review): this cell is not idempotent — `batch` is reassigned from itself and
# `policy._queues` is updated on every run, so execute it once per observation.
batch = policy.normalize_inputs(batch)
if len(policy.expected_image_keys) > 0:
    batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
    # Stack along a new axis inserted at dim=-4, i.e. just before the (C, H, W) dims.
    batch["observation.images"] = torch.stack([batch[k] for k in policy.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
policy._queues = populate_queues(policy._queues, batch)

In [13]:
# policy._queues

In [14]:
# Inspect which keys the policy queues hold and how many image entries are queued.
policy._queues.keys(), len( policy._queues['observation.images'] )

(dict_keys(['observation.state', 'action', 'observation.images']), 2)

In [15]:
# Generate a new batch of actions only when the action queue is empty.
if len(policy._queues["action"]) == 0:
    # stack n latest observations from the queue
    batch = {k: torch.stack(list(policy._queues[k]), dim=1) for k in batch if k in policy._queues}

    # Inference only: disable gradient tracking so autograd does not record the
    # forward pass and keep its intermediate activations alive through `actions`.
    with torch.no_grad():
        actions = policy.diffusion.generate_actions(batch)

        # TODO(rcadene): make above methods return output dictionary?
        actions = policy.unnormalize_outputs({"action": actions})["action"]

    # Swap the first two dims so that extending the queue adds one entry per step.
    policy._queues["action"].extend(actions.transpose(0, 1))

In [16]:
# Display the current contents of the action queue.
policy._queues["action"]

deque([tensor([[203.8272, 424.9017]], device='cuda:0'),
       tensor([[226.5591, 426.3619]], device='cuda:0'),
       tensor([[243.5239, 421.2686]], device='cuda:0'),
       tensor([[259.5940, 418.5714]], device='cuda:0'),
       tensor([[277.0004, 414.8833]], device='cuda:0'),
       tensor([[293.9783, 412.5151]], device='cuda:0'),
       tensor([[313.9379, 407.7516]], device='cuda:0')],
      maxlen=8)

In [17]:
# Take the leftmost (oldest) action out of the queue.
# NOTE: this removes the element, so re-running the cell yields the next action.
action = policy._queues["action"].popleft()
action

tensor([[203.8272, 424.9017]], device='cuda:0')

In [18]:
# Rebuild the model input: for every key of `batch` that also has a queue,
# stack the queued tensors along a new dimension 1.
batch = {
    name: torch.stack(list(policy._queues[name]), dim=1)
    for name in batch
    if name in policy._queues
}

# Report the resulting tensor shape for each key.
for name, tensor in batch.items():
    print(name, tensor.shape)

observation.state torch.Size([1, 2, 2])
observation.images torch.Size([1, 2, 1, 3, 96, 96])


In [19]:
# Generate actions from the stacked observations.
# Inference only: disable gradient tracking so autograd does not record the
# forward pass and keep its intermediate activations alive through `actions`.
with torch.no_grad():
    actions = policy.diffusion.generate_actions(batch)
actions.shape

torch.Size([1, 8, 2])

In [20]:
"""
This function expects `batch` to have:
{
    "observation.state": (B, n_obs_steps, state_dim)

    "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
        AND/OR
    "observation.environment_state": (B, environment_dim)
}
"""
pass

In [21]:
# Read the batch size and the number of stacked observation steps from the first
# two dimensions of the state tensor, and check the latter against the config.
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == policy.diffusion.config.n_obs_steps

batch_size, n_obs_steps

(1, 2)

In [22]:
# Encode image features and concatenate them all together along with the state vector.
global_cond = policy.diffusion._prepare_global_conditioning(batch)  # (B, global_cond_dim)
global_cond.shape

torch.Size([1, 132])

In [23]:
# Inspect the `use_separate_rgb_encoder_per_camera` config flag.
policy.diffusion.config.use_separate_rgb_encoder_per_camera

False

In [24]:
# Inspect the diffusion model's private `_use_env_state` flag.
policy.diffusion._use_env_state

False

In [25]:
# Manual step-through of the global-conditioning computation, exposing intermediates.
# NOTE(review): presumably mirrors `policy.diffusion._prepare_global_conditioning(batch)`
# for the shared-encoder, no-environment-state case — confirm against the library source.
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
global_cond_feats = [batch["observation.state"]]

# Inference only: disable gradient tracking so the encoder's intermediate
# activations are not retained through `img_features`.
with torch.no_grad():
    # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
    img_features = policy.diffusion.rgb_encoder(
        einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
    )
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps)
global_cond_feats.append(img_features)

# Concatenate features then flatten to (B, global_cond_dim).
gc = torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)

img_features.shape, gc.shape, batch["observation.state"].shape

(torch.Size([1, 2, 64]), torch.Size([1, 132]), torch.Size([1, 2, 2]))

In [26]:
# Display the module structure of the image encoder.
policy.diffusion.rgb_encoder

DiffusionRgbEncoder(
  (center_crop): CenterCrop(size=[84, 84])
  (maybe_random_crop): RandomCrop(size=(84, 84), padding=None)
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): GroupNorm(4, 64, eps=1e-05, affine=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): GroupNorm(4, 64, eps=1e-05, affine=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): GroupNorm(4, 64, eps=1e-05, affine=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): GroupNorm(4, 64, eps=1e-05, affine=True)
        (relu): ReLU(inplace=True)
 

In [27]:
# Run sampling, conditioned on `global_cond`.
# Inference only: disable gradient tracking so autograd does not record the
# sampling pass and keep its intermediate activations alive through `actions`.
with torch.no_grad():
    actions = policy.diffusion.conditional_sample(batch_size, global_cond=global_cond)
actions.shape

torch.Size([1, 16, 2])

In [28]:
# Extract `n_action_steps` steps worth of actions (from the current observation).
# The slice starts at index `n_obs_steps - 1` and spans `n_action_steps` entries
# along dimension 1.
# NOTE: `actions` is rebound to the slice, so re-running this cell slices again.
start = n_obs_steps - 1
end = start + policy.diffusion.config.n_action_steps
actions = actions[:, start:end]
actions.shape

torch.Size([1, 8, 2])

In [29]:
# Display the number of observation steps next to the configured number of action steps.
n_obs_steps, policy.diffusion.config.n_action_steps

(2, 8)

In [30]:
# Run sampling again; this rebinds `actions` to a new, unsliced result.
# Inference only: disable gradient tracking so autograd does not record the
# sampling pass and keep its intermediate activations alive through `actions`.
with torch.no_grad():
    actions = policy.diffusion.conditional_sample(batch_size, global_cond=global_cond)
actions.shape

torch.Size([1, 16, 2])

In [33]:
# No explicit random generator; this value is passed as `generator=` to
# `torch.randn` and to the scheduler's `step` in the cells that follow.
generator =None
generator

In [35]:
# Manual step-through of the sampling setup: take device and dtype from the model.
# NOTE: this rebinds the notebook-level `device` variable assigned near the top.
device = get_device_from_parameters(policy.diffusion)
dtype = get_dtype_from_parameters(policy.diffusion)

# Sample prior.
# Standard-normal noise of shape (batch_size, horizon, output_shapes["action"][0]).
sample = torch.randn(
    size=(batch_size, policy.diffusion.config.horizon, policy.diffusion.config.output_shapes["action"][0]),
    dtype=dtype,
    device=device,
    generator=generator,
)

device, dtype, sample.shape

(device(type='cuda', index=0), torch.float32, torch.Size([1, 16, 2]))

In [36]:
# Display the number of inference steps used for sampling.
policy.diffusion.num_inference_steps

100

In [37]:
# Configure the scheduler's timesteps for the chosen number of inference steps.
policy.diffusion.noise_scheduler.set_timesteps(policy.diffusion.num_inference_steps)

In [38]:
# Display the noise scheduler and its configuration.
policy.diffusion.noise_scheduler

DDPMScheduler {
  "_class_name": "DDPMScheduler",
  "_diffusers_version": "0.31.0",
  "beta_end": 0.02,
  "beta_schedule": "squaredcos_cap_v2",
  "beta_start": 0.0001,
  "clip_sample": true,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 100,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "steps_offset": 0,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null,
  "variance_type": "fixed_small"
}

In [39]:
# Manual denoising loop: starting from the noise in `sample`, walk through the
# scheduler's timesteps and update the sample one step at a time.
# Inference only: disable gradient tracking. With autograd enabled, every
# iteration would be recorded and its activations kept alive through `sample`,
# growing memory use across the whole loop.
with torch.no_grad():
    for t in policy.diffusion.noise_scheduler.timesteps:
        # Predict model output.
        model_output = policy.diffusion.unet(
            sample,
            # One timestep value per batch element.
            torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
            global_cond=global_cond,
        )
        # Compute previous sample: x_t -> x_t-1
        sample = policy.diffusion.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample


In [43]:
# [1, 16, 2], [1, 132] -> [1, 16, 2]

In [44]:
# policy.diffusion.unet