In [3]:
import sys
package_path = '/data/lgl5139/hydro_multimodel/dPLHydro_multimodel'
sys.path.append(package_path)


from core.data.dataset_loading import get_data_dict
dataset_dict, _ = get_data_dict(model.config, train=False)


In [36]:
def batch_sample(config: Dict, dataset_dictionary: Dict[str, torch.Tensor], 
                     i_s: int, i_e: int) -> Dict[str, torch.Tensor]:
    """
    Take sample of data for testing batch.
    """
    dataset_sample = {}
    for key, value in dataset_dictionary.items():
        if value.ndim == 3:
            # TODO: I don't think we actually need this.
            # Remove the warmup period for all except airtemp_memory and hydro inputs.
            if key in ['airT_mem_temp_model', 'x_hydro_model', 'inputs_nn_scaled']:
                warm_up = 0
            else:
                warm_up = config['warm_up']
            dataset_sample[key] = value[warm_up:, i_s:i_e, :] #.to(config['device'])
        elif value.ndim == 2:
            dataset_sample[key] = value[i_s:i_e, :] #.to(config['device'])
        else:
            raise ValueError(f"Incorrect input dimensions. {key} array must have 2 or 3 dimensions.")
    return dataset_sample


In [38]:
from core.data import take_sample_test

batch_sample(model.config, dataset_dict, 0,25)['c_nn'].shape

(25, 35)

In [64]:
"""
This is a testing script for running a dPL, physics-informed machine learning
model BMI that is NextGen framework and NOAA OWP operation-ready.

Note:
- The current setup only passes CAMELS (671 basins) data to the BMI. For
    different datasets, `.set_value()` mappings must be modeified to the respective
    forcing + attribute key values.
"""
import os
from ruamel.yaml import YAML
import logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

from core.data.dataset_loading import get_data_dict

log = logging.getLogger(__name__)



################## Initialize the BMI ##################
# Path to BMI config.
config_path = '/data/lgl5139/hydro_multimodel/dPLHydro_multimodel/models/bmi/bmi_config.yaml' #"bmi_config.yaml"

# Create instance of BMI model.
log.info("Creating dPLHydro BMI model instance")
model = BMIdPLHydroModel()

# [CONTROL FUNCTION] Initialize the BMI.
log.info(f"INITIALIZING BMI")
model.initialize(bmi_cfg_filepath=config_path)


################## Get test data ##################
log.info(f"Collecting attribute and forcing data")

# TODO: Adapt this PMI data loader to be more-BMI friendly, less a function iceberg.
# dataset_dict, _ = get_data_dict(model.config, train=False)

# Fixing typo in CAMELS dataset: 'geol_porostiy'.
# (Written into config somewhere inside get_data_dict...)
var_c_nn = model.config['observations']['var_c_nn']
if 'geol_porostiy' in var_c_nn:
    model.config['observations']['var_c_nn'][var_c_nn.index('geol_porostiy')] = 'geol_porosity'


################## Forward model for 1 or multiple timesteps ##################
# n_timesteps = dataset_dict['inputs_nn_scaled'].shape[0]
n_timesteps = 1  # debug
tlim = 366  # debug
n_basins = 25


log.info(f"BEGIN BMI FORWARD: {n_timesteps} timesteps...")

# TODO: write a timestep handler/translator so we can pull out
# forcings/attributes for the specific timesteps we want streamflow predictions for.

# Loop through and return streamflow at each timestep.
for t in range(n_timesteps):
    # NOTE: for each timestep in this loop, the data assignments below are of
    # arrays of basins. e.g., forcings['key'].shape = (1, # basins)

    ################## Map forcings + attributes into BMI ##################
    # Set NN forcings...
    for i, var in enumerate(model.config['observations']['var_t_nn']):
        standard_name = model._var_name_map_short_first[var]
        print("outside", standard_name, dataset_dict['inputs_nn_scaled'][t:tlim, :n_basins, i].shape)
        model.set_value(standard_name, dataset_dict['inputs_nn_scaled'][t:tlim, :n_basins, i], model='nn')
    n_forc = i
    
    # Set NN attributes...
    for i, var in enumerate(model.config['observations']['var_c_nn']):
        standard_name = model._var_name_map_short_first[var]
        model.set_value(standard_name, dataset_dict['inputs_nn_scaled'][t:tlim, :n_basins, n_forc + i + 1], model='nn') 

    # Set physics model forcings...
    for i, var in enumerate(model.config['observations']['var_t_hydro_model']):
        standard_name = model._var_name_map_short_first[var]
        model.set_value(standard_name, dataset_dict['x_hydro_model'][t:tlim, :n_basins, i], model='pm') 

    # Set physics model attributes...
    for i, var in enumerate(model.config['observations']['var_c_hydro_model']):
        standard_name = model._var_name_map_short_first[var]
        # NOTE: These attributes don't have a time dimension...
        model.set_value(standard_name, dataset_dict['c_hydro_model'][:n_basins, i], model='pm') 

    # print(model._values)

    # [CONTROL FUNCTION] Update the model at all basins for one timestep.
    model.update()
    # print(f"Streamflow at time {model.t} is {model.streamflow_cms}")

[2024-07-22 10:09:05][__main__][INFO] - Creating dPLHydro BMI model instance
[2024-07-22 10:09:05][__main__][INFO] - INITIALIZING BMI
[2024-07-22 10:09:05][__main__][INFO] - Collecting attribute and forcing data
[2024-07-22 10:09:05][__main__][INFO] - BEGIN BMI FORWARD: 1 timesteps...


outside atmosphere_water__liquid_equivalent_precipitation_rate (366, 25)
outside land_surface_air__temperature (366, 25)
outside land_surface_water__potential_evaporation_volume_flux (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


In [67]:
"""
This is a testing script for running a dPL, physics-informed machine learning
model BMI that is NextGen framework and NOAA OWP operation-ready.

Note:
- The current setup only passes CAMELS (671 basins) data to the BMI. For
    different datasets, `.set_value()` mappings must be modeified to the respective
    forcing + attribute key values.
"""
import os
import numpy as np
from tqdm import tqdm
from ruamel.yaml import YAML
import logging
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(name)s][%(levelname)s] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

from core.data.dataset_loading import get_data_dict

log = logging.getLogger(__name__)


################## Initialize the BMI ##################
# Path to BMI config.
config_path = '/data/lgl5139/hydro_multimodel/dPLHydro_multimodel/models/bmi/bmi_config.yaml' #"bmi_config.yaml"

# Create instance of BMI model.
log.info("Creating dPLHydro BMI model instance")
model = BMIdPLHydroModel()

# [CONTROL FUNCTION] Initialize the BMI.
log.info(f"INITIALIZING BMI")
model.initialize(bmi_cfg_filepath=config_path)


################## Get test data ##################
log.info(f"Collecting attribute and forcing data")

# Fixing typo in CAMELS dataset: 'geol_porostiy'.
# (Written into config somewhere inside get_data_dict...)
var_c_nn = model.config['observations']['var_c_nn']
if 'geol_porostiy' in var_c_nn:
    model.config['observations']['var_c_nn'][var_c_nn.index('geol_porostiy')] = 'geol_porosity'


################## Forward model for 1 or multiple timesteps ##################
# n_timesteps = dataset_dict['inputs_nn_scaled'].shape[0]
# debugging ----- #
n_timesteps = 2  # debug
tlim = 367  # debug
n_basins = 25
# --------------- #

log.info(f"BEGIN BMI FORWARD: {n_timesteps} timesteps...")

rho = model.config['rho']  # For routing

# Loop through and return streamflow at each timestep.
for t in range(n_timesteps):
    # NOTE: for each timestep in this loop, the data assignments below are of
    # arrays of basins. e.g., forcings['key'].shape = (rho + 1, # basins).
    # NOTE: MHPI models use a warmup period and routing in their forward pass,
    # so we cannot simply pass one timestep to these, but rather warmup or
    # rho + 1 timesteps up to the step we want to predict.
    # TODO: Check inefficiency cost of setting an extra rho timesteps of data
    # in the BMI for each timestep prediction. If too much, we pass all available
    # data into BMI; sounds from Jonathan that this should be fine.

    batched_preds_list = []
    ngrid = dataset_dict['inputs_nn_scaled'].shape[1]
    iS = np.arange(0, ngrid, model.config['batch_basins'])
    iE = np.append(iS[1:], ngrid)

    for i in tqdm(range(len(iS)), leave=False, dynamic_ncols=True):
        batch_sample_dict = batch_sample(model.config, dataset_dict, iS[i], iE[i])


        ################## Map forcings + attributes into BMI ##################
        # Set NN forcings...
        for i, var in enumerate(model.config['observations']['var_t_nn']):
            standard_name = model._var_name_map_short_first[var]
            model.set_value(standard_name, dataset_dict['inputs_nn_scaled'][t:rho + t + 1, :n_basins, i], model='nn')
        n_forc = i
        
        # Set NN attributes...
        for i, var in enumerate(model.config['observations']['var_c_nn']):
            standard_name = model._var_name_map_short_first[var]
            model.set_value(standard_name, dataset_dict['inputs_nn_scaled'][t:rho + t + 1, :n_basins, n_forc + i + 1], model='nn') 

        # Set physics model forcings...
        for i, var in enumerate(model.config['observations']['var_t_hydro_model']):
            standard_name = model._var_name_map_short_first[var]
            model.set_value(standard_name, dataset_dict['x_hydro_model'][t:rho + t + 1, :n_basins, i], model='pm') 

        # Set physics model attributes...
        for i, var in enumerate(model.config['observations']['var_c_hydro_model']):
            standard_name = model._var_name_map_short_first[var]
            # NOTE: These don't have a time dimension.
            model.set_value(standard_name, dataset_dict['c_hydro_model'][:n_basins, i], model='pm') 

        # [CONTROL FUNCTION] Update the model at all basins for one timestep.
        model.update()

        print(f"Streamflow at time {t} is {np.average((model.streamflow_cms).cpu())}")
        print(model.preds['HBV']['flow_sim'])

[2024-07-22 10:22:19][__main__][INFO] - Creating dPLHydro BMI model instance
[2024-07-22 10:22:19][__main__][INFO] - INITIALIZING BMI
[2024-07-22 10:22:19][__main__][INFO] - Collecting attribute and forcing data
[2024-07-22 10:22:19][__main__][INFO] - BEGIN BMI FORWARD: 2 timesteps...
  0%|          | 0/27 [00:00<?, ?it/s]

xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


  4%|▎         | 1/27 [00:00<00:10,  2.52it/s]

Streamflow at time 0 is 0.2967322766780853
tensor([[[0.3797],
         [0.8162],
         [0.4463],
         [0.2316],
         [0.1695],
         [0.1500],
         [0.1231],
         [0.0883],
         [0.2711],
         [0.3231],
         [0.2240],
         [0.6297],
         [0.3755],
         [0.4125],
         [0.1153],
         [0.3766],
         [0.2217],
         [0.3283],
         [0.1741],
         [0.1816],
         [0.3951],
         [0.2239],
         [0.3702],
         [0.1874],
         [0.2038]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


  7%|▋         | 2/27 [00:00<00:09,  2.54it/s]

Streamflow at time 0 is 0.33953049778938293
tensor([[[0.4071],
         [0.8722],
         [0.4850],
         [0.2638],
         [0.2065],
         [0.1873],
         [0.1593],
         [0.1246],
         [0.2845],
         [0.4032],
         [0.2523],
         [0.7263],
         [0.4322],
         [0.4635],
         [0.1321],
         [0.3927],
         [0.2540],
         [0.3327],
         [0.1852],
         [0.1899],
         [0.5389],
         [0.3158],
         [0.3851],
         [0.2644],
         [0.2296]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 11%|█         | 3/27 [00:01<00:09,  2.54it/s]

Streamflow at time 0 is 0.32227107882499695
tensor([[[0.4264],
         [0.8108],
         [0.4889],
         [0.2424],
         [0.1907],
         [0.1570],
         [0.1387],
         [0.1103],
         [0.2928],
         [0.3958],
         [0.2719],
         [0.5657],
         [0.3694],
         [0.4055],
         [0.1190],
         [0.3502],
         [0.2423],
         [0.3957],
         [0.2607],
         [0.1935],
         [0.4848],
         [0.2735],
         [0.4282],
         [0.2268],
         [0.2158]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 15%|█▍        | 4/27 [00:01<00:09,  2.55it/s]

Streamflow at time 0 is 0.2787761688232422
tensor([[[0.3130],
         [0.7174],
         [0.3839],
         [0.2273],
         [0.1763],
         [0.1361],
         [0.0879],
         [0.0987],
         [0.2406],
         [0.3473],
         [0.1659],
         [0.5422],
         [0.3721],
         [0.4076],
         [0.1127],
         [0.3066],
         [0.2203],
         [0.3256],
         [0.1751],
         [0.1812],
         [0.4042],
         [0.2488],
         [0.3441],
         [0.2143],
         [0.2201]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 19%|█▊        | 5/27 [00:01<00:08,  2.56it/s]

Streamflow at time 0 is 0.29672399163246155
tensor([[[0.4316],
         [0.6776],
         [0.4869],
         [0.2414],
         [0.1948],
         [0.1660],
         [0.1084],
         [0.1228],
         [0.2846],
         [0.3293],
         [0.2231],
         [0.5651],
         [0.4365],
         [0.4798],
         [0.1078],
         [0.2690],
         [0.2054],
         [0.3047],
         [0.1914],
         [0.1693],
         [0.3359],
         [0.2736],
         [0.3597],
         [0.2245],
         [0.2291]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 22%|██▏       | 6/27 [00:02<00:08,  2.58it/s]

Streamflow at time 0 is 0.28316232562065125
tensor([[[0.3901],
         [0.7903],
         [0.4396],
         [0.2308],
         [0.2003],
         [0.1345],
         [0.0640],
         [0.0877],
         [0.2486],
         [0.3507],
         [0.1848],
         [0.6416],
         [0.3655],
         [0.3820],
         [0.1042],
         [0.3119],
         [0.1955],
         [0.2954],
         [0.1687],
         [0.1760],
         [0.3725],
         [0.2394],
         [0.3063],
         [0.1988],
         [0.1998]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 26%|██▌       | 7/27 [00:02<00:07,  2.58it/s]

Streamflow at time 0 is 0.32889634370803833
tensor([[[0.4139],
         [0.7557],
         [0.5165],
         [0.2576],
         [0.2295],
         [0.1930],
         [0.1027],
         [0.1119],
         [0.3234],
         [0.3974],
         [0.1896],
         [0.5995],
         [0.4197],
         [0.4581],
         [0.1531],
         [0.4139],
         [0.3155],
         [0.3898],
         [0.2442],
         [0.2284],
         [0.3735],
         [0.2670],
         [0.4014],
         [0.2331],
         [0.2341]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 30%|██▉       | 8/27 [00:03<00:07,  2.58it/s]

Streamflow at time 0 is 0.29761114716529846
tensor([[[0.3854],
         [0.7190],
         [0.4343],
         [0.2435],
         [0.1976],
         [0.1446],
         [0.1169],
         [0.1082],
         [0.2855],
         [0.3703],
         [0.1761],
         [0.5671],
         [0.4133],
         [0.4638],
         [0.1359],
         [0.2600],
         [0.2524],
         [0.3800],
         [0.1571],
         [0.1983],
         [0.3893],
         [0.2449],
         [0.3517],
         [0.2286],
         [0.2165]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 33%|███▎      | 9/27 [00:03<00:07,  2.46it/s]

Streamflow at time 0 is 0.28921642899513245
tensor([[[0.4246],
         [0.6840],
         [0.4478],
         [0.2255],
         [0.1971],
         [0.1420],
         [0.1161],
         [0.0920],
         [0.2849],
         [0.3267],
         [0.1910],
         [0.5657],
         [0.3969],
         [0.4307],
         [0.1133],
         [0.2606],
         [0.2622],
         [0.3321],
         [0.1759],
         [0.1855],
         [0.2999],
         [0.2200],
         [0.3843],
         [0.2286],
         [0.2430]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 37%|███▋      | 10/27 [00:04<00:07,  2.38it/s]

Streamflow at time 0 is 0.3287491202354431
tensor([[[0.3894],
         [0.7792],
         [0.4665],
         [0.2581],
         [0.2290],
         [0.1652],
         [0.1329],
         [0.1165],
         [0.3389],
         [0.3930],
         [0.2621],
         [0.6017],
         [0.4714],
         [0.5191],
         [0.1273],
         [0.3831],
         [0.3354],
         [0.3322],
         [0.1962],
         [0.2209],
         [0.4288],
         [0.2821],
         [0.3509],
         [0.2214],
         [0.2173]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 41%|████      | 11/27 [00:04<00:06,  2.43it/s]

Streamflow at time 0 is 0.29971399903297424
tensor([[[0.3794],
         [0.6616],
         [0.3911],
         [0.2473],
         [0.2157],
         [0.1750],
         [0.1724],
         [0.1157],
         [0.3126],
         [0.3430],
         [0.2178],
         [0.5496],
         [0.4020],
         [0.4182],
         [0.1292],
         [0.3002],
         [0.2879],
         [0.3534],
         [0.1790],
         [0.2031],
         [0.3705],
         [0.2333],
         [0.3689],
         [0.2282],
         [0.2376]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 44%|████▍     | 12/27 [00:04<00:06,  2.47it/s]

Streamflow at time 0 is 0.3206193447113037
tensor([[[0.4424],
         [0.7513],
         [0.4825],
         [0.2428],
         [0.1826],
         [0.1582],
         [0.1352],
         [0.0962],
         [0.2670],
         [0.3245],
         [0.2420],
         [0.6415],
         [0.4206],
         [0.4466],
         [0.1226],
         [0.3548],
         [0.2752],
         [0.3806],
         [0.2480],
         [0.1933],
         [0.4233],
         [0.2944],
         [0.4184],
         [0.2363],
         [0.2353]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 48%|████▊     | 13/27 [00:05<00:05,  2.50it/s]

Streamflow at time 0 is 0.264851450920105
tensor([[[0.4351],
         [0.6914],
         [0.4632],
         [0.2231],
         [0.1497],
         [0.1696],
         [0.1106],
         [0.1002],
         [0.2659],
         [0.3261],
         [0.1566],
         [0.5613],
         [0.3088],
         [0.3491],
         [0.0994],
         [0.1983],
         [0.2125],
         [0.2565],
         [0.1591],
         [0.1835],
         [0.2703],
         [0.2030],
         [0.3154],
         [0.1946],
         [0.2180]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 52%|█████▏    | 14/27 [00:05<00:05,  2.52it/s]

Streamflow at time 0 is 0.30065950751304626
tensor([[[0.4614],
         [0.7580],
         [0.5296],
         [0.2489],
         [0.1894],
         [0.1784],
         [0.0861],
         [0.0886],
         [0.2740],
         [0.3536],
         [0.1888],
         [0.6386],
         [0.3890],
         [0.4164],
         [0.1288],
         [0.2941],
         [0.2744],
         [0.2816],
         [0.1794],
         [0.1876],
         [0.3426],
         [0.2560],
         [0.3189],
         [0.2221],
         [0.2302]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 56%|█████▌    | 15/27 [00:05<00:04,  2.53it/s]

Streamflow at time 0 is 0.3579704165458679
tensor([[[0.3825],
         [0.7966],
         [0.4848],
         [0.2463],
         [0.2424],
         [0.2107],
         [0.1773],
         [0.1292],
         [0.4282],
         [0.3991],
         [0.2962],
         [0.6554],
         [0.5075],
         [0.5321],
         [0.1346],
         [0.5118],
         [0.3455],
         [0.2984],
         [0.1893],
         [0.2037],
         [0.4654],
         [0.3228],
         [0.4488],
         [0.2849],
         [0.2558]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 59%|█████▉    | 16/27 [00:06<00:04,  2.54it/s]

Streamflow at time 0 is 0.3218288719654083
tensor([[[0.3710],
         [0.8159],
         [0.4192],
         [0.2666],
         [0.2086],
         [0.1458],
         [0.1130],
         [0.0987],
         [0.3347],
         [0.3846],
         [0.2374],
         [0.6405],
         [0.4253],
         [0.4576],
         [0.1151],
         [0.3613],
         [0.2960],
         [0.4492],
         [0.2315],
         [0.2000],
         [0.3853],
         [0.2776],
         [0.3547],
         [0.2224],
         [0.2337]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 63%|██████▎   | 17/27 [00:06<00:03,  2.54it/s]

Streamflow at time 0 is 0.3240063786506653
tensor([[[0.3596],
         [0.8298],
         [0.4609],
         [0.2259],
         [0.1935],
         [0.1358],
         [0.0937],
         [0.1215],
         [0.3269],
         [0.3827],
         [0.2685],
         [0.6610],
         [0.4553],
         [0.4827],
         [0.1195],
         [0.3983],
         [0.2307],
         [0.3000],
         [0.1719],
         [0.2068],
         [0.4612],
         [0.3061],
         [0.4077],
         [0.2428],
         [0.2575]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 67%|██████▋   | 18/27 [00:07<00:03,  2.55it/s]

Streamflow at time 0 is 0.30810219049453735
tensor([[[0.4239],
         [0.8429],
         [0.4845],
         [0.2201],
         [0.1732],
         [0.1436],
         [0.1087],
         [0.1274],
         [0.2761],
         [0.3909],
         [0.2453],
         [0.6029],
         [0.4108],
         [0.4312],
         [0.1167],
         [0.3496],
         [0.2172],
         [0.3189],
         [0.1553],
         [0.1696],
         [0.4626],
         [0.2361],
         [0.3321],
         [0.2254],
         [0.2375]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 70%|███████   | 19/27 [00:07<00:03,  2.56it/s]

Streamflow at time 0 is 0.3099818825721741
tensor([[[0.4193],
         [0.7513],
         [0.4742],
         [0.2325],
         [0.2073],
         [0.1135],
         [0.1134],
         [0.1180],
         [0.2865],
         [0.3717],
         [0.2431],
         [0.6083],
         [0.4139],
         [0.4271],
         [0.1142],
         [0.3681],
         [0.2401],
         [0.3395],
         [0.1721],
         [0.1812],
         [0.4340],
         [0.2774],
         [0.3651],
         [0.2272],
         [0.2505]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 74%|███████▍  | 20/27 [00:07<00:02,  2.56it/s]

Streamflow at time 0 is 0.3005250096321106
tensor([[[0.3146],
         [0.6837],
         [0.3972],
         [0.2255],
         [0.1779],
         [0.1693],
         [0.1001],
         [0.1186],
         [0.2279],
         [0.3804],
         [0.1826],
         [0.5611],
         [0.4213],
         [0.4476],
         [0.1263],
         [0.3835],
         [0.2843],
         [0.4372],
         [0.2320],
         [0.1925],
         [0.3749],
         [0.2486],
         [0.3837],
         [0.2091],
         [0.2333]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 78%|███████▊  | 21/27 [00:08<00:02,  2.57it/s]

Streamflow at time 0 is 0.30321064591407776
tensor([[[0.4471],
         [0.6595],
         [0.4523],
         [0.2398],
         [0.1720],
         [0.1935],
         [0.1305],
         [0.1071],
         [0.2446],
         [0.3416],
         [0.2089],
         [0.5672],
         [0.3987],
         [0.4274],
         [0.1175],
         [0.3347],
         [0.2831],
         [0.3593],
         [0.2527],
         [0.2030],
         [0.3568],
         [0.2560],
         [0.3800],
         [0.2321],
         [0.2149]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 81%|████████▏ | 22/27 [00:08<00:01,  2.57it/s]

Streamflow at time 0 is 0.3366590142250061
tensor([[[0.4102],
         [0.8092],
         [0.4798],
         [0.2288],
         [0.2013],
         [0.1752],
         [0.1673],
         [0.0864],
         [0.3200],
         [0.4325],
         [0.2935],
         [0.6397],
         [0.4583],
         [0.4680],
         [0.1273],
         [0.4224],
         [0.2905],
         [0.3728],
         [0.2410],
         [0.1974],
         [0.4747],
         [0.2900],
         [0.3916],
         [0.2236],
         [0.2154]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 85%|████████▌ | 23/27 [00:09<00:01,  2.57it/s]

Streamflow at time 0 is 0.31178703904151917
tensor([[[0.3889],
         [0.7982],
         [0.4614],
         [0.2372],
         [0.1810],
         [0.1474],
         [0.1403],
         [0.1028],
         [0.2525],
         [0.3599],
         [0.2690],
         [0.7285],
         [0.3744],
         [0.4022],
         [0.1043],
         [0.3407],
         [0.2519],
         [0.4080],
         [0.2678],
         [0.1757],
         [0.3973],
         [0.2442],
         [0.3462],
         [0.1919],
         [0.2232]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 89%|████████▉ | 24/27 [00:09<00:01,  2.56it/s]

Streamflow at time 0 is 0.3034447729587555
tensor([[[0.3804],
         [0.7367],
         [0.4951],
         [0.2347],
         [0.2053],
         [0.1550],
         [0.0993],
         [0.1026],
         [0.2761],
         [0.3907],
         [0.1963],
         [0.5490],
         [0.3988],
         [0.4225],
         [0.1264],
         [0.2938],
         [0.2274],
         [0.3375],
         [0.2034],
         [0.2036],
         [0.3583],
         [0.2804],
         [0.4006],
         [0.2391],
         [0.2730]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 93%|█████████▎| 25/27 [00:09<00:00,  2.55it/s]

Streamflow at time 0 is 0.30617886781692505
tensor([[[0.4185],
         [0.7335],
         [0.4968],
         [0.2345],
         [0.1933],
         [0.1388],
         [0.1297],
         [0.1410],
         [0.3057],
         [0.3831],
         [0.2247],
         [0.5642],
         [0.4040],
         [0.4293],
         [0.1141],
         [0.3069],
         [0.2174],
         [0.3411],
         [0.1876],
         [0.1925],
         [0.3731],
         [0.2686],
         [0.3849],
         [0.2322],
         [0.2388]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 96%|█████████▋| 26/27 [00:10<00:00,  2.56it/s]

Streamflow at time 0 is 0.29703181982040405
tensor([[[0.3587],
         [0.7651],
         [0.4360],
         [0.2283],
         [0.1897],
         [0.1438],
         [0.0552],
         [0.1129],
         [0.2635],
         [0.3473],
         [0.2134],
         [0.5987],
         [0.3838],
         [0.4204],
         [0.1157],
         [0.3622],
         [0.2538],
         [0.3249],
         [0.2013],
         [0.1922],
         [0.3581],
         [0.2571],
         [0.3727],
         [0.2358],
         [0.2354]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


                                               

Streamflow at time 0 is 0.2975090444087982
tensor([[[0.3240],
         [0.8644],
         [0.3893],
         [0.2506],
         [0.2026],
         [0.1850],
         [0.1357],
         [0.1099],
         [0.2939],
         [0.3658],
         [0.1979],
         [0.5628],
         [0.3578],
         [0.3837],
         [0.1261],
         [0.4801],
         [0.2505],
         [0.2505],
         [0.1740],
         [0.1806],
         [0.3824],
         [0.2514],
         [0.3319],
         [0.1947],
         [0.1919]]], device='cuda:0')


  0%|          | 0/27 [00:00<?, ?it/s]

xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


  4%|▎         | 1/27 [00:00<00:10,  2.55it/s]

Streamflow at time 1 is 0.3024483025074005
tensor([[[0.3956],
         [0.6190],
         [0.4042],
         [0.2385],
         [0.1902],
         [0.1666],
         [0.1110],
         [0.1258],
         [0.2810],
         [0.3061],
         [0.1720],
         [0.6995],
         [0.4544],
         [0.4959],
         [0.1111],
         [0.2866],
         [0.2743],
         [0.3981],
         [0.1704],
         [0.1978],
         [0.3657],
         [0.2359],
         [0.3735],
         [0.2214],
         [0.2666]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


  7%|▋         | 2/27 [00:00<00:10,  2.42it/s]

Streamflow at time 1 is 0.31004226207733154
tensor([[[0.3690],
         [0.6729],
         [0.4408],
         [0.2235],
         [0.1946],
         [0.1889],
         [0.1335],
         [0.0991],
         [0.2977],
         [0.3159],
         [0.2085],
         [0.6817],
         [0.4137],
         [0.4640],
         [0.1285],
         [0.3489],
         [0.3004],
         [0.3754],
         [0.1948],
         [0.1991],
         [0.3991],
         [0.2313],
         [0.3431],
         [0.2455],
         [0.2814]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 11%|█         | 3/27 [00:01<00:09,  2.47it/s]

Streamflow at time 1 is 0.2887735366821289
tensor([[[0.3431],
         [0.6848],
         [0.4122],
         [0.2133],
         [0.1656],
         [0.1331],
         [0.0853],
         [0.1003],
         [0.2373],
         [0.3201],
         [0.2374],
         [0.7145],
         [0.4355],
         [0.4817],
         [0.1060],
         [0.2720],
         [0.2320],
         [0.2973],
         [0.1618],
         [0.1795],
         [0.4037],
         [0.2400],
         [0.3076],
         [0.2079],
         [0.2473]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 15%|█▍        | 4/27 [00:01<00:09,  2.50it/s]

Streamflow at time 1 is 0.30714160203933716
tensor([[[0.3694],
         [0.6468],
         [0.4341],
         [0.2466],
         [0.2117],
         [0.1803],
         [0.1064],
         [0.1155],
         [0.3246],
         [0.3287],
         [0.2009],
         [0.7073],
         [0.4234],
         [0.4720],
         [0.1359],
         [0.2815],
         [0.3137],
         [0.3538],
         [0.2163],
         [0.1964],
         [0.3351],
         [0.2346],
         [0.3288],
         [0.2526],
         [0.2623]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 19%|█▊        | 5/27 [00:01<00:08,  2.52it/s]

Streamflow at time 1 is 0.33069029450416565
tensor([[[0.4477],
         [0.7327],
         [0.4959],
         [0.2114],
         [0.1474],
         [0.1656],
         [0.0756],
         [0.0932],
         [0.2597],
         [0.4043],
         [0.2357],
         [0.7470],
         [0.4755],
         [0.5085],
         [0.1234],
         [0.3773],
         [0.2697],
         [0.4041],
         [0.1904],
         [0.2002],
         [0.4133],
         [0.2885],
         [0.4629],
         [0.2652],
         [0.2722]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 22%|██▏       | 6/27 [00:02<00:08,  2.53it/s]

Streamflow at time 1 is 0.3361317813396454
tensor([[[0.3754],
         [0.7256],
         [0.4348],
         [0.2505],
         [0.2250],
         [0.1710],
         [0.1770],
         [0.1271],
         [0.3638],
         [0.3906],
         [0.2255],
         [0.7031],
         [0.4646],
         [0.4858],
         [0.1409],
         [0.3862],
         [0.2806],
         [0.3706],
         [0.2088],
         [0.2229],
         [0.4478],
         [0.2764],
         [0.3571],
         [0.2658],
         [0.3267]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 26%|██▌       | 7/27 [00:02<00:07,  2.54it/s]

Streamflow at time 1 is 0.3081676959991455
tensor([[[0.3553],
         [0.7281],
         [0.4564],
         [0.2249],
         [0.1898],
         [0.1225],
         [0.0740],
         [0.0825],
         [0.2522],
         [0.3871],
         [0.1755],
         [0.8051],
         [0.4351],
         [0.4862],
         [0.1017],
         [0.3409],
         [0.2250],
         [0.2891],
         [0.1838],
         [0.1504],
         [0.4662],
         [0.2788],
         [0.3742],
         [0.2546],
         [0.2649]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 30%|██▉       | 8/27 [00:03<00:07,  2.53it/s]

Streamflow at time 1 is 0.2943371832370758
tensor([[[0.3639],
         [0.7456],
         [0.4328],
         [0.2136],
         [0.1619],
         [0.1471],
         [0.1052],
         [0.0890],
         [0.2399],
         [0.3440],
         [0.1848],
         [0.7158],
         [0.4152],
         [0.4599],
         [0.0896],
         [0.3011],
         [0.2541],
         [0.3331],
         [0.1560],
         [0.1825],
         [0.3804],
         [0.2294],
         [0.3353],
         [0.2131],
         [0.2653]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 33%|███▎      | 9/27 [00:03<00:07,  2.54it/s]

Streamflow at time 1 is 0.32358813285827637
tensor([[[0.3383],
         [0.7408],
         [0.4202],
         [0.2411],
         [0.2299],
         [0.1735],
         [0.1082],
         [0.1290],
         [0.3006],
         [0.3528],
         [0.2441],
         [0.7865],
         [0.4516],
         [0.4954],
         [0.1160],
         [0.3338],
         [0.2753],
         [0.3066],
         [0.2110],
         [0.1959],
         [0.4096],
         [0.2819],
         [0.3918],
         [0.2772],
         [0.2787]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 37%|███▋      | 10/27 [00:03<00:06,  2.55it/s]

Streamflow at time 1 is 0.3011869490146637
tensor([[[0.4233],
         [0.7072],
         [0.4716],
         [0.2501],
         [0.1849],
         [0.1370],
         [0.1033],
         [0.0989],
         [0.2553],
         [0.3473],
         [0.1954],
         [0.6878],
         [0.3744],
         [0.4232],
         [0.1253],
         [0.3251],
         [0.2547],
         [0.3011],
         [0.1928],
         [0.1952],
         [0.3673],
         [0.2479],
         [0.3659],
         [0.2303],
         [0.2643]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 41%|████      | 11/27 [00:04<00:06,  2.56it/s]

Streamflow at time 1 is 0.33731552958488464
tensor([[[0.4389],
         [0.8594],
         [0.4948],
         [0.2630],
         [0.2399],
         [0.1289],
         [0.1261],
         [0.1126],
         [0.2955],
         [0.2992],
         [0.2990],
         [0.7703],
         [0.4181],
         [0.4741],
         [0.1191],
         [0.4898],
         [0.2446],
         [0.2929],
         [0.2017],
         [0.1974],
         [0.4897],
         [0.2698],
         [0.3602],
         [0.2495],
         [0.2984]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 44%|████▍     | 12/27 [00:04<00:05,  2.56it/s]

Streamflow at time 1 is 0.3253725469112396
tensor([[[0.3770],
         [0.7547],
         [0.4755],
         [0.2225],
         [0.1718],
         [0.1526],
         [0.1360],
         [0.1097],
         [0.2490],
         [0.3114],
         [0.2208],
         [0.7887],
         [0.4537],
         [0.5313],
         [0.0895],
         [0.5398],
         [0.2395],
         [0.3000],
         [0.2220],
         [0.1735],
         [0.5226],
         [0.2605],
         [0.3548],
         [0.2247],
         [0.2527]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 48%|████▊     | 13/27 [00:05<00:05,  2.56it/s]

Streamflow at time 1 is 0.3175594210624695
tensor([[[0.3985],
         [0.7325],
         [0.4416],
         [0.2338],
         [0.1916],
         [0.2130],
         [0.1184],
         [0.1011],
         [0.2694],
         [0.3795],
         [0.2010],
         [0.6878],
         [0.4264],
         [0.4792],
         [0.1378],
         [0.3685],
         [0.2776],
         [0.3054],
         [0.2271],
         [0.1894],
         [0.3866],
         [0.2658],
         [0.3422],
         [0.2855],
         [0.2792]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 52%|█████▏    | 14/27 [00:05<00:05,  2.47it/s]

Streamflow at time 1 is 0.31179600954055786
tensor([[[0.3763],
         [0.8189],
         [0.4691],
         [0.2108],
         [0.1807],
         [0.1415],
         [0.0581],
         [0.0820],
         [0.2709],
         [0.3199],
         [0.2159],
         [0.8103],
         [0.4648],
         [0.5258],
         [0.1250],
         [0.3716],
         [0.2313],
         [0.2557],
         [0.1465],
         [0.1800],
         [0.4926],
         [0.2387],
         [0.3401],
         [0.2150],
         [0.2536]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 56%|█████▌    | 15/27 [00:05<00:04,  2.50it/s]

Streamflow at time 1 is 0.3247033357620239
tensor([[[0.4447],
         [0.7840],
         [0.4906],
         [0.2300],
         [0.1694],
         [0.1445],
         [0.0811],
         [0.0851],
         [0.3333],
         [0.3685],
         [0.2462],
         [0.8174],
         [0.4844],
         [0.5609],
         [0.1189],
         [0.3966],
         [0.2430],
         [0.3062],
         [0.1738],
         [0.1757],
         [0.3952],
         [0.2273],
         [0.3400],
         [0.2354],
         [0.2653]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 59%|█████▉    | 16/27 [00:06<00:04,  2.40it/s]

Streamflow at time 1 is 0.3193935453891754
tensor([[[0.4394],
         [0.6858],
         [0.5141],
         [0.2328],
         [0.1793],
         [0.1867],
         [0.1310],
         [0.1066],
         [0.2494],
         [0.3465],
         [0.1796],
         [0.6927],
         [0.5045],
         [0.5565],
         [0.1134],
         [0.2809],
         [0.2406],
         [0.4231],
         [0.1764],
         [0.1901],
         [0.4223],
         [0.2524],
         [0.3612],
         [0.2486],
         [0.2705]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 63%|██████▎   | 17/27 [00:06<00:04,  2.42it/s]

Streamflow at time 1 is 0.29916733503341675
tensor([[[0.3726],
         [0.7267],
         [0.4457],
         [0.2296],
         [0.1955],
         [0.1492],
         [0.0970],
         [0.1091],
         [0.2655],
         [0.2709],
         [0.1995],
         [0.7120],
         [0.4315],
         [0.4732],
         [0.1225],
         [0.2924],
         [0.2118],
         [0.3491],
         [0.1643],
         [0.1842],
         [0.3218],
         [0.2529],
         [0.3759],
         [0.2554],
         [0.2708]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 67%|██████▋   | 18/27 [00:07<00:03,  2.41it/s]

Streamflow at time 1 is 0.2939625084400177
tensor([[[0.3346],
         [0.5876],
         [0.3751],
         [0.2140],
         [0.1770],
         [0.1553],
         [0.0990],
         [0.0874],
         [0.3071],
         [0.3022],
         [0.2004],
         [0.7165],
         [0.4291],
         [0.4717],
         [0.1201],
         [0.3155],
         [0.2771],
         [0.3796],
         [0.1850],
         [0.1834],
         [0.2800],
         [0.2650],
         [0.3796],
         [0.2490],
         [0.2576]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 70%|███████   | 19/27 [00:07<00:03,  2.37it/s]

Streamflow at time 1 is 0.3159227669239044
tensor([[[0.3326],
         [0.8314],
         [0.4265],
         [0.2226],
         [0.1837],
         [0.1293],
         [0.1243],
         [0.1046],
         [0.2358],
         [0.3906],
         [0.2447],
         [0.7623],
         [0.3990],
         [0.4458],
         [0.1245],
         [0.3604],
         [0.2522],
         [0.3171],
         [0.2307],
         [0.2026],
         [0.4334],
         [0.2646],
         [0.3662],
         [0.2414],
         [0.2716]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 74%|███████▍  | 20/27 [00:08<00:02,  2.43it/s]

Streamflow at time 1 is 0.3066652715206146
tensor([[[0.3498],
         [0.7395],
         [0.4180],
         [0.2151],
         [0.1816],
         [0.1500],
         [0.1059],
         [0.1108],
         [0.2964],
         [0.3238],
         [0.2444],
         [0.7779],
         [0.4563],
         [0.5200],
         [0.1170],
         [0.3526],
         [0.2111],
         [0.3008],
         [0.1878],
         [0.1733],
         [0.3919],
         [0.2407],
         [0.3348],
         [0.2195],
         [0.2476]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 78%|███████▊  | 21/27 [00:08<00:02,  2.47it/s]

Streamflow at time 1 is 0.3263903856277466
tensor([[[0.3431],
         [0.7356],
         [0.4103],
         [0.2310],
         [0.2081],
         [0.1421],
         [0.0823],
         [0.1133],
         [0.3449],
         [0.3795],
         [0.2487],
         [0.7372],
         [0.4768],
         [0.5316],
         [0.1287],
         [0.3834],
         [0.2481],
         [0.3260],
         [0.2108],
         [0.2079],
         [0.3870],
         [0.2736],
         [0.4305],
         [0.2895],
         [0.2896]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 81%|████████▏ | 22/27 [00:08<00:02,  2.49it/s]

Streamflow at time 1 is 0.2740177512168884
tensor([[[0.3211],
         [0.6505],
         [0.4043],
         [0.2004],
         [0.1642],
         [0.1030],
         [0.0594],
         [0.0622],
         [0.2335],
         [0.3390],
         [0.1636],
         [0.7218],
         [0.4291],
         [0.5106],
         [0.0910],
         [0.2708],
         [0.1763],
         [0.2891],
         [0.2037],
         [0.1398],
         [0.3014],
         [0.2323],
         [0.3309],
         [0.2141],
         [0.2384]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 85%|████████▌ | 23/27 [00:09<00:01,  2.51it/s]

Streamflow at time 1 is 0.30772799253463745
tensor([[[0.3711],
         [0.6285],
         [0.4141],
         [0.2276],
         [0.1912],
         [0.1422],
         [0.1022],
         [0.1163],
         [0.2900],
         [0.3141],
         [0.1736],
         [0.7040],
         [0.4352],
         [0.4934],
         [0.1255],
         [0.4546],
         [0.2678],
         [0.3723],
         [0.2118],
         [0.2078],
         [0.3527],
         [0.2490],
         [0.3775],
         [0.2327],
         [0.2379]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 89%|████████▉ | 24/27 [00:09<00:01,  2.52it/s]

Streamflow at time 1 is 0.30345606803894043
tensor([[[0.4370],
         [0.7198],
         [0.5003],
         [0.2464],
         [0.2134],
         [0.1381],
         [0.0897],
         [0.1062],
         [0.2747],
         [0.3359],
         [0.1797],
         [0.6626],
         [0.4228],
         [0.5040],
         [0.1337],
         [0.2791],
         [0.2449],
         [0.3124],
         [0.1728],
         [0.1948],
         [0.3456],
         [0.2333],
         [0.3304],
         [0.2402],
         [0.2686]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 93%|█████████▎| 25/27 [00:10<00:00,  2.25it/s]

Streamflow at time 1 is 0.2809123992919922
tensor([[[0.3036],
         [0.6717],
         [0.3717],
         [0.2123],
         [0.1678],
         [0.1993],
         [0.0996],
         [0.1004],
         [0.2481],
         [0.3089],
         [0.1652],
         [0.7174],
         [0.3965],
         [0.4318],
         [0.1059],
         [0.3414],
         [0.2623],
         [0.2706],
         [0.1745],
         [0.1692],
         [0.2996],
         [0.2094],
         [0.3164],
         [0.2283],
         [0.2509]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


 96%|█████████▋| 26/27 [00:10<00:00,  2.26it/s]

Streamflow at time 1 is 0.29723691940307617
tensor([[[0.4226],
         [0.6864],
         [0.5107],
         [0.2266],
         [0.1712],
         [0.1534],
         [0.1108],
         [0.1007],
         [0.2667],
         [0.3587],
         [0.1818],
         [0.6629],
         [0.4213],
         [0.4786],
         [0.1063],
         [0.2639],
         [0.2386],
         [0.3495],
         [0.2010],
         [0.1840],
         [0.3035],
         [0.2221],
         [0.3306],
         [0.2329],
         [0.2462]]], device='cuda:0')
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
xnn (366, 25, 3) (366, 25)
(366, 25, 38)


                                               

Streamflow at time 1 is 0.3084976077079773
tensor([[[0.3882],
         [0.7290],
         [0.4345],
         [0.2205],
         [0.1868],
         [0.1320],
         [0.1322],
         [0.0896],
         [0.2534],
         [0.3461],
         [0.2386],
         [0.8004],
         [0.4384],
         [0.4888],
         [0.1065],
         [0.3317],
         [0.2284],
         [0.3269],
         [0.1609],
         [0.2006],
         [0.4092],
         [0.2311],
         [0.3442],
         [0.2383],
         [0.2562]]], device='cuda:0')




In [13]:
model.preds['HBV'].keys()

dict_keys(['flow_sim', 'srflow', 'ssflow', 'gwflow', 'AET_hydro', 'PET_hydro', 'flow_sim_no_rout', 'srflow_no_rout', 'ssflow_no_rout', 'gwflow_no_rout', 'recharge', 'excs', 'evapfactor', 'tosoil', 'percolation', 'BFI_sim'])

In [23]:
model.preds['HBV']['flow_sim'].squeeze()

tensor([0.3243, 0.6412, 0.3652, 0.2053, 0.1810, 0.1370, 0.0779, 0.1165, 0.2484,
        0.2881, 0.1966, 0.8492, 0.4286, 0.4866, 0.1089, 0.2628, 0.2306, 0.2687,
        0.1672, 0.1659, 0.3424, 0.1890, 0.2894, 0.2326, 0.3005],
       device='cuda:0')

In [62]:
"""
BMI wrapper for interfacing dPL hydrology models with NOAA OWP NextGen framework.
"""
# Need this to get external packages like conf.config.
import sys
package_path = '/data/lgl5139/hydro_multimodel/dPLHydro_multimodel'
sys.path.append(package_path)

import os
import logging
from pathlib import Path
from typing import Optional, Any, Dict, Union

import numpy as np
import yaml
from ruamel.yaml import YAML
import torch
import time

from bmipy import Bmi
from conf.config import Config
from models.model_handler import ModelHandler
from omegaconf import DictConfig, OmegaConf
from pydantic import ValidationError
from core.data import take_sample_test

log = logging.getLogger(__name__)



class BMIdPLHydroModel(Bmi):
    """
    Run forward with BMI for a trained differentiable hydrology model.
    """
    def __init__(self):
        """
        Create a dPLHydro model BMI ready for initialization.
        """
        super(BMIdPLHydroModel, self).__init__()
        start_time = time.time()

        self._model = None
        self._initialized = False

        self._start_time = 0.0
        self._values = {}
        self._nn_values = {}
        self._pm_values = {}
        self._end_time = np.finfo(float).max
        self.var_array_lengths = 1

        self.bmi_process_time = 0


        # Required, static attributes of the model
        _att_map = {
        'model_name':         "Differentiable Parameter Learning Hydrology BMI",
        'version':            '1.2',
        'author_name':        'MHPI',
        'time_units':         'days',
        }
        
        # Input forcing/attribute CSDMS Standard Names.
        self._input_var_names = [
            ############## Forcings ##############
            'atmosphere_water__liquid_equivalent_precipitation_rate',
            'land_surface_air__temperature',
            'land_surface_air__max_of_temperature',  # custom name
            'land_surface_air__min_of_temperature',  # custom name
            'land_surface_water__potential_evaporation_volume_flux',  # check name,
            ############## Attributes ##############
            'atmosphere_water__daily_mean_of_liquid_equivalent_precipitation_rate',
            'land_surface_water__daily_mean_of_potential_evaporation_flux',
            'p_seasonality',  # custom name
            'atmosphere_water__precipitation_falling_as_snow_fraction',
            'ratio__mean_potential_evapotranspiration__mean_precipitation',
            'atmosphere_water__frequency_of_high_precipitation_events',
            'atmosphere_water__mean_duration_of_high_precipitation_events',
            'atmosphere_water__precipitation_frequency',
            'atmosphere_water__low_precipitation_duration',
            'basin__mean_of_elevation',
            'basin__mean_of_slope',
            'basin__area',
            'land_vegetation__forest_area_fraction',
            'land_vegetation__max_monthly_mean_of_leaf-area_index',
            'land_vegetation__diff_max_min_monthly_mean_of_leaf-area_index',
            'land_vegetation__max_monthly_mean_of_green_vegetation_fraction',
            'land_vegetation__diff__max_min_monthly_mean_of_green_vegetation_fraction',
            'region_state_land~covered__area_fraction',  # custom name
            'region_state_land~covered__area',  # custom name
            'root__depth',  # custom name
            'soil_bedrock_top__depth__pelletier',
            'soil_bedrock_top__depth__statsgo',
            'soil__porosity',
            'soil__saturated_hydraulic_conductivity',
            'maximum_water_content',
            'soil_sand__volume_fraction',
            'soil_silt__volume_fraction', 
            'soil_clay__volume_fraction',
            'geol_1st_class',  # custom name
            'geol_1st_class__fraction',  # custom name
            'geol_2nd_class',  # custom name
            'geol_2nd_class__fraction',  # custom name
            'basin__carbonate_rocks_area_fraction',
            'soil_active-layer__porosity',  # check name
            'bedrock__permeability'
        ]

        # Output variable names (CSDMS standard names).
        # TODO: Find CSDMS names for the other ouput vars.
        self._output_var_names = [
            'land_surface_water__runoff_volume_flux',
            'srflow',
            'ssflow',
            'gwflow',
            'AET_hydro',
            'PET_hydro',
            'flow_sim_no_rout',
            'srflow_no_rout',
            'ssflow_no_rout',
            'gwflow_no_rout',
            'excs',
            'evapfactor',
            'tosoil',
            'percolation',
            'BFI_sim'
        ]

        # Map CSDMS Standard Names to the model's internal variable names (For CAMELS).
        self._var_name_units_map = {
            ############## Forcings ##############
            'atmosphere_water__liquid_equivalent_precipitation_rate':['prcp(mm/day)', 'mm d-1'],
            'land_surface_air__temperature':['tmean(C)','degC'],
            'land_surface_air__max_of_temperature':['tmax(C)', 'degC'],  # custom name
            'land_surface_air__min_of_temperature':['tmin(C)', 'degC'],  # custom name
            'land_surface_water__potential_evaporation_volume_flux':['PET_hargreaves(mm/day)', 'mm d-1'],  # check name
            ############## Attributes ##############
            'atmosphere_water__daily_mean_of_liquid_equivalent_precipitation_rate':['p_mean','mm d-1'],
            'land_surface_water__daily_mean_of_potential_evaporation_flux':['pet_mean','mm d-1'],
            'p_seasonality':['p_seasonality', '-'],  # custom name
            'atmosphere_water__precipitation_falling_as_snow_fraction':['frac_snow','-'],
            'ratio__mean_potential_evapotranspiration__mean_precipitation':['aridity','-'],
            'atmosphere_water__frequency_of_high_precipitation_events':['high_prec_freq','d yr-1'],
            'atmosphere_water__mean_duration_of_high_precipitation_events':['high_prec_dur','d'],
            'atmosphere_water__precipitation_frequency':['low_prec_freq','d yr-1'],
            'atmosphere_water__low_precipitation_duration':['low_prec_dur','d'],
            'basin__mean_of_elevation':['elev_mean','m'],
            'basin__mean_of_slope':['slope_mean','m km-1'],
            'basin__area':['area_gages2','km2'],
            'land_vegetation__forest_area_fraction':['frac_forest','-'],
            'land_vegetation__max_monthly_mean_of_leaf-area_index':['lai_max','-'],
            'land_vegetation__diff_max_min_monthly_mean_of_leaf-area_index':['lai_diff','-'],
            'land_vegetation__max_monthly_mean_of_green_vegetation_fraction':['gvf_max','-'],
            'land_vegetation__diff__max_min_monthly_mean_of_green_vegetation_fraction':['gvf_diff','-'],
            'region_state_land~covered__area_fraction':['dom_land_cover_frac', 'percent'],  # custom name
            'region_state_land~covered__area':['dom_land_cover', '-'],  # custom name
            'root__depth':['root_depth_50', '-'],  # custom name
            'soil_bedrock_top__depth__pelletier':['soil_depth_pelletier','m'],
            'soil_bedrock_top__depth__statsgo':['soil_depth_statsgo','m'],
            'soil__porosity':['soil_porosity','-'],
            'soil__saturated_hydraulic_conductivity':['soil_conductivity','cm hr-1'],
            'maximum_water_content':['max_water_content','m'],
            'soil_sand__volume_fraction':['sand_frac','percent'],
            'soil_silt__volume_fraction':['silt_frac','percent'], 
            'soil_clay__volume_fraction':['clay_frac','percent'],
            'geol_1st_class':['geol_1st_class', '-'],  # custom name
            'geol_1st_class__fraction':['glim_1st_class_frac', '-'],  # custom name
            'geol_2nd_class':['geol_2nd_class', '-'],  # custom name
            'geol_2nd_class__fraction':['glim_2nd_class_frac', '-'],  # custom name
            'basin__carbonate_rocks_area_fraction':['carbonate_rocks_frac','-'],
            'soil_active-layer__porosity':['geol_porosity', '-'],  # check name
            'bedrock__permeability':['geol_permeability','m2'],
            'drainage__area':['DRAIN_SQKM', 'km2'],  # custom name
            'land_surface__latitude':['lat','degrees'],
            ############## Outputs ##############
            # TODO: Find csdms names.
            'land_surface_water__runoff_volume_flux':['flow_sim','m3 s-1'],
            'srflow':['srflow','m3 s-1'],
            'ssflow':['ssflow','m3 s-1'],
            'gwflow':['gwflow','m3 s-1'],
            'AET_hydro':['AET_hydro','m3 s-1'],
            'PET_hydro':['PET_hydro','m3 s-1'],
            'flow_sim_no_rout':['flow_sim_no_rout','m3 s-1'],
            'srflow_no_rout':['srflow_no_rout','m3 s-1'],
            'ssflow_no_rout':['ssflow_no_rout','m3 s-1'],
            'gwflow_no_rout':['gwflow_no_rout','m3 s-1'],
            'excs':['excs','-'],
            'evapfactor':['evapfactor','-'],
            'tosoil':['tosoil','m3 s-1'],
            'percolation':['percolation','-'],
            'BFI_sim':['BFI_sim','-']
        }
        
        # Keep running total of BMI runtime.
        self.bmi_process_time += time.time() - start_time
    
    def initialize(self, bmi_cfg_filepath: Optional[str] = None) -> None:
        """
        (BMI Control function) Initialize the dPLHydro model.

        Parameters
        ----------
        bmi_cfg_filepath : str, optional
            Path to the BMI configuration file.
        """
        start_time = time.time()

        # Read in BMI configurations.
        self.initialize_config(bmi_cfg_filepath)
        
        # Make lookup tables (Peckham et al.).
        self._var_name_map_long_first = {
            long_name:self._var_name_units_map[long_name][0] for \
            long_name in self._var_name_units_map.keys()
            }
        self._var_name_map_short_first = {
            self._var_name_units_map[long_name][0]:long_name for \
            long_name in self._var_name_units_map.keys()}
        self._var_units_map = {
            long_name:self._var_name_units_map[long_name][1] for \
            long_name in self._var_name_units_map.keys()
        }

        # Initialize inputs and outputs.
        for var in self.config['observations']['var_t_nn'] + self.config['observations']['var_c_nn']:
            standard_name = self._var_name_map_short_first[var]
            self._nn_values[standard_name] = []
            # setattr(self, var, 0)

        for var in self.config['observations']['var_t_hydro_model'] + self.config['observations']['var_c_hydro_model']:
            standard_name = self._var_name_map_short_first[var]
            self._pm_values[standard_name] = []
            # setattr(self, var, 0)


        # Set a simulation start time and gettimestep size.
        self.current_time = self._start_time
        self._time_step_size = self.config['time_step_delta']

        # Load a trained model.
        self._model = ModelHandler(self.config).to(self.config['device'])
        self._initialized = True

        # Intialize dataset (NOTE: move this externally).
        # self._get_data_dict()

        # Keep running total of BMI runtime.
        self.bmi_process_time += time.time() - start_time




    def _get_data_dict(self) -> None:
        """
        Construct data dictionary from BMI input data.

        iS, iE: arrays of start and end pairs of basin indicies for batching.
        """
        dataset_dict, self.config = self._values_to_dict()



    def update_frac(self, time_frac: float) -> None:
        """
        Update model by a fraction of a time step.
        
        Parameters
        ----------
        time_frac : float
            Fraction fo a time step.
        """
        if self.verbose:
            print("Warning: This model is trained to make predictions on one day timesteps.")
        time_step = self.get_time_step()
        self._time_step_size = self._time_step_size * time_frac
        self.update()
        self._time_step_size = time_step

    def update_until(self, end_time: float) -> None:
        """
        (BMI Control function) Update model until a particular time.
        Note: Models should be trained standalone with dPLHydro_PMI first before forward predictions with this BMI.

        Parameters
        ----------
        end_time : float
            Time to run model until.
        """
        n_steps = (end_time - self.get_current_time()) / self.get_time_step()

        for _ in range(int(n_steps)):
            self.update()
        self.update_frac(n_steps - int(n_steps))

    def finalize(self) -> None:
        """
        (BMI Control function) Finalize model.
        """
        # TODO: Force destruction of ESMF and other objects when testing is done
        # to save space.

        self._model = None

    def array_to_tensor(self) -> None:
        """
        Converts input values into Torch tensor object to be read by model. 
        """  
        raise NotImplementedError("array_to_tensor")
    
    def tensor_to_array(self) -> None:
        """
        Converts model output Torch tensor into date + gradient arrays to be
        passed out of BMI for backpropagation, loss, optimizer tuning.
        """  
        raise NotImplementedError("tensor_to_array")
    
    def get_tensor_slice(self):
        """
        Get tensor of input data for a single timestep.
        """
        # sample_dict = take_sample_test(self.bmi_config, self.dataset_dict)
        # self.input_tensor = torch.Tensor()
    
        raise NotImplementedError("get_tensor_slice")

    # ------------------ Finished up to here ------------------
    # ---------------------------------------------------------
    def get_var_type(self, var_name):
        """
        Data type of variable.

        Parameters
        ----------
        var_name : str
            Name of variable as CSDMS Standard Name.

        Returns
        -------
        str
            Data type.
        """
        return str(self.get_value_ptr(var_name).dtype)

    def get_var_units(self, var_name):
        """Get units of variable.

        Parameters
        ----------
        var_name : str
            Name of variable as CSDMS Standard Name.

        Returns
        -------
        str
            Variable units.
        """
        return self._var_units[var_name]

    def get_var_nbytes(self, var_name):
        """Get units of variable.

        Parameters
        ----------
        var_name : str
            Name of variable as CSDMS Standard Name.

        Returns
        -------
        int
            Size of data array in bytes.
        """
        return self.get_value_ptr(var_name).nbytes

    def get_var_itemsize(self, name):
        return np.dtype(self.get_var_type(name)).itemsize

    def get_var_location(self, name):
        return self._var_loc[name]

    def get_var_grid(self, var_name):
        """Grid id for a variable.

        Parameters
        ----------
        var_name : str
            Name of variable as CSDMS Standard Name.

        Returns
        -------
        int
            Grid id.
        """
        # for grid_id, var_name_list in self._grids.items():
        #     if var_name in var_name_list:
        #         return grid_id
        raise NotImplementedError("get_var_grid")

    def get_grid_rank(self, grid_id):
        """Rank of grid.

        Parameters
        ----------
        grid_id : int
            Identifier of a grid.

        Returns
        -------
        int
            Rank of grid.
        """
        # return len(self._model.shape)
        raise NotImplementedError("get_grid_rank")


    def get_grid_size(self, grid_id):
        """Size of grid.

        Parameters
        ----------
        grid_id : int
            Identifier of a grid.

        Returns
        -------
        int
            Size of grid.
        """
        # return int(np.prod(self._model.shape))
        raise NotImplementedError("get_grid_size")


    def get_value_ptr(self, var_name: str, model:str) -> np.ndarray:
        """Reference to values.

        Parameters
        ----------
        var_name : str
            Name of variable as CSDMS Standard Name.

        Returns
        -------
        array_like
            Value array.
        """
        if model == 'nn':
            if var_name not in self._nn_values.keys():
                raise ValueError(f"No known variable in BMI model: {var_name}")
            return self._nn_values[var_name]

        elif model == 'pm':
            if var_name not in self._pm_values.keys():
                raise ValueError(f"No known variable in BMI model: {var_name}")
            return self._pm_values[var_name]
        
        else:
            raise ValueError("Valid model type (nn or pm) must be specified.")

    def get_value(self, var_name, dest):
        """Copy of values.

        Parameters
        ----------
        var_name : str
            Name of variable as CSDMS Standard Name.
        dest : ndarray
            A numpy array into which to place the values.

        Returns
        -------
        array_like
            Copy of values.
        """
        dest[:] = self.get_value_ptr(var_name).flatten()
        return dest

    def get_value_at_indices(self, var_name, dest, indices):
        """Get values at particular indices.

        Parameters
        ----------
        var_name : str
            Name of variable as CSDMS Standard Name.
        dest : ndarray
            A numpy array into which to place the values.
        indices : array_like
            Array of indices.

        Returns
        -------
        array_like
            Values at indices.
        """
        dest[:] = self.get_value_ptr(var_name).take(indices)
        return dest

    def set_value(self, var_name, values: np.ndarray, model:str):
        """Set model values.

        Parameters
        ----------
        var_name : str
            Name of variable as CSDMS Standard Name.
        values : array_like
            Array of new values.
        """
        if not isinstance(values, (np.ndarray, list, tuple)):
            values = np.array([values])

        val = self.get_value_ptr(var_name, model=model)

        # val = values.reshape(val.shape)
        val[:] = values

    def set_value_at_indices(self, name, inds, src):
        """Set model values at particular indices.

        Parameters
        ----------
        var_name : str
            Name of variable as CSDMS Standard Name.
        src : array_like
            Array of new values.
        indices : array_like
            Array of indices.
        """
        val = self.get_value_ptr(name)
        val.flat[inds] = src

    def get_component_name(self):
        """Name of the component."""
        return self._name

    def get_input_item_count(self):
        """Get names of input variables."""
        return len(self._input_var_names)

    def get_output_item_count(self):
        """Get names of output variables."""
        return len(self._output_var_names)

    def get_input_var_names(self):
        """Get names of input variables."""
        return self._input_var_names

    def get_output_var_names(self):
        """Get names of output variables."""
        return self._output_var_names

    def get_grid_shape(self, grid_id, shape):
        """Number of rows and columns of uniform rectilinear grid."""
        # var_name = self._grids[grid_id][0]
        # shape[:] = self.get_value_ptr(var_name).shape
        # return shape
        raise NotImplementedError("get_grid_shape")

    def get_grid_spacing(self, grid_id, spacing):
        """Spacing of rows and columns of uniform rectilinear grid."""
        # spacing[:] = self._model.spacing
        # return spacing
        raise NotImplementedError("get_grid_spacing")

    def get_grid_origin(self, grid_id, origin):
        """Origin of uniform rectilinear grid."""
        # origin[:] = self._model.origin
        # return origin
        raise NotImplementedError("get_grid_origin")

    def get_grid_type(self, grid_id):
        """Type of grid."""
        # return self._grid_type[grid_id]
        raise NotImplementedError("get_grid_type")

    def get_start_time(self):
        """Start time of model."""
        return self._start_time

    def get_end_time(self):
        """End time of model."""
        return self._end_time

    def get_current_time(self):
        return self._current_time

    def get_time_step(self):
        return self._time_step_size

    def get_time_units(self):
        return self._time_units

    def get_grid_edge_count(self, grid):
        raise NotImplementedError("get_grid_edge_count")

    def get_grid_edge_nodes(self, grid, edge_nodes):
        raise NotImplementedError("get_grid_edge_nodes")

    def get_grid_face_count(self, grid):
        raise NotImplementedError("get_grid_face_count")

    def get_grid_face_nodes(self, grid, face_nodes):
        raise NotImplementedError("get_grid_face_nodes")

    def get_grid_node_count(self, grid):
        """Number of grid nodes.

        Parameters
        ----------
        grid : int
            Identifier of a grid.

        Returns
        -------
        int
            Size of grid.
        """
        # return self.get_grid_size(grid)
        raise NotImplementedError("get_grid_node_count")

        

    def get_grid_nodes_per_face(self, grid, nodes_per_face):
        raise NotImplementedError("get_grid_nodes_per_face")

    def get_grid_face_edges(self, grid, face_edges):
        raise NotImplementedError("get_grid_face_edges")

    def get_grid_x(self, grid, x):
        raise NotImplementedError("get_grid_x")

    def get_grid_y(self, grid, y):
        raise NotImplementedError("get_grid_y")

    def get_grid_z(self, grid, z):
        raise NotImplementedError("get_grid_z")

    def initialize_config(self, config_path: str) -> Dict:
        """
        Check that config_path is valid path and convert config into a
        dictionary object.
        """
        config_path = Path(config_path).resolve()
        
        if not config_path:
            raise RuntimeError("No BMI configuration path provided.")
        elif not config_path.is_file():
            raise RuntimeError(f"BMI configuration not found at path {config_path}.")
        else:
            with config_path.open('r') as f:
                self.config = yaml.safe_load(f)
    

        # USE BELOW FOR HYDRA + OMEGACONF:
        # try:
        #     config_dict: Union[Dict[str, Any], Any] = OmegaConf.to_container(
        #         cfg, resolve=True
        #     )
        #     config = Config(**config_dict)
        # except ValidationError as e:
        #     log.exception(e)
        #     raise e
        # return config, config_dict

    def take_sample_test(self, config: Dict, dataset_dictionary: Dict[str, torch.Tensor], 
                        i_s: int, i_e: int) -> Dict[str, torch.Tensor]:
        """
        Take sample of data for testing batch.
        """
        dataset_sample = {}
        for key, value in dataset_dictionary.items():
            if value.ndim == 3:
                # TODO: I don't think we actually need this.
                # Remove the warmup period for all except airtemp_memory and hydro inputs.
                if key in ['airT_mem_temp_model', 'x_hydro_model', 'inputs_nn_scaled']:
                    warm_up = 0
                else:
                    warm_up = config['warm_up']
                dataset_sample[key] = value[warm_up:, i_s:i_e, :].to(config['device'])
            elif value.ndim == 2:
                dataset_sample[key] = value[i_s:i_e, :].to(config['device'])
            else:
                raise ValueError(f"Incorrect input dimensions. {key} array must have 2 or 3 dimensions.")
        return dataset_sample


    def _values_to_dict(self) -> None:
        """
        Take CSDMS Standard Name-mapped forcings + attributes and construct data
        dictionary for NN and physics model.
        """
        # n_basins = self.config['batch_basins']
        n_basins = 25
        rho = self.config['rho']

        # Initialize dict arrays
        x_nn = np.zeros((rho + 1, n_basins, len(self.config['observations']['var_t_nn'])))
        c_nn = np.zeros((rho + 1, n_basins, len(self.config['observations']['var_c_nn'])))
        x_hydro_model = np.zeros((rho + 1, n_basins, len(self.config['observations']['var_t_hydro_model'])))
        c_hydro_model = np.zeros((n_basins, len(self.config['observations']['var_c_hydro_model'])))

        for i, var in enumerate(self.config['observations']['var_t_nn']):
            standard_name = self._var_name_map_short_first[var]
            # NOTE: Using _values is a bit hacky. Should use get_values I think.    
            print("xnn", x_nn.shape, np.array([self._nn_values[standard_name]]).squeeze().shape)
            x_nn[:, :, i] = np.array([self._nn_values[standard_name]]).squeeze()
        
        for i, var in enumerate(self.config['observations']['var_c_nn']):
            standard_name = self._var_name_map_short_first[var]
            c_nn[:, :, i] = np.array([self._nn_values[standard_name]])

        for i, var in enumerate(self.config['observations']['var_t_hydro_model']):
            standard_name = self._var_name_map_short_first[var]
            x_hydro_model[:, :, i] = np.array([self._pm_values[standard_name]])

        for i, var in enumerate(self.config['observations']['var_c_hydro_model']):
            standard_name = self._var_name_map_short_first[var]
            c_hydro_model[:, i] = np.array([self._pm_values[standard_name]])
        
        self.dataset_dict = {
            'inputs_nn_scaled': np.concatenate((x_nn, c_nn), axis=2), #[np.newaxis,:,:],
            'x_hydro_model': x_hydro_model, #[np.newaxis,:,:],
            'c_hydro_model': c_hydro_model
        }
        print(self.dataset_dict['inputs_nn_scaled'].shape)

        # Convert to torch tensors:
        for key in self.dataset_dict.keys():
            if type(self.dataset_dict[key]) == np.ndarray:
                self.dataset_dict[key] = torch.from_numpy(self.dataset_dict[key]).float() #.to(self.config['device'])



    def update(self) -> None:
        """
        (BMI Control function) Advance model state by one time step.

        Note: Models should be trained standalone with dPLHydro_PMI first before forward predictions with this BMI.
        """
        start_time = time.time()

        self.current_time += self._time_step_size 
        
        self._values_to_dict()

        ngrid = self.dataset_dict['inputs_nn_scaled'].shape[1]
        iS = np.arange(0, ngrid, self.config['batch_basins'])
        iE = np.append(iS[1:], ngrid)
        self.dataset_sample = take_sample_test(self.config,
                                          self.dataset_dict,
                                          iS[0],
                                          iE[0]
                                          )

        # Predictions
        self.preds = self._model.forward(self.dataset_sample, eval=True)
        
        # Scale and check output.
        self.scale_output()

    def scale_output(self) -> None:
        """
        Scale and return more meaningful output from wrapped model.
        """
        models = self.config['hydro_models'][0]

        # TODO: still have to finish finding and undoing scaling applied before
        # model run. (See some checks used in bmi_lstm.py.)

        # Strip unnecessary time and variable dims. This gives 1D array of flow
        # at each basin.
        # TODO: setup properly for multiple models later.
        self.streamflow_cms = self.preds[models]['flow_sim'].squeeze()


In [None]:
model._model.forward(sample, eval=True)