<a href="https://colab.research.google.com/github/profteachkids/CHE5136_Fall2023/blob/main/AutoStiff_NonStiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
jax.config.update('jax_enable_x64',True)
from scipy.integrate import solve_ivp
from plotly.subplots import make_subplots

In [102]:
np.set_printoptions(precision=10,linewidth=200)
jnp.set_printoptions(precision=10, linewidth=200)

In [96]:
def vdp(t,v):
    return jnp.array([v[1], 100*(1 - v[0]**2)*v[1] - v[0]])



In [140]:
#Lobatto IIIC
a_lobatto=np.array([[1/6, -1/3, 1/6],
                    [1/6, 5/12, -1/12],
                    [1/6, 2/3, 1/6]])

c_lobatto=np.array([0, 0.5, 1])
b1_lobatto = np.array([1/6, 2/3, 1/6])
b2_lobatto = np.array([-0.5, 2, -0.5])

a_DormandPrince = np.array([[0,     0,      0,      0,      0,      0,      0],
                            [1/5,   0,      0,      0,      0,      0,      0],
                            [3/40,  9/40,   0,      0,      0,      0,      0],
                            [44/45, -56/15, 32/9,   0,      0,      0,      0],
                            [19372/6561,    -25360/2187,    64448/6561,     -212/729,   0,      0, 0],
                            [9017/3168,     -355/33,        46732/5247,     49/176,     -5103/18656,    0, 0],
                            [35/384,        0,          500/1113,       125/192,        -2187/6784,     11/84, 0]])
c_DormandPrince = np.array([0, 1/5, 3/10, 4/5, 8/9, 1, 1.])
b1_DormandPrince = np.array([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0])
b2_DormandPrince = np.array([5179/57600, 0, 7571/16695, 393/640, -92097/339200, 187/2100, 1/40])

def get_explicit_step(rhs, a=a_DormandPrince, b1=b1_DormandPrince, b2=b2_DormandPrince, c=c_DormandPrince):

    def step(t0, y0, h):
        tvec = t0+c*h
        Ny=y0.size
        def f_scan(ik, act):
            i,k=ik
            a,c,t = act
            k=k.at[i].set(rhs(t,y0+h*a@k))
            return (i+1,k), None

        (_,k),_= jax.lax.scan(f_scan, (0,jnp.zeros((c.shape[0], Ny))), (a, c, tvec))
        return y0 + h*b1@k, y0 + h*b2@k

    return jax.jit(step)

In [141]:
def get_implicit_step(rhs, a=a_lobatto, b1=b1_lobatto, b2=b2_lobatto, c=c_lobatto):
    Nstages = c.size
    rhs_vec=jax.jit(jax.vmap(rhs,in_axes=(0,0),out_axes=0))

    def k_func(tvec, y,h,k):
        k=k.reshape(Nstages,y.size)
        yvec = y + h*a@k
        return (k-rhs_vec(tvec,yvec)).ravel()

    k_func_jac=jax.jacobian(k_func,3)

    def step(t0, y0, h):
        tvec = t0+c*h
        k = jnp.tile(rhs(t0,y0).reshape(1,-1),(Nstages,1)).ravel()
        f=k_func(tvec,y0,h,k)

        def body_func(kfi):
            k,f,i=kfi
            f=k_func(tvec,y0,h,k)
            dk = jnp.linalg.solve(k_func_jac(tvec,y0,h,k), -f)
            k=k+dk
            return k,f, i+1

        def cond_func(kfi):
            k,f,i=kfi
            return jnp.logical_and(jnp.linalg.norm(f)<1e-12,i<50)

        k,*_=jax.lax.while_loop(cond_func, body_func, (k, f, 0))
        k=k.reshape(3,y0.size)
        return y0+ h*(b1@k), y0 + h*(b2@k)


    return step

In [142]:
step_exp=get_explicit_step(vdp)
step_imp=get_implicit_step(vdp)
h=0.001
y0=jnp.array([0.51123413,0.5431234341])
print(step_imp(0,y0,h))
print(step_exp(0,y0,h))
print(sol(h))

(Array([0.5117772534, 0.5827294511], dtype=float64), Array([0.5117772534, 0.5827294511], dtype=float64))
(Array([0.5117975477, 0.5842119497], dtype=float64), Array([0.5117975477, 0.5842119506], dtype=float64))
[0.5117975477 0.5842119497]


In [137]:

tend=2.
res=solve_ivp(vdp,(0,tend),y0,method='Radau',dense_output=True,atol=1e-12,rtol=1e-10)
sol=res.sol

array([0.5115107033, 0.5632926404])

In [9]:
sol(h)

array([0.50137312, 0.60163823])

In [10]:
np.linalg.eig(jac(0,y0))

NameError: ignored

In [None]:
y, f_jvp=jax.linearize(lambda v: rhs_sin(0,v), y0)

In [None]:
bold=np.random.uniform(size=2)

In [None]:
for i in range(10):
    bnew= f_jvp(bold)
    u = jnp.sum(bold*bnew)/jnp.sum(bold*bold)
    bnew=bnew/jnp.linalg.norm(bnew)
    bold=bnew
    print(u)

In [None]:
np.set_printoptions(precision=15)
y=[y0]
yprev=y0.copy()

dt=0.01
tplot = np.arange(0,tend*1.01,dt)
for t in tplot:
    ls, vs= np.linalg.eig(jac(0,yprev))
    cs = np.linalg.solve(vs,yprev)
    yprev=np.real(vs@(cs*jnp.exp(ls*dt)))
    y.append(yprev)

In [None]:
linear_y=jnp.stack(y,axis=1)

In [None]:
fig=make_subplots()
fig.add_scatter(x=tplot,y=linear_y[0])
fig.add_scatter(x=tplot,y=sol(tplot)[0])