In [None]:
import numpy as np
import pydot
from pydrake.all import (
    DiagramBuilder,
    MultibodyPlant,
    Parser,
    Propeller,
    PropellerInfo,
    RigidTransform,
    StartMeshcat,
    MeshcatVisualizer,
    SceneGraph,
    Simulator,
    AddMultibodyPlantSceneGraph,
    LeafSystem,
    LeafSystem_,
    ExternallyAppliedSpatialForce,
    ExternallyAppliedSpatialForce_,
    TemplateSystem,
    AbstractValue,
    SpatialForce,
    SpatialForce_,
    CollisionFilterDeclaration,
    GeometrySet,
    LinearQuadraticRegulator
)

from pydrake.examples import (
    QuadrotorGeometry
)
from IPython.display import display, SVG, Image

from underactuated.scenarios import AddFloatingRpyJoint
from underactuated.utils import MakeNamedViewPositions

In [None]:
# Start the Meshcat visualizer once per kernel. Every StartMeshcat() call
# launches a new server on its own port, so guard the call: re-running this
# cell reuses the existing instance instead of leaking another server.
if "meshcat" not in globals():
    meshcat = StartMeshcat()

In [None]:
@TemplateSystem.define("TensileForce_")
def TensileForce_(T):
    """Scalar-templated LeafSystem: a spring pulling one body toward a fixed world anchor.

    Input  "state_input"  -- a 12-vector; entries 0:3 are read as the body's
                             world position.
    Output "force_output" -- a one-element list of ExternallyAppliedSpatialForce
                             acting at the body origin.

    The force magnitude is hooke_K * (distance - length). It is directed toward
    the anchor, so the system also pushes outward when the body is closer than
    the rest length: a two-sided spring, not a slack-capable rope.
    """

    class Impl(LeafSystem_[T]):
        def _construct(self, length, hooke_K, anchor_point, body_index, converter=None):
            LeafSystem_[T].__init__(self, converter)
            # Keep every constructor argument so _construct_copy can rebuild
            # this system for another scalar type (e.g. AutoDiffXd for LQR).
            self.length = length  # rest length in meters, expected > 0
            self.hooke_K = hooke_K  # Hooke's-law spring constant, expected > 0
            self.anchor_point = anchor_point  # 3D attachment point in the world frame
            self.body_index = body_index  # index (in the plant) of the body being pulled

            self.state_input = self.DeclareVectorInputPort("state_input", size=12)

            def allocate_forces():
                return AbstractValue.Make([ExternallyAppliedSpatialForce_[T]()])

            self.force_output = self.DeclareAbstractOutputPort(
                "force_output", alloc=allocate_forces, calc=self.OutputForce
            )

        def OutputForce(self, context, output):
            position = self.state_input.Eval(context)[0:3]
            to_anchor = self.anchor_point - position
            distance = np.linalg.norm(to_anchor)

            # Positive (pulling toward the anchor) when stretched past the rest length.
            magnitude = self.hooke_K * (distance - self.length)
            # The 1e-6 keeps the division finite when the body sits on the anchor.
            force_W = magnitude * (to_anchor / (distance + 1e-6))

            applied = ExternallyAppliedSpatialForce_[T]()
            applied.body_index = self.body_index
            # Zero torque plus the spring force, both expressed in the world frame.
            applied.F_Bq_W = SpatialForce_[T](np.zeros((3, 1)), force_W.reshape(-1, 1))
            # The force acts at the body origin.
            applied.p_BoBq_B = np.zeros(3)
            output.set_value([applied])

        def _construct_copy(self, other, converter=None):
            Impl._construct(
                self,
                other.length,
                other.hooke_K,
                other.anchor_point,
                other.body_index,
                converter=converter,
            )

    return Impl


# Thanks David!
# https://stackoverflow.com/a/72121171/9796174
@TemplateSystem.define("SpatialForceConcatinator_")
def SpatialForceConcatinator_(T):
    """Scalar-templated LeafSystem that merges N force lists into one.

    Declares N_inputs abstract input ports named "Spatial_Force_<i>", each
    carrying a list of ExternallyAppliedSpatialForce. The single output port
    "Spatial_Forces" is the concatenation of those lists in port order.
    """

    class Impl(LeafSystem_[T]):
        def _construct(self, N_inputs, converter=None):
            LeafSystem_[T].__init__(self, converter)
            # Stored so _construct_copy can rebuild for another scalar type.
            self.N_inputs = N_inputs

            self.Input_ports = []
            for index in range(N_inputs):
                model_value = AbstractValue.Make([ExternallyAppliedSpatialForce_[T]()])
                self.Input_ports.append(
                    self.DeclareAbstractInputPort(f"Spatial_Force_{index}", model_value)
                )

            def allocate_output():
                return AbstractValue.Make(
                    [ExternallyAppliedSpatialForce_[T]() for _ in range(N_inputs)]
                )

            self.Output_port = self.DeclareAbstractOutputPort(
                "Spatial_Forces", allocate_output, self.Concatenate
            )

        def Concatenate(self, context, output):
            # Flatten every input list, preserving input-port order.
            merged = [force for port in self.Input_ports for force in port.Eval(context)]
            output.set_value(merged)

        def _construct_copy(self, other, converter=None):
            Impl._construct(self, other.N_inputs, converter=converter)

    return Impl


# Default instantiations: indexing a TemplateSystem with [None] selects its
# default scalar type, giving plain constructors used below as
# TensileForce(...) and SpatialForceConcatinator(...).
TensileForce = TensileForce_[None]
SpatialForceConcatinator = SpatialForceConcatinator_[None]

In [None]:
def make_n_quadrotor_system(n, rope_length=1.0, spring_constant=10.0, anchor_point=None):
    """Build a diagram of ``n`` quadrotors, each tethered to a common anchor.

    Each quadrotor is loaded from Drake's quadrotor URDF and given a floating
    x/y/z + roll/pitch/yaw joint via ``AddFloatingRpyJoint``, so its state has
    6 positions and 6 velocities. Propeller thrust and one ``TensileForce`` per
    quadrotor are concatenated and fed to the plant's applied-spatial-force
    input. A Meshcat visualizer is attached using the notebook-global
    ``meshcat`` created by ``StartMeshcat()``.

    Args:
        n: number of quadrotors.
        rope_length: rest length of each tether in meters. The default 1.0
            matches the unit-radius circle produced by ``default_fixed_point``.
        spring_constant: Hooke's-law constant of each tether.
        anchor_point: world-frame 3-vector where every tether attaches.
            Defaults to the origin.

    Returns:
        ``(diagram, plant)``: the built Diagram and its MultibodyPlant.
        The diagram has input port "u", which takes 4 rotor commands per
        quadrotor. It has output port "q", which carries the plant's full
        state (positions and velocities), despite its name.
    """
    builder = DiagramBuilder()
    # The MultibodyPlant handles f=ma, but doesn't know about propellers.
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
    parser = Parser(plant)
    # Auto-renaming lets the same URDF be loaded n times without name clashes.
    parser.SetAutoRenaming(True)
    quadrotor_model_instances = []
    for _ in range(n):
        (model_instance,) = parser.AddModelsFromUrl(
            "package://drake/examples/quadrotor/quadrotor.urdf"
        )
        quadrotor_model_instances.append(model_instance)
        # By default the multibody has a quaternion floating base.  To match
        # QuadrotorPlant, we can manually add a FloatingRollPitchYaw joint. We set
        # `use_ball_rpy` to false because the BallRpyJoint uses angular velocities
        # instead of ṙ, ṗ, ẏ.
        AddFloatingRpyJoint(
            plant,
            plant.GetFrameByName("base_link", model_instance),
            model_instance,
            use_ball_rpy=False,
        )

    plant.Finalize()

    # Default parameters from quadrotor_plant.cc:
    L = 0.15  # Length of the arms (m).
    kF = 1.0  # Force input constant.
    kM = 0.0245  # Moment input constant.

    if anchor_point is None:
        anchor_point = np.zeros(3)

    # Now we can add in propellers as an external force on the MultibodyPlant.
    prop_info = []
    tensile_forces = []
    for model_instance in quadrotor_model_instances:
        body_index = plant.GetBodyByName("base_link", model_instance).index()
        # Note: Rotors 0 and 2 rotate one way and rotors 1 and 3 rotate the other.
        prop_info += [
            PropellerInfo(body_index, RigidTransform([L, 0, 0]), kF, kM),
            PropellerInfo(body_index, RigidTransform([0, L, 0]), kF, -kM),
            PropellerInfo(body_index, RigidTransform([-L, 0, 0]), kF, kM),
            PropellerInfo(body_index, RigidTransform([0, -L, 0]), kF, -kM),
        ]

        # Give each tether its own copy of the anchor so no two systems share
        # (and could accidentally mutate) the same array.
        tensile_force = builder.AddSystem(
            TensileForce(
                rope_length,
                spring_constant,
                np.array(anchor_point, dtype=float),
                body_index,
            )
        )
        builder.Connect(
            plant.get_state_output_port(model_instance),
            tensile_force.state_input
        )
        tensile_forces.append(tensile_force)

    propellers = builder.AddNamedSystem("propeller", Propeller(prop_info))

    # Stacks the forces from the propellers (port 0) and all tethers (port 1).
    combiner = builder.AddNamedSystem("combiner", SpatialForceConcatinator(2))
    builder.Connect(
        propellers.get_output_port(),
        combiner.Input_ports[0]
    )

    builder.Connect(
        combiner.Output_port,
        plant.get_applied_spatial_force_input_port()
    )

    # Merges the n per-quadrotor tether forces into a single list.
    tensile_combiner = builder.AddNamedSystem("tensile_combiner", SpatialForceConcatinator(n))
    for tensile_force, combiner_port in zip(tensile_forces, tensile_combiner.Input_ports):
        builder.Connect(
            tensile_force.force_output,
            combiner_port
        )
    builder.Connect(
        tensile_combiner.Output_port,
        combiner.Input_ports[1]
    )

    builder.Connect(
        plant.get_body_poses_output_port(),
        propellers.get_body_poses_input_port(),
    )

    builder.ExportInput(propellers.get_command_input_port(), "u")
    # Exported as "q" although it carries the plant's full state (q and v).
    builder.ExportOutput(plant.get_state_output_port(), "q")

    # Relies on the notebook-global `meshcat` from the StartMeshcat() cell.
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    return builder.Build(), plant

In [None]:
def default_fixed_point(n, radius=1.0):
    """Return the nominal hover state for ``n`` quadrotors spaced evenly on a circle.

    Quadrotor ``i`` sits at angle ``2*pi*i/n`` on a circle of ``radius`` in the
    z = 0 plane. Its three rotation coordinates and all velocities are zero.

    The vector is laid out as ``[q; v]``:
      * first, 6 position coordinates per quadrotor in order (x, y, z, then the
        three rotation angles);
      * then, 6 * n zero velocities.

    Args:
        n: number of quadrotors (>= 0).
        radius: circle radius in meters. The default 1.0 matches the tether
            rest length used when building the quadrotor system.

    Returns:
        1-D float array of length 12 * n.
    """
    # n equally spaced angles in [0, 2*pi); drop the duplicate endpoint 2*pi.
    angles = np.linspace(0.0, 2 * np.pi, num=n + 1)[:-1]

    positions = np.zeros((n, 6))
    positions[:, 0] = radius * np.cos(angles)
    positions[:, 1] = radius * np.sin(angles)

    velocities = np.zeros(6 * n)
    return np.concatenate([positions.ravel(), velocities])

In [None]:
def full_system_lqr(n, mass, gravity, system_diagram):
    """Design an LQR controller stabilizing the tethered n-quadrotor formation.

    The system is linearized about ``default_fixed_point(n)``: all quadrotors
    at rest, spaced on a unit circle in the z = 0 plane. At that point every
    rotor commands ``mass * gravity / 4``, so the 4 rotors of each quadrotor
    share its weight (the propellers use force constant kF = 1).

    NOTE(review): this is a true equilibrium only if the tethers exert no
    force there (distance equals the tether rest length). Confirm this when
    changing the tether length or formation radius.

    Args:
        n: number of quadrotors in ``system_diagram``.
        mass: mass of one quadrotor (kg).
        gravity: gravitational acceleration (m/s^2).
        system_diagram: open-loop system. Input port 0 takes the 4 * n rotor
            commands; its continuous state is the MultibodyPlant state.

    Returns:
        The controller system returned by ``LinearQuadraticRegulator``.
    """
    context = system_diagram.CreateDefaultContext()
    context.SetContinuousState(default_fixed_point(n))

    # Fix the nominal input on the system being linearized. Previously this
    # used the notebook-global `diagram`, which only worked by coincidence.
    system_diagram.get_input_port(0).FixValue(
        context, mass * gravity / 4.0 * np.ones(4 * n)
    )

    # MultibodyPlant orders its state as [q; v]: all 6n positions first, then
    # all 6n velocities (the same layout default_fixed_point builds). Weight
    # positions by 10 and velocities by 1 accordingly. The previous per-quadrotor
    # interleaving ([10]*6 + [1]*6) * n mis-weighted states whenever n > 1.
    Q = np.diag(np.repeat([10.0, 1.0], 6 * n))
    R = np.eye(4 * n)

    return LinearQuadraticRegulator(system_diagram, context, Q, R)


def add_controller(n, system_diagram, mass=0.775, gravity=9.81):
    """Close the loop around ``system_diagram`` with an LQR controller.

    The controller reads the system's output port 0 and drives its input
    port 0.

    Args:
        n: number of quadrotors in ``system_diagram``.
        system_diagram: open-loop diagram, e.g. from ``make_n_quadrotor_system``.
        mass: per-quadrotor mass used for the nominal hover input (kg).
            The default comes from Drake's quadrotor_plant.cc.
        gravity: gravitational acceleration (m/s^2). The default comes from
            Drake's quadrotor_plant.cc.

    Returns:
        ``(closed_loop_diagram, inner)``: the built closed-loop Diagram, and
        the ``system_diagram`` subsystem it contains (named "inner_diagram").
    """
    builder = DiagramBuilder()
    inner = builder.AddNamedSystem("inner_diagram", system_diagram)

    controller = builder.AddSystem(full_system_lqr(n, mass, gravity, inner))
    builder.Connect(controller.get_output_port(0), inner.get_input_port(0))
    builder.Connect(inner.get_output_port(), controller.get_input_port(0))

    return builder.Build(), inner


In [None]:
# Number of tethered quadrotors in the formation.
n = 5

# Open-loop system: quadrotors, propellers, tethers and the Meshcat visualizer.
diagram, plant = make_n_quadrotor_system(n)

# Closed-loop system: the open-loop diagram wrapped with an LQR controller.
# The per-quadrotor mass and gravity used for the hover input come from
# Drake's quadrotor_plant.cc.
bigger_diagram, bigger_plant = add_controller(n, diagram)



In [None]:
# Render the MultibodyPlant's body/joint topology as a PNG.
topology_graphs = pydot.graph_from_dot_data(plant.GetTopologyGraphvizString())
display(Image(topology_graphs[0].create_png()))

In [None]:
# Render the closed-loop system diagram (controller + inner diagram) as a PNG.
system_dot = bigger_diagram.GetGraphvizString()
system_graph = pydot.graph_from_dot_data(system_dot)[0]
display(Image(system_graph.create_png()))

In [None]:
def CreateNullExternalForce(plant):
    """Return a placeholder ExternallyAppliedSpatialForce attached to the world body.

    Only ``body_index`` is set (to ``plant``'s world body). The force and
    application-point fields keep their default-constructed values.
    """
    placeholder = ExternallyAppliedSpatialForce()
    world_index = plant.world_body().index()
    placeholder.body_index = world_index
    return placeholder

In [None]:
def DisableCollisionChecking(sg, root_context=None):
    """Filter out collisions between every pair of geometries registered with ``sg``.

    The filter is applied through the context-based collision filter manager,
    so it affects only the SceneGraph context found inside ``root_context``.

    Args:
        sg: the SceneGraph whose collision filters are modified.
        root_context: the root (simulator) context that contains ``sg``'s
            context. Defaults to the notebook-global ``context``, which is
            what the original one-argument form always read implicitly.
    """
    if root_context is None:
        # Backward-compatible fallback to the notebook-global simulator context.
        root_context = context
    sg_context = sg.GetMyContextFromRoot(root_context)
    cfm = sg.collision_filter_manager(sg_context)

    inspector = sg.get_query_output_port().Eval(sg_context).inspector()

    all_geometries = GeometrySet()
    for gid in inspector.GetAllGeometryIds():
        # NOTE(review): this includes geometries without a proximity role
        # (e.g. illustration); confirm Drake ignores those in filter declarations.
        all_geometries.Add(gid)

    cfd = CollisionFilterDeclaration()
    cfd.ExcludeWithin(all_geometries)
    cfm.Apply(cfd)

In [None]:
# Rollout settings. The previous `while True:` loop never finished, so the
# notebook could not "Restart & Run All"; increase NUM_TRIALS for a longer demo.
NUM_TRIALS = 3
SIM_DURATION = 5.0  # seconds of simulated time per trial
NOISE_SCALE = 0.05  # std-dev of the Gaussian perturbation on every state entry
rng = np.random.default_rng(seed=0)  # seeded so the rollouts are reproducible

simulator = Simulator(bigger_diagram)
simulator.set_target_realtime_rate(1)
context = simulator.get_mutable_context()

# Exclude collisions among all scene geometries so quadrotors flying close to
# each other do not generate contact. DisableCollisionChecking reads the
# global `context` assigned above.
sg = diagram.GetSubsystemByName("scene_graph")
DisableCollisionChecking(sg)

# Simulate: start each trial from a randomly perturbed hover formation.
nominal_state = default_fixed_point(n)
for trial in range(NUM_TRIALS):
    context.SetTime(0.0)
    noise = rng.normal(loc=0.0, scale=NOISE_SCALE, size=context.num_continuous_states())
    context.SetContinuousState(nominal_state + noise)

    simulator.Initialize()
    simulator.AdvanceTo(SIM_DURATION)