In [None]:
import os
import warnings

import numpy as np
import torch

from mm_interface import master
from core.utils.randomseed_config import randomseed_config


warnings.filterwarnings("ignore")



# Run setup: global device/dtype, platform-specific output directory, and a
# fixed random seed (0).
device, dtype = master.set_globals()
out_dir = master.set_platform_dirs()
randomseed_config(0)


# Per-model bookkeeping, keyed by model name, so that each differentiable
# model's attributes are managed separately.
# Other model names previously listed here as options: 'HBV', 'hbvhy', 'SACSMA'.
models = dict.fromkeys(('dPLHBV_dyn', 'SACSMA_snow', 'marrmot_PRMS'))
args_list = {
    'dPLHBV_dyn': hbvhyArgs_d,
    'SACSMA_snow': sacsmaSnowArgs,
    'marrmot_PRMS': prmsArgs,
}

# Ensemble strategy; listed options are 'median', 'avg', 'max', 'softmax'.
ENSEMBLE_TYPE = 'max'

# Locate the merged test-set predictions and observations written by a prior
# run; both files live in the same directory.
merged_dir = os.path.join(
    out_dir,
    'hydro_models',
    '671_sites_dp',
    'merged_test_preds_obs',
    'HBV_dynamic_SACSMA_snow_marrmot_PRMS',
)
pred_path = os.path.join(merged_dir, 'preds_multim_E50_B25_R365_BT365_H256_tr1980_1995_n16.npy')
obs_path = os.path.join(merged_dir, 'obs_multim_E50_B25_R365_BT365_H256_tr1980_1995_n16.npy')

# The arrays are loaded with pickling enabled and unwrapped with .item();
# the result is then used as a mapping keyed by model name.
# NOTE(review): allow_pickle=True executes pickled payloads -- only load
# files produced by trusted runs.
preds = np.load(pred_path, allow_pickle=True).item()
obs = np.load(obs_path, allow_pickle=True).item()

# `model_output` / `y_obs` alias the loaded mappings (no copy is made).
model_output, y_obs = preds, obs

# The saved files use the key 'HBV_dynamic'; rename it to 'dPLHBV_dyn' so
# the keys line up with the model names used in `args_list`.
for loaded in (model_output, y_obs):
    loaded['dPLHBV_dyn'] = loaded.pop('HBV_dynamic')

# Initialize per-model accumulators.
flow_preds = []   # one prediction array per model, in `args_list` order
flow_obs = []     # one observation array per distinct observation source
obs_trig = False  # becomes True once a batched hydro model has supplied obs

# Gather individual model predictions and observation data.
#
# Arrays are collected in lists and stacked once after the loop. Stacking
# incrementally based on the loop index breaks when the duplicate-observation
# skip falls on the second model and a later model adds observations (a 2-D
# array would be concatenated with a 3-D one), so the result must not depend
# on the order of the models in `args_list`.
for mod, args in args_list.items():
    mod_out = model_output[mod]
    y_ob = y_obs[mod]

    print(mod)

    if mod in ('HBV', 'SACSMA', 'SACSMA_snow', 'marrmot_PRMS'):
        # Hydro models are tested in batches, so we concatenate them and select
        # the desired flow.
        # Note: modified HBV already has this preparation done during testing.

        # Get flow predictions and swap axes to get shape [basins, days].
        batched_flow = torch.cat([d["flow_sim"] for d in mod_out], dim=1)
        pred = np.swapaxes(batched_flow.squeeze().numpy(), 0, 1)

        if not obs_trig:
            # dPLHBV uses GAGES while the other hydro models use CAMELS data.
            # This means small e-5 variation in observation data between the
            # two. This is averaged if both are used, but to avoid
            # double-counting data from multiple hydro models, only the first
            # hydro model contributes its observations.
            target_idx = args["target"].index("00060_Mean")
            obs = np.swapaxes(y_ob[:, :, target_idx].numpy(), 0, 1)
            obs_trig = True
            dup = False
        else:
            dup = True

    elif mod == 'dPLHBV_dyn':
        # Index 0 on the last axis selects streamflow Qr.
        # NOTE(review): the first 365 steps along axis 1 are dropped --
        # presumably a one-year warm-up period; confirm.
        pred = mod_out[:, :, 0][:, 365:]
        obs = y_ob.squeeze()[:, 365:]
        dup = False

    else:
        raise ValueError(f"Unsupported model type in `models`: {mod!r}.")

    flow_preds.append(pred)
    if not dup:
        # Avoid double-counting observations shared by several hydro models.
        flow_obs.append(obs)

if not flow_preds:
    raise ValueError("No models listed in `args_list`; nothing to concatenate.")

# A single array is passed through unchanged (2-D); several arrays are
# stacked along a new trailing axis (axis 2), one slice per model/source.
tmp_pred = flow_preds[0] if len(flow_preds) == 1 else np.stack(flow_preds, axis=2)
tmp_obs = flow_obs[0] if len(flow_obs) == 1 else np.stack(flow_obs, axis=2)

preds = tmp_pred
obs = tmp_obs

# Collapse the observation data to a single 2-D array: a 3-D input is
# averaged over its last axis (axis 2), a 2-D input is used unchanged, and
# any other rank is rejected.
n_dims = len(obs.shape)
if n_dims not in (2, 3):
    raise ValueError("Error reading prediction data: incorrect formatting.")
comp_obs = obs if n_dims == 2 else np.mean(obs, axis=2)
