# Run a trained policy

This notebook will provide examples on how to run a trained policy and visualize the rollout.

In [1]:
import argparse
import json
import h5py
import imageio
import numpy as np
import os
from copy import deepcopy

import torch

import robomimic
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.envs.env_base import EnvBase
from robomimic.algo import RolloutPolicy
import urllib.request
# packages = [robomimic, FileUtils, TorchUtils, TensorUtils, ObsUtils, robomimic.envs.env_base, robomimic.algo]
# import importlib
# for i in packages:
    # importlib.reload(i)




In [2]:
import tqdm
import torch
from torch.utils.data import DataLoader
import tensorflow as tf

from robomimic.utils.rlds_utils import droid_dataset_transform, robomimic_transform, TorchRLDSDataset, robomimic_dg_transform, dg_dataset_transform

from octo.data.dataset import make_dataset_from_rlds, make_interleaved_dataset, make_single_dataset
from octo.data.utils.data_utils import combine_dataset_statistics
from octo.utils.spec import ModuleSpec

tf.config.set_visible_devices([], "GPU")
from octo.utils.spec import ModuleSpec
import importlib

### Loading trained policy

We have a convenient function called `policy_from_checkpoint` that takes care of building the correct model from the checkpoint and load the trained weights. Of course you could also load the checkpoint manually.

In [3]:
# dg-noforce
dgnf = "C:\workspace\droid_policy_learning\logs\droid\im\diffusion_policy/10-06-None/bz_16_noise_samples_8_sample_weights_1_dataset_names_dg_noforce_cams_workspace_wrist_goal_mode_None_truncated_geom_factor_0.3_ldkeys_proprio-lang_visenc_VisualCore_fuser_None/20241006181237\models\model_epoch_10.pth"
dgnf = dgnf.replace('\\', '/')

# dg
# dg = "C:\workspace\droid_policy_learning\logs\droid\im\diffusion_policy/10-06-None/bz_16_noise_samples_8_sample_weights_1_dataset_names_deligrasp_cams_workspace_wrist_goal_mode_None_truncated_geom_factor_0.3_ldkeys_proprio-lang_visenc_VisualCore_fuser_None/20241006152541\models\model_epoch_45.pth"
# dg = "C:/workspace/droid_policy_learning/logs/droid/im/diffusion_policy/10-05-None/bz_16_noise_samples_8_sample_weights_1_dataset_names_deligrasp_cams_workspace_wrist_goal_mode_None_truncated_geom_factor_0.3_ldkeys_proprio-lang_visenc_VisualCore_fuser_None/20241005230511/models/model_epoch_25.pth"
# dg = "C:\workspace\droid_policy_learning\logs\droid\im\diffusion_policy/10-07-None\deligrasp_grasponly_cams_workspace_wrist/20241007185334\models\model_epoch_48.pth"
# dg = "C:\workspace\droid_policy_learning\logs\droid\im\diffusion_policy/10-07-None/bz_16_noise_samples_8_sample_weights_1_dataset_names_deligrasp_grasponly_cams_workspace_wrist_goal_mode_None_truncated_geom_factor_0.3_ldkeys_proprio-lang_visenc_VisualCore_fuser_None/20241007215755\models\model_epoch_30.pth"
# dg = "C:\workspace\droid_policy_learning\logs\droid\im\diffusion_policy/10-07-None/bz_16_noise_samples_8_sample_weights_1_dataset_names_deligrasp_grasponly_noforce_cams_workspace_wrist_goal_mode_None_truncated_geom_factor_0.3_ldkeys_proprio-lang_visenc_VisualCore_fuser_None/20241007223737\models\model_epoch_30.pth"
dg = "C:\workspace\droid_policy_learning\logs\droid\im\diffusion_policy/10-07-None/bz_16_noise_samples_8_sample_weights_1_dataset_names_deligrasp_cams_workspace_wrist_goal_mode_None_truncated_geom_factor_0.3_ldkeys_proprio-lang_visenc_VisualCore_fuser_None/20241007230909\models\deligrasp_model_epoch_30.pth"
# replace all '\' chars with '/'
dg = dg.replace('\\', '/')
print(dg)
device = TorchUtils.get_torch_device(try_to_use_cuda=True)

# restore policy
policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path=dg, device=device, verbose=True)

C:/workspace/droid_policy_learning/logs/droid/im/diffusion_policy/10-07-None/bz_16_noise_samples_8_sample_weights_1_dataset_names_deligrasp_cams_workspace_wrist_goal_mode_None_truncated_geom_factor_0.3_ldkeys_proprio-lang_visenc_VisualCore_fuser_None/20241007230909/models/deligrasp_model_epoch_30.pth
{
    "algo_name": "diffusion_policy",
    "experiment": {
        "name": "bz_16_noise_samples_8_sample_weights_1_dataset_names_deligrasp_cams_workspace_wrist_goal_mode_None_truncated_geom_factor_0.3_ldkeys_proprio-lang_visenc_VisualCore_fuser_None",
        "validate": false,
        "logging": {
            "terminal_output_to_txt": true,
            "log_tb": true,
            "log_wandb": true,
            "wandb_proj_name": "jaf"
        },
        "mse": {
            "enabled": true,
            "every_n_epochs": 10,
            "on_save_ckpt": true,
            "num_samples": 6,
            "visualize": true
        },
        "save": {
            "enabled": true,
            "ev



DP obs_dim: 512
DP obs_dim shape: [512]
DP global_cond_dim: 1024
number of parameters: 7.990606e+07
DiffusionPolicyUNet (
  ModuleDict(
    (policy): ModuleDict(
      (obs_encoder): DataParallel(
        (module): ObservationGroupEncoder(
            group=obs
            ObservationEncoder(
                Key(
                    name=robot_state/gripper_position
                    shape=[1]
                    modality=low_dim
                    randomizer=ModuleList(
                      (0): None
                    )
                    net=None
                    sharing_from=None
                )
                Key(
                    name=camera/image/varied_camera_1_left_image
                    shape=[3, 128, 128]
                    modality=rgb
                    randomizer=ModuleList(
                      (0): ColorRandomizer(input_shape=[3, 128, 128], brightness=[0.7, 1.3], contrast=[0.7, 1.3], saturation=[0.7, 1.3], hue=[-0.3, 0.3], num_samples=1)
           

### Creating rollout envionment
The policy checkpoint also contains sufficient information to recreate the environment that it's trained with. Again, you may manually create the environment.

In [4]:
DATA_PATH = "C:/Users/willi/tensorflow_datasets"    # UPDATE WITH PATH TO RLDS DATASETS
DATASET_NAME = "deligrasp_dataset"
# DATASET_NAME = "deligrasp_dataset_grasponly"
EXP_LOG_PATH = "C:/workspace/deligrasp_policy_learning/logs" # UPDATE WITH PATH TO DESIRED LOGGING DIRECTORY
sample_weights = [1]

BASE_DATASET_KWARGS = {
    "name": DATASET_NAME,
    "data_dir": DATA_PATH,
    "image_obs_keys": {"primary": "image", "secondary": "wrist_image"},
    "state_obs_keys": ["cartesian_position", "gripper_position", "applied_force", "contact_force"],
    # "state_obs_keys": ["state"], # this makes ["observation"]['proprio'].shape len 16
    "language_key": "language_instruction",
    "norm_skip_keys":  ["proprio"],
    "action_proprio_normalization_type": "bounds",
    "absolute_action_mask": [False] * 11,                    # droid_dataset_transform uses absolute actions
    "action_normalization_mask": [False] * 11,      # don't normalize final (gripper) dimension
    "standardize_fn": dg_dataset_transform,
}

dataset = make_single_dataset(
    BASE_DATASET_KWARGS,
    train=True,
    traj_transform_kwargs=dict(
        window_size=2,
        future_action_window_size=15,
        subsample_length=50,
        skip_unlabeled=True,            # skip all trajectories without language annotation
    ),
    frame_transform_kwargs=dict(
        image_augment_kwargs=dict(
        ),
        resize_size=dict(
            primary=[128, 128],
            secondary=[128, 128],
        ),
        num_parallel_calls=200,
    )
)

dataset = dataset.map(robomimic_dg_transform, num_parallel_calls=48)


RLDS Utils: robomimic_dg_transform: ds_format: dg_rlds


In [5]:
episode = None
for i in dataset:
    episode = i
    break
obs = episode["obs"]
# obs
# # get first value of each key, that is the first observation step
# obs_iter = [{k: obs[k][i] for k in obs.keys()} for i in range(len(obs['robot_state/applied_force']))]
import numpy as np
act = episode["actions"]
obs_iter = []
for i in range(len(obs['robot_state/applied_force'])):
    dict = {}
    for k in obs.keys():
        skip_keys = ["raw"]
        if True in [sk in k for sk in skip_keys]:
            # print(obs[k][i])
            continue
        # o = obs[k][i][0] # tensorflow eagertensor with dimension [T, ...]
        o = obs[k][i] # tensorflow eagertensor with dimension [T, ...]
        if "image" in k:
            # o = o[k][i][0] # tensorflow eagertensor with dimension [T, ...]
            # o = o[0] # tensorflow eagertensor with dimension [T, ...]
            # currently in (H, W, C) format, need to change to (C, H, W)
            # o = np.transpose(o, (2, 0, 1))
            o = np.transpose(o, (0, 3, 1, 2))
        else:
            o = np.array(o)
        # if "image" not in k:
            # o = obs[k][i][0]
        # add dimension 1, so that it is [B, T, ...]
        # o = torch.from_numpy(np.expand_dims(o, axis=0))
        # print(o.ndim)
        print(o.shape)
        # dict[k] = obs[k][i]
        dict[k] = o
    obs_iter.append(dict)

(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 6)
(2, 3, 128, 128)
(2, 3, 128, 128)


In [6]:
policy.start_episode()
policy.goal_mode = None
policy.action_queue = None
policy.eval_mode = True

In [7]:
policy(obs_iter[0])

key: robot_state/gripper_position
ndim: 3, shape: torch.Size([1, 2, 1])
len shapes: 1 shape: [1]
key: camera/image/varied_camera_1_left_image
ndim: 5, shape: torch.Size([1, 2, 3, 128, 128])
len shapes: 3 shape: [3, 128, 128]
key: camera/image/varied_camera_2_left_image
ndim: 5, shape: torch.Size([1, 2, 3, 128, 128])
len shapes: 3 shape: [3, 128, 128]
key: robot_state/applied_force
ndim: 3, shape: torch.Size([1, 2, 1])
len shapes: 1 shape: [1]
key: robot_state/contact_force
ndim: 3, shape: torch.Size([1, 2, 1])
len shapes: 1 shape: [1]
key: robot_state/cartesian_position
ndim: 3, shape: torch.Size([1, 2, 6])
len shapes: 1 shape: [6]
TU: batch_size=1, seq_len=2
center_crop: im.shape: torch.Size([2, 128, 128, 3])
center_crop: im.shape: torch.Size([2, 128, 128, 3])
TU: outputs.shape=torch.Size([2, 512])
TU: outputs=tensor([[-0.2057, -1.7795, -0.1767,  ..., -0.6907, -1.9359, -2.3945],
        [-0.2057, -1.7795, -0.1767,  ..., -0.6907, -1.9359, -2.3945]],
       device='cuda:0')
TU: outputs.

array([-0.01507424, -0.00389477, -0.00696873, -0.04011911, -0.07106104,
        0.05214009,  0.01502148,  0.02163821])

In [8]:
timestep = 0
for i in range(len(obs_iter)):
    print(f"timestep {i}")
    action = policy(obs_iter[i])
    print(100*action)
    print(f"Ground Truth")
    if i < len(act):
        print(np.array(act[i][0]))
    timestep += 1

timestep 0
unnormalized action: {'action/rel_pos': array([ 0.04072944, -0.01496679,  0.00870464], dtype=float32), 'action/rel_rot_6d': array([ 0.12467919,  0.01749989, -0.0349833 ,  0.04255595,  0.13436852,
       -0.01664682], dtype=float32), 'action/gripper_position': array([0.0192178], dtype=float32), 'action/gripper_force': array([0.0150936], dtype=float32)}
normalized action: {'action/rel_pos': array([ 0.04072944, -0.01496679,  0.00870464]), 'action/rel_rot_6d': array([ 0.12467919,  0.01749989, -0.0349833 ,  0.04255595,  0.13436852,
       -0.01664682]), 'action/gripper_position': array([0.0192178]), 'action/gripper_force': array([0.0150936])}
{'action/rel_pos': array([ 0.04072944, -0.01496679,  0.00870464]), 'action/rel_rot_6d': array([ 0.03663751, -0.2686051 , -0.14413816], dtype=float32), 'action/gripper_position': array([0.0192178]), 'action/gripper_force': array([0.0150936])}
[  4.07294407  -1.49667906   0.87046437   3.66375074 -26.8605113
 -14.41381574   1.92177966   1.50936

In [9]:
policy

DiffusionPolicyUNet (
  ModuleDict(
    (policy): ModuleDict(
      (obs_encoder): DataParallel(
        (module): ObservationGroupEncoder(
            group=obs
            ObservationEncoder(
                Key(
                    name=robot_state/gripper_position
                    shape=[1]
                    modality=low_dim
                    randomizer=ModuleList(
                      (0): None
                    )
                    net=None
                    sharing_from=None
                )
                Key(
                    name=camera/image/varied_camera_1_left_image
                    shape=[3, 128, 128]
                    modality=rgb
                    randomizer=ModuleList(
                      (0): ColorRandomizer(input_shape=[3, 128, 128], brightness=[0.7, 1.3], contrast=[0.7, 1.3], saturation=[0.7, 1.3], hue=[-0.3, 0.3], num_samples=1)
                      (1): CropRandomizer(input_shape=[3, 128, 128], crop_size=[116, 116], num_crops=1)
       

### Define the rollout loop
Now let's define the main rollout loop. The loop runs the policy to a target `horizon` and optionally writes the rollout to a video.

In [10]:
def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, camera_names=None):
    """
    Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video, 
    and returns the rollout trajectory.
    Args:
        policy (instance of RolloutPolicy): policy loaded from a checkpoint
        env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
        horizon (int): maximum horizon for the rollout
        render (bool): whether to render rollout on-screen
        video_writer (imageio writer): if provided, use to write rollout to video
        video_skip (int): how often to write video frames
        camera_names (list): determines which camera(s) are used for rendering. Pass more than
            one to output a video with multiple camera views concatenated horizontally.
    Returns:
        stats (dict): some statistics for the rollout - such as return, horizon, and task success
    """
    assert isinstance(env, EnvBase)
    assert isinstance(policy, RolloutPolicy)
    assert not (render and (video_writer is not None))

    policy.start_episode()
    obs = env.reset()
    state_dict = env.get_state()

    # hack that is necessary for robosuite tasks for deterministic action playback
    obs = env.reset_to(state_dict)

    results = {}
    video_count = 0  # video frame counter
    total_reward = 0.
    try:
        for step_i in range(horizon):

            # get action from policy
            act = policy(ob=obs)

            # play action
            next_obs, r, done, _ = env.step(act)

            # compute reward
            total_reward += r
            success = env.is_success()["task"]

            # visualization
            if render:
                env.render(mode="human", camera_name=camera_names[0])
            if video_writer is not None:
                if video_count % video_skip == 0:
                    video_img = []
                    for cam_name in camera_names:
                        video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                    video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
                    video_writer.append_data(video_img)
                video_count += 1

            # break if done or if success
            if done or success:
                break

            # update for next iter
            obs = deepcopy(next_obs)
            state_dict = env.get_state()

    except env.rollout_exceptions as e:
        print("WARNING: got rollout exception {}".format(e))

    stats = dict(Return=total_reward, Horizon=(step_i + 1), Success_Rate=float(success))

    return stats


### Run the policy
Now let's rollout the policy!

In [11]:
rollout_horizon = 400
np.random.seed(0)
torch.manual_seed(0)
video_path = "rollout.mp4"
video_writer = imageio.get_writer(video_path, fps=20)

In [12]:
stats = rollout(
    policy=policy, 
    env=env, 
    horizon=rollout_horizon, 
    render=False, 
    video_writer=video_writer, 
    video_skip=5, 
    camera_names=["agentview"]
)
print(stats)
video_writer.close()

NameError: name 'env' is not defined

### Visualize the rollout

In [None]:
from IPython.display import Video
Video(video_path)