In [None]:
from pathlib import Path
import json
from functools import partial 

import xarray as xr
import numpy as np
import pandas as pd
import seaborn as sns 
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter


# `snakemake` is not defined anywhere in this notebook; presumably it is injected
# by Snakemake when the notebook is executed from a `notebook:` rule.
SNAKEMAKE = snakemake
config = SNAKEMAKE.config
inputs, outputs = SNAKEMAKE.input, SNAKEMAKE.output
partition, experiment = SNAKEMAKE.wildcards.partition, SNAKEMAKE.wildcards.experiment

# Global matplotlib styling comes from the "plotting" section of the config.
plt_cfg = config["plotting"]
plt.rcParams.update(plt_cfg["rcparams"])


# Target variables (data-variable names), in the order they are plotted.
TASKS = [
    "air_temperature",
    "dew_point_temperature",
    "surface_air_pressure",
    "relative_humidity",
    "water_vapor_mixing_ratio",
]


In [None]:
# All figures of this notebook are written into the first declared output (a directory).
out_dir = Path(outputs[0])
out_dir.mkdir(exist_ok=True, parents=True)

# Echo the run parameters so they end up in the executed notebook / log.
print(f"Partition: {partition} \n")
print("Experiment configuration: \n")
exp_config = config["experiments"][experiment]
exp_config_json = json.dumps(exp_config, indent=4)
print(exp_config_json)



In [None]:
def preprocess_predictions(ds, reftimes=None):
    """Tag one prediction file with the run parameters encoded in its path.

    Used as the ``preprocess`` hook of ``xr.open_mfdataset``. The path of the
    source file (``ds.encoding["source"]``) is read as follows:

    - parent directory name ``k1~v1-k2~v2-...`` -> one coordinate per key
      (values are kept as strings),
    - 3rd ancestor directory name -> ``approach``,
    - 4th ancestor ``<x>~<int>`` -> ``split`` (int after the ``~``),
    - 5th ancestor ``<x>~<int>`` -> ``seed`` (int after the ``~``).

    Every parameter becomes a new length-1 dimension so that files from
    different runs can be combined along those dimensions.

    Args:
        ds: Dataset opened from a single prediction file.
        reftimes: optional forecast reference times; when given, the dataset is
            reindexed onto them and loaded into memory.

    Returns:
        The tagged (and optionally reindexed + loaded) dataset.

    Raises:
        ValueError: if a ``-``-separated chunk of the parent directory name
            does not contain exactly one ``~``.
    """
    source = Path(ds.encoding["source"])
    # "k1~v1-k2~v2" -> {"k1": "v1", "k2": "v2"}
    params = dict(pair.split("~") for pair in source.parent.name.split("-"))
    params["approach"] = source.parents[2].name
    params["split"] = int(source.parents[3].name.split("~")[1])
    params["seed"] = int(source.parents[4].name.split("~")[1])

    tagged = ds.assign_coords(params).expand_dims(list(params))
    if reftimes is None:
        return tagged
    return tagged.reindex(forecast_reference_time=reftimes).load()

def ds_to_df(ds, name):
    """Flatten a Dataset into a long-format DataFrame.

    Non-index coordinates are dropped, the data variables are stacked along a
    new ``variable`` dimension, and the values land in a column called
    ``name`` next to one column per dimension (including ``variable``).
    """
    without_aux_coords = ds.reset_coords(drop=True)
    stacked = without_aux_coords.to_array("variable")
    frame = stacked.to_dataframe(name)
    return frame.reset_index()

def unstack(ds):
    """Expand the flat sample dimension ``s`` into three dimensions.

    The values of the ``forecast_reference_time``, ``t`` and ``station``
    coordinates are read first and combined into a MultiIndex. All non-index
    coordinates are then dropped, the MultiIndex is attached as ``s``, and
    ``s`` is unstacked into (forecast_reference_time, t, station).
    """
    index_dims = ["forecast_reference_time", "t", "station"]
    key_arrays = [ds[dim_name].values for dim_name in index_dims]
    sample_index = pd.MultiIndex.from_arrays(key_arrays, names=index_dims)
    flat = ds.reset_coords(drop=True)
    return flat.assign_coords(s=sample_index).unstack("s")

def remove_source_prefix(ds):
    """Strip the ``<source>:`` prefix from every data-variable name.

    ``"<source>:air_temperature"`` becomes ``"air_temperature"``. Only the text
    up to and including the first ``":"`` is removed. Names without a ``":"``
    are left unchanged instead of raising ``ValueError``.

    All renames are applied in a single ``rename`` call instead of one Dataset
    copy per variable.

    Args:
        ds: Dataset whose data variables may carry a ``source:`` prefix.

    Returns:
        A Dataset with the prefixes removed.
    """
    mapping = {var: var.split(":", 1)[1] for var in ds.data_vars if ":" in var}
    return ds.rename(mapping)

In [None]:
# Load the observations (targets) and the predictions of every run, then align
# both on a common grid and compute the prediction error.
if partition not in ("train", "val", "test"):
    raise ValueError(f"Unknown partition {partition!r}; expected 'train', 'val' or 'test'")

obs = remove_source_prefix(unstack(xr.load_dataset(inputs["y"])))
if partition in ("train", "val"):
    # Reindex every prediction file onto the observed reference times while it is read.
    reftimes = obs.forecast_reference_time
    pred = xr.open_mfdataset(
        inputs["predictions"],
        preprocess=partial(preprocess_predictions, reftimes=reftimes),
        parallel=True,
    )
    # Drop reference times for which all values are missing. `how` is passed by
    # keyword: newer xarray releases deprecate positional arguments after `dim`.
    pred = pred.dropna("forecast_reference_time", how="all")
else:
    pred = xr.open_mfdataset(inputs["predictions"], preprocess=preprocess_predictions, parallel=True)

obs = obs.reindex_like(pred).load().chunk({"forecast_reference_time": 200})
pred = pred.squeeze().load().chunk({"forecast_reference_time": 200})
obs = obs.squeeze()
obs, pred = xr.broadcast(obs, pred)
err = pred - obs

In [None]:
# Per-station MAE (averaged over reference and lead times) as a long DataFrame
# with display names for approaches and variables.
mae = (
    ds_to_df(abs(err).mean(["forecast_reference_time","t"]), "MAE")
    .replace(plt_cfg["approach_names"])
    .replace(plt_cfg["var_short_names"])
)

# One panel per target variable so that every variable keeps its own y-scale.
fig, axs = plt.subplots(1, len(TASKS), figsize=(5, 5), sharey=False)
for ax, task in zip(axs, TASKS):
    short_name = plt_cfg["var_short_names"][task]
    # NOTE(review): `palette` is given without `hue`. seaborn then maps the palette
    # onto the x levels (variable names) and newer seaborn warns about this. Confirm
    # whether boxes were meant to be coloured by approach (hue="approach").
    sns.boxplot(
        data=mae.query(f"variable=='{short_name}'"),
        x="variable",
        y="MAE",
        palette=plt_cfg["approach_colors"],
        ax=ax,
    )
    ax.set_ylim(0., None)
    # Clear per-panel labels; only the first panel gets a y-label below.
    ax.set_xlabel("")
    ax.set_ylabel("")

fig.subplots_adjust(wspace=1.5, hspace=0)
axs[0].set_ylabel("Mean absolute error")

fig.savefig(out_dir / "mae_boxplots.png")

In [None]:
# Per-station MAE for each physics-penalty weight alpha: one panel per variable
# plus an "aggregated" panel holding the mean of the per-variable MAEs.
mae = np.abs(err).mean(["forecast_reference_time","t"])
# NOTE(review): this is an unweighted mean of raw MAEs whose variables have different
# units/scales, so large-valued variables dominate. Confirm this is intended.
# If plotting.var_short_names maps this panel, its key must be "aggregated".
mae["aggregated"] = mae.to_array("v").mean("v")
mae = ds_to_df(mae, "mae")
mae = mae.replace(plt_cfg["approach_names"])
mae = mae.replace(plt_cfg["var_short_names"])
cp = sns.catplot(data=mae, x="loss.alpha", y="mae", col="variable", palette="Blues", kind="box", sharey=False, col_wrap=3, aspect=0.5)
for ax in cp.axes.flat:
    # Facet titles read "variable = <name>"; keep only <name>. Split on the first
    # " = " only so that names containing "=" survive intact.
    ax.set_title(ax.get_title().split(" = ", 1)[-1])
    ax.set_ylim(0., None)

plt.savefig(out_dir / "alpha_vs_mae.png")

In [None]:
def rh_from_t_td(t, t_d):
    """Relative humidity (in percent) implied by temperature and dew point.

    Evaluates ``100 * exp(b*t_d/(c+t_d) - b*t/(c+t))``.

    - Where ``t >= 0``: (b, c) = (17.368, 238.83).
    - Elsewhere: (b, c) = (17.856, 245.52).

    Both branches are computed and selected element-wise with ``xr.where``.
    Presumably ``t`` and ``t_d`` are in degrees Celsius — TODO confirm.
    """
    def _exponent(x, b, c):
        return (b * x) / (c + x)

    warm = 100 * np.exp(_exponent(t_d, 17.368, 238.83) - _exponent(t, 17.368, 238.83))
    cold = 100 * np.exp(_exponent(t_d, 17.856, 245.52) - _exponent(t, 17.856, 245.52))
    return xr.where(t >= 0, warm, cold)

def e_from_t_rh(t, rh):
    """Water-vapour pressure from temperature and relative humidity (Magnus form).

    Evaluates ``rh / 100 * a * exp(b*t / (c+t))``.

    - Where ``t >= 0``: (a, b, c) = (6.107, 17.368, 238.83).
    - Elsewhere: (a, b, c) = (6.108, 17.856, 245.52).

    Presumably ``t`` is in degrees Celsius, ``rh`` in percent and the result in
    hPa — TODO confirm.
    """
    above_zero = t >= 0
    warm_branch = rh / 100 * 6.107 * np.exp((17.368 * t) / (238.83 + t))
    cold_branch = rh / 100 * 6.108 * np.exp((17.856 * t) / (245.52 + t))
    return xr.where(above_zero, warm_branch, cold_branch)


def td_from_e_t(e, t):
    """Dew point from vapour pressure ``e`` and temperature ``t`` (inverse Magnus).

    Evaluates ``td = -c * ln(e/a) / (ln(e/a) - b)``.

    - Where ``t >= 0``: (a, b, c) = (6.107, 17.368, 238.83).
    - Elsewhere: (a, b, c) = (6.108, 17.856, 245.52).

    These are the same coefficients used by ``e_from_t_rh`` in this notebook.
    """
    log_ratio_warm = np.log(e / 6.107)
    log_ratio_cold = np.log(e / 6.108)
    td_warm = -238.83 * log_ratio_warm / (log_ratio_warm - 17.368)
    td_cold = -245.52 * log_ratio_cold / (log_ratio_cold - 17.856)
    return xr.where(t >= 0.0, td_warm, td_cold)

def r_from_e_p(e, p):
    """Water-vapour mixing ratio from vapour pressure ``e`` and pressure ``p``.

    Evaluates ``622 * e / (p - e)``; ``e`` and ``p`` must be in the same units.
    The factor 622 presumably makes the result g/kg — TODO confirm.
    """
    dry_air_pressure = p - e
    return 622.0 * (e / dry_air_pressure)

In [None]:
# Physical-consistency penalty of the predictions. The predicted RH and mixing
# ratio are compared with the values re-derived from the other predicted
# variables. Each squared mismatch is normalised by the variance of the
# corresponding observations.
t = pred["air_temperature"]
t_d = pred["dew_point_temperature"]
p = pred["surface_air_pressure"]
rh = pred["relative_humidity"]
r = pred["water_vapor_mixing_ratio"]

# Constraint 1: RH implied by (T, Td).
rh_derived = rh_from_t_td(t, t_d)
rh_residual = (rh_derived - rh) ** 2

# Constraint 2: mixing ratio implied by (T, RH, p) via the vapour pressure e.
# NOTE(review): e_from_t_rh uses a 6.107/6.108 reference pressure (presumably hPa),
# so p must be in the same units for r_from_e_p. Confirm the units of surface_air_pressure.
e = e_from_t_rh(t, rh)
r_derived = r_from_e_p(e, p)
r_residual = (r_derived - r) ** 2

# Residuals are squares and therefore already non-negative.
physical_penalty = (
    rh_residual / obs.relative_humidity.var()
    + r_residual / obs.water_vapor_mixing_ratio.var()
)

In [None]:
# Accuracy vs. physical consistency as a function of the penalty weight alpha.
# The mean physical penalty P is on the left axis (red) and the normalised MAE on
# the right axis (blue).
dims = ["forecast_reference_time", "t", "station"]
# NMAE: each variable's MAE divided by the std of its observations, then averaged
# over variables.
mae = np.abs(err).mean(dim=dims)
mae = (mae / obs.std()).to_array().mean("variable")

# NOTE(review): "loss.alpha" presumably comes from the path parsing in
# preprocess_predictions, i.e. string values. The x-axis is then categorical in
# coordinate order; if that order is lexicographic, "10" would precede "2".
# Confirm the ordering, or convert to float.
fig, ax1 = plt.subplots()
ax1.plot(pred["loss.alpha"].values, physical_penalty.mean(dim=dims), color="red")
ax1.tick_params(axis="x", rotation=90)
ax1.tick_params(axis="y", labelcolor="red")
ax1.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
ax1.set_ylabel(r"$\mathcal{P}$")

ax2 = ax1.twinx()
ax2.plot(pred["loss.alpha"].values, mae.values, color="blue")
ax2.tick_params(axis="y", labelcolor="blue")
ax2.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
ax2.set_ylabel(r"NMAE")

ax1.set_xlabel(r"Physical penalty weight $\alpha$")
# NOTE(review): hard-coded y-limits; values outside these ranges are not shown.
ax1.set_ylim(0., 0.04)
ax2.set_ylim(0.18, 0.22)
fig.tight_layout()
fig.savefig(out_dir / "p_mae_vs_alpha.png")