# Aggregating WRF-Hydro modeling application outputs to HUC12s: 2-Dimensional variables
**Author:** Kevin Sampson; NCAR

#### Note from the Author:
This notebook computes zonal (spatial) statistics of NWM Retrospective outputs over a set of gridded 'zones', which can be any spatial unit such as counties, states, or HUCs. The zone grids must already be resolved on the intended NWM grid (LSM: 1 km, or routing: 250 m). They may optionally cover only a spatial subset of the NWM retrospective domain, but the zone grid and the model grid must match exactly. This script assumes all 'zone' datasets are written in typical GIS fashion, ordered north to south. If an LSM grid is requested, the zone dataset is flipped to south-to-north order in this script. 

## Background
WRF-Hydro modeling application outputs should already have been processed from hourly to monthly summaries on the native WRF-Hydro/NWM grids. We will use these monthly datasets to process the zonal statistics at the HUC12 scale. This notebook aggregates the 2-Dimensional WRF-Hydro modeling application outputs and CONUS404-BA outputs.  

## Processing Environment
This workflow leverages dask and requires 150 GB of allocated memory. 
The python environment used is a conda environment 'wrfhydro_huc12_agg' here: 

/path/to/repo/hytest/dataset_processing/tutorials/niwaa_wrfhydro_monthly_huc12_agg/02_Spatial_Aggregation/wrfhydro_huc12_agg.yml


### Imports

In [None]:
# --- Import Modules --- #

# Import Python Core Modules
import sys
import os
import time
import math
import tracemalloc
import datetime
from pathlib import Path
import logging

# Some environment variables important to dask
os.environ["MALLOC_TRIM_THRESHOLD_"] = "0"
os.environ["DASK_DISTRIBUTED__SCHEDULER__ACTIVE_MEMORY_MANAGER__START"] = "True"
os.environ["DASK_DISTRIBUTED__SCHEDULER__WORKER_SATURATION"] = "1.2"
if 'DASK_ROOT_CONFIG' in os.environ:
    del os.environ['DASK_ROOT_CONFIG']
import dask
from dask.distributed import Client, progress, LocalCluster, performance_report
from dask_jobqueue import SLURMCluster
from dask.diagnostics import ProgressBar
import dask.array as da


# Import Additional Modules
import numpy as np
import xarray as xr
import pandas as pd
import zarr
import flox.xarray

# Import functions from local repository
sys.path.append(r'/path/to/repo/hytest/dataset_processing/tutorials/niwaa_wrfhydro_monthly_huc12_agg/02_Spatial_Aggregation/')
from usgs_common import *

from rich.console import Console
from rich import pretty

# Rich library
pretty.install()
con = Console(record=False, force_jupyter=False)
con.width = 200

import warnings
warnings.filterwarnings('ignore', message=r'.*Sending large graph of size.*')

tic = time.time()
con.print(f'Process initiated at {time.ctime()}')
# --- End Import Modules --- #

## Define the input files and other relevant local variables

In [None]:
# Model output type being aggregated: 'LDASOUT' uses the 1 km LSM zone grid,
# 'RTOUT' uses the 250 m routing zone grid.
NWM_type = 'LDASOUT'

# Variables to aggregate (list form), taken from the monthly LDASOUT and LDASIN summaries
variables = [
    'deltaACCET', 'deltaACSNOW', 'deltaSNEQV', 'deltaSOILM', 'deltaUGDRNOFF',
    'deltaSOILM_depthmean',
    'avgSNEQV', 'avgSOILM', 'avgSOILM_depthmean', 'avgSOILM_wltadj_depthmean',
    'avgSOILSAT', 'avgSOILSAT_wltadj_top1',
    'totPRECIP', 'avgT2D',
]

# Name given to the zone variable added to the dataset (also the zone label in outputs)
zone_name = 'WBDHU12'

# Optional temporal subset of the inputs: enable the flag and give the date range to keep
temporal_subset = True
time_subset_bounds = slice('2011-10-01', '2012-09-30')     # Test HyTest batch (1 year)

## Define the output files and other relevant variables to outputs

In [None]:
# Destination for the 2D aggregation results (the dask performance report is written here too)
outDir = Path('/path/to/outputs/agg_out')
con.print(f'outDir exists: {outDir.is_dir()}')

# Base name shared by the output files; the extension is appended when writing
output_pattern = 'CONUS_HUC12_2D_20111001_20120930'

# Output formats to produce
write_CSV, write_NC = True, True

# Land-mask the zone grid so water cells are excluded from the statistics.
# Only takes effect for LSM-grid (LDASOUT) variables.
landmask_results = True

# Variables aggregated over the full zone area rather than the land-masked area
non_landmask_vars = ['Precip', 'landmask']

# Derived variables processed ahead of `variables`
addVars = ['total_gridded_area', *non_landmask_vars]    # For all other processing
#addVars = ['total_gridded_area']                        # For the soil moisture top layer variables

# Flag: calculate percent soil saturation as a derived output variable
pct_sat = False

### Handle the processing of input variables if the source is raw NWM

Define the input directories and the lists of monthly netCDF input files (`water_*.nc` and `clim_*.nc`), plus any other processing requirements (unit conversion, time resampling, etc.).

In [None]:
# Each file list below is concatenated along time with open_mfdataset, so every file in a
# list must share the same time and other dimensions.
convert_to_mm = False   # unit-conversion flag (not referenced elsewhere in this notebook)

# Directory holding the precipitation/forcing summaries (LDASIN, clim_*.nc).
# This can differ from the directory holding the other files.
clim_dir = Path('/path/to/temporal/aggregations/output/monthly')
con.print(f'clim_dir exists: {clim_dir.is_dir()}')

# Directory holding the land-surface summaries (LDASOUT, water_*.nc).
land_dir = Path('/path/to/temporal/aggregations/output/monthly')
con.print(f'land_dir exists: {land_dir.is_dir()}')

# Non-recursive wildcard search: LDASOUT files first (file_in), then LDASIN files (file_in2)
file_in, file_in2 = (
    get_files_wildcard(search_dir, file_pattern=pattern, recursive=False)
    for search_dir, pattern in ((land_dir, 'water_*.nc'), (clim_dir, 'clim_*.nc'))
)

# The two sources are merged later, so their file counts are expected to match
if len(file_in) != len(file_in2):
    con.print('[orange_red1]WARNING[/]: The number of files in clim_dir and land_dir are not the same.')

### Spin up a Dask Cluster

Spin up a slurm cluster to parallelize the aggregation process. The scheduler for the dask cluster requires 150 GB of memory to allocate jobs. 

In [None]:
%%time

# SLURM account charged for the dask worker jobs.
# NOTE: hard-coded. To take the account from the active allocation instead, use
#     project = os.environ['SLURM_JOB_ACCOUNT']
# (a KeyError there means SLURM_JOB_ACCOUNT is unset, e.g. on a login node, where this
# notebook should not be run).
project = 'impd'

# One single-core, 10 GB worker per SLURM job.
# NOTE(review): shared_temp_directory points at a personal home directory; adjust per user.
cluster = SLURMCluster(job_name='dask_niwaa',
                       account=project,
                       processes=1,
                       cores=1,
                       memory='10GB',
                       interface='ib0',
                       walltime='01:00:00',
                       shared_temp_directory='/home/lstaub/tmp',
                       #job_extra={'hint': 'multithread'},
                       #scheduler_options = {'dashboard_address': ':32939'}
                      )

# Let dask scale between 10 and 30 workers depending on load
cluster.adapt(minimum=10, maximum=30)
con.print(cluster.job_script())

client = Client(cluster)

# Open OnDemand proxy link to the dashboard. The node/port variables are only set inside an
# OnDemand Jupyter session; skip the link when they are missing instead of raising KeyError
# after the cluster has already been started.
jupyter_node = os.environ.get('JUPYTER_SERVER_NODE')
jupyter_port = os.environ.get('JUPYTER_SERVER_PORT')
if jupyter_node and jupyter_port:
    # dashboard_link looks like 'http://host:port/status'; keep the 'port/status' part
    ood_dashboard_link = f"https://hw-ood.cr.usgs.gov/node/{jupyter_node}/{jupyter_port}/proxy/{client.dashboard_link.split(':')[2]}"
    con.print(f'Dask Dashboard for OnDemand is available at: {ood_dashboard_link}')
else:
    ood_dashboard_link = None
    con.print('JUPYTER_SERVER_NODE/JUPYTER_SERVER_PORT not set; skipping the OnDemand dashboard link.')

con.print("The 'cluster' object can be used to adjust cluster behavior.  i.e. 'cluster.adapt(minimum=10)'")
con.print("The 'client' object can be used to directly interact with the cluster.  i.e. 'client.submit(func)' ")
con.print(f"The link to view the client dashboard is:\n>  {client.dashboard_link}")

In [None]:
# Display the dask Client as this cell's output
client

In [None]:
# client.close()
# cluster.close()

#### Open the input file and read some useful information

In [None]:
%%time
# NOTE: 2025-03-27 PAN - added parallel=True to open_mfdataset

def extract_dates(in_paths=None, format_str='%Y%m'):
    """
    Parse the date stamp embedded in each file name into a DatetimeIndex.

    File names are expected to end with "_{datestring}.nc", e.g. "water_201110.nc"
    with the default ``format_str='%Y%m'``.

    Parameters
    ----------
    in_paths : iterable of str or path-like, optional
        Input file paths. Only the base name of each path is inspected.
        Defaults to no paths, which yields an empty index.
    format_str : str
        ``strftime``-style format of the date stamp (default ``'%Y%m'``).

    Returns
    -------
    pandas.DatetimeIndex
        One timestamp per input path, in input order.

    Raises
    ------
    ValueError
        If a date stamp does not match ``format_str``.
    """
    # Avoid a shared mutable default argument
    if in_paths is None:
        in_paths = []
    dt_strings = [
        # Take the LAST "_"-separated token of the name (minus ".nc"), as documented,
        # so file prefixes that themselves contain underscores still parse.
        os.path.basename(in_path).split('.nc')[0].rsplit('_', 1)[-1]
        for in_path in in_paths
    ]
    return pd.to_datetime(dt_strings, format=format_str)

# Names to keep from the inputs: the requested variables plus the time coordinate.
# NOTE(review): `time_coord` is not defined in this notebook; presumably it comes from
# `usgs_common` via the star import -- confirm.
keep_vars = set(variables) | {time_coord}

# Probe the first file of EVERY non-empty file list (a single clim file included) and drop
# whatever is not needed. Each probe dataset is closed as soon as its names are read.
input_lists = [in_list for in_list in (file_in, file_in2) if len(in_list) > 0]
drop_vars = set()
for in_list in input_lists:
    with xr.open_dataset(in_list[0]) as probe_ds:
        drop_vars.update(var_in for var_in in probe_ds.variables if var_in not in keep_vars)
drop_vars = list(drop_vars)  # open_mfdataset expects a list
con.print(f'Dropping {drop_vars} from input file.')

# Only use this method if datasets are coming from multiple directories or file types
with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    # Open each file list as one time-concatenated dataset, label its time axis with the
    # dates parsed from the file names, then merge the lists into a single dataset.
    ds_list = [xr.open_mfdataset(in_list,
                                 combine='nested',
                                 decode_cf=False,
                                 concat_dim='time',
                                 chunks='auto',
                                 parallel=True,
                                 drop_variables=drop_vars).assign_coords(time=extract_dates(in_list))
               for in_list in input_lists]
    ds = xr.merge(ds_list)
    del ds_list

# Perform temporal subset, or not
if temporal_subset:
    ds = ds.loc[{time_coord: time_subset_bounds}]

# Obtain and print information about the input file
ds, timesteps, x_chunk_sizes, y_chunk_sizes, time_chunk_sizes = report_structure(ds, variable=list(ds.data_vars.keys())[0])
ds

#### Obtain the spatial aggregation array

In [None]:
%%time

# Choose a method for spatial aggregation
raster_zones = True
spatial_weights = False

# Use a 2D grid of zone IDs to perform spatial aggregation.
# This is a representation of the zones on the same grid as the analysis data.
if raster_zones:

    # Pick the zone raster matching the model grid: 250 m routing grid for RTOUT,
    # 1 km land-surface (LSM) grid for LDASOUT.
    if NWM_type == 'RTOUT':
        zone_raster = r'/caldera/hovenweep/projects/usgs/water/impd/hytest/niwaa_wrfhydro_monthly_huc12_aggregations_sample_data/HUC12_grids/HUC12s_on_250m_grid.tif'
        LSM_grid = False
    elif NWM_type == 'LDASOUT':
        zone_raster = r'/caldera/hovenweep/projects/usgs/water/impd/hytest/niwaa_wrfhydro_monthly_huc12_aggregations_sample_data/HUC12_grids/HUC12s_on_1000m_grid.tif'
        LSM_grid = True
    else:
        # Fail early with a clear message instead of a NameError on zone_raster below
        raise ValueError(f"Unsupported NWM_type {NWM_type!r}; expected 'RTOUT' or 'LDASOUT'.")
    print('Using raster grid of zones for spatial aggregation: {0}'.format(zone_raster))

    # Data value to define nodata in the zone raster (anywhere that a zone does not exist).
    zone_nodata = 0

    # Read in the raster that defines the zones
    zone_arr, zone_ndv = return_raster_array(zone_raster)

    # Zone rasters are stored north-to-south (GIS order); for the LSM grid flip them once to
    # south-to-north instead of flipping each input array from the model data.
    if LSM_grid:
        zone_arr = zone_arr[flip_dim(['y', 'x'], DimToFlip='y')]

    # Replace nodata values with np.nan, which requires converting to floating point.
    zone_arr = zone_arr.astype('float')
    zone_arr[zone_arr == zone_nodata] = np.nan

    # Count distinct zone IDs. `x != np.nan` is True even for NaN entries, so NaN has to be
    # excluded with np.isnan for the count to be right whether or not NaNs are present.
    print('{0} zones found in the input dataset'.format(np.unique(zone_arr[~np.isnan(zone_arr)]).size))

    # Add zones to the Xarray DataSet object
    zones = xr.DataArray(zone_arr, dims=("y", "x"), name=zone_name)
    #ds[zone_name] = zones.fillna(-1).astype(int)   # workaround flox bug
    ds[zone_name] = zones.fillna(-1).astype(np.int64)   # workaround flox bug
    del zones

    # Obtain landmask grid
    if landmask_results and NWM_type == 'LDASOUT':
        print('  Masking zone grid to LSM LANDMASK variable')
        # NOTE(review): `geogrid` is not defined in this notebook; presumably it comes from
        # usgs_common -- confirm. LANDMASK is loaded eagerly so the file can be closed now.
        with xr.open_dataset(geogrid) as geo_ds:
            landmask = geo_ds['LANDMASK'].squeeze().load()
        zone_masked = zone_arr.copy()
        zone_masked[landmask == 0] = np.nan
        masked_zone_name = '{0}_masked'.format(zone_name)
        zones_ma = xr.DataArray(zone_masked, dims=("y", "x"), name=masked_zone_name)

        # Filling NaN areas (water or ocean) with -1 removes it from that HUC.
        #ds[masked_zone_name] = zones_ma.fillna(-1).astype(int)   # workaround flox bug
        ds[masked_zone_name] = zones_ma.fillna(-1).astype(np.int64)   # workaround flox bug

        # Save the landmask (1s and 0s)
        landmask_da = xr.DataArray(landmask, dims=("y", "x"), name='landmask')
        ds['landmask'] = landmask_da.fillna(0).astype(int)   # workaround flox bug
        del landmask, zones_ma

        # Count distinct zone IDs remaining after land-masking (NaN excluded, see above)
        print('{0} zones found in the input dataset after land-masking'.format(
            np.unique(zone_masked[~np.isnan(zone_masked)]).size))
        del zone_masked

    del zone_arr

# Use a 1D array of pixel weights to perform spatial aggregation
### NOT YET WORKING!
elif spatial_weights:
    sw_file =r'/caldera/hovenweep/projects/usgs/water/impd/hytest/niwaa_wrfhydro_monthly_huc12_aggregations_sample_data/static_niwaa_wrf_hydro_files/WRFHydro_spatialweights_CONUS_250m_NIWAAv1.0.nc'
    print('Using pre-computed NWM-style spatial weight file for spatial aggregation: {0}'.format(sw_file))

    # If the raster used to create spatial weights was created in GIS, then it will start with 0,0 in UL corner.
    # To flip to south_north, select flip_raster==True
    flip_raster = True

    # Open the spatial weight file
    sw_ds = xr.open_dataset(sw_file)

    # Drop the weight-file variables that are not needed (Dataset.drop is deprecated)
    sw_ds = sw_ds.drop_vars(['overlaps', 'polyid', 'regridweight'])
    sw_ds.load()

    display(sw_ds)

    # For now, flox need an integer for the zone IDs
    sw_ds['IDmask'] = sw_ds['IDmask'].astype(np.int64)
    sw_ds = sw_ds.rename({'IDmask':zone_name})

    # Obtain indexer arrays and alter the indices to 'flip' the y dimension if requested.
    # NOTE(review): `LSM_grid_size_y` is not defined in this notebook, and whether j_index is
    # 0- or 1-based (which decides if the flip is off by one) is not visible here -- confirm.
    indexer_i = sw_ds['i_index'].astype(int).data
    if flip_raster:
        indexer_j = LSM_grid_size_y - sw_ds['j_index'].astype(int).data
    else:
        indexer_j = sw_ds['j_index'].astype(int).data

    # Add the spatial weight variables to the dataset
    ds = xr.merge([ds, sw_ds])

In [None]:
# Inspect the dataset after the zone (and, if applied, land-mask) variables were added
ds

In [None]:
# Echo the timesteps selected for processing (self-documenting "name=repr" form)
con.print(f'timesteps={timesteps!r}')

## Iterate over time, processing the zonal statistics

### Perform 2D Groupby operation

This codeblock will execute the 2D groupby (zonal statistic) operation using the `flox` method `xarray_reduce` or `groupby_reduce`.

#### Method of operation

For some datasets, memory limits may cause individual workers to pause once they reach 80% memory utilization, so the size of the chunks processed at once must be chosen carefully. The inputs keep the chunking chosen when they are opened (`chunks='auto'`), and the iteration strategy is defined by how many timesteps are processed at once. The full 2D grid is used at each timestep, so only the time dimension is split. The `time_chunks` list at the top of the processing cell sets which timesteps are handled in each iteration. By default, all timesteps are processed together; a commented-out alternative splits them into groups of `time_chunk_size`. Processing times appear to scale linearly, so this may not be an important factor.

Currently, for certain variables, we calculate the sum over a third dimension, such as soil_layers_stag for the `SOIL_M` variable. 

Currently, the statistical operations provided in the `numpy_groupies` python library are supported:
* `sum`, `nansum`
* `prod`, `nanprod`
* `mean`, `nanmean`
* `var`, `nanvar`
* `std`, `nanstd`
* `min`, `nanmin`
* `max`, `nanmax`
* `first`, `nanfirst`
* `last`, `nanlast`
* `argmax`, `nanargmax`
* `argmin`, `nanargmin`

The results of every iteration and variable are merged into a single dataset. That dataset is written once, to CSV and/or netCDF, in the *Output to disk* section below.

Other configurations assist in chunking the data: the `time_chunks` list determines which timesteps are processed in each iteration. The zonal statistic is currently `mean` for every variable except `total_gridded_area`, which uses `sum` to count the grid cells in each zone.

In [None]:
%%time
con.print(f'Process initiated at {time.ctime()}')

# Variables aggregated over the full zone area (no land-masking). 'Precip' is a derived
# alias of totPRECIP and is always treated as a full-area variable.
full_zone_vars = set(non_landmask_vars) | {'Precip'}

# Land-masked zone variable, created in the zone cell only when land-masking was applied.
# When it is absent, every variable falls back to the unmasked zones.
masked_zone_name = '{0}_masked'.format(zone_name)

# Output to file
with performance_report(filename=os.path.join(outDir, "dask-report_2D_2.html")):
    # Determine how many time chunks we can process at once
    time_chunks = [timesteps]    # To process all times at once, provide nested list containing all timesteps
    #time_chunk_size = 2
    #time_chunks = [timesteps[i:i + time_chunk_size] for i in range(0, len(timesteps), time_chunk_size)]
    con.print(f'There will be {len(time_chunks)} iterations over time.')

    # One 'YYYYMMDDHH' label index per time chunk, used later as the CSV time index.
    # Built once, outside the variable loop, so it is not duplicated for every variable.
    datetime_strings = [pd.to_datetime(time_chunk).strftime('%Y%m%d%H') for time_chunk in time_chunks]

    # Accumulates the zonal statistics of every variable. Reset on each run of this cell
    # so a re-run does not add to stale results from a previous run.
    out_ds = None

    # Iterate over variables
    all_vars = addVars + variables
    con.print(f'There will be up to {len(all_vars)} variables processed.')
    for varnum, variable in enumerate(all_vars):
        tic1 = time.time()
        con.print(f'Processing variable [bold]{variable}[/]')

        # Set the appropriate zone mask
        use_full_zones = variable in full_zone_vars or masked_zone_name not in ds
        if use_full_zones:
            # Use full basin zone array for spatial aggregation. No land-masking
            con.print(f'  Using full basin mask for variable {variable}')
            zone_da = ds[zone_name]
        else:
            # Use land-masked zone array for spatial aggregation
            con.print('  Using land/water mask to remove water cells from analysis')
            zone_da = ds[masked_zone_name]

        # Subset the variable to a DataArray
        if variable in ds:
            da = ds[variable]
        elif variable == 'Precip':
            # Special case: re-use totPRECIP to produce a secondary, full-area result
            da = ds['totPRECIP'].rename(variable)
        elif variable == 'total_gridded_area':
            # Array of ones: its zonal sum is the number of grid cells in each zone
            da = xr.ones_like(ds['landmask']).rename(variable)
        else:
            # Skip instead of silently re-using the previous iteration's data under this name
            con.print(f'[orange_red1]WARNING[/]: Skipping variable {variable}; it is not in the input dataset.')
            continue

        # Initialize list to store temporary partial results
        outputs = []

        # Iterate over time-chunks and process zonal statistics
        for n, time_chunk in enumerate(time_chunks):
            # Subset in time if necessary
            if 'time' in da.dims:
                data = da.loc[dict(time=slice(time_chunk[0], time_chunk[-1]))]
            else:
                data = da

            # Handle total soil moisture depth
            if NWM_type == 'LDASOUT' and variable in ['SOIL_M', 'deltaSOILM', 'avgSOILM']:
                con.print('\tConverting soil moisture value to total water depth (mm) in soil column.')

                # Weight each soil layer by its depth and sum over layers to get the total
                # water depth (mm) in the column.
                # NOTE(review): `soil_depths_mm` is not defined in this notebook; presumably it
                # comes from usgs_common and holds one depth per layer (4 layers) -- confirm.
                soil_dict = dict(soil_weights=('soil_layers_stag', [0, 1, 2, 3]))
                weights = xr.DataArray(soil_depths_mm, dims=('soil_layers_stag',), coords=soil_dict)
                data = (data * weights).sum(dim='soil_layers_stag')
                data.name = variable  # reset the dataarray name

            # Apply groupby operation
            if raster_zones:
                flox_function = 'sum' if variable == 'total_gridded_area' else 'mean'
                con.print(f'\t[{varnum}]    Calculating zonal {flox_function}.')
                output = run_flox(data, zone_da, flox_function=flox_function, n=n)
            elif spatial_weights:
                # Convert from 2D to 1D array using indexer_j and indexer_i
                flox_function = 'sum'
                con.print(f'\t[{varnum}]    Calculating spatially weighted value {flox_function}.')
                output = run_flox(data.data[indexer_j, indexer_i] * ds['weight'],
                                  zone_da,
                                  flox_function=flox_function,
                                  n=n)

            # Results computed on land-masked zones are relabelled to the common zone name
            if not use_full_zones:
                output = output.rename({masked_zone_name: zone_name})
            outputs.append(output)
            del data
        con.print(f'\t[{varnum}] Spatial aggregation step completed in {time.time()-tic1:3.2f} seconds.')

        # Merge all outputs together
        output = xr.merge(outputs)

        # Re-arrange dimensions so that time is the fastest varying dimension
        if 'time' in output.dims:
            output = output[[zone_name, time_coord, variable]]

        if out_ds is None:
            out_ds = output
        else:
            out_ds[variable] = output[variable]
        con.print(f'\t[{varnum}] Iteration completed in {time.time()-tic1:3.2f} seconds.')

In [None]:
# Inspect the combined zonal statistics for all processed variables
out_ds

In [None]:
# NOTE: `output` holds only the merged result of the LAST variable processed in the loop
# above; `out_ds` holds all variables.
print(output)

### Remove unnecessary attributes

In [None]:
# Eliminate unnecessary variable attributes (such as spatial metadata) and attach
# human-readable descriptions to the derived variables.
for variable in out_ds.data_vars:
    # `.attrs` of the returned DataArray is the dataset variable's own dict, so popping
    # from it edits out_ds in place.
    var_attrs = out_ds[variable].attrs
    for attr_name in ('grid_mapping', 'esri_pe_string', 'proj4'):
        var_attrs.pop(attr_name, None)
    if variable == 'landmask':
        out_ds[variable].attrs = {'description': 'Fraction of gridded land area in each HUC12'}
    if variable == 'total_gridded_area':
        out_ds[variable].attrs = {'description': 'Number of 1km grid cells for HUC12. Equivalent to square kilometers. Based on grid association of each HUC12'}

# Now eliminate unnecessary global attributes
for attr_name in ('grid_mapping', 'units', 'esri_pe_string', 'long_name', '_FillValue', 'missing_value'):
    out_ds.attrs.pop(attr_name, None)
out_ds

In [None]:
# Drop the placeholder zone ID -1 from the combined results. -1 marks cells that were NaN
# in the zone grid (zone nodata, plus water cells for the land-masked zones).
out_ds = out_ds.where(out_ds[zone_name]!=-1, drop=True)
out_ds

### Output to disk

In [None]:
%%time

# Read into memory before writing to disk. `compute()` returns a new, loaded object
# rather than loading in place, so the result has to be re-assigned.
out_ds = out_ds.compute()

# Write output file (CSV)
if write_CSV:
    tic1 = time.time()
    out_file = os.path.join(outDir, output_pattern + '_2.csv')
    print('  Writing output to {0}'.format(out_file))
    if os.path.exists(out_file):
        # Append to the existing CSV.
        # NOTE(review): this path writes the long-format `to_dataframe()` table, whereas a new
        # file is written by `write_csv` (usgs_common), apparently with zones as columns;
        # confirm the two layouts are meant to be combined.
        df_in = pd.read_csv(out_file)
        df_out = pd.concat([df_in, out_ds.to_dataframe()])
        df_out.to_csv(out_file)
        print('\t      Output file written in {0:3.2f} seconds.'.format(time.time() - tic1))
    else:
        # Zone labels come from out_ds itself: the -1 placeholder zone has already been
        # dropped from it, while `output` (the last variable's raw result) may still hold it.
        write_csv(out_ds, out_file, columns=out_ds[zone_name], index=[datetime_strings])
    print('\tExport to CSV completed in {0:3.2f} seconds.'.format(time.time() - tic1))

# Write output file (netCDF)
if write_NC:
    tic1 = time.time()
    out_file = os.path.join(outDir, output_pattern + '_2.nc')
    if os.path.exists(out_file):
        # Fully load the existing file and close it before overwriting the same path
        with xr.open_dataset(out_file) as in_ds:
            out_ds2 = xr.merge([in_ds.load(), out_ds.transpose()])
        print('  Writing output to {0}'.format(out_file))
        out_ds2.to_netcdf(out_file, mode='w', format="NETCDF4", compute=True)
        del out_ds2
    else:
        print('  Writing output to {0}'.format(out_file))
        out_ds.transpose().to_netcdf(out_file, mode='w', format="NETCDF4", compute=True)
    print('\tExport to netCDF completed in {0:3.2f} seconds.'.format(time.time() - tic1))

## Spin Down the Cluster and Close datasets
##### After we are done, we can spin down our cluster

In [None]:
# Close the Dask cluster
client.close()
cluster.close()

In [None]:
# Close dataset
ds.close()
print('Process completed in {0: 3.2f} seconds.'.format(time.time()-tic))