# Session 8-1: Multiprocessing, Dask and Xarray

![ntl](./assets/parallel.png)

Recall from the beginning of the course that modern computers have multiple central processing units (CPUs), or cores. Operating systems, like Windows, are optimized to efficiently take advantage of the CPUs on a given system, with hidden subroutines that keep track of CPU loading and RAM allocation. As budding data scientists, we will not go into the details of how a modern computer functions. But we will cover the basics of parallel processing so that you can take advantage of all the CPUs on your system.

Python has two relatively new packages - [`dask and xarray`](https://docs.xarray.dev/en/stable/user-guide/dask.html) - that enable the processing of gridded datasets in parallel with ease. But to understand how these packages work together, we first need to spend a moment looking at the native [`python multiprocessing`](https://docs.python.org/3/library/multiprocessing.html) module.

A few things to remember. When you launch an **Earth Sciences JupyterLab Classroom** instance of Jupyter on Tempest, the system allocates 8 CPUs and 32 GB of RAM for you to use (FYI - Tempest is one of the largest university clusters in the USA; it currently has 12,264 logical CPU cores and 55.9 TB of ECC memory). Thus, if you want to maximize the power of your Jupyter instance, you can have up to 8 CPUs working, each with 4 GB of RAM (32/8 = 4).
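
If you want to check how many CPUs your session can actually use, something like the sketch below can help. It assumes a Linux system such as Tempest: `os.sched_getaffinity` is not available on every platform, so the code falls back to `os.cpu_count()`, which reports every core on the machine and not just the ones allocated to you. The helper names `usable_cpus` and `ram_per_worker_gb` are made up for this example.

```python
import os

def usable_cpus():
    """Return the number of CPUs this process is allowed to run on."""
    if hasattr(os, "sched_getaffinity"):
        # Linux: the set of CPUs this process may be scheduled on
        return len(os.sched_getaffinity(0))
    # Other platforms: total CPUs on the machine (may be more than your share)
    return os.cpu_count() or 1

def ram_per_worker_gb(total_ram_gb, n_workers):
    """Split the available RAM evenly across the workers."""
    return total_ram_gb / n_workers

print("Usable CPUs:", usable_cpus())
print("RAM per worker (GB):", ram_per_worker_gb(32, 8))
```

On some clusters the scheduler limits CPUs in ways this check may not see, so treat the number as a starting point and not a guarantee.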

But when you tell a CPU to perform a task, that worker gets its own pool of memory, and the data being processed on one CPU cannot be accessed by a worker on another CPU. Python works this way because of its [Global Interpreter Lock](https://realpython.com/python-gil/) (GIL), which lets only one thread run Python code at a time within a single process, so `multiprocessing` sidesteps the GIL by launching separate processes that do not share memory. You don't need to understand the GIL in detail for this class. What you do need to know is that to leverage all the CPUs, you need to chunk your data into logical blocks that each CPU can work on without having to talk to the other CPUs.
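
To make the idea of chunking concrete, here is a minimal sketch that splits a list into nearly equal blocks, one per worker. The function name `make_chunks` and the list of ten "days" are made up for illustration; they are not part of any package.

```python
def make_chunks(items, n_chunks):
    """Split a list into n_chunks blocks of nearly equal size."""
    size, extra = divmod(len(items), n_chunks)
    chunks = []
    start = 0
    for i in range(n_chunks):
        # The first `extra` chunks each take one additional item
        end = start + size + (1 if i < extra else 0)
        chunks.append(items[start:end])
        start = end
    return chunks

days = list(range(1, 11))  # stand-in for ten daily files
blocks = make_chunks(days, 4)
print(blocks)  # [[1, 2, 3], [4, 5, 6], [7, 8], [9, 10]]
```

Each block can then be handed to a different CPU, and no block depends on any other.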

This might make you a bit confused. It made me very confused when I started teaching this to myself back in graduate school. The easiest way to learn, as we know, is to play with some code. So let's do that!

<p style="height:1pt"> </p>

<div class="boxhead2">
    Session Topics
</div>

<div class="boxtext2">
<ul class="a">
    <li> 📌 Introduction to <span class="codeb">multiprocessing</span> </li>
    <ul class="b">
        <li> Census API </li>
        <li> Merging with shapefiles </li>
        <li> Plotting Data </li>
        <li> Area Aggregation </li>
    </ul>
</ul>
</div>

<hr style="border-top: 0.2px solid gray; margin-top: 12pt; margin-bottom: 0pt"></hr>

### Instructions
We will work through this notebook together. To run a cell, click on the cell and press "Shift" + "Enter" or click the "Run" button in the toolbar at the top. 

<p style="color:#408000; font-weight: bold"> 🐍 &nbsp; &nbsp; This symbol designates an important note about Python structure, syntax, or another quirk.  </p>

<p style="color:#008C96; font-weight: bold"> ▶️ &nbsp; &nbsp; This symbol designates a cell with code to be run.  </p>

<p style="color:#008C96; font-weight: bold"> ✏️ &nbsp; &nbsp; This symbol designates a partially coded cell with an example.  </p>

<hr style="border-top: 1px solid gray; margin-top: 24px; margin-bottom: 1px"></hr>

# multiprocessing

The [`python multiprocessing`](https://docs.python.org/3/library/multiprocessing.html) module is one of the things that make Python such an awesome tool and why Python is among the most popular languages for data science. This is because it makes scaling code across CPUs quite easy. We are not going to go deep into the package in this class. But I want to introduce you to it, because understanding [`python multiprocessing`](https://docs.python.org/3/library/multiprocessing.html) will make learning [`dask and xarray`](https://docs.xarray.dev/en/stable/user-guide/dask.html) much easier.

Let's begin with a simple example. Let's say you have global, daily maximum heat index data for 2016 that you want to convert from Celsius to Fahrenheit and then find the global average temperature for each day. In other words, you have 366 (2016 is a leap year) raster files on which you need to perform simple arithmetic. These files are pretty large (76 MB) and add up to about 9 GB of data in total.

As you will see below, you can do this with a `for` loop. But it will be faster to do this in parallel. Let's take a look.

<div class="run">
    ▶️ <b> Run the cells below. </b>
</div>

In [None]:
import os
import glob  
import sys
from multiprocessing import Pool 
import time 
import numpy as np
import rasterio
import multiprocessing
import matplotlib.pyplot as plt

In [None]:
# Get the paths to your files as a sorted list
path = os.path.join('/home/group/earthsciences/gphy591/github/GPHY-491-591/materials/Day8/data/2016/')
fns = sorted(glob.glob(path + '*.tif'))

# print first five files to make sure they are in order
fns[:5]

### First, let's take a look at the data
Open one raster and plot it, then look at the metadata to see what the NaN (no-data) value is so we don't mess it up when we convert °C to °F.

<div class="run">
    ▶️ <b> Run the cells below. </b>
</div>

In [None]:
# Let's look at the data for July 1, 2016
arr = rasterio.open(fns[182]).read(1)
plt.imshow(arr, vmin = -32)
plt.colorbar(shrink = 0.4)
plt.title('Maximum Heat Index in °C for July 1, 2016');

In [None]:
# Let's look at the metadata to find the NaN (nodata) value
rasterio.open(fns[182]).meta

### Now let's try a `for` loop.

Remember, our NaN value is `-9999`, so we want to make sure we mask that out when we calculate our global average HI for each day.
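
To see why the masking matters, here is a tiny sketch with a made-up 2 x 2 array standing in for one raster. It compares the mean with the `-9999` values left in, the mean after dropping them with boolean indexing, and the same result using a numpy masked array, which is an alternative way to do it and not the approach used in the cells below.

```python
import numpy as np

NODATA = -9999

# A tiny stand-in for one raster: two ocean (nodata) pixels and two land pixels
arr_c = np.array([[NODATA, 20.0],
                  [30.0, NODATA]])

# Leaving the nodata values in drags the average far below any real temperature
naive_mean = arr_c.mean()

# Boolean indexing keeps only the land pixels
land_mean = arr_c[arr_c != NODATA].mean()

# A masked array hides the nodata values and gives the same answer
masked_mean = np.ma.masked_equal(arr_c, NODATA).mean()

print("Mean with nodata included:", naive_mean)
print("Mean over land only:", land_mean)
print("Mean with a masked array:", masked_mean)
```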

<div class="run">
    ▶️ <b> Run the cells below. </b>
</div>

In [None]:
# function to convert c to f
def C_to_F(Tmax_C):
    "Function converts temp in C to F"
    Tmax_F = (Tmax_C * (9/5)) + 32
    
    return Tmax_F

In [None]:
# clock it
start = time.time()

for fn in fns:
    
    # Get the date
    date = fn.split('data/2016/')[1].split('.tif')[0]
    
    # open the array in c
    arr_c = rasterio.open(fn).read(1)
    
    # convert C to F, using np.where so we keep the -9999 values
    arr_f = np.where(arr_c != -9999, C_to_F(arr_c), -9999) # where arr_c does not equal -9999, convert the data from C to F; everywhere else, write -9999
    
    # get the daily global average maximum heat index 
    land = arr_f[arr_f != -9999] # drop all ocean -9999 values
    avg = land.mean()
    
    print('On', date, 'the global average heat index was', round(avg, 2), '°F')
    
end = time.time()
print('Time:', end-start)

### Now try it in parallel. 
Running the code sequentially takes about 33 seconds on one CPU. That's not too bad, but what if you need to run this calculation 1,000 times? That will add up, and if you make a mistake and need to re-run the code, it will set you back even further. This is where parallel processing is super useful.

Let's see how fast this goes when we feed our list of files to all 8 CPUs. To do this we have to:

1. Create a function to pass to multiprocessing.
2. Create a pool of workers (i.e., CPUs).
3. Pass our function and our list of files to the pool for the workers to work on.
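
Before using the real rasters, the three steps above can be sketched on a toy list of temperatures. This is a minimal example, not the code used later in this notebook: `c_to_f` and `convert_in_parallel` are names made up here, and the worker returns its result so that `pool.map` can collect the values into a list.

```python
from multiprocessing import Pool

# Step 1: a function that works on ONE item and returns its result
def c_to_f(temp_c):
    """Convert a temperature from Celsius to Fahrenheit."""
    return temp_c * 9 / 5 + 32

def convert_in_parallel(temps_c, n_workers=2):
    """Convert a list of temperatures using a pool of worker processes."""
    # Step 2: create a pool of worker processes
    with Pool(processes=n_workers) as pool:
        # Step 3: hand the function and the list to the pool
        return pool.map(c_to_f, temps_c)

if __name__ == "__main__":
    print(convert_in_parallel([0, 100, -40]))
```

The `if __name__ == "__main__":` guard does no harm in a notebook, and it matters if you save the code as a script on systems where new processes re-import the file (for example Windows and macOS).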

<div class="run">
    ▶️ <b> Run the cells below. </b>
</div>

In [None]:
def avg_fast(fn):
    
    "Function takes a global heat index raster, converts the data from C to F, and calculates the average value for that raster, assuming NaN = -9999" 
    
    # you can print your worker id
    # print(multiprocessing.current_process())  # for now, we'll leave this commented out
    
    # Get the date
    date = fn.split('data/2016/')[1].split('.tif')[0]
    
    # open the array in C
    arr_c = rasterio.open(fn).read(1)

    # convert C to F, using np.where so we keep the -9999 values
    arr_f = np.where(arr_c != -9999, C_to_F(arr_c), -9999) # where arr_c does not equal -9999, convert the data from C to F; everywhere else, write -9999
    
    # get the daily global average maximum heat index 
    land_f = arr_f[arr_f != -9999] # drop all ocean -9999 values
    avg_f = land_f.mean()
    
    # print('On', date, 'the global average heat index was', round(avg_f, 2), '°F \n')
    print(date, avg_f)

In [None]:
# Clock it
start = time.time()

# set up your pool of workers, in this case we have 8 
n_cpus = 8 
pool = Pool(processes = n_cpus)

# map the function and the arguments to the pool of workers, in this case avg_fast and the list of files
pool.map(avg_fast, fns)

# shut down your pool of workers and wait for them to exit
pool.close()
pool.join()

end = time.time()
print('Time:', end-start)

### Some things to think about. 

**Wow!** That was quite a bit faster than one CPU working on the data sequentially.

Notice that the `print` statements from each worker are not printed sequentially. This is because each CPU is working independently, and the output is spit out as soon as that CPU is done. Because each file is a relatively small computation for a CPU, several workers finish at nearly the same time, so the output appears in the order the tasks complete and not in the order of the files.
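
If you need the results in file order, one option is to have the worker return its value instead of printing it. `pool.map` returns a list in the same order as the inputs, even when the workers finish out of order. The sketch below uses a made-up function, `slow_square`, that sleeps longer for smaller numbers so that the tasks complete in reverse order.

```python
import time
from multiprocessing import Pool

def slow_square(n):
    """Sleep longer for small n, then return n and its square."""
    time.sleep(max(0, 0.05 * (5 - n)))
    return n, n * n

def squares_in_order(numbers, n_workers=4):
    """Square each number in parallel and return the results in input order."""
    with Pool(processes=n_workers) as pool:
        return pool.map(slow_square, numbers)

if __name__ == "__main__":
    # Printed in input order (1, 2, 3, 4), although 4 finishes first
    for n, sq in squares_in_order([1, 2, 3, 4]):
        print(n, sq)
```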

But if you have each CPU crunch a lot more data, then each worker will be slowed down enough that the outputs get printed more or less sequentially. Below is an example where we convert the data from C to F 40 times, to show you what a bigger job looks like. But we are only going to run this on 4 CPUs and the first 30 files (i.e., `fns[:30]`) as an example.

In [None]:
def avg_slow(fn):
    
    "Function takes a global heat index raster, converts the data from C to F, and calculates the average value for that raster, assuming NaN = -9999" 
    
    # you can print your worker id
    # print(multiprocessing.current_process())  # for now, we'll leave this commented out
    
    # Get the date
    date = fn.split('data/2016/')[1].split('.tif')[0]

    # Convert the data from C to F 40 times to slow everything down   
    for _ in range(40):
        
        # open the array in C
        arr_c = rasterio.open(fn).read(1)
        
        # convert C to F, using np.where so we keep the -9999 values
        arr_f = np.where(arr_c != -9999, C_to_F(arr_c), -9999) # where arr_c does not equal -9999, convert the data from C to F; everywhere else, write -9999
    
    # get the daily global average maximum heat index 
    land_f = arr_f[arr_f != -9999] # drop all ocean -9999 values
    avg_f = land_f.mean()
    
    # print('On', date, 'the global average heat index was', round(avg_f, 2), '°F \n')
    print(multiprocessing.current_process(), date, avg_f)

In [None]:
# Clock it
start = time.time()

# set up your pool of workers, in this case we have 4 
n_cpus = 4 
pool = Pool(processes = n_cpus)

# map the function and the arguments to the pool of workers, in this case avg_slow and the first 30 files
pool.map(avg_slow, fns[:30])

# shut down your pool of workers and wait for them to exit
pool.close()
pool.join()

end = time.time()
print('Time:', end-start)

<hr style="border-top: 1px solid gray; margin-top: 24px; margin-bottom: 1px"></hr>

# Dask & Xarray 

[`python multiprocessing`](https://docs.python.org/3/library/multiprocessing.html) is a great package to use if you want to crunch a bunch of files in parallel. You can even `chunk` datasets and feed lists of chunks to your `Pool` to process data in parallel. But things get more complex when you want the outputs of each CPU's tasks to talk with each other. 

For example, let's say we want to calculate the average daily maximum heat index over 2016 for each _pixel_ in the dataset. This would require a 3-d array (or cube), where our _x_ and _y_ axes are _longitude_ and _latitude_, respectively, and our _z_ axis is _time_. It is possible to open each raster and stack them into a 3-d data cube. But remember, that would require a huge amount of memory, since the total data is about 9 GB on disk. When we open those rasters, they will balloon even further because `GeoTiff` files are compressed to some degree.
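
A quick back-of-the-envelope calculation shows how fast a data cube grows. The sketch below multiplies the number of days, rows, and columns by the bytes per value. The grid size used here (1000 rows by 2000 columns of 4-byte floats) is hypothetical and is not the size of our heat index rasters, and `cube_size_gb` is a name made up for this example.

```python
def cube_size_gb(n_days, n_rows, n_cols, bytes_per_value):
    """Estimate the in-memory size of an uncompressed 3-d array in gigabytes."""
    return n_days * n_rows * n_cols * bytes_per_value / 1e9

# Hypothetical grid: 366 days, 1000 rows, 2000 columns, 4-byte floats
size_gb = cube_size_gb(366, 1000, 2000, 4)
print(round(size_gb, 2), "GB")  # 2.93 GB
```

Swap in the real number of rows and columns from the raster metadata to estimate the size of your own cube.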

![ntl](./assets/data.png)

Welcome to [`Dask and Xarray`](https://docs.xarray.dev/en/stable/user-guide/dask.html)! They will make your life a lot easier. You can complete an entire free online tutorial if you want to go deep on `Dask and Xarray`, [An Introduction to Earth and Environmental Data Science](https://earth-env-data-science.github.io/intro.html). Here, we are just going to touch on the basics of both packages, but I **highly** recommend you read about `Xarray` data structures [here](https://docs.xarray.dev/en/stable/user-guide/data-structures.html) because the terminology can be a bit confusing.

![ntl](./assets/xarray.png)

Basically, `Xarray` allows you to create labeled n-dimensional numpy arrays. So you can label your datasets (temperature, precipitation, etc.) and your dimensions (time, lat/long, etc.) to easily subset the data to run analysis. For example, you could ask what the average heat index in Bozeman is, based on its lat/long, with just a few lines of code. **Note:** We will also use the package [`rioxarray`](https://corteva.github.io/rioxarray/html/rioxarray.html) to load in the GeoTiff files. More on this later. Let's get started!

<div class="run">
    ▶️ <b> Run the cells below. </b>
</div>

In [None]:
# Dependencies 
import xarray as xr
import dask
import rioxarray as rio
import pandas as pd

### A Toy Dataset
Let's start by making two toy xarray data arrays and turning them into an xarray dataset with temperature and precipitation. They will be 100 by 100 by 30 arrays, representing 100 by 100 lat/long and 30 days.

In [None]:
# make a temperature array filled with random integers
temp = np.random.randint(20, high=100, size=(100,100,30), dtype=int)
temp.shape

In [None]:
# make a precip array filled with random integers
precip = np.random.randint(0, high= 10, size=(100,100,30), dtype=int)
precip.shape

In [None]:
# make a 30 day time stamp (note: this reuses the name `time`, which shadows the `time` module imported earlier)
time = pd.date_range("2016-01-01", periods=30)
time

In [None]:
len(time)

In [None]:
# make x and y index values
x = list(range(1,100+1))
y = list(range(1,100+1))

In [None]:
# Turn temp into xarray data array
temp_da = xr.DataArray(data = temp, # data
                       dims = ['x', 'y', 'time'], # dim labels as a list
                       coords = {'x' : x, 'y' : y, 'time' : time}, # coords data as a dict
                       name = 'temp' # name the da
                      )
temp_da

In [None]:
# Turn precip into xarray data array
precip_da = xr.DataArray(data = precip, # data
                       dims = ['x', 'y', 'time'], # dim labels as a list
                       coords = {'x' : x, 'y' : y, 'time' : time}, # coords data as a dict
                       name = 'precip' # name the da
                      )
precip_da

In [None]:
# Now combine temp and precip into a dataset
ds = xr.merge([temp_da, precip_da])
ds

#### A bit about Xarray Datasets

Xarray Datasets are nice because you can easily stack several data arrays and run analysis on the data. Like numpy arrays, they must be the same size. But unlike numpy arrays, you have labels, so it is easy to subset the data.
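
Here is a minimal sketch of what label-based subsetting looks like. It builds its own small dataset (a 4 x 4 grid, 10 days, constant values) so that it runs on its own; the name `toy` and the values are made up for this example.

```python
import numpy as np
import pandas as pd
import xarray as xr

# A small stand-in dataset: 4 x 4 grid, 10 days, constant values for easy checking
dates = pd.date_range("2016-01-01", periods=10)
toy = xr.Dataset(
    {
        "temp": (("x", "y", "time"), np.full((4, 4, 10), 70)),
        "precip": (("x", "y", "time"), np.full((4, 4, 10), 2)),
    },
    coords={"x": [1, 2, 3, 4], "y": [1, 2, 3, 4], "time": dates},
)

# Select one day of temperature by its date label
one_day = toy.temp.sel(time="2016-01-05")

# Select a range of days and a block of grid cells by their labels
# (label-based slices include both end points)
block = toy.sel(time=slice("2016-01-03", "2016-01-06"), x=slice(1, 2), y=slice(3, 4))

print(one_day)
print(dict(block.sizes))
```

With a plain numpy array you would have to work out the integer positions of those dates and grid cells yourself.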

In [None]:
# access the temp data
ds.temp

In [None]:
# access the precip data
ds.precip 

In [None]:
# Estimate correlation between precip and temperature over time
corr = xr.corr(ds.temp, ds.precip, dim= 'time')
corr

## Let's look at some real data with Dask

Using `Xarray` and `Dask` together is really powerful. Be sure to read this [overview](https://docs.xarray.dev/en/stable/user-guide/dask.html) on your own time. But the short of it is that `Dask` allows you to create `chunks` of large datasets that are small enough to load into memory, without loading all the data into memory. When used with `xarray`, you can set up a set of instructions (i.e., computations) to run on an `Xarray` object, and Python will complete them in parallel without loading the whole dataset into memory at once. Confused? I was too. The easiest way to learn this is to actually do it.
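
The idea of setting up instructions first and computing later can be seen with a plain `dask` array, before any rasters are involved. This is a minimal sketch: the array of ones and its chunk size are arbitrary, and the alias `dsa` is used here only to avoid clashing with the name `da` used for data arrays elsewhere in this notebook.

```python
import dask.array as dsa

# Describe a 2000 x 2000 array split into 16 chunks; nothing is computed yet
x = dsa.ones((2000, 2000), chunks=(500, 500))
print(x.numblocks)

# Build up instructions; this is still only a recipe (a task graph)
lazy_mean = (x * 2).mean()
print(type(lazy_mean))

# Only now does Dask do the work, chunk by chunk
result = lazy_mean.compute()
print(result)
```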

Let's look at an example with our heat index data from 2016 that is nearly 9 GB.

<div class="run">
    ▶️ <b> Run the cells below. </b>
</div>

In [None]:
# Get the paths to your files as a sorted list
path = os.path.join('/home/group/earthsciences/gphy591/github/GPHY-491-591/materials/Day8/data/2016/')
fns = sorted(glob.glob(path + '*.tif'))

# print first five files to make sure they are in order
fns[:5]

#### Load a 'view' of all 366 files into an Xarray DataArray

The code below will not load the files into memory. Instead, it creates a DataArray that shows us what the data would look like if it were loaded, lets us set up a set of instructions, and then tells `Dask` to run those instructions in parallel.

Like python `multiprocessing`, you first have to tell `Dask` to tee up 8 CPUs. If you run this on your own machine, `Dask` creates a URL where you can watch your workers work. More on this later.

<div class="run">
    ▶️ <b> Run the cells below. </b>
</div>

In [None]:
# Dependencies 
from dask.distributed import Client, LocalCluster

In [None]:
# create and connect to a dask cluster + get link to watch progress
client = Client(n_workers = 8)
client

In [None]:
# Now 'load' the GeoTiff files
da = xr.concat([rio.open_rasterio(f, chunks = 'auto') for f in fns], dim='band') 

In [None]:
# Take a look at your DataArray
da

In [None]:
# How big is the da object?
sys.getsizeof(da)

As you can see, the da object is **REALLY** small. That is because the data is not actually in memory. `Dask` has chunked the data for us so it will run in parallel. With the `chunks = 'auto'` argument, `dask` picks a reasonable chunk size for us.

### Updating Data
We can update information about our data array, like renaming `band` to `time`, and we can add a `Pandas` time series to the `time` dimension as coordinates, again without actually loading the data into memory.

In [None]:
# Make a pandas daily time series for 2016
time = pd.date_range("2016-01-01", periods=366)
time

In [None]:
# rename time dim
da = da.rename({'band' : 'time'})
da

In [None]:
# Revalue the 'time' coord
da.coords['time'] = time
da

### Subsetting data

There are tons of ways to slice Xarray Data Arrays. Here are three examples.

In [None]:
# Select by month (e.g. January)
jan = da.where(da['time.month'] == 1, drop = True)
jan

In [None]:
# Select a subset by row / col
subset = da[:, 1000:1500, 3000:4000] # subset for west africa
subset

In [None]:
# Slice by lat / long
subset = da.sel(x = slice(-30,19), y = slice(20, -5))
subset

### Average Heat Index
Now, let's calculate the average daily maximum heat index for 2016 with dask! Again, nothing is actually computed until you add `.compute()` to the data array. Check out the code below to understand.

In [None]:
# Estimate the mean daily max heat index along the time dim
daily_avg = da.mean(dim = 'time')
daily_avg

In [None]:
# Now run the code again with .compute() added
daily_avg = da.mean(dim = 'time').compute()
daily_avg

In [None]:
# plot it 
arr = daily_avg.data
plt.imshow(arr, vmin = -32)
plt.colorbar(shrink = 0.3)
plt.title('Average Daily Maximum Heat Index for 2016');

### Save the daily average heat index for 2016 as a tif file

In [None]:
# get the meta data
meta = rasterio.open(fns[0]).meta

In [None]:
# make a file name
fn_out = os.path.join('./himax_2016_avg.tif')

In [None]:
# save it 
with rasterio.open(fn_out, "w", **meta) as dest:
    dest.write(arr, 1)

In [None]:
# Check it
plt.imshow(rasterio.open(fn_out).read(1), vmin = -32);
