If you're running in a separate notebook (e.g., Google Colab), un-comment the cells below as required. Also make sure to set the runtime type (e.g., CPU or GPU) before running the notebook.

In [None]:
# # If you're running on a shared cluster and want to limit the resources you take up:
# import os
# os.environ["OPENBLAS_NUM_THREADS"] = '1'
# os.environ["MKL_NUM_THREADS"] = '1'
# os.environ["VECLIB_MAXIMUM_THREADS"] = '1'
# os.environ["NUMEXPR_NUM_THREADS"] = '1'
# os.environ['OMP_NUM_THREADS'] = '1'
# os.environ['NPROC'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0' # you can change to a GPU ID not in use

In [None]:
# !pip install numpy matplotlib corner h5ify
# !pip install wcosmo jax_tqdm equinox optax flowjax

In [None]:
# # If you're running on CPU:
# !pip install jax numpyro

# # If you're running on GPU
# !pip install -U 'jax[cuda12]'
# !pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
# # Download software injections from LVK O3:
# !mkdir -p data
# !wget https://github.com/mdmould/ml-gw-pop/raw/refs/heads/main/data/vt.h5 -P data

## Neural posterior estimation for gravitational-wave population inference

In this notebook, we'll train a normalizing flow to learn the Bayesian posterior for an astrophysical population model from gravitational-wave catalogues using simulation-based inference. This will be a toy pedagogical example, but hopefully you can use it as a base to build a more realistic inference pipeline. We'll focus on binary black-hole mergers only. We'll use [JAX](https://github.com/jax-ml/jax) as the main workhorse behind this notebook.

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)

In [None]:
# check for GPU devices
jax.devices()

#### Population model

First, let's define the population model that we'll use to model the astrophysical distribution of sources. We'll only include source-frame primary masses, binary mass ratio, and merger redshift in our model.

- The primary masses will follow a simplified version of the [Power Law + Peak](https://arxiv.org/abs/1801.02699) model. Its parameters are:
  - alpha = power law slope,
  - mu = location of the peak,
  - sigma = width of the peak,
  - f = fraction of sources in the peak,
  - mmin = minimum primary mass,
  - mmax = maximum primary mass (mmin and mmax are fixed to 2 and 100 solar masses in the code below, so they are not inferred).
- The mass ratios will follow a simple power law. Parameters:
  - beta = power law slope.
- We'll assume that the merger rate per unit comoving volume and source-frame time evolves as a [power law in redshift](https://arxiv.org/abs/1805.10270). Parameters:
  - gamma = power law slope.
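For reference, these descriptions correspond to the densities below, with the bounds fixed in the code ($m_{\min}=2$, $m_{\max}=100$, $q_{\min}=0.01$, $z_{\max}=2$), $\mathcal{N}$ the normal density and $\Phi$ the standard-normal CDF (power laws written for $\alpha, \beta \neq -1$). The redshift density follows the description above: a rate $\propto (1+z)^\gamma$ per unit comoving volume and source-frame time, where the extra $1/(1+z)$ converts source-frame to detector-frame time.

$$
p(m_1|\Lambda) = (1-f)\,\frac{(\alpha+1)\,m_1^{\alpha}}{m_{\max}^{\alpha+1} - m_{\min}^{\alpha+1}}
+ f\,\frac{\mathcal{N}(m_1|\mu,\sigma)}{\Phi\big(\tfrac{m_{\max}-\mu}{\sigma}\big) - \Phi\big(\tfrac{m_{\min}-\mu}{\sigma}\big)} ,
\qquad m_{\min} \le m_1 \le m_{\max} ,
$$
$$
p(q|\Lambda) = \frac{(\beta+1)\,q^{\beta}}{1 - q_{\min}^{\beta+1}} , \qquad q_{\min} \le q \le 1 ,
$$
$$
p(z|\Lambda) \propto \frac{\mathrm{d}V_c}{\mathrm{d}z}\,\frac{(1+z)^{\gamma}}{1+z} , \qquad 0 < z \le z_{\max} .
$$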

We'll use [wcosmo](https://github.com/ColmTalbot/wcosmo), which is a nice package for cosmological calculations in JAX.

In [None]:
import wcosmo
wcosmo.disable_units()

In [None]:
def truncated_powerlaw(x, alpha, xmin, xmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = x**alpha
    norm = (xmax**(alpha + 1) - xmin**(alpha+1)) / (alpha + 1)
    return cut * shape / norm

def truncated_normal(x, mu, sigma, xmin, xmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = jax.scipy.stats.norm.pdf(x, mu, sigma)
    norm = (
        - jax.scipy.stats.norm.cdf(xmin, mu, sigma)
        + jax.scipy.stats.norm.cdf(xmax, mu, sigma)
    )
    return cut * shape / norm

In [None]:
def pdf_mass_1_source(x, parameters):
    mmin, mmax = 2, 100
    pl = truncated_powerlaw(x, parameters['alpha'], mmin, mmax)
    tn = truncated_normal(x, parameters['mu'], parameters['sigma'], mmin, mmax)
    return (1 - parameters['f']) * pl + parameters['f'] * tn

def pdf_mass_ratio(x, parameters):
    q_min, q_max = 0.01, 1
    return truncated_powerlaw(x, parameters['beta'], q_min, q_max)

def shape_redshift(x, parameters):
    return (1 + x)**parameters['gamma']

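def pdf_redshift(x, parameters):
    zmax = 2
    # dN/dz ∝ dVc/dz * (1+z)^gamma / (1+z): the rate evolves per unit comoving
    # volume and source-frame time, and 1/(1+z) converts to detector-frame time.
    # The constants 4π and 1e9 only rescale the volume element and cancel in the normalization.
    fn = lambda z: (
        shape_redshift(z, parameters) / (1 + z)
        * wcosmo.Planck15.differential_comoving_volume(z) * 4 * jnp.pi / 1e9
    )
    cut = (0 < x) * (x <= zmax)
    shape = fn(x)
    # normalize numerically on a fine grid over [0, zmax]
    zz = jnp.linspace(0, zmax, 10_000)
    norm = jnp.trapezoid(fn(zz), zz)
    return cut * shape / norm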

In [None]:
def density(data, parameters):
    return (
        pdf_mass_1_source(data['mass_1_source'], parameters)
        * pdf_mass_ratio(data['mass_ratio'], parameters)
        * pdf_redshift(data['redshift'], parameters)
    )

In [None]:
# the number of source parameters we include
dim_event = 3

Let's plot what the population models look like for some parameter values.

In [None]:
import matplotlib.pyplot as plt

In [None]:
parameters = dict(
    alpha = -3,
    mu = 35,
    sigma = 3,
    f = 0.05,
    beta = 1,
    gamma = 0,
)

In [None]:
m = jnp.linspace(2, 100, 1_000)
p = pdf_mass_1_source(m, parameters)

plt.plot(m, p)
plt.semilogy()
plt.xlabel('primary mass')
plt.ylabel('PDF')

print(jnp.trapezoid(p, m))

In [None]:
q = jnp.linspace(0, 1, 1_000)
p = pdf_mass_ratio(q, parameters)

plt.plot(q, p)
plt.xlabel('mass ratio')
plt.ylabel('PDF')

print(jnp.trapezoid(p, q))

In [None]:
z = jnp.linspace(0, 2, 1_000)
p = pdf_redshift(z, parameters)

plt.plot(z, p)
plt.xlabel('redshift')
plt.ylabel('PDF')

print(jnp.trapezoid(p, z))

Next, we'll set priors on the parameters of the population model, which are the parameters we want to measure from data. The priors are the distributions we draw from when generating training data, so they are also the priors of the Bayesian posterior that the flow learns.

In [None]:
import numpyro

In [None]:
priors = dict(
    alpha = numpyro.distributions.Uniform(-5, 0),
    mu = numpyro.distributions.Uniform(20, 50),
    sigma = numpyro.distributions.Uniform(1, 10),
    f = numpyro.distributions.Uniform(0, 0.2),
    beta = numpyro.distributions.Uniform(0, 5),
    gamma = numpyro.distributions.Uniform(-5, 5),
)

In [None]:
# the number of population parameters
dim_pop = len(priors)
dim_pop

In [None]:
# This function makes a single draw from these priors.
def sample_parameters(key):
    keys = jax.random.split(key, len(priors))
    # one independent subkey per population parameter
    return {k: priors[k].sample(subkey) for k, subkey in zip(priors, keys)}

In [None]:
sample_parameters(jax.random.key(0))

#### Training data

Now we need some fake observations to train the model with. The way we generate data should follow our model for how the observed catalogue is actually produced. This proceeds as follows:

1. Fix a realization of the universe with population parameters $\Lambda$ (the parameters above).
2. Draw a source with parameters $\theta$ (primary mass, mass ratio, redshift) from the population model.
3. Generate a gravitational-wave signal $s=h(\theta)$ using a waveform model $h$ and add it to detector noise $n$ to produce data $d=n+s$.
4. Decide whether the signal in data $d$ is detected ("det") or not.
5. Repeat 2-4 for many observations over an observing run.
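To make steps 1-5 concrete, here is a deliberately crude toy simulator (not what this notebook uses). Everything in it is a hypothetical stand-in: a single "mass" parameter drawn from a truncated power law, the identity as the "waveform", Gaussian noise, and a fixed threshold on the data as the detection criterion.

```python
import numpy as np

def toy_catalogue(rng, alpha, num_obs, noise_sigma=1.0, threshold=5.0):
    """Toy forward model: simulate sources until num_obs of them are detected."""
    mmin, mmax = 2.0, 100.0  # support of the toy mass power law (requires alpha != -1)
    a1 = alpha + 1.0
    detected = []
    # step 5: keep observing until the catalogue is full
    while len(detected) < num_obs:
        # step 2: theta ~ p(theta | Lambda) via the inverse CDF of a truncated power law
        u = rng.uniform()
        theta = (mmin**a1 + u * (mmax**a1 - mmin**a1)) ** (1.0 / a1)
        # step 3: toy "waveform" h(theta) = theta plus Gaussian noise
        d = theta + rng.normal(0.0, noise_sigma)
        # step 4: toy detection criterion that depends on the data only
        if d > threshold:
            detected.append(d)
    return np.array(detected)

# step 1: fix the population parameters (here just alpha)
rng = np.random.default_rng(0)
catalogue = toy_catalogue(rng, alpha=-2.0, num_obs=100)
print(catalogue.shape, catalogue.min())
```

Because detection depends on the data, this toy catalogue is biased toward heavier sources relative to the population it was drawn from, which is the kind of selection effect the flow will have to account for.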

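To infer the population parameters $\Lambda$ from data $d$, we need to invert this forward model. By Bayes' theorem, the posterior on $\Lambda$ follows from the joint distribution that the forward model defines (written here for a single event)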
$$
p(\mathrm{det},d,\theta,\Lambda) = p(\mathrm{det}|d) p(d|\theta) p(\theta|\Lambda) p(\Lambda) .
$$
This is the probabilistic model that we'll try to get a normalizing flow to learn.

Here's where we'll cheat a bit by reusing the public binary black-hole software injections provided by the LVK for O3. Our "observations" will be the values of the source parameters (primary mass, mass ratio, and redshift) of injections that were classed as detected. Of course, this is not a realistic setting, since we do not observe the source parameters directly, but it will serve our purpose here.

The injections are pre-prepared in the file below (the scripts in the `data/` folder were used to download and prepare the data). The file contains the values of the source parameters as well as the probability density of each injection under the distribution from which it was drawn (the injection "prior"), which we'll call $q(\theta)$ and which is defined [here](https://zenodo.org/records/7890437).

In [None]:
import h5ify

In [None]:
injections = h5ify.load('data/vt.h5')
injections = {
    k: jnp.array(injections[k], dtype = jnp.float64).squeeze()
    for k in ('mass_1_source', 'mass_ratio', 'redshift', 'prior')
}

injections

However, in the generative model above, we want to draw sources from $p(\theta|\Lambda)$, not $q(\theta)$. So we'll play another trick:
$$
p(\mathrm{det},d,\theta,\Lambda) = p(\mathrm{det}|d) p(d|\theta) \frac{p(\theta|\Lambda)}{q(\theta)} q(\theta) p(\Lambda) .
$$
In words: every source in the file above was detected by construction, so resampling these detected injections with probabilities $\propto p(\theta|\Lambda) / q(\theta)$ for given population parameters $\Lambda$ yields detected sources as if they had been drawn from $p(\theta|\Lambda)$ rather than $q(\theta)$.
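Here is a one-dimensional sketch of the same resampling trick in NumPy, using made-up distributions (a uniform "injection" distribution $q$ and a power-law target $p$) rather than the LVK injections, and without any selection step:

```python
import numpy as np

rng = np.random.default_rng(42)

# Proposal q(theta): draws we already have, here Uniform(0, 1) with density 1.
theta = rng.uniform(0.0, 1.0, size=200_000)
q_density = np.ones_like(theta)

# Target p(theta | Lambda): a power law (beta + 1) * theta**beta on [0, 1].
beta = 2.0
p_density = (beta + 1) * theta**beta

# Resample the existing draws (with replacement) with probabilities proportional to p / q.
weights = p_density / q_density
idx = rng.choice(theta.size, size=50_000, p=weights / weights.sum())
resampled = theta[idx]

# The target mean is (beta + 1) / (beta + 2) = 0.75, so this should be close to 0.75.
print(resampled.mean())
```

Only values already present in `theta` can be returned, so when the weights concentrate on a few draws the resampled set contains many repeats.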

This requires $\mathrm{supp}\,p(\theta|\Lambda) \subset \mathrm{supp}\,q(\theta)$, and I chose the priors on $\Lambda$ above to make sure this holds. However, using these injections means we'll have a finite amount of training data, and when $p(\theta|\Lambda)$ is very different from $q(\theta)$, the same few injections may be drawn many times.
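A common diagnostic for this reuse problem is the Kish effective sample size of the importance weights, $N_\mathrm{eff} = (\sum_i w_i)^2 / \sum_i w_i^2$, which ranges from 1 (all weight on a single injection) to the number of injections (equal weights). A minimal NumPy version (the function name is our own):

```python
import numpy as np

def effective_sample_size(weights):
    """Kish effective sample size (sum w)^2 / sum w^2 of non-negative importance weights."""
    weights = np.asarray(weights, dtype=float)
    return weights.sum() ** 2 / (weights**2).sum()

print(effective_sample_size(np.ones(10)))           # equal weights: all 10 draws count
print(effective_sample_size([0.0, 0.0, 5.0, 0.0]))  # one draw carries all the weight
```

In this notebook, one could evaluate it on `density(injections, parameters) / injections['prior']` for a given set of population parameters; values not much larger than the catalogue size suggest that simulated catalogues will contain many repeated injections.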

Below we define a function to draw a catalogue of detections.

In [None]:
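def sample_detections(key, num_obs, parameters):
    # importance weights p(theta|Lambda) / q(theta) for every detected injection
    weights = density(injections, parameters) / injections['prior']
    # resample injection indices (with replacement) in proportion to the weights
    idxs = jax.random.choice(
        key, weights.size, shape = (num_obs,), p = weights / weights.sum(),
    )
    return {k: injections[k][idxs] for k in injections}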

We can see how selection effects alter the underlying distribution.

In [None]:
parameters = dict(
    alpha = -3,
    mu = 35,
    sigma = 3,
    f = 0.05,
    beta = 1,
    gamma = 0,
)

detections = sample_detections(jax.random.key(1), 1_000, parameters)
plt.hist(
    detections['mass_1_source'], density = True, bins = 30,
    label = 'detections',
)

m = jnp.linspace(2, 100, 1_000)
p = pdf_mass_1_source(m, parameters)
plt.plot(m, p, label = 'underlying population')

plt.legend()
plt.semilogy()
plt.xlabel('primary mass')
plt.ylabel('PDF');

#### Normalizing flow

Now let's set up the model that we'll train. We'll use a [block neural autoregressive flow](https://arxiv.org/abs/1904.04676) to approximate the population posterior. There's a nice library called [flowjax](https://github.com/danielward27/flowjax) to do normalizing flows in JAX that we'll use.

In [None]:
from flowjax.distributions import StandardNormal
from flowjax.flows import block_neural_autoregressive_flow

If you aren't familiar with normalizing flows, the (very brief) idea is that you can construct a probability distribution by transforming a simple known distribution (such as a standard normal distribution) with an invertible and differentiable function using the change-of-variables formula. For normalizing flows, that function is parametrized by a neural network, which is what makes the transformation flexible.
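In one dimension, the change-of-variables formula reads: if $y = f(x)$ with $x \sim \mathcal{N}(0, 1)$ and $f$ invertible, then $\log p_Y(y) = \log p_X(f^{-1}(y)) + \log |\mathrm{d}f^{-1}/\mathrm{d}y|$. The minimal JAX check below uses a fixed affine map $f(x) = a x + b$ (an illustrative choice), whose pushforward is known to be $\mathcal{N}(b, a^2)$; in an actual flow, $f$ is a much more flexible bijection parametrized by a neural network.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# A one-layer "flow": y = f(x) = a * x + b applied to x ~ N(0, 1).
a, b = 2.5, -1.0

def flow_log_prob(y):
    # change of variables: log p_Y(y) = log p_X(f^{-1}(y)) + log|d f^{-1} / dy|
    x = (y - b) / a
    return norm.logpdf(x) - jnp.log(jnp.abs(a))

y = jnp.linspace(-5.0, 5.0, 11)
# For an affine map the pushforward is exactly N(b, a^2), so the two should agree.
print(jnp.allclose(flow_log_prob(y), norm.logpdf(y, loc=b, scale=a)))
```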

The transformation is trained so that the output distribution best matches some target distribution - in our case the posterior distribution of the parameters of the population model. The variables being transformed are the population parameters $\Lambda$.

Importantly, normalizing flows can model conditional probability distributions by making the transformation depend on some additional inputs - for posterior distributions, that is the observational data. In our case, we would like to condition on a catalogue of observations. We'll fix the number of observations in the catalogue.

In [None]:
# the fixed number of observations in the catalogue
num_obs = 50

For complicated data, it is typical to produce an embedded representation of it to more efficiently extract the information it contains, e.g., with another neural network; here, we'll keep things simple and just stack "observations" together.
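One property of stacking worth seeing explicitly: the conditioning vector depends on the (arbitrary) order of events and on a fixed catalogue size. The NumPy sketch below contrasts it with a hypothetical permutation-invariant alternative that averages per-event features; the feature map here is a hand-picked toy, whereas an embedding network would learn it.

```python
import numpy as np

rng = np.random.default_rng(0)
num_obs, dim_event = 50, 3
catalogue = rng.normal(size=(num_obs, dim_event))  # stand-in for transformed events

# Stacking (as in this notebook): a fixed-length vector whose entries depend on event order.
stacked = catalogue.ravel()  # shape (num_obs * dim_event,)

def mean_pooled_summary(events):
    """Hypothetical permutation-invariant summary: average a per-event feature map."""
    features = np.concatenate([events, events**2], axis=1)  # toy per-event features
    return features.mean(axis=0)  # shape (2 * dim_event,), for any number of events

shuffled = catalogue[rng.permutation(num_obs)]
print(np.allclose(stacked, shuffled.ravel()))                                     # False
print(np.allclose(mean_pooled_summary(catalogue), mean_pooled_summary(shuffled)))  # True
```

With stacking, the flow has to learn from the training data that event order carries no information, since the simulated detections are drawn independently.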

In [None]:
flow_init = block_neural_autoregressive_flow(
    key = jax.random.key(2),
    base_dist = StandardNormal(shape = (dim_pop,)),
    cond_dim = num_obs * dim_event,
)

We should make sure that our normalizing flow is defined on the right parameter domain. In particular, our priors on $\Lambda$ bound the values each parameter can take, so we'll add fixed transformations that map the flow's unconstrained variables onto those bounded ranges.

In [None]:
# transformations for population parameters

lo_pop = jnp.array([priors[k].low for k in priors])
hi_pop = jnp.array([priors[k].high for k in priors])

# map from unconstrained space to constrained space
def forward_parameters(x):
    y = jax.scipy.stats.norm.cdf(x) * (hi_pop - lo_pop) + lo_pop
    parameters = dict(zip(priors, y))
    return parameters

def inverse_parameters(parameters):
    x = jnp.array([parameters[k] for k in priors])
    x = jax.scipy.stats.norm.ppf((x - lo_pop) / (hi_pop - lo_pop))
    return x
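Why the standard-normal CDF? If $\Lambda_k$ has a uniform prior on $(\mathrm{lo}_k, \mathrm{hi}_k)$, then $\Phi^{-1}\big((\Lambda_k - \mathrm{lo}_k)/(\mathrm{hi}_k - \mathrm{lo}_k)\big)$ is exactly standard normal, i.e., distributed like the flow's base distribution. The standalone check below re-implements the two maps with illustrative bounds (the names `to_constrained` and `to_unconstrained` are ours, not the notebook's) and confirms the round trip and the standard-normal claim numerically.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Illustrative bounds for two parameters (cf. the alpha and mu priors above).
lo = jnp.array([-5.0, 20.0])
hi = jnp.array([0.0, 50.0])

def to_constrained(x):
    # unconstrained R -> (lo, hi): standard-normal CDF, then rescale
    return norm.cdf(x) * (hi - lo) + lo

def to_unconstrained(y):
    # (lo, hi) -> R: rescale to (0, 1), then apply the standard-normal quantile function
    return norm.ppf((y - lo) / (hi - lo))

# Draw from the uniform prior, kept slightly away from the edges to avoid +-inf.
u = jax.random.uniform(jax.random.key(0), (100_000, 2), minval=1e-6, maxval=1 - 1e-6)
y = lo + u * (hi - lo)
x = to_unconstrained(y)

print(x.mean(axis=0), x.std(axis=0))         # should be close to 0 and 1
print(jnp.abs(to_constrained(x) - y).max())  # round-trip error, should be tiny
```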

Though not strictly necessary, we'll also do this for the observations so that all inputs to the flow have roughly the same range of numerical values.

In [None]:
# transformations for source parameters

# primary mass, mass ratio, redshift
lo_event = jnp.array([2, 0.01, 0])
hi_event = jnp.array([100, 1, 2])

x = jnp.array([
    injections['mass_1_source'],
    injections['mass_ratio'],
    injections['redshift'],
]).T
x = (x - lo_event) / (hi_event - lo_event)
x = jax.scipy.stats.norm.ppf(x)
loc = x.mean(axis = 0)
scale = x.std(axis = 0)
x = (x - loc) / scale

# we only need the inverse to embed the data into the flow
def inverse_detections(detections):
    x = jnp.array(
        [detections[k] for k in ('mass_1_source', 'mass_ratio', 'redshift')]
    ).T # shape = (num_obs, dim_event)
    x = (x - lo_event) / (hi_event - lo_event)
    x = jax.scipy.stats.norm.ppf(x)
    x = (x - loc) / scale # shape = (num_obs, dim_event)
    x = x.ravel() # shape = (num_obs * dim_event,)
    return x

#### Training

To train the flow, we need to define a loss function to minimize with respect to the neural-network parameters. The training objective for flows can be thought of in several equivalent ways, including the Kullback-Leibler divergence, the cross entropy, and the likelihood of the training data.

The upshot is that the probability density predicted by the flow can be matched to the joint distribution we defined in [Training data](#Training-data) using only samples from it: we never need to evaluate the actual posterior distribution, which is the key insight behind neural posterior estimation.
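Written out, with $p_\phi(\Lambda|d)$ the flow with neural-network parameters $\phi$ and $d$ now standing for a whole catalogue of detections, the training loss is the expected negative log-probability under the joint distribution, which splits into an expected Kullback-Leibler divergence plus a term that does not depend on $\phi$:

$$
\mathcal{L}(\phi)
= -\,\mathbb{E}_{p(\Lambda)\,p(d|\Lambda)}\big[\log p_\phi(\Lambda|d)\big]
= \mathbb{E}_{p(d)}\Big[D_{\mathrm{KL}}\big(p(\Lambda|d)\,\big\|\,p_\phi(\Lambda|d)\big)\Big]
+ \mathbb{E}_{p(d)}\Big[\mathrm{H}\big[p(\Lambda|d)\big]\Big] .
$$

Minimizing $\mathcal{L}$ therefore minimizes the expected divergence from the true posterior, while estimating $\mathcal{L}$ only requires joint samples $(\Lambda, d)$ from the simulator, averaged over a batch. Since the flow in the code models the transformed parameters, the only difference is an extra Jacobian term that is also independent of $\phi$.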

Below, we define the loss function and training loop to do this. Note that selection effects are automatically included in the generative model, because we only ever simulate detected events, and hence in the flow posterior; this is one of the reasons that simulation-based inference is nice in this context.

First, let's set some training parameters. We'll use [optax](https://github.com/google-deepmind/optax) to update the neural network.

In [None]:
import optax

In [None]:
batch_size = 200
steps = 10_000
learning_rate = 1e-2
optimizer = optax.adam(learning_rate)

flowjax is built on top of [equinox](https://github.com/patrick-kidger/equinox), which (among many other cool things it does) manages the neural network parameters.

In [None]:
import equinox

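Equinox does this by splitting pytrees (nested Python containers whose leaves can be JAX arrays or other objects) into trainable and non-trainable subsets with `equinox.partition`; `equinox.combine` merges them back together.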

In [None]:
# here, trainable parameters are any floating-point arrays:
params_init, static = equinox.partition(flow_init, equinox.is_inexact_array)

In [None]:
# function to sample a single training example:
# - a single set of population parameters
# - a catalogue of num_obs detections
def sample(key):
    key, subkey = jax.random.split(key)
    parameters = sample_parameters(subkey)
    x = inverse_parameters(parameters)
    key, subkey = jax.random.split(key)
    detections = sample_detections(subkey, num_obs, parameters)
    c = inverse_detections(detections)
    return x, c

In [None]:
# the loss function
# takes the mean over a batch of batch_size training examples
def loss_fn(params, key):
    keys = jax.random.split(key, batch_size)
    xs, cs = jax.vmap(sample)(keys)
    flow = equinox.combine(params, static) # rebuild the flow from partitions
    return -flow.log_prob(xs, cs).mean()

In [None]:
# function to update the neural-network parameters
def update(carry, step):
    key, params, state = carry
    key, subkey = jax.random.split(key)
    loss, grad = equinox.filter_value_and_grad(loss_fn)(params, subkey)
    updates, state = optimizer.update(grad, state, params)
    params = equinox.apply_updates(params, updates)
    carry = key, params, state
    return carry, loss

Another handy package is [jax_tqdm](https://github.com/jeremiecoullon/jax-tqdm), which adds a progress bar to JAX loops.

In [None]:
import jax_tqdm

In [None]:
update = jax_tqdm.scan_tqdm(steps, print_rate = 100, tqdm_type = 'std')(update)

In [None]:
# finally, the training loop (written without Python loops using JAX)

state = optimizer.init(params_init)
carry = jax.random.key(3), params_init, state

carry, losses = jax.lax.scan(update, carry, jnp.arange(steps))
key, params, state = carry

flow = equinox.combine(params, static)

In [None]:
# plot the loss function values over training steps
plt.plot(losses);

#### Inference

Now that the model is trained, let's infer the population posterior on a mock catalogue.

In [None]:
# construct a mock catalogue

parameters = dict(
    alpha = -3,
    mu = 35,
    sigma = 3,
    f = 0.05,
    beta = 1,
    gamma = 0,
)

key = jax.random.key(np.random.randint(10**9)) # a different random catalogue on each run
detections = sample_detections(key, num_obs, parameters)

In [None]:
# draw posterior samples from the flow

c = inverse_detections(detections)
x = flow.sample(jax.random.key(4), (10_000,), condition = c)
posterior = jax.vmap(forward_parameters)(x)

posterior.keys()

In [None]:
from corner import corner

In [None]:
corner(
    np.transpose([posterior[k] for k in priors]),
    labels = list(priors),
    truths = [parameters[k] for k in priors],
    levels = (0.5, 0.9, 0.99),
);

Let's also plot the inferred population-level distributions of source parameters and their posterior uncertainties.

In [None]:
grid = dict(
    mass_1_source = jnp.linspace(2, 100, 1_000),
    mass_ratio = jnp.linspace(0, 1, 1_000),
    redshift = jnp.linspace(0, 2, 1_000),
)

pdf = dict(
    mass_1_source = pdf_mass_1_source,
    mass_ratio = pdf_mass_ratio,
    redshift = pdf_redshift,
)

def make_plot(k):
    plt.hist(
        detections[k], bins = 50, density = True, label = 'observed',
        color = 'C0',
    )

    ps = jax.vmap(lambda parameters: pdf[k](grid[k], parameters))(posterior)
    for qs, alpha in (
        ((0.005, 0.995), 0.2),
        ((0.05, 0.95), 0.4),
        ((0.25, 0.75), 0.6),
    ):
        label = f'{(qs[1]-qs[0]) * 100:.0f}% posterior'
        plt.fill_between(
            grid[k], *np.quantile(ps, qs, axis = 0), label = label,
            color = 'C1', alpha = alpha, lw = 0,
        )

    plt.plot(
        grid[k], pdf[k](grid[k], parameters), label = 'true astrophysical',
        c = 'C2',
    )

    plt.legend()
    plt.xlabel(k)
    plt.ylabel(f'p({k})')

In [None]:
make_plot('mass_1_source')
plt.semilogy()
plt.ylim(1e-6, 1e1);

In [None]:
make_plot('mass_ratio')

In [None]:
make_plot('redshift')

#### Homework

There is some immediate tinkering you can do with the code above:
- The training settings, e.g., batch size, number of training steps, learning rate, optimizer etc. Try a [learning-rate scheduler](https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html), for example.
- The [flow settings](https://danielward27.github.io/flowjax/api/flows.html#flowjax.flows.block_neural_autoregressive_flow), e.g., try making the network smaller or larger.
- The flow itself, i.e., try [a different type](https://danielward27.github.io/flowjax/api/flows.html) of normalizing flow.

Try targeting a different gravitational-wave population:
- Change the catalogue size.
- Add more parameters to the population model.
- Change the population model to something more sophisticated.
- Include additional source parameters, e.g., black-hole spins.

We should check our results:
- Test that the model is properly converged.
- Check that the model is not overfitting.
- Compare flow predictions at different points along training, e.g., return the flow with the best loss (at the moment we just take the flow at the last training step).
- How could we check that the flow-predicted posterior is correct?
- How could we use separate distributions for the training prior and inference prior?

Make everything more realistic:
- We don't directly observe the source parameters, so what observational data should we input instead?
- Can the model be applied to varying catalogue sizes?
- How can we apply this for large catalogues from future gravitational-wave detectors?