In [1]:
import numpy as np
import pydot
from IPython.display import HTML, SVG, display
from pydrake.all import (
    AddMultibodyPlantSceneGraph,
    DiagramBuilder,
    GenerateHtml,
    InverseDynamicsController,
    PidController,
    MeshcatVisualizer,
    MeshcatVisualizerParams,
    MultibodyPlant,
    Parser,
    Simulator,
    StartMeshcat,
    SceneGraph,
    PassThrough,
    Demultiplexer,
    LeafSystem,
)

from catbot.utils.meshcat_util import MeshcatCatBotSliders 

In [2]:
# Start a Meshcat visualization server and keep a handle to it. Later cells
# attach the MeshcatVisualizer, the teleop sliders, and the stop button to it.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7001


In [8]:
# -- Add Reward (for testing RL Reward) -- #
class RewardSystem(LeafSystem):
    """Drake LeafSystem that scores the CatBot's arm orientation as an RL reward.

    Input port 0 ("catbot_state"): the plant's 10-element state vector. The
    positions are ordered Center, a_hinge, a_rot, b_hinge, b_rot (presumably
    followed by the matching velocities -- TODO confirm against the plant's
    state names).

    Output port 0 ("reward"): a 1-element vector equal to
    ``reward_offset - cost``, where ``cost`` is a position-only cost.
    """

    # Indices into the state vector, following the ordering described above.
    _CENTER = 0
    _A_ROT = 2
    _B_ROT = 4

    def __init__(self, reward_offset=20.0, verbose=True):
        """Declare the state input port and the reward output port.

        Args:
            reward_offset: Constant the cost is subtracted from. The default of
                20 exceeds the largest possible cost (2 * pi**2 ~= 19.74), so the
                reward stays positive.
            verbose: If True, print the intermediate values on every
                evaluation (the original debugging behavior).
        """
        LeafSystem.__init__(self)
        self._reward_offset = reward_offset
        self._verbose = verbose
        self.DeclareVectorInputPort("catbot_state", 10)
        self.DeclareVectorOutputPort("reward", 1, self.CalcReward)

    @staticmethod
    def _angle_from_pi(angle):
        """Wrap ``angle`` into [-pi, pi), measured relative to pi.

        ``angle % (2*pi)`` lies in [0, 2pi). Subtracting pi shifts that range to
        [-pi, pi), so the result is -pi for an angle of 0 and 0 for an angle of pi.
        """
        return angle % (2 * np.pi) - np.pi

    def CalcReward(self, context, output):
        """Write ``reward_offset - cost`` for the current state into ``output``.

        Each arm's world angle is taken as Center + that arm's rot joint. The
        cost is the sum of the squared wrapped angles from ``_angle_from_pi``.
        No effort or velocity cost is included.
        """
        catbot_state = self.get_input_port(0).Eval(context)
        center = catbot_state[self._CENTER]
        a_rot = catbot_state[self._A_ROT]
        b_rot = catbot_state[self._B_ROT]

        # NOTE(review): this does not clamp to [0, 2pi) as previously commented.
        # The cost is zero when a world angle equals pi (mod 2pi). Confirm that
        # pi is the intended target orientation.
        a_hinge_world = self._angle_from_pi(center + a_rot)
        b_hinge_world = self._angle_from_pi(center + b_rot)

        # Position cost only; effort cost is intentionally left out.
        cost = a_hinge_world ** 2 + b_hinge_world ** 2

        if self._verbose:
            print('center: ', center)
            print('a_rot: ', a_rot)
            print('b_rot: ', b_rot)
            print('a_world: ', a_hinge_world)
            print('b_world: ', b_hinge_world)
            print('cost: ', cost)

        output[0] = self._reward_offset - cost

In [10]:
# Discrete time step (seconds) for both MultibodyPlants.
time_step = 1e-3
# Both the simulated plant and the controller plant load this model.
MODEL_PATH = "../models/singleAxisCatBot.urdf"

builder = DiagramBuilder()

# -- Add original plant -- #
# Adds both MultibodyPlant and the SceneGraph, and wires them together.
plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=time_step)
# Note that we parse into both the plant and the scene_graph here.
model = Parser(plant, scene_graph).AddModelFromFile(MODEL_PATH)

# Disable gravity in the simulated plant.
gravity_field = plant.mutable_gravity_field()
gravity_field.set_gravity_vector(np.array([0.0, 0.0, 0.0]))

plant.Finalize()

# -- Add controller plant -- #
# A standalone copy of the model (not added to the diagram); below it is only
# queried for its actuator count.
controller_plant = MultibodyPlant(time_step=time_step)
# Use a separate name so `model` keeps referring to the instance inside `plant`
# (it is passed to plant.get_state_output_port(model) below).
controller_model = Parser(controller_plant).AddModelFromFile(MODEL_PATH)
controller_gravity_field = controller_plant.mutable_gravity_field()
controller_gravity_field.set_gravity_vector(np.array([0.0, 0.0, 0.0]))
controller_plant.Finalize()

# -- Add visualizer -- #
visualizer = MeshcatVisualizer.AddToBuilder(builder,
                                            scene_graph.get_query_output_port(),
                                            meshcat)

meshcat.ResetRenderMode()
meshcat.DeleteAddedControls()

# -- Add controller -- #
# A single PidController drives all actuated joints.
num_model_actuators = controller_plant.num_actuators()

# Actuator order is a_rev, b_rev, a_hinge, b_hinge, but the plant's positions are
# ordered Center, a_hinge, a_rot, b_hinge, b_rot. (Presumably actuator indices
# follow the URDF's actuator/transmission declarations while positions follow
# the joint tree -- TODO confirm.) The projection therefore reorders the state:
# rows 0-3 select the actuated positions and rows 4-7 the matching velocities,
# both in actuator order.
state_projection_matrix = np.array([
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],

    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
])

# Per-actuator gains, in actuator order.
kp = [0.03, 0.03, 0.03, 0.03]
ki = [0.00, 0.00, 0.0, 0.0]
kd = [0.025, 0.025, 0.04, 0.04]

# Fail fast with a readable message instead of an opaque size error from
# PidController or builder.Connect if the model and these tables drift apart.
assert state_projection_matrix.shape == (
    2 * num_model_actuators, plant.num_multibody_states()
), "state_projection_matrix does not match the plant's actuators/state size"
assert len(kp) == len(ki) == len(kd) == num_model_actuators, \
    "PID gain lists must have one entry per actuator"

catbot_controller = builder.AddSystem(
    PidController(state_projection_matrix, kp, ki, kd)
)

catbot_controller.set_name("catbot_controller")
builder.Connect(
    plant.get_state_output_port(model),
    catbot_controller.get_input_port_estimated_state(),
)
builder.Connect(
    catbot_controller.get_output_port_control(),
    plant.get_actuation_input_port()
)

# -- Set up teleop widgets -- #
teleop = builder.AddSystem(
    MeshcatCatBotSliders(
        meshcat,
    ))

builder.Connect(teleop.get_output_port(0), catbot_controller.get_input_port_desired_state())
builder.Connect(plant.get_state_output_port(), teleop.get_input_port(0))

# -- Add reward system -- #
reward = builder.AddSystem(RewardSystem())
builder.Connect(plant.get_state_output_port(), reward.get_input_port(0))
builder.Connect(reward.get_output_port(0), teleop.get_input_port(1))


diagram = builder.Build()
simulator = Simulator(diagram)
context = simulator.get_mutable_context()

simulator.set_target_realtime_rate(1.0)
meshcat.AddButton("Stop Simulation", "Escape")
print("Press Escape to stop the simulation")

# try/finally removes the Meshcat button even if the run is interrupted
# (e.g. a kernel interrupt) or raises.
try:
    # -- Set init pose -- #
    # state is Center, a_hinge, a_rot, b_hinge, b_rot
    q0 = [0.0, 0.00, 0.00, 0.0, 0.0]
    plant_context = plant.GetMyMutableContextFromRoot(context)
    plant.SetPositions(plant_context, q0)
    simulator.AdvanceTo(0.01)

    # -- Simulate until "Stop Simulation" is clicked -- #
    # The button is only polled between 1-second chunks of simulated time.
    while meshcat.GetButtonClicks("Stop Simulation") < 1:
        simulator.AdvanceTo(simulator.get_context().get_time() + 1.0)
finally:
    meshcat.DeleteButton("Stop Simulation")

Keyboard Controls:
a_rev : KeyQ / KeyE
b_rev : KeyW / KeyS
a_hinge : KeyA / KeyD
b_hinge : KeyJ / KeyL
Press Escape to stop the simulation
Pos:  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
center:  0.0
a_rot:  0.0
b_rot:  0.0
a_world:  -3.141592653589793
b_world:  -3.141592653589793
cost:  19.739208802178716
Reward: [0.2607912]
center:  -0.3346490419657037
a_rot:  0.6911502825247408
b_rot:  0.0004577375377117317
a_world:  -2.7850914130307562
b_world:  2.8074013491618013
cost:  15.638236514213157
Reward: [4.36176349]
center:  -0.03910353488668511
a_rot:  1.1495120777652486
b_rot:  -1.0686981056661178
a_world:  -2.0311841107112296
b_world:  2.03379101303699
cost:  8.262014776315795
Reward: [11.73798522]
center:  1.0322404512453422
a_rot:  -1.0636333327161578
b_rot:  -1.0696635998575303
a_world:  3.1101997721189774
b_world:  3.1041695049776052
cost:  19.30921093812185
Reward: [0.69078906]
center:  1.0353024118907264
a_rot:  -1.0696551180456129
b_rot:  -1.0699709637327877
a_world:  3.1072399474349064
