In [None]:
import os

import cmocean
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from dask.distributed import Client, LocalCluster
from pandas.plotting import table

import src_screening.model.accessor

In [None]:
# Cap the worker count at the number of CPUs actually available: starting 48
# single-threaded workers on a smaller machine only oversubscribes it.
N_WORKERS = min(48, os.cpu_count() or 1)
cluster = LocalCluster(n_workers=N_WORKERS, threads_per_worker=1, local_directory="/tmp")
client = Client(cluster)
client

# Baseline

In [None]:
# Lead times evaluated throughout this notebook; every dataset below is
# restricted to them via `.sel(lead_time=lead_times)`.
lead_times = [
    "1 hour",
]

## Train data

In [None]:
# One chunk per ensemble member; the full time and lead-time axes stay whole.
zarr_chunks = {"time": -1, "ensemble": 1, "lead_time": -1}
train_nature = xr.open_zarr(
    "../../data/raw/train/lr_nature_forecast/", chunks=zarr_chunks
).sel(lead_time=lead_times)
train_forecast = xr.open_zarr(
    "../../data/raw/train/lr_forecast/", chunks=zarr_chunks
).sel(lead_time=lead_times)

## Test data

In [None]:
# One chunk per ensemble member; the full time and lead-time axes stay whole.
zarr_chunks = {"time": -1, "ensemble": 1, "lead_time": -1}
nature_data = xr.open_zarr(
    "../../data/raw/test/lr_nature_forecast/", chunks=zarr_chunks
).sel(lead_time=lead_times)
# The nature forecast at lead time 0s serves as the persistence baseline.
persist_data = xr.open_zarr(
    "../../data/raw/test/lr_nature_forecast/", chunks=zarr_chunks
).sel(lead_time="0s")
forecast_data = xr.open_zarr(
    "../../data/raw/test/lr_forecast/", chunks=zarr_chunks
).sel(lead_time=lead_times)

## Predictions

In [None]:
# The stores matched by each glob are concatenated along the "seed" dimension.
prediction_paths = {
    "U-NeXt (×1)": "../../data/processed/unext_small/*/traj_short",
    "Initial+Difference": "../../data/processed/input_difference/*/traj_short",
}
predictions = {
    label: xr.open_mfdataset(
        glob_path, concat_dim="seed", combine="nested", engine="zarr", parallel=True,
        chunks={"time": -1, "ensemble": 1, "lead_time": -1, "seed": 5},
    )
    for label, glob_path in prediction_paths.items()
}

In [None]:
def _drop_duplicate_lead_times(dataset):
    """Return ``dataset`` keeping only the last entry of each repeated lead time."""
    keep_mask = ~dataset.indexes["lead_time"].duplicated(keep="last")
    return dataset.sel(lead_time=keep_mask)


predictions = {
    label: _drop_duplicate_lead_times(dataset)
    for label, dataset in predictions.items()
}

In [None]:
# Restrict every prediction to the evaluated lead times.
predictions = {
    label: dataset.sel(lead_time=lead_times)
    for label, dataset in predictions.items()
}

## Estimate train errors for normalization

In [None]:
# Forecast minus nature run on the training set; its spread sets the error scale.
train_errors = train_forecast - train_nature

In [None]:
# Sample standard deviation of the training errors, reduced over ensemble,
# time and both mesh dimensions. It is evaluated eagerly, then bounded below
# because it is used as a divisor when normalizing the errors further down.
norm_std = (
    train_errors
    .std(["ensemble", "time", "nMesh2_face", "nMesh2_node"], ddof=1)
    .compute()
    .clip(min=1E-9)
)

# Estimate normalized test errors

In [None]:
# Raw (not yet normalized) test-set errors against the nature run.
fcst_error = forecast_data - nature_data
persist_error = persist_data - nature_data

In [None]:
# Baseline errors, scaled by the training error spread.
all_errors = dict(
    persistence=persist_error / norm_std,
    forecast=fcst_error / norm_std,
)

In [None]:
# Prediction errors use the same normalization as the baselines.
pred_error = {
    label: (prediction - nature_data) / norm_std
    for label, prediction in predictions.items()
}
# New merged dict: baseline entries first, then one entry per prediction.
all_errors = {**all_errors, **pred_error}

# Analyse output

In [None]:
def get_rmse(error):
    """Root-mean-square of ``error`` over its sample dimensions.

    The mean of the squared error is taken over ``ensemble``, ``time``,
    ``nMesh2_face`` and ``nMesh2_node``; when ``error`` also carries a
    ``seed`` dimension, that dimension is averaged over as well (before the
    square root, i.e. the squared errors of all seeds are pooled).
    """
    reduce_dims = ["ensemble", "time", "nMesh2_face", "nMesh2_node"]
    if "seed" in error.dims:
        reduce_dims = ["seed"] + reduce_dims
    mean_square = (error**2).mean(reduce_dims)
    return np.sqrt(mean_square)

In [None]:
def get_mean_rmse(error):
    """Variable-averaged RMSE of ``error``.

    First the mean squared error of each data variable is taken over
    ``ensemble``, ``time``, ``nMesh2_face``, ``nMesh2_node`` (and ``seed`` if
    present). These per-variable MSEs are then stacked along a ``var_names``
    dimension and averaged, and the square root is taken last. The result is
    therefore the root of the mean MSE, not the mean of per-variable RMSEs.
    """
    reduce_dims = ["ensemble", "time", "nMesh2_face", "nMesh2_node"]
    if "seed" in error.dims:
        reduce_dims.insert(0, "seed")
    mse_per_var = (error**2).mean(reduce_dims)
    mean_mse = mse_per_var.to_array("var_names").mean("var_names")
    return np.sqrt(mean_mse)

In [None]:
# One column per error source; rows are indexed by (lead_time, var_names).
rmse_results = pd.DataFrame({
    source: (
        get_rmse(err)
        .to_array("var_names")
        .stack(stacked=["lead_time", "var_names"])
        .to_pandas()
    )
    for source, err in all_errors.items()
})

In [None]:
# Per-variable RMSE table, rounded for display.
rmse_results.round(decimals=2)

In [None]:
# Variable-averaged RMSE per error source, indexed by lead time.
total_rmse = pd.DataFrame({
    source: get_mean_rmse(err).to_pandas()
    for source, err in all_errors.items()
})

In [None]:
# Label the variable-averaged score as a "mean" row so it can be concatenated
# below the per-variable RMSEs in `rmse_results`.
# - The lead times are taken from the existing index instead of a hard-coded
#   1-hour Timedelta, so this also works when `lead_times` has several entries.
# - Both levels get the same names as `rmse_results`; otherwise `pd.concat`
#   may drop the level names and a later lookup of "var_names" fails.
# - Reading level 0 keeps the cell idempotent when it is re-run on an index
#   that is already a MultiIndex.
lead_time_values = total_rmse.index.get_level_values(0)
total_rmse.index = pd.MultiIndex.from_arrays(
    [lead_time_values, ["mean"] * len(lead_time_values)],
    names=["lead_time", "var_names"],
)

In [None]:
# Display the variable-averaged RMSE per error source.
total_rmse

In [None]:
# Stack the per-variable rows and the overall "mean" rows into one table.
combined_result = pd.concat([rmse_results, total_rmse])

In [None]:
# Keep only the variable label for display. The innermost level is selected
# by position and renamed explicitly, because `pd.concat` can drop level names
# when its inputs are not named identically; a lookup by the name "var_names"
# would then raise. Position -1 also works on a plain Index, so re-running
# this cell is harmless.
combined_result.index = combined_result.index.get_level_values(-1).rename("var_names")

In [None]:
# Error sources as rows, variables (plus "mean") as columns.
combined_result.round(decimals=2).transpose()

In [None]:
# LaTeX source of the variable-averaged RMSE table (sources as rows).
total_rmse.transpose().round(decimals=2).to_latex()