<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/adjoint_ode_derivative.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import numpy.polynomial.polynomial as P
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp, Radau, quad_vec, fixed_quad
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root, minimize
from scipy.special import roots_jacobi
from functools import partial

In [2]:
@jax.jit
def _interpolate(t, ts, hs, y0s, Qs):
    t=jnp.atleast_1d(t)
    i=jnp.searchsorted(ts,t, side='right')-1
    x = (t-ts[i])/hs[i]
    p = jnp.cumprod(jnp.tile(x, (3,1)),axis=0)
    y= jnp.where(t<ts[-1],jnp.einsum('tyi, it -> yt', jnp.take(Qs, i, 0), p)+jnp.take(y0s,i,1),jnp.take(y0s,i,1))
    return jnp.squeeze(y)

def get_interp(res):
    """Wrap a ``solve_ivp`` result as a jitted interpolant ``f(t)``.

    ``res`` must come from ``solve_ivp(..., dense_output=True)``: the step
    endpoints ``res.t``, the states ``res.y`` and each segment's ``Q`` matrix
    from ``res.sol.interpolants`` are bound into ``_interpolate``.
    NOTE(review): ``_interpolate`` uses searchsorted, so ``res.t`` is assumed
    to be increasing (forward integration) — confirm for callers.
    """
    knots = jnp.array(res.t)
    coeffs = jnp.stack([segment.Q for segment in res.sol.interpolants])
    return partial(
        _interpolate,
        ts=knots,
        hs=jnp.diff(knots),
        y0s=jnp.array(res.y),
        Qs=coeffs,
    )

In [39]:
# Two tanks in series (see dhdt): outflow of each is a*sqrt(2*g*h), and
# each level changes by net flow / A.
a = 0.01                 # orifice coefficient in a*sqrt(2*g*h)
A = 0.1                  # divisor turning net flow into dh/dt
qin_initial_ss = 0.1     # constant inflow to tank 1
grav = 9.81              # gravitational acceleration (SI units assumed)
sq2g = jnp.sqrt(2*grav)
# Initial levels: steady state of qin = a*sqrt(2 g h)  =>  h = qin**2/a**2/(2 g).
h20 = qin_initial_ss**2/(a**2)/(2*grav)
h10 = h20
hsp = 2*h20              # target for the second level (used by adj_f)
tend = 200.
# Control knots: dhdt interpolates the control values uy linearly over ut.
ut = jnp.linspace(0., tend, 25)
uy = jnp.full_like(ut, 0.02)   # initial guess for the optimizer
# 50-point Gauss-Legendre nodes/weights, nodes mapped from [-1, 1] to [0, tend].
gx, gw = leggauss(50)
gt = tend/2*gx + tend/2

@jax.jit
def dhdt(t, h, uy):
    """Time derivative of the two tank levels.

    Parameters
    ----------
    t : scalar or (Nt,) array
        Evaluation time(s).
    h : (2,) or (2, Nt) array-like
        Levels (h1, h2) of the first and second tank.
    uy : array
        Control values at the knots ``ut`` (linearly interpolated in ``t``).

    Returns
    -------
    Array of shape (2,) or (2, Nt): [dh1/dt, dh2/dt].
    """
    level1, level2 = h
    # Additional inflow to tank 1 from the piecewise-linear control.
    qcontrol = jnp.interp(t, ut, uy)
    # Orifice outflows: tank 1 drains into tank 2, tank 2 drains out.
    q12, q2 = (a*sq2g*(level**0.5) for level in (level1, level2))
    # Mass balance for each tank, both divided by A.
    return jnp.array([qin_initial_ss + qcontrol - q12, q12 - q2])/A  #Nx, Nt

In [40]:
@jax.jit
def adj_f(x, p):
    """Running cost f(x, p): squared error of the second state vs hsp (p unused)."""
    return (x[1] - hsp)**2   #Nt


@jax.jit
def adj_h(x, xp, p, t):
    """ODE residual h(x, x', p, t) = x' - dhdt(t, x, p); zero along a solution."""
    return xp - dhdt(t, x, p)

In [41]:
# Partial derivatives of the running cost f(x, p) and the residual
# h(x, x', p, t) for the adjoint sensitivity computation in obj()
# (adj_h_gradxp is defined for completeness; obj() does not use it).
adj_f_gradx = jax.jit(jax.jacobian(adj_f, argnums=0))
adj_f_gradp = jax.jit(jax.jacobian(adj_f, argnums=1))    #Nt, Np
adj_h_gradx = jax.jit(jax.jacobian(adj_h, argnums=0))
adj_h_gradxp = jax.jit(jax.jacobian(adj_h, argnums=1))
adj_h_gradp = jax.jit(jax.jacobian(adj_h, argnums=2))    #Nx, Nt, Np

In [42]:
# Jitted d(dhdt)/dh for the stiff Radau solver, built once so it is not
# re-created and re-traced eagerly on every objective evaluation.
_dhdt_jac = jax.jit(jax.jacobian(dhdt, 1))


def obj(p):
    """Tracking cost and its adjoint gradient for control knot values ``p``.

    Parameters
    ----------
    p : array, same shape as ``uy``
        Control values at the knots ``ut``.

    Returns
    -------
    interror : scalar
        Gauss-Legendre approximation of the integral over [0, tend] of
        adj_f(h(t), p) = (h2(t) - hsp)**2.
    dpF : array, same shape as ``p``
        d(interror)/dp from the adjoint method:
        dF/dp = integral of f_p + L^T h_p, where
        dL/dt = f_x + h_x^T L with L(tend) = 0.

    Raises
    ------
    RuntimeError
        If the forward or adjoint ODE solve does not succeed.
    """
    fwd = solve_ivp(dhdt, (0, tend), [h10, h20], method='Radau',
                    dense_output=True, jac=_dhdt_jac, args=(p,))
    if not fwd.success:
        raise RuntimeError(f'forward solve failed: {fwd.message}')
    hsol = fwd.sol

    # Reuse adj_f so the objective is exactly the cost the adjoint differentiates.
    interror = tend/2*jnp.sum(gw*adj_f(hsol(gt), p))
    print(interror)

    def adj_ode(t, L, p):
        # Adjoint ODE dL/dt = f_x + h_x^T L, integrated backward in time.
        x = hsol(t)
        xp = dhdt(t, x, p)
        return adj_f_gradx(x, p) + adj_h_gradx(x, xp, p, t).T @ L

    # Evaluate at the current controls p (previously the global initial
    # guess uy was passed, which is only right when h_x does not depend on p).
    adj = solve_ivp(adj_ode, (tend, 0), jnp.zeros(2), method='Radau',
                    dense_output=True, args=(p,))
    if not adj.success:
        raise RuntimeError(f'adjoint solve failed: {adj.message}')
    Lsol = adj.sol

    def dpF_integrand(t, p):
        # f_p + L^T h_p at every quadrature time; shape (Nt, Np).
        x = hsol(t)
        xp = dhdt(t, x, p)
        L = Lsol(t)
        return adj_f_gradp(x, p) + jnp.einsum('xt, xtp -> tp', L, adj_h_gradp(x, xp, p, t))

    dpF = tend/2*jnp.einsum('t, tp -> p', gw, dpF_integrand(gt, p))
    return interror, dpF

In [None]:
# obj returns (value, gradient), hence jac=True; each control knot is
# bounded to [0, 0.1] and the run starts from the flat guess uy.
res = minimize(
    obj,
    uy,
    jac=True,
    bounds=[(0, 0.1) for _ in range(uy.size)],
    options={'maxiter': 100},
)
res

In [44]:

# Re-simulate with the optimized controls res.x and plot both levels
# (primary axis) together with the interpolated control (secondary axis).
hsol = solve_ivp(dhdt, (0, tend), [h10, h20], method='Radau', dense_output=True,
                 jac=jax.jacobian(dhdt, 1), args=(res.x,)).sol
tplot = jnp.linspace(0, tend, 500)
h1, h2 = hsol(tplot)
fig = make_subplots(specs=[[dict(secondary_y=True)]])
fig.add_scatter(x=tplot, y=h1, name='h1')
fig.add_scatter(x=tplot, y=h2, name='h2')
fig.add_scatter(x=tplot, y=jnp.interp(tplot, ut, res.x), secondary_y=True, name='u')
# Label every axis so the figure stands alone when skimmed.
fig.update_xaxes(title_text='t')
fig.update_yaxes(title_text='level h', secondary_y=False)
fig.update_yaxes(title_text='control u', secondary_y=True)
fig.update_layout(width=800, height=600, template='plotly_dark',
                  title='Tank levels under the optimized control')

In [None]:
# Second-tank level sampled on the plotting grid (compare against hsp).
hsol(tplot)[1, :]