## Download GWTC-3 parameter-estimation samples and sensitivity injections

In [None]:
# !wget https://zenodo.org/records/11254021/files/analyses_PowerLawPeak.tar.gz
# !tar -xzvf analyses_PowerLawPeak.tar.gz
# !mv ./analyses/PowerLawPeak/o1o2o3_mass_c_iid_mag_iid_tilt_powerlaw_redshift_samples.hdf5 .
# !rm -r analyses
# !wget https://zenodo.org/records/7890398/files/o1+o2+o3_bbhpop_real+semianalytic-LIGO-T2100377-v2.hdf5

In [1]:
ls

[0m[01;31manalyses_PowerLawPeak.tar.gz[0m
o1+o2+o3_bbhpop_real+semianalytic-LIGO-T2100377-v2.hdf5
o1o2o3_mass_c_iid_mag_iid_tilt_powerlaw_redshift_samples.hdf5
variational.ipynb


## Imports

In [2]:
import numpy as np
import pandas
import h5py
import json
import matplotlib.pyplot as plt
from corner import corner

In [3]:
import os
# Restrict CUDA to device index 1; this is set here, ahead of the jax import below.
# NOTE(review): the GPU index is hard-coded — adjust it for the machine at hand.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import jax
# Turn on JAX's 64-bit mode.
jax.config.update('jax_enable_x64', True)

import gwpopulation
# Select the JAX backend for gwpopulation. `xp` is imported only afterwards,
# presumably so that it resolves to the JAX array module — TODO confirm.
gwpopulation.set_backend('jax')
from gwpopulation.utils import xp

## Load the data

In [4]:
def load_posteriors():
    """Read the per-event posterior samples from the population-analysis file.

    The file ``./o1o2o3_mass_c_iid_mag_iid_tilt_powerlaw_redshift_samples.hdf5``
    is opened read-only. Every dataset in its ``original`` group is indexed
    as ``[event, :]``; row ``event`` of each dataset becomes one column of
    that event's table, with the dataset name as the column name.

    Returns
    -------
    list of pandas.DataFrame
        One frame per event, in the order the events are stored in the file.
    """
    filename = './o1o2o3_mass_c_iid_mag_iid_tilt_powerlaw_redshift_samples.hdf5'

    with h5py.File(filename, 'r') as f:
        group = f['original']
        # `a_1` must be two-dimensional; its first axis sets the event count.
        num_events, _ = group['a_1'].shape

        return [
            pandas.DataFrame({name: group[name][event, :] for name in group})
            for event in range(num_events)
        ]

In [5]:
def load_injections():
    """Load the found sensitivity injections in the population-model parameters.

    Reads the ``injections`` group of
    ``./o1+o2+o3_bbhpop_real+semianalytic-LIGO-T2100377-v2.hdf5`` and keeps an
    injection when

    * its ``name`` is ``'o3'`` and the largest value over all ``ifar``
      datasets (those whose key contains ``cwb`` are ignored) exceeds
      ``1 / 1.0``, or
    * its ``name`` is anything else and ``optimal_snr_net`` exceeds ``10.0``.

    Injections with redshift above 1.9 are then dropped.

    Returns
    -------
    dict
        ``mass_1``, ``mass_ratio``, ``a_1``, ``a_2``, ``cos_tilt_1``,
        ``cos_tilt_2``, ``redshift`` and ``prior`` arrays plus the
        ``total_generated`` file attribute, each passed through ``xp.array``.
        ``prior`` is the file's ``sampling_pdf`` multiplied by
        ``mass_1 * 4 * pi**2 * a_1**2 * a_2**2``.
    """
    far_threshold = 1.0
    snr_threshold = 10.0
    filename = './o1+o2+o3_bbhpop_real+semianalytic-LIGO-T2100377-v2.hdf5'

    with h5py.File(filename, 'r') as f:
        group = f['injections']

        # Element-wise maximum over every `ifar` dataset except the cwb ones.
        ifar = np.max(
            [
                group[key][:]
                for key in group.keys()
                if 'ifar' in key and 'cwb' not in key
            ],
            axis=0,
        )
        snr = group['optimal_snr_net'][:]
        runs = group['name'][:].astype(str)
        found = np.where(
            runs == 'o3', ifar > 1 / far_threshold, snr > snr_threshold,
        )

        def read_found(key):
            # Load the full dataset, then keep only the found injections.
            return group[key][:][found]

        mass_1 = read_found('mass1_source')
        mass_2 = read_found('mass2_source')
        spin1x, spin1y, spin1z = (
            read_found(key) for key in ('spin1x', 'spin1y', 'spin1z')
        )
        spin2x, spin2y, spin2z = (
            read_found(key) for key in ('spin2x', 'spin2y', 'spin2z')
        )
        redshift = read_found('redshift')

        # Spin magnitudes from the Cartesian components.
        a_1 = (spin1x**2 + spin1y**2 + spin1z**2)**.5
        a_2 = (spin2x**2 + spin2y**2 + spin2z**2)**.5

        # Rescale the sampling density in place (keeps the stored dtype).
        prior = read_found('sampling_pdf')
        prior *= mass_1 * 4 * np.pi**2 * a_1**2 * a_2**2

        injections = {
            'mass_1': mass_1,
            'mass_ratio': mass_2 / mass_1,
            'a_1': a_1,
            'a_2': a_2,
            'cos_tilt_1': spin1z / a_1,
            'cos_tilt_2': spin2z / a_2,
            'redshift': redshift,
            'prior': prior,
        }

        # Discard everything beyond redshift 1.9.
        keep = ~(redshift > 1.9)
        injections = {key: values[keep] for key, values in injections.items()}

        injections['total_generated'] = group.attrs['total_generated']

    return {key: xp.array(values) for key, values in injections.items()}

In [6]:
# Per-event posterior tables and the found-injection set, read from the
# HDF5 files in the working directory.
posteriors = load_posteriors()
injections = load_injections()

2024-07-15 19:59:33.527823: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


## Define the population model

In [7]:
from gwpopulation.models.mass import SinglePeakSmoothedMassDistribution
from gwpopulation.models.spin import iid_spin
from gwpopulation.models.redshift import PowerLawRedshift
from gwpopulation.experimental.jax import NonCachingModel

In [8]:
def spin_model(dataset, xi_spin, sigma_spin, alpha_chi, beta_chi):
    """Evaluate ``iid_spin`` with its fourth positional argument held at 1.0.

    The fixed argument is the one the notebook calls ``amax``; all other
    arguments are forwarded unchanged and the result of ``iid_spin`` is
    returned as-is.
    """
    return iid_spin(dataset, xi_spin, sigma_spin, 1.0, alpha_chi, beta_chi)

In [9]:
def make_model():
    """Build a fresh ``NonCachingModel`` over the mass, spin and redshift models.

    The components, in order, are a new ``SinglePeakSmoothedMassDistribution``,
    the ``spin_model`` function, and a new ``PowerLawRedshift`` constructed
    with ``cosmo_model='Planck15'``.
    """
    mass_component = SinglePeakSmoothedMassDistribution()
    redshift_component = PowerLawRedshift(cosmo_model = 'Planck15')
    return NonCachingModel([mass_component, spin_model, redshift_component])

## Define the population likelihood

In [10]:
from gwpopulation.vt import ResamplingVT
from gwpopulation.hyperpe import HyperparameterLikelihood

In [11]:
# Selection-function term built from the found injections. It gets its own
# model instance from make_model() rather than sharing one with the likelihood.
# n_events is the number of per-event posterior tables.
selection_function = ResamplingVT(
    model = make_model(),
    data = injections,
    n_events = len(posteriors),
    marginalize_uncertainty = False,
    enforce_convergence = False,
)

In [12]:
# Population likelihood over all event posteriors, using a second fresh model
# instance as the hyper-prior and the selection function defined above.
# maximum_uncertainty is set to infinity — presumably this disables any
# variance cut inside the likelihood; TODO confirm against gwpopulation.
likelihood = HyperparameterLikelihood(
    posteriors = posteriors,
    hyper_prior = make_model(),
    selection_function = selection_function,
    maximum_uncertainty = xp.inf,
)

## Variational inference

In [16]:
from gwax.flows import default_flow

In [18]:
# Build a flow with gwax's default_flow from a fixed PRNG key and the
# [low, high] interval of each hyperparameter.
# NOTE(review): `prior_bounds` is only defined further down in this notebook
# (cell In [14]); on a fresh top-to-bottom run this cell would hit a
# NameError unless that cell is moved above this one.
flow = default_flow(jax.random.PRNGKey(0), prior_bounds.values())

In [25]:
import equinox
from flowjax.wrappers import NonTrainable

In [26]:
# Split the flow into `params` (leaves selected by equinox.is_inexact_array)
# and `static` (everything else). NonTrainable wrappers are treated as leaves,
# presumably so whatever they wrap ends up in `static` — TODO confirm.
params, static = equinox.partition(
    pytree = flow,
    filter_spec = equinox.is_inexact_array,
    is_leaf = lambda leaf: isinstance(leaf, NonTrainable),
)

In [28]:
# Sanity check: recombine the two halves and call sample_and_log_prob with a
# fixed key; the recorded output is one 13-element sample and a scalar.
equinox.combine(params, static).sample_and_log_prob(jax.random.PRNGKey(0))

(Array([ 3.52449962, -1.39407437,  6.18543927, 68.38022515,  0.70118079,
        28.98300705,  1.59406214,  6.25878098,  8.59514782,  1.95192068,
         0.34938759,  3.20558454, -3.26842975], dtype=float64),
 Array(-19.21724187, dtype=float64))

In [13]:
from gwax.variational import trainer

In [14]:
# [low, high] interval for each of the 13 population hyperparameters, passed
# to the trainer below.
# NOTE(review): an earlier cell (`default_flow(..., prior_bounds.values())`)
# also reads this dict, so this definition needs to run before it.
prior_bounds = dict(
    alpha = [-4, 12],
    beta = [-2, 7],
    mmin = [0, 6.5],
    mmax = [65, 100],
    lam = [0, 1],
    mpp = [20, 50],
    sigpp = [1, 10],
    delta_m = [0, 10],
    alpha_chi = [1, 10],
    beta_chi = [1, 10],
    xi_spin = [0, 1],
    sigma_spin = [0.1, 4],
    lamb = [-6, 6],
)

In [15]:
# Run gwax's variational trainer against the population likelihood. Passing
# flow=None / optimizer=None / taper=None / temper_schedule=None presumably
# selects the trainer's defaults — TODO confirm against gwax.
# NOTE(review): the recorded output of this cell is
# "AttributeError: DynamicJaxprTracer has no attribute sample_and_log_prob",
# i.e. the call does not complete as written; the cause is inside gwax and
# not visible from this notebook.
# NOTE(review): on success this rebinds `flow`, which another cell also assigns.
flow, losses = trainer(
    key = jax.random.PRNGKey(42),
    prior_bounds = prior_bounds,
    likelihood = likelihood,
    flow = None,
    batch_size = 1,
    steps = 1_000,
    learning_rate = 1e-2,
    optimizer = None,
    taper = None,
    temper_schedule = None,
    print_rate = 1,
)

AttributeError: DynamicJaxprTracer has no attribute sample_and_log_prob