In [None]:
import datetime
import inspect
import time

import matplotlib.pyplot as plt
import torch
import yaml
from hy2dl.modelzoo import get_model
from hy2dl.modelzoo.baseconceptualmodel import BaseConceptualModel
from hy2dl.utils.config import Config
from hy2dl.utils.utils import set_random_seed
from utilities.data import DataHandler
from utilities.postprocessing import Postprocessor
from utilities.training import Trainer

### Part 1. Setup

In [None]:
# Load the base configuration. The context manager guarantees the file handle
# is closed once the YAML has been parsed.
with open("files/camels_gb.yml") as config_file:
    config_dict = yaml.safe_load(config_file)

# Override the settings that are specific to this experiment
config_dict["experiment_name"] = "CAMELS-GB_Hybrid-Custom-Model_NSE"
config_dict["path_save_folder"] = "results/run_Hybrid"
config_dict["model"] = "hybrid"
config_dict["conceptual_model"] = "shm"
config_dict["dynamic_parameterization_conceptual_model"] = [
    "dd", "f_thr", "sumax", "beta", "perc", "kf", "ki", "kb"
]

# Convert into 'Config' object, initialise the experiment and write the configuration out
config = Config(config_dict)
config.init_experiment()
config.dump()

### Part 2. Load data

In [None]:
# Build the data handler from the configuration and read the data
handler_data = DataHandler(config)
handler_data.load_data()

basin_ids = handler_data.get_basin_ids()

# One loader per period, collected under the period name
loader_training = handler_data.get_loader("training")
loader_validation = handler_data.get_loader("validation")
dataloaders = dict(training=loader_training, validation=loader_validation)

### Part 3. Model

In [None]:
# Seed the random number generators before the model is created,
# then build the model and move it to the configured device
set_random_seed(cfg=config)
model = get_model(config)
model = model.to(config.device)

In [None]:
# Show the parameter ranges and parameter types of the current conceptual model
print(
    model.conceptual_model.parameter_ranges,
    model.conceptual_model.parameter_type,
    sep="\n",
)

In [None]:
class my_model(BaseConceptualModel):
    """Single-bucket conceptual model driven by two parameters, ``ku`` and ``ET_aux``.

    At every time step the storage ``su`` receives the precipitation, loses
    ``pet * ET_aux`` (floored at zero) and then releases the fraction ``ku`` of
    what is left as discharge. ``cfg.num_conceptual_models`` buckets are run
    side by side and the simulated discharge is their mean.
    """

    def __init__(self, cfg: Config):
        """Store the configuration and derive the per-parameter type mapping.

        Parameters
        ----------
        cfg : Config
            Run configuration. ``num_conceptual_models`` is read here and the
            whole object is handed to ``_map_parameter_type``.
        """
        # Zero-argument super() does not look the class up by its global name,
        # so it keeps working when this notebook cell is re-run.
        super().__init__()
        self.cfg = cfg
        self.n_conceptual_models = cfg.num_conceptual_models
        self.parameter_type = self._map_parameter_type(cfg=cfg)

    def forward(
        self,
        x_conceptual: dict[str, torch.Tensor],
        parameters: dict[str, torch.Tensor],
        initial_states: dict[str, torch.Tensor] | None = None,
    ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
        """Run the bucket over every time step of the input sequence.

        Parameters
        ----------
        x_conceptual : dict[str, torch.Tensor]
            Forcings. The entries ``"precipitation"`` and ``"pet"`` are read;
            ``"precipitation"`` is unpacked as ``(batch_size, seq_length)``.
        parameters : dict[str, torch.Tensor]
            Entries ``"ku"`` and ``"ET_aux"``, indexed as
            ``[batch, time step, conceptual model]``.
        initial_states : dict[str, torch.Tensor] | None
            When given, its ``"su"`` entry is the starting storage. Otherwise
            the storage starts at the default from ``_initial_states``.

        Returns
        -------
        dict
            ``"y_hat"`` (simulated discharge), ``"parameters"`` (passed through
            unchanged), ``"internal_states"`` (storage per time step) and
            ``"final_states"`` (as returned by ``_get_final_states``).
        """
        # initialize structures to store the information
        states, out = self._initialize_information(conceptual_inputs=x_conceptual)

        # initialize constants
        precipitation = x_conceptual["precipitation"]
        pet = x_conceptual["pet"]
        batch_size, seq_length = precipitation.shape

        if initial_states is None:  # no initial states given: start from the default value
            su = torch.full(
                (batch_size, self.n_conceptual_models),
                self._initial_states["su"],
                dtype=torch.float32,
                device=precipitation.device,
            )
        else:  # start from the states handed in by the caller
            su = initial_states["su"]

        # run hydrological model for each time step
        for j in range(seq_length):
            # Broadcast the forcings so every parallel conceptual model sees the same input
            p = torch.tile(precipitation[:, j].unsqueeze(1), (1, self.n_conceptual_models))
            et = torch.tile(pet[:, j].unsqueeze(1), (1, self.n_conceptual_models))

            # 1 bucket reservoir ------------------
            su = su + p  # [mm]
            ret = et * parameters["ET_aux"][:, j, :]  # [mm]
            # Floor the storage at zero. clamp works on the tensor itself, so no
            # extra grad-requiring zero tensor has to be created on every step.
            su = torch.clamp(su - ret, min=0.0)  # [mm]
            qi_out = su * parameters["ku"][:, j, :]  # [mm]
            su = su - qi_out  # [mm]

            # states
            states["su"][:, j, :] = su

            # discharge: average over the parallel conceptual models
            out[:, j, 0] = torch.mean(qi_out, dim=1)  # [mm]

        # last states
        final_states = self._get_final_states(states=states)

        return {
            "y_hat": out,
            "parameters": parameters,
            "internal_states": states,
            "final_states": final_states,
        }

    @property
    def _initial_states(self) -> dict[str, float]:
        """Default starting storage; ``{"su": 0.001}`` unless overridden through the setter."""
        return getattr(self, '_initial_states_values', {"su": 0.001})

    @_initial_states.setter
    def _initial_states(self, value: dict[str, float]):
        self._initial_states_values = value

    @property
    def parameter_ranges(self) -> dict[str, tuple[float, float]]:
        """Pair of bounds per parameter; the defaults below apply unless overridden through the setter."""
        return getattr(self, "parameter_ranges_values", {"ku": (0.002, 1.0), "ET_aux": (0.0, 1.5)})

    @parameter_ranges.setter
    def parameter_ranges(self, value: dict[str, tuple[float, float]]):
        self.parameter_ranges_values = value

# Take the settings stored on the current Config object (read through its
# `_cfg` attribute) and point them at the custom conceptual model
custom_settings = config._cfg
custom_settings["conceptual_model"] = "custom_model"
custom_settings["dynamic_parameterization_conceptual_model"] = ["ku", "ET_aux"]

# Rebuild the Config object from the modified settings and write it out
config = Config(custom_settings)
config.dump()

# Swap the conceptual part of the hybrid model for the custom bucket
model.conceptual_model = my_model(cfg=config)

In [None]:
# Show the parameter ranges and parameter types of the conceptual model now attached
print(
    model.conceptual_model.parameter_ranges,
    model.conceptual_model.parameter_type,
    sep="\n",
)

In [None]:
# Load the weights saved after epoch 4. `map_location` remaps the stored tensors
# to the configured device, so a checkpoint written on another device
# (e.g. a GPU) can still be loaded here.
checkpoint_path = config.path_save_folder / "model/model_epoch_04.pt"
model.load_state_dict(
    torch.load(checkpoint_path, map_location=config.device)
)

### Part 4. Training

In [None]:
# Create the trainer for the current model and data loaders
handler_training = Trainer(config, dataloaders, model)

num_epochs = config.epochs

# Learning rate per epoch (epochs counted from 1): index `config.learning_rate`
# with the largest of its keys that does not exceed the epoch number
lrs = [
    config.learning_rate[max(start for start in config.learning_rate if start <= epoch_number)]
    for epoch_number in range(1, num_epochs + 1)
]

In [None]:
# Start training and report one table row per epoch
config.logger.info("Starting training")
config.logger.info(f"{'':^5} | {'':^8} | {'Training':^30} | {'Validation':^30} |")
config.logger.info(f"{'Epoch':^5} | {'LR':^8} | {'Loss':^8} | {'NSE':^8} | {'Time':^8} | {'Loss':^8} | {'NSE':^8} | {'Time':^8} |")

time_training = time.time()
for epoch in range(num_epochs):
    epoch_number = epoch + 1  # epochs are reported and scheduled starting at 1

    # Set learning rate
    handler_training.optimizer.update_optimizer_lr(epoch=epoch_number)

    # Train
    loss_train, nse_train, time_train = handler_training.run_epoch("training")

    # Columns shared by every row: epoch, learning rate and training metrics
    row_training = f"{epoch_number:^5} | {lrs[epoch]:^8.1e} | {loss_train:^8.4f} | {nse_train:^8.4f} | {time_train:^8} |"

    # Validation only runs every `config.validate_every` epochs; otherwise the
    # validation columns are left blank
    if epoch_number % config.validate_every != 0:
        config.logger.info(f"{row_training} {'':^8} | {'':^8} | {'':^8} |")
        continue

    # Validate
    loss_val, nse_val, time_val = handler_training.run_epoch("validation")
    config.logger.info(f"{row_training} {loss_val:^8.4f} | {nse_val:^8.4f} | {time_val:^8} |")

# Finish training: turn the elapsed wall-clock seconds into H:MM:SS
time_training = str(datetime.timedelta(seconds=int(time.time() - time_training)))
config.logger.info("Run completed successfully")
config.logger.info(f"Total run time: {time_training}")

### Part 5. Postprocess model

In [None]:
# Build the postprocessor from the configuration and run the model through it.
# `results` is displayed and plotted in the following cells, where it is
# queried with `.sel(basin=..., last_n=...)` and read through `y_obs` / `y_hat`.
handler_postprocessing = Postprocessor(config)
results = handler_postprocessing.postprocess(model)

In [None]:
# Display the postprocessed results (bare last expression -> rich notebook output)
results

In [None]:
basin_id = "73014"

# Select the series for this basin once and reuse it for both curves
results_basin = results.sel(basin=basin_id, last_n=365)

fig, ax = plt.subplots(figsize=(30, 5))
results_basin.y_obs.plot(ax=ax, label="Observed", color="tab:blue")
results_basin.y_hat.plot(ax=ax, label="Predicted", color="tab:orange")
# The labels passed above only become visible once a legend is drawn
ax.legend()
plt.show()