In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
from pydrake.all import (
    DiagramBuilder,
    LinearQuadraticRegulator,
    MeshcatVisualizer,
    MultibodyPlant,
    Parser,
    Propeller,
    PropellerInfo,
    RigidTransform,
    RobotDiagramBuilder,
    SceneGraph,
    Simulator,
    Meshcat,
    StartMeshcat,
    namedview,
)
from pydrake.examples import QuadrotorGeometry, QuadrotorPlant, StabilizingLQRController
import numpy as np

running_as_notebook = True
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7000


# LQR

In [5]:
from pydrake.all import DiagramBuilder, FramePoseVector

from control.QuadrotorControllers import QuadrotorLQR
from sim.Quadrotor import MakeMultibodyQuadrotor, QuadrotorGeometry
from math.quaternions import SampleQuaternion
from IPython.display import SVG, display, clear_output, Markdown
import pydot
import time 

quadrotor, mbp = MakeMultibodyQuadrotor(show_diagram=False)

Q = np.diag(np.concatenate(([10] * 6, [1] * 6)))
R = np.eye(4)

builder = DiagramBuilder()

quadrotor_system = builder.AddSystem(quadrotor)

controller = builder.AddSystem(QuadrotorLQR(quadrotor_system, mbp, Q, R))

builder.Connect(
    quadrotor_system.get_output_port(), controller.get_input_port()
)
builder.Connect(
    controller.get_output_port(), quadrotor_system.get_input_port()
)

# Setup visualization
scene_graph = builder.AddSystem(SceneGraph())
quadrotor_geometry_system = QuadrotorGeometry.AddToBuilder(
    builder, quadrotor_system.get_output_port(0), scene_graph
)
meshcat.Delete()
meshcat.ResetRenderMode()
meshcat.SetProperty("/Background", "visible", False)
visualizer = MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

diagram = builder.Build()
# display(SVG(pydot.graph_from_dot_data(
#             diagram.GetGraphvizString(max_depth=1))[0].create_svg()))

# Set up a simulator to run this diagram
simulator = Simulator(diagram)
simulator.set_target_realtime_rate(1.0)
diagram_context = simulator.get_mutable_context()

# Get the subsystem contexts
quadrotor_context = quadrotor_system.GetMyContextFromRoot(diagram_context)
controller_context = controller.GetMyContextFromRoot(diagram_context)
geometry_context = diagram.GetSubsystemContext(quadrotor_geometry_system, diagram_context)

# Function to print the state vector
def dynamically_update_output(state, control):
    clear_output(wait=True)
    display(Markdown(f"**State:**\n```\n{state}\n```\n**Control:**\n```\n{control}\n```"))

def log_state():
    quadrotor_state_port = quadrotor_system.get_output_port(0)
    control_output_port = controller.get_output_port(0)

    quadrotor_state = quadrotor_state_port.Eval(quadrotor_context)
    control = control_output_port.Eval(controller_context)
    
    dynamically_update_output(quadrotor_state, control)
    
    if np.any(np.isnan(quadrotor_state)):
        raise ValueError("Quadrotor state contains NaN values")

    if np.any(np.isnan(control)):
        raise ValueError("Control input contains NaN values")

def inspect_base_link_state(context, plant):
    inspector = plant.GetMyContextFromRoot(context)
    base_link_frame = plant.GetFrameByName('base_link')
    pose = plant.CalcRelativeTransform(inspector, plant.world_frame(), base_link_frame)
    print(f"Base Link Pose:\n{pose}")

def simulate(diagram_context, initial_state, duration):
    diagram_context.SetTime(0.0)
    diagram_context.SetContinuousState(initial_state
    )
    simulator.Initialize()
    try:
        while diagram_context.get_time() < duration:
            # log_state()

            # Force evaluation of output port to see the state in QuadrotorGeometry
            # quadrotor_geometry_output: FramePoseVector = diagram.GetSubsystemByName('QuadrotorGeometry').get_output_port(0).Eval(geometry_context)
            # quadrotor_geometry_pose = quadrotor_geometry_output.value(quadrotor_geometry_output.ids()[0])

            # position_np = np.array(quadrotor_geometry_pose.translation())
            # rotation_np = np.array(quadrotor_geometry_pose.rotation().matrix())
            # print(f"Pose (Position):\n{position_np}")
            # print(f"Pose (Rotation):\n{rotation_np}")

            # inspect_base_link_state(diagram_context, mbp)

            simulator.AdvanceTo(diagram_context.get_time() + 0.1)
    except ValueError as e:
        print(f"Simulation error: {e}")

# initial_state = np.array([1., 0., 0., 0., 0., 0., 1.5, 0., 0., 0., 0., 0., 0.])
initial_state = np.hstack(
    (SampleQuaternion(near_identity=True),
    np.array([1., 0., 0., 0., 0., 0., 0., 0., 0.])
    )
)
simulate(diagram_context, initial_state, duration = 4.0)

/home/malkstik
[ 0.66184571 -0.16674321 -0.50128301  0.66184571  1.          0.
  0.          0.          0.          0.          0.          0.
  0.        ]
[]

[ 6.61845713e-01 -1.66743213e-01 -5.01283011e-01  6.61845713e-01
  1.00000000e+00  0.00000000e+00  0.00000000e+00 -2.34970842e-02
  3.01628312e-02 -1.67132814e-02 -6.87813542e-04 -3.44449057e-04
 -2.61359059e-05]
[]

[ 6.61847940e-01 -1.66743610e-01 -5.01279180e-01  6.61846288e-01
  9.99999862e-01 -6.88898114e-08 -5.22718117e-09 -4.68414257e-02
  6.08129713e-02 -3.27386871e-02 -1.37564626e-03 -6.88907720e-04
 -5.22588615e-05]
[]

[ 6.61846825e-01 -1.66743405e-01 -5.01281095e-01  6.61846004e-01
  9.99999931e-01 -3.44450658e-08 -2.61337475e-09 -4.69170181e-02
  6.05688627e-02 -3.30821663e-02 -1.37563612e-03 -6.88901423e-04
 -5.22634141e-05]
[]

[ 6.61851868e-01 -1.66744262e-01 -5.01272394e-01  6.61847335e-01
  9.99999619e-01 -1.90834508e-07 -1.44778388e-08 -1.53247702e-01
  1.98619003e-01 -1.07400088e-01 -4.49852547e-03 -2.2528

# iLQR

In [None]:
from pydrake.all import DiagramBuilder, FramePoseVector

from control.QuadrotorControllers import QuadrotoriLQR
from sim.Quadrotor import MakeMultibodyQuadrotor, QuadrotorGeometry
from math.quaternions import SampleQuaternion
from IPython.display import SVG, display, clear_output, Markdown

quadrotor, mbp = MakeMultibodyQuadrotor(show_diagram=False)

Q = 
R = np.eye(4)

builder = DiagramBuilder()

quadrotor_system = builder.AddSystem(quadrotor)


iLQRparams = {
    "quadrotor" : quadrotor_system,
    "multibody_plant" : mbp,
    "Q" : np.diag(np.concatenate(([10] * 6, [1] * 6))),
    "R" : np.eye(4),
    "Qf" : np.diag(np.concatenate(([10] * 6, [1] * 6))),
    "Tf" : 3.0,
    "dt" : 0.2
}



controller = builder.AddSystem(QuadrotoriLQR(**iLQRparams))

builder.Connect(
    quadrotor_system.get_output_port(), controller.get_input_port()
)
builder.Connect(
    controller.get_output_port(), quadrotor_system.get_input_port()
)

# Setup visualization
scene_graph = builder.AddSystem(SceneGraph())
quadrotor_geometry_system = QuadrotorGeometry.AddToBuilder(
    builder, quadrotor_system.get_output_port(0), scene_graph
)
meshcat.Delete()
meshcat.ResetRenderMode()
meshcat.SetProperty("/Background", "visible", False)
visualizer = MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

diagram = builder.Build()
# display(SVG(pydot.graph_from_dot_data(
#             diagram.GetGraphvizString(max_depth=1))[0].create_svg()))

# Set up a simulator to run this diagram
simulator = Simulator(diagram)
simulator.set_target_realtime_rate(1.0)
diagram_context = simulator.get_mutable_context()

# Get the subsystem contexts
quadrotor_context = quadrotor_system.GetMyContextFromRoot(diagram_context)
controller_context = controller.GetMyContextFromRoot(diagram_context)
geometry_context = diagram.GetSubsystemContext(quadrotor_geometry_system, diagram_context)

# Function to print the state vector
def dynamically_update_output(state, control):
    clear_output(wait=True)
    display(Markdown(f"**State:**\n```\n{state}\n```\n**Control:**\n```\n{control}\n```"))

def log_state():
    quadrotor_state_port = quadrotor_system.get_output_port(0)
    control_output_port = controller.get_output_port(0)

    quadrotor_state = quadrotor_state_port.Eval(quadrotor_context)
    control = control_output_port.Eval(controller_context)
    
    dynamically_update_output(quadrotor_state, control)
    
    if np.any(np.isnan(quadrotor_state)):
        raise ValueError("Quadrotor state contains NaN values")

    if np.any(np.isnan(control)):
        raise ValueError("Control input contains NaN values")

def inspect_base_link_state(context, plant):
    inspector = plant.GetMyContextFromRoot(context)
    base_link_frame = plant.GetFrameByName('base_link')
    pose = plant.CalcRelativeTransform(inspector, plant.world_frame(), base_link_frame)
    print(f"Base Link Pose:\n{pose}")

def simulate(diagram_context, initial_state, duration):
    diagram_context.SetTime(0.0)
    diagram_context.SetContinuousState(initial_state
    )
    simulator.Initialize()
    try:
        while diagram_context.get_time() < duration:
            log_state()
            simulator.AdvanceTo(diagram_context.get_time() + 0.1)
    except ValueError as e:
        print(f"Simulation error: {e}")

# initial_state = np.array([1., 0., 0., 0., 0., 0., 1.5, 0., 0., 0., 0., 0., 0.])
goal_state = np.hstack(np.array([1., 0., 0., 0.]), 3.0 * np.random.randn)
initial_state = np.hstack(
    (SampleQuaternion(near_identity=True),
    np.array([1., 0., 0., 0., 0., 0., 0., 0., 0.])
    )
)
simulate(diagram_context, initial_state, duration = 4.0)