# Numba Tests

In [14]:
import numpy as np

## 1. Apply jit only if numba is available.

See http://stackoverflow.com/questions/29587317/python-numba-jit-conditional-and-recursive-stack-use

### Scenario 1:
Use a decorator (`try_jit`) built on a `doublewrap` helper, so that it can be applied with or without arguments.

In [111]:
def rmse(y, yhat):
    """ Compute the Root Mean Squared Error (RMSE) between two arrays

    Args:
        y (np.ndarray): observed values
        yhat (np.ndarray): predicted values

    Returns:
        float: square root of the mean of the squared residuals
    """
    residuals = y - yhat
    mean_sq_err = (residuals ** 2).mean()
    return mean_sq_err ** 0.5

In [127]:
# Record whether Numba can be imported so later code can fall back gracefully
try:
    import numba as nb
except ImportError:
    has_numba = False
else:
    has_numba = True
    
from functools import wraps


def doublewrap(f):
    """ Allows decorators to be called with/without args/kwargs

    Allows:
        @decorator
        @decorator()
        @decorator(args, kwargs=values)
        @decorator(kwargs=values)

    The wrapped decorator ``f`` must take the function being decorated as
    its first positional argument; any decorator arguments are forwarded
    after it.

    Note:
        A call with exactly one positional callable and no keyword
        arguments is always treated as the bare ``@decorator`` form, so a
        decorator whose only argument is itself a callable cannot be
        expressed through this helper.

    Modification of answer from user "bj0" on StackOverflow:
    http://stackoverflow.com/a/14412901

    Args:
        f (callable): decorator implementation ``f(func, *args, **kwargs)``

    Returns:
        callable: decorator usable in any of the forms listed above
    """
    @wraps(f)
    def new_dec(*args, **kwargs):
        if len(args) == 1 and not kwargs and callable(args[0]):
            # called as @decorator
            return f(args[0])
        if not args and not kwargs:
            # called as @decorator()
            return f
        # called as @decorator(*args, **kwargs); this also covers the
        # keyword-only form @decorator(key=value), which must not be
        # confused with the empty call above or the options are lost
        return lambda realf: f(realf, *args, **kwargs)

    return new_dec


@doublewrap
def try_jit(f, *args, **kwargs):
    """ Apply numba.jit to function ``f`` if Numba is available

    Args:
        f (callable): function to (optionally) compile
        *args: positional arguments forwarded to ``numba.jit``
            (e.g., a signature string)
        **kwargs: keyword options forwarded to ``numba.jit``
            (e.g., ``nopython=True``)

    Returns:
        callable: the jitted function when ``has_numba`` is true, otherwise
            ``f`` unchanged
    """
    if not has_numba:
        return f
    # Build the jit decorator from the arguments given to ``try_jit`` and
    # apply it to ``f``. An inner ``wrap(*args, **kwargs)`` helper would
    # shadow these names and silently drop the signature/options.
    return nb.jit(*args, **kwargs)(f)

In [106]:
@try_jit
def rmse_jit(y, yhat):
    """ Compute the Root Mean Squared Error (RMSE) between two arrays

    Args:
        y (np.ndarray): observed values
        yhat (np.ndarray): predicted values

    Returns:
        float: square root of the mean of the squared residuals
    """
    residuals = y - yhat
    mean_sq_err = (residuals ** 2).mean()
    return mean_sq_err ** 0.5

rmse_jit(y, yhat)

0.9756407207105383

In [107]:
@try_jit()
def rmse_jit(y, yhat):
    """ Compute the Root Mean Squared Error (RMSE) between two arrays

    Args:
        y (np.ndarray): observed values
        yhat (np.ndarray): predicted values

    Returns:
        float: square root of the mean of the squared residuals
    """
    squared_err = (y - yhat) ** 2
    return squared_err.mean() ** 0.5

rmse_jit(y, yhat)

0.9756407207105383

In [108]:
# The signature must match the data: ``y``/``yhat`` are float64 1-D arrays
# and the function returns a scalar, not an array.
@try_jit("float64(float64[:], float64[:])", nopython=True)
def rmse_jit(y, yhat):
    """ Calculate and return Root Mean Squared Error (RMSE)

    Args:
        y (np.ndarray): known values (1-D, float64)
        yhat (np.ndarray): predicted values (1-D, float64)

    Returns:
        float: Root Mean Squared Error
    """
    return ((y - yhat) ** 2).mean() ** 0.5

rmse_jit(y, yhat)

0.9756407207105383

In [109]:
# Pretend Numba is unavailable so the fallback path of ``try_jit`` is used
has_numba = False
@doublewrap
def try_jit(f, *args, **kwargs):
    """ Apply numba.jit to function ``f`` if Numba is available

    Args:
        f (callable): function to (optionally) compile
        *args: positional arguments forwarded to ``numba.jit``
            (e.g., a signature string)
        **kwargs: keyword options forwarded to ``numba.jit``
            (e.g., ``nopython=True``)

    Returns:
        callable: the jitted function when ``has_numba`` is true, otherwise
            ``f`` unchanged
    """
    if not has_numba:
        return f
    # Build the jit decorator from the arguments given to ``try_jit`` and
    # apply it to ``f``. An inner ``wrap(*args, **kwargs)`` helper would
    # shadow these names and silently drop the signature/options.
    return nb.jit(*args, **kwargs)(f)

# Signature matches the float64 1-D inputs and scalar return; it is unused
# while ``has_numba`` is False but stays valid if Numba is enabled.
@try_jit("float64(float64[:], float64[:])", nopython=True)
def rmse_jit_nojit(y, yhat):
    """ Calculate and return Root Mean Squared Error (RMSE)

    Args:
        y (np.ndarray): known values
        yhat (np.ndarray): predicted values

    Returns:
        float: Root Mean Squared Error
    """
    return ((y - yhat) ** 2).mean() ** 0.5

# Exercise the function defined in this cell (the fallback variant), not
# the previously jitted ``rmse_jit``
rmse_jit_nojit(y, yhat)

0.9756407207105383

In [None]:
@nb.jit()
def rmse_jit_basic(y, yhat):
    """ Compute the Root Mean Squared Error (RMSE) between two arrays

    Decorated directly with ``nb.jit()`` as the baseline for comparison.

    Args:
        y (np.ndarray): observed values
        yhat (np.ndarray): predicted values

    Returns:
        float: square root of the mean of the squared residuals
    """
    residuals = y - yhat
    mean_sq_err = (residuals ** 2).mean()
    return mean_sq_err ** 0.5

rmse_jit_basic(y, yhat)

Test timings!

In [15]:
# Use a seeded generator so the sample data is reproducible across kernel
# restarts
rng = np.random.RandomState(0)
y = rng.rand(100)
yhat = y + rng.standard_normal(100)

In [153]:
# Time the plain function against the nb.jit baseline, the try_jit version,
# and the try_jit version defined while has_numba was False
%timeit rmse(y, yhat)
%timeit rmse_jit_basic(y, yhat)
%timeit rmse_jit(y, yhat)
%timeit rmse_jit_nojit(y, yhat)

The slowest run took 8.50 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 5.42 µs per loop
The slowest run took 9.56 times longer than the fastest. This could mean that an intermediate result is being cached 
1000000 loops, best of 3: 623 ns per loop
The slowest run took 9.53 times longer than the fastest. This could mean that an intermediate result is being cached 
1000000 loops, best of 3: 625 ns per loop
The slowest run took 8.72 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 5.42 µs per loop


**Conclusions**: I'm really surprised it can handle all of the various decorator declarations:
``` python
@try_jit
@try_jit()
@try_jit("signature", kwargs=values)
```

Since the decorator layers run only once, when the function is defined, they add no call-time overhead: `rmse_jit` (625 ns) matches plain `nb.jit` (623 ns), and the no-Numba fallback (5.42 µs) matches the undecorated `rmse` (5.42 µs).

### Scenario 2:
Use a much simpler approach: a plain helper function that is called on an existing function instead of being used as a decorator.

In [151]:
# Record whether Numba can be imported so later code can fall back gracefully
try:
    import numba as nb
except ImportError:
    has_numba = False
else:
    has_numba = True
    

def try_jit2(f, *args, **kwargs):
    """ Return ``f`` compiled with numba.jit if Numba is available

    Args:
        f (callable): function to (optionally) compile
        *args: positional arguments forwarded to ``nb.jit``
        **kwargs: keyword arguments forwarded to ``nb.jit``

    Returns:
        callable: result of ``nb.jit(*args, **kwargs)(f)`` when
            ``has_numba`` is true, otherwise ``f`` itself
    """
    if not has_numba:
        return f
    jit_decorator = nb.jit(*args, **kwargs)
    return jit_decorator(f)
# Compile the plain ``rmse`` through the helper (returned unchanged if
# ``has_numba`` is False)
rmse_jit2 = try_jit2(rmse, nopython=True)

# Force the no-Numba path for the fallback variant defined next
has_numba = False
def try_jit2_nojit(f, *args, **kwargs):
    """ Same helper as ``try_jit2``, defined for the no-Numba comparison

    Args:
        f (callable): function to (optionally) compile
        *args: positional arguments forwarded to ``nb.jit``
        **kwargs: keyword arguments forwarded to ``nb.jit``

    Returns:
        callable: jitted ``f`` when ``has_numba`` is true, otherwise ``f``
    """
    return nb.jit(*args, **kwargs)(f) if has_numba else f
# With ``has_numba`` set to False above, this is just ``rmse`` itself
rmse_jit2_nojit = try_jit2_nojit(rmse, nopython=True)

In [152]:
# Time the plain function against the helper-based jitted and fallback versions
%timeit rmse(y, yhat)
%timeit rmse_jit2(y, yhat)
%timeit rmse_jit2_nojit(y, yhat)

The slowest run took 9.59 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 5.42 µs per loop
The slowest run took 92549.24 times longer than the fastest. This could mean that an intermediate result is being cached 
1000000 loops, best of 3: 624 ns per loop
The slowest run took 9.01 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 5.45 µs per loop


### Test conclusions:
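Both approaches give the same timings (about 625 ns jitted versus about 5.4 µs without Numba), so just use the decorator-based `try_jit` from Scenario 1!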

In [155]:
# IPython introspection: show the help for ``nb.njit``
?nb.njit