If you're running in a separate notebook (e.g., Google Colab), go through and un-comment the cells below as required. Also make sure to set the runtime before running the notebook.

In [None]:
# !pip install numpy matplotlib corner h5ify precession

In [None]:
# # If you're running on CPU:
# !pip install jax numpyro

# # If you're running on GPU
# !pip install -U 'jax[cuda12]'
# !pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
# !pip install wcosmo jax_tqdm equinox optax flowjax

In [None]:
# # Download software injections and parameter estimation from LVK O3:
# !mkdir -p data
# (use the raw/ URLs: the blob/ URLs return the GitHub HTML page, not the file)
# !wget https://github.com/mdmould/ml-gw-pop/raw/main/data/vt.h5 -P data
# !wget https://github.com/mdmould/ml-gw-pop/raw/main/data/pe.h5 -P data

In [None]:
# If you're running on a shared cluster and want to limit the resources you take up:
import os
os.environ["OPENBLAS_NUM_THREADS"] = '1'
os.environ["MKL_NUM_THREADS"] = '1'
os.environ["VECLIB_MAXIMUM_THREADS"] = '1'
os.environ["NUMEXPR_NUM_THREADS"] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['NPROC'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

## Simulation-based population model

In this notebook, we'll train a normalizing flow to learn an astrophysical population model using simulations and use it for population inference on a catalogue of gravitational-wave observations. We'll focus on just binary black-hole mergers. We'll use [JAX](https://github.com/jax-ml/jax) as the main workhorse behind this notebook.

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)

In [None]:
# check for GPU devices
jax.devices()

#### Data

We will perform population inference on the catalogue of black-hole mergers with false-alarm rates < 1/year from O3. Below, we load in pre-prepared parameter estimation results for those events and a set of software injections that we can use to estimate selection effects.

In [None]:
import h5ify

In [None]:
injections = h5ify.load('data/vt.h5')
injections = {
    k: jnp.array(injections[k], dtype = jnp.float64).squeeze()
    for k in injections
}

In [None]:
injections.keys()

In [None]:
posteriors = h5ify.load('data/pe.h5')
posteriors = {
    k: jnp.array([posteriors[event][k] for event in sorted(posteriors)])
    for k in posteriors[list(posteriors)[0]]
}

In [None]:
posteriors.keys()

#### Population model

First, let's define the population model that we'll use for the astrophysical distribution of sources.

As an example, we'll construct a simple simulated model of hierarchical/repeated/second-generation mergers, where some of the population of black holes are the products of previous mergers.

Some of the ideas follow our paper ["Deep learning and Bayesian inference of gravitational-wave populations: Hierarchical black-hole mergers" (arXiv:2203.03651)](https://arxiv.org/abs/2203.03651) and the code ["QLUSTER: quick clusters of merging binary black holes" (arXiv:2305.04987)](https://arxiv.org/abs/2305.04987).

Below are some utilities we'll use to build the population model.

In [None]:
# tapering functions

def cubic_filter(x):
    return (3 - 2 * x) * x**2 * (0 <= x) * (x <= 1) + (1 < x)

def highpass(x, xmin, dmin):
    return cubic_filter((x - xmin) / dmin)

def lowpass(x, xmax, dmax):
    return highpass(x, xmax, -dmax)

def bandpass(x, xmin, xmax, dmin, dmax):
    return highpass(x, xmin, dmin) * lowpass(x, xmax, dmax)

In [None]:
# power law functions

def truncated_powerlaw(x, alpha, xmin, xmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = x**alpha
    norm = (xmax**(alpha + 1) - xmin**(alpha + 1)) / (alpha + 1)
    return cut * shape / norm

def powerlaw_integral(x, alpha, loc, delta):
    a, c, d = alpha, loc, delta
    return (
        3 * (2 * c + (4 + a) * d)
        * (c**2 / (1 + a) - 2 * c * x / (2 + a) + x**2 / (3 + a))
        - 2 * (x - c)**3
    ) * x**(1 + a) / (4 + a) / d**3

def bandpass_powerlaw(x, alpha, xmin, xmax, dmin, dmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = x**alpha * bandpass(x, xmin, xmax, dmin, dmax)
    norm = (
        - powerlaw_integral(xmin, alpha, xmin, dmin)
        + powerlaw_integral(xmin + dmin, alpha, xmin, dmin)
        - (xmin + dmin)**(alpha + 1) / (alpha + 1)
        + (xmax - dmax)**(alpha + 1) / (alpha + 1)
        - powerlaw_integral(xmax - dmax, alpha, xmax, -dmax)
        + powerlaw_integral(xmax, alpha, xmax, -dmax)
    )
    return cut * shape / norm

In [None]:
# Gaussian functions

def truncated_normal(x, mu, sigma, xmin, xmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = jax.scipy.stats.norm.pdf(x, mu, sigma)
    norm = (
        - jax.scipy.stats.norm.cdf(xmin, mu, sigma)
        + jax.scipy.stats.norm.cdf(xmax, mu, sigma)
    )
    return cut * shape / norm

def normal_integral(x, mu, sigma, loc, delta):
    m, s, c, d = mu, sigma, loc, delta
    return (
        jnp.exp(-(x - m)**2 / 2 / s ** 2) * (2 / jnp.pi)**0.5 * s * (
            6 * c * (c + d - m - x)
            - 3 * d * (m + x)
            + 2 * (m**2 + 2 * s**2 + m * x + x**2)
        )
        - jax.lax.erf((m - x) / s / 2**0.5) * (
            (2 * c + 3 * d - 2 * m) * (c - m)**2
            + 3 * s**2 * (2 * c + d - 2 * m)
        )
    ) / 2 / d**3

def bandpass_normal(x, mu, sigma, xmin, xmax, dmin, dmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = (
        jax.scipy.stats.norm.pdf(x, mu, sigma)
        * bandpass(x, xmin, xmax, dmin, dmax)
    )
    norm = (
        - normal_integral(xmin, mu, sigma, xmin, dmin)
        + normal_integral(xmin + dmin, mu, sigma, xmin, dmin)
        - jax.scipy.stats.norm.cdf(xmin + dmin, mu, sigma)
        + jax.scipy.stats.norm.cdf(xmax - dmax, mu, sigma)
        - normal_integral(xmax - dmax, mu, sigma, xmax, -dmax)
        + normal_integral(xmax, mu, sigma, xmax, -dmax)
    )
    return cut * shape / norm

We'll make the simplifying assumption that the merger rate over comoving volume and source-frame time has a shared redshift evolution for all sources. We'll use [wcosmo](https://github.com/ColmTalbot/wcosmo), which is a nice package for cosmological calculations in JAX.

In [None]:
import wcosmo
wcosmo.disable_units()

In [None]:
def shape_z(z, parameters):
    return (1 + z)**parameters['gamma']

def pdf_z(z, parameters):
    zmax = 2
    fn = lambda z: (
        shape_z(z, parameters)
        * wcosmo.Planck15.differential_comoving_volume(z) * 4 * jnp.pi / 1e9
    )
    cut = (0 < z) * (z <= zmax)
    shape = fn(z)
    zz = jnp.linspace(0, zmax, 10_000)
    norm = jnp.trapezoid(fn(zz), zz)
    return cut * shape / norm
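Written out, and assuming we are reading `pdf_z` correctly, the redshift density coded above is the following, where $V_\mathrm{c}$ is the comoving volume and $z_\mathrm{max} = 2$ (the constant factor $4\pi / 10^9$ cancels between numerator and denominator):

$$
p(z | \gamma) = \frac{ (1 + z)^\gamma \, \mathrm{d}V_\mathrm{c} / \mathrm{d}z }{ \int_0^{z_\mathrm{max}} \mathrm{d}z' \, (1 + z')^\gamma \, \mathrm{d}V_\mathrm{c} / \mathrm{d}z' } , \quad 0 < z \leq z_\mathrm{max} .
$$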

To build up a population of second-generation mergers, we'll begin by defining the distribution of first-generation mergers. We'll assume that:
- black-hole masses are distributed from a tapered power law with an additional Gaussian peak (similar to the [Power Law + Peak](https://arxiv.org/abs/1801.02699) model),
- binary mass ratios follow a tapered power law,
- black-hole spins are uniform in magnitude and isotropic in direction.

In [None]:
# primary mass
def pdf_m(m, parameters):
    pl = bandpass_powerlaw(
        m,
        parameters['alpha'],
        parameters['m_min'],
        parameters['m_max'],
        parameters['d_min'],
        parameters['d_max'],
    )
    tn = bandpass_normal(
        m,
        parameters['mu_m'],
        parameters['sigma_m'],
        parameters['m_min'],
        parameters['m_max'],
        parameters['d_min'],
        parameters['d_max'],
    )
    return (1 - parameters['f_m']) * pl + parameters['f_m'] * tn

# mass ratio - this is a bit of a handful, but otherwise, autodiff doesn't work
# let me know if you spot a better way to do it :')
def pdf_q_given_m(q, m, parameters):
    # pdf defined in terms of secondary mass, then converted to mass ratio
    pdf = lambda q, m: highpass_powerlaw(
        q * m, parameters['beta'], parameters['m_min'], m, parameters['d_min'],
    ) * m
    single = lambda q, m: jax.lax.cond(
        parameters['m_min'] <= q * m, lambda: pdf(q, m), lambda: 0.0,
    )
    return jax.vmap(single)(q.ravel(), m.ravel()).reshape(q.shape)
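Note that `pdf_q_given_m` calls `highpass_powerlaw`, which is not among the utilities defined earlier. Below is a minimal sketch of what it might look like, assuming it is the lower-taper-only analogue of `bandpass_powerlaw` with the signature `(x, alpha, xmin, xmax, dmin)` implied by the call. The handling of the case `xmax < xmin + dmin` (via `edge`) is our own assumption, and like `powerlaw_integral` it is singular for `alpha` in {-1, -2, -3, -4}.

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

# copied from the utilities above so that this sketch is self-contained
def cubic_filter(x):
    return (3 - 2 * x) * x**2 * (0 <= x) * (x <= 1) + (1 < x)

def powerlaw_integral(x, alpha, loc, delta):
    a, c, d = alpha, loc, delta
    return (
        3 * (2 * c + (4 + a) * d)
        * (c**2 / (1 + a) - 2 * c * x / (2 + a) + x**2 / (3 + a))
        - 2 * (x - c)**3
    ) * x**(1 + a) / (4 + a) / d**3

# hypothetical helper: power law tapered at the lower edge only
def highpass_powerlaw(x, alpha, xmin, xmax, dmin):
    cut = (xmin <= x) * (x <= xmax)
    shape = x**alpha * cubic_filter((x - xmin) / dmin)
    # the taper ends at xmin + dmin, unless the upper bound comes first
    edge = jnp.minimum(xmin + dmin, xmax)
    norm = (
        # tapered part, from xmin to edge
        - powerlaw_integral(xmin, alpha, xmin, dmin)
        + powerlaw_integral(edge, alpha, xmin, dmin)
        # untapered power law, from edge to xmax
        - edge**(alpha + 1) / (alpha + 1)
        + xmax**(alpha + 1) / (alpha + 1)
    )
    return cut * shape / norm

# numerical check of the normalization with a manual trapezoid rule
x = jnp.linspace(5.0, 40.0, 20_001)
p = highpass_powerlaw(x, -2.5, 5.0, 40.0, 5.0)
print(jnp.sum(0.5 * (p[1:] + p[:-1]) * jnp.diff(x)))  # should be close to 1
```

If the original helper treats the edge case differently, only the normalization for primary masses below `m_min + d_min` would change.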

Next, we need a function to output the properties of black-hole merger remnants from input pre-merger binary properties. Importantly, the remnant receives a recoil - or "kick" - that imparts a velocity, which can eject it from its host environment and thus prevent a second-generation merger.

We follow the fitting formulae to numerical-relativity simulations included in the [precession](https://arxiv.org/abs/1605.01067) Python [package](https://github.com/dgerosa/precession), but with a minimal re-implementation in JAX (I've removed the internal vectorization, which can be done automatically with `jax.vmap`, and fixed to the default spin and kick options).

In [None]:
def eval_eta(q):
    return q / (1 + q)**2

def eval_theta12(theta1, theta2, deltaphi):
    return jnp.arccos(
        jnp.sin(theta1) * jnp.sin(theta2) * jnp.cos(deltaphi)
        + jnp.cos(theta1) * jnp.cos(theta2)
    )

def angles_to_Lframe(theta1, theta2, deltaphi, r, q, chi1, chi2):
    L = r**0.5 * q / (1 + q)**2
    S1 = chi1 / (1 + q)**2
    S2 = chi2 * q**2 / (1 + q)**2

    Lx = 0
    Ly = 0
    Lz = L
    Lvec = jnp.array([Lx, Ly, Lz])

    S1x = S1 * jnp.sin(theta1)
    S1y = 0
    S1z = S1 * jnp.cos(theta1)
    S1vec = jnp.array([S1x, S1y, S1z])

    S2x = S2 * jnp.sin(theta2) * jnp.cos(deltaphi)
    S2y = S2 * jnp.sin(theta2) * jnp.sin(deltaphi)
    S2z = S2 * jnp.cos(theta2)
    S2vec = jnp.array([S2x, S2y, S2z])

    return Lvec, S1vec, S2vec
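To illustrate the point about vectorization: `angles_to_Lframe` builds its vectors from scalars, so it is written for a single binary, and `jax.vmap` maps it over a batch. The sketch below copies the function from above; the input values and the choice `in_axes = (0, 0, 0, None, 0, 0, 0)` (keeping the separation `r` fixed as a scalar) are just for illustration.

```python
import jax
import jax.numpy as jnp

# copied from the cell above: written for scalar inputs (one binary)
def angles_to_Lframe(theta1, theta2, deltaphi, r, q, chi1, chi2):
    L = r**0.5 * q / (1 + q)**2
    S1 = chi1 / (1 + q)**2
    S2 = chi2 * q**2 / (1 + q)**2
    Lvec = jnp.array([0, 0, L])
    S1vec = jnp.array([S1 * jnp.sin(theta1), 0, S1 * jnp.cos(theta1)])
    S2vec = jnp.array([
        S2 * jnp.sin(theta2) * jnp.cos(deltaphi),
        S2 * jnp.sin(theta2) * jnp.sin(deltaphi),
        S2 * jnp.cos(theta2),
    ])
    return Lvec, S1vec, S2vec

# a small batch of four binaries (arbitrary illustrative values)
theta1 = jnp.array([0.1, 0.5, 1.0, 2.0])
theta2 = jnp.array([0.2, 0.7, 1.5, 3.0])
deltaphi = jnp.array([0.0, 1.0, 2.0, 3.0])
q = jnp.array([0.2, 0.5, 0.8, 1.0])
chi1 = jnp.array([0.1, 0.3, 0.6, 0.9])
chi2 = jnp.array([0.9, 0.6, 0.3, 0.1])

# map over the leading axis of every argument except r, which is shared
batched = jax.vmap(angles_to_Lframe, in_axes = (0, 0, 0, None, 0, 0, 0))
Lvec, S1vec, S2vec = batched(theta1, theta2, deltaphi, 1.0, q, chi1, chi2)
print(Lvec.shape, S1vec.shape, S2vec.shape)  # one 3-vector per binary
```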

In [None]:
def remnantmass(theta1, theta2, q, chi1, chi2):
    eta = eval_eta(q)

    chit_par =  ( chi2*q**2 * jnp.cos(theta2) + chi1*jnp.cos(theta1) ) / (1+q)**2

    #Final mass. Barausse Morozova Rezzolla 2012
    p0 = 0.04827
    p1 = 0.01707
    Z1 = 1 + (1-chit_par**2)**(1/3)* ((1+chit_par)**(1/3)+(1-chit_par)**(1/3))
    Z2 = (3* chit_par**2 + Z1**2)**(1/2)
    risco = 3 + Z2 - jnp.sign(chit_par) * ((3-Z1)*(3+Z1+2*Z2))**(1/2)
    Eisco = (1-2/(3*risco))**(1/2)
    #Radiated energy, in units of the initial total mass of the binary
    Erad = eta*(1-Eisco) + 4* eta**2 * (4*p0+16*p1*chit_par*(chit_par+1)+Eisco-1)
    Mfin = 1- Erad # Final mass

    return Mfin

In [None]:
def remnantspin(theta1, theta2, deltaphi, q, chi1, chi2):
    eta = eval_eta(q)

    kfit = jnp.array( [[jnp.nan, 3.39221, 4.48865, -5.77101, -13.0459] ,
                      [35.1278, -72.9336, -86.0036, 93.7371, 200.975],
                      [-146.822, 387.184, 447.009, -467.383, -884.339],
                      [223.911, -648.502, -697.177, 753.738, 1166.89]])
    xifit = 0.474046

    # Calculate K00 from Eq 11
    kfit = kfit.at[0,0].set(4**2 * ( 0.68646 - jnp.sum( kfit[1:,0] /(4**(3+jnp.arange(kfit.shape[0]-1)))) - (3**0.5)/2))

    theta12 = eval_theta12(theta1, theta2, deltaphi)

    eps1 = 0.024
    eps2 = 0.024
    eps12 = 0
    theta1 = theta1 + eps1 * jnp.sin(theta1)
    theta2 = theta2 + eps2 * jnp.sin(theta2)
    theta12 = theta12 + eps12 * jnp.sin(theta12)

    # Eq. 14 - 15
    atot = ( chi1*jnp.cos(theta1) + chi2*jnp.cos(theta2)*q**2 ) / (1+q)**2
    aeff = atot + xifit*eta* ( chi1*jnp.cos(theta1) + chi2*jnp.cos(theta2) )

    # Eq. 2 - 6 evaluated at aeff, as specified in Eq. 11
    Z1= 1 + (1-(aeff**2))**(1/3) * ( (1+aeff)**(1/3) + (1-aeff)**(1/3) )
    Z2= ( (3*aeff**2) + (Z1**2) )**(1/2)
    risco= 3 + Z2 - jnp.sign(aeff) * ( (3-Z1)*(3+Z1+2*Z2) )**(1/2)
    Eisco=(1-2/(3*risco))**(1/2)
    Lisco = (2/(3*(3**(1/2)))) * ( 1 + 2*(3*risco - 2 )**(1/2) )

    # Eq. 13
    etatoi = eta**(1+jnp.arange(kfit.shape[0]))
    innersum = jnp.sum(kfit.T * etatoi,axis=1)
    aefftoj = aeff**(jnp.arange(kfit.shape[1]))
    sumell = jnp.sum(innersum  * aefftoj,axis=0)
    ell = jnp.abs( Lisco  - 2*atot*(Eisco-1)  + sumell )

    # Eq. 16
    chifin = (1/(1+q)**2) * ( chi1**2 + (chi2**2)*(q**4)  + 2*chi1*chi2*(q**2)*jnp.cos(theta12)
            + 2*(chi1*jnp.cos(theta1) + chi2*(q**2)*jnp.cos(theta2))*ell*q + ((ell*q)**2)  )**(1/2)

    return jnp.minimum(chifin,1)

In [None]:
def remnantkick(bigTheta, theta1, theta2, deltaphi, q, chi1, chi2):
# kms=False, maxphase=False, superkick=True, hangupkick=True, crosskick=True, full_output=False):

    eta = eval_eta(q)

    Lvec, S1vec, S2vec = angles_to_Lframe(theta1, theta2, deltaphi, 1, q, chi1, chi2)
    hatL = Lvec / jnp.linalg.norm(Lvec)
    hatS1 = S1vec / jnp.linalg.norm(S1vec)
    hatS2 = S2vec / jnp.linalg.norm(S2vec)

    #More spin parameters.
    Delta = - 1/(1+q) * (q*chi2*hatS2 - chi1*hatS1)
    Delta_par = jnp.dot(Delta, hatL)
    Delta_perp = jnp.linalg.norm(jnp.cross(Delta, hatL))
    chit = 1/(1+q)**2 * (chi2*q**2*hatS2 + chi1*hatS1)
    chit_par = jnp.dot(chit, hatL)
    chit_perp = jnp.linalg.norm(jnp.cross(chit, hatL))

    #Coefficients are quoted in km/s
    #vm and vperp from Kesden et al. 2010a. vpar from Lousto Zlochower 2013
    zeta=jnp.radians(145)
    A=1.2e4
    B=-0.93
    H=6.9e3

    #Multiply by 0/1 boolean flags to select terms
    V11 = 3677.76
    VA = 2481.21
    VB = 1792.45
    VC = 1506.52
    C2 = 1140
    C3 = 2481

    # #maxkick
    # bigTheta=np.random.uniform(0, 2*np.pi,q.shape) * (not maxphase)

    vm = A * eta**2 * (1+B*eta) * (1-q)/(1+q)
    vperp = H * eta**2 * Delta_par
    vpar = 16*eta**2 * (Delta_perp * (V11 + 2*VA*chit_par + 4*VB*chit_par**2 + 8*VC*chit_par**3) + chit_perp * Delta_par * (2*C2 + 4*C3*chit_par)) * jnp.cos(bigTheta)
    kick = jnp.array([vm+vperp*jnp.cos(zeta),vperp*jnp.sin(zeta),vpar]).T

    # if not kms:
    #     kick = kick/299792.458 # speed of light in km/s

    vk = jnp.linalg.norm(kick)

    return vk

Let's check this matches the original code. Note that [precession](https://github.com/dgerosa/precession) resamples the additional angle $\Theta$ internally, but above we made it an explicit input. We seed the random sampling below to ensure the same values are used.

In [None]:
import precession
from corner import corner

In [None]:
n = 10_000
np.random.seed(0)
bigTheta = np.random.uniform(0, 2 * np.pi, n)
theta1 = np.random.uniform(0, np.pi, n)
theta2 = np.random.uniform(0, np.pi, n)
deltaphi = np.random.uniform(0, 2 * np.pi, n)
q = np.random.uniform(0.1, 1, n)
chi1 = np.random.uniform(0, 1, n)
chi2 = np.random.uniform(0, 1, n)

In [None]:
np.random.seed(0)
og = np.array([
    precession.remnantmass(theta1, theta2, q, chi1, chi2),
    precession.remnantspin(theta1, theta2, deltaphi, q, chi1, chi2),
    precession.remnantkick(theta1, theta2, deltaphi, q, chi1, chi2, kms = True),
])

In [None]:
re = np.array([
    jax.vmap(remnantmass)(theta1, theta2, q, chi1, chi2),
    jax.vmap(remnantspin)(theta1, theta2, deltaphi, q, chi1, chi2),
    jax.vmap(remnantkick)(bigTheta, theta1, theta2, deltaphi, q, chi1, chi2),
])

In [None]:
np.allclose(og, re)

In [None]:
fig = None
for i, (samples, ls) in enumerate(zip((og, re), ('-', '--'))):
    fig = corner(
        samples.T, labels = ('mf', 'af', 'vf'), fig = fig,
        plot_datapoints = False, plot_density = False,
        plot_contours = True, fill_contours = False, no_fill_contours = True,
        hist_kwargs = dict(density = True, color = f'C{i}', ls = ls),
        contour_kwargs = dict(colors = [f'C{i}'], linestyles = ls),
    )

#### Priors

#### Training

#### Likelihood

#### Inference

#### Homework

We should add a bit more realism to our model:
- Include more flexible spin distributions.
- Account for differing time delays and redshift evolution for first- and second-generation mergers.
- Allow for a subpopulation of binaries in which both black holes are second generation.
- Allow for a subpopulation of sources in which second-generation mergers can't happen at all, with distinct parameters for its mass distribution.

In [None]:
# primary mass
def pdf_m(m, parameters):
    pl = bandpass_powerlaw(
        m,
        parameters['alpha'],
        parameters['m_min'],
        parameters['m_max'],
        parameters['d_min'],
        parameters['d_max'],
    )
    tn = bandpass_normal(
        m,
        parameters['mu_m'],
        parameters['sigma_m'],
        parameters['m_min'],
        parameters['m_max'],
        parameters['d_min'],
        parameters['d_max'],
    )
    return (1 - parameters['f_m']) * pl + parameters['f_m'] * tn

# mass ratio - this is a bit of a handful, but otherwise, autodiff doesn't work
# let me know if you spot a better way to do it :')
def pdf_q_given_m(q, m, parameters):
    # pdf defined in terms of secondary mass, then converted to mass ratio
    pdf = lambda q, m: highpass_powerlaw(
        q * m, parameters['beta'], parameters['m_min'], m, parameters['d_min'],
    ) * m
    single = lambda q, m: jax.lax.cond(
        parameters['m_min'] <= q * m, lambda: pdf(q, m), lambda: 0.0,
    )
    return jax.vmap(single)(q.ravel(), m.ravel()).reshape(q.shape)

# spin magnitude (truncated_normal is defined with the Gaussian functions above)
def pdf_a(a, parameters):
    return truncated_normal(a, parameters['mu_a'], parameters['sigma_a'], 0, 1)

# spin tilt
def pdf_c(c, parameters):
    return truncated_normal(c, parameters['mu_c'], parameters['sigma_c'], -1, 1)

# redshift
def shape_z(z, parameters):
    return (1 + z)**parameters['gamma']

def pdf_z(z, parameters):
    zmax = 2
    fn = lambda z: (
        shape_z(z, parameters)
        * wcosmo.Planck15.differential_comoving_volume(z) * 4 * jnp.pi / 1e9
    )
    cut = (0 < z) * (z <= zmax)
    shape = fn(z)
    zz = jnp.linspace(0, zmax, 10_000)
    norm = jnp.trapezoid(fn(zz), zz)
    return cut * shape / norm

In [None]:
# the combined probability density for the population model
def density(data, parameters):
    return (
        pdf_m(data['mass_1_source'], parameters)
        * pdf_q_given_m(data['mass_ratio'], data['mass_1_source'], parameters)
        # * pdf_q(data['mass_ratio'], parameters)
        * pdf_a(data['a_1'], parameters)
        * pdf_a(data['a_2'], parameters)
        * pdf_c(data['cos_tilt_1'], parameters)
        * pdf_c(data['cos_tilt_2'], parameters)
        * pdf_z(data['redshift'], parameters)
    )

Let's plot what the population models look like for some parameter values.

In [None]:
import matplotlib.pyplot as plt

In [None]:
parameters = dict(
    alpha = -3.5,
    m_min = 5,
    m_max = 80,
    d_min = 5,
    d_max = 10,
    mu_m = 35,
    sigma_m = 3,
    f_m = 0.1,
)

In [None]:
m = jnp.linspace(2, 100, 1_000)
p = pdf_m(m, parameters)

plt.plot(m, p)
plt.semilogy()
plt.xlabel('primary mass')
plt.ylabel('PDF')
plt.ylim(1e-5, 1e0)

print(jnp.trapezoid(p, m))

#### Priors

Next, we'll set priors on the parameters of the population model - these are the parameters we want to measure from data.

In [None]:
import numpyro

In [None]:
priors = dict(
    alpha = numpyro.distributions.Uniform(-10, 10),
    m_min = numpyro.distributions.Uniform(2, 6),
    m_max = numpyro.distributions.Uniform(70, 100),
    d_min = numpyro.distributions.Uniform(0, 10),
    d_max = numpyro.distributions.Uniform(0, 10),
    mu_m = numpyro.distributions.Uniform(20, 50),
    sigma_m = numpyro.distributions.Uniform(1, 10),
    f_m = numpyro.distributions.Uniform(0, 1),
    beta = numpyro.distributions.Uniform(-10, 10),
    mu_a = numpyro.distributions.Uniform(0, 1),
    sigma_a = numpyro.distributions.Uniform(0.1, 1),
    mu_c = numpyro.distributions.Uniform(-1, 1),
    sigma_c = numpyro.distributions.Uniform(0.1, 4),
    gamma = numpyro.distributions.Uniform(-10, 10),
)

In [None]:
dim = len(priors)
dim

#### Likelihood

How likely is it that our population model is responsible for the observed data?

Unlike neural posterior estimation, which is a simulation-based inference method, variational inference is a likelihood-based method. This also means that it is not amortized, i.e., it is fit to one specific data set. Below we code up the gravitational-wave population likelihood; see, e.g.,

- https://arxiv.org/abs/1809.02063,
- https://arxiv.org/abs/2007.05579,
- https://arxiv.org/abs/2410.19145.

In particular, the likelihood function is approximated with several Monte Carlo integrals, which introduces additional statistical variance (https://arxiv.org/abs/1904.10879, https://arxiv.org/abs/2204.00461, https://arxiv.org/abs/2304.06138). We make sure to keep track of this variance below.
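For reference, the estimator coded up below can be written as follows. The notation is ours, not taken from the references: $N$ observed events, $K$ posterior samples $\theta_{nk}$ per event drawn under the parameter-estimation prior $\pi_\mathrm{PE}$, and $N_\mathrm{found}$ detected injections $\theta_j$ out of $N_\mathrm{inj}$ total drawn from $p_\mathrm{inj}$:

$$
\ln \hat{\mathcal{L}}(\Lambda) = \sum_{n=1}^{N} \ln \left[ \frac{1}{K} \sum_{k=1}^{K} \frac{ p(\theta_{nk} | \Lambda) }{ \pi_\mathrm{PE}(\theta_{nk}) } \right] - N \ln \left[ \frac{1}{N_\mathrm{inj}} \sum_{j=1}^{N_\mathrm{found}} \frac{ p(\theta_j | \Lambda) }{ p_\mathrm{inj}(\theta_j) } \right] .
$$

Each bracket is a Monte Carlo mean $\hat{\mu}$ with estimated variance $\hat{\sigma}^2$, and to first order the variance of $\ln\hat{\mu}$ is $\hat{\sigma}^2 / \hat{\mu}^2$, which is what the "lazy" function below returns.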

In [None]:
# mean and variance of the mean
def mean_and_variance(weights, n):
    mean = jnp.sum(weights, axis = -1) / n
    variance = jnp.sum(weights**2, axis = -1) / n**2 - mean**2 / n
    return mean, variance

# lazy ln(mean) and variance of ln(mean)
def ln_mean_and_variance(weights, n):
    mean, variance = mean_and_variance(weights, n)
    return jnp.log(mean), variance / mean**2
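As a quick sanity check of the variance estimate, here is a toy sketch that compares it to the scatter of the mean across many independent repetitions. The exponential weights are an arbitrary stand-in (mean 1, variance 1), not anything from the real analysis, and `mean_and_variance` is copied from above.

```python
import jax
import jax.numpy as jnp

# copied from the cell above
def mean_and_variance(weights, n):
    mean = jnp.sum(weights, axis = -1) / n
    variance = jnp.sum(weights**2, axis = -1) / n**2 - mean**2 / n
    return mean, variance

n = 1_000     # samples per Monte Carlo integral
reps = 2_000  # independent repetitions of the integral

# toy weights: exponential draws, so the true variance of each mean is 1 / n
weights = jax.random.exponential(jax.random.PRNGKey(0), (reps, n))
means, variances = mean_and_variance(weights, n)

empirical = jnp.var(means)       # scatter of the means across repetitions
predicted = jnp.mean(variances)  # average of the per-repetition estimates
print(empirical, predicted)      # both should be roughly 1 / n
```

If the two numbers agree, the per-integral variance is a reasonable proxy for the Monte Carlo uncertainty we accumulate in the likelihood.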

In [None]:
def ln_likelihood_and_variance(posteriors, injections, density, parameters):
    pe_weights = density(posteriors, parameters) / posteriors['prior']
    vt_weights = density(injections, parameters) / injections['prior']
    num_obs, num_pe = pe_weights.shape
    ln_lkls, pe_variances = ln_mean_and_variance(pe_weights, num_pe)
    ln_pdet, vt_variance = ln_mean_and_variance(vt_weights, injections['total'])
    ln_lkl = ln_lkls.sum() - ln_pdet * num_obs
    variance = pe_variances.sum() + vt_variance * num_obs**2
    # ln_lkl = jnp.nan_to_num(ln_lkl, nan = -jnp.inf)
    # variance = jnp.nan_to_num(variance, nan = jnp.inf)
    return ln_lkl, variance

#### Normalizing flow

Now let's set up the model that we'll train. We'll use a [block neural autoregressive flow](https://arxiv.org/abs/1904.04676) to approximate the population posterior. There's a nice library called [flowjax](https://github.com/danielward27/flowjax) to do normalizing flows in JAX that we'll use. This is built on top of [equinox](https://github.com/patrick-kidger/equinox), which will handle our neural networks.

If you aren't familiar with normalizing flows, the (very brief) idea is that you can construct a probability distribution by transforming a simple known distribution (such as a standard normal distribution) with an invertible and differentiable function using the change-of-variables formula. For normalizing flows, that function is parametrized by a neural network, which is what makes the transformation flexible.

The transformation is trained so that the output distribution best matches some target distribution - in our case the population posterior distribution. The variables being transformed are the population parameters above.
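To make the change-of-variables formula concrete, here is a minimal one-dimensional sketch in plain JAX, with no neural network. The map `transform` is an arbitrary hand-picked invertible function standing in for the flow; the names are ours and not part of flowjax.

```python
import jax
import jax.numpy as jnp

# toy invertible, differentiable map (strictly increasing: slope >= 1)
def transform(z):
    return z + 0.8 * jnp.tanh(2.0 * z)

# log density of the standard normal base distribution
def ln_base(z):
    return -0.5 * z**2 - 0.5 * jnp.log(2 * jnp.pi)

# change of variables: ln q(x) = ln p(z) - ln |dx/dz|, with x = transform(z)
def ln_prob_at(z):
    jacobian = jax.grad(transform)(z)
    return ln_base(z) - jnp.log(jnp.abs(jacobian))

z = jnp.linspace(-8.0, 8.0, 4_001)
x = jax.vmap(transform)(z)
q = jnp.exp(jax.vmap(ln_prob_at)(z))

# the transformed density should still integrate to one (trapezoid rule in x)
total = jnp.sum(0.5 * (q[1:] + q[:-1]) * jnp.diff(x))
print(total)
```

A normalizing flow does the same thing, except that the map has many trainable parameters and acts on all the population parameters at once.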

In [None]:
import equinox
from flowjax.distributions import StandardNormal
from flowjax.flows import block_neural_autoregressive_flow

In [None]:
flow_init = block_neural_autoregressive_flow(
    key = jax.random.key(0),
    base_dist = StandardNormal(shape = (dim,)),
    invert = False,
)

We should take care that our normalizing flow is defined on the parameter domain we want it to be. In particular, our priors impose bounds on the range of values that can be taken. Therefore, we'll add some additional transformations to ensure those bounds are respected. These transformations are fixed and not trainable, unlike the flow transformations.

In [None]:
from flowjax.bijections import Affine, Sigmoid, Chain, Stack
from flowjax.distributions import Transformed
import paramax

In [None]:
bijections = []
for k in priors:
    lo, hi = priors[k].low, priors[k].high
    bijection = Chain([Sigmoid(), Affine(loc = lo, scale = hi - lo)])
    bijections.append(bijection)
flow_init = Transformed(flow_init, paramax.non_trainable(Stack(bijections)))

#### Training

To train the flow, we need to define a loss function to minimize with respect to the neural-network parameters. For variational inference, the most common choice is the Kullback-Leibler divergence from the target posterior $\mathcal{P}(\Lambda)$ to the normalizing flow approximation $\mathcal{Q}(\Lambda)$, with $\Lambda$ being the parameters of our population model that we want to infer:

$$
\mathrm{KL}[\mathcal{Q},\mathcal{P}] = \int \mathrm{d}\,\Lambda\, \mathcal{Q}(\Lambda) \ln \frac{\mathcal{Q}(\Lambda)}{\mathcal{P}(\Lambda)} .
$$

We know that our target posterior can be written using Bayes' theorem:

$$
\mathcal{P}(\Lambda) = \frac { \mathcal{L}(\Lambda) \pi(\Lambda) } { \mathcal{Z} } ,
$$

where $\mathcal{L}(\Lambda)$ is the likelihood function, $\pi(\Lambda)$ is the prior, and $\mathcal{Z} = \int \mathrm{d}\,\Lambda\, \mathcal{L}(\Lambda) \pi(\Lambda)$ is the evidence.

We know the likelihood and prior - they're above - but we don't know the evidence. Since $\ln\mathcal{Z}$ does not depend on the flow, we can drop it and minimize the equivalent loss function:

$$
L = \mathrm{KL}[\mathcal{Q},\mathcal{P}] - \ln\mathcal{Z} = \int \mathrm{d}\,\Lambda\, \mathcal{Q}(\Lambda) \ln \frac{ \mathcal{Q}(\Lambda) }{ \mathcal{L}(\Lambda) \pi(\Lambda) } .
$$

This can be approximated with Monte Carlo integration using a batch of $M$ samples $\{\Lambda_i\}_{i=1}^M$ drawn from the normalizing flow $\mathcal{Q}$:

$$
L \approx \frac{1}{M} \sum_{i=1}^M \ln \frac{ \mathcal{Q}(\Lambda_i) }{ \mathcal{L}(\Lambda_i) \pi(\Lambda_i) } .
$$
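Before applying this to the flow, here is a toy sketch of the same Monte Carlo loss in plain JAX. As stand-ins (our assumptions, not the real problem), $\mathcal{Q}$ is a one-dimensional Gaussian with trainable mean and width, and the unnormalized target $\mathcal{L}\pi$ is a Gaussian with mean 2 and width 0.5. Plain gradient descent replaces the optimizer for brevity.

```python
import jax
import jax.numpy as jnp

# unnormalized log target, playing the role of ln(likelihood * prior)
def ln_target(x):
    return -0.5 * ((x - 2.0) / 0.5)**2

# Monte Carlo estimate of the loss from a batch of samples drawn from Q
def loss_fn(params, key, batch_size = 64):
    mu, ln_sigma = params
    eps = jax.random.normal(key, (batch_size,))
    x = mu + jnp.exp(ln_sigma) * eps  # samples from Q (reparametrized)
    ln_q = -0.5 * eps**2 - ln_sigma - 0.5 * jnp.log(2 * jnp.pi)
    return jnp.mean(ln_q - ln_target(x))

# one step of plain gradient descent with a fixed learning rate
def update(params, key):
    loss, grad = jax.value_and_grad(loss_fn)(params, key)
    return params - 0.05 * grad, loss

keys = jax.random.split(jax.random.PRNGKey(0), 500)
params, losses = jax.lax.scan(update, jnp.array([0.0, 0.0]), keys)

mu, sigma = params[0], jnp.exp(params[1])
print(mu, sigma)    # should approach the target mean 2 and width 0.5
print(losses[-1])   # should approach -ln(Z), with Z = 0.5 * sqrt(2 * pi)
```

Because the target here lies within the variational family, the KL divergence can reach zero and the loss settles at $-\ln\mathcal{Z}$; for the flow, the gap between the two is a measure of how good the approximation is.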

First, let's choose some training settings. We'll use [optax](https://github.com/google-deepmind/optax) to optimize the neural-network parameters and [jax_tqdm](https://github.com/jeremiecoullon/jax-tqdm) to add a progress bar to our loop.

In [None]:
import optax
import jax_tqdm

In [None]:
batch_size = 1 # perhaps surprisingly, this is sufficient
steps = 10_000
learning_rate = 1e-2
optimizer = optax.adam(learning_rate)

Now the loss function and training loop.

In [None]:
# split the flow into trainable and non-trainable partitions
params_init, static = equinox.partition(
    pytree = flow_init,
    filter_spec = equinox.is_inexact_array,
    is_leaf = lambda leaf: isinstance(leaf, paramax.NonTrainable),
)

In [None]:
def loss_fn(params, key):
    flow = equinox.combine(params, static)
    samples, ln_flows = flow.sample_and_log_prob(key, (batch_size,))
    parameters = dict(zip(priors, samples.T))
    ln_lkls, variances = jax.vmap(
        lambda parameters: ln_likelihood_and_variance(
            posteriors, injections, density, parameters,
        ),
    )(parameters)
    ln_priors = jnp.sum(
        jnp.array([priors[k].log_prob(parameters[k]) for k in priors]),
        axis = 0,
    )
    return jnp.mean(ln_flows - ln_priors - ln_lkls)

In [None]:
@jax_tqdm.scan_tqdm(steps, print_rate = 100, tqdm_type = 'std')
def update(carry, step):
    key, params, state = carry
    key, _key = jax.random.split(key)
    loss, grad = equinox.filter_value_and_grad(loss_fn)(params, _key)
    updates, state = optimizer.update(grad, state, params)
    params = equinox.apply_updates(params, updates)
    carry = key, params, state
    return carry, loss

In [None]:
import time

In [None]:
# Finally, the training loop.
# Sometimes the initial JIT compilation takes a few seconds to get going...

state = optimizer.init(params_init)
carry = jax.random.key(1), params_init, state

t0 = time.time()
carry, losses = jax.lax.scan(update, carry, jnp.arange(steps))
dt = time.time() - t0
print('total time including JIT compilation:', dt)

key, params, state = carry
flow = equinox.combine(params, static)

In [None]:
# plot the loss function values over training steps
plt.plot(losses);

#### Inference

Now that the flow is trained, we can draw as many posterior samples as we want.

In [None]:
from corner import corner

In [None]:
samples = flow.sample(jax.random.key(2), (10_000,))
posterior = dict(zip(priors, samples.T))

In [None]:
corner(np.array(samples), labels = list(priors));

Let's also plot the inferred population-level distributions of source parameters and their posterior uncertainties.

In [None]:
grid = dict(
    mass_1_source = jnp.linspace(2, 100, 1_000),
    mass_ratio = jnp.linspace(0, 1, 1_000),
    a = jnp.linspace(0, 1, 1_000),
    cos_tilt = jnp.linspace(-1, 1, 1_000),
    redshift = jnp.linspace(0, 2, 1_000),
)

# the mass ratio model is conditional on the primary mass, so we have to marginalize
def pdf_q_marginal(q, parameters):
    x, y = jnp.meshgrid(q, grid['mass_1_source'], indexing = 'ij')
    p = pdf_q_given_m(x, y, parameters) * pdf_m(y, parameters)
    return jnp.trapezoid(p, y, axis = 1)

pdf = dict(
    mass_1_source = pdf_m,
    mass_ratio = pdf_q_marginal,
    a = pdf_a,
    cos_tilt = pdf_c,
    redshift = pdf_z,
)

In [None]:
def make_plot(k, data):
    # we use sequential map for mass ratio because the integral uses more memory
    single = lambda parameters: pdf[k](grid[k], parameters)
    if k == 'mass_ratio':
        ps = jax.lax.map(single, data)
    else:
        ps = jax.vmap(single)(data)

    for qs, alpha in (
        ((0.005, 0.995), 0.2),
        ((0.05, 0.95), 0.3),
        ((0.25, 0.75), 0.4),
    ):
        label = f'{(qs[1]-qs[0]) * 100:.0f}% posterior'
        plt.fill_between(
            grid[k], *np.quantile(ps, qs, axis = 0), label = label,
            color = 'C0', alpha = alpha, lw = 0,
        )

    plt.plot(
        grid[k], np.median(ps, axis = 0), label = 'median posterior',
        c = 'C1', lw = 2,
    )
    plt.plot(
        grid[k], np.mean(ps, axis = 0), label = 'mean posterior (PPD)',
        c = 'C2', lw = 2, ls = '--',
    )

    plt.legend()
    plt.xlabel(k)
    plt.ylabel(f'p({k})')

In [None]:
for k in 'mass_1_source', 'mass_ratio', 'a', 'cos_tilt', 'redshift':
    make_plot(k, posterior)

    if k == 'mass_1_source':
        plt.semilogy()
        plt.ylim(1e-5, 1e0)
    elif k == 'mass_ratio':
        plt.semilogy()
        plt.ylim(1e-2, 1e1)

    plt.show()

#### Homework

There is some immediate tinkering you can do with the code above:
- The training settings, e.g., batch size, number of training steps, learning rate, optimizer etc. Try a [learning-rate scheduler](https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html), for example.
- The [flow settings](https://danielward27.github.io/flowjax/api/flows.html#flowjax.flows.block_neural_autoregressive_flow), e.g., try making the network smaller or larger.
- The flow itself, i.e., try [a different type](https://danielward27.github.io/flowjax/api/flows.html) of normalizing flow.

Try targeting a different gravitational-wave population:
- Try altering some of the population models.
- Try a completely different population model that you're interested in testing.

You could compare to different code backends or packages:
- Try implementing this in your favourite ML package, e.g., PyTorch, TensorFlow, Julia, etc.
- Compare the results here to an existing variational inference package, e.g., [pyro](https://docs.pyro.ai/en/stable/inference_algos.html) or [numpyro](https://num.pyro.ai/en/latest/svi.html).

We should check our results:
- Do you trust the posterior predicted by the flow and how would you test it?
- How could you use the normalizing flow to compute the Bayesian evidence for model comparison?
- What about the Monte Carlo variance - is it under control?
- Try reusing the likelihood function we coded up with a stochastic sampling algorithm to compare posteriors.

In https://arxiv.org/abs/2504.07197, we have several tips for training and inference validation.