<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_radau.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root
from functools import partial

In [196]:
def rhs1(t,y,u):
    return jnp.array([y[1], -y[0]+u[0]*jnp.sin(u[1]*t)])

y0=jnp.array([1.,0.])
u0=jnp.array([1.,1.])
du=jnp.array([0.001,0.002])
u1=u0+du

rhs1=jax.jit(rhs1)
rhs1_jac=jax.jit(jax.jacobian(rhs1,1))

tend=5.

In [197]:
res=solve_ivp(rhs1, (0,tend), y0, method='Radau', dense_output=True, jac=jax.jacobian(rhs1,1), args=(u0,))
res2=solve_ivp(rhs1, (0,tend), y0, method='Radau', dense_output=True, jac=jax.jacobian(rhs1,1), args=(u1,))

In [198]:
@jax.jit
def _interpolate(t, ts, hs, y0s, Qs):
    t=jnp.atleast_1d(t)
    i=jnp.searchsorted(ts,t, side='right')-1
    x = (t-ts[i])/hs[i]
    p = jnp.cumprod(jnp.tile(x, (3,1)),axis=0)
    y= jnp.where(t<ts[-1],jnp.einsum('tyi, it -> yt', jnp.take(Qs, i, 0), p)+jnp.take(y0s,i,1),jnp.take(y0s,i,1))
    return jnp.squeeze(y)

def get_interp(res):

    ts=jnp.array(res.t)
    hs=ts[1:]-ts[:-1]
    y0s=jnp.array(res.y)
    Qs=jnp.stack([s.Q for s in res.sol.interpolants])

    return partial(_interpolate, ts=ts, hs=hs, y0s=y0s, Qs=Qs)

In [199]:
x = (np.r_[leggauss(2)[0],1])/2 + 0.5
p = jnp.tile(x, (3,1))
p = jnp.cumprod(p,axis=0)
dp = p[:-1,:]*jnp.arange(2,4)[:,None]
rhs_vec=jnp.vectorize(rhs1,signature='(),(2),(2)->(2)')

ts=jnp.array(res.t)
hs=ts[1:]-ts[:-1]
y0s=jnp.array(res.y)
Qs=jnp.stack([s.Q for s in res.sol.interpolants])
Nb = Qs.shape[0]



In [200]:
@jax.jit
def eqs(v,u):
    Qs=jnp.reshape(v, (Nb, 2, 3))
    qp=jnp.einsum('byi, ix -> byx', Qs, p)
    yb0=jnp.cumsum(qp[:,:,-1],axis=0)
    yb0=jnp.concatenate([jnp.zeros_like(y0).reshape(1,-1),yb0],axis=0)+y0
    yb=qp+yb0[:-1,:,None]
    tb=(hs[:,None]*x[None,:]+ts[:-1,None])
    ybp=(jnp.einsum('byi, ix -> byx', Qs[:,:,1:] , dp) + Qs[:,:,0][:,:,None])/hs[:,None,None]
    collocation=jnp.ravel(jnp.swapaxes(ybp,1,2)-rhs_vec(tb,jnp.swapaxes(yb,1,2),u))
    return collocation

jacv=jax.jit(jax.jacobian(eqs,0))
jacu=jax.jit(jax.jacobian(eqs,1))

In [201]:
solve=root(eqs, Qs.ravel(),jac=jacv,args=(u0,))
Qs=jnp.reshape(solve.x, (Nb, 2, 3))
flat_res=Qs.ravel()

In [202]:
Hu=jacu(flat_res,u0)
Hq=jacv(flat_res,u0)
dqdu=-(jnp.linalg.inv(Hq) @ Hu)

In [203]:
def _interpolate(t,v):
    x = (np.r_[leggauss(2)[0],1])/2 + 0.5
    p = jnp.tile(x, (3,1))
    p = jnp.cumprod(p,axis=0)
    Qs=jnp.reshape(v, (Nb, 2, 3))
    qp=jnp.einsum('byi, ix -> byx', Qs, p)

    yb0=jnp.cumsum(qp[:,:,-1],axis=0)
    yb0=jnp.concatenate([jnp.zeros_like(y0).reshape(1,-1),yb0],axis=0)+y0
    t=jnp.atleast_1d(t)
    i=jnp.searchsorted(ts,t, side='right')-1
    x = (t-ts[i])/hs[i]
    p = jnp.cumprod(jnp.tile(x, (3,1)),axis=0)
    y= jnp.where(t<ts[-1],jnp.einsum('tyi, it -> yt', jnp.take(Qs, i, 0), p)+jnp.take(yb0,i,0).T,jnp.take(yb0,i,0).T)
    return jnp.squeeze(y)

jac=jax.jacobian(_interpolate,1)
dydq=jac(tend, flat_res)

In [204]:
dydu = dydq @ dqdu

In [205]:
dydu

Array([[-1.18879274, -5.39939367],
       [-2.39738523,  0.57452449]], dtype=float64)

In [206]:
res.y[:,-1]

array([-0.9049825 , -1.43835127])

In [207]:
res2.y[:,-1]

array([-0.91694816, -1.439567  ])

In [209]:
res.y[:,-1]+dydu @ du

Array([-0.91697008, -1.43959961], dtype=float64)