<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_radau_optimal_control.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root, minimize
from functools import partial

In [28]:
# --- Two-tank model parameters -------------------------------------------
a = 0.01                # outlet coefficient: rhs uses q = a*sqrt(2*9.81*h)
A = 0.2                 # divisor of the net flow in the level equations
qin_initial_ss = 0.1    # constant inlet flow (also fixes the initial steady state)
sq2g = jnp.sqrt(2*9.81) # sqrt(2 g); 9.81 is presumably gravity in SI units

# Initial level of both tanks: the height at which outflow equals the inlet flow.
h20 = qin_initial_ss**2/(a**2)/(2*9.81)
h10 = h20
eint0 = 0.              # running integral of the squared tracking error starts at zero
y0 = jnp.array([h10, h20, eint0])
Ny = y0.shape[0]        # number of states

# --- Control problem ------------------------------------------------------
hsp = 2.*h20            # set point for h2: twice its initial value
tend = 200.             # final time of the horizon
Nu = 25                 # number of control knots
ut = jnp.linspace(0., tend, Nu)          # knot times; rhs interpolates u linearly between them
u0 = jnp.full_like(ut, 0.05)             # initial guess for the control at every knot
bounds = [(0., 0.1) for _ in range(Nu)]  # box constraints on every knot value


@jax.jit
def rhs(t, v, u):
    """Time derivative of the augmented two-tank state.

    Parameters
    ----------
    t : scalar time at which the control is evaluated.
    v : length-3 state ``(h1, h2, eint)``; ``eint`` itself does not enter
        the dynamics.
    u : control values at the knot times ``ut``; linearly interpolated at ``t``.

    Returns
    -------
    Array ``[dh1/dt, dh2/dt, deint/dt]`` where the last entry is the squared
    deviation of ``h2`` from the set point ``hsp``.
    """
    level1, level2, _error_integral = v
    q_ctrl = jnp.interp(t, ut, u)
    # Outflow of each tank: a * sqrt(2 g h).
    outflow1 = a*sq2g*(level1**0.5)
    outflow2 = a*sq2g*(level2**0.5)
    dh1 = (qin_initial_ss + q_ctrl - outflow1)/A
    dh2 = (outflow1 - outflow2)/A
    derr = (level2 - hsp)**2
    return jnp.array([dh1, dh2, derr])

# Jacobian of rhs with respect to the state vector (argument index 1).
rhs_jac = jax.jit(jax.jacobian(rhs, 1))

# Broadcast rhs over leading axes of (t, state); the control vector is shared.
# jnp.vectorize treats the entries of the signature as dimension *names*, so
# symbolic names are used: the output core dimension is the state size `ny`
# (rhs returns one derivative per state), which the previous literal "(2)"
# mislabelled.
rhs_vec = jnp.vectorize(rhs, signature='(),(ny),(nu)->(ny)')

In [29]:
# Collocation abscissae on [0, 1]: the two Gauss-Legendre nodes mapped from
# [-1, 1], followed by the right end point 1.
# NOTE(review): these are not the Radau IIA abscissae; scipy's Radau result is
# only used for the mesh and the initial guess -- confirm this is intended.
x = (np.r_[leggauss(2)[0], 1])/2 + 0.5

# Monomials evaluated at the nodes: rows are x, x**2, x**3.
p = jnp.cumprod(jnp.tile(x, (3, 1)), axis=0)

# Rows 2*x and 3*x**2: derivatives of the x**2 and x**3 terms.
dp = jnp.arange(2, 4)[:, None] * p[:-1, :]

def get_collocation_eqs(res, jit=False):
    """Build the collocation residual on the mesh of a dense ``solve_ivp`` result.

    On every step ``b`` of ``res`` each state is a cubic in the normalised
    time ``s`` in [0, 1]: ``y(s) = y_start[b] + Q[b] @ [s, s**2, s**3]``.
    The unknowns are the flattened coefficients ``Q`` (shape ``(Nb, Ny, 3)``);
    the step start values follow by accumulating the end-of-step increments
    onto ``y0``.

    Parameters
    ----------
    res : result of ``solve_ivp(..., dense_output=True)``; ``res.t`` supplies
        the mesh and ``res.sol.interpolants[k].Q`` the initial coefficients.
    jit : if True, the returned callables are wrapped in ``jax.jit``.

    Returns
    -------
    collocation_eqs : ``(v, u) -> residual``; derivative of the polynomial
        minus ``rhs`` at the nodes ``x`` of every step, flattened.
    _interpolate : ``(t, v) -> y``; evaluates the piecewise polynomial.
    flatres0 : flattened coefficients taken from ``res`` (initial guess).
    jacv, jacu : Jacobians of ``collocation_eqs`` w.r.t. ``v`` and ``u``.
    """
    ts = jnp.array(res.t)
    hs = ts[1:] - ts[:-1]
    Qs0 = jnp.stack([s.Q for s in res.sol.interpolants])
    Nb = Qs0.shape[0]

    flatres0 = Qs0.ravel()

    def _unpack(v):
        # Shared by both closures: coefficients, their values at the nodes,
        # and the state at the start of every step (plus the final state).
        Qs = jnp.reshape(v, (Nb, Ny, 3))
        qp = jnp.einsum('byi, ix -> byx', Qs, p)
        # The last node is s = 1, so qp[:, :, -1] is the increment over a step.
        yb0 = jnp.cumsum(qp[:, :, -1], axis=0)
        yb0 = jnp.concatenate([jnp.zeros_like(y0).reshape(1, -1), yb0], axis=0) + y0
        return Qs, qp, yb0

    def collocation_eqs(v, u):
        Qs, qp, yb0 = _unpack(v)
        yb = qp + yb0[:-1, :, None]
        tb = (hs[:, None]*x[None, :] + ts[:-1, None])
        # d/dt of the cubic: (Q0 + 2*Q1*s + 3*Q2*s**2) / h
        ybp = (jnp.einsum('byi, ix -> byx', Qs[:, :, 1:], dp) + Qs[:, :, 0][:, :, None])/hs[:, None, None]
        collocation = jnp.ravel(jnp.swapaxes(ybp, 1, 2) - rhs_vec(tb, jnp.swapaxes(yb, 1, 2), u))
        return collocation

    def _interpolate(t, v):
        Qs, _qp, yb0 = _unpack(v)
        t = jnp.atleast_1d(t)
        i = jnp.searchsorted(ts, t, side='right') - 1
        # For t >= ts[-1] the raw index is Nb, one past the last step. Clamp
        # the index used for the polynomial gathers so that branch stays
        # finite instead of relying on out-of-bounds gather semantics; the
        # value actually returned there is still yb0[Nb] via the where below.
        ib = jnp.minimum(i, Nb - 1)
        s = (t - ts[ib])/hs[ib]
        ps = jnp.cumprod(jnp.tile(s, (3, 1)), axis=0)
        y_poly = jnp.einsum('tyi, it -> yt', jnp.take(Qs, ib, 0), ps) + jnp.take(yb0, ib, 0).T
        y_end = jnp.take(yb0, i, 0).T
        y = jnp.where(t < ts[-1], y_poly, y_end)
        return jnp.squeeze(y)

    jacv = jax.jacobian(collocation_eqs, 0)
    jacu = jax.jacobian(collocation_eqs, 1)

    if jit:
        return jax.jit(collocation_eqs), jax.jit(_interpolate), flatres0, jax.jit(jacv), jax.jit(jacu)
    else:
        return collocation_eqs, _interpolate, flatres0, jacv, jacu

In [30]:
# Integrate once with the initial control guess: the adaptive Radau steps
# define the collocation mesh and the dense output supplies starting coefficients.
res = solve_ivp(
    rhs, (0, tend), y0,
    method='Radau', dense_output=True, jac=rhs_jac, args=(u0,),
)
collocation_eqs, _interpolate, flatres, jacv, jacu = get_collocation_eqs(res, jit=True)
# Warm-start vector that obj() updates after every successful solve.
flatres0 = flatres.copy()

In [31]:
def obj(u):
    """Objective and gradient for the control knots ``u``.

    Solves the collocation system ``H(q, u) = 0`` for the polynomial
    coefficients ``q`` (warm-started from the module-level ``flatres0``), then
    evaluates the last state component at ``tend`` -- the integral of
    ``(h2 - hsp)**2`` accumulated by ``rhs``.

    The gradient uses the implicit function theorem in adjoint form:
    ``dq/du = -Hq^{-1} Hu`` so ``d err/du = -(Hq^{-T} d err/dq)^T Hu``. This
    needs one linear solve with a single right-hand side instead of forming
    ``inv(Hq)`` and multiplying by the full ``Hu``.

    Parameters
    ----------
    u : control values at the knot times.

    Returns
    -------
    (err, grad) : scalar objective and its gradient w.r.t. ``u``.

    Side effects
    ------------
    Updates the global ``flatres0`` when the nonlinear solve succeeds and
    prints the objective value.
    """
    global flatres0

    def integrated_error(v):
        return _interpolate(tend, v)[-1]
    derrdq = jax.jacobian(integrated_error)

    solve = root(collocation_eqs, flatres0, jac=jacv, args=(u,))
    if solve.success:
        # Keep the converged coefficients as the next warm start.
        flatres0 = solve.x
    else:
        # Best effort: still return a value, but do not hide the failure.
        import warnings
        warnings.warn(
            f"collocation solve did not converge: {solve.message}",
            RuntimeWarning,
        )
    flatres = solve.x

    Hu = jacu(flatres, u)
    Hq = jacv(flatres, u)
    err = integrated_error(flatres)
    # Adjoint solve: Hq^T lam = d err / dq, then grad = -lam^T Hu.
    lam = jnp.linalg.solve(Hq.T, derrdq(flatres))
    grad = -(lam @ Hu)
    print(err)
    return err, grad

In [None]:
# Optimise the control knots; obj returns (value, gradient), hence jac=True.
sol = minimize(fun=obj, x0=u0, jac=True, bounds=bounds)

In [None]:
# Re-integrate with the current optimum so the collocation mesh follows the
# new trajectory, rebuild the residual functions, and optimise again from sol.x.
res = solve_ivp(
    rhs, (0, tend), y0,
    method='Radau', dense_output=True, jac=rhs_jac, args=(sol.x,),
)
collocation_eqs, _interpolate, flatres, jacv, jacu = get_collocation_eqs(res, jit=True)
flatres0 = flatres.copy()
sol = minimize(fun=obj, x0=sol.x, jac=True, bounds=bounds)

In [34]:
# Simulate the system under the optimised control and sample it for plotting.
resoptim = solve_ivp(
    rhs, (0, tend), y0,
    method='Radau', dense_output=True, jac=rhs_jac, args=(sol.x,),
)
tplot = np.linspace(0, tend, 200)
h1, h2, err = resoptim.sol(tplot)

# Levels on the primary axis, control knots on the secondary axis.
fig = make_subplots(specs=[[{'secondary_y': True}]])
fig.add_scatter(x=tplot, y=h1, name='h1')
fig.add_scatter(x=tplot, y=h2, name='h2')
fig.add_scatter(x=ut, y=sol.x, name='qcontrol', mode='lines+markers', secondary_y=True)
fig.update_layout(template='plotly_dark', width=800, height=600)

In [23]:
# Display the h2 set point (defined above as 2*h20) for reference.
hsp

10.193679918450561