# Parallel loops through numpy arrays
We look at speeding up loops through numpy arrays. In this example each iteration calls a third-party library, and that library only accepts a subset of our full array. Because the work happens inside third-party code, we can't apply tricks like JIT compilation.

The scenario here is that we have a 3-dimensional array with dimensions (x,y,time). We will imagine that this is a time series of 2-dimensional maps of ocean salinity. Our third-party library is the seawater library. We assume that the seawater library only accepts 2-dimensional inputs so we need to loop through the time dimension and call this library on each iteration. 
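
The slice-process-restack pattern described above can be sketched with plain numpy. This is a minimal illustration only: `process_snapshot` is a hypothetical stand-in for a third-party function that accepts 2-dimensional input, not the seawater library itself.

```python
import numpy as np


def process_snapshot(snapshot: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for a third-party call that only accepts 2-D input
    if snapshot.ndim != 2:
        raise ValueError("expected a 2-dimensional array")
    return snapshot * 2.0


# Small (x, y, time) array: 3 x 3 maps at 5 timesteps
timeseries = np.arange(3 * 3 * 5, dtype=float).reshape(3, 3, 5)

# Slice along the last (time) axis, process each 2-D map,
# then stack the results back along axis 2
processed = np.stack(
    [process_snapshot(timeseries[:, :, t]) for t in range(timeseries.shape[2])],
    axis=2,
)

print(processed.shape)
```

Each iteration is independent of the others, which is what makes the loop a candidate for parallel execution.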

# Libraries

#### In this example we will use the built-in [concurrent.futures](https://docs.python.org/3/library/concurrent.futures.html) module and the [Joblib](https://joblib.readthedocs.io/en/latest/) and [dask](https://docs.dask.org/en/stable/) libraries. For Dask we use the `dask.delayed` API to parallelise the loop.

In [25]:
# May need to install gsw seawater library here
# !conda install --yes -c conda-forge gsw

In [2]:
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor,ThreadPoolExecutor

import numpy as np
from joblib import Parallel,delayed
import dask

import gsw
# Import getAbsoluteSalinity from a separate script (gswImport.py) so that the parallel worker processes can import and run it
import gswImport


from IPython.display import Markdown, display

In [3]:
def printmd(string):
    display(Markdown(string))

# Generate data
We generate the inputs we need for the seawater library. We add arguments to allow us to specify the size of the array.

In [16]:
def generateData(xyLength:int,timesteps:int):
    SalinityTimeseries = 35 + np.random.standard_normal(size=(xyLength,xyLength,timesteps))
    assert SalinityTimeseries.shape == (xyLength,xyLength,timesteps)
    p = 0
    lon = np.tile(np.linspace(0,100,xyLength)[:,np.newaxis],xyLength)
    assert lon.shape == (xyLength,xyLength)
    lat = np.tile(np.linspace(-30,30,xyLength)[:,np.newaxis],xyLength)
    assert lat.shape == (xyLength,xyLength)
    return SalinityTimeseries,p,lon,lat

SalinityTimeseries,p,lon,lat = generateData(xyLength=3,timesteps=5)    
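
To see what the `np.tile` calls in `generateData` produce, here is a small check with a grid length of 3. The names `n` and `lon_demo` are illustrative only and are not part of the notebook.

```python
import numpy as np

n = 3  # illustrative grid length

# A column vector of n values, tiled n times along the second axis
lon_demo = np.tile(np.linspace(0, 100, n)[:, np.newaxis], n)

print(lon_demo.shape)
print(lon_demo)
```

With this construction the values vary along the first axis and are constant along the second. The same holds for `lat`, since it is built the same way.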

#### We define the function that we are going to call in each iteration: `getAbsoluteSalinity`

In [17]:
def getAbsoluteSalinity(SalinitySnapshot:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray,index=None):
    return (gsw.SA_from_SP(SalinitySnapshot,p,lon,lat),index)

#### Serial processing

In [18]:
def sequentialProcessing(SalinityTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray,):
    return np.stack(
        [getAbsoluteSalinity(SalinityTimeseries[:,:,timestep], p, lon, lat)[0] for timestep in range(SalinityTimeseries.shape[2])],
        axis=2)

outputSeq = sequentialProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat)

# Concurrent.futures

In [20]:
def sortResultsToArray(results:list):
    # Make sure the results are sorted correctly by timestep index
    results = sorted(results,key=lambda x:x[1])
    # Drop the timestep index from each element
    results = [el[0] for el in results]
    # Turn the list of arrays back into a single array
    results = np.stack(results,axis=2)
    return results
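
The reason for carrying an index alongside each result can be illustrated with a case where results may arrive out of order. The sketch below is an assumption-laden toy: `slow_snapshot` is a hypothetical function whose earlier timesteps take longer, and `as_completed` is used to collect results in completion order rather than submission order.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np


def slow_snapshot(index: int):
    # Hypothetical stand-in: earlier timesteps sleep longer
    time.sleep(0.05 * (3 - index))
    return (np.full((2, 2), float(index)), index)


with ThreadPoolExecutor(max_workers=3) as pool:
    futures = [pool.submit(slow_snapshot, i) for i in range(3)]
    # as_completed yields futures as they finish, not in submission order
    unordered = [future.result() for future in as_completed(futures)]

# Sorting on the carried index restores the time ordering before stacking
ordered = sorted(unordered, key=lambda pair: pair[1])
stacked = np.stack([array for array, _ in ordered], axis=2)

print([index for _, index in unordered])
print(stacked[0, 0, :])
```

Whatever order the results arrive in, the stacked array ends up with timesteps in ascending order along axis 2.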

In [21]:
def concurrentProcessing(SalinityTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray,backend = "multiprocessing"):
    results = []
    if backend == "multiprocessing":
        executor = ProcessPoolExecutor()
    else:
        executor = ThreadPoolExecutor()
    with executor as pool:
        futr_results = [
            pool.submit(
                gswImport.getAbsoluteSalinity,
                SalinityTimeseries[:,:,timestep],p,lon,lat,timestep)
            for timestep in range(SalinityTimeseries.shape[2])
        ]
        for future in futr_results: 
            results.append(future.result())
    results = sortResultsToArray(results=results)
    return results
# Check the function runs
outputConcurrent = concurrentProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat,backend = "multiprocessing")
# Check the output matches the serial run
np.testing.assert_array_equal(outputSeq,outputConcurrent)
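
As an alternative to submitting futures and sorting on an index, `Executor.map` returns results in the order of its inputs. The sketch below is not the approach used in this notebook; `double_snapshot` is a hypothetical stand-in for the third-party call, and a thread pool is used so the example runs anywhere without a separate script.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def double_snapshot(snapshot: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the per-timestep third-party call
    return snapshot * 2.0


timeseries = np.arange(2 * 2 * 4, dtype=float).reshape(2, 2, 4)
snapshots = [timeseries[:, :, t] for t in range(timeseries.shape[2])]

with ThreadPoolExecutor() as pool:
    # map yields results in the same order as the input iterable
    results = list(pool.map(double_snapshot, snapshots))

stacked = np.stack(results, axis=2)
print(stacked.shape)
```

The trade-off is that `map` gives less control over individual futures, while `submit` lets you handle each future separately.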

# Joblib

In [23]:
def joblibProcessing(SalinityTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray,backend = "threading",n_jobs:int=2):
    results = Parallel(n_jobs=n_jobs, backend=backend)(delayed(getAbsoluteSalinity)(SalinityTimeseries[:,:,timestep], p, lon, lat,timestep) for timestep in range(SalinityTimeseries.shape[2]))
    results = sortResultsToArray(results=results)
    return results

outputParallel = joblibProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat,n_jobs=1)
np.testing.assert_array_equal(outputSeq,outputParallel)

# Dask Delayed

In [24]:
def daskDelayedProcessing(SalinityTimeseries:np.ndarray,p:int,lon:np.ndarray,lat:np.ndarray):
    results = []
    for timestep in range(SalinityTimeseries.shape[2]):
        results.append(dask.delayed(
            getAbsoluteSalinity)(SalinityTimeseries[:,:,timestep], p, lon, lat,timestep)
                      )
    results = dask.compute(*results)
    results = sortResultsToArray(results=results)
    return results

outputDask = daskDelayedProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat)
np.testing.assert_array_equal(outputParallel,outputDask)

#### Generate some larger data for timings

In [10]:
xyLength = 1000
timesteps = 100
SalinityTimeseries,p,lon,lat = generateData(xyLength=xyLength,timesteps=timesteps)

#### Do timings

In [12]:
printmd("**Serial processing**")
%timeit -n 1 -r 5 sequentialProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat)
printmd("**Concurrent.futures multiprocessing**")
%timeit -n 1 -r 5 concurrentProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat,backend = "multiprocessing")
printmd("**Concurrent.futures threading**")
%timeit -n 1 -r 5 concurrentProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat,backend = "threading")
printmd("**Joblib processing 2 jobs**")
%timeit -n 1 -r 5 joblibProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat,n_jobs=2)
printmd("**Joblib processing 4 jobs**")
%timeit -n 1 -r 5 joblibProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat,n_jobs=4)
printmd("**Joblib processing 8 jobs**")
%timeit -n 1 -r 5 joblibProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat,n_jobs=8)
printmd("**Dask delayed processing**")
%timeit  -n 1 -r 5 daskDelayedProcessing(SalinityTimeseries=SalinityTimeseries,p=p,lon=lon,lat=lat)

**Serial processing**

12.6 s ± 374 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


**Concurrent.futures multiprocessing**

12.5 s ± 800 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


**Concurrent.futures threading**

3.63 s ± 193 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


**Joblib processing 2 jobs**

6.46 s ± 300 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


**Joblib processing 4 jobs**

3.77 s ± 19.6 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


**Joblib processing 8 jobs**

3.43 s ± 35 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


**Dask delayed processing**

3.57 s ± 66.1 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
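
To compare the approaches, the mean timings reported above can be converted into speedups relative to the serial run. This is simple arithmetic on the printed means; it ignores the reported standard deviations, and the dictionary keys are just labels.

```python
# Mean wall-clock times in seconds, copied from the timing output above
timings_s = {
    "Serial processing": 12.6,
    "Concurrent.futures multiprocessing": 12.5,
    "Concurrent.futures threading": 3.63,
    "Joblib processing 2 jobs": 6.46,
    "Joblib processing 4 jobs": 3.77,
    "Joblib processing 8 jobs": 3.43,
    "Dask delayed processing": 3.57,
}

serial = timings_s["Serial processing"]

# Speedup = serial time / parallel time
speedups = {name: serial / seconds for name, seconds in timings_s.items()}

for name, speedup in speedups.items():
    print(f"{name}: {speedup:.2f}x")
```

On these numbers the threaded, Joblib (4 and 8 jobs) and Dask runs are all roughly 3.3 to 3.7 times faster than the serial loop, while the multiprocessing run is essentially unchanged.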
