# Inference on dp on both agentview and handview

In [None]:
import os
import sys
import torch
import dill
import numpy as np
import collections
import tqdm
import imageio

# Import workspace and utilities
from diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace import TrainDiffusionUnetHybridWorkspace
from diffusion_policy.env_runner.robomimic_image_runner import RobomimicImageRunner
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
from diffusion_policy.env.robomimic.robomimic_image_wrapper import RobomimicImageWrapper

import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils

# ----- Helper Classes and Functions -----

# Frame stacker for temporal context.
class FrameStackForTrans:
    def __init__(self, num_frames):
        self.num_frames = num_frames
        self.obs_history = {}
    
    def reset(self, init_obs):
        self.obs_history = {}
        for k in init_obs:
            self.obs_history[k] = collections.deque([init_obs[k][None] for _ in range(self.num_frames)], maxlen=self.num_frames)
        obs = {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history}
        return obs
    
    def add_new_obs(self, new_obs):
        for k in new_obs:
            if 'timesteps' in k or 'actions' in k:
                continue
            self.obs_history[k].append(new_obs[k][None])
        obs = {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history}
        return obs

# Environment wrapper to inject a dummy 'robot0_eye_in_hand_image'
class DummyObsWrapper:
    def __init__(self, env):
        self.env = env
        self.required_key = 'robot0_eye_in_hand_image'
        self.image_shape = (3, 84, 84)  # Adjust if needed
    
    def reset(self):
        obs = self.env.reset()
        if self.required_key not in obs:
            obs[self.required_key] = np.zeros(self.image_shape, dtype=np.float32)
        return obs
    
    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if self.required_key not in obs:
            obs[self.required_key] = np.zeros(self.image_shape, dtype=np.float32)
        return obs, reward, done, info
    
    def render(self, *args, **kwargs):
        return self.env.render(*args, **kwargs)
    
    def __getattr__(self, name):
        return getattr(self.env, name)

# Create environment from metadata and shape configuration.
def create_env(env_meta, shape_meta, enable_render=True):
    modality_mapping = collections.defaultdict(list)
    for key, attr in shape_meta['obs'].items():
        modality_mapping[attr.get('type', 'low_dim')].append(key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(modality_mapping)
    
    env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )
    return env

# Undo the transformation on actions.
def undo_transform_action(action, rotation_transformer):
    raw_shape = action.shape
    if raw_shape[-1] == 20:
        action = action.reshape(-1, 2, 10)
    d_rot = action.shape[-1] - 4
    pos = action[..., :3]
    rot = action[..., 3:3+d_rot]
    gripper = action[..., -1:]
    rot = rotation_transformer.inverse(rot)
    uaction = np.concatenate([pos, rot, gripper], axis=-1)
    if raw_shape[-1] == 20:
        uaction = uaction.reshape(*raw_shape[:-1], 14)
    return uaction

# Rollout inference function.
def rollout_diffusion(env, policy, rotation_transformer, n_obs_steps, n_action_steps, max_steps, return_imgs=False):
    keys_select = ['robot0_eye_in_hand_image', 'agentview_image', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']
    imgs = []
    imgs_eye = []
    # np.random.seed(40)
    # torch.manual_seed(40)
    framestacker = FrameStackForTrans(n_obs_steps)
    obs = env.reset()
    # print("Rollout initial observation keys:", list(obs.keys()))
    policy.reset()
    obs = framestacker.reset(obs)
    done = False
    success = False
    step = 0
    while not done:
        np_obs_dict = {key: obs[key][None, :] for key in keys_select if key in obs}
        obs_dict = dict_apply(np_obs_dict, lambda x: torch.from_numpy(x).to(device))
        with torch.no_grad():
            action_dict = policy.predict_action(obs_dict)
        np_action_dict = dict_apply(action_dict, lambda x: x.detach().cpu().numpy())
        env_action = np_action_dict['action']
        env_action = undo_transform_action(env_action, rotation_transformer)
        env_action = env_action.squeeze()
        # print(env_action)
        for act in env_action:
            act = act
            if return_imgs:
                img = env.render(mode="rgb_array", height=512, width=512, camera_name="agentview")
                img_eye = env.render(mode="rgb_array", height=512, width=512, camera_name="robot0_eye_in_hand")
                imgs.append(img)
                imgs_eye.append(img_eye)
            next_obs, reward, done, info = env.step(act)
            success = env.is_success()["task"]
            step += 1
            if step == max_steps:
                done = True
                break
            obs = framestacker.add_new_obs(next_obs)
            if done or success:
                done = True
                break
        if done:
            break
    return success, imgs, imgs_eye

# ----- Main Execution -----

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Set checkpoint and dataset paths (update these paths)
checkpoint_path = "/Riad/diffusion_policy/data/outputs/Riad_sim_lift_ph_full_2025_03_16_16_54_12/checkpoints/after_train_200_epochs.ckpt"
dataset_path = "/Riad/diffusion_policy/full_image_low_lift_ph.hdf5"  # Used only for env metadata.

# Load checkpoint payload
with open(checkpoint_path, 'rb') as f:
    payload = torch.load(f, pickle_module=dill)
cfg = payload['cfg']

# Build workspace and load the payload (model weights, etc.)
workspace = TrainDiffusionUnetHybridWorkspace(cfg, output_dir=None)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# Select policy from workspace (use EMA model if enabled)
policy = workspace.model
if cfg.training.use_ema:
    policy = workspace.ema_model
policy.to(device)
policy.eval()
print("Policy loaded and set to eval mode.")

# Get environment metadata from dataset.
env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)
env_meta['env_kwargs']['use_object_obs'] = False  # disable object state observation

# Set absolute action mode if needed and initialize the rotation transformer.
abs_action = True
rotation_transformer = None
if abs_action:
    env_meta['env_kwargs']['controller_configs']['control_delta'] = True
    rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')

# Define shape metadata (include expected modalities)
shape_meta = {
    'obs': {
        'robot0_eye_in_hand_image': {'shape': [3, 84, 84], 'type': 'rgb'},
        'agentview_image': {'shape': [3, 84, 84], 'type': 'rgb'},
        'robot0_eef_pos': {'shape': [3]},
        'robot0_eef_quat': {'shape': [4]},
        'robot0_gripper_qpos': {'shape': [2]},
        'object': {'shape': [1]}  # adjust if needed
    },
    'action': {
        'shape': [10]
    }
}

# Create the environment.
raw_env = create_env(env_meta=env_meta, shape_meta=shape_meta, enable_render=True)
print("Created environment with name:", env_meta.get("name", "Unknown"))
print("Action size is", raw_env.action_space.shape[0] if hasattr(raw_env, "action_space") else "Unknown")
print("Original env observation keys:", list(raw_env.reset().keys()))

# Wrap the environment to inject dummy 'robot0_eye_in_hand_image' if missing.
env = DummyObsWrapper(raw_env)
# print("Wrapped env observation keys:", list(env.reset().keys()))

# Inference parameters
n_obs_steps = cfg.dataset_obs_steps if hasattr(cfg, "dataset_obs_steps") else 2
n_action_steps = cfg.n_action_steps if hasattr(cfg, "n_action_steps") else 8
max_steps = 400  # maximum steps per rollout
n_trials = 5    # number of inference trials
fps = 20         # frames per second for the output video

# Run trials and save video for each trial.
trial_success = []
for i in range(n_trials):
    print(f"Trial {i+1}/{n_trials}...")
    # Set return_imgs=True to record frames.
    success, imgs, imgs_eye= rollout_diffusion(env, policy, rotation_transformer, n_obs_steps, n_action_steps, max_steps, return_imgs=True)
    trial_success.append(success)
    print(f"Trial {i+1} success: {success}")
    
    # Save video if images were recorded.
    if imgs:
        video_filename = f"trial_{i+1}_output.mp4"
        imageio.mimwrite(video_filename, imgs, fps=fps, quality=8)
        print(f"Saved video: {video_filename}")
    if imgs_eye:
        video_filename_eye = f"trial_{i+1}_output_eye.mp4"
        imageio.mimwrite(video_filename_eye, imgs_eye, fps=fps, quality=8)
        print(f"Saved video: {video_filename_eye}")

mean_success = np.mean(trial_success)
print("Mean success over trials:", mean_success)


# Policy performance with noise

In [None]:
import os
import sys
import torch
import dill
import numpy as np
import collections
import tqdm
import imageio

# Import workspace and utilities
from diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace import TrainDiffusionUnetHybridWorkspace
from diffusion_policy.env_runner.robomimic_image_runner import RobomimicImageRunner
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
from diffusion_policy.env.robomimic.robomimic_image_wrapper import RobomimicImageWrapper

import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils

# ----- Helper Classes and Functions -----

# Frame stacker for temporal context.
class FrameStackForTrans:
    def __init__(self, num_frames):
        self.num_frames = num_frames
        self.obs_history = {}
    
    def reset(self, init_obs):
        self.obs_history = {}
        for k in init_obs:
            self.obs_history[k] = collections.deque([init_obs[k][None] for _ in range(self.num_frames)], maxlen=self.num_frames)
        obs = {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history}
        return obs
    
    def add_new_obs(self, new_obs):
        for k in new_obs:
            if 'timesteps' in k or 'actions' in k:
                continue
            self.obs_history[k].append(new_obs[k][None])
        obs = {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history}
        return obs

# Environment wrapper to inject a dummy 'robot0_eye_in_hand_image'
class DummyObsWrapper:
    def __init__(self, env):
        self.env = env
        self.required_key = 'robot0_eye_in_hand_image'
        self.image_shape = (3, 84, 84)  # Adjust if needed
    
    def reset(self):
        obs = self.env.reset()
        if self.required_key not in obs:
            obs[self.required_key] = np.zeros(self.image_shape, dtype=np.float32)
        return obs
    
    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if self.required_key not in obs:
            obs[self.required_key] = np.zeros(self.image_shape, dtype=np.float32)
        return obs, reward, done, info
    
    def render(self, *args, **kwargs):
        return self.env.render(*args, **kwargs)
    
    def __getattr__(self, name):
        return getattr(self.env, name)

# Create environment from metadata and shape configuration.
def create_env(env_meta, shape_meta, enable_render=True):
    modality_mapping = collections.defaultdict(list)
    for key, attr in shape_meta['obs'].items():
        modality_mapping[attr.get('type', 'low_dim')].append(key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(modality_mapping)
    
    env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )
    return env

# Undo the transformation on actions.
def undo_transform_action(action, rotation_transformer):
    raw_shape = action.shape
    if raw_shape[-1] == 20:
        action = action.reshape(-1, 2, 10)
    d_rot = action.shape[-1] - 4
    pos = action[..., :3]
    rot = action[..., 3:3+d_rot]
    gripper = action[..., -1:]
    rot = rotation_transformer.inverse(rot)
    uaction = np.concatenate([pos, rot, gripper], axis=-1)
    if raw_shape[-1] == 20:
        uaction = uaction.reshape(*raw_shape[:-1], 14)
    return uaction

# Rollout inference function.
def rollout_diffusion(env, policy, rotation_transformer, n_obs_steps, n_action_steps, max_steps, return_imgs=False):
    keys_select = ['robot0_eye_in_hand_image', 'agentview_image', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']
    imgs = []
    imgs_eye = []
    # np.random.seed(40)
    # torch.manual_seed(40)
    framestacker = FrameStackForTrans(n_obs_steps)
    obs = env.reset()
    # print("Rollout initial observation keys:", list(obs.keys()))
    policy.reset()
    obs = framestacker.reset(obs)
    done = False
    success = False
    step = 0
    while not done:
        np_obs_dict = {key: obs[key][None, :] for key in keys_select if key in obs}
        obs_dict = dict_apply(np_obs_dict, lambda x: torch.from_numpy(x).to(device))
        with torch.no_grad():
            action_dict = policy.predict_action(obs_dict)
        np_action_dict = dict_apply(action_dict, lambda x: x.detach().cpu().numpy())
        env_action = np_action_dict['action']
        env_action = undo_transform_action(env_action, rotation_transformer)
        env_action = env_action.squeeze()
        # print(env_action)
        for act in env_action:
            act = act
            if return_imgs:
                img = env.render(mode="rgb_array", height=512, width=512, camera_name="agentview")
                img_eye = env.render(mode="rgb_array", height=512, width=512, camera_name="robot0_eye_in_hand")
                imgs.append(img)
                imgs_eye.append(img_eye)
            noise = np.random.normal(loc=0.0, scale=0.4, size=act.shape)
            act_noisy = act + noise
            next_obs, reward, done, info = env.step(act_noisy)
            success = env.is_success()["task"]
            step += 1
            if step == max_steps:
                done = True
                break
            obs = framestacker.add_new_obs(next_obs)
            if done or success:
                done = True
                break
        if done:
            break
    return success, imgs, imgs_eye

# ----- Main Execution -----

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Set checkpoint and dataset paths (update these paths)
checkpoint_path = "/Riad/diffusion_policy/data/outputs/Riad_sim_lift_ph_full_2025_03_16_16_54_12/checkpoints/after_train_200_epochs.ckpt"
dataset_path = "/Riad/diffusion_policy/full_image_low_lift_ph.hdf5"  # Used only for env metadata.

# Load checkpoint payload
with open(checkpoint_path, 'rb') as f:
    payload = torch.load(f, pickle_module=dill)
cfg = payload['cfg']

# Build workspace and load the payload (model weights, etc.)
workspace = TrainDiffusionUnetHybridWorkspace(cfg, output_dir=None)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# Select policy from workspace (use EMA model if enabled)
policy = workspace.model
if cfg.training.use_ema:
    policy = workspace.ema_model
policy.to(device)
policy.eval()
print("Policy loaded and set to eval mode.")

# Get environment metadata from dataset.
env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)
env_meta['env_kwargs']['use_object_obs'] = False  # disable object state observation

# Set absolute action mode if needed and initialize the rotation transformer.
abs_action = True
rotation_transformer = None
if abs_action:
    env_meta['env_kwargs']['controller_configs']['control_delta'] = True
    rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')

# Define shape metadata (include expected modalities)
shape_meta = {
    'obs': {
        'robot0_eye_in_hand_image': {'shape': [3, 84, 84], 'type': 'rgb'},
        'agentview_image': {'shape': [3, 84, 84], 'type': 'rgb'},
        'robot0_eef_pos': {'shape': [3]},
        'robot0_eef_quat': {'shape': [4]},
        'robot0_gripper_qpos': {'shape': [2]},
        'object': {'shape': [1]}  # adjust if needed
    },
    'action': {
        'shape': [10]
    }
}

# Create the environment.
raw_env = create_env(env_meta=env_meta, shape_meta=shape_meta, enable_render=True)
print("Created environment with name:", env_meta.get("name", "Unknown"))
print("Action size is", raw_env.action_space.shape[0] if hasattr(raw_env, "action_space") else "Unknown")
print("Original env observation keys:", list(raw_env.reset().keys()))

# Wrap the environment to inject dummy 'robot0_eye_in_hand_image' if missing.
env = DummyObsWrapper(raw_env)
# print("Wrapped env observation keys:", list(env.reset().keys()))

# Inference parameters
n_obs_steps = cfg.dataset_obs_steps if hasattr(cfg, "dataset_obs_steps") else 2
n_action_steps = cfg.n_action_steps if hasattr(cfg, "n_action_steps") else 8
max_steps = 400  # maximum steps per rollout
n_trials = 5    # number of inference trials
fps = 20         # frames per second for the output video

# Run trials and save video for each trial.
trial_success = []
for i in range(n_trials):
    print(f"Trial {i+1}/{n_trials}...")
    # Set return_imgs=True to record frames.
    success, imgs, imgs_eye= rollout_diffusion(env, policy, rotation_transformer, n_obs_steps, n_action_steps, max_steps, return_imgs=True)
    trial_success.append(success)
    print(f"Trial {i+1} success: {success}")
    
    # Save video if images were recorded.
    if imgs:
        video_filename = f"trial_{i+1}_output.mp4"
        imageio.mimwrite(video_filename, imgs, fps=fps, quality=8)
        print(f"Saved video: {video_filename}")
    if imgs_eye:
        video_filename_eye = f"trial_{i+1}_output_eye.mp4"
        imageio.mimwrite(video_filename_eye, imgs_eye, fps=fps, quality=8)
        print(f"Saved video: {video_filename_eye}")

mean_success = np.mean(trial_success)
print("Mean success over trials:", mean_success)


# Inference on only agaentview or hand view

In [1]:


import os
import sys
import torch
import dill
import numpy as np
import collections
import tqdm
import imageio

# Import workspace and utilities
from diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace import TrainDiffusionUnetHybridWorkspace
from diffusion_policy.env_runner.robomimic_image_runner import RobomimicImageRunner
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
from diffusion_policy.env.robomimic.robomimic_image_wrapper import RobomimicImageWrapper

import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils

# ----- Helper Classes and Functions -----

# Frame stacker for temporal context.
class FrameStackForTrans:
    def __init__(self, num_frames):
        self.num_frames = num_frames
        self.obs_history = {}
    
    def reset(self, init_obs):
        self.obs_history = {}
        for k in init_obs:
            self.obs_history[k] = collections.deque([init_obs[k][None] for _ in range(self.num_frames)], maxlen=self.num_frames)
        obs = {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history}
        return obs
    
    def add_new_obs(self, new_obs):
        for k in new_obs:
            if 'timesteps' in k or 'actions' in k:
                continue
            self.obs_history[k].append(new_obs[k][None])
        obs = {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history}
        return obs

# Environment wrapper to inject a dummy 'robot0_eye_in_hand_image'
class DummyObsWrapper:
    def __init__(self, env):
        self.env = env
        self.required_key = 'robot0_eye_in_hand_image'
        self.image_shape = (3, 84, 84)  # Adjust if needed
    
    def reset(self):
        obs = self.env.reset()
        if self.required_key not in obs:
            obs[self.required_key] = np.zeros(self.image_shape, dtype=np.float32)
        return obs
    
    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if self.required_key not in obs:
            obs[self.required_key] = np.zeros(self.image_shape, dtype=np.float32)
        return obs, reward, done, info
    
    def render(self, *args, **kwargs):
        return self.env.render(*args, **kwargs)
    
    def __getattr__(self, name):
        return getattr(self.env, name)

# Create environment from metadata and shape configuration.
def create_env(env_meta, shape_meta, enable_render=True):
    modality_mapping = collections.defaultdict(list)
    for key, attr in shape_meta['obs'].items():
        modality_mapping[attr.get('type', 'low_dim')].append(key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(modality_mapping)
    
    env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )
    return env

# Undo the transformation on actions.
def undo_transform_action(action, rotation_transformer):
    raw_shape = action.shape
    if raw_shape[-1] == 20:
        action = action.reshape(-1, 2, 10)
    d_rot = action.shape[-1] - 4
    pos = action[..., :3]
    rot = action[..., 3:3+d_rot]
    gripper = action[..., -1:]
    rot = rotation_transformer.inverse(rot)
    uaction = np.concatenate([pos, rot, gripper], axis=-1)
    if raw_shape[-1] == 20:
        uaction = uaction.reshape(*raw_shape[:-1], 14)
    return uaction

# Rollout inference function.
def rollout_diffusion(env, policy, rotation_transformer, n_obs_steps, n_action_steps, max_steps, return_imgs=False):
    keys_select = [ 'robot0_eye_in_hand_image', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']
    imgs = []
    imgs_eye = []
    # np.random.seed(40)
    # torch.manual_seed(40)
    framestacker = FrameStackForTrans(n_obs_steps)
    obs = env.reset()
    # print("Rollout initial observation keys:", list(obs.keys()))
    policy.reset()
    obs = framestacker.reset(obs)
    done = False
    success = False
    step = 0
    while not done:
        np_obs_dict = {key: obs[key][None, :] for key in keys_select if key in obs}
        obs_dict = dict_apply(np_obs_dict, lambda x: torch.from_numpy(x).to(device))
        with torch.no_grad():
            action_dict = policy.predict_action(obs_dict)
        np_action_dict = dict_apply(action_dict, lambda x: x.detach().cpu().numpy())
        env_action = np_action_dict['action']
        env_action = undo_transform_action(env_action, rotation_transformer)
        env_action = env_action.squeeze()
        # print(env_action)
        for act in env_action:
            act = act
            if return_imgs:
                img = env.render(mode="rgb_array", height=512, width=512, camera_name="agentview")
                img_eye = env.render(mode="rgb_array", height=512, width=512, camera_name="robot0_eye_in_hand")
                imgs.append(img)
                imgs_eye.append(img_eye)
            # noise = np.random.normal(loc=0.0, scale=0.4, size=act.shape)
            # act_noisy = act + noise
            next_obs, reward, done, info = env.step(act)
            success = env.is_success()["task"]
            step += 1
            if step == max_steps:
                done = True
                break
            obs = framestacker.add_new_obs(next_obs)
            if done or success:
                done = True
                break
        if done:
            break
    return success, imgs, imgs_eye

# ----- Main Execution -----

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Set checkpoint and dataset paths (update these paths)
checkpoint_path = "/home/carl_lab/diffusion_policy/data/outputs/Riad_sim_lift_mh_vision_emp_hand_2025_03_22_00_50_32/checkpoints/after_train_200_epochs.ckpt"
dataset_path = "/home/carl_lab/Riad/data/simulation/full_image_low_lift_ph.hdf5"  # Used only for env metadata.

# Load checkpoint payload
with open(checkpoint_path, 'rb') as f:
    payload = torch.load(f, pickle_module=dill)
cfg = payload['cfg']

# Build workspace and load the payload (model weights, etc.)
workspace = TrainDiffusionUnetHybridWorkspace(cfg, output_dir=None)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# Select policy from workspace (use EMA model if enabled)
policy = workspace.model
if cfg.training.use_ema:
    policy = workspace.ema_model
policy.to(device)
policy.eval()
print("Policy loaded and set to eval mode.")

# Get environment metadata from dataset.
env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)
env_meta['env_kwargs']['use_object_obs'] = False  # disable object state observation

# Set absolute action mode if needed and initialize the rotation transformer.
abs_action = True
rotation_transformer = None
if abs_action:
    env_meta['env_kwargs']['controller_configs']['control_delta'] = True
    rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')

# Define shape metadata (include expected modalities)
shape_meta = {
    'obs': {
        'robot0_eye_in_hand_image': {'shape': [3, 84, 84], 'type': 'rgb'},
        # 'agentview_image': {'shape': [3, 84, 84], 'type': 'rgb'},
        'robot0_eef_pos': {'shape': [3]},
        'robot0_eef_quat': {'shape': [4]},
        'robot0_gripper_qpos': {'shape': [2]},
        'object': {'shape': [1]}  # adjust if needed
    },
    'action': {
        'shape': [10]
    }
}

# Create the environment.
raw_env = create_env(env_meta=env_meta, shape_meta=shape_meta, enable_render=True)
print("Created environment with name:", env_meta.get("name", "Unknown"))
print("Action size is", raw_env.action_space.shape[0] if hasattr(raw_env, "action_space") else "Unknown")
print("Original env observation keys:", list(raw_env.reset().keys()))

# Wrap the environment to inject dummy 'robot0_eye_in_hand_image' if missing.
env = DummyObsWrapper(raw_env)
# print("Wrapped env observation keys:", list(env.reset().keys()))

# Inference parameters
n_obs_steps = cfg.dataset_obs_steps if hasattr(cfg, "dataset_obs_steps") else 2
n_action_steps = cfg.n_action_steps if hasattr(cfg, "n_action_steps") else 8
max_steps = 400  # maximum steps per rollout
n_trials = 10   # number of inference trials
fps = 20         # frames per second for the output video

# Run trials and save video for each trial.
trial_success = []
for i in range(n_trials):
    print(f"Trial {i+1}/{n_trials}...")
    # Set return_imgs=True to record frames.
    success, imgs, imgs_eye= rollout_diffusion(env, policy, rotation_transformer, n_obs_steps, n_action_steps, max_steps, return_imgs=True)
    trial_success.append(success)
    print(f"Trial {i+1} success: {success}")
    
    # Save video if images were recorded.
    if imgs:
        video_filename = f"trial_{i+1}_mh_hand_output.mp4"
        imageio.mimwrite(video_filename, imgs, fps=fps, quality=8)
        print(f"Saved video: {video_filename}")
    if imgs_eye:
        video_filename_eye = f"trial_{i+1}_mh_hand_output_eye.mp4"
        imageio.mimwrite(video_filename_eye, imgs_eye, fps=fps, quality=8)
        print(f"Saved video: {video_filename_eye}")

mean_success = np.mean(trial_success)
print("Mean success over trials:", mean_success)


Using device: cuda


using obs modality: low_dim with keys: ['robot0_eef_quat', 'robot0_gripper_qpos', 'robot0_eef_pos']
using obs modality: rgb with keys: ['robot0_eye_in_hand_image']
using obs modality: depth with keys: []
using obs modality: scan with keys: []




Diffusion params: 1.737294e+07
Vision params: 1.119709e+07
Policy loaded and set to eval mode.
Found 3 GPUs for rendering. Using device 0.
Created environment with name Lift
Action size is 7
Created environment with name: Unknown
Action size is Unknown
Original env observation keys: ['robot0_eye_in_hand_image', 'object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']
Trial 1/10...
Trial 1 success: True
Saved video: trial_1_mh_hand_output.mp4
Saved video: trial_1_mh_hand_output_eye.mp4
Trial 2/10...
Trial 2 success: True
Saved video: trial_2_mh_hand_output.mp4
Saved video: trial_2_mh_hand_output_eye.mp4
Trial 3/10...
Trial 3 success: True
Saved video: trial_3_mh_hand_output.mp4
Saved video: trial_3_mh_hand_output_eye.mp4
Trial 4/10...
Trial 4 success: True
Saved video: trial_4_mh_hand_output.mp4
Saved video: trial_4_mh_hand_output_eye.mp4
Trial 5/10...
Trial 5 success: True
Saved video: trial_5_mh_hand_output.mp4
Saved video: trial_5_mh_hand_output_eye.mp4
Trial 6/10...
Tri

# Inference on only agaentview or hand view (noise)

In [3]:


import os
import sys
import torch
import dill
import numpy as np
import collections
import tqdm
import imageio

# Import workspace and utilities
from diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace import TrainDiffusionUnetHybridWorkspace
from diffusion_policy.env_runner.robomimic_image_runner import RobomimicImageRunner
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
from diffusion_policy.env.robomimic.robomimic_image_wrapper import RobomimicImageWrapper

import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils

# ----- Helper Classes and Functions -----

# Frame stacker for temporal context.
class FrameStackForTrans:
    def __init__(self, num_frames):
        self.num_frames = num_frames
        self.obs_history = {}
    
    def reset(self, init_obs):
        self.obs_history = {}
        for k in init_obs:
            self.obs_history[k] = collections.deque([init_obs[k][None] for _ in range(self.num_frames)], maxlen=self.num_frames)
        obs = {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history}
        return obs
    
    def add_new_obs(self, new_obs):
        for k in new_obs:
            if 'timesteps' in k or 'actions' in k:
                continue
            self.obs_history[k].append(new_obs[k][None])
        obs = {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history}
        return obs

# Environment wrapper to inject a dummy 'robot0_eye_in_hand_image'
class DummyObsWrapper:
    def __init__(self, env):
        self.env = env
        self.required_key = 'robot0_eye_in_hand_image'
        self.image_shape = (3, 84, 84)  # Adjust if needed
    
    def reset(self):
        obs = self.env.reset()
        if self.required_key not in obs:
            obs[self.required_key] = np.zeros(self.image_shape, dtype=np.float32)
        return obs
    
    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if self.required_key not in obs:
            obs[self.required_key] = np.zeros(self.image_shape, dtype=np.float32)
        return obs, reward, done, info
    
    def render(self, *args, **kwargs):
        return self.env.render(*args, **kwargs)
    
    def __getattr__(self, name):
        return getattr(self.env, name)

# Create environment from metadata and shape configuration.
def create_env(env_meta, shape_meta, enable_render=True):
    modality_mapping = collections.defaultdict(list)
    for key, attr in shape_meta['obs'].items():
        modality_mapping[attr.get('type', 'low_dim')].append(key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(modality_mapping)
    
    env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )
    return env

# Undo the transformation on actions.
def undo_transform_action(action, rotation_transformer):
    raw_shape = action.shape
    if raw_shape[-1] == 20:
        action = action.reshape(-1, 2, 10)
    d_rot = action.shape[-1] - 4
    pos = action[..., :3]
    rot = action[..., 3:3+d_rot]
    gripper = action[..., -1:]
    rot = rotation_transformer.inverse(rot)
    uaction = np.concatenate([pos, rot, gripper], axis=-1)
    if raw_shape[-1] == 20:
        uaction = uaction.reshape(*raw_shape[:-1], 14)
    return uaction

# Rollout inference function.
def rollout_diffusion(env, policy, rotation_transformer, n_obs_steps, n_action_steps, max_steps, return_imgs=False):
    keys_select = [ 'robot0_eye_in_hand_image', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']
    imgs = []
    imgs_eye = []
    # np.random.seed(40)
    # torch.manual_seed(40)
    framestacker = FrameStackForTrans(n_obs_steps)
    obs = env.reset()
    # print("Rollout initial observation keys:", list(obs.keys()))
    policy.reset()
    obs = framestacker.reset(obs)
    done = False
    success = False
    step = 0
    while not done:
        np_obs_dict = {key: obs[key][None, :] for key in keys_select if key in obs}
        obs_dict = dict_apply(np_obs_dict, lambda x: torch.from_numpy(x).to(device))
        with torch.no_grad():
            action_dict = policy.predict_action(obs_dict)
        np_action_dict = dict_apply(action_dict, lambda x: x.detach().cpu().numpy())
        env_action = np_action_dict['action']
        env_action = undo_transform_action(env_action, rotation_transformer)
        env_action = env_action.squeeze()
        # print(env_action)
        for act in env_action:
            act = act
            if return_imgs:
                img = env.render(mode="rgb_array", height=512, width=512, camera_name="agentview")
                img_eye = env.render(mode="rgb_array", height=512, width=512, camera_name="robot0_eye_in_hand")
                imgs.append(img)
                imgs_eye.append(img_eye)
            noise = np.random.normal(loc=0.0, scale=0.4, size=act.shape)
            act_noisy = act + noise
            next_obs, reward, done, info = env.step(act_noisy)
            success = env.is_success()["task"]
            step += 1
            if step == max_steps:
                done = True
                break
            obs = framestacker.add_new_obs(next_obs)
            if done or success:
                done = True
                break
        if done:
            break
    return success, imgs, imgs_eye

# ----- Main Execution -----

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Set checkpoint and dataset paths (update these paths)
checkpoint_path = "/home/carl_lab/diffusion_policy/data/outputs/Riad_sim_lift_mh_vision_emp_hand_2025_03_22_00_50_32/checkpoints/after_train_200_epochs.ckpt"
dataset_path = "/home/carl_lab/Riad/data/simulation/full_image_low_lift_ph.hdf5"  # Used only for env metadata.

# Load checkpoint payload
with open(checkpoint_path, 'rb') as f:
    payload = torch.load(f, pickle_module=dill)
cfg = payload['cfg']

# Build workspace and load the payload (model weights, etc.)
workspace = TrainDiffusionUnetHybridWorkspace(cfg, output_dir=None)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# Select policy from workspace (use EMA model if enabled)
policy = workspace.model
if cfg.training.use_ema:
    policy = workspace.ema_model
policy.to(device)
policy.eval()
print("Policy loaded and set to eval mode.")

# Get environment metadata from dataset.
env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path)
env_meta['env_kwargs']['use_object_obs'] = False  # disable object state observation

# Set absolute action mode if needed and initialize the rotation transformer.
abs_action = True
rotation_transformer = None
if abs_action:
    env_meta['env_kwargs']['controller_configs']['control_delta'] = True
    rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')

# Define shape metadata (include expected modalities)
shape_meta = {
    'obs': {
        'robot0_eye_in_hand_image': {'shape': [3, 84, 84], 'type': 'rgb'},
        # 'agentview_image': {'shape': [3, 84, 84], 'type': 'rgb'},
        'robot0_eef_pos': {'shape': [3]},
        'robot0_eef_quat': {'shape': [4]},
        'robot0_gripper_qpos': {'shape': [2]},
        'object': {'shape': [1]}  # adjust if needed
    },
    'action': {
        'shape': [10]
    }
}

# Create the environment.
raw_env = create_env(env_meta=env_meta, shape_meta=shape_meta, enable_render=True)
print("Created environment with name:", env_meta.get("name", "Unknown"))
print("Action size is", raw_env.action_space.shape[0] if hasattr(raw_env, "action_space") else "Unknown")
print("Original env observation keys:", list(raw_env.reset().keys()))

# Wrap the environment to inject dummy 'robot0_eye_in_hand_image' if missing.
env = DummyObsWrapper(raw_env)
# print("Wrapped env observation keys:", list(env.reset().keys()))

# Inference parameters
n_obs_steps = cfg.dataset_obs_steps if hasattr(cfg, "dataset_obs_steps") else 2
n_action_steps = cfg.n_action_steps if hasattr(cfg, "n_action_steps") else 8
max_steps = 700  # maximum steps per rollout
n_trials = 5   # number of inference trials
fps = 20         # frames per second for the output video

# Run trials and save video for each trial.
trial_success = []
for i in range(n_trials):
    print(f"Trial {i+1}/{n_trials}...")
    # Set return_imgs=True to record frames.
    success, imgs, imgs_eye= rollout_diffusion(env, policy, rotation_transformer, n_obs_steps, n_action_steps, max_steps, return_imgs=True)
    trial_success.append(success)
    print(f"Trial {i+1} success: {success}")
    
    # Save video if images were recorded.
    if imgs:
        video_filename = f"results/48_my_trial_{i+1}_mh_hand_output.mp4"
        imageio.mimwrite(video_filename, imgs, fps=fps, quality=8)
        print(f"Saved video: {video_filename}")
    if imgs_eye:
        video_filename_eye = f"results/48_my_trial_{i+1}_mh_hand_output_eye.mp4"
        imageio.mimwrite(video_filename_eye, imgs_eye, fps=fps, quality=8)
        print(f"Saved video: {video_filename_eye}")

mean_success = np.mean(trial_success)
print("Mean success over trials:", mean_success)


Using device: cuda


using obs modality: low_dim with keys: ['robot0_eef_quat', 'robot0_gripper_qpos', 'robot0_eef_pos']
using obs modality: rgb with keys: ['robot0_eye_in_hand_image']
using obs modality: depth with keys: []
using obs modality: scan with keys: []
Diffusion params: 1.737294e+07
Vision params: 1.119709e+07
Policy loaded and set to eval mode.
Created environment with name Lift
Action size is 7
Created environment with name: Unknown
Action size is Unknown
Original env observation keys: ['robot0_eye_in_hand_image', 'object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']
Trial 1/5...
Trial 1 success: False
Saved video: results/48_my_trial_1_mh_hand_output.mp4
Saved video: results/48_my_trial_1_mh_hand_output_eye.mp4
Trial 2/5...
Trial 2 success: False
Saved video: results/48_my_trial_2_mh_hand_output.mp4
Saved video: results/48_my_trial_2_mh_hand_output_eye.mp4
Trial 3/5...
