# Building a 1-dimensional surrogate

ICERM Workshop: Scientific Machine Learning for Gravitational Wave Astronomy

Date: 6/3/2025

Author: Scott Field

**Note**: In addition to the usual Python libraries, you will need rompy and gwsurrogate.

```bash
conda create -n icerm python=3.11 "numpy<2.0"
conda activate icerm
conda install -c conda-forge gwsurrogate
conda install -c conda-forge jupyterlab
pip install forked-rompy
```


# Session goal...

## Target audience

* Introductory tutorial showing how some of the methods referenced in Carl's talk can be implemented in practice.
* Assumes people broadly know about gravitational-wave models, and may even have dabbled in surrogate modeling.

## Problem statement

This notebook describes how to build a simple 1-dimensional surrogate model of your favorite 1-dimensional gravitational-wave model:

\begin{align*}
h(t, \theta, \phi; q) & = h_+(t, \theta, \phi; q) - \mathrm{i} h_\times(t, \theta, \phi; q) \\
& = \sum_{\ell=2}^{\infty} \sum_{m=-\ell}^{\ell} h^{\ell m}(t;q) \, {}_{-2}Y_{\ell m} \left(\theta, \phi \right) \, ,
\end{align*}

where $\theta$ and $\phi$ are the angles specifying the direction of propagation away from the source, $q$ is the mass ratio, and ${}_{-2}Y_{\ell m}$ are the spin-weighted spherical harmonics of spin weight $-2$.

We will build a surrogate for the (2,2) mode

$$h^{22}(t;q)$$


Our strategy combines the following methods:

* Align waveforms in time, but not in phase
* SVD basis
* Empirical interpolant representation
* Splines to approximate the real and imaginary parts of $h(T;q)$ for fixed $T$ and varying $q$

While each part has been reported in the literature (the approach mixes methods considered in https://arxiv.org/abs/1402.4146 and https://arxiv.org/abs/1308.3565), this particular combination has not been considered in a published paper before. This illustrates the flexibility of surrogate building.

In [1]:
!pip install -q condacolab
import condacolab
condacolab.install()

⏬ Downloading https://github.com/jaimergp/miniforge/releases/download/24.11.2-1_colab/Miniforge3-colab-24.11.2-1_colab-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:21
🔁 Restarting kernel...


In [1]:
# After kernel restart, run again to ensure Conda is active
import condacolab
condacolab.check()

✨🍰✨ Everything looks OK!


In [None]:
# -y answers conda's confirmation prompts automatically.
# Note: without -n, `conda install` targets the active environment; in Colab,
# that is condacolab's base environment (the one the kernel runs in), not "icerm".
!conda create -y -n icerm python=3.11 "numpy<2.0"
!conda install -y -c conda-forge gwsurrogate
!conda install -y -c conda-forge jupyterlab
!pip install forked-rompy

In [None]:
# import the usual suspects
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Step 0: Training data from the underlying model

* We need to collect training data on which to build the model
* For simplicity, our training data will come from the generically precessing binary black hole model **NRSur7dq4**
* **NRSur7dq4** is already fast, so a surrogate isn't needed (although it could be made faster still!). In a more realistic scenario, training data would come from numerical relativity simulations.
* To get access to **NRSur7dq4** we will use the Python package [GWSurrogate](https://github.com/sxs-collaboration/gwsurrogate), which is maintained by members of the Simulating eXtreme Spacetimes (SXS) collaboration

In [None]:
import warnings
warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")

import gwsurrogate as gws
gw_model =  gws.LoadSurrogate('NRSur7dq4')

In [None]:
# define a simplified interface to the NRSur7dq4 model
def NRSur7dq4_22_nonspinning(q, dt=0.1):
    """ Simplified interface to NRSur7dq4 to get the (2,2) mode for nonspinning systems.

      INPUT
      =====
      q  -- mass ratio
      dt -- timestep size, in units of M

      OUTPUT
      ======
      times -- time samples, in units of M
      h22   -- complex (2,2) mode sampled at times"""

    chiA  = [0.0, 0.0, 0.0]        # dimensionless spin of the heavier BH
    chiB  = [0.0, 0.0, 0.0]        # dimensionless spin of the lighter BH
    f_low = 0.0065                 # initial frequency of wave
    f_ref = f_low                  # reference frequency the spins are defined at

    times, h, dyn = gw_model(q, chiA, chiB, dt=dt, f_low=f_low, f_ref=f_ref)

    return times, h[(2,2)]


In [None]:
t, h = NRSur7dq4_22_nonspinning(2.0)

In [None]:
plt.plot(t,np.real(h),'r',label = "real")
plt.plot(t,np.imag(h),'k--',label = "imag")
plt.legend()

# Step 1: Populating the training set

There are potentially many reasonable ways to construct a training set on the training region

$$q \in [1, 2]$$

and we will use the easiest of all: a uniformly spaced grid.

In [None]:
# global settings that control the training set
dt = 0.1
train_samples = 50 # more waveforms usually give more accurate models

In [None]:
def training_set_generator(N,verbose=False):
    """Generate N training samples from q in [1,2]"""
    qs = np.linspace(1.0,2.0,N)
    training_data = []
    for q in qs:
        t,h = NRSur7dq4_22_nonspinning(q,dt=dt)
        training_data.append(h)
        if verbose:
            print(f"The number of time samples of h(t;{q}) is {len(h)}")
    return qs, training_data

In [None]:
# generating training data
qs, training_data = training_set_generator(train_samples, verbose=True)

## Interlude I: preparing waveform data -- common durations

* For surrogate modeling to work, we need all of the waveforms to have the same duration/length. From the output above, this is not the case!
* We need to modify our waveform data to fix this problem.

In [None]:
def common_time_grid(training_data):
    """
    INPUT
    =====
    training_data: set of training waveforms

    OUTPUT
    ======
    training data as a numpy array, padding with zeros as
    necessary such that all waveforms are of the same length"""

    longest_waveform = 0
    for h in training_data:
        length = len(h)
        if length > longest_waveform:
            longest_waveform = length

    print(f"Maximum number of waveform time samples over training set = {longest_waveform}")

    padded_training_data = []
    for h in training_data:
        nZeros = longest_waveform - len(h)
        h_pad = np.append(h, np.zeros(nZeros))
        padded_training_data.append(h_pad)

    times = np.arange(longest_waveform)*dt

    padded_training_data = np.vstack(padded_training_data).transpose()

    return times, padded_training_data

In [None]:
times, training_data = common_time_grid(training_data)

In [None]:
plt.figure(1)
plt.plot(times,np.abs(training_data));

## Interlude II: preparing waveform data -- temporal alignment

* For efficient surrogate models (efficient = admits a low-dimensional approximation space), the waveforms should change as little as possible as the parameter is varied.
* Unfortunately, that's not the case here! Notice that the waves peak at different times.
* A simulation-dependent time shift will align the peaks.

**Goal**: Let's align all of the waveform peaks, using the waveform whose peak occurs earliest (the shortest waveform) as the reference.

**Warning**: The code below finds the waveform peak on a discrete grid. For high accuracy models, better peak finding is needed.

In [None]:
def get_peak(t, h):
    """Get the index, time, and value of h at the maximum of |h| on a discrete grid."""
    arg = np.argmax(np.abs(h))
    return [arg, t[arg], h[arg]]

def get_peaks(t, training_set):
    """Find the index of each waveform's peak in the entire training set."""
    time_peak_arg = []
    for i in range(training_set.shape[1]):  # loop over columns (training samples)
        [arg, t_peak, h_peak] = get_peak(t, training_set[:, i])  # i^th training sample
        time_peak_arg.append(arg)
        print("Waveform %i with t_peak = %f" % (i, t_peak))
    return time_peak_arg

def align_peaks(times, training_set):
    """Peak-align a set of waveforms. The waveform whose peak occurs earliest
    (the shortest one) is used as the reference."""

    time_peak_arg = get_peaks(times, training_set)

    min_arg = min(time_peak_arg)
    aligned_training_set = []
    for i in range(training_set.shape[1]):
        # drop the first `offset` samples so this waveform peaks at index min_arg
        offset = time_peak_arg[i] - min_arg
        h_aligned = training_set[offset:, i]
        aligned_training_set.append(h_aligned)

    # re-pad with zeros at the end so all columns have the same length again
    t, training_data_aligned = common_time_grid(aligned_training_set)
    return training_data_aligned
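As the warning above says, `get_peak` only locates the peak to within one time step `dt`. One common refinement, sketched here under the assumption of a uniform time grid and a peak away from the array ends, is to fit a parabola through $|h|$ at the discrete maximum and its two neighbours and take its vertex. The helper name `refine_peak` is hypothetical (it is not part of this notebook, gwsurrogate, or rompy), and shifting waveforms by a non-integer number of samples would additionally require interpolating them onto a common grid.

```python
import numpy as np

def refine_peak(t, h):
    """Sub-sample estimate of the time at which |h| is maximal.

    Assumes a uniformly spaced time grid t. Fits a parabola through |h| at
    the discrete maximum and its two neighbours and returns its vertex.
    """
    amp = np.abs(h)
    i = int(np.argmax(amp))
    if i == 0 or i == len(amp) - 1:
        return t[i]  # peak sits on the boundary: no neighbours to fit
    y0, y1, y2 = amp[i - 1], amp[i], amp[i + 1]
    curvature = y0 - 2.0 * y1 + y2
    if curvature == 0.0:
        return t[i]  # flat top: keep the grid value
    # vertex offset in units of the grid spacing (between -0.5 and 0.5)
    delta = 0.5 * (y0 - y2) / curvature
    step = t[1] - t[0]
    return t[i] + delta * step

# Toy check: a complex signal whose envelope peaks between grid points
t_toy = np.arange(0.0, 10.0, 0.1)
h_toy = np.exp(-(t_toy - 3.27) ** 2 / 2.0) * np.exp(5j * t_toy)
t_grid_peak = t_toy[np.argmax(np.abs(h_toy))]   # limited to the 0.1 grid
t_refined_peak = refine_peak(t_toy, h_toy)      # much closer to 3.27
```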

In [None]:
training_data_aligned = align_peaks(times, training_data)

In [None]:
training_data_aligned.shape

In [None]:
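# peak time of the aligned waveforms (the same for every column after alignment)
_, t_peak_aligned, _ = get_peak(times, training_data_aligned[:, 0])

plt.figure(1)
plt.plot(times,np.abs(training_data_aligned))

plt.figure(2)
plt.plot(times,np.abs(training_data_aligned),'*')
plt.xlim([t_peak_aligned-20,t_peak_aligned+20])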

# Step 2: Building a basis

## The snapshot matrix

The matrix **training_data_aligned** has 25117 rows (the time samples) and 50 columns (the parameter samples). That is, each column is a complex waveform $h^{22}(\vec{t};q)$ sampled at some value of the mass ratio $q$.

## Finding a compact basis for the set of waveforms

**Goal**: Find $n$ basis functions $u_i(t)$ such that (i) $n$ is small and (ii) the approximation

$$h(t;q) \approx \sum_{i=1}^n  c_i (q) u_i(t)$$

is accurate. Some popular options for linear dimensionality reduction (in GW modeling at least) are:

* **The greedy algorithm**: Find the $n$ most important waveforms (i.e., matrix columns).
* **Singular Value Decomposition** (SVD): Find the $n$ most important "modes" of the data (linear combinations over all columns)
* Here we will use the SVD, which is already coded up in numpy.
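For comparison with the SVD, here is a minimal sketch of the greedy algorithm mentioned above, under simple assumptions: the basis is built from the columns of a snapshot matrix, "worst represented" is measured by the Euclidean norm of the projection residual, and an absolute tolerance stops the iteration. The function `greedy_basis` is hypothetical and not taken from any of the packages used in this notebook. Unlike the SVD, every selected basis direction comes from an actual training waveform, and the returned indices record which ones.

```python
import numpy as np

def greedy_basis(A, tol=1e-10, max_size=None):
    """Greedy reduced basis built from the columns of a snapshot matrix A.

    Each step selects the column with the largest projection residual,
    orthonormalizes it against the current basis, and appends it.
    Returns (basis, picked): orthonormal basis columns and chosen column indices.
    """
    max_size = A.shape[1] if max_size is None else max_size
    basis = np.zeros((A.shape[0], 0), dtype=A.dtype)
    picked = []
    residual = A.copy()
    for _ in range(max_size):
        res_norms = np.linalg.norm(residual, axis=0)
        k = int(np.argmax(res_norms))
        if res_norms[k] <= tol:
            break  # every column is already represented to within tol
        e = residual[:, k]
        e = e - basis @ (basis.conj().T @ e)  # second Gram-Schmidt pass for stability
        basis = np.column_stack([basis, e / np.linalg.norm(e)])
        picked.append(k)
        residual = A - basis @ (basis.conj().T @ A)
    return basis, picked

# Toy snapshot matrix of exact rank 3: the greedy algorithm should stop after 3 picks
rng = np.random.default_rng(0)
A_toy = rng.standard_normal((200, 3)) @ rng.standard_normal((3, 20))
basis_toy, picked_toy = greedy_basis(A_toy)
```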

### Background on the SVD

You can think of the SVD as an algorithm that identifies the "most important" directions spanned by the columns of a matrix. A set of directions is "important" if every column of the matrix can be accurately written as a linear combination of them.

Suppose you have a 25117-by-50 matrix, ${\bf A}$, and you discover that all 50 columns can be written as a linear combination of just 2 vectors (up to some accuracy threshold). That would be great: you can express any training waveform as a sum over two basis functions. This is what the SVD provides, and it is how the SVD can be used to approximate a matrix ${\bf A}$.
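The scenario above is easy to reproduce on a small scale. In this hypothetical toy example (not gravitational-wave data), every column is a $q$-dependent mixture of just two fixed functions plus a tiny amount of noise, so the singular values should show two non-negligible entries followed by a sharp drop to the noise level.

```python
import numpy as np

rng = np.random.default_rng(1)
t_toy = np.linspace(0.0, 10.0, 500)
q_toy = np.linspace(1.0, 2.0, 50)

# 500-by-50 toy matrix: each column is q*f1 + q^2*f2 for two fixed functions f1, f2
f1, f2 = np.sin(t_toy), np.cos(3.0 * t_toy)
A_toy = np.column_stack([q * f1 + q**2 * f2 for q in q_toy])
A_toy += 1e-8 * rng.standard_normal(A_toy.shape)  # small "noise" on every entry

s_toy = np.linalg.svd(A_toy, compute_uv=False)   # singular values only
print(s_toy[:4] / s_toy[0])  # two non-negligible ratios, then a drop to ~noise level
```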

### Numpy implementation

The numpy SVD routine will decompose the training data matrix as

$$\text{training\_data\_aligned} = U S V^*$$

For our purpose,
* The most important columns (the basis) are the columns of the matrix $U$.
* The matrix $S$ is diagonal, and its entries are the singular values. numpy returns them as a 1-D array sorted in decreasing order.
* The $i^{th}$ singular value, $s_i$, assigns a weight to how important the $i^{th}$ column of $U$ is.
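A few of numpy's conventions are worth checking before using the output: the singular values come back as a 1-D array rather than a diagonal matrix, the third output is already the conjugate transpose $V^*$, and with `full_matrices=False` the matrix $U$ has only as many columns as the training set has waveforms. The small random complex matrix below is a stand-in for **training_data_aligned**, chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)
# 100 "time samples" by 8 "parameter samples", complex like the waveform data
A_toy = rng.standard_normal((100, 8)) + 1j * rng.standard_normal((100, 8))

U_toy, s_toy, Vh_toy = np.linalg.svd(A_toy, full_matrices=False)

print(U_toy.shape, s_toy.shape, Vh_toy.shape)   # (100, 8) (8,) (8, 8)
sorted_desc = np.all(np.diff(s_toy) <= 0)        # singular values in decreasing order
orthonormal = np.allclose(U_toy.conj().T @ U_toy, np.eye(8))
# U * s scales column i of U by s_i, i.e. computes U S without forming diag(s)
reconstructs = np.allclose((U_toy * s_toy) @ Vh_toy, A_toy)
```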

In [None]:
# First, let's check the claim that peak-aligned training sets are easier to approximate.
# From the above discussion, faster-decaying singular values mean the matrix/model is easier
# to compress
u, s_no_align, vh = np.linalg.svd(training_data, full_matrices=False)
u, s, vh          = np.linalg.svd(training_data_aligned, full_matrices=False)
plt.semilogy(range(train_samples),s,'r*',label='align')
plt.semilogy(range(train_samples),s_no_align,'b+',label='no align')
plt.ylabel('singular value')
plt.xlabel('singular value index k')
plt.legend(loc=3)

In [None]:
# Top 2 singular vectors: these are the 2 most important features according to the SVD
for i in range(2):
    plt.figure(i)
    plt.plot(times,np.real(u[:,i]), label= "real")
    plt.plot(times,np.imag(u[:,i]), label= "imag")
    plt.legend()

## Basis application example

We can immediately use the first $n$ basis vectors to compute

$$h(t;q) \approx \sum_{i=1}^n  c_i (q) u_i(t)$$


where $u_i(t)$ is the $i^{th}$ basis vector and $c_i (q)$ are the coefficients of the representation of $h(t;q)$ in the approximation space span$\{u_i\}_{i=1}^n$.

Since the $u_i$ are orthonormal, the coefficient $c_i (q)$ is found by taking the inner product of $u_i$ with $h(t;q)$.
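In code, the coefficients are $c_i(q) = \langle u_i, h \rangle = \sum_k u_i^*(t_k)\, h(t_k;q)$, a plain sum over time samples with no $\Delta t$ weight, consistent with the sense in which numpy's SVD vectors are orthonormal. The sketch below uses a hypothetical helper `project` and a random orthonormal basis (from a QR factorization) in place of the SVD vectors. It checks that the first coefficient agrees with `np.vdot`, which conjugates its first argument, and that the residual is orthogonal to every basis vector, which is the defining property of an orthogonal projection.

```python
import numpy as np

def project(basis, h):
    """Orthogonal projection of h onto the span of the orthonormal columns of basis.

    Returns (coefficients, approximation) with c_i = sum_k conj(u_i[k]) * h[k].
    """
    coeffs = basis.conj().T @ h
    return coeffs, basis @ coeffs

rng = np.random.default_rng(3)
M_toy = rng.standard_normal((300, 5)) + 1j * rng.standard_normal((300, 5))
basis_toy, _ = np.linalg.qr(M_toy)   # 300-by-5 matrix with orthonormal columns
h_toy = rng.standard_normal(300) + 1j * rng.standard_normal(300)

c_toy, h_proj = project(basis_toy, h_toy)
c0_vdot = np.vdot(basis_toy[:, 0], h_toy)                  # conjugates the first argument
residual_overlaps = basis_toy.conj().T @ (h_toy - h_proj)  # should all be ~0
```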

In [None]:
# Decide on how many column vectors to use as the basis
basis_size = 7 # More vectors -> more accuracy but also more computational cost
basis_set  = u[:,0:basis_size]

In [None]:
# Select a waveform to approximate
h_i = training_data_aligned[:,10]

# compute its representation in the linear space defined by the span of the basis set
proj_coeffs = np.dot(basis_set.conjugate().transpose(), h_i)  # c_i = <u_i, h>
h_approx = np.dot(basis_set, proj_coeffs)                     # sum_i c_i u_i

# plot the real part of the original waveform and its approximation, then the error
plt.figure(1)
plt.plot(times,np.real(h_i),'blue',label='h')
plt.plot(times,np.real(h_approx),'r--',label=f'h with {basis_size} basis vectors')
plt.legend()

plt.figure(2)
wave_err = np.abs(h_i - h_approx)
plt.semilogy(times,wave_err)

# Step 3: An empirical interpolant representation

As an alternative to projecting onto the basis set, an empirical interpolant can be used. The main idea is to trade the $n$ pieces of information, the projection coefficients $c_i$, for direct waveform evaluations.

Given $n$ basis vectors, there (usually) exist $n$ times $\{T_i\}_{i=1}^n$ for which the set of numbers
$$\{ c_i (q) \}_{i=1}^n \Longleftrightarrow \{ h(T_i;q) \}_{i=1}^n$$
contains equivalent information in the sense that

$$ h (t;q) \approx \sum_{i=1}^n  c_i (q) u_i(t) \approx \sum_{i=1}^n h(T_i;q) B_i(t)$$

where the $B_i(t)$ are linear combinations of the $u_i(t)$ (i.e., the approximation space, the span of the basis, is unchanged).
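The notebook uses rompy to find the $T_i$ and $B_i$. To make the idea concrete, here is a minimal sketch of one standard greedy way to choose the nodes (often called the discrete empirical interpolation method, DEIM). It is not rompy's code, and rompy's details may differ: in particular, the notebook applies `B.transpose()` to rompy's `eim.B`, whereas this sketch stores each $B_i(t)$ as a column. Each new node is placed where the current interpolant of the next basis vector is worst, and $B = V\,(V_{\text{nodes}})^{-1}$ is chosen so that $B_i(T_j) = \delta_{ij}$.

```python
import numpy as np

def empirical_interpolant(basis):
    """Greedy (DEIM-style) choice of interpolation nodes for an (N, n) basis.

    Returns (indices, B), with B of shape (N, n), such that
    h ~= B @ h[indices] for any h in the span of the basis columns.
    """
    n = basis.shape[1]
    indices = [int(np.argmax(np.abs(basis[:, 0])))]
    for j in range(1, n):
        V = basis[:, :j]
        # interpolate basis vector j at the current nodes ...
        coeffs = np.linalg.solve(V[indices, :], basis[indices, j])
        residual = basis[:, j] - V @ coeffs
        # ... and add a node where that interpolant is worst
        indices.append(int(np.argmax(np.abs(residual))))
    indices = np.array(indices)
    # B = V (V[indices])^{-1}, so that row indices[i] of B is the i-th unit vector
    B = basis @ np.linalg.inv(basis[indices, :])
    return indices, B

rng = np.random.default_rng(4)
Q_toy, _ = np.linalg.qr(rng.standard_normal((400, 6)))  # toy orthonormal basis
idx_toy, B_toy = empirical_interpolant(Q_toy)

h_in_span = Q_toy @ rng.standard_normal(6)  # any function in the span...
h_rebuilt = B_toy @ h_in_span[idx_toy]      # ...is recovered from 6 samples
```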

In [None]:
# Recommended: "pip install forked-rompy" (https://pypi.org/project/forked-rompy/)
# Original code: https://bitbucket.org/chadgalley/rompy/
import rompy as rp

In [None]:
eim = rp.EmpiricalInterpolant(basis_set.transpose(), verbose=True) # Note the transpose

In [None]:
# the most important pieces of information are the interpolation times T_i and the basis B_i
eim_indicies = eim.indices
B = eim.B

In [None]:
# Let's see which times the empirical interpolation method has discovered as most important:
T_eim = times[eim_indicies]
t_peak_i = times[np.argmax(np.abs(h_i))]  # peak time of the waveform plotted here

plt.figure(1)
plt.plot(times,np.real(h_i),'k')
plt.plot(T_eim,np.zeros_like(T_eim),'r*')

plt.figure(2)
plt.plot(times,np.real(h_i),'k')
plt.plot(T_eim,np.zeros_like(T_eim),'r*')
plt.xlim([t_peak_i-50,t_peak_i+50])

## Empirical interpolation application example

Claim: To compute the waveform $h(t;q)$ at any time (to the accuracy of the basis), we only need to know the waveform's values at the times $T_i$.

Let's check this...

In [None]:
# Select a waveform to approximate
h_i = training_data_aligned[:,10]

# compute its empirical interpolant representation
h_eim = h_i[eim_indicies]
h_approx = np.dot(B.transpose(), h_eim)

# plot the real part of the original waveform and its empirical interpolant, then the errors
plt.figure(1)
plt.plot(times,np.real(h_i),'blue')
plt.plot(times,np.real(h_approx),'r--')

plt.figure(2)
plt.semilogy(times,np.abs(h_i - h_approx),label = 'Empirical interpolant error')
# The empirical interpolant interpolates the data, so the errors are (numerically) zero
# at the EIM nodes; the nodes are marked at an arbitrary height of 1e-12
plt.semilogy(T_eim,np.ones_like(T_eim)*1.e-12,'r*', label = r'$T_{eim}$ nodes')
plt.semilogy(times,wave_err,label = 'SVD projection error')
plt.legend()

# Step 4: Parametric fits

We are almost ready to complete our surrogate model! If we could evaluate the few functions $h(T_i;q)$ at any $q$, then we would be done. So what remains is finding a way to predict their values.

Here one often uses some regression/interpolation method:
- Polynomials
- Splines (This tutorial)
- Gaussian Process Regression
- Deep neural networks (Lucy Thomas' tutorial)
- Others?
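Any of these options only has to map $q$ to the real and imaginary parts of $h(T_i;q)$. As a sketch of the first option, here is a least-squares polynomial fit with numpy's `Polynomial.fit`, applied to a hypothetical smooth stand-in for one EIM node's data (the function and the degree 8 are illustrative choices, not tuned for real waveforms). `Polynomial.fit` maps the $q$ interval to $[-1, 1]$ internally, which keeps the fit well conditioned.

```python
import numpy as np
from numpy.polynomial import Polynomial

# Hypothetical smooth stand-in for the data h(T_i; q) at one EIM node
q_train = np.linspace(1.0, 2.0, 50)
data_toy = np.exp(3.0j * q_train) / q_train

# Fit the real and imaginary parts separately, as the splines below do
fit_re = Polynomial.fit(q_train, data_toy.real, deg=8)
fit_im = Polynomial.fit(q_train, data_toy.imag, deg=8)

q_new = 1.234  # a mass ratio between training points
prediction = fit_re(q_new) + 1j * fit_im(q_new)
truth = np.exp(3.0j * q_new) / q_new
print(abs(prediction - truth))  # small for this smooth toy function
```

Real EIM data with a lot of structure in $q$ (as in the plots above) generally needs higher degrees or a smoother decomposition before a single global polynomial works well.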

In [None]:
# First, view the data we need to fit
h_training_eim = training_data_aligned[eim_indicies,:]

#for counter, i in enumerate(eim_indicies):
for counter, i in enumerate(eim_indicies[:2]):
    plt.figure(counter)
    plt.title("eim index %i at t[i] = %f"%(i,times[i]))
    plt.plot(qs,np.real(h_training_eim[counter,:]),'r',label='real' )
    plt.plot(qs,np.imag(h_training_eim[counter,:]),'k--',label='imag' )
    plt.legend()

## Interlude: dealing with hard-to-approximate functions

Evidently, these functions have a significant amount of structure in $q$. There are a few strategies for dealing with this:

1) One could look for advantageous decompositions of the data, for example,

$$h^{22}(T_i;q) = A^{22}(T_i;q) \exp\left(-\mathrm{i} \phi^{22}(T_i;q) \right)$$

since the amplitude, $A$, and phase, $\phi$, are expected to be "boring" (slowly varying) functions of $q$. **There will be examples of this later today.**

2) One could align the phases at some reference time. This would happen early in the surrogate-building process, for example in Step 1 after temporal alignment.

3) One could brute-force it by using splines with dense grids.

... we will follow option 3.
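For reference, options 1 and 2 take only a few lines each. The sketch below assumes the sign convention $h = A \exp(-\mathrm{i}\phi)$ from the equation above, and a time (or parameter) sampling fine enough that consecutive samples differ in phase by less than $\pi$, so that `np.unwrap` can remove the $2\pi$ jumps of `np.angle`. Both helper names are hypothetical, and the toy signal is made up for illustration.

```python
import numpy as np

def amp_phase(h):
    """Decompose complex samples as h = A * exp(-1j * phi), with phi unwrapped."""
    amp = np.abs(h)
    phi = -np.unwrap(np.angle(h))
    return amp, phi

def align_phase(h, idx_ref):
    """Rotate h by a constant phase so that phi = 0 at sample idx_ref (option 2)."""
    return h * np.exp(-1j * np.angle(h[idx_ref]))

# Toy chirp-like signal with known amplitude and phase
t_toy = np.linspace(0.0, 1.0, 1000)
amp_true = 1.0 + t_toy
phi_true = 40.0 * t_toy**2
h_toy = amp_true * np.exp(-1j * phi_true)

amp_toy, phi_toy = amp_phase(h_toy)   # smooth, unlike Re(h) and Im(h)
h_aligned = align_phase(h_toy, 500)   # phase now vanishes at t_toy[500]
```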

In [None]:
from scipy.interpolate import splrep, splev

In [None]:
# Let's first show how this works for a single EIM node
eim_indx = 1
h_eim_real_spline = splrep(qs, np.real(h_training_eim[eim_indx,:]),k=2) # degree 2
h_eim_imag_spline = splrep(qs, np.imag(h_training_eim[eim_indx,:]),k=2) # degree 2

# evaluate the splines on a dense reference grid
q_dense = np.linspace(min(qs),max(qs),300)

plt.figure(1)
plt.plot(qs,np.real(h_training_eim[eim_indx,:]),'b',label ='real, data')
plt.plot(q_dense,splev(q_dense,h_eim_real_spline),'r--',label ='real, spline')
plt.plot(qs,np.imag(h_training_eim[eim_indx,:]),'k',label ='imag, data')
plt.plot(q_dense,splev(q_dense,h_eim_imag_spline),'g--',label ='imag, spline')
plt.legend()

In [None]:
# Now build spline interpolants for the data at all of the EIM nodes
h_eim_real_spline = [splrep(qs, np.real(h_training_eim[i,:]),k=2) for i in range(len(eim_indicies))]
h_eim_imag_spline = [splrep(qs, np.imag(h_training_eim[i,:]),k=2) for i in range(len(eim_indicies))]

In [None]:
# The full surrogate can now be evaluated by evaluating the splines and using the
# empirical interpolation representation

q=1.2 # mass ratio at which to evaluate the new model

h_eim = np.array([splev(q, h_eim_real_spline[j])  \
             + 1.0j*splev(q,h_eim_imag_spline[j]) for j in range(len(eim_indicies))])
h_approx = np.dot(B.transpose(), h_eim)

# plot the real and imaginary parts of the surrogate prediction
plt.figure(1)
plt.plot(times,np.real(h_approx),'k')
plt.plot(times,np.imag(h_approx),'r--')

# Step 5: Full surrogate

We're basically done.

Let's package everything into a single function to make it more user-friendly, then sanity-check the model by comparing against the training data, and show that it's faster.

In [None]:
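def surrogate(q):
    """Evaluate the (2,2)-mode surrogate at mass ratio q on the aligned time grid `times`."""
    # predict the waveform at the EIM nodes T_i with the splines...
    h_eim = np.array([splev(q, h_eim_real_spline[j])  \
             + 1.0j*splev(q,h_eim_imag_spline[j]) for j in range(len(eim_indicies))])
    # ...then reconstruct the full waveform from the empirical interpolant
    h_approx = np.dot(B.transpose(), h_eim)
    return h_approx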

In [None]:
h_surr = surrogate(q=1.5)
plt.figure(1)
plt.plot(times,np.real(h_surr),'k')
plt.plot(times,np.imag(h_surr),'r--')

In [None]:
# timing experiment: NRSur7dq4 model from gwsurrogate
print(f"{train_samples} samples from NRSur7dq4 model from gwsurrogate...")
# assign to _ so the qs and training data from earlier are not overwritten
%time _ = training_set_generator(train_samples)

print(f"{train_samples} samples from the surrogate model...")
%time surrogate_evals = np.array([surrogate(q) for q in qs]).transpose()

## How to interpret timing experiments

**Note**: NRSur7dq4 is slower because it's computing a lot more physics!

NRSur7dq4 is a fully precessing model with $\ell \leq 4$ modes.

The surrogate we just built is for the nonspinning $(2,2)$ mode only.

**Punchline**: you can typically make a model faster by removing accuracy/physics.

In [None]:
# accuracy as relative error in max norm
# here we compare the model against the training data
h_error = np.abs(training_data_aligned - surrogate_evals)
h_inf =  h_error.max(axis=0) / np.abs(training_data_aligned).max(axis=0)

In [None]:
plt.plot(qs,h_inf)

# Going further

The surrogate model seems to be working well. Further directions to explore include...

* How do the errors compare against testing waveforms not used in training?
* What is the dominant source of model error? Do the errors get smaller as...
  * the number of basis vectors is increased?
  * the training set density is increased?
  * different-order splines are used?
  * the time step dt is decreased?
* This notebook can be used for other 1-dimensional models, e.g., different slices of the precessing parameter space.
* Fairly straightforward extensions can be used for some higher-dimensional models too.
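For the first question, a natural choice of testing waveforms in one dimension is the set of midpoints between neighbouring training values, since these are the parameters farthest from their neighbouring training samples. The helpers below are a hypothetical sketch: `midpoint_test_points` builds those test values, and `relative_max_error` is the same max-norm relative error used above to compare the surrogate against the training data. Before comparing, each test waveform would need the same zero-padding and peak alignment, on the same time grid `times`, as the training set.

```python
import numpy as np

def midpoint_test_points(q_train):
    """Parameter values halfway between neighbouring (sorted) training values."""
    q_sorted = np.sort(np.asarray(q_train, dtype=float))
    return 0.5 * (q_sorted[:-1] + q_sorted[1:])

def relative_max_error(h_true, h_model):
    """Relative error in the max norm: max|h_true - h_model| / max|h_true|."""
    return np.max(np.abs(h_true - h_model)) / np.max(np.abs(h_true))

q_test = midpoint_test_points([1.0, 2.0, 1.5])   # -> [1.25, 1.75]
err_toy = relative_max_error(np.array([1.0, -2.0j, 0.5]),
                             1.01 * np.array([1.0, -2.0j, 0.5]))
```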

# Appendix

## Short exercise 1

Explore the relationship between the singular value spectrum, s, and the waveform error.

Task: Select some waveform from your training set to approximate. Compute the approximation error vs. the number of SVD basis vectors. Plot this error along with the singular values. What do you notice?

In [None]:
# Select a waveform to approximate
h_i = training_data_aligned[:,10]

errors = []
bss = range(1, train_samples + 1)  # u has only train_samples columns
for bs in bss:
    basis_set_local  = u[:,0:bs]

    # compute its representation in the linear space defined by the span of the basis set
    proj_coeffs = np.dot(basis_set_local.conjugate().transpose(), h_i)
    h_approx = np.dot(basis_set_local, proj_coeffs)

    wave_err = np.abs(h_i - h_approx)
    errors.append(np.sum(wave_err))

# plot the approximation error vs. basis size alongside the singular values
plt.semilogy(bss,errors,label='wave error')
plt.semilogy(range(train_samples),s,'r*',label='singular values')
plt.xlabel('basis size / singular value index')
plt.legend()
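One way to make the relationship precise, sketched here on hypothetical toy data rather than the training set: if *every* column of a matrix is projected onto the first $n$ left singular vectors, the total error in the Frobenius norm is exactly $\sqrt{\sum_{k > n} s_k^2}$, which is the error of the best rank-$n$ approximation (Eckart-Young theorem). The single-waveform error computed above is a different quantity (one column, and a sum of absolute values), so for it the singular values indicate the trend rather than giving an exact prediction.

```python
import numpy as np

# Hypothetical toy snapshot matrix: damped oscillations whose frequency varies with q
t_toy = np.linspace(0.0, 1.0, 400)
q_toy = np.linspace(1.0, 2.0, 30)
A_toy = np.column_stack([np.exp(-q * t_toy) * np.cos(10.0 * q * t_toy) for q in q_toy])

U_toy, s_toy, _ = np.linalg.svd(A_toy, full_matrices=False)

measured, predicted = [], []
for n in range(1, 8):
    Un = U_toy[:, :n]
    # project every column onto the first n singular vectors; Frobenius-norm error
    measured.append(np.linalg.norm(A_toy - Un @ (Un.conj().T @ A_toy)))
    # tail of the singular value spectrum (0-based s_toy[n:] = 1-based s_{n+1}, ...)
    predicted.append(np.sqrt(np.sum(s_toy[n:] ** 2)))

for n, m, p in zip(range(1, 8), measured, predicted):
    print(f"n={n}: measured {m:.3e}, predicted {p:.3e}")
```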