In [1]:
import jax
from jax import jit
from jax import lax
from jax import vmap
import jax.numpy as jnp

from functools import partial

jax.config.update('jax_enable_x64', True)

In [2]:
import math
import numpy as np
import plotly.express as px
import IPython
import matplotlib.pyplot as plt 
import ipywidgets as widgets

%config InlineBackend.figure_formats = ['svg']

In [3]:
from jax_control_algorithms.trajectory_optimization import Solver, constraint_geq, constraint_leq
from jax_control_algorithms.ui import manual_investigate, solve_and_plot
from jax_control_algorithms.common import rk4

In [4]:
test_results = list()  # [name, success] entries; checked by verify_test_results at the end

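# Pendulum model

Swing a damped pendulum from rest at $\varphi = 0$ to $\varphi = \pi$, $\dot\varphi = 0$ with a bounded force, using the running cost $w_u u^2$. The dynamics are $\ddot\varphi = u - a\sin\varphi - b\dot\varphi$.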

In [5]:
def problem_def_pendulum(n_steps, dt):
    """Build the trajectory-optimisation problem for a damped pendulum swing-up.

    Continuous-time model with state x = [phi, phi_dot] and scalar input u (force):

        phi_ddot = force - a * sin(phi) - b * phi_dot

    The dynamics are discretised with ``rk4`` using step ``dt``. The problem
    starts at rest at phi = 0 and targets ``theta['phi_final']`` /
    ``theta['phi_dot_final']`` at the end of the horizon. The force is bounded
    to [``force_min``, ``force_max``] and the running cost is ``wu * u**2``.

    Args:
        n_steps: number of time steps in the horizon (sizes the initial guess).
        dt: discretisation step passed to ``rk4``.

    Returns:
        dict with keys 'f', 'g', 'running_cost', 'terminal_constraints',
        'inequality_constraints', 'initial_guess', 'parameters' and 'x0',
        as handed to ``Solver``.
    """

    def model(x, u, k, theta):
        """Return (x_dot, output phi, running cost J) for one time step."""
        del k  # the model is time-invariant
        phi, phi_dot = x

        a, b = theta['a'], theta['b']
        force = jnp.squeeze(u)

        x_dot = jnp.array([
            phi_dot,
            force - a * jnp.sin(phi) - b * phi_dot,
        ])

        # input-effort cost
        J = theta['wu'] * u**2

        return x_dot, phi, J

    def f(x, u, k, theta):
        """Continuous-time dynamics x_dot = f(x, u); discretised below with rk4."""
        x_dot, _, _ = model(x, u, k, theta)
        return x_dot

    def g(x, u, k, theta):
        """System output: the pendulum angle phi."""
        _, phi, _ = model(x, u, k, theta)
        return phi

    def running_cost(x, u, k, theta):
        """Per-step cost wu * u**2."""
        _, _, J = model(x, u, k, theta)
        return J

    def terminal_constraints(x_final, theta):
        """Residuals between the final state and its targets; both are 0 when the targets are met."""
        phi, phi_dot = x_final

        return jnp.array([
            phi - theta['phi_final'],          # == 0 at the end of the horizon
            phi_dot - theta['phi_dot_final'],  # == 0 at the end of the horizon
        ])

    def inequality_constraints(x, u, k, theta):
        """Force bounds force_min <= force <= force_max.

        Indexes ``u[:, 0]``, so ``u`` must be 2-D. It is presumably the whole
        input trajectory, shaped like ``U_guess`` (n_steps, 1); TODO confirm.
        """
        force = u[:, 0]

        return jnp.array([
            constraint_geq(force, theta['force_min']),
            constraint_leq(force, theta['force_max']),
        ])

    def initial_guess(x0, theta):
        """Zero force plus a straight-line interpolation of the state from x0 to the targets."""
        U_guess = jnp.zeros((n_steps, 1))
        X_guess = jnp.vstack((
            jnp.linspace(x0[0], theta['phi_final'], n_steps),
            jnp.linspace(x0[1], theta['phi_dot_final'], n_steps),
        )).T

        return {'X_guess': X_guess, 'U_guess': U_guess}

    theta = {
        'a'             : 5.0,     # gain of the sin(phi) term
        'b'             : 0.1,     # damping on phi_dot
        'wu'            : 0.1,     # weight of the input cost
        'force_min'     : -2.5,
        'force_max'     :  2.5,
        'phi_final'     : jnp.pi,  # target angle [rad]
        'phi_dot_final' : 0.0,     # target angular rate [rad/s]
    }

    # start at rest at phi = 0
    x0 = jnp.array([jnp.deg2rad(0.0), jnp.deg2rad(0.0)])

    f_dscr = rk4(f, dt)

    problem_definition = {
        'f' : f_dscr,
        'g' : g,
        'running_cost' : running_cost,
        'terminal_constraints': terminal_constraints,
        'inequality_constraints' : inequality_constraints,
        'initial_guess' : initial_guess,
        'parameters' : theta,
        'x0' : x0,
    }

    return problem_definition

# problem_def_pendulum returns a dict. Plain tuple-unpacking of a dict binds its
# KEYS (strings), not its values, so each entry is looked up by key instead.
_pendulum_problem = problem_def_pendulum(n_steps=50, dt=0.1)
(f_dscr, g, running_cost, terminal_constraints,
 inequality_constraints, theta, x0, initial_guess) = (
    _pendulum_problem[key] for key in (
        'f', 'g', 'running_cost', 'terminal_constraints',
        'inequality_constraints', 'parameters', 'x0', 'initial_guess',
    )
)

In [6]:
def plot_pendulum(X_opt, U_opt, system_outputs, theta):
    """Plot the optimised pendulum trajectory and force input.

    Draws three stacked panels that share the x-axis (time-step index k):
    - phi, converted to degrees, with its target;
    - phi_dot, converted to degrees, with its target;
    - the force, with its lower and upper bounds.

    Args:
        X_opt: state trajectory; column 0 is phi [rad], column 1 is phi_dot [rad/s].
        U_opt: input trajectory; column 0 is the force.
        system_outputs: outputs of ``g`` along the trajectory (not plotted).
        theta: parameter dict; reads 'phi_final', 'phi_dot_final',
            'force_min' and 'force_max'.
    """
    del system_outputs  # part of the plot-callback interface, not plotted here

    force = U_opt[:, 0]
    phi, phi_dot = X_opt[:, 0], X_opt[:, 1]

    # x-axis is the discrete step index, not seconds
    k_state = jnp.arange(phi.shape[0])
    k_input = jnp.arange(force.shape[0])

    def reference_line(ax, k, value, style, label):
        """Draw a horizontal line at `value` from the first to the last step of `k`."""
        ax.plot([k[0], k[-1]], [value, value], style, label=label)

    fig, (ax_phi, ax_phi_dot, ax_force) = plt.subplots(3, 1, sharex=True, figsize=(4, 6))

    ax_phi.plot(k_state, jnp.rad2deg(phi), 'k', label='phi')
    reference_line(ax_phi, k_state, jnp.rad2deg(theta['phi_final']), 'k:', 'phi final')
    ax_phi.set_ylabel('phi [degrees]')
    ax_phi.legend()

    ax_phi_dot.plot(k_state, jnp.rad2deg(phi_dot), 'k', label='phi_dot')
    reference_line(ax_phi_dot, k_state, jnp.rad2deg(theta['phi_dot_final']), 'k:', 'phi dot final')
    ax_phi_dot.set_ylabel('phi dot [degrees/s]')
    ax_phi_dot.legend()

    ax_force.plot(k_input, force, 'k', label='force')
    reference_line(ax_force, k_input, theta['force_min'], ':k', 'force_min')
    reference_line(ax_force, k_input, theta['force_max'], ':k', 'force_max')
    ax_force.set_ylabel('force []')
    ax_force.set_xlabel('time step k')
    ax_force.legend()

    fig.tight_layout()
    plt.show()

In [7]:
# Solver for a 50-step horizon with dt = 0.1.
solver = Solver(partial(problem_def_pendulum, n_steps=50, dt=0.1))

# Interactive parameter controls. The keys match the parameter names of
# set_theta_fn (presumably how manual_investigate forwards them; TODO confirm).
sliders = dict(
    a=widgets.FloatSlider(
        value=5.0, min=1.0, max=20, step=0.1, description='a'),
    b=widgets.FloatSlider(
        value=0.1, min=0.01, max=5.0, step=0.01, description='b'),
    force_min=widgets.FloatSlider(
        value=-2, min=-10, max=0, step=0.01, description='force_min'),
    force_max=widgets.FloatSlider(
        value=2, min=0, max=10, step=0.01, description='force_max'),
)

def set_theta_fn(solver, a, b, force_min, force_max, wu=None):
    """Write parameter values into ``solver.problem_definition['parameters']`` in place.

    It is passed to ``manual_investigate`` together with sliders named a, b,
    force_min and force_max.

    Args:
        solver: object exposing ``problem_definition['parameters']`` (a dict).
        a, b: model coefficients (sin(phi) gain and damping).
        force_min, force_max: force bounds.
        wu: optional input-cost weight; when None the current value is kept.
    """
    parameters = solver.problem_definition['parameters']
    parameters['a'] = a
    parameters['b'] = b
    parameters['force_min'] = force_min
    parameters['force_max'] = force_max
    if wu is not None:
        parameters['wu'] = wu

In [8]:
def test_pendulum():
    """Solve the pendulum problem once with fixed parameters, independent of the slider state.

    Returns:
        The tuple ``(X_opt, U_opt, system_outputs, res)`` returned by ``solver.run()``.
    """
    set_theta_fn(solver, a=5.0, b=0.1, force_min=-10, force_max=10)
    # `wu` is written directly, so this does not depend on set_theta_fn accepting it
    solver.problem_definition['parameters']['wu'] = 0.1

    X_opt, U_opt, system_outputs, res = solver.run()
    return X_opt, U_opt, system_outputs, res

In [9]:
# Connect the solver, sliders, parameter setter and plot callback, then show the widgets.
# output_box is returned but not displayed here.
ui, output_box, print_output, plot_output = manual_investigate(
    solver,
    sliders,
    set_theta_fn,
    plot_pendulum,
)
display(ui, plot_output, print_output)

GridBox(children=(FloatSlider(value=5.0, description='a', max=20.0, min=1.0), FloatSlider(value=0.1, descripti…

Output()

Output()

In [10]:
# NOTE(review): solver.success reflects the most recent solver run, which may be
# a slider-driven one, so this result depends on interactive state.
test_results += [['pendulum without cart', solver.success]]

# Verify

In [11]:
def verify_test_results(test_results):
    """Raise if any recorded test failed.

    Args:
        test_results: iterable of entries whose element 0 is the test name
            and element 1 is a truthy/falsy success flag.

    Raises:
        AssertionError: naming every failed test. It subclasses BaseException,
            so handlers written for the previous BaseException still catch it.
    """
    failed = [str(r[0]) for r in test_results if not r[1]]
    if not failed:
        return
    if len(failed) == 1:
        raise AssertionError('Test ' + failed[0] + ' failed')
    raise AssertionError('Tests ' + ', '.join(failed) + ' failed')
            
# Raise if any entry recorded in test_results is marked as failed.
verify_test_results(test_results)