# Evaluate Pretrained Policy

## Introduction

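This notebook is based on the LeRobot example script
[2_evaluate_pretrained_policy.py](https://github.com/huggingface/lerobot/blob/main/examples/2_evaluate_pretrained_policy.py).
It loads a pretrained diffusion policy, rolls it out for one episode in the PushT environment,
and saves a video of the rollout.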

## Preparation

In [1]:
from pathlib import Path

import gym_pusht  # noqa: F401
import gymnasium as gym
import imageio
import numpy
import torch
from IPython.display import Video

from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

In [2]:
# Create a directory to store the video of the evaluation
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device: use CUDA when present, otherwise Apple-silicon MPS,
# otherwise fall back to the CPU so the notebook runs on any machine.
# Overwrite `device` by hand below if you want to force a specific backend.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
pretrained_policy_path = "lerobot/diffusion_pusht"
# OR a path to a local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
# Keep the policy weights on the same device as the observation tensors that
# the rollout loop sends to `device`; PyTorch raises a runtime error when
# weights and inputs live on different devices.
policy.to(device)

# Initialize evaluation environment to render two observation types:
# an image of the scene and state/position of the agent. The environment
# also automatically stops running after 300 interactions/steps.
env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    max_episode_steps=300,
)



In [3]:
# Sanity check: the observation features the policy expects should line up with
# what the environment emits. The policy config lists the image channel first,
# whereas the environment's observation space is channel last.
print(policy.config.input_features, env.observation_space, sep="\n")

{'observation.image': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 96, 96)), 'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(2,))}
Dict('agent_pos': Box(0.0, 512.0, (2,), float64), 'pixels': Box(0, 255, (96, 96, 3), uint8))


In [4]:
# Same check on the action side: the action feature produced by the policy
# should match the action space the environment accepts.
print(policy.config.output_features, env.action_space, sep="\n")

{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(2,))}
Box(0.0, 512.0, (2,), float32)


In [5]:
# Reset the policy and environments to prepare for rollout
# NOTE(review): presumably `policy.reset()` clears state the policy keeps between
# calls to `select_action` -- confirm against the policy implementation.
policy.reset()
# A fixed seed is passed to the environment reset, so the starting state should be
# repeatable across runs. `numpy_observation` is the first observation consumed by
# the rollout loop; `info` is not read anywhere in this notebook.
numpy_observation, info = env.reset(seed=42)

## Evaluation

In [6]:
# Collect every reward and every rendered frame of the episode,
# from the initial state to the final state.
rewards = []
frames = []

# Render frame of the initial state
frames.append(env.render())

step = 0
done = False
while not done:
    # Prepare observation for the policy running in PyTorch
    state = torch.from_numpy(numpy_observation["agent_pos"])
    image = torch.from_numpy(numpy_observation["pixels"])

    # Convert to float32, and bring the image from channel last in [0,255]
    # (environment layout) to channel first in [0,1] (policy layout)
    state = state.to(torch.float32)
    image = image.to(torch.float32) / 255
    image = image.permute(2, 0, 1)

    # Send data tensors from CPU to the selected device
    # Should be blocking to avoid race conditions. See:
    # https://github.com/huggingface/lerobot/issues/496#issuecomment-2543784955
    state = state.to(device, non_blocking=False)
    image = image.to(device, non_blocking=False)

    # Add extra (empty) batch dimension, required to forward the policy
    state = state.unsqueeze(0)
    image = image.unsqueeze(0)

    # Create the policy input dictionary
    observation = {
        "observation.state": state,
        "observation.image": image,
    }

    # Predict the next action with respect to the current observation
    with torch.inference_mode():
        action = policy.select_action(observation)

    # Prepare the action for the environment: drop the batch dimension
    # and bring it back to the CPU as a numpy array
    numpy_action = action.squeeze(0).to("cpu").numpy()

    # Step through the environment and receive a new observation
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    print(f"{step=} {reward=} {terminated=}")

    # Keep track of all the rewards and frames
    rewards.append(reward)
    frames.append(env.render())

    # The rollout is considered done when the success state is reached (i.e. terminated is True),
    # or the maximum number of iterations is reached (i.e. truncated is True)
    done = terminated or truncated
    step += 1

# Only `terminated` counts as success; an episode that merely ran out of steps
# (truncated) is reported as a failure.
if terminated:
    print("Success!")
else:
    print("Failure!")

step=0 reward=np.float64(0.0) terminated=False
step=1 reward=np.float64(0.0) terminated=False
step=2 reward=np.float64(0.0) terminated=False
step=3 reward=np.float64(0.0) terminated=False
step=4 reward=np.float64(0.0) terminated=False
step=5 reward=np.float64(0.0) terminated=False
step=6 reward=np.float64(0.0) terminated=False
step=7 reward=np.float64(0.0) terminated=False
step=8 reward=np.float64(0.0) terminated=False
step=9 reward=np.float64(0.0) terminated=False
step=10 reward=np.float64(0.0) terminated=False
step=11 reward=np.float64(0.0) terminated=False
step=12 reward=np.float64(0.0) terminated=False
step=13 reward=np.float64(0.0) terminated=False
step=14 reward=np.float64(0.0) terminated=False
step=15 reward=np.float64(0.0) terminated=False
step=16 reward=np.float64(0.0) terminated=False
step=17 reward=np.float64(0.0) terminated=False
step=18 reward=np.float64(0.0) terminated=False
step=19 reward=np.float64(0.0) terminated=False
step=20 reward=np.float64(0.0) terminated=False
st

## Visualization

In [7]:
# Encode the collected frames into an mp4 video that plays back at the
# environment's own render rate (its number of frames per second).
fps = env.metadata["render_fps"]
video_path = output_directory.joinpath("rollout.mp4")
imageio.mimsave(
    str(video_path),
    numpy.stack(frames),
    fps=fps,
)



In [8]:
# Display the saved rollout video inline; this must stay the last expression
# of the cell so that the notebook renders it.
Video(video_path)