<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/AutoStiff_NonStiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
jax.config.update('jax_enable_x64',True)
from scipy.integrate import solve_ivp
from plotly.subplots import make_subplots

In [98]:
np.set_printoptions(linewidth=200)

In [99]:
def vdp(t,v):
    return jnp.array([v[1], 100*(1 - v[0]**2)*v[1] - v[0]])

y0=jnp.array([0.5,0.5])
tend=2.

In [208]:
#Lobatto IIIC
f16=1/6
f13=1/3
f23=2/3
f512=5/12
f112=1/12

a_lobatto=np.array([[f16, -f13, f16],
                    [f16, f512, -f112],
                    [f16, f23, f16]])

c_lobatto=np.array([0, 0.5, 1])
b1_lobatto = np.array([f16, f23, f16])
b2_lobatto = np.array([-0.5, 2, -0.5])

a_DormandPrince = np.array([[0,     0,      0,      0,      0,      0,      0],
                            [1/5,   0,      0,      0,      0,      0,      0],
                            [3/40,  9/40,   0,      0,      0,      0,      0],
                            [44/45, -56/15, 32/9,   0,      0,      0,      0],
                            [19372/6561,    -25360/2187,    64448/6561,     -212/729,   0,      0, 0],
                            [9017/3168,     -355/33,        46732/5247,     49/176,     -5103/18656,    0, 0],
                            [35/384,        0,          500/1113,       125/192,        -2187/6784,     11/84, 0]])
c_DormandPrince = np.array([0, 1/5, 3/10, 4/5, 8/9, 1, 1.])
b1_DormandPrince = np.array([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0])
b2_DormandPrince = np.array([5179/57600, 0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40])

def get_ks(rhs, Ny, t0, y0, h, a, c):
    tvec = t0+c*h

    def f_scan(ik, act):
        i,k=ik
        a,c,t = act
        k=k.at[i].set(rhs(t,y0+h*a@k))
        return (i+1,k), None

    return jax.lax.scan(f_scan, (0,jnp.zeros((c.shape[0], Ny))), (a, c, tvec))

h=0.0025
(_,k),_=get_ks(vdp, y0.size, 0., y0, h, a_DormandPrince, c_DormandPrince)
y0 + h*b1_DormandPrince@k, y0 + h*b2_DormandPrince@k

(Array([0.50137312, 0.60163824], dtype=float64),
 Array([0.50137312, 0.60163832], dtype=float64))

In [174]:
def get_k_func(rhs, Ny, a, c):
    Nstages = c.size
    rhs_vec=jax.vmap(rhs,in_axes=(0,0),out_axes=0)

    def k_func(t,y,h,k):
        k=k.reshape(Nstages,Ny)
        tvec = t+c*h
        yvec = y + h*a@k
        return (k-rhs_vec(tvec,yvec)).ravel()

    return jax.jit(k_func)

k_func=get_k_func(vdp, y0.size, a_lobatto, c_lobatto)
t0=0
kguess = jnp.tile( vdp(0,y0).reshape(1,-1),(3,1)).ravel()
k_func(0,y0,h,kguess)  #3 x Ny
k_func_jac=jax.jit(jax.jacobian(k_func,3))
k=kguess
yold=y0
for i in range(10):
    f=k_func(0,y0,h,k)
    if jnp.linalg.norm(f)<1e-12:
        break
    dk = jnp.linalg.solve(k_func_jac(0,y0,h,kguess), -f)
    k=k+dk
k=k.reshape(3,y0.size)
ynew1 = yold + h*(b1_lobatto@k)
ynew2 = yold + h*(b2_lobatto@k)
ynew1, ynew2

(Array([0.50020425, 0.51511551], dtype=float64),
 Array([0.50020422, 0.51511329], dtype=float64))

In [160]:

res=solve_ivp(vdp,(0,tend),y0,method='Radau',dense_output=True,atol=1e-10,rtol=1e-8)
sol=res.sol

In [169]:
sol(res.t[1])

array([0.50020425, 0.51511551])

In [153]:
sol(h)

array([0.50005019, 0.50371378])

In [47]:
np.linalg.eig(jac(0,y0))

(array([ 0.68627973, 74.31372027]),
 array([[-0.82451134, -0.01345525],
        [-0.56584542, -0.99990947]]))

In [48]:
y, f_jvp=jax.linearize(lambda v: rhs_sin(0,v), y0)

In [52]:
bold=np.random.uniform(size=2)

In [53]:
for i in range(10):
    bnew= f_jvp(bold)
    u = jnp.sum(bold*bnew)/jnp.sum(bold*bold)
    bnew=bnew/jnp.linalg.norm(bnew)
    bold=bnew
    print(u)

-0.5105499729765429
-0.44862685393586843
-0.38985779374212254
-0.37186965013939727
-0.36555942435890987
-0.3632254153489002
-0.36234473167396125
-0.36200990370770625
-0.36188223796252467
-0.3618335070777532


In [None]:
np.set_printoptions(precision=15)
y=[y0]
yprev=y0.copy()

dt=0.01
tplot = np.arange(0,tend*1.01,dt)
for t in tplot:
    ls, vs= np.linalg.eig(jac(0,yprev))
    cs = np.linalg.solve(vs,yprev)
    yprev=np.real(vs@(cs*jnp.exp(ls*dt)))
    y.append(yprev)

In [None]:
linear_y=jnp.stack(y,axis=1)

In [None]:
fig=make_subplots()
fig.add_scatter(x=tplot,y=linear_y[0])
fig.add_scatter(x=tplot,y=sol(tplot)[0])