# Parallelization with Dask

[Dask](https://docs.dask.org/en/latest/array.html) 

<img src="_static/dask.png" width="500">

## Exercise: Xarray and XGCM with Dask

Let's load a larger dataset (~2.5 GB). Here is where Xarray really shines! The data can easily be split into chunks, each backed by a Dask array. First, as a baseline, we will redo the divergence calculation from the Xarray tutorial without any Dask chunks.

In [1]:
import xarray as xr
import warnings
warnings.filterwarnings("ignore")

ds = xr.open_dataset('TPOSE6_Daily_2012.nc',decode_timedelta=True)

In [2]:
ds.THETA

In [3]:
ds.THETA.data

array([[[[25.939558 , 25.918556 , 25.90393  , ..., 24.047636 ,
          24.014856 , 23.986374 ],
         [25.945522 , 25.92448  , 25.90966  , ..., 23.998495 ,
          23.982985 , 23.966032 ],
         [25.954805 , 25.934542 , 25.920166 , ..., 23.906435 ,
          23.915352 , 23.919128 ],
         ...,
         [25.90544  , 25.898523 , 25.894108 , ..., 26.610434 ,
          26.627668 , 26.6449   ],
         [25.910702 , 25.900629 , 25.893917 , ..., 26.652458 ,
          26.670025 , 26.687704 ],
         [25.904188 , 25.89322  , 25.884508 , ..., 26.696756 ,
          26.713339 , 26.729265 ]],

        [[25.933687 , 25.912369 , 25.897388 , ..., 24.03935  ,
          24.006435 , 23.977962 ],
         [25.939726 , 25.918404 , 25.903294 , ..., 23.990353 ,
          23.974865 , 23.957865 ],
         [25.949148 , 25.928549 , 25.913889 , ..., 23.898745 ,
          23.907852 , 23.911688 ],
         ...,
         [25.900799 , 25.893978 , 25.889643 , ..., 26.60659  ,
          26.623999 , 26.

In [4]:
import xgcm 
import cmocean.cm as cmo

# create the grid object from our dataset
grid = xgcm.Grid(ds, periodic=['X','Y'])
grid

<xgcm.Grid>
X Axis (periodic, boundary=None):
  * center   XC --> outer
  * outer    XG --> center
Z Axis (not periodic, boundary=None):
  * center   Z
T Axis (not periodic, boundary=None):
  * center   time
Y Axis (periodic, boundary=None):
  * center   YC --> outer
  * outer    YG --> center

In [5]:
%%time
u_transport = ds.UVEL * ds.dyG * ds.hFacW * ds.drF
v_transport = ds.VVEL * ds.dxG * ds.hFacS * ds.drF
div_uv = (grid.diff(u_transport, 'X') + grid.diff(v_transport, 'Y')) / ds.rA  # calculate the divergence of the flow

div_uv.compute()

CPU times: user 349 ms, sys: 605 ms, total: 955 ms
Wall time: 954 ms


In [6]:
del div_uv, u_transport, v_transport

### Xarray automatically sped up with Dask

If we `chunk()` our data, Xarray replaces the underlying NumPy arrays with Dask arrays, which lets us run computations on the chunks in parallel. We can see the size of the chunks by inspecting the Dataset.
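To make the idea concrete, here is a minimal sketch of why chunking along `time` allows parallel work, using only NumPy and a thread pool. It is not how Dask is implemented (Dask builds a task graph and schedules it for you). The array shape, the helper names `split_along_time` and `x_difference`, and the worker count are all hypothetical and chosen only for illustration.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def split_along_time(arr, chunk_len):
    """Split `arr` into consecutive blocks of `chunk_len` steps along axis 0."""
    return [arr[start:start + chunk_len] for start in range(0, arr.shape[0], chunk_len)]


def x_difference(block):
    """Periodic first difference along the last axis (no data needed from other blocks)."""
    return np.roll(block, -1, axis=-1) - block


# Synthetic stand-in for a (time, Z, Y, X) field; the shape is made up and kept small.
rng = np.random.default_rng(0)
field = rng.standard_normal((36, 4, 8, 16)).astype(np.float32)

# Chunk by 10 timesteps: three full blocks plus a shorter one holding the remainder.
blocks = split_along_time(field, chunk_len=10)

# Each block is independent along time, so the blocks can be handed to separate workers.
with ThreadPoolExecutor(max_workers=4) as pool:
    pieces = list(pool.map(x_difference, blocks))

# Stitch the per-block results back together and compare with the unchunked answer.
chunked_result = np.concatenate(pieces, axis=0)
serial_result = x_difference(field)
print(np.array_equal(chunked_result, serial_result))
```

The key point is that the operation only looks along `X`, so splitting along `time` never requires one block to see another block's data. Operations that reach across chunk boundaries need extra communication between chunks, which is one reason the choice of chunking dimension matters.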

To start we will try chunking the data by sets of 10 timesteps.

In [7]:
ds = ds.chunk({'time': 10})

In [8]:
ds.THETA.data

|  | Array | Chunk |
|---|---|---|
| Bytes | 619.23 MiB | 16.92 MiB |
| Shape | (366, 22, 84, 240) | (10, 22, 84, 240) |
| Dask graph | 37 chunks in 2 graph layers | |
| Data type | float32 numpy.ndarray | |


Fantastic. That is exactly what we hoped to see: the array is now split into 37 chunks along `time`. Now, can we do our divergence calculation any faster?

In [9]:
# create the grid object from our dataset
grid = xgcm.Grid(ds, periodic=['X','Y'])
grid

<xgcm.Grid>
X Axis (periodic, boundary=None):
  * center   XC --> outer
  * outer    XG --> center
Z Axis (not periodic, boundary=None):
  * center   Z
T Axis (not periodic, boundary=None):
  * center   time
Y Axis (periodic, boundary=None):
  * center   YC --> outer
  * outer    YG --> center

In [10]:
%%time
u_transport = ds.UVEL * ds.dyG * ds.hFacW * ds.drF
v_transport = ds.VVEL * ds.dxG * ds.hFacS * ds.drF
div_uv = (grid.diff(u_transport, 'X') + grid.diff(v_transport, 'Y')) / ds.rA  # calculate the divergence of the flow

div_uv.compute()

CPU times: user 769 ms, sys: 612 ms, total: 1.38 s
Wall time: 440 ms


We can see that without Dask, this computation took a little under 1 second (954 ms wall time). With Dask it took only about 440 ms. That is more than a 2x speedup! Note that the total CPU time actually went up (1.38 s versus 955 ms); the wall time dropped because the work was done in parallel. When you take into account that this subset of the model output is less than 2% of the full model domain, that speedup starts to look pretty nice! If you are clever about when and how you chunk your data, you can get much more than a 2x speedup.

**NOTE** Another reason we only see a 2x speedup here is that these chunks are very small (about 17 MB each), much smaller than what we can hold in memory. Parallelization comes with some overhead (your computer has to do some bookkeeping in the background to schedule and coordinate the work). Very small chunks are inefficient because the overhead and the computation itself may take similar amounts of time. Depending on your system, you may want to aim for ~100 MB or even 1 GB per chunk.
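The chunk sizes quoted above can be checked with a little arithmetic. The sketch below uses the `THETA` shape and `float32` data type reported in the output earlier; the helper names `chunk_mib` and `n_chunks` and the 100 MiB target are assumptions made for illustration, not recommendations from Xarray or Dask.

```python
import math

# Shape of THETA (time, Z, Y, X) as reported in the Dask array summary.
n_time, n_z, n_y, n_x = 366, 22, 84, 240
bytes_per_value = 4  # float32
MIB = 2 ** 20


def chunk_mib(time_steps):
    """Size in MiB of one chunk holding `time_steps` full time slices."""
    return time_steps * n_z * n_y * n_x * bytes_per_value / MIB


def n_chunks(time_steps):
    """Number of chunks along time; the last one may be shorter."""
    return math.ceil(n_time / time_steps)


# The chunking used in this tutorial: 10 timesteps per chunk.
print(f"{chunk_mib(10):.2f} MiB per chunk, {n_chunks(10)} chunks")

# Hypothetical target: how many timesteps per chunk get close to 100 MiB?
target_mib = 100
steps_for_target = max(1, round(target_mib / chunk_mib(1)))
print(f"{steps_for_target} timesteps -> {chunk_mib(steps_for_target):.1f} MiB per chunk, "
      f"{n_chunks(steps_for_target)} chunks")
```

If you wanted to try the larger chunks on this dataset, the corresponding call would be `ds.chunk({'time': steps_for_target})`. Whether that is actually faster depends on your machine, so it is worth timing a few choices rather than trusting a single rule of thumb.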


More resources for Xarray and Dask: [1](https://docs.xarray.dev/en/v2023.01.0/user-guide/dask.html), [2](https://examples.dask.org/xarray.html), [3](https://tutorial.xarray.dev/intermediate/xarray_and_dask.html)  
See [this page](https://docs.xarray.dev/en/stable/user-guide/dask.html#best-practices) for a more detailed discussion of best practices with Xarray and Dask.