In [1]:
import jax
from jax import jit
from jax import lax
from jax import vmap
import jax.numpy as jnp

from functools import partial

# Turn on JAX's 64-bit mode for this notebook. This must run before any JAX arrays
# are created, so it sits in the first cell right after the imports.
jax.config.update('jax_enable_x64', True)

In [2]:
import math
import numpy as np
import plotly.express as px
import IPython
import matplotlib.pyplot as plt 
import ipywidgets as widgets

%config InlineBackend.figure_formats = ['svg']

In [3]:
from jax_control_algorithms.trajectory_optimization import Solver, constraint_geq, constraint_leq
from jax_control_algorithms.ui import manual_investigate, solve_and_plot
from jax_control_algorithms.common import rk4

# Integrator model

In [4]:
def problem_def_integrator(n_steps, dt):
    """
    Define a trajectory-optimization problem for a first-order, single-input system.

    The continuous-time dynamics are

        x_1_dot = u - a * x_1

    and are handed to `rk4` together with the sampling time `dt` to obtain `f_dscr`.
    The running cost is  wu * u**2 . The input bounds u_min / u_max are expressed
    via `constraint_geq` / `constraint_leq`, and the terminal residual
    x_1 - x_1_final is zero when the final state hits its target.

    Args:
        n_steps: number of samples; used here to size the initial guess
        dt: sampling time passed to `rk4`

    Returns:
        The tuple
        (f_dscr, g, running_cost, terminal_state_eq_constraints,
         inequ_constraints, theta, x0, make_guess)
        where theta is the dict of default parameters and x0 the initial state.
    """

    def model(x, u, k, theta):
        """Return the state derivative x_dot and the running cost J for one sample."""
        del k  # the model is time-invariant
        x_1, = x  # the state vector has exactly one element

        a = theta['a']
        x_1_dot = jnp.squeeze(u) - a * x_1

        x_dot = jnp.array([
            x_1_dot,
        ])

        # quadratic input cost; note that u is not squeezed here, so J keeps the shape of u
        J = theta['wu'] * u**2

        return x_dot, J

    def f(x, u, k, theta):
        """Continuous-time system function: only the state derivative of the model."""
        x_dot, _ = model(x, u, k, theta)
        return x_dot

    def g(x, u, k, theta):
        """Output function: this problem defines no system outputs, hence None."""
        # The model is deliberately not evaluated here; its result would be discarded.
        return None

    def running_cost(x, u, k, theta):
        """Running cost J of the model for one sample."""
        _, J = model(x, u, k, theta)
        return J

    def terminal_state_eq_constraints(x_final, theta):
        """Residual of the terminal condition on the final state."""
        x_1, = x_final

        return jnp.array([
            x_1 - theta['x_1_final'],           # == 0 when the final state equals its target
        ])

    def inequ_constraints(x, u, k, theta):
        """Stack the lower- and upper-bound constraints on the input."""
        # u is indexed as a 2-D array here; its first column is the (only) input signal
        u = u[:, 0]

        c_ineq = jnp.array([
            constraint_geq(u, theta['u_min']),
            constraint_leq(u, theta['u_max']),
        ])

        return c_ineq

    def make_guess(x0, theta): # TODO: add theta to the solver
        """Initial guess: zero input and a straight line from x0[0] to x_1_final."""
        U_guess = jnp.zeros((n_steps, 1))

        # vstack of one 1-D ramp gives shape (1, n_steps); the transpose makes it (n_steps, 1)
        X_guess = jnp.vstack((
            jnp.linspace(x0[0], theta['x_1_final'], n_steps),
        )).T

        return X_guess, U_guess

    # default parameters
    theta = {
        'a'             : 5.0,
        'wu'            : 1.00,
        'u_min'         : -2.5,
        'u_max'         : 2.5,
        'x_1_final'     : 1.0,
    }

    # initial state x_1 = 0 (deg2rad of 0.0 is 0.0)
    x0 = jnp.array([ jnp.deg2rad(0.0), ])

    # wrap the continuous-time function f with rk4 and the sampling time dt
    # (presumably a Runge-Kutta-4 discretization -- see jax_control_algorithms.common)
    f_dscr = rk4(f, dt)

    return f_dscr, g, running_cost, terminal_state_eq_constraints, inequ_constraints, theta, x0, make_guess

# Instantiate the problem once at module level with 50 steps and dt = 0.1
(
    f_dscr,
    g,
    running_cost,
    terminal_state_eq_constraints,
    inequ_constraints,
    theta,
    x0,
    make_guess,
) = problem_def_integrator(n_steps=50, dt=0.1)



In [5]:
def plot_integrator(X_opt, U_opt, system_outputs, theta):
    """
    Plot the state and input trajectories together with their reference levels.

    The upper axes show column 0 of X_opt and a dotted line at theta['x_1_final'];
    the lower axes show column 0 of U_opt and dotted lines at theta['u_min'] and
    theta['u_max']. The horizontal axis is the sample index.

    Args:
        X_opt: array indexed as [:, 0] to obtain x_1
        U_opt: array indexed as [:, 0] to obtain u
        system_outputs: return value of the output function g (not used here)
        theta: dict providing 'x_1_final', 'u_min' and 'u_max'
    """
    del system_outputs  # nothing to unpack for this problem

    def sample_axis(signal):
        # sample indices 0 .. len-1 as a float vector
        n_samples = signal.shape[0]
        return jnp.linspace(0, n_samples - 1, n_samples)

    def draw_level(ax, k_axis, level, fmt, label):
        # horizontal line spanning the first to the last sample
        ax.plot([k_axis[0], k_axis[-1]], jnp.array([level, level]), fmt, label=label)

    # prepare data
    state_traj = X_opt[:, 0]
    input_traj = U_opt[:, 0]
    k_state = sample_axis(state_traj)
    k_input = sample_axis(input_traj)

    # two stacked axes that share the sample axis
    fig, axes = plt.subplots(2, 1, sharex=True, figsize=(4, 6))
    ax_state, ax_input = axes

    # state and its terminal target
    ax_state.plot(k_state, state_traj, 'k', label='x_1')
    draw_level(ax_state, k_state, theta['x_1_final'], 'k:', 'x_1_final')
    ax_state.legend()
    ax_state.set_ylabel('x_1 []')

    # input and its bounds
    ax_input.plot(k_input, input_traj, 'k', label='u')
    for bound_name in ('u_min', 'u_max'):
        draw_level(ax_input, k_input, theta[bound_name], ':k', bound_name)
    ax_input.legend()
    ax_input.set_ylabel('u []')

    # Show the plot
    plt.show()

In [6]:
# One FloatSlider per tunable parameter; each row is (name, min, max, initial value).
# All sliders share the step size 0.01 and use the parameter name as description.
sliders = {
    name: widgets.FloatSlider(min=lower, max=upper, step=0.01, value=initial, description=name)
    for name, lower, upper, initial in (
        ('a',      0.0, 1.0,  0.0),
        ('u_min', -10,  0,   -2),
        ('u_max',  0,   10,   2),
    )
}

# The solver receives the problem definition with n_steps and dt already bound
solver = Solver(
    partial(problem_def_integrator, n_steps=50, dt=0.1)
)
# NOTE(review): presumably this skips any float32 pre-iterations -- confirm against Solver
solver.max_float32_iterations = 0


def set_theta_fn(solver, a, u_min, u_max):
    """
    Write the given parameter values into solver.theta.

    Only the entries 'a', 'u_min' and 'u_max' are assigned; all other entries
    of solver.theta are left untouched.
    """
    updates = (('a', a), ('u_min', u_min), ('u_max', u_max))
    for key, value in updates:
        solver.theta[key] = value
        
#solver.theta

In [7]:
def test_pendulum():
    """Smoke test: run the solver once with a = 0 and input bounds of -10 and 10."""
    relaxed_theta = {'a': 0.0, 'u_min': -10, 'u_max': 10}
    set_theta_fn(solver, **relaxed_theta)

    # The unpacking requires run() to yield exactly four items; the values are not inspected
    X_sol, U_sol, outputs, result = solver.run()

In [8]:
# Build the slider-driven UI for the solver and show it together with the
# plot and print output areas (output_box is returned but not displayed here).
(ui, output_box, print_output, plot_output) = manual_investigate(
    solver,
    sliders,
    set_theta_fn,
    plot_integrator,
)
display(ui, plot_output, print_output)

GridBox(children=(FloatSlider(value=0.0, description='a', max=1.0, step=0.01), FloatSlider(value=-2.0, descrip…

Output()

Output()