<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/RadauDDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import numpy as np
import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)
from plotly.subplots import make_subplots
from scipy.integrate import Radau
from functools import partial

In [132]:
def dde_Radau(rhs, trange, y0, tau, rtol=1e-6, atol=1e-8, hist_N=int(10000)):
    """Integrate a delay differential equation with SciPy's implicit Radau stepper.

    Every accepted step is appended to a fixed-size history (step end time,
    state, and the Radau dense-output coefficients ``Q``), so that ``rhs`` can
    look up past values of the solution while the integration is running.

    Parameters
    ----------
    rhs : callable
        ``rhs(t, y, yold) -> dy/dt``.  ``yold(s)`` returns the state (shape
        ``(neq,)``) at the earlier time ``s``.  ``rhs`` is jitted and
        differentiated with ``jax.jacobian``, so it must be JAX-traceable.
    trange : tuple
        ``(t0, tf)`` start and end time of the integration.
    y0 : array, shape (neq,)
        State at ``t0``; also used as the constant history for ``t < t0``.
    tau : float
        Maximum step size handed to the solver.  It should not exceed the
        smallest delay ``rhs`` looks back by; otherwise a Radau stage can ask
        for history newer than the last accepted step, for which only the
        last stored state is returned.
    rtol, atol : float
        Relative / absolute tolerances forwarded to ``scipy.integrate.Radau``.
    hist_N : int
        Capacity of the history buffers, i.e. the initial point plus at most
        ``hist_N - 1`` accepted steps.

    Returns
    -------
    callable
        Vectorised evaluator: an array of times of shape ``S`` maps to an
        array of states of shape ``S + (neq,)``.

    Raises
    ------
    RuntimeError
        If the solver reports a failed step, or if the integration needs more
        accepted steps than the history buffers can hold.
    """
    t0, tf = trange
    neq = len(y0)

    # Unused slots keep t = +inf so that searchsorted always resolves to the
    # most recent accepted step and ts[i+1] is a valid sentinel.
    ts = jnp.full(hist_N, jnp.inf).at[0].set(t0)
    ys = jnp.full((hist_N, neq), jnp.inf).at[0].set(y0)
    # Radau's dense output is a cubic per step: Q has shape (neq, 3).
    Qs = jnp.zeros((hist_N, neq, 3))

    @jax.jit
    def yold(t, ts, ys, Qs):
        """Evaluate the stored solution (or the pre-history) at scalar time ``t``."""

        def before0(t):
            # ys[0] is y0 stored in the buffer's dtype/shape, which keeps both
            # lax.cond branches type-compatible.
            return ys[0]

        def after0(t):
            # Step i+1 covers [ts[i], ts[i+1]]; its coefficients live in Qs[i+1].
            i = jnp.searchsorted(ts, t, side='right') - 1
            h = ts[i+1] - ts[i]
            x = (t - ts[i]) / h
            p = x ** np.arange(1, 4)
            # At/after tf there is no following step, so return the last state.
            return jnp.where(t < tf, Qs[i+1] @ p + ys[i], ys[i])

        # Anything earlier than the start of the integration is pre-history.
        return jax.lax.cond(t < t0, before0, after0, t)

    @jax.jit
    def rhs2(t, y, ts, ys, Qs):
        return rhs(t, y, lambda t: yold(t, ts, ys, Qs))

    jac = jax.jit(jax.jacobian(rhs2, 1))

    # The lambdas read ts/ys/Qs at call time, so they always see the history
    # as rebound by the loop below.
    solver = Radau(lambda t, y: rhs2(t, y, ts, ys, Qs), t0, y0, tf,
                   jac=lambda t, y: jac(t, y, ts, ys, Qs),
                   rtol=rtol, atol=atol, max_step=tau)

    nstep = 0
    while solver.t < tf:
        # JAX silently drops out-of-bounds .at[].set writes, so check capacity
        # explicitly instead of corrupting the history.
        if nstep + 1 >= hist_N:
            raise RuntimeError(
                f"history buffer full after {nstep} steps at t={solver.t}; "
                f"increase hist_N (currently {hist_N})")
        message = solver.step()
        if solver.status == 'failed':
            raise RuntimeError(f"Radau step failed at t={solver.t}: {message}")
        nstep += 1
        ts = ts.at[nstep].set(solver.t)
        ys = ys.at[nstep].set(solver.y)
        Qs = Qs.at[nstep].set(solver.sol.Q)

    return jnp.vectorize(lambda t: yold(t, ts, ys, Qs), signature='()->(n)')


In [170]:
# --- model parameters -------------------------------------------------------
Ac = 1.0       # divides the net flow (inflow - outflow) in rhs
alpha = 0.1    # coefficient of the h**0.5 outflow term
Kc = 1.1       # gain applied to the delayed set-point error
hsp = 1.5      # set point the delayed level is compared against
delay = 1.5    # how far back in time rhs looks up the level
h0 = 1.0       # initial level
tend = 50.0    # final time
# Base inflow: equals the outflow term alpha*h**0.5 evaluated at h = h0.
q0 = alpha * h0 ** 0.5


def rhs(t, y, yold):
    """Rate of change of the level ``y`` with feedback on the delayed level.

    ``yold(s)`` must return the state history at time ``s``; only component 0
    of it is used.  The inflow is ``q0`` plus ``Kc`` times the error between
    ``hsp`` and the level ``delay`` time units ago; the outflow is
    ``alpha * y**0.5``.  The net flow is divided by ``Ac``.
    """
    level = y
    delayed_level = yold(t - delay)[0]
    inflow = q0 + Kc * (hsp - delayed_level)
    outflow = alpha * (level ** 0.5)
    return (inflow - outflow) / Ac

In [171]:
# Solve from t = 0 to tend; the step cap tau is set to the feedback delay.
sol = dde_Radau(rhs, trange=(0, tend), y0=jnp.array([h0]), tau=delay)

In [172]:
# Sample the solution on a uniform grid and plot its first (only) component.
tplot = jnp.linspace(0, tend, 500)
hplot = sol(tplot)[:, 0]
fig = make_subplots()
fig.add_scatter(x=tplot, y=hplot)
# Label the axes so the figure can be read on its own; update_layout returns
# the figure, so as the last expression it also renders the plot.
fig.update_layout(width=800, height=600, template='plotly_dark',
                  xaxis_title='t', yaxis_title='h(t)')