In [None]:
import re
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import scienceplots  # noqa: F401
from tensorflow.keras.models import load_model

from plotting import watermark

In [None]:
model_file = "CNN_nadelhorn_energy_electron_n80000_e25.keras"
# Derive the training-history CSV name from the model file name.
# Only the final extension is stripped: ``split(".")[0]`` would truncate names that
# contain extra dots (e.g. "model_v1.2.keras" -> "history_model_v1.csv").
# NOTE(review): assumes the history CSV is written next to the model file; confirm
# against the training script if model_file ever includes a directory.
_model_path = Path(model_file)
history_file = str(_model_path.with_name(f"history_{_model_path.stem}.csv"))

In [None]:
# Base look: the scienceplots "science" style, layered with its "notebook" variant.
plt.style.use(("science", "notebook"))

In [None]:
# Figure defaults for this notebook, applied on top of the style sheets above:
# font size, scientific-notation thresholds for tick labels, and figure size (inches).
plt.rcParams.update(
    {
        "font.size": 14,
        "axes.formatter.limits": (-5, 4),
        "figure.figsize": (6, 4),
    }
)
# Colours of the active property cycle, used to tie each y-axis label to its curve.
cycle_by_key = plt.rcParams["axes.prop_cycle"].by_key()
colors = cycle_by_key["color"]

In [None]:
# Load the trained network saved in `model_file`.
model = load_model(filepath=model_file)

In [None]:
# Model name, reused in the plot annotation and in the output file names.
model_name: str = model.name

In [None]:
# Print the layer-by-layer architecture (print_fn=None -> Keras' default printer).
model.summary(print_fn=None)

In [None]:
# Metrics logged during training; the plot below treats each row as one epoch.
history_df = pd.read_csv(filepath_or_buffer=history_file)

In [None]:
# Number of training events. The model file name carries it as "..._n<N>_e<epochs>..."
# (e.g. "_n80000_e25"). Parse it rather than hard-coding it, so the plot annotation and
# the output file names cannot drift out of sync with `model_file`.
_n_events_match = re.search(r"_n(\d+)_e\d+", Path(model_file).name)
if _n_events_match is None:
    raise ValueError(
        f"Cannot infer the number of training events from {model_file!r}; "
        "expected a '_n<N>_e<epochs>' tag in the file name."
    )
n_events = int(_n_events_match.group(1))

In [None]:
# Convergence plot: training loss (left axis) and mean absolute error (right axis)
# per epoch, both on log scales. Each row of `history_df` is treated as one epoch.
n_epochs = len(history_df)

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()

ax1.plot(history_df["loss"].to_numpy(), color=colors[0])
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss Function", color=colors[0])

# Use the single "mae" column when it exists. Otherwise fall back to the per-output
# metrics "dense_2_mae" / "dense_3_mae" (presumably a model with two output layers).
# The fallback curves share the axis colour, so they get distinct line styles and a
# legend; previously both were drawn identically and could not be told apart.
if "mae" in history_df.columns:
    ax2.plot(history_df["mae"].to_numpy(), color=colors[1])
else:
    for column, linestyle in (("dense_2_mae", "-"), ("dense_3_mae", "--")):
        ax2.plot(
            history_df[column].to_numpy(),
            color=colors[1],
            linestyle=linestyle,
            label=column,
        )
    ax2.legend()
ax2.set_ylabel("Error", color=colors[1])

# Run summary, placed in axes-fraction coordinates of the primary axis.
ax1.text(
    0.3,
    0.7,
    f"Training dataset: {n_events} events\n"
    f"Training duration: {n_epochs} epochs\n{model_name}",
    transform=ax1.transAxes,
)
ax1.set_yscale("log")
ax2.set_yscale("log")
watermark()

# Save a vector (PDF) and a raster (PNG) copy under the same name. Create the output
# directory first so a fresh checkout does not fail with FileNotFoundError.
plot_dir = Path("plots")
plot_dir.mkdir(parents=True, exist_ok=True)
plot_stem = f"convergence_{model_name}_n{n_events}_e{n_epochs}"
for extension in ("pdf", "png"):
    fig.savefig(plot_dir / f"{plot_stem}.{extension}")