In [None]:
import os

os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

import glob

import bilby
import corner
import h5py
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import numpy as np
from pesummary.io import read as pe_read
import seaborn as sns

from thesis_utils.gw import get_cbc_parameter_labels
from thesis_utils.plotting import set_plotting, save_figure, get_default_figsize

from utils import EVENTS

# Plot styling: colour-blind palette, then the thesis defaults, then no axis grid.
sns.set_palette("colorblind")
set_plotting()
plt.rcParams.update({"axes.grid": False})

In [None]:
# Notebook configuration: event of interest, the bilby run label used in the
# result paths, and whether to overlay GWTC-1 prior draws on the all-events figure.
event, label, plot_samples = "GW150914", "cal_reweight", False

In [None]:
# GWTC-1 prior draws for every event, keyed by event name.
# A distinct loop variable keeps the ``event`` chosen in the config cell intact
# (later cells index ``gwtc1_prior_samples[event]``).
gwtc1_prior_samples = {}
for event_name in EVENTS:
    gwtc1_file = f"../gwtc-1_sample_release/{event_name}_GWTC-1.hdf5"
    with h5py.File(gwtc1_file, "r") as f:
        gwtc1_prior_samples[event_name] = f["prior"][()]

In [None]:
# Mass-only priors for each event: GWTC-1 (LALInference prior files) and
# GWTC-2.1 (the priors stored in the bilby IMRPhenomXPHM results).
bilby_priors = {}
li_priors = {}

# Parameters kept for the comparison; every other prior is dropped.
mass_keys = {"chirp_mass", "mass_ratio", "mass_1", "mass_2", "total_mass"}

for event_name in EVENTS:  # distinct name so the configured ``event`` survives
    print(event_name)
    try:
        lalinf_prior = bilby.gw.prior.CBCPriorDict(
            filename=f"../analysis/original_priors/prior_files/{event_name}.prior"
        )
    except OSError:
        # No LALInference prior for this event: it is left out of both dicts.
        print(f"  no LALInference prior file for {event_name}, skipping")
        continue

    result_pattern = (
        f"../analysis/IMRPhenomXPHM/outdir_nessai_gwtc_1_{event_name}_{label}_*"
        "/final_result/*.hdf5"
    )
    result_files = glob.glob(result_pattern)
    if not result_files:
        raise FileNotFoundError(f"No bilby result matches {result_pattern!r}")
    Pv2_file = result_files[0]
    Pv2_result = bilby.core.result.read_in_result(Pv2_file)
    bilby_prior = Pv2_result.priors.copy()

    for prior in (bilby_prior, lalinf_prior):
        for key in set(prior.keys()) - mass_keys:
            prior.pop(key)
        # The grid samples evaluated later include total_mass, so make sure both
        # priors define it (as a wide constraint when the original prior did not).
        if "total_mass" not in prior:
            prior["total_mass"] = bilby.core.prior.Constraint(
                1, 1000, name="total_mass"
            )

    bilby_priors[event_name] = bilby_prior
    li_priors[event_name] = lalinf_prior

# GWTC-1 posterior samples for the single-event comparisons later in the notebook.
# NOTE(review): "Overall_posterior" is assumed to be the intended dataset of the
# GWTC-1 release file; use "IMRPhenomPv2_posterior" instead if that was meant.
gwtc1_posterior_event = "GW150914"
with h5py.File(
    f"../gwtc-1_sample_release/{gwtc1_posterior_event}_GWTC-1.hdf5", "r"
) as f:
    gwtc1_samples = f["Overall_posterior"][()]

gwtc1_chirp_mass = bilby.gw.conversion.component_masses_to_chirp_mass(
    gwtc1_samples["m1_detector_frame_Msun"],
    gwtc1_samples["m2_detector_frame_Msun"],
)
gwtc1_mass_ratio = bilby.gw.conversion.component_masses_to_mass_ratio(
    gwtc1_samples["m1_detector_frame_Msun"],
    gwtc1_samples["m2_detector_frame_Msun"],
)

In [None]:
# Broad detector-frame (m1, m2) grid: m1 over [1, 500], m2 over [1, 210].
mplot_min = 1
mplot_max = 500
n = 1000
m1_vec = np.linspace(mplot_min, mplot_max, n, endpoint=True)
m2_vec = np.linspace(mplot_min, 210, n, endpoint=True)
# Use m2_vec for the second axis so m2 spans its own, narrower range.
m1, m2 = np.meshgrid(m1_vec, m2_vec)

In [None]:
# Flattened grid points plus every derived mass parameter, so the prior
# dictionaries can be evaluated point by point.
samples = {"mass_1": m1.flatten(), "mass_2": m2.flatten()}
component_masses = (samples["mass_1"], samples["mass_2"])
samples.update(
    mass_ratio=samples["mass_2"] / samples["mass_1"],
    chirp_mass=bilby.gw.conversion.component_masses_to_chirp_mass(*component_masses),
    total_mass=bilby.gw.conversion.component_masses_to_total_mass(*component_masses),
)

In [None]:
# Events drawn in the all-events comparison figure. To plot a subset, replace
# this with an explicit list (e.g. ["GW170729", "GW170809"]); a list keeps the
# panel order deterministic, unlike a set.
events_to_include = EVENTS

In [None]:
# Detector-frame primary masses drawn from the GWTC-1 prior of the current ``event``.
event_prior_draws = gwtc1_prior_samples[event]
m1_samples = event_prior_draws["m1_detector_frame_Msun"]

In [None]:
# Plotting window per event: both component masses start at 1; the table
# gives the upper limits (mass_1, mass_2) in solar masses.
mass_ranges = {
    name: {"mass_1": [1, m1_upper], "mass_2": [1, m2_upper]}
    for name, m1_upper, m2_upper in (
        ("GW150914", 270, 85),
        ("GW151012", 200, 45),
        ("GW151226", 75, 15),
        ("GW170104", 250, 65),
        ("GW170608", 65, 15),
        ("GW170729", 550, 220),
        ("GW170809", 270, 55),
        ("GW170814", 250, 65),
        ("GW170818", 300, 65),
        ("GW170823", 400, 75),
    )
}

In [None]:
# Proxy artists for the figure legend: solid C0 = GWTC-1, dashed C1 = GWTC-2.1.
legend_elements = [
    Line2D([0], [0], color=colour, ls=linestyle, label=catalogue)
    for colour, linestyle, catalogue in (
        ("C0", "-", "GWTC-1"),
        ("C1", "--", "GWTC-2.1"),
    )
]

In [None]:
# Compare, event by event, where the GWTC-1 (LALInference) and GWTC-2.1 (bilby)
# mass priors have support on a grid of detector-frame component masses.
N_GRID = 200  # grid points along each mass axis


def _mass_grid_samples(event_name, n_points):
    """Build a regular (m1, m2) grid spanning ``mass_ranges[event_name]``.

    Returns ``(m1_grid, m2_grid, grid_samples)`` where ``grid_samples`` maps
    mass_1, mass_2, mass_ratio, chirp_mass and total_mass to flattened arrays.
    """
    m1_axis = np.linspace(*mass_ranges[event_name]["mass_1"], n_points, endpoint=True)
    m2_axis = np.linspace(*mass_ranges[event_name]["mass_2"], n_points, endpoint=True)
    m1_grid, m2_grid = np.meshgrid(m1_axis, m2_axis)
    grid_samples = dict(mass_1=m1_grid.flatten(), mass_2=m2_grid.flatten())
    grid_samples["mass_ratio"] = grid_samples["mass_2"] / grid_samples["mass_1"]
    grid_samples["chirp_mass"] = bilby.gw.conversion.component_masses_to_chirp_mass(
        grid_samples["mass_1"], grid_samples["mass_2"]
    )
    grid_samples["total_mass"] = bilby.gw.conversion.component_masses_to_total_mass(
        grid_samples["mass_1"], grid_samples["mass_2"]
    )
    return m1_grid, m2_grid, grid_samples


def _support_map(prior, grid_samples, shape):
    """Return a 0/1 float array marking grid points where ``prior`` has finite ln-probability."""
    ln_prob = prior.ln_prob(grid_samples, axis=0).reshape(shape)
    return np.isfinite(ln_prob).astype(float)


figsize = get_default_figsize()
figsize[1] *= 2
fig, axs = plt.subplots(5, 2, figsize=figsize)
axs = axs.ravel()

for i, event_name in enumerate(events_to_include):
    ax = axs[i]
    if event_name not in li_priors or event_name not in bilby_priors:
        # The prior-loading cell skips events without a LALInference prior
        # file; leave their panel blank instead of raising KeyError.
        ax.set_axis_off()
        continue

    m1_grid, m2_grid, grid_samples = _mass_grid_samples(event_name, N_GRID)
    li_support = _support_map(li_priors[event_name], grid_samples, m1_grid.shape)
    bilby_support = _support_map(bilby_priors[event_name], grid_samples, m1_grid.shape)

    # The support maps only take the values 0 and 1, so the 0.5 level is the
    # boundary between them (a level of 1.0 sits at the edge of the data range).
    ax.contour(m1_grid, m2_grid, li_support, [0.5], colors=["C0"])
    ax.contour(
        m1_grid,
        m2_grid,
        bilby_support,
        [0.5],
        colors=["C1"],
        linestyles="--",
    )

    if plot_samples:
        prior_m1 = gwtc1_prior_samples[event_name]["m1_detector_frame_Msun"]
        prior_m2 = gwtc1_prior_samples[event_name]["m2_detector_frame_Msun"]
        corner.hist2d(
            prior_m1,
            prior_m2,
            plot_density=False,
            plot_datapoints=False,
            no_fill_contours=True,
            smooth=0.9,
            ax=ax,
            color="C2",
            bins=100,
            levels=[0.99, 0.5],
        )
        ax.set_xlim(0.9 * prior_m1.min(), 1.1 * prior_m1.max())
        ax.set_ylim(0.9 * prior_m2.min(), 1.1 * prior_m2.max())
    else:
        ax.set_xlim(*mass_ranges[event_name]["mass_1"])
        ax.set_ylim(*mass_ranges[event_name]["mass_2"])

    if i % 2 == 0:
        ax.set_ylabel(get_cbc_parameter_labels("mass_2", units=True))
    ax.text(0.9, 0.8, event_name, transform=ax.transAxes, ha="right")

# Hide any panels beyond the number of events plotted.
for ax in axs[len(events_to_include):]:
    ax.set_axis_off()

axs[-2].set_xlabel(get_cbc_parameter_labels("mass_1", units=True))
axs[-1].set_xlabel(get_cbc_parameter_labels("mass_1", units=True))

fig.legend(
    handles=legend_elements,
    handlelength=2.0,
    ncol=2,
    bbox_transform=fig.transFigure,
    bbox_to_anchor=(0.5, -0.02),
    loc="lower center",
)

plt.tight_layout()
output_dir = "figures/prior_comparison/"
os.makedirs(output_dir, exist_ok=True)
# The legend is anchored just below the figure, so save with a tight bounding
# box to keep it inside the PNG; save before plt.show().
fig.savefig(os.path.join(output_dir, "all.png"), bbox_inches="tight")
save_figure(fig, "all_priors", path=output_dir)
plt.show()

In [None]:
# Single-event check: GW150914 mass priors from a mock LALInference prior file
# versus the bilby result, restricted to the mass parameters.
comparison_event = "GW150914"
lalinf_prior = bilby.gw.prior.CBCPriorDict(
    filename=f"prior_files/mock_lalinference_priors/{comparison_event}.prior"
)

# Load this event's bilby result explicitly rather than reusing ``Pv2_result``
# from the per-event loop, which holds the last event that loop processed.
result_pattern = (
    f"../analysis/IMRPhenomXPHM/outdir_nessai_gwtc_1_{comparison_event}_{label}_*"
    "/final_result/*.hdf5"
)
result_files = glob.glob(result_pattern)
if not result_files:
    raise FileNotFoundError(f"No bilby result matches {result_pattern!r}")
Pv2_result = bilby.core.result.read_in_result(result_files[0])

# Posterior samples used by the hist2d comparisons below.
# NOTE(review): assumes the posterior table has mass_1, mass_2, chirp_mass and
# mass_ratio columns — confirm for this result file.
bilby_samples = Pv2_result.posterior
bilby_prior = Pv2_result.priors.copy()

mass_keys = {"chirp_mass", "mass_ratio", "mass_1", "mass_2", "total_mass"}
remove_bilby = set(bilby_prior.keys()) - mass_keys
remove_lalinf = set(lalinf_prior.keys()) - mass_keys

for k in remove_bilby:
    bilby_prior.pop(k)

for k in remove_lalinf:
    lalinf_prior.pop(k)

In [None]:
# GW150914-focused (m1, m2) grid: m1 over [5, 160], m2 over [5, 81].
mplot_min, mplot_max = 5, 160
n = 500
m1_vec = np.linspace(mplot_min, mplot_max, num=n)
m2_vec = np.linspace(5, 81, num=n)
m1, m2 = np.meshgrid(m1_vec, m2_vec, indexing="xy")

In [None]:
# Flattened grid points plus the derived mass ratio and chirp mass, keyed by
# the parameter names of the pruned prior dictionaries.
flat_m1 = m1.flatten()
flat_m2 = m2.flatten()
samples = {
    "mass_1": flat_m1,
    "mass_2": flat_m2,
    "mass_ratio": flat_m2 / flat_m1,
    "chirp_mass": bilby.gw.conversion.component_masses_to_chirp_mass(flat_m1, flat_m2),
}

In [None]:
# Display the mass-only mock LALInference prior for inspection.
lalinf_prior

In [None]:
# ln-prior of each prior on the grid, reshaped back to the (m2, m1) mesh layout.
grid_shape = m1.shape
bilby_log_prior, li_log_prior = (
    prior.ln_prob(samples, axis=0).reshape(grid_shape)
    for prior in (bilby_prior, lalinf_prior)
)

In [None]:
# Support boundaries of the two GW150914 mass priors on the (m1, m2) grid:
# mock LALInference (solid, C0) vs bilby (dashed, C1), the same colours as the
# all-events figure.
fig, ax = plt.subplots(1, 1)
# isfinite gives a 0/1 map; the 0.5 level is the edge of each prior's support.
ax.contour(m1, m2, np.isfinite(li_log_prior).astype(float), [0.5], colors="C0")
ax.contour(
    m1,
    m2,
    np.isfinite(bilby_log_prior).astype(float),
    [0.5],
    colors="C1",
    linestyles="--",
)
ax.set_xlabel(get_cbc_parameter_labels("mass_1", units=True))
ax.set_ylabel(get_cbc_parameter_labels("mass_2", units=True))
ax.legend(
    handles=[
        Line2D([0], [0], color="C0", ls="-", label="LALInference"),
        Line2D([0], [0], color="C1", ls="--", label="bilby"),
    ]
)
plt.show()

In [None]:
# Filled contours of the bilby ln-prior over the GW150914 (m1, m2) grid.
fig, ax = plt.subplots()
ax.contourf(m1, m2, bilby_log_prior)
plt.show()

In [None]:
# GW150914 component-mass posteriors: GWTC-1 (LALInference) vs bilby.
fig, ax = plt.subplots(1, 1)
mass_hist_kwargs = dict(
    ax=ax, bins=50, plot_density=False, plot_datapoints=False, smooth=0.9
)
corner.hist2d(
    gwtc1_samples["m1_detector_frame_Msun"],
    gwtc1_samples["m2_detector_frame_Msun"],
    color="C0",
    **mass_hist_kwargs,
)
corner.hist2d(
    bilby_samples["mass_1"],
    bilby_samples["mass_2"],
    color="C1",
    **mass_hist_kwargs,
)
# The hist2d contours carry no legend entries, so label the two sets of
# contours with proxy artists.
ax.legend(
    handles=[
        Line2D([0], [0], color="C0", label="lalinference"),
        Line2D([0], [0], color="C1", label="bilby"),
    ]
)
ax.set_xlabel(r"$m_1$")
ax.set_ylabel(r"$m_2$")
ax.set_xlim(5, 150)
ax.set_ylim(5, 85)
plt.show()

In [None]:
# GW150914 chirp-mass / mass-ratio posteriors: GWTC-1 (LALInference) vs bilby.
fig, ax = plt.subplots(1, 1)
chirp_hist_kwargs = dict(ax=ax, bins=50, plot_density=False, smooth=0.9)
corner.hist2d(gwtc1_chirp_mass, gwtc1_mass_ratio, color="C0", **chirp_hist_kwargs)
corner.hist2d(
    bilby_samples["chirp_mass"],
    bilby_samples["mass_ratio"],
    color="C1",
    **chirp_hist_kwargs,
)
# The hist2d contours carry no legend entries, so label the two sets of
# contours with proxy artists.
ax.legend(
    handles=[
        Line2D([0], [0], color="C0", label="lalinference"),
        Line2D([0], [0], color="C1", label="bilby"),
    ]
)
ax.set_xlabel(r"$\mathcal{M}$")
ax.set_ylabel(r"$q$")
ax.set_xlim(10, 50)
ax.set_ylim(0.0, 1.0)
plt.show()