In [None]:
%matplotlib inline

import os
import numpy as np
import xarray as xr
import pandas as pd

import matplotlib.pyplot as plt

In [None]:
# auxiliary methods

def create_lines_plot(data: xr.DataArray, data_std: xr.DataArray, model_names: list, metric: dict,
                      plt_fname: str, x_coord: str = "hour", **kwargs):
    """
    Plot one line per experiment with a shaded band of data +/- data_std and save the figure as PNG.

    :param data: values to plot; must have an "exp" dimension and the coordinate named by x_coord
    :param data_std: spread around data with the same "exp" labels; shaded as data +/- data_std
    :param model_names: legend label for each entry along data's "exp" dimension (same order)
    :param metric: dict {metric name: unit}; only its first item is used for the y-axis label
    :param plt_fname: output file name; ".png" is appended if missing
    :param x_coord: name of the coordinate used for the x-axis (default: "hour")
    :param kwargs: optional plot settings:
        linestyle: list of matplotlib format strings, one per experiment (default ["k-", "b-"])
        error_color: list of fill colours, one per experiment (default ["grey", "blue"])
        value_range: y-axis limits (default (0.7, 1.1))
        ytick_step: spacing of the y-ticks within value_range, both limits included (default 0.05)
        fs: base font size (default 16)
        ref_line: y-value of an optional horizontal reference line (default None -> no line)
        ref_linestyle: format string of the reference line (default "k--")
        varname: variable name shown in the y-axis label (default "T2m")
        xlabel: x-axis label (default "daytime [UTC]")
    :raises ValueError: if fewer model names, line styles or error colours than experiments are given,
                        or if metric is empty
    """
    # get some plot parameters
    linestyle = kwargs.get("linestyle", ["k-", "b-"])
    err_col = kwargs.get("error_color", ["grey", "blue"])
    val_range = kwargs.get("value_range", (0.7, 1.1))
    ytick_step = kwargs.get("ytick_step", 0.05)
    fs = kwargs.get("fs", 16)
    ref_line = kwargs.get("ref_line", None)
    ref_linestyle = kwargs.get("ref_linestyle", "k--")
    varname = kwargs.get("varname", "T2m")
    xlabel = kwargs.get("xlabel", "daytime [UTC]")

    # validate per-experiment options before a figure is opened
    n_exp = data.sizes["exp"]
    for opt_name, opt_vals in (("model_names", model_names), ("linestyle", linestyle),
                               ("error_color", err_col)):
        if len(opt_vals) < n_exp:
            raise ValueError(f"'{opt_name}' provides {len(opt_vals)} entries, but data contains {n_exp} experiments.")
    if not metric:
        raise ValueError("'metric' must contain one entry {metric name: unit}.")

    x_vals = data[x_coord].values

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        for i, exp in enumerate(data["exp"].values):
            mean_vals = data.sel({"exp": exp}).values
            std_vals = data_std.sel({"exp": exp}).values
            ax.plot(x_vals, mean_vals, linestyle[i], label=model_names[i])
            ax.fill_between(x_vals, mean_vals - std_vals, mean_vals + std_vals, facecolor=err_col[i],
                            alpha=0.2)
        if ref_line is not None:
            ax.plot(x_vals, np.full(len(x_vals), ref_line), ref_linestyle)

        ax.set_ylim(*val_range)
        # half-step margin on the stop value: np.arange excludes the stop, and with a float step
        # whether the upper limit appears would otherwise depend on rounding
        ax.set_yticks(np.arange(val_range[0], val_range[1] + 0.5 * ytick_step, ytick_step))
        # label axis
        ax.set_xlabel(xlabel, fontsize=fs)
        metric_name, metric_unit = next(iter(metric.items()))
        ax.set_ylabel(f"{metric_name} {varname} [{metric_unit}]", fontsize=fs)
        ax.tick_params(axis="both", which="both", direction="out", labelsize=fs - 2)
        ax.legend(fontsize=fs - 2, loc="upper right")

        # save plot
        if not plt_fname.endswith(".png"):
            plt_fname = plt_fname + ".png"
        print(f"Save plot in file '{plt_fname}'")
        fig.savefig(plt_fname, bbox_inches="tight")
    finally:
        # always release the figure, also when plotting or saving fails
        plt.close(fig)

def get_id_from_fname(fname: str) -> str:
    """
    Extract the experiment ID from a file or experiment name.

    The ID is the text following the first occurrence of "id" up to the next "_" (or up to the end
    of the string if no "_" follows), e.g. "atmorep_id26n32cey" -> "26n32cey".
    NOTE: the first "id" anywhere in the name is used, also one inside another word (e.g. "grid").

    :param fname: file or experiment name containing "id<ID>"
    :return: the experiment ID
    :raises ValueError: if fname is not a string, contains no "id", or the extracted ID is empty
    """
    if not isinstance(fname, str):
        raise ValueError(f"Failed to deduce experiment ID from '{fname}'")

    id_pos = fname.find("id")
    if id_pos == -1:
        # str.find signals "not found" with -1 instead of raising
        raise ValueError(f"Failed to deduce experiment ID from '{fname}'")

    start_index = id_pos + len("id")
    end_index = fname.find("_", start_index)
    if end_index == -1:
        # no "_" after the ID: slicing up to -1 would drop the last character
        end_index = len(fname)

    exp_id = fname[start_index:end_index]
    if not exp_id:
        raise ValueError(f"Failed to deduce experiment ID from '{fname}'")

    return exp_id

In [None]:
# parameters
# base directory of the experiment results; can be overridden via the environment variable
# DOWNSCALING_RESULTS_DIR so the notebook does not depend on the hard-coded user path below
results_basedir = os.environ.get(
    "DOWNSCALING_RESULTS_DIR",
    "/p/home/jusers/langguth1/juwels/downscaling_maelstrom/downscaling_jsc_repo/downscaling_ap5/results")
plt_dir = os.path.join(results_basedir, "meta")

# experiment directories (relative to results_basedir) to compare
exp1 = "wgan_t2m_atmorep_test"
exp2 = "atmorep_id26n32cey"

# NOTE(review): varname and year are not used by the visible cells -- confirm whether they are still needed
varname = "T2m"
year = 2018

In [None]:
# main
# make sure the output directory for the comparison plot exists
os.makedirs(plt_dir, exist_ok=True)

# CSV files with the gradient-amplitude evaluation of both experiments
# NOTE(review): exp2's file name contains a double underscore ("__small_dom") -- confirm it matches the file on disk
fexp1, fexp2 = (
    os.path.join(results_basedir, exp_dir, "metric_files", metric_file)
    for exp_dir, metric_file in ((exp1, "eval_grad_amplitude_year.csv"),
                                 (exp2, "eval_grad_amplitude__small_dom_year.csv"))
)

In [None]:
# each metric CSV is expected to hold 24 rows (hour 0-23) with the two columns mean and std
dims = ["hour", "type"]
coord_dict = {"hour": np.arange(24), "type": ["mean", "std"]}

da_gr_exp1, da_gr_exp2 = (
    xr.DataArray(pd.read_csv(csv_file, header=0, index_col=0), dims=dims, coords=coord_dict)
    for csv_file in (fexp1, fexp2)
)

In [None]:
# stack both experiments along a new "exp" dimension labelled by the experiment names
da_gr_all = xr.concat([da_gr_exp1, da_gr_exp2], dim="exp")
da_gr_all = da_gr_all.assign_coords({"exp": [exp1, exp2]})

# create plot: mean per experiment with a shaded +/- std band
# (a horizontal reference line can be added via the `ref_line` keyword, e.g. ref_line=1.)
plt_fname = os.path.join(plt_dir, f"eval_grad_amplitude_{exp1}_{exp2}.png")
create_lines_plot(da_gr_all.sel({"type": "mean"}), da_gr_all.sel({"type": "std"}),
                  ["WGAN", "AtmoRep"], {"GRAD_AMPLITUDE": "1"}, plt_fname)