# Imports

In [236]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from tqdm import tqdm
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

import warnings
# Silence UserWarnings globally for this notebook session.
# NOTE(review): this also hides warnings raised by torch/matplotlib/project
# code; consider narrowing it with `message=` or `module=` once the noisy
# source is known.
warnings.filterwarnings("ignore", category=UserWarning)

# Resolve the repository root (three directory levels above the notebook's
# working directory) so that the `src` package can be imported below.
current_dir = Path.cwd()
project_dir = str(current_dir.parent.parent.parent)
print(project_dir)

# Append only once: re-running this notebook cell would otherwise keep
# adding duplicate entries to sys.path.
if project_dir not in sys.path:
    sys.path.append(project_dir)

from src.thn_run import (
    _load_cfg_and_ds,
    get_basin_interpolators,
    get_calibration_dataset,
)

from src.modelzoo_concept import get_concept_model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/ame805/torchHydroNodes


# Constants

In [237]:
# Run configuration, resolved relative to `project_dir`.
cfg_file = "examples/config_run_calibrate_test.yml"

# Variables fed to the network, and the target variable(s) it predicts.
input_vars = ["prcp", "tmean", "dayl"]
output_vars = ["obs_runoff"]

# Width of each hidden layer of the PINN: five layers of 64 units.
HIDDEN_LAYERS = [64] * 5

# Optimiser settings.
LR = 0.001
EPOCHS = 50

# Classes and functions

In [238]:
class HydrologyPINN(nn.Module):
    """
    A PyTorch neural network model for hydrology-based physics-informed neural networks (PINNs).

    The module holds two independent sets of trainable tensors:

    * ``network`` -- a fully connected MLP with Tanh activations mapping
      ``input_size`` features to ``output_size`` outputs (used by ``forward``);
    * ``params`` -- one scalar ``nn.Parameter`` per entry of ``params_bounds``,
      initialised uniformly inside its bounds (read via ``get_clamped_params``).
    """

    def __init__(self, input_size, output_size,
                 hidden_layers=None, params_bounds=None):
        """
        Build the MLP and the bounded scalar parameters.

        Args:
            input_size (int): Number of input features.
            output_size (int): Number of output features.
            hidden_layers (sequence of int, optional): Width of each hidden
                layer, in order. Defaults to the module-level ``HIDDEN_LAYERS``.
                An empty sequence yields a single linear layer.
            params_bounds (mapping of str -> (lower, upper), optional): Bounds
                for each named scalar parameter; iteration order is preserved.
                ``None`` means no scalar parameters.
        """
        super().__init__()

        if hidden_layers is None:
            # Resolved at call time so the shared module-level list is never
            # bound as a (mutable) default argument.
            hidden_layers = HIDDEN_LAYERS
        if params_bounds is None:
            # Previously a None default crashed on `.items()` below.
            params_bounds = OrderedDict()

        # Consecutive widths: input -> hidden_1 -> ... -> hidden_k, each
        # transition followed by a Tanh.
        sizes = [input_size, *hidden_layers]
        layers = []
        for in_features, out_features in zip(sizes[:-1], sizes[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.Tanh())

        # Linear output layer (no activation).
        layers.append(nn.Linear(sizes[-1], output_size))
        self.network = nn.Sequential(*layers)

        # Wrap in an OrderedDict: older PyTorch versions sort the keys of a
        # plain dict passed to ParameterDict, which would break the order
        # callers rely on when turning the parameters into a list.
        self.params = nn.ParameterDict(OrderedDict(
            (name, nn.Parameter(torch.empty(1).uniform_(bounds[0], bounds[1]),
                                requires_grad=True))
            for name, bounds in params_bounds.items()
        ))
        self.params_bounds = params_bounds  # Store bounds for clamping

    def forward(self, x):
        """Apply the MLP to ``x``; the scalar ``params`` are not used here."""
        return self.network(x)

    def get_clamped_params(self):
        """
        Return the scalar parameters clamped to their bounds.

        Returns:
            dict[str, torch.Tensor]: Same keys and order as ``self.params``;
            each value is ``torch.clamp(param, lower, upper)``. Note that
            ``torch.clamp`` passes zero gradient to a parameter whose value
            lies strictly outside its bounds.
        """
        return {
            name: torch.clamp(param, *self.params_bounds[name])
            for name, param in self.params.items()
        }

    
# Define the data loss and physics-based loss functions
def data_loss(predicted, observed):
    """
    Return the mean squared error between the network output and the observations.

    Args:
        predicted (torch.Tensor): Network output.
        observed (torch.Tensor): Observed targets; expected to match
            ``predicted`` in shape.

    Returns:
        torch.Tensor: Scalar mean of the element-wise squared differences.
    """
    return nn.functional.mse_loss(predicted, observed)

def physics_loss(model, basin, predicted_params, observed):
    """
    Mean squared mismatch between the conceptual model's simulation and observations.

    Args:
        model: Object exposing ``run(basin, basin_params=..., use_grad=...)``.
        basin: Basin identifier forwarded to ``model.run``.
        predicted_params: Parameter values forwarded as ``basin_params``.
        observed (torch.Tensor): Observed target values.

    Returns:
        torch.Tensor: Scalar mean of ``(simulated[-1] - observed) ** 2``.
    """
    # NOTE: despite the name, this is not an ODE-residual term; it compares
    # the last element of the simulation output with the observations.
    # `use_grad=True` presumably keeps the run differentiable w.r.t.
    # `predicted_params` -- TODO confirm against `model.run`.
    simulation = model.run(basin, basin_params=predicted_params, use_grad=True)

    # NOTE(review): assumes `simulation[-1]` is the simulated series and
    # broadcasts cleanly against `observed` (e.g. both (T, 1)); a (T,) vector
    # against a (T, 1) target would silently broadcast to (T, T).
    residual = simulation[-1] - observed
    return residual.pow(2).mean()

# Combine losses for the PINNs approach
def pinn_loss(predicted, observed, predicted_params, model, basin):
    """
    Total training loss: the data-fit term plus the conceptual-model term.

    Args:
        predicted (torch.Tensor): Network output, passed to ``data_loss``.
        observed (torch.Tensor): Observed targets, used by both terms.
        predicted_params: Conceptual-model parameters, passed to ``physics_loss``.
        model: Conceptual model, passed to ``physics_loss``.
        basin: Basin identifier, passed to ``physics_loss``.

    Returns:
        torch.Tensor: Unweighted sum of the two loss terms.
    """
    fit_term = data_loss(predicted, observed)
    model_term = physics_loss(model, basin, predicted_params, observed)
    return fit_term + model_term


def plot_results(observed, predicted, epoch, path='pins_simulation.png'):
    """
    Save a line plot of observed vs. predicted values to an image file.

    Args:
        observed: Observed series; a torch tensor (any device, with or without
            autograd history) or anything matplotlib can plot.
        predicted: Predicted series; a torch tensor or array-like.
        epoch (int): Epoch number shown in the title.
        path (str or Path): Output image file; overwritten on every call.
    """
    def _to_plottable(values):
        # Tensors may carry autograd history or live on the GPU; matplotlib
        # needs a host-side array. Both series are treated the same way.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return values

    # Draw on a dedicated figure and close it afterwards, so repeated calls
    # neither reuse nor clear whatever figure is currently active (e.g. in a
    # notebook) and no empty figure is left open.
    fig, ax = plt.subplots()
    try:
        ax.plot(_to_plottable(observed), label='Observed', color='blue')
        ax.plot(_to_plottable(predicted), label='Predicted',
                color='red', linestyle='--')
        ax.legend()
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Values')
        ax.set_title(f'Observed vs. Predicted (epoch {epoch})')
        fig.savefig(path)
    finally:
        plt.close(fig)

# # Initialize the plot for observed vs predicted
# def init_plot(observed):
#     plt.ion()
#     fig, ax = plt.subplots(figsize=(10, 5))
#     ax.plot(observed, label='Observed', color='blue')
#     pred_line, = ax.plot([], label='Predicted', color='orange')
#     ax.legend()
#     ax.set_xlabel('Time Step')
#     ax.set_ylabel('Values')
#     ax.set_title('Observed vs. Predicted')
#     return fig, ax, pred_line

# # Update the plot with new predictions
# def update_plot(pred_line, predicted, fig):
#     pred_line.set_ydata(predicted.detach().cpu().numpy())  # Update the y-data for predicted
#     pred_line.set_xdata(range(len(predicted)))  # Update the x-data in case of changes in length
#     plt.draw()
#     plt.pause(0.1)  # Pause briefly to update the plot
#     clear_output(wait=True)

# Load and prepare data

In [239]:
# Resolve the config file against the repository root, then load the
# configuration together with the dataset for the conceptual model.
cfg_path = Path(project_dir) / cfg_file
cfg, dataset = _load_cfg_and_ds(cfg_path, model='conceptual')

# Per-basin interpolators, later handed to `get_concept_model`.
interpolators = get_basin_interpolators(dataset, cfg, project_dir)

-- Loading the config file and the dataset
-- Using device: cuda:0 --
Setting seed for reproducibility: 1111
-- Loading basin dynamics into xarray data set.
  0%|          | 0/4 [00:00<?, ?it/s]

100%|██████████| 4/4 [00:00<00:00, 10.42it/s]


# Loop over basins

In [None]:
for basin in tqdm(dataset.basins, disable=cfg.disable_pbar, file=sys.stdout):

    ds_calib, time_idx0 = get_calibration_dataset(cfg, dataset, basin)

    # NOTE(review): this discards the start index just returned by
    # get_calibration_dataset; confirm that starting at index 0 is intended.
    time_idx0 = 0
    model_concept = get_concept_model(cfg, ds_calib, interpolators, time_idx0,
                                      dataset.scaler, odesmethod=cfg.odesmethod)

    # Network sizes follow the configured input/output variables.
    input_size = len(input_vars)
    output_size = len(output_vars)

    # Bounds of the conceptual model's parameters, in a fixed order.
    params_bounds = OrderedDict(model_concept.cfg.params_bounds)

    # Create the PINN model; the optimiser updates both the MLP weights and
    # the bounded scalar parameters (all are in model_pinn.parameters()).
    model_pinn = HydrologyPINN(input_size, output_size,
                               hidden_layers=HIDDEN_LAYERS, params_bounds=params_bounds)
    optimizer = optim.Adam(model_pinn.parameters(), lr=LR)

    # The calibration data do not change across epochs, so build the input
    # and target tensors once per basin instead of once per epoch.
    df_calib = ds_calib.to_dataframe()
    input_data = torch.tensor(df_calib[input_vars].values, dtype=torch.float32)
    observed_data = torch.tensor(df_calib[output_vars].values, dtype=torch.float32)

    model_pinn.train()
    for epoch in range(EPOCHS):
        optimizer.zero_grad()

        # Forward pass through the network.
        predicted_data = model_pinn(input_data)

        # Save a snapshot plot of the current predictions.
        plot_results(observed_data, predicted_data, epoch + 1)

        # Conceptual-model parameters clamped to their bounds, passed on as a
        # list in the order of the ParameterDict.
        predicted_params = model_pinn.get_clamped_params()
        basin_params = list(predicted_params.values())

        # Combined data-fit + conceptual-model loss.
        loss = pinn_loss(predicted_data, observed_data, basin_params, model_concept, basin)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # basin_params is always a list of tensors produced by torch.clamp
        # before optimizer.step(), i.e. the values used for this epoch's loss.
        params_list = [round(param.item(), 4) for param in basin_params]

        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}, params: {params_list}")

Epoch 1/50, Loss: 302.20135498046875, params: [986.0253, 1380.3936, 0.2572, 2083.1235, 132.0493, 9.5271, 7.589, -5.8381]
Epoch 2/50, Loss: 266.2840881347656, params: [986.0253, 1380.3936, 0.2472, 2083.1335, 132.0393, 9.5171, 7.599, -5.8281]
Epoch 3/50, Loss: 300.8600769042969, params: [986.0253, 1380.3936, 0.2396, 2083.1262, 132.0318, 9.5244, 7.5916, -5.8354]
Epoch 4/50, Loss: 315.8741760253906, params: [986.0253, 1380.3936, 0.2338, 2083.1206, 132.026, 9.5302, 7.5859, -5.8412]
Epoch 5/50, Loss: 324.9161682128906, params: [986.0253, 1380.3936, 0.2292, 2083.116, 132.0213, 9.5349, 7.5812, -5.8459]
Epoch 6/50, Loss: 284.9862365722656, params: [986.0253, 1380.3936, 0.2254, 2083.1121, 132.0173, 9.5389, 7.5772, -5.8498]
Epoch 7/50, Loss: 331.5143127441406, params: [986.0253, 1380.3936, 0.2221, 2083.1086, 132.0139, 9.5423, 7.5738, -5.8532]
Epoch 8/50, Loss: 316.1024169921875, params: [986.0253, 1380.3936, 0.2192, 2083.1057, 132.0109, 9.5453, 7.5708, -5.8562]
Epoch 9/50, Loss: 315.001953125, pa