# Building and Optimising Neural Network Surrogates with $\texttt{gwbonsai}$

Author: Lucy M Thomas

Email: lmthomas@caltech.edu

Date: 3rd June 2025

This notebook was created for a tutorial called 'Public Code Packages to Visualise and Optimise Gravitational Wave Surrogate Models'. 

The tutorial is given as part of a workshop session on Scientific Machine Learning for Gravitational Wave Astronomy, ICERM, Brown University, on 3rd June 2025.

In [None]:
# Import useful packages
import numpy as np
import matplotlib.pyplot as plt
import sklearn.model_selection  # provides train_test_split, used below

We consider a simple subspace of non-spinning binary black hole mergers ($\vec{\chi}_{1}=\vec{\chi}_{2}=0$). We restrict to mass ratios between $q=1$ and $q=2$.

Let's build a fit for the time-domain amplitude of the $(2,2)$-mode.

## Create and Align the Training Data

First, let's generate some training data for our fit. We'll use the $\texttt{NRSur7dq4}$ model to generate this training data.

In [None]:
# Let's first check that the installation of gwsurrogate was successful, and then download the surrogate model.
import gwsurrogate as gws
sur = gws.LoadSurrogate('NRSur7dq4') # This download might take a minute or two, depending on your internet connection speed.

Now we'll define functions to generate training data.

In [None]:
# define a simplified interface to the NRSur7dq4 model
def NRSur7dq4_22_nonspinning(q, dt=0.1):
    """ Simplified interface to NRSur7dq4 to get the (2,2)-mode amplitude for nonspinning systems.

      INPUT
      =====
      q  -- mass ratio
      dt -- timestep size, in units of M

      OUTPUT
      ======
      times and the (2,2)-mode amplitude |h_22|"""

    chiA  = [0.0, 0.0, 0.0]        # dimensionless spin of the heavier BH
    chiB  = [0.0, 0.0, 0.0]        # dimensionless spin of the lighter BH
    f_low = 0.0065                 # initial frequency in units of 1/M
    f_ref = f_low                  # reference frequency (1/M) at which spins are defined

    times, h, dyn = sur(q, chiA, chiB, dt=dt, f_low=f_low, f_ref=f_ref)

    return times, np.abs(h[(2,2)])

def training_set_generator(N,dt=0.1,verbose=False):
    """Generate N training samples from q in [1,2]"""
    qs = np.linspace(1.0,2.0,N)
    training_data = []
    for q in qs:
        t,h = NRSur7dq4_22_nonspinning(q,dt=dt)
        training_data.append(h)
        if verbose:
            print('length of h is %i'%len(h))
    return qs, training_data

We'll start with 11 points between $q=1$ and $q=2$, and begin by visualising the training data.

In [None]:
num_train_samples = 11

q_train, train_data = training_set_generator(num_train_samples, verbose=True)

In [None]:
for i in range(num_train_samples):
    times = 0.1 * np.arange(len(train_data[i]))
    plt.plot(times,train_data[i], label='q=%.2f'%q_train[i])
plt.xlabel('time (M)')
plt.ylabel('$|h_{22}|$ (M)')
plt.legend()

In [None]:
def common_time_grid(training_data,dt=0.1):
    """
    INPUT
    =====
    training_data: set of training waveforms 
    
    OUTPUT
    ======
    training data as a numpy array, padding with zeros as 
    necessary such that all waveforms are of the same length"""
    
    longest_waveform = 0
    for h in training_data:
        length = len(h)
        if length > longest_waveform:
            longest_waveform = length
            
    print("longest waveform size = %i"%longest_waveform)
        
    padded_training_data = []
    for h in training_data:
        nZeros = longest_waveform - len(h)
        h_pad = np.append(h, np.zeros(nZeros))
        padded_training_data.append(h_pad)
        
    times = np.arange(longest_waveform)*dt
    
    padded_training_data = np.vstack(padded_training_data).transpose()
    
    return times, padded_training_data

times, train_data = common_time_grid(train_data)

In [None]:
def get_peak(t, h):
    """Get the index and the values of t and h at the maximum of |h| on a discrete grid."""
    arg = np.argmax(np.abs(h))
    return [arg, t[arg], h[arg]]

def get_peaks(t, training_set):
    """Find the index of each waveform's peak in the entire training set.
    training_set holds one waveform per column."""
    time_peak_arg = []
    for i in range(training_set.shape[1]):
        [arg, t_peak, h_peak] = get_peak(t, training_set[:, i])  # i^th training sample
        time_peak_arg.append(arg)
        print("Waveform %i with t_peak = %f" % (i, t_peak))
    print(time_peak_arg)
    return time_peak_arg

def align_peaks(times, training_set):
    """Peak-align a set of waveforms (one per column). The waveform whose peak occurs
    earliest (the shortest one) is used as the reference, and the others are shifted to match it."""
    
    time_peak_arg = get_peaks(times, training_set)
    
    min_arg = min(time_peak_arg)
    aligned_training_set = []
    for i in range(training_set.shape[1]):
        offset = time_peak_arg[i] - min_arg
        print("offset value of %i" % offset)
        h_aligned = training_set[offset:, i]  # drop the first `offset` samples
        aligned_training_set.append(h_aligned)
        
    t, training_data_aligned = common_time_grid(aligned_training_set)
    return training_data_aligned

train_data_aligned = align_peaks(times, train_data)

In [None]:
for i in range(num_train_samples):
    plt.plot(times, train_data_aligned[:,i], label='q=%.2f'%q_train[i])
plt.xlabel('time (M)')
plt.ylabel('$|h_{22}|$ (M)')
plt.legend()

For the actual model, we'll probably need more than 11 training points, so let's generate 401. This might take a few seconds.

In [None]:
num_train_samples = 401
q, data = training_set_generator(num_train_samples, verbose=False)
times, data = common_time_grid(data)
data_aligned = align_peaks(times, data)

We'll find it useful to split this data set into a training, validation and test set for when we come to train our model.
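
It can help to know in advance how many samples land in each set. The sketch below assumes scikit-learn's convention for a float `test_size` (the test fraction is rounded up and the rest goes to the training split), and mirrors the two calls in the next cell, where the first output of the second split is assigned to the test set and the second output to the validation set. The helper `split_sizes` is hypothetical and only does the arithmetic.

```python
import math

def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a float test_size, assuming the test
    fraction is rounded up and the remainder goes to the training split."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test

# First split: 401 waveforms -> training set + a 40% remainder
n_train, n_remainder = split_sizes(401, 0.4)
# Second split: the remainder is halved; the first output goes to the test set,
# the second output to the validation set
n_test, n_validation = split_sizes(n_remainder, 0.5)
print(n_train, n_validation, n_test)  # 240 81 80
```

The 240 training samples are what the assertion in the training-dataset section expects.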

In [None]:
q_train, q_validation, train_data, validation_data = sklearn.model_selection.train_test_split(q, data_aligned.T, random_state=0,test_size=0.4)
q_test, q_validation, test_data, validation_data = sklearn.model_selection.train_test_split(q_validation, validation_data, random_state=0,test_size=0.5)
print('Number of training samples: %i'%len(q_train))
print('Number of validation samples: %i'%len(q_validation))
print('Number of test samples: %i'%len(q_test))
q_train = q_train.reshape(q_train.shape[0], 1)
q_validation = q_validation.reshape(q_validation.shape[0], 1)
q_test = q_test.reshape(q_test.shape[0], 1)


## Create the Bases and Empirical Interpolant

In [None]:
import rompy as rp

In [None]:
integration = rp.Integration([times[0], times[-1]], num=len(times), rule='trapezoidal')
rb = rp.ReducedBasis(integration)
rb.make(train_data, 0, 1e-10, verbose=True)
# We could try reducing/increasing this tolerance to see how it affects the reduced basis size.
print("Reduced basis dimension: %i"%rb.size)
eim = rp.EmpiricalInterpolant(rb.basis, verbose=True)
print("Empirical interpolant completed.")
print("relative compression ratio: %.2f"%(float((len(train_data)*len(times))/(rb.size*eim.size))))

In [None]:
eim.make_data(validation_data)
val_eim_data = eim.data
eim.make_data(test_data)
test_eim_data = eim.data
eim.make_data(train_data)
train_eim_data = eim.data
# eim.data has shape (n_nodes, n_samples); transpose to (n_samples, n_nodes) for Keras.
# Note that reshape would scramble the entries rather than swap the axes.
train_eim_data = train_eim_data.T
val_eim_data = val_eim_data.T
test_eim_data = test_eim_data.T
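
For intuition about what `rp.ReducedBasis` and `rp.EmpiricalInterpolant` compute, here is a minimal NumPy sketch of the standard greedy reduced-basis and empirical interpolation (EIM) algorithms on a toy family of sinusoids rather than the NRSur amplitudes. It assumes a plain Euclidean inner product on the time grid and is not rompy's implementation; `greedy_reduced_basis` and `empirical_interpolant` are hypothetical names. The returned matrix `B` plays the role of `eim.B`: a waveform is rebuilt from its values at the EIM nodes as `h[nodes] @ B`, the same product used later to turn the network's predicted node values into a full amplitude.

```python
import numpy as np

def greedy_reduced_basis(training_set, tol=1e-12):
    """Greedily build an orthonormal basis for the rows of training_set.
    tol bounds the squared projection error of every normalised training row."""
    ts = training_set / np.linalg.norm(training_set, axis=1, keepdims=True)
    basis = np.empty((0, ts.shape[1]))
    errors = np.ones(len(ts))
    while errors.max() > tol and len(basis) < min(ts.shape):
        worst = ts[np.argmax(errors)]
        # Gram-Schmidt against the current basis, applied twice for stability
        v = worst - basis.T @ (basis @ worst)
        v -= basis.T @ (basis @ v)
        basis = np.vstack([basis, v / np.linalg.norm(v)])
        # Squared projection error of each unit-norm training row
        errors = 1.0 - np.sum((ts @ basis.T) ** 2, axis=1)
    return basis

def empirical_interpolant(basis):
    """Greedily pick EIM nodes; return (nodes, B) such that h ~= h[nodes] @ B."""
    nodes = [int(np.argmax(np.abs(basis[0])))]
    for j in range(1, len(basis)):
        # Interpolate basis vector j with the first j basis vectors at the current nodes
        V = basis[:j, nodes].T  # V[a, b] = basis[b][nodes[a]]
        c = np.linalg.solve(V, basis[j, nodes])
        residual = basis[j] - c @ basis[:j]
        nodes.append(int(np.argmax(np.abs(residual))))  # next node: largest residual
    V = basis[:, nodes].T
    B = np.linalg.solve(V.T, basis)  # B = V^{-T} @ basis
    return np.array(nodes), B

# Toy training set: sin(q t) for 50 values of q in [1, 2]
t = np.linspace(0.0, 10.0, 500)
qs = np.linspace(1.0, 2.0, 50)
training = np.array([np.sin(q * t) for q in qs])

basis = greedy_reduced_basis(training)
nodes, B = empirical_interpolant(basis)
recon = training[:, nodes] @ B  # rebuild every waveform from its node values
rel_err = np.linalg.norm(recon - training, axis=1) / np.linalg.norm(training, axis=1)
print(len(basis), rel_err.max())
```

Only `len(nodes)` numbers per waveform need to be fitted across parameter space, which is why the network below has one output per EIM node.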


## Create A Simple Neural Network Fit

We'll make a simple, sequential, fully connected multi-layer perceptron (MLP) to model the amplitude values at the EIM nodes as functions of the mass ratio across the parameter space $q\in[1,2]$.
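
The network maps a single input, $q$, to one output per EIM node. As a sketch of what the Keras model computes, the NumPy forward pass below assumes the same architecture as the first model built below (4 hidden ReLU layers of 10 nodes, a linear output layer of 9 nodes) with random, untrained weights; `mlp_forward` is a hypothetical helper. Each dense layer has (inputs × outputs) weights plus one bias per output, so the total of 449 should match the trainable-parameter count that `model.summary()` reports for this architecture.

```python
import numpy as np

def mlp_forward(q, weights, biases):
    """Forward pass of a fully connected MLP: ReLU hidden layers, linear output."""
    x = q
    for W, b in zip(weights[:-1], biases[:-1]):
        x = np.maximum(0.0, x @ W + b)   # hidden layer with ReLU activation
    return x @ weights[-1] + biases[-1]  # linear output: one value per EIM node

# Layer widths: 1 input, 4 hidden layers of 10 nodes, 9 outputs
widths = [1, 10, 10, 10, 10, 9]
rng = np.random.default_rng(0)
weights = [rng.normal(0.0, 0.5, (m, n)) for m, n in zip(widths[:-1], widths[1:])]
biases = [np.zeros(n) for n in widths[1:]]

n_params = sum(W.size + b.size for W, b in zip(weights, biases))
q = np.linspace(1.0, 2.0, 5).reshape(-1, 1)  # a batch of 5 mass ratios
out = mlp_forward(q, weights, biases)
print(n_params, out.shape)  # 449 (5, 9)
```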

In [None]:
import sys
sys.path.append('/Users/lucythomas/ResearchProjects/gwbonsai/')  # local clone of gwbonsai; not needed if it is pip-installed
import gwbonsai as gb
from keras.models import Sequential
from keras.layers import InputLayer, Dense, Dropout
from keras.optimizers import Adam
import tensorflow
tensorflow.random.set_seed(123456)
from contextlib import redirect_stdout

We'll compile a simple MLP to do the fit, and train it to see how it looks.

In [None]:
file_prefix = './' # Directory in which to save the model summary; change as needed
input_shape = 1 # We have only one input, the mass ratio q
output_shape = 9 # The number of EIM nodes
num_hidden_layers = 4 # Number of hidden layers in the neural network
nodes_per_layer = 10 # Number of nodes per hidden layer
activation = 'relu' # Activation function for the hidden layers
learning_rate = 1e-3 # Learning rate for the optimizer


In [None]:
model = Sequential()
model.add(InputLayer(shape=(input_shape,)))

for layer in range(num_hidden_layers):
    model.add(Dense(nodes_per_layer, activation=activation))
    
model.add(Dense(output_shape, activation='linear'))

model.compile(
    # Optimization algorithm, specify learning rate
    optimizer=Adam(learning_rate=learning_rate),
    # Loss function: mean squared error, since this is a regression problem
    loss='mean_squared_error',
    # Diagnostic quantities
    #metrics=['mean_squared_error']
    )
# Saving summary of compiled model to model_summary.txt.
with open(file_prefix+'model_summary.txt', 'w') as f:
    with redirect_stdout(f):
        model.summary()

In [None]:
history = model.fit(q_train, train_eim_data,epochs=100, batch_size=16, validation_data=(q_validation, val_eim_data), verbose=1)

In [None]:
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()

The training history looks okay: there is a small amount of overfitting after ~10 epochs, but it's not too bad. Let's evaluate the surrogate model on the test set and plot one of the test waveform amplitudes (index 26) to get an idea of how we're doing.

In [None]:
nodes = model.predict(q_test)
predicted_test_data = np.dot(nodes, eim.B)


In [None]:
plt.plot(times, predicted_test_data[26,:], label='q=%.2f (predicted)'%q_test[26, 0])
plt.plot(times, test_data[26,:], label='q=%.2f (true)'%q_test[26, 0])
plt.xlabel('time (M)')
plt.ylabel('$|h_{22}|$ (M)')
plt.legend();

Gross, looks like our model seriously needs some work! But how do we achieve a better fit? 

Is it with more layers? More training data? A different activation function? A bit of everything?

How do we know (without lots of prior experience of this problem) which hyperparameters to change and which to leave?

## Motivation for Systematic Optimisation

In [None]:
from PIL import Image
from IPython.display import display
img = Image.open('mentimeter_screenshot.png')
display(img)

Let's look back at the Mentimeter poll from the end of Melissa's session yesterday.

>'tuning ml is painful'...

>'ml is trial and error'...

>'architecture is important'...

What if there was a way we could try and simplify all this? To optimise the network and its training data in a more systematic way, to try and make sure we're getting the best surrogate we possibly can with a neural network?

Enter: `gwbonsai` (Building and Optimising Neural Surrogates for Astrophysical Inference). 

It's a helper package that provides routines for systematically optimising the training of neural network gravitational wave surrogate models. It's available at this Git repo [link](https://github.com/lucymthomas/gwbonsai/tree/main). 

The package is still being actively developed, so if you would benefit from it and want to see features that don't exist yet, please let me know! 

It leverages `Optuna` to optimise the hyperparameters of the model, and then optimises the amount and distribution of training data required to achieve a good fit.

In practice, we found that for most surrogate-building applications, optimising the whole architecture (not to mention the training dataset) all at once was prohibitively expensive. Therefore we split the problem into three stages:

1. Optimise the functional hyperparameters (the network's non-linear behaviour)
2. Optimise the network size and shape parameters (avoids overfitting)
3. Optimise the training dataset size and distribution (ensures good coverage of the parameter space)
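
To put "prohibitively expensive" into numbers, one can count the configurations in the categorical option grids used in the two hyperparameter stages of this notebook (copied here, with the `dropout_rate` values written out from `np.linspace(0.1, 0.5, 5)`). This is only a grid count under those assumptions, not a cost model: every configuration still needs a full training run, and the TPE sampler does not enumerate the grid. `grid_size` is a hypothetical helper.

```python
from math import prod

# Option grids copied from the functional and size/shape optimisation cells
functional_options = {
    'activation': ['relu', 'tanh', 'sigmoid'],
    'weight_init': ['glorot_uniform'],
    'learning_rate': [1e-4, 1e-3, 1e-2],
    'optimiser': ['Adam', 'Adamax', 'Nadam', 'Ftrl', 'SGD', 'RMSprop'],
    'normalisation': [0],
}
shape_options = {
    'num_hidden_layers': [2, 4, 6, 8, 10],
    'nodes_per_layer': [8, 10, 15, 20, 50],
    'dropout': [0.0, 1.0],
    'dropout_rate': [0.1, 0.2, 0.3, 0.4, 0.5],
    'batch_size': [8],
}

def grid_size(options):
    """Number of distinct configurations in a categorical search space."""
    return prod(len(values) for values in options.values())

n_functional = grid_size(functional_options)
n_shape = grid_size(shape_options)
print(n_functional, n_shape)               # 54 250
print('staged:', n_functional + n_shape)   # staged: 304
print('joint: ', n_functional * n_shape)   # joint:  13500
```

With 50 and 100 trials per stage, each staged search covers a sizeable fraction of its space, whereas the same 150 trials would sample only about 1% of the joint space.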


In the remainder of this tutorial, we will go over these three steps in detail for our (2,2)-mode amplitude model, and hopefully achieve a more convincing surrogate at the end!

## A Quick Aside: How Does the Hyperparameter Optimisation Work?

The default sampler used by `Optuna` is the `TPESampler`, where TPE stands for `Tree-structured Parzen Estimator`.
The mathematical idea behind the TPE sampler is based on Bayesian optimisation: it simultaneously models the distribution of 'good' points in hyperparameter space and of 'bad' ones.


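Let $x$ be the value of a hyperparameter.

Let $y$ be the value of the loss function that this hyperparameter achieves.

We are trying to find the value of $x$ across the space that minimises $y$.

TPE key idea: model $P(x|y)$ rather than $P(y|x)$.

Throughout the optimisation, the completed trials are split at a threshold loss $y^*$ (a quantile $\gamma$ of the losses observed so far), and we model two distributions:

$g(x) = P(x \mid y > y^*) = P(x \mid \mathrm{bad})$

$\ell(x) = P(x \mid y \le y^*) = P(x \mid \mathrm{good})$

New samples are proposed by maximising the expected improvement, which for TPE takes the form

$\mathrm{EI}_{y^*}(x) \propto \left(\gamma + (1-\gamma)\,\dfrac{g(x)}{\ell(x)}\right)^{-1}$

This increases monotonically with the ratio $\ell(x)/g(x)$, so in practice new candidates are chosen to maximise $\ell(x)/g(x)$.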
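
A bare-bones, one-dimensional version of this loop fits in a few lines of NumPy. It is a sketch of the idea only, assuming Gaussian Parzen windows with a fixed bandwidth and a fixed quantile `gamma`; Optuna's `TPESampler` uses adaptive bandwidths, handles categorical and multi-dimensional spaces, and differs in many other details. The names `parzen_pdf` and `toy_tpe` are hypothetical.

```python
import numpy as np

def parzen_pdf(x, centres, bandwidth):
    """Gaussian Parzen-window density estimate built on `centres`, evaluated at x."""
    z = (np.atleast_1d(x)[:, None] - centres[None, :]) / bandwidth
    return np.exp(-0.5 * z**2).sum(axis=1) / (len(centres) * bandwidth * np.sqrt(2.0 * np.pi))

def toy_tpe(objective, low, high, n_init=10, n_iter=40, gamma=0.25,
            n_candidates=24, bandwidth=0.1, seed=0):
    """Minimise a 1D objective on [low, high] with a bare-bones TPE loop."""
    rng = np.random.default_rng(seed)
    xs = list(rng.uniform(low, high, n_init))   # random start-up trials
    ys = [objective(x) for x in xs]
    for _ in range(n_iter):
        x_arr, y_arr = np.array(xs), np.array(ys)
        y_star = np.quantile(y_arr, gamma)      # threshold between 'good' and 'bad'
        good, bad = x_arr[y_arr <= y_star], x_arr[y_arr > y_star]
        # Draw candidates from P(x | good): jitter randomly chosen good points
        candidates = rng.choice(good, size=n_candidates) + bandwidth * rng.standard_normal(n_candidates)
        candidates = np.clip(candidates, low, high)
        # Keep the candidate with the largest ratio P(x | good) / P(x | bad)
        ratio = parzen_pdf(candidates, good, bandwidth) / (parzen_pdf(candidates, bad, bandwidth) + 1e-300)
        x_next = float(candidates[np.argmax(ratio)])
        xs.append(x_next)
        ys.append(objective(x_next))
    best = int(np.argmin(ys))
    return xs[best], ys[best]

# Toy "loss" with its minimum at x = 0.3
x_best, y_best = toy_tpe(lambda x: (x - 0.3) ** 2, 0.0, 1.0)
print(f"best x = {x_best:.3f}, loss = {y_best:.2e}")
```

Modelling $P(x|y)$ this way only needs density estimates over $x$, which is what keeps each TPE proposal cheap compared with fitting a full surrogate of $P(y|x)$.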


In [None]:
img = Image.open('TPE1.png')
display(img)

In [None]:
img = Image.open('TPE2.png')
display(img)

## Optimising Functional Hyperparameters

In [None]:
functional_options_dict = {
    'activation': ['relu', 'tanh', 'sigmoid'],
    'weight_init': ['glorot_uniform'], #['glorot_uniform', 'HeUniform', 'HeNormal'],
    'learning_rate': [1e-4, 1e-3, 1e-2],
    'optimiser': ['Adam','Adamax','Nadam','Ftrl', 'SGD', 'RMSprop'],
    'normalisation': [0] #[0, 1], # Whether to use batch normalization
}

fixed_dict = {
    'nodes_per_layer': 10,
    'num_hidden_layers': 4,
    'dropout': 1.0, # 1 for dropout layer, 0 for no dropout
    'dropout_rate': 0.2, # dropout rate
    'batch_size':16, #[8, 16, 32]
    'num_epochs': 100,
}
input_dim = 1
output_dim = 9

In [None]:
from optuna.samplers import TPESampler
import optuna
from functools import partial
sampler = TPESampler(seed=123456)
from gwbonsai.optimise_hyper.optimise_functional_tensorflow import functional_objective

objective_partial = partial(functional_objective, input_dim=input_dim, output_dim=output_dim, functional_options_dict=functional_options_dict, fixed_dict=fixed_dict, x_train=q_train, train_eim_data=train_eim_data, x_validation=q_validation, val_eim_data=val_eim_data, x_test=q_test, test_eim_data=test_eim_data, eim=eim)
study = optuna.create_study(direction='minimize',sampler=sampler)
study.optimize(objective_partial, n_trials=50)

print("Number of finished trials: ", len(study.trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))
        
best_params = trial.params
np.save('best_functional_hyper.npy', best_params)


In [None]:
from optuna.visualization import plot_contour
from optuna.visualization import plot_optimization_history

plot_optimization_history(study)

In [None]:
plot_contour(study)

There are lots of additional great visualisation tools available in Optuna if you want to dive deeper into your studies; for more information, see their [documentation page](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/005_visualization.html).

## Optimising Size and Shape Parameters

In [None]:
# Load the best functional hyperparameters obtained in the previous step
functional_best_params = np.load('best_functional_hyper.npy', allow_pickle=True)
fixed_dict = functional_best_params.item()
fixed_dict['num_epochs'] = 100

shape_options_dict = {
    'num_hidden_layers': [2,4,6,8,10], # Number of hidden layers in the neural network
    'nodes_per_layer': [8,10,15,20,50], # Number of nodes per hidden layer
    'dropout': [0.0,1.0], # 1 for dropout layer, 0 for no dropout
    'dropout_rate': np.linspace(0.1, 0.5, 5).tolist(), # dropout rate
    'batch_size': [8] # Batch size for training
}

input_dim = 1
output_dim = 9

In [None]:
sampler = TPESampler(seed=123456)
from gwbonsai.optimise_hyper.optimise_size_shape_tensorflow import shape_objective

objective_partial = partial(shape_objective, input_dim=input_dim, output_dim=output_dim, shape_options_dict=shape_options_dict, fixed_dict=fixed_dict, x_train=q_train, train_eim_data=train_eim_data, x_validation=q_validation, val_eim_data=val_eim_data, x_test=q_test, test_eim_data=test_eim_data, eim=eim)
study = optuna.create_study(direction='minimize',sampler=sampler)
study.optimize(objective_partial, n_trials=100)

print("Number of finished trials: ", len(study.trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))
        
best_params = trial.params
np.save('best_shape_hyper.npy', best_params)


In [None]:
from optuna.visualization import plot_contour
from optuna.visualization import plot_optimization_history

plot_optimization_history(study)

In [None]:
plot_contour(study)

## Optimising Training Dataset

Now we've obtained our optimal hyperparameters for the network:

In [None]:
best_functional = np.load('best_functional_hyper.npy', allow_pickle=True).item()
best_shape = np.load('best_shape_hyper.npy', allow_pickle=True).item()
best = best_functional | best_shape  # merge the two dictionaries
best

We will split the original training set into a smaller training set, and the rest will be a holdout set. The smaller training set will iteratively grow larger, as the worst performing points from the holdout set are added to it.
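
A sketch of the general idea, not `gwbonsai`'s `train_iteration` (whose internals are not shown here): the loop below uses a cheap stand-in surrogate, piecewise-linear interpolation in one dimension, so that it runs instantly, whereas in this notebook each iteration retrains the neural network with the best hyperparameters. The function `grow_training_set` and the `tanh` toy target are hypothetical.

```python
import numpy as np

def grow_training_set(x_pool, y_pool, n_initial, append_sizes, seed=0):
    """Greedy training-set growth on a 1D toy problem. The 'surrogate' is
    piecewise-linear interpolation through the current training points."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x_pool))
    train_idx, holdout_idx = order[:n_initial], order[n_initial:]
    mean_errors = []
    for n_add in append_sizes:
        # "Train" the surrogate on the current training set (np.interp needs sorted x)
        sort = np.argsort(x_pool[train_idx])
        x_t, y_t = x_pool[train_idx][sort], y_pool[train_idx][sort]
        # Evaluate it on the holdout set
        errors = (np.interp(x_pool[holdout_idx], x_t, y_t) - y_pool[holdout_idx]) ** 2
        mean_errors.append(errors.mean())
        # Move the n_add worst-predicted holdout points into the training set
        worst = np.argsort(errors)[::-1][:n_add]
        train_idx = np.concatenate([train_idx, holdout_idx[worst]])
        holdout_idx = np.delete(holdout_idx, worst)
    return train_idx, mean_errors

q_pool = np.linspace(1.0, 2.0, 240)
target = np.tanh(20.0 * (q_pool - 1.5))  # toy target with a sharp feature at q = 1.5
train_idx, mean_errors = grow_training_set(q_pool, target, 24, [24] * 5)
print(len(train_idx), [f"{e:.1e}" for e in mean_errors])
```

Because the points moved are those with the largest errors, the training set tends to become denser where the target varies fastest (here around $q=1.5$), which is the kind of behaviour the histogram at the end of this notebook is meant to reveal.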

In [None]:
q_holdout, q_train_0, holdout_eim_data, train_eim_data_0 = sklearn.model_selection.train_test_split(q_train, train_eim_data, random_state=0,test_size=0.1)
print('Number of training samples: %i'%len(q_train_0))
print('Number of holdout samples: %i'%len(q_holdout))

We create a list with the number of training points to add at each iteration of the procedure (the first entry is the size of the initial training set). These values must add up to no more than the length of q_train!

In [None]:
num_iterations = 10
iteration_size = len(q_train) // num_iterations
append_sizes = [iteration_size] * num_iterations

assert np.sum(append_sizes) <= len(q_train), f"Sum is {np.sum(append_sizes)}, but only {len(q_train)} training samples are available"



Set up dataframes to store our results.

In [None]:
import pandas as pd
df_train = pd.DataFrame(train_eim_data, columns=['eim_0', 'eim_1', 'eim_2', 'eim_3', 'eim_4', 'eim_5', 'eim_6', 'eim_7', 'eim_8'])
df_train['q']=q_train.flatten()
df_train

In [None]:
df_test = pd.DataFrame(test_eim_data, columns=['eim_0', 'eim_1', 'eim_2', 'eim_3', 'eim_4', 'eim_5', 'eim_6', 'eim_7', 'eim_8'])
df_test['q']=q_test.flatten()
df_test

In [None]:
input_cols = ['q']
output_cols = ['eim_0', 'eim_1', 'eim_2', 'eim_3', 'eim_4', 'eim_5', 'eim_6', 'eim_7', 'eim_8']

In [None]:
first_iteration = np.full(len(q_train), np.nan)
first_iteration[:append_sizes[0]] = 0
df_train['first_training_iteration'] = first_iteration


In [None]:
from gwbonsai.optimise_data.optimise_data import train_iteration

In [None]:
df_train, df_test = train_iteration(append_sizes, best, df_train, df_test, input_cols, output_cols, 100, eim)

In [None]:
df_train

In [None]:
# Test-set errors recorded after each training iteration
data = [df_test['mean_error_%i' % i] for i in range(num_iterations)]
positions = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19] # Custom positions for each violin

# Create the plot
fig, ax = plt.subplots(figsize=(15, 6))

# Create violin plots
ax.violinplot(data, positions=positions, showmeans=True)

# Set the x-axis ticks and labels
ax.set_xticks(positions)
ax.set_xticklabels(['Mean error 1', 'Mean error 2', 'Mean error 3', 'Mean error 4', 'Mean error 5', 'Mean error 6', 'Mean error 7', 'Mean error 8', 'Mean error 9', 'Mean error 10'])

# Add labels and title
ax.set_xlabel('Training Data Sets')
ax.set_ylabel('Mean Squared Error')

ax.set_yscale('log')
# Show the plot
plt.show()

Where was the training data distributed? Where were extra points added?


In [None]:
# q values of the training points used at each iteration
data = [df_train[df_train['mean_error_%i' % i].notna()]['q'] for i in range(num_iterations)]

# Create the plot
fig, ax = plt.subplots(figsize=(15, 6))

# Overlay a step histogram of the training q values for each iteration
for i, q_values in enumerate(data):
    ax.hist(q_values, label='Iteration %i' % i, histtype='step', bins=20, density=True)

# Add labels and legend
ax.set_xlabel('q')
ax.set_ylabel('Density')
ax.legend()

# Show the plot
plt.show()

## Questions/Notes

How does our final fit compare to the original one at the beginning? Test it out!

Things we did today:

- Optimised the functional hyperparameters of the neural network
- Optimised the shape hyperparameters of the neural network
- Optimised the size and distribution of the training data

Things we did not do:

- Change the size of the reduced basis as our iterative training data set grows
- Optimise the size of the reduced basis for the accuracy we want to achieve
- Use more complicated neural network architectures or features (learning rate schedulers, other kinds of architectures such as hourglass architectures, etc.)
- Play around with data scaling (normalisation, standardisation, etc.)
    - Play around with data scaling (normalisation, standardisation, etc.)

What else could we have done to improve our fits?