In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax.lax import scan
from jax import random


from turtle import color
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
def solve_euler(f, t, x0, progress=True):
    """Integrate ``dx/dt = f(t, x)`` with the explicit (forward) Euler method.

    Parameters
    ----------
    f : callable
        Right-hand side ``f(t, x)``; its result is added to ``x`` after scaling
        by the step size, so it should be shaped like ``x``.
    t : array
        1-D grid of time points (steps may be non-uniform); ``t[0]`` is the time
        at which the state equals ``x0``.
    x0 : array
        Initial state at ``t[0]``.
    progress : bool, optional
        Wrap the time loop in a ``tqdm`` progress bar (default ``True``). If you
        pass this argument to a ``jit``-wrapped version, mark it static.

    Returns
    -------
    jax.Array
        The states at every time point, stacked along axis 1 (for a 1-D ``x0`` of
        length ``d`` the shape is ``(d, len(t))``); column 0 is ``x0``.

    Notes
    -----
    This is a Python-level loop: under ``jit`` it is unrolled into one traced
    step per time point, so compile time grows with ``len(t)``.
    """
    t_left = t[:-1]          # start of each interval
    steps = t[1:] - t[:-1]   # Δt for each interval
    pairs = zip(t_left, steps)
    if progress:
        pairs = tqdm(pairs, total=len(steps))

    x_cur = x0
    X = [x_cur]
    for t_n, dt in pairs:
        # Forward Euler evaluates the slope at the START of the interval: f(t_n, x_n).
        x_cur = x_cur + dt * f(t_n, x_cur)
        X.append(x_cur)

    return jnp.stack(X, axis=1)


def f(t, x):
    """Right-hand side of the Lotka–Volterra (predator–prey) system.

    ``x`` is unpacked into its two components (prey, predator). ``t`` is
    accepted for the ``f(t, x)`` solver signature but not used. The rate
    constants are read from the module-level ``α``, ``β``, ``γ``, ``δ`` at
    call time.
    """
    prey, predator = x
    prey_rate = α * prey - β * prey * predator
    predator_rate = δ * prey * predator - γ * predator
    return jnp.array((prey_rate, predator_rate))



h = 0.001
t_start = 0.0
t_end = 5.0
t = jnp.arange(t_start, t_end + h, h)
x0 = 2.0
y0 = 1.0
α = 1.0
β = 1.0
γ = 1.0
δ = 1.0

xy0 = jnp.array((x0, y0))

# %timeit jit(f)

# f_jit = jit(f)

# %timeit f(0.0, xy0)
# %timeit f_jit(0.0,xy0)

solve_euler_jit = jit(solve_euler,static_argnames=["f"])


%time solve_euler(f,t,xy0)
%time solve_euler_jit(f,t,xy0)

5000it [00:03, 1425.68it/s]


CPU times: total: 4.23 s
Wall time: 4.28 s


5000it [01:00, 82.71it/s] 


In [None]:
def solve_euler_scan(f, t, x0):
    """Forward-Euler integration of ``dx/dt = f(t, x)`` using ``jax.lax.scan``.

    The time loop is a single ``scan`` primitive, so wrapping this in ``jit``
    (with ``f`` static) traces one step instead of unrolling ``len(t)`` of them.

    Parameters
    ----------
    f : callable
        Right-hand side ``f(t, x)`` returning an array shaped like ``x``.
    t : array_like
        1-D grid of time points (steps may be non-uniform); ``t[0]`` is the time
        of ``x0``.
    x0 : array_like
        Initial state at ``t[0]``; expected to be at least 1-D.

    Returns
    -------
    jax.Array
        States stacked along axis 1, the same layout as ``solve_euler``: for a
        1-D ``x0`` of length ``d`` the shape is ``(d, len(t))``; column 0 is ``x0``.
    """
    t = jnp.asarray(t)
    x0 = jnp.asarray(x0)

    def step(x_cur, t_and_dt):
        t_cur, dt = t_and_dt
        # Forward Euler: slope evaluated at the start of the interval.
        x_next = x_cur + dt * f(t_cur, x_cur)
        return x_next, x_next  # (new carry, value collected by scan)

    # Scan over (left endpoint, step size) pairs; ys are stacked along axis 0.
    _, X_rest = scan(step, x0, (t[:-1], t[1:] - t[:-1]))
    X = jnp.concatenate([x0[None], X_rest], axis=0)  # prepend the initial state
    # Move the time axis to position 1 to match jnp.stack(..., axis=1).
    return jnp.moveaxis(X, 0, 1)


def f(t, x):
    """Lotka–Volterra (predator–prey) right-hand side; ``t`` is unused.

    Reads the rate constants ``α``, ``β``, ``γ``, ``δ`` from module scope at
    call time.
    """
    # NOTE(review): this is an exact duplicate of the `f` defined in an earlier
    # cell and silently replaces it; consider keeping a single definition.
    x_val, y_val = x
    return jnp.array((
        α * x_val - β * x_val * y_val,
        δ * x_val * y_val - γ * y_val,
    ))