In [None]:
import bilby
import corner
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
from scipy import stats
import seaborn as sns

from thesis_utils.plotting import (
    set_plotting,
    get_default_figsize,
    save_figure,
    get_default_corner_kwargs,
)
from thesis_utils.gw import get_cbc_parameter_labels
from thesis_utils.io import load_pickle

# Apply the thesis-wide plotting configuration from thesis_utils before any
# figure is created (presumably sets matplotlib rcParams — see
# thesis_utils.plotting.set_plotting).
set_plotting()

In [None]:
# Merged (final) bilby result files for the three GW150914 IMRPhenomPv2
# analyses: no calibration, calibration reweighting, calibration sampling.
# Note the calibration-sampling run uses a "_Pv2_" label in its file name.
no_cal_path = (
    "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/IMRPhenomPv2/"
    "outdir_nessai_gwtc_1_GW150914_no_calibration_Pv2/final_result/"
    "nessai_gwtc_1_GW150914_data0_1126259462-391_analysis_H1L1_merge_result.hdf5"
)
reweight_path = (
    "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/IMRPhenomPv2/"
    "outdir_nessai_gwtc_1_GW150914_cal_reweight_Pv2/final_result/"
    "nessai_gwtc_1_GW150914_data0_1126259462-391_analysis_H1L1_merge_result.hdf5"
)
cal_path = (
    "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/IMRPhenomPv2/"
    "outdir_nessai_gwtc_1_GW150914_cal_sample_Pv2/final_result/"
    "nessai_gwtc_1_GW150914_Pv2_data0_1126259462-391_analysis_H1L1_merge_result.hdf5"
)

In [None]:
# Read the three merged results (same order as before: sampling, none, reweight).
cal_result, no_cal_result, reweight_result = (
    bilby.core.result.read_in_result(result_file)
    for result_file in (cal_path, no_cal_path, reweight_path)
)

In [None]:
# Parameters compared between the analyses, used both for the JS-divergence
# printout and for the corner plot: mass/spin parameters first, then
# inclination, sky position and distance.
parameters = [
    "chirp_mass", "mass_ratio", "chi_p", "chi_eff",
    "theta_jn", "ra", "dec", "luminosity_distance",
]

In [None]:
from pesummary.utils.utils import jensen_shannon_divergence

In [None]:
# For each parameter, print its name and then the pairwise Jensen-Shannon
# divergences of the 1D marginals in the order:
# sampling-vs-reweight, sampling-vs-none, reweight-vs-none.
for key in parameters:
    print(key)
    sampled = cal_result.posterior[key]
    reweighted = reweight_result.posterior[key]
    uncalibrated = no_cal_result.posterior[key]
    comparisons = (
        (sampled, reweighted),
        (sampled, uncalibrated),
        (reweighted, uncalibrated),
    )
    print(*[jensen_shannon_divergence([first, second]) for first, second in comparisons])

In [None]:
# Report the sampling time stored on each result
# (order: calibration sampling, reweighting, no calibration).
for res in (cal_result, reweight_result, no_cal_result):
    print(res.sampling_time)

In [None]:
# Results to overlay in the corner plot. The insertion order of ``results``
# is the order in which the runs are drawn (no calibration first).
results = dict(no_cal=no_cal_result, sample=cal_result, reweight=reweight_result)

# Per-run legend label, line colour and line style for the corner plot.
labels = dict(
    sample="Calibration sampling",
    reweight="Calibration reweighting",
    no_cal="No calibration",
)
colours = dict(sample="C0", reweight="C1", no_cal="k")
linestyles = dict(sample="-", reweight="-", no_cal="--")

In [None]:
# Corner plot overlaying the marginal posteriors of the three analyses.
# Start from the shared defaults, then switch to unfilled contour lines so the
# overlaid runs remain distinguishable. Work on a shallow copy so the helper's
# defaults are not mutated (in case it returns a shared dict — unverified).
corner_kwargs = dict(get_default_corner_kwargs())
corner_kwargs["plot_density"] = False
corner_kwargs["plot_datapoints"] = False
corner_kwargs["no_fill_contours"] = True
corner_kwargs["show_titles"] = False
# Remove any default ``fill_contours`` so it cannot conflict with
# ``no_fill_contours``; tolerate defaults that never define it.
corner_kwargs.pop("fill_contours", None)
# 2D contours at 1 - exp(-k^2 / 2) for k = 1, 2, 3 (the 1/2/3-sigma levels of a
# 2D Gaussian).
corner_kwargs["levels"] = (1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.0))
corner_kwargs["quantiles"] = []
corner_kwargs["smooth"] = 0.9
corner_kwargs["bins"] = 32
corner_kwargs["labelpad"] = 0.1
# The per-run colour/linestyle is written into hist_kwargs below, so make sure
# it exists and is our own copy.
corner_kwargs["hist_kwargs"] = dict(corner_kwargs.get("hist_kwargs") or {})

figsize = 2 * get_default_figsize()
figsize[0] = figsize[1]  # square figure
fig = plt.figure(figsize=figsize)

corner_labels = get_cbc_parameter_labels(parameters, units=True)
legend_elements = []
for key, result in results.items():
    ls = linestyles.get(key)
    c = colours.get(key)
    corner_kwargs["color"] = c
    corner_kwargs["hist_kwargs"]["color"] = c
    corner_kwargs["hist_kwargs"]["ls"] = ls
    corner_kwargs["contour_kwargs"] = dict(linestyles=[ls])
    data = result.posterior[parameters].to_numpy()
    # Passing ``fig`` makes corner overplot onto the same figure for every run.
    fig = corner.corner(
        data,
        fig=fig,
        labels=corner_labels,
        **corner_kwargs,
    )
    legend_elements.append(Line2D([0], [0], color=c, ls=ls, label=labels.get(key)))
fig.legend(
    handles=legend_elements,
    bbox_to_anchor=(0.8, 0.8),
    loc="center",
    fontsize=16,
)
# Save before showing (as in the other figure cells) so the saved file does not
# depend on the backend keeping the figure alive after display.
save_figure(fig, "calibration_posterior", "figures/calibration/")
plt.show()

# Run statistics

In [None]:
# nessai resume files for the two parallel runs (par0/par1) of each analysis.
# The no-calibration runs (outdir_nessai_gwtc_1_GW150914_no_calibration_Pv2,
# par0/par1) are deliberately left out of this comparison.
runs = {
    "sample_cal_0": (
        "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/IMRPhenomPv2/"
        "outdir_nessai_gwtc_1_GW150914_cal_sample_Pv2/result/"
        "nessai_gwtc_1_GW150914_Pv2_data0_1126259462-391_analysis_H1L1_par0_nessai/"
        "nested_sampler_resume.pkl"
    ),
    "sample_cal_1": (
        "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/IMRPhenomPv2/"
        "outdir_nessai_gwtc_1_GW150914_cal_sample_Pv2/result/"
        "nessai_gwtc_1_GW150914_Pv2_data0_1126259462-391_analysis_H1L1_par1_nessai/"
        "nested_sampler_resume.pkl"
    ),
    "reweight_0": (
        "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/IMRPhenomPv2/"
        "outdir_nessai_gwtc_1_GW150914_cal_reweight_Pv2/result/"
        "nessai_gwtc_1_GW150914_data0_1126259462-391_analysis_H1L1_par0_nessai/"
        "nested_sampler_resume.pkl"
    ),
    "reweight_1": (
        "/home/michael.williams/git_repos/nessai-gwtc-1/analysis/IMRPhenomPv2/"
        "outdir_nessai_gwtc_1_GW150914_cal_reweight_Pv2/result/"
        "nessai_gwtc_1_GW150914_data0_1126259462-391_analysis_H1L1_par1_nessai/"
        "nested_sampler_resume.pkl"
    ),
}

In [None]:
# Unpickle the saved sampler state for every run in ``runs``.
# NOTE(review): unpickling executes code from the file; these are assumed to be
# locally produced, trusted resume files.
samplers = {run_name: load_pickle(resume_file) for run_name, resume_file in runs.items()}

In [None]:
# Line style per run: solid for calibration sampling, dashed for reweighting.
ls = dict(
    sample_cal_0="-",
    sample_cal_1="-",
    reweight_0="--",
    reweight_1="--",
)
# Line colour per run: blues for sampling, red/orange for reweighting.
colours = dict(
    sample_cal_0="#2c7bb6",
    sample_cal_1="#abd9e9",
    reweight_0="#d7191c",
    reweight_1="#fdae61",
)

In [None]:
# Only the first run of each analysis gets a legend entry, so each analysis
# appears once in the legend.
labels = dict(sample_cal_0="Sample", reweight_0="Reweight")

In [None]:
# Four stacked panels sharing the iteration axis: likelihood evaluations,
# proposal acceptance, iterations between successive trainings, and the
# rejection-sampling acceptance when the proposal population is drawn.
figsize = get_default_figsize()
figsize[1] *= 1.3
fig, axs = plt.subplots(4, 1, sharex=True, figsize=figsize)
ax_evals, ax_accept, ax_train, ax_pop = axs

for run_name, sampler in samplers.items():
    style = {"color": colours.get(run_name), "ls": ls.get(run_name)}
    # x positions for the per-iteration history arrays, spaced by nlive // 10
    # (assumed to be the interval at which nessai records these histories —
    # TODO confirm).
    history_it = np.arange(len(sampler.min_likelihood)) * (sampler.nlive // 10)
    # Gap (in iterations) between each training and the previous one.
    train_gaps = np.diff(np.array(sampler.training_iterations))

    ax_evals.plot(history_it, sampler.likelihood_evaluations, **style)
    ax_accept.plot(history_it, sampler.mean_acceptance_history, **style)
    ax_train.plot(sampler.training_iterations[1:], train_gaps, **style)
    ax_pop.plot(
        sampler.population_iterations,
        sampler.population_acceptance,
        label=labels.get(run_name, None),
        **style,
    )

ax_evals.set_ylabel("Likelihood\nevaluations")

ax_accept.set_yscale("log")
ax_accept.set_ylabel("Acceptance")

ax_train.set_ylabel("Iterations \nbetween training")
ax_train.set_yscale("log")

ax_pop.set_yscale("log")
ax_pop.set_ylabel("Rejection sampling \nacceptance")
ax_pop.set_xlabel("Iteration")
# Legend sits below the bottom panel; only runs with a label contribute entries.
ax_pop.legend(ncol=len(samplers), loc="center", bbox_to_anchor=(0.5, -0.6))
ax_pop.set_xlim([0, 27_500])

plt.tight_layout()
save_figure(fig, "calibration_comparison_stats", "figures/calibration/")
plt.show()

In [None]:
# Insertion indices of every run, grouped by analysis ("sample" / "reweight").
indices = {
    group: [samplers[run_name].insertion_indices for run_name in run_names]
    for group, run_names in (
        ("sample", ("sample_cal_0", "sample_cal_1")),
        ("reweight", ("reweight_0", "reweight_1")),
    )
}

In [None]:
# Insertion-index diagnostic. For each run, plot the analytic (uniform) CMF of
# the insertion indices minus the empirical CMF. Grey bands show binomial
# 50% / 95% / 99.7% intervals around the analytic CMF.
colours = {
    "reweight": sns.color_palette("Oranges", n_colors=3)[1:],
    "sample": sns.color_palette("Blues", n_colors=3)[1:],
}
linestyles = {
    "reweight": "--",
    "sample": "-",
}

# Number of live points. Assumed to match every run in ``indices``, so that
# insertion indices lie in [0, nlive) — TODO confirm against each sampler's nlive.
nlive = 1000

figsize = 0.8 * get_default_figsize()
fig, ax = plt.subplots(figsize=figsize)
x = np.arange(0, nlive, 1)
analytic = x / x[-1]
# Number of binomial trials for the bands: the mean number of indices per run.
# scipy's binom needs an integer ``n`` (a non-integer float yields NaN quantiles,
# so the bands would silently disappear), hence the rounding.
n = int(round(np.mean([len(idx) for group in indices.values() for idx in group])))
for key, data in indices.items():
    if not data:
        continue
    c = colours[key]
    for i, idx in enumerate(data):
        run_idx = np.asarray(idx, dtype=int)
        if run_idx.size == 0:
            continue
        # bincount with minlength keeps a bin for every possible index, including
        # indices that never occurred. np.unique would drop those bins, which
        # misaligns (or resizes) the empirical CMF relative to ``analytic``.
        counts = np.bincount(run_idx, minlength=nlive)
        estimated = np.cumsum(counts) / run_idx.size
        ax.plot(x, analytic - estimated, ls=linestyles.get(key), c=c[i])

for ci in [0.5, 0.95, 0.997]:
    bound = (1 - ci) / 2
    bound_values = stats.binom.ppf(1 - bound, n, analytic) / n
    lower = bound_values - analytic
    upper = analytic - bound_values
    # Pin both band edges to zero at the first and last index.
    lower[[0, -1]] = 0
    upper[[0, -1]] = 0
    ax.fill_between(x, lower, upper, color="grey", alpha=0.2)

legend_elements = [
    Line2D([0], [0], ls=linestyles["sample"], c=colours["sample"][-1], label="Sample"),
    Line2D(
        [0], [0], ls=linestyles["reweight"], c=colours["reweight"][-1], label="Reweight"
    ),
]
ax.legend(handles=legend_elements)

ax.set_xlim(0, nlive)
ax.set_xlabel("Insertion index")
ax.set_ylabel(r"Analytic CMF - Empirical CMF")
# Save before showing so the saved file does not depend on the backend keeping
# the figure alive after display.
save_figure(fig, "calibration_indices", "figures/calibration/")
plt.show()