In [None]:
import matplotlib.pyplot as plt
import numpy as np
from nessai.utils import rolling_mean
from scipy import stats

from thesis_utils.plotting import set_plotting, save_figure, get_default_figsize
from thesis_utils.io import load_pickle

# Apply the shared thesis plotting configuration from thesis_utils.plotting.
# NOTE(review): presumably sets matplotlib rcParams/styles — implementation not visible here.
set_plotting()

## Comparison plot

An earlier version of this comparison used the runs labelled `v13` (output directories containing `gw_v13`).

In [None]:
# Earlier set of runs (including references to the `v13` runs), kept for
# provenance only: nothing below loads from this dict.
previous_runs = {
    "default": "outdir/paper_default_no_phase_marg_nessai/nested_sampler_resume.pkl",
    # "default": "outdir/gw_v13_nessai/nested_sampler_resume.pkl",
    # "alpha-beta": "outdir/alpha_beta_nessai/nested_sampler_resume.pkl",
    "delta-phase": "outdir/delta_phase_nessai/nested_sampler_resume.pkl",
    # "delta-phase-bilby": "outdir/delta_phase_bilby_nessai/nested_sampler_resume.pkl",
    # "order": "outdir/defaults_order_nessai/nested_sampler_resume.pkl",
    # "quaternions_spins": "outdir/quaternions_spins_nessai/nested_sampler_resume.pkl",
    "quaternions": "outdir/quaternions_default_nessai/nested_sampler_resume.pkl",
    # "gw_v6": "outdir/gw_v6_nessai/nested_sampler_resume.pkl",
    # "no_gw_v2": "outdir/no_gw_v1_nessai/nested_sampler_resume.pkl",
    # "quaternions": "outdir/gw_v13_quaternions_nessai/nested_sampler_resume.pkl",
    # "delta-phase": "outdir/gw_v13_delta_phase_nessai/nested_sampler_resume.pkl",
    "no-gw": "outdir/no_gw_v1_nessai/nested_sampler_resume.pkl",
    "spins": "outdir/gw_v18_nessai/nested_sampler_resume.pkl",
}

# Runs compared in this notebook: run key -> nessai resume pickle.
# Keys are looked up in `labels` (legend text) and in the insertion-index
# `include` list, so they must be spelled consistently (e.g. "quaternions").
# Insertion order fixes the plotting order (colours / line styles).
runs = {
    "default": "outdir/fix-spins/paper_default_phase_marg_nessai/nested_sampler_resume.pkl",
    "no_marg": "outdir/fix-spins/paper_default_no_phase_marg_nessai/nested_sampler_resume.pkl",
    "quaternions": "outdir/fix-spins/quaternions_default_nessai/nested_sampler_resume.pkl",
    "no_gw_v2": "outdir/no_gw_v2_nessai/nested_sampler_resume.pkl",
}

In [None]:
# Legend text for each run key. Covers both the current runs (`default`,
# `no_marg`, `quaternions`, `no_gw_v2`) and the keys used by earlier runs.
labels = {
    "default": "Default",
    "default-other": "Default",
    "no_marg": "No phase marg.",
    "alpha-beta": r"$(\alpha, \beta)$",
    "delta-phase": r"$\Delta\phi$",
    "no-gw": "No GW",
    "no_gw_v2": "No GW",
    "quaternions": "Quaternions",
    "spins": "Spins",
}

In [None]:
# Load every run's resume pickle, keyed by the same names as `runs`
# (insertion order preserved).
# NOTE(review): load_pickle presumably unpickles the file — only use it on
# trusted, locally produced outputs.
samplers = {name: load_pickle(path) for name, path in runs.items()}

In [None]:
# Line styles indexed by run position: four distinct styles, then the first two repeated.
ls = ["-", "--", "-.", ":"] + ["-", "--"]

In [None]:
# Four-panel sampler diagnostics for every run, sharing the iteration axis:
#   0. (rolling-mean) acceptance, 1. cumulative training count,
#   2. (rolling-mean) iterations between training, 3. (rolling-mean)
#   rejection-sampling acceptance.
ROLLING_WINDOW = 16  # window size passed to rolling_mean for smoothed panels

# Copy before scaling so the object returned by the helper is never mutated.
figsize = list(get_default_figsize())
figsize[1] *= 1.3
fig, axs = plt.subplots(4, 1, sharex=True, figsize=figsize)

for i, (name, ns) in enumerate(samplers.items()):
    # Cycle through the styles so more runs than styles cannot raise IndexError.
    style = ls[i % len(ls)]
    # NOTE(review): x-axis assumes one acceptance-history entry per
    # nlive // 10 iterations, with as many entries as min_likelihood — confirm.
    it = np.arange(len(ns.min_likelihood)) * (ns.nlive // 10)
    # Gap (in iterations) between consecutive training iterations.
    dtrain = np.diff(ns.training_iterations)

    axs[0].plot(it, rolling_mean(ns.mean_acceptance_history, ROLLING_WINDOW), ls=style)
    axs[1].plot(
        ns.training_iterations,
        np.arange(len(ns.training_iterations)),
        ls=style,
    )
    axs[2].plot(ns.training_iterations[1:], rolling_mean(dtrain, ROLLING_WINDOW), ls=style)
    axs[3].plot(
        ns.population_iterations,
        rolling_mean(ns.population_acceptance, ROLLING_WINDOW),
        ls=style,
        label=labels.get(name, name),
    )

axs[0].set_ylim([5e-4, 1])
axs[0].set_yscale("log")
axs[0].set_ylabel("Acceptance")

axs[1].set_ylim(0, 8000)
axs[1].set_ylabel("Cumulative \ntraining count")

axs[2].set_ylabel("Iterations \nbetween training")
axs[2].set_yscale("log")

axs[3].set_yscale("log")
axs[3].set_ylim(top=1e-2)
axs[3].set_ylabel("Rejection sampling \nacceptance")

axs[-1].set_xlabel("Iteration")

# Single-row legend placed below the bottom panel.
axs[-1].legend(ncol=len(samplers), loc="center", bbox_to_anchor=(0.45, -0.55))
plt.tight_layout()
save_figure(fig, "phase_comparison_diagnostics", "figures")
plt.show()

In [None]:
# Insertion indices (as arrays) for the subset of runs shown in the
# insertion-index plot; runs not present in `samplers` are simply skipped.
include = ["default", "delta-phase", "quaternions"]
indices = {
    key: np.array(ns.insertion_indices)
    for key, ns in samplers.items()
    if key in include
}

In [None]:
# Compare each run's empirical insertion-index CMF with the analytic CMF
# x / (nlive - 1). Grey bands are binomial intervals (50%, 95%, 99.7%)
# computed with the smallest number of indices among the runs, i.e. the
# widest of the per-run bands.
nlive = 2000  # NOTE(review): must match the runs' nlive — confirm against ns.nlive

x = np.arange(0, nlive, 1)
analytic = x / x[-1]

figsize = get_default_figsize()
fig, ax = plt.subplots(1, 1, figsize=figsize)

# Own name so the global `ls` used by the diagnostics figure is not overwritten.
index_ls = ["-", "--", "-."]

n = np.min([len(idx) for idx in indices.values()])

for j, (key, idx) in enumerate(indices.items()):
    idx = np.asarray(idx, dtype=int)
    assert idx.size == 0 or idx.max() < nlive, (
        f"{key}: insertion index {idx.max()} >= nlive={nlive}"
    )
    # bincount (unlike np.unique) gives a zero count for indices that never
    # occur, so `estimated` always has length nlive and aligns with `analytic`.
    counts = np.bincount(idx, minlength=nlive)
    estimated = np.cumsum(counts) / len(idx)
    ax.plot(
        x,
        analytic - estimated,
        ls=index_ls[j % len(index_ls)],
        c=f"C{j}",
        label=labels.get(key, key),
    )

for ci in [0.5, 0.95, 0.997]:
    bound = (1 - ci) / 2
    bound_values = stats.binom.ppf(1 - bound, n, analytic) / n
    half_width = bound_values - analytic
    # Pin the band to zero at the ends, where the analytic CMF is exactly 0 and 1.
    half_width[[0, -1]] = 0
    ax.fill_between(x, -half_width, half_width, color="grey", alpha=0.2)

ax.set_xlim(0, nlive - 1)
ax.set_xlabel("Insertion index")
ax.set_ylabel("Analytic CMF - Empirical CMF")
ax.legend()
fig.tight_layout()

# NOTE(review): the diagnostics figure passes "figures" as a third argument to
# save_figure; confirm the default location used here is intended.
save_figure(fig, "phase_reparams_insertion_indices")
plt.show()