In [None]:
import glob
import itertools
import os

import seaborn as sns
import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np

from thesis_utils.gw import get_cbc_parameter_labels
from thesis_utils.plotting import set_plotting, get_default_figsize, save_figure, get_default_corner_kwargs, get_default_figsize
from thesis_utils.io import load_json
from thesis_utils import colours as thesis_colours

# Prepend the TeX Live 2022 binaries to PATH.
# NOTE(review): presumably so that matplotlib can find LaTeX for text rendering -- confirm.
os.environ["PATH"] = os.pathsep.join(("/usr/local/texlive/2022/bin/x86_64-linux", os.environ["PATH"]))
# NOTE(review): presumably stops bilby from applying its own matplotlib style -- confirm.
os.environ["BILBY_STYLE"] = "none"

# Apply the plotting configuration provided by thesis_utils.
set_plotting()
# Uncomment to render text without LaTeX:
# plt.rcParams["text.usetex"] = False
# plt.rcParams["font.family"] = "DejaVu Sans"

Results files are saved on HAWK

In [None]:
# Directory with the original (paper) results, one sub-directory per sampler configuration.
paper_results_path = "/home/michael.williams/git_repos/nessai-validation/gw/paper_analysis/original_results/"
# Directory with the updated results; note that `cvm_results_path` is reassigned
# in the 'Examine "bad" runs' section further down.
cvm_results_path = "/home/michael.williams/git_repos/nessai-validation/gw/paper_analysis/updated_results/"

In [None]:
# Directory searched for the per-run `*_nessai/result.hdf5` files of the rerun
# (v0.8.0b1 according to the directory name).
rerun_path = "/home/michael.williams/git_repos/nessai-validation/gw/paper_analysis/outdir_v0.8.0b1/result/"

In [None]:
# Scratch output directories of the constant-volume-mode analyses, with ("marg_dist")
# and without ("A") distance marginalisation according to the directory names.
# NOTE(review): neither variable is referenced in the visible part of the notebook.
orig_path_marg = "/scratch/michael.williams/projects/nessai-validation/gw/paper_analysis/outdir_nessai_constant_volume_mode_marg_dist/"
orig_path_no_marg = "/scratch/michael.williams/projects/nessai-validation/gw/paper_analysis/outdir_nessai_constant_volume_mode_A/"


## Load paper results

In [None]:
# Result directories: every sampler, first without and then with the "_marg" suffix.
samplers = ("nessai", "dynesty")
suffixes = ("", "_marg")
dirs = [sampler + suffix for sampler in samplers for suffix in suffixes]

# Quantities stored as one JSON file each inside every result directory.
result_keys = ["evaluations", "evidence", "runtimes"]

In [None]:
# paper_results[<directory>][<quantity>] is an array built from the values of the
# corresponding "<quantity>.json" file, in the order they appear in the file.
paper_results = {}
for d in dirs:
    paper_results[d] = {
        k: np.array(
            list(load_json(os.path.join(paper_results_path, d, f"{k}.json")).values())
        )
        for k in result_keys
    }

In [None]:
# Combine the values stored for each entry in quadrature (root-sum-square along axis 1).
# NOTE(review): presumably one row per injection and one column per detector -- confirm.
snr_table = np.array(
    list(
        load_json(
            "/home/michael.williams/git_repos/nessai-paper/results/snrs.json"
        ).values()
    )
)
snrs = np.sqrt(np.sum(snr_table ** 2, axis=1))

## Load CVM results

In [None]:
# NOTE(review): `cvm_result_keys` is not used in the visible part of the notebook.
cvm_result_keys = ["evaluations", "log_evidence", "runtimes"]
# Directory containing the summary JSON files of the CVM runs.
cvm_path = "/home/michael.williams/git_repos/nessai-validation/gw/paper_analysis/updated_results/"

In [None]:
# Summary files of the CVM runs, keyed by marginalisation setting and quantity.
cvm_files = {
    "marg": {
        "evaluations": "new_evaluations_w_marg.json",
        "runtimes": "new_runtimes_w_marg.json",
    },
    "no_marg": {
        "evaluations": "new_evaluations_wo_marg.json",
        "runtimes": "new_runtimes_wo_marg.json",
    },
}

# cvm_results[<setting>][<quantity>] is an array built from the values of the JSON file.
cvm_results = {
    setting: {
        quantity: np.array(list(load_json(cvm_path + filename).values()))
        for quantity, filename in files.items()
    }
    for setting, files in cvm_files.items()
}

## Load reruns

In [None]:
import h5py

In [None]:
# Per-run HDF5 result files of the rerun, in natural sort order.
rerun_pattern = rerun_path + "*_nessai/result.hdf5"
nessai_results_files = natsorted(glob.glob(rerun_pattern))

In [None]:
# Name used in this notebook -> dataset read from each HDF5 result file.
rerun_datasets = {
    "evaluations": "total_likelihood_evaluations",
    "log_evidence": "log_evidence",
    "runtimes": "sampling_time",
}

# Read one scalar per dataset from every file, then stack them into arrays
# ordered like `nessai_results_files`.
rerun_results = {name: [] for name in rerun_datasets}
for rf in nessai_results_files:
    with h5py.File(rf, "r") as f:
        for name, dataset in rerun_datasets.items():
            rerun_results[name].append(f[dataset][()])
rerun_results = {name: np.array(values) for name, values in rerun_results.items()}

## Compare results

In [None]:
# Each number is a summary of the per-run ratio "paper result / newer result",
# so values above one mean the newer run needed fewer evaluations or less time.
paper_nessai = paper_results["nessai"]
paper_nessai_marg = paper_results["nessai_marg"]
paper_dynesty_marg = paper_results["dynesty_marg"]
cvm_no_marg = cvm_results["no_marg"]
cvm_marg = cvm_results["marg"]

print("Summary of results (no marg)")
print("Improvement:")
evaluation_ratio = paper_nessai["evaluations"] / cvm_no_marg["evaluations"]
print("Evaluations:", np.median(evaluation_ratio))
runtime_ratio = paper_nessai["runtimes"] / cvm_no_marg["runtimes"]
print("Times:", np.median(runtime_ratio))

print("Summary of results (marg)")
print("Improvement:")
evaluation_ratio = paper_nessai_marg["evaluations"] / cvm_marg["evaluations"]
print("Evaluations:", np.median(evaluation_ratio))
runtime_ratio = paper_nessai_marg["runtimes"] / cvm_marg["runtimes"]
print("Times:", np.median(runtime_ratio))

print("Summary of results (rerun)")
print("Improvement:")
evaluation_ratio = paper_nessai["evaluations"] / rerun_results["evaluations"]
print("Evaluations:", np.median(evaluation_ratio))
runtime_ratio = paper_nessai["runtimes"] / rerun_results["runtimes"]
print("Times:", np.median(runtime_ratio))

print("Summary of results (dynesty)")
# NOTE(review): this is the only summary that uses the mean rather than the
# median for the evaluations -- confirm that this is intentional.
evaluation_ratio = paper_dynesty_marg["evaluations"] / cvm_marg["evaluations"]
print("Evaluations:", np.mean(evaluation_ratio))
runtime_ratio = paper_dynesty_marg["runtimes"] / cvm_marg["runtimes"]
print("Times:", np.median(runtime_ratio))

In [None]:
# Display the median number of likelihood evaluations over the "marg" CVM runs.
np.median(cvm_results["marg"]["evaluations"])

In [None]:
# Matplotlib colour-cycle entry and line style for each set of results.
# A "rerun" entry ("C3" / ":") is currently disabled.
colours = dict(dynesty="C0", nessai="C1", cvm="C2")
ls = dict(dynesty="-.", nessai="--", cvm="-")

In [None]:
# Legend text for each set of results; the key order fixes the legend order.
# A "rerun" entry ("nessai - CVM -  rerun") is currently disabled.
labels = dict(
    dynesty="dynesty",
    nessai="nessai - Williams et al. 2021",
    cvm="nessai - CVM",
)

In [None]:
# Histograms of likelihood evaluations (left column) and wall time (right column)
# for dynesty, the paper nessai runs and the CVM nessai runs. The top row shows the
# runs without distance marginalisation, the bottom row the runs with it.
figsize = get_default_figsize()
figsize[1] *= 0.78
fig, axs = plt.subplots(2, 2, sharey=True, figsize=figsize, sharex="col")

bins = [np.logspace(5.8, 8, 16), np.logspace(0.1, 3.0, 20)]

hist_kwargs = dict(histtype="step")
# Runtimes are divided by 3600 before plotting on the "Wall time [hrs]" axis
# (assumes they are stored in seconds -- TODO confirm).
factors = [1, 3600]

# One entry per row: (dynesty key, nessai key, CVM key, panel annotation).
panel_rows = [
    ("dynesty", "nessai", "no_marg", "No distance marg."),
    ("dynesty_marg", "nessai_marg", "marg", "Distance marg."),
]

for col, k in enumerate(["evaluations", "runtimes"]):
    factor = factors[col]
    b = bins[col]
    for row, (dynesty_key, nessai_key, cvm_key, _) in enumerate(panel_rows):
        # Drawn in the same order as the legend: dynesty, nessai, CVM.
        per_sampler = {
            "dynesty": paper_results[dynesty_key][k],
            "nessai": paper_results[nessai_key][k],
            "cvm": cvm_results[cvm_key][k],
        }
        for sampler, values in per_sampler.items():
            axs[row, col].hist(
                values / factor,
                bins=b,
                color=colours[sampler],
                ls=ls[sampler],
                **hist_kwargs,
            )

for ax in axs.flat:
    ax.set_xscale("log")

# Annotate every panel with the marginalisation setting of its row.
for row, (_, _, _, annotation) in enumerate(panel_rows):
    for col in range(2):
        axs[row, col].text(0.05, 0.9, annotation, transform=axs[row, col].transAxes)

# Only the bottom row needs x labels because the x axes are shared per column.
axs[1, 0].set_xlabel("Likelihood evaluations")
axs[1, 1].set_xlabel("Wall time [hrs]")

# Build proxy lines for a single figure-level legend.
legend_labels = list(labels.values())
handles = [
    mpl.lines.Line2D([0, 1], [0, 1], color=colours[sampler], ls=ls[sampler])
    for sampler in labels
]
plt.tight_layout()
fig.legend(
    handles,
    legend_labels,
    loc="center",
    ncol=3,
    bbox_to_anchor=(0.5, -0.0),
)
save_figure(fig, "nessai_cvm_comparison_gw", "figures")
plt.show()

## Per event comparison

In [None]:
# Per-event ratio of the dynesty result to each nessai result (distance-marginalised
# runs) against SNR, with a horizontal histogram of the same ratios alongside.
# Top row: likelihood evaluations; bottom row: wall time.
fig, axs = plt.subplots(
    2,
    2,
    sharex="col",
    sharey="row",
    gridspec_kw={"width_ratios": [3, 1]},
)
fig.subplots_adjust(wspace=0)

bins = np.logspace(-1.2, 2.5, 12)
hist_kwargs = dict(histtype="step", bins=bins)

# (results compared against dynesty, scatter marker, colour)
comparisons = [
    (paper_results["nessai_marg"], ".", "C1"),
    (cvm_results["marg"], "+", "C2"),
]

for i, k in enumerate(["evaluations", "runtimes"]):
    scatter_ax, hist_ax = axs[i]
    dynesty_values = paper_results["dynesty_marg"][k]

    for reference, marker, colour in comparisons:
        scatter_ax.scatter(
            snrs,
            dynesty_values / reference[k],
            marker=marker,
            color=colour,
        )

    for reference, _, colour in comparisons:
        hist_ax.hist(
            dynesty_values / reference[k],
            orientation="horizontal",
            color=colour,
            **hist_kwargs,
        )

    # Reference line at a ratio of one in both panels of the row.
    for ax in (scatter_ax, hist_ax):
        ax.axhline(1.0, zorder=-1, color="k")
    scatter_ax.set_xscale("log")
    scatter_ax.set_yscale("log")

axs[0, 0].set_ylabel("Ratio - likelihood evaluations")
axs[1, 0].set_ylabel("Ratio - wall times")
axs[1, 0].set_xlabel(r"$\rho$")
axs[1, 1].set_xlabel(r"Counts")

save_figure(fig, "snr_breakdown")

plt.show()

In [None]:
# Per-event ratio of the paper nessai result to the CVM nessai result
# (distance-marginalised runs) against SNR, with a horizontal histogram alongside.
# Top row: likelihood evaluations; bottom row: wall time.
fig, axs = plt.subplots(
    2,
    2,
    sharex="col",
    sharey="row",
    gridspec_kw={"width_ratios": [3, 1]},
)
fig.subplots_adjust(wspace=0)

bins = np.logspace(-0.2, 1.1, 12)
hist_kwargs = dict(histtype="step", bins=bins)

for i, k in enumerate(["evaluations", "runtimes"]):
    improvement = paper_results["nessai_marg"][k] / cvm_results["marg"][k]

    axs[i, 0].scatter(snrs, improvement, marker=".", color="C0")
    axs[i, 1].hist(
        improvement,
        orientation="horizontal",
        color="C0",
        **hist_kwargs,
    )

    # Reference line at a ratio of one in both panels of the row.
    for ax in axs[i]:
        ax.axhline(1.0, zorder=-1, color="k")
    axs[i, 0].set_xscale("log")
    axs[i, 0].set_yscale("log")

axs[0, 0].set_ylabel("Ratio - likelihood evaluations")
axs[1, 0].set_ylabel("Ratio - wall times")
axs[1, 0].set_xlabel(r"$\rho$")
axs[1, 1].set_xlabel(r"Counts")

save_figure(fig, "snr_breakdown_nessai")

plt.show()

## Examine "bad" runs

In [None]:
# Re-point `cvm_results_path` at the per-run result directory of the
# distance-marginalised CVM analysis (replaces the value assigned near the top).
cvm_results_path = "/scratch/michael.williams/projects/nessai-validation/gw/paper_analysis/outdir_nessai_constant_volume_mode_marg_dist/result/"

In [None]:
# Indices of the 16 runs with the 2nd to 17th smallest dynesty/CVM evaluation
# ratio; the run with the smallest ratio is skipped.
# NOTE(review): `to_check` is assigned again in the next cell -- confirm which selection is intended.
to_check = np.argsort(paper_results["dynesty_marg"]["evaluations"] / cvm_results["marg"]["evaluations"])[1:17]

In [None]:
# Per-run ratio of dynesty to nessai (CVM) likelihood evaluations for the
# distance-marginalised runs; values below one mean dynesty needed fewer evaluations.
dynesty_evaluations = paper_results["dynesty_marg"]["evaluations"]
ratios = dynesty_evaluations / cvm_results["marg"]["evaluations"]

# Select the runs where dynesty needed fewer evaluations than nessai. Runs with
# zero dynesty evaluations have a ratio of exactly zero and are excluded
# (presumably they have no dynesty result -- TODO confirm).
# The selection is built from the same arrays as `ratios`: `cvm_results` only
# has the keys "marg" and "no_marg", so indexing it with "evaluations" directly
# raised a KeyError.
# NOTE(review): this replaces the argsort-based `to_check` defined above, and the
# grids plotted below only have 16 panels -- confirm the selection fits.
to_check = np.where((ratios < 1.0) & (dynesty_evaluations != 0.0))[0]

In [None]:
# Display the SNRs of the selected runs.
snrs[to_check]

In [None]:
# Per-run JSON result files of the CVM analysis, in natural sort order.
cvm_result_pattern = cvm_results_path + "*_nessai/result.json"
all_cvm_results = natsorted(glob.glob(cvm_result_pattern))

# Keep only the files whose position in the sorted list is one of the selected
# run indices; the dictionary is keyed by that position.
cvm_rf_to_check = {}
for position, result_file in enumerate(all_cvm_results):
    if position in to_check:
        cvm_rf_to_check[position] = result_file

In [None]:
# SNR and dynesty/CVM evaluation ratio of each selected run, in `to_check` order.
snrs_to_check = snrs[to_check]
ratios_to_check = ratios[to_check]

In [None]:
# Display the evaluation ratios of the selected runs.
ratios_to_check

In [None]:
# Permutation that sorts the selected runs by increasing SNR.
idx = np.argsort(snrs_to_check)

In [None]:
# snr_loc[j] is the position of snrs_to_check[j] within the ascending-sorted SNRs;
# it is used below as the panel index so that panels are ordered by SNR.
snr_loc = np.searchsorted(snrs_to_check[idx], snrs_to_check)

In [None]:
# Load the JSON result of every selected run, keyed by run index.
results_to_check = {
    run_index: load_json(result_file)
    for run_index, result_file in cvm_rf_to_check.items()
}

In [None]:
# Display the run indices for which a result file was loaded.
list(results_to_check.keys())

In [None]:
# Default corner-plot keyword arguments from thesis_utils.
# NOTE(review): not used in the visible part of the notebook.
corner_kwargs = get_default_corner_kwargs()

In [None]:
# Figure size for the 4x4 grids below; the leading factor scales the default size.
grid_figsize = 1 * get_default_figsize()

In [None]:
# Two 4x4 grids for the selected runs: the sky-location (RA/dec) posterior and the
# arrival-time posterior from the nested samples. Panels are ordered by increasing
# SNR, and a triangle marks the runs whose dynesty/CVM evaluation ratio exceeds one.
fig_sky, axs_sky = plt.subplots(
    4, 4,
    sharex=True,
    sharey=True,
    figsize=grid_figsize,
)
fig_sky.subplots_adjust(hspace=0.1, wspace=0.1)

fig_time, axs_time = plt.subplots(4, 4, sharex=True, sharey=True, figsize=grid_figsize)
fig_time.subplots_adjust(hspace=0.1, wspace=0.1)

ra_label, dec_label, t_label = get_cbc_parameter_labels(["ra", "dec", "geocent_time"], units=True)

time_xticks = np.array([-0.1, -0.05, 0.0, 0.05, 0.1])

# Set axis labels before unravelling the arrays. `corner.hist2d` below receives
# RA as x and dec as y, so RA labels the x axis and dec the y axis.
for i in range(4):
    axs_sky[-1, i].set_xlabel(ra_label)
    axs_time[-1, i].set_xlabel(t_label)

for i in range(4):
    axs_sky[i, 0].set_ylabel(dec_label)
    axs_time[i, 0].set_ylabel("Density")

axs_sky = axs_sky.ravel()
axs_time = axs_time.ravel()

bins = np.linspace(-0.1, 0.1, 64, endpoint=True)

# Look up the SNR and the panel position from the run index itself rather than
# relying on `results_to_check` being ordered like `to_check`; the two orders
# differ when `to_check` is not ascending or a result file is missing.
sorted_snrs = np.sort(snrs_to_check)

for run_id, result in results_to_check.items():
    snr = snrs[run_id]
    # Rank of this run's SNR among the selected runs, used as the panel index.
    i = np.searchsorted(sorted_snrs, snr)

    corner.hist2d(
        np.array(result["nested_samples"]["ra"]),
        np.array(result["nested_samples"]["dec"]),
        bins=64,
        color=thesis_colours.teal,
        smooth=0.8,
        ax=axs_sky[i],
        plot_datapoints=False,
        fill_contours=True,
    )

    axs_time[i].hist(
        result["nested_samples"]["geocent_time"],
        density=True,
        bins=bins,
        color=thesis_colours.teal,
    )

    ratio = ratios[run_id]
    if ratio > 1:
        axs_sky[i].text(0.9, 0.8, r"$\blacktriangle$", transform=axs_sky[i].transAxes)
        axs_time[i].text(0.9, 0.8, r"$\blacktriangle$", transform=axs_time[i].transAxes)
    axs_sky[i].text(0.05, 0.8, rf"$\rho={snr:.1f}$", transform=axs_sky[i].transAxes)
    axs_time[i].text(0.05, 0.8, rf"$\rho={snr:.1f}$", transform=axs_time[i].transAxes)
    axs_time[i].set_xticks(time_xticks)
    axs_time[i].tick_params(axis='x', rotation=45)

# Save before showing, as in the other figure cells of this notebook.
save_figure(fig_sky, "bad_injections_sky")
save_figure(fig_time, "bad_injections_time")

plt.show()