In [None]:
from pathlib import Path

import wandb
import numpy as np
import pandas as pd
import torch

from tqdm.autonotebook import tqdm

import matplotlib.pyplot as plt
import matplotlib.colors as mpl_c
import cartopy.crs as ccrs
import cartopy
import cmocean

In [None]:
# Compose the "paper" and "wiley" style sheets; entries later in the list
# override settings from earlier ones, matching two sequential use() calls.
plt.style.use(["paper", "wiley"])

In [None]:
# W&B public-API client; used below to query the runs of the training project.
api = wandb.Api()

In [None]:
# Display names of the runs compared in the validation-loss figure below.
RUN_NAMES = ("deterministic", "det_no_aug", "det_no_labels")

# Select any run whose display name matches one of RUN_NAMES.
diff_runs = api.runs(
    "tobifinn/train_diffusion_nextsim_regional",
    filters={"$or": [{"display_name": run_name} for run_name in RUN_NAMES]},
)

In [None]:
def fetch_val_loss(run, page_size=100000):
    """Return a run's logged validation loss as a Series indexed by global step.

    Only the ``val/loss`` and ``trainer/global_step`` history keys are
    requested. The history scan is materialised into a list exactly once, so
    the values and the index come from the same single fetch.

    Parameters
    ----------
    run : wandb public-API run
        Run whose history is scanned.
    page_size : int
        Page size forwarded to ``run.scan_history``.

    Returns
    -------
    pd.Series
        Validation loss values, indexed by ``trainer/global_step`` and named
        after ``run.name`` so it becomes the column label when concatenated.
    """
    history = list(
        run.scan_history(page_size=page_size, keys=["val/loss", "trainer/global_step"])
    )
    return pd.Series(
        [row["val/loss"] for row in history],
        index=[row["trainer/global_step"] for row in history],
        name=run.name,
    )


# One column per run, outer-aligned on global step (steps a run did not log are NaN).
# Concatenate once instead of growing a DataFrame inside the loop.
val_loss_series = [fetch_val_loss(run) for run in tqdm(diff_runs)]
results = (
    pd.concat(val_loss_series, axis=1).sort_index()
    if val_loss_series
    else pd.DataFrame()  # pd.concat([]) raises; keep the original empty result
)

In [None]:
# (results column / run name, legend label, line color) for each plotted curve.
LOSS_CURVES = [
    ("deterministic", "With labels", "#81B3D5"),
    ("det_no_labels", "W/o labels", "#83D6C1"),
    ("det_no_aug", "W/o augmentation", "#8583D6"),
]
# Global steps are divided by this factor to match the "x 10^5" axis label.
STEP_SCALE = 100000
LOSS_FIG_PATH = Path("figures/fig_app_b3_augment_loss.png")

fig, ax = plt.subplots(figsize=(4, 2.5))
ax.grid(which="both")
for column, label, color in LOSS_CURVES:
    # Drop steps where this run has no validation value (NaN from the outer join).
    val_loss = results[column].dropna()
    ax.plot(val_loss.index / STEP_SCALE, val_loss, label=label, c=color)
ax.set_xlim(0, 1)
ax.set_ylim(0.145, 0.199)
ax.legend()
ax.set_ylabel("Validation loss")
ax.set_xlabel(r"Iterations $\times 10^5$")

# savefig does not create missing directories; ensure the output folder exists.
LOSS_FIG_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(LOSS_FIG_PATH)