In [1]:
import numpy as np
import jax.numpy as jnp
import dask
from src.train import train_and_store
from src.read_inputs import read_inputs
from utils.initial_params import initial_params
from utils.param_bounds import params_lower, params_upper
from utils.param_names import param_names

In [2]:
############
### Dask ###
############
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Each SLURM job requests 1 core, 20 GiB of memory and a 3-hour walltime.
# Use account="pches" instead to charge that allocation.
cluster = SLURMCluster(
    account="open",
    cores=1,
    memory="20GiB",
    walltime="03:00:00",
)
# Ask the scheduler for 30 jobs
cluster.scale(jobs=30)

# Attach a client; leaving it as the last expression displays its summary
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.6.0.161:43247,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Fitting

In [3]:
# Length of timeseries needed for quantile RMSE
N = 2555

##################################
# Define all error functions
##################################

########## RMSE
_rmse = lambda prediction, ys: jnp.sqrt(jnp.nanmean((prediction - ys) ** 2))

########## MSE
_mse = lambda prediction, ys: jnp.nanmean((prediction - ys) ** 2)


########## KGE
def _kge(prediction, ys):
    corr = jnp.nanmean(
        (prediction - jnp.nanmean(prediction)) * (ys - jnp.nanmean(ys))
    ) / (jnp.nanstd(prediction) * jnp.nanstd(ys))
    mean_ratio = jnp.nanmean(prediction) / jnp.nanmean(ys)
    std_ratio = jnp.nanstd(prediction) / jnp.nanstd(ys)
    kge = 1 - jnp.sqrt((corr - 1) ** 2 + (mean_ratio - 1) ** 2 + (std_ratio - 1) ** 2)
    return -kge


########## q0-25 RMSE
qmax = 0.25
size = round(N * qmax)

def _q25rmse(prediction, ys):
    thresh = jnp.quantile(ys, qmax)
    inds = jnp.where(ys <= thresh, size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))


########## q75-100 RMSE
qmin = 0.75
size = round(N * qmin)

def _q75rmse(prediction, ys):
    thresh = jnp.quantile(ys, qmin)
    inds = jnp.where(ys >= thresh, size=size)
    prediction_q = prediction[inds]
    ys_q = ys[inds]
    return jnp.sqrt(jnp.nanmean((prediction_q - ys_q) ** 2))


# (loss function, label) pairs, split into two parallel lists; each label is
# passed to train_and_store as `error_fn_name` alongside its function.
_error_fns, error_fn_names = (
    list(column)
    for column in zip(
        (_rmse, "rmse"),
        (_mse, "mse"),
        (_kge, "kge"),
        (_q25rmse, "q0-25rmse"),
        (_q75rmse, "q75-100rmse"),
    )
)

## CONUS

In [4]:
# Spatial subset to fit; passed to read_inputs and train_and_store below
subset_name = "CONUS"

### SMAP

In [5]:
obs_name = "SMAP"

# Get dimensions (only `ys` is used here)
ys, _, _, _ = read_inputs(subset_name, obs_name, True)
Nspace = ys.shape[0]
Ntime = ys.shape[1]

# Get validation indices: shuffle the indices along axis 0 of `ys` and split
# them into round(1 / val_frac) folds of (nearly) equal size.
val_frac = 0.2
# Explicit int: np.array_split int()-truncates a float section count
# (e.g. 1 / 0.3 -> 3), so pass the rounded value instead of 1 / val_frac
n_val_folds = round(1 / val_frac)
val_split_seed = 42  # fixed seed so the folds are identical after a kernel restart
val_inds_all = np.array_split(
    np.random.default_rng(val_split_seed).permutation(Nspace), n_val_folds
)

In [6]:
%%time
# Parallelize with dask delayed
delayed = []
    
# Random starting paramters
n_random_starts = 5

for _ in range(n_random_starts):
    initial_params = np.random.uniform(params_lower, params_upper)
    # Loop through loss functions
    for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
        # Loop through validation splits
        for val_inds in val_inds_all:
            
            # Hyperparameter adjustments to improve stability
            if error_fn_name == "kge":
                reg_const = 0.001
            else:
                reg_const = 0.01

            if error_fn_name == "mse":
                learning_rate = 1e-4
            else:
                learning_rate = 1e-3

            # Append delayed
            delayed.append(
                dask.delayed(train_and_store)(
                    subset_name = subset_name,
                    obs_name = obs_name,
                    _error_fn = _error_fn,
                    error_fn_name = error_fn_name,
                    initial_theta = initial_params,
                    params_lower = params_lower,
                    params_upper = params_upper,
                    param_names = param_names,
                    val_inds = val_inds,
                    reg_const = reg_const,
                    learning_rate = learning_rate
                )
            )

# Compute
_ = dask.compute(*delayed)

Task exception was never retrieved
future: <Task finished name='Task-3555' coro=<Client._gather.<locals>.wait() done, defined at /storage/home/dcl5300/miniforge3/envs/climate-stack-mamba-2023-12/lib/python3.11/site-packages/distributed/client.py:2208> exception=AllExit()>
Traceback (most recent call last):
  File "/storage/home/dcl5300/miniforge3/envs/climate-stack-mamba-2023-12/lib/python3.11/site-packages/distributed/client.py", line 2217, in wait
    raise AllExit()
distributed.client.AllExit
Task exception was never retrieved
future: <Task finished name='Task-3562' coro=<Client._gather.<locals>.wait() done, defined at /storage/home/dcl5300/miniforge3/envs/climate-stack-mamba-2023-12/lib/python3.11/site-packages/distributed/client.py:2208> exception=AllExit()>
Traceback (most recent call last):
  File "/storage/home/dcl5300/miniforge3/envs/climate-stack-mamba-2023-12/lib/python3.11/site-packages/distributed/client.py", line 2217, in wait
    raise AllExit()
distributed.client.AllExi

KeyboardInterrupt: 

In [11]:
%%time

_error_fn = _rmse
error_fn_name = "rmse"

(
    train_loss_out,
    pred_loss_out,
    reg_loss_out,
    val_loss_out,
    theta,
) = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    initial_theta=initial_params,
    params_lower=params_lower,
    params_upper=params_upper,
    param_names=param_names,
    n_epochs_max = 10
)

Epoch 0 train loss: 71.3068 pred loss: 71.3068, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 37.0238, pred loss: 36.9882, reg_loss: 35.5890, val loss: nan
Epoch 2 total loss: 30.4496, pred loss: 30.4099, reg_loss: 39.7576, val loss: nan
Epoch 3 total loss: 28.5626, pred loss: 28.5236, reg_loss: 39.0326, val loss: nan
Epoch 4 total loss: 26.8609, pred loss: 26.8195, reg_loss: 41.3449, val loss: nan
Epoch 5 total loss: 25.7375, pred loss: 25.6930, reg_loss: 44.5299, val loss: nan
Epoch 6 total loss: 25.2514, pred loss: 25.1778, reg_loss: 73.6590, val loss: nan
Epoch 7 total loss: 25.1462, pred loss: 25.0709, reg_loss: 75.3208, val loss: nan
Epoch 8 total loss: 25.1283, pred loss: 25.0542, reg_loss: 74.0245, val loss: nan
Epoch 9 total loss: 25.1189, pred loss: 25.0447, reg_loss: 74.1748, val loss: nan
Epoch 10 total loss: 25.1099, pred loss: 25.0406, reg_loss: 69.3078, val loss: nan
CPU times: user 18min 48s, sys: 6.25 s, total: 18min 54s
Wall time: 18min 59s


### VIC

In [10]:
# Same spatial subset as the SMAP section (re-set here before the VIC runs)
subset_name = "CONUS"

In [11]:
obs_name = "VIC"

# `val_inds_all` is not recomputed here, so the folds drawn in the SMAP
# section are reused. NOTE(review): this assumes VIC has the same number of
# locations (ys.shape[0]) as SMAP; rebuild the folds if it does not.

In [12]:
%%time

_error_fn = _rmse
error_fn_name = "rmse"

(
    train_loss_out,
    pred_loss_out,
    reg_loss_out,
    val_loss_out,
    theta,
) = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    initial_theta=initial_params,
    params_lower=params_lower,
    params_upper=params_upper,
    param_names=param_names,
    n_epochs_max = 10
)

Epoch 0 train loss: 43.2307 pred loss: 43.2307, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 25.3468, pred loss: 25.3160, reg_loss: 30.8212, val loss: nan
Epoch 2 total loss: 22.6578, pred loss: 22.6208, reg_loss: 37.0498, val loss: nan
Epoch 3 total loss: 22.4348, pred loss: 22.3977, reg_loss: 37.0558, val loss: nan
Epoch 4 total loss: 22.3127, pred loss: 22.2765, reg_loss: 36.2086, val loss: nan
Epoch 5 total loss: 22.2429, pred loss: 22.2045, reg_loss: 38.4276, val loss: nan
Epoch 6 total loss: 22.1966, pred loss: 22.1554, reg_loss: 41.2137, val loss: nan
Epoch 7 total loss: 22.1644, pred loss: 22.1240, reg_loss: 40.4062, val loss: nan
Epoch 8 total loss: 22.1411, pred loss: 22.1000, reg_loss: 41.1042, val loss: nan
Epoch 9 total loss: 22.1241, pred loss: 22.0840, reg_loss: 40.0695, val loss: nan
Epoch 10 total loss: 22.1085, pred loss: 22.0646, reg_loss: 43.9690, val loss: nan
CPU times: user 20min 50s, sys: 9.7 s, total: 21min
Wall time: 21min 12s


### NOAH

In [13]:
obs_name = "NOAH"

# `val_inds_all` is not recomputed here, so the folds drawn in the SMAP
# section are reused. NOTE(review): this assumes NOAH has the same number of
# locations (ys.shape[0]) as SMAP; rebuild the folds if it does not.

In [14]:
%%time

_error_fn = _rmse
error_fn_name = "rmse"

(
    train_loss_out,
    pred_loss_out,
    reg_loss_out,
    val_loss_out,
    theta,
) = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    initial_theta=initial_params,
    params_lower=params_lower,
    params_upper=params_upper,
    param_names=param_names,
    n_epochs_max = 10
)

Epoch 0 train loss: 57.7624 pred loss: 57.7624, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 35.7701, pred loss: 35.7443, reg_loss: 25.7774, val loss: nan
Epoch 2 total loss: 27.6216, pred loss: 27.5778, reg_loss: 43.7611, val loss: nan
Epoch 3 total loss: 26.1965, pred loss: 26.1273, reg_loss: 69.2153, val loss: nan
Epoch 4 total loss: 25.5663, pred loss: 25.4814, reg_loss: 84.8873, val loss: nan
Epoch 5 total loss: 25.3603, pred loss: 25.2688, reg_loss: 91.5227, val loss: nan
Epoch 6 total loss: 25.2955, pred loss: 25.1995, reg_loss: 95.9753, val loss: nan
Epoch 7 total loss: 25.2676, pred loss: 25.1634, reg_loss: 104.1650, val loss: nan
Epoch 8 total loss: 25.2554, pred loss: 25.1532, reg_loss: 102.1412, val loss: nan
Epoch 9 total loss: 25.2472, pred loss: 25.1439, reg_loss: 103.2125, val loss: nan
Epoch 10 total loss: 25.2393, pred loss: 25.1354, reg_loss: 103.8676, val loss: nan
CPU times: user 22min 26s, sys: 10.1 s, total: 22min 36s
Wall time: 22min 46s


### MOSAIC

In [15]:
obs_name = "MOSAIC"

# `val_inds_all` is not recomputed here, so the folds drawn in the SMAP
# section are reused. NOTE(review): this assumes MOSAIC has the same number
# of locations (ys.shape[0]) as SMAP; rebuild the folds if it does not.

In [16]:
%%time

_error_fn = _rmse
error_fn_name = "rmse"

(
    train_loss_out,
    pred_loss_out,
    reg_loss_out,
    val_loss_out,
    theta,
) = train_and_store(
    subset_name=subset_name,
    obs_name=obs_name,
    _error_fn=_error_fn,
    error_fn_name=error_fn_name,
    initial_theta=initial_params,
    params_lower=params_lower,
    params_upper=params_upper,
    param_names=param_names,
    n_epochs_max = 10
)

Epoch 0 train loss: 43.0967 pred loss: 43.0967, reg_loss: 0.0000, val loss: nan
Epoch 1 total loss: 38.0224, pred loss: 37.9420, reg_loss: 80.3598, val loss: nan
Epoch 2 total loss: 36.7350, pred loss: 36.6415, reg_loss: 93.4429, val loss: nan
Epoch 3 total loss: 36.2260, pred loss: 36.1430, reg_loss: 82.9285, val loss: nan
Epoch 4 total loss: 35.2730, pred loss: 35.2043, reg_loss: 68.6413, val loss: nan
Epoch 5 total loss: 33.6450, pred loss: 33.5916, reg_loss: 53.3503, val loss: nan
Epoch 6 total loss: 31.9941, pred loss: 31.9335, reg_loss: 60.5987, val loss: nan
Epoch 7 total loss: 31.1447, pred loss: 31.0626, reg_loss: 82.1357, val loss: nan
Epoch 8 total loss: 30.9197, pred loss: 30.8284, reg_loss: 91.3139, val loss: nan
Epoch 9 total loss: 30.8252, pred loss: 30.7315, reg_loss: 93.6736, val loss: nan
Epoch 10 total loss: 30.7783, pred loss: 30.6808, reg_loss: 97.4418, val loss: nan
CPU times: user 22min 44s, sys: 9.99 s, total: 22min 54s
Wall time: 23min 4s


In [24]:
%%time
# Parallelize with dask delayed
delayed = []

for _error_fn, error_fn_name in zip(_error_fns, error_fn_names):
    for _ in range(10):
        for val_inds in val_inds_all:
            # Hyperparameter adjustments to improve stability
            if error_fn_name == "kge":
                reg_const = 0.001
            else:
                reg_const = 0.01

            if error_fn_name == "mse":
                learning_rate = 1e-4
            else:
                learning_rate = 1e-3

            # Append delayed
            delayed.append(
                dask.delayed(train_and_store)(
                    subset_name = subset_name,
                    obs_name = obs_name,
                    _error_fn = _error_fn,
                    error_fn_name = error_fn_name,
                    initial_theta = initial_params,
                    params_lower = params_lower,
                    params_upper = params_upper,
                    param_names = param_names
                    val_inds = val_inds,
                    reg_const = reg_const,
                    learning_rate = learning_rate
                )
            )

# Compute
_ = dask.compute(*delayed)

CPU times: user 6min 41s, sys: 27.7 s, total: 7min 8s
Wall time: 51min 12s
