In [1]:
import os
from glob import glob
from datetime import datetime
import numpy as np
import xarray as xr
import dask

from utils.global_paths import project_data_path, project_code_path

### pyWBM parameters

In [2]:
# Destination directory for the tidied pyWBM parameter files
output_dir = ('/storage/group/pches/default/users/dcl5300/'
              'lafferty_etal_2025_pyWBM_soil_data/pyWBM_parameters/')

In [3]:
def tidy_params(ds):
    """Cast a calibrated-parameter dataset to compact dtypes and attach CF-style metadata.

    Steps, in order:
    - ``ds.astype("float32")``, then ``lat``/``lon`` cast to ``np.float32`` and
      ``obs_name`` cast to ``object``;
    - ``metric`` renamed to ``loss_metric`` and cast to ``object``;
    - descriptive attrs on ``doy`` (only if it is a coordinate), ``lat``, ``lon``,
      ``obs_name`` and ``loss_metric``;
    - global attrs (creation date, title, references, conventions, contact).

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``lat``, ``lon``, ``obs_name`` and ``metric``.

    Returns
    -------
    xarray.Dataset
        The tidied dataset.
    """
    # --- dtypes ---
    tidied = ds.astype("float32")
    for coord_name in ("lat", "lon"):
        tidied[coord_name] = tidied[coord_name].astype(np.float32)
    tidied["obs_name"] = tidied["obs_name"].astype(object)

    tidied = tidied.rename({"metric": "loss_metric"})
    tidied["loss_metric"] = tidied["loss_metric"].astype(object)

    # --- per-variable metadata ---
    if "doy" in tidied.coords:
        tidied["doy"].attrs.update({"long_name": "Day of year", "calendar": "noleap"})

    tidied["lat"].attrs.update({"standard_name": "latitude", "units": "degrees_north"})
    tidied["lon"].attrs.update({"standard_name": "longitude", "units": "degrees_east"})
    tidied["obs_name"].attrs["note"] = "Calibration dataset"
    tidied["loss_metric"].attrs["note"] = "Calibration loss metric"

    # --- global metadata ---
    tidied.attrs.update({
        "creation_date": datetime.today().strftime("%Y-%m-%d"),
        "title": "pyWBM calibrated parameters; David Lafferty, Cornell University",
        "reference": "Lafferty et al. 2025; DOI: (pending)",
        "WBM_reference": "Grogan et al.: Water balance model (WBM) v.1.0.0: a scalable gridded global hydrologic model with water-tracking functionality, Geosci. Model Dev., 15, 7287–7323, https://doi.org/10.5194/gmd-15-7287-2022, 2022",
        "Conventions": "CF-1.8",
        "contact": "dcl257@cornell.edu",
    })

    return tidied

In [4]:
# Combined parameter dataset. The source file is opened in a context manager so its
# handle is released; the tidied dataset may still read lazily from it, so all
# edits and the write happen before the file is closed.
with xr.open_dataset(f"{project_data_path}/WBM/calibration/eCONUS/param_maps/param_map.nc") as param_map_raw:
    ds = tidy_params(param_map_raw)

    ds['awCap'].attrs['long_name'] = 'Available water capacity in 1m soil depth'
    ds['awCap'].attrs['units'] = 'mm'

    ds['wiltingp'].attrs['long_name'] = 'Wilting point'
    ds['wiltingp'].attrs['units'] = 'mm'

    ds['alpha'].attrs['long_name'] = 'Evapotranspiration response to soil moisture drying'

    ds['beta_R'].attrs['long_name'] = 'Runoff shape parameter'

    # Store
    ds.to_netcdf(f"{output_dir}/pyWBM_parameters.nc")

In [13]:
# Kpet (potential-evapotranspiration coefficients). Opened in a context manager so the
# source handle is released; the tidied dataset may still read lazily from it, so it
# is written before the file is closed.
with xr.open_dataset(f"{project_data_path}/WBM/calibration/eCONUS/param_maps/Kpet.nc") as kpet_raw:
    ds = tidy_params(kpet_raw)

    ds['Kpet'].attrs['long_name'] = 'Potential evapotranspiration coefficients'
    ds['Kpet'].attrs['units'] = ''

    # Store
    ds.to_netcdf(f"{output_dir}/pyWBM_Kpet_coefs.nc")

### Forcing

In [3]:
def reformat_forcing(gcm, member, ssp, overwrite = False):
    """Write one regridded LOCA2 forcing run as a compressed, CF-annotated netCDF file.

    Processing steps:
    - Read ``{project_data_path}/projections/eCONUS/forcing/LOCA2/{gcm}_{member}_{ssp}.zarr``.
    - Keep 2016-01-01 onward, matching the soil-moisture outputs.
    - Normalise timestamps to midnight and convert to a ``noleap`` calendar.
    - Set variable and global attributes.
    - Copy selected global attributes from the original LOCA2 ``pr`` file (2015-2044 chunk).
    - Write with zlib level-5 compression.

    Parameters
    ----------
    gcm : str
        Climate model name, e.g. ``"ACCESS-CM2"``.
    member : str
        Ensemble member, e.g. ``"r1i1p1f1"``.
    ssp : str
        Scenario, e.g. ``"ssp245"``.
    overwrite : bool, default False
        Rebuild the output even if it already exists. Otherwise an existing
        output is left untouched.

    Returns
    -------
    None
        Errors are not raised. The full traceback is written to
        ``{project_code_path}/scripts/logs/{gcm}_{ssp}_{member}_forcing_reformat.txt``
        so a batch loop can keep going.
    """
    import traceback  # function-local so this cell stays self-contained

    # Output location (public forcing directory)
    forcing_dir = '/storage/group/pches/default/users/dcl5300/lafferty_etal_2025_pyWBM_soil_data/forcing'
    output_file = f"{forcing_dir}/{gcm}_{ssp}_{member}_2016-2100_LOCA2_v20220519_regridded.nc"
    # Write to a side file and rename on success. A write interrupted part-way
    # (crash, walltime, full disk) would otherwise leave a truncated output_file
    # that the existence check below would treat as finished on every later run.
    tmp_file = f"{output_file}.tmp"

    # Skip finished outputs unless asked to overwrite
    if os.path.exists(output_file) and not overwrite:
        return None

    try:
        with xr.open_dataset(
            f"{project_data_path}/projections/eCONUS/forcing/LOCA2/{gcm}_{member}_{ssp}.zarr/",
            engine="zarr",
            chunks=None,
        ) as ds_in:
            # Original LOCA2 file: only its global attrs are needed, so copy them and close it
            loca_path = '/storage/group/pches/default/public/LOCA2'
            with xr.open_dataset(
                f"{loca_path}/{gcm}/0p0625deg/{member}/{ssp}/pr/pr.{gcm}.{ssp}.{member}.2015-2044.LOCA_16thdeg_v20220519.nc"
            ) as ds_orig:
                loca_attrs = dict(ds_orig.attrs)

            # Tidy types
            ds = ds_in.astype("float32")
            ds["lat"] = ds["lat"].astype(np.float32)
            ds["lon"] = ds["lon"].astype(np.float32)

            # Convert calendar
            ds = ds.sel(time=slice('2016-01-01', None))  # same start as soil outputs
            ds["time"] = ds.indexes["time"].normalize()
            ds = ds.convert_calendar(calendar="noleap", dim="time")

            # Add variables attrs
            ds['tas'].attrs['units'] = "degC"
            ds['tas'].attrs['standard_name'] = "air_temperature"
            ds['tas'].attrs['long_name'] = "Daily Average Near-Surface Air Temperature"
            ds['tas'].attrs['units_metadata'] = "temperature: on_scale"

            ds['pr'].attrs['units'] = "mm"
            ds['pr'].attrs['standard_name'] = "lwe_thickness_of_precipitation_amount"
            ds['pr'].attrs['long_name'] = "Precipitation"

            ds['time'].attrs['axis'] = 'T'
            ds['time'].attrs['standard_name'] = 'time'

            # Coordinate attrs, matching the parameter and soil-moisture files
            ds['lat'].attrs['standard_name'] = "latitude"
            ds['lat'].attrs['units'] = "degrees_north"

            ds['lon'].attrs['standard_name'] = "longitude"
            ds['lon'].attrs['units'] = "degrees_east"

            # Add global attrs
            ds.attrs['creation_date'] = datetime.today().strftime('%Y-%m-%d')
            ds.attrs['title'] = 'LOCA statistically downscaled climate model data, David W. Pierce, Scripps Institution of Oceanography; regridded by David Lafferty, Cornell University'
            ds.attrs['references'] = 'Lafferty et al. 2025; DOI: (pending)'
            ds.attrs['Conventions'] = 'CF-1.8'
            ds.attrs['history'] = 'LOCA2 outputs regridded using xESMF 0.8.7 to NLDAS-2 12.5km grid (https://ldas.gsfc.nasa.gov/nldas/specifications)'
            ds.attrs['contact'] = 'dcl257@cornell.edu'

            # Some from original LOCA attrs (copied unprefixed)
            loca_attrs_keep = ['activity_id',
                               'frequency',
                               'institution',
                               'institution_id',
                               'mip_era',
                               'nominal_resolution',
                               'parent_activity_id',
                               'parent_source_id',
                               'parent_variant_label',
                               'product',
                               'realm',
                               'license',
                               'LOCA2_version']
            for attr in loca_attrs_keep:
                if attr in loca_attrs:
                    ds.attrs[attr] = loca_attrs[attr]

            # Compress and save. This must happen while ds_in is open, because ds
            # may still read lazily from it.
            comp = dict(zlib=True, complevel=5)
            encoding = {var: comp for var in ds.data_vars}
            ds.to_netcdf(tmp_file, encoding=encoding, mode="w")

        # Publish the finished file in one step
        os.replace(tmp_file, output_file)
    except Exception:
        # Drop any partial write so it cannot be mistaken for a finished file
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        log_dir = f"{project_code_path}/scripts/logs"
        os.makedirs(log_dir, exist_ok=True)
        with open(f"{log_dir}/{gcm}_{ssp}_{member}_forcing_reformat.txt", "w") as f:
            f.write(traceback.format_exc())
    return None

In [4]:
%%time
# Get unique combos (forcings from pyWBM ensemble only)
infos = glob(f"{project_data_path}/projections/eCONUS/out/LOCA2/*.nc")
infos = [path.split('/')[-1] for path in infos] # remove preceding path
infos = np.unique([f"{path.split('_')[0]}_{path.split('_')[1]}_{path.split('_')[2]}" for path in infos]) # drop calibration info

# Loop through all
for info in infos:
    # Get info
    gcm, member, ssp = info.split('_')
    # Reformat
    reformat_forcing(gcm, member, ssp)
    print(gcm, member, ssp)

ACCESS-CM2 r1i1p1f1 ssp245
ACCESS-CM2 r1i1p1f1 ssp370
ACCESS-ESM1-5 r1i1p1f1 ssp245
ACCESS-ESM1-5 r1i1p1f1 ssp370
AWI-CM-1-1-MR r1i1p1f1 ssp245
AWI-CM-1-1-MR r1i1p1f1 ssp370
BCC-CSM2-MR r1i1p1f1 ssp245
BCC-CSM2-MR r1i1p1f1 ssp370
CESM2-LENS r10i1p1f1 ssp370
CNRM-CM6-1 r1i1p1f2 ssp245
CNRM-CM6-1 r1i1p1f2 ssp370
CNRM-ESM2-1 r1i1p1f2 ssp245
CNRM-ESM2-1 r1i1p1f2 ssp370
CanESM5 r1i1p1f1 ssp245
CanESM5 r1i1p1f1 ssp370
EC-Earth3-Veg r1i1p1f1 ssp245
EC-Earth3-Veg r1i1p1f1 ssp370
EC-Earth3 r1i1p1f1 ssp245
EC-Earth3 r1i1p1f1 ssp370
FGOALS-g3 r1i1p1f1 ssp245
FGOALS-g3 r1i1p1f1 ssp370
GFDL-CM4 r1i1p1f1 ssp245
GFDL-ESM4 r1i1p1f1 ssp245
GFDL-ESM4 r1i1p1f1 ssp370
HadGEM3-GC31-LL r1i1p1f3 ssp245
INM-CM4-8 r1i1p1f1 ssp245
INM-CM4-8 r1i1p1f1 ssp370
INM-CM5-0 r1i1p1f1 ssp245
INM-CM5-0 r1i1p1f1 ssp370
IPSL-CM6A-LR r10i1p1f1 ssp370
KACE-1-0-G r1i1p1f1 ssp245
KACE-1-0-G r1i1p1f1 ssp370
MIROC6 r1i1p1f1 ssp245
MIROC6 r1i1p1f1 ssp370
MPI-ESM1-2-HR r10i1p1f1 ssp370
MPI-ESM1-2-LR r10i1p1f1 ssp245
MPI-ESM1-2-LR r

### Soil moisture

In [19]:
@dask.delayed
def reformat_soil_output(gcm, member, ssp, obs_name, loss_metric, overwrite = False):
    """Write one pyWBM soil-moisture projection as a compressed, CF-annotated netCDF file.

    Decorated with ``dask.delayed``, so calling it builds a task. The work below
    runs when that task is computed.

    Processing steps:
    - Read ``{project_data_path}/projections/eCONUS/out/LOCA2/{gcm}_{member}_{ssp}_{obs_name}_{loss_metric}.nc``.
    - Cast to float32.
    - Set variable, coordinate and global attributes, including the calibration dataset and loss.
    - Copy selected global attributes (prefixed ``LOCA2_``) from the original LOCA2 ``pr`` file.
    - Write with zlib level-5 compression.

    Parameters
    ----------
    gcm, member, ssp : str
        Climate model, ensemble member and scenario of the forcing.
    obs_name : str
        Calibration dataset. ``"SMAP"`` is described as SMAP L4; any other
        value is described as NLDAS-2.
    loss_metric : str
        Calibration loss function name, stored verbatim in the global attrs.
    overwrite : bool, default False
        Rebuild the output even if it already exists.

    Returns
    -------
    None
        Errors are not raised. The full traceback is written to
        ``{project_code_path}/scripts/logs/{gcm}_{ssp}_{member}_{obs_name}_{loss_metric}_soilM_reformat.txt``.
    """
    import traceback  # imported here so the task is self-contained on dask workers

    # Output location
    soil_dir = '/storage/group/pches/default/users/dcl5300/lafferty_etal_2025_pyWBM_soil_data/soil_moisture'
    output_file = f"{soil_dir}/{gcm}_{ssp}_{member}_{obs_name}_{loss_metric}_2016-2100.nc"
    # Write to a side file and rename on success. A task killed mid-write (e.g. by
    # the SLURM walltime) would otherwise leave a truncated output_file that the
    # existence check below would treat as finished on every later run.
    tmp_file = f"{output_file}.tmp"

    # Skip finished outputs unless asked to overwrite
    if os.path.exists(output_file) and not overwrite:
        return None

    try:
        with xr.open_dataset(
            f"{project_data_path}/projections/eCONUS/out/LOCA2/{gcm}_{member}_{ssp}_{obs_name}_{loss_metric}.nc"
        ) as ds_in:
            # Original LOCA2 file: only its global attrs are needed, so copy them and close it
            loca_path = '/storage/group/pches/default/public/LOCA2'
            with xr.open_dataset(
                f"{loca_path}/{gcm}/0p0625deg/{member}/{ssp}/pr/pr.{gcm}.{ssp}.{member}.2015-2044.LOCA_16thdeg_v20220519.nc"
            ) as ds_orig:
                loca_attrs = dict(ds_orig.attrs)

            # Tidy types
            ds = ds_in.astype("float32")
            ds["lat"] = ds["lat"].astype(np.float32)
            ds["lon"] = ds["lon"].astype(np.float32)

            # Add variables attrs
            ds['soilMoist'].attrs['units'] = "mm"
            ds['soilMoist'].attrs['standard_name'] = "lwe_thickness_of_soil_moisture_content"
            ds['soilMoist'].attrs['long_name'] = "Daily Average Soil Moisture Content in 1m Depth"
            ds['soilMoist'].attrs['layer_depth'] = "1m"

            ds['time'].attrs['axis'] = 'T'
            ds['time'].attrs['standard_name'] = 'time'

            ds['lat'].attrs['standard_name'] = "latitude"
            ds['lat'].attrs['units'] = "degrees_north"

            ds['lon'].attrs['standard_name'] = "longitude"
            ds['lon'].attrs['units'] = "degrees_east"

            # Add global attrs
            ds.attrs['creation_date'] = datetime.today().strftime('%Y-%m-%d')
            ds.attrs['title'] = 'pyWBM simulation outputs, forced by LOCA2 downscaled climate data; David Lafferty, Cornell University'
            ds.attrs['references'] = 'Lafferty et al. 2025; DOI: (pending)'
            ds.attrs['Conventions'] = 'CF-1.8'
            ds.attrs['contact'] = 'dcl257@cornell.edu'

            calibration_data = "SMAP L4 daily average (https://gmao.gsfc.nasa.gov/GMAO_products/SMAP_L4/)" if obs_name == "SMAP" else \
                               "NLDAS-2 daily average (https://ldas.gsfc.nasa.gov/nldas)"
            ds.attrs['pyWBM_calibration_data'] = calibration_data
            ds.attrs['pyWBM_calibration_loss_function'] = loss_metric

            ds.attrs['frequency'] = "day"
            ds.attrs['LOCA2_version'] = "v20220519"

            # Some from original LOCA attrs, prefixed to distinguish them from pyWBM attrs
            loca_attrs_keep = ['activity_id',
                               'institution',
                               'institution_id',
                               'mip_era',
                               'nominal_resolution',
                               'parent_activity_id',
                               'parent_source_id',
                               'parent_variant_label',
                               'product',
                               'realm',
                               'license']
            for attr in loca_attrs_keep:
                if attr in loca_attrs:
                    ds.attrs[f"LOCA2_{attr}"] = loca_attrs[attr]

            # Compress and save. This must happen while ds_in is open, because ds
            # may still read lazily from it.
            comp = dict(zlib=True, complevel=5)
            encoding = {var: comp for var in ds.data_vars}
            ds.to_netcdf(tmp_file, encoding=encoding, mode="w")

        # Publish the finished file in one step
        os.replace(tmp_file, output_file)
    except Exception:
        # Drop any partial write so it cannot be mistaken for a finished file
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        log_dir = f"{project_code_path}/scripts/logs"
        os.makedirs(log_dir, exist_ok=True)
        with open(f"{log_dir}/{gcm}_{ssp}_{member}_{obs_name}_{loss_metric}_soilM_reformat.txt", "w") as f:
            f.write(traceback.format_exc())
    return None

In [20]:
# Dask on SLURM: each job gets 1 core, 30 GiB of memory and a 1 h walltime.
# Jobs are charged to the "open" allocation ("pches" is the alternative account).
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

cluster = SLURMCluster(account="open", cores=1, memory="30GiB", walltime="01:00:00")
cluster.scale(jobs=10)  # request 10 SLURM jobs

client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.8.142:39879,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [21]:
%%time
# Get unique combos
infos = glob(f"{project_data_path}/projections/eCONUS/out/LOCA2/*.nc")
infos = [path.split('/')[-1] for path in infos]

# Loop through all, parallelize with dask
delayed = []
for info in infos:
    # Get info
    gcm, member, ssp, obs_name, loss_metric = info.replace('.nc', '').split('_')
    # Reformat
    delayed.append(reformat_soil_output(gcm, member, ssp, obs_name, loss_metric))

_ = dask.compute(*delayed)

CPU times: user 47.7 s, sys: 3.18 s, total: 50.9 s
Wall time: 8min 44s
