If you're running in a separate notebook (e.g., Google Colab), go through and un-comment the cells below as required. Also make sure to set the runtime before running the notebook.

In [None]:
# !pip install numpy matplotlib corner h5ify

In [None]:
# # If you're running on CPU:
# !pip install jax numpyro

# # If you're running on GPU
# !pip install -U 'jax[cuda12]'
# !pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
# !pip install wcosmo jax_tqdm equinox optax flowjax

In [None]:
# # Download software injections and parameter estimation from LVK O3:
# !mkdir -p data
# !wget https://github.com/mdmould/ml-gw-pop/raw/main/data/vt.h5 -P data
# !wget https://github.com/mdmould/ml-gw-pop/raw/main/data/pe.h5 -P data

In [None]:
# # If you're running on a shared cluster and want to limit the resources you take up:
# import os
# os.environ["OPENBLAS_NUM_THREADS"] = '1'
# os.environ["MKL_NUM_THREADS"] = '1'
# os.environ["VECLIB_MAXIMUM_THREADS"] = '1'
# os.environ["NUMEXPR_NUM_THREADS"] = '1'
# os.environ['OMP_NUM_THREADS'] = '1'
# os.environ['NPROC'] = '1'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

## Bonus: Hamiltonian Monte Carlo for gravitational-wave population inference

Hamiltonian Monte Carlo is a gradient-based stochastic sampling algorithm. It's particularly useful for sampling from high-dimensional posterior distributions. Here, we'll use [numpyro](https://num.pyro.ai/en/latest/mcmc.html) for gravitational-wave population inference.

Most of the code is copied over from the [variational-inference.ipynb](variational-inference.ipynb) notebook.

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
# Use 64-bit floats throughout; without this, JAX downcasts the jnp.float64
# arrays created below to float32 (with a warning).
jax.config.update('jax_enable_x64', True)

In [None]:
# List the devices JAX can see, to check whether a GPU is available.
jax.devices()

We will perform population inference on the catalogue of black-hole mergers from O3 with false-alarm rates below 1 per year. Below, we load in pre-prepared parameter estimation results for those events and a set of software injections that we can use to estimate selection effects (the scripts in the `data/` folder were used to download and prepare the data).

In [None]:
import h5ify

In [None]:
# Software injections used to estimate selection effects: every dataset is
# cast to a float64 JAX array with singleton dimensions squeezed out.
injections = {
    key: jnp.array(value, dtype=jnp.float64).squeeze()
    for key, value in h5ify.load('data/vt.h5').items()
}

In [None]:
# Per-event parameter-estimation samples, stacked per parameter into arrays of
# shape (num_events, num_samples) with events in sorted order. Stacking needs
# every event to carry the same number of samples. The explicit float64 dtype
# matches the injections above, so the likelihood does not depend on the
# dtype stored in the file.
posteriors = h5ify.load('data/pe.h5')
posteriors = {
    k: jnp.array(
        [posteriors[event][k] for event in sorted(posteriors)],
        dtype=jnp.float64,
    )
    for k in posteriors[next(iter(posteriors))]
}

#### Population model

First, let's define the population model that we'll use to model the astrophysical distribution of sources. We'll include source-frame primary masses, binary mass ratio, dimensionless spin magnitudes, spin-orbit misalignments, and redshift.

- The primary masses and mass ratios will follow my version of the [Power Law + Peak](https://arxiv.org/abs/1801.02699) model, which has tapering functions at low and high black-hole masses with analytical normalization.
- Spin magnitudes will be fit with a truncated normal distribution, independent and identical between primary and secondary black holes.
- Ditto for spin tilts.
- We'll assume that the merger rate evolves over comoving volume and source-frame time as a [power law in redshift](https://arxiv.org/abs/1805.10270).

The key thing to remember is that the likelihood function - and thus the population model - *must* be automatically differentiable. To be compatible with the framework here, it must be coded in JAX.

We'll also use [wcosmo](https://github.com/ColmTalbot/wcosmo), which is a nice package for cosmological calculations in JAX.

In [None]:
import wcosmo
# Presumably switches off wcosmo's unit-carrying (astropy) outputs so that its
# results are plain arrays that combine directly with jax.numpy.
wcosmo.disable_units()

In [None]:
# tapering functions

def cubic_filter(x):
    """Smoothstep taper: 0 for x < 0, 3x**2 - 2x**3 on [0, 1], 1 for x > 1."""
    inside = (0 <= x) * (x <= 1)
    above = 1 < x
    return x**2 * (3 - 2 * x) * inside + above

def highpass(x, xmin, dmin):
    """Rise smoothly from 0 at ``xmin`` to 1 at ``xmin + dmin``.

    A negative ``dmin`` mirrors the taper (this is how ``lowpass`` is built).
    """
    scaled = (x - xmin) / dmin
    return cubic_filter(scaled)

def lowpass(x, xmax, dmax):
    """Fall smoothly from 1 at ``xmax - dmax`` to 0 at ``xmax``.

    Expressed as a highpass anchored at ``xmax`` with negative width, the same
    (loc, -delta) convention that the analytic taper integrals use.
    """
    return highpass(x, xmin=xmax, dmin=-dmax)

def bandpass(x, xmin, xmax, dmin, dmax):
    """Product of a rising taper at ``xmin`` and a falling taper at ``xmax``."""
    rising = highpass(x, xmin, dmin)
    falling = lowpass(x, xmax, dmax)
    return rising * falling

In [None]:
# power law functions

def powerlaw(x, alpha, xmin, xmax):
    """Normalised power law ``x**alpha`` on [xmin, xmax]; zero outside.

    Singular at ``alpha == -1`` (division by ``alpha + 1``).
    """
    exponent = alpha + 1
    normalisation = (xmax**exponent - xmin**exponent) / exponent
    in_support = (xmin <= x) * (x <= xmax)
    return in_support * x**alpha / normalisation

def powerlaw_integral(x, alpha, loc, delta):
    """Closed-form antiderivative of ``x**alpha * (3u**2 - 2u**3)``, u = (x - loc) / delta.

    Inside a taper window this is the antiderivative of
    ``x**alpha * cubic_filter((x - loc) / delta)``. It is used with
    (loc, delta) = (xmin, dmin) for a rising taper and (xmax, -dmax) for a
    falling one, so a taper's contribution is ``F(upper) - F(lower)``. Outside
    the window the polynomial no longer equals cubic_filter, so only evaluate
    it at points within the window.

    Singular at alpha in {-1, -2, -3, -4} (divisions by 1 + a, ..., 4 + a).
    """
    a, c, d = alpha, loc, delta
    return (
        3 * (2 * c + (4 + a) * d)
        * (c**2 / (1 + a) - 2 * c * x / (2 + a) + x**2 / (3 + a))
        - 2 * (x - c)**3
    ) * x**(1 + a) / (4 + a) / d**3

def highpass_powerlaw_integral(x, alpha, xmin, xmax, dmin):
    """Integral of ``t**alpha * highpass(t, xmin, dmin)`` over t in [xmin, min(x, xmax)].

    The result is the sum of two pieces:
    - the tapered piece on [xmin, xmin + dmin], from the closed form in
      ``powerlaw_integral``;
    - the untapered power law from xmin + dmin up to the upper limit.

    Returns 0 when ``min(x, xmax) < xmin``. Singular at alpha in
    {-1, -2, -3, -4}. For ``x == xmax`` this is the normalisation used by
    ``highpass_powerlaw``.
    """
    # Clamp the upper limit so that x < xmax gives the partial integral up to
    # x. Without the clamp, the untapered piece always ran to xmax whenever
    # x >= xmin + dmin.
    upper = jnp.minimum(x, xmax)
    knee = xmin + dmin
    tapered = (
        - powerlaw_integral(xmin, alpha, xmin, dmin)
        + powerlaw_integral(jnp.minimum(knee, upper), alpha, xmin, dmin)
    ) * (xmin <= upper)
    untapered = (
        - knee**(alpha + 1) / (alpha + 1)
        + upper**(alpha + 1) / (alpha + 1)
    ) * (knee <= upper)
    return tapered + untapered

def highpass_powerlaw(x, alpha, xmin, xmax, dmin):
    """Normalised power law on [xmin, xmax] with a smooth rise above ``xmin``.

    The density is proportional to ``x**alpha * highpass(x, xmin, dmin)`` and
    is zero outside [xmin, xmax].
    """
    total = highpass_powerlaw_integral(xmax, alpha, xmin, xmax, dmin)
    unnormalised = x**alpha * highpass(x, xmin, dmin)
    support = (xmin <= x) * (x <= xmax)
    return support * unnormalised / total

def bandpass_powerlaw(x, alpha, xmin, xmax, dmin, dmax):
    """Normalised power law on [xmin, xmax], smoothly tapered at both ends.

    The normalisation is the analytic sum of three pieces:
    - the rising taper [xmin, xmin + dmin];
    - the untapered power law [xmin + dmin, xmax - dmax];
    - the falling taper [xmax - dmax, xmax].

    This assumes the two taper windows do not overlap.
    """
    p = alpha + 1
    knee_lo = xmin + dmin
    knee_hi = xmax - dmax
    norm = (
        - powerlaw_integral(xmin, alpha, xmin, dmin)
        + powerlaw_integral(knee_lo, alpha, xmin, dmin)
        - knee_lo**p / p
        + knee_hi**p / p
        - powerlaw_integral(knee_hi, alpha, xmax, -dmax)
        + powerlaw_integral(xmax, alpha, xmax, -dmax)
    )
    shape = x**alpha * bandpass(x, xmin, xmax, dmin, dmax)
    support = (xmin <= x) * (x <= xmax)
    return support * shape / norm

In [None]:
# Gaussian functions

def truncnorm(x, mu, sigma, xmin, xmax):
    """Normal(mu, sigma) density truncated to, and renormalised on, [xmin, xmax]."""
    gaussian = jax.scipy.stats.norm
    mass = gaussian.cdf(xmax, mu, sigma) - gaussian.cdf(xmin, mu, sigma)
    support = (xmin <= x) * (x <= xmax)
    return support * gaussian.pdf(x, mu, sigma) / mass

def normal_integral(x, mu, sigma, loc, delta):
    """Closed form used as the antiderivative of a tapered Gaussian.

    The integrand is ``norm.pdf(x, mu, sigma) * (3u**2 - 2u**3)`` with
    u = (x - loc) / delta, i.e. the Gaussian times ``cubic_filter`` inside a
    taper window. ``bandpass_normal`` uses it with (loc, delta) = (xmin, dmin)
    for the rising taper and (xmax, -dmax) for the falling taper, taking
    differences between the window endpoints.
    """
    m, s, c, d = mu, sigma, loc, delta
    return (
        jnp.exp(-(x - m)**2 / 2 / s ** 2) * (2 / jnp.pi)**0.5 * s * (
            6 * c * (c + d - m - x)
            - 3 * d * (m + x)
            + 2 * (m**2 + 2 * s**2 + m * x + x**2)
        )
        - jax.lax.erf((m - x) / s / 2**0.5) * (
            (2 * c + 3 * d - 2 * m) * (c - m)**2
            + 3 * s**2 * (2 * c + d - 2 * m)
        )
    ) / 2 / d**3

def bandpass_normal(x, mu, sigma, xmin, xmax, dmin, dmax):
    """Normal(mu, sigma) density on [xmin, xmax], smoothly tapered at both ends.

    The normalisation is analytic:
    - the rising taper [xmin, xmin + dmin];
    - the Gaussian mass between the two taper windows;
    - the falling taper [xmax - dmax, xmax].

    This assumes the windows do not overlap.
    """
    gaussian = jax.scipy.stats.norm
    knee_lo = xmin + dmin
    knee_hi = xmax - dmax
    norm = (
        - normal_integral(xmin, mu, sigma, xmin, dmin)
        + normal_integral(knee_lo, mu, sigma, xmin, dmin)
        - gaussian.cdf(knee_lo, mu, sigma)
        + gaussian.cdf(knee_hi, mu, sigma)
        - normal_integral(knee_hi, mu, sigma, xmax, -dmax)
        + normal_integral(xmax, mu, sigma, xmax, -dmax)
    )
    shape = gaussian.pdf(x, mu, sigma) * bandpass(x, xmin, xmax, dmin, dmax)
    support = (xmin <= x) * (x <= xmax)
    return support * shape / norm

In [None]:
# primary mass
def pdf_m(m, parameters):
    """Primary-mass density: a mixture of a tapered power law and a tapered Gaussian peak.

    Both components share the support [m_min, m_max] and the taper widths
    d_min / d_max. ``f_m`` is the fraction in the Gaussian peak.
    """
    taper = (
        parameters['m_min'],
        parameters['m_max'],
        parameters['d_min'],
        parameters['d_max'],
    )
    power_law = bandpass_powerlaw(m, parameters['alpha'], *taper)
    peak = bandpass_normal(m, parameters['mu_m'], parameters['sigma_m'], *taper)
    fraction = parameters['f_m']
    return (1 - fraction) * power_law + fraction * peak

# mass ratio - this is a bit of a handful, but otherwise, autodiff doesn't work
# let me know if you spot a better way to do it :')
def pdf_q_given_m(q, m, parameters):
    """Conditional density of the mass ratio ``q`` given the primary mass ``m``.

    The secondary mass ``q * m`` follows ``highpass_powerlaw`` with index
    ``beta`` on [m_min, m], tapered above m_min over a width d_min. The
    trailing ``* m`` is the Jacobian |d(q m)/dq| that turns this into a
    density in q.

    ``q`` and ``m`` must have the same number of elements: both are flattened
    and vmapped together, and the result is reshaped to ``q.shape``.

    Why the lax.cond: whenever q * m < m_min the density is zero. If also
    m < m_min, the normalisation inside highpass_powerlaw is 0, giving
    0 / 0 = NaN, so the cond returns 0.0 for those points instead.
    NOTE(review): under vmap the predicate is batched, so JAX evaluates both
    branches and selects. As far as I know, its batching rule also stops
    gradients into the branch not taken, which keeps that NaN out of the
    gradient; a plain jnp.where would not. Confirm for the JAX version in use.
    """
    # pdf defined in terms of secondary mass, then converted to mass ratio
    pdf = lambda q, m: highpass_powerlaw(
        q * m, parameters['beta'], parameters['m_min'], m, parameters['d_min'],
    ) * m
    # evaluated per sample so that, inside vmap, the predicate is a scalar
    single = lambda q, m: jax.lax.cond(
        parameters['m_min'] <= q * m, lambda: pdf(q, m), lambda: 0.0,
    )
    return jax.vmap(single)(q.ravel(), m.ravel()).reshape(q.shape)

# spin magnitude
def pdf_a(a, parameters):
    """Spin-magnitude density: Normal(mu_a, sigma_a) truncated to [0, 1]."""
    mean, width = parameters['mu_a'], parameters['sigma_a']
    return truncnorm(a, mean, width, xmin=0, xmax=1)

# spin tilt
def pdf_c(c, parameters):
    """Spin-tilt density in cos(tilt): Normal(mu_c, sigma_c) truncated to [-1, 1]."""
    mean, width = parameters['mu_c'], parameters['sigma_c']
    return truncnorm(c, mean, width, xmin=-1, xmax=1)

# redshift
def shape_z(z, parameters):
    """Unnormalised redshift-evolution factor ``(1 + z)**gamma``."""
    gamma = parameters['gamma']
    return (1 + z)**gamma

def pdf_z(z, parameters, zmax=2, num_grid=10_000):
    """Redshift density proportional to ``(1 + z)**gamma * dVc/dz`` on (0, zmax].

    Parameters
    ----------
    z : array-like
        Redshifts at which to evaluate the density.
    parameters : dict
        Population parameters; only ``parameters['gamma']`` is read.
    zmax : float, optional
        Upper redshift cut-off (default 2). The density is zero for
        ``z <= 0`` and ``z > zmax``.
    num_grid : int, optional
        Number of points in the trapezoid-rule grid on [0, zmax] used for the
        normalisation (default 10_000).

    Notes
    -----
    The shape has no explicit 1 / (1 + z) time-dilation factor. If the
    source-frame merger rate is written as (1 + z)**kappa, gamma here plays
    the role of kappa - 1. NOTE(review): confirm this is the intended
    parametrisation.
    """
    def unnormalised(zs):
        # The constant 4 * pi / 1e9 cancels between shape and norm; presumably
        # it rescales per-steradian Mpc^3 to full-sky Gpc^3 for nicer numbers.
        return (
            shape_z(zs, parameters)
            * wcosmo.Planck15.differential_comoving_volume(zs) * 4 * jnp.pi / 1e9
        )

    zz = jnp.linspace(0, zmax, num_grid)
    norm = jnp.trapezoid(unnormalised(zz), zz)
    cut = (0 < z) * (z <= zmax)
    return cut * unnormalised(z) / norm

In [None]:
# the combined probability density
def density(data, parameters):
    """Joint population density evaluated at the single-event parameters in ``data``.

    The density is a product of factors:
    - primary mass;
    - mass ratio conditional on primary mass;
    - two spin magnitudes and two spin tilts, each pair i.i.d.;
    - redshift.

    ``data`` must provide the keys 'mass_1_source', 'mass_ratio', 'a_1',
    'a_2', 'cos_tilt_1', 'cos_tilt_2' and 'redshift'.
    """
    m1 = data['mass_1_source']
    factors = (
        pdf_m(m1, parameters),
        pdf_q_given_m(data['mass_ratio'], m1, parameters),
        pdf_a(data['a_1'], parameters),
        pdf_a(data['a_2'], parameters),
        pdf_c(data['cos_tilt_1'], parameters),
        pdf_c(data['cos_tilt_2'], parameters),
        pdf_z(data['redshift'], parameters),
    )
    result = factors[0]
    for factor in factors[1:]:
        result = result * factor
    return result

#### Priors

Next, we'll set priors on the parameters of the population model - these are the parameters we want to measure from data.

In [None]:
import numpyro

In [None]:
priors = {
    # primary mass (pdf_m): power-law index, support edges, taper widths,
    # Gaussian-peak location/width and the fraction of sources in the peak
    'alpha': numpyro.distributions.Uniform(-10, 10),
    'm_min': numpyro.distributions.Uniform(2, 6),
    'm_max': numpyro.distributions.Uniform(70, 100),
    'd_min': numpyro.distributions.Uniform(0, 10),
    'd_max': numpyro.distributions.Uniform(0, 10),
    'mu_m': numpyro.distributions.Uniform(20, 50),
    'sigma_m': numpyro.distributions.Uniform(1, 10),
    'f_m': numpyro.distributions.Uniform(0, 1),
    # mass ratio (pdf_q_given_m): power-law index of the secondary mass
    'beta': numpyro.distributions.Uniform(-10, 10),
    # spin magnitudes (pdf_a): truncated-normal location and width
    'mu_a': numpyro.distributions.Uniform(0, 1),
    'sigma_a': numpyro.distributions.Uniform(0.1, 1),
    # cosine spin tilts (pdf_c): truncated-normal location and width
    'mu_c': numpyro.distributions.Uniform(-1, 1),
    'sigma_c': numpyro.distributions.Uniform(0.1, 4),
    # redshift evolution (shape_z): power-law index in (1 + z)
    'gamma': numpyro.distributions.Uniform(-10, 10),
}

#### Likelihood

How likely is it that our population model is responsible for the observed data?

Below we code up the gravitational-wave population likelihood; see, e.g.,

- https://arxiv.org/abs/1809.02063,
- https://arxiv.org/abs/2007.05579,
- https://arxiv.org/abs/2410.19145.

In particular, the likelihood function is approximated with several Monte Carlo integrals, which introduces additional statistical variance (https://arxiv.org/abs/1904.10879, https://arxiv.org/abs/2204.00461, https://arxiv.org/abs/2304.06138). We make sure to keep track of this variance below.

In [None]:
# mean and variance of the mean
def mean_and_variance(weights, n):
    """Monte Carlo mean over the last axis and the variance of that mean.

    ``n`` is the number of draws the estimator divides by. It may exceed
    ``weights.shape[-1]``, as when the injection total is used (presumably
    counting draws that were not recovered and so carry no weight here).
    """
    first_moment = jnp.sum(weights, axis=-1) / n
    second_moment = jnp.sum(weights**2, axis=-1) / n**2
    return first_moment, second_moment - first_moment**2 / n

# lazy ln(mean) and variance of ln(mean)
def ln_mean_and_variance(weights, n):
    """Log of the Monte Carlo mean and the delta-method variance of that log."""
    mean, var = mean_and_variance(weights, n)
    ln_mean = jnp.log(mean)
    return ln_mean, var / mean**2

In [None]:
def ln_likelihood_and_variance(posteriors, injections, density, parameters):
    """Monte Carlo estimate of the population log-likelihood and its variance.

    The log-likelihood is ``sum_i ln(mean_j w_ij) - N_obs * ln(p_det)``:
    - per-event weights are ``density / posteriors['prior']``, averaged over
      each event's samples;
    - p_det is estimated from ``density / injections['prior']``, averaged
      over ``injections['total']`` draws.

    The returned variance adds the per-event variances to
    ``N_obs**2 * var(ln p_det)``.
    """
    pe_weights = density(posteriors, parameters) / posteriors['prior']
    num_obs, num_pe = pe_weights.shape
    per_event_ln_lkl, per_event_var = ln_mean_and_variance(pe_weights, num_pe)

    vt_weights = density(injections, parameters) / injections['prior']
    ln_pdet, pdet_var = ln_mean_and_variance(vt_weights, injections['total'])

    ln_lkl = per_event_ln_lkl.sum() - ln_pdet * num_obs
    variance = per_event_var.sum() + pdet_var * num_obs**2
    # If NaNs show up during sampling, consider mapping them with
    # jnp.nan_to_num (ln_lkl -> -inf, variance -> inf) before returning.
    return ln_lkl, variance

#### Inference

Now we'll draw samples from the posterior distributions using Hamiltonian Monte Carlo in [numpyro](https://num.pyro.ai/en/latest/mcmc.html).

In [None]:
def sample_reparam(name, dist, **args):
    """Sample ``name`` as a standard-normal latent pushed through ``dist``'s inverse CDF.

    The latent site is named ``'_' + name``; ``name`` itself is recorded as a
    deterministic site. Extra keyword arguments go to ``numpyro.sample``.
    """
    standard_normal = numpyro.distributions.Normal()
    latent = numpyro.sample(f'_{name}', standard_normal, **args)
    uniform = standard_normal.cdf(latent)
    return numpyro.deterministic(name, dist.icdf(uniform))

def sample_priors(priors, reparam = False):
    """Draw every parameter in ``priors`` as a numpyro site and return them by name.

    With ``reparam`` set, each parameter goes through ``sample_reparam``
    instead of being sampled directly.
    """
    sampler = sample_reparam if reparam else numpyro.sample
    return {name: sampler(name, dist) for name, dist in priors.items()}

In [None]:
def numpyro_model(posteriors, injections, density, priors, reparam = False):
    """numpyro model: sample population parameters and weight them by the likelihood.

    The estimated log-likelihood enters through ``numpyro.factor``. It and
    its Monte Carlo variance are also recorded as deterministic sites for
    later diagnostics.
    """
    parameters = sample_priors(priors, reparam)
    ln_likelihood, variance = ln_likelihood_and_variance(
        posteriors, injections, density, parameters,
    )
    for site, value in (('ln_likelihood', ln_likelihood), ('variance', variance)):
        numpyro.deterministic(site, value)
    numpyro.factor('factor', ln_likelihood)

In [None]:
# NUTS over numpyro_model: 1,000 warm-up (adaptation) steps, then 1,000 kept
# samples. The model's data arguments are passed through MCMC.run.
nuts = numpyro.infer.NUTS(numpyro_model)
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=1_000,
    num_samples=1_000,
)
mcmc.run(
    jax.random.key(0),
    posteriors=posteriors,
    injections=injections,
    density=density,
    priors=priors,
)

Below are some summary statistics to check whether the MCMC chain converged.

In [None]:
# Summary statistics for every sampled and deterministic site. Only one chain
# was run, so treat the convergence diagnostics with some caution.
numpyro.diagnostics.print_summary(mcmc.get_samples(), group_by_chain = False)

Let's also check the Monte Carlo variance to see how trustworthy our estimate of the population likelihood is over the posterior samples we drew.

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Histogram of the Monte Carlo variance of the log-likelihood estimate at each
# posterior sample. The red line marks the commonly used threshold of 1, above
# which the estimate is usually considered unreliable.
plt.hist(mcmc.get_samples()['variance'], bins = 100);
plt.axvline(1, c = 'r');

Then let's look at the posterior distribution itself.

In [None]:
from corner import corner

In [None]:
# Fetch the samples once instead of once per parameter. To keep only samples
# whose log-likelihood variance is below 1, replace the full slice with
# `cut = samples['variance'] < 1`.
samples = mcmc.get_samples()
cut = slice(None)
posterior = {k: samples[k][cut] for k in priors}

In [None]:
# Stack the per-parameter sample arrays as columns: shape (num_samples, num_params).
corner(np.column_stack(list(posterior.values())), labels = list(priors));

And finally, the inferred population-level distributions of source parameters and their posterior uncertainties.

In [None]:
# Evaluation grids for plotting each population density. The ranges match the
# model supports: m_min >= 2, m_max <= 100, cos tilt in [-1, 1], and zmax = 2
# in pdf_z.
grid = {
    'mass_1_source': jnp.linspace(2, 100, 1_000),
    'mass_ratio': jnp.linspace(0, 1, 1_000),
    'a': jnp.linspace(0, 1, 1_000),
    'cos_tilt': jnp.linspace(-1, 1, 1_000),
    'redshift': jnp.linspace(0, 2, 1_000),
}

# the mass ratio model is conditional on the primary mass, so we have to marginalize
def pdf_q_marginal(q, parameters):
    """Mass-ratio density marginalised over primary mass on ``grid['mass_1_source']``.

    Evaluates p(q | m) p(m) on the (q, m) product grid and integrates over m
    with the trapezoid rule, giving one value per element of the 1-D ``q``.
    """
    masses = grid['mass_1_source']
    qq, mm = jnp.meshgrid(q, masses, indexing='ij')
    joint = pdf_q_given_m(qq, mm, parameters) * pdf_m(mm, parameters)
    return jnp.trapezoid(joint, mm, axis=1)

# Density function for each plotted quantity, keyed like `grid`. Each entry
# is called as pdf[k](grid[k], parameters).
pdf = {
    'mass_1_source': pdf_m,
    'mass_ratio': pdf_q_marginal,
    'a': pdf_a,
    'cos_tilt': pdf_c,
    'redshift': pdf_z,
}

In [None]:
def make_plot(k, data):
    """Plot the posterior on the population density ``pdf[k]`` over ``grid[k]``.

    ``data`` maps parameter names to arrays of posterior samples along the
    leading axis. The plot shows, on the current matplotlib axes:
    - the 99%, 90% and 50% credible bands;
    - the pointwise median;
    - the pointwise mean (the posterior predictive density).
    """
    x = grid[k]

    def evaluate(parameters):
        return pdf[k](x, parameters)

    # The mass-ratio marginal integrates over a (q, m1) grid for every sample,
    # so evaluate samples sequentially there to bound memory use.
    if k == 'mass_ratio':
        curves = jax.lax.map(evaluate, data)
    else:
        curves = jax.vmap(evaluate)(data)

    bands = (
        ((0.005, 0.995), 0.2),
        ((0.05, 0.95), 0.3),
        ((0.25, 0.75), 0.4),
    )
    for (lo, hi), shade in bands:
        lower, upper = np.quantile(curves, (lo, hi), axis=0)
        plt.fill_between(
            x, lower, upper,
            label=f'{(hi-lo) * 100:.0f}% posterior',
            color='C0', alpha=shade, lw=0,
        )

    plt.plot(
        x, np.median(curves, axis=0),
        label='median posterior', c='C1', lw=2,
    )
    plt.plot(
        x, np.mean(curves, axis=0),
        label='mean posterior (PPD)', c='C2', lw=2, ls='--',
    )

    plt.legend()
    plt.xlabel(k)
    plt.ylabel(f'p({k})')

In [None]:
# Plot every population density. Primary mass and mass ratio use a log y-axis
# with fixed limits so the tails are visible.
for k in 'mass_1_source', 'mass_ratio', 'a', 'cos_tilt', 'redshift':
    make_plot(k, posterior)

    if k == 'mass_1_source':
        plt.semilogy()
        plt.ylim(1e-5, 1e0)
    elif k == 'mass_ratio':
        plt.semilogy()
        plt.ylim(1e-2, 1e1)

    plt.show()