In [None]:
import os

import bilby
import corner
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from thesis_utils.gw import get_cbc_parameter_labels
from thesis_utils.gw import injections
from thesis_utils.plotting import (
    get_default_corner_kwargs,
    set_plotting,
    get_default_figsize,
    save_figure,
)

# Apply the project's plotting defaults from thesis_utils.plotting before any
# figure is created (the exact settings are defined inside `set_plotting`).
set_plotting()

In [None]:
# Bilby results for each detector network ("et_only", "et_plus_ce") and each run
# type ("aligned", "precessing"), as results[network][run_type].
# The output directory defaults to the original author's checkout; set the
# NESSAI_ET_OUTDIR environment variable to point at a local copy instead.
RESULTS_DIR = os.environ.get(
    "NESSAI_ET_OUTDIR", "/home/michael/git_repos/nessai-et/outdir"
)

# Result files relative to RESULTS_DIR.
# NOTE(review): the aligned ET+CE run lives in a directory named "1ET_..." while
# the precessing ET+CE run uses "1ET_1CE_..." — confirm this is the intended run.
result_files = dict(
    et_only=dict(
        aligned="1ET_2023_04_19_XAS/1ET_result.json",
        precessing="1ET_2023_05_05_XP/1ET_result.json",
    ),
    et_plus_ce=dict(
        aligned="1ET_2023_04_21_XAS/1ET_1CE_result.json",
        precessing="1ET_1CE_2023_05_05_XP/1ET_1CE_result.json",
    ),
)

results = {
    network: {
        run_type: bilby.core.result.read_in_result(
            os.path.join(RESULTS_DIR, relative_path)
        )
        for run_type, relative_path in runs.items()
    }
    for network, runs in result_files.items()
}

In [None]:
# NOTE(review): `aligned` is not read anywhere in this notebook; the run that is
# plotted is chosen by `diag_keys` in the plotting cell.
aligned = False

# Posterior parameters to plot, in panel order. "geocent_time" is swapped for
# the offset "dt" in the next cell.
parameters = [
    "chirp_mass", "mass_ratio", "ra", "dec", "theta_jn", "psi", "geocent_time",
]
# Alternative, spin-focused selection:
# parameters = ["chirp_mass", "mass_ratio", "theta_jn", "geocent_time", "chi_eff", "chi_p"]

# Injected (true) parameters: the GW150914-like BBH converted to aligned spins,
# with the luminosity distance overridden (bilby distance units — presumably Mpc).
# NOTE(review): the aligned conversion is applied even when the precessing runs
# are plotted — confirm the injected signal was aligned-spin.
raw_injection = injections.BBH_GW150914.convert_to_aligned().bilby_format()
raw_injection["luminosity_distance"] = 4000.0

# Fill in the derived BBH parameters that the truth lines are looked up from.
injection_parameters = bilby.gw.conversion.generate_all_bbh_parameters(
    raw_injection
)
# "dt" is measured relative to the injected geocent_time, so its true value is 0.
injection_parameters["dt"] = 0.0

In [None]:
# Add a "dt" column to every posterior: the coalescence-time offset from the
# injected value, scaled by 1000 (milliseconds if geocent_time is in seconds).
injected_time = injection_parameters["geocent_time"]
for runs in results.values():
    for result in runs.values():
        result.posterior["dt"] = 1000 * (
            result.posterior["geocent_time"] - injected_time
        )

# Plot the offset instead of the absolute time. The membership guard keeps the
# cell safe to re-run.
if "geocent_time" in parameters:
    parameters.remove("geocent_time")
    parameters.append("dt")

# Injected values, one per plotted parameter, for the truth lines.
truth = [injection_parameters[name] for name in parameters]

In [None]:
# Figure grid: `ndim` rows by `ndim + 2` columns, leaving room for two corner
# plots drawn as triangles offset from each other by two columns.
ndim = len(parameters)
figsize = np.array([10.2756, 6.10236]) * 0.9
fig, axs = plt.subplots(
    nrows=ndim,
    ncols=ndim + 2,
    figsize=figsize,
    gridspec_kw={"hspace": 0.05, "wspace": 0.05},
)

# Display names for the networks and run types (used for legend handles).
legend_labels = {
    "et_only": "ET",
    "et_plus_ce": "ET+CE",
    "aligned": "Aligned",
    "precessing": "Precessing",
}

# Proxy artists for an optional legend; the plotting loop appends to this.
legend_handles = []

# Axis labels, indexed in parallel with `parameters`. Copy to a list so
# individual entries can be overridden.
labels = list(get_cbc_parameter_labels(parameters, units=True))
# "dt" is the derived time offset, so give it an explicit label. Find it via
# `parameters` rather than by searching the rendered labels, so the override
# works whatever label (if any) the helper produces for that name.
if "dt" in parameters:
    labels[parameters.index("dt")] = r"$\Delta t_c\;[\textrm{ms}]$"
bins = 32  # histogram bin count

# Per-key line styles and colours (keys match `results` networks and run types).
linestyles = {
    "et_only": "-",
    "et_plus_ce": "-",
    "aligned": "-",
    "precessing": "-",
}
colours = {
    "et_only": "C0",
    "et_plus_ce": "C1",
    "aligned": "C0",
    "precessing": "C0",
}

# Style for the injected-value (truth) marker lines.
truth_kwargs = dict(
    color="k",
    ls=":",
)

# Run type(s) to plot; each key selects the matching result from both networks.
diag_keys = ["precessing"]

for key in diag_keys:
    colour = colours.get(key)
    ls = linestyles.get(key, "-")

    # Proxy artist for an (optional) figure legend.
    legend_handles.append(
        Line2D([0], [0], ls=ls, color=colour, label=legend_labels.get(key))
    )

    # Both the 1D and 2D panels use the shared `bins` setting (previously
    # hard-coded to 32 here, so changing `bins` had no effect).
    hist_kwargs = dict(
        density=True,
        histtype="step",
        color=colour,
        bins=bins,
        linestyle=ls,
    )

    hist2d_kwargs = dict(
        no_fill_contours=True,
        plot_datapoints=False,
        plot_density=True,
        smooth=0.9,
        color=colour,
        bins=bins,
        contour_kwargs=dict(linestyles=[ls]),
        # Levels 1 - exp(-k^2 / 2) for k = 1, 2, 3: the "1/2/3 sigma"
        # convention for 2D histograms described in the corner.py docs.
        levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.0)),
        new_fig=False,
    )

    # Samples as (rows x parameters) arrays, columns ordered as `parameters`.
    x1 = results["et_only"][key].posterior[parameters].to_numpy()
    x2 = results["et_plus_ce"][key].posterior[parameters].to_numpy()

    # Panel (vdim, hdim) routing, by column offset d = hdim - vdim:
    #   d < 0  -> ET-only 2D histogram (lower triangle)
    #   d == 0 -> ET-only 1D histogram
    #   d == 1 -> blank panel separating the two corner plots
    #   d == 2 -> ET+CE 1D histogram
    #   d > 2  -> ET+CE 2D histogram (x axis is parameter hdim - 2)
    for vdim in range(ndim):
        for hdim in range(ndim + 2):
            ax = axs[vdim, hdim]
            ax.tick_params(labelbottom=False, labelleft=False)
            if vdim == (hdim - 1):
                ax.axis("off")
            elif vdim == hdim:
                ax.hist(x1[:, vdim], **hist_kwargs)
                ax.axvline(truth[vdim], **truth_kwargs)
                ax.set_yticks([])
            elif vdim > hdim:
                corner.hist2d(
                    x1[:, hdim],
                    x1[:, vdim],
                    ax=ax,
                    **hist2d_kwargs,
                )
                ax.axvline(truth[hdim], **truth_kwargs)
                ax.axhline(truth[vdim], **truth_kwargs)
            elif vdim == (hdim - 2):
                ax.hist(x2[:, vdim], **hist_kwargs)
                ax.set_yticks([])
                ax.axvline(truth[vdim], **truth_kwargs)
            else:  # vdim < hdim - 2: ET+CE upper triangle
                corner.hist2d(
                    x2[:, hdim - 2],
                    x2[:, vdim],
                    ax=ax,
                    **hist2d_kwargs,
                )
                ax.axvline(truth[hdim - 2], **truth_kwargs)
                ax.axhline(truth[vdim], **truth_kwargs)

def _rotate_ticklabels(ticklabels, ha=None, va=None):
    """Rotate tick-label Text objects by 45 degrees about their anchor.

    Optionally sets the horizontal (`ha`) and/or vertical (`va`) alignment.
    """
    for text in ticklabels:
        text.set_rotation_mode("anchor")
        text.set_rotation(45)
        if va is not None:
            text.set_va(va)
        if ha is not None:
            text.set_ha(ha)


for i in range(ndim):
    axis_label = labels[i]
    left_ax, right_ax = axs[i, 0], axs[i, -1]
    bottom_ax, top_ax = axs[-1, i], axs[0, i + 2]

    # Left column: y labels for the ET-only panels. The top-left panel is a 1D
    # histogram, so it is skipped.
    if i > 0:
        left_ax.tick_params(labelleft=True, rotation=45, pad=8)
        left_ax.set_ylabel(axis_label)
        _rotate_ticklabels(left_ax.get_yticklabels(), va="center")

    # Right column: y labels for the ET+CE panels. The bottom-right panel is a
    # 1D histogram, so it is skipped.
    if i < ndim - 1:
        right_ax.tick_params(labelright=True, rotation=45, pad=8)
        right_ax.yaxis.set_label_position("right")
        right_ax.set_ylabel(axis_label)
        _rotate_ticklabels(right_ax.get_yticklabels(), va="center")

    # Bottom row (ET-only) and top row (ET+CE): x tick labels.
    bottom_ax.tick_params(labelbottom=True, rotation=45)
    _rotate_ticklabels(bottom_ax.get_xticklabels(), ha="right")

    top_ax.tick_params(labeltop=True, rotation=45)
    _rotate_ticklabels(top_ax.get_xticklabels(), ha="left")

    bottom_ax.set_xlabel(axis_label)

    top_ax.xaxis.set_label_position("top")
    top_ax.set_xlabel(axis_label)


# Draw a diagonal line through the blank (axis-off) panels separating the two
# corner plots, and label each side along it.
# Based on: https://stackoverflow.com/questions/60807792/arrows-between-matplotlib-subplots
display_to_figure = fig.transFigure.inverted()
# Top-left corner of axs[0, 1] and bottom-right corner of axs[-1, -2], both blank
# panels, converted from display to figure-fraction coordinates.
coord1 = display_to_figure.transform(axs[0, 1].transAxes.transform((0.0, 1.0)))
coord2 = display_to_figure.transform(axs[-1, -2].transAxes.transform((1.0, 0.0)))
line = Line2D(
    (coord1[0], coord2[0]),  # xdata
    (coord1[1], coord2[1]),  # ydata
    transform=fig.transFigure,
    color="black",
)
# Use the public API (it also sets line.figure) rather than mutating fig.lines.
fig.add_artist(line)

# Visual angle of the line: scale the figure-fraction extents by the actual
# figure size in inches, so the text rotation matches the drawn slope.
fig_width, fig_height = fig.get_size_inches()
base = (coord2[0] - coord1[0]) * fig_width
height = (coord1[1] - coord2[1]) * fig_height
rotation = np.degrees(np.arctan2(height, base))

# Label the centre of every `nth` diagonal panel. "ET + CE" sits above the line
# (upper triangle) and "ET-only" below it (lower triangle).
nth = 2
fractions = (np.arange(0, ndim, nth) + 0.5) / ndim

for fraction in fractions:
    label_x = coord1[0] + fraction * (coord2[0] - coord1[0])
    label_y = coord1[1] - fraction * (coord1[1] - coord2[1])
    for side_text, side_va in (("ET + CE", "bottom"), ("ET-only", "top")):
        fig.text(
            label_x,
            label_y,
            side_text,
            ha="center",
            va=side_va,
            rotation=-rotation,
            rotation_mode="anchor",
        )

# Legend disabled: the two corner plots are identified by the "ET + CE" /
# "ET-only" text drawn along the diagonal. Uncomment to add a legend built
# from `legend_handles`.
# fig.legend(
#     handles=legend_handles,
#     ncol=len(legend_handles),
#     loc="center",
#     bbox_to_anchor=(0.5, -0.025)
#     # borderaxespad=1,
# )
# Save via the thesis_utils helper (output location/format are decided inside
# `save_figure`), then display the figure.
save_figure(fig, "ET_posterior")
plt.show()