<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_hermite_optimal_control.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [34]:
import warnings
from functools import partial

import numpy as np
from numpy.polynomial.legendre import leggauss
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp
from scipy.optimize import root, minimize

import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)

In [35]:
# Widen NumPy's printed line width to 200 characters for more compact array printing.
np.set_printoptions(linewidth=200)

In [36]:
# Two tanks in series: tank 1 drains into tank 2, which drains out.
# Outflow of each tank follows q = a*sqrt(2*g*h); levels change as dq/A.
a = 0.01                 # outlet coefficient in q = a*sqrt(2*g*h)
A = 0.2                  # divisor turning net flow into level rate (tank area)
g = 9.81                 # gravitational acceleration (single source for both uses below)
qin_initial_ss = 0.1     # constant feed flow into tank 1
sq2g = jnp.sqrt(2*g)
# Initial steady-state level: a*sqrt(2*g*h) == qin_initial_ss solved for h.
h10 = h20 = qin_initial_ss**2/(a**2)/(2*g)
eint0 = 0.               # integrated squared setpoint error starts at zero
y0 = jnp.array([h10, h20, eint0])   # state vector [h1, h2, eint]
hsp = 2.*h20             # tank-2 setpoint: double its initial steady level
tend = 200.              # end of the control horizon
Nu = 25                  # number of control knots
ut = jnp.linspace(0., tend, Nu)     # knot times; rhs interpolates u linearly between them
u0 = jnp.full_like(ut, 0.05)        # initial guess for the extra inflow at each knot
bounds = [(0., 0.1)]*Nu             # per-knot box bounds passed to minimize
Ny = y0.size                        # number of ODE states


@jax.jit
def rhs(t, v, u):
    """Right-hand side of the two-tank ODE augmented with an error integral.

    Parameters
    ----------
    t : float
        Time.
    v : array of length 3
        State ``[h1, h2, eint]``; the third entry is unpacked but not used.
    u : array
        Control values at the knot times ``ut``; linearly interpolated at ``t``
        and added to the constant feed ``qin_initial_ss``.

    Returns
    -------
    jax.Array of length 3
        ``[dh1/dt, dh2/dt, (h2 - hsp)**2]``.
    """
    level1, level2, _ = v
    outlet = a*sq2g                       # a*sqrt(2g): outflow = outlet*sqrt(level)
    q_extra = jnp.interp(t, ut, u)        # piecewise-linear control between knots
    flow_1_to_2 = outlet*(level1**0.5)
    flow_out = outlet*(level2**0.5)
    dlevel1 = (qin_initial_ss + q_extra - flow_1_to_2)/A
    dlevel2 = (flow_1_to_2 - flow_out)/A
    derr = (level2 - hsp)**2
    return jnp.array([dlevel1, dlevel2, derr])

# Jitted Jacobian of rhs with respect to the state v (argument index 1);
# handed to the Radau solver as its `jac`.
rhs_jac = jax.jit(jax.jacobian(rhs, argnums=1))

# Batched rhs: t is mapped along axis 0, v along its columns (axis 1), u is shared,
# and results are stacked as columns, i.e. rhs_vec(t[n], V[Ny, n], u) -> [Ny, n].
rhs_vec = jax.vmap(rhs, in_axes=(0, 1, None), out_axes=1)

In [37]:
# Simulate under the initial control guess u0; this solution's time grid and
# states seed the collocation problem below.
res0 = solve_ivp(
    rhs,
    (0, tend),
    y0,
    method='Radau',
    dense_output=True,
    jac=rhs_jac,
    args=(u0,),
)

In [38]:
def get_collocation(rhs_vec, res, jit=False):

    Nb=res.t.size
    hs=res.t[1:]-res.t[:-1]
    tmid = (res.t[1:]+res.t[:-1])/2

    x = 0.5

    h00 = 2*x**3-3*x**2+1
    h10 = x**3 - 2*x**2 + x
    h01 = -2*x**3 + 3*x**2
    h11 = x**3 - x**2

    h00d = 6*x**2-6*x
    h10d = 3*x**2-4*x + 1
    h01d = -6*x**2 + 6*x
    h11d = 3*x**2 - 2*x

    def collocation(y,u):
        y=jnp.concatenate([res.y[:,0].reshape(Ny,1), y.reshape(Ny,-1)],axis=1)
        m0 = rhs_vec(res.t[:-1],y[:,:-1],u)
        m1 = rhs_vec(res.t[1:],y[:,1:],u)

        ymid=jnp.squeeze(h00*y[:,:-1] + h10*hs*m0 + h01*y[:,1:] + h11*hs*m1)
        ypmid=jnp.squeeze((h00d*y[:,:-1]/hs + h10d*m0 + h01d*y[:,1:]/hs + h11d*m1))

        return (ypmid-rhs_vec(tmid, ymid, u)).ravel()

    def interpolate(t,y,u):
        y=jnp.concatenate([res.y[:,0].reshape(Ny,1), y.reshape(Ny,-1)],axis=1)
        t=jnp.atleast_1d(t)
        i=jnp.searchsorted(res.t,t, side='right')-1   #this precludes JIT
        i=jnp.where(i>=Nb-1,i-1,i)
        x = jnp.where(t<res.t[-1], (t-res.t[i])/hs[i], 1.)
        p0 = y[:,i]
        p1 = y[:,i+1]
        m0 = rhs_vec(res.t[i],p0,u)
        m1 = rhs_vec(res.t[i+1],p1,u)
        h00 = 2*x**3-3*x**2+1
        h10 = x**3 - 2*x**2 + x
        h01 = -2*x**3 + 3*x**2
        h11 = x**3 - x**2
        return jnp.squeeze(h00*p0 + h10*hs[i]*m0 + h01*p1 + h11*hs[i]*m1)

    if jit:
        return jax.jit(collocation), jax.jit(jax.jacobian(collocation,0)), jax.jit(jax.jacobian(collocation,1)), interpolate
    else:
        return collocation, jax.jacobian(collocation,0), jax.jacobian(collocation,1), interpolate

In [66]:
# Residual, its Jacobians w.r.t. states and controls, and the state interpolant,
# all on the time grid of the initial Radau solution.
collocation, jacy, jacu, interpolate = get_collocation(rhs_vec=rhs_vec, res=res0, jit=True)

In [67]:
# Solve the collocation equations for the initial control guess, starting from the
# Radau states; solvex0 is the warm start that obj reads and updates.
solve = root(collocation, res0.y[:, 1:].ravel(), args=(u0,))
if not solve.success:
    # Otherwise an unconverged point would silently seed every later warm start.
    warnings.warn(f'initial collocation solve did not converge: {solve.message}')
solvex0 = solve.x

In [68]:
def obj(u, collocation):
    """Objective and gradient for ``scipy.optimize.minimize(..., jac=True)``.

    Solves ``collocation(y, u) = 0`` for the states ``y``, warm-started from the
    global ``solvex0``. It then evaluates the last state component at ``tend``,
    which is the running integral of ``(h2 - hsp)**2``. Its gradient w.r.t.
    ``u`` comes from the implicit function theorem.

    Relies on the notebook globals ``jacy``, ``jacu``, ``interpolate`` and
    ``tend``. ``jacy``, ``jacu`` and ``interpolate`` must come from the same
    ``get_collocation`` call that produced ``collocation``.

    Parameters
    ----------
    u : array of shape (Nu,)
        Control values at the knot times.
    collocation : callable
        Residual function ``collocation(y, u)``.

    Returns
    -------
    err : scalar
        Integrated squared setpoint error at ``tend``.
    grad : array of shape (Nu,)
        ``d err / d u``.
    """
    global solvex0

    def integrated_error(y):
        return interpolate(tend, y, u)[-1]

    solve = root(collocation, solvex0, jac=jacy, args=(u,))
    if solve.success:
        # Warm-start the next call from this converged solution.
        solvex0 = solve.x
    else:
        # Still return a value so the optimiser can continue, but make the failure visible.
        warnings.warn(f'collocation solve did not converge: {solve.message}')

    Hu = jacu(solve.x, u)    # dR/du
    Hy = jacy(solve.x, u)    # dR/dy
    err = integrated_error(solve.x)
    # Adjoint form of grad = -dE/dy @ inv(Hy) @ Hu. One linear solve with Hy^T
    # replaces forming the explicit inverse; it is cheaper and better conditioned.
    lam = jnp.linalg.solve(Hy.T, jax.jacobian(integrated_error)(solve.x))
    grad = -(lam @ Hu)
    print(err)
    return err, grad

In [None]:
# First optimisation pass on the collocation grid taken from the initial simulation.
sol = minimize(fun=obj, x0=u0, args=(collocation,), jac=True, bounds=bounds)

In [72]:
# Second pass: re-simulate under the first-pass optimum so the collocation grid
# (res0.t) and the initial state guess follow that trajectory, then re-optimise.
res0 = solve_ivp(rhs, (0, tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(sol.x,))

# obj reads the globals jacy, jacu and interpolate, so they are rebound here
# together with the new residual function.
collocation2, jacy, jacu, interpolate = get_collocation(rhs_vec, res0, jit=True)
solve = root(collocation2, res0.y[:, 1:].ravel(), args=(sol.x,))
if not solve.success:
    # Otherwise an unconverged point would silently seed every warm start in obj.
    warnings.warn(f'refined collocation solve did not converge: {solve.message}')
solvex0 = solve.x
sol2 = minimize(obj, sol.x, jac=True, bounds=bounds, args=(collocation2,))

405.05233045631735
6021.851150572598
404.98520817683664
404.8844822979087
404.7378100673953
404.7266515576665
404.7171915881242
404.71605708034093
404.7159645030293
404.71591715249974
404.7158732569733
404.71565457180776
404.7153754295816
404.71510535134286
404.71495676117587
404.71489021198687
404.714874617553
404.71485305266606
404.7147858508441
404.71467570291634
404.71470263988203
404.7146157153837
404.714506482531
404.7144590036756
404.71444471042184
404.7144261269812
404.7143775273167
404.7143387985571
404.71426647913637
404.71422746567293
404.7141647486138
404.71405827648914
404.7139767163848
404.7139316715563
404.71390417673405
404.7138828338328
404.71383412476195
404.71380951422486
404.71378500500805
404.71375324120004
404.71372027634794
404.7136832790019
404.71365739670404
404.7136389335067
404.71363083684906
404.71362326765086
404.7136134201561
404.71361076459806
404.71360479784283
404.7135984728033
404.71358797482003
404.71356896448924
404.71355928877534
404.71354208568334


In [73]:
# Simulate under the refined optimal control and plot both levels with that control.
resoptim = solve_ivp(rhs, (0, tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(sol2.x,))
tplot = np.linspace(0, tend, 200)
h1, h2, err = resoptim.sol(tplot)
fig = make_subplots(specs=[[dict(secondary_y=True)]])
fig.add_scatter(x=tplot, y=h1, name='h1')
fig.add_scatter(x=tplot, y=h2, name='h2')
# Plot the control that was actually simulated above (sol2.x), not the first-pass sol.x.
fig.add_scatter(x=ut, y=sol2.x, secondary_y=True, name='qcontrol', mode='lines+markers')
fig.update_xaxes(title_text='t')
fig.update_yaxes(title_text='level', secondary_y=False)
fig.update_yaxes(title_text='qcontrol', secondary_y=True)
fig.update_layout(width=800, height=600, template='plotly_dark')

In [47]:
# Tank-2 level setpoint (defined as twice the initial steady-state level h20).
hsp

10.193679918450561