<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/RadauDDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import solve_ivp, Radau
from functools import partial

In [None]:
# Delay-equation parameters shared by rhs2/rhs3 and the Radau loop below:
# delay tau, Hill exponent n, production rate beta, decay rate gamma.
tau, n, beta, gamma = 20, 10, 0.25, 0.1

@jax.jit
def rhs2(t, y, ts, tf, ys, Qs):
    """Right-hand side of the Mackey-Glass delay equation.

        y'(t) = beta * y(t - tau) / (1 + y(t - tau)**n) - gamma * y(t)

    The delayed state y(t - tau) is read from the recorded step history via
    ``interpolate`` (step times ``ts``, states ``ys``, Radau dense-output
    coefficients ``Qs``; ``tf`` is the final integration time).
    """
    yold = interpolate(t - tau, ts, tf, ys, Qs)
    # The decay term -gamma*y is subtracted outside the production fraction,
    # not folded into its denominator.
    return beta * yold / (1 + yold**n) - gamma * y

@jax.jit
def rhs3(t, y, ts, tf, ys, Qs):
    """Right-hand side of the test delay equation y'(t) = sin(y(t - tau)).

    ``y`` itself is unused: the derivative depends only on the delayed state,
    which is looked up in the recorded history (``ts``, ``ys``, ``Qs``, with
    final time ``tf``) through ``interpolate``.
    """
    return jnp.sin(interpolate(t - tau, ts, tf, ys, Qs))



# Initial condition and integration window for the Radau loop.
t0 = 0.
y0 = jnp.array([0.00010001])
tf = 300

# Pre-allocated step history passed to rhs3/interpolate: step times (unused
# slots are +inf), states, and Radau dense-output coefficients (3 per component).
size = 10000
neq = y0.size
ts = jnp.full(size, jnp.inf).at[0].set(t0)
ys = jnp.full((size, neq), jnp.inf).at[0].set(y0)
Qs = jnp.zeros((size, neq, 3))

In [None]:
@jax.jit
def interpolate(t, ts, tf, ys, Qs):
    """Evaluate the recorded Radau solution at time ``t``.

    Parameters
    ----------
    t : scalar query time.
    ts : (size,) step times; ts[0] is the initial time and unused slots hold +inf.
    tf : final integration time.
    ys : (size, neq) states at ``ts``; ys[0] is the initial state.
    Qs : (size, neq, 3) Radau dense-output coefficients; Qs[k] describes the
        step ending at ts[k] (the layout built by the Radau loop in this notebook).

    Returns
    -------
    The state at ``t``:
    - ys[0] (constant initial history) for t < ts[0];
    - the cubic Radau dense output on the step containing t for ts[0] <= t < tf;
    - the last stored state at or before t for t >= tf.
    Past the last recorded step, ts[i+1] is +inf, so x == 0 and ys[i] is held.
    """
    def before_start(t, ts, tf, ys, Qs):
        # History before the initial time is the initial state, stored at ys[0].
        return ys[0]

    def after_start(t, ts, tf, ys, Qs):
        i = jnp.searchsorted(ts, t, side='right') - 1
        h = ts[i + 1] - ts[i]
        x = (t - ts[i]) / h
        p = x ** np.arange(1, 4)  # [x, x**2, x**3], as in scipy's RadauDenseOutput
        return jnp.where(t < tf, Qs[i + 1] @ p + ys[i], ys[i])

    # Compare against ts[0] (the initial time) rather than 0, so a nonzero
    # start time does not index ts[-1] for t in [0, t0).
    return jax.lax.cond(t < ts[0], before_start, after_start, t, ts, tf, ys, Qs)

In [None]:
def dde_Radau(rhs, trange, y0, rtol=1e-6, atol=1e-8, hist_N=int(10000), max_step=None):
    """Build a Radau IIA (scipy) solver for a delay differential equation.

    Parameters
    ----------
    rhs : callable ``rhs(t, y, yold) -> dy/dt``. ``yold(s)`` returns the
        solution at an earlier time ``s`` (the constant ``y0`` for ``s < t0``).
        ``rhs`` may be jitted with ``yold`` as a static argument.
    trange : ``(t0, tf)`` integration interval (forward in time).
    y0 : initial state; also the constant history before ``t0``.
    rtol, atol : tolerances forwarded to ``scipy.integrate.Radau``.
    hist_N : capacity of the step-history buffers.
    max_step : largest step Radau may take. Keep it no larger than the
        smallest delay so every ``yold`` query lands in recorded history.
        ``None`` uses the notebook-level ``tau``.

    Returns
    -------
    solve : zero-argument callable. ``solve()`` integrates over ``trange`` and
        returns ``sol(t)``: ``y0`` for t < t0, Radau cubic dense output between
        recorded steps, and the last stored state for t >= tf.

    Raises
    ------
    RuntimeError (from ``solve()``) if a Radau step fails or the history
    buffer fills up.
    """
    t0, tf = trange
    neq = len(y0)

    def history(t, ts, ys, Qs):
        # Piecewise-cubic Radau dense output over the recorded steps.
        def before_start(t):
            return jnp.asarray(y0, dtype=ys.dtype)

        def after_start(t):
            i = jnp.searchsorted(ts, t, side='right') - 1
            # Unused slots of ts are +inf, so past the last step x == 0 and
            # the last stored state is held.
            h = ts[i + 1] - ts[i]
            x = (t - ts[i]) / h
            p = x ** np.arange(1, 4)
            return jnp.where(t < tf, Qs[i + 1] @ p + ys[i], ys[i])

        return jax.lax.cond(t < t0, before_start, after_start, t)

    # The history arrays are explicit (traced) arguments here. If a function
    # closing over them were passed directly as rhs's static `yold` argument,
    # jit would cache it by identity and bake the first-seen arrays into the
    # compiled trace, so later steps would keep reading the initial history.
    @jax.jit
    def f(t, y, ts, ys, Qs):
        return rhs(t, y, lambda s: history(s, ts, ys, Qs))

    jac = jax.jit(jax.jacobian(f, argnums=1))

    def solve():
        ts = jnp.full(hist_N, jnp.inf).at[0].set(t0)
        ys = jnp.full((hist_N, neq), jnp.inf).at[0].set(y0)
        Qs = jnp.zeros((hist_N, neq, 3))

        # The lambdas look up ts/ys/Qs when called, so every evaluation sees
        # the history recorded up to the current step.
        solver = Radau(lambda t, y: f(t, y, ts, ys, Qs), t0, y0, tf,
                       jac=lambda t, y: jac(t, y, ts, ys, Qs),
                       rtol=rtol, atol=atol)
        # Set after construction, so the constructor's initial-step choice is
        # not limited by it.
        solver.max_step = tau if max_step is None else max_step

        nstep = 0
        while solver.t < tf:
            message = solver.step()
            if solver.status == 'failed':
                raise RuntimeError(f'Radau step failed at t={solver.t}: {message}')
            nstep += 1
            # Keep the final slot as a +inf sentinel for the interpolation.
            if nstep >= hist_N - 1:
                raise RuntimeError('DDE history buffer is full; increase hist_N')
            ts = ts.at[nstep].set(solver.t)
            ys = ys.at[nstep].set(solver.y)
            Qs = Qs.at[nstep].set(solver.sol.Q)

        def sol(t):
            return history(t, ts, ys, Qs)

        return sol

    return solve

In [None]:
@partial(jax.jit, static_argnums=2)
def rhs(t, y, yold):
    """Example DDE system for dde_Radau.

        d/dt y[0] = y[1]
        d/dt y[1] = yold(t - 0.01)[0] - y[0]

    ``yold`` is a callable returning the state at an earlier time; it is a
    static (hashable) argument of the jitted function.
    """
    position, velocity = y[0], y[1]
    lagged_position = yold(t - 0.01)[0]
    return jnp.array([velocity, lagged_position - position])

In [None]:
# Configure a solver for the example system on t in [0, 10], starting from [1, 1].
solver = dde_Radau(rhs, trange=(0, 10), y0=jnp.array([1., 1.]))

In [None]:
# Run the integration; the result is a callable mapping a time t to the state y(t).
sol = solver()

In [None]:
# Eleven evenly spaced sample times on [0, 10], plus a batched evaluator of sol.
tplot = jnp.linspace(0, 10, num=11)
sol_vec = jax.vmap(sol, in_axes=0, out_axes=0)

In [None]:
# Evaluate the dense solution at t = 0, 1, ..., 10 (one row per time).
# NOTE(review): the saved output equals [1 + sin(t), cos(t)], i.e. the solution
# this system has if yold always returns y0 = [1, 1] (then y[1]' = 1 - y[0]).
# With a history that tracks the solution, y(t - 0.01) ~ y(t) and y[0] should
# instead grow roughly like 1 + t -- re-run and compare before trusting it.
sol_vec(tplot)

Array([[ 1.        ,  1.        ],
       [ 1.84147106,  0.54030235],
       [ 1.90929756, -0.41614691],
       [ 1.14112001, -0.98999254],
       [ 0.24319754, -0.65364359],
       [ 0.04107574,  0.28366218],
       [ 0.72058449,  0.96017035],
       [ 1.65698658,  0.75390223],
       [ 1.98935812, -0.14550002],
       [ 1.41211843, -0.91113014],
       [ 0.4559789 , -0.83907151]], dtype=float64)

In [None]:
# Integrate y' = sin(y(t - tau)) on [t0, tf] with scipy's Radau, recording each
# accepted step (time, state, dense-output coefficients) into the history
# arrays that rhs3 reads through interpolate.

# Start from a fresh history so re-running this cell does not reuse entries
# written by a previous run.
ts = jnp.full(size, jnp.inf).at[0].set(t0)
ys = jnp.full((size, neq), jnp.inf).at[0].set(y0)
Qs = jnp.zeros((size, neq, 3))

t = t0
nstep = 0
jac = jax.jit(jax.jacobian(rhs3, 1))
# The lambdas read the module-level ts/ys/Qs at call time, so each Radau
# evaluation sees the history recorded so far (they are traced jit arguments).
solver = Radau(lambda t, y: rhs3(t, y, ts, tf, ys, Qs), t0, y0, tf,
               jac=lambda t, y: jac(t, y, ts, tf, ys, Qs), rtol=1e-6, atol=1e-8)
# Never step further than one delay, so t - tau stays inside recorded history.
solver.max_step = tau
while t < tf:
    message = solver.step()
    if solver.status == 'failed':
        raise RuntimeError(f'Radau step failed at t={solver.t}: {message}')
    nstep += 1
    # Keep the final slot as a +inf sentinel for interpolate.
    if nstep >= size - 1:
        raise RuntimeError('history buffer is full; increase size')
    t = solver.t
    ts = ts.at[nstep].set(solver.t)
    ys = ys.at[nstep].set(solver.y)
    Qs = Qs.at[nstep].set(solver.sol.Q)
print(nstep, solver.t, solver.y)