In [1]:
import bilby as bb
import gwpopulation as gwpop
import jax
import matplotlib.pyplot as plt
import pandas as pd
from bilby.core.prior import PriorDict, Uniform
from gwpopulation.experimental.jax import JittedLikelihood, NonCachingModel
import os

# Select JAX as the gwpopulation array backend; done at the top of the notebook,
# before any model or likelihood object is constructed.
gwpop.set_backend("jax")

# Handle to gwpopulation's array module, read after the backend switch above.
# NOTE(review): presumably this now resolves to the JAX numpy namespace — confirm.
xp = gwpop.utils.xp

In [2]:
# Root directory holding the study's data products (posteriors, injections).
# The location is machine-specific, so it can be overridden by exporting the
# A_SHARP_PROJECT_DIR environment variable before starting the kernel. When the
# variable is unset or empty, the original author's path is used, so existing
# behaviour is unchanged.
project_dir = os.environ.get("A_SHARP_PROJECT_DIR") or (
    '/home/divyajyoti/ACADEMIC/Projects/Cardiff_University/Next_gen_detectability/A-sharp-study/'
)

## Load posteriors

In [3]:
# Location of the pickled redshift posteriors for the BBH analysis
# (the file name suggests 499 events).
posteriors_path = os.path.join(
    project_dir, 'gwpopulation', 'BBH', 'redshift_posteriors_499_events.pkl'
)
posteriors = pd.read_pickle(posteriors_path)

## Load injections

In [4]:
import dill

# Location of the dill-serialised set of detected injections for the BBH analysis.
injections_path = os.path.join(
    project_dir, 'gwpopulation', 'BBH', 'detected_injections.pkl'
)

# Deserialise the injections; they are later handed to the selection-function
# (VT) estimator as its `data` argument.
with open(injections_path, "rb") as injections_file:
    injections = dill.load(injections_file)

## Define models and likelihood

In [5]:
# Redshift population model: Madau-Dickinson form with a Planck18 cosmology,
# defined up to z_max = 8.
# (Alternative tried previously: gwpop.models.redshift.PowerLawRedshift(z_max=8).)
redshift_model = gwpop.models.redshift.MadauDickinsonRedshift(
    cosmo_model="Planck18",
    z_max=8,
)

# Wrap the single model function in the non-caching container from
# gwpopulation's experimental JAX module.
model = NonCachingModel(model_functions=[redshift_model])

# Selection function built from the detected injections; n_events is the
# number of entries in the loaded posteriors.
vt = gwpop.vt.ResamplingVT(
    model=model,
    data=injections,
    n_events=len(posteriors),
)

# Hierarchical likelihood combining the per-event posteriors, the population
# model and the selection function.
likelihood = gwpop.hyperpe.HyperparameterLikelihood(
    posteriors=posteriors, hyper_prior=model, selection_function=vt
)

## Define prior

In [6]:
# Uniform prior specification per hyperparameter: (minimum, maximum, LaTeX label).
# A 'lamb' entry, (0.5, 4, "$\\lambda$"), was used with the power-law redshift
# model and is currently disabled.
prior_bounds = {
    'gamma': (1, 5, "$\\gamma$"),
    'kappa': (2, 8, "$\\kappa$"),
    'z_peak': (0.5, 4, "$z_{peak}$"),
}

# Populate the bilby prior dictionary in the same key order as the table above.
priors = PriorDict()
for prior_name, (lower, upper, tex_label) in prior_bounds.items():
    priors[prior_name] = Uniform(minimum=lower, maximum=upper, latex_label=tex_label)

## Just-in-time compile

In [7]:
# Draw one random hyperparameter point from the prior to use as a common test
# point for both the plain and the jitted likelihood.
parameters = priors.sample()
likelihood.parameters.update(parameters)
# Untimed first evaluation of the plain likelihood.
# NOTE(review): presumably a warm-up so the timing below excludes one-time setup — confirm.
likelihood.log_likelihood_ratio()
# Time a second evaluation of the plain (un-jitted) likelihood.
%time print(likelihood.log_likelihood_ratio())
# Wrap the likelihood with gwpopulation's experimental JittedLikelihood and
# evaluate it at the same parameter point.
jit_likelihood = JittedLikelihood(likelihood)
jit_likelihood.parameters.update(parameters)
# Timed twice on purpose: in the recorded output the first jitted call is much
# slower than the second (presumably it includes JIT compilation), and all
# three printed values agree.
%time print(jit_likelihood.log_likelihood_ratio())
%time print(jit_likelihood.log_likelihood_ratio())

4497.104404986363
CPU times: user 669 ms, sys: 181 ms, total: 850 ms
Wall time: 377 ms
4497.104404986363
CPU times: user 8.81 s, sys: 681 ms, total: 9.49 s
Wall time: 2.29 s
4497.104404986363
CPU times: user 209 ms, sys: 46.2 ms, total: 256 ms
Wall time: 81.2 ms


In [8]:
# Sampler configuration forwarded verbatim to bilby.run_sampler as keyword
# arguments. No outdir is given, so results go to bilby's default directory
# (the run log reports 'outdir').
sampler_settings = dict(
    sampler="dynesty",
    nlive=100,
    label="cosmo",
    sample="acceptance-walk",
    naccept=5,
    save="hdf5",
)

# Run the sampler on the jitted likelihood with the priors defined earlier.
result = bb.run_sampler(
    likelihood=jit_likelihood,
    priors=priors,
    **sampler_settings,
)

13:07 bilby INFO    : Running for label 'cosmo', output will be saved to 'outdir'
13:07 bilby INFO    : Analysis priors:
13:07 bilby INFO    : gamma=Uniform(minimum=1, maximum=5, name=None, latex_label='$\\gamma$', unit=None, boundary=None)
13:07 bilby INFO    : kappa=Uniform(minimum=2, maximum=8, name=None, latex_label='$\\kappa$', unit=None, boundary=None)
13:07 bilby INFO    : z_peak=Uniform(minimum=0.5, maximum=4, name=None, latex_label='$z_{peak}$', unit=None, boundary=None)
13:07 bilby INFO    : Analysis likelihood class: <class 'gwpopulation.experimental.jax.JittedLikelihood'>
13:07 bilby INFO    : Analysis likelihood noise evidence: nan
13:07 bilby INFO    : Single likelihood evaluation took 4.427e-05 s
13:07 bilby INFO    : Using sampler Dynesty with kwargs {'nlive': 100, 'bound': 'live', 'sample': 'acceptance-walk', 'periodic': None, 'reflective': None, 'update_interval': 600, 'first_update': None, 'npdim': None, 'rstate': None, 'queue_size': 1, 'pool': None, 'use_pool': None

131it [00:21,  5.91it/s, bound:0 nc:  3 ncall:3.6e+02 eff:36.9% logz-ratio=4502.36+/-0.20 dlogz:23.3>0.1]

13:08 bilby INFO    : Run interrupted by signal 2: checkpoint and exit on 130
13:08 bilby INFO    : Written checkpoint file outdir/cosmo_resume.pickle



Exception while calling loglikelihood function:
  params: [2.83899239 2.92941997 3.31222462]
  args: []
  kwargs: {}
  exception:


Traceback (most recent call last):
  File "/home/divyajyoti/miniconda3/envs/gwpopulation/lib/python3.11/site-packages/dynesty/dynesty.py", line 913, in __call__
    return self.func(np.asarray(x).copy(), *self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/divyajyoti/miniconda3/envs/gwpopulation/lib/python3.11/site-packages/bilby/core/sampler/dynesty.py", line 54, in _log_likelihood_wrapper
    return _sampling_convenience_dump.likelihood.log_likelihood_ratio()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/divyajyoti/miniconda3/envs/gwpopulation/lib/python3.11/site-packages/gwpopulation/experimental/jax.py", line 95, in log_likelihood_ratio
    np.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs))
  File "/home/divyajyoti/miniconda3/envs/gwpopulation/lib/python3.11/site-packages/numpy/lib/_type_check_impl.py", line 458, in nan_to_num
    x = _nx.array(x, subok=True, copy=copy

SystemExit: 130

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
