<a href="https://colab.research.google.com/github/ming-256/GPU-Accelerated-Bayesian-Inference-of-Gravitational-Waves/blob/main/gpu_gw_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install git+https://github.com/handley-lab/blackjax
!pip install anesthetic tqdm astropy
!pip install git+https://github.com/ming-256/jim
!pip install jax[cuda12_pip]==0.4.31 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
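# Note: installation may take 2-3 minutes in Google Colab.
# optax and flowMC are imported in the next cell but are not installed by any cell of this
# notebook; they are assumed to be preinstalled in Colab or pulled in as dependencies of the
# packages above (verify on a fresh runtime).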

Collecting git+https://github.com/handley-lab/blackjax
  Cloning https://github.com/handley-lab/blackjax to /tmp/pip-req-build-grw12m59
  Running command git clone --filter=blob:none --quiet https://github.com/handley-lab/blackjax /tmp/pip-req-build-grw12m59
  Resolved https://github.com/handley-lab/blackjax to commit 171ff14d6f319a7f277396e0d7b772e8ac15a664
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting jaxopt>=0.8 (from blackjax==0.1.dev706+g171ff14)
  Downloading jaxopt-0.8.5-py3-none-any.whl.metadata (3.3 kB)
Downloading jaxopt-0.8.5-py3-none-any.whl (172 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.4/172.4 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: blackjax
  Building wheel for blackjax (pyproject.toml) ... [?25l[?25hdone
  Created wheel for blackjax: filename=blackjax-0.1.

In [2]:
import jax
jax.config.update('jax_enable_x64', True)
import blackjax
import blackjax.ns.adaptive
import matplotlib.pyplot as plt
import time
import jax.scipy.stats as stats
import jax.numpy as jnp
import numpy as np
import tqdm
import anesthetic
from anesthetic import NestedSamples
from functools import partial
from blackjax.ns.adaptive import build_kernel, init
from blackjax.ns.base import new_state_and_info, delete_fn
from blackjax.ns.utils import repeat_kernel, finalise
from blackjax.mcmc.random_walk import build_rmh, RWState
from blackjax import SamplingAlgorithm
from astropy.time import Time
from jimgw.single_event.detector import Detector, H1, L1, V1
from jimgw.single_event.likelihood import original_relative_binning_likelihood as relative_binning_likelihood_function
from jimgw.single_event.waveform import RippleIMRPhenomD_NRTidalv2
from flowMC.strategy.optimization import optimization_Adam
from jaxtyping import Array, Float
import numpy.typing as npt
from scipy.interpolate import interp1d
import optax

In [3]:
# Run label; used later as the stem of the CSV file the samples are written to.
label = 'Test'

# | Define LIGO event data
# GPS time handed to `load_data` and to astropy's `Time(..., format="gps")` below.
gps = 1187008882.43
# Frequency band passed to `load_data`; `fmin` is also the waveform reference frequency.
# (presumably Hz — confirm against the jim `Detector.load_data` signature)
fmin = 23.0
fmax = 2048.0
# Segment layout: `duration - post_trigger_duration` and `post_trigger_duration` are the
# two extents passed to `load_data` (presumably seconds before/after the trigger — confirm).
duration = 128
post_trigger_duration = 2
# NOTE(review): `end_time` and `start_time` are not read anywhere else in the visible notebook.
end_time = gps + post_trigger_duration
start_time = end_time - duration
# Tukey window shape parameter derived from the roll-off and the segment duration.
roll_off = 0.4
tukey_alpha = 2 * roll_off / duration
# PSD estimation settings forwarded to `load_data`.
psd_pad = 16
psd_duration = 1024
detectors = [H1, L1, V1]

# Fetch strain and PSD data for every detector (network access; see the cell output).
for det in detectors:
    det.load_data(
        gps,
        duration - post_trigger_duration,
        post_trigger_duration,
        fmin,
        fmax,
        psd_pad=psd_pad,
        psd_duration=psd_duration,
        tukey_alpha=tukey_alpha,
        gwpy_kwargs={"cache": True, "version": 2}
        )

# Waveform model; individual tidal parameters (lambda_1, lambda_2) are used rather than
# the lambda-tilde parameterisation.
waveform = RippleIMRPhenomD_NRTidalv2(
    f_ref=fmin,
    use_lambda_tildes=False,
    no_taper=False
    )

# Frequency grid taken from H1 (assumes all detectors share the same grid — TODO confirm).
frequencies = H1.frequencies
# Offset added to t_c in the time-alignment phase factors of the likelihood.
epoch = duration - post_trigger_duration
# Greenwich apparent sidereal time (radians) at the trigger time.
gmst = Time(gps, format="gps").sidereal_time("apparent", "greenwich").rad

Fetching data from H1...
Fetching PSD data...
Finished loading data.
Fetching data from L1...
Fetching PSD data...
Finished loading data.
Fetching data from V1...
Fetching PSD data...
Finished loading data.


In [4]:
# Define the parameters class
class ParameterPrior:
    """A named model parameter bundled with its log-prior density.

    Attributes:
        name: Key used to look the parameter up in parameter dictionaries.
        label: Display label for the parameter (LaTeX strings in this notebook).
        prior_fn: Callable invoked as ``prior_fn(value, *args)``.
        args: Tuple of extra positional arguments forwarded to ``prior_fn``.
    """

    def __init__(self, name: str, label: str, prior_fn: callable, *args):
        self.name, self.label = name, label
        self.prior_fn, self.args = prior_fn, args

    def logprob(self, value: float) -> float:
        """Evaluate the stored prior function at ``value`` with the stored extra arguments."""
        density_fn, extra_args = self.prior_fn, self.args
        return density_fn(value, *extra_args)

# Define the prior functions
@jax.jit
def UniformPrior(x: float, min: float, max: float) -> float:
    """Log-density of the uniform distribution with lower bound ``min`` and upper bound ``max``.

    Returns ``-log(max - min)`` inside the interval and ``-inf`` outside it.
    """
    width = max - min
    return stats.uniform.logpdf(x, loc=min, scale=width)

@jax.jit
def SinPrior(x):
    """Log-density ``log(sin(x) / 2)`` on [0, pi]; ``-inf`` for ``x`` outside that interval."""
    outside_support = (x < 0.0) | (x > jnp.pi)
    log_density = jnp.log(jnp.sin(x) / 2.0)
    return jnp.where(outside_support, -jnp.inf, log_density)

@jax.jit
def CosPrior(x):
    """Log-density ``log(cos(x) / 2)`` on [-pi/2, pi/2]; ``-inf`` for ``x`` outside that interval."""
    half_pi = jnp.pi / 2.0
    outside_support = (x < -half_pi) | (x > half_pi)
    log_density = jnp.log(jnp.cos(x) / 2.0)
    return jnp.where(outside_support, -jnp.inf, log_density)

@jax.jit
def BetaPrior(x, min, max):
    """Log-density of a Beta(3, 1) distribution shifted and scaled onto the range ``min``..``max``."""
    width = max - min
    return stats.beta.logpdf(x, a=3.0, b=1.0, loc=min, scale=width)

@jax.jit
def FlatInLogPrior(x: float, min: float, max: float) -> float:
    """Log-density proportional to ``1 / x`` between ``min`` and ``max``.

    Inside the range the value is ``-log(log(max / min)) - log(x)``; outside it is ``-inf``.
    """
    outside_support = (x < min) | (x > max)
    log_density = -jnp.log(jnp.log(max / min)) - jnp.log(x)
    return jnp.where(outside_support, -jnp.inf, log_density)

In [5]:
# Sampled parameters and their priors. Each entry is
# ParameterPrior(name, plot label, prior function, *prior arguments); the order of this
# list fixes the column order of `parameter_names` / `labels` below.
# Units are not stated anywhere in this notebook — presumably solar masses for M_c,
# Mpc for d_L, km/s/Mpc for H_0 and km/s for v_p; confirm before reuse.
parameters = [
    ParameterPrior("M_c", r"$M_c$", UniformPrior, 1.184, 2.168),
    ParameterPrior("q", r"$q$", UniformPrior, 0.125, 1.00),
    ParameterPrior("s1_z", r"$s_{1z}$", UniformPrior, -0.05, 0.05),
    ParameterPrior("s2_z", r"$s_{2z}$", UniformPrior, -0.05, 0.05),
    ParameterPrior("iota", r"$\iota$", SinPrior),
    # Beta(3, 1) stretched over [10, 75] (see BetaPrior).
    ParameterPrior("d_L", r"$d_L$", BetaPrior, 10.0, 75.0),
    ParameterPrior("t_c", r"$t_c$", UniformPrior, -0.1, 0.1),
    ParameterPrior("phase_c", r"$\phi_c$", UniformPrior, 0.0, 2 * jnp.pi),
    ParameterPrior("psi", r"$\psi$", UniformPrior, 0.0, jnp.pi),
    # Sky position is restricted to a 0.01-wide box in each coordinate.
    ParameterPrior("ra", r"$\alpha$", UniformPrior, 3.44, 3.45),
    ParameterPrior("dec", r"$\delta$", UniformPrior, -0.41, -0.40),
    ParameterPrior("lambda_1", r"$\Lambda_1$", UniformPrior, 0.0, 5000.0),
    ParameterPrior("lambda_2", r"$\Lambda_2$", UniformPrior, 0.0, 5000.0),
    # H_0 and v_p enter only through the two Gaussian velocity terms in loglikelihood_fn.
    ParameterPrior("H_0", r"$H_0$", FlatInLogPrior, 50.0, 140.0),
    ParameterPrior("v_p", r"$v_p$", UniformPrior, -1000.0, 1000.0)
]

# Parallel lists of dictionary keys and plot labels, in the same order as `parameters`.
parameter_names = [param.name for param in parameters]
labels = [param.label for param in parameters]

In [6]:
# | Define the log prior function
@jax.jit
def logprior_fn(params_dict):
    """Joint log-prior: the sum of every parameter's log-prior evaluated at its own entry.

    ``params_dict`` must contain a value for each ``name`` in the module-level
    ``parameters`` list; the priors are treated as independent.
    """
    per_parameter_terms = []
    for prior in parameters:
        per_parameter_terms.append(prior.logprob(params_dict[prior.name]))
    return jnp.sum(jnp.array(per_parameter_terms))

In [7]:
# SETUP HETERODYNING
def max_phase_diff(
    f: npt.NDArray[np.floating],
    f_low: float,
    f_high: float,
    chi: float = 1.0,
    ):
    """Signed sum of power-law frequency terms used to place the heterodyne bin edges.

    For every frequency in ``f`` this evaluates

        2 * pi * chi * sum_gamma (f / f_star(gamma)) ** gamma * sign(gamma)

    over the exponents ``gamma = -5/3, -4/3, ..., 5/3``, where ``f_star`` is ``f_low``
    for negative exponents and ``f_high`` for the others.

    Parameters
    ----------
    f : array-like, shape (n,)
        Frequencies at which to evaluate the sum.
    f_low, f_high : float
        Reference frequencies for the negative and non-negative exponents.
    chi : float
        Overall multiplicative factor.

    Returns
    -------
    numpy.ndarray, shape (n,)
    """
    gamma = np.arange(-5, 6, 1) / 3.0
    # Build the reference frequencies as floats. Filling an array created from `f_low`
    # would inherit its dtype, so an integer `f_low` used to truncate a fractional `f_high`.
    f_star = np.where(gamma >= 0, float(f_high), float(f_low))
    ratio = np.asarray(f)[:, None] / f_star
    return 2 * np.pi * chi * np.sum(ratio ** gamma * np.sign(gamma), axis=1)

def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center):
    """Per-bin zeroth and first frequency moments of the PSD-weighted inner products.

    For each bin ``[f_bins[i], f_bins[i + 1])`` this sums, over the frequencies that
    fall in the bin,

        A0 = 4 * sum(data * conj(h_ref) / psd) * df
        A1 = 4 * sum(data * conj(h_ref) / psd * (f - f_bins_center[i])) * df
        B0 = 4 * sum(h_ref * conj(h_ref) / psd) * df
        B1 = 4 * sum(h_ref * conj(h_ref) / psd * (f - f_bins_center[i])) * df

    with ``df = freqs[1] - freqs[0]``.

    Returns
    -------
    tuple of four jax arrays ``(A0, A1, B0, B1)``, each with ``len(f_bins) - 1`` entries.
    """
    df = freqs[1] - freqs[0]
    # Numerators of the two integrands on the full frequency grid.
    data_prod = np.array(data * h_ref.conj())
    self_prod = np.array(h_ref * h_ref.conj())

    def _bin_moments(numerator, selected, center):
        # Zeroth and first moment of numerator / psd over the selected samples.
        weighted = numerator[selected] / psd[selected]
        zeroth = 4 * np.sum(weighted) * df
        first = 4 * np.sum(weighted * (freqs[selected] - center)) * df
        return zeroth, first

    a0_terms, a1_terms, b0_terms, b1_terms = [], [], [], []
    n_bins = len(f_bins) - 1
    for i in range(n_bins):
        # Half-open bin: the lower edge is included, the upper edge is not.
        in_bin = (freqs >= f_bins[i]) & (freqs < f_bins[i + 1])
        f_index = np.where(in_bin)[0]
        center = f_bins_center[i]

        a0, a1 = _bin_moments(data_prod, f_index, center)
        b0, b1 = _bin_moments(self_prod, f_index, center)
        a0_terms.append(a0)
        a1_terms.append(a1)
        b0_terms.append(b0)
        b1_terms.append(b1)

    return (
        jnp.array(a0_terms),
        jnp.array(a1_terms),
        jnp.array(b0_terms),
        jnp.array(b1_terms),
    )

def original_likelihood(
    params: dict[str, Float],
    h_sky: dict[str, Float[Array, " n_dim"]],
    detectors: list[Detector],
    freqs: Float[Array, " n_dim"],
    align_time: Float,
    **kwargs,
) -> Float:
    """Log-likelihood summed over detectors on the full frequency grid.

    For each detector the sky-frame waveform ``h_sky`` is projected with
    ``detector.fd_response`` and multiplied by ``align_time``; the detector then
    contributes ``<d|h> - <h|h> / 2``, where both inner products are
    ``4 * sum(... / psd * df).real`` with ``df = freqs[1] - freqs[0]``.

    Extra keyword arguments are accepted and ignored.
    """
    df = freqs[1] - freqs[0]
    total = 0.0
    for ifo in detectors:
        h_dec = ifo.fd_response(freqs, h_sky, params) * align_time
        h_dec_conj = jnp.conj(h_dec)
        # <d|h>: overlap of the data with the projected waveform.
        match_filter_SNR = 4 * jnp.sum((h_dec_conj * ifo.data) / ifo.psd * df).real
        # <h|h>: overlap of the projected waveform with itself.
        optimal_SNR = 4 * jnp.sum(h_dec_conj * h_dec / ifo.psd * df).real
        total += match_filter_SNR - optimal_SNR / 2

    return total

class HeterodynedLikelihood():
    """Relative-binning (heterodyned) gravitational-wave likelihood.

    Usage: construct the object, call :meth:`reference_state` once to build the
    frequency bins, reference waveforms and summary coefficients, then call
    :meth:`evaluate` for every likelihood evaluation. :meth:`evaluate_original`
    computes the likelihood on the full frequency grid without binning.

    Parameters
    ----------
    detectors : list[Detector]
        Detectors whose ``data``, ``psd``, ``name`` and ``fd_response`` are used.
    waveform : callable
        Called as ``waveform(frequencies, params)``; returns a dict of polarisations.
    frequencies : array
        Full frequency grid of the data.
    epoch : float
        Offset added to ``params["t_c"]`` in the time-alignment phase factor.
    gmst : float
        Sidereal time stored into ``params["gmst"]`` before every waveform call.
    n_bins : int
        Number of heterodyne bins requested from :meth:`make_binning_scheme`.
    """

    def __init__(self, detectors: list[Detector], waveform, frequencies, epoch, gmst, n_bins: int = 100):
        self.detectors = detectors
        self.waveform = waveform
        self.frequencies = frequencies
        self.epoch = epoch
        self.gmst = gmst
        self.n_bins = n_bins
        # Per-detector summary data, keyed by detector name; filled by reference_state().
        self.A0_array = {}
        self.A1_array = {}
        self.B0_array = {}
        self.B1_array = {}
        self.waveform_low_ref = {}
        self.waveform_center_ref = {}

    def make_binning_scheme(
        self, freqs: npt.NDArray[np.floating], n_bins: int, chi: float = 1
    ) -> tuple[Float[Array, " n_bins+1"], Float[Array, " n_bins"]]:
        """Place ``n_bins + 1`` bin edges at equal steps of ``max_phase_diff``.

        Returns
        -------
        (f_bins, f_bins_center)
            Bin edges (including the final endpoint) and the midpoints of adjacent
            edges, both as jax arrays.
        """
        phase_diff_array = max_phase_diff(freqs, freqs[0], freqs[-1], chi=chi)
        bin_f = interp1d(phase_diff_array, freqs)
        # Evaluate the inverse map at all targets in one call rather than growing an
        # array with np.append inside a loop.
        phase_edges = np.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1)
        f_bins = np.asarray(bin_f(phase_edges), dtype=float)
        f_bins_center = (f_bins[:-1] + f_bins[1:]) / 2
        return jnp.array(f_bins), jnp.array(f_bins_center)

    def maximize_likelihood(
        self,
        popsize: int = 500,
        n_steps: int = 3000,
        ):
        """Search for a high-likelihood point with a population of AdamW walkers.

        ``popsize`` starting points are drawn from the priors using the notebook-level
        ``parameters`` list and ``sample_prior`` function (both must be defined before
        this method is called). Every walker then takes ``n_steps`` AdamW steps on
        ``-evaluate_original``.

        The optimisation is unconstrained, so the returned point is not guaranteed to
        lie inside the prior support. The loss is evaluated on the full frequency grid
        for all walkers at once, so memory use grows with
        ``popsize * len(self.frequencies)``.

        Returns
        -------
        dict
            Parameter name -> value for the walker with the lowest finite loss, with
            ``eta`` added (computed from ``q``).

        Raises
        ------
        ValueError
            If ``popsize`` is smaller than 1.
        """
        if popsize < 1:
            raise ValueError("popsize must be at least 1")

        names = [param.name for param in parameters]
        n_dims = len(names)

        def negative_loglikelihood(x):
            named_params = dict(zip(names, x))
            # Same mass-ratio -> symmetric-mass-ratio conversion as loglikelihood_fn.
            q = named_params["q"]
            named_params["eta"] = q / (1 + q) ** 2
            return -self.evaluate_original(named_params)

        print("Starting the optimizer")

        # Draw starting points from the prior; rows containing non-finite values are
        # re-drawn with a fresh key on every pass.
        rng_key = jax.random.PRNGKey(0)
        positions = jnp.full((popsize, n_dims), jnp.nan)
        while True:
            bad_rows = jnp.where(~jnp.all(jnp.isfinite(positions), axis=1))[0]
            if len(bad_rows) == 0:
                break
            rng_key, init_key = jax.random.split(rng_key)
            init_keys = jax.random.split(init_key, n_dims)
            guess = jnp.vstack(
                [sample_prior(param, key, popsize) for param, key in zip(parameters, init_keys)]
            ).T
            finite_guess = guess[jnp.all(jnp.isfinite(guess), axis=1)]
            n_fill = min(len(finite_guess), len(bad_rows))
            positions = positions.at[bad_rows[:n_fill]].set(finite_guess[:n_fill])

        optimizer = optax.adamw(
            learning_rate=0.001
        )
        # One scalar loss (and gradient) per walker.
        loss_and_grad = jax.jit(jax.vmap(jax.value_and_grad(negative_loglikelihood)))
        opt_state = optimizer.init(positions)
        for _ in range(n_steps):
            _, grads = loss_and_grad(positions)
            updates, opt_state = optimizer.update(grads, opt_state, positions)
            positions = optax.apply_updates(positions, updates)

        final_losses = jax.jit(jax.vmap(negative_loglikelihood))(positions)
        # Walkers that diverged to NaN/inf must never be selected.
        final_losses = jnp.where(jnp.isfinite(final_losses), final_losses, jnp.inf)
        best_index = jnp.argmin(final_losses)
        print(f"Best negative log-likelihood: {final_losses[best_index]}")

        best_fit = positions[best_index]
        named_params = dict(zip(names, best_fit))
        q = named_params["q"]
        named_params["eta"] = q / (1 + q) ** 2
        return named_params

    def evaluate_original(
        self, params: dict[str, Float]
    ) -> (
        Float
    ):
        """Log-likelihood on the full frequency grid (no relative binning).

        ``params`` is not modified; ``gmst`` is added to a shallow copy.
        """
        frequencies = self.frequencies
        params = {**params, "gmst": self.gmst}
        # evaluate the waveform as usual
        waveform_sky = self.waveform(frequencies, params)
        align_time = jnp.exp(
            -1j * 2 * jnp.pi * frequencies * (self.epoch + params["t_c"])
        )
        return original_likelihood(
            params,
            waveform_sky,
            self.detectors,
            frequencies,
            align_time
        )

    def reference_state(
        self,
        chains_file: str = 'Final_UniformMass_Final.csv',
        posterior_file: str = 'GW170817_GWTC-1.hdf5',
    ):
        """Build the reference waveforms and per-bin coefficients used by :meth:`evaluate`.

        The reference parameters are the column means of the
        ``IMRPhenomPv2NRT_lowSpin_posterior`` dataset in ``posterior_file``, converted
        to this notebook's parameter names, with fixed values for ``t_c``, ``phase_c``
        and ``psi``. ``chains_file`` is only read for its column layout.

        Both files must already exist in the working directory; nothing in this
        notebook downloads them.

        Side effects: sets ``freq_grid_low`` and ``freq_grid_center`` and fills the
        ``A0/A1/B0/B1_array`` and ``waveform_low_ref/waveform_center_ref`` dicts.
        """
        import h5py  # only needed here; not imported at the top of the notebook

        samples = anesthetic.read_chains(chains_file)
        data_dict = {}
        with h5py.File(posterior_file, "r") as hdf_file:
            dataset = hdf_file["IMRPhenomPv2NRT_lowSpin_posterior"]
            for name in dataset.dtype.names:
                data_dict[name] = np.array(dataset[name])
        # Notebook parameter name -> field name in the posterior file.
        param_conversion = {'M_1':'m1_detector_frame_Msun' , 'M_2':'m2_detector_frame_Msun', 'd_L':'luminosity_distance_Mpc', 'iota':'costheta_jn', 'ra':'right_ascension', 'dec':'declination', 'spin1':'spin1', 'spin2':'spin2', 'costilt1':'costilt1', 'costilt2':'costilt2', 'lambda_1':'lambda1', 'lambda_2':'lambda2'}
        columns = samples.columns
        LVK_samples = anesthetic.MCMCSamples(columns=columns)
        for param, source_field in param_conversion.items():
            LVK_samples[param] = data_dict[source_field]
        LVK_samples["M_c"] = (LVK_samples["M_1"]*LVK_samples["M_2"])**0.6 / (LVK_samples["M_1"]+LVK_samples["M_2"])**0.2
        LVK_samples["q"] = LVK_samples["M_2"] / LVK_samples["M_1"]
        # The file stores cos(theta_jn); convert it to an angle.
        LVK_samples["iota"] = np.arccos(LVK_samples["iota"])
        LVK_samples["s1_z"] = LVK_samples["spin1"] * LVK_samples["costilt1"]
        LVK_samples["s2_z"] = LVK_samples["spin2"] * LVK_samples["costilt2"]
        LVK_samples["eta"] =  LVK_samples["M_1"] * LVK_samples["M_2"] / (LVK_samples["M_1"] + LVK_samples["M_2"])**2
        means = LVK_samples.mean()
        params = means.to_dict()
        # Column keys are (name, label) tuples; keep only the name.
        params = {k[0]: v for k, v in params.items()}
        # Values not taken from the posterior file.
        params["t_c"] = 0.014463470602718088
        params["phase_c"] = np.pi
        params["psi"] = np.pi / 2
        # Bookkeeping columns inherited from the chains-file layout.
        params.pop('logL')
        params.pop('logL_birth')
        params.pop('nlive')

        print(f'Optimized reference parameters: {params}')
        params["gmst"] = self.gmst
        # Nudge eta off 0.25 BEFORE any reference waveform is evaluated, so all three
        # reference waveforms below are generated from identical parameters.
        if jnp.isclose(params["eta"], 0.25):
            params["eta"] = 0.249995
        h_sky = self.waveform(self.frequencies, params)
        # Get the grid of the relative binning scheme (contains the final endpoint)
        # and the center points
        freq_grid, self.freq_grid_center = self.make_binning_scheme(
            np.array(self.frequencies), self.n_bins
            )
        self.freq_grid_low = freq_grid[:-1]
        # Get frequency masks to be applied, for both original
        # and heterodyne frequency grid
        h_amp = jnp.sum(
            jnp.array([jnp.abs(h_sky[key]) for key in h_sky.keys()]), axis=0
        )
        f_valid = self.frequencies[jnp.where(h_amp > 0)[0]]
        f_max = jnp.max(f_valid)
        f_min = jnp.min(f_valid)

        mask_heterodyne_grid = jnp.where((freq_grid <= f_max) & (freq_grid >= f_min))[0]
        mask_heterodyne_low = jnp.where(
            (self.freq_grid_low <= f_max) & (self.freq_grid_low >= f_min)
        )[0]
        mask_heterodyne_center = jnp.where(
            (self.freq_grid_center <= f_max) & (self.freq_grid_center >= f_min)
        )[0]
        freq_grid = freq_grid[mask_heterodyne_grid]
        self.freq_grid_low = self.freq_grid_low[mask_heterodyne_low]
        self.freq_grid_center = self.freq_grid_center[mask_heterodyne_center]

        # Assure frequency grids have same length
        if len(self.freq_grid_low) > len(self.freq_grid_center):
            self.freq_grid_low = self.freq_grid_low[: len(self.freq_grid_center)]

        h_sky_low = self.waveform(self.freq_grid_low, params)
        h_sky_center = self.waveform(self.freq_grid_center, params)

        def _align(freqs):
            # Phase shift that aligns the time of coalescence on the given grid.
            return jnp.exp(-1j * 2 * jnp.pi * freqs * (self.epoch + params["t_c"]))

        align_time = _align(self.frequencies)
        align_time_low = _align(self.freq_grid_low)
        align_time_center = _align(self.freq_grid_center)

        for detector in self.detectors:
            waveform_ref = (
                detector.fd_response(self.frequencies, h_sky, params)
                * align_time
            )
            self.waveform_low_ref[detector.name] = (
                detector.fd_response(self.freq_grid_low, h_sky_low, params)
                * align_time_low
            )
            self.waveform_center_ref[detector.name] = (
                detector.fd_response(
                    self.freq_grid_center, h_sky_center, params
                )
                * align_time_center
            )
            A0, A1, B0, B1 = compute_coefficients(
                detector.data,
                waveform_ref,
                detector.psd,
                self.frequencies,
                freq_grid,
                self.freq_grid_center,
            )
            # NOTE(review): mask_heterodyne_center holds indices into the unmasked
            # centre grid, while compute_coefficients was given the masked grids.
            # The two agree only when no bin was masked out — confirm.
            self.A0_array[detector.name] = A0[mask_heterodyne_center]
            self.A1_array[detector.name] = A1[mask_heterodyne_center]
            self.B0_array[detector.name] = B0[mask_heterodyne_center]
            self.B1_array[detector.name] = B1[mask_heterodyne_center]

    def evaluate(self, params: dict[str, Float]) -> Float:
        """Relative-binning log-likelihood at ``params``.

        Requires :meth:`reference_state` to have been called. ``params`` is not
        modified; ``gmst`` is added to a shallow copy.
        """
        frequencies_low = self.freq_grid_low
        frequencies_center = self.freq_grid_center
        params = {**params, "gmst": self.gmst}
        # evaluate the waveforms as usual
        waveform_sky_low = self.waveform(frequencies_low, params)
        waveform_sky_center = self.waveform(frequencies_center, params)
        align_time_low = jnp.exp(
            -1j * 2 * jnp.pi * frequencies_low * (self.epoch + params["t_c"])
        )
        align_time_center = jnp.exp(
            -1j * 2 * jnp.pi * frequencies_center * (self.epoch + params["t_c"])
        )
        return relative_binning_likelihood_function(
            params,
            self.A0_array,
            self.A1_array,
            self.B0_array,
            self.B1_array,
            waveform_sky_low,
            waveform_sky_center,
            self.waveform_low_ref,
            self.waveform_center_ref,
            self.detectors,
            frequencies_low,
            frequencies_center,
            align_time_low,
            align_time_center
        )

In [8]:
likelihood_function = HeterodynedLikelihood(detectors, waveform, frequencies, epoch, gmst)

# | Define the likelihood function
@jax.jit
def loglikelihood_fn(params_dict):
    """Total log-likelihood: heterodyned GW term plus two Gaussian velocity terms.

    The input dictionary is copied, ``eta`` is derived from ``q``, and the result is
    the sum of

    * ``likelihood_function.evaluate`` on the copied parameters,
    * a normal log-density of 3327 centred on ``v_p + H_0 * d_L`` with width 72,
    * a normal log-density of 310 centred on ``v_p`` with width 150.

    NOTE(review): the constants look like a measured recession velocity and peculiar
    velocity with their uncertainties (presumably km/s) — confirm the source.
    """
    params = {**params_dict}  # work on a copy of the parameter dictionary
    q = params["q"]
    params["eta"] = q / (1 + q) ** 2

    predicted_velocity = params["v_p"] + params["H_0"] * params["d_L"]
    ll_vr = stats.norm.logpdf(3327, loc=predicted_velocity, scale=72)
    ll_vp = stats.norm.logpdf(310, loc=params["v_p"], scale=150)

    gw_loglikelihood = likelihood_function.evaluate(params)
    return gw_loglikelihood + ll_vr + ll_vp

In [9]:
# | Define the Nested Sampling algorithm
# Dimensionality of the sampled space (one entry per ParameterPrior).
num_dims = len(parameter_names)
# Number of live points drawn from the prior.
num_live = 2500
# Points replaced per nested-sampling step: half of the live set.
num_delete = int(num_live * 0.5)
# Inner MCMC steps per replacement: five per dimension.
num_mcmc_steps = int(num_dims * 5)

# Initialize nested sampler
def custom_nsmcmc(
    logprior_fn,
    loglikelihood_fn,
    num_delete,
    num_inner_steps,
):
    """
    Build a custom nested sampling MCMC algorithm from low-level components.

    This demonstrates how to construct a nested sampler using BlackJAX's
    modular infrastructure - useful for research and customization.

    The inner kernel is a random-walk Metropolis-Hastings step on the prior whose
    result is kept only if its log-likelihood exceeds the current threshold. The
    Gaussian proposal width of each parameter is the standard deviation of that
    parameter over the current particles.

    Parameters
    ----------
    logprior_fn : callable
        Function that computes the log prior probability of the parameters.
    loglikelihood_fn : callable
        Function that computes the log likelihood of the parameters.
    num_delete : int
        Number of particles to delete at each step.
    num_inner_steps : int
        Number of inner MCMC steps to perform.

    Returns
    -------
    SamplingAlgorithm
        Custom nested sampling algorithm with init and step functions.
    """

    # Build the MCMC kernel for exploring within likelihood constraints
    mcmc_kernel = build_rmh()

    @repeat_kernel(num_inner_steps)
    def inner_kernel(rng_key, state, logprior_fn, loglikelihood_fn, loglikelihood_0, params):
        """Inner MCMC kernel that explores within likelihood constraint."""

        def proposal_distribution(rng_key, position):
            # Handle dictionary position structure
            if isinstance(position, dict):
                sigma = params['sigma']
                # One independent key per parameter, split up front: a key that has
                # been used to draw numbers must not also be split again.
                leaf_keys = jax.random.split(rng_key, len(position))
                proposal = {}
                for leaf_key, name in zip(leaf_keys, position.keys()):
                    sigma_val = sigma[name] if isinstance(sigma, dict) else sigma
                    noise = jax.random.normal(leaf_key, shape=position[name].shape)
                    proposal[name] = position[name] + sigma_val * noise
                return proposal
            else:
                # Fallback for array position
                step = params['sigma'] * jax.random.normal(rng_key, shape=position.shape)
                return position + step

        # Convert to MCMC state format
        mcmc_state = RWState(position=state.position, logdensity=state.logprior)
        new_mcmc_state, mcmc_info = mcmc_kernel(rng_key, mcmc_state, logprior_fn, proposal_distribution)

        # Evaluate likelihood at new position
        loglikelihood = loglikelihood_fn(new_mcmc_state.position)

        # Create new nested sampling state
        new_state, info = new_state_and_info(
            position=new_mcmc_state.position,
            logprior=new_mcmc_state.logdensity,
            loglikelihood=loglikelihood,
            info=mcmc_info,
        )

        # Accept only if likelihood exceeds threshold (key constraint!)
        new_state = jax.lax.cond(
            loglikelihood > loglikelihood_0,
            lambda _: new_state,
            lambda _: state,
            operand=None,
        )

        return new_state, info

    def update_inner_kernel_params_fn(state, info, params):
        """Adapt step size based on current particle distribution."""
        # Standard deviation of every parameter over the current particles
        sigma_dict = {name: jnp.std(values) for name, values in state.particles.items()}
        return {'sigma': sigma_dict}

    # Build the full nested sampling kernel
    _delete_fn = partial(delete_fn, num_delete=num_delete)

    step_fn = build_kernel(
        logprior_fn,
        loglikelihood_fn,
        _delete_fn,
        inner_kernel,
        update_inner_kernel_params_fn,
    )

    init_fn = partial(
        init,
        logprior_fn=logprior_fn,
        loglikelihood_fn=loglikelihood_fn,
        update_inner_kernel_params_fn=update_inner_kernel_params_fn,
    )

    return SamplingAlgorithm(init_fn, step_fn)



In [10]:
# | Sample live points from the prior
def sample_prior(parameter, key, n_live):
    """Draw ``n_live`` independent samples from the prior of ``parameter``.

    Parameters
    ----------
    parameter : ParameterPrior
        Its ``prior_fn`` selects the sampler and its ``args`` supply the bounds.
    key : jax PRNG key
        Random key consumed by the draw.
    n_live : int
        Number of samples.

    Returns
    -------
    jax array of shape ``(n_live,)``.

    Raises
    ------
    ValueError
        If ``parameter.prior_fn`` is not one of the priors handled below. Returning
        ``None`` instead would be treated as an empty pytree leaf and silently drop
        the parameter from the particle dictionary.
    """
    prior_fn = parameter.prior_fn
    if prior_fn is UniformPrior:
        return jax.random.uniform(key, (n_live,), minval=parameter.args[0], maxval=parameter.args[1])
    if prior_fn is SinPrior:
        # Inverse CDF of sin(x) / 2 on [0, pi].
        return 2 * jnp.arcsin(jax.random.uniform(key, (n_live,)) ** 0.5)
    if prior_fn is CosPrior:
        # Same draw as SinPrior, shifted onto [-pi/2, pi/2].
        return 2 * jnp.arcsin(jax.random.uniform(key, (n_live,)) ** 0.5) - jnp.pi / 2.0
    if prior_fn is BetaPrior:
        # Beta(3, 1) draw rescaled from [0, 1] to [args[0], args[1]].
        return jax.random.beta(key, 3.0, 1.0, (n_live,)) * (parameter.args[1] - parameter.args[0]) + parameter.args[0]
    if prior_fn is FlatInLogPrior:
        # Uniform in log(x) between args[0] and args[1].
        return parameter.args[0] * (parameter.args[1] / parameter.args[0]) ** jax.random.uniform(key, (n_live,))
    raise ValueError(
        f"No prior sampler is defined for parameter {parameter.name!r} "
        f"(prior function {getattr(prior_fn, '__name__', prior_fn)!r})"
    )

# Produce initial particles
# Fixed seed so a fresh run draws the same initial live points; `rng_key` is carried
# forward and split again inside the sampling loop.
rng_key = jax.random.PRNGKey(0)
rng_key, prior_key = jax.random.split(rng_key, 2)
# One independent key per parameter.
prior_keys = jax.random.split(prior_key, len(parameters))

# Create particles as a PyTree (dictionary) where each parameter has shape (num_live,)
particles = {}
for param, key in zip(parameters, prior_keys):
    particles[param.name] = sample_prior(param, key, num_live)

# Build nested sampler
nested_sampler = custom_nsmcmc(
    logprior_fn=logprior_fn,
    loglikelihood_fn=loglikelihood_fn,
    num_delete=num_delete,
    num_inner_steps=num_mcmc_steps,
)

# JIT compile
# (jax.jit only wraps the functions here; tracing/compilation presumably happens on
# their first call in the sampling cell)
init_fn = jax.jit(nested_sampler.init)
step_fn = jax.jit(nested_sampler.step)

# Must run before the first init_fn/step_fn call: `evaluate` reads the frequency grids
# and coefficient dictionaries that reference_state() sets on the likelihood object.
likelihood_function.reference_state()  # Precompute reference state for likelihood

Optimized reference parameters: {'M_1': 1.4900663499588278, 'M_2': 1.2755048307759151, 's1_z': 0.005136138323169717, 's2_z': 0.003235146993487445, 'iota': 2.545065595974997, 'd_L': 38.43694858086378, 't_c': 0.014463470602718088, 'phase_c': 3.141592653589793, 'psi': 1.5707963267948966, 'ra': 3.4461599999999994, 'dec': -0.4080839999999999, 'spin1': 0.023265200435027916, 'spin2': 0.022787912334285128, 'costilt1': 0.17303020549308612, 'costilt2': 0.10384259231844456, 'lambda_1': 368.17802383555687, 'lambda_2': 586.5487031450857, 'M_c': 1.197555435188453, 'q': 0.8605117648682054, 'eta': 0.24786618323504223}


In [11]:
# Run nested sampling
print("Running nested sampling...")
# The timer spans initialisation, every step and the final `finalise` call
# (presumably including JIT compilation on the first calls).
ns_start = time.time()
live = init_fn(particles)
# Per-step records of the deleted points, merged after the loop.
dead = []

with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    # Keep stepping until logZ_live - logZ drops below -3.
    # NOTE(review): presumably this means the live points hold less than e^-3 of the
    # accumulated evidence — confirm against the blackjax nested-sampling state docs.
    while not live.logZ_live - live.logZ < -3:
        rng_key, subkey = jax.random.split(rng_key, 2)
        live, dead_info = step_fn(subkey, live)
        dead.append(dead_info)
        # Each step deletes `num_delete` points.
        pbar.update(num_delete)

# `dead` is rebound from the list of per-step records to the combined result.
dead = blackjax.ns.utils.finalise(live, dead)
ns_time = time.time() - ns_start

Running nested sampling...


Dead points: 25000 dead points [01:04, 390.10 dead points/s] 


In [12]:
# Post processing
# Stack the per-parameter particle arrays into an (n_samples, n_params) table whose
# column order follows `parameter_names`.
data = jnp.vstack([dead.particles[name] for name in parameter_names]).T
samples = NestedSamples(
    data,
    logL=dead.loglikelihood,
    logL_birth=dead.loglikelihood_birth,
    columns=parameter_names,
    labels=labels,
    logzero=jnp.nan,
)

print(f"Sampler runtime: {ns_time:.2f} seconds")
# Evidence estimate; the uncertainty is the spread over 100 draws of logZ.
print(f"Log Evidence: {samples.logZ():.2f} ± {samples.logZ(100).std():.2f}")

# A dangling `samples.` fragment (a syntax error) stood here and was removed.
samples.to_csv(f'{label}.csv')
print(f"Samples saved to {label}.csv")

1250 of which have logL == logL_birth.
This may just indicate numerical rounding errors at the peak of the likelihood, but further investigation of the chains files is recommended.
Dropping the invalid samples.


Sampler runtime: 81.18 seconds
Log Evidence: 77.30 ± 0.08
Samples saved to Test.csv
