<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_hermite_optimal_control.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [34]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root, minimize
from functools import partial

In [35]:
# Widen NumPy's print line width so long arrays wrap less in cell output.
np.set_printoptions(linewidth=200)

In [202]:
# --- Model parameters (used by rhs) ---
a = 0.01                    # outlet coefficient: outflow = a*sqrt(2*g*h)
A = 0.2                     # net flow is divided by A to get the level rate
qin_initial_ss = 0.1        # constant inlet flow; also defines the initial steady state
sq2g = jnp.sqrt(2*9.81)     # sqrt(2*g)

# --- Initial state: both levels at the steady state for qin_initial_ss ---
h20 = qin_initial_ss**2/(a**2)/(2*9.81)
h10 = h20
eint0 = 0.                  # integrated squared error starts at zero
y0 = jnp.array([h10, h20, eint0])
Ny = y0.size                # number of states
hsp = 2.*h20                # setpoint for h2: twice its initial level

# --- Control discretization: Nu piecewise-linear nodes on [0, tend] ---
tend = 200.
Nu = 30
ut = jnp.linspace(0., tend, Nu)
u0 = jnp.full_like(ut, 0.1)                 # initial guess for the control nodes
bounds = [(0., 0.5) for _ in range(Nu)]     # box constraints handed to minimize
dq = jax.grad(jnp.interp, 0)                # derivative of the interpolated control w.r.t. time


@jax.jit
def rhs(t, v, u):
    """Right-hand side of the two-tank level model with an error-integral state.

    Args:
        t: scalar time at which to evaluate the derivative.
        v: state vector ``(h1, h2, eint)``: the two levels and the running
            integral of the squared setpoint error.
        u: control values at the nodes ``ut``; the control flow at ``t`` is
            obtained by linear interpolation (``jnp.interp``).

    Returns:
        Array ``[dh1/dt, dh2/dt, deint/dt]``.

    Reads the module-level values ``qin_initial_ss``, ``ut``, ``a``, ``sq2g``,
    ``A`` and ``hsp``.
    """
    h1, h2, _eint = v  # eint itself does not enter the dynamics
    qcontrol = jnp.interp(t, ut, u)
    # Outflow of each tank: a*sqrt(2*g*h).
    q12 = a*sq2g*(h1**0.5)
    q2 = a*sq2g*(h2**0.5)
    return jnp.array([
        (qin_initial_ss + qcontrol - q12)/A,  # tank 1: inlet + control - outflow
        (q12 - q2)/A,                         # tank 2: fed by tank 1's outflow
        (h2 - hsp)**2,                        # integrand of the squared setpoint error
    ])

# Jacobian of rhs with respect to its state argument v (argument index 1).
rhs_jac = jax.jit(jax.jacobian(rhs, argnums=1))
# Batched rhs: maps over t along axis 0 and over the columns (axis 1) of the
# state matrix, shares u, and stacks the results as columns.
rhs_vec = jax.vmap(rhs, in_axes=(0, 1, None), out_axes=1)

In [203]:
# Baseline simulation with the constant initial control guess u0.
# Its time points (res0.t) become the collocation mesh and its states (res0.y)
# the initial guess for the collocation solve in the following cells.
res0=solve_ivp(rhs, (0,tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(u0,))

In [204]:
def get_collocation(rhs_vec, res, jit=False):
    """Build cubic-Hermite collocation functions on the time mesh of ``res``.

    On every mesh interval the state is represented by the cubic Hermite
    polynomial defined by the node values and the node slopes ``rhs_vec(t, y, u)``.
    The collocation residual requires the polynomial's derivative at the
    interval midpoint to equal ``rhs_vec`` evaluated at the midpoint.

    Args:
        rhs_vec: batched right-hand side ``rhs_vec(t, y, u)`` taking times of
            shape ``(N,)`` and states of shape ``(Ny, N)`` and returning ``(Ny, N)``.
        res: object with attributes ``t`` (mesh times, shape ``(Nb,)``) and ``y``
            (states, shape ``(Ny, Nb)``), e.g. a ``solve_ivp`` result. Only the
            mesh and the initial state ``res.y[:, 0]`` are used.
        jit: if True, the residual and its two Jacobians are wrapped in ``jax.jit``.

    Returns:
        Tuple ``(collocation, jac_y, jac_u, interpolate)``:
        ``collocation(y, u)`` is the flattened midpoint residual, where ``y``
        holds the flattened states at mesh nodes 1..Nb-1 (the initial state is
        fixed); ``jac_y`` / ``jac_u`` are its Jacobians with respect to ``y``
        and ``u``; ``interpolate(t, y, u)`` evaluates the Hermite interpolant
        at time(s) ``t`` (never jitted here).
    """
    # Convert the mesh once so that all later indexing is done on JAX arrays.
    ts = jnp.asarray(res.t)
    y_init = jnp.asarray(res.y[:, 0]).reshape(-1, 1)
    ny = y_init.shape[0]  # state dimension taken from res, not from a global
    nb = ts.size
    hs = ts[1:] - ts[:-1]
    tmid = (ts[1:] + ts[:-1]) / 2

    def hermite_basis(x):
        """Cubic Hermite basis (value0, slope0, value1, slope1) at x in [0, 1]."""
        return (2*x**3 - 3*x**2 + 1,
                x**3 - 2*x**2 + x,
                -2*x**3 + 3*x**2,
                x**3 - x**2)

    def hermite_basis_deriv(x):
        """Derivatives of the cubic Hermite basis with respect to x."""
        return (6*x**2 - 6*x,
                3*x**2 - 4*x + 1,
                -6*x**2 + 6*x,
                3*x**2 - 2*x)

    # Basis values and derivatives at the interval midpoint.
    b00, b10, b01, b11 = hermite_basis(0.5)
    d00, d10, d01, d11 = hermite_basis_deriv(0.5)

    def with_initial(y):
        """Prepend the fixed initial state to the unknown node states."""
        return jnp.concatenate([y_init, y.reshape(ny, -1)], axis=1)

    def collocation(y, u):
        y = with_initial(y)
        y_left, y_right = y[:, :-1], y[:, 1:]
        m0 = rhs_vec(ts[:-1], y_left, u)
        m1 = rhs_vec(ts[1:], y_right, u)

        # No squeeze here: the (Ny, Nb-1) layout must be kept so rhs_vec can
        # still map over axis 1 when Ny == 1 or there is a single interval.
        ymid = b00*y_left + b10*hs*m0 + b01*y_right + b11*hs*m1
        # d/dt = (1/h) d/dx, hence the 1/hs on the value terms only.
        ypmid = d00*y_left/hs + d10*m0 + d01*y_right/hs + d11*m1

        return (ypmid - rhs_vec(tmid, ymid, u)).ravel()

    def interpolate(t, y, u):
        y = with_initial(y)
        t = jnp.atleast_1d(t)
        # Interval index for each t, clipped to a valid interval so that
        # t >= ts[-1] uses the last interval and t < ts[0] uses the first one
        # (instead of wrapping around through index -1).
        i = jnp.clip(jnp.searchsorted(ts, t, side='right') - 1, 0, nb - 2)
        # Local coordinate; pinned to 1 at or beyond the final mesh time.
        x = jnp.where(t < ts[-1], (t - ts[i])/hs[i], 1.)
        p0 = y[:, i]
        p1 = y[:, i+1]
        m0 = rhs_vec(ts[i], p0, u)
        m1 = rhs_vec(ts[i+1], p1, u)
        w00, w10, w01, w11 = hermite_basis(x)
        return jnp.squeeze(w00*p0 + w10*hs[i]*m0 + w01*p1 + w11*hs[i]*m1)

    jac_y = jax.jacobian(collocation, 0)
    jac_u = jax.jacobian(collocation, 1)
    if jit:
        return jax.jit(collocation), jax.jit(jac_y), jax.jit(jac_u), interpolate
    return collocation, jac_y, jac_u, interpolate

In [205]:
# Build the (jitted) collocation residual, its Jacobians w.r.t. the states and
# the controls, and the Hermite interpolant on the mesh of the baseline solve.
collocation, jacy, jacu, interpolate = get_collocation(rhs_vec, res0, jit=True)

In [206]:
# Solve the collocation equations for the initial control u0, starting from the
# Radau states at mesh nodes 1..end. The analytic Jacobian jacy is supplied, as
# in obj, rather than letting root approximate it by finite differences.
solve=root(collocation, res0.y[:,1:].ravel(), jac=jacy, args=(u0,))
# Warm-start vector reused (and updated) by obj on every objective evaluation.
solvex0=solve.x

In [207]:
def obj(u, collocation):
    """Objective and gradient for the optimal-control problem.

    Solves the collocation equations for the control nodes ``u``, then returns
    the integrated squared error at ``tend`` (last state of the interpolant)
    together with its gradient with respect to ``u``, obtained by implicit
    differentiation of the collocation residual ``H(y, u) = 0``:
    ``dy/du = -Hy^{-1} Hu``.

    Args:
        u: control values at the nodes ``ut``.
        collocation: residual function ``collocation(y, u)`` to solve.

    Returns:
        Tuple ``(err, grad)`` suitable for ``scipy.optimize.minimize(..., jac=True)``.

    Notes:
        Reads the globals ``jacy``, ``jacu``, ``interpolate`` and ``tend``; they
        must belong to the same mesh as ``collocation``. The global ``solvex0``
        is used as the starting point and is overwritten with the new solution
        only when the solve succeeds. If the solve fails, the value and gradient
        are still computed from the last iterate.
    """
    global solvex0

    def integrated_error(y):
        # Last state component (eint) of the interpolant at the final time.
        return interpolate(tend,y,u)[-1]
    derrdy = jax.jacobian(integrated_error)

    solve=root(collocation, solvex0, jac=jacy, args=(u,))
    if solve.success:
        solvex0=solve.x
    Hu=jacu(solve.x,u)
    Hy=jacy(solve.x,u)
    # Solve Hy @ dydu = -Hu directly instead of forming the explicit inverse:
    # cheaper and numerically more accurate.
    dydu=-jnp.linalg.solve(Hy, Hu)
    err=integrated_error(solve.x)
    grad=derrdy(solve.x)@dydu
    print(err)
    return err,grad

In [208]:
# Optimize the control nodes within the box bounds. jac=True because obj
# returns the pair (objective value, gradient). Each evaluation prints the
# current objective value.
sol=minimize(obj,u0, jac=True, bounds=bounds, args=(collocation,))

9012.553180972642
9012.553180972642
3846.929612224789
754.214564968631
647.2711092605255
585.7005393482145
478.1323280235726
395.43584385460406
363.2234807581092
340.9906716065142
316.71147142141984
283.9006766021473
262.4181625328994
251.39294338743386
249.97830980732147
249.01300902327276
247.33821339142446
242.93043362190625
234.84271046334604
241.7712937565599
228.415602184974
254.27135060493555
228.41423436187375
255.46589868403817
228.41341033277098
257.49285234900026
228.41310639825122
222.48016677402032
220.43279936401254
216.72288593174275
213.88776596709897
211.9026275356514
210.63210991799556
208.6377567821403
203.75193324511946
201.11688568563287
200.23908106662284
200.17228419429415
200.10752459600545
200.04409759700442
199.98109887447828
199.93602708414124
199.9132689222725
199.8613275910931
199.78949365660688
199.68148741807494
199.65131758368236
199.60506865757407
199.60114575184264
199.60051529476948
199.60052995229555
199.6001208018028
199.59967187994548
199.598089008

In [209]:
# Re-simulate with the optimized control so the adaptive Radau mesh follows the
# optimized trajectory, then repeat the optimization on that new mesh.
res0=solve_ivp(rhs, (0,tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(sol.x,))

# NOTE: jacy, jacu and interpolate are rebound on purpose. obj reads them as
# globals, so they must match the mesh of the residual passed to it (collocation2).
collocation2, jacy, jacu, interpolate = get_collocation(rhs_vec, res0, jit=True)
# Analytic Jacobian supplied, consistent with the root call inside obj.
solve=root(collocation2, res0.y[:,1:].ravel(), jac=jacy, args=(sol.x,))
# Reset obj's warm start: the previous vector belongs to the old mesh.
solvex0=solve.x
sol2=minimize(obj,sol.x, jac=True, bounds=bounds, args=(collocation2,))

669.9961994287046
5195.555513944623
415.0602554741443
285.23442620604806
228.09211255313264
221.55549088477628
203.2344189931571
201.95479307767332
200.496552148605
200.15336407930678
199.9927897878108
199.6746636323827
199.56931249380548
199.49093948027735
199.4267732774044
199.32428660218665
199.18593970975508
199.0550910362583
199.04215259356732
199.0336262950685
199.02555881601077
199.0121955186314
199.01707351377425
199.0044448403822
198.9910101417167
198.98207621296814
198.97956934168792
198.97567485524968
198.96880552270792
198.96213741409773
198.9554963815446
198.95290538552737
198.95087971229233
198.94861310191627
198.9424975934205
198.93959969179627
198.9335305593361
198.93177958463488
198.9310446694962
198.9298804100113
198.92605289413046
198.92219121490024
198.91843689837106
198.91539653145986
198.912942780999
198.9115899133555
198.90734584330545
198.90322132718705
198.89831510586384
198.8948042232
198.89059773429483
198.8834757483499
198.87649133554822
198.86987687094515
1

In [210]:
# Simulate the model with the final optimized control (sol2.x) and plot the
# levels together with that same control profile.
resoptim = solve_ivp(rhs, (0,tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(sol2.x,))
tplot=np.linspace(0,tend,200)
h1,h2,err=resoptim.sol(tplot)
fig=make_subplots(specs=[[dict(secondary_y=True)]])
fig.add_scatter(x=tplot,y=h1, name='h1')
fig.add_scatter(x=tplot,y=h2, name='h2')
# Plot the control that produced the trajectories above (sol2.x), not the
# intermediate result of the first optimization (sol.x).
fig.add_scatter(x=ut,y=sol2.x,secondary_y=True, name='qcontrol',mode='lines+markers')
fig.update_xaxes(title_text='t')
fig.update_yaxes(title_text='level', secondary_y=False)
fig.update_yaxes(title_text='qcontrol', secondary_y=True)
fig.update_layout(width=800,height=600,template='plotly_dark')

In [167]:
# Display the setpoint for h2 (twice the initial steady-state level) for reference.
hsp

10.193679918450561