# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from tqdm import tqdm
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import pandas as pd
import random

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Get the current working directory (the directory the notebook is run from)
current_dir = Path.cwd()

# Make the project importable: the project root is taken to be three levels
# above the working directory, and is appended to sys.path so that the
# `src.*` imports that follow can be resolved.
project_dir = str(current_dir.parent.parent.parent)
print(project_dir)
sys.path.append(project_dir)

from src.thn_run import (
    _load_cfg_and_ds,
    get_basin_interpolators,
    get_calibration_dataset,
)

from src.modelzoo_concept import get_concept_model

/home/ame805/torchHydroNodes


# Constants

In [2]:
# Run configuration file; the path is relative to project_dir and is joined
# with it when the config is loaded.
cfg_file = "examples/config_run_calibrate_test.yml"
# Columns of the calibration dataframe used as network inputs / training target
input_vars = ['prcp', 'tmean', 'dayl']
output_vars = ['obs_runoff']

# Width of each hidden layer of the PINN network (one entry per layer).
# Other settings that were tried:
# HIDDEN_LAYERS = [64, 64, 64, 64, 64]
# HIDDEN_LAYERS = [32, 32, 32, 32, 32]
HIDDEN_LAYERS = [128, 128, 128, 128, 128]

# Learning rate: either a scalar (constant rate, no scheduler), e.g.
# LR = 1e-3
# or a dict describing a step-decay schedule (see setup_optimizer_and_scheduler):
LR = {
    "initial": 0.01,  # learning rate handed to Adam
    "decay": 0.5,  # multiplicative decay factor (StepLR gamma)
    "decay_step_fraction": 2,  # StepLR step_size = epochs // this value (at least 1)
}

EPOCHS = 400
COLL_FRACT = 0.6 # fraction of time series to use as collocation points

# Classes and functions

In [3]:
class HydrologyPINN(nn.Module):
    """
    A PyTorch neural network model for hydrology-based physics-informed neural networks (PINNs).

    The module bundles two things that are optimised together:
      * ``network``: a fully connected network with Tanh activations that maps
        the input features to the output features, and
      * ``params``: one trainable scalar per entry of ``params_bounds``,
        initialised uniformly at random inside its bounds.
    """

    def __init__(self, input_size, output_size,
                 hidden_layers=None,
                 params_bounds=None, scaler=None):
        """
        Initialize the neural network with variable hidden layers.

        Args:
            input_size (int): The number of input features.
            output_size (int): The number of output features.
            hidden_layers (list of int, optional): A list where each element represents the number
                of neurons in each hidden layer. Defaults to the module-level ``HIDDEN_LAYERS``.
            params_bounds (mapping, optional): Maps each parameter name to its ``(lower, upper)``
                bounds. ``None`` (the default) means the model carries no trainable
                physical parameters.
            scaler (optional): Fitted object exposing ``transform(array)`` (e.g. a scikit-learn
                scaler). When given, inputs are passed through it before the network.

        Raises:
            ValueError: If ``hidden_layers`` is empty.
        """
        super().__init__()

        # Resolve the default lazily so the signature does not capture a
        # mutable module-level list at definition time.
        if hidden_layers is None:
            hidden_layers = HIDDEN_LAYERS
        hidden_layers = list(hidden_layers)
        if not hidden_layers:
            raise ValueError("hidden_layers must contain at least one layer size")

        # The declared default (None) used to crash on `.items()`; treat it as
        # "no physical parameters".
        if params_bounds is None:
            params_bounds = OrderedDict()

        # Save the scaler
        self.scaler = scaler

        # Set up the neural network layers: Linear -> Tanh for every hidden
        # layer, followed by a plain Linear output layer.
        widths = [input_size, *hidden_layers]
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.Tanh())
        layers.append(nn.Linear(widths[-1], output_size))
        self.network = nn.Sequential(*layers)

        # Now initialize parameters with the same order guaranteed
        self.params = nn.ParameterDict(OrderedDict(
            (name, nn.Parameter(torch.empty(1).uniform_(bounds[0], bounds[1]),
                                requires_grad=True))
            for name, bounds in params_bounds.items()
        ))
        self.params_bounds = params_bounds  # Store bounds for clamping

    def forward(self, x):
        """
        Run the network on ``x``, scaling it first if a scaler was provided.

        The scaling step goes through NumPy, so no gradient flows back to ``x``
        through it; gradients with respect to the network weights are unaffected.
        """
        if self.scaler is not None:
            # detach() so that inputs that happen to require grad can still be
            # converted to NumPy.
            scaled = self.scaler.transform(x.detach().cpu().numpy())
            x = torch.tensor(scaled, dtype=x.dtype, device=x.device)
        return self.network(x)

    def get_clamped_params(self):
        """
        Return parameters clamped within their bounds.

        Returns:
            dict: parameter name -> tensor clamped to ``params_bounds[name]``,
            in the insertion order of ``self.params``.
        """
        return {name: torch.clamp(param, *self.params_bounds[name]) for name, param in self.params.items()}
 
def setup_optimizer_and_scheduler(model, lr=LR, epochs=EPOCHS):
    """
    Build an Adam optimizer for ``model`` and, optionally, a step-decay scheduler.

    Args:
        model (nn.Module): The PyTorch model whose parameters are optimized.
        lr (float, int or dict): A scalar gives a constant learning rate and no
            scheduler. Otherwise a mapping with the keys ``"initial"`` (starting
            rate), ``"decay"`` (StepLR gamma) and ``"decay_step_fraction"``
            (the rate decays every ``epochs // decay_step_fraction`` epochs).
        epochs (int): Total number of training epochs, used to derive the
            scheduler step size.

    Returns:
        tuple: ``(optimizer, scheduler)`` where ``scheduler`` is ``None`` for a
        scalar learning rate and a ``StepLR`` instance otherwise.
    """
    if not isinstance(lr, (float, int)):
        # Schedule described by a dictionary: start at lr["initial"] ...
        adam = optim.Adam(model.parameters(), lr=lr["initial"])
        # ... and decay every `decay_every` epochs (never less than one epoch).
        decay_every = max(1, epochs // lr["decay_step_fraction"])
        step_decay = optim.lr_scheduler.StepLR(adam, step_size=decay_every, gamma=lr["decay"])
        return adam, step_decay

    # Plain scalar learning rate: constant rate, nothing to schedule.
    return optim.Adam(model.parameters(), lr=lr), None

# Define the data loss and physics-based loss functions
def data_loss(predicted, observed):
    """Return the mean squared error between network predictions and observations."""
    return nn.functional.mse_loss(predicted, observed)

def physics_loss(model, basin, predicted_params, observed, collocation_indices):
    """
    Calculate the physics-based loss on collocation points.

    Args:
        model: The ODE model for simulation; must expose
            ``run(basin, basin_params=..., use_grad=True)``.
        basin: The current basin being processed.
        predicted_params: Parameters predicted by the PINN.
        observed (torch.Tensor): Observed data tensor, indexed along dim 0 by time step.
        collocation_indices (torch.Tensor): Indices (dtype long) for collocation points.

    Returns:
        tuple: ``(loss, runoff_sim)`` where ``loss`` is the mean squared difference
        between simulated and observed runoff at the collocation points and
        ``runoff_sim`` is the full simulated runoff series (the last element of
        the model output).
    """
    # Run the model simulation and get the full time series
    simulated = model.run(basin, basin_params=predicted_params, use_grad=True)
    runoff_sim = simulated[-1]  # Extract the runoff component from the model output

    # Subset observed and simulated data using collocation indices
    observed_colloc = torch.index_select(observed, 0, collocation_indices)
    runoff_sim_colloc = torch.index_select(runoff_sim, 0, collocation_indices)

    # Guard against silent broadcasting: a 1-D simulated series of shape (n,)
    # minus an (n, 1) observation column would otherwise expand to an (n, n)
    # matrix of pairwise differences. When both hold the same number of
    # elements, compare them element by element instead.
    if (runoff_sim_colloc.shape != observed_colloc.shape
            and runoff_sim_colloc.numel() == observed_colloc.numel()):
        runoff_sim_colloc = runoff_sim_colloc.reshape(observed_colloc.shape)

    # Calculate the physics-based penalty for the collocation points
    physics_penalty = runoff_sim_colloc - observed_colloc

    return torch.mean(physics_penalty ** 2), runoff_sim

def pinn_loss(predicted, observed, predicted_params, model, basin, epoch, collocation_indices):
    """
    Combine the data, physics and initial-condition terms into one training loss.

    Besides computing the loss, every call prints the three components and
    redraws the diagnostic figure via ``plot_results`` (period='train').

    Args:
        predicted (torch.Tensor): Network output for the whole series.
        observed (torch.Tensor): Observed series.
        predicted_params: Physical parameters handed to the ODE model.
        model: ODE model forwarded to ``physics_loss``.
        basin: Identifier of the basin being trained.
        epoch (int): Epoch number shown in the plot title.
        collocation_indices (torch.Tensor): Time steps used for the physics term.

    Returns:
        torch.Tensor: Unweighted sum of data loss, physics loss and the MSE of
        the first time step (initial condition).
    """
    mse_data = data_loss(predicted, observed)
    mse_physics, ode_series = physics_loss(
        model, basin, predicted_params, observed, collocation_indices
    )

    # Initial-condition term: match the very first time step.
    first_step_criterion = nn.MSELoss()
    mse_ic = first_step_criterion(predicted[0], observed[0])

    print(f'DataLoss: {mse_data:.3e}, PhysicsLoss: {mse_physics:.3e}, IC Loss: {mse_ic:.3e}')
    plot_results(observed, predicted, ode_series, basin, epoch, period='train')

    total = mse_data + mse_physics
    total = total + mse_ic
    return total

def plot_results(observed, predicted, simulated_ode, basin, epoch, period='train'):
    """
    Save a figure comparing observed, network-predicted and ODE-simulated series.

    The figure is written to ``{basin}_pins_simulation.png`` in the current
    working directory (overwritten on every call) and then closed, so repeated
    calls (e.g. once per epoch) do not accumulate open figures.

    Args:
        observed (torch.Tensor): Observed series.
        predicted (torch.Tensor): Network output.
        simulated_ode (torch.Tensor): Runoff simulated by the ODE model.
        basin: Basin identifier, used in the output file name.
        epoch (int): Epoch number shown in the title.
        period (str): Label shown in the title (default 'train').
    """
    nse_nn = nse_loss(predicted, observed)
    nse_ode = nse_loss(simulated_ode, observed)

    # Move everything to host NumPy arrays once; this also makes plotting work
    # for tensors that live on an accelerator or carry gradients.
    observed_np = observed.detach().cpu().numpy()
    predicted_np = predicted.detach().cpu().numpy()
    simulated_np = simulated_ode.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(observed_np, label='Observed', color='blue')
        ax.plot(predicted_np, label=f'Predicted_NN  (NSE: {nse_nn:.3f})',
                color='red', linestyle='--')
        ax.scatter(range(len(simulated_np)), simulated_np,
                   label=f'Simulated_ODE (NSE: {nse_ode:.3f})', color='green',
                   marker='o', s=10)

        # Plot a marker on the first value of simulated ODE
        ax.scatter(0, simulated_np[0], color='k', marker='X', s=50)

        ax.legend()
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Values')
        ax.set_title(f'Observed vs. Predicted | {period} | epoch {epoch}')
        fig.tight_layout()
        fig.savefig(f'{basin}_pins_simulation.png')
    finally:
        # plt.clf() only clears the axes and keeps the figure alive; close it
        # so that one figure is not leaked per call.
        plt.close(fig)

def nse_loss(predicted, observed):
    """
    Return the Nash-Sutcliffe efficiency of ``predicted`` against ``observed``.

    Computed as ``1 - SSE / (SST + eps)``, where SSE is the sum of squared
    errors, SST the sum of squared deviations of ``observed`` from its mean and
    ``eps`` the float32 machine epsilon (avoids division by zero for a
    constant observed series).
    """
    eps = torch.finfo(torch.float32).eps
    squared_error = (predicted - observed) ** 2
    deviation_from_mean = observed - observed.mean()
    denominator = (deviation_from_mean ** 2).sum() + eps
    return 1 - squared_error.sum() / denominator

def get_collocation_indices(total_length, coll_fraction=0.6, seed=42):
    """
    Generate random indices for collocation points based on a fraction of the total time series,
    ensuring that the points are distributed over the entire domain.

    The draw uses a private random generator seeded with ``seed``; it returns
    the same indices as seeding the module-level generator would, but leaves
    the global ``random`` state untouched.

    Args:
        total_length (int): Total length of the time series.
        coll_fraction (float): Fraction of time series to use as collocation points.
        seed (int): Random seed for reproducibility.

    Returns:
        torch.Tensor: 1-D long tensor with the sorted, unique collocation
        indices; empty when ``int(total_length * coll_fraction)`` is 0.

    Raises:
        ValueError: If more points are requested than the series contains
            (``coll_fraction`` > 1) or the requested count is negative.
    """
    # Private generator: reproducible without reseeding the global RNG
    rng = random.Random(seed)

    # Calculate the number of collocation points based on the fraction
    num_collocation_points = int(total_length * coll_fraction)

    # Randomly sample indices across the full range of the time series,
    # sorted to maintain order in time
    collocation_indices = sorted(rng.sample(range(total_length), num_collocation_points))

    print(f'Number of collocation points: {num_collocation_points}')
    if collocation_indices:
        # The list is sorted, so the extremes are its two ends. Skipped for an
        # empty selection, where max()/min() used to raise.
        print(f'Max index: {collocation_indices[-1]}')
        print(f'Min index: {collocation_indices[0]}')

    return torch.tensor(collocation_indices, dtype=torch.long)

##################################################



# Load and prepare data

In [4]:
# Load the run configuration and the dataset from the YAML file located under
# project_dir. NOTE(review): model='conceptual' presumably selects the
# conceptual-model setup -- confirm in src.thn_run.
cfg, dataset = _load_cfg_and_ds(Path(project_dir) / cfg_file, model='conceptual')

# Get the basin interpolators (handed to get_concept_model for every basin in
# the training loop)
interpolators = get_basin_interpolators(dataset, cfg, project_dir)

-- Loading the config file and the dataset
-- Using device: cuda:0 --
Setting seed for reproducibility: 1111
-- Loading basin dynamics into xarray data set.
100%|██████████| 4/4 [00:00<00:00, 11.92it/s]


# Loop over basins

In [None]:
# For every basin: build the conceptual (ODE) model, train a PINN together with
# the physical parameters, then save the network weights, the input scaler and
# the final parameter values to the current working directory.
for basin in tqdm(dataset.basins, disable=cfg.disable_pbar, file=sys.stdout):
    # # Skip basins that have already been processed
    # if basin in processed_basins:
    #     print(f"Basin {basin} already processed, skipping.")
    #     continue

    ds_calib, time_idx0 = get_calibration_dataset(cfg, dataset, basin)

    # NOTE(review): the start index returned above is discarded and the
    # simulation always starts at index 0 -- confirm this is intended.
    time_idx0 = 0
    model_concept = get_concept_model(cfg, ds_calib, interpolators, time_idx0,
                                      dataset.scaler, odesmethod=cfg.odesmethod)

    # Get the input and output sizes
    input_size = len(input_vars)
    output_size = len(output_vars)

    # Get the parameter bounds
    params_bounds = OrderedDict(model_concept.cfg.params_bounds)

    # DS to DF
    df_calib = ds_calib.to_dataframe()

    # Get the input and output data
    input_data = torch.tensor(df_calib[input_vars].values, dtype=torch.float32)
    observed_data = torch.tensor(df_calib[output_vars].values, dtype=torch.float32)

    # Fit the input scaler; the scaling itself is applied inside the model's
    # forward pass, so input_data stays unscaled here.
    scaler = StandardScaler().fit(input_data)

    # Create the PINN model
    model_pinn = HydrologyPINN(input_size, output_size,
                               hidden_layers=HIDDEN_LAYERS,
                               params_bounds=params_bounds,
                               scaler=scaler)

    optimizer, scheduler = setup_optimizer_and_scheduler(model_pinn, lr=LR, epochs=EPOCHS)

    # Generate collocation indices based on the length of the time series for this basin
    total_length = len(ds_calib.obs_runoff)
    print(f"Total length of time series: {total_length}")
    collocation_indices = get_collocation_indices(total_length, COLL_FRACT,
                                                  seed=model_concept.cfg.seed)

    # Train the model (full-batch: one optimizer step per epoch)
    for epoch in range(EPOCHS):

        model_pinn.train()
        optimizer.zero_grad()

        # Forward pass: Simulate data through the network
        predicted_data = model_pinn(input_data)

        # Clamp the parameters within their bounds
        predicted_params = model_pinn.get_clamped_params()

        # Extract the parameters from the model, in definition order
        basin_params = list(predicted_params.values())

        # Calculate the loss
        loss = pinn_loss(predicted_data, observed_data, basin_params,
                         model_concept, basin, epoch+1, collocation_indices)

        # Backward pass: Compute the gradient of the loss with respect to model parameters
        loss.backward()
        optimizer.step()

        # Update the learning rate if a scheduler is provided
        if scheduler is not None:
            scheduler.step()

        # Values used to compute this epoch's loss (i.e. before the step just
        # taken); every entry is a one-element tensor.
        params_list = [round(param.item(), 4) for param in basin_params]

        # Print the result
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}, params: {params_list}")

    # Save the model
    torch.save(model_pinn.state_dict(), f"{basin}_pinn_model.pth")
    print(f"Model saved for basin {basin}")

    # Save the scaler
    torch.save(scaler, f"{basin}_scaler.pth")
    print(f"Scaler saved for basin {basin}")

    # Save the parameters as csv. They are re-read from the model so that the
    # file reflects the last optimizer step (and therefore matches the saved
    # state_dict) and is still written when no epoch was run. The 'value'
    # column stays first; 'param' adds the parameter names.
    final_params = model_pinn.get_clamped_params()
    params_df = pd.DataFrame({
        'value': [round(param.item(), 4) for param in final_params.values()],
        'param': list(final_params.keys()),
    })
    params_df.to_csv(f"{basin}_params.csv", index=False)

Total length of time series: 1827
Number of collocation points: 1096
Max index: 1826
Min index: 0
DataLoss: 6.906e+00, PhysicsLoss: 2.436e+01, IC Loss: 1.852e-04
Epoch 1/400, Loss: 31.263368606567383, params: [421.1609, 1217.3712, 0.2192, 1329.0988, 168.825, 1.8139, 4.6648, -5.4678]
DataLoss: 4.111e+00, PhysicsLoss: 2.876e+01, IC Loss: 1.904e+00
Epoch 2/400, Loss: 34.77211380004883, params: [421.1609, 1217.3712, 0.2292, 1329.0887, 168.835, 1.8239, 4.6548, -5.4778]
DataLoss: 1.092e+01, PhysicsLoss: 2.387e+01, IC Loss: 3.076e-01
Epoch 3/400, Loss: 35.102359771728516, params: [421.1609, 1217.3712, 0.2219, 1329.0808, 168.842, 1.8217, 4.6546, -5.486]
DataLoss: 4.538e+00, PhysicsLoss: 2.350e+01, IC Loss: 3.862e+00
