<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/AutoStiff_NonStiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
jax.config.update('jax_enable_x64',True)
from scipy.integrate import solve_ivp
from plotly.subplots import make_subplots

In [2]:
# Shared array-printing configuration: ten digits of precision and
# 200-character rows, applied through both the NumPy and JAX entry points.
_print_opts = {'precision': 10, 'linewidth': 200}
np.set_printoptions(**_print_opts)
jnp.set_printoptions(**_print_opts)

In [95]:
def vdp(t,v,u):
    """Right-hand side of a parameterised Van der Pol-type oscillator.

    Args:
        t: time (unused; the system is autonomous).
        v: state, indexed as v[0] (position) and v[1] (velocity).
        u: parameters; u[0] scales the first equation and u[1] scales
           the nonlinear damping term (which carries a fixed factor 100).

    Returns:
        A length-2 jnp array with the time derivatives of v[0] and v[1].
    """
    position = v[0]
    velocity = v[1]
    d_position = u[0]*velocity
    d_velocity = u[1]*100*(1 - position**2)*velocity - position
    return jnp.array([d_position, d_velocity])



In [96]:
# Lobatto IIIC, three stages: coefficient matrix (a), nodes (c) and weights (b1).
# The last row of a_lobatto is identical to b1_lobatto.
a_lobatto=np.array([[1/6, -1/3, 1/6],
                    [1/6, 5/12, -1/12],
                    [1/6, 2/3, 1/6]])

c_lobatto=np.array([0, 0.5, 1])
b1_lobatto = np.array([1/6, 2/3, 1/6])
# Second weight vector (entries sum to 1); presumably the embedded lower-order
# weights used for error estimation -- TODO confirm against a reference.
# NOTE(review): it has 3 entries, whereas get_implicit_step concatenates an extra
# leading row before applying b2, so this b2 does not fit that function as-is.
b2_lobatto = np.array([-0.5, 2, -0.5])

# Three-stage Radau tableau with an embedded second solution.
# Source: https://arxiv.org/pdf/1306.2392.pdf, Hairer's embedded RADAU5

s6=np.sqrt(6)
# Stage coefficient matrix; its last row is identical to b1_radau.
a_radau= np.array([[(88-7*s6)/360,   (296-169*s6)/1800,      (-2+3*s6)/225],
                   [(296+169*s6)/1800,  (88+7*s6)/360,      (-2-3*s6)/225],
                   [(16-s6)/36,         (16+s6)/36,         1/9]])

# Stage nodes; the final node sits at the end of the step (c = 1).
c_radau=np.array([(4-s6)/10, (4+s6)/10, 1])
b1_radau = np.array([(16-s6)/36,         (16+s6)/36,         1/9])
# Free parameter of the embedded formula (weight on the derivative at the step start).
# NOTE(review): the value 1/8 is a choice made here -- confirm against the reference.
radau_g = 1/8
# Embedded weights, FOUR entries: [radau_g, adjusted b1]. The leading entry is
# applied to rhs(t0, y0) in get_implicit_step; the four entries sum to 1.
b2_radau = np.r_[radau_g, b1_radau-radau_g*np.array([(2+3*s6)/6, (2-3*s6)/6, 1/3])]





# Seven-stage explicit tableau (strictly lower-triangular a) named after Dormand-Prince.
# The coefficients appear to match the standard Dormand-Prince RK5(4) pair --
# verify against a reference before changing any entry.
a_DormandPrince = np.array([[0,     0,      0,      0,      0,      0,      0],
                            [1/5,   0,      0,      0,      0,      0,      0],
                            [3/40,  9/40,   0,      0,      0,      0,      0],
                            [44/45, -56/15, 32/9,   0,      0,      0,      0],
                            [19372/6561,    -25360/2187,    64448/6561,     -212/729,   0,      0, 0],
                            [9017/3168,     -355/33,        46732/5247,     49/176,     -5103/18656,    0, 0],
                            [35/384,        0,          500/1113,       125/192,        -2187/6784,     11/84, 0]])
# Stage nodes; the last two stages are both evaluated at the end of the step.
c_DormandPrince = np.array([0, 1/5, 3/10, 4/5, 8/9, 1, 1.])
# Primary weights: identical to the last row of a_DormandPrince.
b1_DormandPrince = np.array([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0])
# Second (embedded) weight vector; get_explicit_step returns the b1 and b2 solutions side by side.
b2_DormandPrince = np.array([5179/57600, 0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40])

def get_explicit_step(rhs, u=None, a=a_DormandPrince, b1=b1_DormandPrince, b2=b2_DormandPrince, c=c_DormandPrince):
    """Build a jitted single-step function for an explicit Runge-Kutta pair.

    Args:
        rhs: callable rhs(t, y, u) returning dy/dt with the same shape as y.
        u: extra parameters forwarded unchanged to rhs.
        a, b1, b2, c: tableau (stage matrix, two weight vectors, nodes).

    Returns:
        A jitted step(t0, y0, h) returning a pair (y_b1, y_b2): the state
        after one step of size h computed with the b1 and the b2 weights.
    """
    n_stages = c.shape[0]

    def step(t0, y0, h):
        stage_times = t0 + c*h

        def fill_stage(carry, stage):
            # carry = (index of the stage being filled, slope table so far)
            idx, slopes = carry
            a_row, _, t_stage = stage
            # Rows of the slope table that are not filled yet are still zero,
            # so the full-row product only picks up earlier stages.
            slope = rhs(t_stage, y0 + h*a_row@slopes, u)
            return (idx + 1, slopes.at[idx].set(slope)), None

        start = (0, jnp.zeros((n_stages, y0.size)))
        (_, slopes), _ = jax.lax.scan(fill_stage, start, (a, c, stage_times))
        return y0 + h*b1@slopes, y0 + h*b2@slopes

    return jax.jit(step)

In [147]:
def get_implicit_step(rhs, u=None, a=a_radau, b1=b1_radau, b2=b2_radau, c=c_radau,
                      newton_iters=5, simplified=False):
    """Build a jitted single-step function for an implicit Runge-Kutta scheme.

    Args:
        rhs: callable rhs(t, y, u) returning dy/dt with the same shape as y.
        u: extra parameters forwarded unchanged to rhs.
        a, b1, c: tableau (stage matrix, primary weights, nodes).
        b2: embedded weights with len(c) + 1 entries; the leading entry
            multiplies rhs(t0, y0) and the rest multiply the stage derivatives.
        newton_iters: fixed number of full-Newton iterations (default path).
            Must be a Python int so the loop stays reverse-mode differentiable.
        simplified: if True, return the simplified-Newton stepper instead.
            It freezes the Jacobian at (t0, y0) and iterates with a
            while_loop to a residual norm of 1e-12 (at most 50 iterations);
            a while_loop cannot be differentiated in reverse mode.

    Returns:
        A jitted step(t0, y0, h) returning (y_b1, y_b2): the state after one
        step of size h using the primary and the embedded weights.
    """
    Nstages = c.size
    rhs_vec=jax.jit(jax.vmap(rhs,in_axes=(0,0,None),out_axes=0))
    rhs_jac=jax.jacobian(rhs,1)

    def k_func(tvec, y,h,k):
        # Residual of the stage equations k_i = rhs(t_i, y + h*sum_j a_ij k_j),
        # with k passed flat and reshaped to (Nstages, Ny).
        k=k.reshape(Nstages,y.size)
        yvec = y + h*a@k
        return (k-rhs_vec(tvec,yvec, u)).ravel()

    k_func_jac=jax.jacobian(k_func,3)

    def step(t0, y0, h):
        # Full Newton on the stage derivatives with a fixed iteration count.
        # A tolerance-based while_loop is avoided on purpose: it does not
        # work with reverse-mode differentiation.
        Ny=y0.size
        tvec = t0+c*h
        rhs0=rhs(t0,y0,u)
        # Initial guess: every stage derivative equals the slope at the step start.
        k = jnp.tile(rhs0.reshape(1,-1),(Nstages,1)).ravel()

        def newton_update(i, k):
            f=k_func(tvec,y0,h,k)
            dk = jnp.linalg.solve(k_func_jac(tvec,y0,h,k), -f)
            return k+dk

        k=jax.lax.fori_loop(0,newton_iters, newton_update, k)
        k=k.reshape(Nstages,Ny)
        return y0+ h*(b1@k), y0 + h*(b2@ jnp.concatenate([rhs0.reshape(1,y0.size), k]))

    def step2(t0, y0, h):
        # Simplified Newton on the stage VALUES Y_i = y0 + h*sum_j a_ij rhs(t_j, Y_j),
        # using one Jacobian frozen at (t0, y0) (cheaper, converges more slowly).
        tvec = t0+c*h
        Ny=y0.size
        rhs0=rhs(t0,y0,u)
        # Uses the tableau passed in as `a`, not the module-level Radau matrix.
        mat=jnp.eye(Nstages*Ny) - h*jnp.kron(a,rhs_jac(t0,y0,u))
        y0_col=jnp.tile(y0,(1,Nstages)).T
        Y0=jnp.tile(y0,(Nstages,1))

        def residual(Y):
            return Y.reshape(-1,1)-y0_col-h*(a @ rhs_vec(tvec,Y,u)).reshape(-1,1)

        def body_func(Yfi):
            Y,_,i=Yfi
            f=residual(Y)
            dY = jnp.linalg.solve(mat, -f)
            return Y+dY.reshape(Nstages,Ny),f,i+1

        def cond_func(Yfi):
            _,f,i=Yfi
            return jnp.logical_and(jnp.linalg.norm(f)>1e-12,i<50)

        Y,_,_=jax.lax.while_loop(cond_func, body_func, (Y0, residual(Y0), 0))

        # The weights act on stage derivatives, so evaluate rhs at the
        # converged stage values before combining.
        K=rhs_vec(tvec,Y,u)
        return y0+ h*(b1@K),y0 + h*(b2@ jnp.concatenate([rhs0.reshape(1,Ny), K]))

    return jax.jit(step2 if simplified else step)

In [148]:
def test(u):
    """Take one implicit step of the vdp system for parameter vector u.

    The step starts at t = 0 from the state [1.1, 0.1] with step size 0.01
    and returns the pair of solutions produced by the implicit stepper.
    """
    start_state = jnp.array([1.1, 0.1])
    step_size = 0.01
    implicit_step = get_implicit_step(vdp, u)
    return implicit_step(0, start_state, step_size)

In [149]:
# Parameter vector handed to vdp as `u`; both multipliers are set to 1 here.
u=jnp.array([1.,1.])
# One implicit step from the fixed start state inside test(); displays both returned solutions.
test(u)

(Array([1.1008503481, 0.0710586368], dtype=float64),
 Array([1.1008503705, 0.0710583774], dtype=float64))

In [151]:
# Reverse-mode Jacobian of the first solution returned by test() with respect to
# the parameter vector u, i.e. differentiation through the implicit step.
test_jac=jax.jit(jax.jacrev(lambda u: test(u)[0]))
test_jac(u)

Array([[ 8.5006043573e-04, -8.8140420406e-05],
       [-7.7978785122e-05, -1.6069119335e-02]], dtype=float64)

In [152]:
# Same Jacobian evaluated at doubled parameters (reuses the already-jitted function).
test_jac(u*2)

Array([[ 0.0007666805, -0.0001537651],
       [-0.0001144488, -0.0129675765]], dtype=float64)

In [123]:
# Compare a single step of the implicit and the explicit integrator with a
# tight-tolerance SciPy reference at t = h.
step_exp=get_explicit_step(vdp,u)
step_imp=get_implicit_step(vdp,u)
h=0.01
y0=jnp.array([1.1, 0.1])
print(step_imp(0,y0,h))
print(step_exp(0,y0,h))
# Compute the reference here rather than reading `sol`, which is only created
# by a cell further down and is therefore undefined on a fresh top-to-bottom run.
# vdp takes the parameter vector as a third argument, hence args=(u,).
ref=solve_ivp(vdp,(0,h),y0,method='Radau',args=(u,),atol=1e-12,rtol=1e-10)
print(ref.y[:,-1])

(Array([1.1008503481, 0.0710586368], dtype=float64), Array([1.1008503705, 0.0710583774], dtype=float64))
(Array([1.100850348 , 0.0710586379], dtype=float64), Array([1.1008503502, 0.0710585885], dtype=float64))
[1.1008503481 0.0710586357]


In [7]:

# Tight-tolerance Radau reference solution on [0, tend] with dense output.
tend=2.
# vdp(t, v, u) needs the parameter vector as a third argument; solve_ivp only
# supplies (t, y), so u is forwarded through `args`.
res=solve_ivp(vdp,(0,tend),y0,method='Radau',dense_output=True,args=(u,),atol=1e-12,rtol=1e-10)
sol=res.sol

In [None]:
# Time points stored in the solve_ivp result.
res.t

In [None]:
# Evaluate the dense-output interpolant of the reference solution at t = h.
sol(h)

In [None]:
# Eigenvalues and eigenvectors of the matrix returned by jac at t = 0, y = y0.
# NOTE(review): `jac` is not defined in any visible cell -- presumably the Jacobian
# of the ODE right-hand side with respect to the state; confirm and define it above.
np.linalg.eig(jac(0,y0))

In [None]:
# Linearise v -> rhs_sin(0, v) around y0: `y` receives the function value at y0 and
# `f_jvp` the linear map that applies the Jacobian at y0 to a tangent vector.
# NOTE(review): `rhs_sin` is not defined in any visible cell -- confirm where it comes from.
y, f_jvp=jax.linearize(lambda v: rhs_sin(0,v), y0)

In [None]:
# Random start vector for the iteration below; drawn from a seeded generator so
# that a fresh Restart & Run All reproduces the same printed sequence.
bold=np.random.default_rng(0).uniform(size=2)

In [None]:
# Power iteration with the linear map f_jvp: apply it, form the Rayleigh
# quotient as an eigenvalue estimate, then renormalise the vector.
# The estimate gets its own name so the global parameter vector `u` (used by
# the stepper cells) is not overwritten by a scalar.
for i in range(10):
    bnew= f_jvp(bold)
    eig_est = jnp.sum(bold*bnew)/jnp.sum(bold*bold)
    bnew=bnew/jnp.linalg.norm(bnew)
    bold=bnew
    print(eig_est)

In [None]:
# Piecewise-linearised propagation: at every grid point the matrix from jac is
# eigen-decomposed and the state is advanced by dt with the exact exponential
# of that linear system.
np.set_printoptions(precision=15)
y=[y0]
yprev=y0.copy()

dt=0.01
tplot = np.arange(0,tend*1.01,dt)
# y already holds the state at tplot[0], so advance once per REMAINING grid
# point; iterating over all of tplot produced one sample more than there are times.
for t in tplot[1:]:
    ls, vs= np.linalg.eig(jac(0,yprev))
    # Coordinates of the current state in the eigenvector basis.
    cs = np.linalg.solve(vs,yprev)
    yprev=np.real(vs@(cs*jnp.exp(ls*dt)))
    y.append(yprev)

In [None]:
# Stack the list of states along axis 1, so linear_y[i] is the history of state component i.
linear_y=jnp.stack(y,axis=1)

In [None]:
# First state component: piecewise-linearised propagation vs. the Radau reference.
fig=make_subplots()
fig.add_scatter(x=tplot,y=linear_y[0],name='piecewise-linearised')
fig.add_scatter(x=tplot,y=sol(tplot)[0],name='solve_ivp Radau reference')
# Label the axes so the figure can be read on its own; update_layout returns
# the figure, so it is still displayed as the cell's last expression.
fig.update_layout(xaxis_title='t',yaxis_title='v[0]')