# Loss plots

This notebook contains the code to reproduce Figures 12 and 14.

In [None]:
from bayesbeat.data import get_n_entries
from bayesbeat.conversion import loss_from_decay_parameter
import h5py
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from cycler import cycler
import numpy as np
import pathlib
import pandas as pd

from utils import get_frequency, model_labels, model_colours

# Apply the project's matplotlib style sheet (paper.mplstyle) to every figure produced below.
plt.style.use("paper.mplstyle")

In [None]:
# Output directory for the per-ringdown loss figures; created (with parents) if absent.
outdir = pathlib.Path("figures", "loss")
outdir.mkdir(parents=True, exist_ok=True)

# File extension (and hence format) used when saving figures.
file_format = "pdf"

Path to the data file

In [None]:
# Data file containing the measured ringdowns.
data_file = "../data/PyTotalAnalysis_2024_02_23.mat"
# Number of ringdowns in the data file; below they are addressed by index 0..n_ringdowns-1.
n_ringdowns = get_n_entries(data_file)

Path to the result files

In [None]:
results_path = pathlib.Path("../results/bayesbeat_inference_results/real_data/")
# One result directory per model, named after the model key (insertion order matters for iteration).
paths = {
    model: results_path / model
    for model in ("model_1_constant_noise", "model_1", "model_3")
}

In [None]:
# Frequency of every ringdown (as returned by get_frequency), stored in ringdown order.
frequencies = np.empty(shape=n_ringdowns)
for ringdown in range(n_ringdowns):
    frequencies[ringdown] = get_frequency(data_file, ringdown)

Load the lookup table that maps measured frequencies to mode families.

In [None]:
# Lookup table used by get_colour: nearest "Measured Frequency" -> mode-family column "m".
freq_lookup = pd.read_csv("labelled_frequencies.csv")


Load the results produced using the original method

In [None]:
# Tab-separated results of the original method. Rows are accessed by position (.iloc[index]
# / enumerate) below, so they are assumed to be in ringdown order — TODO confirm.
original_results = pd.read_csv("../results/reference_results.txt", sep="\t")

Load the posterior samples for $\tau_1$ and $\tau_2$ for every model and ringdown.

In [None]:
# NOTE(review): `parameters` is not referenced in the cells shown; the plots below read
# "tau_1"/"tau_2" directly from each posterior.
parameters = ["tau_1", "tau_2"]

# posteriors[key][index] holds the posterior samples of ringdown `index` for model `key`.
# The list for each model always has exactly n_ringdowns entries so that it stays aligned
# with `frequencies` (later cells use posteriors[key][index] and zip(posteriors[key], frequencies)).
posteriors = {}
for key, path in paths.items():
    posteriors[key] = []
    for index in range(n_ringdowns):
        result_file = path / f"result_ringdown_{index}.hdf5"
        if not result_file.exists():
            print(result_file)
            # Keep a placeholder instead of skipping: skipping would shift every later
            # posterior onto the wrong ringdown (and the wrong frequency).
            posteriors[key].append(None)
            continue
        with h5py.File(result_file, "r") as res_file:
            posteriors[key].append(res_file["posterior_samples"][()])

## Figure 12

In [None]:
indices = [0, 4]

# Models overlaid in every panel; the same order is used for the legend entries.
keys = [
    "model_1_constant_noise",
    "model_1",
    "model_3",
]

for index in indices:

    fig, axs = plt.subplots(1, 2, sharey=False)

    # Panel i shows the loss 1/Q_{i+1} derived from tau_{i+1}.
    for i, ax in enumerate(axs):
        for key in keys:
            post = posteriors[key][index]
            if post is None:
                # No posterior was loaded for this model and ringdown.
                continue
            loss = loss_from_decay_parameter(post[f"tau_{i+1}"], frequencies[index])
            ax.hist(loss, bins=50, label=model_labels[key], histtype="step", color=model_colours[key])

        # Overlay the original-method result: a dashed line at its central value and,
        # when both bounds are finite, a grey band spanning them.
        # NOTE(review): the "_Upper" column is unpacked as phi_min and "_Lower" as phi_max;
        # fill_betweenx is order-agnostic, but confirm the column naming in the reference file.
        phi, phi_min, phi_max = original_results[
            [f"Phi_{i+1}", f"Phi_{i+1}_Upper", f"Phi_{i+1}_Lower"]
        ].iloc[index]
        ylim = ax.get_ylim()
        if np.isfinite(phi_min) and np.isfinite(phi_max):
            ax.fill_betweenx(ylim, phi_min, phi_max, color="lightgrey")
        ax.axvline(phi, color="grey", ls="--", label="Original")

        # Hide y tick labels and restore the y-limits captured from the histograms.
        ax.set_yticklabels([])
        ax.set_ylim(ylim)
        ax.set_xlabel(f"$1 / Q_{i+1}$")

    legend_handles = [Line2D([0], [0], color="grey", ls="--", label="Original")] + [
        Line2D([0], [0], color=model_colours[key], label=model_labels[key]) for key in keys
    ]

    # The legend is anchored below the axes; bbox_inches="tight" keeps it in the saved file.
    fig.legend(handles=legend_handles, loc="center", bbox_to_anchor=(0.5, -0.05), ncol=2)
    fig.tight_layout()
    fig.savefig(outdir / f"loss_{index}.{file_format}", bbox_inches="tight")
    plt.show()

## Figure 14

In [None]:
# Colours from the original paper
colours_orig = ["#fdae61", "#d7191c"]

def get_colour(freq, lookup=None):
    """Determine the colour for a given frequency based on the mode family.

    The row of the lookup table whose "Measured Frequency" is closest to ``freq`` is
    selected, and its mode-family value ``m`` is mapped to the Matplotlib property-cycle
    colour ``f"C{m}"``. "CN" colours follow the active property cycle; the Figure 14
    cell sets that cycle to ``colours_orig``.

    Parameters
    ----------
    freq : float
        Frequency to classify, in the same units as the "Measured Frequency" column.
    lookup : pandas.DataFrame, optional
        Table with "Measured Frequency" and "m" columns. Defaults to the module-level
        ``freq_lookup``.

    Returns
    -------
    str
        A Matplotlib "CN" colour specification, e.g. ``"C0"``.
    """
    if lookup is None:
        lookup = freq_lookup
    # idxmin returns an index *label*, so look it up by label with .at.
    nearest = (lookup["Measured Frequency"] - freq).abs().idxmin()
    return f"C{lookup.at[nearest, 'm']}"

In [None]:
# Models shown alongside the original results, one column each.
keys = [
    "model_1",
    "model_3"
]


def _losses(key, tau_name):
    """Per-ringdown loss samples for ``tau_name`` under model ``key`` (None where no posterior was loaded)."""
    return [
        None if post is None else loss_from_decay_parameter(post[tau_name], f)
        for post, f in zip(posteriors[key], frequencies)
    ]


phi_1 = {key: _losses(key, "tau_1") for key in keys}
phi_2 = {key: _losses(key, "tau_2") for key in keys}


figsize = plt.rcParams["figure.figsize"].copy()
figsize[0] = 1.8 * figsize[0]
figsize[1] = 2 * figsize[1]

n_analyses = 1 + len(keys)

# Rows: 1/Q_1, 1/Q_2, zoomed 1/Q_2. Columns: original method, then one per model.
fig, axs = plt.subplots(3, n_analyses, sharey="row", sharex="col", figsize=figsize)

flagged = np.zeros(len(frequencies), dtype=bool)
missing = np.zeros(len(frequencies), dtype=bool)

freq_factor = 1e-3  # converts frequencies to the kHz used on the x-axis
ring_size = 50
# Original-method values above this threshold are circled as outliers.
outlier_threshold = 1e-6


def _target_rows(column, j):
    """Axes of ``column`` that show 1/Q_{j+1}: Q_1 -> row 0; Q_2 -> row 1 and its zoom row 2."""
    return [column[j]] if j == 0 else [column[j], column[j + 1]]


def _draw_point(target_axes, freq, phi, phi_err, colour, marker=".", ring=False, size=ring_size):
    """Draw one hollow error-bar point on each axis, optionally circled in black."""
    for ax in target_axes:
        ax.errorbar(freq, phi, yerr=phi_err, ls="", fmt=marker, c=colour, markeredgecolor=colour, markerfacecolor="none")
        if ring:
            ax.scatter(freq, phi, color=colour, facecolor="none", edgecolor="k", s=size)


with plt.rc_context({"axes.prop_cycle": (cycler('color', colours_orig))}):

    # Original results: one pass per quality factor (Q_1, then Q_2).
    # NOTE(review): the "_Upper" column is unpacked as phi_min and "_Lower" as phi_max, so the
    # error bars are only non-negative if those columns hold the lower and upper bounds
    # respectively — confirm against the reference file.
    for j in range(2):
        columns = ["Frequency", f"Phi_{j+1}", f"Phi_{j+1}_Upper", f"Phi_{j+1}_Lower"]
        for idx, (freq, phi, phi_min, phi_max) in enumerate(original_results[columns].values):
            colour = get_colour(freq)
            if not np.isfinite(phi_min):
                # The original method produced no interval for this ringdown.
                missing[idx] = True
                continue
            is_outlier = phi > outlier_threshold
            if is_outlier:
                flagged[idx] = True
            phi_err = np.array([(phi - phi_min, phi_max - phi)]).T
            _draw_point(_target_rows(axs[:, 0], j), freq * freq_factor, phi, phi_err, colour, ring=is_outlier)

    print(f"Original method failed for {np.sum(missing)} ringdowns")

    # Model results: median with the 5th-95th percentile range of the loss samples.
    for column, key in zip(axs.T[1:], keys):
        for j, phi_vals in enumerate([phi_1[key], phi_2[key]]):
            for idx, (phi_array, freq) in enumerate(zip(phi_vals, frequencies)):
                if phi_array is None:
                    # No posterior was loaded for this model and ringdown.
                    continue
                if missing[idx]:
                    # Ringdowns the original method could not characterise.
                    # NOTE(review): the legend draws this marker in grey — confirm intended.
                    marker, colour = "^", "darkorange"
                else:
                    marker, colour = ".", get_colour(freq)
                phi_min, phi, phi_max = np.quantile(phi_array, [0.05, 0.5, 0.95])
                phi_err = np.array([(phi - phi_min, phi_max - phi)]).T
                # Previous outliers are circled on the Q_2 rows only.
                _draw_point(
                    _target_rows(column, j), freq * freq_factor, phi, phi_err, colour,
                    marker=marker, ring=(j == 1 and flagged[idx]),
                )
        column[0].set_title(model_labels[key])


    axs[0, 0].set_title("Original")

    zoom_ylims = [0, 6e-7]

    # Shade, on the full 1/Q_2 row, the region shown in the zoom row.
    for i in range(n_analyses):
        xlim = axs[1, i].get_xlim()
        axs[1, i].fill_between(xlim, *zoom_ylims, color="lightgrey", zorder=-1)
        axs[1, i].set_xlim(xlim)


    # Rows share y-axes, so setting the first column sets the whole row.
    axs[0, 0].set_ylim(*zoom_ylims)
    axs[1, 0].set_ylim(0, 5e-6)
    axs[2, 0].set_ylim(*zoom_ylims)

    axs[0, 0].set_ylabel("$1 / Q_1$")
    axs[1, 0].set_ylabel("$1 / Q_2$")
    axs[2, 0].set_ylabel("$1 / Q_2 $ (Zoom)")

    for i in range(n_analyses):
        axs[-1, i].set_xlabel(r"$f$ [kHz]")


    handles = [
        Line2D([0], [0], color="C0", ls="", marker="o", label="$p=0$ mode family", markerfacecolor="none", markeredgecolor="C0"),
        Line2D([0], [0], color="C1", ls="", marker="o", label="$p=1$ mode family", markerfacecolor="none", markeredgecolor="C1"),
        Line2D([0], [0], color="k", ls="", marker="o", label="Previous outliers", markerfacecolor="none", markeredgecolor="k", markersize=10),
        Line2D([0], [0], color="grey", ls="", marker="^", label="Newly characterised", markerfacecolor="none", markeredgecolor="grey"),
    ]
    fig.legend(
        handles=handles,
        bbox_to_anchor=(0.5, -0.05),
        ncol=4,
        loc="center",
    )

    fig.tight_layout()

    # The legend sits below the axes (y=-0.05); bbox_inches="tight" keeps it in the saved file.
    fig.savefig(pathlib.Path("figures") / f"losses_real_data_all.{file_format}", bbox_inches="tight")

    plt.show()