# Argo Data Loading Workflow

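This notebook loads Argo float profiles from a local mirror of the Argo GDAC and adds TEOS-10 variables: Conservative Temperature (CT), Absolute Salinity (SA), potential density anomaly (SIG0) and spiciness (SPICE). It then interpolates every profile onto a common pressure grid. The notebook covers three tasks: loading a single small box, loading a whole basin in parallel with dask, and concatenating the resulting per-box files into one dataset.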

### Imports

In [1]:
from pathlib import Path

import argopy
import gsw
import numpy as np
import xarray as xr
from argopy import DataFetcher as ArgoDataFetcher
# Shared fetcher: reads from a local mirror of the Argo GDAC (path below), with progress bars off.
argo_loader = ArgoDataFetcher(
    src="gdac",
    ftp="/swot/SUM05/dbalwada/Argo_sync",
    progress=False,
)

I've synced a local copy of the entire Argo dataset to the directory `/swot/SUM05/dbalwada/Argo_sync`. `argo_loader` reads from it through the `ftp` argument, so change that path to point at your own mirror. Instructions for setting up a mirror can be found here: http://www.argodatamgt.org/Access-to-data/Argo-GDAC-synchronization-service

It is also possible to fetch data directly from a remote server (e.g. with `src="erddap"`), but I've found this to be inefficient and error-prone for large regions.

## Load Small Box (~5deg x 5deg)

The `get_box` function loads the Argo data within a given lon/lat (and pressure) box and returns it as one dataset, interpolated onto a common pressure grid. This is suitable for fairly small regions (a few degrees in each direction).

In [None]:
def get_box(box, standard_grid=None):
    """Load the Argo profiles inside ``box`` and interpolate them onto a pressure grid.

    Parameters
    ----------
    box : list
        Region passed to ``argo_loader.region``, of the form
        ``[lon_min, lon_max, lat_min, lat_max, pres_min, pres_max]``.
    standard_grid : array-like, optional
        Pressure levels to interpolate onto. Defaults to
        ``np.arange(0, 2002, 2)``, i.e. 0 to 2000 in steps of 2.

    Returns
    -------
    xr.Dataset
        Interpolated profiles, as produced by ``get_ds_interp``, holding the
        TEOS-10 variables CT, SA and SIG0 plus an added SPICE variable. The
        ``raw_attrs`` global attribute is removed if present.

    Raises
    ------
    ValueError
        If no profile in the box could be interpolated onto the grid.
    """
    if standard_grid is None:
        # Build the default per call instead of as a default argument, so one
        # array object is never shared between calls.
        standard_grid = np.arange(0, 2002, 2)

    ds = argo_loader.region(box)
    print("loading points complete")

    ds = ds.to_xarray()
    print("to xarray complete")

    # Add TEOS-10 variables on the point data, then regroup points into profiles.
    ds = ds.argo.teos10(["CT", "SA", "SIG0"])
    ds = ds.argo.point2profile()
    print("point to profile complete")

    ds_interp = get_ds_interp(ds, standard_grid)
    print("interpolation complete")

    # An empty result would otherwise fail below with an opaque AttributeError.
    if "SA" not in ds_interp or "CT" not in ds_interp:
        raise ValueError(
            f"No profiles in box {box} could be interpolated onto the standard grid"
        )

    ds_interp["SPICE"] = gsw.spiciness0(ds_interp.SA, ds_interp.CT).rename("SPICE")
    print("adding spice complete")

    if 'raw_attrs' in ds_interp.attrs:
        del ds_interp.attrs['raw_attrs']

    return ds_interp

In [None]:
def get_ds_interp(ds, standard_grid):
    """Interpolate each profile of ``ds`` onto a common pressure grid.

    Each profile (one entry along ``N_PROF``) is interpolated with argopy's
    ``interp_std_levels``. Only the ``standard_grid`` levels inside that
    profile's sampled pressure range are used. The result is then reindexed
    onto the full ``standard_grid``, so levels outside the sampled range are
    NaN. Profiles that cannot be interpolated are reported and skipped.

    Parameters
    ----------
    ds : xr.Dataset
        Profile-shaped Argo dataset with an ``N_PROF`` dimension and a ``PRES``
        variable along ``N_LEVELS`` (with the argopy ``.argo`` accessor).
    standard_grid : array-like
        Strictly increasing, non-negative pressure levels. The levels need not
        be evenly spaced or start at 0.

    Returns
    -------
    xr.Dataset
        Interpolated profiles concatenated along ``N_PROF``, or an empty
        ``xr.Dataset`` if no profile could be interpolated.

    Raises
    ------
    ValueError
        If ``standard_grid`` is empty, not 1-D, not strictly increasing, or
        contains negative values.
    """
    grid = np.asarray(standard_grid)
    # The grid is the same for every profile, so validate it once up front.
    # np.asarray also makes plain lists work with the comparisons below.
    if grid.ndim != 1 or grid.size == 0:
        raise ValueError("standard_grid must be a non-empty 1-D sequence of pressures")
    if not (np.all(np.diff(grid) > 0) and np.all(grid >= 0)):
        raise ValueError("standard_grid must be strictly increasing and non-negative")

    profs_interp = []

    for n in range(len(ds.N_PROF)):
        prof = ds.isel(N_PROF=n).expand_dims('N_PROF')
        p_min = float(prof.PRES.min())
        p_max = float(prof.PRES.max())

        if not (np.isfinite(p_min) and np.isfinite(p_max)):
            print(f"\tProfile {n} skipped: no valid pressure values")
            continue

        # Pick the requested levels directly from the grid, so any spacing or
        # offset stays aligned with the final reindex. The top is truncated to
        # a whole pressure unit, as int() did before. The bottom never exceeds
        # the deepest sample.
        # NOTE(review): this presumably relies on argopy extending the
        # shallowest value upward and dropping profiles that do not reach the
        # deepest requested level — confirm against argopy's interp_std_levels.
        levels = grid[(grid >= np.floor(p_min)) & (grid <= p_max)].astype(float)
        if levels.size == 0:
            print(f"\tProfile {n} has no standard_grid levels within its pressure range "
                  f"[{p_min}, {p_max}]")
            continue

        # Pressure spacing between consecutive samples along N_LEVELS.
        # NOTE(review): diff() has one fewer N_LEVELS entry than prof; this
        # assignment relies on xarray aligning on an N_LEVELS coordinate — confirm.
        prof['sample_rate'] = prof.PRES.diff('N_LEVELS')

        try:
            prof_interp = prof.argo.interp_std_levels(levels)
            # Expand onto the full grid; levels outside the profile become NaN.
            profs_interp.append(
                prof_interp.reindex({'PRES_INTERPOLATED': grid}, method=None, fill_value=np.nan)
            )
        except ValueError as e:
            print(f"\tProfile {n} skipped due to interpolation error: {e}")

    return xr.concat(profs_interp, dim='N_PROF') if profs_interp else xr.Dataset()

In [None]:
# Small example region: [lon_min, lon_max, lat_min, lat_max, pres_min, pres_max]
box = [-30, -25, 0, 5, 0, 2000]
ds_box = get_box(box=box)

## Load Large Region (~entire basin)

The functions below load Argo data over a much larger lon/lat range and save it as many small boxes, one NetCDF file per box, in the directory specified in the `get_box_dask` function. Once this is complete, use the Concat Region section below to concatenate the individual per-box datasets into one dataset for the whole region.

This workflow parallelizes the loading process by dividing one large region into subregions, and dividing each subregion into boxes. Subregions are processed one at a time. Each box in the current subregion is passed to one dask worker, which loads all Argo profiles in that lon/lat range (using the `get_box` function from above) and saves the result as a NetCDF file. The boxes within a subregion are loaded in parallel. Once one subregion is complete, the workflow moves on to the next and repeats these steps until the whole region is done.

It may take some thought and experimentation to pick `region_step` and `box_step` values that match the resources available on your machine. For efficiency, the number of boxes in each subregion should be slightly lower than the number of available cores; for example, on a system with 72 cores I would aim for ~60 boxes per subregion. Ideally, the number of boxes in each subregion should also be less than the number of workers initialized in the dask cluster.

### Initialize Dask Cluster

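The number of workers, threads per worker and memory limit can be changed to suit a given machine, or omitted to use dask's default settings. Note that `memory_limit` applies to each worker, so `n_workers × memory_limit` should fit within the machine's RAM.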

In [None]:
import dask
from dask.distributed import Client, LocalCluster
from dask.diagnostics import ProgressBar

# Single-threaded workers; memory_limit is enforced per worker.
cluster = LocalCluster(
    n_workers=64,
    threads_per_worker=1,
    memory_limit='60GiB',
)
client = Client(cluster)
print(cluster)  # the cluster's repr includes its address

In [None]:
@dask.delayed
def get_box_delayed(*args, **kwargs):
    """Run ``get_box`` as a lazy dask task, capturing any failure.

    Returns the dataset from ``get_box`` on success. On any exception it
    returns a ``(exception_type_name, message)`` tuple instead of raising, so
    one failing box does not abort ``dask.compute`` for the whole batch.
    """
    try:
        # Calls the notebook-level get_box defined above; no `prf` module is
        # imported in this notebook.
        return get_box(*args, **kwargs)
    except Exception as e:
        return type(e).__name__, str(e)

def get_box_dask(boxes_list, output_dir="/swot/SUM05/amf2288/sync-boxes/new_test"):
    """Load every box in parallel with dask and save each result as NetCDF.

    Parameters
    ----------
    boxes_list : list of list
        Boxes of the form ``[lon_min, lon_max, lat_min, lat_max, ...]``.
    output_dir : str, optional
        Directory the per-box files are written to. Each file is named
        ``lon:(lon_min,lon_max)_lat:(lat_min,lat_max)_ds_z.nc``.

    Returns
    -------
    list of list
        One ``[lon_min, lon_max, lat_min, lat_max, error_type, error_message]``
        entry for each box that failed to load or to save.
    """
    tasks = [get_box_delayed(box) for box in boxes_list]
    results = dask.compute(*tasks)
    errors = []

    for n, (box, result) in enumerate(zip(boxes_list, results)):
        bounds = [box[0], box[1], box[2], box[3]]

        # get_box_delayed returns (type_name, message) instead of raising.
        if isinstance(result, tuple) and isinstance(result[0], str):
            error_type, error_message = result
            print("Error in box {}: {} - {}".format(bounds, error_type, error_message))
            errors.append(bounds + [error_type, error_message])
            continue

        path = "{}/lon:({},{})_lat:({},{})_ds_z.nc".format(output_dir, *bounds)
        try:
            result.to_netcdf(path)
        except Exception as e:
            # Record the failure and keep saving the remaining boxes instead of
            # discarding every result computed in this batch.
            print("Error saving box {}: {} - {}".format(bounds, type(e).__name__, e))
            errors.append(bounds + [type(e).__name__, str(e)])
            continue
        print("Saved box {} of {}".format(n + 1, len(results)))

    return errors

In [None]:
def generate_grid(box, step):
    """Tile a lon/lat box into sub-boxes with edges of at most ``step`` degrees.

    Parameters
    ----------
    box : sequence
        ``[lon_min, lon_max, lat_min, lat_max, *rest]``. Any trailing entries
        (e.g. ``pres_min, pres_max``) are copied unchanged onto every sub-box.
    step : float
        Sub-box edge length. Sub-boxes on the upper edges are clipped to
        ``lon_max`` / ``lat_max``.

    Returns
    -------
    list of list
        Sub-boxes in the same ``[lon_min, lon_max, lat_min, lat_max, *rest]``
        order as ``box``, ordered by latitude band (outer) and then longitude.

    Raises
    ------
    ValueError
        If ``step`` is not positive (the tiling would never terminate).
    """
    if step <= 0:
        raise ValueError("step must be positive, got {}".format(step))

    lon_min, lon_max, lat_min, lat_max = box[0], box[1], box[2], box[3]
    rest = list(box[4:])

    def spans(lo, hi):
        # Derive each start from its index rather than by repeated addition, so
        # fractional steps do not accumulate floating-point error.
        out = []
        i = 0
        start = lo
        while start < hi:
            out.append((start, min(start + step, hi)))
            i += 1
            start = lo + i * step
        return out

    return [
        [lon_lo, lon_hi, lat_lo, lat_hi] + rest
        for lat_lo, lat_hi in spans(lat_min, lat_max)
        for lon_lo, lon_hi in spans(lon_min, lon_max)
    ]

In [None]:
def get_region(area, region_step, target_step):
    """Load ``area`` one subregion at a time, running each subregion's boxes through dask.

    ``area`` is tiled into subregions of ``region_step`` degrees. Each
    subregion is tiled into boxes of ``target_step`` degrees, and all boxes of
    one subregion are handed to ``get_box_dask`` together before the next
    subregion starts. Progress is printed along the way.

    Returns a list with one entry per subregion; each entry is a one-element
    list wrapping that subregion's error list from ``get_box_dask``.
    """
    rule = '-' * 50
    regions = generate_grid(area, region_step)
    total = len(regions)

    print(rule)
    print("Cluster: {}".format(cluster))
    print(rule)
    print("THE REGIONS ARE {}".format(regions))
    print(rule)

    errors_list = []

    for idx, region in enumerate(regions, start=1):
        boxes = generate_grid(region, target_step)
        print(rule)
        print("REGION #{} OUT OF {} IS: {}".format(idx, total, region))
        print(rule)
        print("THE BOXES IN REGION #{} ARE {}".format(idx, boxes))
        print(rule)

        region_errors = get_box_dask(boxes)
        errors_list.append([region_errors])

        print(rule)
        print("COMPLETED REGION #{} OUT OF {}".format(idx, total))
        print(rule)

    return errors_list

In [None]:
# South Atlantic area: [lon_min, lon_max, lat_min, lat_max, pres_min, pres_max]
s_atl = [-75, 25, -90, 0, 0, 2000]
region_step = 40  # subregion size in degrees; one dask batch per subregion
box_step = 5      # box size in degrees; one dask task per box
interp_step = 2   # NOTE: not used by get_region; pressure spacing comes from get_box's standard_grid
get_region(s_atl, region_step=region_step, target_step=box_step)

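`region_step` is the size (in degrees) of each subregion: all boxes in one subregion are submitted to dask together as one batch. `box_step` is the size (in degrees) of each box, i.e. the area loaded by a single dask task. `interp_step` is not used by `get_region`. The pressure spacing of the interpolated output is set by the `standard_grid` argument of `get_box` (2 dbar by default).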

## Concat Region

The functions below take the individual box files saved by `get_region` and concatenate them into one dataset, which is then saved. There are two versions: `concatenate_netcdf` writes a NetCDF file and `concatenate_zarr` writes a Zarr store. Once this is complete (and you've checked that it worked properly), it's probably a good idea to delete the directory of individual NetCDF files.

In [None]:
def concatenate_netcdf(input_dir: str, first_dim: str, second_dim: str, output_dir: str, output_file: str):
    """Concatenate every ``*.nc`` file in ``input_dir`` into one NetCDF file.

    The files are opened lazily with dask and concatenated along ``first_dim``.
    The result is rechunked to 256 along ``first_dim`` and ``second_dim`` and
    written to ``output_dir / output_file``.

    Parameters
    ----------
    input_dir : str
        Directory containing the per-box NetCDF files.
    first_dim : str
        Dimension to concatenate along (e.g. ``"N_PROF"``).
    second_dim : str
        Second dimension to rechunk (e.g. ``"PRES_INTERPOLATED"``).
    output_dir : str
        Directory for the combined file; created if it does not exist.
    output_file : str
        File name of the combined NetCDF file.

    Raises
    ------
    ValueError
        If ``input_dir`` contains no ``*.nc`` files.
    """
    # glob() yields files in arbitrary filesystem order; sort so the
    # concatenation order (and therefore the output) is reproducible.
    netcdf_files = sorted(Path(input_dir).glob("*.nc"))
    if not netcdf_files:
        raise ValueError(f"No NetCDF files found in {input_dir!r}")

    # Create the output directory if it doesn't exist
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    datasets = []
    try:
        # Open datasets lazily with dask
        for file in netcdf_files:
            datasets.append(xr.open_dataset(str(file), chunks={}))

        combined = xr.concat(datasets, dim=first_dim)

        # Rechunk so chunk sizes are uniform; adjust sizes as needed.
        combined = combined.chunk({first_dim: 256, second_dim: 256})

        # NOTE: dask.diagnostics.ProgressBar reports only for dask's local
        # schedulers; with a distributed Client active, watch the dashboard.
        with ProgressBar():
            combined.to_netcdf(output_path / output_file, compute=True)
    finally:
        # Release the file handles held by the lazily opened source files.
        for ds in datasets:
            ds.close()

In [None]:
def concatenate_zarr(input_dir: str, first_dim: str, second_dim: str, output_dir: str, output_file: str):
    """Concatenate every ``*.nc`` file in ``input_dir`` into one Zarr store.

    The files are opened lazily with dask and concatenated along ``first_dim``.
    The result is rechunked to 256 along ``first_dim`` and ``second_dim`` and
    written to ``output_dir / output_file``.

    Parameters
    ----------
    input_dir : str
        Directory containing the per-box NetCDF files.
    first_dim : str
        Dimension to concatenate along (e.g. ``"N_PROF"``).
    second_dim : str
        Second dimension to rechunk (e.g. ``"PRES_INTERPOLATED"``).
    output_dir : str
        Directory for the combined store; created if it does not exist.
    output_file : str
        Name of the Zarr store.

    Raises
    ------
    ValueError
        If ``input_dir`` contains no ``*.nc`` files.
    """
    # glob() yields files in arbitrary filesystem order; sort so the
    # concatenation order (and therefore the output) is reproducible.
    netcdf_files = sorted(Path(input_dir).glob("*.nc"))
    if not netcdf_files:
        raise ValueError(f"No NetCDF files found in {input_dir!r}")

    # Create the output directory if it doesn't exist
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    datasets = []
    try:
        # Open datasets lazily with dask
        for file in netcdf_files:
            datasets.append(xr.open_dataset(str(file), chunks={}))

        combined = xr.concat(datasets, dim=first_dim)

        # Rechunk so chunk sizes are uniform; adjust sizes as needed.
        combined = combined.chunk({first_dim: 256, second_dim: 256})

        # NOTE: dask.diagnostics.ProgressBar reports only for dask's local
        # schedulers; with a distributed Client active, watch the dashboard.
        # NOTE: to_zarr's default mode will not overwrite an existing store.
        with ProgressBar():
            combined.to_zarr(output_path / output_file, compute=True)
    finally:
        # Release the file handles held by the lazily opened source files.
        for ds in datasets:
            ds.close()

In [None]:
# Per-box NetCDF files (the directory get_box_dask writes to) are read from here...
input_directory = "/swot/SUM05/amf2288/sync-boxes/new_test"
# ...and the combined dataset is written here, under one of these names.
output_directory = "/swot/SUM05/amf2288/sync-boxes"
output_netcdf, output_zarr = "new_test.nc", "new_test.zarr"
# Profiles are concatenated along N_PROF; PRES_INTERPOLATED is the pressure axis.
first_dim, second_dim = "N_PROF", "PRES_INTERPOLATED"

In [None]:
concatenate_netcdf(
    input_dir=input_directory,
    first_dim=first_dim,
    second_dim=second_dim,
    output_dir=output_directory,
    output_file=output_netcdf,
)

In [None]:
# Alternative output: uncomment to write a Zarr store instead of a single NetCDF file.
#concatenate_zarr(input_directory, first_dim, second_dim, output_directory, output_zarr)

## Examine Dataset

In [None]:
# Open the combined NetCDF file written by concatenate_netcdf for inspection.
ds = xr.open_dataset(f"{output_directory}/{output_netcdf}")