In [11]:
import numpy as np
import jax.numpy as jnp
import jax
import numba

## NumPy

In [14]:
def func(x: np.ndarray, a: np.ndarray) -> np.ndarray:
    """Evaluate the sine model ``a[0] * sin(a[1] * x) + a[2]`` with NumPy.

    Parameters:
        x: points at which to evaluate the model.
        a: parameters; only ``a[0]`` (amplitude), ``a[1]`` (frequency) and
            ``a[2]`` (offset) are read.

    Returns:
        The element-wise model values.
    """
    amplitude, frequency, offset = a[0], a[1], a[2]
    return amplitude * np.sin(frequency * x) + offset


x = np.linspace(0, 10, 1_000_000).astype(np.float64)
x0 = np.array([5, 2, 6]).astype(np.float64)

func(x, x0)
%timeit func(x, x0)

5.4 ms ± 457 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## NumPy + Numba

In [13]:
@numba.njit
def func(x: np.ndarray, a: np.ndarray) -> np.ndarray:
    """Numba-compiled sine model ``a[0] * sin(a[1] * x) + a[2]``.

    Parameters:
        x: points at which to evaluate the model.
        a: parameters; only ``a[0]`` (amplitude), ``a[1]`` (frequency) and
            ``a[2]`` (offset) are read.

    Returns:
        The element-wise model values.
    """
    amplitude, frequency, offset = a[0], a[1], a[2]
    return amplitude * np.sin(frequency * x) + offset


x = np.linspace(0, 10, 1_000_000).astype(np.float64)
x0 = np.array([5, 2, 6]).astype(np.float64)

func(x, x0)
%timeit func(x, x0)

4.35 ms ± 388 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


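## JAX

By default JAX computes in float32 (unless `jax_enable_x64` is enabled), whereas the NumPy, Numba and Cython cells use float64, so these timings are not a strictly like-for-like comparison.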

In [19]:
def func(x, a):
    """Evaluate the sine model ``a[0] * sin(a[1] * x) + a[2]`` with jax.numpy.

    Parameters:
        x: points at which to evaluate the model.
        a: parameters; only ``a[0]`` (amplitude), ``a[1]`` (frequency) and
            ``a[2]`` (offset) are read.

    Returns:
        The element-wise model values as a JAX array.
    """
    amplitude, frequency, offset = a[0], a[1], a[2]
    return amplitude * jnp.sin(frequency * x) + offset


x = jnp.linspace(0, 10, 1_000_000)
x0 = jnp.array([5, 2, 6])

func(x, x0)

%timeit func(x, x0)

2.09 ms ± 201 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## JAX + JIT compilation (`jax.jit`)

In [20]:
@jax.jit
def func(x, a):
    """JIT-compiled sine model ``a[0] * sin(a[1] * x) + a[2]``.

    Parameters:
        x: points at which to evaluate the model.
        a: parameters; only ``a[0]`` (amplitude), ``a[1]`` (frequency) and
            ``a[2]`` (offset) are read.

    Returns:
        The element-wise model values as a JAX array.
    """
    amplitude, frequency, offset = a[0], a[1], a[2]
    return amplitude * jnp.sin(frequency * x) + offset


x = jnp.linspace(0, 10, 1_000_000)
x0 = jnp.array([5, 2, 6])

func(x, x0)


%timeit func(x, x0)

1.12 ms ± 47.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Cython

In [2]:
%load_ext cython

In [3]:
%%cython -a
import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport sin
np.import_array()


# Element type of the arrays handled below (C double / float64).
ctypedef np.float64_t DTYPE_t


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef np.ndarray[double, ndim=1] sin_func(np.ndarray[double, ndim=1] arr, np.ndarray[double, ndim=1] params):
    """Evaluate a sine model element-wise over a 1-D float64 array.

    ``params`` may have 3 elements ``(amplitude, frequency, offset)``, giving
    ``amplitude * sin(frequency * arr) + offset`` -- the same model as the
    NumPy/Numba/JAX ``func`` cells -- or 4 elements
    ``(amplitude, frequency, phase, offset)``, giving
    ``amplitude * sin(frequency * arr + phase) + offset``.

    Returns:
        A new float64 array with one model value per element of ``arr``.

    Raises:
        ValueError: if ``params`` has neither 3 nor 4 elements. Bounds
            checking is disabled below, so this guard is what prevents
            out-of-range reads of ``params``.
    """
    # Py_ssize_t (not int) so lengths beyond 2**31 - 1 are not truncated.
    cdef Py_ssize_t k
    cdef Py_ssize_t shape = len(arr)
    cdef Py_ssize_t n_params = len(params)
    cdef DTYPE_t amp, freq, phase, offset
    cdef np.ndarray[DTYPE_t, ndim=1] y

    if n_params == 3:
        amp, freq, phase, offset = params[0], params[1], 0.0, params[2]
    elif n_params == 4:
        amp, freq, phase, offset = params[0], params[1], params[2], params[3]
    else:
        raise ValueError("params must have 3 or 4 elements, got %d" % n_params)

    y = np.zeros(shape)
    for k in range(shape):
        y[k] = amp * sin(freq * arr[k] + phase) + offset
    return y

In [8]:
%%cython -a
from libc.math cimport sin  # C library sin, avoids Python-level calls in the loop
import numpy as np
import cython
cimport numpy as np
cimport cython
np.import_array()


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef np.ndarray[double, ndim=1] sin_func(np.ndarray[double, ndim=1] arr, np.ndarray[double, ndim=1] params):
    """Evaluate a sine model element-wise over a 1-D float64 array.

    ``params`` may have 3 elements ``(amplitude, frequency, offset)``, giving
    ``amplitude * sin(frequency * arr) + offset`` -- the same model as the
    NumPy/Numba/JAX ``func`` cells -- or 4 elements
    ``(amplitude, frequency, phase, offset)``, giving
    ``amplitude * sin(frequency * arr + phase) + offset``.

    Returns:
        A new float64 NumPy array with one model value per element of ``arr``.

    Raises:
        ValueError: if ``params`` has neither 3 nor 4 elements. Bounds
            checking is disabled below, so this guard is what prevents
            out-of-range reads of ``params``.
    """
    # Py_ssize_t (not int) so lengths beyond 2**31 - 1 are not truncated.
    cdef Py_ssize_t k
    cdef Py_ssize_t shape = len(arr)
    cdef Py_ssize_t n_params = len(params)
    cdef double a, b, c, d
    cdef double[:] x
    cdef double[:] y

    # Copy parameters into C locals once, outside the hot loop.
    if n_params == 3:
        a, b, c, d = params[0], params[1], 0.0, params[2]
    elif n_params == 4:
        a, b, c, d = params[0], params[1], params[2], params[3]
    else:
        raise ValueError("params must have 3 or 4 elements, got %d" % n_params)

    # Typed memoryviews for fast element access inside the loop.
    x = arr
    y = np.empty(shape, dtype=np.double)

    for k in range(shape):
        y[k] = a * sin(b * x[k] + c) + d

    return np.asarray(y)  # Convert back to a NumPy array for compatibility



In file included from /Users/4cd87a/.cache/ipython/cython/_cython_magic_4b49f573179c0a6aa7c6d8a991e8f366.c:778:
In file included from /Users/4cd87a/anaconda3/envs/phd-main/lib/python3.10/site-packages/numpy/core/include/numpy/arrayobject.h:5:
In file included from /Users/4cd87a/anaconda3/envs/phd-main/lib/python3.10/site-packages/numpy/core/include/numpy/ndarrayobject.h:12:
In file included from /Users/4cd87a/anaconda3/envs/phd-main/lib/python3.10/site-packages/numpy/core/include/numpy/ndarraytypes.h:1940:
 ^
                    CYTHON_FALLTHROUGH;
                    ^
/Users/4cd87a/.cache/ipython/cython/_cython_magic_4b49f573179c0a6aa7c6d8a991e8f366.c:370:34: note: expanded from macro 'CYTHON_FALLTHROUGH'
      #define CYTHON_FALLTHROUGH __attribute__((fallthrough))
                                 ^
                    CYTHON_FALLTHROUGH;
                    ^
/Users/4cd87a/.cache/ipython/cython/_cython_magic_4b49f573179c0a6aa7c6d8a991e8f366.c:370:34: note: expanded from macro 'CYTH

In [9]:
x = np.linspace(0, 10, 1_000_000).astype(np.float64)
x0 = np.array([5, 2, 6]).astype(np.float64)

sin_func(x, x0)

%timeit sin_func(x, x0)

3.81 ms ± 196 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
