This notebook provides examples to go along with the [textbook](https://underactuated.csail.mit.edu/pend.html).  I recommend having both windows open, side-by-side!


In [None]:
from copy import copy

import matplotlib.pyplot as plt
import mpld3
import numpy as np
from IPython.display import display
from pydrake.all import (
    DiagramBuilder,
    Linearize,
    LinearQuadraticRegulator,
    MeshcatVisualizer,
    Saturation,
    SceneGraph,
    Simulator,
    StartMeshcat,
    VectorLogSink,
    VectorSystem,
    wrap_to,
)
from pydrake.examples import PendulumGeometry, PendulumParams, PendulumPlant

from underactuated import running_as_notebook

# Enable mpld3's notebook integration only when running inside a notebook;
# the demos below render their figures with mpld3.display().
if running_as_notebook:
    mpld3.enable_notebook()

In [None]:
# Start the Meshcat visualizer server. Run this cell only once: each
# StartMeshcat() instance consumes its own port. The module-level `meshcat`
# handle is used by swing_up_and_balance_demo below.
meshcat = StartMeshcat()

# Energy Shaping Control

In [None]:
class EnergyShapingController(VectorSystem):
    """Energy-shaping swing-up controller for the simple pendulum.

    Input port: the pendulum state [theta, thetadot] (2 elements).
    Output port: one value, fed (through saturation) to the pendulum's input.

    The control law is

        u = b * thetadot - gain * thetadot * (E - E_desired)

    where b is the damping parameter, E is the total (potential + kinetic)
    energy evaluated by the pendulum plant, and E_desired is the potential
    energy of the upright state [pi, 0]. The `b * thetadot` term presumably
    offsets the plant's viscous damping (TODO confirm against the plant
    model); the second term pushes E toward E_desired.
    """

    def __init__(self, pendulum, gain=0.1):
        """Create the controller.

        Args:
            pendulum: the PendulumPlant whose energy functions are evaluated.
            gain: feedback gain on the energy error. Defaults to 0.1, the
                value this controller has always used.
        """
        VectorSystem.__init__(self, 2, 1)
        self.pendulum = pendulum
        self.gain = gain
        # Private scratch context, used only to evaluate the plant's energies.
        self.pendulum_context = pendulum.CreateDefaultContext()
        self.SetPendulumParams(PendulumParams())

    def SetPendulumParams(self, params):
        """Set the parameters this controller assumes, and recompute E_desired.

        The parameters are copied into the controller's own scratch context.
        They affect only this controller's energy estimates, not the simulated
        plant's context.
        """
        self.pendulum_context.get_mutable_numeric_parameter(0).SetFromVector(
            params.CopyToVector()
        )
        # Desired energy = potential energy at the upright state [pi, 0].
        self.pendulum_context.SetContinuousState([np.pi, 0])
        self.desired_energy = self.pendulum.EvalPotentialEnergy(self.pendulum_context)

    def DoCalcVectorOutput(self, context, pendulum_state, unused, output):
        """Write the energy-shaping command for `pendulum_state` into `output`."""
        # Load the measured state into the scratch context to evaluate energies.
        self.pendulum_context.SetContinuousState(pendulum_state)
        params = self.pendulum_context.get_numeric_parameter(0)
        thetadot = pendulum_state[1]
        total_energy = self.pendulum.EvalPotentialEnergy(
            self.pendulum_context
        ) + self.pendulum.EvalKineticEnergy(self.pendulum_context)
        output[:] = params.damping() * thetadot - self.gain * thetadot * (
            total_energy - self.desired_energy
        )


def PhasePlot(pendulum):
    """Create a phase-plane figure with constant-energy contours of the pendulum.

    Three energy levels are drawn in gray: E_upright, 0.1 * E_upright and
    1.5 * E_upright. E_upright is the potential energy of the state [pi, 0],
    as evaluated by `pendulum`. Each contour solves

        e = 0.5 * m * l**2 * thetadot**2 - m * g * l * cos(theta)

    for thetadot and is drawn for both signs of thetadot. Angles where the
    energy level is unreachable are set to NaN, so the curve breaks there.

    Args:
        pendulum: a PendulumPlant; its default-context parameters are used.

    Returns:
        The matplotlib Axes, so callers can overlay trajectories.
    """
    phase_plot = plt.figure()
    ax = phase_plot.gca()
    theta_lim = [-np.pi, 3.0 * np.pi]
    ax.set_xlim(theta_lim)
    ax.set_ylim(-10.0, 10.0)

    # 601 = 4*150 + 1 samples over a 4*pi span gives a step of pi/150, so the
    # grid includes theta = 0, pi and 2*pi.
    theta = np.linspace(theta_lim[0], theta_lim[1], 601)
    context = pendulum.CreateDefaultContext()
    params = context.get_numeric_parameter(0)
    context.SetContinuousState([np.pi, 0])
    E_upright = pendulum.EvalPotentialEnergy(context)

    mass = params.mass()
    gravity = params.gravity()
    length = params.length()
    for e in (E_upright, 0.1 * E_upright, 1.5 * E_upright):
        thetadot_sq = (e + mass * gravity * length * np.cos(theta)) / (
            0.5 * mass * length * length
        )
        # Fresh array per contour (never reuse a buffer already given to
        # ax.plot); NaN where the energy level is unreachable.
        thetadot = np.full_like(theta, np.nan)
        reachable = thetadot_sq >= 0
        thetadot[reachable] = np.sqrt(thetadot_sq[reachable])
        ax.plot(theta, thetadot, color=[0.6, 0.6, 0.6])
        ax.plot(theta, -thetadot, color=[0.6, 0.6, 0.6])

    return ax


def energy_shaping_demo(*, num_trials=5, duration=4.0):
    """Simulate the energy-shaping controller from random initial states.

    Builds the closed loop pendulum -> EnergyShapingController ->
    Saturation([-3, 3]) -> pendulum and logs the pendulum's state output.
    Each trial's (theta, thetadot) trajectory is overlaid on PhasePlot, and
    the figure is then displayed with mpld3.

    Args:
        num_trials: number of random initial conditions to simulate.
        duration: simulated time (seconds) per trial.
    """
    builder = DiagramBuilder()

    pendulum = builder.AddSystem(PendulumPlant())
    ax = PhasePlot(pendulum)
    # Clamp the controller's command to [-3, 3] before it reaches the plant.
    saturation = builder.AddSystem(Saturation(min_value=[-3], max_value=[3]))
    builder.Connect(saturation.get_output_port(0), pendulum.get_input_port(0))
    controller = builder.AddSystem(EnergyShapingController(pendulum))
    builder.Connect(pendulum.get_output_port(0), controller.get_input_port(0))
    builder.Connect(controller.get_output_port(0), saturation.get_input_port(0))

    logger = builder.AddSystem(VectorLogSink(2))
    builder.Connect(pendulum.get_output_port(0), logger.get_input_port(0))

    diagram = builder.Build()
    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()

    for _ in range(num_trials):
        # Restart the clock and draw a standard-normal initial state.
        context.SetTime(0.0)
        context.SetContinuousState(np.random.randn(2))
        simulator.Initialize()
        simulator.AdvanceTo(duration)
        log = logger.FindLog(context)
        ax.plot(log.data()[0, :], log.data()[1, :])
        # Empty the log so the next trial's plot holds only its own samples.
        log.Clear()

    display(mpld3.display())


# Simulate the energy-shaping controller from random initial states and
# overlay the trajectories on the energy-contour phase plot.
energy_shaping_demo()

## Swing-up and balance

Now we will combine our simple energy-shaping controller with a linear (LQR) controller that stabilizes the upright fixed point once the state gets close enough to it. We'll learn more about this approach in the Acrobot and Cart-Pole notes.

In [None]:
def BalancingLQR(pendulum):
    """Design an LQR balancing controller at the pendulum's upright state.

    The plant is linearized at state [pi, 0] with its input fixed to zero.
    The LQR problem is then solved with state cost Q = diag(10, 1) and input
    cost R = [[1]].

    Returns:
        (K, S): the LQR gain, used as u = -K @ xbar, and the quadratic
        cost-to-go matrix S.
    """
    upright_context = pendulum.CreateDefaultContext()
    # Operating point: x = [pi, 0], u = 0.
    pendulum.get_input_port(0).FixValue(upright_context, [0])
    upright_context.SetContinuousState([np.pi, 0])

    state_cost = np.diag((10.0, 1.0))
    input_cost = np.array([[1.0]])

    linear_model = Linearize(pendulum, upright_context)
    gain, cost_to_go = LinearQuadraticRegulator(
        linear_model.A(), linear_model.B(), state_cost, input_cost
    )
    return gain, cost_to_go


class SwingUpAndBalanceController(VectorSystem):
    """Hybrid controller: energy-shaping swing-up plus LQR balancing.

    Input port: the pendulum state [theta, thetadot]. Output port: one value.

    The state is first expressed relative to the upright: theta is wrapped
    into [0, 2*pi) and pi is subtracted. If the LQR cost-to-go x'Sx of that
    error is strictly below 2, the LQR law u = -K x is applied. Otherwise the
    command comes from an internal EnergyShapingController.
    """

    def __init__(self, pendulum):
        VectorSystem.__init__(self, 2, 1)
        self.K, self.S = BalancingLQR(pendulum)
        # Embedded energy-shaping controller, evaluated through its own context.
        self.energy_shaping = EnergyShapingController(pendulum)
        self.energy_shaping_context = self.energy_shaping.CreateDefaultContext()

        # TODO(russt): Add a witness function to tell the simulator about the
        # discontinuity when switching to LQR.

    def DoCalcVectorOutput(self, context, pendulum_state, unused, output):
        # Error coordinates about the upright fixed point.
        x_err = copy(pendulum_state)
        x_err[0] = wrap_to(x_err[0], 0, 2.0 * np.pi) - np.pi

        # Inside the level set x'Sx < 2: balance with LQR.
        lqr_cost = x_err.dot(self.S.dot(x_err))
        if lqr_cost < 2.0:
            output[:] = -self.K.dot(x_err)
            return

        # Otherwise delegate to energy shaping, fed the raw (unwrapped) state.
        shaping_input = self.energy_shaping.get_input_port(0)
        shaping_output = self.energy_shaping.get_output_port(0)
        shaping_input.FixValue(self.energy_shaping_context, pendulum_state)
        output[:] = shaping_output.Eval(self.energy_shaping_context)


def swing_up_and_balance_demo(show=False, *, num_trials=5, duration=4.0):
    """Simulate swing-up-and-balance from random initial states.

    Builds the closed loop pendulum -> SwingUpAndBalanceController ->
    Saturation([-3, 3]) -> pendulum. The pendulum geometry is attached to
    the module-level `meshcat` visualizer and the pendulum state is logged.
    Each trial's (theta, thetadot) trajectory is overlaid on PhasePlot,
    zoomed in around theta = pi.

    Args:
        show: if True, throttle the simulation to real time so the Meshcat
            animation can be watched.
        num_trials: number of random initial conditions to simulate.
        duration: simulated time (seconds) per trial.
    """
    builder = DiagramBuilder()

    pendulum = builder.AddSystem(PendulumPlant())
    ax = PhasePlot(pendulum)
    # Clamp the controller's command to [-3, 3] before it reaches the plant.
    saturation = builder.AddSystem(Saturation(min_value=[-3], max_value=[3]))
    builder.Connect(saturation.get_output_port(0), pendulum.get_input_port(0))
    controller = builder.AddSystem(SwingUpAndBalanceController(pendulum))
    builder.Connect(pendulum.get_output_port(0), controller.get_input_port(0))
    builder.Connect(controller.get_output_port(0), saturation.get_input_port(0))

    # Setup visualization: pendulum geometry rendered through Meshcat.
    scene_graph = builder.AddSystem(SceneGraph())
    PendulumGeometry.AddToBuilder(
        builder, pendulum.get_state_output_port(), scene_graph
    )
    MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)

    logger = builder.AddSystem(VectorLogSink(2))
    builder.Connect(pendulum.get_output_port(0), logger.get_input_port(0))

    diagram = builder.Build()
    simulator = Simulator(diagram)
    context = simulator.get_mutable_context()

    if show:
        simulator.set_target_realtime_rate(1.0)

    for _ in range(num_trials):
        # Restart the clock and draw a standard-normal initial state.
        context.SetTime(0.0)
        context.SetContinuousState(np.random.randn(2))
        simulator.Initialize()
        simulator.AdvanceTo(duration)
        log = logger.FindLog(context)
        ax.plot(log.data()[0, :], log.data()[1, :])
        # Empty the log so the next trial's plot holds only its own samples.
        log.Clear()

    # Zoom in on the neighborhood of the upright state (theta = pi).
    ax.set_xlim(np.pi - 3.0, np.pi + 3.0)
    ax.set_ylim(-5.0, 5.0)
    display(mpld3.display())


# Throttle to real time only when running interactively in a notebook, so the
# Meshcat animation can be watched; otherwise the simulator's default realtime
# rate is left unchanged.
swing_up_and_balance_demo(show=running_as_notebook)