In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../src")

In [None]:
import numpy as np
from pydrake.all import (
    StartMeshcat,
    Simulator
)

In [None]:
from world import make_n_quadrotor_system
from util import DisableCollisionChecking
from stabilization import find_fixed_point_snopt, lqr_stabilize_to_point, add_controller_to_system

In [None]:
# Start Meshcat; the returned handle is handed to make_n_quadrotor_system
# when the diagram is built.
meshcat = StartMeshcat()

In [None]:
# Seed NumPy's global RNG so the np.random.normal perturbations drawn in the
# simulation cells are repeatable on a fresh run.
np.random.seed(2)
# Number of quadrotors: passed to make_n_quadrotor_system and used to size the
# LQR Q and R matrices.
n_quadrotors = 1

In [None]:
# Build the n_quadrotors system with the project helper; it returns the
# diagram that is simulated below and a plant.
# NOTE(review): the diagram's contents are defined in world.py, not visible here.
diagram, plant = make_n_quadrotor_system(meshcat, n_quadrotors)

In [None]:
# Ask the project helper for a fixed point of the diagram together with the
# control input that goes with it, then show both.
fixed_point, fixed_control = find_fixed_point_snopt(diagram)
print(fixed_point, fixed_control, sep="\n")
# Trailing expression: the cell displays the shape of the fixed-point state.
fixed_point.shape

In [None]:
# Check that it's a fixed point: hold the input at fixed_control and roll the
# system forward from fixed_point. The cell itself only prints the achieved
# realtime rate; the motion is presumably inspected visually in Meshcat.

simulator = Simulator(diagram)
simulator.set_publish_every_time_step(True)
simulator.set_target_realtime_rate(1)
context = simulator.get_mutable_context()

# Run the project's DisableCollisionChecking helper on the scene graph.
sg = diagram.GetSubsystemByName("scene_graph")
DisableCollisionChecking(sg, context)

# Pin the diagram's "u" input port to the constant fixed-point control.
u = diagram.GetInputPort("u")
u.FixValue(context, fixed_control)

# Standard deviation of the Gaussian noise added to the initial state;
# with eps = 0 the rollout starts exactly at fixed_point.
eps = 0
initial_state = fixed_point + np.random.normal(
    loc=0, scale=eps, size=fixed_point.shape
)

# Simulate
context.SetTime(0.0)
context.SetContinuousState(initial_state)

simulator.Initialize()
simulator.AdvanceTo(5.0)
print(simulator.get_actual_realtime_rate())

In [None]:
# Make an LQR controller around (fixed_point, fixed_control).

# Diagonal state-cost weights, one list per coordinate group.
# NOTE(review): the 4 + 3 split of the free-body position weights presumably
# means quaternion + translation — confirm against the plant's state ordering.
Q_quadrotor_pos = [10.0] * 6
Q_quadrotor_vel = [1.0] * 6
Q_freebody_pos = [0.0] * 4 + [10.0] * 3
Q_freebody_vel = [0.0] * 6

# Repeat the quadrotor weights once per vehicle, then append the free body.
Q_pos = n_quadrotors * Q_quadrotor_pos + Q_freebody_pos
Q_vel = n_quadrotors * Q_quadrotor_vel + Q_freebody_vel

# Positions first, velocities second, placed on the diagonal of Q.
Q = np.diag([*Q_pos, *Q_vel])
# Unit input cost, four inputs per quadrotor.
R = np.identity(4 * n_quadrotors)

lqr_controller = lqr_stabilize_to_point(
    diagram, fixed_point, fixed_control, Q, R
)

controlled_diagram, controlled_plant = add_controller_to_system(
    diagram, lqr_controller
)

In [None]:
# Simulate the LQR controller: repeatedly start the closed-loop system from a
# randomly perturbed copy of fixed_point and roll it forward.

# Tunables for this cell. The number of rollouts is bounded so the cell
# terminates on its own (an unbounded loop blocks "Restart & Run All" and can
# only be stopped with a kernel interrupt). Raise n_rollouts to watch longer.
n_rollouts = 10
rollout_duration = 5.0    # simulated time each rollout is advanced to
perturbation_scale = 0.5  # std-dev of the Gaussian noise added to fixed_point

simulator = Simulator(controlled_diagram)
simulator.set_publish_every_time_step(True)
simulator.set_target_realtime_rate(1.0)
context = simulator.get_mutable_context()

# NOTE(review): the scene graph is looked up on `diagram`, while `context`
# belongs to the simulator of `controlled_diagram` — confirm this is intended.
sg = diagram.GetSubsystemByName("scene_graph")
DisableCollisionChecking(sg, context)

# Simulate
for rollout_index in range(n_rollouts):
    # Rewind the clock and draw a fresh perturbed initial state.
    context.SetTime(0.0)
    context.SetContinuousState(
        fixed_point
        + np.random.normal(loc=0, scale=perturbation_scale, size=fixed_point.shape)
    )
    simulator.Initialize()
    simulator.AdvanceTo(rollout_duration)
    print(simulator.get_actual_realtime_rate())

