# Parallelization with Dask

[Dask](https://docs.dask.org/en/latest/array.html) 

<img src="_static/dask.png" width="500">

## Exercise: Xarray and XGCM with Dask

Let's load a larger dataset (~2.5 GB). This is where Xarray really shines: the data can easily be chunked into Dask arrays. First, we will redo the divergence calculation from the Xarray tutorial without any Dask chunks, to get a baseline timing.

In [1]:
import xarray as xr
import warnings
warnings.filterwarnings("ignore")

ds = xr.open_dataset('TPOSE6_Daily_2012.nc', decode_timedelta=True)

In [2]:
ds.THETA

In [4]:
import xgcm 
import cmocean.cm as cmo

# create the grid object from our dataset
grid = xgcm.Grid(ds, periodic=['X','Y'])
grid

<xgcm.Grid>
Z Axis (not periodic, boundary=None):
  * center   Z
T Axis (not periodic, boundary=None):
  * center   time
X Axis (periodic, boundary=None):
  * center   XC --> outer
  * outer    XG --> center
Y Axis (periodic, boundary=None):
  * center   YC --> outer
  * outer    YG --> center

In [5]:
%%time
u_transport = ds.UVEL * ds.dyG * ds.hFacW * ds.drF
v_transport = ds.VVEL * ds.dxG * ds.hFacS * ds.drF
div_uv = (grid.diff(u_transport, 'X') + grid.diff(v_transport, 'Y')) / ds.rA  # calculate the divergence of the flow

div_uv.compute()

CPU times: user 358 ms, sys: 696 ms, total: 1.05 s
Wall time: 1.11 s
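To make the stencil behind `grid.diff` a little more concrete, here is a plain NumPy sketch of the same flux-form divergence. It is an illustration under stated assumptions, not a reproduction of xgcm: the helper `periodic_divergence` is hypothetical, and it assumes MITgcm-style staggering where each transport value sits on the west (or south) face of its cell and face arrays have the same shape as centre arrays, with a periodic wrap. (The `<xgcm.Grid>` repr above lists `XG`/`YG` as `outer` points, so xgcm's actual bookkeeping differs.)

```python
import numpy as np

def periodic_divergence(u_transport, v_transport, rA):
    """Hypothetical helper: flux divergence per unit area on a periodic C-grid.

    Assumes u_transport[..., j, i] is on the west face and v_transport[..., j, i]
    on the south face of cell (j, i), both with the same trailing (Y, X) shape as rA.
    """
    # np.roll(a, -1, axis) puts a[i + 1] at position i (wrapping at the edge),
    # so this is east-face flux minus west-face flux for every cell
    east_minus_west = np.roll(u_transport, -1, axis=-1) - u_transport
    # same idea along Y: north-face flux minus south-face flux
    north_minus_south = np.roll(v_transport, -1, axis=-2) - v_transport
    return (east_minus_west + north_minus_south) / rA

# Uniform flow has zero divergence everywhere
rA = np.full((4, 5), 2.0)
div = periodic_divergence(np.ones((4, 5)), np.ones((4, 5)), rA)
print(div.max())  # 0.0
```

Because every face flux leaves one cell and enters its neighbour, the area-weighted divergence sums to zero over a periodic domain, which is a handy sanity check for any divergence calculation.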


In [6]:
del div_uv, u_transport, v_transport

### Speeding up Xarray automatically with Dask

If we `chunk()` our data, Xarray converts each underlying NumPy array into a Dask array. Operations on the data then become lazy and can run in parallel, chunk by chunk. We can inspect the chunk sizes by looking at the variables in the Dataset.
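A minimal sketch of that behaviour on a small synthetic array (random numbers standing in for `THETA`, with made-up dimension sizes). It assumes both `xarray` and `dask` are installed: after `.chunk()` the data is a Dask array, arithmetic only builds a task graph, and `.compute()` runs it and hands back NumPy-backed data.

```python
import numpy as np
import xarray as xr
import dask.array as da

# Hypothetical small stand-in for ds.THETA (random values, not TPOSE output)
theta = xr.DataArray(
    np.random.default_rng(0).random((366, 4, 5, 6), dtype=np.float32),
    dims=("time", "Z", "YC", "XC"),
    name="THETA",
)
print(isinstance(theta.data, np.ndarray))   # True: plain in-memory NumPy array

chunked = theta.chunk({"time": 10})
print(isinstance(chunked.data, da.Array))   # True: now backed by a Dask array
print(len(chunked.chunks[0]))               # 37: 36 chunks of 10 steps + 1 of 6

# Arithmetic is lazy: this only builds a task graph
anomaly = chunked - chunked.mean("time")
print(isinstance(anomaly.data, da.Array))   # True: still lazy

# compute() executes the graph and returns NumPy-backed data
result = anomaly.compute()
print(isinstance(result.data, np.ndarray))  # True
```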

To start, we will chunk the data into blocks of 10 time steps.
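The number and size of the chunks Dask reports for `THETA` follow directly from the array's shape and data type. A quick worked check in plain Python, assuming the `(time, Z, YC, XC)` shape `(366, 22, 84, 240)` and 4-byte `float32` values:

```python
import math

shape = (366, 22, 84, 240)  # (time, Z, YC, XC) of THETA
itemsize = 4                # bytes per float32 value
time_chunk = 10             # time steps per chunk

# 366 / 10 does not divide evenly: 36 full chunks plus one chunk of 6 steps
n_chunks = math.ceil(shape[0] / time_chunk)
step_bytes = math.prod(shape[1:]) * itemsize  # bytes in a single time step
chunk_mib = time_chunk * step_bytes / 2**20
total_mib = shape[0] * step_bytes / 2**20
print(n_chunks, round(chunk_mib, 2), round(total_mib, 2))  # 37 16.92 619.23
```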

In [7]:
ds = ds.chunk({'time': 10})

In [11]:
ds.THETA

|            | Array                       | Chunk                       |
|------------|-----------------------------|-----------------------------|
| Bytes      | 619.23 MiB                  | 16.92 MiB                   |
| Shape      | (366, 22, 84, 240)          | (10, 22, 84, 240)           |
| Dask graph | 37 chunks in 2 graph layers | 37 chunks in 2 graph layers |
| Data type  | float32 numpy.ndarray       | float32 numpy.ndarray       |

The coordinate variables attached to `THETA` are Dask arrays as well (each with 2 graph layers). Only the one that runs along `time` is split into several chunks; the static grid variables stay in a single chunk each:

| Shape                  | Data type | Array size     | Chunks                         |
|------------------------|-----------|----------------|--------------------------------|
| (366,)                 | int64     | 2.86 kiB       | 37 chunks of (10,), 80 B each  |
| (22,), 3 variables     | float32   | 88 B each      | 1 chunk                        |
| (84, 240), 4 variables | float32   | 78.75 kiB each | 1 chunk                        |
| (22, 84, 240)          | float32   | 1.69 MiB       | 1 chunk                        |
| (22, 84, 240)          | bool      | 433.12 kiB     | 1 chunk                        |


That is exactly what we hoped to see: `THETA` is now split along `time` into 37 chunks of about 17 MiB each. Can we now do our divergence calculation any faster?

In [9]:
# create the grid object from our dataset
grid = xgcm.Grid(ds, periodic=['X','Y'])
grid

<xgcm.Grid>
Z Axis (not periodic, boundary=None):
  * center   Z
T Axis (not periodic, boundary=None):
  * center   time
X Axis (periodic, boundary=None):
  * center   XC --> outer
  * outer    XG --> center
Y Axis (periodic, boundary=None):
  * center   YC --> outer
  * outer    YG --> center

In [10]:
%%time
u_transport = ds.UVEL * ds.dyG * ds.hFacW * ds.drF
v_transport = ds.VVEL * ds.dxG * ds.hFacS * ds.drF
div_uv = (grid.diff(u_transport, 'X') + grid.diff(v_transport, 'Y')) / ds.rA  # calculate the divergence of the flow

div_uv.compute()

CPU times: user 785 ms, sys: 673 ms, total: 1.46 s
Wall time: 482 ms


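We can see that without Dask this computation took a little over 1 second (1.11 s wall time). With Dask it took only 482 ms, a speed-up of more than 2x! Notice that the total CPU time actually went up (1.46 s versus 1.05 s): the wall time drops because Dask spreads the work across several threads running on different cores, even though the work plus the parallelization overhead adds up to more CPU time. Considering that this subset of the model output is less than 2% of the full model domain, that speed-up starts to look pretty nice! If you are clever about when and how you chunk your data, you can get much more than a 2x speed-up.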
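One way to confirm that the gain comes from parallel execution is to run the same lazy graph with Dask's single-threaded `synchronous` scheduler and with its `threads` scheduler. The sketch below uses synthetic data; the array size and the difference-and-sum expression are arbitrary assumptions chosen to loosely resemble the divergence calculation. Timings depend entirely on your machine, so none are shown, but both schedulers must produce the same numbers.

```python
import time
import numpy as np
import dask.array as da

# Hypothetical synthetic data, chunked along the first (time-like) axis;
# -1 means "one chunk spanning the whole dimension"
data = np.random.default_rng(0).random((366, 10, 42, 24))
x = da.from_array(data, chunks=(10, -1, -1, -1))

# Lazy expression: differences along the last axis, summed per time step
expr = (x[..., 1:] - x[..., :-1]).sum(axis=(1, 2, 3))

results = {}
for scheduler in ("synchronous", "threads"):
    start = time.perf_counter()
    results[scheduler] = expr.compute(scheduler=scheduler)
    print(f"{scheduler:>11}: {time.perf_counter() - start:.3f} s")
```

If the `threads` run is not faster on your machine, the chunks are probably too small for the parallelism to outweigh the scheduling overhead.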

**NOTE** Another reason we only see a 2x speed-up here is that these chunks are very small (about 17 MiB each), much smaller than what we can hold in memory. Parallelization has some overhead: your computer has to do some logistics in the background, such as scheduling tasks and moving data between them. With really small chunks, this overhead can take about as long as the computation itself, which makes them inefficient. Depending on your system, you may want to aim for ~100 MB or even 1 GB per chunk.
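To turn a target chunk size into a `time` chunk length, divide the target by the size of one time step. A small sketch in plain Python, assuming the `THETA` layout used above (22 × 84 × 240 `float32` values per time step, 366 steps) and targets of roughly 100 MiB and 1 GiB; `time_steps_per_chunk` is a hypothetical helper, not part of Xarray or Dask.

```python
import math

def time_steps_per_chunk(step_bytes, target_bytes, n_steps):
    """Largest whole number of time steps that fits in target_bytes (between 1 and n_steps)."""
    return max(1, min(n_steps, target_bytes // step_bytes))

step_bytes = 22 * 84 * 240 * 4  # one float32 time step of THETA, about 1.69 MiB
n_steps = 366

for target in (100 * 2**20, 2**30):  # ~100 MiB and 1 GiB targets
    steps = time_steps_per_chunk(step_bytes, target, n_steps)
    print(steps, math.ceil(n_steps / steps))
# 59 7
# 366 1
```

With a ~100 MiB target, `ds.chunk({'time': 59})` would give 7 chunks instead of 37; with a 1 GiB target, the whole year of `THETA` already fits in a single chunk.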


More resources for Xarray and Dask: [1](https://docs.xarray.dev/en/v2023.01.0/user-guide/dask.html), [2](https://examples.dask.org/xarray.html), [3](https://tutorial.xarray.dev/intermediate/xarray_and_dask.html)  
See [this page](https://docs.xarray.dev/en/stable/user-guide/dask.html#best-practices) for a more detailed discussion of best practices with Xarray and Dask.