In [None]:
import os
os.environ["OMP_NUM_THREADS"] = "1"

import xarray as xr
import numpy as np
import pandas as pd

from tqdm.notebook import tqdm

from itertools import product
from distributed import LocalCluster, Client

import matplotlib.pyplot as plt

In [None]:
# Project "paper" style first, then the EGU journal style layered on top of it.
plt.style.use(["paper", "egu_journals"])

In [None]:
# 32 single-threaded worker processes (consistent with OMP_NUM_THREADS=1 set above);
# workers use /tmp as their local directory.
local_cluster = LocalCluster(
    n_workers=32,
    threads_per_worker=1,
    local_directory="/tmp",
)
client = Client(address=local_cluster)
client

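# Estimate normaliser

The normaliser is the per-variable RMSE between the nature run and the forecast on the training set at a lead time of 10 min 8 s. All test errors below are divided by it.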

In [None]:
# Nature run and forecast on the training set at a single lead time; the difference
# between the two defines the error normaliser.
NORM_LEAD_TIME = "10 min 8s"

nature_train_data = (
    xr.open_zarr("../../data/raw/train/nature_data/")
    .sel(lead_time=NORM_LEAD_TIME)
    .chunk({"samples": 50})
)
forecast_train_data = (
    xr.open_zarr("../../data/raw/train/forecast_data/")
    .sel(lead_time=NORM_LEAD_TIME)
    .chunk({"samples": 50})
)

In [None]:
# Training-set error, defined as nature minus forecast.
train_err = (nature_train_data
             - forecast_train_data)

In [None]:
# Normaliser: per-variable RMSE of the training error, averaged over all dimensions.
norm_rmse = (train_err ** 2).mean().pipe(np.sqrt).compute()

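# Load test data (nature, forecast and hybrid-model trajectories)

The nature run and the forecast model are loaded from the raw test data. The two hybrid models (`unext_small` and `input_difference`) are stacked along a `seed` dimension, and duplicated lead times are dropped, keeping the last occurrence.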

In [None]:
# Chunking for the test datasets: one ensemble member per chunk, with the full time
# axis and the full mesh (faces and nodes) in each chunk.
TEST_CHUNKS = {"time": -1, "ensemble": 1, "nMesh2_face": -1, "nMesh2_node": -1}

ds_nature = xr.open_zarr("../../data/raw/test/lr_nature_forecast/", chunks=TEST_CHUNKS)
ds_forecast = xr.open_zarr("../../data/raw/test/lr_forecast/", chunks=TEST_CHUNKS)

In [None]:
# Trajectories of the "unext_small" hybrid model: every directory matched by the glob is
# stacked along a new "seed" dimension; repeated lead times keep their last occurrence.
ds_nn_sota = xr.open_mfdataset(
    "../../data/processed/unext_small/*/traj_short/",
    engine="zarr",
    combine="nested",
    concat_dim="seed",
    parallel=True,
    chunks={"time": -1, "ensemble": 1, "nMesh2_face": -1, "nMesh2_node": -1},
).pipe(
    lambda ds: ds.sel(lead_time=~ds.indexes["lead_time"].duplicated(keep="last"))
)

In [None]:
# Trajectories of the "input_difference" hybrid model: every directory matched by the glob
# is stacked along a new "seed" dimension; repeated lead times keep their last occurrence.
ds_nn_other = xr.open_mfdataset(
    "../../data/processed/input_difference/*/traj_short/",
    engine="zarr",
    combine="nested",
    concat_dim="seed",
    parallel=True,
    chunks={"time": -1, "ensemble": 1, "nMesh2_face": -1, "nMesh2_node": -1},
).pipe(
    lambda ds: ds.sel(lead_time=~ds.indexes["lead_time"].duplicated(keep="last"))
)

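# Estimate scores

RMSE per variable and lead time, averaged over time, ensemble members and mesh elements, and additionally over seeds for the hybrid models. The persistence baseline uses the nature state at the first lead time as the prediction for every lead time.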

In [None]:
# RMSE per lead time: squared errors averaged over time, ensemble members and mesh elements.
score_dims = ["time", "ensemble", "nMesh2_node", "nMesh2_face"]

squared_err_fcst = (ds_forecast - ds_nature) ** 2
rmse_fcst = np.sqrt(squared_err_fcst.mean(score_dims)).compute()

# Persistence: the nature state at the first lead time serves as prediction for all lead times.
squared_err_persist = (ds_nature.isel(lead_time=0) - ds_nature) ** 2
rmse_persist = np.sqrt(squared_err_persist.mean(score_dims)).compute()

In [None]:
# Hybrid models: same RMSE as the baselines, additionally averaged over the "seed" dimension.
nn_score_dims = ["time", "seed", "ensemble", "nMesh2_node", "nMesh2_face"]

rmse_nn_sota = ((ds_nn_sota - ds_nature) ** 2).mean(nn_score_dims).pipe(np.sqrt).compute()
rmse_nn_other = ((ds_nn_other - ds_nature) ** 2).mean(nn_score_dims).pipe(np.sqrt).compute()

In [None]:
def error_subplot(ax, var_name, factor=1.):
    """Plot the normalised RMSE of one variable against lead time for all four predictors.

    Each curve is ``rmse[var_name] / norm_rmse[var_name]``, with lead times converted
    to minutes. The RMSE results are read from the notebook-level variables
    ``rmse_fcst``, ``rmse_persist``, ``rmse_nn_sota``, ``rmse_nn_other`` and
    ``norm_rmse``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    var_name : str
        Name of the variable to plot.
    factor : float, optional
        Currently unused; kept so that existing calls remain valid.

    Returns
    -------
    list
        Line handles in legend order: persistence, forecast, Input+Forecast,
        Input+Difference.
    """
    one_minute = pd.to_timedelta("1min")
    # Drawing order matters: later curves are drawn on top of earlier ones.
    curves = {
        "forecast": (rmse_fcst, dict(c="black", label="Forecast", ls="-")),
        "persistence": (rmse_persist, dict(c="0.5", label="Persistence", ls="-.")),
        "sota": (rmse_nn_sota, dict(c="firebrick", label="Input+Forecast", ls="--")),
        "other": (rmse_nn_other, dict(c="C0", label="Input+Difference", ls="--")),
    }
    lines = {}
    for key, (rmse, style) in curves.items():
        # ax.plot returns a list with a single Line2D for a single x/y pair.
        lines[key], = ax.plot(
            rmse.indexes["lead_time"] / one_minute,
            rmse[var_name] / norm_rmse[var_name],
            **style,
        )
    # The legend order differs from the drawing order: persistence comes first.
    return [lines["persistence"], lines["forecast"], lines["sota"], lines["other"]]

In [None]:
# Multiples of 10 min 8 s up to 50 min 40 s, converted to minutes.
# NOTE(review): not used by the visible plotting cell; presumably intended as minor
# x-ticks. Confirm or remove.
_lead_labels = ["10 min 8s", "20 min 16s", "30 min 24s", "40 min 32s", "50 min 40s"]
x_minor_ticks = pd.to_timedelta(_lead_labels) / pd.Timedelta(minutes=1)

In [None]:
# One row per variable: (panel label, variable name, y-axis upper limit).
panels = [
    ("(a) Velocity", "v", 1.9),
    (r"(b) $\sigma_{yy}$", "stress_yy", 1.3),
    ("(c) Damage", "damage", 2.5),
    ("(d) Area", "area", 14),
]
lead_ticks = np.arange(0, 61, 10)  # major x-ticks every 10 min

fig, ax = plt.subplots(nrows=len(panels))

for row, (axis, (panel_label, var_name, y_max)) in enumerate(zip(ax, panels)):
    # Every panel draws the same four curves; the handles of the last panel feed the legend.
    handles = error_subplot(axis, var_name)
    axis.grid(which='major', alpha=0.5, linestyle="dotted", lw=0.5)
    axis.text(0.01, 0.99, s=panel_label, ha="left", va="top", transform=axis.transAxes, fontsize=9)
    axis.set_ylim(0, y_max)
    axis.set_xlim(0, 61)
    axis.set_xticks(lead_ticks)
    if row < len(panels) - 1:
        # Upper panels keep their ticks but hide the labels; only the bottom row is labelled.
        axis.set_xticklabels([""] * len(lead_ticks))
ax[-1].set_xlabel("Lead time in min")

fig.subplots_adjust(hspace=0.1)

# Legend above the top panel; labels follow the handle order returned by error_subplot.
ax[0].legend(
    handles=handles,
    labels=["Persistence", "Forecast model", "Hybrid \"Initial+Forecast\"", "Hybrid \"Initial+Difference\""],
    ncol=2,
    bbox_to_anchor=(0.5, 0.9),
    loc="lower center",
)
fig.supylabel("Normalised RMSE", x=0.085, fontsize=9)

# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/fig08_short_term_error.pdf", dpi=300, bbox_inches='tight', pad_inches=0)