# Postprocessing trained downscaling models

In [None]:
import os, sys
sys.path.append("../models/")
sys.path.append("../utils/")
sys.path.append("../handle_data/")
import tensorflow as tf
import tensorflow.keras as keras
from handle_data_unet import *
from handle_data_class import  *
from statistical_evaluation import Scores
from plotting import *
import datetime as dt
import numpy as np
import xarray as xr
import json as js

## Base directories for test dataset and model

Adapt `data_dir`, `model_base_dir`, `model_name`, `lztar` and `last`:
 - `data_dir`: directory where the test dataset is stored
 - `model_base_dir`: top-level directory where the trained downscaling models are saved
 - `model_name`: name of the trained model
 - `lztar`: flag if the high-resolved (target) topography is part of the input data
 - `last`: flag if the last (instead of the best) model should be evaluated (requires supervised training optimization, i.e. for WGAN only!)

In [None]:
# directory holding the preprocessed test dataset
data_dir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/"
# top-level directory under which the trained models are stored
model_base_dir = "/p/home/jusers/langguth1/juwels/downscaling_maelstrom/downscaling_jsc_repo/downscaling_ap5/trained_models"
# model_base_dir = "/p/scratch/deepacf/deeprain/ji4/Downsacling/results_ap5/unet_exp0909_booster_epoch30/"
# name of the model to be postprocessed
model_name = "wgan_era5_to_crea6_epochs40_supervision_ztar2in_noes2"
# model_name = "unet_era5_to_crea6_test"
# append the high-resolved target topography to the input features
lztar = True
# lztar = False
# evaluate the last instead of the best checkpoint (WGAN only)
last = False

# construct the model directory paths
model_base = os.path.join(model_base_dir, model_name)
is_wgan = "wgan" in model_name
model_type = "wgan" if is_wgan else "unet"
# the WGAN generator lives in its own sub-directory; add_path leads from there back to model_base
add_str = "_last" if (is_wgan and last) else ""
add_path = ".." if is_wgan else ""
model_dir = os.path.join(model_base, f"{model_name}_generator{add_str}") if is_wgan else model_base

Next, we load the model and retrieve the test dataset by reading the corresponding netCDF file.

In [None]:
print(f"Load model '{model_name}' from {model_dir}")
# compile=False: the model is only used for inference (predict) below
trained_model = keras.models.load_model(model_dir, compile=False)
# the evaluation is performed on the test split of the preprocessed data
print(f"Read test dataset from {data_dir}")
ds_test = xr.open_dataset(os.path.join(data_dir, "preproc_era5_crea6_test.nc"))

## Data preprocessing

After retrieving the reference data (i.e. the ground truth data)...

In [None]:
target_var = "t_2m_tar"  # target 2m temperature used as reference (ground truth)
ground_truth = ds_test[target_var]

... we preprocess the input from the test dataset. For this, the data is reshaped into an xarray DataArray whose last dimension corresponds to the variables (i.e. the feature channels).

In [None]:
# Get the normalization parameters from saved json file
js_norm = os.path.join(model_dir, add_path, "z_norm_dict.json")

try:
    with open(js_norm, "r") as f:
        norm_dict = js.load(f)
except FileNotFoundError as e:
    raise FileNotFoundError(f"Could not find '{js_norm}'. Please check model-directory '{model_dir}'.")

train_vars = list(ds_test.keys())
mu_train, std_train = np.asarray(norm_dict["mu"]), np.asarray(norm_dict["std"])
da_test = HandleDataClass.reshape_ds(ds_test)
da_test = HandleUnetData.z_norm_data(da_test, norm_method="norm", save_path=model_base)

In [None]:
# Split the inputs and the target data
da_test_in, da_test_tar = HandleDataClass.split_in_tar(da_test)
if lztar:
    print("Add high-resolved target topography to input features.")
    da_test_in = xr.concat([da_test_in, da_test_tar.sel({"variables": "hsurf_tar"})], dim="variables")

## Create predictions from trained model

The preprocessed data is fed into the trained model to obtain the downscaled 2m temperature, which is evaluated afterwards.

In [None]:
print("Start inference from trained model...")
# remove singleton dimensions and pass the plain numpy array to Keras
x_test = da_test_in.squeeze().values
y_pred_trans = trained_model.predict(x_test, verbose=1, batch_size=32)

For evaluation, we have to denormalize the data. 

In [None]:
# get coordinates and dimensions from target data
coords = da_test_tar.isel(variables=0).squeeze().coords
dims = da_test_tar.isel(variables=0).squeeze().dims
y_pred = xr.DataArray(y_pred_trans[0].squeeze(), coords=coords, dims=dims)
# perform denormalization
y_pred = HandleUnetData.denormalize(y_pred.squeeze(), 
                                    norm_dict["mu"]["t_2m_tar"], 
                                    norm_dict["std"]["t_2m_tar"])
y_pred = xr.DataArray(y_pred, coords=coords, dims=dims)

## Evaluation

Subsequently, the produced downscaling product is evaluated using the following scores:
- RMSE
- Bias
- Horizontal gradient ratio

For this, we instantiate a score-engine which allows us to efficiently calculate some scores. Furthermore, we set and create the directory for saving the plots.

In [None]:
# get plot directory
plot_dir = os.path.join(".", model_name+add_str)
os.makedirs(plot_dir, exist_ok=True)

avg_dims = ["rlat", "rlon"]
# instantiate score engine
score_engine = Scores(y_pred, ground_truth, avg_dims)

To run the evaluation and to create the desired plots, we define a small auxiliary function.

In [None]:
def run_evaluation(score_engine, score_name: str, score_unit: str, **plt_kwargs):
    """
    Evaluate a (spatially averaged) score and plot its diurnal cycle.

    The score is obtained via score_engine(score_name). Its overall mean and standard deviation are
    printed. Then line plots are created from the hourly mean and standard deviation, one over all
    data and one per season (xarray's 'time.season'). The plots are saved into the global `plot_dir`.
    :param score_engine: Scores-instance; score_engine(score_name) must return a DataArray with a time dimension
    :param score_name: name of the score (also used for plot labels and file names)
    :param score_unit: unit of the score used in the printed summary and the plot labels
    :param plt_kwargs: keyword arguments forwarded to create_line_plot; the entry `model_type`
                       (default: "wgan") additionally sets the plot label and file-name prefix
    :return: True
    """
    os.makedirs(plot_dir, exist_ok=True)
    model_type = plt_kwargs.get("model_type", "wgan")

    print(f"Start evaluation in terms of {score_name}")
    score_all = score_engine(score_name)

    print(f"Globally averaged {score_name}: {score_all.mean().values:.4f} {score_unit}, standard deviation: {score_all.std().values:.4f}")

    # diurnal cycle over the complete dataset
    score_hourly_all = score_all.groupby("time.hour")
    score_hourly_mean, score_hourly_std = score_hourly_all.mean(), score_hourly_all.std()

    # diurnal cycle per season: only visit the hours that are present in the data (e.g. 3-hourly data),
    # since grouping an empty selection by season fails
    hours = np.unique(score_all["time"].dt.hour.values)
    means_sea, stds_sea = [], []
    for hh in hours:
        score_sea = score_all.isel({"time": score_all.time.dt.hour == hh}).groupby("time.season")
        means_sea.append(score_sea.mean())
        stds_sea.append(score_sea.std())
    # attach the hour values as coordinate, consistent with score_hourly_mean from groupby("time.hour")
    score_hourly_mean_sea = xr.concat(means_sea, dim="hour").assign_coords(hour=hours)
    score_hourly_std_sea = xr.concat(stds_sea, dim="hour").assign_coords(hour=hours)

    # create plots
    create_line_plot(score_hourly_mean, score_hourly_std, model_type.upper(),
                     {score_name.upper(): score_unit}, os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}.png"), **plt_kwargs)

    for sea in score_hourly_mean_sea["season"]:
        create_line_plot(score_hourly_mean_sea.sel({"season": sea}),
                         score_hourly_std_sea.sel({"season": sea}),
                         model_type.upper(), {score_name.upper(): score_unit},
                         os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{sea.values}.png"),
                         **plt_kwargs)
    return True
    

Next, we perform the evaluation in terms of the desired metrics sequentially:

In [None]:
# (score name, unit, plot settings) for the domain-averaged evaluation
eval_cfg = [("rmse", "K", {"value_range": (0., 3.)}),
            ("bias", "K", {"value_range": (-1., 1.), "ref_line": 0.}),
            ("grad_amplitude", "1", {"value_range": (0.7, 1.1), "ref_line": 1.})]
for metric, unit, plt_opts in eval_cfg:
    _ = run_evaluation(score_engine, metric, unit, **plt_opts, model_type=model_type)

The evaluation aggregated over the whole target domain is complemented by a spatial evaluation of the verification metrics.
This is useful to identify regions where the downscaling model is most prone to errors and to identify potential reasons for this behaviour.
Thus, we initialize a new Scores-engine which does not perform any averaging beforehand (empty list passed as dims) and ...

In [None]:
no_avg_dims = []  # empty list: the score engine does not average over any dimension
score_engine = Scores(y_pred, ground_truth, no_avg_dims)

... again create a small auxiliary function:

In [None]:
def run_evaluation_spatial(score_engine, score_name: str, score_unit: str, plot_dir, **plt_kwargs):
    """
    Create maps of a gridpoint-wise score: time average, hourly averages and hourly averages per season.

    The score is obtained via score_engine(score_name) and must keep the time and (rlat, rlon)
    dimensions. All maps are saved into `plot_dir`.
    :param score_engine: Scores-instance without averaging dimensions
    :param score_name: name of the score (also used for map titles and file names)
    :param score_unit: unit of the score (currently unused; kept for symmetry with run_evaluation)
    :param plot_dir: directory where the maps are saved (created if missing)
    :param plt_kwargs: keyword arguments forwarded to create_map_score (e.g. cmap, levels); the entry
                       `model_type` (default: "wgan") only sets the file-name prefix and is not forwarded
    :return: True
    """
    os.makedirs(plot_dir, exist_ok=True)

    # pop model_type so that it can be passed by callers without leaking into create_map_score
    model_type = plt_kwargs.pop("model_type", "wgan")
    score_all = score_engine(score_name)
    # rotated-pole projection of the target grid; `ccrs` is expected to come from `from plotting import *`
    cosmo_prj = ccrs.RotatedPole(pole_longitude=-162.0, pole_latitude=39.25)

    score_mean = score_all.mean(dim="time")
    fname = os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_avg_map.png")
    create_map_score(score_mean, fname, score_dims=["rlat", "rlon"],
                     title=f"{score_name.upper()} (avg.)", projection=cosmo_prj, **plt_kwargs)

    # only visit the hours of day that are present in the data (e.g. 3-hourly data)
    score_hourly_mean = score_all.groupby("time.hour").mean(dim=["time"])
    hours = [int(hh) for hh in score_hourly_mean["hour"].values]
    for hh in hours:
        fname = os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{hh:02d}_map.png")
        create_map_score(score_hourly_mean.sel({"hour": hh}), fname,
                         score_dims=["rlat", "rlon"], title=f"{score_name.upper()} {hh:02d} UTC",
                         projection=cosmo_prj, **plt_kwargs)

    for hh in hours:
        score_now = score_all.isel({"time": score_all.time.dt.hour == hh}).groupby("time.season").mean(dim="time")
        for sea in score_now["season"]:
            fname = os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{sea.values}_{hh:02d}_map.png")
            # title upper-cased consistently with the other maps
            create_map_score(score_now.sel({"season": sea}), fname, score_dims=["rlat", "rlon"],
                             title=f"{score_name.upper()} {sea.values} {hh:02d} UTC", projection=cosmo_prj, **plt_kwargs)

    return True

We run the spatial evaluation procedure for the different metrics. Note that the plots are saved in separate sub-directories for better organization.

In [None]:
# (metric, contour levels, colormap) for the spatial maps
spatial_cfg = [("rmse", np.arange(0., 3.1, 0.2), mpl.cm.afmhot_r),
               ("bias", np.arange(-2., 2.1, 0.1), mpl.cm.seismic)]
for metric, levels, colormap in spatial_cfg:
    # one discrete colour per contour level
    colors = colormap(np.linspace(0., 1., len(levels)))
    _ = run_evaluation_spatial(score_engine, metric, "K", os.path.join(plot_dir, f"{metric}_spatial"),
                               cmap=colors, levels=levels)

# NOTE(review): model_type is not passed here, so run_evaluation_spatial uses its default ("wgan") in the file names.
# grad_amplitude is not evaluated spatially: it does not work as the gradient gets already spatially averaged
#lvl_grad = np.arange(0.5, 1.51, 0.025)
#cmap_grad = mpl.cm.seismic(np.linspace(0., 1., len(lvl_grad)))
#_ = run_evaluation_spatial(score_engine, "grad_amplitude", "1", os.path.join(plot_dir, "grad_amplitude_spatial"), cmap=cmap_grad, levels=lvl_grad)