In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
from pathlib import Path
# Put the parent of the current working directory on sys.path so that the
# imports in later cells can be resolved from there (presumably that is where
# the `robot` and `policy` packages live — confirm against the repo layout).
sys.path.append(str(Path(os.getcwd()).parent.absolute()))

In [2]:
import mujoco
import mujoco.viewer
from pathlib import Path

import numpy as np
from robot import kinematics
from robot.sim import LEFT_KINEMATIC_CHAIN, RIGHT_KINEMATIC_CHAIN, BimanualSim, randomize_block_position

from policy.privileged_policy import PrivilegedPolicy

# Interactive rollout: run PrivilegedPolicy in a BimanualSim built from the
# default scene merged with block.xml, show it in the passive MuJoCo viewer and
# print, for every step, whether each gripper looks like it is gripping and
# whether the success condition is met. The loop ends when the viewer is closed
# or as soon as `success` becomes True.
with BimanualSim(merge_xml_files=[Path('block.xml')], on_mujoco_init=randomize_block_position) as sim:
  policy = PrivilegedPolicy(sim.model, sim.data)
  
  # Simulation time seen on the previous iteration; used to notice time going backwards.
  prev_time = sim.data.time
  obs = sim.get_obs()
  # Per-gripper stability tracking (no suffix = left gripper, `_r` = right gripper).
  # `recent_min_gripper_deviation*` is the reference deviation the current value is
  # compared against; despite the name, the code below stores the latest
  # out-of-band deviation in it, not a running minimum.
  # `gripper_deviation_stable_for*` counts consecutive iterations spent within
  # +/-0.01 of that reference.
  recent_min_gripper_deviation = 0.0
  recent_min_gripper_deviation_r = 0.0
  gripper_deviation_stable_for = 0
  gripper_deviation_stable_for_r = 0
  with mujoco.viewer.launch_passive(sim.model, sim.data) as viewer:
    while viewer.is_running():
      # Simulation time decreased since the last iteration (presumably the scene
      # was reset from the viewer — TODO confirm), so reset the sim wrapper too.
      if sim.data.time < prev_time:
        sim.reset()
      prev_time = sim.data.time

      # ubiquitous obs+action sim step loop
      action = policy(obs)
      obs = sim.step(action)

      # World position of the body named 'block'.
      block_pos = sim.data.xpos[mujoco.mj_name2id(sim.model, mujoco.mjtObj.mjOBJ_BODY, 'block')]
      # Gripper reference points from forward kinematics. The last two entries of each
      # arm's qpos are dropped (presumably the gripper joints — TODO confirm) and
      # [0.1, 0, 0] is passed as the extra argument of augmented_forward (presumably an
      # offset from the end of the chain to the point between the fingers — TODO confirm).
      gripper_pos, _ = kinematics.augmented_forward(RIGHT_KINEMATIC_CHAIN, obs.qpos.right_arm[:-2], np.array([0.1, 0.0, 0.0]))
      gripper_pos_l, _ = kinematics.augmented_forward(LEFT_KINEMATIC_CHAIN, obs.qpos.left_arm[:-2], np.array([0.1, 0.0, 0.0]))
      # Block must be within 0.05 of the right gripper point and more than 0.2 away
      # from the left one for the step to count as a success.
      block_close_to_gripper = np.linalg.norm(block_pos - gripper_pos).item() < 0.05
      block_far_from_left = np.linalg.norm(block_pos - gripper_pos_l).item() > 0.2

      # Right gripper: deviation = observed gripper value (converted to action space)
      # minus the value that was just commanded. `current_gripper_obs_r` is not used
      # again in this cell.
      current_gripper_obs_r = obs.qpos.to_approximate_action().right_gripper
      gripper_deviation_r = obs.qpos.to_approximate_action().right_gripper - action.right_gripper

      # Left gripper: same quantities. `current_gripper_obs` is not used again in this cell.
      current_gripper_obs = obs.qpos.to_approximate_action().left_gripper
      gripper_deviation = obs.qpos.to_approximate_action().left_gripper - action.left_gripper
      grip = False
      # Left gripper stability check: if the deviation left the +/-0.01 band around the
      # reference, restart the count from this value; otherwise report a grip once the
      # count already exceeds 10 and the deviation is above 0.2, then bump the count.
      if not recent_min_gripper_deviation - 0.01 < gripper_deviation < recent_min_gripper_deviation + 0.01:
        # print('reset', gripper_deviation)
        recent_min_gripper_deviation = gripper_deviation
        gripper_deviation_stable_for = 0
      else:
        if gripper_deviation_stable_for > 10 and gripper_deviation > 0.2:
          grip = True
        gripper_deviation_stable_for += 1
      grip_r = False
      success = False
      # Right gripper: identical stability check. Success additionally requires the
      # block to be close to the right gripper and far from the left gripper.
      if not recent_min_gripper_deviation_r - 0.01 < gripper_deviation_r < recent_min_gripper_deviation_r + 0.01:
        # print('reset', gripper_deviation)
        recent_min_gripper_deviation_r = gripper_deviation_r
        gripper_deviation_stable_for_r = 0
      else:
        if gripper_deviation_stable_for_r > 10 and gripper_deviation_r > 0.2:
          grip_r = True
          if block_close_to_gripper and block_far_from_left:
            success = True
        gripper_deviation_stable_for_r += 1
      # One status line per simulation step.
      print(f'l: {grip}, r: {grip_r}, success: {success} {sim.data.time:.2f}')
      # Push the new simulation state to the viewer window.
      viewer.sync()
      if success:
        break

l: False, r: False, success: False 0.01
l: False, r: False, success: False 0.02
l: False, r: False, success: False 0.03
l: False, r: False, success: False 0.04
l: False, r: False, success: False 0.05
l: False, r: False, success: False 0.06
l: False, r: False, success: False 0.07
l: False, r: False, success: False 0.08
l: False, r: False, success: False 0.09
l: False, r: False, success: False 0.10
l: False, r: False, success: False 0.11
l: False, r: False, success: False 0.12
l: False, r: False, success: False 0.13
l: False, r: False, success: False 0.14
l: False, r: False, success: False 0.15
l: False, r: False, success: False 0.16
l: False, r: False, success: False 0.17
l: False, r: False, success: False 0.18
l: False, r: False, success: False 0.19
l: False, r: False, success: False 0.20
l: False, r: False, success: False 0.21
l: False, r: False, success: False 0.22
l: False, r: False, success: False 0.23
l: False, r: False, success: False 0.24
l: False, r: False, success: False 0.25


In [None]:
from typing import Callable, Literal

from tqdm import tqdm

from robot.sim import BimanualAction, BimanualObs

class GripperTracker:
  """Tracks one arm's gripper error over time and exposes its gripper point.

  The "gripper error" is the observed gripper value (taken from
  `obs.qpos.to_approximate_action()`) minus the gripper value commanded in the
  action. The tracker counts how many consecutive updates that error has stayed
  inside a narrow band around a reference value, so callers can tell a settled
  error apart from a transient one.

  Attributes:
    kinematic_chain: Chain passed to `kinematics.augmented_forward` for this arm.
    get_joint_pos: Callable `obs -> joint positions` for this arm; drops the last
      two qpos entries (presumably the gripper joints — TODO confirm).
    get_gripper_error: Callable `(action, obs) -> observed - commanded` gripper value.
    last_obs: Observation from the most recent `update()`, or None before the first one.
    last_stable_error: Reference error the current error is compared against.
    stability_duration: Consecutive updates spent within the band around the reference.
    stability_threshold: Half-width of the band around `last_stable_error`.
    min_stable_steps: `is_stable()` requires strictly more than this many in-band updates.
    grip_error_threshold: `is_gripping()` requires the reference error to exceed this.
  """

  def __init__(
    self,
    arm: Literal['left', 'right'],
    *,
    stability_threshold: float = 0.01,
    min_stable_steps: int = 10,
    grip_error_threshold: float = 0.2,
  ) -> None:
    """Creates a tracker for the given arm.

    Args:
      arm: Which arm to track, 'left' or 'right'.
      stability_threshold: Half-width of the band the error must stay in to count as stable.
      min_stable_steps: Number of in-band updates that must be exceeded for `is_stable()`.
      grip_error_threshold: Error above which `is_gripping()` reports True.

    Raises:
      ValueError: If `arm` is neither 'left' nor 'right'.
    """
    # Reject typos explicitly instead of silently tracking the right arm.
    if arm not in ('left', 'right'):
      raise ValueError(f"arm must be 'left' or 'right', got {arm!r}")
    if arm == 'left':
      self.kinematic_chain = LEFT_KINEMATIC_CHAIN
      self.get_joint_pos = lambda obs: obs.qpos.left_arm[:-2]
      self.get_gripper_error = lambda action, obs: obs.qpos.to_approximate_action().left_gripper - action.left_gripper
    else:
      self.kinematic_chain = RIGHT_KINEMATIC_CHAIN
      self.get_joint_pos = lambda obs: obs.qpos.right_arm[:-2]
      self.get_gripper_error = lambda action, obs: obs.qpos.to_approximate_action().right_gripper - action.right_gripper
    self.stability_threshold = stability_threshold
    self.min_stable_steps = min_stable_steps
    self.grip_error_threshold = grip_error_threshold
    self.last_obs: BimanualObs | None = None
    self.last_stable_error = 0.0
    self.stability_duration = 0

  def update(self, action: BimanualAction, obs: BimanualObs) -> None:
    """Records `obs` and refreshes the stability bookkeeping.

    Args:
      action: The action whose gripper command is compared against the observation.
      obs: The observation to compare against `action` and to remember for `pos()`.
    """
    self.last_obs = obs
    gripper_error = self.get_gripper_error(action, obs)
    lower = self.last_stable_error - self.stability_threshold
    upper = self.last_stable_error + self.stability_threshold
    if lower < gripper_error < upper:
      # Still inside the band around the reference: one more stable step.
      self.stability_duration += 1
    else:
      # Left the band: the new error becomes the reference and counting restarts.
      self.last_stable_error = gripper_error
      self.stability_duration = 0

  def pos(self) -> np.ndarray:
    """Returns the gripper point computed from the last observation.

    The point is the first element returned by `kinematics.augmented_forward`
    for this arm's chain with the extra argument [0.1, 0, 0].

    Raises:
      RuntimeError: If called before the first `update()`.
    """
    if self.last_obs is None:
      raise RuntimeError('GripperTracker.pos() called before the first update()')
    joint_pos = self.get_joint_pos(self.last_obs)
    return kinematics.augmented_forward(self.kinematic_chain, joint_pos, np.array([0.1, 0.0, 0.0]))[0]

  def is_stable(self) -> bool:
    """True once the error has stayed in band for more than `min_stable_steps` updates."""
    return self.stability_duration > self.min_stable_steps

  def is_gripping(self) -> bool:
    """True when the reference error exceeds `grip_error_threshold`.

    This does not look at stability on its own: right after the error leaves the
    band the reference is simply the latest error. Combine with `is_stable()` to
    ignore transient errors.
    """
    return self.last_stable_error > self.grip_error_threshold

def evaluate_rollout(create_sim: Callable[[], BimanualSim], policy: Callable[[BimanualObs], BimanualAction], max_steps_per_rollout: int) -> bool:
  """Runs one rollout of `policy` and reports whether the success condition was reached.

  A step counts as a success when the right gripper's error has been stable
  and above the gripping threshold (per `GripperTracker.is_stable()` and
  `is_gripping()`), the body named 'block' is within 0.05 of the right gripper
  point and it is more than 0.2 away from the left gripper point.

  Args:
    create_sim: Zero-argument factory returning the simulation; the result is
      used as a context manager for the duration of the rollout.
    policy: Maps the latest observation to the next action.
    max_steps_per_rollout: Maximum number of simulation steps to run.

  Returns:
    True as soon as a step satisfies the success condition, False if the step
    budget is exhausted first.

  Raises:
    ValueError: If the simulation model has no body named 'block'.
  """
  left_gripper_tracker = GripperTracker('left')
  right_gripper_tracker = GripperTracker('right')

  with create_sim() as sim:
    # Resolve the body id once. mj_name2id returns -1 for an unknown name, which
    # would otherwise silently index the last body in `xpos`.
    block_id = mujoco.mj_name2id(sim.model, mujoco.mjtObj.mjOBJ_BODY, 'block')
    if block_id < 0:
      raise ValueError("simulation model has no body named 'block'")

    obs = sim.get_obs()
    for _ in tqdm(range(max_steps_per_rollout)):
      action = policy(obs)
      obs = sim.step(action)

      left_gripper_tracker.update(action, obs)
      right_gripper_tracker.update(action, obs)

      block_pos = sim.data.xpos[block_id]
      right_gripper_distance = np.linalg.norm(block_pos - right_gripper_tracker.pos()).item()
      left_gripper_distance = np.linalg.norm(block_pos - left_gripper_tracker.pos()).item()

      # Require a settled error, not just a large one: while the gripper is still
      # moving towards its commanded value the error can exceed the threshold briefly.
      right_is_holding = right_gripper_tracker.is_stable() and right_gripper_tracker.is_gripping()
      if right_is_holding and right_gripper_distance < 0.05 and left_gripper_distance > 0.2:
        return True
  return False

# Build the block scene and a privileged policy bound to its model/data, then run
# one rollout of at most 600 steps. The factory hands back the already-built
# `sim`, so the policy and the rollout operate on the same simulation instance.
# The rollout result is the cell's last expression and is shown as its output.
sim = BimanualSim(
  merge_xml_files=[Path('block.xml')],
  on_mujoco_init=randomize_block_position,
)
policy = PrivilegedPolicy(sim.model, sim.data)
evaluate_rollout(create_sim=lambda: sim, policy=policy, max_steps_per_rollout=600)

 66%|██████▌   | 397/600 [01:44<00:53,  3.79it/s]


True