In [1]:
import numpy as np
from pydrake.all import (
    RigidTransform,
    RotationMatrix,
    StartMeshcat,
    RandomGenerator,
)
import numpy as np

import sponana.utils
from sponana.planner.rrt import SpotProblem, rrt_planning
import sponana.sim

In [2]:
# Launch the Meshcat server; this handle is shared by the simulation setup
# and by the path visualization at the end of the notebook.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7001


In [None]:
# Reset the visualizer so that re-running this cell starts from an empty scene.
meshcat.Delete()
meshcat.DeleteAddedControls()

# Two random sources: `rng` is NumPy's generator (Python side, passed to
# clutter_gen), `generator` is Drake's RandomGenerator (C++ side), seeded from it.
rng = np.random.default_rng(145)
generator = RandomGenerator(rng.integers(0, 1000))

# Simulation options.
add_spot = True
simulation_time = -1  # -1: run indefinitely until ESC is pressed
debug = True
add_fixed_cameras = False
enable_arm_ik = True  # turn this off if you find the arm too annoying
use_teleop = False

simulator, diagram = sponana.sim.clutter_gen(
    meshcat,
    rng,
    debug=debug,
    simulation_time=simulation_time,
    add_spot=add_spot,
    add_fixed_cameras=add_fixed_cameras,
    enable_arm_ik=enable_arm_ik,
    use_teleop=use_teleop,
)

In [4]:
# Handles into the built diagram. `context` deliberately stays the root
# (diagram) context: the helpers further down call GetMyContextFromRoot
# themselves to obtain each subsystem's own context.
context = simulator.get_mutable_context()
station = diagram.GetSubsystemByName("station")
plant = station.GetSubsystemByName("plant")
scene_graph = station.GetSubsystemByName("scene_graph")

In [6]:
def check_collision_pair_name(pair_name0, pair_name1, prefix="spot"):
    """Return True when exactly one of two geometry names belongs to the robot.

    A geometry counts as part of the robot when its name starts with
    ``prefix`` (``"spot"`` by default). A pair where both names match
    (robot touching itself) or neither matches (environment touching
    environment) is not reported as a robot collision.

    Args:
        pair_name0: name of the first geometry in the contact pair.
        pair_name1: name of the second geometry in the contact pair.
        prefix: name prefix identifying the robot's geometries.

    Returns:
        True iff exactly one of the two names starts with ``prefix``.
    """
    return pair_name0.startswith(prefix) != pair_name1.startswith(prefix)

In [7]:
def in_collision(plant, scene_graph, context, print_collisions=True):
    """Return True if a robot geometry penetrates a non-robot geometry.

    Evaluates the plant's geometry-query input port and scans the
    penetrating point pairs. A pair counts when check_collision_pair_name
    reports that exactly one side belongs to the robot. The scan stops at
    the first such pair.

    Args:
        plant: the MultibodyPlant whose geometry-query input port is evaluated.
        scene_graph: the SceneGraph. Kept for interface compatibility; the
            inspector is now taken from the evaluated query object itself.
        context: the root (diagram) context; the plant's own context is
            extracted with GetMyContextFromRoot.
        print_collisions: if True, print the names of the first colliding pair.

    Returns:
        True on the first qualifying penetrating pair, otherwise False.
    """
    plant_context = plant.GetMyContextFromRoot(context)
    query_object = plant.get_geometry_query_input_port().Eval(plant_context)
    # The QueryObject already carries an inspector, so the SceneGraph query
    # port does not need to be evaluated a second time just to resolve names.
    inspector = query_object.inspector()

    for pair in query_object.ComputePointPairPenetration():
        name_a = inspector.GetName(pair.id_A)
        name_b = inspector.GetName(pair.id_B)
        if check_collision_pair_name(name_a, name_b):
            if print_collisions:
                print(name_a, name_b)
            return True
    return False

In [8]:
# Sanity check on the configuration produced by the scene setup above.
collision_check = in_collision(plant, scene_graph, context, print_collisions=True)
print(f"Check Collision: {collision_check}")

Check Collision: False


In [9]:
from sponana.controller.inverse_kinematics import q_nominal_arm


def move_spot(q_spot):
    """Write Spot's positions into the plant context and return the old ones.

    A 3-element ``q_spot`` is treated as a base-only configuration and has
    ``q_nominal_arm`` appended to it. Any other length is written unchanged.

    Returns:
        Spot's positions as read before the update, so the caller can restore them.
    """
    spot = plant.GetModelInstanceByName("spot")
    plant_ctx = plant.GetMyContextFromRoot(context)
    previous_q = plant.GetPositions(plant_ctx, spot)

    base_only = len(q_spot) == 3
    q_target = np.concatenate((q_spot, q_nominal_arm)) if base_only else q_spot

    plant.SetPositions(plant_ctx, spot, q_target)
    return previous_q

In [31]:
def ExistsCollision(q_spot):
    """Collision-checker callback for the planner: is Spot in collision at q_spot?

    Temporarily writes ``q_spot`` into the plant context (a 3-element
    base-only configuration gets the nominal arm appended by move_spot),
    runs in_collision, and then always restores Spot's previous positions.
    The check therefore leaves the shared plant context unchanged.

    Args:
        q_spot: a 3-element base configuration or a full Spot configuration.

    Returns:
        True if a Spot / non-Spot penetration exists at ``q_spot``, else False.
    """
    spot_init_positions = move_spot(q_spot)
    try:
        return in_collision(plant, scene_graph, context, print_collisions=False)
    finally:
        # Restore unconditionally (also if in_collision raises). Previously
        # Spot was restored only after a collision, so every collision-free
        # sample silently left Spot moved in the shared context.
        move_spot(spot_init_positions)

In [14]:
# Planar base configurations for the planner. Element 2 is used as a rotation
# about Z when the path is visualized below.
base_pose = np.array([1.0, 1.50392176e-12, 3.15001955])
q_start = base_pose

# Goal configuration. Other goals tried earlier:
#   [1.0, 1.50392176e-12 - 1, 3.15001955]
#   [1.0, -0.5, 3.15001955]
#   [0.20894849, -0.47792893, 0.2475]
q_goal = np.array([-2.0, -2.0, 3.15001955])

In [32]:
# Plan from q_start to q_goal, using ExistsCollision to reject configurations.
# NOTE(review): 1000 and 0.05 are presumably the iteration budget and the
# goal-sampling probability; confirm against sponana.planner.rrt.rrt_planning.
spot_problem = SpotProblem(
    collision_checker=ExistsCollision,
    q_start=q_start,
    q_goal=q_goal,
)
path = rrt_planning(spot_problem, 1000, 0.05)

In [35]:
num_waypoints = len(path)  # number of configurations returned by the planner
print(num_waypoints)

326


In [37]:
# visualize RRT output
from manipulation.meshcat_utils import AddMeshcatTriad


def visualize_path(path, opacity=0.2, prefix="trajectory"):
    """Draw one Meshcat triad per planar configuration in ``path``.

    Each configuration ``q`` is read as ``(x, y, yaw)``. Its triad is placed
    at ``(q[0], q[1], 0)`` and rotated by ``q[2]`` about the Z axis.

    Args:
        path: sequence of configurations with at least three entries each.
        opacity: triad opacity (defaults to the previously hard-coded 0.2).
        prefix: Meshcat name prefix; triad ``i`` is added at ``f"{prefix}_{i}"``
            (the default reproduces the previous ``trajectory_{i}`` names).
    """
    for i, q in enumerate(path):
        # Distinct name for the transform; the loop variable used to be
        # rebound from a configuration array to a RigidTransform.
        X_triad = RigidTransform(RotationMatrix.MakeZRotation(q[2]), [q[0], q[1], 0.0])
        AddMeshcatTriad(meshcat, f"{prefix}_{i}", X_PT=X_triad, opacity=opacity)


visualize_path(path)