In [None]:
# Suppress warnings
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Standard imports
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import numpy as np

    # Drake imports
    from meshcat.servers.zmqserver import start_zmq_server_as_subprocess
    import pydrake
    from pydrake.all import LogOutput, DirectCollocation, DirectTranscription, MathematicalProgram, InputPortSelection

# Imports of other project files
from log_wrapper import LogWrapper
import constants
import finger
import pedestal
from paper import Paper

from pydrake.all import (MultibodyPlant, Parser, DiagramBuilder, Simulator,
                         PlanarSceneGraphVisualizer, SceneGraph, TrajectorySource,
                         SnoptSolver, MultibodyPositionToGeometryPose, PiecewisePolynomial,
                         MathematicalProgram, JacobianWrtVariable, eq, RollPitchYaw, AutoDiffXd)

# Other imports
import importlib

In [None]:
# Matplotlib configuration: select the 'science' / 'no-latex' styles first,
# then override the default font size for every figure.
plt.style.use(['science', 'no-latex'])
font = dict(size=14)
matplotlib.rc('font', **font)

In [None]:
# Meshcat init: start the ZMQ server as a subprocess and keep the returned
# process handle together with the ZMQ and web URLs.
proc, zmq_url, web_url = start_zmq_server_as_subprocess()

In [None]:
## Pre-finalize steps
builder = pydrake.systems.framework.DiagramBuilder()

# Populate the plant: pedestal first, then the paper, then the finger
plant, scene_graph = pydrake.multibody.plant.AddMultibodyPlantSceneGraph(
    builder,
    time_step=constants.DT,
)
pedestal_instance = pedestal.AddPedestal(plant)

# These joint angles start the paper in approximately the right spot.
# The list has 20 entries (0-indexed): pi/100 for entries 0-1, -pi/10 for
# entries 11-15, and 0 everywhere else.
paper = Paper(
    plant,
    20,
    default_joint_angle=(
        [np.pi/100]*2
        + [0]*9
        + [-np.pi/10]*5
        + [0]*4
    ),
)
paper.weld_paper_edge(pedestal.PEDESTAL_WIDTH, pedestal.PEDESTAL_HEIGHT)

finger_instance = finger.AddFinger(plant, constants.INIT_Y, constants.INIT_Z)

# The logger is constructed from the plant's body count, so it is created
# only after all bodies have been added
log_wrapper = LogWrapper(plant.num_bodies())
builder.AddSystem(log_wrapper)

In [None]:
## CHOOSE CONTROL SYSTEM HERE BY UNCOMMENTING
# Leave exactly one `finger_ctrlr = ...` assignment uncommented; the cells
# below use whatever `finger_ctrlr` ends up bound to.

# # PD control: hits too low
# finger_ctrlr = finger.PDFinger(
#     plant,
#     int(finger_instance),
#     [
#         [constants.INIT_Y, constants.INIT_Z],
#         [constants.INIT_Y*1.1, constants.INIT_Z],
#         [constants.INIT_Y*1.1, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2+0.05],
#         [0, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2+0.05],
#         [0, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2],
#     ],
#     tspan_per_segment=1,
#     ky=10,
#     kz=10
# )

# # PD control: hits too high (same waypoints, but +0.08 instead of +0.05)
# finger_ctrlr = finger.PDFinger(
#     plant,
#     int(finger_instance),
#     [
#         [constants.INIT_Y, constants.INIT_Z],
#         [constants.INIT_Y*1.1, constants.INIT_Z],
#         [constants.INIT_Y*1.1, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2+0.08],
#         [0, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2+0.08],
#         [0, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2],
#     ],
#     tspan_per_segment=1,
#     ky=10,
#     kz=10
# )


# Edge feedback (the currently active choice)
finger_ctrlr = finger.EdgeController(
    plant,
    int(finger_instance),
    int(paper.get_free_edge_instance()),
    # Values taken from the paper model
    w_l=paper.link_width,
    # Hand-picked controller settings
    d_d=0.04,
    F_Nd=0.03,
    K=15,
    debug=True,
)

# # Optimization controller
# finger_ctrlr = finger.OptimizationController(
#     plant,
#     paper,
#     int(finger_instance),
#     paper.get_free_edge_instance()
# )

In [None]:
plant.Finalize()

## Post finalize steps
# paper.init_ctrlrs(builder)

# Connect finger controller: its output feeds the finger's actuation input,
# and it receives the plant's body poses (input 0) and body spatial
# velocities (input 1)
builder.AddSystem(finger_ctrlr)
builder.Connect(finger_ctrlr.get_output_port(), plant.get_actuation_input_port(finger_instance))
builder.Connect(plant.get_body_poses_output_port(), finger_ctrlr.get_input_port(0))
builder.Connect(plant.get_body_spatial_velocities_output_port(), finger_ctrlr.get_input_port(1))

# Add logger: the log wrapper receives the same two plant outputs
# (poses on input 0, spatial velocities on input 1)
builder.Connect(plant.get_body_poses_output_port(), log_wrapper.get_input_port(0))
builder.Connect(plant.get_body_spatial_velocities_output_port(), log_wrapper.get_input_port(1))

# Visualization and logging: record the log wrapper's output and attach the
# Meshcat visualizer to the scene graph
logger = LogOutput(log_wrapper.get_output_port(), builder)
vis = pydrake.systems.meshcat_visualizer.ConnectMeshcatVisualizer(builder, scene_graph)

# Build the diagram and create its default context
diagram = builder.Build()
diagram_context = diagram.CreateDefaultContext()
# paper.connect_ctrlrs(diagram, diagram_context)

In [None]:
# Only the optimization controller has an `optimize` step; run it on the
# plant's current positions before simulating.
if type(finger_ctrlr) is finger.OptimizationController:
    # MultibodyPlant queries must be given the plant's own sub-context, not
    # the root diagram context, so resolve it from the root first.
    finger_ctrlr.optimize(
        plant.GetPositions(plant.GetMyContextFromRoot(diagram_context)))

In [None]:
# Finalize simulation and visualization: create and initialize a simulator
# for the diagram, advance it to constants.TSPAN while the visualizer is
# recording, then publish the recording
simulator = pydrake.systems.analysis.Simulator(diagram, diagram_context)
simulator.Initialize()
vis.start_recording()
simulator.AdvanceTo(constants.TSPAN)
vis.stop_recording()
vis.publish_recording()

In [None]:
# Plot manipulator position vs. trajectory (only meaningful for the PD
# controller, which exposes its reference path as `ys` / `zs`)
if type(finger_ctrlr) is finger.PDFinger:
    # The finger's logged rows are addressed as 12*<instance index> + offset;
    # offsets 1 and 2 are the ones plotted as y and z here
    finger_row = 12*int(finger_instance)
    log_data = logger.data()

    plt.figure(figsize=(6, 4))
    plt.plot(
        log_data[finger_row+1],
        log_data[finger_row+2],
        label='Manipulator position',
    )
    plt.plot(finger_ctrlr.ys, finger_ctrlr.zs, label='Trajectory')
    plt.xlabel("$y$ position")
    plt.ylabel("$z$ position")
    plt.legend()
    plt.show()

In [None]:
# Generate plots for paper: snapshots of the paper links and the finger in
# the y-z plane at ten evenly spaced times, colored by time.
nb = plant.num_bodies()  # not used in this cell; kept in case other cells read it
paper_idxs = [int(i) for i in paper.link_instances]

# Fetch the log once. Rows are addressed as 12*<index> + offset, with offsets
# 1 and 2 used as the y and z coordinates (same convention as the plot above).
log_data = logger.data()
y_traces = np.array([log_data[12*b+1] for b in paper_idxs])
z_traces = np.array([log_data[12*b+2] for b in paper_idxs])
finger_y = log_data[12*int(finger_instance)+1]
finger_z = log_data[12*int(finger_instance)+2]

times = np.arange(0,constants.TSPAN,constants.TSPAN/10)
# plt.get_cmap is used instead of cm.get_cmap, which newer Matplotlib removed
cmap = plt.get_cmap("viridis_r")
plt.figure(figsize=(2*3,2*2))
if type(finger_ctrlr) is finger.PDFinger:
    plt.plot(finger_ctrlr.ys, finger_ctrlr.zs, '--k', zorder=-1)
for t in times:
    c = cmap(t/constants.TSPAN)
    # Round rather than truncate so floating-point error in t/DT cannot
    # select the sample one step too early
    idx = int(round(t/constants.DT))
    # The last entry of the link list is left out of the drawn paper outline
    plt.plot(y_traces[:-1,idx], z_traces[:-1,idx], color=c)
    plt.scatter(finger_y[idx], finger_z[idx], color=c, s=300, zorder=1)

# Draw two off-screen points spanning [0, TSPAN] purely to give the colorbar
# a mappable, then restore the axis limits so they stay invisible
xlim = plt.xlim()
ylim = plt.ylim()
plt.scatter([xlim[0]-50, xlim[0]-50], [ylim[0]-50, ylim[0]-50], c=[0, constants.TSPAN], cmap=cmap)
plt.xlim(xlim)
plt.ylim(ylim)
cb = plt.colorbar()
cb.set_label("Time")
# The horizontal axis shows the +1 (y) rows, so label it y rather than x
plt.xlabel("$y$ position")
plt.ylabel("$z$ position")
plt.show()