In [None]:
import numpy as np
import pydot
from pydrake.all import (
    DiagramBuilder,
    MultibodyPlant,
    Parser,
    Propeller,
    PropellerInfo,
    RigidTransform,
    StartMeshcat,
    MeshcatVisualizer,
    SceneGraph,
    Simulator,
    AddMultibodyPlantSceneGraph,
    LeafSystem,
    LeafSystem_,
    ExternallyAppliedSpatialForce,
    ExternallyAppliedSpatialForce_,
    TemplateSystem,
    AbstractValue
)
from pydrake.examples import (
    QuadrotorGeometry
)
from IPython.display import display, SVG, Image

from underactuated.scenarios import AddFloatingRpyJoint

In [None]:
# Start the Meshcat visualizer. Run this cell only once per kernel session:
# every StartMeshcat() call creates a new instance, and each instance
# consumes a port.
meshcat = StartMeshcat()

In [None]:
class TensileForces(LeafSystem):
    """Placeholder system: it declares no ports, state, or behavior yet.

    TODO: implement. Judging by the name, this is presumably intended to
    compute tensile forces to apply to the plant -- confirm the intent
    before relying on it.
    """

# Thanks David!
# https://stackoverflow.com/a/72121171/9796174
@TemplateSystem.define("SpatialForceConcatinator_")
def SpatialForceConcatinator_(T):
    """Template for a system that merges several spatial-force lists into one.

    For scalar type ``T`` the generated system declares ``N_inputs`` abstract
    input ports, each carrying a list of ``ExternallyAppliedSpatialForce_[T]``,
    and a single abstract output port holding the concatenation of all input
    lists (in port order).
    """

    class Impl(LeafSystem_[T]):
        def _construct(self, N_inputs, converter=None):
            """Declare ``N_inputs`` list-valued inputs and one list-valued output.

            Args:
                N_inputs: Number of abstract input ports to declare.
                converter: Scalar converter forwarded to ``LeafSystem_[T]``.
            """
            LeafSystem_[T].__init__(self, converter)
            self.N_inputs = N_inputs

            # One abstract input per source; the model value is a one-element
            # list of spatial forces.
            self.Input_ports = []
            for index in range(N_inputs):
                model_value = AbstractValue.Make(
                    [ExternallyAppliedSpatialForce_[T]()]
                )
                self.Input_ports.append(
                    self.DeclareAbstractInputPort(
                        f"Spatial_Force_{index}", model_value
                    )
                )

            def allocate_output():
                # Allocator for the output: a list with one default-constructed
                # force per input port.
                return AbstractValue.Make(
                    [ExternallyAppliedSpatialForce_[T]() for _ in range(N_inputs)]
                )

            self.Output_port = self.DeclareAbstractOutputPort(
                "Spatial_Forces", allocate_output, self.Concatenate
            )

        def Concatenate(self, context, output):
            """Write the concatenation of every input port's list to ``output``.

            Every input port is evaluated, so each one must be connected or
            fixed in ``context``.
            """
            merged = []
            for port in self.Input_ports:
                merged.extend(port.Eval(context))
            output.set_value(merged)

        def _construct_copy(self, other, converter=None):
            """Scalar-conversion constructor: rebuild with ``other``'s port count."""
            Impl._construct(self, other.N_inputs, converter=converter)

    return Impl

# Default instantiation: indexing the template with None selects its default
# scalar type.
SpatialForceConcatinator = SpatialForceConcatinator_[None]

In [None]:
def make_n_quadrotor_system(n):
    """Build a diagram with ``n`` quadrotors actuated through a Propeller system.

    Each quadrotor is loaded from Drake's quadrotor URDF and given a floating
    roll-pitch-yaw joint. A single ``Propeller`` system (four rotors per
    quadrotor) feeds input 0 of a two-input ``SpatialForceConcatinator``, whose
    output drives the plant's applied-spatial-force input. Input 1 of the
    combiner is NOT connected here; callers must connect or fix it before the
    diagram can be simulated. The propeller command port is exported as "u".

    Relies on the module-level ``meshcat`` instance for visualization.

    Args:
        n: Number of quadrotor models to add to the plant.

    Returns:
        A ``(diagram, plant)`` tuple: the built diagram and the finalized
        MultibodyPlant contained in it.
    """
    builder = DiagramBuilder()
    # The MultibodyPlant handles f=ma, but doesn't know about propellers.
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)

    parser = Parser(plant)
    parser.SetAutoRenaming(True)

    def add_quadrotor():
        """Load one quadrotor model, attach its floating joint, return its instance."""
        (instance,) = parser.AddModelsFromUrl(
            "package://drake/examples/quadrotor/quadrotor.urdf"
        )
        # By default the multibody has a quaternion floating base.  To match
        # QuadrotorPlant, we can manually add a FloatingRollPitchYaw joint. We
        # set `use_ball_rpy` to false because the BallRpyJoint uses angular
        # velocities instead of ṙ, ṗ, ẏ.
        base_frame = plant.GetFrameByName("base_link", instance)
        AddFloatingRpyJoint(plant, base_frame, instance, use_ball_rpy=False)
        return instance

    quadrotor_model_instances = [add_quadrotor() for _ in range(n)]

    plant.Finalize()

    # Default parameters from quadrotor_plant.cc:
    L = 0.15  # Length of the arms (m).
    kF = 1.0  # Force input constant.
    kM = 0.0245  # Moment input constant.

    # Rotor positions in the base frame paired with their moment constants.
    # Note: Rotors 0 and 2 rotate one way and rotors 1 and 3 rotate the other.
    rotor_layout = (
        ([L, 0, 0], kM),
        ([0, L, 0], -kM),
        ([-L, 0, 0], kM),
        ([0, -L, 0], -kM),
    )

    # Now we can add in propellers as an external force on the MultibodyPlant.
    base_indices = [
        plant.GetBodyByName("base_link", instance).index()
        for instance in quadrotor_model_instances
    ]
    prop_info = [
        PropellerInfo(base_index, RigidTransform(offset), kF, moment_constant)
        for base_index in base_indices
        for offset, moment_constant in rotor_layout
    ]

    propellers = builder.AddNamedSystem("propeller", Propeller(prop_info))
    combiner = builder.AddNamedSystem("combiner", SpatialForceConcatinator(2))

    # propellers -> combiner[0] -> plant applied forces; plant poses -> propellers.
    builder.Connect(propellers.get_output_port(), combiner.Input_ports[0])
    builder.Connect(
        combiner.Output_port, plant.get_applied_spatial_force_input_port()
    )
    builder.Connect(
        plant.get_body_poses_output_port(),
        propellers.get_body_poses_input_port(),
    )
    builder.ExportInput(propellers.get_command_input_port(), "u")

    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    diagram = builder.Build()
    return diagram, plant

In [None]:
# Build a five-quadrotor diagram; keep the plant handle for the cells below.
diagram, plant = make_n_quadrotor_system(5)

In [None]:
# Render the plant's topology graph inline: Graphviz text -> pydot -> PNG.
# graph_from_dot_data returns a list of graphs; the first one is used.
display(
    Image(
        pydot.graph_from_dot_data(plant.GetTopologyGraphvizString())[0].create_png()
    )
)

In [None]:
# Render the full system diagram inline: Graphviz text -> pydot -> PNG.
# graph_from_dot_data returns a list of graphs; the first one is used.
display(
    Image(
        pydot.graph_from_dot_data(diagram.GetGraphvizString())[
            0
        ].create_png()
    )
)

In [None]:
def CreateNullExternalForce(plant):
    """Return a placeholder external force attached to ``plant``'s world body.

    Only ``body_index`` is set; the force and application point keep their
    default-constructed values (presumably zero -- confirm against the Drake
    documentation for ExternallyAppliedSpatialForce).

    Args:
        plant: Plant whose world body the force is attached to.

    Returns:
        An ``ExternallyAppliedSpatialForce`` targeting the world body.
    """
    placeholder_force = ExternallyAppliedSpatialForce()
    placeholder_force.body_index = plant.world_body().index()
    return placeholder_force

In [None]:
# Roll the system out from random initial states with zero propeller command.
NUM_ROLLOUTS = 5  # Bounded so that "Restart & Run All" terminates.
ROLLOUT_DURATION = 1.5  # Simulated seconds per rollout.
SEED = 0  # Fixed seed so the random initial states are reproducible.
rng = np.random.default_rng(SEED)

simulator = Simulator(diagram)
simulator.set_target_realtime_rate(0.5)
context = simulator.get_mutable_context()

# Command zero on every propeller.
u = diagram.GetInputPort("u")
u.FixValue(context, np.zeros(u.size()))

# Input 1 of the combiner has no upstream system, so fix it to a single
# placeholder force on the world body; the combiner evaluates every input
# port when computing its output.
combiner_system = diagram.GetSubsystemByName("combiner")
combiner_empty_port = combiner_system.Input_ports[1]
combiner_empty_port.FixValue(
    combiner_system.GetMyContextFromRoot(context),
    [CreateNullExternalForce(plant)],
)

# Simulate
for _ in range(NUM_ROLLOUTS):
    context.SetTime(0.0)
    # Draw a fresh random initial state (std 0.5 per coordinate).
    context.SetContinuousState(
        0.5 * rng.standard_normal(context.num_continuous_states())
    )
    simulator.Initialize()
    simulator.AdvanceTo(ROLLOUT_DURATION)