In [None]:
from pathlib import Path
import json
from functools import partial 

import xarray as xr
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import norm 
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from itertools import combinations

# `snakemake` is not defined anywhere in this notebook; it is presumably injected
# into the kernel namespace by Snakemake's notebook integration -- TODO confirm.
SNAKEMAKE = snakemake
# Shorthands for the rule's inputs, outputs and workflow configuration.
inputs = SNAKEMAKE.input
outputs = SNAKEMAKE.output
config = SNAKEMAKE.config
# Plotting-related settings (rcparams, approach/variable names, colors, units).
plt_cfg = config["plotting"]
# Data partition being evaluated; later cells branch on "train", "val" and "test".
partition = SNAKEMAKE.wildcards.partition

# Apply the project-wide matplotlib style before any figure is created.
plt.rcParams.update(plt_cfg["rcparams"])

# Target variables evaluated throughout the notebook. The order determines the
# panel order of every multi-panel figure below.
TASKS = [
    "air_temperature",
    "dew_point_temperature",
    "surface_air_pressure",
    "relative_humidity",
    "water_vapor_mixing_ratio",
]


In [None]:
# The first rule output is used as a directory that collects every figure and
# table written by this notebook; create it if it does not exist yet.
out_dir = Path(outputs[0])
out_dir.mkdir(parents=True, exist_ok=True)
print(f"Partition: {partition} \n")
print("Experiment configuration: \n")
# Only the "default" experiment is evaluated here; its "approaches" entry is
# read by the scatter plots and the significance tests further down.
exp_config = config["experiments"]["default"]
print(json.dumps(exp_config, indent=4))

# Data import

In [None]:
def preprocess_predictions(ds, reftimes=None):
    """
    Tag a prediction dataset with the experiment settings encoded in its path.

    The source file path (``ds.encoding["source"]``) is inspected: the name of
    the third parent directory is used as ``approach``, and the integers after
    the ``~`` in the fourth and fifth parent directory names are used as
    ``split`` and ``seed`` respectively. Each of them becomes a new length-one
    dimension of the dataset.

    Parameters:
    -----------
    ds: xr.Dataset
        Dataset opened from a single prediction file.
    reftimes: optional
        If given, the dataset is reindexed to these forecast reference times
        and loaded into memory.

    Returns:
    --------
    The dataset with the additional ``approach``, ``split`` and ``seed``
    dimensions.
    """
    ancestors = Path(ds.encoding["source"]).parents
    coords = {
        "approach": ancestors[2].name,
        "split": int(ancestors[3].name.split("~")[1]),
        "seed": int(ancestors[4].name.split("~")[1]),
    }
    tagged = ds.assign_coords(coords).expand_dims(list(coords))
    if reftimes is None:
        return tagged
    return tagged.reindex(forecast_reference_time=reftimes).load()


def process_nwp(ds, parameters):
    """
    Select the NWP baseline variables from the feature dataset.

    Every data variable is renamed by dropping its ``<source>:`` prefix (the
    part up to the first colon, if any) and a trailing ``_ensavg`` suffix.
    Variables that carry neither are left untouched instead of raising.

    Parameters:
    -----------
    ds: xr.Dataset
        Feature dataset whose variables are named like ``source:name_ensavg``.
    parameters: list of str
        Names (after renaming) of the variables to keep.

    Returns:
    --------
    The dataset restricted to ``parameters``.
    """
    renames = {}
    for var in ds.data_vars:
        name = var.split(":", 1)[-1].removesuffix("_ensavg")
        if name != var:
            renames[var] = name
    # A single rename call avoids rebuilding the dataset once per variable.
    return ds.rename(renames)[parameters]

def ds_to_df(ds, name):
    """
    Flatten a dataset into a long-format table.

    Non-index coordinates are dropped, the data variables are stacked along a
    new ``"variable"`` dimension, and the resulting array is converted to a
    dataframe whose value column is called ``name``. The index is reset so
    that every dimension becomes a regular column.
    """
    stacked = ds.reset_coords(drop=True).to_array("variable")
    frame = stacked.to_dataframe(name)
    return frame.reset_index()

def unstack(ds):
    """
    Turn a flat sample axis into separate reference-time, lead-time and
    station dimensions.

    A MultiIndex is built from the ``forecast_reference_time``, ``t`` and
    ``station`` values of the dataset, assigned as coordinate ``s`` and then
    unstacked. All other coordinates are dropped beforehand.
    """
    sample_dims = ["forecast_reference_time", "t", "station"]
    levels = [ds[d].values for d in sample_dims]
    sample_index = pd.MultiIndex.from_arrays(levels, names=sample_dims)
    flat = ds.reset_coords(drop=True)
    return flat.assign_coords(s=sample_index).unstack("s")

def remove_source_prefix(ds):
    """
    Strip the ``<source>:`` prefix from every data variable name.

    Only the part up to the first colon is removed. Variables without a colon
    are left unchanged instead of raising, and all renames are applied in a
    single call.
    """
    renames = {}
    for var in ds.data_vars:
        name = var.split(":", 1)[-1]
        if name != var:
            renames[var] = name
    return ds.rename(renames)


# Fail early with a clear message instead of a NameError on `pred`/`obs` below.
if partition not in ("train", "val", "test"):
    raise ValueError(f"Unknown partition {partition!r}; expected 'train', 'val' or 'test'.")

# Observations are loaded the same way for every partition.
obs = remove_source_prefix(unstack(xr.load_dataset(inputs["y"])))

if partition in ("train", "val"):
    # Reindex every prediction file to the reference times of the observations
    # while it is opened, then drop reference times without any prediction.
    reftimes = obs.forecast_reference_time
    pred = xr.open_mfdataset(
        inputs["predictions"],
        preprocess=partial(preprocess_predictions, reftimes=reftimes),
        parallel=True,
    )
    # `how` is keyword-only in current xarray releases.
    pred = pred.dropna("forecast_reference_time", how="all")
else:
    pred = xr.open_mfdataset(inputs["predictions"], preprocess=preprocess_predictions, parallel=True)

# Put observations on the prediction grid and give both the same chunking.
obs = obs.reindex_like(pred).load().chunk({"forecast_reference_time": 200})
pred = pred.squeeze().load().chunk({"forecast_reference_time": 200})
obs = obs.squeeze()
# After broadcasting, `obs` also carries the approach/split/seed dimensions of `pred`.
obs, pred = xr.broadcast(obs, pred)


reftimes = obs.forecast_reference_time.values
stations = obs.station.values
# NOTE(review): hard-coded absolute, user-specific path. It should be declared
# as a Snakemake input or a config entry so the workflow runs elsewhere.
features = xr.open_zarr("/scratch/fzanetta/pcpp-workflow/data/01_raw/features.zarr").sel(station=stations, forecast_reference_time=reftimes)
# The raw NWP forecasts of the target variables serve as the reference forecast.
nwp = process_nwp(features, TASKS)
nwp = nwp.compute().chunk("auto").persist()



# Plotting and tables functions

In [None]:
def boxplots(ds, cfg, figsize=(6,6), ylim_mul = (1., 1.), yticks_inside=False, approaches=None, return_legend=False):
    """
    Draw one box plot panel per task, with one box per approach.

    Parameters:
    -----------
    ds: xr.Dataset
        Metric values, one data variable per task, with an ``approach``
        dimension. It is flattened with ``ds_to_df``.
    cfg: dict
        Plotting configuration; the keys ``approach_names``,
        ``var_short_names`` and ``approach_colors`` are read.
    figsize: tuple
        Figure size passed to ``plt.subplots``.
    ylim_mul: tuple
        Multipliers applied to the automatic lower and upper y-limits.
    yticks_inside: bool
        If True, only two y-ticks are kept and they are drawn inside the axes.
    approaches: optional
        Keys of ``cfg["approach_names"]`` to include; all of them if None.
    return_legend: bool
        If True, the legend is returned as a third element.

    Returns:
    --------
    ``(fig, axs)`` or ``(fig, axs, lgd)``.
    """
    approach_names = cfg["approach_names"]
    short_names = cfg["var_short_names"]

    # Long-format table with display names for approaches and variables.
    df = ds_to_df(ds, "metric")
    df = df.replace(approach_names)
    df = df.replace(short_names)
    df = df.rename(columns={"approach":"Approach"})

    if approaches is None:
        names = approach_names.values()
    else:
        names = [label for key, label in approach_names.items() if key in approaches]

    fig, axs = plt.subplots(1, len(TASKS), figsize=figsize, sharey=False)

    # One panel per task; each panel only sees the rows of its own variable.
    for ax, task in zip(axs, TASKS):
        short = short_names[task]
        sns.boxplot(
            data=df.query(f"variable=='{short}'"),
            x="variable",
            y="metric",
            hue="Approach",
            hue_order=names,
            palette=cfg["approach_colors"],
            ax=ax,
            showfliers=False
        )

    for ax, task in zip(axs, TASKS):
        ax.legend().remove()
        ax.set_ylabel("")
        ax.set_xlabel("")
        ymin, ymax = ax.get_ylim()
        new_ylims = (ymin * ylim_mul[0], ymax * ylim_mul[1])
        if task == "relative_humidity":
            # Relative humidity gets an additional 10% stretch on both ends.
            new_ylims = (new_ylims[0] * 0.9, new_ylims[1] * 1.1)
        ax.set_ylim(*new_ylims)
        if not yticks_inside:
            continue
        # Two ticks, placed between the automatic and the rescaled limits.
        ymin_ = round(ymin * 0.4 + new_ylims[0] * 0.6, 2)
        ymax_ = round(ymax * 0.4 + new_ylims[1] * 0.6, 2)
        ax.tick_params(axis="y",direction="in", pad=-40)
        ax.set_yticks([ymin_, ymax_], labels=[ymin_, ymax_])

    # A single legend for the whole figure, positioned in figure coordinates.
    lgd = axs[0].legend(
        bbox_to_anchor=(0.075, 0.87, 0.88, 0.1),
        loc="lower left",
        ncol=2,
        mode="expand",
        borderaxespad=0.,
        frameon=False,
        bbox_transform=plt.gcf().transFigure
    )

    plt.subplots_adjust(wspace=1.5, hspace=0)
    return (fig, axs, lgd) if return_legend else (fig, axs)

def latex_table(results, cfg, scores=["mae", "msss"], labels=["MAE", "MSSS"]):
    """
    Render two scores per variable and approach as a LaTeX ``table*``.

    Parameters:
    -----------
    results: mapping
        ``results[score][variable][approach]`` must yield a number for both
        entries of ``scores`` (e.g. a dict of dataframes).
    cfg: dict
        Plotting configuration; ``var_short_names`` provides the column group
        headers and ``approach_names`` the row labels.
    scores: list of str
        Keys of the two scores in ``results``.
    labels: list of str
        Column labels of the two scores.

    Returns:
    --------
    out: str
        The LaTeX source of the table.
    """
    first, second = scores[0], scores[1]
    var_symbols = list(cfg["var_short_names"].values())
    n_vars = len(var_symbols)
    indent = " " * 8
    row_end = "\\\\ \n"  # LaTeX row terminator followed by a newline

    # preamble: one label column plus two groups of n_vars columns
    col_spec = "l " + "l" * n_vars + " " + "l" * n_vars
    out = "\\begin{table*}\n"
    out += "    \\caption{This is a table. Add something here.}\n"
    out += "    \\renewcommand{\\arraystretch}{1.1} \n"
    out += "    \\begin{tabular*}{\\hsize}{@{\\extracolsep\\fill}" + col_spec + "@{}}\n"

    # header: variable names spanning two columns each, then the score labels
    out += indent + "\\topline\n"
    out += indent + "&" + " & ".join(f"\\multicolumn{{2}}{{c}}{{{v}}}" for v in var_symbols) + row_end
    out += indent + " ".join(f"\\cmidrule(lr){{{2 + 2 * i}-{3 + 2 * i}}}" for i in range(n_vars)) + " \n"
    out += indent + "& " + " & ".join([f"{labels[0]}", f"{labels[1]}"] * n_vars) + " " + row_end

    # body: one row per approach
    for approach, name in cfg["approach_names"].items():
        cells = [
            f"{results[first][var][approach]:.3f} & {results[second][var][approach]:.3f}"
            for var in results[first].keys()
        ]
        out += indent + f"{name} & " + " & ".join(cells) + " " + row_end
    out += indent + row_end

    # postamble
    out += "   \\end{tabular*}\n"
    out += "\\end{table*}"
    return out

def latex_table_counts(results, cfg):
    """
    Render one value per variable and approach as a LaTeX ``table*``.

    Parameters:
    -----------
    results: mapping
        ``results[variable][approach]`` is written verbatim into the table.
    cfg: dict
        Plotting configuration; ``var_short_names`` provides the column
        headers and ``approach_names`` the row labels.

    Returns:
    --------
    out: str
        The LaTeX source of the table.
    """
    var_symbols = list(cfg["var_short_names"].values())
    indent = " " * 8
    row_end = "\\\\ \n"  # LaTeX row terminator followed by a newline

    # preamble: one label column plus one column per variable
    col_spec = "l " + "l" * len(var_symbols)
    out = "\\begin{table*}\n"
    out += "    \\caption{This is a table. Add something here.}\n"
    out += "    \\renewcommand{\\arraystretch}{1.1} \n"
    out += "    \\begin{tabular*}{\\hsize}{@{\\extracolsep\\fill}" + col_spec + "@{}}\n"

    # header
    out += indent + "\\topline\n"
    out += indent + "& " + " & ".join(f"{v}" for v in var_symbols) + row_end
    out += indent + "\\midline\n"

    # body: one row per approach
    for approach, name in cfg["approach_names"].items():
        cells = [f"{results[var][approach]}" for var in results.keys()]
        out += indent + f"{name} & " + " & ".join(cells) + " " + row_end
    out += indent + row_end

    # postamble
    out += "   \\end{tabular*}\n"
    out += "\\end{table*}"
    return out

# Metrics 

In [None]:
def mean_absolute_error(pred, obs, reduce_dims):
    """Mean of the absolute difference between `pred` and `obs` over `reduce_dims`."""
    abs_err = abs(pred - obs)
    return abs_err.mean(reduce_dims)

def mean_squared_error(pred, obs, reduce_dims):
    """Mean of the squared difference between `pred` and `obs` over `reduce_dims`."""
    sq_err = (pred - obs) ** 2
    return sq_err.mean(reduce_dims)

def mean_squared_skill_score(pred, ref, obs, reduce_dims):
    """
    Skill of `pred` relative to the reference forecast `ref`.

    Computed as ``1 - MSE(pred) / MSE(ref)``, with both mean squared errors
    taken against `obs` over `reduce_dims`.
    """
    sq_err_pred = (pred - obs) ** 2
    sq_err_ref = (ref - obs) ** 2
    ratio = sq_err_pred.mean(reduce_dims) / sq_err_ref.mean(reduce_dims)
    return 1 - ratio

def r_squared(pred, obs, reduce_dims):
    """
    Coefficient of determination of `pred` with respect to `obs`.

    Computed as ``1 - SS_res / SS_tot``, where both sums of squares run over
    `reduce_dims` and ``SS_tot`` is taken around the mean of `obs` over the
    same dimensions.
    """
    deviation = obs.mean(reduce_dims) - obs
    ss_tot = (deviation ** 2).sum(reduce_dims)
    ss_res = ((pred - obs) ** 2).sum(reduce_dims)
    return 1 - ss_res / ss_tot




# Boxplots

These plots show the performance metrics averaged over the forecast reference time, station and leadtime. The distributions represent the variability due to cross-validation split and random seed.

In the case of the MSSS and the R$^2$, we first compute the metrics for each station individually and take the average afterwards. This removes the variability related to each station's climatological mean. If we did not do this, the scores would be very high (close to 1), but the evaluation would not be fair.

### Mean Absolute Error

In [None]:
# MAE pooled over reference time, station and lead time; the remaining
# dimensions (approach, split, seed) provide the spread shown by each box.
mae = mean_absolute_error(pred, obs, ["forecast_reference_time", "station","t"])
fig, axs = boxplots(mae, plt_cfg, figsize=(6,5), ylim_mul=(0.97, 1.03), yticks_inside=True)
axs[0].set_ylabel("Mean absolute error")
# tight_layout first, then tighten the horizontal spacing between panels.
plt.tight_layout()
plt.subplots_adjust(wspace=0.1)
plt.savefig(out_dir / "mae_boxplots.png")

### Mean Squared Skill Score

In [None]:
# MSSS relative to the raw NWP forecast, computed per station (over reference
# time and lead time) and only then averaged over stations.
msss = mean_squared_skill_score(pred, nwp, obs, ["forecast_reference_time", "t"]).mean("station")
fig, axs = boxplots(msss, plt_cfg, figsize=(6,5), ylim_mul=(0.9, 1.1), yticks_inside=True)
axs[0].set_ylabel("Mean Squared Skill Score")
# tight_layout first, then tighten the horizontal spacing between panels.
plt.tight_layout()
plt.subplots_adjust(wspace=0.1)
plt.savefig(out_dir / "msss_boxplots.png")

### Coefficient of determination

In [None]:
# R^2 computed per station (over reference time and lead time) and only then
# averaged over stations.
r2 = r_squared(pred, obs, ["forecast_reference_time", "t"]).mean("station")
fig, axs = boxplots(r2, plt_cfg, figsize=(6,5), ylim_mul=(0.99, 1.015), yticks_inside=True)
axs[0].set_ylabel("Coefficient of determination")
# tight_layout first, then tighten the horizontal spacing between panels.
plt.tight_layout()
plt.subplots_adjust(wspace=0.1)
plt.savefig(out_dir / "r_squared_boxplots.png")

# Scatter plots

Here we plot the NWP baseline performance metric on the x-axis and the postprocessed predictions on the y-axis.

In [None]:
def station_metrics_scatterplot(err_pp, err_nwp):
    """
    Scatter the score of every postprocessing approach against the NWP score.

    One panel is drawn per task on a 2x3 grid; the unused last panel is
    removed. Each panel has identical x- and y-limits, an aspect ratio of 1
    and a dashed diagonal, so points below the diagonal have a lower score
    than the NWP baseline.

    Parameters:
    -----------
    err_pp: xr.Dataset
        Scores of the postprocessed forecasts, with an ``approach`` dimension.
    err_nwp: xr.Dataset
        Scores of the NWP baseline.

    Returns:
    --------
    ``(fig, axs)`` where ``axs`` is the 2x3 array of axes.
    """
    fig, axs = plt.subplots(2,3, figsize=(15,10), layout="constrained")
    panels = axs.ravel()[:-1]

    for var, ax in zip(TASKS, panels):
        for i, approach in enumerate(exp_config["approaches"]):
            ax.scatter(
                err_nwp[var].values,
                err_pp[var].sel(approach=approach).values,
                c=plt_cfg["approach_colors"][i],
                s=12,
                alpha=0.7,
                label=plt_cfg["approach_names"][approach]
            )

        # Common limits for both axes, taken from the automatic ones.
        x_lo, x_hi = ax.get_xlim()
        y_lo, y_hi = ax.get_ylim()
        lim = (min(x_lo, y_lo), max(x_hi, y_hi))

        # Diagonal through the centre of the panel.
        centre = np.mean(lim)
        ax.axline((centre, centre), slope=1, linestyle="--", c="k", linewidth=1)
        ax.set(
            ylim=lim, xlim=lim,
            title=plt_cfg["var_long_names"][var],
            aspect=1
        )

    # Every panel has the same entries, so the legend is built from the last
    # panel drawn and placed at figure level.
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, bbox_to_anchor=(0.95, 0.3))
    axs[1,2].remove()

    return fig, axs

### Mean Absolute Error

In [None]:
# Per-station MAE. The NWP score is additionally averaged over `approach`,
# leaving a single NWP value per station for the x-axis.
mae_nwp = mean_absolute_error(nwp, obs, ["forecast_reference_time","seed","split","t"]).mean(["approach"])
mae_pp = mean_absolute_error(pred, obs, ["forecast_reference_time","seed","split","t"])

fig, axs = station_metrics_scatterplot(mae_pp, mae_nwp)
# zip stops after the last task, so the removed sixth axis is never touched.
for var, ax in zip(TASKS, axs.ravel()):
    ax.set(
        ylabel=f"MAE [{plt_cfg['var_units'][var]}] of postprocessing", 
        xlabel=f"MAE [{plt_cfg['var_units'][var]}] of NWP",
    )
    
plt.savefig(out_dir / "mae_scatterplot.png")

### Mean squared error

In [None]:
# Per-station MSE. The NWP score is additionally averaged over `approach`,
# leaving a single NWP value per station for the x-axis.
# Stored under `mse_*` names so the MAE results are not overwritten.
mse_nwp = mean_squared_error(nwp, obs, ["forecast_reference_time","seed","split","t"]).mean(["approach"])
mse_pp = mean_squared_error(pred, obs, ["forecast_reference_time","seed","split","t"])

fig, axs = station_metrics_scatterplot(mse_pp, mse_nwp)
# zip stops after the last task, so the removed sixth axis is never touched.
for var, ax in zip(TASKS, axs.ravel()):
    ax.set(
        ylabel=f"MSE [{plt_cfg['var_units'][var]}]$^2$ of postprocessing",
        xlabel=f"MSE [{plt_cfg['var_units'][var]}]$^2$ of NWP",
    )

plt.savefig(out_dir / "mse_scatterplot.png")

# Overall results & tables

In [None]:
# Overall scores per approach and variable: each metric is computed per station
# (pooling reference time, seed, split and lead time), averaged over stations
# and written to a JSON file in the output directory.

# mae
mae = mean_absolute_error(pred, obs, ["forecast_reference_time","seed", "split","t"]).mean("station")
mae.to_dataframe().to_json(out_dir / "mae_results.json", indent=4)

    
# msss
# the raw NWP forecast is the reference of the skill score
msss = mean_squared_skill_score(pred, nwp, obs, ["forecast_reference_time","seed", "split", "t"]).mean("station")
msss.to_dataframe().to_json(out_dir / "msss_results.json", indent=4)


# r squared
r2 = r_squared(pred, obs, ["forecast_reference_time","seed", "split", "t"]).mean("station")
r2.to_dataframe().to_json(out_dir / "r_squared_results.json", indent=4)



### Main results table

In [None]:
# Combine the overall MAE and MSSS into one LaTeX table and save its source.
out = latex_table({"mae":mae.to_dataframe(), "msss":msss.to_dataframe()}, plt_cfg)
with open(out_dir / "latex_table_results.txt", "w") as f:
    f.write(out)

## Significance testing
We perform tests of equal predictive performance using the Diebold-Mariano test, and apply the Benjamini-Hochberg procedure to control the false discovery rate in multiple testing. We apply tests between each pair of models, for every combination of leadtime and station. The loss vectors are first averaged over the cross-validation split and the random seeds. 
We assume independence between forecast errors.

In [None]:
def diebold_mariano_test(l1, l2):
    """
    Diebold-Mariano (DM) test, assuming independence between
    loss values (as in Schulz and Lerch (2022)).

    The statistic is the mean loss differential ``l1 - l2`` divided by its
    standard error, using the population variance (``ddof=0``). A negative
    statistic means the first forecast has the lower mean loss. The p-value is
    two-sided under a standard normal null distribution.

    Parameters:
    -----------
    l1: np.ndarray
        Loss of the first forecast.
    l2: np.ndarray
        Loss of the second forecast.

    Returns:
    -------
    dm_stat: the DM test statistic
    p_value: the p-value

    Edge cases: ``(nan, nan)`` for empty input; ``(0.0, 1.0)`` when the two
    loss series are identical (no evidence of a difference).
    """
    D = np.asarray(l1, dtype=float) - np.asarray(l2, dtype=float)
    n = D.shape[0]
    if n == 0:
        return np.nan, np.nan
    D_mean = np.mean(D)
    se = np.sqrt(np.var(D) / n)
    if se == 0 and D_mean == 0:
        # Identical losses: 0/0 would otherwise produce NaN.
        return 0.0, 1.0
    # A constant non-zero differential legitimately gives an infinite statistic.
    with np.errstate(divide="ignore"):
        dm_stat = D_mean / se
    # The survival function keeps precision in the far tail, where 1 - cdf
    # rounds to exactly zero.
    p_value = 2 * norm.sf(np.abs(dm_stat))
    return dm_stat, p_value


def benjamin_hochberg_correction(p_values, q=0.05):
    """
    Flag the p-values that are significant under the Benjamini-Hochberg
    step-up procedure.

    With the p-values sorted in ascending order, the largest rank ``k`` with
    ``p_(k) <= k / m * q`` is determined and every p-value not larger than
    ``p_(k)`` is declared significant. If no rank satisfies the condition,
    nothing is significant.

    Parameters:
    -----------
    p_values: np.ndarray
        The array of p-values. NaN entries are never significant but count
        towards the number of tests ``m``.
    q: float, optional
        The false discovery rate to control.

    Returns:
    --------
    Boolean mask with the same shape as ``p_values``.
    """
    m = len(p_values)
    sorted_p_values = np.sort(p_values)  # NaNs end up last
    if m:
        critical_values = np.arange(1, m + 1) / m * q
        passing = np.flatnonzero(sorted_p_values <= critical_values)
    else:
        passing = np.array([], dtype=int)
    # -inf rejects nothing, since p-values are non-negative.
    threshold = sorted_p_values[passing[-1]] if passing.size else -np.inf
    return p_values <= threshold


In [None]:
def predictive_performance_test(errors: xr.DataArray):
    """
    For a given target variable, perform Diebold-Mariano tests between each pair
    of approaches, for every combination of leadtime and station.

    The approaches are read from ``exp_config["approaches"]``. For each pair,
    the p-values of all station/leadtime tests are corrected jointly with the
    Benjamini-Hochberg procedure.

    Parameters:
    -----------
    errors: xr.DataArray
        The errors for the given target variable, with the dimensions
        ``approach``, ``station``, ``t`` and ``forecast_reference_time``.

    Returns:
    --------
    results: np.ndarray
        Square matrix with one row and column per approach. Entry ``[i, j]``
        is the percentage of station/leadtime combinations for which approach
        ``i`` is significantly better than approach ``j``. The diagonal is NaN.
    """
    approaches = list(exp_config["approaches"])
    results = np.full((len(approaches), len(approaches)), np.nan)

    # loop for pairwise model comparisons
    for (a1_idx, a1), (a2_idx, a2) in combinations(enumerate(approaches), 2):
        records = []
        # loop over stations and leadtimes
        for station in errors.station.values:
            for leadtime in errors.t.values:
                l1 = errors.sel(approach=a1, station=station, t=leadtime).values
                l2 = errors.sel(approach=a2, station=station, t=leadtime).values

                # Keep only the reference times where both losses are
                # available, so that the two vectors stay aligned.
                valid = ~(np.isnan(l1) | np.isnan(l2))

                # diebold-mariano test
                dm, p = diebold_mariano_test(l1[valid], l2[valid])
                records.append((station, leadtime, dm, p))

        # aggregate results
        res = pd.DataFrame(records, columns=["station","leadtime", "DM","p"]).set_index("station")
        n_tests = len(res)
        if n_tests == 0:
            continue

        # The correction depends only on the p-values: compute it once.
        significant = benjamin_hochberg_correction(res["p"])

        # forecasts of model 1 significantly better than 2
        results[a1_idx, a2_idx] = int(((res["DM"] < 0.) & significant).sum()) / n_tests * 100

        # forecasts of model 2 significantly better than 1
        results[a2_idx, a1_idx] = int(((res["DM"] > 0.) & significant).sum()) / n_tests * 100

    return results

In [None]:
# Absolute errors averaged over cross-validation split and seed; the tests are
# run on these averaged loss vectors.
errors = abs(pred - obs).mean(["split","seed"]).compute()

# One table per task: rows are the "winning" approach, columns the "losing" one.
dfs = []
for var in TASKS:
    results = predictive_performance_test(errors[var])
    df = pd.DataFrame(results, columns=exp_config["approaches"], index=exp_config["approaches"])
    # Row mean: how often this approach wins on average.
    df["Winning average"] = df.mean(axis=1)
    # Column mean: how often this approach loses on average. The last column
    # ("Winning average") is excluded and left empty in this row.
    df.loc["Losing average"] = (*df.mean(axis=0).values[:-1], np.nan)
    dfs.append(df)

In [None]:
def latex_table_tests(results, cfg):
    """
    Render a pairwise significance table as a LaTeX ``table*``.

    Parameters:
    -----------
    results: pd.DataFrame
        One row per approach plus a ``"Losing average"`` row; one column per
        approach plus a ``"Winning average"`` column. NaN cells are left empty.
    cfg: dict
        Plotting configuration; ``approach_names`` maps approach keys to the
        names used for the column headers and row labels.

    Returns:
    --------
    out: str
        The LaTeX source of the table.
    """
    def fmt(value):
        # NaN marks cells without a value (diagonal, bottom-right corner).
        return "" if np.isnan(value) else f"{value:.2f}"

    approach_names = cfg["approach_names"]
    colnames = list(approach_names.values()) + ["Winning average"]
    indent = " " * 8
    row_end = "\\\\ \n"  # LaTeX row terminator followed by a newline

    # preamble: label column, one column per approach, one for the average
    col_spec = "l " + "l" * len(approach_names) + " l"
    out = "\\begin{table*}\n"
    out += "    \\caption{This is a table. Add something here.}\n"
    out += "    \\renewcommand{\\arraystretch}{1.} \n"
    out += "    \\begin{tabular*}{\\hsize}{@{\\extracolsep\\fill}" + col_spec + "@{}}\n"

    # header
    out += indent + "\\topline\n"
    out += indent + "& " + " & ".join(f"{v}" for v in colnames) + row_end
    out += indent + "\\midline\n"

    # body
    for name in results.index.values:
        row = indent
        if name == "Losing average":
            # The summary row is separated from the approaches by a rule.
            row += "\\midline\n" + indent
        cells = [fmt(results.loc[name, col]) for col in results.columns]
        row += f"{approach_names.get(name, name)} & " + " & ".join(cells)
        row += " " + row_end
        out += row
    out += indent + row_end

    # postamble
    out += "   \\end{tabular*}\n"
    out += "\\end{table*}"
    return out

# Write one LaTeX table per task into the output directory.
# NOTE(review): the file name prefix is spelled "signficance"; kept as is
# because other workflow steps may rely on the current name -- confirm.
for var, df in zip(TASKS, dfs):
    out = latex_table_tests(df, plt_cfg)
    with open(out_dir / f"signficance_table_{var}.txt", "w") as f:
        f.write(out)

## Extra: suspect scores
One might reasonably believe that the skill scores (MSSS) presented above are suspiciously high, especially for pressure. The reason for these large values is that the errors of the NWP model are indeed very large: this is typical in regions of complex topography (such as Switzerland), where quantities related to elevation often have very large biases due to the difference between the model elevation and the true elevation.

In the code below we show that the mean bias of the NWP forecast at certain stations varies between -80 and +87 hPa.

In [None]:
# Mean signed error (bias) of the raw NWP pressure forecast per station.
nwp_err = nwp - obs
# All dimensions except `station` are averaged out; the cast to int only
# shortens the displayed values.
nwp_err["surface_air_pressure"].mean(["seed","split","forecast_reference_time","t","approach"]).compute().astype(int)

This bias is easily corrected by the postprocessing models, and gets close to zero.

In [None]:
# Same per-station bias, now for the postprocessed pressure forecasts
# (averaged over all approaches, seeds and splits).
err = pred - obs
err["surface_air_pressure"].mean(["seed","split","forecast_reference_time","t", "approach"]).compute().astype(int)