In [1]:
import time
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys
from dataclasses import dataclass


sys.path.insert(0, "/root/work")
import bench_buddy
print("Loaded bench_buddy from:", bench_buddy.__file__)

from pydrake.multibody.plant import CoulombFriction
from pydrake.multibody.tree import JointActuatorIndex, FixedOffsetFrame
from pydrake.geometry import Box
from pydrake.systems.framework import LeafSystem, BasicVector
from pydrake.systems.primitives import Adder, LogVectorOutput
from pydrake.systems.sensors import CameraConfig, ApplyCameraConfig
from bar_rgbd_sensor import MeshcatPointCloudPublisher
from pydrake.all import InverseKinematics, Solve

from __future__ import annotations

import math
from typing import Optional, Tuple

import numpy as np

from pydrake.math import RigidTransform, RotationMatrix
from pydrake.perception import BaseField, Fields, PointCloud
from pydrake.common.value import AbstractValue
from pydrake.systems.framework import BasicVector, LeafSystem

from pydrake.systems.sensors import CameraInfo, ImageDepth32F

try:
    from scipy.spatial import cKDTree  # type: ignore

    _HAS_SCIPY = True
except ImportError:  # pragma: no cover - scipy should exist, but guard anyway
    cKDTree = None
    _HAS_SCIPY = False

# Make /root/work importable
sys.path.insert(0, "/root/work")

import bench_buddy
print("Loaded bench_buddy from:", bench_buddy.__file__)

# Drake manipulation helper
from manipulation import ConfigureParser

from pydrake.all import (
    DiagramBuilder,
    AddMultibodyPlantSceneGraph,
    Simulator,
    StartMeshcat,
    MeshcatVisualizer,
    MeshcatVisualizerParams,
    Parser,
    RigidTransform,
    RotationMatrix,
)

# Multibody tree pieces
from pydrake.multibody.tree import (
    FixedOffsetFrame,
    PrismaticJoint,
    SpatialInertia,
    UnitInertia,
    JointActuatorIndex,
)

import numpy as np
from typing import Tuple
from pydrake.all import (
    BasicVector,
    JointActuatorIndex,
    CameraInfo,
    ImageDepth32F,
    RigidTransform,
)
from pydrake.perception import BaseField, Fields, PointCloud


from pydrake.multibody.plant import CoulombFriction

# Geometry + cameras
from pydrake.geometry import (
    Box,  
    ClippingRange,
    ColorRenderCamera,
    DepthRenderCamera,
    DepthRange,
    MakeRenderEngineVtk,
    RenderCameraCore,
    RenderEngineVtkParams,
    Rgba,
)

from pydrake.systems.sensors import CameraInfo, RgbdSensor

from pydrake.all import Meshcat

# Systems framework / utilities
from pydrake.systems.framework import LeafSystem, BasicVector
from pydrake.systems.primitives import Adder, LogVectorOutput

# Global constants used by the builder
SUPPORT_SIZE = [0.08, 0.10, 0.08]  # x, y, z (meters)
SUPPORT_MASS = 5.0
BAR_RADIUS = 0.014
GRAVITY = [0.0, 0.0, -9.81]
PR2_BASE_DEFAULT_OFFSET_X = 1.0  # relative x-offset from the bar center (m)

# Resolve asset paths relative to common launch locations. In a notebook there
# is no __file__, so the search starts from the current working directory.
if "__file__" in globals():
    _root = Path(__file__).resolve().parent
else:
    _root = Path.cwd()
_asset_roots = [
    _root,
    _root / "Bench-Buddy",
    _root.parent,
    _root.parent / "Bench-Buddy",
]
# Last resort: the repository containing the already-imported bench_buddy
# package (<repo>/bench_buddy/__init__.py -> <repo>/assets). This keeps asset
# lookup working when the kernel was started from an unrelated directory.
_bench_buddy_file = getattr(sys.modules.get("bench_buddy"), "__file__", None)
if _bench_buddy_file:
    _asset_roots.append(Path(_bench_buddy_file).resolve().parent.parent)

ASSETS_DIR = None
for candidate_root in _asset_roots:
    candidate_assets = candidate_root / "assets"
    # is_dir(): a stray *file* named "assets" must not be accepted.
    if candidate_assets.is_dir():
        ASSETS_DIR = candidate_assets
        break
if ASSETS_DIR is None:
    raise FileNotFoundError(
        f"Could not locate Bench-Buddy/assets starting from {_root}."
    )
RACK_SDF_PATH = ASSETS_DIR / "rack.sdf"
BAR_SDF_PATH = ASSETS_DIR / "bar.sdf"
LEFT_HAND_SDF_PATH = ASSETS_DIR / "left_hand.sdf"
RIGHT_HAND_SDF_PATH = ASSETS_DIR / "right_hand.sdf"


# Parked arm configurations, ordered like PR2_LEFT_ARM_JOINTS /
# PR2_RIGHT_ARM_JOINTS (radians). Left and right differ only in wrist roll sign.
QL_PARK  = [0.0,  0.9, 0.0, -1.3, 0.0, 0.5,  np.pi/2.0]
# QL_GRASP = [0.0, 0.3, 0.35, -0.9, 0.0, -0.1, np.pi/2]

QR_PARK  = [0.0,  0.9, 0.0, -1.3, 0.0, 0.5, -np.pi/2.0]
# QR_GRASP = [0.0, 0.3, -0.35, -0.9, 0.0, -0.1, -np.pi/2]


PR2_LEFT_ARM_JOINTS = [
    "l_shoulder_pan_joint",
    "l_shoulder_lift_joint",
    "l_upper_arm_roll_joint",
    "l_elbow_flex_joint",
    "l_forearm_roll_joint",
    "l_wrist_flex_joint",
    "l_wrist_roll_joint",
]


PR2_RIGHT_ARM_JOINTS = [
    "r_shoulder_pan_joint",
    "r_shoulder_lift_joint",
    "r_upper_arm_roll_joint",
    "r_elbow_flex_joint",
    "r_forearm_roll_joint",
    "r_wrist_flex_joint",
    "r_wrist_roll_joint",
]

# goal pose for the PR2 at the end: [-0.15, 0.0, 0.75]

# # Your assets directory
# _ASSETS_DIR = Path(bench_buddy.__file__).resolve().parent.parent / "assets"

# # Global constants used by the builder
# SUPPORT_SIZE = [0.08, 0.10, 0.08]  # x, y, z (meters)
# SUPPORT_MASS = 5.0
# BAR_RADIUS = 0.014
# GRAVITY = [0.0, 0.0, -9.81]
# PR2_BASE_DEFAULT_OFFSET_X = 1.0  # relative x-offset from the bar center (m)

# RACK_SDF_PATH = str(_ASSETS_DIR / "rack.sdf")
# BAR_SDF_PATH = str(_ASSETS_DIR / "bar.sdf")
# LEFT_HAND_SDF_PATH = str(_ASSETS_DIR / "left_hand.sdf")
# RIGHT_HAND_SDF_PATH = str(_ASSETS_DIR / "right_hand.sdf")

# QL_PARK  = [0.0,  0.9, 0.0, -1.3, 0.0, 0.5,  np.pi/2.0]
# # QL_GRASP = [0.0, 0.3, 0.35, -0.9, 0.0, -0.1, np.pi/2]

# QR_PARK  = [0.0,  0.9, 0.0, -1.3, 0.0, 0.5, -np.pi/2.0]
# # QR_GRASP = [0.0, 0.3, -0.35, -0.9, 0.0, -0.1, -np.pi/2]


# # goal pose for the PR2 at the end: [-0.15, 0.0, 0.75]



Loaded bench_buddy from: /workspaces/6.4210/Bench-Buddy/bench_buddy/__init__.py


SyntaxError: invalid syntax (bar_rgbd_sensor.py, line 697)

# Bar sensor & PD controllers

- BarHeightSensor is my little tap into the plant for the bar’s true z and ż.
- BarTrackingPDController takes that [z, ż] measurement and drives the two support actuators so the bar follows the scripted lift profile z_ref(t).
- Pr2ArmPDController handles all the arm/base joints with a park→grasp→lift joint-space PD routine (optionally nudged by the sensed bar pose).
- HeadCameraPointCloud back-projects the PR2 head RGB-D camera's depth image, together with the camera pose, into a world-frame point cloud of what the camera sees (including the bar).
- BarPoseFromICP aligns a template bar point set with that observed cloud via ICP to recover the bar’s estimated pose in world coordinates.
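- BarRgbdSensor is the LeafSystem wrapper that chains the head-camera point-cloud builder with the ICP estimator and publishes a perception-based bar position (`bar_xyz` = [x, y, z]). That position can stand in for the ground-truth BarHeightSensor output in the PD controllers. Note that BarHeightSensor publishes [z, ż] instead, so consumers must pick out z accordingly.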

In [None]:
class BarHeightSensor(LeafSystem):
    """
    Ground-truth bar height sensor.

    Input port 0 ("x") takes the plant's full state vector [q; v]. The
    "z_and_zdot" output publishes the world z coordinate of the bar body's
    origin and that origin's world-frame vertical velocity.
    """

    def __init__(self, plant, bar_body):
        super().__init__()
        self._plant = plant
        self._bar_body = bar_body
        # Scratch plant context, overwritten with the incoming state on every
        # evaluation so pose/velocity can be recomputed on demand.
        self._plant_context = plant.CreateDefaultContext()

        self._nq = plant.num_positions()
        self._nv = plant.num_velocities()

        state_size = self._nq + self._nv
        self.DeclareVectorInputPort("x", BasicVector(state_size))
        self.DeclareVectorOutputPort(
            "z_and_zdot", BasicVector(2), self._calc_output
        )

    def _calc_output(self, context, output):
        state = self.get_input_port(0).Eval(context)
        nq = self._nq
        plant = self._plant
        scratch = self._plant_context

        plant.SetPositions(scratch, state[:nq])
        plant.SetVelocities(scratch, state[nq:])

        pose = plant.EvalBodyPoseInWorld(scratch, self._bar_body)
        height = pose.translation()[2]

        twist = plant.EvalBodySpatialVelocityInWorld(scratch, self._bar_body)
        vertical_speed = twist.translational()[2]

        output.SetFromVector([height, vertical_speed])

class BarTrackingPDController(LeafSystem):
    """
    PD loop that drives the two support actuators so the bar tracks z_ref(t).

    Args:
        plant: plant owning the support joints and their actuators.
        left_joint, right_joint: the two support joints; each must be driven
            by an actuator in ``plant``.
        z_ref_fn: callable ``t -> (z_ref, zdot_ref)``.
        kp, kd: gains on the height error and height-rate error.

    Input port "bar_z": [z_bar, zdot_bar].
    Output port "u": plant-sized actuation vector. Only the two support
    actuator entries are non-zero; each gets half of the PD effort.
    """

    def __init__(self, plant, left_joint, right_joint, z_ref_fn,
                 kp=4000.0, kd=800.0):
        super().__init__()
        self._plant = plant
        self._left_joint = left_joint
        self._right_joint = right_joint
        self._z_ref_fn = z_ref_fn
        self._kp = kp
        self._kd = kd

        self._na = plant.num_actuators()

        # Map joint -> actuator index by comparing joint *indices*. Joint names
        # are only unique within a model instance, so a name match could pick
        # up an unrelated joint from another model.
        left_index = left_joint.index()
        right_index = right_joint.index()
        self._left_actuator_index = None
        self._right_actuator_index = None
        for i in range(self._na):
            joint_index = plant.get_joint_actuator(JointActuatorIndex(i)).joint().index()
            if joint_index == left_index:
                self._left_actuator_index = i
            elif joint_index == right_index:
                self._right_actuator_index = i
            if (self._left_actuator_index is not None
                    and self._right_actuator_index is not None):
                break

        assert self._left_actuator_index is not None, "Left support actuator not found"
        assert self._right_actuator_index is not None, "Right support actuator not found"

        # Input is [z_bar, zdot_bar]; output is the plant-sized actuation vector.
        self.DeclareVectorInputPort("bar_z", BasicVector(2))
        self.DeclareVectorOutputPort("u",
                                     BasicVector(self._na),
                                     self._calc_output)

    def _calc_output(self, context, output):
        t = context.get_time()
        z_ref, zdot_ref = self._z_ref_fn(t)

        z_bar, zdot_bar = self.get_input_port(0).Eval(context)

        e = z_ref - z_bar
        edot = zdot_ref - zdot_bar

        # No gravity feed-forward: the height error carries the bar/support weight.
        F_des = self._kp * e + self._kd * edot

        # Split the desired vertical force evenly between left/right supports.
        u = np.zeros(self._na)
        u[self._left_actuator_index] = 0.5 * F_des
        u[self._right_actuator_index] = 0.5 * F_des

        output.SetFromVector(u)

class Pr2ArmPDController(LeafSystem):
    """
    Joint-space PD routine for a named list of PR2 joints (arm, base, whatever).

    Desired positions follow the time-scheduled park -> grasp -> lift profile
    from ``compute_phase_joint_targets``; desired velocities are zero.
    Two optional extras:
      * bar following: one joint (``follow_joint_name``) is re-targeted to
        ``q_park[j] + k_z_to_joint * (z_bar - z_start - clearance)``;
      * gripper: one joint gets its own PD loop toward ``gripper_open_angle``
        before ``gripper_close_time`` (default ``t_grab``) and toward
        ``gripper_closed_angle`` from then on.

    Input ports:
        0 "x": full plant state [q; v].
        1 "bar_xyz": declared only when bar following is enabled, which needs
            ``follow_joint_name`` in ``joint_names`` and a non-None ``z_start``.
    Output port:
        "u": plant-sized actuation vector; entries for actuators this
            controller does not own stay zero.

    Note: ``q_pre`` is stored but not used by the current phase schedule.
    """

    def __init__(
        self,
        plant,
        joint_names,
        t_grab,
        q_park,
        q_pre,
        q_grasp,
        kp=50.0,
        kd=10.0,
        z_start=None,
        clearance=0.05,
        follow_joint_name=None,
        k_z_to_joint=-3.5,
        q_lift=None,
        lift_duration=2.0,
        gripper_joint_name=None,
        gripper_open_angle=3.0,
        gripper_closed_angle=0.02,
        gripper_close_time=None,
        gripper_kp=150.0,
        gripper_kd=5.0,
        approach_duration=2.0,
        grasp_hold_duration=2.0,
    ):
        super().__init__()
        self._plant = plant
        self._nq = plant.num_positions()
        self._nv = plant.num_velocities()
        self._na = plant.num_actuators()

        # Joint-space waypoints, ordered like `joint_names`.
        self._q_park = np.array(q_park, dtype=float)
        self._q_pre = np.array(q_pre, dtype=float)
        self._q_grasp = np.array(q_grasp, dtype=float)
        q_lift_vec = q_lift if q_lift is not None else q_grasp
        self._q_lift = np.array(q_lift_vec, dtype=float)

        self._joint_q_indices = []
        self._joint_v_indices = []
        self._actuator_indices = []
        for name in joint_names:
            joint = plant.GetJointByName(name)
            self._joint_q_indices.append(joint.position_start())
            self._joint_v_indices.append(joint.velocity_start())
            self._actuator_indices.append(self._find_actuator_index(plant, name))

        # Phase schedule; all durations are clamped to be non-negative.
        self._t_grab = t_grab
        self._kp = kp
        self._kd = kd
        self._approach_duration = max(0.0, approach_duration)
        self._grasp_hold_duration = max(0.0, grasp_hold_duration)
        self._lift_duration = max(0.0, lift_duration)
        self._t_pick = max(0.0, self._t_grab - self._approach_duration)
        self._t_reach = self._t_pick + self._approach_duration
        self._t_hold_grasp = self._t_reach + self._grasp_hold_duration
        self._t_lift_done = (
            self._t_hold_grasp + self._lift_duration
            if self._lift_duration > 0.0
            else self._t_hold_grasp
        )

        self.DeclareVectorInputPort("x", BasicVector(self._nq + self._nv))

        # Optional bar-height following for a single joint.
        self._z_start = z_start
        self._clearance = clearance
        self._k_z_to_joint = k_z_to_joint

        self._follow_joint_name = follow_joint_name
        if (
            follow_joint_name is not None
            and follow_joint_name in joint_names
            and z_start is not None
        ):
            self._follow_idx = joint_names.index(follow_joint_name)
            self._use_bar_follow = True
        else:
            self._follow_idx = None
            self._use_bar_follow = False

        if self._use_bar_follow:
            self.DeclareVectorInputPort("bar_xyz", BasicVector(3))

        # Optional independent gripper PD loop.
        self._gripper_open_angle = gripper_open_angle
        self._gripper_closed_angle = gripper_closed_angle
        self._gripper_close_time = (
            gripper_close_time if gripper_close_time is not None else self._t_grab
        )
        self._gripper_kp = gripper_kp
        self._gripper_kd = gripper_kd
        self._gripper_q_index = None
        self._gripper_v_index = None
        self._gripper_u_index = None
        if gripper_joint_name is not None:
            joint = plant.GetJointByName(gripper_joint_name)
            self._gripper_q_index = joint.position_start()
            self._gripper_v_index = joint.velocity_start()
            self._gripper_u_index = self._find_actuator_index(plant, gripper_joint_name)

        self.DeclareVectorOutputPort(
            "u", BasicVector(self._na), self._calc_output
        )

    @staticmethod
    def _find_actuator_index(plant, joint_name):
        """Return the actuator index driving ``joint_name``; AssertionError if none."""
        for i in range(plant.num_actuators()):
            actuator = plant.get_joint_actuator(JointActuatorIndex(i))
            if actuator.joint().name() == joint_name:
                return i
        raise AssertionError(f"No actuator for joint {joint_name}")

    def compute_phase_joint_targets(self, t):
        """
        Return ``(q_des, qd_des)`` for time ``t`` (fresh arrays; safe to mutate).

        Phases:
          - t < t_pick:                        hold q_park
          - t_pick <= t < t_reach:             linear q_park -> q_grasp
          - t_reach <= t < t_hold_grasp:       hold q_grasp
          - t_hold_grasp <= t < t_lift_done:   linear q_grasp -> q_lift
          - otherwise:                         hold q_lift
        qd_des is always zero.
        """
        if t < self._t_pick:
            q_des = np.array(self._q_park, dtype=float)
        elif t < self._t_reach:
            alpha = (t - self._t_pick) / (self._t_reach - self._t_pick)
            q_des = (1.0 - alpha) * self._q_park + alpha * self._q_grasp
        elif t < self._t_hold_grasp:
            q_des = np.array(self._q_grasp, dtype=float)
        elif self._lift_duration > 0.0 and t < self._t_lift_done:
            beta = (t - self._t_hold_grasp) / self._lift_duration
            beta = np.clip(beta, 0.0, 1.0)
            q_des = (1.0 - beta) * self._q_grasp + beta * self._q_lift
        else:
            q_des = np.array(self._q_lift, dtype=float)

        qd_des = np.zeros_like(q_des)
        return q_des, qd_des

    def _desired_joints(self, t):
        """Joint targets at time ``t``; delegates to the phase schedule."""
        return self.compute_phase_joint_targets(t)

    def _read_bar_z(self, context):
        """
        Read the bar height from input port 1.

        Uses index 2 of a >= 3-vector, else index 0. It falls back to
        ``z_start`` when the vector is empty or the value is non-finite:
        BarRgbdSensor publishes NaN when perception fails, and NaN must not
        reach the torques.
        """
        bar_vec = np.asarray(self.get_input_port(1).Eval(context), dtype=float).flatten()
        if bar_vec.size >= 3:
            z_bar = bar_vec[2]
        elif bar_vec.size >= 1:
            z_bar = bar_vec[0]
        else:
            return float(self._z_start)
        if not np.isfinite(z_bar):
            return float(self._z_start)
        return float(z_bar)

    def _calc_output(self, context, output):
        t = context.get_time()
        x = self.get_input_port(0).Eval(context)  # full state [q; v]
        q = x[: self._nq]
        v = x[self._nq :]

        # Targets must exist before the PD loop below consumes them.
        q_des, qd_des = self._desired_joints(t)

        if self._use_bar_follow:
            dz = self._read_bar_z(context) - self._z_start - self._clearance
            base_angle = self._q_park[self._follow_idx]
            q_des[self._follow_idx] = base_angle + self._k_z_to_joint * dz

        u = np.zeros(self._na)
        for k, (q_idx, v_idx, u_idx) in enumerate(
            zip(self._joint_q_indices, self._joint_v_indices, self._actuator_indices)
        ):
            e = q_des[k] - q[q_idx]
            edot = qd_des[k] - v[v_idx]
            u[u_idx] = self._kp * e + self._kd * edot

        if self._gripper_u_index is not None:
            qg = q[self._gripper_q_index]
            vg = v[self._gripper_v_index]
            target = (
                self._gripper_closed_angle
                if t >= self._gripper_close_time
                else self._gripper_open_angle
            )
            # PD toward the open/closed target with zero desired velocity.
            # This overrides the arm loop if the gripper joint is also listed.
            u[self._gripper_u_index] = (
                self._gripper_kp * (target - qg) - self._gripper_kd * vg
            )

        output.SetFromVector(u)

class HeadCameraPointCloud:
    """
    Back-project a depth image into a world-frame point cloud.

    Args:
        camera_info: pinhole intrinsics and image size. Images passed to
            ``build_cloud`` must have exactly this size.
        depth_limits: (min, max) depth in meters. Pixels outside this
            inclusive range, or with non-finite depth, are dropped.
        fields: fields of the returned PointCloud. Only XYZ is filled in.

    Raises:
        ValueError: if ``depth_limits`` is not ``0 < min < max``.
    """

    def __init__(
        self,
        camera_info: CameraInfo,
        depth_limits: Tuple[float, float] = (0.2, 3.5),
        fields: Fields = Fields(BaseField.kXYZs),
    ) -> None:
        self._info = camera_info
        self._min_depth, self._max_depth = depth_limits
        if self._min_depth <= 0.0 or self._max_depth <= self._min_depth:
            raise ValueError("Invalid depth limits for HeadCameraPointCloud.")
        self._fields = fields

        width = int(camera_info.width())
        height = int(camera_info.height())
        # Per-pixel (u, v) coordinate grids, shape (height, width), built once.
        u_coords = np.arange(width, dtype=np.float32)
        v_coords = np.arange(height, dtype=np.float32)
        self._uu, self._vv = np.meshgrid(u_coords, v_coords)
        self._fx = float(camera_info.focal_x())
        self._fy = float(camera_info.focal_y())
        self._cx = float(camera_info.center_x())
        self._cy = float(camera_info.center_y())

    def build_cloud(
        self, depth_image: ImageDepth32F, X_WC: RigidTransform
    ) -> PointCloud:
        """
        Return the world-frame cloud of ``depth_image`` seen from pose ``X_WC``.

        Surviving pixels are back-projected with the pinhole model into the
        camera frame C (x from u, y from v, z = depth) and mapped to world
        with X_WC. If no pixel survives the depth filter, an empty cloud is
        returned.

        Raises:
            ValueError: if the image size differs from the CameraInfo size.
        """
        width = int(depth_image.width())
        height = int(depth_image.height())
        if width != self._uu.shape[1] or height != self._uu.shape[0]:
            raise ValueError("Depth image size does not match CameraInfo.")

        # Take the single depth channel (data presumably is (H, W, 1)).
        # np.asarray returns a view when no conversion is needed. Unlike
        # np.array(..., copy=False), it does not raise under NumPy 2 when a
        # copy turns out to be required.
        depth = np.asarray(depth_image.data[:, :, 0], dtype=np.float32)
        finite_mask = np.isfinite(depth)
        in_range = (depth >= self._min_depth) & (depth <= self._max_depth)
        mask = finite_mask & in_range
        if not np.any(mask):
            return PointCloud(0, self._fields)

        z = depth[mask]
        u = self._uu[mask]
        v = self._vv[mask]
        x = (u - self._cx) / self._fx * z
        y = (v - self._cy) / self._fy * z

        pts_C = np.vstack((x, y, z))  # (3, N) in camera frame
        R_WC = X_WC.rotation().matrix()
        t_WC = X_WC.translation().reshape(3, 1)
        pts_W = (R_WC @ pts_C) + t_WC

        cloud = PointCloud(pts_W.shape[1], self._fields)
        cloud.mutable_xyzs()[:] = pts_W
        return cloud

# Helper functions
- sample_bar_template_points generates a cylindrical point-cloud template of the bar for ICP.
- SimulationParameters is a dataclass that bundles the settings used to run the simulations.

In [None]:

class BarRgbdSensor(LeafSystem):
    """
    Perception pipeline as a LeafSystem: depth image -> world cloud -> ICP.

    Input ports:
        0 "depth_image": ImageDepth32F sized from ``camera_info``.
        1 "camera_pose_WC": RigidTransform X_WC of the depth camera.
    Output ports:
        "bar_xyz": translation of the ICP bar pose estimate. It is NaN in all
            three entries when the cloud is empty or ICP raises RuntimeError.
        "cloud_W": the world-frame point cloud that is handed to ICP.
    """

    def __init__(
        self,
        camera_info: CameraInfo,
        template_points_B: np.ndarray,
        depth_limits: Tuple[float, float] = (0.2, 3.5),
    ):
        super().__init__()
        self._cloud_builder = HeadCameraPointCloud(
            camera_info=camera_info,
            depth_limits=depth_limits,
        )
        self._icp = BarPoseFromICP(template_points_B=template_points_B)

        blank_depth = ImageDepth32F(int(camera_info.width()), int(camera_info.height()))
        self.DeclareAbstractInputPort("depth_image", AbstractValue.Make(blank_depth))
        self.DeclareAbstractInputPort(
            "camera_pose_WC", AbstractValue.Make(RigidTransform())
        )

        self.DeclareVectorOutputPort("bar_xyz", BasicVector(3), self._calc_bar_xyz)
        self.DeclareAbstractOutputPort(
            "cloud_W", self._empty_cloud_value, self._calc_cloud_output
        )

    @staticmethod
    def _empty_cloud_value():
        """Allocator for the "cloud_W" output: an empty XYZ-only cloud."""
        return AbstractValue.Make(PointCloud(0, Fields(BaseField.kXYZs)))

    def _build_cloud(self, context) -> PointCloud:
        depth_port, pose_port = self.get_input_port(0), self.get_input_port(1)
        return self._cloud_builder.build_cloud(
            depth_port.Eval(context), pose_port.Eval(context)
        )

    def _calc_bar_xyz(self, context, output):
        unknown_xyz = [math.nan] * 3
        cloud = self._build_cloud(context)
        if cloud.size() == 0:
            output.SetFromVector(unknown_xyz)
            return
        try:
            output.SetFromVector(self._icp.estimate_pose(cloud).translation())
        except RuntimeError:
            output.SetFromVector(unknown_xyz)

    def _calc_cloud_output(self, context, output):
        output.set_value(self._build_cloud(context))


class MeshcatPointCloudPublisher(LeafSystem):
    """
    Periodically draws the PointCloud on input port "cloud_W" in Meshcat.

    Every ``publish_period`` seconds, starting at t = 0, the cloud is sent to
    ``path`` with point size 0.01 and color ``color``. An empty cloud instead
    deletes whatever is currently drawn at ``path``.
    """

    def __init__(
        self,
        meshcat,
        path: str = "perception/bar_cloud",
        publish_period: float = 1.0 / 30.0,
        color: Rgba = Rgba(1.0, 0.2, 0.7, 1.0),
    ):
        super().__init__()
        self._meshcat, self._path, self._color = meshcat, path, color

        empty_cloud = PointCloud(0, Fields(BaseField.kXYZs))
        self.DeclareAbstractInputPort("cloud_W", AbstractValue.Make(empty_cloud))
        self.DeclarePeriodicPublishEvent(
            period_sec=publish_period,
            offset_sec=0.0,
            publish=self._do_publish,
        )

    def _do_publish(self, context):
        cloud = self.get_input_port(0).Eval(context)
        if cloud.size() > 0:
            self._meshcat.SetObject(
                self._path, cloud, point_size=0.01, rgba=self._color
            )
        else:
            self._meshcat.Delete(self._path)


class Pr2PickAndLiftTrajectory(LeafSystem):
    """
    High-level pick-and-lift gripper trajectory generator driven by a bar pose.

    Planning happens once, on the first output evaluation at which the "X_WB"
    input (port 1) holds a RigidTransform. It produces a gripper pose
    trajectory X_WG(t) and a gripper command trajectory, both stored in
    abstract state:

      - [0, t0]:  hold the pre-grasp pose X_WB * [approach_offset]
      - [t0, t1]: move linearly to the pick pose X_WB
      - [t1, t2]: lift by ``lift_height`` along the bar frame's z axis
      - [t2, t3]: hold the lifted pose

    with t1 = t_grab, t0 = max(0, t1 - approach_duration),
    t2 = t1 + lift_duration and t3 = t2 + hold_duration.

    "wsg_position" is 0.09 until t1 - 0.2, ramps linearly to 0.02 by
    t1 + 0.1, and holds 0.02 afterwards. Both outputs hold their end values
    outside the planned time span. Before planning they output identity and
    0.0 respectively.

    Knot times are clamped to t >= 0, and coinciding knots are merged, so
    degenerate timings (e.g. t_grab <= approach_duration) do not make the
    trajectory constructors throw.
    """

    def __init__(
        self,
        plant: MultibodyPlant,
        pr2_model,
        approach_offset: np.ndarray = np.array([0.0, 0.2, 0.1]),
        lift_height: float = 0.35,
        t_grab: float = 5.0,
        approach_duration: float = 2.0,
        lift_duration: float = 2.0,
        hold_duration: float = 2.0,
    ):
        super().__init__()
        self._plant = plant
        self._pr2_model = pr2_model
        # np.array (not np.asarray) so each instance owns a copy instead of
        # aliasing the shared default array or the caller's array.
        self._approach_offset = np.array(approach_offset, dtype=float)
        self._lift_height = lift_height
        self._t_grab = t_grab
        # Negative durations would produce non-monotonic knot times.
        self._approach_duration = max(0.0, approach_duration)
        self._lift_duration = max(0.0, lift_duration)
        self._hold_duration = max(0.0, hold_duration)

        # NOTE(review): "body_poses" is declared but never read by this system.
        self.DeclareAbstractInputPort("body_poses", AbstractValue.Make([]))
        self.DeclareAbstractInputPort("X_WB", AbstractValue.Make(RigidTransform()))

        self._traj_index = self.DeclareAbstractState(
            AbstractValue.Make({"X_WG_traj": None, "wsg_traj": None})
        )

        self.DeclareAbstractOutputPort(
            "X_WG",
            lambda: AbstractValue.Make(RigidTransform()),
            self._calc_X_WG_output,
        )
        self.DeclareVectorOutputPort(
            "wsg_position",
            BasicVector(1),
            self._calc_wsg_output,
        )

    @staticmethod
    def _strictly_increasing_knots(times, values):
        """
        Clamp knot times to >= 0 and merge non-increasing knots (later wins).

        Drake piecewise trajectories need strictly increasing breaks and, as
        far as I know, at least two of them. If everything collapses onto a
        single instant, a constant segment is appended.
        """
        out_times, out_values = [], []
        for t, value in zip(times, values):
            t = max(0.0, float(t))
            if out_times and t <= out_times[-1]:
                out_values[-1] = value
            else:
                out_times.append(t)
                out_values.append(value)
        if len(out_times) < 2:
            out_times.append(out_times[-1] + 1.0)
            out_values.append(out_values[-1])
        return out_times, out_values

    def _plan_if_needed(self, context):
        # NOTE(review): this mutates abstract state from inside an output calc,
        # which Drake does not sanction. It works only because the state holds
        # a plain Python dict that is updated in place. An initialization or
        # unrestricted-update event would be the supported mechanism.
        traj_state = context.get_abstract_state(self._traj_index).get_value()
        if traj_state["X_WG_traj"] is not None:
            return

        X_WB = self.get_input_port(1).Eval(context)
        if not isinstance(X_WB, RigidTransform):
            return

        # Local import: this cell must not rely on another cell having
        # imported the trajectory classes.
        from pydrake.trajectories import PiecewisePolynomial, PiecewisePose

        X_pre = X_WB @ RigidTransform(self._approach_offset)
        X_pick = X_WB
        # NOTE(review): right-multiplying lifts along the bar frame's z axis,
        # not world z. Confirm that is intended.
        X_lift = X_pick @ RigidTransform([0.0, 0.0, self._lift_height])

        t1 = self._t_grab
        t0 = max(0.0, t1 - self._approach_duration)
        t2 = t1 + self._lift_duration
        t3 = t2 + self._hold_duration

        pose_times, poses = self._strictly_increasing_knots(
            [0.0, t0, t1, t2, t3], [X_pre, X_pre, X_pick, X_lift, X_lift]
        )
        wsg_times, wsg_values = self._strictly_increasing_knots(
            [0.0, t1 - 0.2, t1 + 0.1, t3], [0.09, 0.09, 0.02, 0.02]
        )

        traj_state["X_WG_traj"] = PiecewisePose.MakeLinear(pose_times, poses)
        traj_state["wsg_traj"] = PiecewisePolynomial.FirstOrderHold(
            wsg_times, np.array([wsg_values])
        )

    def _calc_X_WG_output(self, context, output):
        self._plan_if_needed(context)
        traj_state = context.get_abstract_state(self._traj_index).get_value()
        traj = traj_state["X_WG_traj"]
        if traj is None:
            output.set_value(RigidTransform())
            return
        t = np.clip(context.get_time(), traj.start_time(), traj.end_time())
        output.set_value(traj.GetPose(t))

    def _calc_wsg_output(self, context, output):
        self._plan_if_needed(context)
        traj_state = context.get_abstract_state(self._traj_index).get_value()
        wsg = traj_state["wsg_traj"]
        if wsg is None:
            output.SetAtIndex(0, 0.0)
            return
        t = np.clip(context.get_time(), wsg.start_time(), wsg.end_time())
        # value() returns a 1x1 matrix. Index it explicitly: float() on an
        # ndim > 0 array is deprecated in NumPy.
        output.SetAtIndex(0, float(np.asarray(wsg.value(t))[0, 0]))


def add_bar_and_supports(
    plant,
    parser,
    params: SimulationParameters,
    left_xy=(0.0, 0.20),
    right_xy=(0.0, -0.20),
    support_size=(0.08, 0.10, 0.08),
    bar_radius=0.014,
    support_mass=5.0,
):
    """
    Add the bar and two actuated vertical support lifts to ``plant``.

    The bar is loaded from ``params.bar_sdf_path``. Its body "bar" gets a
    default free-body pose at (0, 0, ``params.z_start``) with identity
    rotation.

    Each support is a box of ``support_size`` and mass ``support_mass``, with
    visual and collision geometry, in its own model instance. It rides on a
    prismatic joint along world z whose parent frame sits at
    (``*_xy[0]``, ``*_xy[1]``, 0). The default joint translation puts the box
    top at ``z_start - bar_radius``, directly under the bar. Each joint gets
    an actuator ("left_support_actuator" / "right_support_actuator").

    Returns:
        (bar_body, left_joint, right_joint, handles), where ``handles`` also
        holds the support bodies, parent frames and ``bar_start_pos``.
    """
    left_xy = np.asarray(left_xy, dtype=float)
    right_xy = np.asarray(right_xy, dtype=float)
    half_height = support_size[2] / 2.0
    # Box center such that its top face touches the bottom of the bar.
    support_center_z = params.z_start - bar_radius - half_height

    left_parent_pose = RigidTransform(RotationMatrix.Identity(), [left_xy[0], left_xy[1], 0.0])
    right_parent_pose = RigidTransform(RotationMatrix.Identity(), [right_xy[0], right_xy[1], 0.0])
    left_parent_frame = plant.AddFrame(
        FixedOffsetFrame("left_support_parent", plant.world_frame(), left_parent_pose)
    )
    right_parent_frame = plant.AddFrame(
        FixedOffsetFrame("right_support_parent", plant.world_frame(), right_parent_pose)
    )

    bar_model = parser.AddModels(params.bar_sdf_path)[0]
    bar_body = plant.GetBodyByName("bar", bar_model)
    bar_start_pos = np.array([0.0, 0.0, params.z_start])
    plant.SetDefaultFreeBodyPose(bar_body, RigidTransform(RotationMatrix.Identity(), bar_start_pos))

    G_SP = UnitInertia.SolidBox(*support_size)
    M_SP = SpatialInertia(support_mass, [0.0, 0.0, 0.0], G_SP)
    left_model = plant.AddModelInstance("left_support_model")
    right_model = plant.AddModelInstance("right_support_model")
    left_body = plant.AddRigidBody("left_support", left_model, M_SP)
    right_body = plant.AddRigidBody("right_support", right_model, M_SP)

    box = Box(*support_size)
    friction = CoulombFriction(0.9, 0.8)
    plant.RegisterVisualGeometry(left_body, RigidTransform(), box, "left_support_vis", [0.2, 0.2, 1.0, 1.0])
    plant.RegisterCollisionGeometry(left_body, RigidTransform(), box, "left_support_col", friction)
    plant.RegisterVisualGeometry(right_body, RigidTransform(), box, "right_support_vis", [0.2, 1.0, 0.2, 1.0])
    plant.RegisterCollisionGeometry(right_body, RigidTransform(), box, "right_support_col", friction)

    left_joint = plant.AddJoint(
        PrismaticJoint(
            "left_support_z",
            left_parent_frame,
            left_body.body_frame(),
            np.array([0.0, 0.0, 1.0]),
        )
    )
    left_joint.set_default_translation(support_center_z)

    right_joint = plant.AddJoint(
        PrismaticJoint(
            "right_support_z",
            right_parent_frame,
            right_body.body_frame(),
            np.array([0.0, 0.0, 1.0]),
        )
    )
    right_joint.set_default_translation(support_center_z)

    plant.AddJointActuator("left_support_actuator", left_joint)
    plant.AddJointActuator("right_support_actuator", right_joint)

    handles = {
        "left_body": left_body,
        "right_body": right_body,
        "left_joint": left_joint,
        "right_joint": right_joint,
        "left_parent_frame": left_parent_frame,
        "right_parent_frame": right_parent_frame,
        "bar_start_pos": bar_start_pos,
    }
    return bar_body, left_joint, right_joint, handles


def add_pr2_model(plant, parser, bar_start_pos: np.ndarray) -> tuple[int, dict]:
    """
    Load the simplified PR2 and set its default configuration around the bar.

    Defaults set here:
      - planar base: x = bar x + PR2_BASE_DEFAULT_OFFSET_X, y = bar y,
        theta = pi (presumably so the robot faces the bar; TODO confirm);
      - "l_/r_gripper_l_finger_joint" angles = 3.0 (opened);
      - wrist roll: +pi/2 on the left arm, -pi/2 on the right arm.

    Returns:
        (pr2_model, handles): the model instance plus a dict of the joints
        above and the two gripper tool frames.
    """
    model = parser.AddModelsFromUrl(
        "package://drake_models/pr2_description/urdf/pr2_simplified.urdf"
    )[0]

    def joint(name):
        return plant.GetJointByName(name, model)

    base_x, base_y, base_theta = joint("x"), joint("y"), joint("theta")
    base_x.set_default_translation(bar_start_pos[0] + PR2_BASE_DEFAULT_OFFSET_X)
    base_y.set_default_translation(bar_start_pos[1])
    base_theta.set_default_angle(np.pi)

    gripper_open = 3.0
    l_gripper = joint("l_gripper_l_finger_joint")
    r_gripper = joint("r_gripper_l_finger_joint")
    for finger in (l_gripper, r_gripper):
        finger.set_default_angle(gripper_open)

    l_wrist_roll = joint("l_wrist_roll_joint")
    r_wrist_roll = joint("r_wrist_roll_joint")
    half_turn = np.pi / 2.0
    l_wrist_roll.set_default_angle(half_turn)
    r_wrist_roll.set_default_angle(-half_turn)

    return model, {
        "base_x": base_x,
        "base_y": base_y,
        "base_theta": base_theta,
        "l_gripper_joint": l_gripper,
        "r_gripper_joint": r_gripper,
        "l_wrist_roll_joint": l_wrist_roll,
        "r_wrist_roll_joint": r_wrist_roll,
        "l_gripper_frame": plant.GetFrameByName("l_gripper_tool_frame", model),
        "r_gripper_frame": plant.GetFrameByName("r_gripper_tool_frame", model),
    }


def sample_bar_template_points(
    radius: float,
    half_length: float,
    axial_samples: int = 50,
    angular_samples: int = 32,
) -> np.ndarray:
    """
    Generate a cylindrical surface point cloud of the bar (ICP template).

    The cylinder axis is the template frame's z axis. Points lie on the
    lateral surface only (no end caps):
      - ``axial_samples`` evenly spaced heights in [-half_length, half_length],
        with both endpoints included;
      - ``angular_samples`` evenly spaced angles in [0, 2*pi), with the
        endpoint excluded.

    Returns:
        Array of shape (axial_samples * angular_samples, 3). Rows are ordered
        with z varying slowest and the angle fastest.

    Raises:
        ValueError: if either sample count is < 1.
    """
    if axial_samples < 1 or angular_samples < 1:
        raise ValueError(
            "axial_samples and angular_samples must both be >= 1, got "
            f"{axial_samples} and {angular_samples}."
        )
    z_vals = np.linspace(-half_length, half_length, axial_samples)
    thetas = np.linspace(0.0, 2.0 * np.pi, angular_samples, endpoint=False)
    # The ring is identical at every height: compute it once, then tile it
    # across heights (angle fastest) and repeat each height once per angle.
    ring_x = radius * np.cos(thetas)
    ring_y = radius * np.sin(thetas)
    x = np.tile(ring_x, axial_samples)
    y = np.tile(ring_y, axial_samples)
    z = np.repeat(z_vals, angular_samples)
    return np.stack([x, y, z], axis=1)


def compute_symmetric_bar_grasp(
    plant,
    p_center_W,
    half_width: float = 0.20,
    allow_base_motion: bool = True,
    left_gripper_frame: str = "l_gripper_tool_frame",
    right_gripper_frame: str = "r_gripper_tool_frame",
    position_tol: float = 2e-3,
):
    """Solve IK for symmetric left/right gripper targets around the bar midpoint.

    The left target is ``p_center_W + [0, half_width, 0]`` and the right target
    is ``p_center_W - [0, half_width, 0]`` (world frame). Both gripper
    constraints are placed in ONE InverseKinematics program. As a result, the
    returned left-arm angles, right-arm angles and base pose all come from the
    same whole-robot configuration.

    Args:
        plant: MultibodyPlant containing the PR2; assumed already finalized.
        p_center_W: Bar midpoint in world coordinates (reshaped to 3 floats).
        half_width: Offset of each gripper target from the midpoint along
            world y.
        allow_base_motion: If True, base joints "x"/"y" may move by up to 1.0,
            and "theta" by up to 0.75, from their default values. If False,
            they are pinned to the defaults.
        left_gripper_frame: Name of the left gripper frame to constrain.
        right_gripper_frame: Name of the right gripper frame to constrain.
        position_tol: Half-width, per axis, of the box around each target.

    Returns:
        ``(qL, qR, base_pose)``:
        - ``qL``: list of floats for PR2_LEFT_ARM_JOINTS.
        - ``qR``: list of floats for PR2_RIGHT_ARM_JOINTS.
        - ``base_pose``: dict mapping "x", "y" and "theta" to floats.
        Returns ``(None, None, None)`` if the solver fails.
    """
    p_center_W = np.array(p_center_W, dtype=float).reshape(3)
    lateral = np.array([0.0, half_width, 0.0])
    gripper_targets = (
        (left_gripper_frame, p_center_W + lateral),
        (right_gripper_frame, p_center_W - lateral),
    )

    q_nominal = plant.GetPositions(plant.CreateDefaultContext())
    base_joint_names = ("x", "y", "theta")

    def _position_index(joint_name):
        # Start index of the named joint's positions in the plant's q vector.
        return plant.GetJointByName(joint_name).position_start()

    ik = InverseKinematics(plant, plant.CreateDefaultContext())
    q = ik.q()
    prog = ik.prog()
    # Prefer solutions near the default posture, and start the solver there.
    prog.AddQuadraticErrorCost(np.eye(plant.num_positions()), q_nominal, q)
    prog.SetInitialGuess(q, q_nominal)

    for name in base_joint_names:
        idx = _position_index(name)
        q0 = q_nominal[idx]
        if allow_base_motion:
            slack = 0.75 if name == "theta" else 1.0
            prog.AddBoundingBoxConstraint(q0 - slack, q0 + slack, q[idx])
        else:
            prog.AddBoundingBoxConstraint(q0, q0, q[idx])

    world = plant.world_frame()
    for frame_name, target in gripper_targets:
        # Keep the gripper frame origin inside a small box around its target.
        ik.AddPositionConstraint(
            frameA=world,
            frameB=plant.GetFrameByName(frame_name),
            p_BQ=np.zeros(3),
            p_AQ_lower=target - position_tol,
            p_AQ_upper=target + position_tol,
        )

    result = Solve(prog)
    if not result.is_success():
        return None, None, None
    q_sol = result.GetSolution(q)

    qL = [float(q_sol[_position_index(name)]) for name in PR2_LEFT_ARM_JOINTS]
    qR = [float(q_sol[_position_index(name)]) for name in PR2_RIGHT_ARM_JOINTS]
    base_pose = {
        name: float(q_sol[_position_index(name)]) for name in base_joint_names
    }
    return qL, qR, base_pose


# Bar Catch Simulation

In [None]:
class BarCatchSimulation:
    """
    Builds the full diagram (plant + sensors + controllers) and provides a
    helper `run` method for executing and plotting the result.

    Plant actuation is the sum of three sources combined by a 3-input
    ``Adder``:
    - input 0: the bar-support PD controller;
    - inputs 1 and 2: the left and right PR2 arm controllers.

    The arm controllers are only built as part of the RGB-D pipeline. When
    ``use_rgbd_sensor`` is False, inputs 1 and 2 are instead fed a constant
    zero vector, so no Adder input is left unconnected.
    """

    # NOTE(review): SUPPORT_SIZE is not read anywhere in this class;
    # presumably the (x, y, z) extents of the bar supports [m]. Confirm
    # against add_bar_and_supports / SimulationParameters.
    SUPPORT_SIZE = np.array([0.08, 0.10, 0.08])
    # Bar radius [m]; used to sample the cylindrical ICP template.
    BAR_RADIUS = 0.014

    def __init__(
        self,
        z_start=0.35,
        z_target=0.85,
        t_hold_bottom=1.0,
        t_move=4.0,
        t_hold_top=1.0,
        bar_sdf_path=str(BAR_SDF_PATH),
        use_rgbd_sensor: bool = True,
    ):
        """Build the plant, controllers, optional RGB-D pipeline and simulator.

        Args:
            z_start: Low bar height of the reference trajectory [m].
            z_target: High bar height of the reference trajectory [m].
            t_hold_bottom: Duration of the initial hold at z_start [s].
            t_move: Duration of each ramp (up and down) [s].
            t_hold_top: Duration of the hold at z_target [s].
            bar_sdf_path: Path to the bar model file.
            use_rgbd_sensor: If True, add the overhead RGB-D camera, the bar
                perception system and the PR2 arm controllers.
        """
        self.params = SimulationParameters(
            z_start=z_start,
            z_target=z_target,
            t_hold_bottom=t_hold_bottom,
            t_move=t_move,
            t_hold_top=t_hold_top,
            bar_sdf_path=bar_sdf_path,
        )
        self.meshcat = StartMeshcat()
        self.builder = DiagramBuilder()
        self.plant, self.scene_graph = AddMultibodyPlantSceneGraph(
            self.builder, time_step=0.001
        )

        # The RGB-D camera renders through this named VTK renderer.
        self.scene_graph.AddRenderer(
            "renderer", MakeRenderEngineVtk(RenderEngineVtkParams())
        )
        self.parser = Parser(self.plant)

        self.bar_body, self.left_joint, self.right_joint, support_handles = add_bar_and_supports(
            self.plant, self.parser, self.params
        )
        self.bar_start_pos = support_handles["bar_start_pos"]

        self.pr2_model, self.pr2_handles = add_pr2_model(
            self.plant, self.parser, self.bar_start_pos
        )

        # Penetration allowance [m]: Drake uses it to estimate default
        # penalty-contact stiffness (smaller => stiffer, less interpenetration).
        self.plant.set_penetration_allowance(1e-3)
        self.plant.mutable_gravity_field().set_gravity_vector([0.0, 0.0, -9.81])
        self._configure_reference_trajectory()
        self.plant.Finalize()
        self._add_ground_truth_control()
        if use_rgbd_sensor:
            self._add_rgbd_pipeline()
        else:
            # The arm controllers live in the RGB-D pipeline. Without them,
            # Adder inputs 1 and 2 must still be driven, otherwise evaluating
            # the plant's actuation input fails.
            self._connect_zero_arm_actuation()
            self.bar_pose_logger = None
        self._add_logging_and_visualization()
        self._finalize_simulator()

    # ------------------------------------------------------------------
    # Trajectory + controller wiring
    # ------------------------------------------------------------------
    def _configure_reference_trajectory(self):
        """Define the piecewise-linear bar-height reference and grab time.

        Phases:
        - [0, t1): hold at z_start;
        - [t1, t2): ramp up to z_target;
        - [t2, t3): hold at z_target;
        - [t3, t4): ramp back down;
        - t >= t4: hold at z_start.

        ``t_grab`` is the midpoint of the downward ramp.
        """
        p = self.params
        self.t1 = p.t_hold_bottom
        self.t2 = self.t1 + p.t_move
        self.t3 = self.t2 + p.t_hold_top
        self.t4 = self.t3 + p.t_move

        def z_ref_fn(t: float) -> Tuple[float, float]:
            # NOTE(review): the second element is always 0.0, even during the
            # ramps. It is presumably a velocity reference; confirm against
            # BarTrackingPDController.
            if t < self.t1:
                z = p.z_start
            elif t < self.t2:
                alpha = (t - self.t1) / (p.t_move)
                z = p.z_start + alpha * (p.z_target - p.z_start)
            elif t < self.t3:
                z = p.z_target
            elif t < self.t4:
                alpha = (t - self.t3) / (p.t_move)
                z = p.z_target + (p.z_start - p.z_target) * alpha
            else:
                z = p.z_start
            return z, 0.0

        self.z_ref_fn = z_ref_fn
        self.t_grab = 0.5 * (self.t3 + self.t4)

    def _add_ground_truth_control(self):
        """Wire the true bar-height sensor into the support PD controller.

        Also creates the 3-input actuation ``Adder`` that feeds the plant and
        connects the support controller to its input 0.
        """
        self.bar_sensor = self.builder.AddSystem(BarHeightSensor(self.plant, self.bar_body))
        self.builder.Connect(
            self.plant.get_state_output_port(), self.bar_sensor.get_input_port(0)
        )

        self.bar_controller = self.builder.AddSystem(
            BarTrackingPDController(
                plant=self.plant,
                left_joint=self.left_joint,
                right_joint=self.right_joint,
                z_ref_fn=self.z_ref_fn,
                kp=4000.0,
                kd=800.0,
            )
        )
        self.builder.Connect(
            self.bar_sensor.get_output_port(0), self.bar_controller.get_input_port(0)
        )

        # Input 0: support controller; inputs 1/2: left/right arm controllers.
        self.actuation_adder = self.builder.AddSystem(
            Adder(3, self.plant.num_actuators())
        )
        self.builder.Connect(
            self.bar_controller.get_output_port(0),
            self.actuation_adder.get_input_port(0),
        )
        self.builder.Connect(
            self.actuation_adder.get_output_port(0),
            self.plant.get_actuation_input_port(),
        )

    def _connect_zero_arm_actuation(self):
        """Drive the arm-controller Adder inputs (1 and 2) with zeros.

        Used when the RGB-D pipeline, which builds the arm controllers, is
        disabled.
        """
        from pydrake.systems.primitives import ConstantVectorSource

        zero_source = self.builder.AddSystem(
            ConstantVectorSource(np.zeros(self.plant.num_actuators()))
        )
        # A single output port may fan out to several input ports.
        for port_index in (1, 2):
            self.builder.Connect(
                zero_source.get_output_port(0),
                self.actuation_adder.get_input_port(port_index),
            )

    # ------------------------------------------------------------------
    # RGB-D perception stack
    # ------------------------------------------------------------------

    def _add_rgbd_pipeline(self):
        """Add the overhead RGB-D camera, bar perception and PR2 arm control.

        Steps:
        1. Render a depth image with an overhead camera.
        2. Estimate the bar pose from the depth image with ``BarRgbdSensor``.
        3. Publish that system's cloud to Meshcat.
        4. Add left/right ``Pr2ArmPDController``s whose grasp postures come
           from an IK solve at the bar's start position. Their outputs feed
           Adder inputs 1 and 2.
        5. Log ``BarRgbdSensor`` output 0 as ``bar_pose_logger``.
        """
        cam_info = CameraInfo(width=320, height=240, fov_y=np.deg2rad(60.0))
        camera_pose = RigidTransform(
            RotationMatrix.MakeXRotation(np.pi),  # forward -> world -Z
            [0.5, 0.0, 1.2],
        )

        clipping = ClippingRange(0.1, 5.0)
        core = RenderCameraCore(
            renderer_name="renderer",
            intrinsics=cam_info,
            clipping=clipping,
            X_BS=RigidTransform(),
        )
        color_camera = ColorRenderCamera(core, show_window=False)
        depth_camera = DepthRenderCamera(core, DepthRange(0.1, 5.0))

        self.rgbd_sensor = self.builder.AddSystem(
            RgbdSensor(
                parent_id=self.scene_graph.world_frame_id(),
                X_PB=camera_pose,
                color_camera=color_camera,
                depth_camera=depth_camera,
            )
        )
        self.rgbd_sensor.set_name("overhead_rgbd")
        self.builder.Connect(
            self.scene_graph.get_query_output_port(),
            self.rgbd_sensor.query_object_input_port(),
        )

        template_points = sample_bar_template_points(
            radius=self.BAR_RADIUS, half_length=0.45
        )
        self.bar_rgbd_sensor = self.builder.AddSystem(
            BarRgbdSensor(camera_info=cam_info, template_points_B=template_points)
        )

        cloud_viz = self.builder.AddSystem(
            MeshcatPointCloudPublisher(
                meshcat=self.meshcat,
                path="perception/bar_cloud",
                publish_period=1.0 / 30.0,
                color=Rgba(1.0, 0.2, 0.7, 1.0),
            )
        )
        self.builder.Connect(
            self.bar_rgbd_sensor.GetOutputPort("cloud_W"),
            cloud_viz.get_input_port(0),
        )

        bar_midpoint_port = self.bar_rgbd_sensor.get_output_port(0)

        # NOTE(review): base_pose is computed but not used, and the base is
        # not commanded to it; the arm targets assume the IK's base pose.
        # On IK failure the grasp postures are None; confirm that
        # Pr2ArmPDController accepts that.
        grasp_qL, grasp_qR, base_pose = compute_symmetric_bar_grasp(
            plant=self.plant,
            p_center_W=self.bar_start_pos,
            half_width=0.02,
            allow_base_motion=True,
        )

        left_ctrl = self.builder.AddSystem(
            Pr2ArmPDController(
                plant=self.plant,
                joint_names=PR2_LEFT_ARM_JOINTS,
                t_grab=self.t_grab,
                q_park=QL_PARK,
                q_pre=QL_PARK,
                q_grasp=grasp_qL,
                q_lift=grasp_qL,
                z_start=self.params.z_start,
                follow_joint_name="l_shoulder_lift_joint",
                gripper_joint_name="l_gripper_l_finger_joint",
                gripper_close_time=self.t_grab,
            )
        )

        right_ctrl = self.builder.AddSystem(
            Pr2ArmPDController(
                plant=self.plant,
                joint_names=PR2_RIGHT_ARM_JOINTS,
                t_grab=self.t_grab,
                q_park=QR_PARK,
                q_pre=QR_PARK,
                q_grasp=grasp_qR,
                q_lift=grasp_qR,
                z_start=self.params.z_start,
                follow_joint_name="r_shoulder_lift_joint",
                gripper_joint_name="r_gripper_l_finger_joint",
                gripper_close_time=self.t_grab,
            )
        )

        # Perception inputs: 0 = depth image, 1 = camera body pose in world.
        self.builder.Connect(
            self.rgbd_sensor.depth_image_32F_output_port(),
            self.bar_rgbd_sensor.get_input_port(0),
        )
        self.builder.Connect(
            self.rgbd_sensor.body_pose_in_world_output_port(),
            self.bar_rgbd_sensor.get_input_port(1),
        )

        self.builder.Connect(
            self.plant.get_state_output_port(), left_ctrl.get_input_port(0)
        )
        # Controllers with a second input port also receive the perceived bar.
        if left_ctrl.num_input_ports() > 1:
            self.builder.Connect(bar_midpoint_port, left_ctrl.get_input_port(1))
        self.builder.Connect(
            left_ctrl.get_output_port(0),
            self.actuation_adder.get_input_port(1),
        )

        self.builder.Connect(
            self.plant.get_state_output_port(), right_ctrl.get_input_port(0)
        )
        if right_ctrl.num_input_ports() > 1:
            self.builder.Connect(bar_midpoint_port, right_ctrl.get_input_port(1))
        self.builder.Connect(
            right_ctrl.get_output_port(0),
            self.actuation_adder.get_input_port(2),
        )

        self.bar_pose_logger = LogVectorOutput(
            self.bar_rgbd_sensor.get_output_port(0), self.builder
        )
        self.bar_pose_logger.set_name("bar_pose_rgbd")

    # ------------------------------------------------------------------
    # Logging + viz + simulator
    # ------------------------------------------------------------------
    def _add_logging_and_visualization(self):
        """Log the true bar height and attach the Meshcat visualizer."""
        self.bar_z_logger = LogVectorOutput(
            self.bar_sensor.get_output_port(0), self.builder
        )
        self.bar_z_logger.set_name("bar_z_logger")
        MeshcatVisualizer.AddToBuilder(
            self.builder, self.scene_graph, self.meshcat
        )

    def _finalize_simulator(self):
        """Build the diagram and create a real-time-paced, initialized simulator."""
        self.diagram = self.builder.Build()
        self.simulator = Simulator(self.diagram)
        self.simulator.Initialize()
        self.simulator.set_target_realtime_rate(1.0)
        self.context = self.simulator.get_mutable_context()

    def run(self, T_final=10.0, plot=True):
        """Advance the simulation to ``T_final`` and optionally plot bar height.

        The plot shows:
        - the logged true bar height;
        - the reference ``z_ref(t)``;
        - the RGB-D estimate (row 2 of the logged perception output), if that
          pipeline exists;
        - a vertical marker at ``t_grab``.

        Returns:
            ``self``, to allow chaining.
        """
        print(f"Running simulation to t = {T_final} s")
        self.simulator.AdvanceTo(T_final)

        if not plot:
            return self

        log = self.bar_z_logger.FindLog(self.context)
        t_samples = log.sample_times()
        z_samples = log.data()[0, :]

        plt.figure()
        plt.plot(t_samples, z_samples, label="bar z(t) – truth")
        plt.plot(
            t_samples,
            [self.z_ref_fn(t)[0] for t in t_samples],
            "--",
            label="z_ref(t)",
        )
        if hasattr(self, "bar_pose_logger") and self.bar_pose_logger is not None:
            rgbd_log = self.bar_pose_logger.FindLog(self.context)
            rgbd_times = rgbd_log.sample_times()
            rgbd_z = rgbd_log.data()[2, :]
            plt.plot(rgbd_times, rgbd_z, label="RGB-D estimated z")
        plt.axvline(self.t_grab, linestyle=":", label="t_grab")
        plt.xlabel("time [s]")
        plt.ylabel("bar height z [m]")
        plt.legend()
        plt.title("Bar height tracking")
        plt.show()

        return self



def build_bench_buddy_station(meshcat):
    """Build a standalone station diagram with no controllers or sensors.

    The station is built as follows:
    - add the bar with its supports and the PR2 to a fresh plant;
    - set the contact penetration allowance and gravity;
    - finalize the plant;
    - attach a Meshcat visualizer.

    Args:
        meshcat: Meshcat instance the visualizer publishes to.

    Returns:
        ``(diagram, plant, scene_graph, handles)``. ``handles`` contains, in
        this order:
        - the bar body;
        - both support joints;
        - the entries of the support-handle mapping;
        - the PR2 model instance;
        - the entries of the PR2 handle mapping.
    """
    # Same trajectory parameters as BarCatchSimulation's constructor defaults.
    station_params = SimulationParameters(
        z_start=0.35,
        z_target=0.85,
        t_hold_bottom=1.0,
        t_move=4.0,
        t_hold_top=1.0,
        bar_sdf_path=str(BAR_SDF_PATH),
    )

    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.001)
    scene_graph.AddRenderer(
        "renderer", MakeRenderEngineVtk(RenderEngineVtkParams())
    )

    model_parser = Parser(plant)
    ConfigureParser(model_parser)

    bar_body, left_joint, right_joint, support_handles = add_bar_and_supports(
        plant, model_parser, station_params
    )
    pr2_model, pr2_handles = add_pr2_model(
        plant, model_parser, support_handles["bar_start_pos"]
    )

    plant.set_penetration_allowance(1e-3)
    plant.mutable_gravity_field().set_gravity_vector(GRAVITY)
    plant.Finalize()

    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)
    diagram = builder.Build()

    # Later entries override earlier ones on key collisions (dict-literal order).
    handles = {
        "bar_body": bar_body,
        "left_support_joint": left_joint,
        "right_support_joint": right_joint,
    }
    handles.update(support_handles)
    handles["pr2_model"] = pr2_model
    handles.update(pr2_handles)

    return diagram, plant, scene_graph, handles


def run(T_final=10.0, plot=True, **kwargs):
    """Construct a BarCatchSimulation from ``kwargs`` and run it.

    Args:
        T_final: Simulation end time [s].
        plot: Whether to plot the bar-height logs after simulating.
        **kwargs: Forwarded unchanged to ``BarCatchSimulation``.

    Returns:
        Whatever ``BarCatchSimulation.run`` returns (the simulation object).
    """
    simulation = BarCatchSimulation(**kwargs)
    outcome = simulation.run(T_final=T_final, plot=plot)
    return outcome

def start_meshcat():
    """Start a Meshcat instance via Drake's ``StartMeshcat`` and return it."""
    meshcat_instance = StartMeshcat()
    return meshcat_instance


# Main

In [None]:
# Open a Meshcat session and build the standalone station into it.
meshcat = start_meshcat()
diagram, plant, scene_graph, handles = build_bench_buddy_station(meshcat=meshcat)
print("Built bench buddy station diagram.")

INFO:drake:Meshcat listening for connections at http://localhost:7000


Built bench buddy station diagram.
