In [1]:
from diffeqpy import de
from time import process_time
from julia import Main
import numba
import numpy as np


In [2]:
jul_f = Main.eval("""
function f(dy,y,p,t)
    x, y, z = y
    sigma, rho, beta = p
    dy[1] = sigma * (y - x)
    dy[2] = x * (rho - z) - y
    dy[3] = x * y - beta * z
end""")
u0 = [1.0,0.0,0.0]
tspan = (0., 100.)
p = [10.0,28.0,2.66]
prob = de.ODEProblem(jul_f, u0, tspan, p)
sol = de.solve(prob)


In [3]:
%timeit -n 1000 sol_new = de.solve(prob)

2.7 ms ± 275 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]:
from scipy.integrate import solve_ivp

du = np.zeros((3), dtype=np.float32)

def f2(t, u, du, sigma, rho, beta):
    x, y, z = u
    du[0] = sigma * (y - x)
    du[1] = x * (rho - z) - y
    du[2] = x * y - beta * z
    return du

numba_f = numba.jit(f2, nopython=True)


In [5]:
%%timeit -n 10

sol_new = solve_ivp(
    numba_f,
    tspan,
    u0,
    method='DOP853',
    t_eval=None,
    dense_output=False,
    events=None,
    vectorized=False,
    args=(du,10.0,28.0,2.66),
)


181 ms ± 35.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [67]:
@numba.njit(
    cache = True,
    fastmath = True,
)
def derivs(
    N,
    y,
    N_eqs,
):
    NEQS = N_eqs
    dy_dN = np.zeros(NEQS)
    
    if y[2] >= 1:
        dy_dN = np.zeros(NEQS)
    else:
        if y[2] > VERY_SMALL_NUM:
            dy_dN[0] = - np.sqrt(y[2] / (4 * np.pi))
        else:
            dy_dN[0] = 0.0
        
        dy_dN[1] = y[1] * y[2]
        dy_dN[2] = y[2] * (y[3] + 2 * y[2])
        dy_dN[3] = 2 * y[4] - 5 * y[2] * y[3] - 12 * y[2] * y[2]
        
        for i in range(4, NEQS-1):
            dy_dN[i] = (0.5 * (i-3) * y[3] + (i-4) * y[2]) * y[i] + y[i+1]
            
        dy_dN[NEQS-1] = (0.5 * (NEQS-4) * y[3] + (NEQS-5) * y[2]) * y[NEQS-1]

    return dy_dN

@numba.njit(
    cache = True,
    fastmath = True,
)
def inflation_ends(
    N,
    y,
    N_eqs,
):
    return y[2] - (1 - 1e-8)


In [49]:
deriv_fun = """
function f(dy,y,p,t)
    n = p
    if y[3] >= 1
        for i in 1:n
            dy[i] = 0
        end
    else
        if y[3] > 1e-8
            dy[1] = - sqrt(y[3] / (4 * pi))
        else
            dy[1] = 0
        end
        
        dy[2] = y[2] * y[3]
        dy[3] = y[3] * (y[4] + 2 * y[3])
        dy[4] = 2 * y[5] - 5 * y[3] * y[4] - 12 * y[3] * y[3]

        for i in 5:(n-1)
            dy[i] = (0.5 * (i-4) * y[4] + (i-5) * y[3]) * y[i] + y[i+1]
        end
        
        dy[n] = (0.5 * (n-4) * y[4] + (n-5) * y[3]) * y[n]
    end
end
"""

jul_deriv = Main.eval(deriv_fun)


In [41]:
VERY_SMALL_NUM = 1e-8

for _ in range(10000):
    NEQS = np.random.randint(5, 20)
    print(NEQS, end=' ')
    y1 = np.random.uniform(size=NEQS, low = 0.0, high = 1.0)
    # print(y1)

    dy1 = derivs(1, y1, NEQS)
    dy2 = jul_deriv(np.zeros(NEQS), y1, NEQS)
    
    assert np.all(np.isclose(dy1, dy2, atol=1e-16))


7 12 9 18 18 10 8 11 7 10 8 9 13 8 13 5 7 16 16 10 13 7 8 11 5 5 16 15 10 19 16 12 10 13 16 9 6 13 14 11 6 12 5 8 16 7 13 14 18 9 18 18 14 8 15 17 10 6 16 12 8 16 18 7 15 6 18 8 8 17 9 7 7 5 7 9 5 10 11 11 19 10 5 18 13 13 13 14 5 15 15 6 13 19 5 10 6 8 16 9 9 12 10 8 9 19 18 5 14 12 7 7 17 15 13 8 18 16 17 12 16 19 19 12 6 14 6 7 9 19 6 8 18 12 16 15 14 17 17 17 5 8 12 11 12 18 14 5 11 17 14 16 7 15 17 11 12 11 5 14 14 19 15 16 13 10 15 18 13 11 17 18 13 6 11 9 6 13 6 8 5 7 16 18 5 16 19 17 14 7 15 11 16 10 13 13 13 13 11 14 11 6 9 17 18 18 17 7 19 9 12 9 17 19 5 9 6 18 13 9 16 5 12 12 10 11 17 17 10 16 11 11 16 15 11 18 5 5 8 13 6 18 13 15 18 5 7 5 14 16 8 6 9 16 12 9 17 13 16 17 13 5 16 19 8 12 7 10 18 13 17 18 15 9 17 11 19 5 5 7 13 7 14 14 9 11 19 12 14 16 8 9 8 17 14 17 12 9 17 15 5 5 8 6 18 12 15 5 18 14 14 18 12 7 17 17 10 14 9 8 19 14 12 19 7 7 6 15 6 10 11 19 16 5 14 19 11 6 10 8 11 14 9 18 15 9 8 13 9 15 11 12 7 9 17 15 15 11 14 19 19 19 5 11 15 6 15 12 12 12 12 6 18 17 10 1

In [42]:
import numba as nb

NUM_EFOLDS_MIN = 46.
NUM_EFOLDS_MAX = 60.

@nb.njit(cache = True, fastmath = True)
def pick_init_vals(N_eqs):
    np.random.seed(0)
    init_vals = np.zeros(N_eqs)

    init_vals[0] = 0.0
    init_vals[1] = 1.0
    init_vals[2] = np.random.uniform(0, 0.8)
    init_vals[3] = np.random.uniform(-0.5, 0.5)
    init_vals[4] = np.random.uniform(-0.05, 0.05)

    width = 0.05
    
    for i in range(5, N_eqs):
        init_vals[i] = np.random.uniform(-0.5 * width, 0.5 * width)
        width *= 0.1
        
    init_N_efolds = np.random.uniform(NUM_EFOLDS_MIN, NUM_EFOLDS_MAX)
    
    return init_vals, init_N_efolds

N_start, N_end = 1000, 0


In [68]:
N_eqs = 8

y_init, N_efolds_init = pick_init_vals(N_eqs)


In [83]:
%%timeit -n 100

sol = solve_ivp(
    derivs,
    [N_start, N_end],
    y_init,
    args=(N_eqs,),
    events=inflation_ends,
    method='DOP853',
    first_step=1e-6,
)

149 ms ± 21.6 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [79]:
cond = Main.eval("""
condition(y, t, integrator) = y[3] - (1 - 1e-8)
""")

affect = Main.eval("""
affect!(integrator) = terminate!(integrator)
""")

callback_event = de.ContinuousCallback(cond, affect)

In [80]:
prob = de.ODEProblem(jul_deriv, y_init, (N_start, N_end), N_eqs)
sol = de.solve(prob, callback = callback_event)


In [81]:
%timeit -n 1000 sol = de.solve(prob)


734 µs ± 205 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [85]:
Main.eval('include("./f.jl")')


[[0.0,
  0.0022360242881008263,
  0.018143704830655496,
  0.052173852104473124,
  0.0947203530743621,
  0.14073793521832798,
  0.19325113632945576,
  0.2531223087194043,
  0.3142472815328738,
  0.39563654769937434,
  0.47767050695135505,
  0.5475817558169621,
  0.5963069168054481,
  0.6354511214987237,
  0.6531080245170758,
  0.653935797873861,
  0.653935797873861],
 [1.0,
  0.9980364233402051,
  0.9844655225892893,
  0.9577094037156277,
  0.9283312348404813,
  0.9010616657459889,
  0.8745852949036484,
  0.8484943458789029,
  0.823525174310516,
  0.7866440106201726,
  0.7357209040152348,
  0.6722737166741259,
  0.6122894156928359,
  0.552224874235176,
  0.5211521247775276,
  0.5196333593122198,
  0.5196333593122198],
 [0.061803071349734756,
  0.06117237300189653,
  0.05674963690822145,
  0.04779246091812154,
  0.03792821765656158,
  0.029365943943097456,
  0.022587919644139566,
  0.018896620972074018,
  0.020044717484678162,
  0.033933414518149296,
  0.08328278599305103,
  0.2065104910