# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from tqdm import tqdm
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import pandas as pd
import random

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Locate the project root: the notebook's working directory sits three levels
# below it. Appending it to sys.path lets the `src.*` imports below resolve.
current_dir = Path.cwd()
project_root = current_dir.parent.parent.parent
project_dir = str(project_root)

print(project_dir)
sys.path.append(project_dir)

from src.thn_run import (
    _load_cfg_and_ds,
    get_basin_interpolators,
    get_calibration_dataset,
)

from src.modelzoo_concept import get_concept_model

from utils_hydroPINNs import *

/home/ame805/torchHydroNodes


# Constants

In [None]:
# Run configuration (path relative to the project root) and the dataset
# variables used as network inputs and as the training target.
cfg_file = "examples/config_run_calibrate_test.yml"
input_vars = ['prcp', 'tmean', 'dayl']
output_vars = ['obs_runoff']

# Five hidden layers of 64 units each (32 and 128 units per layer were the
# alternatives tried).
HIDDEN_LAYERS = [64] * 5

# Learning-rate settings consumed by setup_optimizer_and_scheduler.
# A separate, higher rate for model.params ("params_lr": 0.1) was tried and is
# currently disabled; earlier runs used a plain scalar LR of 1e-3.
LR = dict(
    initial=0.001,          # base learning rate
    decay=0.5,              # learning-rate decay factor
    decay_step_fraction=4,  # used to calculate the scheduler step_size
)

EPOCHS = 100000
COLL_FRACT = 0.6  # fraction of the time series used as collocation points

# Load and prepare data

In [3]:
# Load the run configuration and the basin dataset for the conceptual model.
config_path = Path(project_dir) / cfg_file
cfg, dataset = _load_cfg_and_ds(config_path, model='conceptual')

# Build the basin interpolators (later handed to get_concept_model).
interpolators = get_basin_interpolators(dataset, cfg, project_dir)

-- Loading the config file and the dataset
-- Using device: cuda:0 --
Setting seed for reproducibility: 1111
-- Loading basin dynamics into xarray data set.
100%|██████████| 2/2 [00:00<00:00, 12.04it/s]




# Loop over basins

In [4]:
# Train one PINN per basin: build the conceptual model for the basin, fit the
# network with pinn_loss against the observed runoff, then write the network
# weights, the input scaler and the final clamped parameters to the CWD.
for basin in tqdm(dataset.basins, disable=cfg.disable_pbar, file=sys.stdout):
    # # Skip basins that have already been processed
    # if basin in processed_basins:
    #     print(f"Basin {basin} already processed, skipping.")
    #     continue

    ds_calib, time_idx0 = get_calibration_dataset(cfg, dataset, basin)

    # NOTE(review): this discards the time_idx0 returned just above; confirm
    # the override is intentional.
    time_idx0 = 0
    model_concept = get_concept_model(cfg, ds_calib, interpolators, time_idx0,
                                      dataset.scaler, odesmethod=cfg.odesmethod)

    # Network input/output widths follow the variable lists
    input_size = len(input_vars)
    output_size = len(output_vars)

    # Parameter bounds of the conceptual model, order preserved
    params_bounds = OrderedDict(model_concept.cfg.params_bounds)

    # xarray Dataset -> pandas DataFrame
    df_calib = ds_calib.to_dataframe()

    # Inputs and observed targets as float32 tensors
    input_data = torch.tensor(df_calib[input_vars].values, dtype=torch.float32)
    observed_data = torch.tensor(df_calib[output_vars].values, dtype=torch.float32)

    # The scaler is only fitted here; it is passed to the PINN (and saved
    # below) instead of being applied to input_data in this cell.
    scaler = StandardScaler().fit(input_data)

    # Create the PINN model
    # model_pinn = HydrologyPINN_v0(input_size, output_size,
    #                            hidden_layers=HIDDEN_LAYERS,
    #                            params_bounds=params_bounds,
    #                            scaler=scaler)
    model_pinn = HydrologyPINN_v1(input_size, hidden_layers=HIDDEN_LAYERS,
                                  params_bounds=params_bounds,
                                  scaler=scaler)

    # optimizer = optim.Adam(model_pinn.parameters(), lr=LR)
    optimizer, scheduler = setup_optimizer_and_scheduler(model_pinn, lr=LR, epochs=EPOCHS)

    # Generate collocation indices based on the length of this basin's series
    total_length = len(ds_calib.obs_runoff)
    collocation_indices = get_collocation_indices(total_length, COLL_FRACT,
                                                  seed=model_concept.cfg.seed)

    # Defaults so the save step below still works when EPOCHS == 0
    param_names, params_list = [], []

    # Train the model
    for epoch in range(EPOCHS):

        model_pinn.train()
        optimizer.zero_grad()

        # Forward pass through the network
        predicted_data = model_pinn(input_data)

        # # Update the plot with new predicted values
        # plot_results(observed_data, predicted_data, basin, epoch+1)

        # Parameters clamped within their bounds: a mapping name -> value
        predicted_params = model_pinn.get_clamped_params()
        param_names = list(predicted_params.keys())
        basin_params = list(predicted_params.values())

        # Physics-informed loss
        loss = pinn_loss(predicted_data, observed_data, basin_params,
                         model_concept, basin, epoch+1, collocation_indices)

        # Backward pass: gradients of the loss w.r.t. the network parameters
        loss.backward()

        # Diagnostic: report when no gradient reached predicted_data[1].
        # NOTE(review): for a non-leaf tensor .grad stays None unless the model
        # called retain_grad() on it -- confirm this check is meaningful.
        if predicted_data[1].grad is None:
            print("Gradients for S1:", predicted_data[1].grad)

        # Clip gradients to stabilize training
        torch.nn.utils.clip_grad_norm_(model_pinn.parameters(), max_norm=5.0)

        # Step the optimizer to update parameters
        optimizer.step()

        # Update the learning rate if a scheduler is provided
        if scheduler is not None:
            scheduler.step()

        # Round parameter values (tensor or plain number) for logging/saving
        params_list = [
            round(p.item(), 4) if isinstance(p, torch.Tensor) else round(p, 4)
            for p in basin_params
        ]

        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}, params: {params_list}")

    # Save the model
    torch.save(model_pinn.state_dict(), f"{basin}_pinn_model.pth")
    print(f"Model saved for basin {basin}")

    # Save the scaler
    torch.save(scaler, f"{basin}_scaler.pth")
    print(f"Scaler saved for basin {basin}")

    # Save the parameters as CSV, keeping the name next to each value so the
    # file can be read without knowing the parameter order.
    params_df = pd.DataFrame({'param': param_names, 'value': params_list})
    params_df.to_csv(f"{basin}_params.csv", index=False)

smax: Parameter containing:
tensor([1709.4611])
s1: tensor([[0.],
        [0.]], grad_fn=<SliceBackward0>) tensor([[0.],
        [0.]], grad_fn=<SliceBackward0>)
smax - s1 tensor([[1709.4611],
        [1709.4611]], grad_fn=<SliceBackward0>) tensor([[1709.4611],
        [1709.4611]], grad_fn=<SliceBackward0>)
result: tensor([[3.7027e-13],
        [3.7027e-13]], grad_fn=<SliceBackward0>) tensor([[3.7027e-13],
        [3.7027e-13]], grad_fn=<SliceBackward0>)
Epoch 1/100000, Loss: 6.357026100158691, params: [0.0, 1303.0043, 0.0167, 1709.4611, 18.47, 2.6745, 0.1757, -2.093]
smax: Parameter containing:
tensor([1709.4611])
s1: tensor([[0.],
        [0.]], grad_fn=<SliceBackward0>) tensor([[0.],
        [0.]], grad_fn=<SliceBackward0>)
smax - s1 tensor([[1709.4611],
        [1709.4611]], grad_fn=<SliceBackward0>) tensor([[1709.4611],
        [1709.4611]], grad_fn=<SliceBackward0>)
result: tensor([[3.7027e-13],
        [3.7027e-13]], grad_fn=<SliceBackward0>) tensor([[3.7027e-13],
        [3.70

KeyboardInterrupt: 

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>