<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_hermite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root
from functools import partial

In [151]:
# Widen NumPy's print line width so the long residual / Jacobian arrays below wrap onto fewer lines.
np.set_printoptions(linewidth=200)

In [272]:
def rhs(t, y, u):
    """Forced oscillator y'' = -y + u[0]*sin(u[1]*t) written as a first-order system.

    y holds (position, velocity); u holds (forcing amplitude, forcing frequency).
    Returns the time derivative (velocity, acceleration) as a length-2 jnp array.
    """
    position, velocity = y[0], y[1]
    amplitude, frequency = u[0], u[1]
    acceleration = -position + amplitude * jnp.sin(frequency * t)
    return jnp.array([velocity, acceleration])

# Initial state and nominal parameter vector; Ny / Nu are their sizes.
y0 = jnp.array([1., 0.])
u0 = jnp.array([1., 1.])
Ny, Nu = y0.size, u0.size

# Parameter perturbation used at the end to check the first-order sensitivity prediction.
du = jnp.array([0.001, 0.002])
u1 = u0 + du

rhs = jax.jit(rhs)
rhs_jac = jax.jit(jax.jacobian(rhs, argnums=1))   # d rhs / d y

tend = 5.  # final integration time

# rhs batched over time points: t is mapped over axis 0, y over axis 1, u is shared;
# the per-time derivatives are stacked along output axis 1.
rhs_vec = jax.vmap(rhs, in_axes=(0, 1, None), out_axes=1)

In [273]:
# Reference Radau trajectories at the nominal (u0) and perturbed (u1) parameters.
# Pass the jitted Jacobian d(rhs)/dy (rhs_jac) rather than re-wrapping rhs in an
# un-jitted jax.jacobian; solve_ivp forwards `args` to jac as well as to rhs.
res0 = solve_ivp(rhs, (0, tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(u0,))
res1 = solve_ivp(rhs, (0, tend), y0, method='Radau', dense_output=True, jac=rhs_jac, args=(u1,))

In [274]:
def get_interpolate(rhs_vec, res):
    """Return a cubic Hermite interpolant over the time grid of an ODE solution.

    On each step [ts[i], ts[i+1]] the trajectory is the cubic Hermite polynomial that
    matches the node values y[:, i], y[:, i+1] and the node slopes obtained by
    evaluating rhs_vec at those nodes.

    Args:
        rhs_vec: batched right-hand side, called as rhs_vec(t, y, u) with t of shape
            (k,) and y of shape (Ny, k); expected to return shape (Ny, k).
        res: object with a 1-D time-grid attribute ``t`` (e.g. a solve_ivp result).

    Returns:
        interpolate(t, y, u) -> (value, derivative), where y holds the node values with
        shape (Ny, len(res.t)). Both outputs are squeezed. Times at or after the last
        node return the last node value and slope; times before the first node return
        the first node value and slope.
    """
    # jnp copies (not NumPy) so they can be indexed with JAX index arrays.
    ts = jnp.asarray(res.t)
    Nb = ts.size
    hs = jnp.diff(ts)

    def interpolate(t, y, u):
        t = jnp.atleast_1d(t)
        # Segment index clamped to [0, Nb-2]. Without the lower clamp, a time before
        # ts[0] gives -1, which would silently wrap around to the last segment.
        i = jnp.clip(jnp.searchsorted(ts, t, side='right') - 1, 0, Nb - 2)
        # Normalized position in the segment: held at 0 before the grid and forced to
        # exactly 1 at/after its end.
        x = jnp.where(t < ts[-1], jnp.maximum((t - ts[i]) / hs[i], 0.), 1.)
        p0 = y[:, i]
        p1 = y[:, i + 1]
        m0 = rhs_vec(ts[i], p0, u)
        m1 = rhs_vec(ts[i + 1], p1, u)

        # Cubic Hermite basis functions ...
        h00 = 2*x**3 - 3*x**2 + 1
        h10 = x**3 - 2*x**2 + x
        h01 = -2*x**3 + 3*x**2
        h11 = x**3 - x**2
        # ... and their derivatives with respect to x (divide by h for d/dt).
        h00d = 6*x**2 - 6*x
        h10d = 3*x**2 - 4*x + 1
        h01d = -6*x**2 + 6*x
        h11d = 3*x**2 - 2*x

        h = hs[i]
        value = h00*p0 + h10*h*m0 + h01*p1 + h11*h*m1
        slope = h00d*p0/h + h10d*m0 + h01d*p1/h + h11d*m1
        return jnp.squeeze(value), jnp.squeeze(slope)

    return interpolate

In [404]:
def get_collocation(rhs_vec, res, jit=False):
    """Build a Hermite-midpoint collocation residual on the time grid of an ODE solution.

    The unknowns are the states at res.t[1:], flattened row-major from shape
    (Ny, Nb-1). The initial state is pinned to res.y[:, 0]. On each step the state is
    the cubic Hermite polynomial through the endpoint values, with endpoint slopes
    given by rhs_vec. The residual is the mismatch between that cubic's time
    derivative and rhs_vec, both evaluated at the step midpoint.

    Args:
        rhs_vec: batched right-hand side, called as rhs_vec(t, y, u) with t of shape
            (k,) and y of shape (Ny, k); expected to return shape (Ny, k).
        res: object with attributes t (shape (Nb,)) and y (shape (Ny, Nb)), e.g. the
            result of scipy.integrate.solve_ivp.
        jit: if True, the residual and both Jacobians are wrapped in jax.jit.

    Returns:
        (collocation, jac_y, jac_u, interpolate):
            collocation(y, u): residual of shape (Ny*(Nb-1),).
            jac_y, jac_u: Jacobians of collocation with respect to y and u.
            interpolate(t, y, u): Hermite interpolant of the full trajectory at t.
    """
    ts = jnp.asarray(res.t)              # jnp so it can be indexed with JAX index arrays
    Ny = res.y.shape[0]                  # state dimension, from the solution rather than a global
    Nb = ts.size
    hs = jnp.diff(ts)
    tmid = (ts[1:] + ts[:-1]) / 2
    y_init = jnp.asarray(res.y[:, :1])   # (Ny, 1) initial state, held fixed

    # Cubic Hermite basis functions and their x-derivatives at the step midpoint x = 0.5.
    x = 0.5
    h00 = 2*x**3 - 3*x**2 + 1
    h10 = x**3 - 2*x**2 + x
    h01 = -2*x**3 + 3*x**2
    h11 = x**3 - x**2

    h00d = 6*x**2 - 6*x
    h10d = 3*x**2 - 4*x + 1
    h01d = -6*x**2 + 6*x
    h11d = 3*x**2 - 2*x

    def _full_states(y):
        # Prepend the fixed initial state to the unknowns, giving shape (Ny, Nb).
        return jnp.concatenate([y_init, jnp.reshape(y, (Ny, -1))], axis=1)

    def collocation(y, u):
        """Midpoint collocation residual, shape (Ny*(Nb-1),), for unknowns y and parameters u."""
        y = _full_states(y)
        y_left, y_right = y[:, :-1], y[:, 1:]
        m0 = rhs_vec(ts[:-1], y_left, u)
        m1 = rhs_vec(ts[1:], y_right, u)
        # Keep the (Ny, Nb-1) layout (no squeeze) so rhs_vec can still map over axis 1
        # when Ny == 1 or when the grid has a single step.
        ymid = h00*y_left + h10*hs*m0 + h01*y_right + h11*hs*m1
        ypmid = h00d*y_left/hs + h10d*m0 + h01d*y_right/hs + h11d*m1
        return (ypmid - rhs_vec(tmid, ymid, u)).ravel()

    def interpolate(t, y, u):
        """Hermite interpolant of the full trajectory (initial state + unknowns y) at time(s) t."""
        y = _full_states(y)
        t = jnp.atleast_1d(t)
        # Segment index clamped to [0, Nb-2] so t < res.t[0] does not wrap to the last segment.
        i = jnp.clip(jnp.searchsorted(ts, t, side='right') - 1, 0, Nb - 2)
        # Normalized position, held at 0 before the grid and exactly 1 at/after its end.
        s = jnp.where(t < ts[-1], jnp.maximum((t - ts[i]) / hs[i], 0.), 1.)
        p0, p1 = y[:, i], y[:, i + 1]
        m0 = rhs_vec(ts[i], p0, u)
        m1 = rhs_vec(ts[i + 1], p1, u)
        b00 = 2*s**3 - 3*s**2 + 1
        b10 = s**3 - 2*s**2 + s
        b01 = -2*s**3 + 3*s**2
        b11 = s**3 - s**2
        return jnp.squeeze(b00*p0 + b10*hs[i]*m0 + b01*p1 + b11*hs[i]*m1)

    jac_y = jax.jacobian(collocation, 0)
    jac_u = jax.jacobian(collocation, 1)
    if jit:
        return jax.jit(collocation), jax.jit(jac_y), jax.jit(jac_u), interpolate
    return collocation, jac_y, jac_u, interpolate

In [405]:
# Residual, its Jacobians w.r.t. the unknowns and the parameters, and the interpolant on res0's grid.
(collocation, jacy, jacu, interpolate) = get_collocation(rhs_vec, res0, jit=True)

In [406]:
# Collocation residual evaluated at the Radau node values (initial node excluded).
collocation(res0.y[:, 1:].reshape(-1), u0)

Array([-1.15080146e-13, -4.16390715e-11, -4.28181340e-07, -3.43446973e-05, -9.23492167e-05, -6.81894703e-05, -3.06915856e-05,  4.61364137e-05,  1.42483594e-05,  1.95632238e-05,  2.27625328e-05,
        2.34939622e-05,  2.16457660e-05,  4.52922165e-05,  4.40695237e-05, -5.59378561e-05, -1.58861880e-04, -1.66533454e-15, -2.04882777e-11, -1.80481962e-07, -7.21175151e-06,  9.24596984e-06,
        3.68097133e-05,  5.34625832e-05,  1.38016374e-04,  1.08019965e-05,  6.57834230e-06,  1.34721545e-06, -4.30503722e-06, -9.73712801e-06, -4.13338450e-05, -1.34906790e-04, -1.25633041e-04,
        8.89311012e-05], dtype=float64)

In [407]:
# Solve the collocation equations for the states at res0.t[1:], starting from the Radau
# solution. Supply the exact JAX Jacobian instead of MINPACK's finite-difference estimate.
solve = root(collocation, res0.y[:, 1:].ravel(), args=(u0,), jac=jacy)
# The implicit-function sensitivities below are only valid at a converged root.
assert solve.success, solve.message

In [408]:
# Interpolated end state y(tend) from the collocation solution.
# (The unused `t = jnp.linspace(0, tend, 10)` grid was dead code and has been removed.)
interpolate(tend, solve.x, u0)

Array([-0.90486299, -1.43818692], dtype=float64)

In [414]:
# Implicit function theorem: collocation(y(u), u) = 0  =>  dy/du = -(dC/dy)^{-1} dC/du.
Hu = jacu(solve.x, u0)
Hy = jacy(solve.x, u0)
# Solve the linear system directly instead of forming the explicit inverse
# (cheaper and numerically better conditioned; same result in exact arithmetic).
dy_du = -jnp.linalg.solve(Hy, Hu)

In [415]:
# Jacobian of the interpolated end state with respect to the collocation unknowns (argument 1).
jac = jax.jacobian(interpolate, argnums=1)
dyend_dy = jac(tend, solve.x, u0)

In [417]:
# Chain rule: d y(tend)/du = d y(tend)/d y_nodes  x  d y_nodes/du.
dyend_du = jnp.matmul(dyend_dy, dy_du)

In [418]:
# Radau end state at the nominal parameters u0.
res0.y[..., -1]

array([-0.9049825 , -1.43835127])

In [419]:
# Radau end state at the perturbed parameters u1 = u0 + du.
res1.y[..., -1]

array([-0.91694816, -1.439567  ])

In [420]:
# First-order prediction of the perturbed end state; compare with res1's end state above.
res0.y[..., -1] + jnp.matmul(dyend_du, du)

Array([-0.91696745, -1.43960056], dtype=float64)