# Training a PINN for PINNSim

This notebook describes the basic elements of the workflow for training a PINN. Its structure mimics `pinnsim.learning_functions.workflow`; however, some sections of the workflow are simplified here for illustration. At the end of the notebook, we briefly explain how `pinnsim.learning_functions.workflow` can be used at a high level; for the details, we refer to the function itself.

### Imports

In [None]:
import matplotlib.pyplot as plt
import torch
from pinnsim import LEARNING_DATA_PATH
from pinnsim.configurations.hyperparameter_configs import (
    convert_sweep_config_to_run_config,
    default_hyperparameter_setup,
)
from pinnsim.learning_functions.loss_normed_state import LossNormedState
from pinnsim.learning_functions.setup_functions import (
    setup_dataset,
    setup_nn_model,
    setup_optimiser,
    setup_schedulers,
)

## Setup

We begin by defining a config file that contains all relevant information, e.g., the generator name (`generator_name`) or the number of neurons per hidden layer (`hidden_layer_size`). Ideally, all variations to the training setup can be adjusted in the config file as this allows us to easily modify and track the training setups.
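
To illustrate the idea of a single config object, the sketch below flattens a sweep-style dictionary into a run config with attribute access. It assumes the WandB sweep layout (`parameters` containing `value` entries); the helper `flatten_sweep_config` and all values are hypothetical and do not reproduce `convert_sweep_config_to_run_config`.

```python
from types import SimpleNamespace

# Hypothetical sweep-style config; the names mirror the text, the values are made up.
sweep_config_example = {
    "parameters": {
        "generator_name": {"value": "example_generator"},
        "hidden_layer_size": {"value": 32},
    }
}


def flatten_sweep_config(sweep_config):
    """Collect the fixed `value` entries into an object with attribute access."""
    fixed_values = {
        name: entry["value"]
        for name, entry in sweep_config["parameters"].items()
        if "value" in entry  # swept entries (e.g. "values") are skipped in this sketch
    }
    return SimpleNamespace(**fixed_values)


config_example = flatten_sweep_config(sweep_config_example)
print(config_example.generator_name, config_example.hidden_layer_size)
```

Keeping every setting in one object means that a training run is fully described by the config that produced it.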

In [None]:
sweep_config = default_hyperparameter_setup()
config = convert_sweep_config_to_run_config(sweep_config=sweep_config)

print(config.generator_name)
print(config.hidden_layer_size)

### Datasets

Based on this config, we load the datasets and the component model, i.e., the generator model. We use `dataset_training` and `dataset_collocation` to optimise the NN parameters, while `dataset_validation` is used to track the training process and to select the best-performing model. After training, `dataset_testing` is used to assess the performance in an unbiased way. As the datasets are simulated, we want to avoid re-simulating them for every training run. Therefore, we simulate large datasets once, store them in `LEARNING_DATA_PATH` (specified in `pinnsim_utils.__init__.py`), and subsequently only sample from them. For more details, see `setup_dataset`.
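
The "simulate once, sample many times" pattern can be sketched as follows. This is a minimal illustration under assumptions: the file name, the dictionary keys, the random stand-in data, and the helper `sample_stored_dataset` are all hypothetical and not taken from `setup_dataset`.

```python
import tempfile
from pathlib import Path

import torch


def sample_stored_dataset(file_path, n_points, seed):
    """Load a stored tensor dictionary and draw `n_points` rows without replacement."""
    stored = torch.load(file_path)
    n_stored = stored["time"].shape[0]
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(n_stored, generator=generator)[:n_points]
    # The same indices are used for every tensor so that the rows stay aligned.
    return {name: values[indices] for name, values in stored.items()}


with tempfile.TemporaryDirectory() as directory:
    file_path = Path(directory) / "large_dataset.pt"
    # Stand-in for an expensive simulation that is run only once.
    large_dataset = {"time": torch.rand(1000, 1), "state_result": torch.rand(1000, 2)}
    torch.save(large_dataset, file_path)
    subset = sample_stored_dataset(file_path, n_points=100, seed=0)

print(subset["time"].shape, subset["state_result"].shape)
```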

In [None]:
(
    dataset_training,
    dataset_validation,
    dataset_testing,
    dataset_collocation,
    component_model,
) = setup_dataset(config=config, data_path=LEARNING_DATA_PATH)

The datasets are constructed similarly to datasets from `torch.utils.data` (see `pinnsim.dataset_functions.dataset_object.py`). This allows indexing the dataset as shown below. The returned tuple contains
- time $t$
- initial_state $x_0$
- control_input $u$
- voltage_parametrisation $\Xi$
- result_state $\hat{x}(t)$

that is, all values that we need for training.
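
A minimal sketch of such a dataset object is given below. The class `TrajectoryDataset` and all tensor dimensions are hypothetical; the sketch only shows how integer and slice indexing can return the five-element tuple described above.

```python
import torch
from torch.utils.data import Dataset


class TrajectoryDataset(Dataset):
    """Hypothetical stand-in: stores five aligned tensors and returns them as a tuple."""

    def __init__(self, time, state_initial, control_input, voltage_parametrisation, state_result):
        self.tensors = (time, state_initial, control_input, voltage_parametrisation, state_result)

    def __len__(self):
        return self.tensors[0].shape[0]

    def __getitem__(self, index):
        # Works for an integer (one point) and for a slice such as [:] (all points).
        return tuple(tensor[index] for tensor in self.tensors)


n_points = 4
dataset_example = TrajectoryDataset(
    time=torch.linspace(0.0, 1.0, n_points).reshape(-1, 1),
    state_initial=torch.zeros(n_points, 2),
    control_input=torch.ones(n_points, 1),
    voltage_parametrisation=torch.zeros(n_points, 3),
    state_result=torch.zeros(n_points, 2),
)

time, state_initial, control_input, voltage_parametrisation, state_result = dataset_example[1]
print(time.shape, dataset_example[:][0].shape)
```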

In [None]:
dataset_training[1]

### NN model

For the setup of the NN model, we need a few config parameters (the number and size of the hidden layers and, optionally, a seed to control the initialisation of the NN's parameters). Additionally, we provide the component model, as we will use it to calculate the physics loss $\mathcal{L}_c$. Optionally, the training dataset can be used to set an internal normalisation of the NN.
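
The role of the seed and of a data-based normalisation can be sketched with plain `torch`. The helper `build_network`, the `tanh` activation, and the example inputs are assumptions for illustration; `setup_nn_model` may differ in all of these respects.

```python
import torch


def build_network(n_inputs, n_outputs, hidden_layer_size, n_hidden_layers, seed=None):
    """Plain feed-forward network; a seed makes the initial parameters reproducible."""
    if seed is not None:
        torch.manual_seed(seed)  # note: this sets the global torch seed
    layers = []
    width_in = n_inputs
    for _ in range(n_hidden_layers):
        layers += [torch.nn.Linear(width_in, hidden_layer_size), torch.nn.Tanh()]
        width_in = hidden_layer_size
    layers.append(torch.nn.Linear(width_in, n_outputs))
    return torch.nn.Sequential(*layers)


network_a = build_network(4, 2, hidden_layer_size=16, n_hidden_layers=2, seed=1)
network_b = build_network(4, 2, hidden_layer_size=16, n_hidden_layers=2, seed=1)
same_initialisation = all(
    torch.equal(p_a, p_b) for p_a, p_b in zip(network_a.parameters(), network_b.parameters())
)

# Normalisation statistics taken from (made-up) training inputs.
training_inputs = torch.tensor([[0.0, 10.0], [1.0, 30.0], [2.0, 20.0]])
input_mean = training_inputs.mean(dim=0)
input_std = training_inputs.std(dim=0)
normalised_inputs = (training_inputs - input_mean) / input_std
print(same_initialisation, normalised_inputs.mean(dim=0))
```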

We aim to approximate the integration operation
$\begin{align}
\hat{x}(t) = x_0 + \int_0^t f(x(\tau), u, \Xi) \, d\tau
\end{align}$
with the NN; hence, the basic input-output structure that the model should follow is $[t, x_0, u, \Xi] \mapsto \hat{x}$. Calling `nn_model.forward()` with the corresponding inputs predicts the value $\hat{x}$. In the module `pinnsim.learning_functions.dynamical_system_NN.py`, we apply a few adjustments to a simple feed-forward neural network in order to help the learning process. In any case, the above mapping is always maintained.
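
To make the mapping concrete, consider a toy scalar system for which the integral can be solved in closed form. The system, the value of `DECAY_RATE`, and the omission of $\Xi$ are assumptions made purely for illustration. The closed-form `flow_map` plays the role that the NN takes for systems without an analytical solution: it returns the state at time $t$ directly, without stepping through time.

```python
import math

DECAY_RATE = 2.0  # hypothetical parameter of the toy system dx/dt = -DECAY_RATE * x + u


def update_function(state, control_input):
    """Right-hand side f(x, u) of the toy system."""
    return -DECAY_RATE * state + control_input


def integrate_euler(time, state_initial, control_input, n_steps=10_000):
    """Reference: explicit Euler approximation of x_0 + int_0^t f(x, u) dtau."""
    step_size = time / n_steps
    state = state_initial
    for _ in range(n_steps):
        state = state + step_size * update_function(state, control_input)
    return state


def flow_map(time, state_initial, control_input):
    """Closed-form solution: the mapping [t, x_0, u] -> x(t) that a NN would approximate."""
    steady_state = control_input / DECAY_RATE
    return steady_state + (state_initial - steady_state) * math.exp(-DECAY_RATE * time)


state_reference = integrate_euler(time=0.5, state_initial=1.0, control_input=0.4)
state_direct = flow_map(time=0.5, state_initial=1.0, control_input=0.4)
print(abs(state_reference - state_direct))
```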

In [None]:
nn_model = setup_nn_model(
    config=config,
    power_system_model=component_model,
    training_dataset=dataset_training,
)

### Optimiser, scheduler, and loss function

The remaining elements that we need to set up are
- an optimiser: we use an L-BFGS optimiser as implemented in `torch`.
- optionally, schedulers for the learning rate and for the weighting factor $\alpha$ between the data loss $\mathcal{L}_x$ and the physics loss $\mathcal{L}_c$. The total loss is calculated as $\mathcal{L} = \mathcal{L}_x + \alpha \mathcal{L}_c$. In our experience, too large a value for $\alpha$ can lead to problems during the first epochs of the training.
- a loss function: the variables of the state $\hat{x}$ do not all have the same units, hence we define a scaling and compute the loss in the scaled/normed state space. The scaling factors are component-dependent.
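
The effect of computing the loss in a scaled state space can be sketched as follows. The class `ScaledMSELoss` and the scale factors are hypothetical and are not the implementation of `LossNormedState`; the sketch only shows why a scaling is useful when the state variables differ by orders of magnitude.

```python
import torch


class ScaledMSELoss(torch.nn.Module):
    """Mean squared error after dividing each state variable by a scale factor."""

    def __init__(self, state_scale):
        super().__init__()
        # Buffer: moves with the module (e.g. to a GPU) but is not trained.
        self.register_buffer("state_scale", state_scale)

    def forward(self, inputs, targets):
        scaled_error = (inputs - targets) / self.state_scale
        return torch.mean(scaled_error**2)


# Hypothetical scales: first state of order 1, second state of order 100.
loss_function_example = ScaledMSELoss(state_scale=torch.tensor([1.0, 100.0]))
prediction = torch.tensor([[0.1, 10.0]])
target = torch.tensor([[0.0, 0.0]])

loss_scaled = loss_function_example(inputs=prediction, targets=target)
loss_unscaled = torch.mean((prediction - target) ** 2)
print(loss_scaled.item(), loss_unscaled.item())
```

Both states deviate by 10 % of their scale. The scaled loss weights them equally, whereas the unscaled loss is dominated almost entirely by the second state.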

In [None]:
optimiser = setup_optimiser(nn_model=nn_model, config=config)

learning_rate_scheduler, loss_weight_scheduler = setup_schedulers(
    nn_model=nn_model,
    optimiser=optimiser,
    config=config,
)

loss_function = LossNormedState(component_model=component_model)

## Training and evaluation function

Next, we define the functions that constitute an epoch in the training process, namely a training step and an evaluation step.

### Training step

After unpacking the training dataset with simulated points (`dataset`) and the dataset with collocation points (`dataset_collocation`), we define a closure function that will be supplied to the optimiser. 

For the data points, we simply call `nn_model.forward()` and then apply the specified (scaled) loss function. 

For the collocation points, the function `nn_model.forward_lhs_rhs()` provides
- the state prediction `state_prediction_c` $\hat{x}$
- the temporal derivative of the state prediction `d_dt_state_prediction_c` $\frac{d}{dt} \hat{x}$
- the update function $f$ evaluated at the state prediction `f_prediction_c` $f(\hat{x}, u, \Xi)$

The physics loss $\mathcal{L}_c$ is calculated as $\Vert \frac{d}{dt} \hat{x} - f(\hat{x}, u, \Xi) \Vert^2$ and added to the data loss $\mathcal{L}_x$ to form the total loss $\mathcal{L} = \mathcal{L}_x + \alpha \mathcal{L}_c$.
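
The residual $\frac{d}{dt} \hat{x} - f(\hat{x}, u, \Xi)$ can be sketched with automatic differentiation. This is one common way to obtain the time derivative and not necessarily how `forward_lhs_rhs` is implemented; the toy system, `DECAY_RATE`, and the two candidate functions (which stand in for the NN) are assumptions for illustration.

```python
import torch

DECAY_RATE = 0.5  # hypothetical toy system dx/dt = -DECAY_RATE * x


def update_function(state):
    return -DECAY_RATE * state


def physics_loss(candidate, time, state_initial):
    """Mean squared residual d/dt x_hat - f(x_hat) at the collocation points."""
    time = time.clone().requires_grad_(True)
    state_prediction = candidate(time, state_initial)
    # Each row depends only on its own time value, so one backward pass with a
    # vector of ones yields the per-sample time derivative of a scalar state.
    (d_dt_state_prediction,) = torch.autograd.grad(
        outputs=state_prediction,
        inputs=time,
        grad_outputs=torch.ones_like(state_prediction),
        create_graph=True,  # keep the graph so the loss can be backpropagated
    )
    residual = d_dt_state_prediction - update_function(state_prediction)
    return torch.mean(residual**2)


def exact_solution(time, state_initial):
    return state_initial * torch.exp(-DECAY_RATE * time)


def poor_candidate(time, state_initial):
    return state_initial * (1.0 - time)


time_c = torch.linspace(0.0, 1.0, 11, dtype=torch.float64).reshape(-1, 1)
state_initial_c = torch.ones_like(time_c)
loss_exact = physics_loss(exact_solution, time_c, state_initial_c)
loss_poor = physics_loss(poor_candidate, time_c, state_initial_c)
print(loss_exact.item(), loss_poor.item())
```

The exact solution satisfies the differential equation, so its physics loss vanishes up to floating-point error, while the poor candidate is penalised. No simulated target values are needed, which is why collocation points are cheap to generate.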

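As the loss values can become very small, we multiply the loss by a factor `loss_multiplier` so that the internal tolerance settings of the L-BFGS optimiser do not affect the optimisation. The loss returned by `train_epoch` is divided by the same factor again.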
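
The interaction between small loss values and the L-BFGS tolerances can be reproduced on a toy problem. The sketch below uses `torch.optim.LBFGS` with its default `tolerance_grad=1e-7`; the loss function and the fixed multiplier of `1e6` are assumptions chosen to make the effect visible, whereas the training step in this notebook uses `nn_model.epochs_total + 1`.

```python
import torch


def fit_with_lbfgs(loss_multiplier):
    """One `step` of L-BFGS on a tiny loss 1e-8 * (w - 1)^2, scaled by `loss_multiplier`."""
    weight = torch.nn.Parameter(torch.tensor([0.0]))
    optimiser = torch.optim.LBFGS([weight])  # default tolerance_grad=1e-7

    def closure():
        optimiser.zero_grad()
        loss = 1e-8 * torch.sum((weight - 1.0) ** 2) * loss_multiplier
        loss.backward()
        return loss

    # `step` returns the loss of the first closure evaluation.
    loss = optimiser.step(closure)
    # Divide by the multiplier to report the loss on its original scale.
    return weight.item(), (loss / loss_multiplier).item()


weight_unscaled, _ = fit_with_lbfgs(loss_multiplier=1.0)
weight_scaled, _ = fit_with_lbfgs(loss_multiplier=1e6)
print(weight_unscaled, weight_scaled)
```

Without scaling, the initial gradient is already below `tolerance_grad`, so the optimiser returns without updating the parameter. With the multiplier, the same problem is solved and the parameter moves to its optimum at 1.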

In [None]:

def train_epoch(
    dataset,
    dataset_collocation,
    nn_model,
    loss_function,
    optimiser,
):
    nn_model.train()
    time, state_initial, control_input, voltage_parametrisation, state_result = dataset
    (
        time_c,
        state_initial_c,
        control_input_c,
        voltage_parametrisation_c,
        _,
    ) = dataset_collocation
    loss_multiplier = torch.tensor(nn_model.epochs_total + 1)

    def closure():
        optimiser.zero_grad()

        state_prediction = nn_model.forward(
            time=time,
            state_initial=state_initial,
            control_input=control_input,
            voltage_parametrisation=voltage_parametrisation,
        )

        loss_prediction = loss_function(inputs=state_prediction, targets=state_result)

        state_prediction_c, d_dt_state_prediction_c, f_prediction_c = nn_model.forward_lhs_rhs(
            time=time_c,
            state_initial=state_initial_c,
            control_input=control_input_c,
            voltage_parametrisation=voltage_parametrisation_c,
        )
        
        loss_physics = loss_function(
            inputs=d_dt_state_prediction_c, targets=f_prediction_c
        )


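        # Total loss L = L_x + alpha * L_c, scaled by loss_multiplier for L-BFGS
        loss = (
            loss_prediction + nn_model.physics_regulariser * loss_physics
        ) * loss_multiplier
        loss.backward()

        return loss

    # L-BFGS may evaluate the closure several times within one step
    loss = optimiser.step(closure)

    # Undo the scaling so that the reported loss is on its original scale
    return loss / loss_multiplier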

### Evaluation step

The evaluation function simply evaluates the loss on the validation dataset. The NN model should be set to evaluation mode with `nn_model.eval()` in case dropout or batch normalisation layers are used. The line `with torch.no_grad()` indicates that, unlike in training, no gradients are needed for the enclosed function evaluation; this speeds up the computation.
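
Both mechanisms can be checked on a small stand-alone network. The model below is hypothetical and unrelated to `nn_model`; it contains a dropout layer only to make the effect of the evaluation mode visible.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(3, 8),
    torch.nn.Dropout(p=0.5),  # behaves differently in training and evaluation mode
    torch.nn.Linear(8, 1),
)
inputs = torch.rand(5, 3)

model.eval()  # dropout is switched off, so repeated calls agree
with torch.no_grad():  # no computation graph is recorded inside this block
    output_first = model(inputs)
    output_second = model(inputs)

print(torch.equal(output_first, output_second), output_first.requires_grad)
```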

In [None]:
def evaluate_model(
    dataset,
    nn_model,
    loss_function,
):
    nn_model.eval()

    with torch.no_grad():
        (
            time,
            state_initial,
            control_input,
            voltage_parametrisation,
            state_result,
        ) = dataset

        (
            state_prediction,
            d_dt_state_prediction,
            f_state_prediction,
        ) = nn_model.forward_lhs_rhs(
            time=time,
            state_initial=state_initial,
            control_input=control_input,
            voltage_parametrisation=voltage_parametrisation,
        )
    
    loss_prediction = loss_function(inputs=state_prediction, targets=state_result)
    loss_physics = loss_function(inputs=d_dt_state_prediction, targets=f_state_prediction)

    return loss_prediction, loss_physics


## Training loop

The training becomes a simple loop over `train_epoch` and `evaluate_model`. Optionally, we can add logging and/or printing functions and scheduled adjustments of certain parameters, e.g., the learning rate or the loss weighting parameter. If the training process is logged, it can be visualised afterwards.
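
A scheduled adjustment of the loss weighting parameter $\alpha$ could, for example, be a linear ramp. The function `ramp_physics_weight` and its parameters are hypothetical and do not describe the scheduler returned by `setup_schedulers`; the sketch is motivated by the observation above that too large a value for $\alpha$ can cause problems in the first epochs.

```python
def ramp_physics_weight(epoch, weight_final, ramp_epochs):
    """Linearly increase the weight alpha from 0 to `weight_final` over `ramp_epochs`."""
    fraction = min(1.0, epoch / ramp_epochs)
    return weight_final * fraction


weights = [ramp_physics_weight(epoch, weight_final=0.1, ramp_epochs=4) for epoch in range(6)]
print(weights)
```

The weight starts at zero, so the first epochs fit the simulated data only, and it stays at `weight_final` once the ramp is complete.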

In [None]:
logging_list = list()
maximum_epochs = 20

print("Epoch | Loss training | Loss prediction val| Loss physics val")

while nn_model.epochs_total < maximum_epochs:
    loss = train_epoch(
        dataset=dataset_training[:],
        dataset_collocation=dataset_collocation[:],
        nn_model=nn_model,
        loss_function=loss_function,
        optimiser=optimiser,
    )
    
    loss_prediction, loss_physics = evaluate_model(
        dataset=dataset_validation[:],
        nn_model=nn_model,
        loss_function=loss_function,
    )

    learning_rate_scheduler.step()
    loss_weight_scheduler()
    print(
        f" {nn_model.epochs_total + 1:04} |"
        f"      {loss:.2e} |"
        f"           {loss_prediction:.2e} |"
        f"         {loss_physics:.2e}"
    )
    logging_list.append(torch.stack([torch.tensor(nn_model.epochs_total + 1), loss, loss_prediction, loss_physics]))
    nn_model.epochs_total += 1
    

In [None]:
results_log = torch.vstack(logging_list).detach()

plt.plot(results_log[:, 0], results_log[:, 1], label="Training loss")
plt.plot(results_log[:, 0], results_log[:, 2], label="Validation loss data")
plt.plot(results_log[:, 0], results_log[:, 3], label="Validation loss physics")
plt.loglog()
plt.legend()
plt.xlim([1, maximum_epochs])
plt.xlabel("Epoch [-]")
plt.ylabel("Loss [-]")
plt.show()

### Model testing

Finally, the model can be tested, e.g., by evaluating the test loss. In `pinnsim.learning_functions.testing_functions.py`, we also compute the maximum and mean absolute error of each state to improve the interpretability of the learning outcome.
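
Per-state error statistics of this kind can be sketched as follows. The helper `absolute_error_statistics` and the example values are hypothetical and are not the code in `testing_functions.py`.

```python
import torch


def absolute_error_statistics(state_prediction, state_result):
    """Maximum and mean absolute error per state variable (column)."""
    absolute_error = torch.abs(state_prediction - state_result)
    return absolute_error.max(dim=0).values, absolute_error.mean(dim=0)


state_prediction = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
state_result = torch.tensor([[1.5, 10.0], [2.0, 24.0], [2.5, 30.0]])
error_max, error_mean = absolute_error_statistics(state_prediction, state_result)
print(error_max, error_mean)
```

Unlike the scalar test loss, these values are expressed in the units of each state, which makes them easier to interpret.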

In [None]:
def test_model(
    dataset,
    nn_model,
    loss_function,
):
    nn_model.eval()

    with torch.no_grad():
        (
            time,
            state_initial,
            control_input,
            voltage_parametrisation,
            state_result,
        ) = dataset

        state_prediction = nn_model.predict(
            time=time,
            state_initial=state_initial,
            control_input=control_input,
            voltage_parametrisation=voltage_parametrisation,
        )

    loss_testing = loss_function(inputs=state_prediction, targets=state_result)
    print(f"Test loss {loss_testing:.2e}")

    
test_model(
    dataset=dataset_testing[:],
    nn_model=nn_model,
    loss_function=loss_function,
)

## The complete workflow

In `pinnsim.learning_functions`, we add more complexity to the setup described above. `workflow.py` implements the entire workflow so that it can be called from a single config, as shown below. Important additions are the logging functionality using [WandB](https://wandb.ai) and the automatic saving of the best model based on the validation data loss.

In [None]:
from pinnsim.learning_functions.workflow import train
sweep_config = default_hyperparameter_setup()
run_config = convert_sweep_config_to_run_config(sweep_config=sweep_config)

After defining `run_config`, the entire process can be started with the following line. The notebook `training_sweeps` shows how multiple models can easily be trained using this setup.

In [None]:
train(config=run_config)