In [None]:
# Suppress warnings emitted while importing the libraries below
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Standard imports
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import numpy as np

    # Drake imports
    from meshcat.servers.zmqserver import start_zmq_server_as_subprocess
    import pydrake.all
    from pydrake.all import (
        RigidTransform, RotationMatrix, RollPitchYaw, RevoluteSpring, LogOutput, SignalLogger, ForceElement, 
        DoorHinge, DoorHingeConfig, PidController, FixedInputPortValue, SpatialVelocity
    )

# Imports of other project files
import body_pose_wrapper
import constants
import finger
import pedestal
from paper import Paper

# Other imports
import importlib

In [None]:
# Meshcat init: launch a meshcat-server in a child process and keep the returned
# triple -- `proc` (the subprocess handle), `zmq_url` and `web_url` -- as notebook
# globals for use by later cells.
meshcat_server = start_zmq_server_as_subprocess()
proc, zmq_url, web_url = meshcat_server

In [None]:
## Pre-finalize steps: everything that has to be registered with the plant/builder
## before plant.Finalize() is called.
builder = pydrake.systems.framework.DiagramBuilder()

# Discrete-time multibody plant wired to a scene graph, stepped at constants.DT.
plant, scene_graph = pydrake.multibody.plant.AddMultibodyPlantSceneGraph(
    builder, time_step=constants.DT
)

# Scene contents, in registration order: pedestal, paper, pose wrapper system, finger.
pedestal_instance = pedestal.AddPedestal(plant)

# NOTE(review): 20 and 24 are unexplained magic numbers (presumably the paper link
# count and the number of bodies the wrapper expects) -- confirm and name them.
paper = Paper(plant, 20)
paper.weld_paper_edge(pedestal.PEDESTAL_WIDTH, pedestal.PEDESTAL_HEIGHT)

# DiagramBuilder.AddSystem hands back the system it registered.
pose_wrapper = builder.AddSystem(body_pose_wrapper.BodyPoseWrapper(24))

finger_instance = finger.AddFinger(plant, constants.INIT_X, constants.INIT_Z)

In [None]:
## Finger controller
# PD controller that follows a list of (x, z) waypoints. The only thing varied between
# runs is the height of the approach waypoints above the final waypoint:
#   APPROACH_CLEARANCE = 0.05 -> finger hits too low
#   APPROACH_CLEARANCE = 0.06 -> finger hits too high
APPROACH_CLEARANCE = 0.05

# Height of the final waypoint, and of the two approach waypoints above it.
# (Evaluated as (H + R/2) + clearance, exactly like the original inline expression.)
target_z = pedestal.PEDESTAL_HEIGHT + constants.FINGER_RADIUS/2
approach_z = target_z + APPROACH_CLEARANCE

finger_waypoints = [
    [constants.INIT_X, constants.INIT_Z],      # start
    [constants.INIT_X*1.1, constants.INIT_Z],  # step out in x at the start height
    [constants.INIT_X*1.1, approach_z],        # descend to the approach height
    [0, approach_z],                           # move across to x = 0
    [0, target_z],                             # descend to the final height
]

finger_ctrlr = finger.PDFinger(plant, int(finger_instance), finger_waypoints)

In [None]:
plant.Finalize()

## Post-finalize steps
paper.init_ctrlrs(builder)

# Connect the finger controller: it drives the finger's actuation input and reads the
# plant's body poses (input 0) and body spatial velocities (input 1).
builder.AddSystem(finger_ctrlr)
builder.Connect(finger_ctrlr.get_output_port(), plant.get_actuation_input_port(finger_instance))
builder.Connect(plant.get_body_poses_output_port(), finger_ctrlr.get_input_port(0))
builder.Connect(plant.get_body_spatial_velocities_output_port(), finger_ctrlr.get_input_port(1))
builder.Connect(plant.get_body_poses_output_port(), pose_wrapper.get_input_port())

# Visualization and logging.
poses_logger = LogOutput(pose_wrapper.get_output_port(), builder)
# Point the visualizer at the meshcat-server started in the "Meshcat init" cell. Without
# zmq_url it falls back to its default address, which is not guaranteed to be that
# server (e.g. if another server already holds the default port).
vis = pydrake.systems.meshcat_visualizer.ConnectMeshcatVisualizer(
    builder, scene_graph, zmq_url=zmq_url
)

# Build the diagram, then run the steps that need the built diagram and its context.
diagram = builder.Build()
diagram_context = diagram.CreateDefaultContext()
paper.connect_ctrlrs(diagram, diagram_context)

In [None]:
# Run the simulation to constants.TSPAN while recording the meshcat animation.
simulator = pydrake.systems.analysis.Simulator(diagram, diagram_context)
simulator.Initialize()
vis.start_recording()
try:
    simulator.AdvanceTo(constants.TSPAN)
finally:
    # Stop and publish even if the simulation raises. This keeps the visualizer from
    # being left in recording mode, and the frames up to the failure can still be
    # replayed. The exception itself still propagates.
    vis.stop_recording()
    vis.publish_recording()

In [None]:
# Compare the finger's logged x/z path against the waypoint trajectory it was given.
pose_log = poses_logger.data()
# NOTE(review): rows are addressed as 6*<finger model-instance index> (+0 for x, +2 for z);
# this assumes BodyPoseWrapper lays out 6 values per entry in that order -- confirm.
finger_row = 6*int(finger_instance)

fig, ax = plt.subplots()
ax.plot(pose_log[finger_row], pose_log[finger_row + 2], label='Manipulator position')
ax.plot(finger_ctrlr.xs, finger_ctrlr.zs, label='Trajectory')
ax.set_xlabel("$x$ position")
ax.set_ylabel("$z$ position")
ax.legend()
plt.show()

In [None]:
# Generate plots for the paper: snapshots of the paper links and the finger at evenly
# spaced times, coloured by time, over the finger's waypoint trajectory.
pose_log = poses_logger.data()
sample_times = poses_logger.sample_times()
finger_row = 6*int(finger_instance)

# One row per paper link (x and z components), shape (n_links, n_samples).
paper_idxs = [int(i) for i in paper.paper_instances]
x_traces = pose_log[[6*b for b in paper_idxs]]
z_traces = pose_log[[6*b + 2 for b in paper_idxs]]

times = np.arange(2, constants.TSPAN+1, 2)
# plt.get_cmap replaces cm.get_cmap, which is deprecated/removed in newer matplotlib.
cmap = plt.get_cmap("viridis_r")

fig, ax = plt.subplots(figsize=(2*3, 2*2))
ax.plot(finger_ctrlr.xs, finger_ctrlr.zs, '--k', zorder=-1)
for t in times:
    c = cmap(t/constants.TSPAN)
    # Nearest logged sample to time t. The previous int(t/constants.DT) could truncate
    # to the previous sample (float division can land just below an integer, e.g.
    # 0.3/0.1 == 2.9999999999999996) and assumed exactly one log sample per DT.
    idx = int(np.argmin(np.abs(sample_times - t)))
    # NOTE(review): the last paper link is excluded ([:-1]); kept from the original.
    ax.plot(x_traces[:-1, idx], z_traces[:-1, idx], color=c)
    ax.scatter(pose_log[finger_row, idx],
               pose_log[finger_row + 2, idx],
               color=c, s=300, zorder=1)

# Colorbar over [0, TSPAN] driven by a standalone ScalarMappable, instead of an
# off-screen dummy scatter plus saving/restoring the axis limits.
sm = cm.ScalarMappable(norm=plt.Normalize(vmin=0, vmax=constants.TSPAN), cmap=cmap)
sm.set_array([])  # needed by older matplotlib versions; harmless on newer ones
cb = fig.colorbar(sm, ax=ax)
cb.set_label("Time")
ax.set_xlabel("$x$ position")
ax.set_ylabel("$z$ position")
plt.show()