In [None]:
from bayesbeat.result import get_fit
from bayesbeat.model import GenericAnalyticGaussianBeam
from bayesbeat.data import get_data, get_n_entries
import h5py
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pathlib

from utils import (
    get_duration,
    add_log10_bayes_factor_colorbar,
    get_bayes_factor_markers,
    get_frequency,
)

# Apply the shared paper figure style globally (updates rcParams for every later plot).
# NOTE(review): "paper.mplstyle" is not a built-in style name, so matplotlib treats it
# as a file path relative to the current working directory — run the notebook from
# its own folder or this call fails.
plt.style.use("paper.mplstyle")

In [None]:
# Extension (and therefore backend) used for every figure this notebook saves.
file_format = "pdf"

# Saved figures go into ./figures relative to the working directory; the folder is
# created on the first run and silently reused afterwards.
outdir = pathlib.Path("figures")
outdir.mkdir(exist_ok=True)

In [None]:
# Location of the .mat data file, relative to the notebook's working directory;
# it is read by the get_data call in the next cell.
data_file = "../data/PyTotalAnalysis_2024_02_23.mat"

In [None]:
# Entry of the data file to load (presumably 0 selects the first measurement).
index = 0

# get_data yields four values; the trailing one is discarded here.
# NOTE(review): x_data, y_data and frequency are not used by any later cell in this
# notebook — confirm whether this load is still needed.
x_data, y_data, frequency, _ = get_data(
    data_file,
    index=index,
)

In [None]:
# Truncation orders to compare: number of series terms kept in each model.
n_terms = [1, 3, 5, 7]

# Shared evaluation grid: 1e5 samples from 0 to 1e4 (seconds, per the later plot labels).
t = np.linspace(0, 1e4, int(1e5))

# One analytic Gaussian-beam model per truncation order. The detector/beam geometry
# (presumably in metres) is identical for all of them; only the equation name and
# n_terms change, so models[k] corresponds to n_terms[k].
models = [
    GenericAnalyticGaussianBeam(
        t,
        None,
        photodiode_gap=0.5e-3,
        photodiode_size=10.2e-3,
        beam_radius=3.3e-3,
        include_gap=True,
        equation_name=f"General_Equation_{n}_Terms",
        n_terms=n,
    )
    for n in n_terms
]

In [None]:
# First example parameter set: a_1 is 100x larger than a_2, both decay times are
# equal, and both sigma_* noise parameters are zero.
theta = {
    "a_1": 1e-3,
    "a_2": 1e-5,
    "a_scale": 1.0,
    "tau_1": 100.0,
    "tau_2": 100.0,
    "domega": 10,
    "dphi": 6,
    "x_offset": 0.0,
    "sigma_amp_noise": 0.0,
    "sigma_constant_noise": 0.0,
}

# Model output for every truncation order, keyed by its number of terms.
fits = {n: model.signal_model(theta) for n, model in zip(n_terms, models)}

In [None]:
# Quick look at the start of the 3-term model for the first parameter set.
cutoff = 1000  # number of leading samples of `t` to display

fig_quick, ax_quick = plt.subplots()
ax_quick.plot(t[:cutoff], fits[3][:cutoff])
# Label the axes (same conventions as the comparison figure) so the plot stands alone.
ax_quick.set_xlabel("$t$ [s]")
ax_quick.set_ylabel("$A$")
ax_quick.set_title("$T=3$")
plt.show()

In [None]:
# Second example parameter set: equal amplitudes (a_1 = a_2 = 1e-4), slightly different
# decay times (100 vs 110), dphi = pi/4, and both sigma_* noise parameters zero.
theta_1 = {
    "a_1": 1e-4,
    "a_2": 1e-4,
    "a_scale": 1.0,
    "tau_1": 100.0,
    "tau_2": 110.0,
    "domega": 0.5,
    "dphi": np.pi / 4,
    "x_offset": 0.0,
    "sigma_amp_noise": 0.0,
    "sigma_constant_noise": 0.0,
}

# Model output for every truncation order, keyed by its number of terms.
fits_1 = {n: model.signal_model(theta_1) for n, model in zip(n_terms, models)}

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

In [None]:
import matplotlib as mpl

In [None]:
# 2x2 comparison figure.
#   columns: left = first parameter set (`fits`), right = second (`fits_1`)
#   top row: model output A(t) for each number of terms T
#   bottom row: each model divided by the T=1 model
figsize = list(plt.rcParams["figure.figsize"])  # copy, so rcParams is not mutated
figsize[0] *= 1.8
figsize[1] *= 1.5

ls = ["-", "--", "-.", ":"]

fig, axs = plt.subplots(2, 2, figsize=figsize, sharex=True)

for col, fit_set in enumerate((fits, fits_1)):
    for i, (key, y_fit) in enumerate(fit_set.items()):
        # Same colour/linestyle for a given T in every panel; wrap the linestyles
        # so more than four term counts do not raise IndexError.
        style = dict(color=f"C{i}", ls=ls[i % len(ls)])
        axs[0, col].plot(t, y_fit, label=key, **style)
        if key == 1:
            # The T=1 model is the reference; its ratio to itself is identically 1.
            continue
        axs[1, col].plot(t, y_fit / fit_set[1], **style)

for ax in axs.flat:
    ax.set_xlim(1e-1, 1e3)

for ax in axs[0, :]:
    ax.set_ylabel("$A$")
for ax in axs[1, :]:
    ax.set_xlabel("$t$ [s]")
    ax.set_ylabel("$A / A_{T=1}$")

# One shared legend below the panels (all panels use the same T -> style mapping).
handles, labels = axs[0, 0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    ncol=max(len(labels), 1),
    title=r"Number of terms ($T$)",
    loc="center",
    bbox_to_anchor=(0.5, 0.05),
)

# tight_layout does not account for figure-level legends, so reserve the bottom
# strip of the figure for it; otherwise the legend overlaps the x-axis labels.
fig.tight_layout(rect=(0, 0.1, 1, 1))
fig.savefig(outdir / f"example_model_3_test.{file_format}")