In [1]:
# from jax import config
# config.update("jax_enable_x64", True)

In [1]:
# from jax import config
# config.update("jax_debug_nans", False)

In [1]:
import numpy as np
import jax.numpy as jnp
import dask
from src.train import train_and_store
from src.read_inputs import read_inputs
from utils.initial_params import initial_params
from utils.param_bounds import params_lower, params_upper
from utils.param_names import param_names

In [2]:
# ############
# ### Dask ###
# ############
# from dask_jobqueue import SLURMCluster

# cluster = SLURMCluster(
#     # account="pches",
#     account="open",
#     cores=1,
#     memory="5GiB",
#     walltime="01:00:00"
# )
# cluster.scale(jobs=30)  # ask for jobs

# from dask.distributed import Client
# client = Client(cluster)
# client

# Fitting

In [2]:
# Length of timeseries needed for quantile RMSE
N = 2555

##################################
# Define all error functions
##################################

########## RMSE
_rmse = lambda prediction, ys: jnp.sqrt(jnp.nanmean((prediction - ys) ** 2))

########## MSE
_mse = lambda prediction, ys: jnp.nanmean((prediction - ys) ** 2)


########## KGE
def _kge(prediction, ys):
    corr = jnp.nanmean(
        (prediction - jnp.nanmean(prediction)) * (ys - jnp.nanmean(ys))
    ) / (jnp.nanstd(prediction) * jnp.nanstd(ys))
    mean_ratio = jnp.nanmean(prediction) / jnp.nanmean(ys)
    std_ratio = jnp.nanstd(prediction) / jnp.nanstd(ys)
    kge = 1 - jnp.sqrt((corr - 1) ** 2 + (mean_ratio - 1) ** 2 + (std_ratio - 1) ** 2)
    return -kge


########## q0-25 RMSE
qmax = 0.25
size = round(N * qmax)

def _q25rmse(prediction, ys):
    thresh = jnp.quantile(ys, qmax)
    inds = jnp.where(ys <= thresh, size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))


########## q75-100 RMSE
qmin = 0.75
size = round(N * qmin)

def _q75rmse(prediction, ys):
    thresh = jnp.quantile(ys, qmin)
    inds = jnp.where(ys >= thresh, size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))


# Registry of the error functions defined in this cell, exposed as two
# parallel lists (labels and callables) in matching order.
error_fn_names, _error_fns = map(
    list,
    zip(
        ("rmse", _rmse),
        ("mse", _mse),
        ("kge", _kge),
        ("q0-25rmse", _q25rmse),
        ("q75-100rmse", _q75rmse),
    ),
)

## CONUS

In [3]:
# Name of the data subset; forwarded to train_and_store as subset_name by the
# training cells below.
subset_name = "CONUS"

### SMAP

In [4]:
# Name of the observation dataset; forwarded to train_and_store as obs_name by
# the training cell that follows.
obs_name = "SMAP"

# NOTE(review): disabled scaffolding for reading the data dimensions and
# building validation folds (val_inds_all). The dask cell at the end of the
# notebook consumes a variable with that name.
# # Get dimensions
# ys, _, _, _ = read_inputs(subset_name, obs_name, True)
# Nspace = ys.shape[0]
# Ntime = ys.shape[1]

# # Get validation indices
# val_frac = 0.2
# val_inds_all = np.array_split(np.random.permutation(Nspace), 1/val_frac)

In [None]:
%%time

_error_fn = _rmse
error_fn_name = "rmse"

(
    train_loss_out,
    pred_loss_out,
    reg_loss_out,
    val_loss_out,
    theta,
) = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    initial_theta=initial_params,
    params_lower=params_lower,
    params_upper=params_upper,
    param_names=param_names,
    n_epochs_max = 10
)

In [11]:
%%time

_error_fn = _rmse
error_fn_name = "rmse"

(
    train_loss_out,
    pred_loss_out,
    reg_loss_out,
    val_loss_out,
    theta,
) = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    initial_theta=initial_params,
    params_lower=params_lower,
    params_upper=params_upper,
    param_names=param_names,
    n_epochs_max = 10
)

Epoch 0 train loss: 71.3068 pred loss: 71.3068, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 37.0238, pred loss: 36.9882, reg_loss: 35.5890, val loss: nan
Epoch 2 total loss: 30.4496, pred loss: 30.4099, reg_loss: 39.7576, val loss: nan
Epoch 3 total loss: 28.5626, pred loss: 28.5236, reg_loss: 39.0326, val loss: nan
Epoch 4 total loss: 26.8609, pred loss: 26.8195, reg_loss: 41.3449, val loss: nan
Epoch 5 total loss: 25.7375, pred loss: 25.6930, reg_loss: 44.5299, val loss: nan
Epoch 6 total loss: 25.2514, pred loss: 25.1778, reg_loss: 73.6590, val loss: nan
Epoch 7 total loss: 25.1462, pred loss: 25.0709, reg_loss: 75.3208, val loss: nan
Epoch 8 total loss: 25.1283, pred loss: 25.0542, reg_loss: 74.0245, val loss: nan
Epoch 9 total loss: 25.1189, pred loss: 25.0447, reg_loss: 74.1748, val loss: nan
Epoch 10 total loss: 25.1099, pred loss: 25.0406, reg_loss: 69.3078, val loss: nan
CPU times: user 18min 48s, sys: 6.25 s, total: 18min 54s
Wall time: 18min 59s


### VIC

In [4]:
# NOTE(review): re-assigns the same value already set at the top of the CONUS
# section; kept so this section can be run on its own.
subset_name = "CONUS"

In [5]:
# Name of the observation dataset; forwarded to train_and_store as obs_name by
# the training cell that follows.
obs_name = "VIC"

# NOTE(review): disabled scaffolding for reading the data dimensions and
# building validation folds (val_inds_all). The dask cell at the end of the
# notebook consumes a variable with that name.
# # Get dimensions
# ys, _, _, _ = read_inputs(subset_name, obs_name, True)
# Nspace = ys.shape[0]
# Ntime = ys.shape[1]

# # Get validation indices
# val_frac = 0.2
# val_inds_all = np.array_split(np.random.permutation(Nspace), 1/val_frac)

In [6]:
%%time

_error_fn = _rmse
error_fn_name = "rmse"

(
    train_loss_out,
    pred_loss_out,
    reg_loss_out,
    val_loss_out,
    theta,
) = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    initial_theta=initial_params,
    params_lower=params_lower,
    params_upper=params_upper,
    param_names=param_names,
    n_epochs_max = 10
)

Epoch 0 train loss: 43.2285 pred loss: 43.2285, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 31.2653, pred loss: 31.2455, reg_loss: 19.7406, val loss: nan
Epoch 2 total loss: 28.5166, pred loss: 28.4924, reg_loss: 24.2528, val loss: nan
Epoch 3 total loss: 27.4719, pred loss: 27.4466, reg_loss: 25.2189, val loss: nan
Epoch 4 total loss: 26.7883, pred loss: 26.7634, reg_loss: 24.9428, val loss: nan
Epoch 5 total loss: 26.4324, pred loss: 26.4068, reg_loss: 25.5514, val loss: nan
Epoch 6 total loss: 26.3011, pred loss: 26.2723, reg_loss: 28.7572, val loss: nan
Epoch 7 total loss: 26.2459, pred loss: 26.2111, reg_loss: 34.8074, val loss: nan
Epoch 8 total loss: 26.2234, pred loss: 26.1878, reg_loss: 35.5653, val loss: nan
Epoch 9 total loss: 26.1747, pred loss: 26.1417, reg_loss: 33.0146, val loss: nan
Epoch 10 total loss: 26.1460, pred loss: 26.1147, reg_loss: 31.2960, val loss: nan
CPU times: user 20min 11s, sys: 6.57 s, total: 20min 17s
Wall time: 20min 27s


### NOAH

In [12]:
# Name of the observation dataset; forwarded to train_and_store as obs_name by
# the training cell that follows.
obs_name = "NOAH"

# NOTE(review): disabled scaffolding for reading the data dimensions and
# building validation folds (val_inds_all). The dask cell at the end of the
# notebook consumes a variable with that name.
# # Get dimensions
# ys, _, _, _ = read_inputs(subset_name, obs_name, True)
# Nspace = ys.shape[0]
# Ntime = ys.shape[1]

# # Get validation indices
# val_frac = 0.2
# val_inds_all = np.array_split(np.random.permutation(Nspace), 1/val_frac)

In [13]:
%%time

_error_fn = _rmse
error_fn_name = "rmse"

(
    train_loss_out,
    pred_loss_out,
    reg_loss_out,
    val_loss_out,
    theta,
) = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    initial_theta=initial_params,
    params_lower=params_lower,
    params_upper=params_upper,
    param_names=param_names,
    n_epochs_max = 10
)

Epoch 0 train loss: 57.7611 pred loss: 57.7611, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 42.8357, pred loss: 42.8026, reg_loss: 33.0917, val loss: nan
Epoch 2 total loss: 38.4613, pred loss: 38.4211, reg_loss: 40.2449, val loss: nan
Epoch 3 total loss: 37.5037, pred loss: 37.4652, reg_loss: 38.4609, val loss: nan
Epoch 4 total loss: 37.0446, pred loss: 37.0048, reg_loss: 39.8278, val loss: nan
Epoch 5 total loss: 36.8838, pred loss: 36.8361, reg_loss: 47.6472, val loss: nan
Epoch 6 total loss: 36.8323, pred loss: 36.7863, reg_loss: 46.0656, val loss: nan
Epoch 7 total loss: 36.8175, pred loss: 36.7704, reg_loss: 47.1084, val loss: nan
Epoch 8 total loss: 36.8152, pred loss: 36.7659, reg_loss: 49.3062, val loss: nan
Epoch 9 total loss: 36.8129, pred loss: 36.7595, reg_loss: 53.3739, val loss: nan
Epoch 10 total loss: 36.8129, pred loss: 36.7631, reg_loss: 49.8330, val loss: nan
CPU times: user 19min 29s, sys: 7.89 s, total: 19min 37s
Wall time: 19min 53s


### MOSAIC

In [14]:
# Name of the observation dataset; forwarded to train_and_store as obs_name by
# the training cell that follows.
obs_name = "MOSAIC"

# NOTE(review): disabled scaffolding for reading the data dimensions and
# building validation folds (val_inds_all). The dask cell at the end of the
# notebook consumes a variable with that name.
# # Get dimensions
# ys, _, _, _ = read_inputs(subset_name, obs_name, True)
# Nspace = ys.shape[0]
# Ntime = ys.shape[1]

# # Get validation indices
# val_frac = 0.2
# val_inds_all = np.array_split(np.random.permutation(Nspace), 1/val_frac)

In [15]:
%%time

_error_fn = _rmse
error_fn_name = "rmse"

(
    train_loss_out,
    pred_loss_out,
    reg_loss_out,
    val_loss_out,
    theta,
) = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    initial_theta=initial_params,
    params_lower=params_lower,
    params_upper=params_upper,
    param_names=param_names,
    n_epochs_max = 10
)

Epoch 0 train loss: 43.0931 pred loss: 43.0931, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 39.6757, pred loss: 39.6027, reg_loss: 73.0110, val loss: nan
Epoch 2 total loss: 38.8521, pred loss: 38.7778, reg_loss: 74.3228, val loss: nan
Epoch 3 total loss: 38.0258, pred loss: 37.9529, reg_loss: 72.9220, val loss: nan
Epoch 4 total loss: 36.9927, pred loss: 36.9277, reg_loss: 65.0387, val loss: nan
Epoch 5 total loss: 35.8053, pred loss: 35.7488, reg_loss: 56.5669, val loss: nan
Epoch 6 total loss: 34.8588, pred loss: 34.7524, reg_loss: 106.4587, val loss: nan
Epoch 7 total loss: 34.4979, pred loss: 34.3991, reg_loss: 98.8139, val loss: nan
Epoch 8 total loss: 34.4250, pred loss: 34.3239, reg_loss: 101.1077, val loss: nan
Epoch 9 total loss: 34.3971, pred loss: 34.3001, reg_loss: 96.9636, val loss: nan
Epoch 10 total loss: 34.3896, pred loss: 34.2926, reg_loss: 97.0100, val loss: nan
CPU times: user 19min 37s, sys: 11.1 s, total: 19min 48s
Wall time: 24min 57s


In [24]:
%%time
# Parallelize with dask delayed
delayed = []

for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
    for _ in range(10):
        for val_inds in val_inds_all:
            # Hyperparameter adjustments to improve stability
            if error_fn_name == "kge":
                reg_const = 0.001
            else:
                reg_const = 0.01

            if error_fn_name == "mse":
                learning_rate = 1e-4
            else:
                learning_rate = 1e-3

            # Append delayed
            delayed.append(
                dask.delayed(train_and_store)(
                    subset_name = subset_name,
                    obs_name = obs_name,
                    _error_fn = _error_fn,
                    error_fn_name = error_fn_name,
                    initial_theta = initial_params,
                    params_lower = params_lower,
                    params_upper = params_upper,
                    param_names = param_names
                    val_inds = val_inds,
                    reg_const = reg_const,
                    learning_rate = learning_rate
                )
            )

# Compute
_ = dask.compute(*delayed)

CPU times: user 6min 41s, sys: 27.7 s, total: 7min 8s
Wall time: 51min 12s
