<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/AutoStiff_NonStiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [180]:
import jax
import jax.numpy as jnp
import numpy as np
jax.config.update('jax_enable_x64',True)
from scipy.integrate import solve_ivp
from plotly.subplots import make_subplots
from functools import partial

In [2]:
np.set_printoptions(precision=10,linewidth=200)
jnp.set_printoptions(precision=10, linewidth=200)

In [385]:
def vdp(t,v,u):
    return jnp.array([v[1], 100*(1 - v[0]**2)*v[1] - v[0]])



In [386]:
#Lobatto IIIC
a_lobatto=np.array([[1/6, -1/3, 1/6],
                    [1/6, 5/12, -1/12],
                    [1/6, 2/3, 1/6]])

c_lobatto=np.array([0, 0.5, 1])
b1_lobatto = np.array([1/6, 2/3, 1/6])
b2_lobatto = np.array([-0.5, 2, -0.5])

#https://arxiv.org/pdf/1306.2392.pdf, Hairer's embedded RADAU5

s6=np.sqrt(6)
a_radau= np.array([[(88-7*s6)/360,   (296-169*s6)/1800,      (-2+3*s6)/225],
                   [(296+169*s6)/1800,  (88+7*s6)/360,      (-2-3*s6)/225],
                   [(16-s6)/36,         (16+s6)/36,         1/9]])

c_radau=np.array([(4-s6)/10, (4+s6)/10, 1])
b1_radau = np.array([(16-s6)/36,         (16+s6)/36,         1/9])
radau_g = 1/8
b2_radau = np.r_[radau_g, b1_radau-radau_g*np.array([(2+3*s6)/6, (2-3*s6)/6, 1/3])]





a_DormandPrince = np.array([[0,     0,      0,      0,      0,      0,      0],
                            [1/5,   0,      0,      0,      0,      0,      0],
                            [3/40,  9/40,   0,      0,      0,      0,      0],
                            [44/45, -56/15, 32/9,   0,      0,      0,      0],
                            [19372/6561,    -25360/2187,    64448/6561,     -212/729,   0,      0, 0],
                            [9017/3168,     -355/33,        46732/5247,     49/176,     -5103/18656,    0, 0],
                            [35/384,        0,          500/1113,       125/192,        -2187/6784,     11/84, 0]])
c_DormandPrince = np.array([0, 1/5, 3/10, 4/5, 8/9, 1, 1.])
b1_DormandPrince = np.array([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0])
b2_DormandPrince = np.array([5179/57600, 0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40])

def get_explicit_step(rhs, u=None, a=a_DormandPrince, b1=b1_DormandPrince, b2=b2_DormandPrince, c=c_DormandPrince):

    def step(t0, y0, h):
        tvec = t0+c*h
        Ny=y0.size
        def f_scan(ik, act):
            i,k=ik
            a,c,t = act
            k=k.at[i].set(rhs(t,y0+h*a@k, u))
            return (i+1,k), None

        (_,k),_= jax.lax.scan(f_scan, (0,jnp.zeros((c.shape[0], Ny))), (a, c, tvec))
        return y0 + h*b1@k, y0 + h*b2@k

    return jax.jit(step)

In [387]:
def get_implicit_step(rhs, u=None, a=a_radau, b1=b1_radau, b2=b2_radau, c=c_radau):
    Nstages = c.size
    rhs_vec=jax.jit(jax.vmap(rhs,in_axes=(0,0,None),out_axes=0))
    rhs_jac=jax.jacobian(rhs,1)

    def k_func(tvec, y,h,k):
        k=k.reshape(Nstages,y.size)
        yvec = y + h*a@k
        return (k-rhs_vec(tvec,yvec, u)).ravel()

    k_func_jac=jax.jacobian(k_func,3)

    def step(t0, y0, h):
        Ny=y0.size
        tvec = t0+c*h
        rhs0=rhs(t0,y0,u)
        k = jnp.tile(rhs0.reshape(1,-1),(Nstages,1)).ravel()
        f0=k_func(tvec,y0,h,k)

        def body_func_for(i, kf):
            k,f=kf
            f=k_func(tvec,y0,h,k)
            dk = jnp.linalg.solve(k_func_jac(tvec,y0,h,k), -f)
            k=k+dk
            return k,f

        def body_func_while(kfi):
            k,f,i=kfi
            f=k_func(tvec,y0,h,k)
            dk = jnp.linalg.solve(k_func_jac(tvec,y0,h,k), -f)
            k=k+dk
            return k,f,i+1

        def cond_func(kfi):
            k,f,i=kfi
            return jnp.logical_and(jnp.linalg.norm(f)>1e-12,i<50)

        # k,f,i=jax.lax.while_loop(cond_func, body_func_while, (k, f0, 0))  #does not work with reverse mode
        k,f=jax.lax.fori_loop(0,5, body_func_for, (k, f0))  #5 iterations should be enough
        k=k.reshape(Nstages,Ny)
        return y0+ h*(b1@k), y0 + h*(b2@ jnp.concatenate([rhs0.reshape(1,y0.size), k]))

    def step2(t0, y0, h):
        #Simplified Newton Method (less accurate)
        tvec = t0+c*h
        Ny=y0.size
        mat=jnp.kron(jnp.eye(Nstages),jnp.eye(Ny)) - h*jnp.kron(a_radau,rhs_jac(t0,y0,u))
        Y0=jnp.tile(y0,(Nstages,1))
        f0=Y0.reshape(-1,1)-jnp.tile(y0,(1,Nstages)).T-h*(a_radau @ rhs_vec(tvec,Y0,u)).reshape(-1,1)

        def body_func(Yfi):
            Y,f,i=Yfi
            f=Y.reshape(-1,1)-jnp.tile(y0,(1,Nstages)).T-h*(a_radau @ rhs_vec(tvec,Y,u)).reshape(-1,1)
            dY = jnp.linalg.solve(mat, -f)
            Y=Y+dY.reshape(Nstages,Ny)
            return Y,f,i+1

        def cond_func(Yfi):
            _,f,i=Yfi
            return jnp.logical_and(jnp.linalg.norm(f)>1e-12,i<50)

        Y,f,i=jax.lax.while_loop(cond_func, body_func, (Y0, f0, 0))

        return y0+ h*(b1@Y),y0 + h*(b2@ jnp.concatenate([rhs(t0,y0,u).reshape(1,Ny), Y]))

    return jax.jit(step)

In [388]:
def test(u):
    step=get_implicit_step(vdp,u)
    h=0.01
    y0=jnp.array([1.1, 0.1])
    return step(0,y0,h)

In [448]:
@partial(jax.jit, static_argnums=(0,))
def is_stiff(rhs, t0, y0):
    Ny=y0.size
    y, f_jvp = jax.linearize(lambda y: rhs(t0,y), y0)
    rand_v0 = jnp.array(np.random.uniform(size=Ny))

    def f_scan(v0, _):
        v = f_jvp(v0)
        u = jnp.sum(v0*v)/jnp.sum(v0*v0)
        v = v/jnp.linalg.norm(v)
        return v, u

    _, us = jax.lax.scan(f_scan, rand_v0, xs=None, length=10)

    return jnp.any(us<-25)

def auto_ivp(rhs, trange, y0, u=None, h=1, atol=1e-6, rtol=1e-3):
    rhsty=jax.jit(lambda t,y: rhs(t,y,u))

    Ny=y0.size
    t0,tend= trange
    step_exp=get_explicit_step(rhs,u)
    step_imp=get_implicit_step(rhs,u)
    t=t0
    y=y0
    ts=[t0]
    ys=[y0]
    stiffs=[]
    counts=[]
    hs=[]
    def h_ok(y1,y2,h):
        ratio = jnp.max(jnp.abs(y1-y2)/(atol+rtol*jnp.abs(y1)))
        if ratio>1:
            h=h/2
            return False, h
        if ratio<0.5:
            h=h*1.5
        return True,h

    while t<tend:
        stiff=0
        count=0
        if is_stiff(rhsty, t, y):
            stiff=1
            while True:
                count+=1
                yb1, yb2 = step_imp(t,y,h)
                hok, h = h_ok(yb1, yb2, h)
                if hok:
                    break
        else:
            while True:
                count+=1
                yb1, yb2 = step_exp(t,y,h)
                hok, h = h_ok(yb1, yb2, h)
                if hok:
                    break
        t=t+h
        y=yb1
        ts.append(t)
        ys.append(y)
        stiffs.append(stiff)
        counts.append(count)
        hs.append(h)

    return jnp.array(ts), jnp.stack(ys,axis=1),jnp.array(stiffs),jnp.array(counts),jnp.array(hs)


In [450]:
y0=jnp.array([1.,0.])
ts, ys, stiffs, counts, hs=auto_ivp(vdp,(0,10), y0, u=u, h=0.0001, atol=1e-12,rtol=1e-10)

In [451]:
counts[:5]

Array([1, 1, 1, 1, 1], dtype=int64)

In [452]:
ts.size

1413

In [453]:
np.histogram(hs)

(array([1398,    2,    1,    1,    1,    2,    1,    3,    2,    1]),
 array([4.1881212292e-05, 1.1926071003e-01, 2.3847953885e-01, 3.5769836767e-01, 4.7691719649e-01, 5.9613602531e-01, 7.1535485413e-01, 8.3457368295e-01, 9.5379251177e-01, 1.0730113406e+00,
        1.1922301694e+00]))

In [454]:

tend=12
res=solve_ivp(jax.jit(vdp),(0,tend),y0,method='LSODA',args=(u,), t_eval=ts)
sol=res.sol

In [455]:
res.t.size

1413

In [456]:
ts.size

1413

In [457]:
res.y[:,:5]

array([[ 1.0000000000e+00,  9.9999998775e-01,  9.9999992869e-01,  9.9999974517e-01,  9.9999925632e-01],
       [ 0.0000000000e+00, -1.4999999846e-04, -3.7499998406e-04, -7.1249993245e-04, -1.2187497325e-03]])

In [458]:
ys[:,:5]

Array([[ 1.0000000000e+00,  9.9999999500e-01,  9.9999996875e-01,  9.9999988719e-01,  9.9999966992e-01],
       [ 0.0000000000e+00, -9.9999999836e-05, -2.4999999749e-04, -4.7499998341e-04, -8.1249992150e-04]], dtype=float64)

In [426]:
ts

Array([0.0000000000e+00, 2.3437500000e-03, 4.1015625000e-03, ..., 8.3604719483e+00, 9.4530392138e+00, 1.0272464663e+01], dtype=float64)

In [427]:
res.t

array([0.0000000000e+00, 2.3437500000e-03, 4.1015625000e-03, ..., 8.3604719483e+00, 9.4530392138e+00, 1.0272464663e+01])