<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_radau.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import numpy.polynomial.polynomial as P
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp, Radau
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root
from scipy.special import roots_jacobi
from functools import partial

In [2]:
def rhs(t,y,u):
    """Right-hand side of the first-order system y0' = y1, y1' = -y0.

    Args:
        t: Independent variable (not used by this system).
        y: State vector; only ``y[0]`` and ``y[1]`` are read.
        u: Extra parameter (not used by this system).

    Returns:
        A ``jnp`` array ``[y[1], -y[0]]``.
    """
    dy0 = y[1]
    dy1 = -y[0]
    return jnp.array([dy0, dy1])

In [11]:
# Initial-value problem: start from y(0) = [0, 1] and integrate over [0, tend].
y0 = jnp.array([0., 1.])
u = 0.0       # extra argument forwarded to rhs through `args`
tend = 10

# Radau with dense output so that the per-step interpolants are kept in res.sol.
# The Jacobian with respect to the state (positional argument 1 of rhs) comes from JAX.
res = solve_ivp(
    rhs,
    (0, tend),
    y0,
    method='Radau',
    jac=jax.jacobian(rhs, argnums=1),
    args=(u,),
    dense_output=True,
    atol=1e-7,
    rtol=1e-5,
)

In [12]:
@jax.jit
def _interpolate(t, ts, hs, y0s, Qs):
    t=jnp.atleast_1d(t)
    i=jnp.searchsorted(ts,t, side='right')-1
    x = (t-ts[i])/hs[i]
    p = jnp.cumprod(jnp.tile(x, (3,1)),axis=0)
    y= jnp.where(t<ts[-1],jnp.einsum('tyi, it -> yt', jnp.take(Qs, i, 0), p)+jnp.take(y0s,i,1),jnp.take(y0s,i,1))
    return jnp.squeeze(y)

def get_interp(res):
    """Build an interpolation callable from a dense-output solver result.

    Args:
        res: Result object exposing ``t``, ``y`` and ``sol.interpolants``;
            every interpolant must carry a ``Q`` attribute.

    Returns:
        ``_interpolate`` with ``ts``, ``hs``, ``y0s`` and ``Qs`` pre-bound as
        keyword arguments, so only the query time ``t`` remains to be passed.
    """
    step_times = jnp.array(res.t)
    coefficients = jnp.stack([segment.Q for segment in res.sol.interpolants])
    bound = dict(
        ts=step_times,
        hs=step_times[1:] - step_times[:-1],
        y0s=jnp.array(res.y),
        Qs=coefficients,
    )
    return partial(_interpolate, **bound)

In [13]:
# Abscissae on [0, 1]: both step end points plus the three roots returned by
# roots_jacobi(3, 0, 1), mapped affinely from [-1, 1].
x = np.concatenate(([-1.], roots_jacobi(3, 0, 1)[0], [1.])) / 2 + 0.5
# Rows of p are x, x**2, x**3 (the monomials that multiply the Q coefficients).
p = jnp.cumprod(jnp.tile(x, (3, 1)), axis=0)
# Rows of dp are 2*x and 3*x**2, i.e. the derivatives of the x**2 and x**3 rows.
dp = jnp.arange(2, 4)[:, None] * p[:-1, :]
# rhs broadcast over leading axes: scalar t, length-2 state, scalar u -> length-2 output.
rhs_vec = jnp.vectorize(rhs, signature='(),(2),()->(2)')

# The same arrays get_interp binds internally, kept at module level for later cells.
ts = jnp.array(res.t)
hs = ts[1:] - ts[:-1]
y0s = jnp.array(res.y)
Qs = jnp.stack([segment.Q for segment in res.sol.interpolants])
Nb = len(Qs)   # number of integration steps

interpolate = get_interp(res)
# Sensitivity of the interpolated value to the coefficients (positional argument 4).
jac = jax.jacobian(_interpolate, argnums=4)
jac(4, ts, hs, y0s, Qs).shape

(2, 58, 2, 3)

In [14]:
@jax.jit
def eqs(v):
    """Residual vector of the collocation system for the packed unknowns ``v``.

    ``v`` holds ``Nb*2*3`` polynomial coefficients (reshaped to
    ``(Nb, 2, 3)``) followed by ``2*Nb`` step-end states (reshaped to
    ``(2, Nb)``).  The function reads the module-level names ``Nb``, ``y0``,
    ``p``, ``dp``, ``x``, ``hs``, ``ts``, ``rhs_vec`` and ``u``.

    Returns:
        1-D array: the end-of-step matching residuals followed by the
        collocation residuals at the interior abscissae.
    """
    n_coef = Nb * 2 * 3
    coef = jnp.reshape(v[:n_coef], (Nb, 2, 3))

    # Step-end states arrive state-major; transpose to (Nb, 2) and prepend the
    # fixed initial state so that row k is the state at boundary k.
    y_end = jnp.reshape(v[n_coef:], (2, Nb)).T
    y_nodes = jnp.concatenate([y0.reshape(1, 2), y_end], axis=0)
    y_left = y_nodes[:-1, :]
    y_right = y_nodes[1:, :]

    # Polynomial increment at every abscissa; drop the first abscissa and add
    # the state at the left boundary of each step.
    increment = jnp.einsum('byi, ix -> byx', coef, p)
    y_col = increment[:, :, 1:] + y_left[:, :, None]

    # Physical times of the interior abscissae in each step.
    t_col = hs[:, None] * x[None, 1:-1] + ts[:-1, None]

    # Derivative of the polynomial: constant term from coefficient 0 plus the
    # dp-weighted higher coefficients, scaled by 1/h, kept at interior abscissae.
    slope = jnp.einsum('byi, ix -> byx', coef[:, :, 1:], dp) + coef[:, :, 0][:, :, None]
    dy_dt = (slope / hs[:, None, None])[:, :, 1:-1]

    # Polynomial value at the last abscissa must equal the next boundary state.
    end_equality = jnp.ravel(y_col[:, :, -1] - y_right)
    # Polynomial slope must equal rhs at the interior abscissae.
    f_col = rhs_vec(t_col, jnp.swapaxes(y_col[:, :, :-1], 1, 2), u)
    collocation = jnp.ravel(jnp.swapaxes(dy_dt, 1, 2) - f_col)
    return jnp.concatenate([end_equality, collocation])

# Compiled Jacobian of the residual vector eqs with respect to v.
# This rebinds `jac`, which the previous cell used for the Jacobian of _interpolate.
jac=jax.jit(jax.jacobian(eqs))

In [15]:
# Initial guess: the solver's own interpolant coefficients followed by its
# step-end states (the initial state is excluded because eqs holds it fixed).
solve = root(
    eqs,
    jnp.concatenate([Qs.ravel(), res.y[:, 1:].ravel()]),
    jac=jac,
)

# Unpack the solution using the same layout that eqs expects.
qflat = solve.x[:Nb*2*3]
Qs = jnp.reshape(qflat, (Nb, 2, 3))
# Put the initial state back as column 0 so the result has Nb+1 columns.
yb0 = jnp.concatenate(
    [jnp.reshape(y0, (2, 1)), jnp.reshape(solve.x[Nb*2*3:], (2, Nb))],
    axis=1,
)

In [16]:
# Sample the refined piecewise polynomial on a uniform grid of 100 points.
tplot = jnp.linspace(0, tend, num=100)
# NOTE(review): this rebinds y0 (until now the initial condition) to the first
# state's trajectory, so the root-solve cell cannot be re-run afterwards without
# first re-running the cell that defines the initial condition.
y0, y1 = _interpolate(tplot, ts=ts, hs=hs, y0s=yb0, Qs=Qs)

In [17]:
# Second state component of the refined solution sampled on tplot.
y1

Array([ 1.        ,  0.99490444,  0.97966431,  0.95443904,  0.91948234,
        0.87515286,  0.8219015 ,  0.7602711 ,  0.69089113,  0.61446632,
        0.53177851,  0.44366861,  0.35103606,  0.25482466,  0.15601544,
        0.05561561, -0.04535129, -0.14585625, -0.24487398, -0.34139594,
       -0.43443721, -0.52305019, -0.60633164, -0.68343065, -0.75356426,
       -0.81601409, -0.87014655, -0.91540744, -0.95133705, -0.97756908,
       -0.99383369, -0.99996898, -0.99590787, -0.98169581, -0.95747473,
       -0.92349343, -0.880098  , -0.82772938, -0.76692404, -0.69829856,
       -0.62255552, -0.54046501, -0.45286506, -0.36064835, -0.26475473,
       -0.16616225, -0.06587548,  0.03508298,  0.13568371,  0.23490161,
        0.33172475,  0.42516636,  0.51427447,  0.59813862,  0.67590667,
        0.74678297,  0.81004734,  0.86505302,  0.91124081,  0.94813998,
        0.97537133,  0.99266164,  0.99983031,  0.99680774,  0.98362233,
        0.96041014,  0.92740785,  0.8849494 ,  0.83347128,  0.77

In [18]:
# cos(t) on the same grid, for comparison with the y1 values displayed in the previous cell.
jnp.cos(tplot)

Array([ 1.        ,  0.99490282,  0.97966323,  0.95443659,  0.91948007,
        0.87515004,  0.8218984 ,  0.76026803,  0.69088721,  0.61446323,
        0.53177518,  0.44366602,  0.35103397,  0.25482335,  0.15601496,
        0.0556161 , -0.04534973, -0.14585325, -0.24486989, -0.34139023,
       -0.43443032, -0.52304166, -0.60632092, -0.68341913, -0.75355031,
       -0.81599952, -0.87013012, -0.91539031, -0.95131866, -0.97754893,
       -0.9938137 , -0.99994717, -0.9958868 , -0.981674  , -0.95745366,
       -0.92347268, -0.88007748, -0.82771044, -0.76690542, -0.69828229,
       -0.6225406 , -0.54045251, -0.45285485, -0.36064061, -0.26474988,
       -0.16616018, -0.06587659,  0.03507857,  0.13567613,  0.23489055,
        0.33171042,  0.4251487 ,  0.51425287,  0.59811455,  0.67587883,
        0.74675295,  0.8100144 ,  0.86501827,  0.91120382,  0.94810022,
        0.97533134,  0.99261957,  0.99978867,  0.99676556,  0.98358105,
        0.96036956,  0.9273677 ,  0.88491192,  0.83343502,  0.77