In [2]:
from pathlib import Path
import json
from functools import partial 

import xarray as xr
import numpy as np
import pandas as pd
import seaborn as sns 
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter


# `snakemake` is not defined in this notebook; it must already exist in the
# kernel namespace when this cell runs.
SNAKEMAKE = snakemake

# Workflow handles used by the cells below.
inputs, outputs, config = SNAKEMAKE.input, SNAKEMAKE.output, SNAKEMAKE.config
plt_cfg = config["plotting"]

# Wildcards selecting which data partition and which experiment to evaluate.
partition, experiment = SNAKEMAKE.wildcards.partition, SNAKEMAKE.wildcards.experiment


# Apply the project-wide matplotlib settings from the plotting configuration.
plt.rcParams.update(plt_cfg["rcparams"])


# Names of the target variables.
TASKS = [
    "air_temperature",
    "dew_point_temperature",
    "surface_air_pressure",
    "relative_humidity",
    "water_vapor_mixing_ratio",
]

In [3]:
# The first declared output is used as the directory that receives the figures.
out_dir = Path(outputs[0])
out_dir.mkdir(parents=True, exist_ok=True)

# Echo what is being evaluated so the rendered notebook documents itself.
print(f"Partition: {partition} \n")
print("Experiment configuration: \n")
exp_config = config["experiments"][experiment]
exp_config_json = json.dumps(exp_config, indent=4)
print(exp_config_json)

approaches = exp_config["approaches"]

In [4]:
def preprocess_predictions(ds, reftimes=None):
    """Tag one prediction dataset with its approach, split and seed.

    The three values are parsed from the directory names above the file the
    dataset was opened from (``ds.encoding["source"]``): the approach is the
    name of the third ancestor directory, the seed and split are the integers
    following the ``~`` in the names of the fourth and fifth ancestors.

    Parameters
    ----------
    ds
        Dataset whose ``encoding["source"]`` holds the path it was read from.
    reftimes
        Optional values to reindex ``forecast_reference_time`` to. When given,
        the reindexed dataset is also loaded before being returned.

    Returns
    -------
    The dataset with ``approach``, ``split`` and ``seed`` added as new
    length-one dimensions (in that order).
    """
    ancestors = Path(ds.encoding["source"]).parents
    coords = {
        "approach": ancestors[2].name,
        "split": int(ancestors[4].name.split("~")[1]),
        "seed": int(ancestors[3].name.split("~")[1]),
    }
    tagged = ds.assign_coords(coords).expand_dims(list(coords))
    if reftimes is None:
        return tagged
    return tagged.reindex(forecast_reference_time=reftimes).load()

def ds_to_df(ds, name):
    """Convert a dataset into a long-format data frame.

    Coordinates that are not indexes are dropped, the data variables are
    stacked along a new ``"variable"`` dimension, and the resulting array is
    turned into a data frame whose value column is called ``name``. The index
    is reset so that every dimension becomes an ordinary column.
    """
    stacked = ds.reset_coords(drop=True).to_array("variable")
    frame = stacked.to_dataframe(name)
    return frame.reset_index()

def unstack(ds):
    """Unstack the ``s`` coordinate into reference time, lead time and station.

    A MultiIndex is built from the ``forecast_reference_time``, ``t`` and
    ``station`` entries of ``ds``, attached as coordinate ``s`` (after dropping
    the non-index coordinates) and then unstacked.

    NOTE(review): presumably ``ds`` arrives with a flat sample dimension named
    ``s`` -- confirm against the file produced upstream.
    """
    sample_dims = ["forecast_reference_time", "t", "station"]
    level_values = [ds[level].values for level in sample_dims]
    sample_index = pd.MultiIndex.from_arrays(level_values, names=sample_dims)
    flat = ds.reset_coords(drop=True)
    return flat.assign_coords(s=sample_index).unstack("s")

def remove_source_prefix(ds):
    """Strip the leading ``<source>:`` prefix from every data variable name.

    Everything up to and including the first colon is removed, so
    ``"obs:air_temperature"`` becomes ``"air_temperature"``. Variables whose
    name contains no colon are left untouched instead of raising.

    Parameters
    ----------
    ds
        Dataset-like object exposing ``data_vars`` (string names) and
        ``rename(mapping)``.

    Returns
    -------
    The renamed dataset, or ``ds`` itself when there is nothing to rename.
    """
    renames = {}
    for var in ds.data_vars:
        _, sep, name = var.partition(":")
        if sep:  # only names that actually carry a prefix are renamed
            renames[var] = name
    # One rename call with the full mapping avoids rebuilding the dataset
    # once per variable.
    return ds.rename(renames) if renames else ds

In [5]:
# Fail early with a clear message instead of a NameError further down.
if partition not in ("train", "val", "test"):
    raise ValueError(f"Unsupported partition: {partition!r}")

# Observations are needed for every partition, so load them once.
obs = remove_source_prefix(unstack(xr.load_dataset(inputs["y"])))

if partition in ("train", "val"):
    # Each prediction file is reindexed to the observed reference times while
    # it is opened; reference times left without any prediction are dropped.
    reftimes = obs.forecast_reference_time
    pred = xr.open_mfdataset(
        inputs["predictions"],
        preprocess=partial(preprocess_predictions, reftimes=reftimes),
        parallel=True,
    )
    # `how` is passed by keyword: newer xarray releases make it keyword-only.
    pred = pred.dropna("forecast_reference_time", how="all")
else:
    pred = xr.open_mfdataset(
        inputs["predictions"],
        preprocess=preprocess_predictions,
        parallel=True,
    )

# Align observations with the predictions, then bring both into memory and
# re-chunk them before computing the error.
obs = obs.reindex_like(pred).load().chunk("auto")
pred = pred.load().chunk("auto")
obs, pred = xr.broadcast(obs, pred)
err = pred - obs

In [6]:
# Mean absolute error per variable, averaged over reference time, lead time and station.
reduce_dims = ["var"]
mae = abs(err).mean(dim=["forecast_reference_time", "t", "station"])
# Normalise by the standard deviation of the observations, then average the
# normalised errors across variables.
nmae = (mae / obs.std()).to_array(dim="var").mean(dim=reduce_dims).compute()

In [7]:
# 50 %, 90 %, 99 % and 99.9 % quantiles of the observed air temperature
# (no `dim` is passed to `quantile`).
air_temp_quantiles = obs["air_temperature"].quantile(q=[0.5, 0.9, 0.99, 0.999])
q50, q90, q99, q999 = air_temp_quantiles.compute().values

In [8]:
# One entry per temperature threshold, in the order q50, q90, q99, q999.
res_nmae = []
res_mae = []
res_bias = []

sample_dims = ["forecast_reference_time", "t", "station"]
reduce_dims = ["var"]

# Loop-invariant pieces: the absolute error and the normalisation constant do
# not depend on the threshold, so the standard deviation is evaluated once
# instead of once per `.compute()` call inside the loop.
abs_err = abs(err)
obs_std = obs.std().compute()

for q in (q50, q90, q99, q999):
    # Keep only the samples whose observed air temperature exceeds the threshold.
    above_threshold = obs.air_temperature > q
    bias = err.where(above_threshold).mean(sample_dims)
    mae = abs_err.where(above_threshold).mean(sample_dims)
    # Normalised MAE, averaged across variables.
    nmae = (mae / obs_std).to_array("var").mean(reduce_dims).compute()
    res_nmae.append(nmae)
    res_mae.append(mae)
    res_bias.append(bias)

In [9]:
# Stack the per-threshold results along a new "quantile" dimension labelled
# with the quantile levels.
quantile_levels = [0.5, 0.9, 0.99, 0.999]
nmae_, mae_, bias_ = (
    xr.concat(parts, dim="quantile").assign_coords(quantile=quantile_levels)
    for parts in (res_nmae, res_mae, res_bias)
)

In [17]:
fig, ax = plt.subplots(1, figsize=(6, 5))

# Long-format table of the aggregated NMAE with approach identifiers replaced
# by their display names.
approach_labels = plt_cfg["approach_names"]
df = nmae_.to_dataframe(name="Aggregated NMAE").reset_index()
df = df.replace(approach_labels).rename(columns={"approach": "Approach"})

# One box per approach and quantile level.
sns.boxplot(
    data=df,
    x="quantile",
    y="Aggregated NMAE",
    hue="Approach",
    hue_order=list(approach_labels.values()),
    palette=plt_cfg["approach_colors"],
    showfliers=False,
    ax=ax,
)

ax.set_xlabel("Quantile of $T$")

# The legend is anchored in figure coordinates along the top edge, spread over
# two columns.
lgd = ax.legend(
    bbox_to_anchor=(0.13, 0.885, 0.78, 0.1),
    loc="lower left",
    ncol=2,
    mode="expand",
    borderaxespad=0.,
    frameon=False,
    fontsize=11,
    bbox_transform=fig.transFigure,
)

fig.savefig(out_dir / "NMAE_vs_quantile.png")