# Scale-up analysis in the cloud -- Option 1: Dask within a single EC2 instance

### Revision history

1. 3/20/2023, first draft, Jinbo Wang

Produced by the PO.DAAC coding club. <br>
Supported by NASA ESDIS PO.DAAC.


In [1]:
# Import necessary libraries
import xarray as xr
import dask.array as da
from dask.distributed import Client, LocalCluster, progress
from dask import delayed
import h5py
import numpy as np

#import earthaccess

In [4]:
# Detect how many logical CPUs (vCPUs) this machine reports; this is a natural
# choice for the number of Dask workers on a single instance.
import multiprocessing
n_workers = multiprocessing.cpu_count()
print(n_workers)

64


In [5]:
# Connect to a Dask cluster that is already running (this cell does not create one).
#
# dask-labextension (https://github.com/dask/dask-labextension) makes it easy to
# start a LocalCluster from the JupyterLab sidebar and to open the Dask dashboard.
# Start a cluster there, then update the scheduler address below to match it.
#
# Alternative: create the cluster directly in this notebook, e.g.
#   cluster = LocalCluster(n_workers=n_workers, threads_per_worker=1)
#   client = Client(cluster)
scheduler_address = 'tcp://127.0.0.1:43247'
client = Client(scheduler_address)

In [6]:
def init_S3FileSystem():
    """
    Return an s3fs filesystem authenticated with temporary PO.DAAC AWS credentials.

    The credentials are requested from the PO.DAAC service at
    https://archive.podaac.earthdata.nasa.gov/s3credentials. No explicit
    authentication is passed to ``requests``; the request relies on requests
    picking up your Earthdata Login (EDL) credentials from the ``.netrc`` file.

    Return:
    =======

    s3: s3fs.S3FileSystem
        A filesystem that uses the temporary key, secret and session token.

    Raises:
    =======

    requests.HTTPError
        If the credential service answers with an HTTP error status (for
        example, because of missing or wrong EDL credentials).
    """
    import requests, s3fs

    response = requests.get('https://archive.podaac.earthdata.nasa.gov/s3credentials')
    # Fail with a clear HTTP error here instead of an opaque JSON/KeyError below.
    response.raise_for_status()
    creds = response.json()
    return s3fs.S3FileSystem(anon=False,
                             key=creds['accessKeyId'],
                             secret=creds['secretAccessKey'],
                             token=creds['sessionToken'])
def c_granule(fn, region=False):
    """
    Return a mean SST value for one MUR-SST granule read with h5py.

    Parameters
    ----------
    fn : str
        S3 path of the granule, opened through the module-level ``s3sys``
        filesystem (created by ``init_S3FileSystem``).
    region : bool, optional
        False (default): average the 10x10 corner ``analysed_sst[0, :10, :10]``,
        a small stand-in used for quick timing tests.
        True: average over the California Current box (30-50N, 140-110W),
        keeping only pixels whose ``mask`` value equals 1.

    Returns
    -------
    numpy scalar
        Mean of ``analysed_sst * 0.001`` over the selected pixels.

    Notes
    -----
    NOTE(review): h5py returns the raw stored values and only a 0.001 factor
    is applied here. If the variable also carries ``add_offset`` or
    ``_FillValue`` attributes, the result is not in physical units; confirm
    against the file attributes.
    """
    # The with-statement closes both the HDF5 handle and the underlying S3 file.
    with s3sys.open(fn) as fobj, h5py.File(fobj, 'r') as ds:
        if not region:
            return (ds["analysed_sst"][0, :10, :10] * 0.001).mean()

        # California Current bounding box
        min_lat, max_lat = 30, 50
        min_lon, max_lon = -140, -110
        lat, lon = ds['lat'][:], ds['lon'][:]
        # Index of the grid point nearest to each box edge (first one on ties).
        i0 = int(np.argmin(np.abs(lat - min_lat)))
        i1 = int(np.argmin(np.abs(lat - max_lat)))
        j0 = int(np.argmin(np.abs(lon - min_lon)))
        j1 = int(np.argmin(np.abs(lon - max_lon)))

        # analysed_sst has shape (1, n_lat, n_lon), so the latitude slice comes
        # before the longitude slice. mask is assumed to share that layout.
        sst = ds["analysed_sst"][0, i0:i1, j0:j1] * 0.001
        not_open_sea = ds['mask'][0, i0:i1, j0:j1] != 1
        return np.ma.masked_array(sst, mask=not_open_sea).mean()

def downscaling_mursst(fn):
    """
    Read one MUR SST 1 km granule and average it onto a 1x1 degree grid.

    The granule's ``analysed_sst`` field is read through xarray, which by
    default applies CF decoding so fill values become NaN. Invalid values are
    masked, and every valid pixel inside each 1x1 degree box is averaged.

    Parameters
    ----------
    fn : str
        S3 path of the granule.

    Returns
    -------
    sst : numpy.ma.MaskedArray (180x360)
        Mean SST for each 1x1 degree box. Boxes without any valid pixel are
        masked.

    Dependencies
    ------------
    - xarray (imported as xr), with the h5netcdf engine
    - numpy (imported as np)
    - a module-level ``s3sys`` filesystem (the s3fs.S3FileSystem returned by
      ``init_S3FileSystem``)

    Notes
    -----
    The input file is expected to be in netCDF format and should contain a
    variable named 'analysed_sst' of shape 1x17999x36000, as in GHRSST
    products. The 17999x36000 field is expanded to 18000x36000 by repeating
    its first row, then split into 180x360 boxes of 100x100 pixels.

    NOTE(review): a full-resolution field occupies several GB in memory (the
    exact size depends on the decoded dtype), so running many of these tasks
    at once on one machine can exhaust worker memory.
    """
    # Close both the S3 object and the dataset. .values forces the lazy data to
    # be read into memory before the file is closed.
    with s3sys.open(fn) as fobj, xr.open_dataset(fobj, engine='h5netcdf') as d:
        sst = d['analysed_sst'][0, ...].values

    # Repeat the first latitude row so 17999 rows become 18000 = 180 * 100.
    sst = np.concatenate([sst[0:1, :], sst], axis=0)
    # The concatenated array is a fresh temporary, so masking it without an
    # extra copy is safe and lowers peak memory.
    sst = np.ma.masked_invalid(sst, copy=False)

    # View as (180 lat boxes, 100 rows, 360 lon boxes, 100 columns) and average
    # all valid pixels of each box in one step. Averaging the two axes one after
    # the other would give a mean of row-means, which weights rows unevenly when
    # some pixels are masked.
    return sst.reshape(180, 100, 360, 100).mean(axis=(1, 3))

Find all granule names using s3fs.glob.

In [7]:
%%time
s3sys=init_S3FileSystem()
s3path="s3://podaac-ops-cumulus-protected/MUR-JPL-L4-GLOB-v4.1/"
fns=s3sys.glob(s3path+"*.nc")
print("total granules = ",len(fns))
print("Example filename: ", fns[0])

total granules =  7606
Example filename:  podaac-ops-cumulus-protected/MUR-JPL-L4-GLOB-v4.1/20020601090000-JPL-L4_GHRSST-SSTfnd-MUR-GLOB-v02.0-fv04.1.nc
CPU times: user 2.08 s, sys: 59.4 ms, total: 2.14 s
Wall time: 9.75 s


Test the speed of loading one MUR-SST granule. 

In [8]:
%%time

d=h5py.File(s3sys.open('s3://'+fns[0],'rb'))
print((d.keys()))
d['analysed_sst'][:].shape

<KeysViewHDF5 ['analysed_sst', 'analysis_error', 'lat', 'lon', 'mask', 'sea_ice_fraction', 'time']>
CPU times: user 5.91 s, sys: 2.5 s, total: 8.42 s
Wall time: 17 s


(1, 17999, 36000)

Wrap the `downscaling_mursst` function with `dask.delayed`.

In [9]:
%%time
delayed_process_granule = delayed(downscaling_mursst)
# Process all granules in parallel using Dask
results = [delayed_process_granule(fn) for fn in fns]

CPU times: user 295 ms, sys: 43.6 ms, total: 338 ms
Wall time: 313 ms


In [10]:
%%time
da.compute(*results)

CancelledError: downscaling_mursst-8719f5e6-aa61-43e3-a96d-2400532555d2

2023-04-02 06:47:34,520 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client


#### 