In [None]:
# Suppress warnings
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Standard imports
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import numpy as np

    # Drake imports
    from meshcat.servers.zmqserver import start_zmq_server_as_subprocess
    import pydrake
    from pydrake.all import LogOutput, DirectCollocation, DirectTranscription, MathematicalProgram, InputPortSelection

# Imports of other project files
from log_wrapper import LogWrapper
import constants
import finger
import pedestal
from paper import Paper

from pydrake.all import (MultibodyPlant, Parser, DiagramBuilder, Simulator,
                         PlanarSceneGraphVisualizer, SceneGraph, TrajectorySource,
                         SnoptSolver, MultibodyPositionToGeometryPose, PiecewisePolynomial,
                         MathematicalProgram, JacobianWrtVariable, eq)

# Other imports
import importlib

In [None]:
# Plot styling: base font size first, then the 'science' + 'no-latex' style sheets
font = dict(size=14)
plt.style.use(['science', 'no-latex'])
# Apply the font settings after the style sheets so they take precedence
matplotlib.rc('font', **font)

In [None]:
# Meshcat init: start the Meshcat ZMQ server as a subprocess.
# The call returns three values that are kept as notebook globals; going by
# their names these are presumably the server process handle, the ZMQ URL
# and the browser URL -- TODO confirm against the meshcat documentation.
proc, zmq_url, web_url = start_zmq_server_as_subprocess()

In [None]:
## Pre-finalize steps
builder = pydrake.systems.framework.DiagramBuilder()

# Populate the plant in order: pedestal, paper, finger
plant, scene_graph = pydrake.multibody.plant.AddMultibodyPlantSceneGraph(builder, time_step=constants.DT)
pedestal_instance = pedestal.AddPedestal(plant)

# Initial joint angles that start the paper in approximately the right spot.
# 20 entries in total, built from runs of repeated values.
paper_joint_angles = (
    [-np.pi/100]*2
    + [0]*9
    + [np.pi/10]*5
    + [0]*4
)
paper = Paper(plant, 20, default_joint_angle=paper_joint_angles)
paper.weld_paper_edge(pedestal.PEDESTAL_WIDTH, pedestal.PEDESTAL_HEIGHT)

finger_instance = finger.AddFinger(plant, constants.INIT_X, constants.INIT_Z)

# The logger is sized from the body count, so it is created only after every
# body has been added to the plant
log_wrapper = LogWrapper(plant.num_bodies())
builder.AddSystem(log_wrapper)

In [None]:
## CHOOSE CONTROL SYSTEM HERE BY UNCOMMENTING
# # PD control: hits too low
# finger_ctrlr = finger.PDFinger(
#     plant,
#     int(finger_instance),
#     [
#         [constants.INIT_X, constants.INIT_Z],
#         [constants.INIT_X*1.1, constants.INIT_Z],
#         [constants.INIT_X*1.1, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2+0.05],
#         [0, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2+0.05],
#         [0, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2],
#     ],
#     tspan_per_segment=1,
#     kx=10,
#     kz=10
# )

# # PD control: hits too high
# finger_ctrlr = finger.PDFinger(
#     plant,
#     int(finger_instance),
#     [
#         [constants.INIT_X, constants.INIT_Z],
#         [constants.INIT_X*1.1, constants.INIT_Z],
#         [constants.INIT_X*1.1, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2+0.06],
#         [0, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2+0.06],
#         [0, pedestal.PEDESTAL_HEIGHT+constants.FINGER_RADIUS/2],
#     ],
# )


# # Edge feedback
# finger_ctrlr = finger.EdgeController(
#     plant,
#     int(finger_instance),
#     int(paper.get_free_edge_instance()),
#     K=15,
#     F_Nd=0.03,
#     d_d=0.04,
#     w_l=paper.link_width,
#     debug=True,
# )

# Active choice: optimization controller, built from the plant, the Paper
# wrapper, and the finger's model instance converted to an int
finger_ctrlr = finger.OptimizationController(plant, paper, int(finger_instance))

In [None]:
# No more bodies are added past this point
plant.Finalize()

## Post finalize steps
paper.init_ctrlrs(builder)

# Plant outputs consumed by both the finger controller and the logger
body_poses_port = plant.get_body_poses_output_port()
body_velocities_port = plant.get_body_spatial_velocities_output_port()

# Connect finger controller: its output drives the finger's actuation input,
# and it reads body poses (input 0) and spatial velocities (input 1)
builder.AddSystem(finger_ctrlr)
builder.Connect(finger_ctrlr.get_output_port(), plant.get_actuation_input_port(finger_instance))
builder.Connect(body_poses_port, finger_ctrlr.get_input_port(0))
builder.Connect(body_velocities_port, finger_ctrlr.get_input_port(1))

# Feed the same two plant outputs into the log wrapper
builder.Connect(body_poses_port, log_wrapper.get_input_port(0))
builder.Connect(body_velocities_port, log_wrapper.get_input_port(1))

# Visualization and logging
logger = LogOutput(log_wrapper.get_output_port(), builder)
vis = pydrake.systems.meshcat_visualizer.ConnectMeshcatVisualizer(builder, scene_graph)

# Build the diagram, then do the steps that need the built diagram and its context
diagram = builder.Build()
diagram_context = diagram.CreateDefaultContext()
paper.connect_ctrlrs(diagram, diagram_context)

In [None]:
# Run the controller's optimization step before simulating.
# The controller is user-selectable in an earlier cell, and the plotting cells
# already branch on the controller type. Guard this call the same way so that
# "Run All" does not stop here when the selected controller has no optimize()
# method (presumably the case for the PD and edge controllers -- verify in finger.py).
if hasattr(finger_ctrlr, "optimize"):
    finger_ctrlr.optimize()

In [None]:
# plant = plant.ToAutoDiffXd()
# nq = plant.num_positions()
# nf = 2
# def manipulator_equations(vars):
#     # configuration, velocity, acceleration, force
#     assert vars.size == 3 * nq + nf
#     split_at = [nq, 2 * nq, 3 * nq]
#     q, qd, qdd, f = np.split(vars, split_at)
    
#     # set state
#     context = plant.CreateDefaultContext()
#     plant.SetPositions(context, q)
#     plant.SetVelocities(context, qd)
    
#     # matrices for the manipulator equations
#     M = plant.CalcMassMatrixViaInverseDynamics(context)
#     Cv = plant.CalcBiasTerm(context)
#     tauG = plant.CalcGravityGeneralizedForces(context)
    
#     # Jacobian of the stance foot
#     J = get_foot_jacobian(plant, context)
    
#     # return violation of the manipulator equations
#     return M.dot(qdd) + Cv - tauG - J.T.dot(f)

# def get_foot_jacobian(plant, context):
    
#     # get reference frames for the given leg and the ground
#     finger_frame = plant.GetBodyByName("finger_body").body_frame()
#     wolrd_frame = plant.world_frame()

#     # compute Jacobian matrix
#     J = plant.CalcJacobianTranslationalVelocity(
#         context,
#         JacobianWrtVariable(0),
#         finger_frame,
#         [0, 0, 0],
#         wolrd_frame,
#         wolrd_frame
#     )
    
#     # discard y components since we are in 2D
#     return J[[0, 2]]

In [None]:
# # time steps in the trajectory optimization
# opt_dt = 100*constants.DT
# T = int(constants.TSPAN/(opt_dt))

# # minimum and maximum time interval in seconds
# h_min = opt_dt*(1 - 1e-1)
# h_max = opt_dt*(1 + 1e-1)

# # initialize program
# prog = MathematicalProgram()

# # vector of the time intervals
# # (distances between the T + 1 break points)
# h = prog.NewContinuousVariables(T, name='h')

# # system configuration, generalized velocities, and accelerations
# q = prog.NewContinuousVariables(rows=T+1, cols=nq, name='q')
# qd = prog.NewContinuousVariables(rows=T+1, cols=nq, name='qd')
# qdd = prog.NewContinuousVariables(rows=T, cols=nq, name='qdd')

# # stance-foot force
# f = prog.NewContinuousVariables(rows=T, cols=2, name='f')

# # lower and upper bound on the time steps for all t
# prog.AddBoundingBoxConstraint([h_min] * T, [h_max] * T, h)

# # link the configurations, velocities, and accelerations
# # uses implicit Euler method, https://en.wikipedia.org/wiki/Backward_Euler_method
# for t in range(T):
#     prog.AddConstraint(eq(q[t+1], q[t] + h[t] * qd[t+1]))
#     prog.AddConstraint(eq(qd[t+1], qd[t] + h[t] * qdd[t]))

# # manipulator equations for all t (implicit Euler)
# for t in range(T):
#     vars = np.concatenate((q[t+1], qd[t+1], qdd[t], f[t]))
#     prog.AddConstraint(manipulator_equations, lb=[0]*nq, ub=[0]*nq, vars=vars)

In [None]:
# solver = SnoptSolver()
# result = solver.Solve(prog)
# result.is_success()

In [None]:
# result.GetSolution(f)

In [None]:
# Finalize simulation and visualization
# Runs the built diagram from diagram_context up to time constants.TSPAN while
# the visualizer records. The calls below are order-dependent: recording is
# started before time advances and stopped once the run is complete.
simulator = pydrake.systems.analysis.Simulator(diagram, diagram_context)
simulator.Initialize()
vis.start_recording()
simulator.AdvanceTo(constants.TSPAN)
vis.stop_recording()
# Presumably pushes the recorded animation to the Meshcat viewer -- TODO confirm
vis.publish_recording()

In [None]:
# Compare the logged manipulator path with the PD reference trajectory.
# Only the PD controller exposes the xs/zs reference used here.
if type(finger_ctrlr) is finger.PDFinger:
    log_data = logger.data()
    # The indexing assumes 12 log rows per body, with x at offset 0 and z at
    # offset 2 -- TODO confirm against LogWrapper
    finger_row = 12*int(finger_instance)

    plt.figure(figsize=(6, 4))
    plt.plot(log_data[finger_row], log_data[finger_row+2], label='Manipulator position')
    plt.plot(finger_ctrlr.xs, finger_ctrlr.zs, label='Trajectory')
    plt.xlabel("$x$ position")
    plt.ylabel("$z$ position")
    plt.legend()
    plt.show()

In [None]:
# Generate plots for paper: snapshots of the paper links and the finger at 15
# evenly spaced times, colored by time.
nb = plant.num_bodies()
log_data = logger.data()  # fetch the log matrix once instead of on every access
# The indexing assumes 12 log rows per body, with x at offset 0 and z at
# offset 2 -- TODO confirm against LogWrapper
finger_row = 12*int(finger_instance)

paper_idxs = [int(i) for i in paper.link_instances]
x_traces = np.array([log_data[12*b] for b in paper_idxs])
z_traces = np.array([log_data[12*b+2] for b in paper_idxs])

times = np.arange(0,constants.TSPAN,constants.TSPAN/15)
# plt.get_cmap replaces cm.get_cmap, which was deprecated in Matplotlib 3.7 and
# removed in 3.9; plt.get_cmap is available in old and new releases alike.
cmap = plt.get_cmap("viridis_r")
plt.figure(figsize=(2*3,2*2))
if type(finger_ctrlr) is finger.PDFinger:
    # Reference trajectory drawn underneath the snapshots
    plt.plot(finger_ctrlr.xs, finger_ctrlr.zs, '--k', zorder=-1)
for t in times:
    c = cmap(t/constants.TSPAN)
    # Round instead of truncating: a quotient such as 2.9999999999999996 would
    # otherwise select the sample before the intended one.
    idx = int(round(t/constants.DT))
    # The last entry of the link traces is left out of the drawn line
    plt.plot(x_traces[:-1,idx], z_traces[:-1,idx], color=c)
    plt.scatter(log_data[finger_row,idx],
                log_data[finger_row+2,idx],
                color=c, s=300, zorder=1)

# Colorbar trick: scatter two points far outside the current view, colored by
# the time range, then restore the limits so only the colorbar is affected.
xlim = plt.xlim()
ylim = plt.ylim()
plt.scatter([xlim[0]-50, xlim[0]-50], [ylim[0]-50, ylim[0]-50], c=[0, constants.TSPAN], cmap=cmap)
plt.xlim(xlim)
plt.ylim(ylim)
cb = plt.colorbar()
cb.set_label("Time")
plt.xlabel("$x$ position")
plt.ylabel("$z$ position")
plt.show()

In [None]:
# Edge-controller diagnostics: the logged debug['FN'] and debug['d'] traces
# plotted against the F_Nd setpoint.
if type(finger_ctrlr) is finger.EdgeController:
    # Create the figure inside the guard so that other controllers do not
    # leave an empty 12x12 figure in the notebook output.
    plt.figure(figsize=(12,12))
    plt.plot(finger_ctrlr.debug['FN'], label='FN')
    plt.plot(finger_ctrlr.debug['d'], label='d')
    # Dashed so the setpoint line is distinguishable from the FN trace
    plt.axhline(finger_ctrlr.F_Nd, linestyle='--', label='F_Nd')
    plt.legend()
    plt.show()