In [None]:
from visualise import visualise_v_stream, visualise_v_quiver
from NN_models import dfNN_for_vmap, PINN_backbone
from simulate import simulate_convergence, simulate_branching, simulate_ridge, simulate_merge, simulate_deflection
from metrics import compute_RMSE, compute_MAE
from utils import set_seed

import torch
import torch.nn as nn
from torch.func import vmap, jacfwd
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import os
import pandas as pd

%load_ext autoreload
%autoreload 2

# Full dfNN training loop (all simulations × multiple runs)

In [26]:
from NN_models import dfNN_for_vmap, PINN_backbone
from simulate import simulate_convergence, simulate_branching, simulate_ridge, simulate_merge, simulate_deflection
from metrics import compute_RMSE, compute_MAE
from utils import set_seed

import torch
from torch.func import vmap, jacfwd
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import os
import pandas as pd

# Fix the global RNG state once, so weight initialisation and batch shuffling repeat across sessions
set_seed(42)

# Tag used in every results filename written by this cell
model_name = "dfNN"

#########################
### x_train & y_train ###
#########################

# Simulation functions (re-imported so this cell can run on its own)
from simulate import (
    simulate_branching,
    simulate_convergence,
    simulate_deflection,
    simulate_merge,
    simulate_ridge,
)

# Name -> simulation function. Insertion order is the order in which
# simulations are processed (and trained) further down.
simulations = dict(
    convergence = simulate_convergence,
    branching = simulate_branching,
    merge = simulate_merge,
    deflection = simulate_deflection,
    ridge = simulate_ridge,
)

# Training inputs are shared by all simulations; cast to float32 for the network
x_train = torch.load("data/sim_data/x_train_lines_discretised_0to1.pt").float()

# Simulation name -> training targets evaluated at x_train
y_train_dict = {}

for name, sim_func in simulations.items():
    # Evaluate this simulation's velocity field at the shared training inputs
    y_train = sim_func(x_train)
    y_train_dict[name] = y_train

    # Summary; the trailing "\n" plus print's own newline yields the blank separator line
    print(
        f"=== {name.upper()} ===\n"
        f"Training inputs shape: {x_train.shape}\n"
        f"Training observations shape: {y_train.shape}\n"
        f"Training inputs dtype: {x_train.dtype}\n"
    )

#######################
### x_test & y_test ###
#######################

print("=== Generating test data ===")

# Grid resolution: fine enough for the metrics, coarse enough for readable quiver plots
N_SIDE = 20

side_array = torch.linspace(start = 0.0, end = 1.0, steps = N_SIDE)
XX, YY = torch.meshgrid(side_array, side_array, indexing = "xy")
# Stack the two coordinate planes -> (N_SIDE, N_SIDE, 2)
x_test_grid = torch.stack([XX, YY], dim = -1)
# Long format: one (x, y) row per grid point -> (N_SIDE * N_SIDE, 2)
x_test = x_test_grid.reshape(-1, 2)

# Simulation name -> test targets evaluated on the grid
y_test_dict = {}

for name, sim_func in simulations.items():
    # Evaluate this simulation's velocity field on the test grid
    y_test = sim_func(x_test)
    y_test_dict[name] = y_test

    # Summary; the trailing "\n" plus print's own newline yields the blank separator line
    print(
        f"=== {name.upper()} ===\n"
        f"Test inputs shape: {x_test.shape}\n"
        f"Test observations shape: {y_test.shape}\n"
        f"Test inputs dtype: {x_test.dtype}\n"
    )

    # Optional visual check of each true field:
    # visualise_v_quiver(y_test, x_test, title_string = name)

#####################
### Training loop ###
#####################

# AdamW hyperparameters
LEARNING_RATE = 0.0001
WEIGHT_DECAY = 1e-4

# Early stopping: stop once this many epochs pass without a new best training loss
PATIENCE = 50
# Epoch cap per run (kept small for quick iteration; the full-run value appears to be 2000)
max_num_epochs = 5

# Independent runs per simulation, used for the mean/std of the metrics (full value appears to be 10)
NUM_RUNS = 3

# Every CSV / tensor produced below is written into this folder (created if missing)
results_dir = "results"
os.makedirs(results_dir, exist_ok = True)

### LOOP OVER SIMULATIONS ###
for name, sim_func in simulations.items():
    print(f"\nTraining for {name.upper()}...")

    # Metrics for the current simulation: one row per run
    simulation_results = []

    # Inputs are shared across simulations; select THIS simulation's targets.
    # y_test must be selected here as well: after the data-generation loop above it
    # still holds the last simulation's test targets, so without this line every
    # simulation's test loss and test metrics would be scored against that field.
    y_train = y_train_dict[name]
    y_test = y_test_dict[name]

    ### LOOP OVER RUNS ###
    for run in range(NUM_RUNS):
        print(f"\n--- Training Run {run + 1}/{NUM_RUNS} ---")

        # Convert to DataLoader for batching
        dataset = TensorDataset(x_train, y_train)
        dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

        # Initialise fresh model (reproducible because the seed was set once above)
        dfNN_model = dfNN_for_vmap()
        dfNN_model.train()

        # MSE loss; it is square-rooted below so losses are RMSE in data units
        criterion = torch.nn.MSELoss()

        # AdamW optimiser
        optimizer = optim.AdamW(dfNN_model.parameters(), lr = LEARNING_RATE, weight_decay = WEIGHT_DECAY)

        # Preallocated per-epoch loss histories
        epoch_train_losses = torch.zeros(max_num_epochs)
        epoch_test_losses = torch.zeros(max_num_epochs)

        # Early stopping state (monitors the training loss)
        best_loss = float('inf')
        epochs_no_improve = 0
        # Reset every run so a run that never improves cannot restore another run's weights
        best_model_state = None
        # Epochs actually completed; early stopping can end training before max_num_epochs
        num_epochs_run = 0

        ### LOOP OVER EPOCHS ###
        print("\nStart Training")
        for epoch in range(max_num_epochs):

            epoch_train_loss = 0.0  # Accumulate batch losses within epoch
            epoch_test_loss = 0.0

            for x_batch, y_batch in dataloader:
                x_batch.requires_grad_()

                # Forward pass
                y_pred = vmap(dfNN_model)(x_batch)

                # Compute loss (RMSE for same units as data)
                loss = torch.sqrt(criterion(y_pred, y_batch))
                epoch_train_loss += loss.item()

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Test loss after this step, for the loss-convergence plot
                y_test_pred = vmap(dfNN_model)(x_test)
                epoch_test_loss += torch.sqrt(criterion(y_test_pred, y_test)).item()

            # Average over batches
            avg_train_loss = epoch_train_loss / len(dataloader)
            avg_test_loss = epoch_test_loss / len(dataloader)

            epoch_train_losses[epoch] = avg_train_loss
            epoch_test_losses[epoch] = avg_test_loss
            num_epochs_run = epoch + 1

            print(f"Epoch {epoch+1}/{max_num_epochs}, Training Loss (RMSE): {avg_train_loss:.4f}")

            # Early stopping check
            if avg_train_loss < best_loss:
                best_loss = avg_train_loss
                epochs_no_improve = 0  # Reset counter
                # state_dict() returns references to the live parameter tensors, which
                # optimizer.step() keeps updating in place; clone to freeze this snapshot
                best_model_state = {
                    key: tensor.detach().clone()
                    for key, tensor in dfNN_model.state_dict().items()
                }
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= PATIENCE:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

        # Restore the best weights seen during this run
        if best_model_state is None:
            raise RuntimeError(
                f"{model_name} on {name}, run {run + 1}: no completed epoch achieved a finite "
                "training loss (or max_num_epochs is 0), so there is no best model to restore."
            )
        dfNN_model.load_state_dict(best_model_state)
        print(f"Training of {model_name} complete for {name.upper()}. Restored best model.")

        ################
        ### EVALUATE ###
        ################

        # Evaluate the trained model
        dfNN_model.eval()

        y_train_dfNN_predicted = vmap(dfNN_model)(x_train).detach()
        y_test_dfNN_predicted = vmap(dfNN_model)(x_test).detach()

        # Only save things for one run
        if run == 0:
            #(1) Save predictions from first run so we can visualise them later
            torch.save(y_test_dfNN_predicted, f"{results_dir}/{name}_{model_name}_test_predictions.pt")

            #(2) Save loss over epochs; only the epochs that actually ran, so an early
            # stop does not leave trailing rows of zeros in the CSV
            df_losses = pd.DataFrame({
                'Epoch': list(range(num_epochs_run)),
                'Train Loss RMSE': epoch_train_losses[:num_epochs_run].tolist(),
                'Test Loss RMSE': epoch_test_losses[:num_epochs_run].tolist(),
                })

            df_losses.to_csv(f"{results_dir}/{name}_{model_name}_losses_over_epochs.csv", index = False)

        # Divergence: trace of each point's Jacobian, summed over ALL points into one float.
        # NOTE(review): this is a signed sum, so positive and negative divergence can cancel;
        # a mean absolute divergence may be a more informative metric.
        dfNN_train_div = torch.diagonal(vmap(jacfwd(dfNN_model))(x_train), dim1 = -2, dim2 = -1).detach().sum().item()
        dfNN_test_div = torch.diagonal(vmap(jacfwd(dfNN_model))(x_test), dim1 = -2, dim2 = -1).detach().sum().item()

        # Compute metrics (convert tensors to float)
        dfNN_train_RMSE = compute_RMSE(y_train, y_train_dfNN_predicted).item()
        dfNN_train_MAE = compute_MAE(y_train, y_train_dfNN_predicted).item()

        dfNN_test_RMSE = compute_RMSE(y_test, y_test_dfNN_predicted).item()
        dfNN_test_MAE = compute_MAE(y_test, y_test_dfNN_predicted).item()

        # Store results in list
        simulation_results.append([
            run + 1, dfNN_train_RMSE, dfNN_train_MAE, dfNN_train_div,
            dfNN_test_RMSE, dfNN_test_MAE, dfNN_test_div
        ])

    ### FINISH LOOP OVER RUNS ###
    # Convert results to a Pandas DataFrame
    df = pd.DataFrame(
        simulation_results,
        columns = ["Run", "Train RMSE", "Train MAE", "Train Divergence",
                   "Test RMSE", "Test MAE", "Test Divergence"])

    # Compute mean and standard deviation for each metric
    mean_std_df = df.iloc[:, 1:].agg(["mean", "std"])  # Exclude "Run" column

    # Save results to CSV
    results_file = os.path.join(results_dir, f"{name}_{model_name}_metrics_per_run.csv")
    df.to_csv(results_file, index = False)
    print(f"\nResults saved to {results_file}")

    # Save mean and standard deviation to CSV
    mean_std_file = os.path.join(results_dir, f"{name}_{model_name}_metrics_summary.csv")
    mean_std_df.to_csv(mean_std_file)
    print(f"\nMean & Std saved to {mean_std_file}")

=== CONVERGENCE ===
Training inputs shape: torch.Size([196, 2])
Training observations shape: torch.Size([196, 2])
Training inputs dtype: torch.float32

=== BRANCHING ===
Training inputs shape: torch.Size([196, 2])
Training observations shape: torch.Size([196, 2])
Training inputs dtype: torch.float32

=== MERGE ===
Training inputs shape: torch.Size([196, 2])
Training observations shape: torch.Size([196, 2])
Training inputs dtype: torch.float32

=== DEFLECTION ===
Training inputs shape: torch.Size([196, 2])
Training observations shape: torch.Size([196, 2])
Training inputs dtype: torch.float32

=== RIDGE ===
Training inputs shape: torch.Size([196, 2])
Training observations shape: torch.Size([196, 2])
Training inputs dtype: torch.float32

=== Generating test data ===
=== CONVERGENCE ===
Test inputs shape: torch.Size([400, 2])
Test observations shape: torch.Size([400, 2])
Test inputs dtype: torch.float32

=== BRANCHING ===
Test inputs shape: torch.Size([400, 2])
Test observations shape: tor

In [24]:
# Length of the preallocated loss history (max_num_epochs slots, including unfilled zeros)
len(epoch_train_losses)

5

In [25]:
# Number of epoch indices written to the losses CSV (= first dimension of the history)
epoch_train_losses.size(0)

5

# Load training data

In [None]:
from simulate import simulate_convergence, simulate_branching, simulate_ridge, simulate_merge, simulate_deflection

# Training inputs (note: NOT the `_0to1` file used by the full loop above), cast to float32
x_train = torch.load("data/sim_data/x_train_lines_discretised.pt").float()
# Targets: the convergence velocity field evaluated at those inputs
y_train = simulate_convergence(x_train)

# Small dataset: report sizes; the embedded trailing "\n" produces the blank separator line
print(
    f"The shape of the training inputs is {x_train.shape}.\n"
    f"The shape of the training observations is {y_train.shape}.\n"
)
print(f"The dtype of the training inputs is {x_train.dtype}.")

In [None]:
# Evaluation grid: 20 points per axis spanning [0, 3] in both coordinates
side = torch.linspace(start = 0., end = 3., steps = 20)
XX, YY = torch.meshgrid(side, side, indexing = "xy")
# (20, 20, 2) coordinate grid, flattened to one (x, y) row per point
x_test_grid = torch.stack((XX, YY), dim = -1)
x_test = x_test_grid.reshape(-1, 2)

# Ground-truth convergence velocities on the grid
y_test = simulate_convergence(x_test)

# Single dfNN run with early stopping (convergence simulation)

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.optim as optim

# Convert to DataLoader for batching (shuffled mini-batches of 32)
dataset = TensorDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

# Initialise fresh model
dfNN_model = dfNN_for_vmap()
dfNN_model.train()

# MSE loss; square-rooted below so the loss is an RMSE in the same units as the data
criterion = nn.MSELoss()

# AdamW optimiser
optimizer = optim.AdamW(dfNN_model.parameters(), lr = 0.0001, weight_decay = 1e-4)
max_num_epochs = 1000

# Preallocated per-epoch loss histories (slots after an early stop stay 0)
epoch_losses = torch.zeros(max_num_epochs)
epoch_test_losses = torch.zeros(max_num_epochs)

# Early stopping parameters (monitors the TRAINING loss)
patience = 50  # Stop after 50 epochs with no improvement
best_loss = float('inf')
epochs_no_improve = 0
best_model_state = None  # set on the first improving epoch

print("\nStart Training")
for epoch in range(max_num_epochs):

    epoch_loss = 0.0  # Accumulate batch losses
    epoch_test_loss = 0.0

    for x_batch, y_batch in dataloader:
        x_batch.requires_grad_()

        y_pred = vmap(dfNN_model)(x_batch)

        # Compute loss (RMSE for same units as data)
        loss = torch.sqrt(criterion(y_pred, y_batch))
        epoch_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Test loss after this step (monitoring only)
        y_test_pred = vmap(dfNN_model)(x_test)
        test_loss = torch.sqrt(criterion(y_test_pred, y_test))
        epoch_test_loss += test_loss.item()

    # Compute average loss for the epoch
    avg_train_loss = epoch_loss / len(dataloader)
    avg_test_loss = epoch_test_loss / len(dataloader)

    epoch_losses[epoch] = avg_train_loss
    epoch_test_losses[epoch] = avg_test_loss

    print(f"Epoch {epoch+1}/{max_num_epochs}, Training Loss (RMSE): {avg_train_loss:.4f}")

    # Early stopping check
    if avg_train_loss < best_loss:
        best_loss = avg_train_loss
        epochs_no_improve = 0  # Reset counter
        # state_dict() holds references to the live parameters, which optimizer.step()
        # keeps updating in place; clone so this really is the best-epoch snapshot
        best_model_state = {
            key: tensor.detach().clone()
            for key, tensor in dfNN_model.state_dict().items()
        }
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs.")
        break

# Load the best model before stopping
if best_model_state is None:
    raise RuntimeError(
        "No completed epoch achieved a finite training loss (or max_num_epochs is 0); "
        "there is no best model to restore."
    )
dfNN_model.load_state_dict(best_model_state)
print("Training complete. Restored best model.")

In [None]:
import os
import torch
import pandas as pd
import numpy as np

# Ensure the results folder exists
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# Number of independent training runs (a fresh model each run)
num_runs = 10
# Fixed epoch budget per run (this cell does no early stopping)
max_num_epochs = 1000

# Recreate the loss and the DataLoader explicitly instead of relying on whichever
# objects earlier cells happened to leave in the kernel
criterion = torch.nn.MSELoss()
dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train, y_train), batch_size=32, shuffle=True
)

# One row of metrics per run
all_results = []

for run in range(num_runs):
    print(f"\n--- Training Run {run + 1}/{num_runs} ---")

    # Initialize model and set to training mode
    dfNN_model = dfNN_for_vmap()
    dfNN_model.train()

    # Define optimizer
    optimizer = torch.optim.AdamW(dfNN_model.parameters(), lr=0.0001, weight_decay=1e-4)

    # Training loop (RMSE loss; nothing is logged per epoch, so no loss is accumulated)
    for epoch in range(max_num_epochs):
        for x_batch, y_batch in dataloader:
            x_batch.requires_grad_()

            y_pred = vmap(dfNN_model)(x_batch)

            # Compute loss
            loss = torch.sqrt(criterion(y_pred, y_batch))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Evaluate the trained model
    dfNN_model.eval()

    y_train_dfNN_predicted = vmap(dfNN_model)(x_train).detach()
    y_test_dfNN_predicted = vmap(dfNN_model)(x_test).detach()

    # Divergence: per-point Jacobian trace, summed (signed) over all points -> float
    dfNN_train_div = torch.diagonal(vmap(jacfwd(dfNN_model))(x_train), dim1=-2, dim2=-1).detach().sum().item()
    dfNN_test_div = torch.diagonal(vmap(jacfwd(dfNN_model))(x_test), dim1=-2, dim2=-1).detach().sum().item()

    # Compute metrics (convert tensors to float)
    dfNN_train_RMSE = compute_RMSE(y_train, y_train_dfNN_predicted).item()
    dfNN_train_MAE = compute_MAE(y_train, y_train_dfNN_predicted).item()

    dfNN_test_RMSE = compute_RMSE(y_test, y_test_dfNN_predicted).item()
    dfNN_test_MAE = compute_MAE(y_test, y_test_dfNN_predicted).item()

    # Store results in list
    all_results.append([
        run + 1, dfNN_train_RMSE, dfNN_train_MAE, dfNN_train_div,
        dfNN_test_RMSE, dfNN_test_MAE, dfNN_test_div
    ])

# Convert results to a Pandas DataFrame
df = pd.DataFrame(
    all_results,
    columns=["Run", "Train RMSE", "Train MAE", "Train Divergence",
             "Test RMSE", "Test MAE", "Test Divergence"]
)

# Compute mean and standard deviation for each metric
mean_std_df = df.iloc[:, 1:].agg(["mean", "std"])  # Exclude "Run" column

# Save results to CSV
results_file = os.path.join(results_dir, "dfNN_performance.csv")
df.to_csv(results_file, index=False)
print(f"\nResults saved to {results_file}")

# Save mean and standard deviation to CSV
mean_std_file = os.path.join(results_dir, "dfNN_performance_summary.csv")
mean_std_df.to_csv(mean_std_file)
print(f"\nMean & Std saved to {mean_std_file}")

# Print summary (header follows num_runs rather than a hard-coded count)
print(f"\nMean & Std of {num_runs} Runs:")
print(mean_std_df)

In [None]:
import os
import json  # To save structured results

# Make sure the output directory exists before writing into it
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# Evaluation mode for inference
dfNN_model.eval()

# Predictions on the training inputs and the test grid (detached: metrics only)
y_train_dfNN_predicted = vmap(dfNN_model)(x_train).detach()
y_test_dfNN_predicted = vmap(dfNN_model)(x_test).detach()

# Divergence: per-point Jacobian trace, summed (signed) over all points -> float
dfNN_train_div = torch.diagonal(vmap(jacfwd(dfNN_model))(x_train), dim1=-2, dim2=-1).detach().sum().item()
dfNN_test_div = torch.diagonal(vmap(jacfwd(dfNN_model))(x_test), dim1=-2, dim2=-1).detach().sum().item()

# Error metrics as plain Python floats
dfNN_train_RMSE = compute_RMSE(y_train, y_train_dfNN_predicted).item()
dfNN_train_MAE = compute_MAE(y_train, y_train_dfNN_predicted).item()
dfNN_test_RMSE = compute_RMSE(y_test, y_test_dfNN_predicted).item()
dfNN_test_MAE = compute_MAE(y_test, y_test_dfNN_predicted).item()

# Nested layout split -> metric -> value; used both for printing and for the JSON file
results = {
    "train": {"RMSE": dfNN_train_RMSE, "MAE": dfNN_train_MAE, "Divergence": dfNN_train_div},
    "test": {"RMSE": dfNN_test_RMSE, "MAE": dfNN_test_MAE, "Divergence": dfNN_test_div},
}

# One line per metric; the last train line gets an extra "\n" to separate it from test
for split, split_metrics in results.items():
    for metric, value in split_metrics.items():
        trailer = "\n" if (split, metric) == ("train", "Divergence") else ""
        print(f"dfNN {split.capitalize()} {metric}: {value:.4f}{trailer}")

# Persist the same dictionary as pretty-printed JSON
results_file = os.path.join(results_dir, "dfNN_performance.json")
with open(results_file, "w") as json_out:
    json.dump(results, json_out, indent=4)

print(f"\nResults saved to {results_file}")

In [None]:
# Convert to DataLoader for batching
dataset = TensorDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

# Initialise fresh model
dfNN_model = dfNN_for_vmap()
dfNN_model.train()

# MSE loss; square-rooted below so the loss is an RMSE in the same units as the data
criterion = nn.MSELoss()

# AdamW optimiser with a fixed epoch budget (no early stopping in this cell)
optimizer = optim.AdamW(dfNN_model.parameters(), lr = 0.0001, weight_decay = 1e-4)
num_epochs = 1000 # ~600 epochs were noted as good

# Preallocated per-epoch loss histories (plotted in the next cell)
epoch_losses = torch.zeros(num_epochs)
epoch_test_losses = torch.zeros(num_epochs)

print()
print("Start Training")
for epoch in range(num_epochs):

    epoch_loss = 0.0  # Accumulate batch losses
    epoch_test_loss = 0.0

    for x_batch, y_batch in dataloader:
        x_batch.requires_grad_()

        y_pred = vmap(dfNN_model)(x_batch)

        # Compute loss (RMSE for same units as data)
        loss = torch.sqrt(criterion(y_pred, y_batch))
        epoch_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Test loss after this step (monitoring only)
        y_test_pred = vmap(dfNN_model)(x_test)
        test_loss = torch.sqrt(criterion(y_test_pred, y_test))
        epoch_test_loss += test_loss.item()

    # Store the average loss for the epoch
    avg_train_loss = epoch_loss / len(dataloader)
    avg_test_loss = epoch_test_loss / len(dataloader)
    epoch_losses[epoch] = avg_train_loss
    epoch_test_losses[epoch] = avg_test_loss

    # Report the epoch average (the value stored and plotted), not the last
    # mini-batch's loss, which is noisy and may come from a smaller partial batch
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss (RMSE): {avg_train_loss:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Plot results
plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), epoch_losses, label = "Train Loss", color = "blue")
plt.plot(range(1, num_epochs + 1), epoch_test_losses, label = "Test Loss", color = "red")
plt.xlabel("Epochs")
plt.ylabel("RMSE Loss")
plt.title("dfNN Training & Test Loss Over Epochs")
plt.legend()
plt.grid()
plt.show()

In [None]:
# Put the model in evaluation mode before computing metrics
dfNN_model.eval()

# Forward passes on the training inputs and on the test grid
y_train_dfNN_predicted = vmap(dfNN_model)(x_train)
y_test_dfNN_predicted = vmap(dfNN_model)(x_test)

# Signed divergence: trace of each per-point Jacobian, summed over all points
train_jacobians = vmap(jacfwd(dfNN_model))(x_train)
dfNN_train_div = torch.diagonal(train_jacobians, dim1 = -2, dim2 = -1).detach().sum().item()
test_jacobians = vmap(jacfwd(dfNN_model))(x_test)
dfNN_test_div = torch.diagonal(test_jacobians, dim1 = -2, dim2 = -1).detach().sum().item()

# --- Train metrics ---
dfNN_train_RMSE = compute_RMSE(y_train, y_train_dfNN_predicted)
print(f"dfNN Train RMSE: {dfNN_train_RMSE:.4f}")
dfNN_train_MAE = compute_MAE(y_train, y_train_dfNN_predicted)
print(f"dfNN Train MAE: {dfNN_train_MAE:.4f}")
print(f"dfNN Train Divergence: {dfNN_train_div:.4f}")

# --- Test metrics (separated from train by a blank line) ---
print("")
dfNN_test_RMSE = compute_RMSE(y_test, y_test_dfNN_predicted)
print(f"dfNN Test RMSE: {dfNN_test_RMSE:.4f}")
dfNN_test_MAE = compute_MAE(y_test, y_test_dfNN_predicted)
print(f"dfNN Test MAE: {dfNN_test_MAE:.4f}")
print(f"dfNN Test Divergence: {dfNN_test_div:.4f}")

In [None]:
# Quiver plot of the predicted test field; visualise_v_quiver takes (v, x) in that order
visualise_v_quiver(
    y_test_dfNN_predicted.detach(),
    x_test.detach(),
    title_string = "dfNN Predicted Convergence Field",
)

# PINN (divergence enforced as a soft penalty in the loss)

In [None]:
# Train the PINN on data RMSE + a soft divergence penalty.
# No test loss is calculated here, which makes this cell faster.

# Convert to DataLoader for batching
dataset = TensorDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

# Weight of the divergence penalty; 0.5 weights both loss terms equally
w = 0.5

# Initialise fresh model
PINN_model = PINN_backbone()
PINN_model.train()

# MSE loss; square-rooted below so the data term is an RMSE in data units
criterion = nn.MSELoss()

# AdamW optimiser
optimizer = optim.AdamW(PINN_model.parameters(), lr = 0.0001, weight_decay = 1e-4)
num_epochs = 1000

# Per-epoch histories (the test ones are allocated but not filled in this cell)
epoch_train_losses = torch.zeros(num_epochs)
epoch_train_rmse_losses = torch.zeros(num_epochs)
epoch_test_losses = torch.zeros(num_epochs)
epoch_test_rmse_losses = torch.zeros(num_epochs)

print()
print("Start Training")
for epoch in range(num_epochs):

    epoch_train_loss = 0.0  # Accumulate batch losses
    epoch_train_rmse_loss = 0.0

    for x_batch, y_batch in dataloader:
        x_batch.requires_grad_()

        y_pred = vmap(PINN_model)(x_batch)
        # Per-point Jacobians: torch.Size([32 (batch_dim), 2 (out_dim), 2 (in_dim)])
        batch_jacobian = vmap(jacfwd(PINN_model))(x_batch)
        # Per-point divergence ∂f1/∂x1 + ∂f2/∂x2 (trace of each Jacobian) -> (batch,)
        batch_divergence = torch.diagonal(batch_jacobian, dim1 = -2, dim2 = -1).sum(dim = -1)
        # Square each point's divergence BEFORE averaging so positive and negative
        # values cannot cancel; the mean keeps the scale independent of batch size
        batch_divergence_loss = torch.square(batch_divergence).mean()

        # Compute loss (RMSE for same units as data) + divergence loss
        batch_rmse = torch.sqrt(criterion(y_pred, y_batch))
        loss = (1 - w) * batch_rmse + w * batch_divergence_loss
        epoch_train_loss += loss.item()
        epoch_train_rmse_loss += batch_rmse.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Store the average loss for the epoch
    avg_train_loss = epoch_train_loss / len(dataloader)
    epoch_train_losses[epoch] = avg_train_loss
    epoch_train_rmse_losses[epoch] = epoch_train_rmse_loss / len(dataloader)

    # Report the epoch average rather than the last mini-batch's loss
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss (RMSE + divergence loss): {avg_train_loss:.4f}")

In [None]:
# Train the PINN (data RMSE + soft divergence penalty) while tracking test losses

# Convert to DataLoader for batching
dataset = TensorDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

# Weight of the divergence penalty; 0.5 weights both loss terms equally
w = 0.5

# Initialise fresh model
PINN_model = PINN_backbone()
PINN_model.train()

# MSE loss; square-rooted below so the data term is an RMSE in data units
criterion = nn.MSELoss()

# AdamW optimiser
optimizer = optim.AdamW(PINN_model.parameters(), lr = 0.0001, weight_decay = 1e-4)
num_epochs = 1000

# Per-epoch histories: combined loss and its RMSE component, for train and test
epoch_train_losses = torch.zeros(num_epochs)
epoch_train_rmse_losses = torch.zeros(num_epochs)
epoch_test_losses = torch.zeros(num_epochs)
epoch_test_rmse_losses = torch.zeros(num_epochs)

print()
print("Start Training")
for epoch in range(num_epochs):

    # Plain Python-float accumulators, so no autograd graph is kept alive across batches
    epoch_train_loss = 0.0
    epoch_train_rmse_loss = 0.0
    epoch_test_loss = 0.0
    epoch_test_rmse_loss = 0.0

    for x_batch, y_batch in dataloader:
        x_batch.requires_grad_()

        y_pred = vmap(PINN_model)(x_batch)
        # Per-point Jacobians: torch.Size([32 (batch_dim), 2 (out_dim), 2 (in_dim)])
        batch_jacobian = vmap(jacfwd(PINN_model))(x_batch)
        # Per-point divergence ∂f1/∂x1 + ∂f2/∂x2 (trace of each Jacobian) -> (batch,)
        batch_divergence = torch.diagonal(batch_jacobian, dim1 = -2, dim2 = -1).sum(dim = -1)
        # Square each point's divergence BEFORE averaging so positive and negative
        # values cannot cancel; the mean keeps the scale independent of batch size
        batch_divergence_loss = torch.square(batch_divergence).mean()

        # Compute loss (RMSE for same units as data) + divergence loss
        batch_rmse = torch.sqrt(criterion(y_pred, y_batch))
        loss = (1 - w) * batch_rmse + w * batch_divergence_loss
        epoch_train_loss += loss.item()
        epoch_train_rmse_loss += batch_rmse.item()

        # Test loss (monitoring only). Every term is reduced to a float with .item()
        # before accumulating, so no test-time autograd graph is retained.
        y_test_pred = vmap(PINN_model)(x_test)
        test_rmse = torch.sqrt(criterion(y_test_pred, y_test)).item()
        test_divergence = torch.diagonal(vmap(jacfwd(PINN_model))(x_test), dim1 = -2, dim2 = -1).sum(dim = -1)
        test_divergence_loss = torch.square(test_divergence).mean().item()
        epoch_test_rmse_loss += test_rmse
        epoch_test_loss += (1 - w) * test_rmse + w * test_divergence_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Store the average loss for the epoch
    avg_train_loss = epoch_train_loss / len(dataloader)
    epoch_train_losses[epoch] = avg_train_loss
    epoch_train_rmse_losses[epoch] = epoch_train_rmse_loss / len(dataloader)
    epoch_test_losses[epoch] = epoch_test_loss / len(dataloader)
    epoch_test_rmse_losses[epoch] = epoch_test_rmse_loss / len(dataloader)

    # Report the epoch average rather than the last mini-batch's loss
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss (RMSE + divergence loss): {avg_train_loss:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Plot results
plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), epoch_train_losses, label = "Train Loss", color = "blue")
plt.plot(range(1, num_epochs + 1), epoch_train_rmse_losses.detach(), label = "Train RMSE Loss", color = "lightblue")
# plt.plot(range(1, num_epochs + 1), epoch_test_losses.detach(), label = "Test Loss", color = "red")
plt.plot(range(1, num_epochs + 1), epoch_test_rmse_losses.detach(), label = "Test RMSE Loss", color = "pink")
plt.xlabel("Epochs")
plt.ylabel("RMSE Loss")
plt.title("PINN Training & Test Loss Over Epochs")
plt.legend()
plt.grid()
plt.show()

In [None]:
# Put the model in evaluation mode before computing metrics
PINN_model.eval()

# Forward passes on the training inputs and on the test grid
y_train_PINN_predicted = vmap(PINN_model)(x_train)
y_test_PINN_predicted = vmap(PINN_model)(x_test)

# Divergence diagnostics from the per-point Jacobians
train_jacobians = vmap(jacfwd(PINN_model))(x_train)
PINN_train_div = torch.diagonal(train_jacobians, dim1 = -2, dim2 = -1).detach().sum().item()
test_jacobians = vmap(jacfwd(PINN_model))(x_test)
# Scalar: signed sum of all per-point divergences
PINN_test_div = torch.diagonal(test_jacobians, dim1 = -2, dim2 = -1).detach().sum().item()
# Field: one divergence value per test point (used to colour the quiver plot)
PINN_test_div_field = torch.diagonal(vmap(jacfwd(PINN_model))(x_test), dim1 = -2, dim2 = -1).sum(-1).detach()

# --- Train metrics ---
PINN_train_RMSE = compute_RMSE(y_train, y_train_PINN_predicted)
print(f"PINN Train RMSE: {PINN_train_RMSE:.4f}")
PINN_train_MAE = compute_MAE(y_train, y_train_PINN_predicted)
print(f"PINN Train MAE: {PINN_train_MAE:.4f}")
print(f"PINN Train Divergence: {PINN_train_div:.4f}")

# --- Test metrics (separated from train by a blank line) ---
print("")
PINN_test_RMSE = compute_RMSE(y_test, y_test_PINN_predicted)
print(f"PINN Test RMSE: {PINN_test_RMSE:.4f}")
PINN_test_MAE = compute_MAE(y_test, y_test_PINN_predicted)
print(f"PINN Test MAE: {PINN_test_MAE:.4f}")
print(f"PINN Test Divergence: {PINN_test_div:.4f}")

In [None]:
# Quiver plot of the predicted PINN test field; arguments are ordered (v, x, ...).
# The third argument is the per-point divergence field; presumably it colours the
# arrows, with color_abs_max bounding the colour scale (confirm in visualise.py).
visualise_v_quiver(
    y_test_PINN_predicted.detach(),
    x_test.detach(),
    PINN_test_div_field,
    title_string = "PINN Predicted Convergence Field",
    color_abs_max = 0.1,
)