In [None]:
import dask
import numpy as np
import pandas as pd
import xarray as xr
from dask.distributed import LocalCluster, Client
from pandas.plotting import table

In [None]:
# Start a local Dask cluster with 12 single-threaded workers, /tmp as the
# workers' local directory and the dashboard served on port 8789, then attach
# a client to it.
cluster = LocalCluster(
    n_workers=12,
    threads_per_worker=1,
    local_directory="/tmp",
    dashboard_address=":8789",
)
client = Client(cluster)
# Bare expression: the notebook renders the client summary as the cell output.
client

# Baseline

In [None]:
# Lead-time label passed to `.sel(lead_time=...)` when the train and test
# datasets are opened, so every comparison uses the same forecast lead time.
# NOTE(review): presumably matched against a timedelta-like coordinate — confirm.
lead_time = "10 min 8s"

## Train data

In [None]:
# Open the two training stores with the whole time axis in one chunk and one
# ensemble member per chunk, then keep only the configured lead time.
train_nature = (
    xr.open_zarr(
        "../../data/raw/train/lr_nature_forecast/",
        chunks={"time": -1, "ensemble": 1},
    )
    .sel(lead_time=lead_time)
)
train_forecast = (
    xr.open_zarr(
        "../../data/raw/train/lr_forecast/",
        chunks={"time": -1, "ensemble": 1},
    )
    .sel(lead_time=lead_time)
)

## Test data

In [None]:
# Test-set reference data, opened with the same chunking as the training
# stores and reduced to the configured lead time.
nature_data = (
    xr.open_zarr(
        "../../data/raw/test/lr_nature_forecast/",
        chunks={"time": -1, "ensemble": 1},
    )
    .sel(lead_time=lead_time)
)

## Prediction

In [None]:
# Root directory of the processed experiment outputs and the number of seed
# sub-directories (0 .. N_SEEDS - 1) stored beneath each experiment.
PREDICTION_ROOT = "../../data/processed"
N_SEEDS = 10

# Result label -> experiment directory below PREDICTION_ROOT. Insertion order
# is preserved, so it fixes the order of the entries in `predictions`.
EXPERIMENT_DIRS = {
    "initial_only": "input_initial",
    "forecast_only": "input_forecast",
    "initial_forecast": "unext_small",
    "without_forcing": "input_woforcing",
    "difference_only": "input_only_difference",
    "initial_difference": "input_difference",
    "forecast_difference": "input_fcst_difference",
}


def load_seed_predictions(experiment_dir, n_seeds=N_SEEDS):
    """Open the offline predictions of one experiment for every seed.

    Parameters
    ----------
    experiment_dir : str
        Directory name below ``PREDICTION_ROOT``.
    n_seeds : int, optional
        Number of seed sub-directories (``0`` .. ``n_seeds - 1``) to open.

    Returns
    -------
    xarray.Dataset
        The per-seed stores concatenated along a new ``"seed"`` dimension,
        each opened with the whole ``time`` axis in a single chunk.
    """
    per_seed = [
        xr.open_zarr(
            f"{PREDICTION_ROOT}/{experiment_dir}/{seed:d}/prediction_offline",
            chunks={"time": -1},
        )
        for seed in range(n_seeds)
    ]
    return xr.concat(per_seed, dim="seed")


# Same keys, order and paths as the previous hand-written dictionary.
predictions = {
    label: load_seed_predictions(experiment_dir)
    for label, experiment_dir in EXPERIMENT_DIRS.items()
}

## Estimate train errors for normalization

In [None]:
# Training error: forecast minus the nature data, both already restricted to
# the selected lead time when they were opened.
train_err = train_forecast-train_nature

In [None]:
# Location and scale of the training error used for normalization:
#   * the 0.5 quantile (median), taken after merging the ensemble axis into a
#     single chunk,
#   * the mean absolute deviation from that median over the listed dimensions.
train_median = (
    train_err
    .chunk({"ensemble": -1})
    .quantile(0.5)
)
train_scale = np.abs(train_err - train_median).mean(
    ["ensemble", "time", "nMesh2_face", "nMesh2_node"]
)

In [None]:
# Evaluate both normalization statistics in a single pass. `train_scale` is
# built on top of `train_median`, so two separate `.compute()` calls would
# evaluate the median graph (and re-read the training error) twice;
# `dask.compute` merges the graphs and shares the common tasks.
norm_median, norm_scale = dask.compute(train_median, train_scale)

In [None]:
# Mean of the training nature data with no dimension argument, evaluated
# eagerly via `.compute()`.
# NOTE(review): `train_mean` is not referenced by any other visible cell —
# confirm it is still needed before paying for this computation.
train_mean = train_nature.mean().compute()

# Estimate normalized prediction errors

In [None]:
# Per-experiment prediction error against the test data, divided by the
# training-error scale so the variables become comparable.
pred_error = {
    experiment: (prediction - nature_data) / norm_scale
    for experiment, prediction in predictions.items()
}

# Mean absolute error (MAE)

In [None]:
def get_mae(error):
    """Return the mean absolute value of ``error``.

    The absolute value is taken element-wise with ``np.abs``; the mean is
    whatever ``.mean()`` of the resulting object yields when called without
    arguments.
    """
    magnitude = np.abs(error)
    return magnitude.mean()

In [None]:
# One column per experiment; rows are the data variables, indexed by the
# "var_names" dimension created when the reduced dataset is turned into an array.
mae_results = pd.DataFrame(
    {
        experiment: (
            get_mae(experiment_error)
            .to_array("var_names")
            .to_pandas()
        )
        for experiment, experiment_error in pred_error.items()
    }
)

In [None]:
# Average the MAE over the rows of `mae_results` for each experiment column,
# and label the resulting series "mean" so it becomes a column of that name
# when combined.
mae_mean = mae_results.mean().rename("mean")

# Combine per-variable MAE with the mean across variables

In [None]:
# Experiments as rows: the transposed MAE table, followed by the "mean"
# series appended as an extra column.
combined_results = pd.concat(
    [mae_results.T, mae_mean],
    axis="columns",
)

In [None]:
# Round to two decimals for presentation, then render the selected variables
# plus the mean column as a LaTeX table (returned as the cell output).
combined_results = combined_results.round(2)
combined_results.loc[:, ["area", "damage", "stress_yy", "v", "mean"]].to_latex()