<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/implicit_deriv_radau.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import numpy.polynomial.polynomial as P
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp, Radau
from numpy.polynomial.legendre import leggauss
from scipy.optimize import root
from scipy.special import roots_jacobi
from functools import partial

In [91]:
def rhs(t,y,u):
    return jnp.array([y[1], -y[0]+u[0]*jnp.sin(u[1]*t)])

In [151]:
y0=jnp.array([0., 1.])
u=jnp.array([0.05,2.])
du=jnp.array([0.002,0.003])
tend=10
res=solve_ivp(rhs,(0,tend),y0,method='Radau',jac=jax.jacobian(rhs,1),args=(u,),dense_output=True, atol=1e-7,rtol=1e-5)
resdu=solve_ivp(rhs,(0,tend),y0,method='Radau',jac=jax.jacobian(rhs,1),args=(u+du,),dense_output=True, atol=1e-7,rtol=1e-5)

In [93]:
@jax.jit
def _interpolate(t, ts, hs, y0s, Qs):
    t=jnp.atleast_1d(t)
    i=jnp.searchsorted(ts,t, side='right')-1
    x = (t-ts[i])/hs[i]
    p = jnp.cumprod(jnp.tile(x, (3,1)),axis=0)
    y= jnp.where(t<ts[-1],jnp.einsum('tyi, it -> yt', jnp.take(Qs, i, 0), p)+jnp.take(y0s,i,1),jnp.take(y0s,i,1))
    return jnp.squeeze(y)

def get_interp(res):

    ts=jnp.array(res.t)
    hs=ts[1:]-ts[:-1]
    y0s=jnp.array(res.y)
    Qs=jnp.stack([s.Q for s in res.sol.interpolants])

    return partial(_interpolate, ts=ts, hs=hs, y0s=y0s, Qs=Qs)

In [98]:
x = (np.r_[-1.,leggauss(3)[0],1])/2 + 0.5
p = jnp.tile(x, (3,1))
p = jnp.cumprod(p,axis=0)
dp = p[:-1,:]*jnp.arange(2,4)[:,None]
rhs_vec=jnp.vectorize(rhs,signature='(),(2),(2)->(2)')

ts=jnp.array(res.t)
hs=ts[1:]-ts[:-1]
y0s=jnp.array(res.y)
Qs=jnp.stack([s.Q for s in res.sol.interpolants])
Nb = Qs.shape[0]



(2, 59, 2, 3)

In [99]:
@jax.jit
def eqs(v,u):
    qflat,yb0=v[:Nb*2*3],v[Nb*2*3:]
    Qs=jnp.reshape(qflat, (Nb, 2, 3))
    yb0=jnp.reshape(yb0,(2,Nb)).T
    yb0=jnp.concatenate([y0.reshape(1,2),yb0],axis=0)
    qp=jnp.einsum('byi, ix -> byx', Qs, p)
    yb=qp[:,:,1:]+yb0[:-1,:,None]
    tb=(hs[:,None]*x[None,1:-1]+ts[:-1,None])
    ybp=(jnp.einsum('byi, ix -> byx', Qs[:,:,1:] , dp) + Qs[:,:,0][:,:,None])/hs[:,None,None]
    ybp=ybp[:,:,1:-1]
    end_equality=jnp.ravel(yb[:,:,-1]-yb0[1:,:])
    collocation=jnp.ravel(jnp.swapaxes(ybp,1,2)-rhs_vec(tb,jnp.swapaxes(yb[:,:,:-1],1,2),u))
    return jnp.concatenate([end_equality, collocation])

jacv=jax.jit(jax.jacobian(eqs,0))
jacu=jax.jit(jax.jacobian(eqs,1))

In [100]:
flat_res=jnp.concatenate([Qs.ravel(),res.y[:,1:].ravel()])

In [101]:
solve=root(eqs, jnp.concatenate([Qs.ravel(),res.y[:,1:].ravel()]),jac=jacv,args=(u,))
qflat,yb0=solve.x[:Nb*2*3],solve.x[Nb*2*3:]
Qs=jnp.reshape(qflat, (Nb, 2, 3))
yb0=jnp.concatenate([jnp.reshape(y0,(2,1)),jnp.reshape(yb0,(2,Nb))], axis=1)

In [102]:
tplot = jnp.linspace(0,tend,100)
ycol0,ycol1=_interpolate(tplot, ts, hs, yb0, Qs)

In [111]:
Hu=jacu(flat_res,u)
Hq=jacv(flat_res,u)

In [132]:
dqdu=(jnp.linalg.inv(Hq) @ Hu)

In [128]:

def _interpolate(t,v):
    qflat,yb0=v[:Nb*2*3],v[Nb*2*3:]
    Qs=jnp.reshape(qflat, (Nb, 2, 3))
    yb0=jnp.reshape(yb0,(2,Nb))
    yb0=jnp.concatenate([y0.reshape(2,1),yb0],axis=1)
    t=jnp.atleast_1d(t)
    i=jnp.searchsorted(ts,t, side='right')-1
    x = (t-ts[i])/hs[i]
    p = jnp.cumprod(jnp.tile(x, (3,1)),axis=0)
    y= jnp.where(t<ts[-1],jnp.einsum('tyi, it -> yt', jnp.take(Qs, i, 0), p)+jnp.take(yb0,i,1),jnp.take(yb0,i,1))
    return jnp.squeeze(y)

def eqs(v,u):
    qflat,yb0=v[:Nb*2*3],v[Nb*2*3:]
    Qs=jnp.reshape(qflat, (Nb, 2, 3))
    yb0=jnp.reshape(yb0,(2,Nb)).T
    yb0=jnp.concatenate([y0.reshape(1,2),yb0],axis=0)
    qp=jnp.einsum('byi, ix -> byx', Qs, p)
    yb=qp[:,:,1:]+yb0[:-1,:,None]
    tb=(hs[:,None]*x[None,1:-1]+ts[:-1,None])
    ybp=(jnp.einsum('byi, ix -> byx', Qs[:,:,1:] , dp) + Qs[:,:,0][:,:,None])/hs[:,None,None]
    ybp=ybp[:,:,1:-1]
    end_equality=jnp.ravel(yb[:,:,-1]-yb0[1:,:])
    collocation=jnp.ravel(jnp.swapaxes(ybp,1,2)-rhs_vec(tb,jnp.swapaxes(yb[:,:,:-1],1,2),u))
    return jnp.concatenate([end_equality, collocation])

jac=jax.jacobian(_interpolate,1)
dydq=jac(tend, flat_res)

In [134]:
dydu = dydq @ dqdu

In [152]:
dydu @ du

Array([0.00143183, 0.000646  ], dtype=float64)

In [153]:
resdu.y[:,-1]-res.y[:,-1]

array([-0.00142811, -0.00060177])