This notebook provides examples to go along with the [textbook](https://underactuated.csail.mit.edu/lqr.html).  I recommend having both windows open, side-by-side!


In [None]:
import matplotlib.pyplot as plt
import mpld3
import numpy as np
from IPython.display import Markdown, display
from pydrake.all import (
    AddMultibodyPlantSceneGraph,
    ConstantVectorSource,
    DiagramBuilder,
    DiscreteTimeLinearQuadraticRegulator,
    IsControllable,
    Linearize,
    LogVectorOutput,
    MatrixGain,
    MeshcatVisualizer,
    OutputPortSelection,
    Parser,
    Simulator,
    StartMeshcat,
)
from pydrake.common.containers import namedview

from underactuated import running_as_notebook

# Turn on mpld3's notebook integration only when running as a notebook;
# otherwise the call is skipped entirely.
if running_as_notebook:
    mpld3.enable_notebook()

In [None]:
# Start the visualizer (run this cell only once, each instance consumes a port).
# The resulting `meshcat` handle is shared by all of the example functions below.
meshcat = StartMeshcat()

## Balancing a 2D Segway (aka "Ballbot")

Q: Would LQR work even better if I used absolute coordinates for the bot, instead of relative? (The wheel rotates a lot, but the bot angle should not change much.)

In [None]:
# URDF for the "sliding base" ballbot.  The joint chain is:
#   world --(prismatic "x")--> ball_x --(continuous "theta_ball")--> ball
#   ball --(continuous "theta_bot")--> bot
# Only "theta_bot" has a transmission (actuator "torque").  The ground and the
# ball both carry collision geometry with friction coefficients of 100.
ballbot_sliding_base_urdf = """
<?xml version="1.0"?>

<robot xmlns="http://drake.mit.edu"
 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
 name="Ballbot">

  <link name="ground">
    <visual>
      <origin xyz="0 0 -5" rpy="0 0 0" />
      <geometry>
        <box size="1000 1000 10" />
      </geometry>
      <material>
        <color rgba="0.93 .74 .4 1" />
      </material>
    </visual>
    <collision>
      <origin xyz="0 0 -5" rpy="0 0 0" />
      <geometry>
        <box size="1000 1000 10" />
      </geometry>
      <drake:proximity_properties>
        <drake:mu_dynamic value="100" />
        <drake:mu_static value="100" />
      </drake:proximity_properties>
    </collision>
  </link>

  <joint name="ground_weld" type="fixed">
    <parent link="world" />
    <child link="ground" />
  </joint>

  <link name="ball_x" />

  <link name="ball">
    <inertial>
      <origin xyz="0 0 0" rpy="0 0 0" />
      <mass value="5" />
      <inertia ixx=".02" ixy="0" ixz="0" iyy="0.02" iyz="0" izz="0.02" />
    </inertial>
    <visual>
      <origin xyz="0 0 0" rpy="0 0 0" />
      <geometry>
        <sphere radius=".1" />
      </geometry>
      <material>
        <color rgba="0.25 0.52 0.96 1" />
      </material>
    </visual>
    <visual> 
      <!-- add a visual cue to see when the ball is rolling -->
      <origin xyz="0 0 0" rpy="0 0 0" />
      <geometry>
        <box size="0.04 .201 0.002" />
      </geometry>
      <material>
        <color rgba=".1 .1 .1 1" />
      </material>
    </visual>
    <collision>
      <origin xyz="0 0 0" rpy="0 0 0" />
      <geometry>
        <sphere radius=".1" />
      </geometry>
      <drake:proximity_properties>
        <drake:mu_dynamic value="100" />
        <drake:mu_static value="100" />
      </drake:proximity_properties>
    </collision>
  </link>

  <joint name="x" type="prismatic">
    <parent link="world" />
    <child link="ball_x" />
    <!-- height is set to (radius - 0.001) to have sufficient penetration to have a normal force that supports the required frictional force -->
    <origin xyz="0 0 .099" />
    <axis xyz="1 0 0" />
  </joint>

  <link name="bot">
    <inertial>
      <origin xyz="0 0 .05" rpy="0 0 0" />
      <mass value="4" />
      <inertia ixx="0.018" ixy="0" ixz="0" iyy="0.018" iyz="0" izz="0.0288" />
    </inertial>
    <!-- no collision geometry since Drake AutoDiffXd doesn't support cylinder on box collisions yet. -->
    <visual>
      <origin xyz="0 0 .05" rpy="0 0 0" />
      <geometry>
         <cylinder length=".1" radius=".12" />
      </geometry>
      <material>
        <color rgba=".61 .63 .67 1" />
      </material>
    </visual>
  </link>  

  <joint name="theta_ball" type="continuous">
    <parent link="ball_x" />
    <child link="ball" />
    <axis xyz="0 -1 0" />   
    <!-- use -1 to be consistent with the planar joint having z up below -->
  </joint>

  <joint name="theta_bot" type="continuous">
    <parent link="ball" />
    <child link="bot" />
    <axis xyz="0 -1 0" />
  </joint>

  <transmission type="SimpleTransmission" name="ball_torque">
    <actuator name="torque" />
    <joint name="theta_bot" />
  </transmission>
  
</robot>
"""

# Effective ball radius r given the contact point, and the normalizer that
# gives the rows of P unit length.
r = 0.0995
n = np.sqrt(1 + r**2)

# F picks out the two constraint quantities x + r * theta_ball and
# xdot + r * thetadot_ball from the 6-element sliding-base state.
F = np.zeros((2, 6))
F[0, :2] = (1, r)
F[1, 3:5] = (1, r)

# P has four orthonormal rows, each orthogonal to both rows of F, so P @ x is
# a set of coordinates on the subspace where F @ x = 0.
P = np.zeros((4, 6))
P[0, :2] = (r / n, -1 / n)
P[1, 2] = 1
P[2, 3:5] = (r / n, -1 / n)
P[3, 5] = 1

# Sanity checks: the rows of P are orthogonal to F and orthonormal.
np.testing.assert_almost_equal(P @ F.T, np.zeros((4, 2)))
np.testing.assert_almost_equal(P @ P.T, np.eye(4))


def BallbotUprightState():
    """Return the upright fixed point: a 6-tuple of zeros (all positions and velocities)."""
    return (0,) * 6


def MakeSlidingBaseBallbot(time_step=0.001):
    """Build a diagram wrapping the sliding-base ballbot model.

    A MultibodyPlant / SceneGraph pair is created with the given time step and
    populated from `ballbot_sliding_base_urdf`.

    Args:
        time_step: Time step passed to AddMultibodyPlantSceneGraph.

    Returns:
        The built diagram, exporting the plant's actuation input as "torque",
        the plant's state output as "state", and the scene graph's query
        output as "query".
    """
    diagram_builder = DiagramBuilder()

    mbp, sg = AddMultibodyPlantSceneGraph(diagram_builder, time_step)
    parser = Parser(mbp)
    parser.AddModelsFromString(ballbot_sliding_base_urdf, "urdf")
    mbp.Finalize()

    diagram_builder.ExportInput(mbp.get_actuation_input_port(), "torque")
    # Export order is preserved: "state" first, then "query".
    exported_outputs = (
        (mbp.get_state_output_port(), "state"),
        (sg.get_query_output_port(), "query"),
    )
    for port, port_name in exported_outputs:
        diagram_builder.ExportOutput(port, port_name)

    return diagram_builder.Build()


def BallbotControllability(plant):
    """Display whether `plant`, linearized at the upright state, is controllable.

    The system is linearized about BallbotUprightState() with its input fixed
    to zero and no output port selected, and the result of IsControllable is
    rendered as Markdown.

    Args:
        plant: System to linearize; it is accessed through get_input_port()
            and its context's discrete state.
    """
    upright_context = plant.CreateDefaultContext()
    upright_context.SetDiscreteState(BallbotUprightState())
    plant.get_input_port().FixValue(upright_context, [0])

    linearization = Linearize(
        plant, upright_context, output_port_index=OutputPortSelection.kNoOutput
    )
    # Optional extra output (requires ToLatex and ControllabilityMatrix to be
    # in scope):
    # display(Markdown("$A = " + ToLatex(linearization.A()) + "$"))
    # display(Markdown("$B = " + ToLatex(linearization.B()) + "$"))
    # display(
    #    Markdown("$ctrb(A,B) = "
    #             + ToLatex(ControllabilityMatrix(linearization), 6) + "$"))
    verdict = IsControllable(linearization)
    display(Markdown(f"Is controllable? {verdict}"))


def BallbotManifoldLQR(plant):
    """Design a discrete-time LQR in the P-coordinates and return it as a gain.

    `plant` is linearized about BallbotUprightState() with zero input.  The
    linearization is projected with P (A -> P A P^T, B -> P B), LQR is solved
    on that 4-state system, and the gain is mapped back to the full state.

    Args:
        plant: System to linearize; it is accessed through get_input_port()
            and its context's discrete state.

    Returns:
        A MatrixGain with gain matrix -K @ P.
    """
    upright_context = plant.CreateDefaultContext()
    upright_context.SetDiscreteState(BallbotUprightState())
    plant.get_input_port().FixValue(upright_context, [0])

    linearization = Linearize(
        plant, upright_context, output_port_index=OutputPortSelection.kNoOutput
    )

    # Dynamics expressed in the reduced coordinates z = P @ x.
    A_reduced = P @ linearization.A() @ P.T
    B_reduced = P @ linearization.B()

    state_cost = np.diag((10, 10, 1, 1))
    input_cost = np.array([1])
    gain, _ = DiscreteTimeLinearQuadraticRegulator(
        A_reduced, B_reduced, state_cost, input_cost
    )

    return MatrixGain(-gain @ P)


def BallbotSlidingBaseExample():
    """Simulate the sliding-base ballbot under the manifold LQR and plot the result.

    Steps:
      1. Build the sliding-base plant and display its controllability check.
      2. Close the loop with BallbotManifoldLQR (or, if the debug flag below
         is flipped, with a derivative-only PID on theta_ball).
      3. Start from an initial state of the form P.T @ z, which satisfies
         F @ x0 = 0 (asserted below).
      4. Record to Meshcat, then plot the three positions and F @ x over time.
    """
    builder = DiagramBuilder()
    time_step = 0.001

    plant = builder.AddSystem(MakeSlidingBaseBallbot(time_step))
    logger = LogVectorOutput(plant.GetOutputPort("state"), builder)

    # For completeness, we can check that the plant by itself is not controllable:
    BallbotControllability(plant)

    use_debug_pid = False  # Useful for debugging.
    if use_debug_pid:
        # PidController is only needed on this debugging path and is not part
        # of the file-level imports, so bring it into scope here.
        from pydrake.all import PidController

        kd_wheel = 1
        pid = builder.AddSystem(
            PidController(
                # Select (theta_ball, thetadot_ball) from the 6-element state.
                state_projection=np.array([[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0]]),
                kp=[0],
                ki=[0],
                kd=[kd_wheel],
            )
        )
        builder.Connect(pid.get_output_port(), plant.get_input_port())
        builder.Connect(
            plant.GetOutputPort("state"), pid.get_input_port_estimated_state()
        )
        pid_command = builder.AddSystem(ConstantVectorSource([0, 1]))
        builder.Connect(
            pid_command.get_output_port(), pid.get_input_port_desired_state()
        )
    else:
        controller = builder.AddSystem(BallbotManifoldLQR(plant))
        controller.set_name("LQR Controller")
        builder.Connect(plant.GetOutputPort("state"), controller.get_input_port())
        builder.Connect(controller.get_output_port(), plant.get_input_port())

    # Setup visualization
    meshcat.Delete()
    meshcat.Set2dRenderMode(xmin=-0.5, xmax=0.5, ymin=-0.2, ymax=0.5)
    visualizer = MeshcatVisualizer.AddToBuilder(
        builder, plant.GetOutputPort("query"), meshcat
    )

    diagram = builder.Build()

    # For reference, let's draw the diagram we've assembled (requires SVG and
    # pydot to be imported):
    # display(SVG(pydot.graph_from_dot_data(diagram.GetGraphvizString())[0].create_svg()))

    # Set up a simulator to run this diagram
    simulator = Simulator(diagram)
    # simulator.set_target_realtime_rate(1.0 if running_as_notebook else 0.0)
    context = simulator.get_mutable_context()

    # Initial condition expressed in the reduced coordinates and lifted with
    # P.T.  The random term is scaled by 0.0, i.e. currently switched off.
    x0 = BallbotUprightState() + P.T @ (
        0.0
        * np.random.randn(
            4,
        )
        + np.array([1, 0, 0, 0])
    )
    assert np.allclose(F @ x0, np.zeros(2))
    context.SetDiscreteState(x0)

    # Simulate
    visualizer.StartRecording(False)
    simulator.AdvanceTo(4 if running_as_notebook else 0.1)
    visualizer.PublishRecording()

    log = logger.FindLog(context)
    fig, ax = plt.subplots(1, 2, figsize=(10, 6))
    ax[0].plot(log.sample_times(), log.data()[:3, :].T)
    ax[0].legend(["x", "theta_ball", "theta_bot"])
    # F has rows [1, r, 0, ...], so the plotted quantities are sums, not
    # differences; the labels must carry a plus sign to match.
    ax[1].plot(log.sample_times(), (F @ log.data()).T)
    ax[1].legend(["x + r*theta_ball", "xdot + r*thetadot_ball"])


BallbotSlidingBaseExample()

But there is a subtle point to be made here.  I've modeled the wheels as a prismatic joint at the base.  Let's instead model the robot with a floating joint, and use collision elements to simulate the interaction between the robot and the ground...

In [None]:
# URDF for the "floating base" ballbot.  The ball is attached to the world by
# a planar joint ("floating_base") and the bot to the ball by the continuous
# joint "theta2", which is the only joint with a transmission.  Ground, ball
# and bot all carry collision geometry; no friction properties are specified.
ballbot_floating_base_urdf = """
<?xml version="1.0"?>

<robot xmlns="http://drake.mit.edu"
 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
 name="Ballbot">

  <link name="ground">
    <visual>
      <origin xyz="0 0 -5" rpy="0 0 0" />
      <geometry>
        <box size="1000 1000 10" />
      </geometry>
      <material>
        <color rgba="0.93 .74 .4 1" />
      </material>
    </visual>
    <collision>
      <origin xyz="0 0 -5" rpy="0 0 0" />
      <geometry>
        <box size="1000 1000 10" />
      </geometry>
    </collision>
  </link>

  <joint name="ground_weld" type="fixed">
    <parent link="world" />
    <child link="ground" />
  </joint>

  <link name="ball">
    <inertial>
      <origin xyz="0 0 0" rpy="0 0 0" />
      <mass value="5" />
      <inertia ixx=".02" ixy="0" ixz="0" iyy="0.02" iyz="0" izz="0.02" />
    </inertial>
    <visual>
      <origin xyz="0 0 0" rpy="0 0 0" />
      <geometry>
        <sphere radius=".1" />
      </geometry>
      <material>
        <color rgba="0.25 0.52 0.96 1" />
      </material>
    </visual>
    <visual> 
      <!-- add a visual cue to see when the ball is rolling -->
      <origin xyz="0 0 0" rpy="0 0 0" />
      <geometry>
        <box size="0.04  0.002 .201" />
      </geometry>
      <material>
        <color rgba=".1 .1 .1 1" />
      </material>
    </visual>
    <collision>
      <origin xyz="0 0 0" rpy="0 0 0" />
      <geometry>
        <sphere radius=".1" />
      </geometry>
    </collision>
  </link>

  <link name="bot">
    <inertial>
      <origin xyz="0 0 .05" rpy="0 0 0" />
      <mass value="4" />
      <inertia ixx="0.018" ixy="0" ixz="0" iyy="0.018" iyz="0" izz="0.0288" />
    </inertial>
    <collision>
      <origin xyz="0 0 .05" rpy="0 0 0" />
      <geometry>
         <cylinder length=".1" radius=".12" />
      </geometry>
    </collision>
    <visual>
      <origin xyz="0 0 .05" rpy="0 0 0" />
      <geometry>
         <cylinder length=".1" radius=".12" />
      </geometry>
      <material>
        <color rgba=".61 .63 .67 1" />
      </material>
    </visual>
  </link>
  
  <joint name="floating_base" type="planar">
    <parent link="world" />
    <child link="ball" />
    <origin rpy="1.57 0 0" xyz="0 0 .1" />
  </joint>

  <joint name="theta2" type="continuous">
    <parent link="ball" />
    <child link="bot" />
    <origin rpy="-1.57 0 0" xyz="0 0 0" />
    <axis xyz="0 -1 0" />
  </joint>

  <transmission type="SimpleTransmission" name="ball_torque">
    <actuator name="torque" />
    <joint name="theta2" />
    <mechanicalReduction>1</mechanicalReduction>
  </transmission>

</robot>
"""


def BallbotFloatingBaseExample():
    """Simulate the floating-base ballbot with zero torque and print the final z.

    The plant is built from `ballbot_floating_base_urdf`, driven by a constant
    zero command, started at z = 0.15 with theta1 = 0.05, recorded to Meshcat,
    and the z coordinate at the end of the simulation is printed.
    """
    # Named view over the plant's 8-element position/velocity vector.
    State = namedview(
        "State",
        ["x", "z", "theta1", "theta2", "xdot", "zdot", "theta1dot", "theta2dot"],
    )

    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.001)
    Parser(plant).AddModelsFromString(ballbot_floating_base_urdf, "urdf")
    plant.Finalize()

    # Just use zero instead of a controller to start
    zero_torque = builder.AddSystem(ConstantVectorSource([0.0]))
    builder.Connect(zero_torque.get_output_port(), plant.get_actuation_input_port())

    # Setup visualization
    meshcat.Delete()
    meshcat.Set2dRenderMode(xmin=-0.5, xmax=0.5, ymin=-0.2, ymax=0.5)
    visualizer = MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    diagram = builder.Build()

    # Set up a simulator to run this diagram
    simulator = Simulator(diagram)
    simulator.set_target_realtime_rate(1.0 if running_as_notebook else 0.0)
    root_context = simulator.get_mutable_context()
    plant_context = plant.GetMyContextFromRoot(root_context)

    # Initial condition: everything zero except z and theta1.
    x0 = State(np.zeros(8))
    x0.z, x0.theta1 = 0.15, 0.05
    plant.SetPositionsAndVelocities(plant_context, x0[:])

    # Simulate
    end_time = 3 if running_as_notebook else 0.1
    visualizer.StartRecording(False)
    simulator.AdvanceTo(end_time)
    visualizer.PublishRecording()

    xf = State(plant.GetPositionsAndVelocities(plant_context))
    print(f"z at final time = {xf.z}")


BallbotFloatingBaseExample()

Here's the real test... you should be able to run the controller designed for the simple model on the floating base version of the ballbot.  Give it a shot!

In [None]:
def BallbotFloatingBaseLqrExample():
    """Run the sliding-base LQR controller on the floating-base ballbot.

    One diagram holds two closed loops that are simulated side by side:

      * the sliding-base ("minimal coordinates") model with its LQR controller;
      * the floating-base MultibodyPlant, whose 8-element state is projected
        onto the 6-element minimal state (dropping z and zdot) and fed to a
        second copy of the same controller.

    Before simulating, one forced discrete update of each plant is printed
    from matching initial conditions.  After simulating, the minimal state and
    the (projected) floating state are plotted on a shared y-range.
    """
    builder = DiagramBuilder()
    time_step = 0.001

    # State layouts.  The same name lists drive the namedviews and the plot
    # legends so that printed values and plotted curves are labelled
    # identically.
    minimal_state_names = [
        "x",
        "theta_ball",
        "theta_bot_m_ball",
        "xdot",
        "thetadot_ball",
        "thetadot_bot_m_ball",
    ]
    floating_state_names = [
        "x",
        "z",
        "theta_ball",
        "theta_bot_m_ball",
        "xdot",
        "zdot",
        "thetadot_ball",
        "thetadot_bot_m_ball",
    ]
    MinimalState = namedview("MinimalState", minimal_state_names)
    FloatingState = namedview("FloatingState", floating_state_names)

    # Closed loop 1: sliding-base model with its own LQR controller.
    ballbot_minimal_coordinates = builder.AddSystem(MakeSlidingBaseBallbot(time_step))

    controller_minimal_coordinates = builder.AddSystem(
        BallbotManifoldLQR(ballbot_minimal_coordinates)
    )
    controller_minimal_coordinates.set_name("LQR Controller for Minimal")
    builder.Connect(
        controller_minimal_coordinates.get_output_port(),
        ballbot_minimal_coordinates.get_input_port(),
    )
    builder.Connect(
        ballbot_minimal_coordinates.GetOutputPort("state"),
        controller_minimal_coordinates.get_input_port(),
    )

    # Closed loop 2: floating-base plant driven by a controller that was
    # designed on the sliding-base model.
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step)
    Parser(plant).AddModelsFromString(ballbot_floating_base_urdf, "urdf")
    plant.Finalize()

    controller = builder.AddSystem(BallbotManifoldLQR(ballbot_minimal_coordinates))
    controller.set_name("LQR Controller for Floating")
    builder.Connect(controller.get_output_port(), plant.get_actuation_input_port())
    # builder.ExportInput(plant.get_actuation_input_port(), "u")

    # Selection matrix that drops the z and zdot rows of the floating state.
    dropped_rows = [
        floating_state_names.index("z"),
        floating_state_names.index("zdot"),
    ]
    P_floating_to_minimal = np.delete(np.eye(8), dropped_rows, 0)
    floating_state_to_minimal_state = builder.AddSystem(
        MatrixGain(P_floating_to_minimal)
    )
    builder.Connect(
        plant.get_state_output_port(),
        floating_state_to_minimal_state.get_input_port(),
    )
    builder.Connect(
        floating_state_to_minimal_state.get_output_port(),
        controller.get_input_port(),
    )

    # Setup visualization
    meshcat.Delete()
    meshcat.Set2dRenderMode(xmin=-0.5, xmax=0.5, ymin=-0.2, ymax=0.5)
    visualizer = MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    minimal_state_logger = LogVectorOutput(
        ballbot_minimal_coordinates.GetOutputPort("state"), builder
    )
    floating_state_logger = LogVectorOutput(plant.get_state_output_port(), builder)
    # The control loggers are not plotted below; they are kept in the diagram
    # so the control signals can be inspected when debugging.
    minimal_control_logger = LogVectorOutput(
        controller_minimal_coordinates.get_output_port(), builder
    )
    floating_control_logger = LogVectorOutput(controller.get_output_port(), builder)

    diagram = builder.Build()

    # Set up a simulator to run this diagram
    simulator = Simulator(diagram)
    # simulator.set_target_realtime_rate(1.0 if running_as_notebook else 0.0)
    context = simulator.get_mutable_context()
    # diagram.GetInputPort("u").FixValue(context, [3])

    x0 = FloatingState(np.zeros(8))
    x0.z = -0.0018  # the approximate resting depth due to soft contact
    x0.theta_bot_m_ball = 0.05
    plant_context = plant.GetMyContextFromRoot(context)
    plant.SetPositionsAndVelocities(plant_context, x0[:])

    # Start the sliding-base model from the projection of the same state.
    ballbot_minimal_state_context = (
        ballbot_minimal_coordinates.GetMyMutableContextFromRoot(context)
    )
    minimal_x0 = MinimalState(P_floating_to_minimal @ x0[:])
    ballbot_minimal_state_context.SetDiscreteState(minimal_x0[:])

    # Print one forced discrete update of each model for comparison.
    values = plant.AllocateDiscreteVariables()
    plant.CalcForcedDiscreteVariableUpdate(plant_context, values)
    print(f"floating_xn = {FloatingState(values.get_value())}")
    print(f"projected_xn = {MinimalState( P_floating_to_minimal @ values.get_value())}")

    values = ballbot_minimal_coordinates.AllocateDiscreteVariables()
    ballbot_minimal_coordinates.CalcForcedDiscreteVariableUpdate(
        ballbot_minimal_state_context, values
    )
    print(f"minimal_xn = {MinimalState(values.get_value())}")

    # Simulate
    visualizer.StartRecording(False)
    simulator.AdvanceTo(5 if running_as_notebook else 0.1)
    visualizer.PublishRecording()

    minimal_state_log = minimal_state_logger.FindLog(context)
    floating_state_log = floating_state_logger.FindLog(context)
    fig, ax = plt.subplots(1, 2, figsize=(15, 8))
    ax[0].plot(minimal_state_log.sample_times(), minimal_state_log.data().T)
    ax[0].set_title("minimal state")
    ax[0].legend(minimal_state_names)

    plot_projected_floating_state = True  # False plots all 8 floating states.
    if plot_projected_floating_state:
        ax[1].plot(
            floating_state_log.sample_times(),
            (P_floating_to_minimal @ floating_state_log.data()).T,
        )
        # The projection has the minimal-state ordering, so reuse its labels.
        ax[1].legend(minimal_state_names)
    else:
        ax[1].plot(floating_state_log.sample_times(), floating_state_log.data().T)
        ax[1].legend(floating_state_names)
    ax[1].set_title("floating state")
    # Share the y-range so the two panels are directly comparable.
    ax[0].set_ylim(ax[1].get_ylim())


BallbotFloatingBaseLqrExample()