In [None]:
import os

os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

import bilby
import corner
import glob
import h5py
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import numpy.lib.recfunctions as rfn
import tqdm
import seaborn as sns
from scipy import stats

from importlib import reload

import thesis_utils
from thesis_utils.gw import get_cbc_parameter_labels

reload(thesis_utils.gw.utils)
from thesis_utils.plotting import (
    set_plotting,
    save_figure,
    get_default_corner_kwargs,
    get_default_figsize,
)

set_plotting()

In [None]:
# Discovery-paper posterior samples, keyed by "low_spin" / "high_spin".
discovery = {}
with h5py.File("GW190425/discovery_posterior_samples.h5", "r") as discovery_file:
    print(discovery_file["PhenomPNRT-LS"].keys())
    discovery.update(
        (spin_key, discovery_file[f"{group}/posterior_samples"][()])
        for spin_key, group in (
            ("low_spin", "PhenomPNRT-LS"),
            ("high_spin", "PhenomPNRT-HS"),
        )
    )

In [None]:
# GWTC-2 IMRPhenomPv2_NRTidal posterior samples, keyed by "low_spin" / "high_spin".
gwtc2 = {}
with h5py.File(
    "/home/michael.williams/git_repos/nessai-gwtc-1/comparison/GW190425/GW190425/GW190425.h5",
    "r",
) as gwtc2_file:
    print(gwtc2_file.keys())
    gwtc2.update(
        (spin_key, gwtc2_file[f"C01:IMRPhenomPv2_NRTidal-{tag}/posterior_samples"][()])
        for spin_key, tag in (("low_spin", "LS"), ("high_spin", "HS"))
    )

In [None]:
# GWTC-2.1 data-release IMRPhenomPv2_NRTidal posterior samples, keyed by
# "low_spin" / "high_spin".
gwtc2p1 = {}
with h5py.File(
    "/home/michael.williams/git_repos/nessai-gwtc-1/gwtc-2.1_sample_release/IGWN-GWTC2p1-v2-GW190425_081805_PEDataRelease_mixed_nocosmo.h5",
    "r",
) as gwtc2p1_file:
    print(gwtc2p1_file["C01:IMRPhenomPv2_NRTidal:HighSpin"].keys())
    gwtc2p1.update(
        (spin_key, gwtc2p1_file[f"C01:IMRPhenomPv2_NRTidal:{tag}/posterior_samples"][()])
        for spin_key, tag in (("low_spin", "LowSpin"), ("high_spin", "HighSpin"))
    )

In [None]:
# Merged bilby result files of the low-spin nessai runs. The two runs differ
# only by the "_reweight" suffix on the output directory.
# (The GWTC-2.1 dynesty result, ProdF4_..._dynesty_merge_result.json, is
# intentionally not included.)
low_spin_result_files = {
    label: (
        "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/GW190425/"
        f"outdir_GW190425_low_spin{suffix}/final_result/"
        "GW190425_low_spin_data0_1240215503-017_analysis_L1V1_merge_result.hdf5"
    )
    for label, suffix in (("sample", ""), ("reweight", "_reweight"))
}

In [None]:
# Merged bilby result files of the high-spin nessai runs. The two runs differ
# only by the "_reweight" suffix on the output directory.
# (The GWTC-2.1 dynesty result, ProdF5_..._dynesty_merge_result.json, is
# intentionally not included.)
high_spin_result_files = {
    label: (
        "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/GW190425/"
        f"outdir_GW190425_high_spin{suffix}/final_result/"
        "GW190425_high_spin_data0_1240215503-017_analysis_L1V1_merge_result.hdf5"
    )
    for label, suffix in (("sample", ""), ("reweight", "_reweight"))
}

In [None]:
# Read every high-spin bilby result, keeping the labels of
# high_spin_result_files. Derived parameters are not added here; if needed,
# apply e.g. bilby.gw.conversion.generate_spin_parameters to result.posterior.
high_spin_results = {
    label: bilby.core.result.read_in_result(result_path)
    for label, result_path in tqdm.tqdm(high_spin_result_files.items())
}

In [None]:
# Read every low-spin bilby result, keeping the labels of
# low_spin_result_files. Derived parameters are not added here; if needed,
# apply e.g. bilby.gw.conversion.generate_all_bns_parameters to result.posterior.
low_spin_results = {
    label: bilby.core.result.read_in_result(result_path)
    for label, result_path in tqdm.tqdm(low_spin_result_files.items())
}

In [None]:
# nessai results indexed first by spin prior, then by run label.
results = dict(
    low_spin=low_spin_results,
    high_spin=high_spin_results,
)

In [None]:
# Corner plot overlaying the discovery-paper and GWTC-2.1 posteriors with the
# nessai runs for a single spin prior.
spin = "high_spin"

parameters = ["chirp_mass", "mass_ratio", "lambda_1", "lambda_2", "a_1", "a_2", "theta_jn"]
# Colour per nessai run. These match the combined low/high-spin figure, so a
# run has the same colour in every figure.
colours = {
    "gwtc2p1": "k",
    "sample": "C0",
    "reweight": "C1",
}
labels = {
    "gwtc2p1": "GWTC-2.1",
    "sample": r"\texttt{nessai} - Sample calibration",
    "reweight": r"\texttt{nessai} - Reweight calibration",
}

corner_kwargs = get_default_corner_kwargs()
corner_kwargs["plot_density"] = False
corner_kwargs["plot_datapoints"] = False
corner_kwargs["no_fill_contours"] = True
corner_kwargs["show_titles"] = False
# Drop ``fill_contours`` if the defaults set it (no KeyError if they do not).
corner_kwargs.pop("fill_contours", None)
# 1, 2 and 3 sigma levels of a 2D Gaussian: 1 - exp(-k^2 / 2).
corner_kwargs["levels"] = (1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.0))
corner_kwargs["quantiles"] = []
corner_kwargs["smooth"] = 2.0
corner_kwargs["bins"] = 32
# Private copy, because it is edited for every overlay below.
corner_kwargs["hist_kwargs"] = dict(corner_kwargs.get("hist_kwargs") or {})

# (samples, colour, linestyle, legend label) for every overlay, in drawing order.
overlays = [
    (rfn.structured_to_unstructured(discovery[spin][parameters]), "k", "--", "Discovery"),
    (rfn.structured_to_unstructured(gwtc2p1[spin][parameters]), "k", "-", "GWTC-2.1"),
] + [
    (result.posterior[parameters].to_numpy(), colours.get(key), "-", labels.get(key))
    for key, result in results[spin].items()
]

# Without an explicit range, corner sets each panel's limits from the data of
# every call. The last overlay would then decide the limits and clip the
# others, so use one range per parameter that covers all overlays.
all_overlay_samples = np.concatenate([overlay[0] for overlay in overlays], axis=0)
corner_kwargs["range"] = list(
    zip(all_overlay_samples.min(axis=0), all_overlay_samples.max(axis=0))
)

axis_labels = get_cbc_parameter_labels(parameters, units=True)
fig = None
legend_elements = []
for overlay_samples, colour, ls, label in overlays:
    corner_kwargs["color"] = colour
    corner_kwargs["hist_kwargs"]["color"] = colour
    corner_kwargs["hist_kwargs"]["ls"] = ls
    fig = corner.corner(
        overlay_samples,
        fig=fig,
        labels=axis_labels,
        labelpad=-0.1,
        contour_kwargs=dict(linestyles=[ls]),
        **corner_kwargs,
    )
    legend_elements.append(Line2D([0], [0], color=colour, ls=ls, label=label))

fig.legend(
    handles=legend_elements,
    bbox_to_anchor=(1.0, 0.9),
    loc="upper right",
    fontsize=18,
)
plt.show()
save_figure(fig, f"GW190425_{spin}_posterior", "figures/GW190425/")

## Insertion indices

In [None]:
# Glob patterns for the per-parallel-run (``par*``) nessai result files,
# indexed by calibration treatment and then by spin prior. The two treatments
# differ only by the "_reweight" suffix on the output directory.
all_paths = {
    calibration: {
        spin_prior: (
            "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/GW190425/"
            f"outdir_GW190425_{spin_prior}{suffix}/result/"
            f"GW190425_{spin_prior}_data0_1240215503-017_analysis_L1V1_par*_nessai/result.hdf5"
        )
        for spin_prior in ("low_spin", "high_spin")
    }
    for calibration, suffix in (("reweight", "_reweight"), ("sample", ""))
}

In [None]:
# Insertion-index diagnostic: for every parallel run, plot the analytic CMF
# minus the empirical CMF of its insertion indices, with binomial confidence
# bands (50%, 95%, 99.7%) in grey.

# Number of possible insertion indices; insertion indices are assumed to lie
# in [0, nlive). NOTE(review): hard-coded, confirm against the run config.
nlive = 1000
x = np.arange(nlive)
analytic = x / x[-1]
palettes = {"low_spin": "Oranges", "high_spin": "Blues"}
linestyles = {"low_spin": "--", "high_spin": "-"}
legend_names = {"low_spin": "Low spin", "high_spin": "High spin"}

for overall_label, paths in all_paths.items():
    # Load the insertion indices of every run matching each glob. Sorting makes
    # the run -> colour assignment reproducible between kernels.
    indices = {}
    for key, path in paths.items():
        nessai_result_files = sorted(glob.glob(path))
        if not nessai_result_files:
            raise FileNotFoundError(f"No nessai result files match: {path}")
        indices[key] = []
        for rf in nessai_result_files:
            with h5py.File(rf, "r") as f:
                indices[key].append(np.asarray(f["insertion_indices"][()], dtype=int))

    # One shade per run, skipping the palest shade. At least four shades are
    # generated, and more when there are more runs, so every run gets a colour.
    colours = {
        key: sns.color_palette(palettes[key], n_colors=max(4, len(data)) + 1)[1:]
        for key, data in indices.items()
    }

    figsize = 0.8 * get_default_figsize()
    fig, ax = plt.subplots(figsize=figsize)
    for key, data in indices.items():
        for run_colour, idx in zip(colours[key], data):
            # Unlike np.unique, bincount keeps zero counts for indices that were
            # never drawn, so the empirical CMF always aligns with ``analytic``.
            counts = np.bincount(idx, minlength=nlive)
            estimated = np.cumsum(counts) / len(idx)
            ax.plot(x, analytic - estimated, ls=linestyles[key], c=run_colour)

    # scipy.stats.binom needs an integer number of trials (a non-integer n
    # makes ppf return NaN), so round the mean number of indices per run.
    n = int(round(np.mean([len(idx) for data in indices.values() for idx in data])))
    for ci in [0.5, 0.95, 0.997]:
        tail = (1 - ci) / 2
        bound_values = stats.binom.ppf(1 - tail, n, analytic) / n
        # Band for (analytic - empirical), symmetric about zero and pinned to
        # zero at both ends.
        upper = bound_values - analytic
        lower = -upper
        upper[[0, -1]] = 0
        lower[[0, -1]] = 0
        ax.fill_between(x, lower, upper, color="grey", alpha=0.2)

    ax.legend(
        handles=[
            Line2D([0], [0], ls=linestyles[key], c=colours[key][-1], label=legend_names[key])
            for key in indices
        ]
    )
    ax.set_xlim(0, nlive)
    ax.set_xlabel("Insertion index")
    ax.set_ylabel(r"Analytic CMF - Empirical CMF")
    plt.show()
    save_figure(fig, f"GW190425_insertion_indices_{overall_label}", "figures/GW190425/")

In [None]:
# Combined low-/high-spin posterior figure on an ndim x (ndim + 2) grid:
# - The lower-left triangle (columns 0..ndim-1) shows the low-spin posteriors.
# - The upper-right triangle (columns 2..ndim+1) shows the high-spin
#   posteriors; column hdim there shows parameter hdim - 2.
# - Panels with column == row + 1 are hidden, and a diagonal line drawn
#   through them separates the two triangles.
parameters = [
    "chirp_mass",
    "mass_ratio",
    "a_1",
    "a_2",
    "lambda_1",
    "lambda_2",
    "theta_jn",
]
ndim = len(parameters)
spins = ("low_spin", "high_spin")
keys = ["discovery", "gwtc2p1", "sample", "reweight"]

legend_labels = {
    "discovery": "Discovery paper",
    "gwtc2p1": "GWTC-2.1",
    "sample": r"\texttt{nessai} - Calibration sampling",
    "reweight": r"\texttt{nessai} - Calibration reweighting",
}
labels = get_cbc_parameter_labels(parameters, units=True)
bins = 32
linestyles = dict(
    discovery="--",
)
colours = {
    "discovery": "k",
    "gwtc2p1": "k",
    "sample": "C0",
    "reweight": "C1",
}
# 1, 2 and 3 sigma levels of a 2D Gaussian: 1 - exp(-k^2 / 2).
levels = (1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.0))


def _get_samples(key):
    """Return ``(low_spin, high_spin)`` sample arrays for dataset ``key``.

    Each array has one column per entry of ``parameters``, in that order.
    """
    if key == "discovery":
        return tuple(rfn.structured_to_unstructured(discovery[s][parameters]) for s in spins)
    if key == "gwtc2p1":
        return tuple(rfn.structured_to_unstructured(gwtc2p1[s][parameters]) for s in spins)
    return tuple(results[s][key].posterior[parameters].to_numpy() for s in spins)


def _common_range(arrays):
    """Return an ``(ncols, 2)`` array of per-column (min, max) over ``arrays``."""
    stacked = np.concatenate(arrays, axis=0)
    return np.column_stack([stacked.min(axis=0), stacked.max(axis=0)])


def _rotate_tick_labels(tick_labels, ha=None, va=None):
    """Rotate ``tick_labels`` by 45 degrees about their anchor and align them."""
    for tick_label in tick_labels:
        tick_label.set_rotation_mode("anchor")
        tick_label.set_rotation(45)
        if va is not None:
            tick_label.set_va(va)
        if ha is not None:
            tick_label.set_ha(ha)


samples = {key: _get_samples(key) for key in keys}
# corner.hist2d sets a panel's limits from the data it is given. When
# overlaying datasets, the last one drawn would therefore set the limits
# (clipping the others), and the 1D and 2D panels of a column would not line
# up. Instead, use one range per parameter that covers every dataset,
# computed separately for each spin.
low_range = _common_range([x_low for x_low, _ in samples.values()])
high_range = _common_range([x_high for _, x_high in samples.values()])

figsize = 0.9 * np.array([10.2756, 6.10236])
fig, axs = plt.subplots(
    ndim,
    ndim + 2,
    figsize=figsize,
    gridspec_kw=dict(hspace=0.05, wspace=0.05),
)
for ax in axs.flat:
    ax.tick_params(labelbottom=False, labelleft=False)
for vdim in range(ndim):
    axs[vdim, vdim + 1].axis("off")

legend_handles = []
for key in keys:
    colour = colours.get(key)
    ls = linestyles.get(key, "-")
    legend_handles.append(
        Line2D([0], [0], ls=ls, color=colour, label=legend_labels.get(key))
    )

    hist_kwargs = dict(
        density=True,
        histtype="step",
        color=colour,
        bins=bins,
        linestyle=ls,
    )
    hist2d_kwargs = dict(
        no_fill_contours=True,
        plot_datapoints=False,
        plot_density=False,
        smooth=2.0,
        color=colour,
        bins=bins,
        contour_kwargs=dict(linestyles=[ls]),
        levels=levels,
    )
    x1, x2 = samples[key]

    for vdim in range(ndim):
        # Low-spin triangle: the 2D panels left of the diagonal, then the 1D
        # marginal on the diagonal.
        for hdim in range(vdim):
            corner.hist2d(
                x1[:, hdim],
                x1[:, vdim],
                ax=axs[vdim, hdim],
                range=[low_range[hdim], low_range[vdim]],
                **hist2d_kwargs,
            )
        marginal_ax = axs[vdim, vdim]
        marginal_ax.hist(x1[:, vdim], range=tuple(low_range[vdim]), **hist_kwargs)
        marginal_ax.set_xlim(low_range[vdim])
        marginal_ax.set_yticks([])

        # High-spin triangle, shifted two columns right: the 1D marginal, then
        # the 2D panels to its right.
        marginal_ax = axs[vdim, vdim + 2]
        marginal_ax.hist(x2[:, vdim], range=tuple(high_range[vdim]), **hist_kwargs)
        marginal_ax.set_xlim(high_range[vdim])
        marginal_ax.set_yticks([])
        for hdim in range(vdim + 3, ndim + 2):
            corner.hist2d(
                x2[:, hdim - 2],
                x2[:, vdim],
                ax=axs[vdim, hdim],
                range=[high_range[hdim - 2], high_range[vdim]],
                **hist2d_kwargs,
            )

# Parameter labels: low spin on the left and bottom edges, high spin on the
# right and top edges.
for i in range(ndim):
    if i > 0:
        axs[i, 0].tick_params(labelleft=True, rotation=45, pad=8)
        axs[i, 0].set_ylabel(labels[i])
        _rotate_tick_labels(axs[i, 0].get_yticklabels(), va="center")
    if i < ndim - 1:
        axs[i, -1].tick_params(labelright=True, rotation=45, pad=8)
        axs[i, -1].yaxis.set_label_position("right")
        axs[i, -1].set_ylabel(labels[i])
        _rotate_tick_labels(axs[i, -1].get_yticklabels(), va="center")

    axs[-1, i].tick_params(labelbottom=True, rotation=45)
    _rotate_tick_labels(axs[-1, i].get_xticklabels(), ha="right")
    axs[-1, i].set_xlabel(labels[i])

    axs[0, i + 2].tick_params(labeltop=True, rotation=45)
    _rotate_tick_labels(axs[0, i + 2].get_xticklabels(), ha="left")
    axs[0, i + 2].xaxis.set_label_position("top")
    axs[0, i + 2].set_xlabel(labels[i])

# Diagonal separator from the top-left corner of the first hidden panel to the
# bottom-right corner of the last one. Based on:
# https://stackoverflow.com/questions/60807792/arrows-between-matplotlib-subplots
transFigure = fig.transFigure.inverted()
coord1 = transFigure.transform(axs[0, 1].transAxes.transform((0.0, 1.0)))
coord2 = transFigure.transform(axs[-1, -2].transAxes.transform((1.0, 0.0)))
line = Line2D(
    (coord1[0], coord2[0]),  # xdata
    (coord1[1], coord2[1]),  # ydata
    transform=fig.transFigure,
    color="black",
)
fig.add_artist(line)

# Angle of the separator in physical units (inches), so the text follows it.
base = (coord2[0] - coord1[0]) * figsize[0]
height = (coord1[1] - coord2[1]) * figsize[1]
rotation = np.degrees(np.arctan2(height, base))

# Label every ``nth`` cell along the separator: "High spin" on the upper side
# (va="bottom") and "Low spin" on the lower side (va="top").
nth = 2
fractions = 0.5 * (1 / ndim) + np.arange(0, ndim, nth) / ndim
for fraction in fractions:
    text_x = coord1[0] + fraction * (coord2[0] - coord1[0])
    text_y = coord1[1] - fraction * (coord1[1] - coord2[1])
    for text, va in (("High spin", "bottom"), ("Low spin", "top")):
        fig.text(
            text_x,
            text_y,
            text,
            ha="center",
            va=va,
            rotation=-rotation,
            rotation_mode="anchor",
        )

fig.legend(
    handles=legend_handles,
    ncol=len(legend_handles),
    loc="center",
    bbox_to_anchor=(0.5, -0.025),
)
plt.show()
save_figure(fig, "GW190425_posteriors", "figures/GW190425/")

fig.savefig("figures/GW190425/GW190425_posteriors.png", bbox_inches="tight", dpi=300)