In [None]:
import os

# Disable HDF5 file locking (commonly required on shared/network filesystems).
# Set here, before h5py and bilby are imported below.
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
import glob

import bilby
import corner
import h5py
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import numpy.lib.recfunctions as rfn
from scipy import stats
import seaborn as sns

# Project-local (thesis_utils) helpers: parameter axis labels and figure styling/saving.
from thesis_utils.gw import get_cbc_parameter_labels
from thesis_utils.plotting import set_plotting, save_figure, get_default_figsize

# Presumably applies the project's matplotlib style; called before any figure is created.
set_plotting()

In [None]:
# Location of the nessai GWTC-1 GW170817 analyses. Defaults to the original
# author's checkout; set NESSAI_GWTC1_GW170817_DIR to point at another copy.
GW170817_ANALYSIS_DIR = os.environ.get(
    "NESSAI_GWTC1_GW170817_DIR",
    "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/GW170817",
)

# Plot label -> run-directory suffix for the analyses compared in this notebook.
# Other runs that exist but are not used here:
#   "reweight": "reweight_cal", "sample": "sample_calibration",
#   "sample-fixed": "sample_calibration_fix_sky",
#   "reweight-fixed": "reweight_cal_fix_sky".
nessai_run_dirs = {
    "sample-l1l2": "sample_calibration_fix_sky_l1_l2",
    "reweight-l1l2": "reweight_cal_fix_sky_l1_l2",
    "sample-priors": "sample_calibration_fix_sky_priors",
    "reweight-priors": "reweight_cal_fix_sky_priors",
}


def nessai_result_path(spin, run_dir):
    """Return the merged bilby result file for one spin prior and run directory.

    Parameters
    ----------
    spin : str
        Spin-prior tag used in the directory names, e.g. "low_spin".
    run_dir : str
        Run-directory suffix, e.g. "sample_calibration_fix_sky_priors".

    Notes
    -----
    The result file name always contains "sample_cal", including for the
    reweighting runs.
    """
    return (
        f"{GW170817_ANALYSIS_DIR}/outdir_nessai_gwtc_1_GW170817_{spin}_{run_dir}"
        f"/final_result/nessai_gwtc_1_GW170817_{spin}_sample_cal_data0_1187008882-4457"
        "_analysis_H1L1V1_merge_result.hdf5"
    )


# spin prior -> plot label -> result file path
nessai_result_files = {
    spin_prior: {
        label: nessai_result_path(spin_prior, run_dir)
        for label, run_dir in nessai_run_dirs.items()
    }
    for spin_prior in ("low_spin", "high_spin")
}

In [None]:
# Read every configured nessai result into nessai_results[spin][label].
# A run whose file raises OSError when read is reported and left out.
nessai_results = {}
for spin_key, label_to_file in nessai_result_files.items():
    loaded = nessai_results.setdefault(spin_key, {})
    for run_label, result_file in label_to_file.items():
        try:
            loaded_result = bilby.core.result.read_in_result(result_file)
        except OSError:
            print(f"Skipping: {result_file}")
            continue
        loaded[run_label] = loaded_result

In [None]:
# GWTC-1 data-release datasets for the two spin-prior analyses of GW170817.
gwtc1_datasets = {
    "low_spin": "IMRPhenomPv2NRT_lowSpin_posterior",
    "high_spin": "IMRPhenomPv2NRT_highSpin_posterior",
}

gwtc1_results = {}
with h5py.File("../gwtc-1_sample_release/GW170817_GWTC-1.hdf5", "r") as f:
    for key, dataset in gwtc1_datasets.items():
        # [()] reads the whole structured dataset into memory before the file closes.
        gwtc1_results[key] = f[dataset][()]

# Append the derived parameters used in the comparison plots (bilby naming),
# computed from the GWTC-1 detector-frame masses and tidal parameters.
for key, s in list(gwtc1_results.items()):
    m1 = s["m1_detector_frame_Msun"]
    m2 = s["m2_detector_frame_Msun"]
    gwtc1_results[key] = rfn.append_fields(
        s,
        names=[
            "theta_jn",
            "lambda_tilde",
            "delta_lambda_tilde",
            "mass_ratio",
            "chirp_mass",
        ],
        data=[
            np.arccos(s["costheta_jn"]),
            bilby.gw.conversion.lambda_1_lambda_2_to_lambda_tilde(
                s["lambda1"], s["lambda2"], m1, m2
            ),
            bilby.gw.conversion.lambda_1_lambda_2_to_delta_lambda_tilde(
                s["lambda1"], s["lambda2"], m1, m2
            ),
            bilby.gw.conversion.component_masses_to_mass_ratio(m1, m2),
            bilby.gw.conversion.component_masses_to_chirp_mass(m1, m2),
        ],
        # The default (usemask=True) returns a masked array; the plotting code
        # only needs a plain structured ndarray.
        usemask=False,
    )

In [None]:
# Inspect the available fields: the GWTC-1 columns plus the derived ones appended above.
gwtc1_results["low_spin"].dtype.names

In [None]:
# bilby parameter name -> column name in the GWTC-1 sample release.
# Names missing from this mapping are looked up unchanged (via .get(p, p)).
gwtc1_mapping = {
    "mass_1": "m1_detector_frame_Msun",
    "mass_2": "m2_detector_frame_Msun",
    "lambda_1": "lambda1",
    "lambda_2": "lambda2",
    "ra": "right_ascension",
    "dec": "declination",
    "cos_theta_jn": "costheta_jn",
    "a_1": "spin1",
    "a_2": "spin2",
}

In [None]:
# Parameters of interest for the GW170817 comparison.
# NOTE(review): `keys` is not referenced by any of the visible cells below.
keys = "mass_1 mass_2 lambda_tilde delta_lambda_tilde a_1 a_2 ra dec".split()

In [None]:
# Spin-prior analysis shown in the single-panel tidal comparison below.
# NOTE(review): `event` is not referenced by any of the visible cells.
spin, event = "high_spin", "GW170817"

In [None]:
# Overlay the (lambda_tilde, delta_lambda_tilde) posteriors from every nessai
# run loaded for the selected spin prior.
# 2D credible levels equivalent to 1, 2 and 3 sigma for a 2D Gaussian:
# 1 - exp(-k**2 / 2) for k = 1, 2, 3.
sigma_levels = (1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.0))
tidal_labels = get_cbc_parameter_labels(
    ["lambda_tilde", "delta_lambda_tilde"], units=True
)

fig, ax = plt.subplots()

legend_handles = []
for i, (label, result) in enumerate(nessai_results[spin].items()):
    colour = f"C{i}"
    corner.hist2d(
        result.posterior["lambda_tilde"].to_numpy(),
        result.posterior["delta_lambda_tilde"].to_numpy(),
        color=colour,
        ax=ax,
        no_fill_contours=True,
        plot_datapoints=False,
        plot_density=False,
        smooth=1.0,
        levels=sigma_levels,
    )
    # corner.hist2d draws no legend entries, so build proxy handles.
    legend_handles.append(Line2D([0], [0], c=colour, label=label))

ax.set_xlim(-250, 2000)
ax.set_ylim(-500, 500)
ax.set_xlabel(tidal_labels[0])
ax.set_ylabel(tidal_labels[1])
ax.legend(handles=legend_handles)
plt.show()

In [None]:
# Corner-style comparison of the GWTC-1 and nessai posteriors for GW170817.
#
# The grid has ndim rows and ndim + 2 columns:
#   * row >= column      -> low-spin analyses (1D histograms where row == column),
#   * row == column - 1  -> blank panels, crossed by a diagonal separator line,
#   * row <= column - 2  -> high-spin analyses (1D histograms where row == column - 2).
parameters = ["mass_1", "mass_2", "lambda_tilde", "delta_lambda_tilde", "a_1", "a_2"]

ndim = len(parameters)
figsize = 0.9 * np.array([10.2756, 6.10236])
fig, axs = plt.subplots(
    ndim,
    ndim + 2,
    figsize=figsize,
    gridspec_kw=dict(hspace=0.05, wspace=0.05),
)

legend_labels = {
    "gwtc1": "GWTC-1",
    "sample-priors": r"\texttt{nessai} - Calibration sampling",
    "reweight-priors": r"\texttt{nessai} - Calibration reweighting",
}
legend_handles = []

labels = get_cbc_parameter_labels(parameters, units=True)
bins = 32
# 2D credible levels equivalent to 1, 2 and 3 sigma for a 2D Gaussian.
sigma_levels = (1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.0))

linestyles = dict(
    gwtc1="--",
)
colours = {
    "gwtc1": "k",
    "sample-priors": "C0",
    "reweight-priors": "C1",
    "sample-l1l2": "C2",
    "reweight-l1l2": "C3",
}


def _union_range(current, samples):
    """Return the per-parameter [min, max] spanning both `current` and `samples`.

    `samples` is a 2D array (samples x parameters); `current` is the
    (parameters x 2) array returned by a previous call, or None for the first dataset.
    """
    new = np.array([samples.min(axis=0), samples.max(axis=0)]).T
    if current is None:
        return new
    return np.array(
        [np.minimum(current[:, 0], new[:, 0]), np.maximum(current[:, 1], new[:, 1])]
    ).T


def _rotate_ticklabels(ticklabels, **alignment):
    """Rotate tick labels by 45 degrees about their anchor and apply `alignment`."""
    for ticklabel in ticklabels:
        ticklabel.set_rotation_mode("anchor")
        ticklabel.set_rotation(45)
        ticklabel.set(**alignment)


# Axis limits accumulate over all datasets drawn so far, so that later
# datasets never clip earlier ones.
x1range = None
x2range = None

for key in ["gwtc1", "sample-priors", "reweight-priors"]:
    colour = colours.get(key)
    ls = linestyles.get(key, "-")

    legend_handles.append(
        Line2D([0], [0], ls=ls, color=colour, label=legend_labels.get(key))
    )

    hist_kwargs = dict(
        density=True,
        histtype="step",
        color=colour,
        bins=bins,
        linestyle=ls,
    )
    hist2d_kwargs = dict(
        no_fill_contours=True,
        plot_datapoints=False,
        plot_density=False,
        smooth=2.0,
        color=colour,
        bins=bins,
        contour_kwargs=dict(linestyles=[ls]),
        levels=sigma_levels,
    )

    # x1: low-spin samples, x2: high-spin samples; both (n_samples, ndim).
    if key == "gwtc1":
        gwtc1_parameters = [gwtc1_mapping.get(p, p) for p in parameters]
        x1 = rfn.structured_to_unstructured(
            rfn.repack_fields(gwtc1_results["low_spin"][gwtc1_parameters])
        )
        x2 = rfn.structured_to_unstructured(
            rfn.repack_fields(gwtc1_results["high_spin"][gwtc1_parameters])
        )
    else:
        x1 = nessai_results["low_spin"][key].posterior[parameters].to_numpy()
        x2 = nessai_results["high_spin"][key].posterior[parameters].to_numpy()

    x1range = _union_range(x1range, x1)
    x2range = _union_range(x2range, x2)

    for vdim in range(ndim):
        for hdim in range(ndim + 2):
            panel = axs[vdim, hdim]
            panel.tick_params(labelbottom=False, labelleft=False)
            if vdim == hdim - 1:
                # Separator diagonal between the two triangles.
                panel.axis("off")
            elif vdim == hdim:
                # Low-spin 1D marginal.
                panel.hist(x1[:, vdim], **hist_kwargs)
                panel.set_yticks([])
                panel.set_xlim(x1range[vdim])
            elif vdim > hdim:
                # Low-spin 2D marginal.
                corner.hist2d(x1[:, hdim], x1[:, vdim], ax=panel, **hist2d_kwargs)
                panel.set_xlim(x1range[hdim])
                panel.set_ylim(x1range[vdim])
            elif vdim == hdim - 2:
                # High-spin 1D marginal.
                panel.hist(x2[:, vdim], **hist_kwargs)
                panel.set_yticks([])
                panel.set_xlim(x2range[vdim])
            else:
                # High-spin 2D marginal (vdim < hdim - 2); columns are offset by 2.
                corner.hist2d(x2[:, hdim - 2], x2[:, vdim], ax=panel, **hist2d_kwargs)
                panel.set_xlim(x2range[hdim - 2])
                panel.set_ylim(x2range[vdim])

for i in range(ndim):
    # Left column: y labels for the low-spin rows (top-left panel is a histogram).
    if i > 0:
        axs[i, 0].tick_params(labelleft=True, rotation=45, pad=8)
        axs[i, 0].set_ylabel(labels[i])
        _rotate_ticklabels(axs[i, 0].get_yticklabels(), verticalalignment="center")
    # Right column: y labels for the high-spin rows (bottom-right panel is a histogram).
    if i < ndim - 1:
        axs[i, -1].tick_params(labelright=True, rotation=45, pad=8)
        axs[i, -1].yaxis.set_label_position("right")
        axs[i, -1].set_ylabel(labels[i])
        _rotate_ticklabels(axs[i, -1].get_yticklabels(), verticalalignment="center")

    # Bottom row: x labels for the low-spin columns.
    axs[-1, i].tick_params(labelbottom=True, rotation=45)
    _rotate_ticklabels(axs[-1, i].get_xticklabels(), horizontalalignment="right")
    axs[-1, i].set_xlabel(labels[i])

    # Top row: x labels for the high-spin columns.
    axs[0, i + 2].tick_params(labeltop=True, rotation=45)
    _rotate_ticklabels(axs[0, i + 2].get_xticklabels(), horizontalalignment="left")
    axs[0, i + 2].xaxis.set_label_position("top")
    axs[0, i + 2].set_xlabel(labels[i])


# Diagonal separator through the blank panels, from the top-left corner of
# axs[0, 1] to the bottom-right corner of axs[-1, -2], in figure coordinates.
# Based on: https://stackoverflow.com/questions/60807792/arrows-between-matplotlib-subplots
transFigure = fig.transFigure.inverted()
coord1 = transFigure.transform(axs[0, 1].transAxes.transform((0.0, 1.0)))
coord2 = transFigure.transform(axs[-1, -2].transAxes.transform((1.0, 0.0)))
separator = Line2D(
    (coord1[0], coord2[0]),  # xdata
    (coord1[1], coord2[1]),  # ydata
    transform=fig.transFigure,
    color="black",
)
fig.add_artist(separator)

# Angle of the separator in physical units (inches) so the text runs parallel to it.
base = (coord2[0] - coord1[0]) * figsize[0]
height = (coord1[1] - coord2[1]) * figsize[1]
rotation = np.degrees(np.arctan(height / base))

# Label every `nth` blank panel along the separator, at the panel centres.
nth = 2
fractions = 0.5 / ndim + np.arange(0, ndim, nth) / ndim

for fraction in fractions:
    text_x = coord1[0] + fraction * (coord2[0] - coord1[0])
    text_y = coord1[1] - fraction * (coord1[1] - coord2[1])
    # "High spin" sits above the line, "Low spin" below it.
    for text, va in (("High spin", "bottom"), ("Low spin", "top")):
        fig.text(
            text_x,
            text_y,
            text,
            ha="center",
            va=va,
            rotation=-rotation,
            rotation_mode="anchor",
        )

fig.legend(
    handles=legend_handles,
    ncol=len(legend_handles),
    loc="center",
    bbox_to_anchor=(0.5, -0.025),
)
plt.show()
save_figure(fig, "GW170817_posteriors", "figures/GW170817/")

In [None]:
# Collect the insertion indices stored by nessai for each run:
# nessai_indices[spin][label] is a list with one array per "*_nessai/result.hdf5"
# file found in that run's bilby output directory.
nessai_indices = {}
for key, results in nessai_results.items():
    nessai_indices[key] = {}
    for label, result in results.items():
        # Sorted so the run order (and hence plot colour) is reproducible;
        # glob returns files in arbitrary order.
        files = sorted(
            glob.glob(os.path.join(result.outdir, "*_nessai", "result.hdf5"))
        )
        if not files:
            print(f"No nessai result files found in: {result.outdir}")
        nessai_indices[key][label] = []
        for rf in files:
            with h5py.File(rf, "r") as f:
                nessai_indices[key][label].append(f["insertion_indices"][:])

In [None]:
# Insertion-index diagnostic for each analysis: plot the analytic CMF of a
# uniform distribution over [0, nlive) minus the empirical CMF of the recorded
# insertion indices, with binomial bands at 50%, 95% and 99.7%.
nlive = 1000
colours = {
    "low_spin": sns.color_palette("Oranges", n_colors=3)[1:],
    "high_spin": sns.color_palette("Blues", n_colors=3)[1:],
}
linestyles = {
    "low_spin": "--",
    "high_spin": "-",
}
spin_labels = {
    "low_spin": "Low spin",
    "high_spin": "High spin",
}

x = np.arange(0, nlive, 1)
analytic = x / x[-1]

for analysis in ["sample-l1l2", "reweight-l1l2"]:
    indices = {
        spin_key: nessai_indices[spin_key][analysis]
        for spin_key in ("low_spin", "high_spin")
    }
    all_runs = indices["low_spin"] + indices["high_spin"]
    if not all_runs:
        print(f"No insertion indices found for {analysis}, skipping")
        continue
    # Mean number of insertions per run; rounded because scipy's binomial
    # distribution needs an integer number of trials (non-integer n gives NaN).
    n = int(round(np.mean([len(idx) for idx in all_runs])))

    figsize = 0.8 * get_default_figsize()
    fig = plt.figure(figsize=figsize)
    for key, data in indices.items():
        if not data:
            continue
        palette = colours[key]
        for i, idx in enumerate(data):
            # bincount (unlike np.unique) also counts indices that never
            # occurred, so `estimated` always has nlive entries aligned with
            # `analytic`.
            counts = np.bincount(np.asarray(idx, dtype=int), minlength=nlive)
            estimated = np.cumsum(counts) / len(idx)
            plt.plot(
                analytic - estimated,
                ls=linestyles[key],
                c=palette[i % len(palette)],
            )

    for ci in [0.5, 0.95, 0.997]:
        bound = (1 - ci) / 2
        bound_values = stats.binom.ppf(1 - bound, n, analytic) / n
        lower = bound_values - analytic
        upper = analytic - bound_values
        # Collapse the band at the end points, where analytic is 0 and 1 and the
        # binomial is degenerate.
        for band in (lower, upper):
            band[[0, -1]] = 0

        plt.fill_between(x, lower, upper, color="grey", alpha=0.2)

    legend_elements = [
        Line2D([0], [0], ls=linestyles[k], c=colours[k][-1], label=spin_labels[k])
        for k in ("low_spin", "high_spin")
    ]
    plt.legend(handles=legend_elements)

    plt.xlim(0, nlive)
    plt.xlabel("Insertion index")
    plt.ylabel(r"Analytic CMF - Empirical CMF")
    plt.show()
    save_figure(fig, f"GW170817_insertion_indices_{analysis}", "figures/GW170817/")