In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

In [None]:
# Apply the "paper" style first and then layer "wiley" on top; styles later in
# the list override settings of earlier ones. Neither is a built-in matplotlib
# style — presumably custom sheets installed in the user's stylelib (verify).
plt.style.use(["paper", "wiley"])

In [None]:
# W&B public API client, used below to query runs and their logged artifacts.
api = wandb.Api()

In [None]:
# Fetch every run in the W&B project whose display name is one of the two
# configurations compared in this notebook.
runs = api.runs(
    "tobifinn/test_diffusion_nextsim_regional",
    filters={
        "$or": [
            {"display_name": run_name}
            for run_name in ("deterministic", "det_wo_damage")
        ]
    },
)

In [None]:
# Five per-variable standard deviations.
# NOTE(review): not referenced in the cells shown here; presumably the
# normalisation scales of the model variables — confirm before relying on it.
std = np.asarray((0.7506, 0.1848, 0.1968, 0.0836, 0.0878), dtype=float)

In [None]:
# Collect the test RMSE table of every matched run, keyed by run name.
SCORE_COLUMNS = ["rmse_sit", "rmse_sic", "rmse_siu", "rmse_siv"]


def _load_test_scores(run):
    """Return the run's test scores as a DataFrame, or None if unavailable.

    Reads the ``test/scores.table.json`` entry of the first logged artifact of
    type ``run_table`` and keeps only the SCORE_COLUMNS.

    Returns None when the run has no ``run_table`` artifact, or when the
    artifact's ``.get`` yields None for the table entry (guarded defensively).
    """
    artifact = next(
        (art for art in run.logged_artifacts() if art.type == "run_table"),
        None,
    )
    if artifact is None:
        return None
    table = artifact.get("test/scores.table.json")
    if table is None:
        return None
    return pd.DataFrame(table.data, columns=table.columns)[SCORE_COLUMNS]


results = {}
for run in runs:
    scores = _load_test_scores(run)
    if scores is None:
        # Runs without a test table are still skipped, but no longer silently,
        # and unrelated errors are no longer masked by a broad try/except.
        print(f"Skipping run {run.name!r}: no test score table found")
        continue
    results[run.name] = scores

In [None]:
# Compare the RMSE over lead time of the deterministic model and the variant
# without damage: SIT in the top panel (a), SIU in the bottom panel (b).
STEPS_PER_DAY = 2  # the step index is divided by 2 to give lead time in days
PANELS = (
    # (score column, panel tag, y-axis label, y-axis upper limit)
    ("rmse_sit", "(a)", "RMSE SIT (m)", 0.35),
    ("rmse_siu", "(b)", "RMSE SIU (m/s)", 0.035),
)
CURVES = (
    # (run name in `results`, legend label, line style)
    ("deterministic", "Deterministic", dict(ls="-", c="#81B3D5")),
    ("det_wo_damage", "W/o damage", dict(ls="--", c="black")),
)

missing_runs = [name for name, _, _ in CURVES if name not in results]
if missing_runs:
    raise KeyError(f"No test scores loaded for run(s): {missing_runs}")

# sharex links the x-limits and hides the tick labels of the upper panel.
fig, ax = plt.subplots(nrows=2, figsize=(4, 3), dpi=150, sharex=True)

for axis, (column, tag, ylabel, ymax) in zip(ax, PANELS):
    axis.grid(True)
    for run_name, label, style in CURVES:
        scores = results[run_name]
        # Each run is plotted against its OWN index; previously the
        # "det_wo_damage" curve reused the "deterministic" index, which
        # misaligns or fails when the two runs differ in length.
        axis.plot(
            scores.index / STEPS_PER_DAY,
            scores[column],
            alpha=1.,
            label=label,
            **style,
        )
    axis.text(0.02, 0.98, tag, ha="left", va="top", transform=axis.transAxes)
    axis.set_ylim(0, ymax)
    axis.set_ylabel(ylabel)

ax[0].set_xlim(0, 15)
ax[1].set_xlabel("Lead time (days)")
ax[0].legend()

fig.subplots_adjust(hspace=0.1)
fig.align_ylabels(ax)
# savefig does not create missing directories.
Path("figures").mkdir(exist_ok=True)
fig.savefig("figures/fig_app_b2_damage_rmse.png", dpi=300)