<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_radau_optimal_control.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root, minimize
from functools import partial

In [210]:
# Problem constants for the two-tank level-control example.
grav = 9.81                 # gravitational acceleration (SI units assumed)
a = 0.01                    # outlet area in the outflow law q = a*sqrt(2*g*h)
A = 0.2                     # tank cross-sectional area (dh/dt = net inflow / A)
qin_initial_ss = 0.1        # constant feed flow into tank 1
sq2g = jnp.sqrt(2*grav)
# Initial steady state with only the feed flowing: qin = a*sqrt(2 g h)  =>  h = qin**2 / (a**2 * 2 g)
h10 = h20 = qin_initial_ss**2/(a**2)/(2*grav)
eint0 = 0.                  # integrated squared tracking error starts at zero
y0 = jnp.array([h10, h20, eint0])
hsp = 2.*h20                # setpoint for the tank-2 level
tend = 200.                 # final time of the horizon
Nu = 25                     # number of control knots
ut = jnp.linspace(0., tend, Nu)          # knot times of the piecewise-linear control
u0 = jnp.full_like(ut, 0.05)             # initial control guess
bounds = [(0., 0.1)]*Nu                  # box bounds on each control knot
Ny = y0.size                             # number of states (h1, h2, eint)


@jax.jit
def rhs(t, v, u):
    """ODE right-hand side of the two-tank system plus a tracking-error integrator.

    State ``v = (h1, h2, eint)``: level of tank 1, level of tank 2, and the
    running integral of ``(h2 - hsp)**2``.  The control ``u`` holds flow values
    at the knot times ``ut`` and is linearly interpolated at ``t``.  Tank 1
    receives the constant feed ``qin_initial_ss`` plus the control flow and
    drains into tank 2; tank 2 drains out.  Outflows are ``a*sqrt(2g)*h**0.5``.

    Returns ``jnp.array([dh1/dt, dh2/dt, deint/dt])``.
    """
    level1, level2, _ = v
    q_feed = qin_initial_ss
    # Piecewise-linear control flow evaluated at time t.
    q_ctrl = jnp.interp(t, ut, u)
    out1 = a * sq2g * (level1 ** 0.5)   # tank 1 -> tank 2
    out2 = a * sq2g * (level2 ** 0.5)   # tank 2 -> outside
    dlevel1 = (q_feed + q_ctrl - out1) / A
    dlevel2 = (out1 - out2) / A
    dsq_err = (level2 - hsp) ** 2
    return jnp.array([dlevel1, dlevel2, dsq_err])

# Jacobian of rhs with respect to the state v (argument index 1); passed to
# solve_ivp's Radau solver, which forwards the same extra `args` to it.
rhs_jac=jax.jit(jax.jacobian(rhs,1))

# Broadcast rhs over leading axes: t is a scalar core, v has Ny core entries
# (and so does the output), u has one entry per control knot.  jnp.vectorize
# treats core-dimension names as labels (digits are not fixed sizes), so
# descriptive names are used; the state has Ny == y0.size entries.
rhs_vec=jnp.vectorize(rhs,signature='(),(ny),(nu)->(ny)')

In [211]:
# Reference nodes on the unit interval s in [0, 1]: the left end, the three
# Gauss-Legendre points (mapped from [-1, 1]) and the right end -> 5 nodes.
x = np.concatenate(([-1.], leggauss(3)[0], [1.])) / 2 + 0.5
# Monomial rows s, s**2, s**3 evaluated at every node: shape (3, 5).
p = jnp.cumprod(jnp.tile(x, (3, 1)), axis=0)
# d/ds of s**2 and s**3 (i.e. 2*s and 3*s**2) at every node: shape (2, 5).
dp = p[:2] * jnp.array([2, 3])[:, None]

def get_collocation_eqs(res, jit=False):
    """Build a Gauss-collocation residual system warm-started from a Radau solve.

    The mesh is the step sequence ``res.t`` of a ``solve_ivp`` Radau run.  On
    step k, with normalized time ``s = (t - ts[k]) / hs[k]`` in [0, 1], every
    state is a cubic

        y(s) = yb0[k] + Q[k, :, 0]*s + Q[k, :, 1]*s**2 + Q[k, :, 2]*s**3,

    the same form as the dense-output interpolants in ``res.sol.interpolants``,
    whose ``Q`` matrices provide the initial guess.

    The unknown vector ``v`` is laid out as
    ``[Q raveled as (Nb, Ny, 3), states at ts[1:] raveled as (Ny, Nb)]``;
    the state at ``ts[0]`` is pinned to the global ``y0``.

    Parameters
    ----------
    res : OdeResult
        Output of ``solve_ivp(..., method='Radau', dense_output=True)``.
    jit : bool, optional
        Wrap every returned callable in ``jax.jit`` when True.

    Returns
    -------
    collocation_eqs : callable ``(v, u) -> residual``
        Continuity residuals at every step end, followed by ODE residuals at
        the interior nodes ``x[1:-1]`` of every step (4*Nb*Ny equations for
        4*Nb*Ny unknowns).
    _interpolate : callable ``(t, v) -> y``
        Evaluates the piecewise cubic at scalar or 1-D ``t``; for
        ``t >= ts[-1]`` it returns the final node state.
    flatres0 : jax array
        Initial guess for ``v`` taken from ``res``.
    jacv, jacu : callable
        Jacobians of ``collocation_eqs`` with respect to ``v`` and ``u``.
    """
    ts = jnp.array(res.t)
    hs = ts[1:] - ts[:-1]
    Qs = jnp.stack([interp.Q for interp in res.sol.interpolants])
    Nb = Qs.shape[0]
    nq = Nb * Ny * 3   # number of polynomial coefficients at the front of v

    flatres0 = jnp.concatenate([Qs.ravel(), res.y[:, 1:].ravel()])

    def collocation_eqs(v, u):
        qflat, yb0 = v[:nq], v[nq:]
        Qs = jnp.reshape(qflat, (Nb, Ny, 3))
        yb0 = jnp.reshape(yb0, (Ny, Nb)).T
        yb0 = jnp.concatenate([y0.reshape(1, -1), yb0], axis=0)   # (Nb+1, Ny)
        # Polynomial values at the nodes x[1:] (interior nodes and right end).
        qp = jnp.einsum('byi, ix -> byx', Qs, p)
        yb = qp[:, :, 1:] + yb0[:-1, :, None]
        tb = (hs[:, None] * x[None, 1:-1] + ts[:-1, None])
        # Time derivative dy/dt = (dy/ds) / h at every node, then keep interior ones.
        ybp = (jnp.einsum('byi, ix -> byx', Qs[:, :, 1:], dp) + Qs[:, :, 0][:, :, None]) / hs[:, None, None]
        ybp = ybp[:, :, 1:-1]
        # Each step's end value must equal the next step's start value.
        end_equality = jnp.ravel(yb[:, :, -1] - yb0[1:, :])
        collocation = jnp.ravel(jnp.swapaxes(ybp, 1, 2) - rhs_vec(tb, jnp.swapaxes(yb[:, :, :-1], 1, 2), u))
        return jnp.concatenate([end_equality, collocation])

    def _interpolate(t, v):
        qflat, yb0 = v[:nq], v[nq:]
        Qs = jnp.reshape(qflat, (Nb, Ny, 3))
        yb0 = jnp.reshape(yb0, (Ny, Nb))
        yb0 = jnp.concatenate([y0.reshape(Ny, 1), yb0], axis=1)   # (Ny, Nb+1)
        t = jnp.atleast_1d(t)
        # Step index containing t.  searchsorted yields Nb for t >= ts[-1]; clip
        # it so the polynomial branch never indexes Qs/hs out of bounds (that
        # branch is discarded by the where below for those t anyway).
        i = jnp.clip(jnp.searchsorted(ts, t, side='right') - 1, 0, Nb - 1)
        s = (t - ts[i]) / hs[i]
        powers = jnp.cumprod(jnp.tile(s, (3, 1)), axis=0)   # rows s, s**2, s**3
        inside = jnp.einsum('tyi, it -> yt', jnp.take(Qs, i, 0), powers) + jnp.take(yb0, i, 1)
        y = jnp.where(t < ts[-1], inside, yb0[:, -1:])
        return jnp.squeeze(y)

    jacv = jax.jacobian(collocation_eqs, 0)
    jacu = jax.jacobian(collocation_eqs, 1)

    if jit:
        return jax.jit(collocation_eqs), jax.jit(_interpolate), flatres0, jax.jit(jacv), jax.jit(jacu)
    else:
        return collocation_eqs, _interpolate, flatres0, jacv, jacu

In [212]:
# Warm start: simulate with the initial control guess u0 (Radau, dense output)
# and build the collocation system on that solver's step mesh.
res = solve_ivp(fun=rhs, t_span=(0, tend), y0=y0, method='Radau',
                dense_output=True, jac=rhs_jac, args=(u0,))
(collocation_eqs, _interpolate,
 flatres, jacv, jacu) = get_collocation_eqs(res, jit=True)
# Separate copy that obj uses (and overwrites) as its root-solve starting point.
flatres0 = flatres.copy()

In [213]:
def obj(u):
    """Objective and gradient for ``scipy.optimize.minimize(..., jac=True)``.

    Solves the collocation equations H(v, u) = 0 for the control knots ``u``,
    starting from the global ``flatres0``.  Returns the integrated squared
    tracking error at ``tend`` (last state component) and its gradient, which
    comes from implicit differentiation:

        d err/d u = -(d err/d v) Hv^{-1} Hu

    The product is evaluated with one adjoint solve ``Hv^T lam = d err/d v``
    instead of forming ``Hv^{-1}`` explicitly.

    Side effects: replaces the global ``flatres0`` when the root solve
    converges (warm start for the next call), warns when it does not, and
    prints the objective value.
    """
    import warnings

    global flatres0

    def integrated_error(v):
        # Last state component (integral of the squared error) at the final time.
        return _interpolate(tend, v)[-1]

    solve = root(collocation_eqs, flatres0, jac=jacv, args=(u,))
    if solve.success:
        flatres0 = solve.x
    else:
        # The sensitivity formula assumes H(v, u) = 0; flag unconverged solves.
        warnings.warn(f'collocation root solve did not converge: {solve.message}', RuntimeWarning)
    flatres = solve.x
    Hu = jacu(flatres, u)
    Hq = jacv(flatres, u)
    err = integrated_error(flatres)
    derr_dq = jax.grad(integrated_error)(flatres)
    # Adjoint: lam solves Hq^T lam = d err/d q, so lam^T Hu = (d err/d q) Hq^{-1} Hu.
    lam = jnp.linalg.solve(Hq.T, derr_dq)
    grad = -(lam @ Hu)
    print(err)
    return err, grad

In [214]:
# Bounded minimization of the integrated tracking error over the control knots; obj returns (value, gradient), hence jac=True.
sol = minimize(fun=obj, x0=u0, jac=True, bounds=bounds)

776.7043887612344
2528.9552642382214
593.9878195369629
501.4775536352571
435.8960051829062
418.11711720858386
412.6464409987631
410.7392541862199
409.0792694614904
408.12872504680183
406.96756042892326
406.1083153832253
405.5469381162547
405.12303215919894
405.0645567172953
404.9726282849158
404.92565551060045
404.8875912632971
404.84616088827624
404.8256396570772
404.8195748002216
404.8154649940564
404.80832204718683
404.79771342149525
404.7896577920208
404.7882729112284
404.78674308063376
404.78618257466024
404.78555756689144
404.7848300382125
404.7837928265094
404.78255124694584
404.7803503123615
404.77957662238356
404.77921127556465
404.7790444945964
404.7787604275206
404.77809362435886
404.77661775604656
404.7739965765329
404.80546749102297
404.773673099305
404.77109967020056
404.76982494845066
404.7694874487374
404.7693347390322
404.7693206232438
404.76919404039745
404.7689526057921
404.76865622675297
404.76837036338134
404.7679512658097
404.7671979453115
404.76638803476476
404.7

In [215]:
# Refinement pass: re-simulate with the optimized controls so the collocation
# mesh and warm start follow the optimized trajectory, then optimize again from sol.x.
res = solve_ivp(fun=rhs, t_span=(0, tend), y0=y0, method='Radau',
                dense_output=True, jac=rhs_jac, args=(sol.x,))
(collocation_eqs, _interpolate,
 flatres, jacv, jacu) = get_collocation_eqs(res, jit=True)
flatres0 = flatres.copy()
sol = minimize(fun=obj, x0=sol.x, jac=True, bounds=bounds)

405.5962449993125
3650.913067077604
405.00684525037866
404.8264068116468
404.79586082182817
404.7842836134466
404.7573122692553
404.74592017391717
404.7386294244325
404.7344045496817
404.72889800169315
404.71878115814377
404.70833686333884
404.7047401073377
404.7096727323581
404.7042979339435
404.70339107608385
404.7020111859316
404.7007282741746
404.69963734220056
404.69913645887107
404.69882112659286
404.6979703283375
404.698361704958
404.69754778546354
404.6971606379727
404.69665973332087
404.6963595546218
404.6962886950148
404.6957326951749
404.6956490507648
404.6954979156819
404.69527723933254
404.69503138141334
404.69470003581
404.69450034740396
404.69423867915464
404.6936708891473
404.69314895574723
404.6925077788649
404.69256110880684
404.692301983116
404.6920262148421
404.6918360157864
404.69159064702546
404.69123430694776
404.69093246013824
404.69066160445874
404.69027785815155
404.6900289768969
404.68926815888557
404.688907646436
404.6887886732553
404.6887891241931
404.68874

In [216]:
# Re-simulate with the optimized control profile; plot both tank levels (left
# axis) and the optimized control knots (right axis).
resoptim = solve_ivp(rhs, (0, tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(sol.x,))
tplot = np.linspace(0, tend, 200)
h1, h2, err = resoptim.sol(tplot)   # err: integrated squared error (not plotted)
fig = make_subplots(specs=[[dict(secondary_y=True)]])
fig.add_scatter(x=tplot, y=h1, name='h1')
fig.add_scatter(x=tplot, y=h2, name='h2')
# Reference line for the tank-2 level setpoint.
fig.add_scatter(x=[0., float(tend)], y=[float(hsp), float(hsp)], name='h2 setpoint',
                mode='lines', line=dict(dash='dash'))
fig.add_scatter(x=ut, y=sol.x, secondary_y=True, name='qcontrol', mode='lines+markers')
fig.update_layout(width=800, height=600, template='plotly_dark',
                  title='Optimized control: tank levels and control flow')
fig.update_xaxes(title_text='t')
fig.update_yaxes(title_text='level', secondary_y=False)
fig.update_yaxes(title_text='qcontrol', secondary_y=True)
fig