# Train a hybrid model - MLP/LSTM

In this tutorial, we demonstrate how to train a hybrid model that uses an MLP or an LSTM as its neural network component. The training data is derived from the pre-processed outputs generated by running the conceptual model. The hybrid model is trained to predict $Q$ (runoff) so that its output matches the observed data.

**IMPORTANT**: Although this tutorial uses a basin file that lists several basins, it does not demonstrate a multi-basin model. Instead, a separate single-basin model is trained for each basin.

**Before we start**

- This tutorial is rendered from a Jupyter notebook that is hosted on GitHub. If you'd like to run the code yourself, you can access the notebook and configuration files directly from the repository: [03-TrainHybridModel](https://github.com/jpcurbelo/torchHydroNodes/tree/master/tutorials/03-TrainHybridModel).

- To run this notebook locally, ensure you have completed the setup steps outlined in [Getting started](https://torchhydronodes.readthedocs.io/en/latest/usage/getting_started.html). These steps include setting up the environment, installing the required packages, and preparing the data files necessary for the tutorial.

- **Dependency on a Previous Tutorial**: Before running this tutorial, you must complete the [01-RunConceptModel Tutorial](https://torchhydronodes.readthedocs.io/en/latest/tutorials/run-concept-model.html). After completing it:

    1. Move the generated run folder to ``src/data``.

    2. Update the ``data_dir`` field in ``config_run_train_mlp.yml`` ([here](https://github.com/jpcurbelo/torchHydroNodes/blob/master/tutorials/03-TrainHybridModel/config_run_train_mlp.yml)) or ``config_run_train_lstm.yml`` ([here](https://github.com/jpcurbelo/torchHydroNodes/blob/master/tutorials/03-TrainHybridModel/config_run_train_lstm.yml)) so that it points to this folder.
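
You can also make step 2 from Python instead of a text editor. The sketch below rewrites the `data_dir` line in place. `set_data_dir` is a hypothetical helper, not part of torchHydroNodes. It assumes `data_dir` is a top-level key written on a single line. Editing that one line keeps comments and key order intact, which loading and re-dumping the file with a YAML library would not.

```python
import re
from pathlib import Path

# Matches a top-level (unindented) "data_dir:" line; nested keys are left alone.
_DATA_DIR_LINE = re.compile(r"^data_dir[ \t]*:.*$", re.MULTILINE)


def set_data_dir(config_path, new_dir):
    """Point the top-level ``data_dir`` key of a YAML config at ``new_dir``.

    Hypothetical helper: assumes ``data_dir`` is a top-level key on one line.
    """
    path = Path(config_path)
    text = path.read_text()
    # A function replacement keeps backslashes in Windows paths literal.
    new_text, count = _DATA_DIR_LINE.subn(lambda _m: f"data_dir: {new_dir}", text)
    if count != 1:
        raise ValueError(f"Expected one top-level data_dir key in {path}, found {count}")
    path.write_text(new_text)


# Example (replace the placeholder with your own run folder):
# set_data_dir("config_run_train_mlp.yml", "path/to/src/data/<run_folder>")
```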

# Import packages

```python
%reload_ext autoreload
%autoreload 2

import sys
from pathlib import Path

# Dynamically set the project directory based on the notebook's location
notebook_dir = Path().resolve()
project_dir = str(notebook_dir.parent.parent)  # Adjust based on your project structure
sys.path.append(project_dir)

import os
import yaml

from src.thn_run import (
    _load_cfg_and_ds,
    get_basin_interpolators
)

from src.modelzoo_concept import get_concept_model
from src.modelzoo_nn import (
    get_nn_model,
    get_nn_pretrainer,
)
from src.modelzoo_hybrid import (
    get_hybrid_model,
    get_trainer,
)
```
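
The `project_dir` line above assumes the notebook sits exactly two levels below the repository root (`tutorials/03-TrainHybridModel`). If you move the notebook, a more forgiving option is to search upward for the root. The hypothetical `find_project_root` below treats the first ancestor that contains a `src/` directory as the root. Change `marker` if your layout differs.

```python
from pathlib import Path


def find_project_root(start=None, marker="src"):
    """Return the nearest directory at or above ``start`` that contains ``marker``.

    Hypothetical helper: assumes the repository root is the first ancestor
    holding a ``src/`` directory (the package imported as ``src.*``).
    """
    start = Path(start if start is not None else Path.cwd()).resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    raise FileNotFoundError(f"No directory at or above {start} contains '{marker}/'")


# Usage in the notebook:
# project_dir = str(find_project_root())
# sys.path.append(project_dir)
```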

# Constants

Feel free to run and explore both *nn_type = 'mlp'* and *nn_type = 'lstm'*.

```python
nn_type = 'mlp'
# nn_type = 'lstm'

config_file = f'config_run_train_{nn_type}.yml'
```
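
The tutorial provides configuration files for `'mlp'` and `'lstm'`. A small guard such as the hypothetical `config_file_for` below catches a typo (for example `'LSTM '`) right away, instead of surfacing later as a missing-file error. It assumes the `config_run_train_<nn_type>.yml` naming used above.

```python
SUPPORTED_NN_TYPES = ("mlp", "lstm")


def config_file_for(nn_type: str) -> str:
    """Return the tutorial config file name for ``nn_type`` (hypothetical helper)."""
    key = nn_type.strip().lower()  # tolerate stray spaces and capitals
    if key not in SUPPORTED_NN_TYPES:
        raise ValueError(f"nn_type must be one of {SUPPORTED_NN_TYPES}, got {nn_type!r}")
    return f"config_run_train_{key}.yml"


# Usage:
# config_file = config_file_for(nn_type)
```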

# Load main config file

Loading the main configuration file once and deriving a per-basin configuration from it is the key step when running multiple single-basin models. For a complete implementation, refer to *src/scripts_paper/run_hybrid_trainer_single_all_mlp.py*. A parallelized version of the code demonstrated in this tutorial is also available for faster execution.
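
Each basin is trained independently, so the per-basin runs can be executed in parallel. The sketch below is not the repository's parallel script. It assumes a hypothetical `train_one_basin(basin)` that wraps the body of this tutorial's training loop and returns whatever summary you want to keep. Errors are collected per basin, so one failing basin does not stop the others.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_basins_in_parallel(basins, train_one_basin, max_workers=2,
                           executor_cls=ThreadPoolExecutor):
    """Call ``train_one_basin(basin)`` for every basin using a pool of workers.

    Returns ``(results, failures)``: dicts mapping basin ID to the return value
    or to the exception raised. ``train_one_basin`` is a hypothetical callable.
    """
    results, failures = {}, {}
    with executor_cls(max_workers=max_workers) as pool:
        futures = {pool.submit(train_one_basin, basin): basin for basin in basins}
        for future in as_completed(futures):
            basin = futures[future]
            try:
                results[basin] = future.result()
            except Exception as exc:  # keep going; report failures at the end
                failures[basin] = exc
    return results, failures
```

For CPU-bound training, `concurrent.futures.ProcessPoolExecutor` is usually a better fit than threads. In that case, `train_one_basin` must be a module-level function so it can be pickled. With either executor, give each basin its own copy of the configuration (for example `{**cfg, 'basin_file': ...}`) rather than reassigning `cfg['basin_file']` on a shared dict. Also keep the basin ID in every temporary file name.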

```python
# Load the MAIN configuration file
if Path(config_file).exists():
    with open(config_file, 'r') as f:
        cfg = yaml.safe_load(f)
else:
    raise FileNotFoundError(f'Configuration file {config_file} not found!')

# Read the basin list: one ID per line, kept as strings (preserves leading zeros),
# skipping blank lines such as a trailing newline at the end of the file
with open(cfg['basin_file'], 'r') as f:
    all_basins = [line.strip() for line in f if line.strip()]

print(all_basins)
```

```
['01013500', '01022500', '01030500', '06431500']
```


# Train hybrid model for each basin

```python
for basin in all_basins:

    # Temporary basin file listing only the current basin
    basin_file = f'temp_basin_{basin}_{nn_type}.txt'
    with open(basin_file, 'w') as f:
        f.write(basin)

    # Point the configuration at the single-basin file
    cfg['basin_file'] = basin_file

    # Temporary configuration file, e.g. config_run_train_mlp_temp_mlp_01013500.yml
    config_file_temp = f'{Path(config_file).stem}_temp_{nn_type}_{basin}.yml'
    with open(config_file_temp, 'w') as f:
        yaml.dump(cfg, f)

    # Load the configuration file and dataset
    cfg_run, dataset = _load_cfg_and_ds(
        Path(config_file_temp), model='hybrid')

    # The temporary files are no longer needed once the config and dataset are loaded
    if os.path.isfile(basin_file):
        os.remove(basin_file)
    if os.path.isfile(config_file_temp):
        os.remove(config_file_temp)

    # Get the basin interpolators
    interpolators = get_basin_interpolators(dataset, cfg_run, project_dir)

    # Conceptual model
    time_idx0 = 0
    model_concept = get_concept_model(cfg_run, dataset.ds_train,
                                      interpolators, time_idx0,
                                      dataset.scaler)

    # Neural network model
    model_nn = get_nn_model(model_concept, dataset.ds_static)

    # Pretrainer
    pretrainer = get_nn_pretrainer(model_nn, dataset)

    # Pretrain the neural network before training the hybrid model
    pretrain_ok = pretrainer.train(loss=cfg_run.loss_pretrain,
                                   lr=cfg_run.lr_pretrain,
                                   epochs=cfg_run.epochs_pretrain,
                                   disable_pbar=False,
                                   any_log=False)

    # Train the hybrid model only if pretraining succeeded
    if pretrain_ok:
        # Build the hybrid model
        model_hybrid = get_hybrid_model(cfg_run, pretrainer, dataset)

        # Build the trainer
        trainer = get_trainer(model_hybrid)

        # Train the model
        trainer.train()
    else:
        print(f'Pretraining failed for basin {basin}')
```


```
-- Loading the config file and the dataset
-- Using device: cpu --
Setting seed for reproducibility: 111
cfg.nn_model_dir is not defined - parameters MUST be defined in the config file
-- Loading basin dynamics into xarray data set.
100%|██████████| 1/1 [00:00<00:00, 22.59it/s]
------------------------------------------------------------
-- Pretraining the neural network model -- (cpu)
------------------------------------------------------------
# Epoch 00001: 100%|██████████| 29/29 [00:00<00:00, 120.84it/s, Loss=1.0589e+00]
* Plotting basin 01013500: 100%|██████████| 1/1 [00:00<00:00,  1.18it/s]
# Epoch 00002: 100%|██████████| 29/29 [00:00<00:00, 152.34it/s, Loss=3.4281e-01]
# Epoch 00003: 100%|██████████| 29/29 [00:00<00:00, 143.18it/s, Loss=1.8930e-01]
# Epoch 00004: 100%|██████████| 29/29 [00:00<00:00, 160.47it/s, Loss=1.3642e-01]
# Epoch 00005: 100%|██████████| 29/29 [00:00<00:00, 167.77it/s, Loss=1.1274e-01]
# Epoch 00006: 100%|██████████| 29/29 [00:00<00:00, 170.74it/s, Loss=8.4502e-02]
# Epoch 00007: 100%|██████████| 29/29 [00:00<00:00, 177.40it/s, Loss=8.5674e-02]
# Epoch 00008: 100%|██████████| 29/29 [00:00<00:00, 172.84it/s, Loss=4.5157e-02]
# Epoch 00009: 100%|██████████| 29/29 [00:00<00:00, 166.84it/s

KeyboardInterrupt:
```
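
The training loop above writes its temporary basin and configuration files to the working directory and deletes them by hand. If `_load_cfg_and_ds` raises, those files are left behind. One alternative is the hypothetical `single_basin_config` below. It keeps both files in a `tempfile.TemporaryDirectory`, which is removed when the `with` block exits, even on error, and it copies `cfg` instead of modifying it. The `dump` argument stands in for `yaml.dump`. The sketch assumes `_load_cfg_and_ds` accepts an absolute path for `basin_file`.

```python
import contextlib
import tempfile
from pathlib import Path


@contextlib.contextmanager
def single_basin_config(cfg, basin, dump):
    """Yield the path of a temporary config restricted to ``basin``.

    Hypothetical helper. ``dump(obj, stream)`` serializes the config; in the
    tutorial this would be ``yaml.dump``. Both temporary files live in a
    TemporaryDirectory and are deleted when the with-block exits.
    """
    with tempfile.TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)
        basin_file = tmp_dir / f"basin_{basin}.txt"
        basin_file.write_text(basin)
        # Shallow copy: the shared cfg keeps its original basin_file entry.
        basin_cfg = {**cfg, "basin_file": str(basin_file)}
        config_path = tmp_dir / f"config_{basin}.yml"
        with open(config_path, "w") as fh:
            dump(basin_cfg, fh)
        yield config_path


# Usage inside the per-basin loop:
# with single_basin_config(cfg, basin, yaml.dump) as config_path:
#     cfg_run, dataset = _load_cfg_and_ds(config_path, model='hybrid')
```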

After training, you may want to explore the *evaluate* and *save_plots* methods of the *BaseHybridModelTrainer* class (*src/modelzoo_hybrid/basetrainer*).