In [None]:
import numpy as np
import pydot
from pydrake.all import (
    DiagramBuilder,
    MultibodyPlant,
    Parser,
    Propeller,
    PropellerInfo,
    RigidTransform,
    StartMeshcat,
    MeshcatVisualizer,
    SceneGraph,
    Simulator
)
from pydrake.examples import (
    QuadrotorGeometry
)
from IPython.display import display, SVG

from underactuated.scenarios import AddFloatingRpyJoint

In [None]:
# Start the Meshcat visualizer. Run this cell only once: each instance consumes a
# port. The resulting `meshcat` object is referenced by later cells when the
# diagram is built.
meshcat = StartMeshcat()

In [None]:
def make_n_quadrotor_system(n):
    """Assemble a diagram that simulates and visualizes `n` quadrotors.

    Every quadrotor is parsed from Drake's example URDF into one shared
    MultibodyPlant and attached to the world through a roll-pitch-yaw floating
    joint. A single Propeller system supplies four rotor forces per vehicle,
    and its command input is exported from the diagram under the name "u".

    Note: the visualizer is attached to the notebook-level `meshcat` object,
    so that variable must already exist when this function is called.

    Args:
        n: Number of quadrotors to add to the plant.

    Returns:
        A `(diagram, plant)` tuple holding the built diagram and the finalized
        MultibodyPlant inside it.
    """
    quadrotor_url = "package://drake/examples/quadrotor/quadrotor.urdf"

    builder = DiagramBuilder()
    # MultibodyPlant takes care of the rigid-body dynamics (f=ma); propellers
    # are not part of it and are added as a separate system further down.
    plant = builder.AddSystem(MultibodyPlant(0.0))
    urdf_parser = Parser(plant)
    # The same URDF is loaded repeatedly, so let the parser rename duplicates.
    urdf_parser.SetAutoRenaming(True)

    instances = []
    for _ in range(n):
        # Unpacking into a one-element pattern insists on exactly one model.
        [instance] = urdf_parser.AddModelsFromUrl(quadrotor_url)
        # A multibody gets a quaternion floating base by default.  To match
        # QuadrotorPlant we add a FloatingRollPitchYaw joint by hand, with
        # `use_ball_rpy` off because the BallRpyJoint uses angular velocities
        # rather than ṙ, ṗ, ẏ.
        AddFloatingRpyJoint(
            plant,
            plant.GetFrameByName("base_link", instance),
            instance,
            use_ball_rpy=False,
        )
        instances.append(instance)

    plant.Finalize()

    # Default parameters taken from quadrotor_plant.cc.
    arm_length = 0.15  # Length of the arms (m).
    force_constant = 1.0  # Force input constant.
    moment_constant = 0.0245  # Moment input constant.

    # One (rotor position in the body frame, signed moment constant) entry per
    # rotor.  Rotors 0 and 2 rotate one way and rotors 1 and 3 the other.
    rotor_layout = (
        ([arm_length, 0, 0], moment_constant),
        ([0, arm_length, 0], -moment_constant),
        ([-arm_length, 0, 0], moment_constant),
        ([0, -arm_length, 0], -moment_constant),
    )

    # The propellers enter as an external spatial force on the MultibodyPlant.
    base_indices = (
        plant.GetBodyByName("base_link", instance).index()
        for instance in instances
    )
    prop_info = [
        PropellerInfo(
            base_index, RigidTransform(offset), force_constant, signed_moment
        )
        for base_index in base_indices
        for offset, signed_moment in rotor_layout
    ]

    propellers = builder.AddNamedSystem("propeller", Propeller(prop_info))
    builder.Connect(
        propellers.get_output_port(),
        plant.get_applied_spatial_force_input_port(),
    )
    builder.Connect(
        plant.get_body_poses_output_port(),
        propellers.get_body_poses_input_port(),
    )
    builder.ExportInput(propellers.get_command_input_port(), "u")

    # Visualization: one QuadrotorGeometry per vehicle, driven by that model
    # instance's state output, all published through Meshcat.
    scene_graph = builder.AddSystem(SceneGraph())
    for instance in instances:
        QuadrotorGeometry.AddToBuilder(
            builder, plant.get_state_output_port(instance), scene_graph
        )
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    diagram = builder.Build()
    return diagram, plant

In [None]:
# Build a system with five quadrotors. The plant handle is kept alongside the
# diagram so that the cells below can inspect it directly.
diagram, plant = make_n_quadrotor_system(5)

In [None]:
# Draw the plant's topology graph inline as an SVG.
plant_topology_dot = plant.GetTopologyGraphvizString()
display(SVG(pydot.graph_from_dot_data(plant_topology_dot)[0].create_svg()))

In [None]:
# Draw the full system diagram inline as an SVG.
diagram_dot = diagram.GetGraphvizString()
display(SVG(pydot.graph_from_dot_data(diagram_dot)[0].create_svg()))

In [None]:
# Show the size of the exported command input "u".
# NOTE(review): presumably one entry per propeller (4 per quadrotor) — confirm.
diagram.GetInputPort("u").size()

In [None]:
# Roll out the system with zero propeller command from several random initial
# states, playing each rollout back at real-time rate.
SEED = 0  # Fixed seed so that Restart & Run All reproduces the same rollouts.
NUM_ROLLOUTS = 5
ROLLOUT_DURATION = 4.0  # Simulated time advanced per rollout.
INITIAL_STATE_SCALE = 0.5  # Scale applied to the standard-normal state draw.

# A local, seeded generator keeps the draws independent of the global NumPy
# random state that other cells may have touched.
rng = np.random.default_rng(SEED)

simulator = Simulator(diagram)
simulator.set_target_realtime_rate(1.0)
context = simulator.get_mutable_context()

# Hold the exported command input at zero for every rollout.
u = diagram.GetInputPort("u")
u.FixValue(context, np.zeros(u.size()))

# Simulate
# (The loop variable is deliberately not `_`, which IPython reserves for the
# most recent cell output.)
for rollout in range(NUM_ROLLOUTS):
    # Rewind the clock and draw a fresh random continuous state.
    context.SetTime(0.0)
    context.SetContinuousState(
        INITIAL_STATE_SCALE
        * rng.standard_normal(context.num_continuous_states())
    )
    # Re-initialize after editing the context, then run the rollout.
    simulator.Initialize()
    simulator.AdvanceTo(ROLLOUT_DURATION)
