# Python Dask Demonstration on HPC Orion

Created: 7 Feb 2024

By: Kerrie Geil, Associate Research Professor, Geosystems Research Institute, Mississippi State University

**There are a few different ways to implement parallelization with python Dask on an HPC.** Dask includes a [few different schedulers](https://docs.dask.org/en/stable/scheduling.html) (schedulers send calculations to your cpu cores or threads), and the code you write is slightly different for each one. Dask's single-machine scheduler is the default, but it's used mostly on pcs; it doesn't scale to an HPC environment. I prefer this scheduler on my laptop and workstation. For an HPC environment, Dask offers its distributed scheduler. The distributed scheduler also has an option that runs on a pc, so your code can be more transferable.

Using the **single-machine scheduler** on a pc is generally more "seamless", i.e. less code to write and more things working auto-magically behind the scenes. Its limitation is that you can only parallelize across the cpu cores of a single machine, and it isn't designed to work well in an HPC environment.
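
For comparison, here is a minimal sketch of the single-machine scheduler on a small synthetic array (the shape and chunking are arbitrary, chosen only for illustration). No `Client` is created. When no client exists, `.compute()` on a dask array uses the threaded single-machine scheduler by default. The `scheduler=` keyword or `dask.config.set` picks a scheduler explicitly.

```python
import dask
import dask.array as da
import numpy as np

# small synthetic (time, lat, lon) array standing in for real gridded data
values = np.random.default_rng(0).random((30, 20, 40)).astype(np.float32)

# wrap it in a dask array split into 4 chunks along the last (lon) axis
x = da.from_array(values, chunks=(30, 20, 10))

# no Client exists, so compute() uses the single-machine scheduler;
# 'threads' is already the default for dask arrays, named here explicitly
mean_threads = x.mean(axis=0).compute(scheduler="threads")

# 'synchronous' runs every task in the main thread, which is handy for debugging
with dask.config.set(scheduler="synchronous"):
    mean_sync = x.mean(axis=0).compute()

print(mean_threads.shape)  # (20, 40)
```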

With the **distributed scheduler** in an HPC environment, you have a couple of options for building a cluster. When you set up a cluster you are defining how much computing power (cores, memory, time) you want to access.

Dask's **SLURMCluster** function allows you to access multiple compute nodes simultaneously. Theoretically, you could access 10 entire nodes (400 cpu cores) per compute job. The biggest limitations to using SLURMCluster on HPC Orion are node availability and the overhead time it takes to start the cluster, which is usually about 45-60 seconds. For SLURMCluster to work seamlessly, the cpu cores you request essentially need to be available immediately. This can't always be guaranteed on HPC Orion because there are a ton of other users on the system, so the performance of SLURMCluster on Orion can be inconsistent or unpredictable.

The other option for building a cluster is Dask's **LocalCluster** function. LocalCluster is kind of the HPC equivalent of Dask's single-machine scheduler: it allows you to access the cores and memory on a single node. On the orion partition this limits you to 40 cores and 192GB RAM; on the bigmem partition you'd be limited to 40 cores and 384GB RAM. Reminder: info about partitions, memory, etc. on Orion can be found in the [User Guide](https://intranet.hpc.msstate.edu/helpdesk/resource-docs/orion_guide.php) (username/password required). I find that even if you don't count the time it takes SLURMCluster to start up, computing with LocalCluster is still a bit faster when running on 1 node's worth of cores. Another benefit of LocalCluster is that it makes your code more transferable: the same code should also be able to run in other HPC environments or on your own pc, it will just take longer because there are fewer cores.

**My personal approach to parallelization** is to first attempt using the single-machine scheduler on my pc. If I need more compute power (because a code is taking too long to run or I need more memory) then I have two choices. I could use the Dask distributed scheduler with LocalCluster or SLURMCluster from a Jupyter notebook or a .py file. Or I could not use Dask or Jupyter notebooks at all, instead opting to run .py files in batch mode. This would require programming more of the parallelization manually in shell scripts that call the .py files. This is the "traditional" way of using an HPC. I try to always opt for Jupyter notebooks and LocalCluster if possible. If I still need more computational power I'll try SLURMCluster. SLURMCluster is good if my computation only needs a few nodes and doesn't need to be run frequently since SLURMCluster can be inconsistent due to system usage. If you don't want to babysit your code or you have code that is more operational (runs every day) and needs to succeed every time, then I'd abandon Jupyter notebooks in favor of batched .py scripts. 

**Now, we'll cover a couple of different Dask parallelization techniques using Dask's distributed scheduler. These examples work best with gridded data in netcdf, zarr, or npy stacks.** Dask can also handle dataframes (tabular/txt data), but that won't be covered here.


## Set up

Generally it's good practice to import all your packages up front here.

However, we won't do that here in order to make it more clear which packages we are using at each stage of the notebook.

In [None]:
# what I would normally import here
# import glob
# from dask.distributed import Client,LocalCluster
# from dask_jobqueue import SLURMCluster
# import xarray as xr
# import dask.array as da
# import dask
# import matplotlib.pyplot as plt
# import numpy as np
# import getpass

In [None]:
# directory and file paths

datadir='/path/to/dir/where/Tmax/netcdf/lives/'
outdir='/path/to/dir/in/your/personal/workspace/'

# Dask distributed scheduler using LocalCluster

Here we compute on a single node: you will have access to the node you launched Jupyter on. When you are going to use LocalCluster, my recommendation is to launch Jupyter with number of nodes = 1, number of tasks = 40, and `--exclusive` in the additional slurm parameters on the launch page (`--exclusive` keeps other jobs off your node).

We launched Jupyter on 1 node with 40 cores. On Orion each core has 2 threads, so the maximum number of threads we can use with Dask LocalCluster is 80. We won't actually need all those cores and threads for this notebook, though. We'll set up our LocalCluster with 1/4 of the compute power available on our node: 20 Dask workers where each worker is 1 thread. Dask calls these parameters n_workers and threads_per_worker.
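
The arithmetic above can be wrapped in a small helper. This is a sketch: `localcluster_settings` is a hypothetical function, not part of Dask. It assumes you want single-threaded workers and an even split of the node's RAM among them. Dask's own default (`memory_limit='auto'`) also splits total system memory evenly between workers.

```python
def localcluster_settings(cores_per_node, threads_per_core, node_ram_gb, fraction):
    """Hypothetical helper: LocalCluster keyword arguments for one node.

    Uses `fraction` of the node's hardware threads as single-threaded workers
    and splits the node's RAM evenly across those workers.
    """
    total_threads = cores_per_node * threads_per_core
    n_workers = max(1, int(total_threads * fraction))
    mem_per_worker_gb = node_ram_gb / n_workers
    return {
        "n_workers": n_workers,
        "threads_per_worker": 1,
        "memory_limit": f"{mem_per_worker_gb:.1f}GB",
    }

# orion partition numbers from the text: 40 cores, 2 threads per core, 192GB RAM
settings = localcluster_settings(40, 2, 192, 0.25)
print(settings)  # {'n_workers': 20, 'threads_per_worker': 1, 'memory_limit': '9.6GB'}
```

The resulting dictionary could be passed as `LocalCluster(**settings)`. You may prefer to round the memory limit down a little to leave headroom for the Jupyter kernel itself.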

In [None]:
print('importing...')
from dask.distributed import Client,LocalCluster

print('starting client...')
nworkers=20
cluster=LocalCluster(n_workers=nworkers,threads_per_worker=1)#, memory_limit="4.5GiB") # a cluster where each thread is a separate process or "worker"
client=Client(cluster)  # connect to your compute cluster
client.wait_for_workers(n_workers=nworkers,timeout=10) # wait up to 10s for the cluster to be fully ready, error if not ready in 10s
client # print info

### dask arrays

Dask arrays are great if your data is in netcdf or zarr format or npy stacks (usually gridded data).

A Dask array is lazy: its data isn't loaded into memory until you compute (or persist) it. Dask arrays are also chunked, i.e. divided into smaller blocks. This allows for easy parallelization by spreading the chunks across many cpu cores when it comes time to execute a calculation.
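
Here is a small sketch of what "lazy" and "chunked" look like in practice. It uses a synthetic numpy array with arbitrary sizes. Unlike a netcdf file opened with xarray, this source array already sits in memory. The dask behavior is the same, though: building a calculation only records tasks, and `.compute()` runs them chunk by chunk.

```python
import dask.array as da
import numpy as np

data = np.arange(6 * 4 * 12, dtype=np.float32).reshape(6, 4, 12)  # (time, lat, lon)

# split into 3 chunks along lon; each chunk is a (6, 4, 4) block
x = da.from_array(data, chunks=(6, 4, 4))
print(x.numblocks)  # (1, 1, 3)

# building the calculation only records tasks in the graph; nothing runs yet
lazy_mean = x.mean(axis=0)
print(type(lazy_mean).__name__)  # Array (a dask array, not numbers yet)

# compute() executes the graph and returns a numpy array
result = lazy_mean.compute()
print(type(result).__name__)  # ndarray
```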

Problems may arise for complicated calculations, e.g. when a calculation calls multiple custom functions or subroutines and involves many variables. The problem usually arises because these types of calculations require holding too much data in memory. Behind the scenes, Dask attempts to estimate memory needs for every task, but when calculations get complicated Dask sometimes can't estimate memory needs well enough. In this situation, it's better to use dask delayed (demonstrated later).

In [None]:
import glob

# glob is an easy way to get a list of multiple files
# we'll only use 1 of the files so this is just for demonstration purposes
f=sorted(glob.glob(datadir+'*.nc')) # glob returns files in arbitrary order, so sort them
len(f),f[0:5]

#### let's talk chunking

Chunking means dividing up your array of data into separate parts in one or more dimensions.

For xarray, when you provide the chunks parameter your xarray object is backed by a dask array, which is not held in memory. This makes computing with data bigger than memory easy. In the examples here our data is not bigger than memory, but the code scales.

My approach, when it's possible, is to choose a total number of chunks that is equal to the number of cores/threads that I have. This is when data is not small but can still fit into memory. If data is bigger than memory, I choose a total number of chunks that is a multiple of the cores/threads that I have. What dask does is cycle through the chunks, sending 1 or more chunks at a time to each core/thread.
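
The reasoning about matching chunk counts to threads can be sketched with an idealized model. The model assumes every chunk takes the same time and each thread handles one chunk at a time. Real Dask scheduling is dynamic, so treat this only as intuition. `compute_rounds` is a hypothetical helper.

```python
import math

def compute_rounds(n_chunks, n_threads):
    """Idealized number of passes through the chunks, assuming every chunk
    takes the same time and each thread works on one chunk at a time."""
    return math.ceil(n_chunks / n_threads)

n_chunks = 18
for n_threads in (9, 18, 20, 80):
    idle = max(0, n_threads - n_chunks)  # threads that never get a chunk
    print(f"{n_threads} threads: {compute_rounds(n_chunks, n_threads)} round(s), {idle} idle")

# with more chunks than threads, a multiple of the thread count keeps every round full:
# 80 chunks on 20 threads and 72 chunks on 20 threads both need 4 rounds,
# but the last round of the 72-chunk case leaves 8 threads idle
print(compute_rounds(80, 20), compute_rounds(72, 20))  # 4 4
```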

It's not always best to make nchunks = ncores or nthreads, though. The size (bytes) of a chunk also makes a difference. You don't want to make the chunks too small, because cycling through the chunks adds overhead; for small data it will be faster to compute without chunking at all. You also don't want to make chunks too big, or you could run out of memory as the cores compute. So how do you choose a good chunk size? Really, it's trial and error. On HPC Orion, a compute node on the orion partition has about 190GB of RAM shared across its 40 cores. If you have 40 or more chunks and are operating on 1 node with 40 cores, you want to make sure your calculation on a single chunk won't need more than each core's share of the RAM (190/40 = 4.75GB if you're using all cores). It is often not easy to estimate how much RAM a calculation will use, especially if the calculation is complex. Some good practices for RAM management are to avoid double precision data types when possible and to delete old variables hanging in memory once a step of the calculation no longer needs them. Personally, I usually try a couple of different chunk sizes between 100MB and 1GB to see which ends up being fastest.
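
The memory budget above is plain arithmetic, sketched below for a few candidate longitude chunk widths of the example data. `chunk_nbytes` is a hypothetical helper. Keep in mind that a calculation often needs several times a chunk's size for intermediate results, so the chunk size is only a lower bound on memory use.

```python
import math

def chunk_nbytes(chunk_shape, itemsize):
    """Bytes held by one chunk: number of elements times bytes per element."""
    return math.prod(chunk_shape) * itemsize

node_ram_gb, node_cores = 190, 40
ram_per_core_gb = node_ram_gb / node_cores  # 4.75 GB if all 40 cores are busy

# candidate lon chunk widths for the (time=365, lat=1800, lon=4320) example data
for lon_width in (120, 240, 480, 1080):
    for dtype_name, itemsize in (("float32", 4), ("float64", 8)):
        gb = chunk_nbytes((365, 1800, lon_width), itemsize) / 1e9
        flag = "exceeds per-core RAM" if gb > ram_per_core_gb else "fits"
        print(f"lon={lon_width:4d} {dtype_name}: {gb:.2f} GB per chunk ({flag})")
```

With lon=240, a float32 chunk comes to about 0.63 GB, inside the 100MB-1GB range. The same chunk in float64 doubles to about 1.26 GB, which is one reason to avoid double precision when you can.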

Here, I know that my data dimensions are 'time':365, 'lat':1800, 'lon':4320. Xarray's chunks parameter sets the chunk dimensions. So if I want to chunk over the longitude dimension and I want 18 chunks, the xarray parameter is chunks={'time':-1,'lat':-1,'lon':240}, where -1 means "the whole dimension in one chunk" and 240 comes from 4320/18. You could also write chunks={'time':365,'lat':1800,'lon':240}, which is interpreted identically. Writing just chunks={'lon':240} often gives the same result. However, depending on your xarray version, dimensions you leave out may follow the file's on-disk (netcdf) chunking instead, so spelling out every dimension is the safer choice.
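
Xarray passes these chunk sizes down to dask. The sketch below uses dask's positional (tuple) form of the same chunking to check that `-1` and the explicit dimension sizes match. It also confirms the resulting chunk count and per-chunk size. The tuple form stands in for xarray's dimension-name dictionary. `da.zeros` is lazy, so the full-size array is never actually allocated.

```python
import dask.array as da

shape = (365, 1800, 4320)  # (time, lat, lon), as in the example data

a = da.zeros(shape, chunks=(-1, -1, 240), dtype="float32")     # -1 = whole axis
b = da.zeros(shape, chunks=(365, 1800, 240), dtype="float32")  # explicit sizes

print(a.chunks == b.chunks)  # True
print(a.numblocks)           # (1, 1, 18)
print(a.chunksize)           # (365, 1800, 240)
print(a.nbytes / a.npartitions / 1e9)  # GB per chunk, about 0.63
```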



In [None]:
# open chunked file
# ds is an xarray data structure filled with dask arrays

import xarray as xr

chunks={'time':-1,'lat':-1,'lon':240}

ds=xr.open_mfdataset(f[0],chunks=chunks,lock=False)
ds

### using built-in functions on dask arrays

In [None]:
%%time 
# send information about our data chunks and computation tasks to the workers
# here the graph has 2 layers (tasks): load and chunk
# persist starts moving the data chunks into the workers' memory in the background
# var=ds.Tmax.persist()
var=ds['Tmax-2m'].persist()

In [None]:
%%time
# most calculations will be lazy if you don't include .compute()
# lazy means the compute task is recorded on the dask graph but not executed
# .compute() is what executes the calculation
var_mean=var.mean('time')
var_mean

In [None]:
%%time
# execute the calculation
var_mean=var_mean.compute()
var_mean

In [None]:
# the above steps can also be combined into a single line

var_mean=var.mean('time').compute()

In [None]:
var_mean.plot()

What about calculations that span more than one line?

In [None]:
%%time
# you can string multiple calculations together and call compute only on the last variable to execute everything
var_monthly=var\
            .groupby('time.month')\
            .mean('time') # monthly means
month_minval=var_monthly\
            .min('month') # minimum of monthly means
month_minval.compute()

In [None]:
month_minval.plot()

#### custom compute functions on dask arrays

When we have written our own functions, we can apply them to dask arrays with .map_blocks

In [None]:
# this function could contain anything, but we'll keep it simple here
def my_function(x):
    newval=x.mean('time')
    return newval

In [None]:
%%time
# xarray map_blocks
# var is technically an xarray object even though it's backed with dask arrays
# so map_blocks here is from the xarray library, see more at https://docs.xarray.dev/en/stable/generated/xarray.map_blocks.html
varmean=var.map_blocks(my_function).compute()
varmean.plot()
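
The `var.map_blocks(...)` call above is xarray's `map_blocks`, which hands each block to your function as a labeled xarray object. For a plain dask array (for example `var.data`), `dask.array.map_blocks` does the same job but passes numpy blocks. With the dask version, you describe shape changes yourself with keywords such as `drop_axis`. Below is a minimal sketch on synthetic data; `block_time_mean` is a hypothetical stand-in for `my_function`. Dropping the time axis this way only makes sense because time is a single chunk, as with `'time':-1` above.

```python
import dask.array as da
import numpy as np

data = np.random.default_rng(1).random((12, 6, 24))  # (time, lat, lon)

# time must be a single chunk so each block sees the full time series
x = da.from_array(data, chunks=(-1, -1, 6))

def block_time_mean(block):
    # receives a plain numpy block of shape (12, 6, 6) and returns shape (6, 6)
    return block.mean(axis=0)

# drop_axis=0 tells dask that the output has no time axis
out = da.map_blocks(block_time_mean, x, drop_axis=0, dtype=x.dtype)
result = out.compute()
print(result.shape)  # (6, 24)
```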

#### custom compute functions on overlapping chunks

When a calculation needs values from neighboring pixels (e.g. convolutions), we can use dask.array.map_overlap so that each chunk can see a few pixels of its neighbors.

In [None]:
import numpy as np

# imagine this function is applying a convolution filter or something more complicated
# we keep it simple here
def my_mult_function(x):
    newval=x*3.
    return newval

In [None]:
### this works if my_mult_function doesn't change the data shape ###
import dask.array as da

# the dask function .map_overlap takes dask/numpy arrays as input, not xarray data structures
# our xarray data structure is already backed by a dask array (because we opened the nc file with xr.open_mfdataset with chunks), so
# all we have to do to access the data in dask format is use the .data attribute of our xarray object

# we want our calculation to happen on overlapping chunks where
#  - each chunk is extended by 3 pixels from its neighboring chunks along every axis (depth=3),
#  - no extra padding is added at the outer edges of the full array (boundary='none'),
#  - and the overlap pixels are trimmed off of each chunk after the calculation (trim=True)

# remember this just adds to the dask graph, does not actually compute
varmult=da.map_overlap(my_mult_function, var.data, dtype=np.float32, depth=3, boundary='none', trim=True)

In [None]:
%%time

# do the computation
# compute brings the result into memory in the form of a numpy array
# we are only multiplying, so the computation should return
# a numpy array with 3 dimensions
varmult=varmult.compute()
varmult.shape

In [None]:
import matplotlib.pyplot as plt

# plot the first time step of the varmult numpy array

plt.imshow(varmult[0,:,:],interpolation='none',cmap='coolwarm')
plt.colorbar(shrink=0.5)
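
Multiplying by 3 doesn't use neighboring pixels, so the overlap makes no difference to that result. Here is a minimal sketch on a synthetic 1-D array where overlap does matter. `moving_sum3` is a hypothetical stand-in for a convolution filter. Without overlap, the values next to each chunk edge never see their neighbors. With `depth=1`, the chunked result matches the calculation done on the whole array at once.

```python
import dask.array as da
import numpy as np

def moving_sum3(block):
    """3-point moving sum; the first and last elements are left unchanged."""
    out = block.copy()
    out[1:-1] = block[:-2] + block[1:-1] + block[2:]
    return out

values = np.arange(20, dtype=np.float64)
x = da.from_array(values, chunks=5)  # 4 chunks of 5 values

reference = moving_sum3(values)  # computed on the whole array at once

# map_blocks: each chunk is processed alone, so chunk-edge values are wrong
no_overlap = x.map_blocks(moving_sum3, dtype=x.dtype).compute()

# depth=1 borrows 1 value from each neighboring chunk, then trims it off again
with_overlap = da.map_overlap(
    moving_sum3, x, depth=1, boundary="none", trim=True, dtype=x.dtype
).compute()

print(np.array_equal(no_overlap, reference))    # False
print(np.array_equal(with_overlap, reference))  # True
```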

#### try with a function that changes the data shape

In [None]:
# this function accepts data of shape (time, lat, lon)
# applies the .mean function
# and returns data of shape (lat,lon)

# imagine though that we could be applying a convolution filter here
# then doing more calculations that end up reducing the dimensions of our input data
def my_mean_function(x):
    newval=np.mean(x,axis=0)
    return newval

In [None]:
import dask.array as da

# dask's .map_overlap uses .map_blocks under the hood, meaning that
# .map_overlap accepts parameters for .map_blocks as **kwargs
# we saw above with varmult that .map_overlap works without any kwargs if the function doesn't change the shape of the data
# since .mean does change the shape of the data from (time, lat, lon) to (lat, lon), we have to use the kwarg "drop_axis"
# see the dask pages for .map_overlap and .map_blocks for more info
# all the possible kwargs you can use are listed on the .map_blocks page

kwargs={'drop_axis':0} # kwargs for .map_blocks, put them in a python dictionary {'key':value} here

# same thing as before except call our mean function and include the kwargs
varmean=da.map_overlap(my_mean_function, var.data, dtype=np.float32, depth=3, boundary='none', trim=True, **kwargs)
varmean

In [None]:
%%time
varmean=varmean.compute()

In [None]:
import matplotlib.pyplot as plt
plt.imshow(varmean,interpolation='none',cmap='coolwarm')
plt.colorbar(shrink=0.5)

### dask delayed

Dask has a few other features for executing computations. Sometimes, if we string too many calculations together on a dask array, dask can struggle to manage memory (see the note on complicated calculations above). In these cases it's better to use Dask delayed. Dask delayed also works well for custom-written functions.

Dask delayed can wrap any python function. In the approach used here, though, each delayed task receives one chunk as a plain numpy array, not an xarray object. So first we'll take the dask array that backs our xarray object; its chunks are numpy arrays, and it is still not in memory.

In [None]:
var_np=var.data # the underlying dask array (its chunks are numpy arrays, still not in memory)
var_np

In [None]:
# working with the raw array means all our labels (dimension names and coordinates) are gone
# so we can't use .mean('time') from the xarray library anymore https://docs.xarray.dev/en/stable/generated/xarray.DataArray.mean.html
# we have to use .mean(axis=0) from the numpy library https://numpy.org/doc/stable/reference/generated/numpy.mean.html

# this function could contain anything, but we'll keep it simple here
def my_np_function(x):
    newval=np.mean(x,axis=0)
    return newval

In [None]:
# for large arrays we delay them first
# this reduces moving big data across the workers by just moving chunks where they are needed instead of moving the whole array
 
var_delay=var_np.to_delayed().ravel() # make each chunk a delayed object
var_delay  # the output shows all the chunks as dask delayed objects

In [None]:
import dask
task_list=[dask.delayed(my_np_function)(var_chunk) for var_chunk in var_delay] # create a list of delayed compute tasks
# what we've done above is called a list comprehension in python, see more at https://docs.python.org/3/tutorial/datastructures.html#list-comprehensions
task_list

In [None]:
%%time
# do the computation
result_chunks=dask.compute(*task_list)

In [None]:
# result_chunks is a tuple of 18 arrays, 1 array for each chunk
len(result_chunks),result_chunks[0],result_chunks[0].shape

In [None]:
# to reassemble into a single array we concatenate
import numpy as np
result=np.concatenate(result_chunks,axis=1) # put the chunks together along the longitude dimension axis 1
result.shape
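
`to_delayed().ravel()` flattens the grid of chunks, so it's worth checking the reassembly step on data where the right answer is known. Here is a minimal sketch of the same pattern on a small synthetic array; `chunk_time_mean` is a hypothetical stand-in for `my_np_function`.

```python
import dask
import dask.array as da
import numpy as np

data = np.random.default_rng(2).random((10, 8, 30))  # (time, lat, lon)
x = da.from_array(data, chunks=(-1, -1, 10))  # 3 chunks, along lon only

def chunk_time_mean(chunk):
    # each delayed task receives one chunk as a plain numpy array
    return chunk.mean(axis=0)

blocks = x.to_delayed().ravel()  # one Delayed object per chunk
tasks = [dask.delayed(chunk_time_mean)(b) for b in blocks]
pieces = dask.compute(*tasks)  # a tuple of numpy arrays, one per chunk

# the chunks were split along lon, which is axis 1 once time is averaged away
result = np.concatenate(pieces, axis=1)
print(result.shape)  # (8, 30)
```

The single `np.concatenate` only works because the chunks run along one axis. If the data were also chunked along lat, the pieces would have to be reassembled as a 2-D grid of blocks (for example with `np.block`). That is one reason to chunk along a single dimension when using this pattern.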

In [None]:
# how to plot this numpy array
import matplotlib.pyplot as plt
plt.imshow(result,interpolation='none',cmap='coolwarm')
plt.colorbar(shrink=0.5)

In [None]:
# how to convert back to xarray
new_xr_var = xr.DataArray(result,dims=var_mean.dims,coords=var_mean.coords)
new_xr_var

In [None]:
new_xr_var.plot()

There are plenty of ways to make xarray plots look nicer too

### write results to a new netcdf file

In [None]:
# convert xarray data array to xarray dataset
varname='tmax_mean'
ds_out=new_xr_var.to_dataset(name=varname) # must give the variable in the dataset a name
ds_out

In [None]:
# assign some more metadata: variable attributes and spatial reference copied from the .nc file we read in originally

ds_out[varname].attrs=ds['Tmax-2m'].attrs
ds_out=ds_out.assign_coords({'spatial_ref':ds.spatial_ref[0]})
ds_out

In [None]:
# encoding settings that make writing the netcdf work (fill values, time units, compression, data type)

time_encoding={'calendar':'standard','units':'days since 1900-01-01 00:00:00','_FillValue':None}
lat_encoding={'_FillValue':None}
lon_encoding={'_FillValue':None}
var_encoding = {'zlib':True,'dtype':'float32'}    


import getpass
user=getpass.getuser()
outfile=outdir+'dask_demo_output_'+user+'.nc'

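# only pass encodings for variables that actually exist in ds_out
# ('time' may be gone after averaging over time, and to_netcdf can error on encodings for missing variables)
encoding={'lat':lat_encoding,
          'lon':lon_encoding,
          'time':time_encoding,
          varname:var_encoding}
encoding={k:v for k,v in encoding.items() if k in ds_out.variables}
ds_out.to_netcdf(outfile,encoding=encoding)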

In [None]:
# you don't ever want to connect to multiple clients at once, always shutdown your client before starting a new one
# restarting the jupyter kernel (from the kernel menu) also works to shutdown a client, but will clear out your whole notebook too

client.shutdown()  

# Dask distributed scheduler using SLURMCluster

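With SLURMCluster, setting up the cluster is different, and so is where the code runs: the Dask workers run inside separate SLURM jobs (possibly on other compute nodes) instead of on the node running Jupyter. The rest of the code is the same as above.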

In [None]:
# this deletes all the variables in your notebook (-f skips the confirmation prompt)
# I never really use this except in demo notebooks
%reset -f

In [None]:
# USER NEEDS TO MODIFY QUEUE AND ACCOUNT PARAMETERS OF SLURMCluster BEFORE RUNNING

# some of these things we've already imported
# we don't need to import things again
# but I'm doing it anyway as a reminder of what packages we're using
from dask_jobqueue import SLURMCluster  # this import is new
from dask.distributed import Client
from time import time, sleep
import os  # this is new too

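# make a folder in your home directory for all the logs that SLURM spits out
# expanduser turns '~' into your home directory path; os.path.exists and os.makedirs don't expand it on their own
logpath=os.path.expanduser('~/dask-worker-space-can-be-deleted')
os.makedirs(logpath,exist_ok=True)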

# this describes 1 worker: each SLURM job runs 1 worker process with 2 cores and 9GB of memory
cluster = SLURMCluster(
    queue='xxxxx',
    account="xxxxxx-xxxxx",
    processes=1,
    cores=2,
    memory='9GB',
    walltime="00:20:00",
    log_directory=logpath)

client=Client(cluster) # connect to cluster

# I choose 18 workers here because that's how many chunks I have in my data 
# If I chose 80 workers the compute wouldn't go any faster because I only have 18 chunks to compute
# This data isn't very big (~10GB) so making smaller chunks on more workers doesn't speed things up (I tried)
nworkers=18  
cluster.scale(nworkers) # increase the size of the cluster to 18 workers
client.wait_for_workers(n_workers=nworkers,timeout=120) # wait up to 2 min for the cluster to be fully ready, error if not ready in 2min
client

### using built-in functions on dask arrays

In [None]:
%%time
# open chunked file
# ds is an xarray data structure filled with dask arrays

import glob
import xarray as xr

datadir='/path/to/dir/where/Tmax/netcdf/lives/'
outdir='/path/to/dir/in/your/personal/workspace/'

f=sorted(glob.glob(datadir+'*.nc')) # glob returns files in arbitrary order, so sort them

chunks={'time':-1,'lat':-1,'lon':240}

ds=xr.open_mfdataset(f[0],chunks=chunks)
ds

In [None]:
%%time
# send information about our data chunks and computation tasks to the workers
# here the graph has 2 layers (tasks): load and chunk
# persist starts moving the data chunks into the workers' memory in the background
# most calculations will be lazy if you don't include .compute()
# lazy means the compute task is recorded on the dask graph but not executed
# .compute() is what executes the calculation
var=ds['Tmax-2m'].persist()
var_mean=var.mean('time').compute()

In [None]:
%%time
# you can string multiple calculations together and call compute only on the last variable to execute everything
var_monthly=var.groupby('time.month').mean('time') # monthly means
month_minval=var_monthly.min('month') # minimum of monthly means at each grid point
month_minval=month_minval.compute() # keep the computed result (compute returns a new object)
month_minval

### using custom functions on dask arrays

In [None]:
%%time
# this function could contain anything, but we'll keep it simple here
def my_function(x):
    newval=x.mean('time')
    return newval

varmean=var.map_blocks(my_function).compute()

In [None]:
varmean.plot()

### dask delayed

In [None]:
%%time
# converting to numpy means all our labels get deleted
# so we can't use .mean('time') from the xarray library anymore https://docs.xarray.dev/en/stable/generated/xarray.DataArray.mean.html
# we have to use .mean(axis=0) from the numpy library https://numpy.org/doc/stable/reference/generated/numpy.mean.html

import dask
import numpy as np  # needed for np.concatenate below (%reset cleared the earlier import)

# this function could contain anything, but we'll keep it simple here
def my_np_function(x):
    newval=x.mean(axis=0) # 0 is the index of the time dimension
    return newval

var_np=var.data # the underlying dask array (its chunks are numpy arrays, still not in memory)
var_delay=var_np.to_delayed().ravel() # make each chunk a delayed object
task_list=[dask.delayed(my_np_function)(var_chunk) for var_chunk in var_delay] # create a list of delayed compute tasks
result_chunks=dask.compute(*task_list)
result=np.concatenate(result_chunks,axis=1) # put the chunks together along the longitude dimension axis 1

In [None]:
import matplotlib.pyplot as plt

plt.imshow(result,interpolation='none',cmap='coolwarm')
plt.colorbar(shrink=0.5)

In [None]:
client.shutdown()