In [1]:
import os 
from datetime import datetime

import numpy as np
import jax.numpy as jnp
import jax
import optax

import dask

import matplotlib.pyplot as plt

from water_balance_jax import wbm_jax, construct_Kpet_vec
from initial_params import initial_params, constants
from param_bounds import params_lower, params_upper

In [2]:
#####################
#### Directories ####
#####################
# Root of the project data tree. Set the WBM_DATA_PATH environment variable to run
# against a different location without editing the notebook; otherwise the original
# cluster path is used.
project_data_path = os.environ.get(
    "WBM_DATA_PATH",
    "/storage/group/pches/default/users/dcl5300/wbm_soilM_crop_uc_lafferty-etal-2024-tbd_DATA",
)

In [3]:
# Parameter names, in the order the entries of `theta` are laid out:
# 9 shared soil/runoff parameters, then 10 crop-coefficient parameters for each crop.
param_names = [
    "awCap_scalar", "wiltingp_scalar",
    "alpha_claycoef", "alpha_sandcoef", "alpha_siltcoef",
    "betaHBV_claycoef", "betaHBV_sandcoef", "betaHBV_siltcoef", "betaHBV_elevcoef",
] + [
    f"{field}_{crop}"
    for crop in ("corn", "cotton", "rice", "sorghum", "soybeans", "wheat")
    for field in ("GS_start", "GS_end", "L_ini", "L_dev", "L_mid",
                  "Kc_ini", "Kc_mid", "Kc_end", "K_min", "K_max")
]

In [4]:
############
### Dask ###
############
from dask_jobqueue import SLURMCluster
from dask.distributed import Client

# Each SLURM job requests 1 core, 10 GiB of memory and a 30-minute walltime.
# (Alternative allocation: account="pches".)
cluster = SLURMCluster(account="open", cores=1, memory="10GiB", walltime="00:30:00")
cluster.scale(jobs=30)  # submit 30 jobs

client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.162:37507,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


### Preliminaries

In [5]:
def read_inputs(subset_name, obs_name, remove_nans):
    """Load observations and WBM inputs for one calibration subset, flattened for vmap.

    Reads ``{obs_name}_validation.npy`` and ``inputs.npz`` from
    ``{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/``, collapses the two
    spatial dimensions into a single leading gridpoint axis, and optionally drops every
    gridpoint that has a NaN anywhere in the returned arrays.

    Parameters
    ----------
    subset_name : str
        Spatial subset directory name (e.g. 'centralUS').
    obs_name : str
        Observation product directory name (e.g. 'SMAP', 'VIC').
    remove_nans : bool
        If True, drop gridpoints with a NaN in the obs, daily forcing, LAI or static maps.

    Returns
    -------
    ys : (n_points, nt) observations.
    x_forcing_nt : (n_points, 2, nt) daily forcing stacked as [tas, prcp].
    x_forcing_nyrs : (n_points, 365) LAI.
    x_maps : (n_points, 14) static maps, in the order unpacked by ``make_prediction``.
    """
    data_dir = f"{project_data_path}/WBM/calibration/{subset_name}/{obs_name}"

    ######################
    # Read obs
    obs = np.load(f"{data_dir}/{obs_name}_validation.npy")

    ######################
    # Read and extract inputs. np.load on an .npz returns an NpzFile that keeps the
    # archive open; the context manager releases the file handle once the arrays
    # below have been materialized.
    with np.load(f"{data_dir}/inputs.npz") as npz:
        # Meteo forcing
        tas = npz['tas']
        prcp = npz['prcp']

        # LAI
        lai = npz['lai']

        # Soil properties
        awCap = npz['awCap']
        wiltingp = npz['wiltingp']
        clayfrac = npz['clayfrac']
        sandfrac = npz['sandfrac']
        siltfrac = npz['siltfrac']

        # Land use: crop fractions only. The non-crop weight is not read here;
        # make_prediction derives it as 1 minus the sum of these crop fractions.
        corn = npz['corn']
        cotton = npz['cotton']
        rice = npz['rice']
        sorghum = npz['sorghum']
        soybeans = npz['soybeans']
        wheat = npz['durum_wheat'] + npz['spring_wheat'] + npz['winter_wheat']

        # Geophysical
        elev_std = npz['elev_std']
        lats = npz['lats']

        # Initial conditions
        Ws_init = npz['soilMoist_init']

    ##########################
    # Prepare inputs for vmap:
    # spatial dimensions need to be collapsed and first
    # NaN gridpoints need to be removed
    nx, ny, nt = tas.shape

    # The daily series must cover a whole number of 365-day years
    assert nt % 365 == 0

    ## Obs
    ys = obs.reshape(nx * ny, nt)
    nan_inds_obs = jnp.isnan(ys).any(axis=1)

    ## Forcing: all days
    tas_in = tas.reshape(nx * ny, nt)
    prcp_in = prcp.reshape(nx * ny, nt)

    x_forcing_nt = jnp.stack([tas_in, prcp_in], axis=1)
    nan_inds_forcing_nt = jnp.isnan(x_forcing_nt).any(axis=(1, 2))

    ## Forcing: yearly
    x_forcing_nyrs = lai.reshape(nx * ny, 365)
    nan_inds_forcing_nyrs = jnp.isnan(x_forcing_nyrs).any(axis=1)

    ## Maps
    # Tiling lats nx times matches the C-order flattening (i * ny + j) only if `lats`
    # indexes the second spatial axis (length ny) -- NOTE(review): confirm axis order.
    lats_in = np.tile(lats, nx)

    x_maps = jnp.stack([awCap.reshape(nx * ny), wiltingp.reshape(nx * ny),
                        Ws_init.reshape(nx * ny),
                        clayfrac.reshape(nx * ny), sandfrac.reshape(nx * ny), siltfrac.reshape(nx * ny),
                        lats_in, elev_std.reshape(nx * ny),
                        corn.reshape(nx * ny), cotton.reshape(nx * ny), rice.reshape(nx * ny),
                        sorghum.reshape(nx * ny), soybeans.reshape(nx * ny), wheat.reshape(nx * ny)],
                       axis=1)
    nan_inds_maps = jnp.isnan(x_maps).any(axis=1)

    # Remove NaNs if desired: a gridpoint is dropped if any of its inputs is NaN
    if remove_nans:
        nan_inds = nan_inds_obs | nan_inds_forcing_nt | nan_inds_forcing_nyrs | nan_inds_maps
        keep = ~nan_inds
        ys = ys[keep]
        x_forcing_nt = x_forcing_nt[keep]
        x_forcing_nyrs = x_forcing_nyrs[keep]
        x_maps = x_maps[keep]

    return ys, x_forcing_nt, x_forcing_nyrs, x_maps

In [6]:
def make_prediction(theta, constants, x_forcing_nt, x_forcing_nyrs, x_maps):
    """Run the water-balance model for a single gridpoint.

    Called per gridpoint under ``jax.vmap`` by the loss functions in ``train_and_store``.

    Parameters
    ----------
    theta : array of 69 log-parameters, ordered like ``param_names``.
        They are exponentiated here. The first 9 are shared soil/runoff parameters.
        The remaining 60 are 10 crop-coefficient parameters for each of corn, cotton,
        rice, sorghum, soybeans and wheat.
    constants : (Ts, Tm, Wi_init, Sp_init)
    x_forcing_nt : pair (tas, prcp) of daily series.
    x_forcing_nyrs : LAI with 365 values.
    x_maps : 14 static values, in the order stacked by ``read_inputs``.

    Returns
    -------
    Whatever ``wbm_jax`` returns; it is compared against the observations by the
    error function.
    """
    tas, prcp = x_forcing_nt
    lai = x_forcing_nyrs

    (awCap, wiltingp, Ws_init,
     clayfrac, sandfrac, siltfrac,
     lats, elev_std,
     corn, cotton, rice, sorghum, soybeans, wheat) = x_maps

    Ts, Tm, Wi_init, Sp_init = constants

    # Parameters are optimized in log space
    p = jnp.exp(theta)
    (awCap_scalar, wiltingp_scalar,
     alpha_claycoef, alpha_sandcoef, alpha_siltcoef,
     betaHBV_claycoef, betaHBV_sandcoef, betaHBV_siltcoef, betaHBV_elevcoef) = p[:9]
    # One row per crop (corn, cotton, rice, sorghum, soybeans, wheat), 10 params each
    crop_table = p[9:].reshape(6, 10)

    # Per-cover Kpet curves: one per crop, then a flat 1.0 for the non-crop remainder
    kpet_per_cover = []
    for gs_start, gs_end, l_ini, l_dev, l_mid, kc_ini, kc_mid, kc_end, k_min, k_max in crop_table:
        # Fraction of the season left after the ini/dev/mid stages
        l_remaining = 1. - (l_ini + l_dev + l_mid)
        kpet_per_cover.append(
            construct_Kpet_vec(gs_start, gs_end, l_ini, l_dev, l_mid, l_remaining,
                               kc_ini, kc_mid, kc_end, k_min, k_max, lai)
        )
    kpet_per_cover.append(jnp.ones(365))

    # Land-cover weighted average of the Kpet curves
    other = 1. - (corn + cotton + rice + sorghum + soybeans + wheat)
    cover_weights = jnp.array([corn, cotton, rice, sorghum, soybeans, wheat, other])
    Kpet = jnp.average(jnp.array(kpet_per_cover), weights=cover_weights, axis=0)

    # Gridpoint-level parameters seen by the WBM
    alpha = 1.0 + (alpha_claycoef * clayfrac) + (alpha_sandcoef * sandfrac) + (alpha_siltcoef * siltfrac)
    betaHBV = 1.0 + (betaHBV_claycoef * clayfrac) + (betaHBV_sandcoef * sandfrac) + (betaHBV_siltcoef * siltfrac) + (betaHBV_elevcoef * elev_std)
    wbm_params = (Ts, Tm, wiltingp * wiltingp_scalar, awCap * awCap_scalar, alpha, betaHBV)

    return wbm_jax(tas, prcp, Kpet, Ws_init, Wi_init, Sp_init, lai, lats, wbm_params)

In [7]:
def train_and_store(subset_name, obs_name, _error_fn, error_fn_name, n_epochs,
                    batch_size = 2**7,
                    opt = 'adam',
                    learning_rate = 1e-2,
                    val_frac = 0.2,
                    reg_const = 0.01,
                    initial_params = initial_params,
                    params_lower = params_lower,
                    params_upper = params_upper,
                    seed = None):
    """Calibrate the WBM parameters by mini-batch gradient descent and log every epoch.

    Gridpoints are split at random into train/validation sets. Each epoch shuffles the
    training points into mini-batches and takes one optimizer step per batch, using the
    nan-mean over gridpoints of the gradient of
    (prediction loss + reg_const * regularization loss). After each epoch it appends one
    line (losses plus the current theta) to a text file under
    ``{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/``.

    Parameters
    ----------
    subset_name, obs_name : str
        Passed to ``read_inputs`` and used to build the output path.
    _error_fn : callable
        ``_error_fn(prediction, ys) -> scalar`` loss for one gridpoint.
    error_fn_name : str
        Label used in the output file name and in each logged row.
    n_epochs : int
        The training loop runs ``n_epochs + 1`` passes over the training data.
    batch_size : int
        Maximum number of gridpoints per mini-batch.
    opt : {'adam', 'sgd'}
        Optimizer. 'sgd' ignores ``learning_rate`` and uses a fixed step of 1e-5.
    learning_rate : float
        Adam step size.
    val_frac : float
        Fraction of gridpoints held out for validation (0 disables validation).
    reg_const : float
        Weight of the regularization loss.
    initial_params, params_lower, params_upper : array
        Starting theta and the bounds used by the regularizer and by the validity check.
    seed : int or None
        Seed for the train/val split and mini-batch shuffling. None uses NumPy's
        global random state.

    Returns
    -------
    tuple of 4 arrays of length ``n_epochs + 1``
        Total, prediction, regularization and validation loss per epoch.
    None
        Returned if an update leaves theta outside [params_lower, params_upper] or
        non-finite. In that case the partially written log file is deleted.

    Raises
    ------
    ValueError
        If ``opt`` is not 'adam' or 'sgd'.
    """
    #############################################
    # Loss function with correct error metric
    ############################################
    # Prediction loss
    def prediction_loss(theta, constants,
                        x_forcing_nt, x_forcing_nyrs, x_maps, ys):

        prediction = make_prediction(theta, constants, x_forcing_nt, x_forcing_nyrs, x_maps)

        return _error_fn(prediction, ys)

    # Regularization loss: grows as theta leaves initial_params and blows up at the bounds
    def reg_loss(theta, initial_params, params_lower, params_upper):

        return jnp.nansum((theta - initial_params)**2 / ((theta - params_lower) * (params_upper - theta)))

    # Total loss
    def loss_fn(theta, reg_const, initial_params, params_lower, params_upper, constants,
                x_forcing_nt, x_forcing_nyrs, x_maps, ys):

        return prediction_loss(theta, constants, x_forcing_nt, x_forcing_nyrs, x_maps, ys) + \
                reg_const * reg_loss(theta, initial_params, params_lower, params_upper)

    # jit and vmap it: theta and hyperparameters are shared, data is batched over gridpoints
    pred_loss_value = jax.jit(jax.vmap(prediction_loss, in_axes=(None, None, 0, 0, 0, 0), out_axes=0))
    loss_value_and_grad = jax.jit(jax.vmap(jax.value_and_grad(loss_fn), in_axes=(None, None, None, None, None, None, 0, 0, 0, 0), out_axes=0))

    ###########################
    # Setup
    ###########################
    # Initial parameters
    theta = initial_params

    # Optimizer (validated before any data is read)
    if opt == 'adam':
        adam = optax.adam(learning_rate=learning_rate)
        opt_fn = adam.update
        opt_state = adam.init(theta)
    elif opt == 'sgd':
        # NOTE: plain SGD deliberately overrides `learning_rate` with a fixed 1e-5
        learning_rate = 1e-5
        opt_state = None
        def sgd(gradients, state):
            return -learning_rate * gradients, state
        opt_fn = sgd
    else:
        raise ValueError(f"Unknown optimizer {opt!r}; expected 'adam' or 'sgd'")

    # Random source for the split and the shuffling
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Read data
    ys, x_forcing_nt, x_forcing_nyrs, x_maps = read_inputs(subset_name, obs_name, True)
    N = ys.shape[0]

    # Get train/val split over space
    N_val = int(N * val_frac)
    N_train = N - N_val

    train_idx = rng.choice(N, N_train, replace=False)
    ys_train, x_forcing_nt_train, x_forcing_nyrs_train, x_maps_train = ys[train_idx], x_forcing_nt[train_idx], x_forcing_nyrs[train_idx], x_maps[train_idx]

    if N_val > 0:
        # All points not drawn for training, in ascending order
        val_idx = np.setdiff1d(np.arange(N), train_idx)
        ys_val, x_forcing_nt_val, x_forcing_nyrs_val, x_maps_val = ys[val_idx], x_forcing_nt[val_idx], x_forcing_nyrs[val_idx], x_maps[val_idx]

    # Loss
    train_loss_out = np.empty(n_epochs + 1)
    pred_loss_out = np.empty(n_epochs + 1)
    reg_loss_out = np.empty(n_epochs + 1)
    val_loss_out = np.empty(n_epochs + 1)

    # Where to store results
    datetime_str = datetime.now().strftime('%Y%m%d-%H%M')
    training_name = f"{error_fn_name}_{str(n_epochs)}epochs_{str(batch_size)}batchsize_{opt}-opt_{str(val_frac)}val_{str(reg_const)}reg_{datetime_str}"
    out_file_path = f"{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/{training_name}.txt"

    # initial results (printed only, not logged to file)
    pred_loss_init = jnp.mean(pred_loss_value(theta,
                                              constants,
                                              x_forcing_nt_train,
                                              x_forcing_nyrs_train,
                                              x_maps_train,
                                              ys_train))
    if N_val > 0:
        val_loss_init = jnp.mean(pred_loss_value(theta,
                                                 constants,
                                                 x_forcing_nt_val,
                                                 x_forcing_nyrs_val,
                                                 x_maps_val,
                                                 ys_val))
    else:
        val_loss_init = np.nan

    reg_loss_init = reg_loss(theta, initial_params, params_lower, params_upper)
    print(f"Epoch 0 pred loss: {pred_loss_init:.4f}, reg_loss: {reg_loss_init:.4f}, val loss: {val_loss_init:.4f}")

    # `with` guarantees the log file is closed even if a jitted call raises
    with open(out_file_path, "w") as f:
        f.write(f"epoch metric train_loss pred_loss reg_loss val_loss {' '.join(param_names)}\n")

        ###########################
        # Training loop
        ###########################
        # NOTE: runs n_epochs + 1 passes, logged as epochs 1..n_epochs + 1
        for epoch in range(n_epochs + 1):
            # Shuffle indices
            shuffled_inds = rng.permutation(N_train)

            # Consecutive slices of at most batch_size points. Stepping by batch_size
            # never yields an empty slice; an empty batch would give an all-NaN mean
            # gradient and silently turn theta into NaN.
            minibatch_inds = [shuffled_inds[i:i + batch_size] for i in range(0, N_train, batch_size)]

            # Per-gridpoint losses of each batch
            batch_loss = []

            for inds in minibatch_inds:
                # Calculate gradient of loss function, update parameters
                loss, grads = loss_value_and_grad(theta, reg_const, initial_params, params_lower, params_upper, constants,
                                                  x_forcing_nt_train[inds],
                                                  x_forcing_nyrs_train[inds],
                                                  x_maps_train[inds],
                                                  ys_train[inds])
                updates, opt_state = opt_fn(jnp.nanmean(grads, axis=0), opt_state)
                theta = optax.apply_updates(theta, updates)
                batch_loss.append(loss)
                # Abort if theta steps outside bounds or becomes NaN/inf (NaN fails
                # every comparison, so the bound checks alone would let it through)
                if (not jnp.all(jnp.isfinite(theta))) or (theta < params_lower).any() or (theta > params_upper).any():
                    print('Found invalid parameter')
                    f.close()
                    os.remove(out_file_path)
                    return None

            # Save all losses
            train_loss_out[epoch] = jnp.nanmean(jnp.concatenate(batch_loss))
            reg_loss_out[epoch] = reg_loss(theta, initial_params, params_lower, params_upper)
            # Approximation: batch-averaged total loss minus end-of-epoch regularization
            pred_loss_out[epoch] = train_loss_out[epoch] - (reg_const * reg_loss_out[epoch])
            if N_val > 0:
                val_loss_out[epoch] = jnp.mean(pred_loss_value(theta,
                                                              constants,
                                                              x_forcing_nt_val,
                                                              x_forcing_nyrs_val,
                                                              x_maps_val,
                                                              ys_val))
            else:
                val_loss_out[epoch] = jnp.nan

            # Write every epoch
            theta_str = [str(param) for param in theta]
            f.write(f"{str(epoch + 1)} {error_fn_name} {train_loss_out[epoch]:.4f} {pred_loss_out[epoch]:.4f} {reg_loss_out[epoch]:.4f} {val_loss_out[epoch]:.4f} {' '.join(theta_str)}\n")
            # Print every 5
            if epoch % 5 == 0:
                print(f"Epoch {str(epoch + 1)} total loss: {train_loss_out[epoch]:.4f}, pred loss: {pred_loss_out[epoch]:.4f}, reg_loss: {reg_loss_out[epoch]:.4f}, val loss: {val_loss_out[epoch]:.4f}")

    return train_loss_out, pred_loss_out, reg_loss_out, val_loss_out

# Fitting

In [8]:
# Length of timeseries needed for quantile RMSE
N = 2555

# Define all error functions
# RMSE
_rmse = lambda prediction, ys: jnp.sqrt(jnp.nanmean((prediction-ys)**2))

# MSE
_mse = lambda prediction, ys: jnp.nanmean((prediction-ys)**2)

# KGE
def _kge(prediction, ys):
    corr = jnp.nanmean((prediction - jnp.nanmean(prediction))*(ys - jnp.nanmean(ys))) / (jnp.nanstd(prediction) * jnp.nanstd(ys))
    mean_ratio = jnp.nanmean(prediction) / jnp.nanmean(ys)
    std_ratio = jnp.nanstd(prediction) / jnp.nanstd(ys)
    kge = 1 - jnp.sqrt((corr - 1)**2 + (mean_ratio - 1)**2 + (std_ratio - 1)**2)
    return -kge 

# q0-25 RMSE
qmax = 0.25
size = round(N * qmax)
def _q25rmse(prediction, ys):
    thresh = jnp.quantile(ys, qmax)
    inds = jnp.where(ys <= thresh, size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q)**2))

# q75-100 RMSE
qmin = 0.75
size = round(N * qmin)
def _q75rmse(prediction, ys):
    thresh = jnp.quantile(ys, qmin)
    inds = jnp.where(ys >= thresh, size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q)**2))

_error_fns = [_rmse, _mse, _kge, _q25rmse, _q75rmse]
error_fn_names = ['rmse', 'mse', 'kge', 'q0-25rmse', 'q75-100rmse']

### SMAP

In [10]:
# Info: calibration subset and observation product fitted in this section
subset_name, obs_name = 'centralUS', 'SMAP'

In [None]:
%%time
# Parallelize with dask delayed: build one lazy training run per
# (error metric, batch size) pair, then execute them all on the cluster.
delayed = []

for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
    # Per-metric hyperparameters: weaker regularization for KGE, smaller step for MSE
    reg_const = 0.001 if error_fn_name == 'kge' else 0.01
    learning_rate = 1e-3 if error_fn_name == 'mse' else 1e-2

    for batch_size in (2**5, 2**6, 2**7, 2**8, 2**9):
        delayed.append(
            dask.delayed(train_and_store)(
                subset_name=subset_name,
                obs_name=obs_name,
                _error_fn=_error_fn,
                error_fn_name=error_fn_name,
                batch_size=batch_size,
                reg_const=reg_const,
                learning_rate=learning_rate,
                n_epochs=30,
                val_frac=0.0,
            )
        )

# One result per run: train_and_store's loss arrays, or None if it hit an invalid parameter
out = dask.compute(*delayed)

### VIC

In [9]:
# Info
subset_name = 'centralUS'
obs_name = 'VIC'

# Length of the obs time axis (needed for quantile RMSE). mmap_mode='r' only parses the
# .npy header, so the full validation array is not loaded just to read its shape.
N = np.load(f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/{obs_name}_validation.npy',
            mmap_mode='r').shape[-1]

In [16]:
%%time
# Parallelize with dask delayed: build one lazy training run per
# (error metric, batch size) pair, then execute them all on the cluster.
delayed = []

for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
    # Per-metric hyperparameters: weaker regularization for KGE, smaller step for MSE
    reg_const = 0.001 if error_fn_name == 'kge' else 0.01
    learning_rate = 1e-3 if error_fn_name == 'mse' else 1e-2

    for batch_size in (2**5, 2**6, 2**7, 2**8, 2**9):
        delayed.append(
            dask.delayed(train_and_store)(
                subset_name=subset_name,
                obs_name=obs_name,
                _error_fn=_error_fn,
                error_fn_name=error_fn_name,
                batch_size=batch_size,
                reg_const=reg_const,
                learning_rate=learning_rate,
                n_epochs=30,
                val_frac=0.0,
            )
        )

# One result per run: train_and_store's loss arrays, or None if it hit an invalid parameter
out = dask.compute(*delayed)

CPU times: user 4min 41s, sys: 21.2 s, total: 5min 2s
Wall time: 56min 33s


### MOSAIC

In [18]:
# Info
subset_name = 'centralUS'
obs_name = 'MOSAIC'

# Length of the obs time axis (needed for quantile RMSE). mmap_mode='r' only parses the
# .npy header, so the full validation array is not loaded just to read its shape.
N = np.load(f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/{obs_name}_validation.npy',
            mmap_mode='r').shape[-1]

In [19]:
%%time
# Parallelize with dask delayed: build one lazy training run per
# (error metric, batch size) pair, then execute them all on the cluster.
delayed = []

for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
    # Per-metric hyperparameters: weaker regularization for KGE, smaller step for MSE
    reg_const = 0.001 if error_fn_name == 'kge' else 0.01
    learning_rate = 1e-3 if error_fn_name == 'mse' else 1e-2

    for batch_size in (2**5, 2**6, 2**7, 2**8, 2**9):
        delayed.append(
            dask.delayed(train_and_store)(
                subset_name=subset_name,
                obs_name=obs_name,
                _error_fn=_error_fn,
                error_fn_name=error_fn_name,
                batch_size=batch_size,
                reg_const=reg_const,
                learning_rate=learning_rate,
                n_epochs=30,
                val_frac=0.0,
            )
        )

# One result per run: train_and_store's loss arrays, or None if it hit an invalid parameter
out = dask.compute(*delayed)

CPU times: user 4min 23s, sys: 20.8 s, total: 4min 43s
Wall time: 56min 43s


### NOAH

In [20]:
# Info
subset_name = 'centralUS'
obs_name = 'NOAH'

# Length of the obs time axis (needed for quantile RMSE). mmap_mode='r' only parses the
# .npy header, so the full validation array is not loaded just to read its shape.
N = np.load(f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/{obs_name}_validation.npy',
            mmap_mode='r').shape[-1]

In [None]:
%%time
# Parallelize with dask delayed: build one lazy training run per
# (error metric, batch size) pair, then execute them all on the cluster.
delayed = []

for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
    # Per-metric hyperparameters: weaker regularization for KGE, smaller step for MSE
    reg_const = 0.001 if error_fn_name == 'kge' else 0.01
    learning_rate = 1e-3 if error_fn_name == 'mse' else 1e-2

    for batch_size in (2**5, 2**6, 2**7, 2**8, 2**9):
        delayed.append(
            dask.delayed(train_and_store)(
                subset_name=subset_name,
                obs_name=obs_name,
                _error_fn=_error_fn,
                error_fn_name=error_fn_name,
                batch_size=batch_size,
                reg_const=reg_const,
                learning_rate=learning_rate,
                n_epochs=30,
                val_frac=0.0,
            )
        )

# One result per run: train_and_store's loss arrays, or None if it hit an invalid parameter
out = dask.compute(*delayed)