<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_radau.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import numpy.polynomial.polynomial as P
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp, Radau
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root
from scipy.special import roots_jacobi
from functools import partial

In [89]:
def rhs(t, y, u):
    """Right-hand side of the forced oscillator y0' = y1, y1' = -y0 + u*sin(t).

    Args:
        t: time (scalar).
        y: state vector of length 2.
        u: forcing amplitude.

    Returns:
        jnp array [dy0/dt, dy1/dt].
    """
    position, velocity = y[0], y[1]
    forcing = u * jnp.sin(t)
    return jnp.array([velocity, forcing - position])

In [80]:
# Problem setup: initial state, forcing amplitude and integration horizon.
y0 = jnp.array([0., 1.])
u = 0.05
tend = 10
# Reference solution from SciPy's Radau integrator. The state Jacobian comes
# from JAX autodiff; dense_output=True keeps the per-step interpolants whose
# Q matrices are reused as the collocation initial guess below.
res = solve_ivp(
    rhs,
    (0, tend),
    y0,
    method='Radau',
    jac=jax.jacobian(rhs, argnums=1),
    args=(u,),
    dense_output=True,
    atol=1e-7,
    rtol=1e-5,
)

In [81]:
@jax.jit
def _interpolate(t, ts, hs, y0s, Qs):
    t=jnp.atleast_1d(t)
    i=jnp.searchsorted(ts,t, side='right')-1
    x = (t-ts[i])/hs[i]
    p = jnp.cumprod(jnp.tile(x, (3,1)),axis=0)
    y= jnp.where(t<ts[-1],jnp.einsum('tyi, it -> yt', jnp.take(Qs, i, 0), p)+jnp.take(y0s,i,1),jnp.take(y0s,i,1))
    return jnp.squeeze(y)

def get_interp(res):
    """Bind the dense output of a SciPy Radau solve to ``_interpolate``.

    Args:
        res: result of ``solve_ivp(..., method='Radau', dense_output=True)``.
            Its ``t``, its ``y`` and the ``Q`` matrix of every entry in
            ``res.sol.interpolants`` are copied into JAX arrays.

    Returns:
        A ``functools.partial`` of ``_interpolate`` with ``ts, hs, y0s, Qs``
        bound, so it is called as ``f(t)``.
    """
    step_times = jnp.array(res.t)
    coef_stack = jnp.stack([piece.Q for piece in res.sol.interpolants])
    return partial(
        _interpolate,
        ts=step_times,
        hs=jnp.diff(step_times),
        y0s=jnp.array(res.y),
        Qs=coef_stack,
    )

In [82]:
# Nodes on the unit step [0, 1]: left end, 3 Gauss-Legendre points, right end.
x = 0.5 * np.concatenate(([-1.], leggauss(3)[0], [1.])) + 0.5
# Monomial rows x, x**2, x**3 evaluated at every node -> shape (3, 5).
p = jnp.cumprod(jnp.tile(x, (3, 1)), axis=0)
# Derivatives of x**2 and x**3 -> rows 2x and 3x**2 (the x term's derivative,
# 1, is added separately in eqs).
dp = p[:-1, :] * jnp.arange(2, 4)[:, None]
rhs_vec = jnp.vectorize(rhs, signature='(),(2),()->(2)')

# Radau mesh and dense-output coefficients (the same arrays get_interp binds).
ts = jnp.array(res.t)
hs = jnp.diff(ts)
y0s = jnp.array(res.y)
Qs = jnp.stack([piece.Q for piece in res.sol.interpolants])
Nb = len(Qs)

interpolate = get_interp(res)
# Jacobian of the interpolated state with respect to the Q coefficients (argument 4).
jac = jax.jacobian(_interpolate, argnums=4)
jac(4, ts, hs, y0s, Qs).shape

(2, 55, 2, 3)

In [83]:
@jax.jit
def eqs(v, u):
    """Residual vector of the collocation system for parameter ``u``.

    ``v`` packs two groups of unknowns:
      * the first ``Nb*2*3`` entries are the per-step coefficients,
        reshaped to (Nb, 2, 3);
      * the remaining ``2*Nb`` entries are the states at the end of each
        step, laid out as (2, Nb) row-major (state index first).
    The known initial state ``y0`` is prepended as the start of step 0.

    Returns:
        ``2*Nb`` continuity residuals (polynomial value at the right node
        minus the next step's start state), followed by ``Nb*3*2``
        collocation residuals (time derivative of the polynomial minus
        ``rhs`` at the interior nodes ``x[1:-1]``).

    NOTE: reads module-level ``Nb, y0, p, dp, x, ts, hs, rhs_vec``. Under
    ``jax.jit`` these are baked in when the function is first traced, so
    re-run this cell after changing any of them.
    """
    n_coef = Nb * 2 * 3
    coefs = v[:n_coef].reshape(Nb, 2, 3)
    # Step start states, shape (Nb+1, 2): y0 followed by the unknown end states.
    starts = jnp.concatenate([y0[None, :], v[n_coef:].reshape(2, Nb).T], axis=0)

    # Polynomial values at nodes x[1:] (interior nodes + right end), (Nb, 2, 4).
    y_nodes = jnp.einsum('byi, ix -> byx', coefs, p)[:, :, 1:] + starts[:-1, :, None]
    # Physical times of the interior nodes, (Nb, 3).
    t_nodes = ts[:-1, None] + hs[:, None] * x[None, 1:-1]
    # dy/dx = Q[...,0] + 2 Q[...,1] x + 3 Q[...,2] x**2; divide by h for dy/dt.
    slope = jnp.einsum('byi, ix -> byx', coefs[:, :, 1:], dp) + coefs[:, :, :1]
    dydt = slope[:, :, 1:-1] / hs[:, None, None]

    continuity = (y_nodes[:, :, -1] - starts[1:, :]).ravel()
    f_nodes = rhs_vec(t_nodes, y_nodes[:, :, :-1].transpose(0, 2, 1), u)
    collocation = (dydt.transpose(0, 2, 1) - f_nodes).ravel()
    return jnp.concatenate([continuity, collocation])

# Jitted Jacobians of the residuals: w.r.t. the unknowns v (passed to root)
# and w.r.t. the scalar parameter u (for sensitivities).
jacv = jax.jit(jax.jacobian(eqs, argnums=0))
jacu = jax.jit(jax.jacobian(eqs, argnums=1))

In [84]:
# Initial guess for the collocation unknowns: Radau Q coefficients followed by
# the Radau states at the end of every step, in (state, step) row-major order.
flat_res = jnp.concatenate([jnp.ravel(Qs), jnp.ravel(res.y[:, 1:])])

In [85]:
# Solve the collocation system, starting from the Radau-based guess flat_res.
# Reusing flat_res (instead of rebuilding it from the global Qs, which this
# cell overwrites) keeps the cell idempotent on re-run.
solve = root(eqs, flat_res, jac=jacv, args=(u,))
# Fail loudly rather than plotting an unconverged solution.
assert solve.success, solve.message
n_q = Nb * 2 * 3
qflat, yb0 = solve.x[:n_q], solve.x[n_q:]
# Collocation solution: coefficients (Nb, 2, 3) and boundary states (2, Nb+1),
# the same layouts _interpolate expects for Qs and y0s.
Qs = jnp.reshape(qflat, (Nb, 2, 3))
yb0 = jnp.concatenate([jnp.reshape(y0, (2, 1)), jnp.reshape(yb0, (2, Nb))], axis=1)

In [86]:
# Sample the collocation solution (boundary states yb0, coefficients Qs) on a
# uniform grid over the whole horizon.
n_plot = 100
tplot = jnp.linspace(0, tend, n_plot)
ycol0, ycol1 = _interpolate(tplot, ts, hs, yb0, Qs)

In [87]:
# Display the second state component y[1] (= dy[0]/dt per rhs) of the
# collocation solution sampled on tplot.
ycol1

Array([ 1.        ,  0.99565905,  0.98166623,  0.95815277,  0.92533996,
        0.88353937,  0.83315143,  0.77465633,  0.7086177 ,  0.63566712,
        0.55650875,  0.47190269,  0.38266593,  0.28965774,  0.1937771 ,
        0.09594958, -0.00287859, -0.10175114, -0.19971071, -0.29580769,
       -0.38910976, -0.4787114 , -0.5637406 , -0.6433714 , -0.71682573,
       -0.78338886, -0.84240774, -0.89330617, -0.93558059, -0.96881681,
       -0.992681  , -1.00693599, -1.01143227, -1.006118  , -0.99103499,
       -0.96631873, -0.93220113, -0.8890006 , -0.83712984, -0.77707944,
       -0.70942597, -0.63481515, -0.55396486, -0.46765039, -0.37670463,
       -0.28200317, -0.18446122, -0.08502137,  0.01535386,  0.11569162,
        0.21501841,  0.31236946,  0.40679714,  0.49738259,  0.58324056,
        0.66353362,  0.7374737 ,  0.80433726,  0.86346343,  0.91427135,
        0.95625408,  0.98899489,  1.01216212,  1.02551868,  1.02892229,
        1.02232483,  1.00577941,  0.9794298 ,  0.9435221 ,  0.89

In [88]:
# Sensitivity of the residuals to the parameter u, evaluated at the converged
# collocation solution. That is the point where implicit differentiation
# applies: dv/du = -jacv(v*, u)^{-1} @ jacu(v*, u). Here eqs is affine in u
# (rhs contains u*sin(t)), so the value equals the one at the initial guess.
jacu(solve.x, u)

Array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.  