In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

In [None]:
# Apply both style sheets in one call; matplotlib applies a list of styles
# from first to last, so settings in "wiley" override those from "paper".
# NOTE(review): neither is a matplotlib built-in style — presumably custom
# .mplstyle files installed in the local matplotlib config; TODO confirm.
plt.style.use(["paper", "wiley"])

In [None]:
# W&B public-API client, used below to query the project's runs and to
# download their logged artifacts. No credentials are set in this notebook.
api = wandb.Api()

In [None]:
# Query the runs compared in the clipping figure by their W&B display name:
# the deterministic and diffusion models, each with and without clipping.
runs = api.runs(
    "tobifinn/test_diffusion_nextsim_regional",
    filters={
        "$or": [
            {"display_name": display_name}
            for display_name in (
                "deterministic",
                "det_no_clipping",
                "diffusion_best_loss_single",
                "diff_l_exp_single_best_no_clipping",
            )
        ]
    },
)

In [None]:
# Per-variable scales used to normalise the RMSEs in the nRMSE computation.
# Entries are paired positionally with the RMSE columns selected when the
# scores are loaded (rmse_sit, rmse_sic, rmse_damage, rmse_siu, rmse_siv).
std = np.array(
    [
        0.7506,  # sit
        0.1848,  # sic
        0.1968,  # damage
        0.0836,  # siu
        0.0878,  # siv
    ]
)

In [None]:
# Per-variable RMSE columns combined into the normalised RMSE (nRMSE). The
# order must match the entries of `std`, since the two are paired positionally.
SCORE_COLUMNS = ["rmse_sit", "rmse_sic", "rmse_damage", "rmse_siu", "rmse_siv"]


def compute_nrmse(scores: pd.DataFrame, scale: np.ndarray) -> pd.Series:
    """Combine the per-variable RMSEs of each row into one normalised RMSE.

    nRMSE = sqrt(mean_i(rmse_i**2 / scale_i**2)) over ``SCORE_COLUMNS``.

    Args:
        scores: Frame containing at least the ``SCORE_COLUMNS`` columns.
        scale: One normalisation scale per entry of ``SCORE_COLUMNS``.

    Returns:
        Series of nRMSE values, indexed like ``scores``.
    """
    normalised_mse = scores[SCORE_COLUMNS] ** 2 / np.asarray(scale) ** 2
    return np.sqrt(normalised_mse.mean(axis=1))


# Maps run name -> frame of per-variable test RMSEs plus an "nrmse" column.
results = {}
for run in runs:
    # Only runs that logged a "run_table" artifact carry test scores. Skip the
    # others explicitly (and visibly) instead of catching IndexError around
    # the whole body, which also hid unrelated failures.
    artifact = next(
        (a for a in run.logged_artifacts() if a.type == "run_table"), None
    )
    if artifact is None:
        print(f"Skipping run {run.name!r}: no 'run_table' artifact logged")
        continue

    wb_table = artifact.get("test/scores.table.json")
    if wb_table is None:
        print(f"Skipping run {run.name!r}: artifact has no 'test/scores.table.json'")
        continue

    scores = pd.DataFrame(wb_table.data, columns=wb_table.columns)[SCORE_COLUMNS]
    if run.name in results:
        # Display names need not be unique in W&B; the last run wins, as before.
        print(f"Note: several runs named {run.name!r}; keeping the last one")
    # assign() returns a new frame, so no value is written into a column
    # selection (avoids pandas' SettingWithCopyWarning).
    results[run.name] = scores.assign(nrmse=compute_nrmse(scores, std))

In [None]:
# Curves in legend order: (W&B run name, legend label, colour, line style, marker).
# Dashed lines with markers are the runs without clipping.
CLIPPING_CURVES = [
    ("deterministic", "Deterministic", "#81B3D5", "-", None),
    ("det_no_clipping", "Deterministic w/o clip", "#81B3D5", "--", "x"),
    ("diffusion_best_loss_single", "Diffusion", "#A56262", "-", None),
    ("diff_l_exp_single_best_no_clipping", "Diffusion w/o clip", "#A56262", "--", "+"),
]
# Dividing the score-table index by this gives the lead time in days.
STEPS_PER_DAY = 2
FIGURE_PATH = Path("figures") / "fig_app_b1_clipping.png"

# Fail with a clear message if a run was skipped while loading the scores.
missing_runs = [name for name, *_ in CLIPPING_CURVES if name not in results]
if missing_runs:
    raise KeyError(f"No test scores loaded for run(s): {missing_runs}")

fig, ax = plt.subplots(figsize=(4, 2.5), dpi=150)
ax.grid(which="both")
# Log-scaled y axis; equivalent to drawing the first curve with semilogy().
ax.set_yscale("log")
for run_name, label, colour, linestyle, marker in CLIPPING_CURVES:
    run_scores = results[run_name]
    ax.plot(
        run_scores.index / STEPS_PER_DAY,
        run_scores["nrmse"],
        ls=linestyle, marker=marker, c=colour, alpha=1.,
        label=label,
    )

ax.set_xlim(0, 15)
ax.set_xlabel("Lead time (days)")
ax.set_ylabel("nRMSE")
ax.legend(framealpha=1)

# Create the output directory so the save does not fail on a fresh checkout.
FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(FIGURE_PATH, dpi=300)