In [None]:
import numpy as np

import pydrake
from pydrake.all import (
    PiecewisePolynomial, RigidTransform, RotationMatrix
)

from src.drake_helpers import (
    BuildAndSimulateTrajectory, 
    setup_manipulation_station,
    visualize_transform, 
)
from src.ik import create_q_knots
from src.throw import plan_throw

In [None]:
# start a single meshcat server instance to use for remainder of this notebook.
# `proc` is the handle returned for the spawned server, `zmq_url` is passed to the
# station/simulation helpers in later cells, and `web_url` is printed below.
# NOTE(review): nothing in this notebook shuts `proc` down again; presumably it
# should be terminated when the notebook is finished -- confirm.
from meshcat.servers.zmqserver import start_zmq_server_as_subprocess
proc, zmq_url, web_url = start_zmq_server_as_subprocess(server_args=[])
# Print the viewer address so it can be opened in a browser tab.
print(web_url)

In [None]:
# --- Scene parameters -------------------------------------------------------
# World-frame point the throw is aimed at, and the offset between the gripper
# frame and the object used when building the grasp pose.
P_WORLD_TARGET = np.array([-4, 0, 2])
GRIPPER_TO_OBJECT_DIST = 0.13 # meters

# Starting pose of the object: a quarter turn about world z at a fixed position.
T_world_objectInitial = RigidTransform(
    R=RotationMatrix.MakeZRotation(0.5 * np.pi),
    p=[-0.1, -0.69, 0.104998503],
)

# Gripper pose for the grasp: the object's position shifted by
# GRIPPER_TO_OBJECT_DIST along world z, rotated -90 degrees about x.
T_world_gripperObject = RigidTransform(
    R=RotationMatrix.MakeXRotation(-0.5 * np.pi),
    p=T_world_objectInitial.translation() + np.array([0.0, 0.0, GRIPPER_TO_OBJECT_DIST]),
)

# Get initial pose of the gripper by using default context of manip station;
# the helper also hands back the meshcat handle used for the visualizations below.
T_world_robotInitial, meshcat = setup_manipulation_station(
    T_world_objectInitial, zmq_url
)

# Show the object's initial frame in the viewer.
visualize_transform(meshcat, "T_world_obj0", T_world_objectInitial)

# Plan the throw. The three results are consumed by later cells: knot times,
# gripper poses at those knots, and the gripper command trajectory.
# NOTE(review): meanings inferred from how the values are used downstream --
# confirm against src.throw.plan_throw.
t_lst, pose_lst, g_traj = plan_throw(
    T_world_robotInitial=T_world_robotInitial,
    T_world_gripperObject=T_world_gripperObject,
    p_world_target=P_WORLD_TARGET,
    gripper_to_object_dist=GRIPPER_TO_OBJECT_DIST,
    meshcat=meshcat,
)

# viz the full trajectory if desired
# for t, pose in enumerate(pose_lst):
#    visualize_transform(meshcat, str(t), pose)

In [None]:
# Turn the planned pose sequence into joint space: one row of joint values per pose.
q_knots = np.array(create_q_knots(pose_lst))

# Interpolate through the knots over the times in t_lst. Only the first seven
# columns are kept (presumably the arm joints -- confirm against create_q_knots),
# transposed so that each column is one knot.
q_traj = PiecewisePolynomial.CubicShapePreserving(
    t_lst,
    q_knots[:, :7].T,
)

In [None]:
# Run the throw: build the simulation from the arm trajectory (q_traj) and the
# gripper trajectory (g_traj), then play it while recording in meshcat.
# Note that `meshcat` is rebound here to the handle returned by this helper.
# NOTE(review): the fourth positional argument is left as None; its meaning is
# defined by BuildAndSimulateTrajectory -- confirm.
simulator, station_plant, meshcat = BuildAndSimulateTrajectory(q_traj, g_traj, T_world_objectInitial, None, zmq_url)

# Mark the throw target in the viewer with an unrotated frame at P_WORLD_TARGET.
visualize_transform(
    meshcat,
    "TARGET",
    RigidTransform(RotationMatrix.MakeZRotation(0), P_WORLD_TARGET),
    prefix='',
    length=0.3,
    radius=0.02
)

# Simulate for exactly the duration of the joint trajectory.
sim_end_time = q_traj.end_time()

meshcat.start_recording()
try:
    print(f"Running for {sim_end_time} seconds")
    simulator.AdvanceTo(sim_end_time)
finally:
    # Always close and publish the recording, even if the simulation raises,
    # so the visualizer is not left in recording mode and the partial
    # animation remains available for debugging. The exception still propagates.
    meshcat.stop_recording()
    meshcat.publish_recording()