In [1]:
# details of the processor
!grep -m 1 "model name" /proc/cpuinfo

model name	: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz


In [2]:
import numpy as np

np.__version__

'1.18.1'

# Writing Fast Python



In [3]:
def single_gaussian(x, A, x0, w):
    return A * np.exp(-(x - x0)**2. / (2. * w**2.))

In [4]:
x = np.linspace(-10., 10., 300000)

In [5]:
%timeit single_gaussian(x, 1., 0., 1.)

3.4 ms ± 36.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
import numba

numba.__version__

'0.48.0'

In [7]:
jit_single_gaussian = numba.jit(single_gaussian, nopython=True, fastmath=True)

In [8]:
%timeit jit_single_gaussian(x, 1., 0., 1.)

355 µs ± 5.58 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Calculate multiple Gaussians now

In [9]:
def first_multiple_gaussians(x: np.ndarray, A: np.ndarray, x0: np.ndarray, w: np.ndarray):
    Y = np.zeros((x.size, len(A)))
    ngaussians = len(A)
    # loop over each Gaussian
    for i in range(ngaussians):
        # Use the previously jit'd function to compute gaussians
        Y[:,i] = jit_single_gaussian(x, A[i], x0[i], w[i])
    return Y

In [10]:
ngaussians = 100

A = np.random.normal(5., 2., ngaussians)
x0 = np.random.rand(ngaussians)
w = np.random.rand(ngaussians)  # already non-negative, so no abs() needed

In [11]:
%timeit first_multiple_gaussians(x, A, x0, w)

392 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
jit_first_multiple_gaussians = numba.jit(first_multiple_gaussians, nopython=True, fastmath=True)

In [13]:
%timeit jit_first_multiple_gaussians(x, A, x0, w)

412 ms ± 5.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Parallelize for loop

A simple way of improving the performance is to make the `for` loop parallelizable: different iterations of the loop can run concurrently instead of one after another. The `jit` functionality in Numba accepts a kwarg, `parallel=True`, which automatically parallelizes eligible parts of the code. In particular, we replace Python's built-in `range` with `numba.prange`, which tells Numba that the iterations of the `for` loop are independent and can be distributed across threads.

The rest of the code is identical!

In [14]:
@numba.njit(fastmath=True, parallel=True, nogil=True)
def second_multiple_gaussians(x: np.ndarray, A: np.ndarray, x0: np.ndarray, w: np.ndarray):
    Y = np.zeros((x.size, len(A)))
    ngaussians = len(A)
    # Replace range generator with numba.prange to parallelize for loop
    for i in numba.prange(ngaussians):
        # Use the previously jit'd function to compute gaussians
        Y[:,i] = jit_single_gaussian(x, A[i], x0[i], w[i])
    return Y

In [15]:
%timeit second_multiple_gaussians(x, A, x0, w)

261 ms ± 6.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Being more clever about vectorization

In the previous implementations, we used a naive `for` loop to iterate over each of our Gaussians. By taking advantage of vectorization, however, we can potentially obtain a significant improvement in performance: in a `for` loop, NumPy completes a computation and returns to Python on every iteration, whereas vectorization hands one big task to NumPy, which runs to completion without going back to Python. The back-and-forth involves overhead, which adds up as the loop gets longer.

In the implementation below, we're going to use pure NumPy, without any JIT compilation. Instead of looping and computing the amplitude of each Gaussian per iteration, we're going to compute them all at once! To obtain our 2D array of values, we take the `x` array, reshape it into a single row, and repeat that row so that every row corresponds to one Gaussian. We then reshape the Gaussian parameters into 2D arrays with a single column, so that each row holds the parameters for one Gaussian. Taking advantage of element-wise broadcasting, we sweep across `x` and transform each element, step by step, into the corresponding Gaussian amplitude.
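To make the broadcasting step concrete, here is a minimal sketch on toy-sized arrays (the sizes and parameter values are made up for illustration and are not the benchmark inputs). It relies on NumPy broadcasting a `(1, npoints)` row against `(ngaussians, 1)` columns, and checks the result against a plain loop. Unlike the notebook's function, it does not materialize the repeated `x` array, so it cannot work in place; it only illustrates the shapes involved.

```python
import numpy as np

# Small illustrative inputs (hypothetical values, not the benchmark sizes)
x = np.linspace(-10.0, 10.0, 7)          # shape (7,)
A = np.array([1.0, 2.0, 3.0])            # one entry per Gaussian
x0 = np.array([-1.0, 0.0, 1.0])
w = np.array([0.5, 1.0, 2.0])

# Row vector of x values against column vectors of parameters:
# shapes (1, 7) and (3, 1) broadcast to (3, 7)
X = x.reshape(1, -1)
A_col, x0_col, w_col = A.reshape(-1, 1), x0.reshape(-1, 1), w.reshape(-1, 1)
Y_broadcast = A_col * np.exp(-(X - x0_col) ** 2 / (2.0 * w_col ** 2))

# Reference: one Gaussian per loop iteration, stacked as rows
Y_loop = np.stack([
    A[i] * np.exp(-(x - x0[i]) ** 2 / (2.0 * w[i] ** 2))
    for i in range(len(A))
])

print(Y_broadcast.shape)                  # (3, 7)
print(np.allclose(Y_broadcast, Y_loop))   # True
```

Note that each row of the result is one Gaussian, which is the transpose of the layout produced by the earlier loop-based functions (one Gaussian per column).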

A big part of this is the ability to perform all the operations in place: as you'll see in the code below, we take extra care to avoid allocating new arrays, because NumPy will otherwise create a fresh array for the result of each operation, which adds overhead. For example:

```Y = np.square(Y)```

will cause NumPy to allocate a new array for the result, even though logically (to us) we are squaring the array and assigning it back to itself. Core NumPy math functions have an `out` keyword that writes the result into a pre-allocated array, so `Y = np.square(Y, out=Y)` tells NumPy there is no need to allocate anything new: just write the outputs back into `Y`!
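A quick way to convince yourself of the difference is to check object identity and memory sharing. This is a small sketch on a toy array, using only `np.square` and `np.shares_memory`:

```python
import numpy as np

Y = np.arange(6, dtype=np.float64).reshape(2, 3)

# Without `out`, the result lives in a freshly allocated array
squared_new = np.square(Y)
print(squared_new is Y)                    # False
print(np.shares_memory(squared_new, Y))    # False

# With `out=Y`, the result is written into Y's existing buffer,
# and the returned object is Y itself
squared_inplace = np.square(Y, out=Y)
print(squared_inplace is Y)                # True
```

The augmented assignment operators used in the function below (`-=`, `/=`, `*=`) also operate in place on NumPy arrays, which is why they need no `out` argument.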

In [16]:
def vectorized_multiple_gaussians(x: np.ndarray, A: np.ndarray, x0: np.ndarray, w: np.ndarray):
    # get the expected number of Gaussians
    ngaussians = len(A)
    # first reshape the 1D x array into a one row 2D array, then
    # repeat the row `ngaussian` number of times along rows
    Y = np.repeat(x.reshape(1, -1), ngaussians, axis=0)
    # reshape the parameters into a column vector, so that the number
    # of rows match the number of Gaussians
    A, x0, w = A.reshape(-1, 1), x0.reshape(-1, 1), w.reshape(-1, 1)
    # start performing in-place operations, transforming the x values
    # into the Gaussian amplitude corresponding to each row
    Y -= x0
    Y = np.square(Y, out=Y)
    Y /= (2 * w**2)
    Y = np.negative(Y, out=Y)
    Y = np.exp(Y, out=Y)
    Y *= A
    return Y

In [17]:
%timeit vectorized_multiple_gaussians(x, A, x0, w)

317 ms ± 9.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Our vectorized multiple Gaussian implementation gets us close to the parallelized Numba JIT'd version (317 ms versus 261 ms), even without compilation or parallelization! That means we could potentially beat the previous implementation if we add a touch of JIT. Unfortunately, it seems that the Numba implementations of NumPy functions do not support the `out` keyword, so there may be some overhead. Another problem is that Numba's `repeat` does not support an `axis` argument, so we have to split the function in two and only JIT the actual computation.

In [18]:
@numba.njit(fastmath=True, nogil=True, parallel=True)
def jit_vec_multi_gaussian(Y: np.ndarray, A: np.ndarray, x0: np.ndarray, w: np.ndarray):
    A, x0, w = A.reshape(-1, 1), x0.reshape(-1, 1), w.reshape(-1, 1)
    Y -= x0
    Y *= Y
    Y /= (2 * w**2)
    Y = np.negative(Y)
    Y = np.exp(Y)
    Y *= A
    return Y

def jit_vec_wrapper(x: np.ndarray, A: np.ndarray, x0: np.ndarray, w: np.ndarray):
    ngaussians = len(A)
    Y = np.repeat(x.reshape(1, -1), ngaussians, axis=0)
    Y = jit_vec_multi_gaussian(Y, A, x0, w)
    return Y

This roughly halves our computation time, as the timing below shows! If you take the JIT flags out one at a time, you'll find that the biggest difference comes from `parallel=True`: our arrays are big enough that the benefit of parallelizing the computation on a single array outweighs the overhead associated with parallelization.

In [19]:
%timeit jit_vec_wrapper(x, A, x0, w)

164 ms ± 9.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Using Jax JIT compilation

Jax is a relatively new library that provides a different kind of JIT compilation from Numba: instead of relying on generating LLVM code, Jax uses XLA (Accelerated Linear Algebra, an optimizing compiler for array computations). Since this type of optimization is more or less entirely orthogonal, we will get different kinds of speed-ups here. One of the more notable differences is that `jp.repeat` supports an `axis` argument, meaning we can enclose the entire operation in a single function. The rest of the code is more or less the same.

In [20]:
import jax
import jax.numpy as jp

jax.__version__

'0.1.62'

In [21]:
def jax_vectorized_multiple_gaussians(x: np.ndarray, A: np.ndarray, x0: np.ndarray, w: np.ndarray):
    ngaussians = len(A)
    Y = jp.repeat(x.reshape(1, -1), ngaussians, axis=0)
    A, x0, w = A.reshape(-1, 1), x0.reshape(-1, 1), w.reshape(-1, 1)
    Y -= x0
    Y = jp.square(Y)
    Y /= (2 * w**2)
    Y = jp.exp(-Y)
    Y *= A
    return Y

In [22]:
# generate a JIT'd version of the function; `jit` is exposed by the
# top-level `jax` package, not by `jax.numpy`
jax_jit_vec = jax.jit(jax_vectorized_multiple_gaussians)

In [23]:
%timeit jax_jit_vec(x, A, x0, w)



15.2 ms ± 596 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


This is an order of magnitude faster! I'm not 100% sure what XLA's compilation is doing better than Numba's LLVM-based compilation in this case, so when you need to write code for speed I would recommend trying both methods and experimenting.
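Two things are worth checking before reading too much into a comparison like this. Jax uses 32-bit floats by default, whereas the NumPy and Numba versions here work in 64-bit. Jax also dispatches work asynchronously, so a timing that does not wait for the result may under-report the real cost. The sketch below is one way to time a jitted function more carefully, calling `block_until_ready()` on the output and separating the first (compiling) call from later ones. The function `gaussians`, the array sizes, and the use of `time.perf_counter` are illustrative assumptions, not the notebook's benchmark.

```python
import time

import numpy as np
import jax
import jax.numpy as jp


def gaussians(x, A, x0, w):
    # Broadcast a (1, npoints) row against (ngaussians, 1) columns
    X = x.reshape(1, -1)
    A, x0, w = A.reshape(-1, 1), x0.reshape(-1, 1), w.reshape(-1, 1)
    return A * jp.exp(-(X - x0) ** 2 / (2.0 * w ** 2))


jit_gaussians = jax.jit(gaussians)

# Small illustrative inputs (hypothetical sizes, smaller than the benchmark)
rng = np.random.default_rng(0)
x = np.linspace(-10.0, 10.0, 1000)
A = rng.normal(5.0, 2.0, 10)
x0 = rng.random(10)
w = rng.random(10) + 0.1  # keep the widths away from zero

# The first call includes tracing and XLA compilation, so time it separately
start = time.perf_counter()
Y = jit_gaussians(x, A, x0, w).block_until_ready()
first_call = time.perf_counter() - start

# Later calls reuse the compiled code; block_until_ready() waits for the
# computation to finish instead of returning as soon as it is dispatched
start = time.perf_counter()
Y = jit_gaussians(x, A, x0, w).block_until_ready()
later_call = time.perf_counter() - start

# Typically float32 unless 64-bit mode has been enabled in the Jax config
print(Y.dtype)
print(f"first call: {first_call:.4f} s, later call: {later_call:.6f} s")
```

Inside a notebook, the equivalent would be `%timeit jit_gaussians(x, A, x0, w).block_until_ready()` after one warm-up call.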