<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/RadauDDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp, Radau
from functools import partial

In [203]:
def rhs(t, y, ts, tf, ys, Qs, delay=0.01):
    """Right-hand side of the test system y0' = y1, y1' = -y0 + y0(t - delay).

    Parameters
    ----------
    t : float
        Current time.
    y : array
        Current state; only y[0] and y[1] are read.
    ts, tf, ys, Qs
        Stored step history, forwarded unchanged to `interpolate` to evaluate
        the delayed state.
    delay : float, optional
        Delay applied to the first component. Defaults to 0.01, the value that
        was previously hard-coded.

    Returns
    -------
    jax.Array of shape (2,)
        [y[1], -y[0] + y_delayed[0]].
    """
    # The former debug print(ys[:10]) ran on every RHS evaluation and flooded
    # the cell output, so it has been removed.
    yold = interpolate(t - delay, ts, tf, ys, Qs)
    return jnp.array([y[1], -y[0] + yold[0]])

# Global parameters read by the jitted right-hand sides below:
#   tau   -- delay; rhs2 and rhs3 evaluate the stored history at t - tau
#   n     -- exponent applied to the delayed state in rhs2
#   beta  -- coefficient multiplying the delayed term in rhs2
#   gamma -- coefficient multiplying the current state y in rhs2
# NOTE: jax.jit captures globals as constants when a function is first traced,
# so re-run the function definitions after changing these values.
tau, n, beta, gamma = 20, 10, 0.25, 0.1

@jax.jit
def rhs2(t, y, ts, tf, ys, Qs):
    """Mackey–Glass right-hand side.

        dy/dt = beta * y(t - tau) / (1 + y(t - tau)**n) - gamma * y(t)

    Parameters
    ----------
    t : scalar
        Current time.
    y : array
        Current state.
    ts, tf, ys, Qs
        Stored step history, forwarded to `interpolate` to evaluate y(t - tau).

    Returns
    -------
    array with the same shape as y.
    """
    yold = interpolate(t - tau, ts, tf, ys, Qs)
    # The decay term -gamma*y belongs outside the fraction. It previously sat
    # inside the denominator, which is not the Mackey–Glass model and can make
    # the denominator vanish.
    return beta * yold / (1 + yold**n) - gamma * y

@jax.jit
def rhs3(t, y, ts, tf, ys, Qs):
    """Right-hand side dy/dt = sin(y(t - tau)).

    The delayed state is reconstructed from the stored step history
    (ts, tf, ys, Qs) by `interpolate`. The current state `y` is not used,
    so the Jacobian of this function with respect to y is zero.
    """
    return jnp.sin(interpolate(t - tau, ts, tf, ys, Qs))


# Jacobian of rhs3 with respect to its second argument (the state y), jitted.
jac = jax.jit(jax.jacobian(rhs3, 1))

# Integration window and initial state (a length-1 state vector).
t0 = 0.
tf = 500
y0 = jnp.array([0.00011])

# Pre-allocated history buffers that the integration loop fills one step at a time.
# Unused slots of ts/ys hold +inf, which keeps ts sorted so the searchsorted in
# `interpolate` only selects stored steps. Qs holds, per step, the dense-output
# coefficients (3 per state component), all zero until written.
size = 10000
neq = y0.size
ts = jnp.full(size, jnp.inf).at[0].set(t0)
ys = jnp.full((size, neq), jnp.inf).at[0].set(y0)
Qs = jnp.zeros((size, neq, 3))

In [204]:

def interpolate(t, ts, tf, ys, Qs):
    """Evaluate the stored solution history at time t.

    Parameters
    ----------
    t : scalar
        Query time (may be a traced value inside a jitted caller).
    ts : array, shape (size,)
        Stored step times. ts[0] is the initial time, and unused trailing
        slots are +inf so the array stays sorted.
    tf : scalar
        For t >= tf, the polynomial correction is skipped and the stored state
        at the step starting at or before t is returned.
    ys : array, shape (size, neq)
        State at each ts; ys[0] is the initial state and also serves as the
        constant history for t < ts[0].
    Qs : array, shape (size, neq, 3)
        Qs[k] holds the cubic dense-output coefficients of the step ending at
        ts[k] (as stored from Radau's ``solver.sol.Q``). On [ts[k-1], ts[k]]:
        y(t) = ys[k-1] + Qs[k] @ [x, x**2, x**3], with
        x = (t - ts[k-1]) / (ts[k] - ts[k-1]).

    Returns
    -------
    array, shape (neq,)
        ys[0] for t < ts[0]; otherwise the dense-output value. Past the last
        stored step, ts[i + 1] is +inf, so x == 0 and the last stored state is
        returned (constant extrapolation).
    """

    def before_start(t, ts, tf, ys, Qs):
        # Constant initial history. Read it from ys[0] (which equals y0) rather
        # than the global y0, so a jitted caller does not bake in a stale global.
        return ys[0]

    def after_start(t, ts, tf, ys, Qs):
        # Index of the last stored step time <= t.
        i = jnp.searchsorted(ts, t, side='right') - 1
        h = ts[i + 1] - ts[i]
        x = (t - ts[i]) / h
        p = x ** np.arange(1, 4)  # [x, x**2, x**3]
        return jnp.where(t < tf, Qs[i + 1] @ p + ys[i], ys[i])

    # Compare against the stored start time ts[0] instead of a literal 0.
    # Otherwise, with a nonzero t0, queries in 0 <= t < t0 would reach
    # after_start with index -1 and produce NaN.
    return jax.lax.cond(t < ts[0], before_start, after_start, t, ts, tf, ys, Qs)

In [None]:
# Integrate the delay equation step by step with Radau (method of steps).
# After each accepted step, the end time, state and dense-output coefficients
# are appended to ts/ys/Qs. The right-hand side reads them back through
# `interpolate` to evaluate the delayed state y(t - tau).

# Start from fresh history buffers so re-running this cell does not reuse
# steps stored by a previous run.
ts = jnp.full(size, jnp.inf).at[0].set(t0)
ys = jnp.full((size, neq), jnp.inf).at[0].set(y0)
Qs = jnp.zeros((size, neq, 3))

t = t0
nstep = 0
# The lambdas look up the module-level ts/ys/Qs each time they are called,
# so they always see the steps appended below.
# max_step=tau (set once here rather than on every iteration) keeps each stage
# time t_old + c*h (c <= 1) within tau of the step start. The delayed time
# therefore always falls in the already-stored history.
solver = Radau(lambda t, y: rhs3(t, y, ts, tf, ys, Qs), t0, y0, tf,
               jac=lambda t, y: jac(t, y, ts, tf, ys, Qs),
               max_step=tau, rtol=1e-9, atol=1e-12)
while t < tf:
    # JAX silently drops out-of-bounds .at[].set writes, so fail loudly instead.
    # The last slot is kept as a +inf sentinel for interpolate's ts[i + 1] lookup.
    if nstep + 1 >= size - 1:
        raise RuntimeError(f"history buffers are full after {nstep} steps; increase `size`")
    message = solver.step()
    if solver.status == 'failed':
        # Storing a failed step would duplicate a time entry (h == 0 -> NaN).
        raise RuntimeError(f"Radau step failed at t={solver.t}: {message}")
    nstep += 1
    t = solver.t
    ts = ts.at[nstep].set(solver.t)
    ys = ys.at[nstep].set(solver.y)
    Qs = Qs.at[nstep].set(solver.sol.Q)

# Final summary instead of one printed line per step.
nstep, solver.t, solver.y

In [184]:
# Interactive IPython help for the solver's step method. It only shows the
# docstring and does not advance the solver. This is a development aid;
# consider removing it from the final notebook.
solver.step?

In [149]:
# Spot-check the stored dense-output history at t = 0.2.
# NOTE(review): the saved output below has two components, but the current y0 has
# length 1. It appears to come from an earlier run of a two-equation system
# (e.g. `rhs`), so re-run to refresh it.
interpolate(.2,ts,tf,ys,Qs)

Array([0.19866933, 0.98006656], dtype=float64)

In [142]:
# Reference value sin(0.2). It is apparently compared by eye with the first
# component of the interpolate(.2, ...) output above.
np.sin(0.2)

0.19866933079506122

In [67]:
# First five stored step times: ts[0] is t0, and later entries are the solver's
# step end times.
ts[:5]

Array([0.        , 0.000999  , 0.01098901, 0.11088911, 0.68699441],      dtype=float64, weak_type=True)