In [None]:
import glob
import os

# Prepend a TeX Live install to PATH, presumably so matplotlib can call LaTeX for
# text rendering. NOTE(review): machine-specific absolute path — adjust per system.
os.environ["PATH"] = os.pathsep.join(("/usr/local/texlive/2023/bin/x86_64-linux", os.environ["PATH"]))
# Set before `import bilby` below; presumably stops bilby from applying its own
# plotting style so the project style from set_plotting() is used — TODO confirm.
os.environ["BILBY_STYLE"] = "none"

import bilby
import seaborn as sns
import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np
import tqdm
import h5py

import thesis_utils
from thesis_utils.gw import get_cbc_parameter_labels
from thesis_utils.plotting import (
    set_plotting,
    get_default_figsize,
    save_figure,
    make_pp_plot_bilby_results,
)
from thesis_utils.io import load_json
from thesis_utils import colours as thesis_colours

# Project plotting defaults from thesis_utils.plotting.
# (To disable LaTeX text rendering afterwards: plt.rcParams["text.usetex"] = False)
set_plotting()

from pp_plot import make_pp_plot

from importlib import reload

# Note: reload() re-executes only the thesis_utils package __init__; names bound
# above via `from thesis_utils.<submodule> import ...` keep their original objects.
# The trailing ';' suppresses the module repr in the cell output.
reload(thesis_utils);

In [None]:
# Output directories of the three sets of runs, relative to the working directory.
# NOTE(review): the CVM directory is spelled "cmv" — confirm it matches the disk.
path, path_cvm, path_marg = (
    "outdir_nessai_mass_ratio_rerun/",
    "outdir_nessai_cmv/",
    "outdir_nessai_phase_marg/",
)


In [None]:
# Bilby result files for each set of runs, naturally sorted so that every list
# follows the same run ordering. os.path.join avoids the "//" produced by
# concatenating "/result" onto directory names that already end in "/".
result_files = natsorted(glob.glob(os.path.join(path, "result", "*_result.hdf5")))
cvm_result_files = natsorted(glob.glob(os.path.join(path_cvm, "result", "*_result.hdf5")))
marg_result_files = natsorted(glob.glob(os.path.join(path_marg, "result", "*_result.hdf5")))

# Fail early with a clear message rather than an IndexError / empty plot later on.
assert result_files, f"No *_result.hdf5 files found under {path!r}"
assert cvm_result_files, f"No *_result.hdf5 files found under {path_cvm!r}"
assert marg_result_files, f"No *_result.hdf5 files found under {path_marg!r}"

In [None]:
# Load every bilby result for the main runs, in the order of result_files.
results = [
    bilby.core.result.read_in_result(result_file)
    for result_file in tqdm.tqdm(result_files)
]

In [None]:
# Load every bilby result from the path_marg runs, in the order of marg_result_files.
marg_results = list(map(bilby.core.result.read_in_result, tqdm.tqdm(marg_result_files)))

In [None]:
# Load every bilby result from the path_cvm runs, in the order of cvm_result_files.
cvm_results = [
    bilby.core.result.read_in_result(cvm_file)
    for cvm_file in tqdm.tqdm(cvm_result_files)
]

In [None]:
# Sampled parameters and their unit-less plot labels, taken from the first result.
# NOTE(review): the same labels are reused for the marg/cvm P-P plots, which
# assumes every set of runs samples the same parameters — confirm.
parameters = results[0].search_parameter_keys
labels = dict(
    (key, get_cbc_parameter_labels(key, units=False)) for key in parameters
)

In [None]:
# Square figure whose side is two-thirds of the default figure width.
# Build a fresh float array instead of editing the returned object in place: this
# also works if get_default_figsize() returns a tuple/list, and cannot corrupt it
# if the helper hands out a shared array (NOTE(review): unknown — confirm).
figsize = np.full(2, np.asarray(get_default_figsize(), dtype=float)[0] / 1.5)

In [None]:
# P-P plot for the main runs (results).
# savefig does not create missing directories, so make sure figures/ exists.
os.makedirs("figures", exist_ok=True)
fig, pvals = make_pp_plot_bilby_results(
    results,
    labels=labels,
    width=figsize[0],
    height=figsize[0],  # figsize is square
    colours=["#d73027", "#fc8d59", "#fee090", "#91bfdb", "#4575b4"],
)
fig.savefig("figures/pp_plot_spins.pdf")

In [None]:
# P-P plot for the path_marg runs, styled like the main-run P-P plot.
fig, pvals = make_pp_plot_bilby_results(
    marg_results,
    colours=[
        "#d73027",
        "#fc8d59",
        "#fee090",
        "#91bfdb",
        "#4575b4",
    ],
    height=figsize[0],
    width=figsize[0],
    labels=labels,
)
fig.savefig("figures/pp_plot_marg.pdf")

In [None]:
# P-P plot for the path_cvm runs, styled like the main-run P-P plot.
fig, pvals = make_pp_plot_bilby_results(
    cvm_results,
    colours=[
        "#d73027",
        "#fc8d59",
        "#fee090",
        "#91bfdb",
        "#4575b4",
    ],
    height=figsize[0],
    width=figsize[0],
    labels=labels,
)
fig.savefig("figures/pp_plot_cvm.pdf")

In [None]:
# Optimal SNR of each interferometer for every result, read from the likelihood
# metadata stored in the bilby result. One row per result, one column per detector
# (assumes every result lists the same detectors, in the same order).
snr_rows = []
for result in results:
    detectors = result.meta_data["likelihood"]["interferometers"]
    snr_rows.append([detectors[name]["optimal_SNR"] for name in detectors])
snrs = np.array(snr_rows)

In [None]:
# Network SNR per result: per-detector SNRs added in quadrature.
network_snrs = np.sqrt(np.sum(np.square(snrs), axis=1))

In [None]:
# nessai's own per-run output files (`*_nessai/result.hdf5`), naturally sorted so
# they follow the same run ordering as the bilby result lists. os.path.join avoids
# the "//" produced by concatenating onto directory names ending in "/".
nessai_results_files = natsorted(glob.glob(os.path.join(path, "result", "*_nessai", "result.hdf5")))
marg_nessai_results_files = natsorted(glob.glob(os.path.join(path_marg, "result", "*_nessai", "result.hdf5")))

# Fail early rather than producing empty arrays and confusing plotting errors later.
assert nessai_results_files, f"No nessai result files found under {path!r}"
assert marg_nessai_results_files, f"No nessai result files found under {path_marg!r}"

In [None]:
# Run statistics read from each nessai result file.
# Keys are the names used in this notebook; values are the HDF5 dataset names.
nessai_datasets = {
    "evaluations": "total_likelihood_evaluations",
    "log_evidence": "log_evidence",
    "sampling_time": "sampling_time",
    "population_time": "population_time",
    "likelihood_time": "likelihood_evaluation_time",
    "training_time": "training_time",
}
collected = {name: [] for name in nessai_datasets}
for result_path in nessai_results_files:
    with h5py.File(result_path, "r") as hdf:
        for name, dataset in nessai_datasets.items():
            collected[name].append(hdf[dataset][()])
# One numpy array per statistic, ordered like nessai_results_files.
nessai_results = {name: np.array(values) for name, values in collected.items()}

In [None]:
# Same run statistics as for the main runs, read from the path_marg nessai files.
# Keys are the names used in this notebook; values are the HDF5 dataset names.
marg_nessai_datasets = {
    "evaluations": "total_likelihood_evaluations",
    "log_evidence": "log_evidence",
    "sampling_time": "sampling_time",
    "population_time": "population_time",
    "likelihood_time": "likelihood_evaluation_time",
    "training_time": "training_time",
}
marg_collected = {name: [] for name in marg_nessai_datasets}
for marg_path in marg_nessai_results_files:
    with h5py.File(marg_path, "r") as hdf:
        for name, dataset in marg_nessai_datasets.items():
            marg_collected[name].append(hdf[dataset][()])
# One numpy array per statistic, ordered like marg_nessai_results_files.
marg_nessai_results = {name: np.array(values) for name, values in marg_collected.items()}

In [None]:
# Half the default figure size, as a new float array. The object returned by
# get_default_figsize() is not modified, and tuple/list returns also work.
half_figsize = 0.5 * np.asarray(get_default_figsize(), dtype=float)

In [None]:
# The colour (network_snrs, from the bilby results) and both nessai result sets are
# matched up purely by position. Check that the lengths agree; ordering is assumed
# to agree because every file list was naturally sorted.
assert (
    len(network_snrs)
    == len(nessai_results["sampling_time"])
    == len(marg_nessai_results["sampling_time"])
), "bilby and nessai result sets contain different numbers of runs"

# Left panel: wall time vs likelihood evaluations for the main runs, coloured by SNR.
fig, ax = plt.subplots(figsize=half_figsize)
points = ax.scatter(
    nessai_results["sampling_time"] / 3600,  # seconds -> hours
    nessai_results["evaluations"],
    c=network_snrs,
    cmap="cividis",
)
ax.set_xlabel("Wall time [hrs]")
ax.set_ylabel("Likelihood evaluations")
fig.colorbar(points, ax=ax, label="SNR")
save_figure(fig, "phase_stats", bbox_inches=None)
plt.show()

# Remember this axes' position so the second figure's axes can be placed identically.
ax_pos = ax.get_position().bounds

# Right panel: main-run cost divided by the matching path_marg run's cost
# (values > 1 mean the main run was more expensive).
fig, ax = plt.subplots(figsize=half_figsize)
points = ax.scatter(
    nessai_results["sampling_time"] / marg_nessai_results["sampling_time"],
    nessai_results["evaluations"] / marg_nessai_results["evaluations"],
    c=network_snrs,
    cmap="cividis",
)
ax.set_xlabel("Wall time ratio")
ax.set_ylabel("Likelihood evaluations\nratio")
fig.colorbar(points, ax=ax, label="SNR")
ax.set_position(ax_pos)
save_figure(fig, "phase_comparison", bbox_inches=None)
plt.show()

In [None]:
# Median ratio of likelihood evaluations: main runs / matching path_marg runs.
np.median(np.divide(nessai_results["evaluations"], marg_nessai_results["evaluations"]))

In [None]:
# Median number of likelihood evaluations over the main runs.
np.median(nessai_results["evaluations"], axis=None)

In [None]:
# Fraction of each run's total sampling (wall) time spent in each stage.
# Likelihood time is subtracted from population time, presumably because nessai
# counts likelihood evaluations inside population here — TODO confirm.
wall_time = nessai_results["sampling_time"]
likelihood_fraction = nessai_results["likelihood_time"] / wall_time
pop_fraction = (nessai_results["population_time"] - nessai_results["likelihood_time"]) / wall_time
train_fraction = nessai_results["training_time"] / wall_time

In [None]:
# Average fraction of wall time spent training, over the main runs.
train_fraction.mean()

In [None]:
# Histograms (across runs) of the wall-time fraction spent in population
# (excluding likelihood time) and in likelihood evaluation.
hist_kwargs = {"histtype": "step", "lw": 2.0}
fig, ax = plt.subplots(figsize=half_figsize)
for fraction, series_name in ((pop_fraction, "Population"), (likelihood_fraction, "Likelihood")):
    ax.hist(fraction, label=series_name, **hist_kwargs)
ax.set_xlabel("Fraction of wall time")
ax.legend()
save_figure(fig, "phase_time_fraction")
plt.show()