In [1]:
import os

import numpy as np
import jax.numpy as jnp

import dask

from train import train_and_store
from read_inputs import read_hindcast_inputs

from utils.initial_params import initial_params
from utils.param_bounds import params_lower, params_upper
from utils.param_names import param_names
from utils.global_paths import project_data_path

In [2]:
############
### Dask ###
############
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Keyword arguments forwarded unchanged to SLURMCluster
slurm_options = dict(
    # account="pches",
    account="open",
    cores=1,
    memory="15GiB",
    walltime="04:00:00",
)

cluster = SLURMCluster(**slurm_options)
cluster.scale(jobs=30)  # ask for jobs

# Attach a client to the cluster; the bare name on the last line displays its summary
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://146.186.150.12:41789,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Fitting

In [4]:
# Length of timeseries needed for quantile RMSE.
# The outer-quantile error functions below turn this into the fixed `size`
# argument of jnp.where (round(N * 0.5) and round(N * 0.2)).
# NOTE(review): presumably this must equal the number of time steps in the
# `ys` passed to those functions — confirm against the training data.
N = 2555

##################################
# Define all error functions
##################################

########## RMSE
def _rmse(prediction, ys):
    return jnp.sqrt(jnp.nanmean((prediction - ys) ** 2))

########## ubRMSE
def _ubrmse(prediction, ys):
    prediction_cent = prediction - jnp.mean(prediction)
    ys_cent = ys - jnp.mean(ys)
    return jnp.sqrt(jnp.nanmean((prediction_cent - ys_cent) ** 2))

########## MSE
def _mse(prediction, ys):
    return jnp.nanmean((prediction - ys) ** 2)

########## ubMSE
def _ubmse(prediction, ys):
    prediction_cent = prediction - jnp.mean(prediction)
    ys_cent = ys - jnp.mean(ys)
    return jnp.nanmean((prediction_cent - ys_cent) ** 2)

########## MAE
def _mae(prediction, ys):
    return jnp.nanmean(jnp.abs(prediction - ys))

########## ubMAE
def _ubmae(prediction, ys):
    prediction_cent = prediction - jnp.mean(prediction)
    ys_cent = ys - jnp.mean(ys)
    return jnp.nanmean(jnp.abs(prediction_cent - ys_cent))

########## KGE
def _kge(prediction, ys):
    corr = jnp.nanmean(
        (prediction - jnp.nanmean(prediction)) * (ys - jnp.nanmean(ys))
    ) / (jnp.nanstd(prediction) * jnp.nanstd(ys))
    mean_ratio = jnp.nanmean(prediction) / jnp.nanmean(ys)
    std_ratio = jnp.nanstd(prediction) / jnp.nanstd(ys)
    kge = 1 - jnp.sqrt(
        (corr - 1) ** 2 + (mean_ratio - 1) ** 2 + (std_ratio - 1) ** 2
    )
    return -kge 

############ Nash-Sutcliffe efficiency
def _nse(prediction, ys):
    nse = 1 - jnp.sum((ys - prediction) ** 2) / jnp.sum(
        (ys - jnp.mean(ys)) ** 2
    )
    return -nse

######### outer50RMSE
def _outer50rmse(prediction, ys):
    size = round(N * 0.5)
    q25 = jnp.quantile(ys, 0.25)
    q75 = jnp.quantile(ys, 0.75)
    inds = jnp.where((ys <= q25) | (ys >= q75), size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))


######### outer50ubRMSE
def _outer50ubrmse(prediction, ys):
    prediction_cent = prediction - jnp.mean(prediction)
    ys_cent = ys - jnp.mean(ys)
    size = round(N * 0.5)
    q25 = jnp.quantile(ys_cent, 0.25)
    q75 = jnp.quantile(ys_cent, 0.75)
    inds = jnp.where((ys_cent <= q25) | (ys_cent >= q75), size=size)
    prediction_q = prediction_cent[inds]
    ys_q = ys_cent[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))

######### outer20RMSE
def _outer20rmse(prediction, ys):
    size = round(N * 0.2)
    q10 = jnp.quantile(ys, 0.1)
    q90 = jnp.quantile(ys, 0.9)
    inds = jnp.where((ys <= q10) | (ys >= q90), size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))

######### outer20ubRMSE
def _outer20ubrmse(prediction, ys):
    prediction_cent = prediction - jnp.mean(prediction)
    ys_cent = ys - jnp.mean(ys)
    size = round(N * 0.2)
    q10 = jnp.quantile(ys_cent, 0.1)
    q90 = jnp.quantile(ys_cent, 0.9)
    inds = jnp.where((ys_cent <= q10) | (ys_cent >= q90), size=size)
    prediction_q = prediction_cent[inds]
    ys_q = ys_cent[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))


######### Weighted RMSE where upper 10% of points are weighted 9 times more
def _rmse_upper10weighted(prediction, ys):
    q90 = jnp.quantile(ys, 0.9)
    weights = jnp.where(ys >= q90, 9, 1)
    weighted_rmse = jnp.sqrt(
        jnp.average((prediction - ys) ** 2, weights=weights)
    )
    return weighted_rmse

######### Weighted ubRMSE where upper 10% of points are weighted 9 times more
def _ubrmse_upper10weighted(prediction, ys):
    prediction_cent = prediction - jnp.mean(prediction)
    ys_cent = ys - jnp.mean(ys)
    q90 = jnp.quantile(ys_cent, 0.9)
    weights = jnp.where(ys_cent >= q90, 9, 1)
    weighted_rmse = jnp.sqrt(
        jnp.average((prediction_cent - ys_cent) ** 2, weights=weights)
    )
    return weighted_rmse
    
######### Weighted RMSE where lower 10% of points are weighted 9 times more
def _rmse_lower10weighted(prediction, ys):
    q10 = jnp.quantile(ys, 0.1)
    weights = jnp.where(ys <= q10, 9, 1)
    weighted_rmse = jnp.sqrt(
        jnp.average((prediction - ys) ** 2, weights=weights)
    )
    return weighted_rmse

######### Weighted ubRMSE where lower 10% of points are weighted 9 times more
def _ubrmse_lower10weighted(prediction, ys):
    prediction_cent = prediction - jnp.mean(prediction)
    ys_cent = ys - jnp.mean(ys)
    q10 = jnp.quantile(ys_cent, 0.1)
    weights = jnp.where(ys_cent <= q10, 9, 1)
    weighted_rmse = jnp.sqrt(
        jnp.average((prediction_cent - ys_cent) ** 2, weights=weights)
    )
    return weighted_rmse


# Single (name, function) table for every loss defined above. The parallel
# name/function lists are derived from it so they cannot drift out of step.
_error_fn_table = [
    ("rmse", _rmse),
    ("ubrmse", _ubrmse),
    ("mse", _mse),
    ("ubmse", _ubmse),
    ("mae", _mae),
    ("ubmae", _ubmae),
    ("kge", _kge),
    ("nse", _nse),
    ("outer50rmse", _outer50rmse),
    ("outer50ubrmse", _outer50ubrmse),
    ("outer20rmse", _outer20rmse),
    ("outer20ubrmse", _outer20ubrmse),
    ("rmse-upper10", _rmse_upper10weighted),
    ("ubrmse-upper10", _ubrmse_upper10weighted),
    ("rmse-lower10", _rmse_lower10weighted),
    ("ubrmse-lower10", _ubrmse_lower10weighted),
]
_error_fns_all = [fn for _name, fn in _error_fn_table]
error_fn_names_all = [name for name, _fn in _error_fn_table]

# Subset of losses still to be run, selected by name from the table above
error_fn_names_todo = [
    "ubmse",
    "ubmae",
    "outer50ubrmse",
    "outer20ubrmse",
    "ubrmse-upper10",
    "ubrmse-lower10",
]
_error_fns_todo = [dict(_error_fn_table)[name] for name in error_fn_names_todo]


######################
# OLD 
######################

# ########## q0-25 RMSE
# qmax = 0.25
# size = round(N * qmax)

# def _q25rmse(prediction, ys):
#     thresh = jnp.quantile(ys, qmax)
#     inds = jnp.where(ys <= thresh, size=size)
#     prediction_q = prediction[inds]
#     ys_q = ys[inds]
#     return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))


# ########## q75-100 RMSE
# qmin = 0.75
# size = round(N * qmin)

# def _q75rmse(prediction, ys):
#     thresh = jnp.quantile(ys, qmin)
#     inds = jnp.where(ys >= thresh, size=size)
#     prediction_q = prediction[inds]
#     ys_q = ys[inds]
#     return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))

# ########## Anomaly correlation
# def _anomaly_corr(prediction, ys):
#     prediction_cent = prediction - jnp.mean(prediction)
#     ys_cent = ys - jnp.mean(ys)
#     anom_corr = jnp.nanmean(
#         (prediction_cent - jnp.nanmean(prediction_cent))
#         * (ys_cent - jnp.nanmean(ys_cent))
#     ) / (jnp.nanstd(prediction_cent) * jnp.nanstd(ys_cent))
#     return -anom_corr


# ############ Taylor skill score
# def _taylor_skill(prediction, ys):
#     corr = jnp.nanmean(
#         (prediction - jnp.nanmean(prediction)) * (ys - jnp.nanmean(ys))
#     ) / (jnp.nanstd(prediction) * jnp.nanstd(ys))
#     std_ratio = jnp.nanstd(prediction) / jnp.nanstd(ys)
#     R0 = 1.0
#     taylor = 4 * (1 + corr) / (((std_ratio + 1 / std_ratio) ** 2) * (1 + R0))
#     return -taylor

In [5]:
def make_or_read_split(subset_name, obs_name, val_frac=0.2):
    """Return the validation folds for a subset/obs pair, creating them once.

    The folds are cached in
    ``{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/split_{val_frac}.npz``.
    If that file exists it is read back; otherwise the spatial indices
    ``0 .. ys.shape[0] - 1`` are randomly permuted, split into roughly
    ``1 / val_frac`` near-equal folds, saved and returned.

    Parameters
    ----------
    subset_name, obs_name : str
        Used to build the cache path and forwarded to `read_hindcast_inputs`.
    val_frac : float, default 0.2
        Target fraction of points per validation fold.

    Returns
    -------
    list of numpy.ndarray
        One integer index array per fold.
    """
    train_test_file = f"{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/split_{str(val_frac)}.npz"

    if os.path.exists(train_test_file):
        # NpzFile keeps the archive open; the context manager closes it once
        # every fold has been materialised as an in-memory array.
        with np.load(train_test_file) as npz:
            val_inds_all = [npz[key] for key in npz.keys()]
    else:
        ys, _, _, _ = read_hindcast_inputs(subset_name, obs_name, True)
        Nspace = ys.shape[0]

        # Explicit integer fold count instead of handing a float to
        # np.array_split; rounding guards against values such as 4.9999.
        n_folds = round(1 / val_frac)

        # Get validation indices
        val_inds_all = np.array_split(np.random.permutation(Nspace), n_folds)

        out_dict = {str(i): fold for i, fold in enumerate(val_inds_all)}
        # np.savez does not create missing directories
        os.makedirs(os.path.dirname(train_test_file), exist_ok=True)
        np.savez(train_test_file, **out_dict)

    return val_inds_all

## eCONUS

In [6]:
# Name of the calibration subset; interpolated into the split-file path and
# passed to train_and_store in the cells below.
subset_name = "eCONUS"

In [7]:
%%time
# Build one dask-delayed training task per (obs product, loss, starting point)
delayed = []

# Number of random starting parameter sets, on top of the default start
n_random_starts = 4

# Hyperparameter adjustments to improve stability: regularisation constant per
# loss name; any loss not listed here uses 0.01
reg_const_by_loss = {"kge": 0.001, "nse": 0.001, "mse": 0.1}

for obs_name in ["SMAP", "VIC", "NOAH", "MOSAIC"]:
    # Validation splits are switched off here, so every task receives an
    # empty index list
    # val_inds_all = make_or_read_split(subset_name, obs_name)
    val_inds = []

    for _error_fn, error_fn_name in zip(_error_fns_todo, error_fn_names_todo):
        # Depends only on the loss, so look it up once per loss
        reg_const = reg_const_by_loss.get(error_fn_name, 0.01)

        # # Loop through validation splits
        # for val_inds in val_inds_all:
        for idr in range(n_random_starts + 1):
            # Start 0 uses the default parameters; the others are uniform
            # draws between the lower and upper parameter bounds
            initial_theta = (
                initial_params
                if idr == 0
                else np.random.uniform(params_lower, params_upper)
            )

            task = dask.delayed(train_and_store)(
                subset_name=subset_name,
                obs_name=obs_name,
                _error_fn=_error_fn,
                error_fn_name=error_fn_name,
                initial_theta=initial_theta,
                params_lower=params_lower,
                params_upper=params_upper,
                param_names=param_names,
                val_inds=val_inds,
                reg_const=reg_const,
            )
            delayed.append(task)

# Run every task on the cluster; the returned values are discarded
_ = dask.compute(*delayed)

CPU times: user 3min 48s, sys: 21.7 s, total: 4min 10s
Wall time: 2h 39min 16s


In [7]:
%%time
# Parallelize with dask delayed: one training task per
# (obs product, loss, starting point)
delayed = []

# Random starting parameters
n_random_starts = 4

# Loop through obs
for obs_name in ["SMAP", "VIC", "NOAH", "MOSAIC"]:
    # Get val inds (validation splits are switched off: empty index list)
    # val_inds_all = make_or_read_split(subset_name, obs_name)
    val_inds = []
    # Loop through loss functions.
    # This cell previously iterated over `_error_fns` / `error_fn_names`, which
    # are not defined anywhere in the notebook and raise NameError on a fresh
    # kernel; the full registry defined in the Fitting cell is used instead.
    for _error_fn, error_fn_name in zip(_error_fns_all, error_fn_names_all):
        # # Loop through validation splits
        # for val_inds in val_inds_all:
        # Generate starting params
        for idr in range(n_random_starts + 1):
            if idr == 0:
                # First start always uses the default parameter set
                initial_theta = initial_params
            else:
                initial_theta = np.random.uniform(params_lower, params_upper)

            # Hyperparameter adjustments to improve stability
            # ("ac" and "taylor" are only matched if those commented-out
            # losses are re-enabled)
            if error_fn_name in ["kge", "nse", "ac", "taylor"]:
                reg_const = 0.001
            elif error_fn_name == "mse":
                reg_const = 0.1
            else:
                reg_const = 0.01

            # Append delayed
            delayed.append(
                dask.delayed(train_and_store)(
                    subset_name=subset_name,
                    obs_name=obs_name,
                    _error_fn=_error_fn,
                    error_fn_name=error_fn_name,
                    initial_theta=initial_theta,
                    params_lower=params_lower,
                    params_upper=params_upper,
                    param_names=param_names,
                    val_inds=val_inds,
                    reg_const=reg_const,
                )
            )

# Compute
_ = dask.compute(*delayed)

CPU times: user 13min 13s, sys: 53.1 s, total: 14min 6s
Wall time: 3h 40min 48s


### Old

In [8]:
obs_name = "VIC"

# Read/perform train-test split.
# This cell used to repeat the body of make_or_read_split inline (leaving the
# loaded .npz handle open and computing an unused Ntime); calling the helper
# builds the same split-file path and yields the same `val_inds_all`.
val_frac = 0.2
val_inds_all = make_or_read_split(subset_name, obs_name, val_frac)

In [9]:
%%time
# Parallelize with dask delayed: one training task per
# (random start, loss, validation fold)
delayed = []

# Random starting parameters
n_random_starts = 5

for _ in range(n_random_starts):
    # Loop through loss functions.
    # `_error_fns` / `error_fn_names` are not defined anywhere in the notebook
    # (NameError on a fresh kernel), so the full registry is used instead.
    for _error_fn, error_fn_name in zip(_error_fns_all, error_fn_names_all):
        # Hyperparameter adjustments to improve stability
        if error_fn_name == "kge":
            reg_const = 0.001
        else:
            reg_const = 0.01
        if error_fn_name == "mse":
            learning_rate = 1e-3
        else:
            learning_rate = 1e-2

        # Loop through validation splits
        for val_inds in val_inds_all:
            # Generate starting params. Bound to `initial_theta` rather than
            # `initial_params` so the default parameter set imported at the
            # top of the notebook is not overwritten for the other cells.
            initial_theta = np.random.uniform(params_lower, params_upper)

            # Append delayed
            delayed.append(
                dask.delayed(train_and_store)(
                    subset_name=subset_name,
                    obs_name=obs_name,
                    _error_fn=_error_fn,
                    error_fn_name=error_fn_name,
                    initial_theta=initial_theta,
                    params_lower=params_lower,
                    params_upper=params_upper,
                    param_names=param_names,
                    val_inds=val_inds,
                    learning_rate=learning_rate,
                    reg_const=reg_const,
                )
            )

# Compute
_ = dask.compute(*delayed)

CPU times: user 2min 15s, sys: 13.9 s, total: 2min 29s
Wall time: 1h 44min 58s
