In [None]:
import numba
import gsw
import numpy as np
import cffi

Start with an example from older numba docs:
http://numba.pydata.org/numba-doc/0.12.1/interface_c.html
I haven't found anything similar in the current docs:
https://numba.readthedocs.io/en/stable/reference/pysupported.html?highlight=cffi

In [None]:
# Example from older numba docs.
from numba import jit
from cffi import FFI

ffi = FFI()
ffi.cdef('double sin(double x);')

# loads the entire C namespace
C = ffi.dlopen(None)
c_sin = C.sin

@jit(nopython=True)
def cffi_sin_example(x):
    return c_sin(x)

cffi_sin_example(2.2)

Modify the example for a single function from gsw.

In [None]:
ffi = cffi.FFI()
ffi.cdef("""
     double gsw_rho(double s, double t, double p);
 """)
C = ffi.dlopen("libgswteos-10.so")
s, t, p = 35, 25, 1600
sf, tf, pf = 35.0, 25.0, 1600.0
rho = C.gsw_rho(s, t, p)
print(type(C))
print(rho)

In [None]:
%timeit C.gsw_rho(s, t, p)
%timeit C.gsw_rho(sf, tf, pf)

In [None]:
%timeit gsw.rho(s, t, p)

For scalar arguments, this cffi ABI access is much faster than the gsw module.

Use a loop to work with matching vector arguments:

In [None]:
def manyrho(s, t, p):
    out = np.empty(s.shape, float)
    for i in range(len(s)):
        out[i] = C.gsw_rho(s[i], t[i], p[i])
    return out

In [None]:
s = np.ones((1000,)) * 35
t = np.ones_like(s) * 25
p = np.linspace(0, 2500, len(s))

%timeit manyrho(s, t, p)
%timeit gsw.rho(s, t, p)
# 531 µs ± 2.18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
# 49.6 µs ± 790 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


As expected, the gsw module is now much faster. It still has considerable overhead, though: going from 1 to 1000 calculations only raises the time from 21 to 50 µs, so even with 1000 calculations about 40% of the time is overhead.
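Here is a quick check of that estimate. It uses only the timings quoted above and treats the single-value time as a fixed per-call overhead. That is an assumption, since the real overhead need not be constant:

```python
# Timings quoted in the text, in microseconds.
t_one = 21.0       # gsw.rho with a single value
t_thousand = 50.0  # gsw.rho with 1000 values
n = 1000

# Assumption: the single-value time is a fixed overhead paid once per call.
overhead_fraction = t_one / t_thousand
marginal_ns = (t_thousand - t_one) / (n - 1) * 1000.0  # cost of each extra value

print(f"overhead fraction: {overhead_fraction:.2f}")     # overhead fraction: 0.42
print(f"marginal cost: {marginal_ns:.0f} ns per value")  # marginal cost: 29 ns per value
```

By the same arithmetic, the Python loop in `manyrho` costs about 531 µs / 1000 ≈ 0.53 µs per value, roughly 18 times the marginal cost inside gsw.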

Now try using the JIT. It seems the memory allocation has to be done outside, and returning the array is not possible, at least with this simple form.
The argument signature supplied to the decorator is not necessary. Supplying it makes numba compile eagerly, when the decorator runs, rather than on the first call. It also restricts the function to exactly those argument types.

In [None]:
rho = C.gsw_rho  # Necessary: numba needs the cffi function itself as a global, not an attribute lookup on C.
dtype = np.dtype(np.float64)
out = np.empty(s.shape, dtype)

@numba.njit((numba.float64[:], numba.float64[:], numba.float64[:], numba.float64[:]))
def manyrhojit(s, t, p, out):
    for i in range(np.shape(s)[0]):
        out[i] = rho(s[i], t[i], p[i])

In [None]:
dir(C)

In [None]:
manyrhojit(s, t, p, out)

In [None]:
(out == gsw.rho(s, t, p)).all()

In [None]:
%timeit manyrhojit(s, t, p, out)


It's almost twice as fast as using the gsw module, presumably because
there is less overhead; it is not a ufunc.  Try making a ufunc:


In [None]:
@numba.vectorize
def vecrhojit(s, t, p):
    return rho(s, t, p)

In [None]:
r = vecrhojit(s, t, p)

In [None]:
(r == out).all()

In [None]:
%timeit vecrhojit(s, t, p)
%timeit vecrhojit(35, 25, 1600)

Wow! It's still faster than the gsw version, both for vectors and for scalars!
Even with the extra ufunc machinery, it's as fast as the version specialized to
1-D arrays.  Check that broadcasting works:

In [None]:
rr = vecrhojit(s.reshape((2, 500)), t[:500], 10)
print(rr.shape)
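NumPy can predict the output shape of that broadcasting check without calling the ufunc. Below is a minimal sketch, assuming NumPy 1.20 or later, which provides `np.broadcast_shapes`:

```python
import numpy as np

# Argument shapes in the broadcasting check above: s reshaped to (2, 500),
# t[:500] with shape (500,), and the scalar 10, whose shape is ().
out_shape = np.broadcast_shapes((2, 500), (500,), ())
print(out_shape)  # (2, 500)
```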

## Can we import from the gsw-Python dll?

In [None]:
dllname = gsw._gsw_ufuncs.__file__

ffip = cffi.FFI()
ffip.cdef("""
     double gsw_rho(double s, double t, double p);
 """)
Cp = ffip.dlopen(dllname)
s, t, p = 35.0, 25.0, 1600.0
rho = Cp.gsw_rho(s, t, p)
print(type(Cp))
print(rho)

In [None]:
%timeit Cp.gsw_rho(s, t, p)

In [None]:
@numba.njit
def pass_in(func, args):
    return func(*args)

@numba.njit
def pass_in3(func, s, t, p):
    return func(s, t, p)

pass_in(Cp.gsw_rho, (s, t, p))

In [None]:
func = Cp.gsw_rho
%timeit pass_in(Cp.gsw_rho, (s, t, p))
%timeit pass_in(func, (s, t, p))
%timeit pass_in3(func, s, t, p)

In [None]:
func = Cp.gsw_rho  # this name resolution has to be done outside the jitted function
@numba.njit
def no_pass_in(args):
    return func(*args)

%timeit no_pass_in((s, t, p))

In [None]:
@numba.njit
def no_pass_in3(s, t, p):
    return func(s, t, p)
%timeit no_pass_in3(s, t, p)

There is significant overhead in passing and unpacking the tuple instead of
passing the three scalars directly as arguments.

## Go through ctypes instead of cffi?

In [None]:
import ctypes
gswlib = ctypes.cdll.LoadLibrary(dllname)

In [None]:
crho = gswlib.gsw_rho
crho.restype = ctypes.c_double
crho.argtypes = (ctypes.c_double, ctypes.c_double, ctypes.c_double)
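The `restype` and `argtypes` assignments are what make the call work. By default, ctypes assumes a C function returns `int`, and it will not convert Python floats to `double` arguments on its own. Below is a minimal sketch that uses only the standard library, assuming a POSIX system. The C math library's `fma(x, y, z) = x*y + z` stands in for `gsw_rho`:

```python
import ctypes
import ctypes.util

# Assumption: a POSIX system; libm's fma(x, y, z) = x*y + z stands in for
# gsw_rho, which has the same double(double, double, double) signature.
libm = ctypes.CDLL(ctypes.util.find_library("m"))
cfma = libm.fma

# Without these two lines ctypes would assume an int return value and
# would reject Python floats as arguments.
cfma.restype = ctypes.c_double
cfma.argtypes = (ctypes.c_double, ctypes.c_double, ctypes.c_double)

value = cfma(2.0, 3.0, 1.0)
print(value)  # 7.0
```

Because `argtypes` is set, Python ints such as `cfma(2, 3, 1)` are converted to `double` as well.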

In [None]:
crho(s, t, p)

In [None]:
%timeit crho(s, t, p)

In [None]:
@numba.njit
def no_pass_ctypes(s, t, p):
    return crho(s, t, p)

%timeit no_pass_ctypes(s, t, p)

In [None]:
%timeit pass_in3(crho, s, t, p)

The no-pass-in mode is slightly faster with cffi, but the pass-in mode is **much** faster with ctypes than with cffi. Even with ctypes, though, pass-in is still slower than no-pass-in by about a factor of 7.

## Try WAP (Wrapper Address Protocol)

https://docs.python.org/3/library/ctypes.html#loading-dynamic-link-libraries


In [None]:
class Gswlib_rho(numba.types.WrapperAddressProtocol):
    def __wrapper_address__(self):
        return ctypes.cast(gswlib.gsw_rho, ctypes.c_voidp).value
    def signature(self):
        return numba.float64(numba.float64, numba.float64, numba.float64)
    
gswlib_rho = Gswlib_rho()

@numba.njit
def pass_in_WAP(func, s, t, p):
    return func(s, t, p)

@numba.njit
def no_pass_in_WAP(s, t, p):
    return gswlib_rho(s, t, p)

print(pass_in_WAP(gswlib_rho, s, t, p))
%timeit pass_in_WAP(gswlib_rho, s, t, p)
%timeit no_pass_in_WAP(s, t, p)

That is disappointing: it is faster than cffi pass-in but slower
than the unadorned ctypes pass-in (1.55 µs). (The WAP also makes it
slightly slower in no-pass-in mode.)

## Ctypes and njit

In [None]:
from neutral_surfaces._densjmd95 import rho as ndrho
%timeit ndrho(s, t, p)
%timeit crho(s, t, p)
%timeit no_pass_ctypes(s, t, p)

To review the timings above: plain ctypes is slow; wrapping the ctypes call in a jitted function makes it as fast as anything tried here; and the fully jitted JMD95 `rho` is similar in speed to the jitted ctypes call into gsw.

In [None]:
cspecvol = gswlib.gsw_specvol
cspecvol.restype = ctypes.c_double
cspecvol.argtypes = (ctypes.c_double, ctypes.c_double, ctypes.c_double) 

In [None]:
print(crho(s, t, p), cspecvol(s, t, p))

### Test: in no-pass mode, can we change the function?

In [None]:
@numba.njit
def wrap_somefunc(s, t, p):
    return somefunc(s, t, p)

somefunc = cspecvol
print(wrap_somefunc(s, t, p))
%timeit wrap_somefunc(s, t, p)

somefunc = crho
print(wrap_somefunc(s, t, p))
%timeit wrap_somefunc(s, t, p)

**DANGER:** the function that `somefunc` points to is compiled in the first time the wrapper is called. Rebinding the name `somefunc` after that first call has no
effect.

Try putting the switching logic inside the wrapper.

In [None]:
funcs = [crho, cspecvol]  # doesn't work; "reflected list" as global
funcs = (crho, cspecvol)

@numba.njit
def wrap_funclist(ind, s, t, p):
    return funcs[ind](s, t, p)

print(wrap_funclist(0, s, t, p), wrap_funclist(1, s, t, p))
%timeit wrap_funclist(0, s, t, p)
%timeit wrap_funclist(1, s, t, p)

**Bingo!** That looks very promising: the speed is nearly the same as in pure no-pass mode, but we can select an entry from a global tuple at run time. Again, though, the tuple's contents are fixed when the jitting occurs and stay that way for the rest of the run.

https://stackoverflow.com/questions/44131691/how-to-clear-cache-or-force-recompilation-in-numba

https://numba.pydata.org/numba-doc/dev/user/faq.html

Try using the `recompile()` method:

In [None]:
@numba.njit
def wrap_somefunc(s, t, p):
    return somefunc(s, t, p)

somefunc = cspecvol
wrap_somefunc.recompile()
print(wrap_somefunc(s, t, p))
%timeit wrap_somefunc(s, t, p)

somefunc = crho
wrap_somefunc.recompile()
print(wrap_somefunc(s, t, p))
%timeit wrap_somefunc(s, t, p)
print("recompilation time:")
%timeit wrap_somefunc.recompile()

That looks like a reasonable alternative. It could be used either in no-pass mode, as above, or with the method of passing in a tuple of functions, if we wanted to be able to register new functions in a new tuple of options. It requires keeping track of all the functions that need to be recompiled when an input argument changes the selection, with all of this handled in the top-level function. Each such recompilation seems to add at least 22 ms to that top-level function, which is not a big deal. Consolidating the low-level functions into as few as possible will probably speed up both the recompilations and the execution.