In [None]:
from pathlib import Path
import json
from functools import partial 

import xarray as xr
import numpy as np
import pandas as pd
import seaborn as sns 
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import datashader as dsh
import datashader.transfer_functions as dtf
from datashader.mpl_ext import dsshow


# `snakemake` is not defined in this notebook; it is presumably injected into
# the global namespace by Snakemake when it executes the notebook.
SNAKEMAKE = snakemake
inputs, outputs, config = SNAKEMAKE.input, SNAKEMAKE.output, SNAKEMAKE.config
plt_cfg = config["plotting"]
partition, experiment = (
    SNAKEMAKE.wildcards.partition,
    SNAKEMAKE.wildcards.experiment,
)

# Apply the matplotlib style shared through the workflow configuration.
plt.rcParams.update(plt_cfg["rcparams"])

# Names of the predicted variables.
TASKS = [
    "air_temperature",
    "dew_point_temperature",
    "surface_air_pressure",
    "relative_humidity",
    "water_vapor_mixing_ratio",
]



In [None]:
# Directory that receives every figure written by this notebook.
out_dir = Path(outputs[0])
out_dir.mkdir(parents=True, exist_ok=True)

print(f"Partition: {partition} \n")
print("Experiment configuration: \n")
exp_config = config["experiments"][experiment]
print(json.dumps(exp_config, indent=4))

# Work on a copy: removing entries from the configured list in place would
# silently alter the shared config and make this cell non-idempotent.
approaches = list(exp_config["approaches"])

# "offline_constrained" is excluded from the plots below.
if "offline_constrained" in approaches:
    approaches.remove("offline_constrained")

# Move "architecture_constrained" to the end so it becomes the last panel.
if "architecture_constrained" in approaches:
    approaches.remove("architecture_constrained")
    approaches.append("architecture_constrained")

In [None]:
def preprocess_predictions(ds, reftimes=None):
    """Tag one predictions dataset with the run it came from.

    The run is identified from the file path stored in
    ``ds.encoding["source"]``: the third-level parent directory name is the
    approach, and the fourth- and fifth-level parent directory names carry the
    seed and the split as the integer after a ``~`` separator.

    Parameters
    ----------
    ds
        Dataset opened from a single predictions file.
    reftimes
        Optional forecast reference times. When given, the dataset is
        reindexed onto them along ``forecast_reference_time`` and loaded.

    Returns
    -------
    The dataset with ``approach``, ``split`` and ``seed`` added as new
    length-one dimensions.
    """
    ancestors = Path(ds.encoding["source"]).parents
    run_info = {
        "approach": ancestors[2].name,
        "split": int(ancestors[4].name.split("~")[1]),
        "seed": int(ancestors[3].name.split("~")[1]),
    }
    tagged = ds.assign_coords(run_info).expand_dims(list(run_info))
    if reftimes is None:
        return tagged
    return tagged.reindex(forecast_reference_time=reftimes).load()

def ds_to_df(ds, name):
    """Convert a dataset into a long-format dataframe.

    Non-index coordinates are dropped, all data variables are stacked along a
    new ``"variable"`` dimension, and the result is turned into a dataframe
    whose value column is called ``name``. The index is reset so that every
    dimension becomes an ordinary column.
    """
    stacked = ds.reset_coords(drop=True).to_array("variable")
    frame = stacked.to_dataframe(name)
    return frame.reset_index()

def unstack(ds):
    """Unstack the flat sample dimension ``s`` into its three components.

    A ``pandas.MultiIndex`` is built from the ``forecast_reference_time``,
    ``t`` and ``station`` values of ``ds``, assigned as the coordinate ``s``
    (after all non-index coordinates have been dropped), and ``s`` is then
    unstacked so that each level becomes its own dimension.
    """
    sample_dims = ["forecast_reference_time", "t", "station"]
    level_values = [ds[dim].values for dim in sample_dims]
    sample_index = pd.MultiIndex.from_arrays(level_values, names=sample_dims)
    return ds.reset_coords(drop=True).assign_coords(s=sample_index).unstack("s")

def remove_source_prefix(ds):
    """Strip the leading ``<source>:`` prefix from data variable names.

    Every data variable whose name contains a colon is renamed to the part
    after the first colon. Names without a colon are left untouched instead
    of raising, and all renames are applied in a single ``rename`` call.

    Parameters
    ----------
    ds
        Dataset whose data variables may be named ``"<source>:<name>"``.

    Returns
    -------
    The dataset with the prefixes removed; ``ds`` itself when there is
    nothing to rename.
    """
    renames = {
        var: var.split(":", 1)[1]
        for var in ds.data_vars
        if isinstance(var, str) and ":" in var
    }
    if not renames:
        return ds
    return ds.rename(renames)

In [None]:
# Fail early with a clear message instead of a NameError further down.
if partition not in ("train", "val", "test"):
    raise ValueError(f"Unknown partition: {partition!r}")

# Observations are loaded the same way for every partition.
obs = remove_source_prefix(unstack(xr.load_dataset(inputs["y"])))

if partition in ("train", "val"):
    # Put every prediction file on the observations' reference times before
    # combining, then drop the reference times that hold no prediction at all.
    reftimes = obs.forecast_reference_time
    pred = xr.open_mfdataset(
        inputs["predictions"],
        preprocess=partial(preprocess_predictions, reftimes=reftimes),
        parallel=True,
    )
    # `how` is passed by keyword: positional use is deprecated in xarray.
    pred = pred.dropna("forecast_reference_time", how="all")
else:
    pred = xr.open_mfdataset(
        inputs["predictions"],
        preprocess=preprocess_predictions,
        parallel=True,
    )

# Align observations with the predictions, then broadcast both to a common
# shape so the error can be computed elementwise.
obs = obs.reindex_like(pred).load().chunk("auto")
pred = pred.load().chunk("auto")
obs, pred = xr.broadcast(obs, pred)
err = pred - obs

In [None]:
def rh_from_t_td(t, t_d):
    """Relative humidity (in percent) from temperature and dew point.

    Uses a Magnus-type exponential expression with one coefficient set where
    ``t >= 0`` and another where ``t < 0``.

    Parameters
    ----------
    t
        Air temperature (presumably in degrees Celsius -- TODO confirm).
    t_d
        Dew point temperature, same unit as ``t``.
    """
    rh_nonneg = 100 * np.exp((17.368 * t_d) / (238.83 + t_d) - (17.368 * t) / (238.83 + t))
    rh_neg = 100 * np.exp((17.856 * t_d) / (245.52 + t_d) - (17.856 * t) / (245.52 + t))
    return xr.where(t >= 0, rh_nonneg, rh_neg)

def e_from_t_rh(t, rh):
    """Water vapor pressure from temperature and relative humidity.

    ``rh`` is divided by 100, i.e. it is expected in percent. The coefficient
    set of the Magnus-type expression is chosen by the sign of ``t``
    (``t >= 0`` versus ``t < 0``).
    Units are presumably degrees Celsius in, hPa out -- TODO confirm.
    """
    e_nonneg = rh / 100 * 6.107 * np.exp((17.368 * t) / (238.83 + t))
    e_neg = rh / 100 * 6.108 * np.exp((17.856 * t) / (245.52 + t))
    return xr.where(t >= 0, e_nonneg, e_neg)

def e_from_td(t_d, t):
    """Water vapor pressure from the dew point temperature.

    The value is computed from ``t_d`` only; ``t`` is used solely to pick the
    coefficient set (``t >= 0`` versus ``t < 0``).
    Units are presumably degrees Celsius in, hPa out -- TODO confirm.
    """
    e_nonneg = 6.107 * np.exp((17.368 * t_d) / (238.83 + t_d))
    e_neg = 6.108 * np.exp((17.856 * t_d) / (245.52 + t_d))
    return xr.where(t >= 0, e_nonneg, e_neg)

def td_from_e_t(e, t):
    """Dew point temperature from water vapor pressure.

    Algebraic inverse of the expression used in ``e_from_td``: the value is
    computed from ``e`` only, while ``t`` selects the coefficient set
    (``t >= 0.0`` versus ``t < 0.0``).
    """
    td_nonneg = -238.83 * (np.log(e / 6.107)) / (np.log(e / 6.107) - 17.368)
    td_neg = -245.52 * (np.log(e / 6.108)) / (np.log(e / 6.108) - 17.856)
    return xr.where(t >= 0.0, td_nonneg, td_neg)

def r_from_e_p(e, p):
    """Water vapor mixing ratio from vapor pressure ``e`` and pressure ``p``.

    Computes ``622 * e / (p - e)``; ``e`` and ``p`` must share the same unit.
    The factor 622 presumably yields grams per kilogram -- TODO confirm.
    """
    vapor_to_dry_ratio = e / (p - e)
    return 622.0 * vapor_to_dry_ratio

def r_from_p_td(p, t_d, t):
    """Water vapor mixing ratio from pressure and dew point temperature.

    The vapor pressure is first obtained with ``e_from_td(t_d, t)`` and then
    converted to a mixing ratio with ``r_from_e_p``.
    """
    return r_from_e_p(e_from_td(t_d, t), p)

In [None]:

def get_categories(ds, var, bins, labels):
    """Assign a categorical label to every value of ``ds[var]``.

    The interior edges in ``bins`` are extended with ``-inf`` and ``+inf``,
    giving ``len(bins) + 1`` half-open intervals ``[lower, upper)``; values in
    the i-th interval receive ``labels[i]``.

    Parameters
    ----------
    ds
        Dataset (or mapping) holding the variable to categorise.
    var
        Name of the variable in ``ds``.
    bins
        Interior bin edges, in increasing order.
    labels
        One label per interval, so ``len(labels) == len(bins) + 1``.

    Returns
    -------
    xr.DataArray
        String array with the dimensions and coordinates of ``ds[var]``.
    """
    assert len(bins) + 1 == len(labels)
    da = ds[var]
    edges = (-np.inf, *bins, np.inf)

    # np.zeros guarantees empty strings; np.empty_like leaves the memory
    # uninitialised, which makes the `== ""` check below unreliable.
    cat = np.zeros(da.shape, dtype="U100")
    for lower, upper, label in zip(edges[:-1], edges[1:], labels):
        in_bin = np.asarray((lower <= da) & (da < upper))
        cat[in_bin] = label

    # Values that matched no interval (e.g. NaN) fall back to the third label.
    # NOTE(review): the hard-coded index 2 is kept from the original code --
    # confirm that this is the intended fallback category.
    unmatched = cat == ""
    if unmatched.any():
        cat[unmatched] = labels[2]

    # Pass dims explicitly instead of relying on inference from the coords.
    return xr.DataArray(cat, coords=da.coords, dims=da.dims)

def shade_plot(df, ax, **kwargs):
    """Draw a datashader point-density image of ``df`` on ``ax``.

    Parameters
    ----------
    df
        Dataframe with a ``"y"`` column and the column used for the x axis.
    ax
        Matplotlib axes to draw on.
    **kwargs
        ``aggregator`` (default ``dsh.count()``) and ``norm`` (default
        ``"eq_hist"``) are forwarded to ``dsshow`` under those names. ``x``
        optionally names the x column; when omitted, the single column of
        ``df`` other than ``"y"`` is used. Anything else is passed to
        ``dsshow`` unchanged.

    Returns
    -------
    The artist returned by ``dsshow``.

    Raises
    ------
    ValueError
        If ``x`` is not given and ``df`` does not have exactly one column
        besides ``"y"``.
    """
    agg = kwargs.pop("aggregator", dsh.count())
    norm = kwargs.pop("norm", "eq_hist")

    # Resolve the x column from the arguments rather than from a module-level
    # `x` variable, so the function does not depend on hidden notebook state.
    x_col = kwargs.pop("x", None)
    if x_col is None:
        candidates = [col for col in df.columns if col != "y"]
        if len(candidates) != 1:
            raise ValueError(
                "Cannot infer the x column: expected exactly one column "
                f"besides 'y', found {candidates}. Pass x=<column name>."
            )
        x_col = candidates[0]

    return dsshow(
        df,
        dsh.Point(x_col, "y"),
        aggregator=agg,
        norm=norm,
        ax=ax,
        aspect="auto",
        **kwargs
    )

def RH_deviation(ds):
    """Difference between given and derived relative humidity.

    The derived value is ``rh_from_t_td`` applied to the dataset's
    ``air_temperature`` and ``dew_point_temperature``; the result is
    ``relative_humidity`` minus that derived value.
    """
    fields = ds.reset_coords(drop=True)
    temperature = fields["air_temperature"]
    dew_point = fields["dew_point_temperature"]
    rh_given = fields["relative_humidity"]
    return rh_given - rh_from_t_td(temperature, dew_point)

def r_deviation(ds):
    """Difference between given and derived water vapor mixing ratio.

    The derived value is obtained from ``dew_point_temperature`` (vapor
    pressure via ``e_from_td``, with ``air_temperature`` selecting the
    coefficient set) and ``surface_air_pressure`` (via ``r_from_e_p``); the
    result is ``water_vapor_mixing_ratio`` minus that derived value.
    """
    fields = ds.reset_coords(drop=True)
    temperature = fields["air_temperature"]
    dew_point = fields["dew_point_temperature"]
    pressure = fields["surface_air_pressure"]
    r_given = fields["water_vapor_mixing_ratio"]
    vapor_pressure = e_from_td(dew_point, temperature)
    return r_given - r_from_e_p(vapor_pressure, pressure)

In [None]:
# Density plot of the mixing-ratio constraint violation, one panel per approach.
x = "water_vapor_mixing_ratio"

# squeeze=False keeps `axs` indexable even when there is a single approach.
fig, axs = plt.subplots(
    1, len(approaches), figsize=(6.5, 5), sharey=True, sharex=True, squeeze=False
)
axs = axs[0]

# Axis limits follow the "unconstrained" panel; when that approach is absent
# the first panel is used instead of failing on undefined limits.
extent = None
for i, approach in enumerate(approaches):
    ds = pred.sel(approach=approach).reset_coords(drop=True)
    y = r_deviation(ds)
    df = xr.Dataset({x: ds[x], "y": y}).to_dataframe()
    artist = shade_plot(df, axs[i], norm="eq_hist")
    if approach == "unconstrained" or extent is None:
        extent = artist.get_extent()
    axs[i].set_title(plt_cfg["approach_names"][approach], fontsize=10)
    axs[i].set_xlabel(r"$r$ [g $kg^{-1}$]")

xmin, xmax, ymin, ymax = extent
axs[0].set_ylabel(r"$r - g(P, T_d)$")
# Axes are shared, so setting the limits on the first panel applies to all.
axs[0].set_ylim(ymin - 2, ymax + 2)
axs[0].set_xlim(xmin - 2, xmax + 5)
fig.tight_layout()
fig.savefig(out_dir / "mixing_ratio_deviations.png")

In [None]:
# Density plot of the relative-humidity constraint violation, one panel per approach.
x = "relative_humidity"

# squeeze=False keeps `axs` indexable even when there is a single approach.
fig, axs = plt.subplots(
    1, len(approaches), figsize=(6.5, 5), sharey=True, sharex=True, squeeze=False
)
axs = axs[0]

# The x limits follow the "unconstrained" panel; when that approach is absent
# the first panel is used rather than stale limits left by an earlier cell.
extent = None
for i, approach in enumerate(approaches):
    ds = pred.sel(approach=approach).reset_coords(drop=True)
    y = RH_deviation(ds)
    df = xr.Dataset({x: ds[x], "y": y}).to_dataframe()
    artist = shade_plot(df, axs[i])
    if approach == "unconstrained" or extent is None:
        extent = artist.get_extent()
    axs[i].set_title(plt_cfg["approach_names"][approach], fontsize=10)
    axs[i].set_xlabel(r"$RH$ [%]")

xmin, xmax, ymin, ymax = extent
axs[0].set_ylabel(r"$RH - f(T, T_d)$")
# Fixed y range; axes are shared, so it applies to every panel.
axs[0].set_ylim(-32, 20)
axs[0].set_xlim(xmin - 2, xmax + 2)
fig.tight_layout()
fig.savefig(out_dir / "relative_humidity_deviations.png")