In [None]:
import os

import wandb
import numpy as np
import pandas as pd

from tqdm.autonotebook import tqdm

import matplotlib.pyplot as plt
import matplotlib.colors as mpl_c
import matplotlib.gridspec as mpl_gs
import cartopy.crs as ccrs
import cartopy
import cmocean

In [None]:
# Apply the "paper" Matplotlib style sheet to every figure drawn below.
# NOTE(review): "paper" is presumably a custom sheet installed in the user's
# stylelib directory (it is not passed as a path) — `use` fails if it is missing.
plt.style.use(["paper"])

In [None]:
# wandb public-API client; the next cell uses it to query the logged runs.
# NOTE(review): no key is passed here, so authentication presumably comes from
# the local wandb configuration (e.g. `wandb login` / WANDB_API_KEY) — confirm.
api = wandb.Api()

In [None]:
# Select the three ensemble runs compared in the figure: a run matches when its
# display name equals any one of the names below ("$or" of equality filters).
runs = api.runs(
    "tobifinn/test_diffusion_nextsim_regional",
    filters={
        "$or": [
            {"display_name": display_name}
            for display_name in (
                "resdiff_l_best_loss_ensemble",
                "diffusion_best_loss_ensemble",
                "stochastic_ensemble",
            )
        ]
    },
)

In [None]:
def _load_run_table(artifacts, name_fragment, table_path):
    """Load one wandb table logged by a run as a DataFrame.

    Picks the first artifact of type "run_table" whose name contains
    `name_fragment`, fetches the table stored at `table_path` inside it and
    converts it using the table's own column names.

    Returns None when no artifact matches, so the caller can skip the run.
    """
    artifact = next(
        (
            candidate
            for candidate in artifacts
            if candidate.type == "run_table" and name_fragment in candidate.name
        ),
        None,
    )
    if artifact is None:
        return None
    table = artifact.get(table_path)
    return pd.DataFrame(table.data, columns=table.columns)


# run name -> test-score table / rank-histogram table.
results = {}
rank_hist = {}
for run in tqdm(runs):
    # Query the (remote) artifact listing once per run instead of once per table.
    artifacts = list(run.logged_artifacts())
    for target, name_fragment, table_path in (
        (results, "testscores", "test/scores.table.json"),
        (rank_hist, "testrank_hist", "test/rank_hist.table.json"),
    ):
        frame = _load_run_table(artifacts, name_fragment, table_path)
        if frame is None:
            # Still best-effort, but make the gap visible instead of failing
            # later with an unexplained KeyError in the plotting cell.
            tqdm.write(f"{run.name}: no '{name_fragment}' run_table artifact, skipped")
            continue
        target[run.name] = frame

In [None]:
# Appendix figure: spread/RMSE ratio over lead time (top row) and normalised
# rank histograms of the diffusion ensemble (bottom 2x2). Reads `results` and
# `rank_hist` from the loading cell.

# (wandb run name, legend label, extra line kwargs) — order sets legend order.
_SPREAD_MODELS = (
    ("stochastic_ensemble", "Stochastic", {"c": "#D6D683"}),
    ("diffusion_best_loss_ensemble", "Diffusion", {"c": "#A56262"}),
    ("resdiff_l_best_loss_ensemble", "ResDiff", {"c": "#9E62A6", "ls": "--"}),
)
# 30 lead times, 0.5 to 15 days in half-day steps.
_LEAD_DAYS = np.arange(1, 31) / 2
# 18 edges -> 17 bins centred on ranks 1..17 (presumably a 16-member ensemble; TODO confirm).
_RANK_EDGES = np.arange(0, 18) + 0.5
_RANK_BASELINE = -1.5
_GRID_KW = dict(ls="dotted", lw=0.5, c="0.5", alpha=0.5, zorder=1)


def _plot_spread_ratio(ax, var, panel_label):
    """Draw the spread/RMSE ratio of every model in `_SPREAD_MODELS` on `ax`.

    `var` is the suffix of the `spread_<var>` / `rmse_<var>` score columns
    (here "sit" or "siu"); `panel_label` is the "(a)"-style corner tag.
    """
    ax.grid(**_GRID_KW)
    ax.axvline(0.5, c="0.75", lw=1)
    ax.axvline(10, c="0.75", lw=1)
    for run_name, legend_label, line_kw in _SPREAD_MODELS:
        scores = results[run_name]
        ratio = scores[f"spread_{var}"] / scores[f"rmse_{var}"]
        # Drop the first row so the remaining 30 rows line up with _LEAD_DAYS.
        ax.plot(_LEAD_DAYS, ratio[1:], label=legend_label, **line_kw)
    ax.text(0.05, 0.97, panel_label, ha="left", va="top", transform=ax.transAxes, zorder=99)
    ax.set_ylim(0, 1.1)
    ax.set_xlim(0, 15.5)


def _plot_rank_hist(ax, column, panel_label, show_xticklabels, show_yticklabels):
    """Draw rank-histogram `column` of the diffusion ensemble on `ax`.

    Values are divided by their mean and shown on a log2 scale, so 0 (the
    dashed line) means a flat histogram; y tick labels show the un-logged ratio.
    All four rank panels use the same layering: grid < zero line < histogram.
    """
    ax.grid(**_GRID_KW)
    ax.axhline(0, ls="--", c="black", zorder=2)
    values = rank_hist["diffusion_best_loss_ensemble"][column]
    # NOTE(review): an empty rank bin yields log2(0) = -inf here.
    log_ratio = np.log2(values / values.mean())
    ax.stairs(log_ratio, _RANK_EDGES, baseline=_RANK_BASELINE, fc="#A5626299", fill=True, zorder=3)
    ax.stairs(log_ratio, _RANK_EDGES, baseline=_RANK_BASELINE, ec="#A56262FF", fill=False, zorder=4)
    ax.text(0.05, 0.97, panel_label, ha="left", va="top", transform=ax.transAxes, zorder=5)
    ax.set_xticks(np.arange(0, 17, 4) + 1)
    ax.set_xlim(0.55, 17.45)
    ax.set_ylim(-1.5, 1.5)
    ax.set_yticks(np.linspace(-1, 1, num=3))
    if not show_xticklabels:
        ax.set_xticklabels([])
    if show_yticklabels:
        ax.set_yticklabels(np.logspace(-1, 1, num=3, base=2))
    else:
        ax.set_yticklabels([])


fig = plt.figure(figsize=(4, 3.5))
gs = mpl_gs.GridSpec(nrows=2, ncols=2, height_ratios=(1, 2), hspace=0.35, wspace=0.1)

# Top row: spread/RMSE ratio, thickness (left) and velocity (right).
ax = fig.add_subplot(gs[0, 0])
_plot_spread_ratio(ax, "sit", "(a)")
ax.set_ylabel("Spread/RMSE", x=0.02)
ax.legend(framealpha=1., ncol=3, loc="lower center", bbox_to_anchor=(1., 1.1))
ax.set_title("Thickness", size=10, y=0.92)

ax = fig.add_subplot(gs[0, 1])
_plot_spread_ratio(ax, "siu", "(b)")
ax.set_yticklabels([])
ax.set_title("Velocity", size=10, y=0.92)

fig.supxlabel("Lead time (days)", x=0.51, y=0.6, va="top", size=10)

# Bottom 2x2: rows are the "1it"/"20it" column prefixes, columns sit/siu.
gs_rank = mpl_gs.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[1, :], wspace=0.1, hspace=0.15)

ax = fig.add_subplot(gs_rank[0, 0])
_plot_rank_hist(ax, "1it_sit", "(c)", show_xticklabels=False, show_yticklabels=True)

ax = fig.add_subplot(gs_rank[0, 1])
_plot_rank_hist(ax, "1it_siu", "(d)", show_xticklabels=False, show_yticklabels=False)

ax = fig.add_subplot(gs_rank[1, 0])
_plot_rank_hist(ax, "20it_sit", "(e)", show_xticklabels=True, show_yticklabels=True)
# y=1.05 centres the shared label between the two rank rows.
ax.set_ylabel("Normalized p(rank)", y=1.05, x=0.02)

ax = fig.add_subplot(gs_rank[1, 1])
_plot_rank_hist(ax, "20it_siu", "(f)", show_xticklabels=True, show_yticklabels=False)

fig.text(x=0.51, y=0.01, s="Rank", va="bottom", ha="center", size=10)

# Create the output directory so Restart & Run All works on a fresh checkout.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/fig_app_b5_spread.png", dpi=300)