In [1]:
from pathlib import Path

import gym_pusht  # noqa: F401
import gymnasium as gym
import imageio
import numpy
import torch

In [2]:
# Select your device
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
pretrained_policy_path = "IliaLarchenko/dot_pusht_keypoints"

In [14]:
from lerobot.common.policies.dot.modeling_dot import DOTPolicy
policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)

In [7]:
env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="environment_state_agent_pos",
    max_episode_steps=300,
)

In [15]:
print(policy.config.input_features)
print(env.observation_space)

{'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(2,)), 'observation.environment_state': PolicyFeature(type=<FeatureType.ENV: 'ENV'>, shape=(16,))}
Dict('agent_pos': Box(0.0, 512.0, (2,), float64), 'environment_state': Box(0.0, 512.0, (16,), float64))


In [16]:
print(policy.config.output_features)
print(env.action_space)

{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(2,))}
Box(0.0, 512.0, (2,), float32)


In [17]:
policy.reset()
numpy_observation, info = env.reset(seed=42)

In [18]:
# Prepare to collect every rewards and all the frames of the episode,
# from initial state to final state.
rewards = []
frames = []

# Render frame of the initial state
frames.append(env.render())

In [24]:
step = 0
done = False

while not done:
    # Prepare observation for the policy
    state = torch.from_numpy(numpy_observation["agent_pos"])  # Agent position
    env_state = torch.from_numpy(numpy_observation["environment_state"])  # Environment state

    # Convert to float32
    state = state.to(torch.float32)
    env_state = env_state.to(torch.float32)

    # Send data tensors from CPU to GPU
    state = state.to(device, non_blocking=True)
    env_state = env_state.to(device, non_blocking=True)

    # Add extra (empty) batch dimension, required to forward the policy
    state = state.unsqueeze(0)
    env_state = env_state.unsqueeze(0)

    # Create the policy input dictionary
    observation = {
        "observation.state": state,
        "observation.environment_state": env_state,  # Add environment_state here
    }

    # Predict the next action with respect to the current observation
    with torch.inference_mode():
        action = policy.select_action(observation)

    # Prepare the action for the environment
    numpy_action = action.squeeze(0).to("cpu").numpy()

    # Step through the environment and receive a new observation
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    print(f"{step=} {reward=} {terminated=}")

    # Keep track of all the rewards and frames
    rewards.append(reward)
    frames.append(env.render())

    # The rollout is considered done when the success state is reached (i.e. terminated is True),
    # or the maximum number of iterations is reached (i.e. truncated is True)
    done = terminated or truncated or done
    step += 1


step=0 reward=np.float64(0.0) terminated=False
step=1 reward=np.float64(0.0) terminated=False
step=2 reward=np.float64(0.0) terminated=False
step=3 reward=np.float64(0.0) terminated=False
step=4 reward=np.float64(0.0) terminated=False
step=5 reward=np.float64(0.0) terminated=False
step=6 reward=np.float64(0.0) terminated=False
step=7 reward=np.float64(0.0) terminated=False
step=8 reward=np.float64(0.0) terminated=False
step=9 reward=np.float64(0.0) terminated=False
step=10 reward=np.float64(0.0) terminated=False
step=11 reward=np.float64(0.0) terminated=False
step=12 reward=np.float64(0.0) terminated=False
step=13 reward=np.float64(0.0) terminated=False
step=14 reward=np.float64(0.0) terminated=False
step=15 reward=np.float64(0.0) terminated=False
step=16 reward=np.float64(0.0) terminated=False
step=17 reward=np.float64(0.0) terminated=False
step=18 reward=np.float64(0.0) terminated=False
step=19 reward=np.float64(0.0) terminated=False
step=20 reward=np.float64(0.0) terminated=False
st

In [27]:
if terminated:
    print("Success!")
else:
    print("Failure!")

# Get the speed of environment (i.e. its number of frames per second).
fps = env.metadata["render_fps"]

# Encode all frames into a mp4 video.
video_path = "/home/lerobot/output/rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)

print(f"Video of the evaluation is available in '{video_path}'.")

Failure!




Video of the evaluation is available in '/home/lerobot/output/rollout.mp4'.


In [15]:
#now on aloha
import imageio
import gymnasium as gym
import numpy as np
import gym_aloha
env = gym.make(
    "gym_aloha/AlohaInsertion-v0",
    obs_type="pixels",
    max_episode_steps=300,
)

In [9]:
from lerobot.common.policies.dot.modeling_dot import DOTPolicy
pretrained_policy_path = "IliaLarchenko/dot_bimanual_insert"
policy = DOTPolicy.from_pretrained(pretrained_policy_path, map_location=device)


  from .autonotebook import tqdm as notebook_tqdm
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 172MB/s]


In [10]:
# We can verify that the shapes of the features expected by the policy match the ones from the observations
# produced by the environment
print(policy.config.input_features)
print(env.observation_space)

{'observation.images.top': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)), 'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(14,))}
Dict('top': Box(0, 255, (480, 640, 3), uint8))


In [11]:
# Similarly, we can check that the actions produced by the policy will match the actions expected by the
# environment
print(policy.config.output_features)
print(env.action_space)

{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(14,))}
Box(-1.0, 1.0, (14,), float32)


In [16]:
# Reset the policy and environments to prepare for rollout
policy.reset()
numpy_observation, info = env.reset(seed=42)

FatalError: gladLoadGL error