In [1]:
%autosave 10

import ctypes
import math
import os

import numpy as np
from scipy import LowLevelCallable
from scipy.integrate import quad, nquad

Autosaving every 10 seconds


In [2]:
# Smoke test with a plain Python callable: integrate sin(x) over [0, pi].
# quad returns a (value, estimated absolute error) tuple.
quad(math.sin, 0, math.pi)

(2.0, 2.220446049250313e-14)

In [3]:
def get_quad_llc(dll_path, func_name, argtypes=(ctypes.c_double,)):
    """Wrap a C function from a shared library as a SciPy ``LowLevelCallable``.

    Parameters
    ----------
    dll_path : str
        Path of the shared library, handed to ``ctypes.CDLL``.
    func_name : str
        Name of the exported symbol to look up in the library.
    argtypes : sequence of ctypes types, optional
        Argument types of the C function. Defaults to a single ``double``.

    Returns
    -------
    scipy.LowLevelCallable
        Wrapper around the ctypes function; the return type is always
        declared as ``double``.
    """
    c_func = ctypes.CDLL(dll_path)[func_name]
    # Declare the full C prototype on the ctypes function object before
    # handing it to LowLevelCallable.
    c_func.argtypes = argtypes
    c_func.restype = ctypes.c_double
    return LowLevelCallable(c_func)

# Compile functions.c into the shared library functions.so
# (optimised with -O3 and -ffast-math, position-independent code).
!gcc functions.c -O3 -ffast-math -fPIC -shared -o functions.so
# Absolute path of the library inside the current working directory;
# this is what gets passed to get_quad_llc / ctypes.CDLL in later cells.
dll_path = os.path.join(os.getcwd(), 'functions.so')

In [4]:
# Build the Cython source cy.pyx in place (-i) so it can be imported below.
!cythonize -i cy.pyx
# NOTE(review): imported here rather than in the first cell, presumably
# because the extension module only exists after the build step above.
import cy
# Wrap the function exported by the module under the name 'cython_func'.
func_cy_llc = LowLevelCallable.from_cython(cy, 'cython_func')

running build_ext


In [5]:
# ctypes-backed integrand: symbol 'func' from the compiled library, using the
# default single-double argument list of get_quad_llc.
func_llc = get_quad_llc(dll_path=dll_path, func_name='func')

def func_numpy(x):
    """Return sin(x) / x, evaluating the sine with NumPy.

    There is no special handling of ``x == 0``.
    """
    sine = np.sin(x)
    return sine / x

def func_math(x):
    """Return sin(x) / x, evaluating the sine with the ``math`` module.

    Scalar-only; ``x == 0`` raises ``ZeroDivisionError``.
    """
    sine = math.sin(x)
    return sine / x

assert quad(func_llc, 0, 1) == quad(func_cy_llc, 0, 1) == quad(func_numpy, 0, 1) == quad(func_math, 0, 1)

%timeit quad(func_llc, 0, 1)
%timeit quad(func_cy_llc, 0, 1)
%timeit quad(func_numpy, 0, 1)
%timeit quad(func_math, 0, 1)

4.38 µs ± 6.06 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
4.46 µs ± 17.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
49.9 µs ± 278 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
11.3 µs ± 46 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [10]:
# ctypes-backed integrand 'dphot_dz'. It is declared with the (int, double *)
# argument list so that extra parameters passed through ``args=`` reach it.
dphot_llc = get_quad_llc(
    dll_path,
    'dphot_dz',
    argtypes=(ctypes.c_int, ctypes.POINTER(ctypes.c_double)),
)

def dphot_math(z, Omega):
    """Return 1 / sqrt((1 - Omega) * (1 + z)**3 + Omega).

    Pure-Python counterpart of the compiled ``dphot_dz`` integrand.
    NOTE(review): presumably ``z`` is a redshift and ``Omega`` a density
    parameter -- confirm against functions.c.
    """
    scaled_term = (1.0 - Omega) * (1.0 + z)**3
    radicand = scaled_term + Omega
    return 1.0 / math.sqrt(radicand)

assert quad(dphot_llc, 0, 1, args=(0.7,)) == quad(dphot_math, 0, 1, args=(0.7,))

%timeit quad(dphot_llc, 0, 1, args=(0.7,))
%timeit quad(dphot_math, 0, 1, args=(0.7,))

4.36 µs ± 14 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
15.4 µs ± 68.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [6]:
# ctypes-backed integrand 'force', declared with the (int, double *) argument
# list so it can receive several coordinates plus extra ``args=`` values.
force_llc = get_quad_llc(
    dll_path,
    'force',
    argtypes=(ctypes.c_int, ctypes.POINTER(ctypes.c_double)),
)

def force_math(x, y, z, R):
    """Piecewise function of the distance r = sqrt(x^2 + y^2 + z^2).

    Returns ``r**-2`` when ``r > R`` and ``r / R**3`` otherwise (including
    ``r == R``). Pure-Python counterpart of the compiled ``force`` integrand.
    NOTE(review): this looks like the field magnitude of a uniform sphere of
    radius ``R`` -- confirm against functions.c.
    """
    dist = math.sqrt(x*x + y*y + z*z)
    return dist**-2 if dist > R else dist / R**3

print(nquad(force_llc, [[0., 10.], [0, 10.], [0., 10.]], args=(1.0,)))
print(nquad(force_math, [[0., 10.], [0, 10.], [0., 10.]], args=(1.0,)))

%timeit nquad(force_llc, [[0., 10.], [0, 10.], [0., 10.]], args=(1.0,))
%timeit nquad(force_math, [[0., 10.], [0, 10.], [0., 10.]], args=(1.0,))

(18.007213311678726, 1.9422110470122394e-07)
(18.007213311678722, 1.9422110459084792e-07)
1.07 s ± 5.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.35 s ± 22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
