<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_radau_optimal_control.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root, minimize
from functools import partial

In [164]:
# --- Model parameters (read as globals by rhs) ---
a = 0.01                  # factor in the outflow law a*sqrt(2*g*h); presumably the outlet area -- confirm
A = 0.2                   # divisor of the net flow in the level equations; presumably tank cross-section -- confirm
qin_initial_ss = 0.1      # constant inflow term used by rhs
sq2g = jnp.sqrt(2*9.81)   # sqrt(2*g), evaluated once

# --- Initial state: [h1, h2, integrated squared error] ---
# Level at which the outflow a*sqrt(2*g*h) equals qin_initial_ss.
h20 = qin_initial_ss**2/(a**2)/(2*9.81)
h10 = h20
eint0 = 0.
y0 = jnp.array([h10, h20, eint0])
Ny = y0.size

# Target for h2: rhs accumulates (h2 - hsp)**2 in the third state.
hsp = 2.*h20

# --- Time horizon and control grid ---
tend = 100.
Nu = 25
ut = jnp.linspace(0., tend, Nu)          # times of the control values; rhs interpolates between them
u0 = jnp.full_like(ut, 0.05)             # initial guess for the control values
bounds = [(0., 2.) for _ in range(Nu)]   # one (lower, upper) pair per control value


@jax.jit
def rhs(t, v, u):
    """Time derivative of the augmented state ``[h1, h2, eint]``.

    Parameters
    ----------
    t : time at which the derivative is evaluated.
    v : state with three components: the two levels h1, h2 and the running
        integral of the squared deviation of h2 from ``hsp`` (the integral
        itself does not enter the derivative).
    u : control values at the times ``ut``; linearly interpolated at ``t``.

    Returns
    -------
    Array ``[dh1/dt, dh2/dt, (h2 - hsp)**2]``.

    Reads the module-level values ``qin_initial_ss``, ``ut``, ``a``, ``sq2g``,
    ``A`` and ``hsp``.
    """
    level1, level2, _ = v
    q_control = jnp.interp(t, ut, u)
    # Flows leaving each level, both of the form a*sqrt(2*g*h).
    q_12 = a*sq2g*(level1**0.5)
    q_out = a*sq2g*(level2**0.5)
    dh1 = (qin_initial_ss + q_control - q_12)/A
    dh2 = (q_12 - q_out)/A
    squared_error = (level2 - hsp)**2
    return jnp.array([dh1, dh2, squared_error])

# Jacobian of rhs with respect to the state v (argument index 1), compiled once
# so that every stiff solve reuses the same traced function.
rhs_jac=jax.jit(jax.jacobian(rhs,1))

# rhs broadcast over leading batch axes of t and v, with u shared by all points.
# The tokens in a jnp.vectorize signature are dimension *labels*, not sizes, so
# the former '(2)' / '(25)' labels were silently bound to the real lengths
# (the state has Ny = 3 components). Symbolic labels say what is meant:
# scalar t, state of length y, control of length u -> derivative of length y.
rhs_vec=jnp.vectorize(rhs,signature='(),(y),(u)->(y)')

In [165]:
# Baseline stiff integration with the initial control guess u0. The jitted state
# Jacobian rhs_jac is passed (as in the other solve_ivp calls of this notebook)
# instead of building a fresh, un-jitted jax.jacobian(rhs, 1) here.
res=solve_ivp(rhs, (0,tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(u0,))


In [166]:
# Reference nodes on [0, 1]: both end points plus the 3 Gauss-Legendre points
# mapped from [-1, 1].
x = np.concatenate(([-1.], leggauss(3)[0], [1.]))/2 + 0.5
# Rows hold x, x**2 and x**3 at every node (monomial basis of the cubic increment).
p = jnp.cumprod(jnp.tile(x, (3, 1)), axis=0)
# Rows hold 2*x and 3*x**2: derivatives of the quadratic and cubic monomials.
dp = jnp.arange(2, 4)[:, None]*p[:-1, :]

def get_collocation_eqs(res):
    """Build a differentiable collocation system on the time mesh of a dense ODE solve.

    On each of the ``Nb`` integrator steps the state is represented by a cubic
    in the local coordinate ``s = (t - ts[b])/hs[b]``::

        y(s) = y_b + Q[:, 0]*s + Q[:, 1]*s**2 + Q[:, 2]*s**3

    The flat unknown vector ``v`` is ``[Qs.ravel(), Y.ravel()]`` where ``Qs``
    has shape ``(Nb, Ny, 3)`` and ``Y`` has shape ``(Ny, Nb)`` and holds the
    state at the mesh points ``ts[1:]`` (the state at ``ts[0]`` is the fixed
    module-level ``y0``).

    Parameters
    ----------
    res : result object of ``solve_ivp(..., dense_output=True)`` as produced in
        this notebook with ``method='Radau'``. The attributes read are
        ``res.t``, ``res.y`` and ``res.sol.interpolants[k].Q``.

    Returns
    -------
    collocation_eqs : callable ``(v, u) -> residual``; the residual has the
        same length as ``v`` and vanishes when every cubic ends on the next
        mesh value and its time derivative equals ``rhs`` at the three
        interior nodes of each step.
    _interpolate : callable ``(t, v) -> y`` evaluating the piecewise cubic.
    flatres0 : initial guess for ``v`` taken from the integrator's own
        interpolants and mesh values.
    jacv, jacu : Jacobians of ``collocation_eqs`` with respect to ``v`` and ``u``.

    Uses the module-level ``x``, ``p``, ``dp``, ``y0``, ``Ny`` and ``rhs_vec``.
    """
    ts=jnp.array(res.t)
    hs=ts[1:]-ts[:-1]
    Qs=jnp.stack([s.Q for s in res.sol.interpolants])
    Nb = Qs.shape[0]

    flatres0 = jnp.concatenate([Qs.ravel(),res.y[:,1:].ravel()])

    def collocation_eqs(v,u):
        """Residual of the collocation system for unknowns ``v`` and control ``u``."""
        qflat,yb0=v[:Nb*Ny*3],v[Nb*Ny*3:]
        Qs=jnp.reshape(qflat, (Nb, Ny, 3))
        # Mesh values are stored state-major; transpose to (Nb, Ny) and prepend y0.
        yb0=jnp.reshape(yb0,(Ny,Nb)).T
        yb0=jnp.concatenate([y0.reshape(1,-1),yb0],axis=0)
        # Polynomial increments at all 5 reference nodes, then drop the s=0 node
        # and add each step's starting value: state at 3 interior nodes + right end.
        qp=jnp.einsum('byi, ix -> byx', Qs, p)
        yb=qp[:,:,1:]+yb0[:-1,:,None]
        # Physical times of the 3 interior nodes of every step.
        tb=(hs[:,None]*x[None,1:-1]+ts[:-1,None])
        # dy/dt = (Q0 + 2*Q1*s + 3*Q2*s**2)/h, kept only at the interior nodes.
        ybp=(jnp.einsum('byi, ix -> byx', Qs[:,:,1:] , dp) + Qs[:,:,0][:,:,None])/hs[:,None,None]
        ybp=ybp[:,:,1:-1]
        # Continuity: the cubic's value at s=1 must equal the next mesh value.
        end_equality=jnp.ravel(yb[:,:,-1]-yb0[1:,:])
        # Collocation: the cubic's derivative must equal rhs at the interior nodes.
        collocation=jnp.ravel(jnp.swapaxes(ybp,1,2)-rhs_vec(tb,jnp.swapaxes(yb[:,:,:-1],1,2),u))
        return jnp.concatenate([end_equality, collocation])

    def _interpolate(t,v):
        """Evaluate the piecewise cubic described by ``v`` at time(s) ``t``.

        For ``t >= ts[-1]`` the stored final mesh value is returned directly.
        """
        qflat,yb0=v[:Nb*Ny*3],v[Nb*Ny*3:]
        Qs=jnp.reshape(qflat, (Nb, Ny, 3))
        yb0=jnp.reshape(yb0,(Ny,Nb))
        yb0=jnp.concatenate([y0.reshape(Ny,1),yb0],axis=1)
        t=jnp.atleast_1d(t)
        i=jnp.searchsorted(ts,t, side='right')-1
        # At t >= ts[-1] the search returns i == Nb, one past the last step.
        # Cap the index used for the polynomial pieces so the (unselected)
        # polynomial branch of the where below is built from in-range, finite
        # data instead of relying on out-of-bounds indexing semantics.
        j=jnp.minimum(i, Nb-1)
        s = (t-ts[j])/hs[j]
        powers = jnp.cumprod(jnp.tile(s, (3,1)),axis=0)
        y= jnp.where(t<ts[-1],jnp.einsum('tyi, it -> yt', jnp.take(Qs, j, 0), powers)+jnp.take(yb0,i,1),jnp.take(yb0,i,1))
        return jnp.squeeze(y)

    jacv=jax.jacobian(collocation_eqs,0)
    jacu=jax.jacobian(collocation_eqs,1)

    return collocation_eqs, _interpolate, flatres0, jacv, jacu


In [167]:
def obj(u):
    """Objective and gradient for the control values ``u``.

    The model is integrated with ``solve_ivp`` for the given ``u``; the
    resulting mesh and interpolants seed a collocation system ``H(q, u) = 0``
    that is then solved with ``root``. The objective is the last state
    component (the integrated squared error) at ``tend``. Its gradient follows
    from implicit differentiation of ``H``::

        d err/du = -(d err/dq) Hq^{-1} Hu

    which is evaluated with a single adjoint solve ``Hq^T lam = d err/dq``
    rather than by forming ``Hq^{-1}`` explicitly.

    Parameters
    ----------
    u : control values at the times ``ut``.

    Returns
    -------
    (err, grad) : objective value and its gradient with respect to ``u``.

    Side effects: prints ``err`` and stores the latest collocation solution in
    the module-level name ``flatres``.
    """
    import warnings

    global flatres
    res=solve_ivp(rhs, (0,tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(u,))
    collocation_eqs, _interpolate, flatres, jacv, jacu = get_collocation_eqs(res)

    def integrated_error(v):
        # Last state component at the final time.
        return _interpolate(tend,v)[-1]
    derrdq = jax.jacobian(integrated_error)

    solve=root(collocation_eqs, flatres, jac=jacv,args=(u,))
    if not solve.success:
        # The gradient below is only valid at a root of the collocation system.
        warnings.warn(f'collocation root solve did not converge: {solve.message}', RuntimeWarning)
    flatres=solve.x
    Hu=jacu(flatres,u)
    Hq=jacv(flatres,u)
    err=integrated_error(flatres)
    # Adjoint form of -(derr/dq) @ inv(Hq) @ Hu: one linear solve with a single
    # right-hand side, no explicit matrix inverse.
    lam=jnp.linalg.solve(Hq.T, derrdq(flatres))
    grad=-(lam@Hu)
    print(err)
    return err,grad

In [168]:
# Bounded minimisation over the control values; with jac=True, obj returns the
# pair (objective value, gradient).
sol = minimize(fun=obj, x0=u0, jac=True, bounds=bounds)

677.4760689520996
974348.7951566718
599.7867598391985
516.6895569171318
351.22344298439407
310.7339969404937
284.33583932297205
263.6488905771273
249.34633724921105
230.3659384830863
221.27057171099278
208.40220246232
202.8314859813526
201.6768778672056
198.17064170059703
191.89457365028292
187.48770357788
184.90285934788085
182.50464511640067
181.49764951091822
180.17644492953076
179.14170896424844
178.67233072446322
178.54374061019098
178.12736019886458
176.63456505891588
170.98615876224196
156.4423011301458
150.99472296740066
146.3561965637912
144.50151397932194
144.26502342839444
144.0900694803978
142.97758285127267
141.91753203142306
141.763517105713
141.2563003277396
141.25161184613873
141.32884020719555
141.2884112582651
141.24736664131393
141.2448225104094
141.3284258255536
141.24607214468352
141.2448212568318
141.24482122456337
141.24579075501333
141.24482082081315
141.24482081109966
141.24537547624152
141.2448206358707
141.24482064524906
141.2448206358707
141.24085135972118
1

In [169]:
# Re-integrate the model with the optimised control values and plot both levels
# (left axis) together with the control profile (right axis).
resoptim = solve_ivp(rhs, (0,tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(sol.x,))
tplot=np.linspace(0,tend,200)
h1,h2,err=resoptim.sol(tplot)
fig=make_subplots(specs=[[dict(secondary_y=True)]])
fig.add_scatter(x=tplot,y=h1, name='h1')
fig.add_scatter(x=tplot,y=h2, name='h2')
fig.add_scatter(x=ut,y=sol.x,secondary_y=True, name='qcontrol',mode='lines+markers')
# Axis titles so the figure can be read on its own.
fig.update_xaxes(title_text='t')
fig.update_yaxes(title_text='h1, h2', secondary_y=False)
fig.update_yaxes(title_text='qcontrol', secondary_y=True)
fig.update_layout(width=800,height=600,template='plotly_dark')