# Parallel loops through numpy arrays with Dask and Joblib
We look at speeding up loops over numpy arrays. In this example each iteration has to call a third-party library, and that library only accepts a subset of our full array. Because the work happens inside the third-party library, we can't apply tricks like JIT compilation to it.

The scenario here is that we have a 3-dimensional array with dimensions (x,y,time). We will imagine that this is a time series of 2-dimensional maps of ocean salinity. Our third-party library is the seawater library, imported below as `gsw`. This library only accepts 2-dimensional inputs here, so we need to loop through the time dimension and call it once per timestep.
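As a minimal sketch of this loop pattern, the snippet below slices a small (x,y,time) array one timestep at a time and stacks the results back together. The function `process_2d` is a hypothetical stand-in for the library call, not part of the seawater library; it only enforces the 2-dimensional restriction and adds a constant.

```python
import numpy as np

def process_2d(snapshot: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a library call that only accepts 2-D input."""
    if snapshot.ndim != 2:
        raise ValueError("expected a 2-dimensional array")
    return snapshot + 1.0

# Small (x, y, time) array: a 4 x 3 grid with 5 timesteps
cube = np.arange(4 * 3 * 5, dtype=float).reshape(4, 3, 5)

# Slice out one 2-D map per timestep, process it, then stack along the time axis
slices = [process_2d(cube[:, :, t]) for t in range(cube.shape[2])]
result = np.stack(slices, axis=2)

print(result.shape)
```

Stacking along `axis=2` puts the time dimension back in its original position, so the output has the same shape as the input.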

# Libraries
In this example we will use the [Joblib](https://joblib.readthedocs.io/en/latest/) and [Dask](https://docs.dask.org/en/stable/) libraries. In the case of Dask we use `dask.delayed` to parallelise the loop.

In [17]:
import numpy as np
from joblib import Parallel,delayed
import dask

import gsw
%load_ext memory_profiler



# Generate data
We generate the inputs we need for the seawater library. We add arguments to allow us to specify the size of the array.

In [2]:
def generateData(xyLength:int,timesteps:int):
    # Practical salinity: values scattered around 35 with unit-variance noise
    SPTimeseries = 35 + np.random.standard_normal(size=(xyLength,xyLength,timesteps))
    assert SPTimeseries.shape == (xyLength,xyLength,timesteps)
    # Sea pressure in dbar; 0 corresponds to the sea surface
    p = 0
    # 2-dimensional longitude and latitude arrays matching one (x,y) snapshot
    lon = np.tile(np.linspace(0,100,xyLength)[:,np.newaxis],xyLength)
    assert lon.shape == (xyLength,xyLength)
    lat = np.tile(np.linspace(-30,30,xyLength)[:,np.newaxis],xyLength)
    assert lat.shape == (xyLength,xyLength)
    return SPTimeseries,p,lon,lat

SPTimeseries,p,lon,lat = generateData(xyLength=3,timesteps=3)    

We define `getAbsoluteSalinity`, the function that we are going to call in each iteration. It wraps `gsw.SA_from_SP`, which converts practical salinity to absolute salinity.

In [3]:
def getAbsoluteSalinity(SPSnapshot:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray):
    return gsw.SA_from_SP(SPSnapshot,p,lon,lat)

First we create a baseline, non-parallelised function that processes the timesteps sequentially.

In [1]:
def sequentialProcessing(SPTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray,):
    return np.stack(
        [getAbsoluteSalinity(SPTimeseries[:,:,timestep], p, lon, lat) for timestep in range(SPTimeseries.shape[2])],
        axis=2)

outputSeq = sequentialProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)
assert outputSeq.shape == SPTimeseries.shape

Now we create a parallel processing function using Joblib and check that its output is the same as the output of the sequential version.

In [5]:
def parallelProcessing(SPTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray,backend = "threading",n_jobs:int=2):
    return np.stack(
        Parallel(n_jobs=n_jobs, backend=backend)(delayed(getAbsoluteSalinity)(SPTimeseries[:,:,timestep], p, lon, lat) for timestep in range(SPTimeseries.shape[2])),
    axis=2)

outputParallel = parallelProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat,n_jobs=1)
assert outputParallel.shape == SPTimeseries.shape
np.testing.assert_array_equal(outputSeq,outputParallel)
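The equality check above relies on Joblib returning results in the order the tasks were submitted, so that `np.stack` rebuilds the time axis correctly. The sketch below illustrates that behaviour with a hypothetical stand-in function, `scale_snapshot`, in place of the seawater call. Note that the threading backend only gives a speed-up when the called function releases the GIL; whether a given library does so is best confirmed by timing it.

```python
import numpy as np
from joblib import Parallel, delayed

def scale_snapshot(snapshot, factor):
    """Hypothetical stand-in for the per-timestep library call."""
    return snapshot * factor

rng = np.random.default_rng(0)
cube = rng.standard_normal((4, 4, 6))

# Reference result from a plain list comprehension
sequential = [scale_snapshot(cube[:, :, t], 2.0) for t in range(cube.shape[2])]

# Same tasks dispatched to two threads; the returned list keeps submission order
parallel = Parallel(n_jobs=2, backend="threading")(
    delayed(scale_snapshot)(cube[:, :, t], 2.0) for t in range(cube.shape[2])
)

same_order = all(np.array_equal(s, p) for s, p in zip(sequential, parallel))
print(same_order)
```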

How about doing the calculation with 32-bit floats instead of 64-bit floats? We can re-use the same function, but we cast the input to 32-bit floats first.

Note that we have to compare the outputs with `np.testing.assert_array_almost_equal` to 5 decimal places, because the 32-bit input is only approximately equal to the 64-bit input and so the outputs are only nearly equivalent.

In [6]:
SPTimeseries32 = SPTimeseries.astype(np.float32)
outputParallel32 = parallelProcessing(SPTimeseries=SPTimeseries32,p=p,lon=lon,lat=lat,n_jobs=1)
assert outputParallel32.shape == SPTimeseries.shape
np.testing.assert_array_almost_equal(outputSeq,outputParallel32,decimal=5)
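To see why a tolerance of 5 decimal places is a reasonable choice, the sketch below measures the rounding error introduced by casting values near 35 to 32-bit floats. It uses only numpy and its own small random array, so it says nothing about how the seawater library propagates that error; it only shows the size of the input perturbation.

```python
import numpy as np

rng = np.random.default_rng(0)
sp64 = 35 + rng.standard_normal((3, 3, 3))
sp32 = sp64.astype(np.float32)

# Largest absolute error introduced by the cast to float32
cast_error = np.abs(sp64 - sp32.astype(np.float64)).max()

# Gap between adjacent float32 values near 35
gap = np.spacing(np.float32(35.0))

# Tolerance used by assert_array_almost_equal with decimal=5
tolerance = 1.5 * 10 ** -5

print(f"max cast error: {cast_error:.2e}")
print(f"float32 spacing near 35: {gap:.2e}")
print(f"float64 bytes: {sp64.nbytes}, float32 bytes: {sp32.nbytes}")
```

The cast error is a few parts in a million, below the `1.5e-5` tolerance, while the float32 array takes half the memory of the float64 one.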

Finally we create a function that does the processing with `dask.delayed`. As before, we check that the output is the same as in the earlier cases.

In [7]:
def daskDelayedProcessing(SPTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray):
    outputs = []
    for timestep in range(SPTimeseries.shape[2]):
        y = dask.delayed(getAbsoluteSalinity)(SPTimeseries[:,:,timestep], p, lon, lat)
        outputs.append(y)
    return np.stack(dask.compute(*outputs),axis=2)

outputDask = daskDelayedProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)
assert outputDask.shape == SPTimeseries.shape
np.testing.assert_array_equal(outputParallel,outputDask)
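The pattern in `daskDelayedProcessing` has two phases: calling `dask.delayed(...)` only records a task, and `dask.compute` then runs all of the recorded tasks. The sketch below shows those phases separately, again with a hypothetical stand-in function, `scale_snapshot`, rather than the seawater call.

```python
import numpy as np
import dask

def scale_snapshot(snapshot, factor):
    """Hypothetical stand-in for the per-timestep library call."""
    return snapshot * factor

cube = np.arange(2 * 2 * 3, dtype=float).reshape(2, 2, 3)

# Phase 1: build lazy tasks; nothing is computed yet
tasks = [
    dask.delayed(scale_snapshot)(cube[:, :, t], 2.0)
    for t in range(cube.shape[2])
]
print(type(tasks[0]).__name__)

# Phase 2: run every task; results come back as a tuple in task order
results = dask.compute(*tasks)
stacked = np.stack(results, axis=2)
print(stacked.shape)
```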

Now we generate a larger array and compare the run time and peak memory of each approach.

In [63]:
xyLength = 1000
timesteps = 100
SPTimeseries,p,lon,lat = generateData(xyLength=xyLength,timesteps=timesteps)    

In [64]:
%timeit -n 1 -r 1 sequentialProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)

12.8 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [65]:
%timeit -n 1 -r 1 parallelProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat,n_jobs=2)

6.31 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [68]:
%timeit -n 1 -r 1 parallelProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat,n_jobs=4)

3.75 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [69]:
%timeit -n 1 -r 1 parallelProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat,n_jobs=8)

3.96 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
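To put the Joblib timings in context, the sketch below converts them into speed-up and parallel efficiency relative to the sequential run. The numbers are copied from the `%timeit` outputs above; each comes from a single run, so small differences such as the one between 4 and 8 jobs may just be noise.

```python
# Wall-clock times in seconds, copied from the %timeit outputs above
sequential_time = 12.8
joblib_times = {2: 6.31, 4: 3.75, 8: 3.96}

summary = {}
for n_jobs, elapsed in joblib_times.items():
    speedup = sequential_time / elapsed   # how many times faster than sequential
    efficiency = speedup / n_jobs         # fraction of ideal linear scaling
    summary[n_jobs] = (speedup, efficiency)
    print(f"n_jobs={n_jobs}: speedup {speedup:.2f}x, efficiency {efficiency:.2f}")
```

On these figures the speed-up is close to linear for 2 jobs, still substantial for 4, and no better for 8 than for 4.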


In [73]:
SPTimeseries32 = SPTimeseries.astype(np.float32)
%timeit parallelProcessing(SPTimeseries=SPTimeseries32,p=p,lon=lon,lat=lat,n_jobs=8)

2.9 s ± 34.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [42]:
%memit sequentialProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)

peak memory: 3301.66 MiB, increment: 1464.84 MiB


In [43]:
%memit parallelProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat,n_jobs=8)

peak memory: 3300.92 MiB, increment: 1525.00 MiB


In [74]:
%memit parallelProcessing(SPTimeseries=SPTimeseries32,p=p,lon=lon,lat=lat,n_jobs=8)

peak memory: 3816.03 MiB, increment: 1585.90 MiB


In [47]:
%memit daskDelayedProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)

peak memory: 2436.57 MiB, increment: 1356.46 MiB
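A rough back-of-the-envelope estimate helps in reading the `%memit` increments above. The sketch below computes the size of one full array for the 1000 x 1000 x 100 case, and assumes (this is an assumption, not something measured here) that at peak both the list of per-timestep results and the stacked copy made by `np.stack` are held in memory as 64-bit floats.

```python
xy_length, timesteps = 1000, 100
bytes_per_mib = 2 ** 20

elements = xy_length * xy_length * timesteps

# Size of one full (x, y, time) array for each dtype
float64_mib = elements * 8 / bytes_per_mib
float32_mib = elements * 4 / bytes_per_mib

# Assumption: list of float64 results plus the stacked float64 output coexist at peak
estimated_increment_mib = 2 * float64_mib

print(f"float64 array: {float64_mib:.2f} MiB")
print(f"float32 array: {float32_mib:.2f} MiB")
print(f"estimated increment: {estimated_increment_mib:.2f} MiB")
```

Under that assumption the estimate is about 1526 MiB, the same order as the measured increments. It would also be consistent with the 32-bit input not reducing the increment if the library returns 64-bit output regardless of the input dtype, though that is not verified above.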


In [45]:
%timeit daskDelayedProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)

4.78 s ± 101 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [49]:
SPTimeseries.dtype

dtype('float64')