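# MPC for a pendulum
Model predictive control (MPC) of a simple pendulum: at every control step a direct-collocation trajectory optimization is re-solved and only its first input is applied.  
Inspired by:  
https://github.com/simorxb/MPC-Pendulum-Python-2/blob/main/MPC.py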

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from IPython.display import HTML, display

from pydrake.all import (
    AddMultibodyPlantSceneGraph,
    ControllabilityMatrix,
    DiagramBuilder,
    LeafSystem,
    Linearize,
    LinearQuadraticRegulator,
    MeshcatVisualizer,
    Parser,
    Simulator,
    StartMeshcat,
    DirectCollocation,
    FiniteHorizonLinearQuadraticRegulator,
    FiniteHorizonLinearQuadraticRegulatorOptions,
    LogVectorOutput,
    MakeFiniteHorizonLinearQuadraticRegulator,
    MultibodyPlant,
    MultibodyPositionToGeometryPose,
    PiecewisePolynomial,
    PlanarSceneGraphVisualizer,
    SceneGraph,
    SnoptSolver,
    Solve,
    TrajectorySource,
    MathematicalProgram,
    MakeVectorVariable,
    Variable,
    ConstantVectorSource,
    SnoptSolver,
    SolverOptions,
    SolverType
)

from pydrake.examples import AcrobotGeometry, AcrobotPlant, PendulumPlant, PendulumState
from underactuated import ConfigureParser, running_as_notebook
from underactuated.meshcat_utils import MeshcatSliders
from IPython.core.display import Image, display
from underactuated.pendulum import PendulumVisualizer

In [None]:
# Start Drake's Meshcat visualizer server; it logs the URL to connect to
# (see the INFO line in this cell's output).
# NOTE(review): `meshcat` is not used by any of the cells shown here.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at http://localhost:7000


In [None]:
def _plot_dircol_solution(dircol, result):
    """Plot the phase portrait of a solved dircol plan and animate it."""
    x_trajectory = dircol.ReconstructStateTrajectory(result)
    print(dircol.GetSampleTimes(result))

    times = np.linspace(x_trajectory.start_time(), x_trajectory.end_time(), 100)
    x_knots = np.hstack([x_trajectory.value(t) for t in times])

    _, ax = plt.subplots()
    ax.set_xlabel("$q$")
    ax.set_ylabel(r"$\dot{q}$")  # raw string: "\d" is an invalid escape otherwise
    ax.plot(x_knots[0, :], x_knots[1, :])
    plt.show()

    # Animate the result.
    vis = PendulumVisualizer(show=False)
    ani = vis.animate(x_trajectory)
    display(HTML(ani.to_jshtml()))


def pend_dircol(
    N,
    theta0=0.0,
    thetadot0=0.0,
    theta_goal=np.pi,
    *,
    torque_limit=3.0,
    R=10.0,
    verbose=False,
    plot=False,
):
    """Plan a minimum-effort pendulum motion with direct collocation.

    The PendulumPlant dynamics are transcribed over ``N`` knot points with
    equal time steps (each between 0.05 s and 0.5 s; the step length is a
    decision variable). The first knot is pinned to ``(theta0, thetadot0)``,
    the last knot to ``(theta_goal, 0)``, the torque is bounded by
    ``+-torque_limit`` and ``R * u**2`` is used as the running cost.

    Args:
        N: number of collocation knot points (the planning horizon).
        theta0: initial angle [rad]. Default 0.0, the original fixed start.
        thetadot0: initial angular velocity [rad/s]. Default 0.0.
        theta_goal: target angle [rad], reached with zero velocity.
            Default pi, the original fixed goal.
        torque_limit: symmetric bound on the input torque [N*m].
        R: weight of the running input cost.
        verbose: print the decision variables and the solver status.
        plot: plot the phase portrait and animate the planned trajectory.

    Returns:
        ``dircol.GetInputSamples(result)``: a matrix with one row per input
        (the pendulum has one) and one column per knot. It is returned even
        when the solve fails; the failure is printed.
    """
    plant = PendulumPlant()
    context = plant.CreateDefaultContext()

    min_dt, max_dt = 0.05, 0.5
    dircol = DirectCollocation(
        plant,  # the plant dynamics become collocation constraints automatically
        context,
        num_time_samples=N,
        minimum_time_step=min_dt,
        maximum_time_step=max_dt,
    )
    prog = dircol.prog()
    dircol.AddEqualTimeIntervalsConstraints()

    u = dircol.input()
    dircol.AddConstraintToAllKnotPoints(-torque_limit <= u[0])
    dircol.AddConstraintToAllKnotPoints(u[0] <= torque_limit)

    initial_state = PendulumState()
    initial_state.set_theta(theta0)
    initial_state.set_thetadot(thetadot0)
    x0 = initial_state.get_value()
    # Equal lower/upper bounds pin the boundary states. A linear equality
    # constraint would be more elegant but is blocked on drake #8315.
    prog.AddBoundingBoxConstraint(x0, x0, dircol.initial_state())

    final_state = PendulumState()
    final_state.set_theta(theta_goal)
    final_state.set_thetadot(0.0)
    xf = final_state.get_value()
    prog.AddBoundingBoxConstraint(xf, xf, dircol.final_state())

    dircol.AddRunningCost(R * u[0] ** 2)

    # Initial guess: a straight line in state space over 4 s; no input guess.
    initial_x_trajectory = PiecewisePolynomial.FirstOrderHold([0.0, 4.0], [x0, xf])
    dircol.SetInitialTrajectory(PiecewisePolynomial(), initial_x_trajectory)

    if verbose:
        print(dircol.initial_state())
        print(x0)
        print(prog.decision_variables())

    result = Solve(prog)
    if verbose or not result.is_success():
        print("optimization success: ", result.is_success())

    if plot:
        _plot_dircol_solution(dircol, result)

    return dircol.GetInputSamples(result)


N = 21  # collocation knots per re-plan (the MPC horizon)
time_range = 10.0  # total simulated time [s]
dt = 0.1  # control period [s]: one re-plan per period
L = round(time_range / dt)

# The "real" system the MPC controls, simulated by Drake.
sim_plant = PendulumPlant()
simulator = Simulator(sim_plant)
sim_context = simulator.get_mutable_context()
sim_context.SetContinuousState([0.0, 0.0])  # start at rest at theta = 0

tau = np.zeros(L)  # applied torque per control period
q_ref = np.zeros(L)  # reference angle per control period
x_log = np.zeros((L + 1, 2))  # simulated [theta, thetadot] at each period boundary
x_log[0] = sim_context.get_continuous_state_vector().CopyToVector()

for i in range(L):
    # Computed rather than accumulated, so the 5 s switch is exact.
    t = i * dt

    # ---- generate reference point: pi for the first 5 s, then pi/2.
    q_ref[i] = np.pi if t < 5.0 else np.pi / 2

    # ----- solve trajopt for the next N knots, starting from the current state.
    # NOTE(review): angles are not wrapped; the goal is q_ref[i] exactly.
    theta, thetadot = x_log[i]
    tau_traj = pend_dircol(N, theta, thetadot, q_ref[i])

    # Receding horizon: apply only the first planned input (row 0 = the only
    # input, column 0 = first knot). The plan's own time step may differ
    # from dt, so holding it for dt is an approximation.
    tau[i] = tau_traj[0, 0]
    # TODO: warm-start the next solve from tau_traj via SetInitialTrajectory.

    # ----- run the simulation for one period with the input held constant
    # and record the state at the next step.
    sim_plant.get_input_port(0).FixValue(sim_context, [tau[i]])
    simulator.AdvanceTo(t + dt)
    x_log[i + 1] = sim_context.get_continuous_state_vector().CopyToVector()
