# Run a trained policy

This notebook will provide examples on how to run a trained policy and visualize the rollout.

In [10]:
import argparse
import json
import h5py
import imageio
import numpy as np
import os
from copy import deepcopy

import torch

import robomimic
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.envs.env_base import EnvBase
from robomimic.algo import RolloutPolicy

import urllib.request


### Download policy checkpoint
First, let's try downloading a pretrained model from our model zoo.

In [11]:
# Get pretrained checkpooint from the model zoo

# ckpt_path = "lift_ph_low_dim_epoch_1000_succ_100.pth"
# # Lift (Proficient Human)
# urllib.request.urlretrieve(
#     "http://downloads.cs.stanford.edu/downloads/rt_benchmark/model_zoo/lift/bc_rnn/lift_ph_low_dim_epoch_1000_succ_100.pth",
#     filename=ckpt_path
# )

ckpt_path = "Pretrained/lift_ph_image_epoch_500_succ_100.pth"
# urllib.request.urlretrieve(
#     "http://downloads.cs.stanford.edu/downloads/rt_benchmark/model_zoo/lift/bc_rnn/lift_ph_image_epoch_500_succ_100.pth",
#     filename=ckpt_path
# )

# ckpt_path = "tool_hang_ph_image_epoch_440_succ_74.pth"
# urllib.request.urlretrieve(
#     "http://downloads.cs.stanford.edu/downloads/rt_benchmark/model_zoo/lift/bc_rnn/tool_hang_ph_image_epoch_440_succ_74.pth",
#     filename=ckpt_path
# )
assert os.path.exists(ckpt_path)

### Loading trained policy
We have a convenient function called `policy_from_checkpoint` that takes care of building the correct model from the checkpoint and load the trained weights. Of course you could also load the checkpoint manually.

In [12]:
device = TorchUtils.get_torch_device(try_to_use_cuda=False)
# restore policy
policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path=ckpt_path, device=device, verbose=True)

{
    "algo_name": "bc",
    "experiment": {
        "name": "core_bc_rnn_lift_ph_image",
        "validate": true,
        "logging": {
            "terminal_output_to_txt": true,
            "log_tb": true
        },
        "save": {
            "enabled": true,
            "every_n_seconds": null,
            "every_n_epochs": 20,
            "epochs": [],
            "on_best_validation": false,
            "on_best_rollout_return": false,
            "on_best_rollout_success_rate": true
        },
        "epoch_every_n_steps": 500,
        "validation_epoch_every_n_steps": 50,
        "env": null,
        "additional_envs": null,
        "render": false,
        "render_video": true,
        "keep_all_videos": false,
        "video_skip": 5,
        "rollout": {
            "enabled": true,
            "n": 50,
            "horizon": 400,
            "rate": 20,
            "warmstart": 0,
            "terminate_on_success": true
        }
    },
    "train": {
        "data": "



ObservationKeyToModalityDict: mean not found, adding mean to mapping with assumed low_dim modality!
ObservationKeyToModalityDict: scale not found, adding scale to mapping with assumed low_dim modality!
ObservationKeyToModalityDict: logits not found, adding logits to mapping with assumed low_dim modality!
BC_RNN_GMM (
  ModuleDict(
    (policy): RNNGMMActorNetwork(
        action_dim=7, std_activation=softplus, low_noise_eval=True, num_nodes=5, min_std=0.0001
  
        encoder=ObservationGroupEncoder(
            group=obs
            ObservationEncoder(
                Key(
                    name=agentview_image
                    shape=(3, 84, 84)
                    modality=rgb
                    randomizer=CropRandomizer(input_shape=(3, 84, 84), crop_size=[76, 76], num_crops=1)
                    net=VisualCore(
                      input_shape=[3, 76, 76]
                      output_shape=[64]
                      backbone_net=ResNet18Conv(input_channel=3, input_coord_con

### Creating rollout envionment
The policy checkpoint also contains sufficient information to recreate the environment that it's trained with. Again, you may manually create the environment.

In [13]:
# create environment from saved checkpoint
env, _ = FileUtils.env_from_checkpoint(
    ckpt_dict=ckpt_dict, 
    render=False, # we won't do on-screen rendering in the notebook
    render_offscreen=True, # render to RGB images for video
    verbose=True,
)

Created environment with name Lift
Action size is 7
    No environment version found in dataset!
    Cannot verify if dataset and installed environment versions match
)[0m
Lift
{
    "camera_depths": false,
    "camera_heights": 84,
    "camera_names": [
        "agentview",
        "robot0_eye_in_hand"
    ],
    "camera_widths": 84,
    "control_freq": 20,
    "controller_configs": {
        "control_delta": true,
        "damping": 1,
        "damping_limits": [
            0,
            10
        ],
        "impedance_mode": "fixed",
        "input_max": 1,
        "input_min": -1,
        "interpolation": null,
        "kp": 150,
        "kp_limits": [
            0,
            300
        ],
        "orientation_limits": null,
        "output_max": [
            0.05,
            0.05,
            0.05,
            0.5,
            0.5,
            0.5
        ],
        "output_min": [
            -0.05,
            -0.05,
            -0.05,
            -0.5,
            -0.

### Define the rollout loop
Now let's define the main rollout loop. The loop runs the policy to a target `horizon` and optionally writes the rollout to a video.

In [14]:
def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, camera_names=None):
    """
    Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video, 
    and returns the rollout trajectory.
    Args:
        policy (instance of RolloutPolicy): policy loaded from a checkpoint
        env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
        horizon (int): maximum horizon for the rollout
        render (bool): whether to render rollout on-screen
        video_writer (imageio writer): if provided, use to write rollout to video
        video_skip (int): how often to write video frames
        camera_names (list): determines which camera(s) are used for rendering. Pass more than
            one to output a video with multiple camera views concatenated horizontally.
    Returns:
        stats (dict): some statistics for the rollout - such as return, horizon, and task success
    """
    assert isinstance(env, EnvBase)
    assert isinstance(policy, RolloutPolicy)
    assert not (render and (video_writer is not None))

    policy.start_episode()
    obs = env.reset()
    state_dict = env.get_state()

    # hack that is necessary for robosuite tasks for deterministic action playback
    obs = env.reset_to(state_dict)

    results = {}
    video_count = 0  # video frame counter
    total_reward = 0.
    try:
        for step_i in range(horizon):

            # get action from policy
            act = policy(ob=obs)

            # play action
            next_obs, r, done, _ = env.step(act)

            # compute reward
            total_reward += r
            success = env.is_success()["task"]

            # visualization
            if render:
                env.render(mode="human", camera_name=camera_names[0])
            if video_writer is not None:
                if video_count % video_skip == 0:
                    video_img = []
                    for cam_name in camera_names:
                        video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                    video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
                    video_writer.append_data(video_img)
                    print(video_count)
                video_count += 1

            # break if done or if success
            if done or success:
                break

            # update for next iter
            obs = deepcopy(next_obs)
            state_dict = env.get_state()

    except env.rollout_exceptions as e:
        print("WARNING: got rollout exception {}".format(e))

    stats = dict(Return=total_reward, Horizon=(step_i + 1), Success_Rate=float(success))

    return stats


### Run the policy
Now let's rollout the policy!

In [15]:
type(ckpt_path)

str

In [16]:
rollout_horizon = 400
np.random.seed(0)
torch.manual_seed(0)
video_path = ckpt_path + "_rollout.mp4"
video_writer = imageio.get_writer(video_path, fps=20)

In [17]:
stats = rollout(
    policy=policy, 
    env=env, 
    horizon=rollout_horizon, 
    render=False, 
    video_writer=video_writer, 
    video_skip=5, 
    camera_names=["agentview"]
)
print(stats)
video_writer.close()

0
5
10
15
20
25
30
35
40
{'Return': 1.0, 'Horizon': 43, 'Success_Rate': 1.0}


### Visualize the rollout

In [18]:
from IPython.display import Video
Video(video_path)