In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import wandb
from tqdm.autonotebook import tqdm

import cartopy
import cartopy.crs as ccrs
import cmocean
import matplotlib.colors as mpl_c
import matplotlib.pyplot as plt

In [None]:
# Apply both style sheets in one call; they are applied left to right, so
# "wiley" settings override "paper" settings where the two overlap.
plt.style.use(["paper", "wiley"])

In [None]:
# Client for the Weights & Biases public API; used below to query logged runs.
# NOTE(review): presumably relies on an existing wandb login / API key in the
# environment -- confirm before running on a fresh machine.
api = wandb.Api()

In [None]:
# Select every run in the project whose display name is exactly "diff_l_exp".
diff_runs = api.runs(
    path="tobifinn/train_diffusion_nextsim_regional",
    filters={"display_name": "diff_l_exp"},
)

In [None]:
# Per-variable standard deviations. The history-loading cell indexes the
# derived variances in the order: sit, sic, damage, siu, siv.
# NOTE(review): origin of these numbers is not visible here -- presumably
# dataset normalisation statistics; confirm.
std = np.array(
    [
        0.7506,
        0.1848,
        0.1968,
        0.0836,
        0.0878,
    ]
)
# Variances, used to normalise the logged mean squared errors.
var = np.square(std)

In [None]:
def _scan_to_frame(run, columns):
    """Download selected history keys of ``run`` as a step-indexed frame.

    Parameters
    ----------
    run
        A wandb run object providing ``scan_history``.
    columns : dict[str, str]
        Maps each logged history key to the column name used in the output.

    Returns
    -------
    pandas.DataFrame
        One column per entry of ``columns``, indexed by
        ``"trainer/global_step"``.
    """
    step_key = "trainer/global_step"
    # Materialise the scan exactly once: the scan object restarts (and
    # re-downloads the history) every time it is iterated, so feeding it to
    # several comprehensions multiplies the number of server round trips.
    rows = list(run.scan_history(page_size=100000, keys=[*columns, step_key]))
    return pd.DataFrame(
        {name: [row[key] for row in rows] for key, name in columns.items()},
        index=[row[step_key] for row in rows],
    )


# Logged score keys -> output column names. The order matches the entries of
# ``var`` (sit, sic, damage, siu, siv), which normalise the MSEs below.
score_columns = {
    "scores/mse_sit": "mse_sit",
    "scores/mse_sic": "mse_sic",
    "scores/mse_damage": "mse_sid",
    "scores/mse_siu": "mse_siu",
    "scores/mse_siv": "mse_siv",
}

# Collect one frame per run and concatenate once at the end; growing a
# DataFrame with ``pd.concat`` inside the loop copies all previous rows on
# every iteration.
run_frames = []
for run in tqdm(diff_runs):
    curr_train = _scan_to_frame(run, {"train/loss": "train_loss"})
    curr_val = _scan_to_frame(run, {"val/loss": "val_loss"})
    # Divide each MSE column by the variance of its variable (positional match
    # between the five columns and the five entries of ``var``).
    curr_mse = _scan_to_frame(run, score_columns).div(var, axis="columns")
    # Train loss, validation loss and scores are fetched separately; the outer
    # join on the step index leaves NaN where a quantity has no entry.
    run_frames.append(pd.concat([curr_train, curr_val, curr_mse], axis=1))

# Keep the original behaviour of yielding an empty frame when no run matched.
results = pd.concat(run_frames, axis=0).sort_index() if run_frames else pd.DataFrame()

In [None]:
# Remove the rows logged at two specific global steps before plotting.
# NOTE(review): the reason is not visible here -- presumably spurious or
# outlier log entries; confirm and document.
# ``errors="ignore"`` makes the cell idempotent: re-running it after the rows
# have already been removed no longer raises a KeyError.
results = results.drop(index=[198999, 335499], errors="ignore")

In [None]:
# Two-panel figure against the global step, both on a log10 y-axis:
# (a) training / validation loss, (b) normalised RMSE of the logged scores.
step_scale = 100000  # the x-axis is shown in units of 10^5 iterations
x_limits = (0, 3.85)
mse_columns = ["mse_sit", "mse_sic", "mse_sid", "mse_siu", "mse_siv"]

# The merged table holds NaN wherever a quantity has no entry at a step, so
# each curve drops its own gaps before plotting.
train_loss = results["train_loss"].dropna()
val_loss = results["val_loss"].dropna()
# Select the score columns by name rather than by position, so the plot does
# not silently change if the column layout of ``results`` changes.
mse_results = results[mse_columns].dropna()
# The MSEs were divided by the per-variable variance when ``results`` was
# built, so their square root is a normalised RMSE.
nrmse_avg = np.sqrt(mse_results.mean(axis=1))
log_nrmse_avg = np.log10(nrmse_avg)

fig, ax = plt.subplots(nrows=2, figsize=(4, 3))
ax_loss, ax_nrmse = ax

# --- (a) loss curves -------------------------------------------------------
ax_loss.grid()
ax_loss.plot(train_loss.index / step_scale, np.log10(train_loss), c="salmon", label="Train", alpha=0.5)
ax_loss.plot(val_loss.index / step_scale, np.log10(val_loss), c="black", label="Validation")
# Highlight the step with the lowest validation loss.
best_loss_pos = val_loss.argmin()
ax_loss.scatter(
    val_loss.index[best_loss_pos] / step_scale,
    np.log10(val_loss.iloc[best_loss_pos]),
    fc="yellow", ec="black", s=10, marker="o", lw=0.5, zorder=99,
)

ax_loss.text(0.02, 0.98, "(a)", ha="left", va="top", transform=ax_loss.transAxes)
ax_loss.legend()
ax_loss.set_ylabel(r"$\log_{10}$(Loss)")
ax_loss.set_xlim(*x_limits)
# The x-axis is labelled on the lower panel only.
ax_loss.set_xticklabels([])

# --- (b) normalised RMSE ---------------------------------------------------
ax_nrmse.grid()
# One dashed grey line per variable; all share the same legend label.
plt_var = ax_nrmse.plot(
    mse_results.index / step_scale, np.log10(np.sqrt(mse_results)),
    c="0.5", ls="--", label="Variables", lw=0.7, alpha=0.7,
)
plt_avg, = ax_nrmse.plot(mse_results.index / step_scale, log_nrmse_avg, c="black", label="Averaged")
# Highlight the step with the lowest variable-averaged nRMSE.
best_nrmse_pos = nrmse_avg.argmin()
ax_nrmse.scatter(
    mse_results.index[best_nrmse_pos] / step_scale,
    log_nrmse_avg.iloc[best_nrmse_pos],
    fc="yellow", ec="black", s=20, marker="*", lw=0.5, zorder=99,
)
ax_nrmse.text(0.02, 0.98, "(b)", ha="left", va="top", transform=ax_nrmse.transAxes)
# A single legend entry stands in for all per-variable lines.
ax_nrmse.legend(handles=[plt_var[0], plt_avg])
# Ticks are placed in log10 space but labelled with the untransformed values.
ax_nrmse.set_yticks(np.log10([0.1, 0.2, 0.5]))
ax_nrmse.set_yticklabels([0.1, 0.2, 0.5])
ax_nrmse.set_ylabel(r"nRMSE")
ax_nrmse.set_xlim(*x_limits)
ax_nrmse.set_xlabel(r"Iterations $\times 10^5$")

fig.align_ylabels(ax)
fig.subplots_adjust(hspace=0.1)

# Create the output directory if needed so saving works on a fresh checkout.
fig_path = Path("figures/fig_app_b6_overfitting_diffusion.png")
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path)