# Fast computations in python

Author: Mathurin Massias first.last@gmail.com

Inspired by Jake VanderPlas's blog post: https://jakevdp.github.io/blog/2013/06/15/numba-vs-cython-take-2/

One of the notorious drawbacks of Python is its low speed. Because it is dynamically typed, Python performs all kinds of checks when executing basic operations such as `a + b`; this can harm computation time. The most well-known example of slow Python computation is long for loops involving small computations, as we will observe below.

In this notebook, we use the computation of pairwise distances as a running example to introduce different ways to make Python faster.

## Setup

Consider an observation matrix $X \in \mathbb{R}^{500 \times 10}$ (hence, 500 observations $x_i$ living in dimension 10). 

In many Machine Learning tasks, in particular clustering, the matrix of pairwise distances is needed: $D =(d_{ij})_{1 \leq i,j \leq n}$ with 
$$d_{ij} = ||x_i - x_j||$$ 

Computing it naively costs $\mathcal{O}(n^2 p)$ operations for $n$ observations in dimension $p$ (here $n = 500$ and $p = 10$), which can quickly get high and become the bottleneck in algorithms.
Let us try to compute this matrix in different ways.

In [1]:
# requirements:
# !conda install cython numpy scipy scikit-learn 
# !pip install numba

In [2]:
import numpy as np
np.random.seed(0)
X = np.random.randn(500, 10)

### Naive python

In [3]:
import math

def pairwise_python(X):
    n_samples, n_features = X.shape
    D = np.zeros((n_samples, n_samples))
    for i in range(n_samples):
        for j in range(n_samples):
            tmp = 0
            for k in range(n_features):
                tmp += (X[i, k] - X[j, k]) ** 2
            D[i, j] = math.sqrt(tmp)
    return D

# we could be twice as fast by exploiting the symmetry D[i, j] = D[j, i], but we don't.
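
For reference, the symmetry mentioned in the comment could be exploited as in the sketch below. The name `pairwise_python_symmetric` is hypothetical and not used in the rest of the notebook; the inner loop is unchanged, only the range of `j` differs.

```python
import math
import numpy as np


def pairwise_python_symmetric(X):
    """Pairwise Euclidean distances, visiting each pair (i, j) only once.

    Hypothetical variant of pairwise_python, for illustration only.
    """
    n_samples, n_features = X.shape
    D = np.zeros((n_samples, n_samples))
    for i in range(n_samples):
        # the diagonal stays at 0, so we start at j = i + 1
        for j in range(i + 1, n_samples):
            tmp = 0.0
            for k in range(n_features):
                tmp += (X[i, k] - X[j, k]) ** 2
            D[i, j] = math.sqrt(tmp)
            D[j, i] = D[i, j]  # fill the mirrored entry by symmetry
    return D
```

This roughly halves the number of inner iterations, but the cost remains $\mathcal{O}(n^2 p)$ and the loops are still executed by the Python interpreter.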

For a precise timing (averaged over repetitions), we use the `%timeit` ipython magic as follows.

In [4]:
%timeit pairwise_python(X)

2.53 s ± 128 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


This is, as expected, quite slow, even though we are far from the "Big Data" regime.
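
To put this number in perspective, a back-of-the-envelope computation gives the average time spent per innermost iteration. It only uses the matrix size and the mean timing reported above, which will differ on another machine.

```python
n_samples, n_features = 500, 10
total_time = 2.53  # seconds, mean reported by %timeit in this notebook

# one innermost iteration per (i, j, k) triple
n_inner_iterations = n_samples ** 2 * n_features
time_per_iteration = total_time / n_inner_iterations

print(n_inner_iterations)        # 2500000
print(time_per_iteration * 1e6)  # about 1 microsecond per iteration
```

About one microsecond for a subtraction, a square and an addition is far more than the arithmetic itself requires: most of the time goes into interpreter overhead.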

### Numba

Next, we compile our function Just-In-Time (jit) with `numba`, a package that has seen tremendous development in the Python ecosystem over the last decade, thanks to its ease of use and the impressive speed-ups it provides. 
**If you remember only one way to make Python fast, it should be this one.**

In [5]:
from numba import jit

pairwise_numba = jit(pairwise_python)  # jit acts as a decorator: it takes
# a function as input, and returns another function.

# An alternative (more pythonic) syntax is to put it on top of the function 
# it affects, prefixed with @, as in:
@jit
def a_function():
    # your code here
    return 0

We call the function once so that it gets compiled, i.e. the type of each variable is inferred and machine code is generated for these types. Subsequent calls with arguments of the same types reuse the compiled code; calling the function with arguments of a different type triggers a new compilation.

In [6]:
pairwise_numba(np.zeros([10, 3]));  # compilation happens

In [7]:
%timeit pairwise_numba(X)

2.75 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


We get a speedup of roughly 1000 using only a decorator! And we did not even have to modify our original code. 
The ease of use is numba's main feature.

Going beyond our naive code, numba can compile more complicated functions. For example, it supports code which contains a lot of numpy functions, and more and more are added at each release (numba is still a "young" project). 

Nevertheless, some functions must be modified before applying `@jit` to them. For example, numba does not (currently) support the `axis` keyword of some numpy functions such as `np.mean` (hopefully, it will be supported in the future; remember that `np.dot` was not supported by numba a few years ago!).

In [8]:
# this does not work. IRL it would not make sense to jit this function, as there is 
# no code to optimize (numpy does everything); it is only for pedagogical purposes
@jit(nopython=True)
def numba_issue(X):
    return np.mean(X, axis=1)  # compute the mean of each row X 

numba_issue(X)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function mean at 0x7f4b2c65f4c0>) found for signature:
 
 >>> mean(array(float64, 2d, C), axis=Literal[int](1))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'Numpy_method_redirection.generic': File: numba/core/typing/npydecl.py: Line 348.
    With argument(s): '(array(float64, 2d, C), axis=Literal[int](1))':
   Rejected as the implementation raised a specific error:
     TypingError: numba doesn't support kwarg for mean
  raised from /home/mathurin/miniconda3/lib/python3.8/site-packages/numba/core/typing/npydecl.py:370

During: resolving callee type: Function(<function mean at 0x7f4b2c65f4c0>)
During: typing of call at <ipython-input-8-092ded842e28> (5)


File "<ipython-input-8-092ded842e28>", line 5:
def numba_issue(X):
    return np.mean(X, axis=1)  # compute the mean of each row X 
    ^


Note that in the cell above, we use `jit(nopython=True)` to force an error when jit does not know how to handle some code. An equivalent is to replace the `jit` decorator with `njit` (which is imported as: `from numba import njit`).

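If we had used plain `jit`, we would have obtained a warning, but the code would have been executed as non-compiled Python. This holds for the numba version used here; from numba 0.59 on, `jit` defaults to nopython mode and raises the error as well.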

Oftentimes, one must slightly rewrite the original function to make it numba compatible. In the example above, a fix is as follows.

In [9]:
# one must resort to more 'naive' code (which will still be lightning fast)
@jit
def numba_fixed(X):
    results = np.zeros(X.shape[0])
    for i in range(X.shape[0]):
        results[i] = np.mean(X[i, :])  # avoid using the axis keyword
    return results

numba_fixed(np.random.randn(5, 6))

array([-0.55033805,  0.63196909, -0.6423911 , -0.47853447,  0.33903917])

The lab could stop here, but numba's compilation process has a lot of dependencies (usually shipped with anaconda), which may make it unsuitable for some applications. We also need to check that there are no faster alternatives.

As a first candidate, numpy is usually quite fast at performing vectorized operations on its arrays, and complicated operations can be implemented using broadcasting.

### numpy

In [10]:
def pairwise_numpy(X):
    return np.sqrt(np.sum((X[:, None, :] - X) ** 2, axis=-1))

In [11]:
# check that we did not mess up the computation
np.allclose(pairwise_numpy(X), pairwise_numba(X))

True
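
To see what the one-liner does, it may help to look at the shapes involved. The sketch below regenerates a matrix of the same size so that it runs on its own; the names `A` and `diff` are introduced only for illustration.

```python
import numpy as np

X = np.random.RandomState(0).randn(500, 10)

A = X[:, None, :]  # shape (500, 1, 10): a new axis of length 1 is inserted
diff = A - X       # X is broadcast as (1, 500, 10); the result is (500, 500, 10)
print(A.shape, diff.shape)

# diff[i, j, :] holds x_i - x_j, so summing the squares over the last axis
# gives the squared distance between observations i and j
assert np.allclose(diff[3, 7], X[3] - X[7])

# the intermediate array stores n * n * p float64 values (8 bytes each)
print(diff.nbytes / 1e6, "MB")  # 500 * 500 * 10 * 8 bytes = 20.0 MB
```

The temporary array is `n_features` times larger than the final distance matrix, which is where the memory cost of this approach comes from.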

In [12]:
%timeit pairwise_numpy(X)

15.8 ms ± 514 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


The syntax relies on broadcasting, which is a powerful but not "obvious" tool, and we are about 6 times slower than numba, not to mention the large memory requirement.
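
One possible way to reduce the memory footprint, which is not part of the original benchmark, is to expand the squared norm as $||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 \langle x_i, x_j \rangle$. The sketch below uses a hypothetical function name, `pairwise_numpy_gram`, and has not been timed here.

```python
import numpy as np


def pairwise_numpy_gram(X):
    """Pairwise Euclidean distances via the expansion of the squared norm.

    Hypothetical alternative to pairwise_numpy, for illustration only.
    """
    sq_norms = np.sum(X ** 2, axis=1)  # shape (n,)
    gram = X @ X.T                     # shape (n, n), matrix of inner products
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2 * gram
    # rounding errors can produce tiny negative values: clip before the sqrt
    np.maximum(sq_dists, 0, out=sq_dists)
    return np.sqrt(sq_dists)
```

Only arrays of shape `(n, n)` are allocated, instead of `(n, n, p)`. The price is a loss of accuracy for points that are very close to each other, because the subtraction cancels most significant digits; results then only match the direct formula up to a small tolerance.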

### sklearn, scipy

Computing distances is a frequent task in scientific python, hence one can expect to find functions accomplishing this task in popular packages, which have done the optimization for us. 
We benchmark the implementations of `sklearn` and `scipy`, which rely on optimized compiled routines.

In [None]:
from scipy.spatial.distance import cdist
%timeit cdist(X, X)

It is faster than numba (measurements may vary on different architectures and for different sizes of $X$), but numba is more flexible: we can implement our own functions. scipy has the benefit of being reliable: it is tested and documented, contrary to code we may write ourselves.

sklearn also has a way to compute distances:

In [None]:
from sklearn.metrics import pairwise_distances
%timeit pairwise_distances(X)

The time taken is in the same range of values, although on the higher side this time. Differences between implementations may include: the kinds of distance supported (Euclidean distance, Manhattan distance, etc.), support for sparse data $X$, support for centering and normalization, etc.

With sklearn and scipy, we get fast code, but we cannot customize it. sklearn relies on Cython, a hybrid of C and Python, which makes it possible to write code that is compiled to C or C++ while retaining the syntax of Python and still allowing occasional calls to Python. We can even call the extremely fast BLAS and LAPACK routines (used to perform vector/vector, matrix/vector and matrix/matrix computations, as well as typical linear algebra operations such as the SVD).

Personal comment: Cython is nice and powerful, but it has a **high entry cost**, and is hard to debug. 

In [None]:
# this is an ipython magic to compile cython code inside the notebook.
# Cython compilation is more complicated outside of notebook, and usually
# delegated to the setup.py of the python package using it.
%load_ext cython

In [None]:
%%cython
# note: cimport, not import
cimport cython
import numpy as np
from libc.math cimport sqrt

# note the particular syntax, in particular variable
# types are declared
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[:, :] pairwise_cython(double[:, :] X):
    cdef int n_samples = X.shape[0]
    cdef int n_features = X.shape[1]
    cdef double[:, :] D = np.empty((n_samples, n_samples))
    cdef int i, j, k
    cdef double tmp
    with nogil:  # make compilation break if a "call to python" is made
        for i in range(n_samples):  # we can still use range
            for j in range(n_samples):
                tmp = 0.
                for k in range(n_features):
                    tmp += (X[i, k] - X[j, k]) ** 2
                D[i, j] = sqrt(tmp)
    return np.asarray(D)

In [None]:
%timeit pairwise_cython(X)

It's complicated to write, but it's fast!

To summarize:
    
- numba: flexible, supports lots of functions, fast. Needs a compiler (shipped with conda)
- numpy and its vectorized computations may be enough for simple computations
- look for optimized functions in existing packages (scipy, sklearn): usually fast, but you may not find what you need
- cython: super powerful. Hard to write, hard to debug, hard to maintain. Easy to ship! And the price of compilation is paid offline, whereas numba, like Julia, has an overhead at the first call.