# Speeding up Python with ctypes

**Brief summary:**
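- In this benchmark, the ctypes version (4.88 ms) is about 60 times faster than pure Python (295 ms), but slower than *numba* (151 µs) and *jax* (35.7 µs).
    - This could be due to implicit type conversion at the ctypes call boundary,
    - Or to weaker optimization of the C code by GCC (the compile command below passes no `-O` flag, so GCC uses its default `-O0`).
- To consider: including C files in a library might complicate installation.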

In [40]:
import ctypes
import os
import numpy as np
from numba import njit
from jax import jit as jjit

## Conditional sum

### ctypes

File: mysum.c 


```c
long long mysum(int n, int* array) {
    if (n <= 1) {
        return 0;
    }
    long long res = 0;
    for (int i = 0; i < n-1; ++i) {
        if (array[i] > array[i+1]) {
            res += array[i];
        }
    }
    return res;
}
```


To compile a C shared library:
```
gcc -fPIC -Wall -Werror -shared -o mysum.so mysum.c
```
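Since the command above passes no optimization flag, one thing to try is rebuilding with `-O2`. The sketch below assembles that command from Python; the helper names `build_gcc_command` and `compile_shared_library` are hypothetical, and the choice of `-O2` is an assumption whose effect on the timings is not measured here.

```python
import shutil
import subprocess


def build_gcc_command(source, output, opt_level="-O2"):
    """Return the gcc argument list for building a shared library."""
    return ["gcc", opt_level, "-fPIC", "-Wall", "-Werror", "-shared",
            "-o", output, source]


def compile_shared_library(source, output, opt_level="-O2"):
    """Run gcc if it is on PATH; return False when gcc is not available."""
    if shutil.which("gcc") is None:
        return False
    # check=True raises CalledProcessError if compilation fails
    subprocess.run(build_gcc_command(source, output, opt_level), check=True)
    return True


# Show the command that would be run for the file in this notebook
print(" ".join(build_gcc_command("mysum.c", "mysum.so")))
```

Calling `compile_shared_library("mysum.c", "mysum.so")` would then produce the library that is loaded in the next cell.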

In [32]:
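# Load the compiled shared library from the current directory
mysum = ctypes.cdll.LoadLibrary(os.path.abspath('mysum.so'))

# Declare the C signature: long long mysum(int n, int* array)
mysum.mysum.restype = ctypes.c_longlong
mysum.mysum.argtypes = [ctypes.c_int,
                        np.ctypeslib.ndpointer(dtype=np.int32)]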

arr = np.array([1,2,3,2,3], dtype=np.int32)
print(mysum.mysum(len(arr), arr))

3
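The `ndpointer(dtype=np.int32)` entry in `argtypes` also acts as a guard: ctypes calls its `from_param` method on every call, and arrays with another dtype are rejected with a `TypeError`. The sketch below shows that check in isolation, without loading the shared library; the helper name `is_accepted` is hypothetical.

```python
import numpy as np

# Same pointer type as the one placed in argtypes
int32_ptr = np.ctypeslib.ndpointer(dtype=np.int32)


def is_accepted(array):
    """Return True if `array` passes the ndpointer check done by ctypes."""
    try:
        int32_ptr.from_param(array)
        return True
    except TypeError:
        return False


wide = np.arange(5, dtype=np.int64)
print(is_accepted(wide))                   # False: dtype does not match
print(is_accepted(wide.astype(np.int32)))  # True: converted explicitly
```

This is why an array with a wider integer dtype needs an explicit `.astype(np.int32)` before it is handed to the C function.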


In [38]:
arr = np.random.choice(5, size=1000000).astype(np.int32)
print(mysum.mysum(len(arr), arr))

1199522


In [46]:
%timeit _ = mysum.mysum(len(arr), arr)

4.88 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### python

In [44]:
def mysumfoo(array):
    n = len(array)
    if n <= 1:
        return 0
    res = 0
    for i in range(n-1):
        if array[i] > array[i+1]:
            res += array[i]
    return res

print(mysumfoo(arr))
%timeit _ = mysumfoo(arr)

1199522
295 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
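As a cross-check for the loop-based versions, the same conditional sum can be written as a vectorized NumPy expression. This is only a sketch: the name `conditional_sum_np` is hypothetical, and this variant is not part of the timings reported here.

```python
import numpy as np


def conditional_sum_np(array):
    """Sum array[i] over all i where array[i] > array[i+1]."""
    left, right = array[:-1], array[1:]
    # Accumulate in int64 to mirror the `long long` result of the C version
    return int(left[left > right].sum(dtype=np.int64))


print(conditional_sum_np(np.array([1, 2, 3, 2, 3], dtype=np.int32)))  # 3
```

Arrays with fewer than two elements yield empty slices, so the function returns 0 for them without a special case.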


### numba

In [45]:
@njit
def mysumfoo(array):
    n = len(array)
    if n <= 1:
        return 0
    res = 0
    for i in range(n-1):
        if array[i] > array[i+1]:
            res += array[i]
    return res

# The first call triggers JIT compilation, so it is kept out of the timing
mysumfoo(arr)
%timeit _ = mysumfoo(arr)

151 µs ± 1.86 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### jax

In [52]:
def mysumfoo(array):
    n = len(array)
    if n <= 1:
        return 0
    res = 0
    for i in range(n-1):
        if array[i] > array[i+1]:
            res += array[i]
    return res

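# NOTE: static_argnums=(0,) marks `array` as a compile-time constant, so the
# Python loop runs once during tracing and the timing below likely measures
# little more than a cached dispatch. Recent JAX versions may also reject
# non-hashable static arguments such as NumPy arrays.
mysumfoo = jjit(mysumfoo, static_argnums=(0,))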

print(mysumfoo(arr))
%timeit _ = mysumfoo(arr)

1199522
35.7 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
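With `static_argnums=(0,)`, the array is treated as a static argument, so the call above probably does not run the summation on each invocation. A variant that traces the array as a regular input could look like the sketch below. The name `conditional_sum_jax` is hypothetical, and the loop is replaced by `jnp.where`, which is a different formulation from the one timed above. With JAX's default 32-bit mode, the sum is accumulated in int32.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit


@jit
def conditional_sum_jax(array):
    """Sum array[i] over all i where array[i] > array[i+1]."""
    left, right = array[:-1], array[1:]
    # jnp.where keeps the comparison inside the traced computation,
    # unlike a Python `if` on array values
    return jnp.sum(jnp.where(left > right, left, 0))


small = np.array([1, 2, 3, 2, 3], dtype=np.int32)
print(int(conditional_sum_jax(small)))  # 3
```

To time it, something like `%timeit conditional_sum_jax(arr).block_until_ready()` would wait for the result instead of only measuring asynchronous dispatch. No timing for this variant is reported here.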
