# Tutorial: ODE Integration

This notebook looks at how one can do ODE integration in JAX from scratch.

In [1]:
import sys, os
from pyprojroot import here

# spider up the directory tree to find the project root
# (presumably the first ancestor directory containing a ".home" marker file — see pyprojroot docs)
root = here(project_files=[".home"])

# append the project root to the module search path
sys.path.append(str(root))


Let's take an initial value problem (IVP) of the form:

$$
\dot{x}(t) = \boldsymbol{f}(x(t),t), \hspace{4mm} x(0)=x_0
$$

where the solution to this differential equation satisfies (with $t_0 = 0$):

$$
x(t) = x_0 + \int_{t_0}^{t}\boldsymbol{f}(x(\tau), \tau)\,d\tau
$$

Almost all problems involve some sort of discretization:

$$
u_{t+1} = u_{t} + \boldsymbol{g}(u_{t}, c)
$$

## Euler Integration Method

$$
u_{t+1} = u_{t} + \delta_t \, \boldsymbol{f}(u_t, t)
$$

where $\delta_t$ is the step size between consecutive time points.

In [None]:
def odeint_euler(f, y0, t, *args):
  """Integrate ``dy/dt = f(y, t, *args)`` with the explicit (forward) Euler method.

  Args:
    f: callable ``f(y, t, *args)`` returning the time derivative of ``y``.
    y0: initial state, i.e. the value of ``y`` at ``t[0]``.
    t: 1-D array of time points. The step size is the difference between
      consecutive entries, so the grid does not have to be uniform.
    *args: extra positional arguments forwarded unchanged to ``f``.

  Returns:
    The states at ``t[1:]``, stacked along a new leading axis by ``lax.scan``.
    The initial state ``y0`` is not included in the output.
  """
  # Imported here so the function does not depend on `lax` having been
  # imported by an earlier cell.
  from jax import lax

  def step(state, t_next):
    y_prev, t_prev = state
    dt = t_next - t_prev
    # Euler update: follow the slope at the start of the interval.
    y = y_prev + dt * f(y_prev, t_prev, *args)
    # Carry (y, t_next) to the next step and emit y as this step's output.
    return (y, t_next), y

  _, ys = lax.scan(step, (y0, t[0]), t[1:])
  return ys

# Runge-Kutta, 4th Order

In [None]:
def odeint_rk4(f, y0, t, *args):
  """Integrate ``dy/dt = f(y, t, *args)`` with the classical 4th-order Runge-Kutta method.

  Args:
    f: callable ``f(y, t, *args)`` returning the time derivative of ``y``.
    y0: initial state, i.e. the value of ``y`` at ``t[0]``.
    t: 1-D array of time points. The step size is the difference between
      consecutive entries, so the grid does not have to be uniform.
    *args: extra positional arguments forwarded unchanged to ``f``.

  Returns:
    The states at ``t[1:]``, stacked along a new leading axis by ``lax.scan``.
    The initial state ``y0`` is not included in the output.
  """
  # Imported here so the function does not depend on `lax` having been
  # imported by an earlier cell.
  from jax import lax

  def step(state, t_next):
    y_prev, t_prev = state
    h = t_next - t_prev
    # Slope samples at the start, twice at the midpoint, and at the end of the step.
    k1 = h * f(y_prev, t_prev, *args)
    k2 = h * f(y_prev + k1/2., t_prev + h/2., *args)
    k3 = h * f(y_prev + k2/2., t_prev + h/2., *args)
    # The final stage is evaluated at the end of the step, t_prev + h == t_next
    # (not one further step ahead).
    k4 = h * f(y_prev + k3, t_next, *args)
    y = y_prev + 1./6 * (k1 + 2 * k2 + 2 * k3 + k4)
    # Carry (y, t_next) to the next step and emit y as this step's output.
    return (y, t_next), y

  _, ys = lax.scan(step, (y0, t[0]), t[1:])
  return ys

# Diffrax

In [2]:
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    """Vector field of the linear decay ODE ``dy/dt = -y``.

    ``t`` and ``args`` are accepted but unused. Note that the argument order
    ``(t, y, args)`` differs from the ``f(y, t, *args)`` convention used by
    ``odeint_euler`` and ``odeint_rk4`` earlier in this notebook.
    """
    return -y

# wrap the vector field `f` as an ODE term for diffrax
term = ODETerm(f)
# Dopri5 solver (an adaptive-capable explicit Runge-Kutta scheme — see diffrax docs)
solver = Dopri5()
# initial condition: a two-component state
y0 = jnp.array([2., 3.])
# integrate from t0=0 to t1=1, starting with a step size of 0.1
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

