In [7]:
import os 
from datetime import datetime

import numpy as np
import jax.numpy as jnp
import jax
import optax

import dask

import matplotlib.pyplot as plt

from water_balance_jax import wbm_jax, construct_Kpet_vec
from initial_params import initial_params, initial_params_vic, constants
from param_bounds import params_lower, params_upper, params_vic_lower, params_vic_upper
from read_inputs import read_inputs, read_inputs
from prediction import make_prediction, make_prediction_vic
from global_paths import project_data_path

In [5]:
############
### Dask ###
############
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Each SLURM job requests 1 core, 5 GiB of memory and a 1 hour wall time.
# (Alternative allocation account: "pches".)
cluster = SLURMCluster(account="open", cores=1, memory="5GiB", walltime="01:00:00")
cluster.scale(jobs=30)  # request 30 jobs

client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.156:39569,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


### Preliminaries

In [8]:
def train_and_store(subset_name,
                    obs_name,
                    _error_fn,
                    error_fn_name,
                    param_names = None,
                    n_epochs = 30,
                    batch_size = 2**5,
                    opt = 'adam',
                    learning_rate = 1e-3,
                    val_inds = None,
                    reg_const = 0.001,
                    initial_params = initial_params,
                    params_lower = params_lower,
                    params_upper = params_upper,
                    val_frac = 0.0):
    """Calibrate WBM parameters by mini-batch gradient descent and log every epoch.

    Grid cells (axis 0 of the arrays returned by ``read_inputs``) are split into
    training and validation sets. The training cells are reshuffled into
    mini-batches every epoch. One line per epoch (losses plus the parameter
    vector) is written to
    ``{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/<name>.txt``.

    Parameters
    ----------
    subset_name, obs_name : str
        Passed to ``read_inputs``. ``obs_name == "VIC"`` uses
        ``make_prediction_vic``; anything else uses ``make_prediction``.
    _error_fn : callable
        ``_error_fn(prediction, ys)`` -> scalar loss for one grid cell.
    error_fn_name : str
        Label used in the output file name and written on every row.
    param_names : sequence of str, optional
        Parameter column names for the file header. Defaults to
        ``param0, param1, ...``.
    n_epochs : int
        The training loop makes ``n_epochs + 1`` passes over the training cells.
    batch_size : int
        Grid cells per mini-batch. The last batch may be smaller.
    opt : {'adam', 'sgd'}
        'sgd' uses a fixed step of 1e-5 and ignores ``learning_rate``.
    learning_rate : float
        Adam step size.
    val_inds : array-like of int, optional
        Indices of held-out grid cells. When omitted, ``round(val_frac * N)``
        randomly chosen cells are held out (none when ``val_frac`` is 0).
    reg_const : float
        Weight of the regularization term in the total loss.
    initial_params, params_lower, params_upper : array-like
        Prior parameter values and box bounds. theta is initialised uniformly
        within the bounds and regularised towards ``initial_params``.
    val_frac : float
        Fraction of cells to hold out. Only used when ``val_inds`` is None.

    Returns
    -------
    train_loss_out, pred_loss_out, reg_loss_out, val_loss_out : np.ndarray
        Per-epoch losses of length ``n_epochs + 1``. Entries for epochs that
        never ran (early abort) are NaN.
    theta : array
        Final parameters. After more than 5 out-of-bounds or non-finite
        resets, training aborts and the offending theta is returned.

    Raises
    ------
    ValueError
        If ``opt`` is unknown or no training cells remain after the split.
    """
    #############################################
    # Loss function with correct error metric
    #############################################
    # Prediction loss for a single grid cell
    predict = make_prediction_vic if obs_name == "VIC" else make_prediction

    def prediction_loss(theta, constants,
                        x_forcing_nt, x_forcing_nyrs, x_maps, ys):
        prediction = predict(theta, constants, x_forcing_nt, x_forcing_nyrs, x_maps)
        return _error_fn(prediction, ys)

    # Regularization loss: quadratic pull towards initial_params. The
    # denominator goes to 0 as theta approaches either bound.
    def reg_loss(theta, initial_params, params_lower, params_upper):
        return jnp.nansum((theta - initial_params)**2 / ((theta - params_lower) * (params_upper - theta)))

    # Total loss
    def loss_fn(theta, reg_const, initial_params, params_lower, params_upper, constants,
                x_forcing_nt, x_forcing_nyrs, x_maps, ys):
        return prediction_loss(theta, constants, x_forcing_nt, x_forcing_nyrs, x_maps, ys) + \
                reg_const * reg_loss(theta, initial_params, params_lower, params_upper)

    # jit and vmap over grid cells (axis 0 of forcing/map/obs arrays); theta etc. are shared
    pred_loss_value = jax.jit(jax.vmap(prediction_loss, in_axes=(None, None, 0, 0, 0, 0), out_axes=0))
    loss_value_and_grad = jax.jit(jax.vmap(jax.value_and_grad(loss_fn),
                                           in_axes=(None, None, None, None, None, None, 0, 0, 0, 0),
                                           out_axes=0))

    ###########################
    # Setup
    ###########################
    # Read data
    ys, x_forcing_nt, x_forcing_nyrs, x_maps = read_inputs(subset_name, obs_name, True)
    N = ys.shape[0]

    # Train/val split over space (axis 0). Without explicit val_inds, hold out a
    # random val_frac of the cells (none when val_frac is 0).
    if val_inds is None:
        n_val = int(round(val_frac * N))
        val_inds = np.random.permutation(N)[:n_val] if n_val > 0 else []
    val_inds = np.asarray(val_inds, dtype=int)
    has_val = val_inds.size > 0

    if has_val:
        ys_val, x_forcing_nt_val, x_forcing_nyrs_val, x_maps_val = ys[val_inds], x_forcing_nt[val_inds], x_forcing_nyrs[val_inds], x_maps[val_inds]
        # Ascending indices not in val_inds, without an O(N * n_val) membership scan
        train_inds = np.setdiff1d(np.arange(N), val_inds)
        ys_train, x_forcing_nt_train, x_forcing_nyrs_train, x_maps_train = ys[train_inds], x_forcing_nt[train_inds], x_forcing_nyrs[train_inds], x_maps[train_inds]
    else:
        ys_train, x_forcing_nt_train, x_forcing_nyrs_train, x_maps_train = ys, x_forcing_nt, x_forcing_nyrs, x_maps

    N_train = ys_train.shape[0]
    if N_train == 0:
        raise ValueError("No training grid cells remain after removing val_inds")

    # Initial parameters: random start inside the box bounds
    theta = np.random.uniform(low=params_lower, high=params_upper)

    # Optimizer: opt_fn(grads, state) -> (updates, state); init_opt_state(theta) -> state
    if opt == 'adam':
        adam = optax.adam(learning_rate=learning_rate)
        opt_fn = adam.update
        init_opt_state = adam.init
    elif opt == 'sgd':
        # NOTE: plain SGD uses a fixed step of 1e-5 and ignores learning_rate
        sgd_step = 1e-5
        def opt_fn(gradients, state):
            return -sgd_step * gradients, state
        def init_opt_state(params):
            return None
    else:
        raise ValueError(f"Unknown optimizer {opt!r}; expected 'adam' or 'sgd'")
    opt_state = init_opt_state(theta)

    # Per-epoch losses. NaN-filled so an early abort never returns uninitialised memory.
    train_loss_out = np.full(n_epochs + 1, np.nan)
    pred_loss_out = np.full(n_epochs + 1, np.nan)
    reg_loss_out = np.full(n_epochs + 1, np.nan)
    val_loss_out = np.full(n_epochs + 1, np.nan)

    # Where to store results
    random_str = str(abs(theta[0])).replace('.','')[:5] # used to discern different starting values
    val_tag = str(val_inds[0]) if has_val else "no"
    training_name = f"{error_fn_name}_{n_epochs}epochs_{batch_size}batchsize_{val_tag}val_{reg_const}reg_{random_str}r"
    if param_names is None:
        param_names = [f"param{i}" for i in range(np.size(theta))]

    out_file_path = f"{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/{training_name}.txt"

    def val_loss(theta):
        """Mean prediction loss over the validation cells (NaN without a split)."""
        if not has_val:
            return np.nan
        return jnp.mean(pred_loss_value(theta,
                                        constants,
                                        x_forcing_nt_val,
                                        x_forcing_nyrs_val,
                                        x_maps_val,
                                        ys_val))

    # `with` guarantees the file is closed on early return or on any exception
    with open(out_file_path, "w") as f:
        f.write(f"epoch metric train_loss pred_loss reg_loss val_loss {' '.join(param_names)}\n")

        # initial results
        pred_loss_init = jnp.mean(pred_loss_value(theta,
                                                  constants,
                                                  x_forcing_nt_train,
                                                  x_forcing_nyrs_train,
                                                  x_maps_train,
                                                  ys_train))
        val_loss_init = val_loss(theta)
        reg_loss_init = reg_loss(theta, initial_params, params_lower, params_upper)
        print(f"Epoch 0 pred loss: {pred_loss_init:.4f}, reg_loss: {reg_loss_init:.4f}, val loss: {val_loss_init:.4f}")

        ###########################
        # Training loop
        ###########################
        invalid_theta_count = 0

        for epoch in range(n_epochs + 1):
            # Shuffle, then cut into consecutive mini-batches (none is ever empty;
            # an empty batch would make the nanmean of gradients NaN)
            shuffled_inds = np.random.permutation(N_train)
            minibatch_inds = [shuffled_inds[i:i + batch_size] for i in range(0, N_train, batch_size)]

            batch_loss = []
            for inds in minibatch_inds:
                # Calculate gradient of loss function, update parameters
                loss, grads = loss_value_and_grad(theta, reg_const, initial_params, params_lower, params_upper, constants,
                                                  x_forcing_nt_train[inds],
                                                  x_forcing_nyrs_train[inds],
                                                  x_maps_train[inds],
                                                  ys_train[inds])
                updates, opt_state = opt_fn(jnp.nanmean(grads, axis=0), opt_state)
                theta = optax.apply_updates(theta, updates)
                batch_loss.append(loss)

                # Re-initialise if theta leaves the bounds or becomes NaN/inf
                # (NaN compares False against both bounds, so test it explicitly)
                invalid = (theta < params_lower) | (theta > params_upper) | ~jnp.isfinite(theta)
                if bool(jnp.any(invalid)):
                    print('Found invalid parameter... re-initializaing')
                    invalid_theta_count += 1
                    if invalid_theta_count > 5:
                        return train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta
                    theta = np.random.uniform(low=params_lower, high=params_upper)
                    # Fresh start: drop optimizer state accumulated on the old trajectory
                    opt_state = init_opt_state(theta)

            # Save all losses. train loss averages per-cell losses over the epoch's
            # batches; reg loss is evaluated at the end-of-epoch theta.
            train_loss_out[epoch] = jnp.nanmean(jnp.concatenate(batch_loss))
            reg_loss_out[epoch] = reg_loss(theta, initial_params, params_lower, params_upper)
            pred_loss_out[epoch] = train_loss_out[epoch] - (reg_const * reg_loss_out[epoch])
            val_loss_out[epoch] = val_loss(theta)

            # Write every epoch
            theta_str = [str(param) for param in theta]
            f.write(f"{str(epoch + 1)} {error_fn_name} {train_loss_out[epoch]:.4f} {pred_loss_out[epoch]:.4f} {reg_loss_out[epoch]:.4f} {val_loss_out[epoch]:.4f} {' '.join(theta_str)}\n")
            # Print every 5
            if epoch % 5 == 0:
                print(f"Epoch {str(epoch + 1)} total loss: {train_loss_out[epoch]:.4f}, pred loss: {pred_loss_out[epoch]:.4f}, reg_loss: {reg_loss_out[epoch]:.4f}, val loss: {val_loss_out[epoch]:.4f}")

    return train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta

# Fitting

In [9]:
# Length of timeseries needed for quantile RMSE
N = 2555

# Define all error functions
# RMSE
_rmse = lambda prediction, ys: jnp.sqrt(jnp.nanmean((prediction-ys)**2))

# MSE
_mse = lambda prediction, ys: jnp.nanmean((prediction-ys)**2)

# KGE
def _kge(prediction, ys):
    corr = jnp.nanmean((prediction - jnp.nanmean(prediction))*(ys - jnp.nanmean(ys))) / (jnp.nanstd(prediction) * jnp.nanstd(ys))
    mean_ratio = jnp.nanmean(prediction) / jnp.nanmean(ys)
    std_ratio = jnp.nanstd(prediction) / jnp.nanstd(ys)
    kge = 1 - jnp.sqrt((corr - 1)**2 + (mean_ratio - 1)**2 + (std_ratio - 1)**2)
    return -kge 

# q0-25 RMSE
qmax = 0.25
size = round(N * qmax)
def _q25rmse(prediction, ys):
    thresh = jnp.quantile(ys, qmax)
    inds = jnp.where(ys <= thresh, size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q)**2))

# q75-100 RMSE
qmin = 0.75
size = round(N * qmin)
def _q75rmse(prediction, ys):
    thresh = jnp.quantile(ys, qmin)
    inds = jnp.where(ys >= thresh, size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q)**2))

_error_fns = [_rmse, _mse, _kge, _q25rmse, _q75rmse]
error_fn_names = ['rmse', 'mse', 'kge', 'q0-25rmse', 'q75-100rmse']

### NOAH (spatial cross-validation over error metrics)

In [10]:
# Data subset and observation product used in this section
subset_name, obs_name = 'centralUS', 'NOAH'

In [11]:
# Get dimensions: Nspace from axis 0 of ys, Ntime from axis 1
ys, _, _, _ = read_inputs(subset_name, obs_name, True)
Nspace = ys.shape[0]
Ntime = ys.shape[1]

# Validation folds: randomly permuted cell indices split into round(1 / val_frac)
# nearly equal groups. round() avoids int() truncation of the float quotient
# (e.g. val_frac=0.15 -> 6.67 would become 6 folds rather than 7).
val_frac = 0.2
n_folds = round(1 / val_frac)
val_inds_all = np.array_split(np.random.permutation(Nspace), n_folds)

KeyError: 'rootDepth is not a file in the archive'

In [24]:
%%time
# Parallelize with dask delayed
delayed = []

for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
    for _ in range(4):
        for val_inds in val_inds_all:
            delayed.append(dask.delayed(train_and_store)(subset_name = subset_name,
                                                         obs_name = obs_name,
                                                         _error_fn = _error_fn,
                                                         error_fn_name = error_fn_name,
                                                         val_inds = val_inds))

_ = dask.compute(*delayed)

CPU times: user 6min 41s, sys: 27.7 s, total: 7min 8s
Wall time: 51min 12s


### VIC

In [13]:
# Data subset and observation product used in this section
subset_name, obs_name = 'centralUS', 'VIC'

In [14]:
# Run: soil class scalar using log, random start, scaled (RMSE, lr 1e-3, 30 epochs)
_error_fn, error_fn_name = _rmse, 'rmse'

train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    learning_rate=0.001,
    n_epochs=30,
    val_frac=0.0,
)

Epoch 0 pred loss: 151.2948, reg_loss: 210.8976, val loss: nan
Epoch 1 total loss: 145.2396, pred loss: 144.4348, reg_loss: 80.4860, val loss: nan
Epoch 6 total loss: 69.4841, pred loss: 69.0368, reg_loss: 44.7317, val loss: nan
Epoch 11 total loss: 43.8797, pred loss: 43.6616, reg_loss: 21.8073, val loss: nan
Epoch 16 total loss: 33.9699, pred loss: 33.7160, reg_loss: 25.3886, val loss: nan
Epoch 21 total loss: 29.3166, pred loss: 29.0769, reg_loss: 23.9760, val loss: nan
Epoch 26 total loss: 26.9986, pred loss: 26.7692, reg_loss: 22.9425, val loss: nan
Epoch 31 total loss: 26.6741, pred loss: 26.4396, reg_loss: 23.4555, val loss: nan


In [12]:
# Run: soil class scalar using log, random start (RMSE, lr 1e-3, 30 epochs)
_error_fn, error_fn_name = _rmse, 'rmse'

train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    learning_rate=0.001,
    n_epochs=30,
    val_frac=0.0,
)

Epoch 0 pred loss: 193.0402, reg_loss: 275.4830, val loss: nan
Epoch 1 total loss: 188.6235, pred loss: 187.8289, reg_loss: 79.4585, val loss: nan
Epoch 6 total loss: 126.1107, pred loss: 125.7390, reg_loss: 37.1648, val loss: nan
Epoch 11 total loss: 72.8450, pred loss: 72.5604, reg_loss: 28.4649, val loss: nan
Epoch 16 total loss: 46.6197, pred loss: 46.3994, reg_loss: 22.0373, val loss: nan
Epoch 21 total loss: 37.7458, pred loss: 37.4790, reg_loss: 26.6730, val loss: nan
Epoch 26 total loss: 33.6993, pred loss: 33.4204, reg_loss: 27.8836, val loss: nan
Epoch 31 total loss: 32.5981, pred loss: 32.3313, reg_loss: 26.6761, val loss: nan


In [17]:
# Run: NOAH + regression on soil contents using log, random start (RMSE, lr 1e-3, 30 epochs)
_error_fn, error_fn_name = _rmse, 'rmse'

train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    learning_rate=0.001,
    n_epochs=30,
    val_frac=0.0,
)

Epoch 0 pred loss: 135.3471, reg_loss: 619.0754, val loss: nan
Epoch 1 total loss: 117.6131, pred loss: 116.4815, reg_loss: 113.1661, val loss: nan
Epoch 6 total loss: 51.9875, pred loss: 51.4211, reg_loss: 56.6422, val loss: nan
Epoch 11 total loss: 42.1469, pred loss: 41.7205, reg_loss: 42.6382, val loss: nan
Epoch 16 total loss: 38.3026, pred loss: 37.9357, reg_loss: 36.6916, val loss: nan
Epoch 21 total loss: 36.0634, pred loss: 35.7291, reg_loss: 33.4250, val loss: nan
Epoch 26 total loss: 34.8380, pred loss: 34.5312, reg_loss: 30.6831, val loss: nan
Epoch 31 total loss: 34.2670, pred loss: 33.9823, reg_loss: 28.4617, val loss: nan


In [20]:
# Run: NOAH, log, random start (RMSE, lr 1e-3, 30 epochs)
_error_fn, error_fn_name = _rmse, 'rmse'

train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    learning_rate=0.001,
    n_epochs=30,
    val_frac=0.0,
)

Epoch 0 pred loss: 123.8417, reg_loss: 195.7297, val loss: nan
Epoch 1 total loss: 112.4669, pred loss: 111.8620, reg_loss: 60.4960, val loss: nan
Epoch 6 total loss: 64.0041, pred loss: 63.6166, reg_loss: 38.7476, val loss: nan
Epoch 11 total loss: 48.7567, pred loss: 48.4685, reg_loss: 28.8285, val loss: nan
Epoch 16 total loss: 44.0785, pred loss: 43.8326, reg_loss: 24.5905, val loss: nan
Epoch 21 total loss: 41.5831, pred loss: 41.3319, reg_loss: 25.1146, val loss: nan
Epoch 26 total loss: 39.6683, pred loss: 39.4182, reg_loss: 25.0046, val loss: nan
Epoch 31 total loss: 38.1857, pred loss: 37.9466, reg_loss: 23.9035, val loss: nan


In [15]:
# Run: NOAH + regression on soil contents using log (RMSE, lr 1e-3, 30 epochs)
_error_fn, error_fn_name = _rmse, 'rmse'

train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    learning_rate=0.001,
    n_epochs=30,
    val_frac=0.0,
)

Epoch 0 pred loss: 139.6269, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 115.5750, pred loss: 115.4117, reg_loss: 16.3227, val loss: nan
Epoch 6 total loss: 39.4394, pred loss: 39.2534, reg_loss: 18.6035, val loss: nan
Epoch 11 total loss: 36.1339, pred loss: 35.9304, reg_loss: 20.3449, val loss: nan
Epoch 16 total loss: 34.4622, pred loss: 34.2524, reg_loss: 20.9795, val loss: nan
Epoch 21 total loss: 33.8736, pred loss: 33.6661, reg_loss: 20.7497, val loss: nan
Epoch 26 total loss: 33.6492, pred loss: 33.4455, reg_loss: 20.3705, val loss: nan
Epoch 31 total loss: 33.4985, pred loss: 33.2909, reg_loss: 20.7576, val loss: nan


In [12]:
# Run: regression on soil contents plus intercept using log, random start (RMSE, lr 1e-3, 30 epochs)
_error_fn, error_fn_name = _rmse, 'rmse'

train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    learning_rate=0.001,
    n_epochs=30,
    val_frac=0.0,
)

Epoch 0 pred loss: 398.7293, reg_loss: 149.0950, val loss: nan
Epoch 1 total loss: 357.0240, pred loss: 356.3212, reg_loss: 70.2793, val loss: nan
Epoch 6 total loss: 109.7697, pred loss: 109.0889, reg_loss: 68.0762, val loss: nan
Epoch 11 total loss: 55.9770, pred loss: 55.7513, reg_loss: 22.5671, val loss: nan
Epoch 16 total loss: 46.4427, pred loss: 46.3075, reg_loss: 13.5284, val loss: nan
Epoch 21 total loss: 41.7153, pred loss: 41.5609, reg_loss: 15.4350, val loss: nan
Epoch 26 total loss: 39.8752, pred loss: 39.7038, reg_loss: 17.1428, val loss: nan
Epoch 31 total loss: 39.2561, pred loss: 39.0825, reg_loss: 17.3580, val loss: nan


In [27]:
# Run: regression on soil contents plus intercept using log (RMSE, lr 1e-3, 30 epochs)
_error_fn, error_fn_name = _rmse, 'rmse'

train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    learning_rate=0.001,
    n_epochs=30,
    val_frac=0.0,
)

Epoch 0 pred loss: 56.6325, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 49.9408, pred loss: 49.8898, reg_loss: 5.1048, val loss: nan
Epoch 6 total loss: 40.4870, pred loss: 40.3236, reg_loss: 16.3395, val loss: nan
Epoch 11 total loss: 39.0435, pred loss: 38.8532, reg_loss: 19.0305, val loss: nan
Epoch 16 total loss: 38.5324, pred loss: 38.3486, reg_loss: 18.3766, val loss: nan
Epoch 21 total loss: 38.1424, pred loss: 37.9535, reg_loss: 18.8949, val loss: nan
Epoch 26 total loss: 37.8405, pred loss: 37.6562, reg_loss: 18.4250, val loss: nan
Epoch 31 total loss: 37.6220, pred loss: 37.4392, reg_loss: 18.2792, val loss: nan


In [14]:
# Run: regression on soil contents using log (RMSE, lr 1e-3, 30 epochs)
_error_fn, error_fn_name = _rmse, 'rmse'

train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    learning_rate=0.001,
    n_epochs=30,
    val_frac=0.0,
)

Epoch 0 pred loss: 56.2360, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 49.7308, pred loss: 49.6789, reg_loss: 5.1875, val loss: nan
Epoch 6 total loss: 40.0404, pred loss: 39.8642, reg_loss: 17.6196, val loss: nan
Epoch 11 total loss: 39.1196, pred loss: 38.9322, reg_loss: 18.7309, val loss: nan
Epoch 16 total loss: 38.4627, pred loss: 38.2776, reg_loss: 18.5110, val loss: nan
Epoch 21 total loss: 37.9999, pred loss: 37.8117, reg_loss: 18.8236, val loss: nan
Epoch 26 total loss: 37.7233, pred loss: 37.5467, reg_loss: 17.6600, val loss: nan
Epoch 31 total loss: 37.5629, pred loss: 37.3785, reg_loss: 18.4447, val loss: nan


In [15]:
# Run: regression on soil contents using non-log (RMSE, lr 1e-3, 30 epochs)
_error_fn, error_fn_name = _rmse, 'rmse'

train_loss_out, pred_loss_out, reg_loss_out, val_loss_out, theta = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    learning_rate=0.001,
    n_epochs=30,
    val_frac=0.0,
)

Epoch 0 pred loss: 56.2178, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 54.7293, pred loss: 54.6346, reg_loss: 9.4750, val loss: nan
Epoch 6 total loss: 46.0005, pred loss: 45.8825, reg_loss: 11.7977, val loss: nan
Epoch 11 total loss: 45.2821, pred loss: 45.1470, reg_loss: 13.5068, val loss: nan
Epoch 16 total loss: 44.7553, pred loss: 44.6166, reg_loss: 13.8699, val loss: nan
Epoch 21 total loss: 44.3233, pred loss: 44.1600, reg_loss: 16.3280, val loss: nan
Epoch 26 total loss: 43.9569, pred loss: 43.7966, reg_loss: 16.0265, val loss: nan
Epoch 31 total loss: 43.6540, pred loss: 43.4892, reg_loss: 16.4812, val loss: nan


In [None]:
# %%time
# # Parallelize with dask delayed
# delayed = []

# for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
#     # Hyperparameter adjustments
#     if error_fn_name == 'kge':
#         reg_const = 0.001
#     else:
#         reg_const = 0.01

#     if error_fn_name == 'mse':
#         learning_rate = 1e-3
#     else:
#         learning_rate = 1e-2
        
#     for batch_size in [2**5, 2**6, 2**7, 2**8, 2**9]:
#         delayed.append(dask.delayed(train_and_store)(subset_name = subset_name,
#                                                      obs_name = obs_name,
#                                                      _error_fn = _error_fn,
#                                                      error_fn_name = error_fn_name,
#                                                      batch_size = batch_size,
#                                                      reg_const = reg_const,
#                                                      learning_rate = learning_rate,
#                                                      n_epochs = 30,
#                                                      val_frac = 0.0))

# out = dask.compute(*delayed)

### MOSAIC (batch-size sweep over error metrics)

In [18]:
# Info
subset_name = 'centralUS'
obs_name = 'MOSAIC'

# needed for quantile RMSE
N = np.load(f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/{obs_name}_validation.npy').shape[-1]

In [19]:
%%time
# Parallelize with dask delayed
delayed = []

for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
    # Hyperparameter adjustments
    if error_fn_name == 'kge':
        reg_const = 0.001
    else:
        reg_const = 0.01

    if error_fn_name == 'mse':
        learning_rate = 1e-3
    else:
        learning_rate = 1e-2
        
    for batch_size in [2**5, 2**6, 2**7, 2**8, 2**9]:
        delayed.append(dask.delayed(train_and_store)(subset_name = subset_name,
                                                     obs_name = obs_name,
                                                     _error_fn = _error_fn,
                                                     error_fn_name = error_fn_name,
                                                     batch_size = batch_size,
                                                     reg_const = reg_const,
                                                     learning_rate = learning_rate,
                                                     n_epochs = 30,
                                                     val_frac = 0.0))

out = dask.compute(*delayed)

CPU times: user 4min 23s, sys: 20.8 s, total: 4min 43s
Wall time: 56min 43s


### NOAH (batch-size sweep over error metrics)

In [20]:
# Info
subset_name = 'centralUS'
obs_name = 'NOAH'

# needed for quantile RMSE
N = np.load(f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/{obs_name}_validation.npy').shape[-1]

In [None]:
%%time
# Parallelize with dask delayed
delayed = []

for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
    # Hyperparameter adjustments
    if error_fn_name == 'kge':
        reg_const = 0.001
    else:
        reg_const = 0.01

    if error_fn_name == 'mse':
        learning_rate = 1e-3
    else:
        learning_rate = 1e-2
        
    for batch_size in [2**5, 2**6, 2**7, 2**8, 2**9]:
        delayed.append(dask.delayed(train_and_store)(subset_name = subset_name,
                                                     obs_name = obs_name,
                                                     _error_fn = _error_fn,
                                                     error_fn_name = error_fn_name,
                                                     batch_size = batch_size,
                                                     reg_const = reg_const,
                                                     learning_rate = learning_rate,
                                                     n_epochs = 30,
                                                     val_frac = 0.0))

out = dask.compute(*delayed)