In [None]:
from pathlib import Path

import wandb
import numpy as np
import pandas as pd

from tqdm.autonotebook import tqdm

import matplotlib.pyplot as plt
import matplotlib.colors as mpl_c
import cartopy.crs as ccrs
import cartopy
import cmocean

In [None]:
# Apply the figure style sheets in order; settings from "wiley" override
# those from "paper" where both define them.
# NOTE(review): neither name is a built-in matplotlib style, so both must be
# available in the local stylelib for this cell to run.
plt.style.use(["paper", "wiley"])


In [None]:
# Client for the W&B public API; used below to query runs and fetch their
# logged artifacts.
# NOTE(review): presumably authenticates via a prior `wandb login` or the
# WANDB_API_KEY environment variable -- confirm on a fresh machine.
api = wandb.Api()

In [None]:
# W&B project holding the experiments, and the tag that selects the runs
# whose test tables are analysed in this notebook.
WANDB_PROJECT = "tobifinn/test_diffusion_nextsim_regional"
RUN_FILTERS = {"tags": "errors"}

runs = api.runs(WANDB_PROJECT, filters=RUN_FILTERS)

In [None]:
# Collect the logged test power spectrum of every run into a DataFrame,
# keyed by run name. Runs that never logged a "testspectrum" run_table
# artifact are skipped explicitly. Unlike a blanket `except IndexError`,
# this does not hide unrelated errors.
results = {}
for run in tqdm(runs):
    # `next` stops at the first match instead of materialising every artifact.
    spectrum_artifact = next(
        (
            artifact
            for artifact in run.logged_artifacts()
            if artifact.type == "run_table" and "testspectrum" in artifact.name
        ),
        None,
    )
    if spectrum_artifact is None:
        continue
    spectrum_table = spectrum_artifact.get("test/spectrum.table.json")
    results[run.name] = pd.DataFrame(spectrum_table.data, columns=spectrum_table.columns)

# Wavelength axis $\Delta x$ (km) of the spectra

In [None]:
# Wavelength (km) of each spectral bin: domain length divided by wavenumber.
# NOTE(review): 64 and 12 are presumably the number of grid points per side
# and the grid spacing in km (a 768 km domain) -- confirm against the dataset.
N_GRID_POINTS = 64
GRID_SPACING_KM = 12
# Wavenumbers 1..32, i.e. up to N_GRID_POINTS // 2; this length must match
# the number of rows in each logged spectrum table.
wavenumbers = np.arange(1, N_GRID_POINTS // 2 + 1)
delta_x = 1 / wavenumbers * N_GRID_POINTS * GRID_SPACING_KM

In [None]:
# Figure 4: power spectral densities (rows: thickness, damage, velocity)
# after 2 and 15 days (columns). Each panel overlays the `0it_*` spectrum of
# the deterministic run (legend: neXtSIM) with the deterministic and
# ResDiffusion spectra at the column's lead time.

# (spectrum-table column suffix, y-axis label, y-limits) for each row.
SPECTRUM_ROWS = [
    ("sit", "Thickness ($m^2$)", (4E3, 3E6)),
    ("damage", "Damage ($1^2$)", (5E2, 3E5)),
    ("siu", r"Velocity ($\frac{m^2}{s^2}$)", (2.5E1, 4E4)),
]
# (spectrum-table column prefix, panel title) for each column.
LEAD_COLUMNS = [
    ("4it", "After 2 days"),
    ("30it", "After 15 days"),
]
# (run name, fixed column prefix or None for the column's lead time,
#  line colour, legend label), drawn in this order in every panel.
SPECTRUM_CURVES = [
    ("deterministic", "0it", "black", "neXtSIM"),
    ("deterministic", None, "#81B3D5", "Deterministic"),
    ("resdiff_l_best_loss_ensemble", None, "#9E62A6", "ResDiffusion"),
]

# Fail early with a clear message if a run was skipped while loading or its
# spectrum does not match the wavelength axis.
for run_name in {curve[0] for curve in SPECTRUM_CURVES}:
    assert run_name in results, f"no spectrum table loaded for run {run_name!r}"
    assert len(results[run_name]) == len(delta_x), f"spectrum length mismatch for run {run_name!r}"

fig, ax = plt.subplots(
    nrows=len(SPECTRUM_ROWS),
    ncols=len(LEAD_COLUMNS),
    dpi=90,
    sharex=True,   # only the bottom row keeps x tick labels
    sharey="row",  # only the left column keeps y tick labels
)

for row, (variable, ylabel, ylim) in enumerate(SPECTRUM_ROWS):
    for col, (lead, title) in enumerate(LEAD_COLUMNS):
        panel = ax[row, col]
        for run_name, prefix, colour, label in SPECTRUM_CURVES:
            column = f"{prefix or lead}_{variable}"
            panel.loglog(delta_x, results[run_name][column], c=colour, label=label)
        # Panels are lettered (a), (b), ... row by row.
        panel_letter = chr(ord("a") + row * len(LEAD_COLUMNS) + col)
        panel.text(0.05, 0.98, f"({panel_letter})", ha="left", va="top", transform=panel.transAxes)
        panel.grid(ls="dotted", lw=0.5, alpha=0.5)
        # Reversed limits: long wavelengths on the left, short on the right.
        panel.set_xlim(1000, 20)
        panel.set_ylim(*ylim)
        if row == 0:
            panel.set_title(title, fontsize=10, y=0.8)
    ax[row, 0].set_ylabel(ylabel)

# A single legend above the top row, centred on the left edge of panel (b).
ax[0, 1].legend(loc="lower center", ncol=3, frameon=True, framealpha=1., bbox_to_anchor=[0., 1.05])

fig.align_ylabels(list(ax[:, 0]))
fig.supylabel("Power spectral density", y=0.5, x=0)
fig.supxlabel(r"Wavelength $\Delta x$ (km)", x=0.5, y=0)

# Create the output directory so a fresh checkout can Restart & Run All.
Path("figures").mkdir(exist_ok=True)
fig.savefig("figures/fig_04_spectrum.png", dpi=300)