In [None]:
# Suppress warnings
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Standard imports
    import matplotlib.pyplot as plt
    import numpy as np

    # Drake imports
    from meshcat.servers.zmqserver import start_zmq_server_as_subprocess
    import pydrake.all
    from pydrake.all import (
        RigidTransform, RotationMatrix, RollPitchYaw, RevoluteSpring, LogOutput, SignalLogger, ForceElement, 
        DoorHinge, DoorHingeConfig, PidController, FixedInputPortValue
    )
from paper import Paper
import constants

In [None]:
# Meshcat init
# Start a meshcat ZMQ server in a subprocess and keep the returned process
# handle together with the ZMQ and web URLs for later use.
proc, zmq_url, web_url = start_zmq_server_as_subprocess()

In [None]:
# Constants
# Pedestal dimensions. PEDESTAL_HEIGHT and PEDESTAL_WIDTH are reused below to
# place the pedestal, the welded paper edge and the finger waypoints.
# NOTE(review): presumably meters and matching pedestal.sdf -- confirm.
PEDESTAL_WIDTH = 0.225
PEDESTAL_HEIGHT = 0.1
PEDESTAL_DEPTH = 0.3

# DT is used both as the plant time step and as the trajectory sample spacing.
DT = 1e-4
# Default initial finger joint translations; also the first trajectory waypoint.
INIT_X = PEDESTAL_WIDTH*0.6
INIT_Z = 0

# Fingers normally are around 2cm wide
FINGER_RADIUS = 0.01
# Volume of a sphere of radius FINGER_RADIUS.
FINGER_VOLUME = (4/3)*FINGER_RADIUS**3*np.pi
# NOTE(review): 1e3 is presumably the density of water in kg/m^3 -- confirm units.
FINGER_MASS = FINGER_VOLUME*1e3 # Assume finger is made of water

In [None]:
def AddPedestal(plant):
    """Load the pedestal model into ``plant`` and weld it to the world frame.

    Args:
        plant: The plant to add the pedestal to. It must not be finalized yet,
            because a model and a weld joint are added to it.

    Returns:
        The model instance returned by the parser for "pedestal.sdf".
    """
    # Parse pedestal model
    parser = pydrake.multibody.parsing.Parser(plant)
    pedestal_instance = parser.AddModelFromFile("pedestal.sdf")

    # Weld pedestal to world, with the "pedestal_base" frame offset upwards by
    # half the pedestal height.
    # NOTE(review): presumably this puts the pedestal's bottom face at z = 0,
    # assuming the base frame sits at the pedestal's center -- confirm in the SDF.
    plant.WeldFrames(
        plant.world_frame(),
        plant.GetFrameByName("pedestal_base", pedestal_instance),
        RigidTransform(RotationMatrix(), [0, 0, PEDESTAL_HEIGHT/2])
    )

    return pedestal_instance

In [None]:
def AddPointFinger(plant, init_x=INIT_X, init_z=INIT_Z):
    """Add a spherical "point finger" that can translate along x and z.

    The finger is built as a two-joint chain: the world is connected to a
    massless intermediate body by a prismatic joint along x, and that body is
    connected to the actual finger sphere by a prismatic joint along z. Each
    joint gets its own actuator (x first, then z).

    Args:
        plant: The plant to add the finger to; must not be finalized yet.
        init_x: Default translation of the x joint.
        init_z: Default translation of the z joint.

    Returns:
        The model instance that holds the finger bodies.
    """
    radius = FINGER_RADIUS
    finger = plant.AddModelInstance("finger")

    # Add a massless false body that only carries the x translation, so that
    # the real finger body can be attached to it through a second (z) joint.
    false_body = plant.AddRigidBody("false_body", finger,
        pydrake.multibody.tree.SpatialInertia(0, [0,0,0], pydrake.multibody.tree.UnitInertia(0,0,0)))

    # Initialize finger body
    finger_body = plant.AddRigidBody("body", finger,
        pydrake.multibody.tree.SpatialInertia(
            mass=FINGER_MASS,
            p_PScm_E=np.array([0., 0., 0.]),
            G_SP_E=pydrake.multibody.tree.UnitInertia(1.0, 1.0, 1.0)))

    # Register geometry
    shape = pydrake.geometry.Sphere(radius)
    if plant.geometry_source_is_registered():
        plant.RegisterCollisionGeometry(
            finger_body, RigidTransform(), shape, "body", pydrake.multibody.plant.CoulombFriction(
                constants.FRICTION, constants.FRICTION))
        plant.RegisterVisualGeometry(finger_body, RigidTransform(), shape, "body", [.9, .5, .5, 1.0])

    # Use the frames of the bodies created above instead of looking them up by
    # name: a name-only lookup fails as soon as any other model in the plant
    # has a frame called "false_body" or "body".
    false_frame = false_body.body_frame()
    finger_frame = finger_body.body_frame()

    # Add control joints for x and z movement
    finger_x = plant.AddJoint(pydrake.multibody.tree.PrismaticJoint(
        "finger_x",
        plant.world_frame(),
        false_frame, [1, 0, 0], -1, 1))
    plant.AddJointActuator("finger_x", finger_x)
    finger_x.set_default_translation(init_x)
    finger_z = plant.AddJoint(pydrake.multibody.tree.PrismaticJoint(
        "finger_z",
        false_frame,
        finger_frame, [0, 0, 1], -1, 1))
    finger_z.set_default_translation(init_z)
    plant.AddJointActuator("finger_z", finger_z)

    return finger

In [None]:
class FingerControl(pydrake.systems.framework.LeafSystem):
    """PD controller that drives the finger along a sampled x/z trajectory.

    Input port "finger_state" (size 4) is unpacked as [x, z, xdot, zdot].
    Output port "finger_actuation" (size 2) is [fx, fz]: a PD term on the
    position and velocity errors plus a feed-forward term that cancels the
    x and z components of gravity acting on the finger mass.

    The reference sample is selected from the context time, so the output is
    a pure function of the context: it does not depend on how many times the
    output has been evaluated, and re-running a simulation from t = 0 replays
    the trajectory from its beginning.
    """

    def __init__(self, plant, traj, dt=DT):
        """Set up ports, gains and the reference trajectory.

        Args:
            plant: Plant whose gravity field is used for gravity compensation.
            traj: Tuple ``(xs, zs, xdots, zdots)`` of equally long sequences
                holding reference positions and velocities.
            dt: Time between two consecutive trajectory samples.

        Raises:
            ValueError: If ``dt`` is not positive, or the trajectory is empty
                or its four sequences differ in length.
        """
        pydrake.systems.framework.LeafSystem.__init__(self)
        if dt <= 0:
            raise ValueError("dt must be positive")
        self._plant = plant
        self._dt = dt

        self.DeclareVectorInputPort(
            "finger_state", pydrake.systems.framework.BasicVector(4))
        self.DeclareVectorOutputPort(
            "finger_actuation", pydrake.systems.framework.BasicVector(2),
                                    self.CalcOutput)

        self.xs, self.zs, self.xdots, self.zdots = traj
        n_samples = len(self.xs)
        if n_samples == 0:
            raise ValueError("traj must contain at least one sample")
        if not (len(self.zs) == len(self.xdots) == len(self.zdots) == n_samples):
            raise ValueError("traj sequences must all have the same length")

        self._plot_reference()

        # PD gains: k* act on position error, d* on velocity error.
        self.kx = 5
        self.kz = 5
        self.dx = 0.01
        self.dz = 0.01

    def _plot_reference(self):
        """Plot the reference signals per sample and the reference x-z path."""
        plt.figure()
        plt.plot(self.xs, label="x")
        plt.plot(self.zs, label="z")
        plt.plot(self.xdots, label="xdot")
        plt.plot(self.zdots, label="zdot")
        plt.xlabel("Sample index")
        plt.legend()
        plt.figure()
        plt.plot(self.xs, self.zs)
        plt.xlabel("x")
        plt.ylabel("z")

    def CalcOutput(self, context, output):
        """Write the actuation [fx, fz] for the current context into ``output``."""
        # Only the x and z components of gravity matter for the two joints.
        g = self._plant.gravity_field().gravity_vector()[[0,2]]
        x, z, xdot, zdot = self.get_input_port(0).Eval(context)

        # Pick the reference sample from simulation time rather than from a
        # call counter; round() absorbs floating-point drift in the time value.
        idx = int(round(context.get_time() / self._dt))
        if idx < len(self.xs):
            fx = self.kx*(self.xs[idx] - x) + self.dx*(self.xdots[idx] - xdot)
            fz = self.kz*(self.zs[idx] - z) + self.dz*(self.zdots[idx] - zdot)
        else:
            # Past the end of the trajectory: hold the last position at rest.
            fx = self.kx*(self.xs[-1] - x) + self.dx*(-xdot)
            fz = self.kz*(self.zs[-1] - z) + self.dz*(- zdot)
        output.SetFromVector(-FINGER_MASS*g + [fx, fz])

In [None]:
# Waypoints for the finger as [x, z] pairs, visited in order:
# start at the initial position, shift to 1.1*INIT_X, rise to 0.06 above the
# final height, move to x = 0, then descend to PEDESTAL_HEIGHT+FINGER_RADIUS/2.
pts = [
        [INIT_X, INIT_Z],
        [INIT_X*1.1, INIT_Z],
        [INIT_X*1.1, PEDESTAL_HEIGHT+FINGER_RADIUS/2+0.06],
        [0, PEDESTAL_HEIGHT+FINGER_RADIUS/2+0.06],
        [0, PEDESTAL_HEIGHT+FINGER_RADIUS/2],
    ]

def pts_to_traj(pts, tspan_per_segment, dt):
    """Turn a list of waypoints into a piecewise-linear sampled trajectory.

    Every pair of consecutive waypoints is linearly interpolated with samples
    spaced ``dt`` apart over ``tspan_per_segment``. The final waypoint is
    appended as the last sample so the trajectory ends exactly on it.
    Velocities are backward differences of the positions, with the first
    velocity set to 0.

    Args:
        pts: Sequence of ``[x, z]`` waypoints.
        tspan_per_segment: Duration allotted to each segment.
        dt: Time between consecutive samples.

    Returns:
        Tuple ``(xs, zs, xdots, zdots)`` of lists that all have the same
        length (all empty when ``pts`` is empty).

    Raises:
        ValueError: If ``dt`` or ``tspan_per_segment`` is not positive.
    """
    if dt <= 0:
        raise ValueError("dt must be positive")
    if tspan_per_segment <= 0:
        raise ValueError("tspan_per_segment must be positive")

    xs = []
    zs = []

    # The progress fractions (in [0, 1)) are identical for every segment, so
    # compute them once instead of once per segment.
    prog_fracs = np.arange(0, tspan_per_segment, dt)/tspan_per_segment

    for start_pt, end_pt in zip(pts[:-1], pts[1:]):
        for prog_frac in prog_fracs:
            xs.append(start_pt[0] + (end_pt[0] - start_pt[0])*prog_frac)
            zs.append(start_pt[1] + (end_pt[1] - start_pt[1])*prog_frac)

    # The fractions stop short of 1, so the last waypoint has to be added
    # explicitly; otherwise the trajectory would end one step before it.
    if len(pts) > 0:
        xs.append(float(pts[-1][0]))
        zs.append(float(pts[-1][1]))

    # Start at rest, then use backward differences. Seeding the lists only when
    # there is at least one sample keeps all four lists the same length.
    xdots = [0] if xs else []
    zdots = [0] if zs else []
    for i in range(len(xs)-1):
        xdots.append((xs[i+1]-xs[i])/dt)
        zdots.append((zs[i+1]-zs[i])/dt)

    return xs, zs, xdots, zdots

# Sample the waypoint path with 5 time units per segment and a sample spacing
# of DT. NOTE(review): time units are presumably seconds -- confirm.
traj = pts_to_traj(pts, 5, DT)

In [None]:
## Pre-finalize steps
# Everything in this cell adds elements to the plant and therefore runs before
# plant.Finalize().
builder = pydrake.systems.framework.DiagramBuilder()

# Add all elements
# The plant is created with time_step=DT, the same spacing used to sample traj.
plant, scene_graph = pydrake.multibody.plant.AddMultibodyPlantSceneGraph(builder, time_step=DT)
pedestal_instance = AddPedestal(plant)

# NOTE(review): the meaning of 20 is defined by paper.Paper (not visible here);
# presumably the number of paper links -- confirm.
paper = Paper(plant, 20)
paper.weld_paper_edge(PEDESTAL_WIDTH, PEDESTAL_HEIGHT)

# Add the finger model and create (but do not yet wire up) its controller.
finger_instance = AddPointFinger(plant)
finger_ctrlr = FingerControl(plant, traj)

In [None]:
plant.Finalize()

## Post finalize steps
paper.init_ctrlrs(builder)

# Connect finger controller: controller output -> finger actuation input, and
# finger state output -> controller input.
builder.AddSystem(finger_ctrlr)
builder.Connect(finger_ctrlr.get_output_port(), plant.get_actuation_input_port(finger_instance))
builder.Connect(plant.get_state_output_port(finger_instance), finger_ctrlr.get_input_port())

# Visualization and logging
poses_logger = LogOutput(plant.get_state_output_port(finger_instance), builder)
# Pass the URL of the meshcat server started earlier in this notebook.
# Without it the visualizer falls back to the default address, which is a
# different server whenever the subprocess had to pick a non-default port
# (for example after re-running the meshcat init cell).
vis = pydrake.systems.meshcat_visualizer.ConnectMeshcatVisualizer(
    builder, scene_graph, zmq_url=zmq_url)

# Build diagram and do actions requiring the built diagram and its context
diagram = builder.Build()
diagram_context = diagram.CreateDefaultContext()
paper.connect_ctrlrs(diagram, diagram_context)

In [None]:
# Finalize simulation and visualization
simulator = pydrake.systems.analysis.Simulator(diagram, diagram_context)
simulator.Initialize()
# Record the visualizer output for the whole run so it can be replayed.
vis.start_recording()
# 20 equals the reference trajectory duration: 4 segments at 5 time units each.
simulator.AdvanceTo(20)
vis.stop_recording()
vis.publish_recording()

In [None]:
# Logged finger state over time. The rows follow the state layout
# [x, z, xdot, zdot]: the second joint translates along the z axis, so rows 1
# and 3 are z position and z velocity.
plt.figure()
plt.plot(poses_logger.sample_times(), poses_logger.data()[0,:], 'b', label="x position")
plt.plot(poses_logger.sample_times(), poses_logger.data()[2,:], 'b--', label="x velocity")
plt.plot(poses_logger.sample_times(), poses_logger.data()[1,:], 'm', label="z position")
plt.plot(poses_logger.sample_times(), poses_logger.data()[3,:], 'm--', label="z velocity")
plt.legend()
plt.xlabel("Time (s)")
plt.show()

In [None]:
# Compare the x-z path the finger actually followed (logged state rows 0 and 1)
# with the reference path it was commanded to follow.
plt.figure()
plt.plot(poses_logger.data()[0,:], poses_logger.data()[1,:], label="actual")
plt.plot(traj[0], traj[1], label="reference")
plt.xlabel("x")
plt.ylabel("z")
plt.legend()
plt.show()

In [None]:
# NOTE(review): traj is the 4-tuple (xs, zs, xdots, zdots), so this always
# displays 4; len(traj[0]) would give the number of samples -- confirm intent.
len(traj)