In [None]:
import wandb
import numpy as np
import pandas as pd
import torch

from tqdm.autonotebook import tqdm

import matplotlib.pyplot as plt
import matplotlib.colors as mpl_c
import cartopy.crs as ccrs
import cartopy
import cmocean

In [None]:
# Stack the project style sheets; matplotlib applies a list first-to-last, so "wiley"
# overrides any settings it shares with "paper".
plt.style.use(["paper", "wiley"])

In [None]:
# Public wandb API client; every run / history query below goes through it.
api = wandb.Api()

In [None]:
# Deterministic baseline: a single run addressed by its "entity/project/run_id" path.
det_run = api.run("tobifinn/train_diffusion_nextsim_regional/abk97702")

In [None]:
# All runs whose display name is "diff_l_exp" (presumably resumed segments of one
# training); their score histories are concatenated into one learning curve later on.
diff_runs = api.runs(
    "tobifinn/train_diffusion_nextsim_regional",
    filters=dict(display_name="diff_l_exp"),
)

In [None]:
# All runs whose display name is "resdiff_l_exp" (presumably resumed segments of one
# training); their score histories are concatenated into one learning curve later on.
resdiff_runs = api.runs(
    "tobifinn/train_diffusion_nextsim_regional",
    filters=dict(display_name="resdiff_l_exp"),
)

In [None]:
# Per-variable standard deviations used to normalise the logged MSE scores. The order
# matches how the nRMSE computation indexes `var`: sit, sic, damage, siu, siv.
# TODO(review): record where these values come from (presumably training-set statistics).
std = np.array([0.7506, 0.1848, 0.1968, 0.0836, 0.0878])
var = np.square(std)

In [None]:
# nRMSE learning curves from the logged validation scores. For every logged step the five
# per-variable MSEs are normalised by `var`, averaged, and square-rooted:
#     nRMSE = sqrt(mean_k(MSE_k / var_k))
MSE_KEYS = [
    "scores/mse_sit",
    "scores/mse_sic",
    "scores/mse_damage",
    "scores/mse_siu",
    "scores/mse_siv",
]
STEP_KEY = "trainer/global_step"


def run_to_nrmse(run, variances, page_size=100000):
    """Return the nRMSE curve of one wandb run as a float Series indexed by global step.

    Parameters
    ----------
    run : wandb public-API run
        Run whose score history is scanned.
    variances : sequence of float
        Normalisation variances, in the same order as ``MSE_KEYS``.
    page_size : int
        Page size passed to ``run.scan_history``.

    Raises
    ------
    ValueError
        If the number of variances does not match the number of MSE keys.

    The history is fetched once and materialised, so the values and the step index are
    taken from the same rows. ``scan_history(keys=...)`` only yields rows in which every
    requested key was logged.
    """
    if len(variances) != len(MSE_KEYS):
        # zip() would silently truncate and give a wrong mean.
        raise ValueError(
            f"expected {len(MSE_KEYS)} variances, got {len(variances)}"
        )
    rows = list(run.scan_history(page_size=page_size, keys=[*MSE_KEYS, STEP_KEY]))
    values = [
        np.sqrt(np.mean([row[key] / v for key, v in zip(MSE_KEYS, variances)]))
        for row in rows
    ]
    return pd.Series(values, index=[row[STEP_KEY] for row in rows], dtype=float)


def runs_to_nrmse(runs, variances, desc=None):
    """Concatenate the nRMSE curves of several runs and sort them by global step.

    Steps logged by more than one run are kept as duplicate index entries.

    Raises
    ------
    ValueError
        If ``runs`` yields no run (e.g. the display-name filter matched nothing).
    """
    curves = [run_to_nrmse(run, variances) for run in tqdm(runs, desc=desc)]
    if not curves:
        raise ValueError(f"no wandb runs found for {desc!r}")
    # One concat at the end instead of re-concatenating inside the loop.
    return pd.concat(curves).sort_index()


det_nrmse = run_to_nrmse(det_run, var)
diff_nrmse = runs_to_nrmse(diff_runs, var, desc="diff_l_exp")
resdiff_nrmse = runs_to_nrmse(resdiff_runs, var, desc="resdiff_l_exp")

In [None]:
# NOTE(review): removes one hard-coded point (position 43 after sorting by step) from the
# ResDiffusion curve -- presumably an artefact; document which step it is and why.
# `drop` works on labels: if that step occurs more than once (overlapping runs), every row
# with that step is removed. Not idempotent: re-running this cell drops another point.
outlier_step = resdiff_nrmse.index[43]
resdiff_nrmse = resdiff_nrmse.drop(index=outlier_step)

In [None]:
# Global step stored in each model's best checkpoint; used to mark the selected point on
# the curves. Checkpoints are loaded in the same order as the names on the left.
# NOTE(review): torch.load unpickles the whole checkpoint (weights included) just to read
# one integer -- fine for these local files, never use it on untrusted ones. Newer torch
# versions default to weights_only=True, which may reject these checkpoints; verify.
it_det, it_diff, it_resdiff = (
    torch.load(ckpt_path, map_location="cpu")["global_step"]
    for ckpt_path in (
        "../data/models/deterministic/deterministic/best.ckpt",
        "../data/models/diffusion/diff_l_exp/best.ckpt",
        "../data/models/diffusion/resdiff_l_exp/best.ckpt",
    )
)

In [None]:
# nRMSE learning curves of the three models (x-axis in units of 1e5 iterations). The
# yellow marker on each curve is the logged point nearest to the step stored in that
# model's best checkpoint.
curve_specs = [
    # (curve, best-checkpoint step, colour, extra Line2D kwargs)
    (det_nrmse, it_det, "#81B3D5", dict(label="Deterministic")),
    (diff_nrmse, it_diff, "#A56262", dict(ls="-", alpha=1., label="Diffusion")),
    (resdiff_nrmse, it_resdiff, "#9E62A6", dict(alpha=1., label="ResDiffusion", ls="--")),
]

fig, ax = plt.subplots(figsize=(3, 1.5), dpi=300)
ax.grid(ls="dotted", lw=0.5)

for curve, best_step, colour, line_kwargs in curve_specs:
    ax.plot(curve.index / 1E5, curve, c=colour, **line_kwargs)
    # reindex(..., method="nearest") requires a monotonic, duplicate-free step index.
    best_value = curve.reindex(index=[best_step], method="nearest")
    ax.scatter(
        best_step / 1E5, best_value,
        fc="yellow", ec=colour, s=10, marker="o", lw=0.5, zorder=99,
    )

ax.set_ylabel("nRMSE")
ax.set_xlabel(r"Iterations $\times 10^5$")

ax.legend(framealpha=1, loc="upper right", bbox_to_anchor=(1., 1.))
ax.set_ylim(0.13, 0.18)
ax.set_yticks([0.13, 0.14, 0.15, 0.16, 0.17])
ax.set_xlim(0, 5.05)
fig.savefig("figures/fig_02_res_diff_loss.png", dpi=300)