# Examples for MuJoCo and Bimanual Sim Usage

In [1]:
# Re-import edited project modules (e.g. `sim`, `kinematics`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

### Low-level MuJoCo Model Instantiation

In [None]:
# Load the raw ALOHA MJCF with the plain MuJoCo API (no BimanualSim wrapper) and open it in a viewer.
import os
import mujoco
import mujoco.viewer

from robot_descriptions import aloha_mj_description

# NOTE(review): `os` and `ET` are not used in this cell.
from xml.etree import ElementTree as ET

# Instantiate the MuJoCo bimanual model and its simulation state directly from the MJCF file.
model = mujoco.MjModel.from_xml_path(aloha_mj_description.MJCF_PATH)
data = mujoco.MjData(model)

# Launch an interactive viewer on `model`/`data`. Per the MuJoCo docs, `launch` blocks until the
# window is closed (`launch_passive`, used further down, does not).
mujoco.viewer.launch(model, data)

In [None]:
# Inspect the raw model's joint positions (qpos) as left by the viewer session above.
data.qpos

array([-1.24866767e-02, -9.85510335e-02,  1.00000000e+00,  8.35067762e-02,
        1.49315150e-01, -7.94138101e-02,  3.74128540e-02,  4.02052328e-02,
       -8.73947428e-04,  1.06232035e-01, -1.28021856e-01, -1.63744567e-02,
       -4.58504852e-02, -8.15300697e-02,  3.46516773e-02,  3.86005880e-02])

In [None]:
# Inspect the raw model's actuator control inputs (ctrl).
data.ctrl

array([ 0.1571  , -0.591665,  0.039415,  0.34562 ,  0.819155, -0.12568 ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ])

In [42]:
from pathlib import Path
from kinematics import extract_kinematic_info, forward_kinematics
from sim import BimanualSim
from robot_descriptions import aloha_mj_description

# Sanity-check forward kinematics: move the mocap indicator to the FK-predicted position of the
# last joint of the right arm chain, then inspect the result in the viewer.

sim = BimanualSim(merge_xml_files=[Path('block.xml'), Path('indicator.xml')])

# Kinematic chain of the right arm, from its base link to the gripper base.
k = extract_kinematic_info(aloha_mj_description.MJCF_PATH, 'right/base_link', 'right/gripper_base')
for j in k:
  print(j)
# qpos[8:14] is used as the right arm's six joint angles (same slice the policy uses for the right arm).
joint_pos, joint_quat = forward_kinematics(k, sim.data.qpos[8:14])
# mocap_id = mujoco.mj_name2id(sim.model, mujoco.mjtObj.mjOBJ_BODY, 'indicator')
print(sim.data.mocap_pos.shape, joint_pos[-1].shape)
# NOTE(review): assumes the indicator is mocap body 0 — confirm. Resolving the index via
# model.body_mocapid[mj_name2id(..., 'indicator')] would survive additional mocap bodies.
sim.data.mocap_pos[0] = joint_pos[-1]
# sim.data.qpos[23:26] = joint_pos[-1]

sim.launch_viewer()

KinematicLink(joint_name='right/waist', joint_limits=(np.float64(-3.14158), np.float64(3.14158)), rotation_axis=(np.float64(0.0), np.float64(0.0), np.float64(1.0)), origin_pos=array([ 0.469, -0.019,  0.099]), quat=(np.float64(0.0), np.float64(0.0), np.float64(1.0), np.float64(0.0)))
KinematicLink(joint_name='right/shoulder', joint_limits=(np.float64(-1.85005), np.float64(1.25664)), rotation_axis=(np.float64(0.0), np.float64(1.0), np.float64(0.0)), origin_pos=array([0.     , 0.     , 0.04805]), quat=(np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(1.0)))
KinematicLink(joint_name='right/elbow', joint_limits=(np.float64(-1.76278), np.float64(1.6057)), rotation_axis=(np.float64(0.0), np.float64(1.0), np.float64(0.0)), origin_pos=array([0.05955, 0.     , 0.3    ]), quat=(np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(1.0)))
KinematicLink(joint_name='right/forearm_roll', joint_limits=(np.float64(-3.14158), np.float64(3.14158)), rotation_axis=(np.float64(1.0), np.

In [None]:
import copy
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Tuple, get_args

import mujoco
import numpy as np
from robot_descriptions import aloha_mj_description
from scipy.spatial.transform import Rotation as scipyrotation

from kinematics import KinematicLink, extract_kinematic_info, forward_kinematics
from sim import BimanualSim


def quat_to_axes(quat: np.ndarray) -> np.ndarray:
  """Rotate the canonical x, y, z unit vectors by ``quat``.

  ``quat`` is passed to ``scipy.spatial.transform.Rotation.from_quat`` as-is, i.e. it must be
  in scipy's default scalar-last (x, y, z, w) order.

  Returns:
    A (3, 3) array whose rows are the rotated x, y and z axes.
  """
  unit_axes = np.eye(3)
  rotation = scipyrotation.from_quat(quat)
  return rotation.apply(unit_axes)

def wxyz_to_xyzw(quat: np.ndarray) -> np.ndarray:
  """Reorder a quaternion from scalar-first (w, x, y, z) to scalar-last (x, y, z, w).

  Used to feed MuJoCo body quaternions (``data.xquat``) to scipy's ``Rotation.from_quat``.
  Returns a new array; the input is not modified.
  """
  # Rolling left by one position moves the leading scalar component to the end.
  return np.roll(quat, -1, axis=0)

def get_block_orientation(
  model: mujoco.MjModel,
  data: mujoco.MjData,
  reference_body_name: str = 'left/base_link',
) -> Tuple[np.ndarray, np.ndarray]:
  """Return the block's position and its three body axes in a canonical order.

  The block body's local axes are classified as:
    * vertical: the axis most aligned (up to sign) with world +z;
    * facing: of the remaining two, the one most aligned (up to sign) with the vector from
      ``reference_body_name`` to the block;
    * side: the axis left over.

  Args:
    model: MuJoCo model containing a body named 'block' and a body named ``reference_body_name``.
    data: MuJoCo data holding the current body poses (``xpos`` / ``xquat``).
    reference_body_name: body the facing direction is measured from; defaults to the left arm base.

  Returns:
    ``(position, axes)``: a copy of the block's ``xpos`` (so it does not change as the simulation
    steps) and a (3, 3) array whose rows are ``(facing_axis, side_axis, vertical_axis)``. Axis signs
    are not normalized; each row may point either way along its axis.

  Raises:
    ValueError: if either body name is not present in ``model``.
  """
  block_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, 'block')
  reference_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, reference_body_name)
  # mj_name2id returns -1 for unknown names, which would silently index the *last* body.
  for name, body_id in (('block', block_id), (reference_body_name, reference_id)):
    if body_id < 0:
      raise ValueError(f'body {name!r} not found in the MuJoCo model')

  # Rows are the block's local x, y, z axes (MuJoCo quaternions are scalar-first).
  axes = quat_to_axes(wxyz_to_xyzw(data.xquat[block_id]))

  # Vertical axis: largest |z| component, i.e. most aligned with world up regardless of sign.
  vertical_index = int(np.abs(axes[:, 2]).argmax())
  vertical_axis = axes[vertical_index]
  axes = np.delete(axes, vertical_index, axis=0)

  # Facing axis: of the two remaining axes, the one most aligned with reference -> block.
  block_pos = data.xpos[block_id].copy()
  reference_to_block = block_pos - data.xpos[reference_id]
  facing_index = int(np.abs(axes @ reference_to_block).argmax())
  facing_axis = axes[facing_index]
  side_axis = np.delete(axes, facing_index, axis=0)[0]

  return block_pos, np.stack((facing_axis, side_axis, vertical_axis))

def augmented_kinematic_forward(
  kinematic_chain: List[KinematicLink],
  joint_q: np.ndarray,
  end_effector_displacement: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
  """Forward kinematics to a virtual end-effector point attached to the last link of a chain.

  A fixed pseudo-link named 'end_effector' is appended to ``kinematic_chain``. It is placed at
  ``end_effector_displacement`` relative to the last link, has orientation (0, 0, 0, 1), a (0, 0)
  joint range, and is always evaluated at joint angle 0.

  Args:
    kinematic_chain: links of the arm, base first.
    joint_q: one joint angle per link in ``kinematic_chain``.
    end_effector_displacement: offset of the virtual point from the last link.

  Returns:
    ``(position, quaternion)`` of the virtual end effector, i.e. the last entries of the two
    arrays returned by ``forward_kinematics``.
  """
  end_effector_link = KinematicLink(
    'end_effector',
    (0, 0),
    (0, 1, 0),
    tuple(end_effector_displacement),
    (0, 0, 0, 1),
  )
  # The pseudo-link's joint is pinned at 0 so it contributes only its fixed offset.
  extended_q = np.append(joint_q, 0.0)
  positions, quaternions = forward_kinematics(kinematic_chain + [end_effector_link], extended_q)
  return positions[-1], quaternions[-1]

def build_grasp_ik_objective(
  kinematic_chain: List[KinematicLink],
  end_effector_displacement: np.ndarray,
  target_pos: np.ndarray,
  grasp_axis: np.ndarray,
  position_weight: float = 0.1,
) -> Callable[[np.ndarray], np.ndarray]:
  """Build the residual function used by the grasp inverse-kinematics least-squares solve.

  Args:
    kinematic_chain: arm links, base first.
    end_effector_displacement: offset of the virtual end-effector point from the last link.
    target_pos: desired position of the virtual end-effector point.
    grasp_axis: direction the gripper's local x and z axes should both be perpendicular to,
      i.e. the direction the gripper's local y axis should align with (up to sign).
    position_weight: scale applied to the position residuals relative to the orientation
      residuals (0.1, as before, by default).

  Returns:
    ``objective(joint_q)`` returning a length-5 residual vector: the weighted position error
    (3 entries) followed by ``|dot|`` of the gripper's local x and z axes with ``grasp_axis``
    (2 entries).
  """
  def objective(joint_q: np.ndarray) -> np.ndarray:
    gripper_pos, gripper_quat = augmented_kinematic_forward(kinematic_chain, joint_q, end_effector_displacement)
    # Gripper's local x and z axes expressed in the frame of the forward kinematics.
    gripper_grasp_plane = scipyrotation.from_quat(gripper_quat).apply(np.array([[1, 0, 0], [0, 0, 1]]))

    pos_error = gripper_pos - target_pos
    # Zero when both in-plane axes are perpendicular to the grasp axis.
    grasp_angle_error = np.abs(gripper_grasp_plane @ grasp_axis)
    return np.concat([position_weight * pos_error, grasp_angle_error])
  return objective

# `scipy.optimize.least_squares` (used by PrivilegedPolicy's IK step) is reached through this module.
import scipy
# NOTE(review): this module-level left-arm chain is not referenced by the code below — the
# functions take `kinematic_chain` as a parameter and PrivilegedPolicy builds its own chains.
kinematic_chain = extract_kinematic_info(aloha_mj_description.MJCF_PATH, 'left/base_link', 'left/gripper_base')
from sim import BimanualAction, BimanualObs


class PrivilegedPolicy:
  """Scripted bimanual block hand-over policy driven by privileged simulator state.

  The policy is a finite-state machine over ``Stage``. Each call dispatches to the subpolicy for
  ``self.policy_stage``, and subpolicies advance the stage themselves once their goal is reached.
  The block pose is read directly from ``data`` rather than from the observation. Arm joint
  targets come from a bounded least-squares inverse-kinematics solve.

  Stage sequence:
    1. The left arm greets, approaches, grasps and raises the block.
    2. The right arm greets, approaches and grasps it.
    3. The left gripper releases and the right arm retracts.
    4. The last action is held ('done').
  """

  # Policy stages, listed in the order they are executed.
  Stage = Literal[
    'left-greet-block',
    'left-approach-block',
    'left-grasp-block',
    'left-raise-block',
    'right-greet-block',
    'right-approach-block',
    'right-grasp-block',
    'left-release-block',
    'right-retract-block',
    'done'
  ]

  def __init__(self, model: mujoco.MjModel, data: mujoco.MjData):
    """Create the policy in its first stage ('left-greet-block').

    Args:
      model: MuJoCo model of the scene; must contain bodies named 'block' and 'left/base_link'.
      data: MuJoCo data of the running simulation; the block pose is read from it every step.
    """
    self.model = model
    self.data = data
    self.policy_stage: PrivilegedPolicy.Stage = 'left-greet-block'
    # Stage -> subpolicy. Each subpolicy returns an action and may advance self.policy_stage.
    self.subpolicies: Dict[PrivilegedPolicy.Stage, Callable[[BimanualObs, Dict], BimanualAction]] = {
      'left-greet-block': self._left_greet_block,
      'left-approach-block': self._left_approach_block,
      'left-grasp-block': self._left_grasp_block,
      'left-raise-block': self._left_raise_block,
      'right-greet-block': self._right_greet_block,
      'right-approach-block': self._right_approach_block,
      'right-grasp-block': self._right_grasp_block,
      'left-release-block': self._left_release_block,
      'right-retract-block': self._right_retract_block,
      'done': self._done,
    }
    # Per-stage scratch state (step / settling counters) handed to that stage's subpolicy.
    self.subpolicy_state: Dict[PrivilegedPolicy.Stage, Any] = {stage: {} for stage in get_args(PrivilegedPolicy.Stage)}
    self.left_kinematic_chain = extract_kinematic_info(aloha_mj_description.MJCF_PATH, 'left/base_link', 'left/gripper_base')
    self.right_kinematic_chain = extract_kinematic_info(aloha_mj_description.MJCF_PATH, 'right/base_link', 'right/gripper_base')
    self.previous_action = BimanualAction()

  def __call__(self, obs: BimanualObs) -> BimanualAction:
    """Return the action for the current stage, given the latest observation.

    Most subpolicies build their action by mutating ``self.previous_action`` in place. It is
    therefore copied first. Without the copy, every call would return the same object, and
    actions handed to the caller on earlier steps would silently change afterwards.
    """
    self.previous_action = copy.deepcopy(self.previous_action)
    action = self.subpolicies[self.policy_stage](obs, self.subpolicy_state[self.policy_stage])
    self.previous_action = action
    return action

  def _left_greet_block(self, obs: BimanualObs, state: Dict) -> BimanualAction:
    """Bring a point 0.2 ahead of the left gripper to the block, with the gripper opened to 0.37.

    The right arm is held at fixed shoulder/elbow targets meanwhile.
    """
    target_pos, block_axes = get_block_orientation(self.model, self.data)
    action = self._inverse_kinematics_pass(
      'left',
      self.previous_action,
      obs,
      state,
      target_pos,
      block_axes[1],
      end_effector_displacement=np.array([0.2, 0.0, 0.0]),
      on_target_reached='left-approach-block'
    )
    action.left_gripper = 0.37

    action.right_shoulder = -1
    action.right_elbow = 1

    return action

  def _left_approach_block(self, obs: BimanualObs, state: Dict) -> BimanualAction:
    """Move closer: bring a point 0.1 ahead of the left gripper to the block."""
    target_pos, block_axes = get_block_orientation(self.model, self.data)
    return self._inverse_kinematics_pass(
      'left',
      self.previous_action,
      obs,
      state,
      target_pos,
      block_axes[1],
      end_effector_displacement=np.array([0.1, 0.0, 0.0]),
      on_target_reached='left-grasp-block'
    )

  def _left_grasp_block(self, obs: BimanualObs, state: Dict) -> BimanualAction:
    """Close the left gripper (0) for 10 steps, then advance to raising the block."""
    self._maintain_stage(state, steps=10, then='left-raise-block')
    # Presumably an action whose joint targets match the current joint positions — TODO confirm.
    action = obs.qpos.to_approximate_action()
    action.left_gripper = 0
    return action

  def _left_raise_block(self, obs: BimanualObs, state: Dict) -> BimanualAction:
    """Lift the grasped block with the left arm to a fixed hand-over position."""
    target_pos, gripper_axis = np.array([-0.05, 0.0, 0.4]), np.array([0.0, 1.0, 0.0])
    action = self._inverse_kinematics_pass(
      'left',
      self.previous_action,
      obs,
      state,
      target_pos,
      gripper_axis,
      end_effector_displacement=np.array([0.1, 0.0, 0.0]),
      on_target_reached='right-greet-block'
    )
    action.left_gripper = 0
    return action

  def _right_greet_block(self, obs: BimanualObs, state: Dict) -> BimanualAction:
    """Bring a point 0.2 ahead of the opened right gripper to the block; left gripper stays closed."""
    base_action = self.previous_action

    target_pos, block_axes = get_block_orientation(self.model, self.data)
    action = self._inverse_kinematics_pass(
      'right',
      base_action,
      obs,
      state,
      target_pos,
      block_axes[2],
      end_effector_displacement=np.array([0.2, 0.0, 0.0]),
      on_target_reached='right-approach-block'
    )
    action.left_gripper = 0
    action.right_gripper = 0.37

    return action

  def _right_approach_block(self, obs: BimanualObs, state: Dict) -> BimanualAction:
    """Move closer: bring a point 0.1 ahead of the right gripper to the block."""
    base_action = self.previous_action

    target_pos, block_axes = get_block_orientation(self.model, self.data)
    action = self._inverse_kinematics_pass(
      'right',
      base_action,
      obs,
      state,
      target_pos,
      block_axes[2],
      end_effector_displacement=np.array([0.1, 0.0, 0.0]),
      on_target_reached='right-grasp-block'
    )
    action.left_gripper = 0
    action.right_gripper = 0.37

    return action

  def _right_grasp_block(self, obs: BimanualObs, state: Dict) -> BimanualAction:
    """Close the right gripper for 10 steps while both grippers hold the block."""
    self._maintain_stage(state, steps=10, then='left-release-block')
    action = self.previous_action
    action.left_gripper = 0
    action.right_gripper = 0
    return action

  def _left_release_block(self, obs: BimanualObs, state: Dict) -> BimanualAction:
    """Open the left gripper (0.37) for 10 steps, leaving the block in the right gripper."""
    self._maintain_stage(state, steps=10, then='right-retract-block')
    action = self.previous_action
    action.left_gripper = 0.37
    action.right_gripper = 0
    return action

  def _right_retract_block(self, obs: BimanualObs, state: Dict) -> BimanualAction:
    """Carry the block away with the right arm to a fixed position."""
    target_pos, gripper_axis = np.array([0.2, 0.0, 0.4]), np.array([0.0, 1.0, 0.0])
    action = self._inverse_kinematics_pass(
      'right',
      self.previous_action,
      obs,
      state,
      target_pos,
      gripper_axis,
      end_effector_displacement=np.array([0.1, 0.0, 0.0]),
      on_target_reached='done'
    )
    action.right_gripper = 0
    return action

  def _done(self, _: BimanualObs, __: Dict) -> BimanualAction:
    """Terminal stage: keep repeating the last action."""
    return self.previous_action

  def _maintain_stage(self, state: Dict, steps: int, then: 'PrivilegedPolicy.Stage'):
    """Count calls in the current stage and switch to ``then`` once more than ``steps`` have elapsed."""
    state['steps'] = state.get('steps', 0) + 1
    if state['steps'] > steps:
      self.policy_stage = then

  def _inverse_kinematics_pass(
    self,
    arm: Literal['left', 'right'],
    base_action: BimanualAction,
    obs: BimanualObs,
    state: Dict,
    target_pos: np.ndarray,
    grasp_axis: np.ndarray,
    on_target_reached: 'PrivilegedPolicy.Stage',
    target_tolerance_distance: float = 0.04,
    end_effector_displacement: np.ndarray | None = None,
    max_step_distance: float = 0.05,
    min_target_height: float = 0.07,
    settling_steps: int = 10,
  ) -> BimanualAction:
    """Take one bounded IK step of ``arm`` towards ``target_pos`` and write it into ``base_action``.

    Steps:
      1. Forward kinematics on the arm's current joint angles gives the position of a virtual
         point offset by ``end_effector_displacement`` from the gripper base.
      2. If that point stays within ``target_tolerance_distance`` of ``target_pos`` for more than
         ``settling_steps`` consecutive calls, ``self.policy_stage`` becomes ``on_target_reached``.
      3. An intermediate target at most ``max_step_distance`` along the straight line towards
         ``target_pos`` is chosen, with its z coordinate clamped to at least
         ``min_target_height``.
      4. A bounded least-squares IK solve finds joint angles that reach the intermediate target
         while keeping the gripper's local x and z axes perpendicular to ``grasp_axis``.

    Args:
      arm: which arm to move.
      base_action: action whose joint entries for ``arm`` are overwritten in place.
      obs: current observation; only ``obs.qpos.array`` is read.
      state: per-stage scratch dict; the 'settling-steps' counter is stored here.
      target_pos: desired position of the virtual end-effector point.
      grasp_axis: direction the gripper's local x and z axes should be perpendicular to.
      on_target_reached: stage to switch to once the target has been held.
      target_tolerance_distance: distance counted as "at the target" (4cm by default).
      end_effector_displacement: offset of the virtual point from the gripper base; zero if None.
      max_step_distance: maximum distance the IK target moves per call (5cm by default).
      min_target_height: lower bound on the IK target's z coordinate.
      settling_steps: consecutive in-tolerance calls required; the stage advances on the next one.

    Returns:
      ``base_action`` (the same object) with the arm's joint targets set.
    """
    state.setdefault('settling-steps', 0)

    # qpos and action index the arms differently (qpos has extra finger entries between arms).
    if arm == 'left':
      arm_joints_obs, arm_joints_action, kinematic_chain = slice(0, 6), slice(0, 6), self.left_kinematic_chain
    else:
      arm_joints_obs, arm_joints_action, kinematic_chain = slice(8, 14), slice(7, 13), self.right_kinematic_chain
    if end_effector_displacement is None:
      end_effector_displacement = np.array([0.0, 0.0, 0.0])

    # Forward kinematics on the current arm joints (fingers excluded).
    joint_angles = obs.qpos.array[arm_joints_obs]
    gripper_pos, _ = augmented_kinematic_forward(kinematic_chain, joint_angles, end_effector_displacement)

    gripper_to_target = target_pos - gripper_pos
    target_distance = np.linalg.norm(gripper_to_target).item()

    # Advance the stage once the target has been held continuously for long enough.
    if target_distance <= target_tolerance_distance:
      state['settling-steps'] += 1
      if state['settling-steps'] > settling_steps:
        self.policy_stage = on_target_reached
    else:
      state['settling-steps'] = 0

    # Step at most max_step_distance towards the target. Guard the exactly-at-target case:
    # normalizing a zero vector would produce a NaN target and make least_squares raise.
    if target_distance > 0.0:
      target_dir = gripper_to_target / target_distance
    else:
      target_dir = np.zeros_like(gripper_to_target)
    constrained_target_pos = gripper_pos + target_dir * min(max_step_distance, target_distance)
    constrained_target_pos[2] = max(constrained_target_pos[2], min_target_height)

    # Solve IK within the joint limits, starting from the current (clipped) joint angles.
    lower_bound = np.array([link.joint_limits[0] for link in kinematic_chain])
    upper_bound = np.array([link.joint_limits[1] for link in kinematic_chain])
    fit_joint_angles = scipy.optimize.least_squares(
      build_grasp_ik_objective(kinematic_chain, end_effector_displacement, constrained_target_pos, grasp_axis),
      np.clip(joint_angles, lower_bound, upper_bound),
      bounds=(lower_bound, upper_bound)
    ).x

    base_action.array[arm_joints_action] = fit_joint_angles
    return base_action


def on_mujoco_init(model: mujoco.MjModel, data: mujoco.MjData):
  """Initial-state hook for BimanualSim: shift ``data.qpos[16:19]`` by (0, 0.1, 0).

  The bare ALOHA model has 16 qpos entries, so indices 16-18 are presumably the position of the
  merged block's free joint — TODO confirm. ``data`` is modified in place; ``model`` is untouched.

  Returns:
    ``(model, data)``.
  """
  shifted = slice(16, 19)
  data.qpos[shifted] += np.array([0.0, 0.1, 0.0])
  return model, data
sim = BimanualSim(merge_xml_files=[Path('block.xml'), Path('indicator.xml')], on_mujoco_init=on_mujoco_init)
policy = PrivilegedPolicy(sim.model, sim.data)

# Run the privileged policy live in MuJoCo's passive viewer until the window is closed.
import mujoco.viewer
prev_time = sim.data.time
with mujoco.viewer.launch_passive(sim.model, sim.data) as viewer:
  while viewer.is_running():
    # Simulation time going backwards means the episode was reset externally (presumably via the
    # viewer's reset control). Reset the sim and rebuild the policy as well. Otherwise its stage,
    # per-stage counters and previous action would carry over from the old episode.
    if sim.data.time < prev_time:
      sim.reset()
      policy = PrivilegedPolicy(sim.model, sim.data)
    prev_time = sim.data.time
    sim.step(policy(sim.get_obs()))
    viewer.sync()

In [33]:
# Right-arm kinematic chain (base_link -> gripper_base), kept as `rk` for the FK check below.
rk = extract_kinematic_info(aloha_mj_description.MJCF_PATH, 'right/base_link', 'right/gripper_base')
rk

[KinematicLink(joint_name='right/waist', joint_limits=(np.float64(-3.14158), np.float64(3.14158)), rotation_axis=(np.float64(0.0), np.float64(0.0), np.float64(1.0)), origin_pos=array([ 0.469, -0.019,  0.099]), quat=(np.float64(0.0), np.float64(0.0), np.float64(1.0), np.float64(0.0))),
 KinematicLink(joint_name='right/shoulder', joint_limits=(np.float64(-1.85005), np.float64(1.25664)), rotation_axis=(np.float64(0.0), np.float64(1.0), np.float64(0.0)), origin_pos=array([0.     , 0.     , 0.04805]), quat=(np.float64(0.0), np.float64(0.0), np.float64(1.0), np.float64(0.0))),
 KinematicLink(joint_name='right/elbow', joint_limits=(np.float64(-1.76278), np.float64(1.6057)), rotation_axis=(np.float64(0.0), np.float64(1.0), np.float64(0.0)), origin_pos=array([0.05955, 0.     , 0.3    ]), quat=(np.float64(0.0), np.float64(0.0), np.float64(1.0), np.float64(0.0))),
 KinematicLink(joint_name='right/forearm_roll', joint_limits=(np.float64(-3.14158), np.float64(3.14158)), rotation_axis=(np.float64(1.

In [15]:
# Left-arm kinematic chain, displayed for comparison with the right arm.
extract_kinematic_info(aloha_mj_description.MJCF_PATH, 'left/base_link', 'left/gripper_base')

[KinematicLink(joint_name='left/waist', joint_limits=(np.float64(-3.14158), np.float64(3.14158)), rotation_axis=(np.float64(0.0), np.float64(0.0), np.float64(1.0)), origin_pos=array([-0.469, -0.019,  0.099]), quat=(np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(1.0))),
 KinematicLink(joint_name='left/shoulder', joint_limits=(np.float64(-1.85005), np.float64(1.25664)), rotation_axis=(np.float64(0.0), np.float64(1.0), np.float64(0.0)), origin_pos=array([0.     , 0.     , 0.04805]), quat=(np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(1.0))),
 KinematicLink(joint_name='left/elbow', joint_limits=(np.float64(-1.76278), np.float64(1.6057)), rotation_axis=(np.float64(0.0), np.float64(1.0), np.float64(0.0)), origin_pos=array([0.05955, 0.     , 0.3    ]), quat=(np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(1.0))),
 KinematicLink(joint_name='left/forearm_roll', joint_limits=(np.float64(-3.14158), np.float64(3.14158)), rotation_axis=(np.float64(1.0), 

In [21]:
# Forward kinematics of the right arm at the all-zero joint configuration: (positions, quaternions) per link.
forward_kinematics(rk, np.zeros(len(rk)))

(array([[ 0.469   , -0.019   ,  0.099   ],
        [ 0.469   , -0.019   ,  0.14705 ],
        [ 0.52855 , -0.019   ,  0.44705 ],
        [ 0.32855 , -0.019   ,  0.44705 ],
        [ 0.42855 , -0.019   ,  0.44705 ],
        [ 0.358806, -0.019   ,  0.44705 ]]),
 array([[ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0., -1.],
        [ 0.,  0., -1., -0.],
        [ 0.,  0.,  0.,  1.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0., -1.]]))

In [148]:
# NOTE(review): duplicate of the on_mujoco_init used by the interactive viewer cell; shifts
# qpos[16:19] (presumably the block's free-joint position) by +0.1 in its second component.
def on_mujoco_init(model: mujoco.MjModel, data: mujoco.MjData):
  data.qpos[16:19] += np.array([0, 0.1, 0])
  return model, data
sim = BimanualSim(merge_xml_files=[Path('block.xml'), Path('indicator.xml')], on_mujoco_init=on_mujoco_init)
policy = PrivilegedPolicy(sim.model, sim.data)
import time
# Headless 400-step rollout of the privileged policy; prints the elapsed wall-clock seconds.
start = time.perf_counter()
sim.rollout(policy, 400)
print(time.perf_counter() - start)

37.57714790001046


In [None]:
# Import from `sim` like every other cell. `robot.sim` would be a different module object (if it
# resolves at all), so its BimanualAction/BimanualObs would not be the classes BimanualSim uses.
from sim import BimanualAction, BimanualObs


def perfect_policy(model: mujoco.MjModel, data: mujoco.MjData, obs: BimanualObs) -> BimanualAction:
  """Unfinished stub for a policy with direct access to the MuJoCo ``model`` and ``data``.

  TODO: implement. It currently returns a default-constructed ``BimanualAction``. Previously it
  fell off the end and returned None, which a caller passing the result to ``sim.step`` cannot
  use. See ``PrivilegedPolicy`` for the scripted implementation.
  """
  action = BimanualAction()
  return action

## BimanualSim Policy Deployment

In [3]:
import os
import numpy as np
import cv2
from IPython.display import Video
from sim import BimanualObs, BimanualSim

image_dims = (480, 640)  # h x w


def toy_policy(obs: BimanualObs) -> np.ndarray:
  """Return joint targets equal to the current qpos entries 0-6 and 8-14, each plus 0.01.

  NOTE(review): indexes ``obs.qpos`` directly like an ndarray, whereas the policy cells above use
  ``obs.qpos.array``. Confirm which matches the current BimanualObs API.
  """
  return obs.qpos[list(range(7)) + list(range(8, 15))] + 0.01

# rollout simulation for 30 steps, recording the left wrist camera before each step
sim = BimanualSim(camera_dims=image_dims, obs_camera_names=['wrist_cam_left'])
left_wrist_frames = []

obs = sim.get_obs()
for sim_step in range(30):
  left_wrist_frames.append(obs.visual[0])
  action = toy_policy(obs)
  obs = sim.step(action)

# save frames to video (OpenCV expects (width, height) and BGR channel order)
os.makedirs('out', exist_ok=True)
video_path = 'out/example-bimanual-rollout.mp4'
video_writer = cv2.VideoWriter(video_path, cv2.VideoWriter.fourcc(*'H264'), 20, tuple(reversed(image_dims)))
# If the codec cannot be opened (H264 availability depends on the OpenCV build), write() silently
# does nothing and leaves an empty/invalid file; fail loudly instead.
if not video_writer.isOpened():
  raise RuntimeError(f'could not open {video_path!r} for writing with the H264 codec')
try:
  for frame in left_wrist_frames:
    video_writer.write(cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR))
finally:
  video_writer.release()
Video(video_path, width=image_dims[1], height=image_dims[0])