# Parallel loops through numpy arrays with Dask and Joblib
This notebook looks at speeding up loops over numpy arrays. In this example each iteration calls a third-party library, and that library only accepts a subset of the full array. Because the work happens inside a third-party library, we can't apply tricks like JIT compilation.

The scenario is a 3-dimensional array with dimensions (x, y, time), which we treat as a time series of 2-dimensional maps of ocean salinity. Our third-party library is the seawater library (`gsw`). This library only accepts 2-dimensional inputs, so we loop over the time dimension and call the library once per timestep.

# Libraries
In this example we will use the [Joblib](https://joblib.readthedocs.io/en/latest/) and [Dask](https://docs.dask.org/en/stable/) libraries. For Dask we use `dask.delayed` to parallelise the loop.

In [10]:
!pip install numpy --upgrade
!pip install gsw --upgrade


In [2]:
import numpy as np
from joblib import Parallel,delayed
import dask

# import gsw


# Generate data
We generate the inputs we need for the seawater library. We add arguments to allow us to specify the size of the array.

In [3]:
def generateData(xyLength:int,timesteps:int):
    # Salinity values scattered around 35, with shape (x, y, time)
    SPTimeseries = 35 + np.random.standard_normal(size=(xyLength,xyLength,timesteps))
    assert SPTimeseries.shape == (xyLength,xyLength,timesteps)
    # Scalar p shared by every grid point
    p = 0
    # 2-dimensional lon and lat arrays matching the shape of one (x, y) snapshot
    lon = np.tile(np.linspace(0,100,xyLength)[:,np.newaxis],xyLength)
    assert lon.shape == (xyLength,xyLength)
    lat = np.tile(np.linspace(-30,30,xyLength)[:,np.newaxis],xyLength)
    assert lat.shape == (xyLength,xyLength)
    return SPTimeseries,p,lon,lat

# Small arrays first, so the correctness checks below run quickly
SPTimeseries,p,lon,lat = generateData(xyLength=3,timesteps=3)
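
In `generateData`, `np.tile` produces `lon` and `lat` arrays that both vary along the first axis only. If a grid is wanted where longitude varies along one axis and latitude along the other, `np.meshgrid` is one option. The sketch below is an assumption about the desired layout, not part of the original notebook, and the helper name `generateGrid` is hypothetical.

```python
import numpy as np

def generateGrid(xyLength: int):
    # Hypothetical helper: lon varies along axis 0, lat varies along axis 1
    lon1d = np.linspace(0, 100, xyLength)
    lat1d = np.linspace(-30, 30, xyLength)
    # indexing="ij" keeps the first input on axis 0 and the second on axis 1
    lonGrid, latGrid = np.meshgrid(lon1d, lat1d, indexing="ij")
    return lonGrid, latGrid

lonGrid, latGrid = generateGrid(xyLength=3)
assert lonGrid.shape == latGrid.shape == (3, 3)
```

Either layout gives arrays with the same shape as one (x, y) snapshot, so the rest of the notebook would work unchanged.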

We define `getAbsoluteSalinity`, the function that we are going to call in each iteration.

In [4]:
def getAbsoluteSalinity(SPSnapshot:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray):
    # Placeholder: returns the snapshot unchanged (the gsw import above is commented out)
    return SPSnapshot
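
The function above returns its input unchanged because the `gsw` import is commented out. As a sketch of what the real call might look like, the variant below uses `gsw.SA_from_SP(SP, p, lon, lat)` when `gsw` is installed and falls back to the placeholder otherwise. The name `getAbsoluteSalinityGsw`, the fallback, and the choice of `SA_from_SP` are assumptions; the notebook does not say which `gsw` function it intended.

```python
import numpy as np

try:
    import gsw  # optional: only used if it is installed
except ImportError:
    gsw = None

def getAbsoluteSalinityGsw(SPSnapshot, p, lon, lat):
    # Hypothetical variant of getAbsoluteSalinity, assuming SA_from_SP is the intended call
    if gsw is None:
        # Fallback keeps the example runnable without gsw
        return SPSnapshot
    return gsw.SA_from_SP(SPSnapshot, p, lon, lat)

# Small demo inputs built the same way as in generateData
snapshotDemo = 35 + np.random.standard_normal(size=(3, 3))
lonDemo = np.tile(np.linspace(0, 100, 3)[:, np.newaxis], 3)
latDemo = np.tile(np.linspace(-30, 30, 3)[:, np.newaxis], 3)
SADemo = getAbsoluteSalinityGsw(snapshotDemo, 0, lonDemo, latDemo)
assert SADemo.shape == snapshotDemo.shape
```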

First we create a baseline, non-parallelised function that processes the timesteps sequentially.

In [5]:
def sequentialProcessing(SPTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray):
    return np.stack(
        [getAbsoluteSalinity(SPTimeseries[:,:,timestep], p, lon, lat) for timestep in range(SPTimeseries.shape[2])],
        axis=2)

outputSeq = sequentialProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)
assert outputSeq.shape == SPTimeseries.shape

Now we create a parallel processing function using Joblib and test that its output matches the output of the sequential processing.

In [6]:
def joblibProcessing(SPTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray,backend = "threading",n_jobs:int=2):
    return np.stack(
        Parallel(n_jobs=n_jobs, backend=backend)(delayed(getAbsoluteSalinity)(
            SPTimeseries[:,:,timestep], p, lon, lat) for timestep in range(SPTimeseries.shape[2])),
    axis=2)

outputParallel = joblibProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat,n_jobs=1)
np.testing.assert_array_equal(outputSeq,outputParallel)
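
`joblibProcessing` exposes `backend` and `n_jobs` as arguments. The sketch below isolates that pattern with a hypothetical stand-in function, `identitySnapshot`, and uses `n_jobs=-1`, which asks Joblib to use all available cores. Threads only help if the library call releases the GIL; whether the seawater library does so is not established here.

```python
import numpy as np
from joblib import Parallel, delayed

def identitySnapshot(snapshot):
    # Hypothetical stand-in for the per-timestep library call
    return snapshot

def processWithBackend(timeseries, backend="threading", n_jobs=2):
    results = Parallel(n_jobs=n_jobs, backend=backend)(
        delayed(identitySnapshot)(timeseries[:, :, t])
        for t in range(timeseries.shape[2])
    )
    # Parallel returns results in submission order, so stacking restores the time axis
    return np.stack(results, axis=2)

demoJoblib = np.random.standard_normal(size=(4, 4, 5))
outThreads = processWithBackend(demoJoblib, backend="threading", n_jobs=-1)
np.testing.assert_array_equal(outThreads, demoJoblib)
```

Passing `backend="loky"` instead would run the same tasks in separate processes, at the cost of copying each snapshot to a worker.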

Finally we create a function for processing using `dask.delayed`. As before, we test that the output matches the previous results.

In [None]:
def daskDelayedProcessing(SPTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray):
    outputs = []
    for timestep in range(SPTimeseries.shape[2]):
        y = dask.delayed(getAbsoluteSalinity)(SPTimeseries[:,:,timestep], p, lon, lat)
        outputs.append(y)
    return np.stack(dask.compute(*outputs),axis=2)

outputDask = daskDelayedProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)
assert outputDask.shape == SPTimeseries.shape
np.testing.assert_array_equal(outputParallel,outputDask)
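
`daskDelayedProcessing` leaves the choice of scheduler to Dask. A minimal sketch of selecting it explicitly through the `scheduler` argument of `dask.compute` follows; `scaleSnapshot` is a hypothetical stand-in for the library call, and the `"synchronous"` scheduler is included only as a single-threaded reference for checking results.

```python
import dask
import numpy as np

def scaleSnapshot(snapshot, factor):
    # Hypothetical stand-in for the per-timestep library call
    return snapshot * factor

demoDask = np.arange(24, dtype=float).reshape(2, 3, 4)
# Build one lazy task per timestep; nothing runs until dask.compute is called
tasks = [dask.delayed(scaleSnapshot)(demoDask[:, :, t], 2.0)
         for t in range(demoDask.shape[2])]

outThreaded = np.stack(dask.compute(*tasks, scheduler="threads"), axis=2)
outSync = np.stack(dask.compute(*tasks, scheduler="synchronous"), axis=2)
np.testing.assert_array_equal(outThreaded, demoDask * 2.0)
np.testing.assert_array_equal(outThreaded, outSync)
```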

In [None]:
xyLength = 1000
timesteps = 100
SPTimeseries,p,lon,lat = generateData(xyLength=xyLength,timesteps=timesteps)    

In [None]:
%timeit -n 1 -r 1 sequentialProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)

In [None]:
%timeit -n 1 -r 1 joblibProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat,n_jobs=2)

In [None]:
%timeit -n 1 -r 1 joblibProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat,n_jobs=4)

In [None]:
%timeit -n 1 -r 1 joblibProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat,n_jobs=8)

In [None]:
SPTimeseries32 = SPTimeseries.astype(np.float32)
%timeit joblibProcessing(SPTimeseries=SPTimeseries32,p=p,lon=lon,lat=lat,n_jobs=8)

In [None]:
%timeit daskDelayedProcessing(SPTimeseries=SPTimeseries,p=p,lon=lon,lat=lat)
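
`%timeit` is an IPython magic, so the timing cells above only run in a notebook or IPython session. For a plain Python script, a rough equivalent can be sketched with `time.perf_counter`. The helper `timeCall` and the stand-in `stackCopy` are hypothetical names used only for this illustration.

```python
import time
import numpy as np

def timeCall(func, *args, repeats=3, **kwargs):
    # Hypothetical helper: best wall-clock time in seconds over `repeats` runs
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best

def stackCopy(timeseries):
    # Stand-in workload: slice each timestep and stack the slices back together
    return np.stack([timeseries[:, :, t] for t in range(timeseries.shape[2])], axis=2)

elapsed = timeCall(stackCopy, np.zeros((50, 50, 10)))
print(f"best of 3 runs: {elapsed:.6f} s")
```

The same helper could wrap `sequentialProcessing`, `joblibProcessing`, or `daskDelayedProcessing` with their keyword arguments.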