<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_radau_optimal_control.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root, minimize
from functools import partial

In [31]:
# --- Physical constants of the two-tank model (consumed by rhs) ---------------
a=0.01                        # multiplies sq2g*sqrt(h) in both outflow terms
A = 0.2                       # divides the net flow in each level equation
sq2g = jnp.sqrt(2*9.81)       # sqrt(2*g) with g = 9.81
qin_initial_ss = 0.1          # constant inflow that defines the initial steady state

# --- Initial state: both levels at the steady state for qin_initial_ss --------
# Setting a*sqrt(2*g*h) == qin and solving for h gives qin^2 / a^2 / (2*g).
h20 = qin_initial_ss**2/(a**2)/(2*9.81)
h10 = h20
eint0 = 0.                    # running integral of (h2-hsp)^2 starts at zero
y0 = jnp.array([h10, h20, eint0])
Ny = y0.shape[0]              # number of state components
hsp = 1.1*h20                 # set point for h2: 10 % above the initial level

# --- Time horizon and piecewise-linear control parameterisation ---------------
tend = 1000.
Nu = 20                       # number of control knots
ut = jnp.linspace(0., tend, Nu)             # knot times
u0 = jnp.full_like(ut, 0.1)                 # initial guess for the knot values
bounds = [(0., 5.) for _ in range(Nu)]      # box constraint on every knot value


@jax.jit
def rhs(t, v, u):
    """Time derivative of the augmented state ``(h1, h2, eint)``.

    Args:
        t: scalar time at which the control is evaluated.
        v: state vector of three components ``(h1, h2, eint)``; ``eint`` itself
            does not enter the derivative.
        u: control values at the knot times ``ut``; the control flow at ``t``
            is obtained by linear interpolation between knots.

    Returns:
        Array ``[dh1/dt, dh2/dt, deint/dt]`` where the last entry is the
        squared deviation of ``h2`` from the set point ``hsp``.
    """
    level1, level2, _ = v
    # Total feed to the first tank: fixed base inflow plus interpolated control.
    feed = qin_initial_ss + jnp.interp(t, ut, u)
    # Outflows proportional to the square root of each level.
    out1 = a*sq2g*(level1**0.5)
    out2 = a*sq2g*(level2**0.5)
    dlevel1 = (feed - out1)/A
    dlevel2 = (out1 - out2)/A
    dpenalty = (level2 - hsp)**2
    return jnp.array([dlevel1, dlevel2, dpenalty])

# Broadcasting wrapper around rhs: scalar time, a state vector and a control
# vector map to a derivative vector of the same length as the state.
# jnp.vectorize treats every token in the signature as a symbolic dimension
# *name*, not a fixed size, so distinct letters are used for the state ("y") and
# control ("u") axes.  Numeric tokens would collide whenever the number of
# control knots happened to equal the token used for the state axis.
rhs_vec=jnp.vectorize(rhs,signature='(),(y),(u)->(y)')

In [32]:
# Integrate once with SciPy's Radau method for the initial control guess u0.
# The accepted step grid and the dense-output coefficients of this run are
# reused as the fixed mesh and the starting guess of the collocation system.
res = solve_ivp(
    rhs,
    (0,tend),
    y0,
    method='Radau',
    dense_output=True,
    jac=jax.jacobian(rhs, argnums=1),   # Jacobian of rhs with respect to the state
    args=(u0,),
)
ts = jnp.array(res.t)                   # step boundaries chosen by the integrator
hs = jnp.diff(ts)                       # width of each step
y0s = jnp.array(res.y)                  # states at the step boundaries, one column per time
# One coefficient array Q per step, taken from the dense-output interpolants.
Qs = jnp.stack([seg.Q for seg in res.sol.interpolants])
Nb = len(res.sol.interpolants)          # number of steps

# Flat unknown vector: all Q coefficients first, then the boundary states
# (initial column dropped because y0 is fixed), each flattened in C order.
flatres0 = jnp.concatenate([Qs.reshape(-1), res.y[:,1:].reshape(-1)])

In [33]:
# Normalised abscissae on [0, 1]: the step start (0), the three Gauss-Legendre
# nodes mapped from [-1, 1], and the step end (1).
gauss_nodes = leggauss(3)[0]
x = np.concatenate(([-1.], gauss_nodes, [1.]))/2 + 0.5
# Rows of p are x, x^2, x^3 evaluated at every abscissa.
p = jnp.cumprod(jnp.tile(x, (3,1)), axis=0)
# Rows of dp are 2*x and 3*x^2, i.e. the derivatives of the x^2 and x^3 rows.
dp = jnp.arange(2,4)[:,None]*p[:-1,:]

@jax.jit
def eqs(v,u):
    """Residual of the collocation system for control knots ``u``.

    Args:
        v: flat unknown vector. The first ``Nb*Ny*3`` entries are the
            per-step polynomial coefficients (reshaped to ``(Nb, Ny, 3)``);
            the remainder are the states at the end of every step, stored as a
            flattened ``(Ny, Nb)`` array.
        u: control values at the knot times, forwarded to ``rhs_vec``.

    Returns:
        1-D array holding first the end-of-step continuity residuals and then
        the collocation residuals at the three interior nodes of every step.
    """
    n_coeff = Nb*Ny*3
    coeffs = jnp.reshape(v[:n_coeff], (Nb, Ny, 3))
    # Boundary states as rows, with the fixed initial state prepended: (Nb+1, Ny).
    nodes = jnp.concatenate(
        [y0.reshape(1,-1), jnp.reshape(v[n_coeff:], (Ny, Nb)).T], axis=0)

    # Polynomial increment at every abscissa in x, then the state at the three
    # interior nodes and the step end (the x = 0 column is dropped).
    increment = jnp.einsum('byi, ix -> byx', coeffs, p)
    states = increment[:,:,1:] + nodes[:-1,:,None]

    # Physical times of the interior nodes of each step.
    node_times = hs[:,None]*x[None,1:-1] + ts[:-1,None]

    # Derivative of the polynomial with respect to time: linear coefficient plus
    # the differentiated quadratic and cubic terms, rescaled by the step width.
    slopes = (jnp.einsum('byi, ix -> byx', coeffs[:,:,1:], dp)
              + coeffs[:,:,:1])/hs[:,None,None]
    slopes = slopes[:,:,1:-1]

    # The polynomial evaluated at the step end must equal the next boundary state.
    end_equality = jnp.ravel(states[:,:,-1] - nodes[1:,:])
    # The polynomial slope must equal the model right-hand side at each node.
    model_slopes = rhs_vec(node_times, jnp.swapaxes(states[:,:,:-1], 1, 2), u)
    collocation = jnp.ravel(jnp.swapaxes(slopes, 1, 2) - model_slopes)
    return jnp.concatenate([end_equality, collocation])

# Compiled Jacobians of the collocation residual:
#   jacv -> with respect to the unknown vector v (first argument of eqs)
#   jacu -> with respect to the control knots u (second argument of eqs)
jacv = jax.jit(jax.jacobian(eqs, argnums=0))
jacu = jax.jit(jax.jacobian(eqs, argnums=1))

In [34]:
# Solve the collocation system for the initial control guess, starting from the
# Radau dense-output coefficients and using the exact Jacobian.
solve = root(eqs, flatres0, args=(u0,), jac=jacv)

In [35]:

def _interpolate(t,v):
    """Evaluate the piecewise-cubic collocation solution at time(s) ``t``.

    Args:
        t: scalar or 1-D array of query times.
        v: flat unknown vector laid out as in ``eqs`` (polynomial coefficients
            followed by the end-of-step states).

    Returns:
        The interpolated state with singleton axes squeezed out: shape
        ``(Ny,)`` for a single time, ``(Ny, len(t))`` for several.
    """
    n_coeff = Nb*Ny*3
    coeffs = jnp.reshape(v[:n_coeff], (Nb, Ny, 3))
    # Boundary states as columns, with the fixed initial state first: (Ny, Nb+1).
    nodes = jnp.concatenate(
        [y0.reshape(Ny,1), jnp.reshape(v[n_coeff:], (Ny, Nb))], axis=1)

    times = jnp.atleast_1d(t)
    # Index of the step whose left boundary is the last one not exceeding t.
    step = jnp.searchsorted(ts, times, side='right') - 1
    # Normalised position inside that step and its first three powers.
    frac = (times - ts[step])/hs[step]
    powers = jnp.cumprod(jnp.tile(frac, (3,1)), axis=0)

    left_state = jnp.take(nodes, step, 1)
    poly_state = jnp.einsum('tyi, it -> yt', jnp.take(coeffs, step, 0), powers) + left_state
    # At or beyond the final boundary there is no step to evaluate, so the
    # stored boundary state is returned instead of the polynomial value.
    state = jnp.where(times < ts[-1], poly_state, left_state)
    return jnp.squeeze(state)

def errint(v):
    """Return the last state component at ``tend`` for the solution ``v``.

    The last component is the accumulated integral of ``(h2-hsp)**2`` defined
    by ``rhs``, so this is the value the optimiser minimises.
    """
    final_state = _interpolate(tend, v)
    return final_state[-1]
# Derivative of the final accumulated error with respect to the flat unknown
# vector of the collocation system.
derrdq = jax.jacobian(errint, argnums=0)

In [36]:

# Warm start for the collocation solve.  obj() overwrites it on every call so
# that each evaluation starts from the previously computed solution.
flatres=flatres0

def obj(u):
    """Objective and gradient for the optimal-control problem.

    Solves the collocation system ``eqs(q, u) = 0`` for the control knots
    ``u`` and differentiates the final accumulated error through that solve
    with the implicit function theorem.

    Args:
        u: control values at the knot times.

    Returns:
        Tuple ``(err, grad)``: the accumulated squared set-point error at
        ``tend`` and its gradient with respect to ``u``.

    Side effects:
        Updates the module-level ``flatres`` warm start and prints ``err``.
    """
    import warnings

    global flatres
    solve=root(eqs, flatres, jac=jacv,args=(u,))
    if not solve.success:
        # Keep going (the optimiser may recover) but do not hide the failure:
        # the returned value and gradient are unreliable for this iterate.
        warnings.warn(f'collocation solve did not converge: {solve.message}')
    flatres=(solve.x)[:]
    Hu=jacu(flatres,u)
    Hq=jacv(flatres,u)
    # Implicit differentiation: dq/du = -Hq^{-1} Hu, hence
    #   grad = (de/dq) dq/du = -(Hq^{-T} de/dq)^T Hu.
    # A single adjoint solve replaces the explicit inverse and the
    # (n x n) @ (n x Nu) product, which is cheaper and better conditioned.
    adjoint=jnp.linalg.solve(Hq.T, derrdq(flatres))
    err=errint(flatres)
    grad=-(adjoint@Hu)
    print(err)
    return err,grad

In [43]:
# Bounded minimisation over the control knots; obj returns both the value and
# the gradient, which is what jac=True tells SciPy to expect.
sol = minimize(obj, u0, bounds=bounds, jac=True)

23022.588288980867
259.7777756998617
48.039437885355504
5.808548040049536
5.369799017485421
5.002082449313983
4.472910619782905
4.369658662578919
3.9640564782106957
3.854524351769607
3.7490072836549744
3.641685775297632
3.578449310362496
3.5745574895603776
3.564347263078288
3.563438560847784
3.5620383679328183
3.5624014408186806
3.561808427876903
3.56163242212031
3.5614417341122144
3.561371426258537
3.5613255963647887
3.561317222359199
3.56131106825183
3.5612890436392592
3.5612527082743077
3.561349030957153
3.561242185686726
3.5612222096906416
3.5612133228416494
3.5612063413213435
3.5611952647968614
3.561178060975475
3.5611550701321177
3.5611199185153497
3.561081207319555
3.56105853487736
3.5609948640964726
3.5609481105388303
3.5609102421070453
3.560860609190602
3.5609301382683065
3.5608475428878585
3.560837169244957
3.5608346832707967
3.5608334895449834
3.5608336274625354
3.560833418598026
3.5608333596745196
3.560833343315087
3.5608333165262596
3.5608332891992593
3.5608333759162982
3.

In [44]:
# Re-solve the collocation system at the optimised control, warm-started from
# the last solution stored by obj, to obtain the trajectory that is plotted.
solve = root(eqs, flatres, args=(sol.x,), jac=jacv)

In [45]:
# Sample the optimised trajectory on a uniform grid for plotting.
tplot=np.linspace(0,tend,200)
h1,h2,err=_interpolate(tplot,solve.x)
# Levels go on the primary y-axis, the interpolated control on the secondary one.
fig=make_subplots(specs=[[dict(secondary_y=True)]])
fig.add_scatter(x=tplot,y=h1,name='h1')
fig.add_scatter(x=tplot,y=h2,name='h2')
fig.add_scatter(x=tplot,y=jnp.interp(tplot,ut,sol.x),name='qcontrol',secondary_y=True)
# Name the traces and label the axes so the figure can be read on its own.
fig.update_xaxes(title_text='t')
fig.update_yaxes(title_text='h1, h2',secondary_y=False)
fig.update_yaxes(title_text='qcontrol',secondary_y=True)
fig.update_layout(width=800,height=600,template='plotly_dark')