In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import numpy as np

from compute_jsd import SEARCH_PARAMETERS

from thesis_utils.plotting import (
    set_plotting,
    save_figure,
    get_default_figsize,
)
from thesis_utils.colours import lighten_colour, pillarbox
from thesis_utils.io import load_json
from thesis_utils.gw import get_cbc_parameter_labels

import seaborn as sns

# Apply the shared thesis plotting style, then force LaTeX rendering of
# all text in the figures produced below.
set_plotting()
plt.rcParams.update({"text.usetex": True})

In [None]:
# JSD results computed with and without distance marginalization.
marg_file, no_marg_file = (
    "results/jsd_results_marg.json",
    "results/jsd_results_nomarg.json",
)

# Loaded in the same order as the paths above (marg first, then no-marg).
marg_results_dict, no_marg_results_dict = (
    load_json(path) for path in (marg_file, no_marg_file)
)

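Convert the results to a more useful format: a mapping `parameter: JSD values`, with one array per parameter collecting its values across all result entries.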

In [None]:
factor = 10**3  # scale applied to stored JSD values before plotting (y-axis is labelled "JSD [mbits]"; assumes stored JSD is in bits)

In [None]:
def _stack_by_parameter(results_by_entry, parameters):
    """Collect each parameter's value from every result entry into one array.

    Parameters
    ----------
    results_by_entry : dict
        Mapping whose values are dicts keyed by parameter name.
    parameters : iterable of str
        Parameter names to extract.

    Returns
    -------
    dict
        ``{parameter: np.array([entry[parameter] for entry in results_by_entry.values()])}``,
        in the order given by ``parameters``.
    """
    return {
        name: np.array([entry[name] for entry in results_by_entry.values()])
        for name in parameters
    }


# Re-key both result sets by parameter (marg first, then no-marg).
marg_results = _stack_by_parameter(
    marg_results_dict, SEARCH_PARAMETERS.get("marg")
)
no_marg_results = _stack_by_parameter(
    no_marg_results_dict, SEARCH_PARAMETERS.get("nomarg")
)

In [None]:
def _max_first_element(results_by_entry):
    """Return, per result entry, the maximum of ``value[0]`` over all its parameters.

    Note: this ranges over every parameter stored in the JSON entry, not only
    those listed in ``SEARCH_PARAMETERS``. Element 0 is the same element used
    as the JSD value when plotting per-parameter histograms.
    """
    per_entry_max = [
        max(values[0] for values in entry.values())
        for entry in results_by_entry.values()
    ]
    return np.array(per_entry_max)


# Largest JSD over all parameters for each result entry, scaled by `factor`.
max_jsd_nomarg = _max_first_element(no_marg_results_dict) * factor
max_jsd_marg = _max_first_element(marg_results_dict) * factor

In [None]:
# Index of one result entry (position among the JSON's outer entries,
# presumably a single injection/run — TODO confirm) whose no-marginalization
# JSD is marked with a horizontal line on each column of the JSD figure.
single_idx = 25

In [None]:
# Mirrored histograms of the per-entry JSD for each parameter: bars to the
# left of each column centre are without distance marginalization, bars to
# the right are with it. The final "Max." column shows, per entry, the
# largest JSD over all parameters.

nomarg_parameters = SEARCH_PARAMETERS.get("nomarg")
marg_parameters = SEARCH_PARAMETERS.get("marg")


def _barh_histogram(ax, values, bins, left, colour, mirror=False):
    """Draw a horizontal histogram of ``values`` anchored at x = ``left``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    values : array-like of values to histogram (values outside ``bins``
        are not counted, per ``np.histogram``).
    bins : bin edges.
    left : x position the bars start from.
    colour : bar colour.
    mirror : if True, bars extend towards negative x (leftwards).

    Returns
    -------
    numpy.ndarray
        The bin counts, so callers can size annotations to the bars.
    """
    freqs, bin_edges = np.histogram(values, bins=bins)
    bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2
    widths = np.diff(bin_edges)
    ax.barh(
        bin_centres,
        -freqs if mirror else freqs,
        left=left,
        height=widths,
        color=colour,
    )
    return freqs


fig, ax = plt.subplots()

# Horizontal spacing between columns, in x data units (histogram counts).
sep = 75

single_width = sep / 2  # not used in this cell; kept for later cells

# One column per no-marg parameter plus the final "Max." column. Derived
# from the parameter list so ticks and labels always have the same length.
n_columns = len(nomarg_parameters) + 1
xticks = np.arange(0, n_columns * sep, sep)

left = 0

bins = np.logspace(-0.2, 2, 32, base=10)

n_samples = 1000
# 10 / n_samples (in the units of the stored JSD), scaled like the data.
# NOTE(review): presumably a noise-level criterion for n_samples posterior
# samples — TODO confirm.
threshold = (10 / n_samples) * factor

# 7-colour palette repeated enough times to cover every parameter (at least
# twice, as before), so colours[i] is valid for any number of parameters.
palette = sns.color_palette("crest", n_colors=7)
n_repeats = max(2, int(np.ceil(len(nomarg_parameters) / len(palette))))
colours = np.tile(palette, (n_repeats, 1))

# Also mark the single entry's marg JSD (previously commented-out code).
show_single_marg = False

for i, parameter in enumerate(nomarg_parameters):
    vals = no_marg_results[parameter][..., 0] * factor
    freqs = _barh_histogram(ax, vals, bins, left, colours[i], mirror=True)

    # Mark the JSD of the chosen single entry across the no-marg half.
    ax.plot(
        [left - 1.25 * freqs.max(), left],
        vals[single_idx] * np.ones(2),
        color="C1",
    )

    if parameter in marg_parameters:
        vals_marg = marg_results[parameter][..., 0] * factor
        freqs_marg = _barh_histogram(
            ax, vals_marg, bins, left, lighten_colour(colours[i], 0.5)
        )
        if show_single_marg:
            ax.plot(
                [left, left + 1.2 * freqs_marg.max()],
                vals_marg[single_idx] * np.ones(2),
                color="C1",
            )

    left += sep

# Final column: maximum JSD over all parameters.
_barh_histogram(ax, max_jsd_nomarg, bins, left, pillarbox, mirror=True)
_barh_histogram(ax, max_jsd_marg, bins, left, lighten_colour(pillarbox, 0.5))

ax.set_yscale("log")
ax.set_xlim(left=-sep)

ax.axhline(threshold, ls="--", color="k")

ax.set_xticks(xticks)
ax.set_xticklabels(get_cbc_parameter_labels(nomarg_parameters) + ["Max."])
ax.tick_params(axis="x", which="minor", bottom=False, top=False)

ax.set_ylabel("JSD [mbits]")

handles = [
    Patch(facecolor="grey", label="Without distance marginalization"),
    Patch(
        facecolor=lighten_colour("grey", 0.5),
        label="With distance marginalization",
    ),
    Line2D([0], [1], ls="--", color="k", label="Threshold"),
]

ax.legend(handles=handles)

# Save before showing: some backends close/destroy the figure in show().
save_figure(fig, "jsd", "figures")

plt.show()