In [1]:
#!/usr/bin/env python
import os
import sys
import time
import torch
import dill
import numpy as np
import collections
import imageio
import cv2

# -------------------------------
# Deoxys and Franka Interface Imports
# -------------------------------
sys.path.append("/home/franka_deoxys/deoxys_control/deoxys")
from deoxys import config_root
from deoxys.franka_interface import FrankaInterface
from deoxys.utils import YamlConfig
from deoxys.utils.config_utils import robot_config_parse_args
from deoxys.utils.input_utils import input2action
from deoxys.utils.io_devices import SpaceMouse
from deoxys.utils.log_utils import get_deoxys_example_logger
from deoxys.experimental.motion_utils import follow_joint_traj, reset_joints_to

sys.path.append("/home/franka_deoxys/deoxys_vision")
from deoxys_vision.networking.camera_redis_interface import CameraRedisSubInterface
from deoxys_vision.utils.camera_utils import assert_camera_ref_convention, get_camera_info

# -------------------------------
# Diffusion Policy Imports
# -------------------------------
sys.path.append("/home/franka_deoxys/diffusion_policy")
from diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace import TrainDiffusionUnetHybridWorkspace
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from diffusion_policy.common.pytorch_util import dict_apply
from util_eval import RobotStateRawObsDictGenerator, FrameStackForTrans

# -------------------------------
# Set device
# -------------------------------
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print("Using device:", device)

  from .autonotebook import tqdm as notebook_tqdm
pybullet build time: Nov 28 2023 23:52:03


Using device: cuda


In [2]:

# -------------------------------
# Load Robot and Controller Configurations
# -------------------------------
# Parse the deoxys CLI options (interface config, controller config, controller
# type) and build the robot-side objects the rest of the notebook relies on.
args = robot_config_parse_args()

interface_cfg_path = os.path.join(config_root, args.interface_cfg)
controller_cfg_path = os.path.join(config_root, args.controller_cfg)

robot_interface = FrankaInterface(interface_cfg_path)
controller_cfg = YamlConfig(controller_cfg_path).as_easydict()
controller_type = args.controller_type

# SpaceMouse teleoperation is not used during policy rollout. To enable it
# (adjust vendor/product ids for wireless or other devices):
#   spacemouse = SpaceMouse(vendor_id=9583, product_id=50770)
#   spacemouse.start_control()

# Turns raw deoxys robot/gripper state into an observation dict
# (RobotStateRawObsDictGenerator is imported from util_eval).
raw_obs_dict_generator = RobotStateRawObsDictGenerator()

def set_gripper(open=True):
    """Send a single gripper-only command to the robot.

    The 7-element action has its first six (arm) components set to zero and its
    last component set to the gripper command: -1.0 when ``open`` is true,
    1.0 otherwise.

    NOTE(review): zero arm components are a no-op only for a delta controller;
    confirm this is safe for the configured ``controller_type``.

    Args:
        open: True to open the gripper, False to close it. The name shadows the
            builtin but is kept because callers pass it by keyword.
    """
    gripper_cmd = 1.0
    if open:
        gripper_cmd = -1.0
    gripper_only_action = np.zeros(7)
    gripper_only_action[-1] = gripper_cmd
    robot_interface.control(
        controller_type=controller_type,
        action=gripper_only_action,
        controller_cfg=controller_cfg,
    )


In [3]:


# -------------------------------
# Setup Camera Interfaces
# -------------------------------
camera_ids = [0, 1]
cr_interfaces = {}
use_depth = False


def _open_camera_stream(cam_id, with_depth):
    """Validate camera ``rs_<cam_id>``, print its info, and start a Redis subscriber for it."""
    ref = f"rs_{cam_id}"
    assert_camera_ref_convention(ref)
    info = get_camera_info(ref)
    print("Camera Info for {}:".format(ref), info)
    stream = CameraRedisSubInterface(camera_info=info, use_depth=with_depth, redis_host='127.0.0.1')
    stream.start()
    return stream


# One started subscriber per camera id, keyed by that id.
for cam_id in camera_ids:
    cr_interfaces[cam_id] = _open_camera_stream(cam_id, use_depth)

def get_imgs(use_depth=False):
    """Grab the latest frame from every camera listed in ``camera_ids``.

    Returns:
        dict containing, for each camera id ``i``:
          - ``camera_{i}``: the subscriber's ``get_img_info()`` result,
          - ``camera_{i}_color``: the color image with its last axis reversed
            (BGR -> RGB, assuming the stream publishes BGR) resized to 320x240,
          - ``camera_{i}_depth``: the depth image resized to 224x224, only when
            ``use_depth`` is true.
    """
    frames = {}
    for cid in camera_ids:
        stream = cr_interfaces[cid]
        frames[f"camera_{cid}"] = stream.get_img_info()

        raw = stream.get_img()
        # cv2.resize takes (width, height).
        frames[f"camera_{cid}_color"] = cv2.resize(raw["color"][..., ::-1], (320, 240))

        if not use_depth:
            continue
        frames[f"camera_{cid}_depth"] = cv2.resize(raw["depth"], (224, 224))
    return frames

def get_current_obs():
    """Build one observation dict from the newest robot state and camera frames.

    Low-dimensional entries come from ``raw_obs_dict_generator`` applied to the
    last entries of the robot interface's state and gripper-state buffers. Two
    image entries are then added as float32 channel-first arrays scaled to [0, 1].
    """
    # NOTE(review): reads private deoxys buffers; indexing [-1] raises IndexError
    # if no state has been received yet.
    robot_state = {
        "last_state": robot_interface._state_buffer[-1],
        "last_gripper_state": robot_interface._gripper_state_buffer[-1],
    }
    obs_dict = raw_obs_dict_generator.get_raw_obs_dict(robot_state)

    frames = get_imgs(use_depth=False)

    def _to_chw_unit(img):
        # HxWxC uint8-range image -> CxHxW float32 in [0, 1].
        return img.transpose(2, 0, 1).astype(np.float32) / 255.0

    # camera_0 -> agentview_rgb, camera_1 -> eye_in_hand_rgb.
    # NOTE(review): confirm which physical camera (wrist vs. front) is camera_0
    # vs. camera_1. The key names suggest camera_1 should be the wrist view, and
    # the mapping must match the training data.
    obs_dict['agentview_rgb'] = _to_chw_unit(frames['camera_0_color'])
    obs_dict['eye_in_hand_rgb'] = _to_chw_unit(frames['camera_1_color'])
    return obs_dict

# -------------------------------
# Helper Classes
# -------------------------------
class FrameStackForTrans:
    """Rolling window of the last ``num_frames`` observations, kept per key.

    NOTE: this definition shadows the ``FrameStackForTrans`` imported from
    ``util_eval`` at the top of the notebook; the rollout uses this one.

    ``reset`` and ``add_new_obs`` both return a dict mapping every tracked key to
    an array of shape ``(num_frames, *obs_shape)``, with the oldest frame first.
    Keys containing ``'timesteps'`` or ``'actions'`` are never tracked, so no
    key is ever returned with a frozen (stale) history.
    """

    # Substrings marking bookkeeping entries that must not be frame-stacked.
    SKIP_KEY_SUBSTRINGS = ("timesteps", "actions")

    def __init__(self, num_frames):
        self.num_frames = num_frames
        self.obs_history = {}

    @classmethod
    def _is_skipped(cls, key):
        """Return True if ``key`` is a bookkeeping key excluded from stacking."""
        return any(sub in key for sub in cls.SKIP_KEY_SUBSTRINGS)

    def _stacked(self):
        """Concatenate each key's history along a new leading time axis."""
        return {k: np.concatenate(hist, axis=0) for k, hist in self.obs_history.items()}

    def reset(self, init_obs):
        """Fill the window with ``num_frames`` copies of ``init_obs`` and return the stack."""
        # Apply the same key filter as add_new_obs. Otherwise skipped keys would
        # be stacked here once and then never updated again.
        self.obs_history = {
            k: collections.deque([v[None]] * self.num_frames, maxlen=self.num_frames)
            for k, v in init_obs.items()
            if not self._is_skipped(k)
        }
        return self._stacked()

    def add_new_obs(self, new_obs):
        """Append one observation per tracked key (dropping the oldest) and return the stack.

        Raises:
            KeyError: if ``new_obs`` has a non-skipped key that was not present at reset.
        """
        for k, v in new_obs.items():
            if self._is_skipped(k):
                continue
            self.obs_history[k].append(v[None])
        return self._stacked()

def undo_transform_action(action, rotation_transformer):
    """Map policy actions from the training rotation encoding back to the robot's.

    Per step the last axis is laid out as ``[pos(3), rot(D_rot), gripper(1)]``
    with ``D_rot = last_dim - 4``. ``rotation_transformer.inverse`` converts the
    rotation part; the result is ``[pos, inverse(rot), gripper]``. A last axis of
    size 20 is treated as two such 10-wide actions side by side (two arms) and
    is returned 14 wide.

    Args:
        action: numpy array whose last axis holds the encoded action.
        rotation_transformer: object with an ``inverse`` method, or None. When
            None (``abs_action`` is False in this notebook), ``action`` is
            returned unchanged, because there is no rotation encoding to undo.

    Returns:
        numpy array of converted actions (or ``action`` itself when no
        transformer is given).
    """
    if rotation_transformer is None:
        return action

    raw_shape = action.shape
    dual_arm = raw_shape[-1] == 20
    if dual_arm:
        # 20 = 2 arms x (3 pos + 6 rot + 1 gripper).
        action = action.reshape(-1, 2, 10)

    d_rot = action.shape[-1] - 4
    pos = action[..., :3]
    rot = rotation_transformer.inverse(action[..., 3:3 + d_rot])
    gripper = action[..., -1:]
    uaction = np.concatenate([pos, rot, gripper], axis=-1)

    if dual_arm:
        uaction = uaction.reshape(*raw_shape[:-1], 14)
    return uaction

Camera Info for rs_0: {'camera_id': 0, 'camera_type': 'rs', 'camera_name': 'camera_rs_0'}
CameraRedisSubInterface:: {'camera_id': 0, 'camera_type': 'rs', 'camera_name': 'camera_rs_0'} True False
Camera Info for rs_1: {'camera_id': 1, 'camera_type': 'rs', 'camera_name': 'camera_rs_1'}
CameraRedisSubInterface:: {'camera_id': 1, 'camera_type': 'rs', 'camera_name': 'camera_rs_1'} True False


In [4]:


# -------------------------------
# Load Diffusion Policy Checkpoint
# -------------------------------

checkpoint = "/home/franka_deoxys/data_franka/epoch_700_drawer_bellpepper_bed.ckpt"
# SECURITY: torch.load with a pickle module can execute arbitrary code stored in
# the file. Only load checkpoints from a trusted source.
# map_location="cpu" keeps loading working when the checkpoint was saved from a
# GPU but this kernel selected the CPU; the policy is moved to `device` below.
with open(checkpoint, 'rb') as f:
    payload = torch.load(f, map_location="cpu", pickle_module=dill)
cfg = payload['cfg']

workspace = TrainDiffusionUnetHybridWorkspace(cfg, output_dir=None)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# Select the policy model (the EMA copy when training enabled use_ema).
policy = workspace.model
if getattr(cfg.training, "use_ema", False):
    policy = workspace.ema_model
policy.to(device)
policy.eval()
print("Diffusion policy loaded and set to eval mode.")

# -------------------------------
# Rotation Transformer (only needed for absolute-action checkpoints)
# -------------------------------
# NOTE(review): abs_action is hard-coded. It must match how this checkpoint was
# trained (actions carrying a 6D rotation => True). Confirm against cfg.
abs_action = True
rotation_transformer = None
if abs_action:
    # Policy actions carry 6D rotations; .inverse() maps them back to axis-angle.
    rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')



using obs modality: low_dim with keys: ['ee_states', 'gripper_states', 'joint_states']
using obs modality: rgb with keys: ['agentview_rgb', 'eye_in_hand_rgb']
using obs modality: depth with keys: []
using obs modality: scan with keys: []




Diffusion params: 2.564722e+08
Vision params: 2.239418e+07
Diffusion policy loaded and set to eval mode.


In [5]:


# -------------------------------
# Inference / Rollout Function
# -------------------------------
def rollout_diffusion(policy, rotation_transformer, n_obs_steps, n_action_steps, max_steps,
                      return_imgs=False, n_exec_steps=4):
    """Run the diffusion policy closed-loop on the Franka through deoxys.

    Each iteration does the following:
      1. Stacks the last ``n_obs_steps`` observations.
      2. Predicts an action chunk.
      3. Executes at most ``n_action_steps`` and at most ``n_exec_steps`` actions
         of that chunk through ``robot_interface.control``.
      4. Reads a fresh observation.

    Args:
        policy: model exposing ``predict_action(obs_dict)`` that returns a dict
            with an ``'action'`` tensor whose first axis is the batch.
        rotation_transformer: used to undo the rotation encoding of the actions.
            Pass None to send actions exactly as predicted.
        n_obs_steps: number of observation frames kept in the frame stack.
        n_action_steps: upper bound on actions executed per chunk (None = no bound).
        max_steps: number of predict/execute iterations.
        return_imgs: if True, record the eye-in-hand image after every chunk.
        n_exec_steps: safety cap on actions executed per chunk (default 4;
            None = no cap).

    Returns:
        ``(success, imgs)``:
          - ``success`` stays False until a task-success check is added.
          - ``imgs`` holds HxWxC float32 images in [0, 1] when ``return_imgs``
            is True; otherwise it is empty.
    """
    # Observation keys the policy consumes (images and robot state).
    keys_select = ['agentview_rgb', 'joint_states', 'ee_states', 'eye_in_hand_rgb', 'gripper_states']
    imgs = []
    success = False  # TODO: set from a task-specific success check.

    framestacker = FrameStackForTrans(n_obs_steps)
    obs = framestacker.reset(get_current_obs())

    for step in range(1, max_steps + 1):
        # Add a batch axis of size 1 and keep only the keys the policy uses.
        np_obs_dict = {key: obs[key][None] for key in keys_select if key in obs}
        obs_tensor = dict_apply(np_obs_dict, lambda x: torch.from_numpy(x).to(device))
        with torch.no_grad():
            action_dict = policy.predict_action(obs_tensor)
        env_action = action_dict['action'].detach().cpu().numpy()
        if rotation_transformer is not None:
            env_action = undo_transform_action(env_action, rotation_transformer)
        # Flatten (batch=1, T, D) to (T, D). Unlike squeeze(), this stays 2-D
        # when T == 1, so the loop below still iterates over whole actions.
        env_action = env_action.reshape(-1, env_action.shape[-1])

        # NOTE(review): with abs_action=True these are absolute targets. Confirm the
        # configured deoxys controller interprets them as absolute, not as deltas.
        for action in env_action[:n_action_steps][:n_exec_steps]:
            robot_interface.control(
                controller_type=controller_type,
                action=action,  # numpy array, passed as-is
                controller_cfg=controller_cfg,
            )

        is_last = step >= max_steps
        if return_imgs or not is_last:
            # One post-action read serves both the recording and the frame stack.
            latest_obs = get_current_obs()
            if return_imgs:
                imgs.append(latest_obs['eye_in_hand_rgb'].transpose(1, 2, 0))  # CHW -> HWC
            if not is_last:
                obs = framestacker.add_new_obs(latest_obs)

    return success, imgs

In [6]:
# Joint configuration (7 values) for reset_joints_to before a rollout.
# Alternatives recorded for other tasks:
#   block (normal):           [-0.03, 0.244, -0.03, -1.673, 0.134, 1.905, 0.665]
#   block (agent + eye cams): [0.034, 0.109, -0.012, -1.63, 0.005, 1.776, 0.696]
#   candy:                    [-0.03, 0.244, -0.03, -1.673, 0.134, 1.905, 0.665]
reset_joint_positions = [
    0.09162008114028396, -0.19826458111314524, -0.01990020486871322,
    -2.4732269941140346, -0.01307073642274261, 2.30396583422025,
    0.8480939705504309,
]


In [7]:
# # reset robot
# set_gripper(open=True)
# reset_joints_to(robot_interface, reset_joint_positions)
# set_gripper(open=True)

In [8]:

# policy.reset()
# -------------------------------
# Main Inference Loop
# -------------------------------

# Horizons: prefer the values stored on the policy itself (presumably what its
# predict_action() uses), falling back to the training config.
n_obs_steps = getattr(policy, "n_obs_steps", getattr(cfg, "dataset_obs_steps", 2))
n_action_steps = getattr(policy, "n_action_steps", getattr(cfg, "n_action_steps", 4))
max_steps = 400   # maximum predict/execute iterations per trial
n_trials = 5      # used only by the multi-trial loop sketched below (disabled)
fps = 20          # frame rate for the (disabled) video output below

trial_success = []
print("Starting trial ")
try:
    success, imgs = rollout_diffusion(policy, rotation_transformer, n_obs_steps,
                                      n_action_steps, max_steps, return_imgs=True)
    trial_success.append(success)
    # Multi-trial version (re-enable together with a robot reset between trials):
    # for trial in range(n_trials):
    #     ...
    #     print("Trial {} success: {}".format(trial+1, success))
    #     if imgs:
    #         video_filename = f"trial_{trial+1}_output.mp4"
    #         imageio.mimwrite(video_filename, imgs, fps=fps, quality=8)
    #         print("Saved video:", video_filename)
    time.sleep(2.0)
finally:
    # Release the robot connection even when the rollout raises or is
    # interrupted (e.g. KeyboardInterrupt from stopping the cell).
    robot_interface.close()

mean_success = np.mean(trial_success)
print("Mean success over trials:", mean_success)


Starting trial 


KeyboardInterrupt: 