<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_radau_optimal_control.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root, minimize
from functools import partial

In [2]:
a=0.01
A = 0.2
qin_initial_ss = 0.1
sq2g = jnp.sqrt(2*9.81)
h10 = h20=qin_initial_ss**2/(a**2)/(2*9.81)
eint0 = 0.
y0=jnp.array([h10, h20, eint0])
hsp = 2.*h20
tend=200.
Nu=25
ut = jnp.linspace(0.,tend,Nu)
u0 = jnp.full_like(ut,0.05)
bounds = [(0.,0.1)]*Nu
Ny = y0.size


@jax.jit
def rhs(t, v, u):
    h1,h2,eint = v
    qin=qin_initial_ss
    qcontrol= jnp.interp(t,ut, u)
    q12 = a*sq2g*(h1**0.5)
    q2 = a*sq2g*(h2**0.5)
    return jnp.array( [(qin + qcontrol - q12)/A, (q12 - q2)/A, (h2-hsp)**2])

rhs_jac=jax.jit(jax.jacobian(rhs,1))

rhs_vec=jnp.vectorize(rhs,signature=f'(),({Ny}),({Nu})->(2)')



In [4]:
x = (np.r_[leggauss(2)[0],1])/2 + 0.5
p = jnp.tile(x, (3,1))
p = jnp.cumprod(p,axis=0)
dp = p[:-1,:]*jnp.arange(2,4)[:,None]

def get_collocation_eqs(res, jit=False):
    ts=jnp.array(res.t)
    hs=ts[1:]-ts[:-1]
    y0s=jnp.array(res.y)
    Qs=jnp.stack([s.Q for s in res.sol.interpolants])
    Nb = Qs.shape[0]

    flatres0 = Qs.ravel()

    def collocation_eqs(v,u):
        Qs=jnp.reshape(v, (Nb, Ny, 3))
        qp=jnp.einsum('byi, ix -> byx', Qs, p)
        yb0=jnp.cumsum(qp[:,:,-1],axis=0)
        yb0=jnp.concatenate([jnp.zeros_like(y0).reshape(1,-1),yb0],axis=0)+y0
        yb=qp+yb0[:-1,:,None]
        tb=(hs[:,None]*x[None,:]+ts[:-1,None])
        ybp=(jnp.einsum('byi, ix -> byx', Qs[:,:,1:] , dp) + Qs[:,:,0][:,:,None])/hs[:,None,None]
        collocation=jnp.ravel(jnp.swapaxes(ybp,1,2)-rhs_vec(tb,jnp.swapaxes(yb,1,2),u))
        return collocation

    def _interpolate(t,v):
        Qs=jnp.reshape(v, (Nb, Ny, 3))
        qp=jnp.einsum('byi, ix -> byx', Qs, p)
        yb0=jnp.cumsum(qp[:,:,-1],axis=0)
        yb0=jnp.concatenate([jnp.zeros_like(y0).reshape(1,-1),yb0],axis=0)+y0
        t=jnp.atleast_1d(t)
        i=jnp.searchsorted(ts,t, side='right')-1
        x = (t-ts[i])/hs[i]
        px = jnp.cumprod(jnp.tile(x, (3,1)),axis=0)
        y= jnp.where(t<ts[-1],jnp.einsum('tyi, it -> yt', jnp.take(Qs, i, 0), px)+jnp.take(yb0,i,0).T,jnp.take(yb0,i,0).T)
        return jnp.squeeze(y)

    jacv=jax.jacobian(collocation_eqs,0)
    jacu=jax.jacobian(collocation_eqs,1)

    if jit:
        return jax.jit(collocation_eqs), jax.jit(_interpolate), flatres0, jax.jit(jacv), jax.jit(jacu)
    else:
        return collocation_eqs, _interpolate, flatres0, jacv, jacu

In [5]:
res=solve_ivp(rhs, (0,tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(u0,))
collocation_eqs, _interpolate, flatres, jacv, jacu = get_collocation_eqs(res, jit=True)
flatres0=flatres.copy()

In [6]:
def obj(u):

    global flatres0, collocation_eqs

    def integrated_error(v):
        return _interpolate(tend,v)[-1]
    derrdq = jax.jacobian(integrated_error)

    solve=root(collocation_eqs, flatres0, jac=jacv,args=(u,))
    if solve.success:
        flatres0=solve.x
    flatres=solve.x
    Hu=jacu(flatres,u)
    Hq=jacv(flatres,u)
    dqdu=-(jnp.linalg.inv(Hq) @ Hu)
    err=integrated_error(flatres)
    grad=derrdq(flatres)@dqdu
    print(err)
    return err,grad

In [7]:
sol=minimize(obj,u0, jac=True, bounds=bounds)

776.6532429542498
2474.7241424332324
590.8392225201786
514.8693369700031
432.8321957905433
421.435700619241
413.3466682831723
410.17024579583983
408.84308128709426
407.79367673873764
406.2783443031737
407.149708602948
406.08301551119365
405.4853230507892
405.0351760685983
404.87642492459855
404.8072391190385
404.6650035643278
404.5834685722916
404.54237459808087
404.5287911573993
404.52074564144715
404.5055478922206
404.4830899060207
404.5059067048823
404.4744251719496
404.45833952996213
404.4528856276723
404.4413086625257
404.4384832509532
404.4199587652767
404.4091630241865
404.3900008923138
404.3470492203678
404.9356523666527
404.345998805733
404.36568416261804
404.3408291232734
404.3631055552519
404.34075387192894
404.3448998375142
404.33894795677014
404.31123273287204
404.29802935843804
404.2923615778957
404.3016344507999
404.2907594010414
404.28848738372204
404.2843351680604
404.27786888318553
404.27484233976185
404.2756758850017
404.273703232849
404.27278131320094
404.2711085889

In [8]:
res = solve_ivp(rhs, (0,tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(sol.x,))
collocation_eqs, _interpolate, flatres, jacv, jacu = get_collocation_eqs(res, jit=True)
flatres0=flatres.copy()
sol=minimize(obj,sol.x, jac=True, bounds=bounds)

406.38362790603554
2877.8283294937455
405.04994516132064
404.9287827544847
404.7778251837257
404.7517514596252
404.7332513136924
404.7276577714302
404.71883940457474
404.71381530766155
404.7015325509392
404.69512496193573
404.6870241386143
404.7038158198371
404.68637833818394
404.68385757749326
404.683004149082
404.68242074013426
404.68135878969736
404.68113545161447
404.68007151080405
404.67993930741204
404.67966725691485
404.6792284496897
404.67882191737846
404.6783362656821
404.67955989235264
404.6781839407144
404.677704605066
404.67737179464484
404.67707019151845
404.67684618310733
404.6764847063518
404.67622911528736
404.6757251219198
404.675305718637
404.67505399062975
404.6748571145881
404.6746542127733
404.67485255184687
404.67455177402627
404.6744583873641
404.67429406537735
404.6741526636269
404.67389997084297
404.673558586701
404.6733579768428
404.6729642170697
404.67289445335314
404.67284884380115
404.6727565104064
404.67256162484824
404.67675535012
404.67253179990064
404.6

In [9]:
resoptim = solve_ivp(rhs, (0,tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(sol.x,))
tplot=np.linspace(0,tend,200)
h1,h2,err=resoptim.sol(tplot)
fig=make_subplots(specs=[[dict(secondary_y=True)]])
fig.add_scatter(x=tplot,y=h1, name='h1')
fig.add_scatter(x=tplot,y=h2, name='h2')
fig.add_scatter(x=ut,y=sol.x,secondary_y=True, name='qcontrol',mode='lines+markers')
fig.update_layout(width=800,height=600,template='plotly_dark')