In [None]:
import dask
import numpy as np
import pandas as pd
import xarray as xr
from dask.distributed import LocalCluster, Client
from pandas.plotting import table


In [None]:
# Local Dask cluster: 12 workers with one thread each, local_directory set to
# /tmp, dashboard address ":8789".
cluster = LocalCluster(
    n_workers=12,
    threads_per_worker=1,
    local_directory="/tmp",
    dashboard_address=":8789",
)
client = Client(cluster)
# Last expression of the cell: shows the client summary in the notebook.
client

# Baseline

In [None]:
# Lead time passed to `.sel(lead_time=...)` when the raw train and test
# datasets are opened below.
lead_time = "10 min 8s"

## Train data

In [None]:
# Open the raw training stores with dask chunks (whole time axis in one chunk,
# one ensemble member per chunk) and keep only the selected lead time.
train_nature = xr.open_zarr(
    "../../data/raw/train/lr_nature_forecast/",
    chunks={"time": -1, "ensemble": 1},
).sel(lead_time=lead_time)
train_forecast = xr.open_zarr(
    "../../data/raw/train/lr_forecast/",
    chunks={"time": -1, "ensemble": 1},
).sel(lead_time=lead_time)

## Test data

In [None]:
# Reference ("nature") data of the test set at the selected lead time, opened
# with the same chunking as the training stores.
nature_data = xr.open_zarr(
    "../../data/raw/test/lr_nature_forecast/",
    chunks={"time": -1, "ensemble": 1},
).sel(lead_time=lead_time)

## Prediction

In [None]:
def load_seed_ensemble(experiment, n_seeds=10):
    """Open the offline predictions of every seed of one experiment.

    Parameters
    ----------
    experiment : str
        Directory name below ``../../data/processed/``.
    n_seeds : int, optional
        Number of seed sub-directories (``0`` .. ``n_seeds - 1``) to open.

    Returns
    -------
    xr.Dataset
        The per-seed stores concatenated along the ``"seed"`` dimension.
    """
    per_seed = [
        xr.open_zarr(
            f"../../data/processed/{experiment}/{seed:d}/prediction_offline",
            chunks={"time": -1},
        )
        for seed in range(n_seeds)
    ]
    return xr.concat(per_seed, dim="seed")


# Result label -> experiment directory below ../../data/processed/.
# NOTE(review): "laplace_nll" is read from "unext_small", the only entry whose
# directory differs from its label -- confirm this mapping is intended.
prediction_dirs = {
    "gaussian_fixed": "gaussian_fixed",
    "gaussian_nll": "gaussian_nll",
    "laplace_fixed": "laplace_fixed",
    "laplace_nll": "unext_small",
}

predictions = {
    label: load_seed_ensemble(experiment)
    for label, experiment in prediction_dirs.items()
}

## Estimate train errors for normalization

In [None]:
# Training error (forecast minus nature data); its statistics are used as
# normalization constants for the metrics below.
train_err = train_forecast-train_nature

In [None]:
# Dimensions pooled when estimating the normalization constants.
_norm_reduce_dims = ["ensemble", "time", "nMesh2_face", "nMesh2_node"]

# Sample standard deviation (ddof=1) of the training error.
train_std = train_err.std(_norm_reduce_dims, ddof=1)

# Median of the training error over all dimensions; "ensemble" is merged into
# a single chunk before the quantile is taken.
# NOTE(review): whether the remaining dimensions are single-chunk depends on
# the store's chunking -- confirm if the quantile complains about chunks.
train_median = train_err.chunk({"ensemble": -1}).quantile(0.5)

# Mean absolute deviation of the training error from its median.
train_scale = np.abs(train_err - train_median).mean(_norm_reduce_dims)

In [None]:
# Evaluate both normalization constants in one dask.compute call so that the
# tasks they share (loading the training data and forming train_err) run once
# instead of once per .compute() call.
norm_std, norm_scale = dask.compute(train_std, train_scale)

# Estimate general errors

In [None]:
# Prediction-minus-truth error on the test set, one entry per model label.
pred_error = {
    label: prediction - nature_data
    for label, prediction in predictions.items()
}

# Metric definitions (MAE, RMSE, correlation)

In [None]:
def get_mae(error):
    """Return the normalized mean absolute error as a single value.

    The error is divided by the module-level ``norm_scale``, the absolute
    value is averaged over all dimensions of each variable, and the
    per-variable means are then averaged along ``"var_names"``.
    """
    per_variable_mae = np.abs(error / norm_scale).mean()
    return per_variable_mae.to_array("var_names").mean()

In [None]:
def get_rmse(error):
    """Return the normalized root-mean-squared error as a single value.

    The error is divided by the module-level ``norm_std`` and squared; the
    squares are averaged over all dimensions of each variable, then across
    variables along ``"var_names"``, and the square root is taken last.
    """
    squared_normalized = (error / norm_std) ** 2
    mean_square = squared_normalized.mean().to_array("var_names").mean()
    return np.sqrt(mean_square)

In [None]:
def get_corr(prediction, truth):
    """Return the average spatial correlation between prediction and truth.

    For each variable the correlation is computed over the grid dimensions
    ``nMesh2_face`` / ``nMesh2_node``. The correlations are averaged in
    Fisher-z space (``arctanh`` -> mean -> ``tanh``), first over the
    remaining dimensions of each variable and then across variables.

    NOTE(review): ``arctanh`` is non-finite for |corr| >= 1; the small
    epsilon added to the standard deviations is the only guard here.
    """
    grid_dims = ["nMesh2_face", "nMesh2_node"]
    eps = 1E-9  # keeps the division finite for spatially constant fields

    # Anomalies with respect to the spatial mean.
    pred_anom = prediction - prediction.mean(grid_dims)
    truth_anom = truth - truth.mean(grid_dims)

    # Number of grid points per field, used for the unbiased covariance.
    n_points = xr.ones_like(pred_anom).sum(grid_dims)
    covariance = (pred_anom * truth_anom).sum(grid_dims) / (n_points - 1)

    pred_sigma = pred_anom.std(grid_dims, ddof=1) + eps
    truth_sigma = truth_anom.std(grid_dims, ddof=1) + eps
    corr = covariance / pred_sigma / truth_sigma

    mean_z = np.arctanh(corr).mean().to_array("var_names").mean()
    return np.tanh(mean_z)

# Evaluate MAE, RMSE, and correlation

In [None]:
# Normalized MAE per model label; float() evaluates each result to a scalar.
mae_results = pd.Series(
    {label: float(get_mae(err)) for label, err in pred_error.items()},
    name="mae",
)

In [None]:
# Normalized RMSE per model label; float() evaluates each result to a scalar.
rmse_results = pd.Series(
    {label: float(get_rmse(err)) for label, err in pred_error.items()},
    name="rmse",
)

In [None]:
# Average spatial correlation with the test reference data per model label.
corr_results = pd.Series(
    {
        label: float(get_corr(prediction, nature_data))
        for label, prediction in predictions.items()
    },
    name="corr",
)

# Combine

In [None]:
# One row per model label, one column per metric (column names come from the
# Series names).
combined_results = pd.concat(
    [rmse_results, mae_results, corr_results],
    axis=1,
)

In [None]:
# Display the unrounded metrics table.
combined_results

In [None]:
# LaTeX source of the table, rounded to two decimals.
# NOTE(review): as a bare expression this shows the string repr (escaped
# newlines); wrap in print() if the output is meant to be copy-pasted.
combined_results.round(2).to_latex()