In [1]:
import os

import numpy as np
import jax.numpy as jnp

import dask

from src.train import train_and_store
from src.read_inputs import read_hindcast_inputs

from utils.initial_params import initial_params
from utils.param_bounds import params_lower, params_upper
from utils.param_names import param_names
from utils.global_paths import project_data_path

In [22]:
# Load the SMAP validation array for the eCONUS subset from the project data directory
smap_validation_path = f"{project_data_path}/WBM/calibration/eCONUS/SMAP/SMAP_validation.npy"
obs = np.load(smap_validation_path)

In [23]:
# Show the dimensions of the loaded array
obs.shape

(317, 194, 2555)

In [41]:
# Collapse the two leading axes into one and keep the last axis as-is.
# The sizes are taken from the array itself rather than hard-coded, so the
# cell keeps working if the input grid or record length changes.
# (presumably the layout is (lat, lon, time) — TODO confirm)
ys = obs.reshape(-1, obs.shape[-1])

# Flag every row that contains at least one NaN
nan_inds_obs = np.isnan(ys).any(axis=1)

In [42]:
# Display the per-row NaN mask
nan_inds_obs

array([ True,  True,  True, ...,  True,  True,  True])

In [46]:
# Take an independent copy of the mask, then clear the flag on its first entry
nan_inds_obs_new = nan_inds_obs.copy()
nan_inds_obs_new[0] = False

In [47]:
# Display the original mask (the modification above was made on a copy)
nan_inds_obs

array([ True,  True,  True, ...,  True,  True,  True])

In [48]:
# Display the modified copy of the mask
nan_inds_obs_new

array([False,  True,  True, ...,  True,  True,  True])

In [49]:
# Element-wise OR of the two boolean masks
nan_inds_obs_new | nan_inds_obs

array([ True,  True,  True, ...,  True,  True,  True])

In [2]:
############
### Dask ###
############
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Resources requested for each SLURM job.
# NOTE(review): the training cells further down report wall times above
# 1 h 40 min, which is longer than this walltime — confirm that workers are
# not being terminated part-way through a run.
slurm_job_options = dict(
    # account="pches",
    account="open",
    cores=1,
    memory="15GiB",
    walltime="01:00:00",
)

cluster = SLURMCluster(**slurm_job_options)
cluster.scale(jobs=25)  # ask for jobs

# Connect a client to the cluster and display its summary
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.158:46375,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Fitting

In [6]:
# Length of timeseries needed for quantile RMSE
# (2555 matches the size of the last axis of the SMAP array inspected above)
N = 2555

##################################
# Define all error functions
##################################

########## RMSE
_rmse = lambda prediction, ys: jnp.sqrt(jnp.nanmean((prediction - ys) ** 2))

########## MSE
_mse = lambda prediction, ys: jnp.nanmean((prediction - ys) ** 2)

########## KGE
def _kge(prediction, ys):
    corr = jnp.nanmean(
        (prediction - jnp.nanmean(prediction)) * (ys - jnp.nanmean(ys))
    ) / (jnp.nanstd(prediction) * jnp.nanstd(ys))
    mean_ratio = jnp.nanmean(prediction) / jnp.nanmean(ys)
    std_ratio = jnp.nanstd(prediction) / jnp.nanstd(ys)
    kge = 1 - jnp.sqrt((corr - 1) ** 2 + (mean_ratio - 1) ** 2 + (std_ratio - 1) ** 2)
    return -kge


######### hollowRMSE
# Half of the time-series length N, rounded to an integer
size = round(N * 0.5)

def _hollowrmse(prediction, ys):
    q25 = jnp.quantile(ys, 0.25)
    q75 = jnp.quantile(ys, 0.75)
    inds = jnp.where((ys <= q25) | (ys >= q75), size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))

    
# _error_fns = [_rmse, _mse, _kge, _hollowrmse]
# error_fn_names = ["rmse", "mse", "kge", "hollow-rmse"]

# Active loss functions and their labels. The training cells zip these two
# lists together and branch on the label, so they must stay in the same order.
_error_fns = [_mse]
error_fn_names = ["mse"]

# ########## q0-25 RMSE
# qmax = 0.25
# size = round(N * qmax)

# def _q25rmse(prediction, ys):
#     thresh = jnp.quantile(ys, qmax)
#     inds = jnp.where(ys <= thresh, size=size)
#     prediction_q = prediction[inds]
#     ys_q = ys[inds]
#     return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))


# ########## q75-100 RMSE
# qmin = 0.75
# size = round(N * qmin)

# def _q75rmse(prediction, ys):
#     thresh = jnp.quantile(ys, qmin)
#     inds = jnp.where(ys >= thresh, size=size)
#     prediction_q = prediction[inds]
#     ys_q = ys[inds]
#     return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))

## eCONUS

In [4]:
# Subset label: used in the calibration data paths and passed to
# read_hindcast_inputs / train_and_store in the cells below
subset_name = "eCONUS"

### SMAP

In [7]:
obs_name = "SMAP"

# Read/perform train-test split
val_frac = 0.2
train_test_file = f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/split_{val_frac}.npz'

if os.path.exists(train_test_file):
    # Reuse the cached split; the context manager closes the archive once
    # every fold has been read into memory
    with np.load(train_test_file) as npz:
        val_inds_all = [npz[key] for key in npz.keys()]
else:
    ys, _, _, _ = read_hindcast_inputs(subset_name, obs_name, True)
    Nspace = ys.shape[0]

    # Number of folds, rounded so floating-point error in 1 / val_frac
    # cannot silently drop a fold
    n_splits = int(round(1 / val_frac))

    # Get validation indices: shuffle the row indices and cut them into folds.
    # NOTE(review): no random seed is set in this cell, so a regenerated
    # split differs from run to run; only the cached file makes it repeatable.
    val_inds_all = np.array_split(np.random.permutation(Nspace), n_splits)

    # Store one array per fold, keyed by its position
    out_dict = {str(i): inds for i, inds in enumerate(val_inds_all)}
    os.makedirs(os.path.dirname(train_test_file), exist_ok=True)
    np.savez(train_test_file, **out_dict)

In [8]:
%%time
# Parallelize with dask delayed: one task per
# (random start, loss function, validation split) combination
delayed = []

# Number of random starting parameter vectors
n_random_starts = 5

# Hyperparameter adjustments to improve stability: regularisation constant
# per loss function; losses not listed here use the default
reg_const_by_loss = {"kge": 0.001, "mse": 0.1}
default_reg_const = 0.01

for _ in range(n_random_starts):
    # Loop through loss functions
    for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
        reg_const = reg_const_by_loss.get(error_fn_name, default_reg_const)

        # Loop through validation splits
        for val_inds in val_inds_all:
            # Generate starting params uniformly within the bounds. Stored
            # as `initial_theta` so the `initial_params` imported from
            # utils.initial_params is not overwritten.
            initial_theta = np.random.uniform(params_lower, params_upper)

            # Append delayed
            delayed.append(
                dask.delayed(train_and_store)(
                    subset_name = subset_name,
                    obs_name = obs_name,
                    _error_fn = _error_fn,
                    error_fn_name = error_fn_name,
                    initial_theta = initial_theta,
                    params_lower = params_lower,
                    params_upper = params_upper,
                    param_names = param_names,
                    val_inds = val_inds,
                    reg_const = reg_const,
                )
            )

# Compute
_ = dask.compute(*delayed)

CPU times: user 2min 56s, sys: 13.5 s, total: 3min 10s
Wall time: 36min 32s


### VIC

In [8]:
obs_name = "VIC"

# Read/perform train-test split
val_frac = 0.2
train_test_file = f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/split_{val_frac}.npz'

if os.path.exists(train_test_file):
    # Reuse the cached split; the context manager closes the archive once
    # every fold has been read into memory
    with np.load(train_test_file) as npz:
        val_inds_all = [npz[key] for key in npz.keys()]
else:
    ys, _, _, _ = read_hindcast_inputs(subset_name, obs_name, True)
    Nspace = ys.shape[0]

    # Number of folds, rounded so floating-point error in 1 / val_frac
    # cannot silently drop a fold
    n_splits = int(round(1 / val_frac))

    # Get validation indices: shuffle the row indices and cut them into folds.
    # NOTE(review): no random seed is set in this cell, so a regenerated
    # split differs from run to run; only the cached file makes it repeatable.
    val_inds_all = np.array_split(np.random.permutation(Nspace), n_splits)

    # Store one array per fold, keyed by its position
    out_dict = {str(i): inds for i, inds in enumerate(val_inds_all)}
    os.makedirs(os.path.dirname(train_test_file), exist_ok=True)
    np.savez(train_test_file, **out_dict)

In [9]:
%%time
# Parallelize with dask delayed: one task per
# (random start, loss function, validation split) combination
delayed = []

# Number of random starting parameter vectors
n_random_starts = 5

for _ in range(n_random_starts):
    # Loop through loss functions
    for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
        # Hyperparameter adjustments to improve stability
        reg_const = 0.001 if error_fn_name == "kge" else 0.01
        learning_rate = 1e-3 if error_fn_name == 'mse' else 1e-2

        # Loop through validation splits
        for val_inds in val_inds_all:
            # Generate starting params uniformly within the bounds. Stored
            # as `initial_theta` so the `initial_params` imported from
            # utils.initial_params is not overwritten.
            initial_theta = np.random.uniform(params_lower, params_upper)

            # Append delayed
            delayed.append(
                dask.delayed(train_and_store)(
                    subset_name = subset_name,
                    obs_name = obs_name,
                    _error_fn = _error_fn,
                    error_fn_name = error_fn_name,
                    initial_theta = initial_theta,
                    params_lower = params_lower,
                    params_upper = params_upper,
                    param_names = param_names,
                    val_inds = val_inds,
                    learning_rate = learning_rate,
                    reg_const = reg_const,
                )
            )

# Compute
_ = dask.compute(*delayed)

CPU times: user 2min 15s, sys: 13.9 s, total: 2min 29s
Wall time: 1h 44min 58s


### NOAH

In [10]:
obs_name = "NOAH"

# Read/perform train-test split
val_frac = 0.2
train_test_file = f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/split_{val_frac}.npz'

if os.path.exists(train_test_file):
    # Reuse the cached split; the context manager closes the archive once
    # every fold has been read into memory
    with np.load(train_test_file) as npz:
        val_inds_all = [npz[key] for key in npz.keys()]
else:
    ys, _, _, _ = read_hindcast_inputs(subset_name, obs_name, True)
    Nspace = ys.shape[0]

    # Number of folds, rounded so floating-point error in 1 / val_frac
    # cannot silently drop a fold
    n_splits = int(round(1 / val_frac))

    # Get validation indices: shuffle the row indices and cut them into folds.
    # NOTE(review): no random seed is set in this cell, so a regenerated
    # split differs from run to run; only the cached file makes it repeatable.
    val_inds_all = np.array_split(np.random.permutation(Nspace), n_splits)

    # Store one array per fold, keyed by its position
    out_dict = {str(i): inds for i, inds in enumerate(val_inds_all)}
    os.makedirs(os.path.dirname(train_test_file), exist_ok=True)
    np.savez(train_test_file, **out_dict)

In [11]:
%%time
# Parallelize with dask delayed: one task per
# (random start, loss function, validation split) combination
delayed = []

# Number of random starting parameter vectors
n_random_starts = 5

for _ in range(n_random_starts):
    # Loop through loss functions
    for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
        # Hyperparameter adjustments to improve stability
        reg_const = 0.001 if error_fn_name == "kge" else 0.01
        learning_rate = 1e-3 if error_fn_name == 'mse' else 1e-2

        # Loop through validation splits
        for val_inds in val_inds_all:
            # Generate starting params uniformly within the bounds. Stored
            # as `initial_theta` so the `initial_params` imported from
            # utils.initial_params is not overwritten.
            initial_theta = np.random.uniform(params_lower, params_upper)

            # Append delayed
            delayed.append(
                dask.delayed(train_and_store)(
                    subset_name = subset_name,
                    obs_name = obs_name,
                    _error_fn = _error_fn,
                    error_fn_name = error_fn_name,
                    initial_theta = initial_theta,
                    params_lower = params_lower,
                    params_upper = params_upper,
                    param_names = param_names,
                    val_inds = val_inds,
                    learning_rate = learning_rate,
                    reg_const = reg_const,
                )
            )

# Compute
_ = dask.compute(*delayed)

CPU times: user 2min 9s, sys: 13.2 s, total: 2min 22s
Wall time: 1h 41min 34s


### MOSAIC

In [12]:
obs_name = "MOSAIC"

# Read/perform train-test split
val_frac = 0.2
train_test_file = f'{project_data_path}/WBM/calibration/{subset_name}/{obs_name}/training_res/split_{val_frac}.npz'

if os.path.exists(train_test_file):
    # Reuse the cached split; the context manager closes the archive once
    # every fold has been read into memory
    with np.load(train_test_file) as npz:
        val_inds_all = [npz[key] for key in npz.keys()]
else:
    ys, _, _, _ = read_hindcast_inputs(subset_name, obs_name, True)
    Nspace = ys.shape[0]

    # Number of folds, rounded so floating-point error in 1 / val_frac
    # cannot silently drop a fold
    n_splits = int(round(1 / val_frac))

    # Get validation indices: shuffle the row indices and cut them into folds.
    # NOTE(review): no random seed is set in this cell, so a regenerated
    # split differs from run to run; only the cached file makes it repeatable.
    val_inds_all = np.array_split(np.random.permutation(Nspace), n_splits)

    # Store one array per fold, keyed by its position
    out_dict = {str(i): inds for i, inds in enumerate(val_inds_all)}
    os.makedirs(os.path.dirname(train_test_file), exist_ok=True)
    np.savez(train_test_file, **out_dict)

In [13]:
%%time
# Parallelize with dask delayed: one task per
# (random start, loss function, validation split) combination
delayed = []

# Number of random starting parameter vectors
n_random_starts = 5

for _ in range(n_random_starts):
    # Loop through loss functions
    for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
        # Hyperparameter adjustments to improve stability
        reg_const = 0.001 if error_fn_name == "kge" else 0.01
        learning_rate = 1e-3 if error_fn_name == 'mse' else 1e-2

        # Loop through validation splits
        for val_inds in val_inds_all:
            # Generate starting params uniformly within the bounds. Stored
            # as `initial_theta` so the `initial_params` imported from
            # utils.initial_params is not overwritten.
            initial_theta = np.random.uniform(params_lower, params_upper)

            # Append delayed
            delayed.append(
                dask.delayed(train_and_store)(
                    subset_name = subset_name,
                    obs_name = obs_name,
                    _error_fn = _error_fn,
                    error_fn_name = error_fn_name,
                    initial_theta = initial_theta,
                    params_lower = params_lower,
                    params_upper = params_upper,
                    param_names = param_names,
                    val_inds = val_inds,
                    learning_rate = learning_rate,
                    reg_const = reg_const,
                )
            )

# Compute
_ = dask.compute(*delayed)

CPU times: user 1min 57s, sys: 13.2 s, total: 2min 10s
Wall time: 1h 40min 38s
