In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

In [None]:
# Apply the "paper" style first, then layer "wiley" on top of it. Matplotlib applies a
# list of styles from first to last, so later entries override earlier ones exactly as
# two sequential `plt.style.use` calls would.
# NOTE(review): both names are presumably custom .mplstyle files in the user's stylelib
# (not shipped with matplotlib) — confirm they are installed before a fresh run.
plt.style.use(["paper", "wiley"])

In [None]:
# W&B public-API client, used to query the runs of the
# "tobifinn/test_diffusion_nextsim_regional" project.
# NOTE(review): credentials presumably come from the local W&B login / environment —
# confirm they are configured before running this notebook on a fresh machine.
api = wandb.Api()

In [None]:
# All runs in the project carrying the "errors" tag; their test-score tables are compared below.
runs = api.runs("tobifinn/test_diffusion_nextsim_regional", filters=dict(tags="errors"))

In [None]:
# Per-variable normalisation scales for the RMSE scores, in the order
# sit, sic, damage, siu, siv — this order must match the RMSE column order used when
# computing the normalised RMSE.
# NOTE(review): presumably the per-variable standard deviations of the training data — confirm source.
std = np.array((0.7506, 0.1848, 0.1968, 0.0836, 0.0878))

In [None]:
# Normalised RMSE (nRMSE) per lead time for every fetched run.
# For each row of a run's test score table: divide each variable's RMSE by its scale in
# `std`, square, average over the five variables, then take the square root.
# Result: `results` has one column per run name, indexed by the score-table row index.
SCORE_COLUMNS = ["rmse_sit", "rmse_sic", "rmse_damage", "rmse_siu", "rmse_siv"]
# `std` is ordered to match SCORE_COLUMNS; fail loudly if the two drift apart.
assert len(SCORE_COLUMNS) == len(std), "std must have exactly one entry per score column"

nrmse_per_run = {}
skipped_runs = []
for run in runs:
    table_artifact = next(
        (artifact for artifact in run.logged_artifacts() if artifact.type == "run_table"),
        None,
    )
    if table_artifact is None:
        # Run has no logged score table — skip it explicitly instead of swallowing an IndexError.
        skipped_runs.append(run.name)
        continue
    score_table = table_artifact.get("test/scores.table.json")
    scores = pd.DataFrame(score_table.data, columns=score_table.columns)
    nrmse_per_run[run.name] = np.sqrt(
        ((scores[SCORE_COLUMNS] ** 2) / std ** 2).mean(axis=1)
    )

# Build the frame once. Unlike assigning columns into an empty DataFrame, this takes the
# union of all runs' indices, so runs with more lead times than the first are not truncated.
results = pd.DataFrame(nrmse_per_run)

if skipped_runs:
    print(f"Skipped {len(skipped_runs)} run(s) without a run_table artifact: {skipped_runs}")

In [None]:
# Figure 3: nRMSE as a function of lead time for the four model variants.
# The score-table index counts output steps; dividing by STEPS_PER_DAY converts it to days.
# NOTE(review): the original divides by 2, i.e. it assumes 12-hourly output — confirm.
STEPS_PER_DAY = 2
FIGURE_PATH = Path("figures/fig_03_nrmse.png")

# (column in `results`, legend label, line style, colour) — one entry per plotted model.
PLOT_SPECS = [
    ("deterministic", "Deterministic", "-", "#81B3D5"),
    ("stochastic_ensemble", "Stochastic", "--", "#D6D683"),
    ("diffusion_best_loss_ensemble", "Diffusion", "-", "#A56262"),
    ("resdiff_l_best_loss_ensemble", "ResDiffusion", "--", "#9E62A6"),
]

fig, ax = plt.subplots(figsize=(3.5, 2), dpi=150)
ax.grid(ls="dotted", lw=0.5, c="0.5", alpha=0.5, zorder=1)

lead_time_days = results.index / STEPS_PER_DAY
for column, label, linestyle, colour in PLOT_SPECS:
    ax.plot(lead_time_days, results[column], ls=linestyle, c=colour, alpha=1., label=label)

ax.set_xlim(0, 15)
ax.set_xlabel("Lead time (days)")

ax.set_ylim(0, 0.55)
ax.set_ylabel("nRMSE")

ax.legend()

# Create the output directory so a fresh checkout does not fail with FileNotFoundError.
FIGURE_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(FIGURE_PATH, dpi=300)