In [86]:
from NN_models import PINN_backbone
# from simulate import simulate_convergence, simulate_branching, simulate_ridge, simulate_merge, simulate_deflection
from metrics import compute_RMSE, compute_MAE
from utils import set_seed

# Global file for training configs
from configs import PATIENCE, MAX_NUM_EPOCHS, NUM_RUNS, LEARNING_RATE, WEIGHT_DECAY, BATCH_SIZE, N_SIDE, PINN_RESULTS_DIR, W_PINN_DIV_WEIGHT, PINN_LEARNING_RATE

import torch
from torch.func import vmap, jacfwd
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import os
import pandas as pd

### TIMING ###
import time
start_time = time.time()  # Start timing after imports

# Set seed for reproducibility
set_seed(42)

# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

model_name = "PINN"

#########################
### x_train & y_train ###
#########################

# Import all simulation functions
from simulate import (
    simulate_detailed_convergence,
    simulate_detailed_deflection,
    simulate_detailed_curve,
    simulate_detailed_ridges,
    simulate_detailed_branching,
)

# Define simulations as a dictionary with names as keys to function objects
simulations = {
    "convergence_dtl": simulate_detailed_convergence,
    #"deflection_dtl": simulate_detailed_deflection,
    #"curve_dtl": simulate_detailed_curve,
    #"ridges_dtl": simulate_detailed_ridges,
    #"branching_dtl": simulate_detailed_branching,
}

# Load training inputs, not weights_only = True
x_train = torch.load("data/sim_data/x_train_lines_discretised_0to1.pt", weights_only = False).float()

# Storage dictionaries
y_train_dict = {}

# Make y_train_dict: Iterate over all simulation functions
for sim_name, sim_func in simulations.items():

    # Generate training observations
    y_train = sim_func(x_train)
    y_train_dict[sim_name] = y_train  # Store training outputs

    # Print details
    print(f"=== {sim_name.upper()} ===")
    print(f"Training inputs shape: {x_train.shape}")
    print(f"Training observations shape: {y_train.shape}")
    print(f"Training inputs dtype: {x_train.dtype}")
    print()

#######################
### x_test & y_test ###
#######################

print("=== Generating test data ===")

# Choose discretisation that is good for simulations and also for quiver plotting
N_SIDE = N_SIDE

side_array = torch.linspace(start = 0.0, end = 1.0, steps = N_SIDE)
XX, YY = torch.meshgrid(side_array, side_array, indexing = "xy")
x_test_grid = torch.cat([XX.unsqueeze(-1), YY.unsqueeze(-1)], dim = -1)
# long format
x_test = x_test_grid.reshape(-1, 2)

# Storage dictionaries
y_test_dict = {}

# Make y_test_dict: Iterate over all simulation functions
for sim_name, sim_func in simulations.items():

    # Generate test observations
    y_test = sim_func(x_test)
    y_test_dict[sim_name] = y_test  # Store test outputs

    # Print details
    print(f"=== {sim_name.upper()} ===")
    print(f"Test inputs shape: {x_test.shape}")
    print(f"Test observations shape: {y_test.shape}")
    print(f"Test inputs dtype: {x_test.dtype}")
    print()

#####################
### Training loop ###
#####################

# Early stopping parameters
PATIENCE = PATIENCE  # Stop after 50 epochs with no improvement

MAX_NUM_EPOCHS = MAX_NUM_EPOCHS # 2000

# Number of training runs for mean and std of metrics
NUM_RUNS = 1 # 10

# higher lr for PINN
LEARNING_RATE = PINN_LEARNING_RATE
WEIGHT_DECAY = WEIGHT_DECAY

BATCH_SIZE = BATCH_SIZE

w = W_PINN_DIV_WEIGHT

# Ensure the results folder exists
RESULTS_DIR = PINN_RESULTS_DIR # Change this to "results" for full training

os.makedirs(RESULTS_DIR, exist_ok = True)

### LOOP OVER SIMULATIONS ###
for sim_name, sim_func in simulations.items():
    print(f"\nTraining for {sim_name.upper()}...")

    # Store metrics for the current simulation
    simulation_results = []

    # x_train, x_test stays the same but select y_train
    y_train = y_train_dict[sim_name]
    # select the correct y_test (PREVIOUS ERROR)
    y_test = y_test_dict[sim_name]

    ### LOOP OVER RUNS ###
    for run in range(NUM_RUNS):
        print(f"\n--- Training Run {run + 1}/{NUM_RUNS} ---")

        # Convert to DataLoader for batching
        dataset = TensorDataset(x_train, y_train)
        dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

        # Initialise fresh model
        # we seeded so this is reproducible
        PINN_model = PINN_backbone().to(device)
        PINN_model.train()

        # Define loss function (e.g., MSE for regression)
        criterion = torch.nn.MSELoss()

        # Define optimizer (e.g., AdamW)
        optimizer = optim.AdamW(PINN_model.parameters(), lr = LEARNING_RATE, weight_decay = WEIGHT_DECAY)

        # Initialise tensor to store losses
        epoch_train_losses = torch.zeros(MAX_NUM_EPOCHS)
        epoch_train_rmse_losses = torch.zeros(MAX_NUM_EPOCHS)
        epoch_test_losses = torch.zeros(MAX_NUM_EPOCHS)
        epoch_test_rmse_losses = torch.zeros(MAX_NUM_EPOCHS)

        # Early stopping variables
        best_loss = float('inf')
        epochs_no_improve = 0

        ### LOOP OVER EPOCHS ###
        print("\nStart Training")
        for epoch in range(MAX_NUM_EPOCHS):

            epoch_train_loss = 0.0  # Accumulate batch losses
            epoch_train_rmse_loss = 0.0

            for batch in dataloader:
                PINN_model.train()

                x_batch, y_batch = batch
                # put on GPU
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                # inplace
                x_batch.requires_grad_()

                # Forward pass
                y_pred = vmap(PINN_model)(x_batch)
                # torch.Size([32 (batch_dim), 2 (out_dim), 2 (in_dim)])
                batch_divergence = vmap(jacfwd(PINN_model))(x_batch)
                # sum: f1/x1 + f2/x2, square to account for negative
                batch_divergence_loss = torch.square(torch.diagonal(batch_divergence, dim1 = -2, dim2 = -1).sum())

                # Compute loss (RMSE for same units as data) + divergence loss
                loss = (1 - w) * torch.sqrt(criterion(y_pred, y_batch)) + w * batch_divergence_loss
                epoch_train_loss += loss.item()
                epoch_train_rmse_loss += torch.sqrt(criterion(y_pred, y_batch)).item()

                # Backpropagation
                # AFTER
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            ### END BATCH LOOP ###
            PINN_model.eval()

            # Test loss outside of batch loop - once per epoch
            y_test_pred = vmap(PINN_model)(x_test.to(device))
            # compute just once
            epoch_test_rmse_loss = torch.sqrt(criterion(y_test_pred, y_test.to(device))).item()
            epoch_test_loss = (1 - w) * torch.sqrt(criterion(y_test_pred, y_test.to(device))) + w * torch.square(torch.diagonal(vmap(jacfwd(PINN_model))(x_test.to(device)), dim1 = -2, dim2 = -1).sum()).item()

            # Compute average loss for the epoch and store
            epoch_train_losses[epoch] = epoch_train_loss / len(dataloader)
            epoch_train_rmse_losses[epoch] = epoch_train_rmse_loss / len(dataloader)
            epoch_test_losses[epoch] = epoch_test_loss # old calc once per epoch over all
            epoch_test_rmse_losses[epoch] = epoch_test_rmse_loss

            print(f"{sim_name} {model_name} Run {run + 1}/{NUM_RUNS}, Epoch {epoch + 1}/{MAX_NUM_EPOCHS}, Training Loss (RMSE): {epoch_train_losses[epoch]:.4f}")

            # Early stopping check
            if epoch_train_losses[epoch] < best_loss:
                best_loss = epoch_train_losses[epoch]
                epochs_no_improve = 0  # Reset counter
                best_model_state = PINN_model.state_dict()  # Save best model
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= PATIENCE:
                print(f"Early stopping triggered after {epoch + 1} epochs.")
                break

        ### END EPOCH LOOP ###
        # Load the best model before stopping
        PINN_model.load_state_dict(best_model_state)
        print(f"Run {run + 1}/{NUM_RUNS}, Training of {model_name} complete for {sim_name.upper()}. Restored best model.")

        ################
        ### EVALUATE ###
        ################

        # Evaluate the trained model for each run
        PINN_model.eval()

        y_train_PINN_predicted = vmap(PINN_model)(x_train.to(device)).detach()
        y_test_PINN_predicted = vmap(PINN_model)(x_test.to(device)).detach()

        # Compute Divergence (convert tensor to float)
        PINN_train_div = torch.diagonal(vmap(jacfwd(PINN_model))(x_train.to(device)), dim1 = -2, dim2 = -1).detach().sum().item()
        PINN_test_div = torch.diagonal(vmap(jacfwd(PINN_model))(x_test.to(device)), dim1 = -2, dim2 = -1).detach().sum().item()

Using device: cuda

=== CONVERGENCE_DTL ===
Training inputs shape: torch.Size([196, 2])
Training observations shape: torch.Size([196, 2])
Training inputs dtype: torch.float32

=== Generating test data ===
=== CONVERGENCE_DTL ===
Test inputs shape: torch.Size([400, 2])
Test observations shape: torch.Size([400, 2])
Test inputs dtype: torch.float32


Training for CONVERGENCE_DTL...

--- Training Run 1/1 ---

Start Training
convergence_dtl PINN Run 1/1, Epoch 1/2000, Training Loss (RMSE): 0.6746
convergence_dtl PINN Run 1/1, Epoch 2/2000, Training Loss (RMSE): 0.4724
convergence_dtl PINN Run 1/1, Epoch 3/2000, Training Loss (RMSE): 0.5190
convergence_dtl PINN Run 1/1, Epoch 4/2000, Training Loss (RMSE): 0.4592
convergence_dtl PINN Run 1/1, Epoch 5/2000, Training Loss (RMSE): 0.4411
convergence_dtl PINN Run 1/1, Epoch 6/2000, Training Loss (RMSE): 0.3808
convergence_dtl PINN Run 1/1, Epoch 7/2000, Training Loss (RMSE): 0.4188
convergence_dtl PINN Run 1/1, Epoch 8/2000, Training Loss (RMSE):

In [87]:
print(f"PINN divergence (train): {PINN_train_div:.4f}")

PINN divergence (train): -0.3034


In [92]:
torch.diagonal(vmap(jacfwd(PINN_model))(x_train.to(device)), dim1 = -2, dim2 = -1).detach().sum(dim = -1).abs().sum()

tensor(2.0124, device='cuda:0')

In [101]:
x_train_grad = x_train.to(device).requires_grad_()
y_train_PINN_predicted = vmap(PINN_model)(x_train_grad)

In [93]:
# autograd div train
u_indicator_train, v_indicator_train = torch.zeros_like(y_train_PINN_predicted), torch.zeros_like(y_train_PINN_predicted)
u_indicator_train[:, 0] = 1.0 # output column u selected
v_indicator_train[:, 1] = 1.0 # output column v selected

In [102]:
GP_train_div = (torch.autograd.grad(
            outputs = y_train_PINN_predicted,
            inputs = x_train_grad,
            grad_outputs = u_indicator_train,
            create_graph = True
        )[0][:, 0] + torch.autograd.grad(
            outputs = y_train_PINN_predicted,
            inputs = x_train_grad,
            grad_outputs = v_indicator_train,
            create_graph = True
        )[0][:, 1]).abs().sum().item() # v with respect to y

In [103]:
GP_train_div

2.012362003326416

In [None]:
# functional div test
jac_autograd_test = torch.autograd.functional.jacobian(apply_GP, 
                                        x_test.to(device))
jac_autograd_test = torch.einsum("bobi -> boi", jac_autograd_test) # batch out batch in
GP_test_div = torch.diagonal(jac_autograd_test, dim1 = -2, dim2 = -1).sum().item()

In [29]:
print(GP_test_div)

-12.526342391967773


In [30]:
jac_autograd_test.shape
jac_autograd_test_per_item.shape
jac_autograd_test_per_item

tensor([[-4.7544e-01,  5.3905e-01],
        [-3.6514e-01,  3.5680e-01],
        [-1.8622e-01,  1.2825e-01],
        [ 4.1965e-02, -1.2784e-01],
        [ 2.9171e-01, -3.8863e-01],
        [ 5.3086e-01, -6.2973e-01],
        [ 7.2783e-01, -8.2798e-01],
        [ 8.5671e-01, -9.6408e-01],
        [ 9.0161e-01, -1.0248e+00],
        [ 8.5936e-01, -1.0044e+00],
        [ 7.4000e-01, -9.0492e-01],
        [ 5.6497e-01, -7.3582e-01],
        [ 3.6322e-01, -5.1233e-01],
        [ 1.6616e-01, -2.5351e-01],
        [ 2.2838e-03,  2.0085e-02],
        [-1.0754e-01,  2.8869e-01],
        [-1.5314e-01,  5.3522e-01],
        [-1.3611e-01,  7.4656e-01],
        [-6.8718e-02,  9.1426e-01],
        [ 2.8924e-02,  1.0346e+00],
        [-4.9582e-01,  5.4583e-01],
        [-4.1540e-01,  3.9960e-01],
        [-2.5884e-01,  2.0148e-01],
        [-4.2740e-02, -3.2286e-02],
        [ 2.0676e-01, -2.8075e-01],
        [ 4.5766e-01, -5.2047e-01],
        [ 6.7709e-01, -7.2820e-01],
        [ 8.3655e-01, -8.835

In [31]:
mean_pred_test

autograd = torch.autograd.grad(
    outputs = mean_pred_test.sum(),
    inputs = x_test_grad,
)[0]

autograd.sum()

tensor(800.2223, device='cuda:0')

In [78]:
u_indicator_test, v_indicator_test = torch.zeros_like(mean_pred_test), torch.zeros_like(mean_pred_test)
u_indicator_test[:, 0] = 1.0 # output column u selected
v_indicator_test[:, 1] = 1.0 # output column v selected

In [None]:
autograd_u = torch.autograd.grad(
    outputs = mean_pred_test,
    inputs = x_test_grad,
    grad_outputs = u_indicator_test,
    create_graph = True
)[0][:, 0] # u with respect to x

autograd_v = torch.autograd.grad(
    outputs = mean_pred_test,
    inputs = x_test_grad,
    grad_outputs = v_indicator,
    create_graph = True
)[0][:, 1] # v with respect to y

In [79]:
divergence = torch.autograd.grad(
    outputs = mean_pred_test,
    inputs = x_test_grad,
    grad_outputs = u_indicator_test,
    create_graph = True
)[0][:, 0] + torch.autograd.grad(
    outputs = mean_pred_test,
    inputs = x_test_grad,
    grad_outputs = v_indicator_test,
    create_graph = True
)[0][:, 1] # v with respect to y

In [83]:
divergence.abs().sum().item()

66.52590942382812

In [75]:
divergence.sum().item()

-12.526344299316406

In [67]:
combined = autograd_u[:, 0] + autograd_v[:, 1]
combined.sum()

tensor(-12.5263, device='cuda:0', grad_fn=<SumBackward0>)

In [56]:
autograd

tensor([[-4.7544e-01,  5.2412e-01],
        [-3.6514e-01,  4.5890e-01],
        [-1.8622e-01,  3.7656e-01],
        [ 4.1965e-02,  2.8763e-01],
        [ 2.9171e-01,  2.0410e-01],
        [ 5.3086e-01,  1.3799e-01],
        [ 7.2783e-01,  9.9891e-02],
        [ 8.5671e-01,  9.7717e-02],
        [ 9.0161e-01,  1.3569e-01],
        [ 8.5936e-01,  2.1391e-01],
        [ 7.4000e-01,  3.2843e-01],
        [ 5.6497e-01,  4.7182e-01],
        [ 3.6322e-01,  6.3414e-01],
        [ 1.6616e-01,  8.0415e-01],
        [ 2.2838e-03,  9.7053e-01],
        [-1.0754e-01,  1.1230e+00],
        [-1.5314e-01,  1.2533e+00],
        [-1.3611e-01,  1.3554e+00],
        [-6.8718e-02,  1.4260e+00],
        [ 2.8924e-02,  1.4640e+00],
        [-4.9582e-01,  1.2426e+00],
        [-4.1540e-01,  1.2385e+00],
        [-2.5884e-01,  1.1988e+00],
        [-4.2740e-02,  1.1309e+00],
        [ 2.0676e-01,  1.0449e+00],
        [ 4.5766e-01,  9.5280e-01],
        [ 6.7709e-01,  8.6669e-01],
        [ 8.3655e-01,  7.976

In [40]:
divergence = 0.0
for i in range(mean_pred_test.shape[-1]):  # loop over output dimensions (e.g., u and v)
    grad_outputs = torch.zeros_like(mean_pred_test)
    grad_outputs[:, i] = 1.0  # pick out derivative w.r.t. output i

    grad = torch.autograd.grad(
        outputs = mean_pred_test,
        inputs = x_test_grad,
        grad_outputs = grad_outputs,
        create_graph = True
    )[0]  # shape (batch, input_dim)

    divergence += grad[:, i]  # take derivative of output i w.r.t. input i

GP_test_div_manual = divergence.sum().item()

In [42]:
divergence

tensor([ 6.3601e-02, -8.3437e-03, -5.7976e-02, -8.5873e-02, -9.6913e-02,
        -9.8862e-02, -1.0014e-01, -1.0737e-01, -1.2319e-01, -1.4502e-01,
        -1.6492e-01, -1.7085e-01, -1.4911e-01, -8.7354e-02,  2.2368e-02,
         1.8115e-01,  3.8208e-01,  6.1045e-01,  8.4554e-01,  1.0635e+00,
         5.0009e-02, -1.5807e-02, -5.7362e-02, -7.5025e-02, -7.3992e-02,
        -6.2808e-02, -5.1117e-02, -4.7046e-02, -5.4862e-02, -7.3443e-02,
        -9.5995e-02, -1.1119e-01, -1.0548e-01, -6.6133e-02,  1.5629e-02,
         1.4210e-01,  3.0776e-01,  4.9938e-01,  6.9772e-01,  8.8029e-01,
         4.9696e-02, -9.1718e-03, -4.2615e-02, -5.0822e-02, -3.9174e-02,
        -1.6823e-02,  5.5739e-03,  1.8605e-02,  1.6584e-02, -7.4810e-04,
        -2.7729e-02, -5.3777e-02, -6.5623e-02, -5.0288e-02,  1.7172e-03,
         9.3819e-02,  2.2190e-01,  3.7437e-01,  5.3361e-01,  6.7877e-01,
         5.5814e-02,  3.7631e-03, -2.2331e-02, -2.2447e-02, -2.0576e-03,
         2.9247e-02,  6.0052e-02,  7.9869e-02,  8.1

In [41]:
GP_test_div_manual

-12.526344299316406