In [None]:
import sys
sys.path.append("../underactuated")

In [None]:
import numpy as np

from polynomial_fvi import *
from pydrake.examples.pendulum import (PendulumPlant)
from pydrake.all import (DiagramBuilder, Simulator, WrapToSystem, LeafSystem,
                         BasicVector)
from underactuated.pendulum import PendulumVisualizer
from underactuated.jupyter import AdvanceToAndVisualize

In [None]:
class Controller(LeafSystem):
    """Drake LeafSystem that outputs the control computed from a value function.

    Reads the pendulum state on its ``state`` input port, converts it to the
    polynomial_fvi convention, and writes ``calc_u_opt``'s result on its
    ``policy`` output port.

    Args:
        J: Value-function coefficients passed to ``calc_dJdz``.
        plant: The pendulum plant; a default context is created from it and
            stored on the instance.
        params_dict: Setup dictionary (e.g. from ``pendulum_setup``). Must
            contain ``"x2z"``; it is also forwarded to ``calc_dJdz`` and
            ``calc_u_opt``.
    """

    def __init__(self, J, plant, params_dict):
        LeafSystem.__init__(self)
        self.plant = plant
        self.context = plant.CreateDefaultContext()
        self.x_dim = 2
        self.u_dim = 1
        # Keep our own reference so the output callback never depends on a
        # notebook-global variable that happens to share the name.
        self.params_dict = params_dict
        self.x2z = params_dict["x2z"]
        self.J = J
        self.poly_func = lambda t, i, n: monomial(t, i, n)

        # J, the basis function and params_dict are all fixed for the life of
        # this system, so build the dJ/dz expression once instead of on every
        # output evaluation.
        self.dJdz_expr, self.z_var = calc_dJdz(
            self.J, self.poly_func, self.params_dict)

        self.state_input_port = self.DeclareVectorInputPort(
            "state", BasicVector(self.x_dim))

        self.policy_output_port = self.DeclareVectorOutputPort(
            "policy", BasicVector(self.u_dim), self.CalculateController)

    def CalculateController(self, context, output):
        """Output callback: write the policy's control for the current state."""
        # Work on a copy so the angle shift below cannot alter the value held
        # by the input port.
        state = np.array(self.state_input_port.Eval(context), copy=True)
        # In polynomial_fvi, the axis is pointing upwards; the pendulum plant
        # has the axis pointing downwards.
        state[0] += np.pi
        z = self.x2z(state)
        y = output.get_mutable_value()
        y[:] = calc_u_opt(self.dJdz_expr, self.z_var, z, self.params_dict)

In [None]:
def simulate(J, params_dict, x0=(0.1, 0.0), duration=12):
    """Simulate and animate the pendulum driven by the value-function policy.

    Builds a diagram in which the pendulum output is first passed through a
    WrapToSystem that wraps state index 0 into [0, 2*pi). The wrapped state is
    then fed to ``Controller``, whose output drives the pendulum's input. The
    unwrapped pendulum output goes to a ``PendulumVisualizer``.

    Args:
        J: Value-function coefficients forwarded to ``Controller``.
        params_dict: Setup dictionary forwarded to ``Controller``.
        x0: Initial continuous state of the pendulum; defaults to
            ``(0.1, 0.0)``.
        duration: Time to advance the simulator to; defaults to ``12``.
    """
    builder = DiagramBuilder()
    pendulum = builder.AddSystem(PendulumPlant())

    # Wrap the angle (state index 0) into [0, 2*pi) before the controller sees it.
    wrap = builder.AddSystem(WrapToSystem(2))
    wrap.set_interval(0, 0, 2*np.pi)
    builder.Connect(pendulum.get_output_port(0), wrap.get_input_port(0))

    vi_policy = builder.AddSystem(Controller(J, pendulum, params_dict))
    builder.Connect(wrap.get_output_port(0), vi_policy.get_input_port(0))
    builder.Connect(vi_policy.get_output_port(0), pendulum.get_input_port(0))

    visualizer = builder.AddSystem(PendulumVisualizer(show=False))
    builder.Connect(pendulum.get_output_port(0), visualizer.get_input_port(0))

    diagram = builder.Build()
    simulator = Simulator(diagram)
    simulator.get_mutable_context().SetContinuousState(list(x0))

    AdvanceToAndVisualize(simulator, visualizer, duration)

In [None]:
# Identify the saved value function to replay: the method, polynomial basis,
# degree and dt together make up its file name under pendulum_swingup/.
method = LSTSQ
poly_type = MONOMIAL
deg = 2
dt = 0.01

J_path = f"pendulum_swingup/{method}/{poly_type}/J_{deg}_{dt}.npy"
J = np.load(J_path)

params_dict = pendulum_setup(poly_type)
simulate(J, params_dict)