If you're running in a separate notebook environment (e.g., Google Colab), go through the cells below and uncomment them as required. Also make sure to select the runtime type (e.g., CPU or GPU) before running the notebook.

In [None]:
# If you're running on a shared cluster and want to limit the resources you take up:
import os
os.environ["OPENBLAS_NUM_THREADS"] = '1'
os.environ["MKL_NUM_THREADS"] = '1'
os.environ["VECLIB_MAXIMUM_THREADS"] = '1'
os.environ["NUMEXPR_NUM_THREADS"] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['NPROC'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # you can change to a GPU ID not in use

In [None]:
# !pip install numpy matplotlib corner h5ify precession
# !pip install wcosmo jax_tqdm equinox optax flowjax

In [None]:
# # If you're running on CPU:
# !pip install jax numpyro

# # If you're running on GPU
# !pip install -U 'jax[cuda12]'
# !pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

## Simulation-based population model of gravitational-wave sources

In this notebook, we'll train a normalizing flow to learn an astrophysical population model using simulations.

We'll use [JAX](https://github.com/jax-ml/jax) as the main workhorse behind this notebook.

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)

In [None]:
# check for GPU devices
jax.devices()

#### Population model

First, let's define the population model that we'll use for the astrophysical distribution of sources.

As an example, we'll construct a simple simulated model of "hierarchical" mergers, where some of the population of black holes are the products of previous black-hole mergers. We call "first-generation" (1G) black holes those that are born from stars and "second-generation" (2G) black holes those that are born from black-hole mergers.

Some of the ideas are based on our paper ["Deep learning and Bayesian inference of gravitational-wave populations: Hierarchical black-hole mergers" (arXiv:2203.03651)](https://arxiv.org/abs/2203.03651) and on the [code](http://github.com/mdmould/qluster) accompanying ["QLUSTER: quick clusters of merging binary black holes" (arXiv:2305.04987)](https://arxiv.org/abs/2305.04987).

Below are some utilities we'll use to build the population model.

In [None]:
# tapering functions

def cubic_filter(x):
    return (3 - 2 * x) * x**2 * (0 <= x) * (x <= 1) + (1 < x)

def highpass(x, xmin, dmin):
    return cubic_filter((x - xmin) / dmin)

def lowpass(x, xmax, dmax):
    return highpass(x, xmax, -dmax)

def bandpass(x, xmin, xmax, dmin, dmax):
    return highpass(x, xmin, dmin) * lowpass(x, xmax, dmax)
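The following NumPy-only restatement of `cubic_filter` and `bandpass` illustrates how the taper behaves. It swaps `jax.numpy` for `numpy` so that it runs without JAX, and the argument values are arbitrary. The weight rises smoothly from 0 at `xmin` to 1 at `xmin + dmin`, stays flat, and falls back to 0 between `xmax - dmax` and `xmax`:

```python
import numpy as np

def cubic_filter(x):
    # smoothstep 3x^2 - 2x^3 on [0, 1], 0 below the interval and 1 above it
    return (3 - 2 * x) * x**2 * (0 <= x) * (x <= 1) + (1 < x)

def bandpass(x, xmin, xmax, dmin, dmax):
    # rising taper over [xmin, xmin + dmin] times falling taper over [xmax - dmax, xmax]
    rise = cubic_filter((x - xmin) / dmin)
    fall = cubic_filter((xmax - x) / dmax)  # equivalent to highpass(x, xmax, -dmax)
    return rise * fall

x = np.array([5.0, 7.5, 10.0, 30.0, 47.0, 50.0])
weights = bandpass(x, xmin=5.0, xmax=50.0, dmin=5.0, dmax=6.0)
print(weights)  # values 0, 0.5, 1, 1, 0.5, 0
```

The cubic (rather than, e.g., a linear ramp) has zero slope at both ends of each taper, so the resulting density is continuous with a continuous first derivative.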

In [None]:
# power law functions

def powerlaw_integral(x, alpha, loc, delta):
    a, c, d = alpha, loc, delta
    return (
        3 * (2 * c + (4 + a) * d)
        * (c**2 / (1 + a) - 2 * c * x / (2 + a) + x**2 / (3 + a))
        - 2 * (x - c)**3
    ) * x**(1 + a) / (4 + a) / d**3

def bandpass_powerlaw(x, alpha, xmin, xmax, dmin, dmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = x**alpha * bandpass(x, xmin, xmax, dmin, dmax)
    norm = (
        - powerlaw_integral(xmin, alpha, xmin, dmin)
        + powerlaw_integral(xmin + dmin, alpha, xmin, dmin)
        - (xmin + dmin)**(alpha + 1) / (alpha + 1)
        + (xmax - dmax)**(alpha + 1) / (alpha + 1)
        - powerlaw_integral(xmax - dmax, alpha, xmax, -dmax)
        + powerlaw_integral(xmax, alpha, xmax, -dmax)
    )
    return cut * shape / norm
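The normalization in `bandpass_powerlaw` integrates the two tapered edges analytically with `powerlaw_integral`, which is an antiderivative of $x^\alpha$ times the cubic taper. It integrates the flat middle as a plain power law. The sketch below is a NumPy transcription of those functions that checks numerically that the density integrates to one. It assumes that the two taper regions do not overlap (`xmin + dmin <= xmax - dmax`) and that $\alpha \notin \{-1, -2, -3, -4\}$, where the formulae divide by zero. The parameter values are arbitrary:

```python
import numpy as np

def cubic_filter(x):
    return (3 - 2 * x) * x**2 * (0 <= x) * (x <= 1) + (1 < x)

def bandpass(x, xmin, xmax, dmin, dmax):
    return cubic_filter((x - xmin) / dmin) * cubic_filter((xmax - x) / dmax)

def powerlaw_integral(x, alpha, loc, delta):
    # antiderivative of x**alpha * cubic_filter((x - loc) / delta) inside the taper
    a, c, d = alpha, loc, delta
    return (
        3 * (2 * c + (4 + a) * d)
        * (c**2 / (1 + a) - 2 * c * x / (2 + a) + x**2 / (3 + a))
        - 2 * (x - c)**3
    ) * x**(1 + a) / (4 + a) / d**3

def bandpass_powerlaw(x, alpha, xmin, xmax, dmin, dmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = x**alpha * bandpass(x, xmin, xmax, dmin, dmax)
    norm = (
        - powerlaw_integral(xmin, alpha, xmin, dmin)          # lower taper
        + powerlaw_integral(xmin + dmin, alpha, xmin, dmin)
        - (xmin + dmin)**(alpha + 1) / (alpha + 1)            # flat power law
        + (xmax - dmax)**(alpha + 1) / (alpha + 1)
        - powerlaw_integral(xmax - dmax, alpha, xmax, -dmax)  # upper taper
        + powerlaw_integral(xmax, alpha, xmax, -dmax)
    )
    return cut * shape / norm

# trapezoid-rule integral over the full support on a fine grid
x = np.linspace(5.0, 59.0, 200_001)
p = bandpass_powerlaw(x, alpha=-2.5, xmin=5.0, xmax=59.0, dmin=5.0, dmax=9.0)
total = np.sum(np.diff(x) * (p[1:] + p[:-1]) / 2)
print(total)  # close to 1
```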

In [None]:
# Gaussian functions

def truncated_normal(x, mu, sigma, xmin, xmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = jax.scipy.stats.norm.pdf(x, mu, sigma)
    norm = (
        - jax.scipy.stats.norm.cdf(xmin, mu, sigma)
        + jax.scipy.stats.norm.cdf(xmax, mu, sigma)
    )
    return cut * shape / norm

def normal_integral(x, mu, sigma, loc, delta):
    m, s, c, d = mu, sigma, loc, delta
    return (
        jnp.exp(-(x - m)**2 / 2 / s ** 2) * (2 / jnp.pi)**0.5 * s * (
            6 * c * (c + d - m - x)
            - 3 * d * (m + x)
            + 2 * (m**2 + 2 * s**2 + m * x + x**2)
        )
        - jax.lax.erf((m - x) / s / 2**0.5) * (
            (2 * c + 3 * d - 2 * m) * (c - m)**2
            + 3 * s**2 * (2 * c + d - 2 * m)
        )
    ) / 2 / d**3

def bandpass_normal(x, mu, sigma, xmin, xmax, dmin, dmax):
    cut = (xmin <= x) * (x <= xmax)
    shape = (
        jax.scipy.stats.norm.pdf(x, mu, sigma)
        * bandpass(x, xmin, xmax, dmin, dmax)
    )
    norm = (
        - normal_integral(xmin, mu, sigma, xmin, dmin)
        + normal_integral(xmin + dmin, mu, sigma, xmin, dmin)
        - jax.scipy.stats.norm.cdf(xmin + dmin, mu, sigma)
        + jax.scipy.stats.norm.cdf(xmax - dmax, mu, sigma)
        - normal_integral(xmax - dmax, mu, sigma, xmax, -dmax)
        + normal_integral(xmax, mu, sigma, xmax, -dmax)
    )
    return cut * shape / norm

To build up a population of second-generation mergers, we'll begin by defining the distribution of first-generation mergers. We'll assume that:
- black-hole masses follow a power law with an additional Gaussian peak (similar to the [Power Law + Peak](https://arxiv.org/abs/1801.02699) model). Both components are tapered smoothly to zero at the low-mass end (between `m_min` and `m_min + d_min`) and above the Gaussian peak (between `mu_m` and `mu_m + 3 * sigma_m`); the upper taper mocks up the effect of the (pulsational) pair instability;
- black-hole spin magnitudes are drawn from a normal distribution truncated to [0, 1];
- spins are isotropic in direction.
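Isotropic spin directions mean that the cosine of each tilt angle is uniform on [-1, 1] and the azimuthal angle is uniform on [0, 2π). This is how `sample_1g` below draws them. The short NumPy sketch that follows is not part of the model. It checks the statement by drawing uniformly random directions as normalized 3D Gaussian vectors:

```python
import numpy as np

rng = np.random.default_rng(0)

# uniformly random directions: normalize isotropic 3D Gaussian vectors
v = rng.normal(size=(100_000, 3))
v /= np.linalg.norm(v, axis=1, keepdims=True)

cos_tilt = v[:, 2]  # cosine of the angle to the z axis
# a uniform distribution on [-1, 1] has mean 0 and variance 1/3
print(cos_tilt.mean(), cos_tilt.var())
```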

In [None]:
def maximum_mass(parameters):
    return parameters['mu_m'] + parameters['sigma_m'] * 3

def pdf_m(m, parameters):
    pl = bandpass_powerlaw(
        m,
        alpha = parameters['alpha'],
        xmin = parameters['m_min'],
        xmax = maximum_mass(parameters),
        dmin = parameters['d_min'],
        dmax = parameters['sigma_m'] * 3,
    )
    tn = bandpass_normal(
        m,
        mu = parameters['mu_m'],
        sigma = parameters['sigma_m'],
        xmin = parameters['m_min'],
        xmax = maximum_mass(parameters),
        dmin = parameters['d_min'],
        dmax = parameters['sigma_m'] * 3,
    )
    return (1 - parameters['f_m']) * pl + parameters['f_m'] * tn

In [None]:
import matplotlib.pyplot as plt

In [None]:
parameters = dict(
    alpha = -2.5,
    m_min = 5,
    d_min = 5,
    mu_m = 50,
    sigma_m = 3,
    f_m = 0.1,
)

m = jnp.linspace(2, 100, 1_000)
p = pdf_m(m, parameters)

plt.plot(m, p)
plt.axvline(parameters['mu_m'], c = 'r', ls = '--')
plt.axvline(maximum_mass(parameters), c = 'r', ls = '--')
plt.semilogy()
plt.ylim(1e-5, 1e0)

jnp.trapezoid(p, m)  # numerical normalization check: should be close to 1

To generate training data for our model, we need a function that samples 1G mergers.

We pair black holes according to a power law with slope `beta` in the total binary mass. Candidate binaries are drawn from the mass distribution and then resampled with weights proportional to $(m_1 + m_2)^\beta$.

In [None]:
def inversion_sampling(key, shape, pdf, x):
    cdf = jnp.cumsum(jnp.diff(x) * (pdf[1:] + pdf[:-1]) / 2)
    cdf = jnp.insert(cdf / cdf[-1], 0, 0)
    u = jax.random.uniform(key, shape)
    return jnp.interp(u, cdf, x)

def sample_m(key, shape, parameters):
    x = jnp.linspace(parameters['m_min'], maximum_mass(parameters), 1_000)
    p = pdf_m(x, parameters)
    return inversion_sampling(key, shape, p, x)

def sample_truncated_normal(key, shape, mu, sigma, lo, hi):
    u = jax.random.uniform(key, shape)
    loc = jax.scipy.stats.norm.cdf(lo, mu, sigma)
    scale = jax.scipy.stats.norm.cdf(hi, mu, sigma) - loc
    return jax.scipy.stats.norm.ppf(u * scale + loc, mu, sigma)
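`inversion_sampling` performs inverse transform sampling on a grid. It builds the CDF with the trapezoid rule, normalizes it, and maps uniform draws through the linearly interpolated inverse CDF. The NumPy sketch below follows the same steps with an arbitrary test density $p(x) = 2x$ on $[0, 1]$, whose mean is $2/3$. Here a NumPy `Generator` stands in for the JAX key:

```python
import numpy as np

def inversion_sampling(rng, shape, pdf, x):
    # trapezoid-rule CDF on the grid, normalized so that it ends at 1
    cdf = np.cumsum(np.diff(x) * (pdf[1:] + pdf[:-1]) / 2)
    cdf = np.insert(cdf / cdf[-1], 0, 0)
    # invert the CDF by linear interpolation of uniform draws
    u = rng.uniform(size=shape)
    return np.interp(u, cdf, x)

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 1_000)
samples = inversion_sampling(rng, (100_000,), 2 * x, x)
print(samples.mean())  # close to 2/3
```

The density does not need to be normalized beforehand, because the CDF is divided by its final value. `sample_truncated_normal` uses the same idea with the exact normal CDF and quantile function instead of a grid.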

In [None]:
def sample_1g(key, n, parameters):
    # masses
    key, subkey = jax.random.split(key)
    m = sample_m(subkey, (2, n), parameters)
    m1, m2 = m.max(axis = 0), m.min(axis = 0)
    q = m2 / m1

    # pairing function
    pair = (m1 + m2)**parameters['beta']
    key, subkey = jax.random.split(key)
    idxs = jax.random.choice(subkey, n, shape = (n,), p = pair)
    m1, q = m1[idxs], q[idxs]

    # spin magnitudes
    key, subkey = jax.random.split(key)
    a1, a2 = sample_truncated_normal(
        subkey, (2, n), parameters['mu_a'], parameters['sigma_a'], 0, 1,
    )

    # spin tilts
    key, subkey = jax.random.split(key)
    c1, c2 = jax.random.uniform(subkey, (2, n), minval = -1, maxval = 1)

    # azimuthal angle
    key, subkey = jax.random.split(key)
    dp = jax.random.uniform(subkey, (n,), minval = 0, maxval = 2 * jnp.pi)

    return jnp.array([m1, q, a1, a2, c1, c2, dp])
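The pairing step in `sample_1g` resamples the candidate binaries with replacement, using weights $(m_1 + m_2)^\beta$. As a result, $\beta > 0$ favors heavier binaries and $\beta < 0$ favors lighter ones. The NumPy sketch below uses made-up total masses (uniform between 10 and 100, purely for illustration) and `rng.choice` in place of `jax.random.choice`:

```python
import numpy as np

rng = np.random.default_rng(1)

# hypothetical candidate total masses, for illustration only
mtot = rng.uniform(10, 100, size=50_000)

def pair(mtot, beta, rng):
    # resample candidates with replacement, with probability proportional to mtot**beta
    w = mtot**beta
    idxs = rng.choice(mtot.size, size=mtot.size, p=w / w.sum())
    return mtot[idxs]

for beta in (-2, 0, 2):
    print(beta, pair(mtot, beta, rng).mean())
```

Analytically, the weighted means for this toy input are about 25.6, 55 and 75 for `beta` = -2, 0 and 2.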

Next, we want to sample a population of mergers involving a 2G black hole. We'll just consider binaries in which one black hole is 1G and the other is 2G (1G+2G binary).

To do so, we need a function that outputs the properties of black-hole merger remnants given the pre-merger binary properties. Importantly, the remnant receives a recoil (or "kick") that imparts a velocity, which can eject it from its host environment and thus prevent it from participating in 2G mergers.

We follow the fitting formulae to numerical-relativity simulations included in the [precession](https://arxiv.org/abs/1605.01067) Python [package](https://github.com/dgerosa/precession), but with a minimal re-implementation in JAX. I've removed the internal vectorization, which `jax.vmap` can do automatically, and fixed the spin and kick options to their defaults, except that kicks are returned in km/s.

In [None]:
def eval_eta(q):
    return q / (1 + q)**2

def eval_theta12(theta1, theta2, deltaphi):
    return jnp.arccos(
        jnp.sin(theta1) * jnp.sin(theta2) * jnp.cos(deltaphi)
        + jnp.cos(theta1) * jnp.cos(theta2)
    )

def angles_to_Lframe(theta1, theta2, deltaphi, r, q, chi1, chi2):
    L = r**0.5 * q / (1 + q)**2
    S1 = chi1 / (1 + q)**2
    S2 = chi2 * q**2 / (1 + q)**2

    Lx = 0
    Ly = 0
    Lz = L
    Lvec = jnp.array([Lx, Ly, Lz])

    S1x = S1 * jnp.sin(theta1)
    S1y = 0
    S1z = S1 * jnp.cos(theta1)
    S1vec = jnp.array([S1x, S1y, S1z])

    S2x = S2 * jnp.sin(theta2) * jnp.cos(deltaphi)
    S2y = S2 * jnp.sin(theta2) * jnp.sin(deltaphi)
    S2z = S2 * jnp.cos(theta2)
    S2vec = jnp.array([S2x, S2y, S2z])

    return Lvec, S1vec, S2vec

In [None]:
def remnant_mass(theta1, theta2, q, chi1, chi2):
    eta = eval_eta(q)

    chit_par =  ( chi2*q**2 * jnp.cos(theta2) + chi1*jnp.cos(theta1) ) / (1+q)**2

    #Final mass. Barausse Morozova Rezzolla 2012
    p0 = 0.04827
    p1 = 0.01707
    Z1 = 1 + (1-chit_par**2)**(1/3)* ((1+chit_par)**(1/3)+(1-chit_par)**(1/3))
    Z2 = (3* chit_par**2 + Z1**2)**(1/2)
    risco = 3 + Z2 - jnp.sign(chit_par) * ((3-Z1)*(3+Z1+2*Z2))**(1/2)
    Eisco = (1-2/(3*risco))**(1/2)
    #Radiated energy, in units of the initial total mass of the binary
    Erad = eta*(1-Eisco) + 4* eta**2 * (4*p0+16*p1*chit_par*(chit_par+1)+Eisco-1)
    Mfin = 1- Erad # Final mass

    return Mfin
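The final-mass fit has a useful sanity check. Consider an equal-mass ($q = 1$, so $\eta = 1/4$), non-spinning binary. Then `chit_par` is zero and the ISCO radius is 6. The two terms in the radiated energy combine to give $E_\mathrm{rad} = \eta(1 - E_\mathrm{isco}) + 4\eta^2(4 p_0 + E_\mathrm{isco} - 1) = p_0$ exactly, i.e., about 4.8% of the total mass is radiated. A NumPy transcription of `remnant_mass` reproduces this:

```python
import numpy as np

def remnant_mass(theta1, theta2, q, chi1, chi2):
    # NumPy transcription of the JAX function above (Barausse, Morozova & Rezzolla 2012 fit)
    eta = q / (1 + q)**2
    chit_par = (chi2 * q**2 * np.cos(theta2) + chi1 * np.cos(theta1)) / (1 + q)**2
    p0, p1 = 0.04827, 0.01707
    Z1 = 1 + (1 - chit_par**2)**(1/3) * ((1 + chit_par)**(1/3) + (1 - chit_par)**(1/3))
    Z2 = (3 * chit_par**2 + Z1**2)**(1/2)
    risco = 3 + Z2 - np.sign(chit_par) * ((3 - Z1) * (3 + Z1 + 2 * Z2))**(1/2)
    Eisco = (1 - 2 / (3 * risco))**(1/2)
    Erad = eta * (1 - Eisco) + 4 * eta**2 * (
        4 * p0 + 16 * p1 * chit_par * (chit_par + 1) + Eisco - 1
    )
    return 1 - Erad  # final mass in units of the initial total mass

mf = remnant_mass(theta1=0.0, theta2=0.0, q=1.0, chi1=0.0, chi2=0.0)
print(mf)  # equals 1 - p0 (about 0.95173) up to rounding
```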

In [None]:
def remnant_spin(theta1, theta2, deltaphi, q, chi1, chi2):
    eta = eval_eta(q)

    kfit = jnp.array( [[jnp.nan, 3.39221, 4.48865, -5.77101, -13.0459] ,
                      [35.1278, -72.9336, -86.0036, 93.7371, 200.975],
                      [-146.822, 387.184, 447.009, -467.383, -884.339],
                      [223.911, -648.502, -697.177, 753.738, 1166.89]])
    xifit = 0.474046

    # Calculate K00 from Eq 11
    kfit = kfit.at[0,0].set(4**2 * ( 0.68646 - jnp.sum( kfit[1:,0] /(4**(3+jnp.arange(kfit.shape[0]-1)))) - (3**0.5)/2))

    theta12 = eval_theta12(theta1, theta2, deltaphi)

    eps1 = 0.024
    eps2 = 0.024
    eps12 = 0
    theta1 = theta1 + eps1 * jnp.sin(theta1)
    theta2 = theta2 + eps2 * jnp.sin(theta2)
    theta12 = theta12 + eps12 * jnp.sin(theta12)

    # Eq. 14 - 15
    atot = ( chi1*jnp.cos(theta1) + chi2*jnp.cos(theta2)*q**2 ) / (1+q)**2
    aeff = atot + xifit*eta* ( chi1*jnp.cos(theta1) + chi2*jnp.cos(theta2) )

    # Eq. 2 - 6 evaluated at aeff, as specified in Eq. 11
    Z1= 1 + (1-(aeff**2))**(1/3) * ( (1+aeff)**(1/3) + (1-aeff)**(1/3) )
    Z2= ( (3*aeff**2) + (Z1**2) )**(1/2)
    risco= 3 + Z2 - jnp.sign(aeff) * ( (3-Z1)*(3+Z1+2*Z2) )**(1/2)
    Eisco=(1-2/(3*risco))**(1/2)
    Lisco = (2/(3*(3**(1/2)))) * ( 1 + 2*(3*risco - 2 )**(1/2) )

    # Eq. 13
    etatoi = eta**(1+jnp.arange(kfit.shape[0]))
    innersum = jnp.sum(kfit.T * etatoi,axis=1)
    aefftoj = aeff**(jnp.arange(kfit.shape[1]))
    sumell = jnp.sum(innersum  * aefftoj,axis=0)
    ell = jnp.abs( Lisco  - 2*atot*(Eisco-1)  + sumell )

    # Eq. 16
    chifin = (1/(1+q)**2) * ( chi1**2 + (chi2**2)*(q**4)  + 2*chi1*chi2*(q**2)*jnp.cos(theta12)
            + 2*(chi1*jnp.cos(theta1) + chi2*(q**2)*jnp.cos(theta2))*ell*q + ((ell*q)**2)  )**(1/2)

    return jnp.minimum(chifin,1)

In [None]:
def remnant_kick(bigTheta, theta1, theta2, deltaphi, q, chi1, chi2):
    # precession's other options are fixed here to maxphase=False, superkick=True,
    # hangupkick=True, crosskick=True, full_output=False; kicks are returned in km/s (as with kms=True)

    eta = eval_eta(q)

    Lvec, S1vec, S2vec = angles_to_Lframe(theta1, theta2, deltaphi, 1, q, chi1, chi2)
    hatL = Lvec / jnp.linalg.norm(Lvec)
    hatS1 = S1vec / jnp.linalg.norm(S1vec)
    hatS2 = S2vec / jnp.linalg.norm(S2vec)

    #More spin parameters.
    Delta = - 1/(1+q) * (q*chi2*hatS2 - chi1*hatS1)
    Delta_par = jnp.dot(Delta, hatL)
    Delta_perp = jnp.linalg.norm(jnp.cross(Delta, hatL))
    chit = 1/(1+q)**2 * (chi2*q**2*hatS2 + chi1*hatS1)
    chit_par = jnp.dot(chit, hatL)
    chit_perp = jnp.linalg.norm(jnp.cross(chit, hatL))

    #Coefficients are quoted in km/s
    # vm and vperp from Kesden et al. 2010a; vpar from Lousto & Zlochower 2013
    zeta=jnp.radians(145)
    A=1.2e4
    B=-0.93
    H=6.9e3

    # parallel-kick coefficients (all terms are always included here; in precession they can be switched off with boolean flags)
    V11 = 3677.76
    VA = 2481.21
    VB = 1792.45
    VC = 1506.52
    C2 = 1140
    C3 = 2481

    # bigTheta is an explicit input here; with maxphase=False, precession instead
    # draws it internally as np.random.uniform(0, 2*np.pi, q.shape)

    vm = A * eta**2 * (1+B*eta) * (1-q)/(1+q)
    vperp = H * eta**2 * Delta_par
    vpar = 16*eta**2 * (Delta_perp * (V11 + 2*VA*chit_par + 4*VB*chit_par**2 + 8*VC*chit_par**3) + chit_perp * Delta_par * (2*C2 + 4*C3*chit_par)) * jnp.cos(bigTheta)
    kick = jnp.array([vm+vperp*jnp.cos(zeta),vperp*jnp.sin(zeta),vpar]).T

    # if not kms:
    #     kick = kick/299792.458 # speed of light in km/s

    vk = jnp.linalg.norm(kick)

    return vk

Let's check that this matches the original code. Note that [precession](https://github.com/dgerosa/precession) resamples the additional angle $\Theta$ internally, whereas above we made it an explicit input. We seed the random number generator below so that the same values are used.

In [None]:
import precession
from corner import corner

In [None]:
n = 10_000
np.random.seed(0)
bigTheta = np.random.uniform(0, 2 * np.pi, n)
theta1 = np.random.uniform(0, np.pi, n)
theta2 = np.random.uniform(0, np.pi, n)
deltaphi = np.random.uniform(0, 2 * np.pi, n)
q = np.random.uniform(0.1, 1, n)
chi1 = np.random.uniform(0, 1, n)
chi2 = np.random.uniform(0, 1, n)

In [None]:
np.random.seed(0)
og = np.array([
    precession.remnantmass(theta1, theta2, q, chi1, chi2),
    precession.remnantspin(theta1, theta2, deltaphi, q, chi1, chi2),
    precession.remnantkick(theta1, theta2, deltaphi, q, chi1, chi2, kms = True),
])

In [None]:
re = np.array([
    jax.vmap(remnant_mass)(theta1, theta2, q, chi1, chi2),
    jax.vmap(remnant_spin)(theta1, theta2, deltaphi, q, chi1, chi2),
    jax.vmap(remnant_kick)(bigTheta, theta1, theta2, deltaphi, q, chi1, chi2),
])

In [None]:
np.allclose(og, re)

In [None]:
fig = None
for i, (samples, ls) in enumerate(zip((og, re), ('-', '--'))):
    fig = corner(
        samples.T, labels = ('mf', 'af', 'vf'), fig = fig,
        plot_datapoints = False, plot_density = False,
        plot_contours = True, fill_contours = False, no_fill_contours = True,
        hist_kwargs = dict(density = True, color = f'C{i}', ls = ls),
        contour_kwargs = dict(colors = [f'C{i}'], linestyles = ls),
    )

Next, we'll combine the 1G population with the remnant fits to produce a population of 1G+2G black-hole mergers.

For the 2G black holes, we will assume that all mergers occur in an environment with escape speed `v_esc`. Remnants that receive a kick larger than `v_esc` are "ejected" and do not lead to 1G+2G mergers.

The 2G black holes that are retained will pair with 1G black holes, again according to a power law in the total mass, this time with slope `gamma`.

In [None]:
def remnant(m1, q, a1, a2, c1, c2, dp, th):
    t1, t2 = jnp.arccos(c1), jnp.arccos(c2)
    mf = remnant_mass(t1, t2, q, a1, a2) * m1 * (1 + q)
    af = remnant_spin(t1, t2, dp, q, a1, a2)
    vf = remnant_kick(th, t1, t2, dp, q, a1, a2)
    return mf, af, vf

In [None]:
def sample_2g(key, mf, af, vf, parameters):
    n = mf.shape[0]

    # 1g mass
    key, subkey = jax.random.split(key)
    m = sample_m(subkey, (n,), parameters)

    # ejection and pairing: only retained remnants (vf < v_esc) can pair,
    # with weights given by a power law in the total mass
    key, subkey = jax.random.split(key)
    pair = (vf < parameters['v_esc']) * (mf + m)**parameters['gamma']
    idxs = jax.random.choice(subkey, n, shape = (n,), p = pair)
    mf, af, m = mf[idxs], af[idxs], m[idxs]

    # masses
    m_2g = jnp.stack([mf, m])
    sort = jnp.argsort(m_2g, axis = 0)
    m2_2g, m1_2g = jnp.take_along_axis(m_2g, sort, axis = 0)
    q_2g = m2_2g / m1_2g

    # spin magnitudes
    key, subkey = jax.random.split(key)
    a = sample_truncated_normal(
        subkey, (n,), parameters['mu_a'], parameters['sigma_a'], 0, 1,
    )
    a2_2g, a1_2g = jnp.take_along_axis(jnp.stack([af, a]), sort, axis = 0)

    # spin tilts
    key, subkey = jax.random.split(key)
    c1_2g, c2_2g = jax.random.uniform(subkey, (2, n), minval = -1, maxval = 1)

    # azimuthal spin
    key, subkey = jax.random.split(key)
    dp_2g = jax.random.uniform(subkey, (n,), minval = 0, maxval = 2 * jnp.pi)

    return jnp.array([m1_2g, q_2g, a1_2g, a2_2g, c1_2g, c2_2g, dp_2g])

Now we combine these ingredients to sample the overall population.

The last assumption we'll make is that the overall fraction of 1G+2G mergers is proportional to the fraction of 2G black holes that are retained (i.e., with kicks smaller than `v_esc`), with constant of proportionality `r_2g`.
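In concrete terms, the mixing step in `sample_mergers` first computes $f = r_{2g} \times$ (fraction of remnant kicks below `v_esc`). Then, independently for each merger, it keeps the 1G+2G sample with probability $f$ and the 1G+1G sample otherwise. The NumPy sketch below uses hypothetical kick values and string labels as placeholders for the two populations:

```python
import numpy as np

rng = np.random.default_rng(2)

v_esc, r_2g = 200.0, 0.5
vf = np.array([50.0, 150.0, 250.0, 400.0])  # hypothetical remnant kicks in km/s

retained = (vf < v_esc).mean()  # fraction of 2G remnants that stay in the host
f = retained * r_2g             # probability that a given merger is 1G+2G
print(f)  # 0.25

# per-merger Bernoulli selection between the two populations
n = 100_000
b = rng.binomial(n=1, p=f, size=n)
population = np.where(b, '1G+2G', '1G+1G')
print((population == '1G+2G').mean())  # close to f
```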

As we assume that the spin directions are isotropic, we will only model the masses and spin magnitudes directly.

In [None]:
dim = 4

In [None]:
def sample_mergers(key, n, parameters):
    # 1g+1g mergers
    key, subkey = jax.random.split(key)
    mergers = sample_1g(subkey, n, parameters)

    # 1g+1g remnants (2g)
    key, subkey = jax.random.split(key)
    th = jax.random.uniform(subkey, (n,), minval = 0, maxval = 2 * jnp.pi)
    mf, af, vf = jax.vmap(remnant)(*mergers, th)

    # 1g+2g mergers
    key, subkey = jax.random.split(key)
    mergers_2g = sample_2g(subkey, mf, af, vf, parameters)

    # mixing
    key, subkey = jax.random.split(key)
    f = (vf < parameters['v_esc']).mean() * parameters['r_2g']
    b = jax.random.binomial(subkey, n = 1, p = f, shape = (n,))
    mergers = jnp.where(b, mergers_2g, mergers)

    # we'll just fit masses and spin magnitudes as the other angles are
    # mostly independent
    return mergers[:4].T

Let's check how our simulated model looks on an example.

Is this simulated model realistic? No. Will it do for this tutorial? Probably.

In [None]:
parameters = dict(
    alpha = -2.5,
    m_min = 5,
    d_min = 5,
    mu_m = 35,
    sigma_m = 3,
    f_m = 0.1,
    mu_a = 0.1,
    sigma_a = 0.2,
    beta = 0,
    v_esc = 200,
    gamma = 0,
    r_2g = 0.5,
)

In [None]:
mergers = sample_mergers(jax.random.key(0), 100_000, parameters)

In [None]:
fig = corner(np.array(mergers), bins = 50, labels = ('m1', 'q', 'a1', 'a2'));

#### Priors

Moving on, the next thing we need is a prior distribution of model parameters over which we will sample when training the normalizing flow.

In [None]:
import numpyro

In [None]:
priors = dict(
    alpha = numpyro.distributions.Uniform(-5, 0),
    m_min = numpyro.distributions.Uniform(2, 10),
    d_min = numpyro.distributions.Uniform(0, 10),
    mu_m = numpyro.distributions.Uniform(20, 50),
    sigma_m = numpyro.distributions.Uniform(1, 5),
    f_m = numpyro.distributions.Uniform(0, 0.2),
    mu_a = numpyro.distributions.Uniform(0, 0.5),
    sigma_a = numpyro.distributions.Uniform(0.1, 0.5),
    beta = numpyro.distributions.Uniform(-5, 5),
    v_esc = numpyro.distributions.Uniform(0, 1_000),
    gamma = numpyro.distributions.Uniform(-5, 5),
    r_2g = numpyro.distributions.Uniform(0, 0.5),
)

In [None]:
cond_dim = len(priors)

In [None]:
def sample_parameters(key):
    keys = jax.random.split(key, len(priors))
    return {k: priors[k].sample(key) for k, key in zip(priors, keys)}

#### Normalizing flow

Now let's set up the model that we'll train. We'll use a [block neural autoregressive flow](https://arxiv.org/abs/1904.04676) to approximate the simulations. There's a nice library called [flowjax](https://github.com/danielward27/flowjax) to do normalizing flows in JAX that we'll use. This is built on top of [equinox](https://github.com/patrick-kidger/equinox), which will handle our neural networks.

If you aren't familiar with normalizing flows, the (very brief) idea is that you can construct a probability distribution by transforming a simple known distribution (such as a standard normal distribution) with an invertible and differentiable function using the change-of-variables formula. For normalizing flows, that function is parametrized by a neural network, which is what makes the transformation flexible.
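Here is a concrete, non-neural instance of the change-of-variables formula. Take a standard normal base variable $z$ and an affine map $x = f(z) = \mu + \sigma z$. Then $\log p_X(x) = \log p_Z(f^{-1}(x)) + \log|\mathrm{d}f^{-1}/\mathrm{d}x| = \log p_Z((x - \mu)/\sigma) - \log\sigma$, which should equal the log-density of $\mathcal{N}(\mu, \sigma^2)$. The NumPy sketch below checks this. The values of $\mu$ and $\sigma$ are arbitrary; in a normalizing flow, the affine map is replaced by a learned (and here conditioned) neural transformation:

```python
import numpy as np

def log_standard_normal(z):
    return -0.5 * z**2 - 0.5 * np.log(2 * np.pi)

def affine_flow_log_prob(x, mu, sigma):
    # change of variables: invert x = mu + sigma * z, then add the log-Jacobian of the inverse
    z = (x - mu) / sigma
    log_det_inverse_jacobian = -np.log(sigma)
    return log_standard_normal(z) + log_det_inverse_jacobian

def gaussian_log_pdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma)**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

x = np.linspace(-3, 7, 11)
print(np.allclose(affine_flow_log_prob(x, 2.0, 1.5), gaussian_log_pdf(x, 2.0, 1.5)))  # True
```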

The transformation is trained so that the output distribution best matches some target distribution, in our case the simulation-based population model.

In [None]:
import equinox
from flowjax.distributions import StandardNormal
from flowjax.flows import block_neural_autoregressive_flow

In [None]:
flow_init = block_neural_autoregressive_flow(
    key = jax.random.key(1),
    base_dist = StandardNormal(shape = (dim,)),
    cond_dim = cond_dim,
    invert = True,
)

Equinox handles the trainable parameters by splitting pytrees (nested Python containers of JAX arrays) into trainable and non-trainable subsets.

In [None]:
params_init, static = equinox.partition(flow_init, equinox.is_inexact_array)

We can easily get the number of parameters (weights and biases of the neural networks) in our model.

In [None]:
import jax.flatten_util
array_init, unravel = jax.flatten_util.ravel_pytree(params_init)
array_init.size

We should take care that our normalizing flow is defined on the domain we want. In particular, our model restricts the range of values that the merger parameters can take (e.g., mass ratios and spin magnitudes lie in [0, 1]). We'll therefore add some additional transformations to ensure that those bounds are respected. Unlike the flow transformations, these transformations are fixed and not trainable. We clip the inputs because the bounding transformations are asymptotic: the logit maps the boundary values to infinity, so extremal values, e.g., exactly equal masses ($q = 1$), are ill defined.

In [None]:
lo = jnp.array([2, 0, 0, 0])
hi = jnp.array([200, 1, 1, 1])

def inverse_mergers(mergers):
    x = (mergers - lo) / (hi - lo)
    x = jnp.clip(x, 1e-5, 1 - 1e-5)
    return jax.scipy.special.logit(x)

def forward_mergers(x):
    return jax.nn.sigmoid(x) * (hi - lo) + lo
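The NumPy sketch below mirrors the same transformation pair and shows why the clipping matters. A value exactly on the boundary, such as $q = 1$ or $a = 0$, would map to $\pm\infty$ under the logit. Clipped values stay finite, and the sigmoid maps them back to within about $10^{-5}$ of the original scaled value. The merger values are made up for illustration:

```python
import numpy as np

lo = np.array([2.0, 0.0, 0.0, 0.0])    # lower bounds for (m1, q, a1, a2), as above
hi = np.array([200.0, 1.0, 1.0, 1.0])  # upper bounds

def logit(p):
    return np.log(p) - np.log1p(-p)

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def inverse_mergers(mergers, eps=1e-5):
    x = (mergers - lo) / (hi - lo)
    x = np.clip(x, eps, 1 - eps)  # keep boundary values such as q = 1 finite
    return logit(x)

def forward_mergers(x):
    return sigmoid(x) * (hi - lo) + lo

mergers = np.array([[30.0, 1.0, 0.0, 0.3],    # equal masses, non-spinning primary
                    [80.0, 0.5, 0.7, 0.1]])   # an interior point
x = inverse_mergers(mergers)
print(np.isfinite(x).all())  # True
print(np.abs(forward_mergers(x) - mergers).max())  # small: only clipped entries move
```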

Though not strictly necessary, we'll also do something similar for the population parameters, so that all inputs to the flow have roughly the same range of numerical values.

In [None]:
def inverse_parameters(parameters):
    x = jnp.array([
        (parameters[k] - priors[k].low) / (priors[k].high - priors[k].low)
        for k in priors
    ])
    # x = jnp.clip(x, 1e-5, 1 - 1e-5)
    x = jax.scipy.stats.norm.ppf(x)
    return x
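The idea behind `inverse_parameters` is as follows. Rescaling a uniformly distributed parameter to $[0, 1]$ and then applying the standard normal quantile function (the probit) turns the uniform prior into a standard normal one, so all conditioning inputs are of order unity. The prior endpoints map to $\pm\infty$, which is what the commented-out clip would guard against. Note the difference between the two implementations at those endpoints: `jax.scipy.stats.norm.ppf` returns infinities, whereas the standard-library `statistics.NormalDist.inv_cdf` used in this sketch raises an error. The sketch uses the `alpha` prior bounds from above:

```python
import random
from statistics import NormalDist

low, high = -5.0, 0.0  # bounds of the Uniform prior on alpha

def probit_transform(value):
    u = (value - low) / (high - low)  # rescale to [0, 1]
    return NormalDist().inv_cdf(u)    # standard normal quantile

print(probit_transform(-2.5))  # the prior midpoint maps to 0.0

# uniform prior draws become approximately standard normal
random.seed(0)
z = [probit_transform(random.uniform(low, high)) for _ in range(100_000)]
mean = sum(z) / len(z)
var = sum((v - mean)**2 for v in z) / len(z)
print(mean, var)  # close to 0 and 1
```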

#### Training

On to training the model.

To train the flow, we need to define a loss function to minimize with respect to the neural-network parameters. The training objective for flows can be thought of in several equivalent ways. These include minimizing the Kullback-Leibler divergence from the simulated distribution to the flow, minimizing the cross entropy, and maximizing the likelihood of the training data. The upshot is that the probability density predicted by the flow is fitted to the empirical distribution of the simulations.
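In symbols, let $\Lambda$ denote the population parameters drawn from the prior $\pi(\Lambda)$, $p(x \mid \Lambda)$ the simulator's distribution of merger parameters, and $q_\phi(x \mid \Lambda)$ the flow with network weights $\phi$. The loss that `loss_fn` estimates is then given below. Each training step forms a Monte Carlo estimate with a single $\Lambda$ and `batch_size` mergers:

```latex
\begin{aligned}
\mathcal{L}(\phi)
&= -\,\mathbb{E}_{\Lambda \sim \pi}\,\mathbb{E}_{x \sim p(x \mid \Lambda)}\big[\log q_\phi(x \mid \Lambda)\big] \\
&= \mathbb{E}_{\Lambda \sim \pi}\Big[ D_\mathrm{KL}\big(p(\cdot \mid \Lambda) \,\|\, q_\phi(\cdot \mid \Lambda)\big) + H\big[p(\cdot \mid \Lambda)\big] \Big].
\end{aligned}
```

The entropy $H$ of the simulator does not depend on $\phi$. Minimizing this cross entropy (the negative log-likelihood) is therefore the same as minimizing the prior-averaged KL divergence. Because the flow is trained on the logit-transformed mergers, the fixed bounding map contributes only a $\phi$-independent Jacobian term, so the minimizer is unchanged.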

Below, we define the loss function and training loop to do this.

First, let's set some training parameters. We'll use [optax](https://github.com/google-deepmind/optax) to update the neural network.

In [None]:
import optax

In [None]:
batch_size = 1_000
steps = 10_000
learning_rate = 1e-2
learning_rate = optax.cosine_decay_schedule(learning_rate, steps)
optimizer = optax.adam(learning_rate)

*Note that at each training step we show the flow a single simulation, i.e., one draw of `parameters` that produces `batch_size` mergers.*

In [None]:
# wrap the sampling and bounding functions together
def sample(key):
    key, subkey = jax.random.split(key)
    parameters = sample_parameters(subkey)
    c = inverse_parameters(parameters)
    key, subkey = jax.random.split(key)
    mergers = sample_mergers(subkey, batch_size, parameters)
    x = inverse_mergers(mergers)
    return x, c

# the negative mean log-likelihood of mergers from a single simulation
def loss_fn(params, key):
    flow = equinox.combine(params, static)
    x, c = sample(key)
    return -flow.log_prob(x, c).mean()

In [None]:
# if things go wrong and the loss becomes non-finite,
# then revert to the best parameters and optimizer state found so far
def check_finite(carry, loss):
    key, params, state, best_params, best_state, best_loss = carry
    params, state = jax.lax.cond(
        jnp.isfinite(loss),
        lambda: (params, state),
        lambda: (best_params, best_state),
    )
    return key, params, state, best_params, best_state, best_loss

# keep track of the flow with the best loss value
def check_best(carry, loss):
    key, params, state, best_params, best_state, best_loss = carry
    best_params, best_state, best_loss = jax.lax.cond(
        loss < best_loss,
        lambda: (params, state, loss),
        lambda: (best_params, best_state, best_loss),
    )
    return key, params, state, best_params, best_state, best_loss

# function to update the neural-network parameters
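def update(carry, step):
    key, params, state, best_params, best_state, best_loss = carry
    key, subkey = jax.random.split(key)
    loss, grad = equinox.filter_value_and_grad(loss_fn)(params, subkey)
    # `loss` was evaluated at the current (pre-update) parameters,
    # so compare it against the best loss before applying the update
    carry = key, params, state, best_params, best_state, best_loss
    key, params, state, best_params, best_state, best_loss = check_best(carry, loss)
    updates, state = optimizer.update(grad, state, params)
    params = equinox.apply_updates(params, updates)
    carry = key, params, state, best_params, best_state, best_loss
    carry = check_finite(carry, loss)
    return carry, loss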

Another handy package is [jax_tqdm](https://github.com/jeremiecoullon/jax-tqdm), which adds a progress bar to JAX loops.

In [None]:
import jax_tqdm

In [None]:
update = jax_tqdm.scan_tqdm(steps, print_rate = 100, tqdm_type = 'std')(update)

Now let's train it!

In [None]:
# finally, the training loop (using jax.lax.scan instead of a Python loop)
state = optimizer.init(params_init)
carry = (
    jax.random.key(2),
    params_init,
    state,
    params_init,
    state,
    jnp.inf,
)
carry, losses = jax.lax.scan(update, carry, jnp.arange(steps))
key, params, state, best_params, best_state, best_loss = carry
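`jax.lax.scan` takes the place of a Python `for` loop. It threads the `carry` tuple through `update` once per element of `jnp.arange(steps)` and stacks the per-step outputs (here, the losses), while compiling the loop body only once. Its semantics are those of the plain-Python sketch below, which is adapted from the reference pseudocode in the JAX documentation. The toy `step` function is only an illustration:

```python
import numpy as np

def scan(f, init, xs):
    # pure-Python reference semantics of jax.lax.scan
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

# toy stand-in for `update`: the carry is a running total and the output is the new total
def step(total, x):
    total = total + x
    return total, total

final, history = scan(step, 0, np.arange(5))
print(final, history)  # final is 10; history is [0, 1, 3, 6, 10]
```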

In [None]:
# the loss values over training steps
plt.plot(losses);
plt.axvline(losses.argmin(), c = 'r');

We can reconstruct the flow at both the last training step and the step with the best loss (but note that, due to sampling noise, the lowest loss value is not necessarily the one producing the "best" trained model).

In [None]:
last_flow = equinox.combine(params, static)
best_flow = equinox.combine(best_params, static)

How did the training do? Let's compare the trained flow to the actual simulated model.

In [None]:
from corner import corner

In [None]:
# # to select particular population parameters, uncomment this
# parameters = dict(
#     alpha = -2.5,
#     m_min = 5,
#     d_min = 5,
#     mu_m = 35,
#     sigma_m = 3,
#     f_m = 0.1,
#     mu_a = 0.1,
#     sigma_a = 0.2,
#     beta = 0,
#     v_esc = 200,
#     gamma = 0,
#     r_2g = 0.2,
# )

# to select random population parameters (comment this out if using the dict above)
parameters = sample_parameters(jax.random.key(np.random.randint(1_000_000_000)))

c = inverse_parameters(parameters)

mergers = sample_mergers(jax.random.key(4), 10_000, parameters)

x = last_flow.sample(jax.random.key(5), (10_000,), condition = c)
last_flow_mergers = forward_mergers(x)

x = best_flow.sample(jax.random.key(6), (10_000,), condition = c)
best_flow_mergers = forward_mergers(x)

In [None]:
parameters

In [None]:
lim = np.max([
    mergers[:, 0].max(),
    last_flow_mergers[:, 0].max(),
    best_flow_mergers[:, 0].max(),
])
bounds = np.transpose([lo, hi]).astype(float)  # float, so the mass limit set below is not truncated
bounds[0, 1] = lim

fig = None
for i, samples in enumerate((mergers, last_flow_mergers, best_flow_mergers)):
    fig = corner(
        np.array(samples), labels = ('m1', 'q', 'a1', 'a2'),
        fig = fig, plot_datapoints = False, plot_density = False,
        plot_contours = True, fill_contours = False, no_fill_contours = True,
        hist_kwargs = dict(density = True, color = f'C{i}'),
        contour_kwargs = dict(colors = [f'C{i}']),
        range = bounds,
        levels = (0.5, 0.9, 0.99), bins = 20, smooth = 0.5,
    )

for i in range(dim):
    ax = fig.axes[i + dim * i]
    ylim = max(patch.xy[:, 1].max() for patch in ax.patches)
    ax.set_ylim(0, ylim * 1.1)

fig.axes[1].legend(
    handles = fig.axes[0].patches,
    labels = ('simulation', 'last flow', 'best flow'),
    fontsize = 20,
    loc = 'upper left',
);

#### Homework

Now that you've trained the population model, what next?
- Combine the implementation of the likelihood and sampling algorithms from other notebooks to infer the population parameters of our model.
- Run the original simulated model and the trained flow using the population inference results to see how predictions compare.

We should add a bit more realism to our model:
- Add a model for the evolution of the merger rate over redshift.
- Account for differing time delays and redshift evolution for first- and second-generation mergers.
- Allow for a subpopulation of binaries in which both black holes are second generation or of higher generation.
- Allow for subpopulations from other channels with distinct first-generation properties, e.g., isolated evolution where second-generation mergers don't occur.
- Improve the prescriptions for binary pairing.
- Include a more realistic model for the host environments and their escape speeds.

There is some more immediate tinkering you can do with the training code above:
- The training settings, e.g., batch size, number of training steps, learning rate, optimizer, etc. For example, try a different [learning-rate scheduler](https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html) from the cosine decay used above.
- The [flow settings](https://danielward27.github.io/flowjax/api/flows.html#flowjax.flows.block_neural_autoregressive_flow), e.g., try making the network smaller or larger.
- The flow itself, i.e., try [a different type](https://danielward27.github.io/flowjax/api/flows.html) of normalizing flow.