<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/AutoStiff_NonStiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
jax.config.update('jax_enable_x64',True)
from scipy.integrate import solve_ivp
from plotly.subplots import make_subplots
from functools import partial

In [2]:
np.set_printoptions(precision=10,linewidth=200)
jnp.set_printoptions(precision=10, linewidth=200)

In [3]:
def vdp(t,v,u):
    return jnp.array([v[1], 100*(1 - v[0]**2)*v[1] - v[0]])



In [4]:
#Lobatto IIIC
a_lobatto=np.array([[1/6, -1/3, 1/6],
                    [1/6, 5/12, -1/12],
                    [1/6, 2/3, 1/6]])

c_lobatto=np.array([0, 0.5, 1])
b1_lobatto = np.array([1/6, 2/3, 1/6])
b2_lobatto = np.array([-0.5, 2, -0.5])

#https://arxiv.org/pdf/1306.2392.pdf, Hairer's embedded RADAU5

s6=np.sqrt(6)
a_radau= np.array([[(88-7*s6)/360,   (296-169*s6)/1800,      (-2+3*s6)/225],
                   [(296+169*s6)/1800,  (88+7*s6)/360,      (-2-3*s6)/225],
                   [(16-s6)/36,         (16+s6)/36,         1/9]])

c_radau=np.array([(4-s6)/10, (4+s6)/10, 1])
b1_radau = np.array([(16-s6)/36,         (16+s6)/36,         1/9])
radau_g = 1/8
b2_radau = np.r_[radau_g, b1_radau-radau_g*np.array([(2+3*s6)/6, (2-3*s6)/6, 1/3])]





a_DormandPrince = np.array([[0,     0,      0,      0,      0,      0,      0],
                            [1/5,   0,      0,      0,      0,      0,      0],
                            [3/40,  9/40,   0,      0,      0,      0,      0],
                            [44/45, -56/15, 32/9,   0,      0,      0,      0],
                            [19372/6561,    -25360/2187,    64448/6561,     -212/729,   0,      0, 0],
                            [9017/3168,     -355/33,        46732/5247,     49/176,     -5103/18656,    0, 0],
                            [35/384,        0,          500/1113,       125/192,        -2187/6784,     11/84, 0]])
c_DormandPrince = np.array([0, 1/5, 3/10, 4/5, 8/9, 1, 1.])
b1_DormandPrince = np.array([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0])
b2_DormandPrince = np.array([5179/57600, 0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40])

def get_explicit_step(rhs, u=None, a=a_DormandPrince, b1=b1_DormandPrince, b2=b2_DormandPrince, c=c_DormandPrince):

    def step(t0, y, h):
        tvec = t0+c*h
        Ny=y.size
        def f_scan(ik, act):
            i,k=ik
            a,c,t = act
            k=k.at[i].set(rhs(t,y+h*a@k, u))
            return (i+1,k), None

        (_,k),_= jax.lax.scan(f_scan, (0,jnp.zeros((c.shape[0], Ny))), (a, c, tvec))
        return y + h*b1@k, y + h*b2@k

    return jax.jit(step)

In [5]:
def get_implicit_step(rhs, u=None, a=a_radau, b1=b1_radau, b2=b2_radau, c=c_radau):
    Nstages = c.size
    rhs_vec=jax.jit(jax.vmap(rhs,in_axes=(0,0,None),out_axes=0))
    rhs_jac=jax.jacobian(rhs,1)

    def k_func(tvec, y,h,k):
        k=k.reshape(Nstages,y.size)
        yvec = y + h*a@k
        return (k-rhs_vec(tvec,yvec, u)).ravel()

    k_func_jac=jax.jacobian(k_func,3)

    def step(t0, y0, h):
        Ny=y0.size
        tvec = t0+c*h
        rhs0=rhs(t0,y0,u)
        k = jnp.tile(rhs0.reshape(1,-1),(Nstages,1)).ravel()
        f0=k_func(tvec,y0,h,k)

        def body_func_for(i, kf):
            k,f=kf
            f=k_func(tvec,y0,h,k)
            dk = jnp.linalg.solve(k_func_jac(tvec,y0,h,k), -f)
            k=k+dk
            return k,f

        def body_func_while(kfi):
            k,f,i=kfi
            f=k_func(tvec,y0,h,k)
            dk = jnp.linalg.solve(k_func_jac(tvec,y0,h,k), -f)
            k=k+dk
            return k,f,i+1

        def cond_func(kfi):
            k,f,i=kfi
            return jnp.logical_and(jnp.linalg.norm(f)>1e-12,i<50)

        # k,f,i=jax.lax.while_loop(cond_func, body_func_while, (k, f0, 0))  #does not work with reverse mode
        k,f=jax.lax.fori_loop(0,5, body_func_for, (k, f0))  #5 iterations should be enough
        k=k.reshape(Nstages,Ny)
        return y0+ h*(b1@k), y0 + h*(b2@ jnp.concatenate([rhs0.reshape(1,y0.size), k]))

    def step2(t0, y0, h):
        #Simplified Newton Method (less accurate)
        tvec = t0+c*h
        Ny=y0.size
        mat=jnp.kron(jnp.eye(Nstages),jnp.eye(Ny)) - h*jnp.kron(a_radau,rhs_jac(t0,y0,u))
        Y0=jnp.tile(y0,(Nstages,1))
        f0=Y0.reshape(-1,1)-jnp.tile(y0,(1,Nstages)).T-h*(a_radau @ rhs_vec(tvec,Y0,u)).reshape(-1,1)

        def body_func(Yfi):
            Y,f,i=Yfi
            f=Y.reshape(-1,1)-jnp.tile(y0,(1,Nstages)).T-h*(a_radau @ rhs_vec(tvec,Y,u)).reshape(-1,1)
            dY = jnp.linalg.solve(mat, -f)
            Y=Y+dY.reshape(Nstages,Ny)
            return Y,f,i+1

        def cond_func(Yfi):
            _,f,i=Yfi
            return jnp.logical_and(jnp.linalg.norm(f)>1e-12,i<50)

        Y,f,i=jax.lax.while_loop(cond_func, body_func, (Y0, f0, 0))

        return y0+ h*(b1@Y),y0 + h*(b2@ jnp.concatenate([rhs(t0,y0,u).reshape(1,Ny), Y]))

    return jax.jit(step)

In [6]:
def test(u):
    step=get_implicit_step(vdp,u)
    h=0.01
    y0=jnp.array([1.1, 0.1])
    return step(0,y0,h)

In [130]:
@partial(jax.jit, static_argnums=(0,))
def is_stiff(rhs, t0, y):
    Ny=y.size
    y, f_jvp = jax.linearize(lambda y: rhs(t0,y), y)
    rand_v0 = jnp.array(np.random.uniform(size=Ny))

    def f_scan(v0, _):
        v = f_jvp(v0)
        u = jnp.sum(v0*v)/jnp.sum(v0*v0)
        v = v/jnp.linalg.norm(v)
        return v, u

    _, us = jax.lax.scan(f_scan, rand_v0, xs=None, length=10)

    return jnp.any(us<-25)

def auto_ivp(rhs, trange, y0, u=None, h=1, atol=1e-6, rtol=1e-3):
    rhsty=jax.jit(lambda t,y: rhs(t,y,u))

    Ny=y0.size
    t0,tend= trange
    step_exp=get_explicit_step(rhs,u)
    step_imp=get_implicit_step(rhs,u)
    t=t0
    y=y0
    ts=[t0]
    ys=[y]
    stiffs=[]
    counts=[]
    hs=[]
    def h_ok(y1,y2,h):

        ratio = jnp.max(jnp.abs(y1-y2)/(atol+rtol*jnp.abs(y1)))
        if ratio>1:
            hnew = h/2
            return False, h/2, h/2
        if ratio<0.5:
            return True, h*1.5, h
        return True, h, h

    while t<tend:
        stiff=0
        count=0
        hnew= h
        if is_stiff(rhsty, t, y):
            stiff=1
            while True:
                count+=1
                yb1, yb2 = step_imp(t,y,hnew)
                hok, hnew, h = h_ok(yb1, yb2, hnew)
                if hok:
                    break
        else:
            while True:
                count+=1
                yb1, yb2 = step_exp(t,y,hnew)
                hok, hnew, h = h_ok(yb1, yb2, hnew)
                if hok:
                    break
        t=t+h
        y=yb1
        ts.append(t)
        ys.append(y)
        stiffs.append(stiff)
        counts.append(count)
        hs.append(h)
        h=hnew

    return jnp.array(ts), jnp.stack(ys,axis=1),jnp.array(stiffs),jnp.array(counts),jnp.array(hs)


In [140]:
y0=jnp.array([1.,0.])
u=jnp.array([1.,1.])
ts, ys, stiffs, counts, hs=auto_ivp(vdp,(0,3000.), y0, u=u, h=0.01, atol=1e-12,rtol=1e-10)

In [141]:
stiffs.sum()/stiffs.size

Array(0.8070432179, dtype=float64)

In [142]:
np.histogram(counts,bins=[0,1,2,3,4,5])

(array([    0, 65034,  9412,    23,    14]), array([0, 1, 2, 3, 4, 5]))

In [143]:
np.histogram(hs)

(array([66706,  3051,  1585,  1102,   864,   542,   333,   174,   100,    26]),
 array([2.4827023059e-05, 1.0787707276e-01, 2.1572931850e-01, 3.2358156424e-01, 4.3143380998e-01, 5.3928605572e-01, 6.4713830146e-01, 7.5499054720e-01, 8.6284279293e-01, 9.7069503867e-01,
        1.0785472844e+00]))

In [144]:

tend=3000
res=solve_ivp(jax.jit(vdp),(0,tend),y0,method='LSODA',args=(u,), dense_output=True,atol=1e-12,rtol=1e-10)

In [145]:
res.sol(ts)[:,-10:]

array([[-1.3852584286, -1.3831675457, -1.3815927517, -1.3792197423, -1.3780283067, -1.3762349068, -1.374884884 , -1.3728517526, -1.3713204478, -1.3690129585],
       [ 0.0150688518,  0.015141387 ,  0.0151965307,  0.0152804696,  0.0153230014,  0.0153875181,  0.0154364822,  0.0155108755,  0.0155674321,  0.0156535219]])

In [146]:
ys[:,-10:]

Array([[-1.3852584816, -1.383167599 , -1.3815928052, -1.3792197961, -1.3780283607, -1.376234961 , -1.3748849383, -1.3728518072, -1.3713205026, -1.3690130136],
       [ 0.0150688499,  0.0151413851,  0.0151965288,  0.0152804677,  0.0153229995,  0.0153875162,  0.0154364803,  0.0155108735,  0.01556743  ,  0.0156535198]], dtype=float64)

In [147]:
res.t.size

42206

In [148]:
ts.size

74484