### Adapting ensemble code of Kamlesh 2023
*Requires CUDA.*

---

In [3]:
from config.read_configurations import config_hbv as hbvArgs
from config.read_configurations import config_prms as prmsArgs
from config.read_configurations import config_sacsma as sacsmaArgs
from config.read_configurations import config_sacsma_snow as sacsmaSnowArgs
from config.read_configurations import config_hbv_hydrodl as hbvhyArgs_d


import torch
import os
import platform
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import scipy.stats
# from post import plot

from core.utils.randomseed_config import randomseed_config
from core.utils.master import create_output_dirs
from MODELS.loss_functions.get_loss_function import get_lossFun
from MODELS.test_dp_HBV_dynamic import test_dp_hbv
from core.data_processing.data_loading import loadData
from core.data_processing.normalization import transNorm
from core.utils.randomseed_config import randomseed_config
from core.data_processing.model import (
    take_sample_test,
    converting_flow_from_ft3_per_sec_to_mm_per_day
)

import warnings
# Install a blanket "ignore" filter: every warning raised after this point is
# suppressed (presumably to keep the notebook output readable). Remove this
# line while debugging, since it also hides deprecation and numerical warnings.
warnings.filterwarnings("ignore")



# Location of the `hydro_multimodel_results` directory, keyed by the value
# returned from `platform.system()`.
#
# macOS note: some operations are not yet working with MPS, so it may be
# necessary to set an environment variable to fall back to the CPU instead:
# %env PYTORCH_ENABLE_MPS_FALLBACK=1
_RESULTS_DIR_BY_OS = {
    # macOS
    'Darwin': '/Users/leoglonz/Desktop/water/data/model_runs/hydro_multimodel_results',
    # Windows
    'Windows': 'D:\\data\\model_runs\\hydro_multimodel_results\\',
    # Linux (Colab)
    'Linux': '/content/drive/MyDrive/Colab/data/model_runs/hydro_multimodel_results',
}

_host_os = platform.system()
if _host_os not in _RESULTS_DIR_BY_OS:
    # Any platform without a configured results directory is rejected.
    raise ValueError('Unsupported operating system.')
out_dir = _RESULTS_DIR_BY_OS[_host_os]


##-----## Multi-model Parameters ##-----##
##--------------------------------------##
# Two dictionaries, keyed by model name, manage each model's attributes
# separately: `models` holds an empty (None) slot per model and `args_list`
# holds that model's configuration object.
# Other keys noted as options: 'HBV', 'hbvhy', 'SACSMA'.
models = dict.fromkeys(('dPLHBV_dyn', 'SACSMA_snow', 'marrmot_PRMS'))
args_list = {
    'dPLHBV_dyn': hbvhyArgs_d,
    'SACSMA_snow': sacsmaSnowArgs,
    'marrmot_PRMS': prmsArgs,
}
ENSEMBLE_TYPE = 'max'  # 'median', 'avg', 'max', 'softmax'

# Test observations and predictions saved by a prior run live side by side in
# the same output directory.
_run_output_dir = os.path.join(out_dir, 'multimodels', '671_sites_dp', 'output')
pred_path = os.path.join(_run_output_dir, 'preds_671_dPLHBVd_SACSMASnow_PRMS.npy')
obs_path = os.path.join(_run_output_dir, 'obs_671_dPLHBVd_SACSMASnow_PRMS.npy')

# `.item()` unwraps the single Python object stored in each file; both objects
# are indexed by model name further down.
# NOTE(review): allow_pickle=True executes pickled code on load -- only use
# with files produced by a trusted run.
model_output = np.load(pred_path, allow_pickle=True).item()
y_obs = np.load(obs_path, allow_pickle=True).item()

# `preds` / `obs` are aliases of the loaded objects (same references).
preds = model_output
obs = y_obs

# Initialize
# NOTE(review): `flow_preds` and `flow_obs` are not read anywhere in this cell;
# they are kept only in case a later cell references them -- confirm and remove.
flow_preds = []
flow_obs = None
obs_trig = False

# Models whose test output is a list of per-batch dicts holding "flow_sim".
HYDRO_MODELS = ('HBV', 'SACSMA', 'SACSMA_snow', 'marrmot_PRMS')


def _stack_along_model_axis(parts):
    """Combine per-model 2-D arrays into one array with a trailing model axis.

    A single array is returned unchanged (still 2-D), which matches what the
    previous incremental accumulation produced for one entry. Two or more
    arrays are stacked on a new axis 2.

    Raises:
        ValueError: if `parts` is empty.
    """
    if not parts:
        raise ValueError("No model outputs to combine; `args_list` is empty.")
    if len(parts) == 1:
        return parts[0]
    return np.stack(parts, axis=2)


# Per-model arrays are collected first and stacked once after the loop. The
# earlier index-based accumulation (i == 0 / i == 1 / else) failed when the
# first two models were both hydro models followed by a non-duplicate one,
# because the observation array was still 2-D when concatenated on axis 2.
pred_parts = []
obs_parts = []

# Concatenate individual model predictions, and observation data.
for mod in args_list:
    args = args_list[mod]
    mod_out = model_output[mod]
    y_ob = y_obs[mod]

    print(mod)

    if mod in HYDRO_MODELS:
        # Hydro models are tested in batches, so we concatenate them and select
        # the desired flow.
        # Note: modified HBV already has this preparation done during testing.

        # Get flow predictions and swap axes to get shape [basins, days]
        pred = np.swapaxes(
            torch.cat([d["flow_sim"] for d in mod_out], dim=1).squeeze().numpy(),
            0,
            1,
        )

        if not obs_trig:
            # dPLHBV uses GAGES while the other hydro models use CAMELS data.
            # This means small e-5 variation in observation data between the
            # two. This is averaged if both models are used, but to avoid
            # double-counting data from multiple hydro models, use a trigger so
            # only the first hydro model contributes its observations.
            obs = np.swapaxes(
                y_ob[:, :, args["target"].index("00060_Mean")].numpy(), 0, 1
            )
            obs_trig = True
            dup = False
        else:
            dup = True

    elif mod == 'dPLHBV_dyn':
        # Set dim2 = 0 to get streamflow Qr.
        # The first 365 entries along axis 1 are dropped for both predictions
        # and observations (presumably a warm-up period -- TODO confirm).
        pred = mod_out[:, :, 0][:, 365:]
        obs = y_ob.squeeze()[:, 365:]
        dup = False

    else:
        # Name the offending key so the misconfiguration is easy to locate.
        raise ValueError(f"Unsupported model type in `models`: {mod!r}.")

    pred_parts.append(pred)
    if not dup:
        # Avoid double-counting observations shared between hydro models.
        obs_parts.append(obs)

tmp_pred = _stack_along_model_axis(pred_parts)
tmp_obs = _stack_along_model_axis(obs_parts)

preds = tmp_pred
obs = tmp_obs

dPLHBV_dyn
SACSMA_snow
marrmot_PRMS


In [4]:
preds

array([[[0.68206453, 0.04178718, 0.04005739],
        [0.6691796 , 0.15529171, 0.11451818],
        [0.66096807, 0.28792492, 0.20852403],
        ...,
        [0.8316542 , 1.36070156, 1.47097933],
        [0.9395818 , 1.60485446, 1.63182008],
        [1.1111125 , 1.95312655, 1.91992617]],

       [[1.0886465 , 0.30196935, 0.22478479],
        [1.025554  , 0.63441604, 0.54297   ],
        [0.95555204, 0.75878251, 0.72244602],
        ...,
        [0.65361506, 0.63856089, 0.87026596],
        [0.658928  , 0.66052788, 0.95191711],
        [0.6488367 , 0.62889987, 0.8994624 ]],

       [[0.45831633, 0.02072911, 0.02349203],
        [0.4823348 , 0.09775467, 0.08600431],
        [0.4811496 , 0.19010772, 0.15757091],
        ...,
        [0.5142777 , 0.79288632, 1.0482471 ],
        [0.5479347 , 0.84627336, 1.09326005],
        [0.57785803, 0.8880285 , 1.10479856]],

       ...,

       [[0.8393843 , 0.20238577, 0.0420501 ],
        [0.8366754 , 0.2631875 , 0.05698096],
        [0.8341537 , 0

In [5]:
obs

array([[[0.67010016, 0.67010128],
        [0.63534618, 0.63534725],
        [0.62448556, 0.62448663],
        ...,
        [0.69399352, 0.69399476],
        [0.81671851, 0.81671995],
        [0.93509925, 0.93510085]],

       [[0.9127725 , 0.91277152],
        [0.82746666, 0.82746577],
        [0.7506914 , 0.75069064],
        ...,
        [0.2687134 , 0.26871312],
        [0.30710103, 0.30710071],
        [0.3198969 , 0.31989658]],

       [[0.32544019, 0.32544002],
        [0.35272659, 0.3527264 ],
        [0.3620439 , 0.36204371],
        ...,
        [0.38267507, 0.38267487],
        [0.3986476 , 0.39864737],
        [0.41462012, 0.41461989]],

       ...,

       [[0.11595911, 0.11595792],
        [0.11595911, 0.11595792],
        [0.11595911, 0.11595792],
        ...,
        [0.19326518, 0.19326322],
        [0.18774332, 0.18774141],
        [0.18774332, 0.18774141]],

       [[0.02008291, 0.02008214],
        [0.02126426, 0.02126344],
        [0.02244561, 0.02244474],
        .