In [21]:
import bilby as bb
import gwpopulation as gwpop
import jax
import matplotlib.pyplot as plt
import pandas as pd
from bilby.core.prior import PriorDict, Uniform
from gwpopulation.experimental.jax import JittedLikelihood, NonCachingModel
import os
import bilby

In [22]:
gwpop.set_backend("jax")

# Array module that gwpopulation selected for the active backend.
xp = gwpop.utils.xp

# Root of the A-sharp study checkout. To run on another machine (e.g. the
# '/home/divyajyoti.nln/Cardiff_University/Next_gen_detectability/A-sharp-study/'
# checkout), set the ASHARP_PROJECT_DIR environment variable instead of editing
# this cell; without it the original path below is used.
project_dir = os.environ.get(
    'ASHARP_PROJECT_DIR',
    '/home/divyajyoti/ACADEMIC/Projects/Cardiff_University/Next_gen_detectability/A-sharp-study/',
)

In [31]:
# Detector network being analysed; it also names the results sub-directory.
# (Other option used in this study: netw = 'Asharp'.)
netw = 'CE4020ET123'
outdir = os.path.join(
    project_dir,
    'gwpopulation',
    'BBH',
    'gwpop_analysis_results',
    netw,
    'run06_original_cov_mf_from_opt_SNR_injections',
)

In [32]:
# bilby result of the hyper-parameter sampling run for `netw`. The filename is
# looked up from `netw` (instead of toggling commented-out lines) so that
# switching networks cannot silently keep loading the other network's result.
result_filenames = {
    'Asharp': 'Asharp-study-gwpop_result.hdf5',
    'CE4020ET123': 'CE4020ET123_result.hdf5',
}
result = bilby.result.read_in_result(filename=os.path.join(outdir, result_filenames[netw]))

In [33]:
# True (presumably injected) values of the Madau-Dickinson redshift
# hyper-parameters, drawn as truth lines on the corner plot.
true_params = dict(
    gamma=1.8032,
    kappa=5.3023,
    z_peak=1.8362,
)

In [34]:
# Corner plot of the hyper-parameter posterior with the true values overlaid.
# Passing `true_params` as a dict (not a list) makes bilby plot only those
# parameters and draw their values as truths. The output filename is looked up
# from `netw` so it always matches the network whose result was loaded.
corner_filenames = {
    'Asharp': 'Asharp-study-gwpop_corner_with_truths.png',
    'CE4020ET123': 'CE4020ET123_corner_with_truths.png',
}
filename = corner_filenames[netw]
result.plot_corner(
    parameters=true_params,
    outdir=outdir,
    quantiles=(0.05, 0.95),
    filename=os.path.join(outdir, filename),
);  # trailing ';' hides the returned Figure's text repr; the plot is saved to file

<Figure size 760x760 with 9 Axes>

In [None]:
# ## Load posteriors
#
# Single-event posterior samples. The file is chosen from `netw` so the
# network is selected in exactly one place.

bbh_data_dir = os.path.join(project_dir, 'gwpopulation', 'BBH')

posterior_files = {
    'CE4020ET123': 'CE4020ET123_CoBA10_2_PLP_z_MD_zmax_10_lmrd_22_no_spins_z_posteriors_1500_events.pkl',
    'Asharp': 'LHI_Asharp_1_PLP_z_MD_zmax_6_lmrd_22_no_spins_z_posteriors_499_events.pkl',
}
posteriors = pd.read_pickle(os.path.join(bbh_data_dir, posterior_files[netw]))
# n_events for the selection function is taken from this, so fail loudly on an empty file.
assert len(posteriors) > 0, f"no single-event posteriors loaded for {netw}"


# ## Load injections
#
# Detected injections used below as the data for the selection function.
# dill deserialisation can execute arbitrary code: only load files produced
# by this project.

import dill

injection_files = {
    'CE4020ET123': 'CE4020ET123_CoBA10_SNR_2_pop_PLP_spin_prec_z_MD_zmax_10_lmrd_22_corrected_td_detected_injs_mf_SNR_1M_points.pkl',
    'Asharp': 'LHI_Asharp_SNR_1_pop_PLP_spin_prec_z_MD_zmax_6_lmrd_22_corrected_td_detected_injs_mf_SNR_1M_points.pkl',
}
with open(os.path.join(bbh_data_dir, injection_files[netw]), "rb") as ff:
    injections = dill.load(ff)


# ## Define models and likelihood

# NOTE(review): the data files above are labelled zmax_10 (CE4020ET123) and
# zmax_6 (Asharp), but the redshift model is truncated at z_max=8. For
# CE4020ET123 this presumably gives zero population weight above z = 8.
# Confirm this matches the settings of the sampling run whose result is loaded.
model = NonCachingModel(
    model_functions=[gwpop.models.redshift.MadauDickinsonRedshift(cosmo_model="Planck18", z_max=8)],
    # Alternative: [gwpop.models.redshift.PowerLawRedshift(z_max=8)]
)

vt = gwpop.vt.ResamplingVT(model=model, data=injections, n_events=len(posteriors))

likelihood = gwpop.hyperpe.HyperparameterLikelihood(
    posteriors=posteriors,
    hyper_prior=model,
    selection_function=vt,
)

# Uniform priors on the Madau-Dickinson hyper-parameters.
priors = PriorDict()
priors['gamma'] = Uniform(minimum=0, maximum=5, latex_label="$\\gamma$")
priors['kappa'] = Uniform(minimum=0, maximum=20, latex_label="$\\kappa$")
priors['z_peak'] = Uniform(minimum=0.5, maximum=4, latex_label="$z_{peak}$")

# Random starting point for a quick likelihood sanity check.
# NOTE(review): the draw is unseeded, so the sanity-check value evaluated at
# these parameters changes between runs.
parameters = priors.sample()
likelihood.parameters.update(parameters)

In [6]:
# Sanity check: evaluate the likelihood at the parameters drawn from the prior
# when the likelihood was set up. The result is a pair (ln-likelihood,
# variance); presumably the variance is the Monte Carlo variance of the
# ln-likelihood estimate. TODO confirm against the gwpopulation docs.
likelihood.ln_likelihood_and_variance()

(Array(10168.56957079, dtype=float64), Array(2.16934084, dtype=float64))

In [7]:
# JIT-compile the per-sample post-processing once with JAX. func(sample)
# returns the sample augmented with extra statistics: ln_bf_<i> / var_<i>
# columns, a total `variance`, and the log_likelihood used in the later cells.
func = jax.jit(likelihood.generate_extra_statistics)
# Alternative (likelihood and variance only):
#func = jax.jit(likelihood.ln_likelihood_and_variance)

In [11]:
# Spot-check the jitted statistics on the first few hyper-posterior samples
# before running them over the full posterior.
# head() before to_dict() converts only the rows actually used, not the whole
# posterior; the records are identical to slicing the full list.
N_TEST_SAMPLES = 30
test_full_posterior = pd.DataFrame(
    [func(sample) for sample in result.posterior.head(N_TEST_SAMPLES).to_dict(orient="records")]
).astype(float)

In [12]:
# Summary statistics of every column: the hyper-parameters, variance, and the
# ln_bf_<i> / var_<i> columns (presumably one pair per event). The table is
# very wide, so pandas truncates the middle columns.
test_full_posterior.describe()

Unnamed: 0,gamma,kappa,ln_bf_0,ln_bf_1,ln_bf_10,ln_bf_100,ln_bf_1000,ln_bf_1001,ln_bf_1002,ln_bf_1003,...,var_992,var_993,var_994,var_995,var_996,var_997,var_998,var_999,variance,z_peak
count,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0,...,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0,30.0
mean,2.03141,5.840315,5.85524,4.78745,7.573994,6.829806,7.621988,7.657583,7.483845,7.61989,...,6e-06,1.488449e-08,0.000123,4e-06,2.623074e-07,1.629704e-07,0.000108,1.062879e-06,0.156281,1.675543
std,0.856416,0.578233,0.106829,0.138769,0.078003,0.081321,0.055896,0.064892,0.062475,0.07666,...,2e-06,1.435541e-08,2.3e-05,1e-06,1.696958e-07,1.228433e-07,2.9e-05,5.567972e-07,0.038967,0.361045
min,0.689909,4.809981,5.670565,4.554714,7.444769,6.70161,7.542046,7.564413,7.380849,7.497371,...,4e-06,2.389196e-10,8.9e-05,2e-06,7.255203e-08,3.216687e-08,6.5e-05,3.940353e-07,0.100688,1.143314
25%,1.264477,5.331091,5.75909,4.706936,7.529056,6.760402,7.581869,7.605919,7.42828,7.577019,...,5e-06,3.884931e-09,0.000107,3e-06,1.233467e-07,6.25871e-08,8.4e-05,5.875182e-07,0.122329,1.38278
50%,2.026869,5.907591,5.889743,4.795774,7.568749,6.836537,7.601243,7.638857,7.481026,7.603232,...,6e-06,8.528385e-09,0.000121,4e-06,1.94975e-07,1.08514e-07,0.000107,1.001532e-06,0.150357,1.662818
75%,2.747392,6.207635,5.940037,4.872676,7.626938,6.902603,7.679225,7.711825,7.525335,7.665619,...,8e-06,2.046306e-08,0.000132,5e-06,4.400225e-07,2.707802e-07,0.00013,1.461471e-06,0.196289,1.904933
max,3.722217,6.798804,5.995183,5.037301,7.713021,6.953098,7.706873,7.764884,7.582507,7.756766,...,1e-05,4.925773e-08,0.000163,6e-06,6.071621e-07,4.296557e-07,0.000159,2.289514e-06,0.217839,2.35466


In [13]:
# Linear correlations between the sampled hyper-parameters, the log-likelihood
# and its variance over the spot-check samples.
test_full_posterior[
    [*result.search_parameter_keys, "log_likelihood", "variance"]
].corr()

Unnamed: 0,gamma,kappa,z_peak,log_likelihood,variance
gamma,1.0,0.67871,-0.90219,0.396908,0.166507
kappa,0.67871,1.0,-0.3633,0.218216,0.339513
z_peak,-0.90219,-0.3633,1.0,-0.457952,-0.26548
log_likelihood,0.396908,0.218216,-0.457952,1.0,0.116619
variance,0.166507,0.339513,-0.26548,0.116619,1.0


In [18]:
# Pairwise scatter plots of the hyper-parameters, the log-likelihood and its
# variance over the spot-check samples.
pd.plotting.scatter_matrix(
    frame=test_full_posterior.loc[:, ["gamma", "kappa", "z_peak", "log_likelihood", "variance"]],
    alpha=0.1,
)
plt.show()

  plt.show()


In [20]:
# Does the likelihood variance change with z_peak? Plotted with the explicit
# fig/ax API so the figure carries its own labels and title.
fig, ax = plt.subplots()
ax.scatter(test_full_posterior['z_peak'], test_full_posterior['variance'])
ax.set(
    xlabel='$z_{peak}$',
    ylabel='variance',
    title=f'{netw}: variance vs $z_{{peak}}$ (spot-check samples)',
)
plt.show()

  plt.show()


In [14]:
# Full-posterior version of the spot check: evaluates the jitted statistics on
# every hyper-posterior sample. Off by default (presumably slow for a long
# posterior); set RUN_FULL_POSTERIOR = True to run it.
RUN_FULL_POSTERIOR = False
if RUN_FULL_POSTERIOR:
    from IPython.display import display

    full_posterior = pd.DataFrame(
        [func(sample) for sample in result.posterior.to_dict(orient="records")]
    ).astype(float)
    display(full_posterior.describe())