In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
from pydrake.all import (
    DiagramBuilder,
    LinearQuadraticRegulator,
    MeshcatVisualizer,
    MultibodyPlant,
    Parser,
    Propeller,
    PropellerInfo,
    RigidTransform,
    RobotDiagramBuilder,
    SceneGraph,
    Simulator,
    Meshcat,
    StartMeshcat,
    namedview,
)
from pydrake.examples import QuadrotorGeometry, QuadrotorPlant, StabilizingLQRController
import numpy as np

# Flag set to True in this cell; nothing in the visible cells reads it.
# NOTE(review): presumably gates interactive behaviour elsewhere — confirm.
running_as_notebook = True
# Start a Meshcat instance once; the cells below reuse this `meshcat` handle
# for visualization (and, in the iLQR cell, for recording).
meshcat = StartMeshcat()

# LQR

In [None]:
from pydrake.all import DiagramBuilder
from control.UtilLeafSystems import NonConstantVectorSource

from control.QuadrotorControllers import QuadrotorQuatLQRController
from sim.Quadrotor import MakeMultibodyQuadrotor, QuadrotorGeometry
from maths.quaternions import SampleQuaternion
from IPython.display import SVG, display, clear_output, Markdown
import pydot
import time 

# Time step. NOTE(review): this value is not forwarded in the controller
# parameters below and the simulation loop of this cell advances by a fixed
# 0.1 — confirm whether it is still needed.
dt = 0.001

# Diagonal state-cost weights: 10 on the first six entries, 1 on the last six.
# NOTE(review): presumably a 12-dim error state for the 13-dim quaternion
# state — confirm against QuadrotorQuatLQRController.
Q = np.diag([10] * 6 + [1] * 6)
# 4x4 identity input-cost weights.
R = np.eye(4)

# Keyword arguments handed to the LQR controller; the terminal cost reuses Q.
iLQRparams = dict(Q=Q, R=R, Qf=Q, mode="discrete")

# Assemble the closed loop: quadrotor plant, LQR controller and a source
# system that provides the 13-element goal state.
builder = DiagramBuilder()

quadrotor, mbp = MakeMultibodyQuadrotor(show_diagram=False)
quadrotor_system = builder.AddSystem(quadrotor)
controller = builder.AddSystem(QuadrotorQuatLQRController(**iLQRparams))
goal_state_source = builder.AddSystem(NonConstantVectorSource(13))

# Controller input 0 <- plant output; controller input 1 <- goal state.
builder.Connect(quadrotor_system.get_output_port(), controller.get_input_port(0))
builder.Connect(goal_state_source.get_output_port(), controller.get_input_port(1))

# Controller output -> plant input closes the feedback loop.
builder.Connect(controller.get_output_port(), quadrotor_system.get_input_port())

# Setup visualization: add a SceneGraph, feed it from the plant's output
# port 0 through QuadrotorGeometry, and attach a Meshcat visualizer.
# Note: `QuadrotorGeometry` here is the one imported from sim.Quadrotor in
# this cell, which rebinds the name imported from pydrake.examples earlier.
scene_graph = builder.AddSystem(SceneGraph())
quadrotor_geometry_system = QuadrotorGeometry.AddToBuilder(
    builder, quadrotor_system.get_output_port(0), scene_graph
)
# Reset the shared Meshcat instance before attaching the new visualizer so
# that content from a previously run cell does not linger.
meshcat.Delete()
meshcat.ResetRenderMode()
meshcat.SetProperty("/Background", "visible", False)
visualizer = MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

# Finalize the diagram; no systems or connections are added after this.
diagram = builder.Build()
# Optional: render the diagram structure inline (left disabled).
# display(SVG(pydot.graph_from_dot_data(
#             diagram.GetGraphvizString(max_depth=1))[0].create_svg()))

# Set up a simulator to run this diagram, paced at a target real-time rate
# of 1.0.
simulator = Simulator(diagram)
simulator.set_target_realtime_rate(1.0)
# Root context owned by the simulator; simulate() sets time and state on it.
diagram_context = simulator.get_mutable_context()

# Get the subsystem contexts (used by the logging helpers below to evaluate
# the plant and controller output ports).
quadrotor_context = quadrotor_system.GetMyContextFromRoot(diagram_context)
controller_context = controller.GetMyContextFromRoot(diagram_context)
# Only referenced by the commented-out geometry debug snippet in simulate().
geometry_context = diagram.GetSubsystemContext(quadrotor_geometry_system, diagram_context)

# Function to print the state vector
def dynamically_update_output(state, control):
    """Replace the cell output with a Markdown view of state and control.

    Args:
        state: Value rendered inside the "State" code fence.
        control: Value rendered inside the "Control" code fence.
    """
    report = f"**State:**\n```\n{state}\n```\n**Control:**\n```\n{control}\n```"
    # Clear what was shown before, then display the fresh report.
    clear_output(wait=True)
    display(Markdown(report))

def log_state():
    """Display the current plant state and control, then reject NaNs.

    Evaluates output port 0 of the module-level ``quadrotor_system`` and
    ``controller`` in their respective module-level contexts.

    Raises:
        ValueError: If the state or the control contains NaN values.
    """
    state_port = quadrotor_system.get_output_port(0)
    control_port = controller.get_output_port(0)

    state_now = state_port.Eval(quadrotor_context)
    control_now = control_port.Eval(controller_context)

    # Show the values first so a NaN is visible before the error is raised.
    dynamically_update_output(state_now, control_now)

    checks = (
        (state_now, "Quadrotor state contains NaN values"),
        (control_now, "Control input contains NaN values"),
    )
    for values, message in checks:
        if np.isnan(values).any():
            raise ValueError(message)

def inspect_base_link_state(context, plant):
    """Print the pose of the plant's ``base_link`` frame relative to world.

    Args:
        context: Root context from which the plant's own context is fetched.
        plant: Plant that owns a frame named ``base_link``.
    """
    plant_context = plant.GetMyContextFromRoot(context)
    base_frame = plant.GetFrameByName('base_link')
    world_to_base = plant.CalcRelativeTransform(
        plant_context, plant.world_frame(), base_frame
    )
    print(f"Base Link Pose:\n{world_to_base}")

def simulate(diagram_context, initial_state, ref_state, duration, step=0.1):
    """Run the closed-loop simulation from time 0 up to ``duration``.

    Uses the module-level ``simulator`` and ``goal_state_source``. A
    ``ValueError`` raised while advancing is caught and printed rather than
    propagated.

    Args:
        diagram_context: Root context of the simulated diagram.
        initial_state: Continuous state written into ``diagram_context``.
        ref_state: Goal state handed to ``goal_state_source``.
        duration: Simulation time at which to stop.
        step: Simulation time advanced per loop iteration. Defaults to 0.1,
            the value that used to be hard-coded.

    Raises:
        ValueError: If ``step`` is not positive (the loop would never end).
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    diagram_context.SetTime(0.0)
    diagram_context.SetContinuousState(initial_state)
    goal_state_source.SetState(ref_state)
    simulator.Initialize()
    try:
        while diagram_context.get_time() < duration:
            # Debug hooks (disabled):
            # log_state()

            # Force evaluation of output port to see the state in QuadrotorGeometry
            # quadrotor_geometry_output: FramePoseVector = diagram.GetSubsystemByName('QuadrotorGeometry').get_output_port(0).Eval(geometry_context)
            # quadrotor_geometry_pose = quadrotor_geometry_output.value(quadrotor_geometry_output.ids()[0])

            # position_np = np.array(quadrotor_geometry_pose.translation())
            # rotation_np = np.array(quadrotor_geometry_pose.rotation().matrix())
            # print(f"Pose (Position):\n{position_np}")
            # print(f"Pose (Rotation):\n{rotation_np}")

            # inspect_base_link_state(diagram_context, mbp)

            # Clamp to `duration`: repeatedly adding `step` accumulates
            # floating-point error, which can leave the clock marginally
            # below `duration` and would otherwise trigger one extra full
            # step past the requested end time.
            next_time = min(diagram_context.get_time() + step, duration)
            simulator.AdvanceTo(next_time)
    except ValueError as e:
        print(f"Simulation error: {e}")

# Alternative fixed start state (disabled):
# initial_state = np.array([1., 0., 0., 0., 0., 0., 1.5, 0., 0., 0., 0., 0., 0.])

# Start state: a sampled near-identity quaternion followed by nine fixed
# values, 13 entries in total.
initial_state = np.hstack([
    SampleQuaternion(near_identity=True),
    [1., 2., 3., 2., 0., 0., 0., 0., 0.],
])

# Goal state passed to the goal-state source.
ref_state = np.array([1., 0., 0., 0.] + [0., 0., 1.] + [0.] * 6)

simulate(diagram_context, initial_state, ref_state, duration=4.0)

# iLQR

In [None]:
from pydrake.all import DiagramBuilder
from control.UtilLeafSystems import NonConstantVectorSource

from control.QuadrotorControllers import QuadrotorController_Quat
from sim.Quadrotor import MakeMultibodyQuadrotor, QuadrotorGeometry
from maths.quaternions import SampleQuaternion
from IPython.display import SVG, display, clear_output, Markdown
import pydot
import time 

import numpy as np

# Widen NumPy's print line width so long vectors are not wrapped.
np.set_printoptions(linewidth=500)

# Diagonal running state cost: 10 on the first six entries, 1 on the last six.
Q = np.diag([10.] * 6 + [1.] * 6)
# Diagonal terminal state cost: 500 (x3), 100 (x3), 20 (x6).
Qf = np.diag([500.] * 3 + [100.] * 3 + [20.] * 6)

# Diagonal 4x4 input cost with 0.02 on every diagonal entry.
R = np.diag(np.full(4, 0.02))
# Time step for this cell.
dt = 0.01

# Controller settings: the cost weights and time step defined above plus the
# remaining keyword arguments expected by QuadrotorController_Quat.
iLQRparams = dict(
    Q=Q,
    R=R,
    Qf=Qf,
    N=30,
    dt=dt,
    max_iter=20,
    max_linesearch_iters=10,
    d_tol=5e-4,
)

builder = DiagramBuilder()

# Init systems
quadrotor, mbp = MakeMultibodyQuadrotor(show_diagram=False)
quadrotor_system = builder.AddSystem(quadrotor)

# print(mbp.GetStateNames())

controller = builder.AddSystem(QuadrotorController_Quat(**iLQRparams))
goal_state_source = builder.AddSystem(NonConstantVectorSource(13))

# Controller input 0 <- plant output; controller input 1 <- goal state.
builder.Connect(quadrotor_system.get_output_port(), controller.get_input_port(0))
builder.Connect(goal_state_source.get_output_port(), controller.get_input_port(1))

# Controller output -> plant input closes the feedback loop.
builder.Connect(controller.get_output_port(), quadrotor_system.get_input_port())

# Setup visualization: add a SceneGraph, feed it from the plant's output
# port 0 through QuadrotorGeometry, and attach a Meshcat visualizer.
scene_graph = builder.AddSystem(SceneGraph())
quadrotor_geometry_system = QuadrotorGeometry.AddToBuilder(
    builder, quadrotor_system.get_output_port(0), scene_graph
)
# Reset the shared Meshcat instance before attaching the new visualizer so
# that content from a previously run cell does not linger.
meshcat.Delete()
meshcat.ResetRenderMode()
meshcat.SetProperty("/Background", "visible", False)
visualizer = MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

# Finalize the diagram; no systems or connections are added after this.
diagram = builder.Build()
# Optional: render the diagram structure inline (left disabled).
# display(SVG(pydot.graph_from_dot_data(
#             diagram.GetGraphvizString(max_depth=1))[0].create_svg()))

# Set up a simulator to run this diagram
simulator = Simulator(diagram)
# With the next line left commented out, no target real-time rate is set in
# this cell (unlike the LQR cell, which sets 1.0).
# simulator.set_target_realtime_rate(1.0)
# Root context owned by the simulator; simulate() sets time and state on it.
diagram_context = simulator.get_mutable_context()

# Get the subsystem contexts (used by the logging helpers below to evaluate
# the plant and controller output ports).
quadrotor_context = quadrotor_system.GetMyContextFromRoot(diagram_context)
controller_context = controller.GetMyContextFromRoot(diagram_context)
# Not referenced by any other code visible in this cell.
geometry_context = diagram.GetSubsystemContext(quadrotor_geometry_system, diagram_context)

# Function to print the state vector
def dynamically_update_output(state, control, time):
    """Replace the cell output with a Markdown view of time, state, control.

    Args:
        state: Value rendered after "State:" inside the first code fence.
        control: Value rendered inside the "Control" code fence.
        time: Number formatted with ``4.4f`` in the heading. Note that
            inside this function the name shadows the imported ``time``
            module.
    """
    report = f"**t = {time:4.4f}s:**\n```\nState: {state}\n```\n**Control:**\n```\n{control}\n```\n"
    # Clear what was shown before, then display the fresh report.
    clear_output(wait=True)
    display(Markdown(report))

def log_state(time):
    """Display the current plant state and control, then reject NaNs.

    Evaluates output port 0 of the module-level ``quadrotor_system`` and
    ``controller`` in their respective module-level contexts.

    Args:
        time: Time stamp forwarded to ``dynamically_update_output``.

    Raises:
        ValueError: If the state or the control contains NaN values.
    """
    state_port = quadrotor_system.get_output_port(0)
    control_port = controller.get_output_port(0)

    state_now = state_port.Eval(quadrotor_context)
    control_now = control_port.Eval(controller_context)

    # Show the values first so a NaN is visible before the error is raised.
    dynamically_update_output(state_now, control_now, time)

    checks = (
        (state_now, "Quadrotor state contains NaN values"),
        (control_now, "Control input contains NaN values"),
    )
    for values, message in checks:
        if np.isnan(values).any():
            raise ValueError(message)

def inspect_base_link_state(context, plant):
    """Print the pose of the plant's ``base_link`` frame relative to world.

    Args:
        context: Root context from which the plant's own context is fetched.
        plant: Plant that owns a frame named ``base_link``.
    """
    plant_context = plant.GetMyContextFromRoot(context)
    base_frame = plant.GetFrameByName('base_link')
    world_to_base = plant.CalcRelativeTransform(
        plant_context, plant.world_frame(), base_frame
    )
    print(f"Base Link Pose:\n{world_to_base}")

def simulate(diagram_context, initial_state, ref_state, duration, step=None):
    """Run the closed-loop simulation from time 0 up to ``duration``.

    Uses the module-level ``simulator``, ``goal_state_source`` and
    ``iLQRparams``. A ``ValueError`` raised while advancing is caught and
    printed rather than propagated.

    Args:
        diagram_context: Root context of the simulated diagram.
        initial_state: Continuous state written into ``diagram_context``.
        ref_state: Goal state handed to ``goal_state_source``.
        duration: Simulation time at which to stop.
        step: Simulation time advanced per loop iteration. Defaults to
            ``iLQRparams["dt"]``, the stride that was used previously.

    Raises:
        ValueError: If ``step`` is not positive (the loop would never end).
    """
    if step is None:
        # Looked up once; the value does not change during the run.
        step = iLQRparams["dt"]
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    diagram_context.SetTime(0.0)
    diagram_context.SetContinuousState(initial_state)
    goal_state_source.SetState(ref_state)
    simulator.Initialize()
    try:
        while diagram_context.get_time() < duration:
            # log_state(diagram_context.get_time())
            # Clamp to `duration`: repeatedly adding `step` accumulates
            # floating-point error, which can leave the clock marginally
            # below `duration` and would otherwise trigger one extra full
            # step past the requested end time.
            sim_time = min(diagram_context.get_time() + step, duration)
            simulator.AdvanceTo(sim_time)
            # log_state(sim_time)
    except ValueError as e:
        print(f"Simulation error: {e}")


# NOTE(review): this import rebinds the module-level name `R`, which held the
# input-cost matrix defined earlier in this cell. The controller has already
# been built at this point, but any later use of `R` as a matrix would get
# the Rotation class instead — consider a different alias.
from scipy.spatial.transform import Rotation as R
# Set pitch angle to 90deg (singularity for rpy).
# SciPy returns the quaternion scalar-last, i.e. [x, y, z, w].
q = R.from_euler("zyx", [0, 90, 0], degrees=True).as_quat()
# Rearrange to the scalar-first convention [w, x, y, z] used by the state
# vectors in this notebook. The vector part is q[:-1]; slicing q[1:] would
# drop x and duplicate w, producing a non-unit quaternion.
q = np.hstack((q[-1], q[:-1]))
# Alternative start with a sampled near-identity attitude (disabled):
# initial_state = np.hstack(
#     (SampleQuaternion(near_identity=True),
#     np.array([1., 2., 3., 2., 0., 0., 0., 0., 0.])
#     )
# )

# Start state: the quaternion `q` built above followed by nine fixed values,
# 13 entries in total.
initial_state = np.hstack(
    (q,
    np.array([1., 2., 1., 1.4, 0., 0., 10., 0., 0.])
    )
)
# Goal state passed to the goal-state source.
ref_state = np.array([1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])

print("Begin simulation")
meshcat.StartRecording()
try:
    simulate(diagram_context, initial_state, ref_state, duration = 5.0)
finally:
    # simulate() only handles ValueError; make sure the recording is closed
    # and published even when the run aborts with another exception, so the
    # partial trajectory can still be inspected.
    meshcat.StopRecording()
    meshcat.PublishRecording()
