# Efficient Hierarchical Data Access Patterns Using Xarray DataTrees 

### Authors:
- Chris Battisto
- Eni Awowale

### Overview:
This notebook demonstrates how to use `xarray.DataTree` to compute a time average across every group of a hierarchical dataset derived from the [_GPM IMERG Final Precipitation L3 Half Hourly 0.1 degree x 0.1 degree V07 (GPM_3IMERGHH_07)_](https://disc.gsfc.nasa.gov/datasets/GPM_3IMERGHH_07/summary) and [_MERRA-2 tavg1_2d_flx_Nx: 2d,1-Hourly,Time-Averaged,Single-Level,Assimilation,Surface Flux Diagnostics V5.12.4 (M2T1NXFLX)_](https://disc.gsfc.nasa.gov/datasets/M2T1NXFLX/summary) products. It then compares this with the equivalent recursive approach in `netcdf4-python`.

### Import Libraries

```python
import xarray as xr
import netCDF4 as nc4
import numpy as np
```

### Method A: Average over Time Using DataTrees

First, open the dataset with `xr.open_datatree`, which opens the file's full group hierarchy as a single `DataTree`:

```python
src = xr.open_datatree('precipitation.nc4')
print(src)
```

Then, compute the time average for every group in a single call, without recursively looping through the groups:

```python
dt_mean = src.mean(dim='time')
print(dt_mean)
```

Plotting the averaged `observed/precipitation` variable shows the 10-hour average precipitation rate during Hurricane Ida:

```python
dt_mean['observed/precipitation'].plot()
```

### Method B (the old way): Average over Time Using `netcdf4-python`

First, open the original data file, then create a new file that will hold the time-averaged arrays.

```python
# Open original file
src = nc4.Dataset('precipitation.nc4', 'r')

# Create new file
dst = nc4.Dataset('precipitation_mean.nc4', 'w', format='NETCDF4')

print(dst)
```

Because the dataset is organized into groups and subgroups, we need functions that recursively traverse each group and copy its attributes, dimensions, and variables into the new file, computing the time average along the way. The functions also need extra logic for the time dimension, which is typically unlimited. It is dropped from the output, and every variable that depends on it is averaged along that axis.

```python
def copy_dims(src_group, dst_group):
    """Copy the dimensions defined in src_group, except time, to dst_group."""
    for dim_name, dim in src_group.dimensions.items():
        if dim_name == 'time':
            # Drop the time dimension: the averaged output has no time axis
            continue
        if dim_name in dst_group.dimensions:
            continue
        if dim.isunlimited():
            dst_group.createDimension(dim_name, None)
        else:
            dst_group.createDimension(dim_name, len(dim))


def copy_vars(src_group, dst_group):
    """Copy variables to dst_group, averaging any that depend on time."""
    for var_name, var in src_group.variables.items():
        dims = var.dimensions

        # Skip time coordinate variable
        if var_name == 'time':
            continue

        # Copy lat/lon coordinate variables as-is
        if var_name in ['lat', 'lon']:
            new_var = dst_group.createVariable(var_name, var.datatype, dims)
            new_var.setncatts({k: var.getncattr(k) for k in var.ncattrs()})
            new_var[:] = var[:]
            continue

        # For other variables that have a time dimension: average over time and remove it
        if 'time' in dims:
            new_dims = tuple(d for d in dims if d != 'time')
            new_var = dst_group.createVariable(var_name, var.datatype, new_dims)
            new_var.setncatts({k: var.getncattr(k) for k in var.ncattrs()})

            # var[:] returns a masked array, so fill values are excluded from the mean
            data = var[:]
            time_axis = dims.index('time')
            mean_data = data.mean(axis=time_axis)

            new_var[:] = mean_data
        else:
            # Variables without a time dimension are copied as-is
            new_var = dst_group.createVariable(var_name, var.datatype, dims)
            new_var.setncatts({k: var.getncattr(k) for k in var.ncattrs()})
            new_var[:] = var[:]


def recursive_copy(src_group, dst_group):
    """Copy attributes, dimensions, and variables, then recurse into subgroups."""
    dst_group.setncatts({attr: src_group.getncattr(attr) for attr in src_group.ncattrs()})
    copy_dims(src_group, dst_group)
    copy_vars(src_group, dst_group)

    for name, subgrp in src_group.groups.items():
        dst_subgrp = dst_group.createGroup(name)
        recursive_copy(subgrp, dst_subgrp)


# Recursively copy starting at the root level
recursive_copy(src, dst)

# Close files
src.close()
dst.close()
```
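The averaging step inside `copy_vars` can be checked in isolation. In this hypothetical example, plain NumPy stands in for a netCDF variable. `dims` mimics `var.dimensions`, and the masked array mimics what `var[:]` returns when netCDF4-python auto-masks values equal to a `_FillValue` (here an assumed `-9999.0`). Because `numpy.ma` leaves masked entries out of the mean, fill values do not bias the average.

```python
import numpy as np

# Hypothetical variable: dimension names as netCDF4 would report them,
# and a masked array standing in for var[:] with an assumed fill value
dims = ("time", "lat", "lon")
data = np.ma.masked_equal(
    np.array([[[1.0, 2.0]], [[3.0, -9999.0]]]),  # shape (time=2, lat=1, lon=2)
    -9999.0,
)

time_axis = dims.index("time")                    # position of the time axis
new_dims = tuple(d for d in dims if d != "time")  # dimensions left after averaging
mean_data = data.mean(axis=time_axis)             # masked entries are skipped

print(new_dims, mean_data.shape)  # ('lat', 'lon') (1, 2)
print(mean_data.tolist())         # [[2.0, 2.0]]
```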

Finally, open the new file as a DataTree and plot the same variable to compare it with the result of Method A:

```python
xr.open_datatree('precipitation_mean.nc4')['observed/precipitation'].plot()
```