# Imports

In [131]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from tqdm import tqdm
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim

# Make the project root importable so the `src.*` imports below resolve.
# The notebook is assumed to sit three directory levels below the project root.
current_dir = Path.cwd()
project_dir = str(current_dir.parent.parent.parent)
print(project_dir)
# Only append once: re-running this cell in the same kernel must not keep
# growing sys.path with duplicate entries.
if project_dir not in sys.path:
    sys.path.append(project_dir)

from src.thn_run import (
    _load_cfg_and_ds,
    get_basin_interpolators,
    get_calibration_dataset,
)

from src.modelzoo_concept import get_concept_model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/ame805/torchHydroNodes


# Constants

In [132]:
# Run configuration (resolved relative to the project root)
cfg_file = "examples/config_run_calibrate_test.yml"

# Network inputs and training target, by column name in the calibration data
input_vars = ["prcp", "tmean", "dayl"]
output_vars = ["obs_runoff"]

# PINN architecture: width of each hidden layer
HIDDEN_LAYERS = [64, 64]

# Optimiser settings
LR = 0.001
EPOCHS = 50

# Classes and functions

In [133]:
class HydrologyPINN(nn.Module):
    """
    A PyTorch neural network model for hydrology-based physics-informed neural networks (PINNs).

    The module holds two independent groups of trainable weights:

    * ``self.network``: a fully connected MLP with Tanh hidden activations and a
      linear output layer, mapping ``input_size`` features to ``output_size`` outputs.
    * ``self.params``: one scalar ``nn.Parameter`` per entry of ``params_bounds``,
      initialised uniformly inside its bounds. These do not depend on the network
      input; read them through :meth:`get_clamped_params`.
    """

    def __init__(self, input_size, output_size,
                 hidden_layers=None, params_bounds=None):
        """
        Initialize the neural network with variable hidden layers.

        Args:
            input_size (int): The number of input features.
            output_size (int): The number of output features.
            hidden_layers (list of int, optional): Width of each hidden layer, in order.
                Defaults to the notebook-level ``HIDDEN_LAYERS`` (looked up when the
                model is built). An empty list gives a single linear layer.
            params_bounds (mapping of str -> (lower, upper), optional): Bounds for each
                trainable parameter, keyed by parameter name. Iteration order of this
                mapping is the order of ``self.params``. ``None`` means no parameters.
        """
        super().__init__()

        if hidden_layers is None:
            hidden_layers = HIDDEN_LAYERS

        # Consecutive widths define the Linear layers: input -> h1 -> ... -> hk
        widths = [input_size, *hidden_layers]
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.Tanh())

        # Output layer is linear (no activation)
        layers.append(nn.Linear(widths[-1], output_size))
        self.network = nn.Sequential(*layers)

        if params_bounds is None:
            params_bounds = OrderedDict()

        # Wrap in OrderedDict: ParameterDict only guarantees insertion order for
        # ordered inputs, and callers rely on that order when turning the
        # parameters into a positional list.
        self.params = nn.ParameterDict(OrderedDict(
            (name, nn.Parameter(torch.empty(1).uniform_(bounds[0], bounds[1]),
                                requires_grad=True))
            for name, bounds in params_bounds.items()
        ))
        self.params_bounds = params_bounds  # Store bounds for clamping

    def forward(self, x):
        """Apply the MLP to ``x``; the last dimension of ``x`` must equal ``input_size``."""
        return self.network(x)

    def get_clamped_params(self):
        """
        Return parameters clamped within their bounds.

        Returns:
            dict[str, torch.Tensor]: parameter name -> 1-element tensor, in the order
            of ``self.params``. Note that ``torch.clamp`` passes zero gradient for
            values outside the bounds.
        """
        return {name: torch.clamp(param, *self.params_bounds[name])
                for name, param in self.params.items()}

    
# Define the data loss and physics-based loss functions
def data_loss(predicted, observed):
    """
    Mean squared error between the network predictions and the observations.

    Returns a scalar tensor (mean over all elements).
    """
    return nn.functional.mse_loss(predicted, observed, reduction="mean")

def physics_loss(model, basin, predicted_params, observed):
    """
    Squared-error penalty between the conceptual-model simulation and observations.

    Args:
        model: Object exposing ``run(basin, basin_params=..., use_grad=...)``.
        basin: Basin identifier passed through to ``model.run``.
        predicted_params (list of torch.Tensor): Passed to ``model.run`` as ``basin_params``.
        observed (torch.Tensor): Observed target values.

    Returns:
        torch.Tensor: scalar mean of the squared differences between
        ``model.run(...)[-1]`` and ``observed``.
    """
    # NOTE(review): this compares the simulation directly with observations; no
    # differential-equation residual (ds/dt, etc.) is computed yet, so it acts as a
    # second data-fit term rather than a true physics residual.
    simulated = model.run(basin, basin_params=predicted_params, use_grad=True)

    # Presumably the last element of the run output is the simulated target series
    # (e.g. runoff) -- TODO confirm against model.run.
    sim_out = simulated[-1]

    # Guard against silent broadcasting: a (T,) simulation minus a (T, 1)
    # observation would broadcast to (T, T) and give a meaningless mean.
    if sim_out.shape != observed.shape and sim_out.numel() == observed.numel():
        sim_out = sim_out.reshape(observed.shape)

    physics_penalty = sim_out - observed
    return torch.mean(physics_penalty ** 2)

# Combine losses for the PINNs approach
def pinn_loss(predicted, observed, predicted_params, model, basin,
              data_weight=1.0, physics_weight=1.0):
    """
    Combined PINN objective: weighted sum of ``data_loss`` and ``physics_loss``.

    Args:
        predicted (torch.Tensor): Network output.
        observed (torch.Tensor): Observed target values.
        predicted_params (list of torch.Tensor): Parameters forwarded to ``physics_loss``.
        model: Conceptual model forwarded to ``physics_loss``.
        basin: Basin identifier forwarded to ``physics_loss``.
        data_weight (float): Weight of the data term (default 1.0).
        physics_weight (float): Weight of the physics term (default 1.0).

    Returns:
        torch.Tensor: scalar loss. With the default weights this is the plain sum
        of the two terms.
    """
    return (data_weight * data_loss(predicted, observed)
            + physics_weight * physics_loss(model, basin, predicted_params, observed))

# Load and prepare data

In [134]:
# Resolve the run configuration against the project root, then load the
# config together with the dataset for the conceptual model.
config_path = Path(project_dir) / cfg_file
cfg, dataset = _load_cfg_and_ds(config_path, model='conceptual')

# Basin interpolators (passed to get_concept_model when building each basin's model)
interpolators = get_basin_interpolators(dataset, cfg, project_dir)

-- Loading the config file and the dataset
-- Using device: cuda:0 --
Setting seed for reproducibility: 1111
-- Loading basin dynamics into xarray data set.
100%|██████████| 4/4 [00:00<00:00, 12.09it/s]


# Loop over basins

In [None]:
# Trained model and per-epoch loss history for every basin (otherwise only the
# last basin's model would survive the loop).
pinn_results = {}

for basin in tqdm(dataset.basins, disable=cfg.disable_pbar, file=sys.stdout):
    ds_calib, time_idx0 = get_calibration_dataset(cfg, dataset, basin)

    # NOTE(review): this discards the time_idx0 returned by get_calibration_dataset
    # and always starts at index 0 -- confirm this is intended.
    time_idx0 = 0
    model_concept = get_concept_model(cfg, ds_calib, interpolators, time_idx0,
                                      dataset.scaler, odesmethod=cfg.odesmethod)

    # Network input/output widths follow the configured variable lists
    input_size = len(input_vars)
    output_size = len(output_vars)

    # Parameter bounds from the conceptual model's config; their order defines
    # the order of the positional parameter list passed to the model below.
    params_bounds = OrderedDict(model_concept.cfg.params_bounds)

    # Create a fresh PINN model and optimizer for this basin
    model_pinn = HydrologyPINN(input_size, output_size,
                               hidden_layers=HIDDEN_LAYERS, params_bounds=params_bounds)
    optimizer = optim.Adam(model_pinn.parameters(), lr=LR)

    # The training data does not change across epochs: build the tensors once
    df_calib = ds_calib.to_dataframe()
    input_data = torch.tensor(df_calib[input_vars].values, dtype=torch.float32)
    observed_data = torch.tensor(df_calib[output_vars].values, dtype=torch.float32)

    loss_history = []
    model_pinn.train()
    for epoch in range(EPOCHS):
        optimizer.zero_grad()

        # Forward pass: network prediction plus the current (clamped) parameters
        predicted_data = model_pinn(input_data)
        predicted_params = model_pinn.get_clamped_params()
        # Positional list in params_bounds order (presumably what model.run expects)
        basin_params = list(predicted_params.values())

        loss = pinn_loss(predicted_data, observed_data, basin_params, model_concept, basin)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_history.append(loss_value)
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss_value}")

    pinn_results[basin] = {"model": model_pinn, "loss_history": loss_history}

Epoch 1/50, Loss: 969.2459106445312


  


Epoch 2/50, Loss: 968.2750854492188


  


Epoch 3/50, Loss: 967.3612670898438


  


Epoch 4/50, Loss: 966.5269775390625


  


Epoch 5/50, Loss: 965.4683837890625


  


Epoch 6/50, Loss: 964.280029296875


  


Epoch 7/50, Loss: 963.1788330078125


  


Epoch 8/50, Loss: 962.1548461914062


  


Epoch 9/50, Loss: 961.211181640625


  


Epoch 10/50, Loss: 960.3460693359375


  


Epoch 11/50, Loss: 959.5568237304688


  


Epoch 12/50, Loss: 958.8379516601562


  


Epoch 13/50, Loss: 958.1835327148438


  
