# Postprocessing trained downscaling models

In [None]:
import os, sys
sys.path.append("../models/")
sys.path.append("../utils/")
sys.path.append("../handle_data/")
import tensorflow as tf
import tensorflow.keras as keras
from handle_data_unet import *
from handle_data_class import  *
from statistical_evaluation import Scores
from plotting import *
import datetime as dt
import numpy as np
import xarray as xr
import json as js

from collections import OrderedDict

## Base directories for test dataset and model

Adapt `data_dir`, `model_base_dir`, `model_name` and `lztar`.
 - `data_dir`: directory where the test dataset is stored
 - `model_base_dir`: top-level directory where trained downscaling models are saved
 - `model_name`: name of the trained model
 - `lztar`: flag indicating whether the high-resolution (target) topography is part of the input data

In [None]:
# directory holding the preprocessed netCDF files (the test dataset is read from here)
data_dir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/"
# top-level directory under which the trained models are stored
model_base_dir = "/p/home/jusers/langguth1/juwels/downscaling_maelstrom/downscaling_jsc_repo/downscaling_ap5/trained_models"
# name of the model to be postprocessed
model_name = "wgan_era5_to_crea6_epochs40_supervision_ztar2in_noes2"
# if True, the "hsurf_tar" field of the target data is appended to the input features before inference
lztar = True

# construct model directory paths: <model_base_dir>/<model_name>/<model_name>_generator
model_base = os.path.join(model_base_dir, model_name)
model_dir = os.path.join(model_base, f"{model_name}_generator")

Next, we load the model and also retrieve the testing dataset by reading the corresponding netCDF-file.

In [None]:
# Load the trained generator; compile=False because the model is only used for predict() in this notebook
print(f"Load model '{model_name}'")
trained_model = keras.models.load_model(model_dir, compile=False)
# Read the *test* dataset (the message used to claim the training dataset was read)
print(f"Read test dataset from {data_dir}")
ds_test = xr.open_dataset(os.path.join(data_dir, "preproc_era5_crea6_test.nc"))

## Data preprocessing

After retrieving the reference data (i.e. the ground truth data)...

In [None]:
# reference field from the test dataset that the predictions are scored against
ground_truth = ds_test["t_2m_tar"]

... we preprocess the input from the test dataset. For this, the data is reshaped into an xarray DataArray whose last dimension corresponds to the variables (the feature channels) and is then normalized.

In [None]:
# Get the normalization parameters from the JSON file located one level above the generator directory
js_norm = os.path.join(model_dir, "..", "z_norm_dict.json")

try:
    with open(js_norm, "r") as f:
        norm_dict = js.load(f)
except FileNotFoundError as e:
    # chain the original exception so that the traceback keeps the underlying OS error
    raise FileNotFoundError(f"Could not find '{js_norm}'. Please check model-directory '{model_dir}'.") from e

train_vars = list(ds_test.keys())
# NOTE(review): norm_dict["mu"] / norm_dict["std"] are indexed by variable name further down
# (norm_dict["mu"]["t_2m_tar"]), i.e. they are mappings. np.asarray() on a mapping yields a
# 0-d object array, so mu_train/std_train are probably not usable as numeric arrays -- confirm
# whether they (and train_vars) are still needed.
mu_train, std_train = np.asarray(norm_dict["mu"]), np.asarray(norm_dict["std"])
# reshape the dataset into a single DataArray and normalize it
# (presumably z_norm_data picks up the stored parameters from save_path -- verify against HandleUnetData)
da_test = HandleDataClass.reshape_ds(ds_test)
da_test = HandleUnetData.z_norm_data(da_test, norm_method="norm", save_path=model_base)

In [None]:
# Separate the model inputs from the target fields
da_test_in, da_test_tar = HandleDataClass.split_in_tar(da_test)
if lztar:
    print("Add high-resolved target topography to input features.")
    # append the 'hsurf_tar' field of the target data as an extra entry along 'variables'
    topo_tar = da_test_tar.sel(variables="hsurf_tar")
    da_test_in = xr.concat((da_test_in, topo_tar), dim="variables")

## Create predictions from trained model

The preprocessed data is fed into the trained model to obtain the downscaled 2m temperature, which is evaluated later.

In [None]:
print("Start inference from trained model...")
# squeeze() drops size-1 dimensions; the plain numpy array is handed to Keras in batches of 32
# NOTE(review): element 0 of the result is used afterwards, so the model presumably returns
# more than one output -- confirm against the model definition
y_pred_trans =  trained_model.predict(da_test_in.squeeze().values, batch_size=32, verbose=1)

For evaluation, we have to denormalize the data. 

In [None]:
# template slice of the target data: supplies coordinates and dimension names for the prediction
tar_template = da_test_tar.isel(variables=0).squeeze()
coords, dims = tar_template.coords, tar_template.dims
# wrap element 0 of the model output into a DataArray carrying the target coordinates
y_pred = xr.DataArray(y_pred_trans[0].squeeze(), coords=coords, dims=dims)
# undo the normalization with the stored mean / standard deviation of t_2m_tar
mu_t2m, std_t2m = norm_dict["mu"]["t_2m_tar"], norm_dict["std"]["t_2m_tar"]
y_pred = HandleUnetData.denormalize(y_pred.squeeze(), mu_t2m, std_t2m)
y_pred = xr.DataArray(y_pred, coords=coords, dims=dims)

## Evaluation

Subsequently, the produced downscaling product is evaluated using the following scores
- RMSE
- Bias
- Horizontal gradient ratio

In [None]:
def create_line_plot(data: xr.DataArray, data_std: xr.DataArray, model_name: str, metric: dict,
                     filename: str, ylim=(0., 4.)):
    """
    Plot a metric against the hour of the day with a shaded band of +/- data_std and save the figure.

    :param data: metric values; must provide an 'hour' dimension/coordinate which is used as x-axis
    :param data_std: spread of the metric (same shape as data), drawn as shaded band around data
    :param model_name: label of the curve shown in the legend
    :param metric: dictionary whose first item is {<metric name>: <unit>}; used for the y-axis label
    :param filename: path of the image file the figure is saved to
    :param ylim: (lower, upper) limits of the y-axis; pass None to let matplotlib autoscale
                 (e.g. for metrics that can become negative). Defaults to the former fixed range.
    """
    fig, ax = plt.subplots(1, 1)

    hours, values, spread = data["hour"].values, data.values, data_std.values
    ax.plot(hours, values, 'k-', label=model_name)
    ax.fill_between(hours, values - spread, values + spread, facecolor="blue", alpha=0.2)
    if ylim is not None:
        ax.set_ylim(*ylim)

    # label axis
    ax.set_xlabel("daytime [UTC]", fontsize=16)
    metric_name, metric_unit = list(metric.items())[0]
    ax.set_ylabel(f"{metric_name} T2m [{metric_unit}]", fontsize=16)
    ax.tick_params(axis="both", which="both", direction="out", labelsize=14)
    # the curve label was set but never rendered without a legend
    ax.legend(fontsize=14)

    # save plot to file
    fig.savefig(filename)

In [None]:
# plots are written to a sub-directory of the current working directory named after the model
plot_dir = os.path.join(".", model_name)
os.makedirs(plot_dir, exist_ok=True)

# dimensions handed to the score engine (presumably the ones the scores are averaged over -- see Scores)
avg_dims = ["rlat", "rlon"]
# instantiate score engine with the denormalized prediction and the reference data
score_engine = Scores(y_pred, ground_truth, avg_dims)

Now, we start to create the RMSE-plots:

In [None]:
# calculate rmse
rmse_all = score_engine("rmse")

# diurnal cycle over the whole test period
rmse_hourly_all = rmse_all.groupby("time.hour")
rmse_hourly_mean, rmse_hourly_std = rmse_hourly_all.mean(), rmse_hourly_all.std()

# diurnal cycle per season: loop only over the hours that are present in the data
# (an empty selection cannot be grouped), collect the seasonal statistics and
# concatenate them once along an explicit 'hour' coordinate
hours_of_day = np.unique(rmse_all["time"].dt.hour.values)
sea_means, sea_stds = [], []
for hh in hours_of_day:
    tmp = rmse_all.isel({"time": rmse_all.time.dt.hour == hh}).groupby("time.season")
    sea_means.append(tmp.mean())
    sea_stds.append(tmp.std())
rmse_hourly_mean_sea = xr.concat(sea_means, dim="hour").assign_coords(hour=hours_of_day)
rmse_hourly_std_sea = xr.concat(sea_stds, dim="hour").assign_coords(hour=hours_of_day)

# create RMSE plots
create_line_plot(rmse_hourly_mean, rmse_hourly_std, "WGAN",
                 {"RMSE": "K"}, os.path.join(plot_dir, "downscaling_wgan_rmse.png"))

for sea in rmse_hourly_mean_sea["season"]:
    # the shaded band must use the seasonal standard deviation (the mean was passed twice before)
    create_line_plot(rmse_hourly_mean_sea.sel({"season": sea}),
                     rmse_hourly_std_sea.sel({"season": sea}),
                     "WGAN", {"RMSE": "K"},
                     os.path.join(plot_dir, f"downscaling_wgan_rmse_{sea.values}.png"))

Next, the Bias gets visualized:

In [None]:
# calculate bias
bias_all = score_engine("bias")

# diurnal cycle over the whole test period
bias_hourly_all = bias_all.groupby("time.hour")
bias_hourly_mean, bias_hourly_std = bias_hourly_all.mean(), bias_hourly_all.std()

# diurnal cycle per season: loop only over the hours that are present in the data
# (an empty selection cannot be grouped), collect the seasonal statistics and
# concatenate them once along an explicit 'hour' coordinate
hours_of_day = np.unique(bias_all["time"].dt.hour.values)
sea_means, sea_stds = [], []
for hh in hours_of_day:
    tmp = bias_all.isel({"time": bias_all.time.dt.hour == hh}).groupby("time.season")
    sea_means.append(tmp.mean())
    sea_stds.append(tmp.std())
bias_hourly_mean_sea = xr.concat(sea_means, dim="hour").assign_coords(hour=hours_of_day)
bias_hourly_std_sea = xr.concat(sea_stds, dim="hour").assign_coords(hour=hours_of_day)

# create bias plots
# NOTE(review): create_line_plot sets the y-axis to [0, 4] for these calls, so negative
# biases are cut off in the figures -- consider a symmetric range for this metric.
create_line_plot(bias_hourly_mean, bias_hourly_std, "WGAN",
                 {"Bias": "K"}, os.path.join(plot_dir, "downscaling_wgan_bias.png"))

for sea in bias_hourly_mean_sea["season"]:
    # the shaded band must use the seasonal standard deviation (the mean was passed twice before)
    create_line_plot(bias_hourly_mean_sea.sel({"season": sea}),
                     bias_hourly_std_sea.sel({"season": sea}),
                     "WGAN", {"Bias": "K"},
                     os.path.join(plot_dir, f"downscaling_wgan_bias_{sea.values}.png"))

Finally, the spatial variability gets evaluated:

In [None]:
# calculate the horizontal gradient amplitude score
grad_all = score_engine("grad_amplitude")

# diurnal cycle over the whole test period
grad_hourly_all = grad_all.groupby("time.hour")
grad_hourly_mean, grad_hourly_std = grad_hourly_all.mean(), grad_hourly_all.std()

# diurnal cycle per season: loop only over the hours that are present in the data
# (an empty selection cannot be grouped), collect the seasonal statistics and
# concatenate them once along an explicit 'hour' coordinate
hours_of_day = np.unique(grad_all["time"].dt.hour.values)
sea_means, sea_stds = [], []
for hh in hours_of_day:
    tmp = grad_all.isel({"time": grad_all.time.dt.hour == hh}).groupby("time.season")
    sea_means.append(tmp.mean())
    sea_stds.append(tmp.std())
grad_hourly_mean_sea = xr.concat(sea_means, dim="hour").assign_coords(hour=hours_of_day)
grad_hourly_std_sea = xr.concat(sea_stds, dim="hour").assign_coords(hour=hours_of_day)

# create gradient ratio plots
create_line_plot(grad_hourly_mean, grad_hourly_std, "WGAN",
                 {"Gradient ratio": "1"}, os.path.join(plot_dir, "downscaling_wgan_grad_rat.png"))

for sea in grad_hourly_mean_sea["season"]:
    create_line_plot(grad_hourly_mean_sea.sel({"season": sea}),
                     grad_hourly_std_sea.sel({"season": sea}),
                     "WGAN", {"Gradient ratio": "1"},
                     os.path.join(plot_dir, f"downscaling_wgan_grad_rat_{sea.values}.png"))

## Evaluation with map plots

In [None]:
########## TO-DO ##########
# Draft cell for map plots of a single time step; it is not runnable as it stands:
# NOTE(review): `tind` is not defined anywhere in the visible notebook.
# NOTE(review): `y_pred_trans` is the raw return value of `trained_model.predict`, so `.isel`
#               is not available on it; the denormalized DataArray `y_pred` is presumably meant.
# NOTE(review): the keys "t2m_tar" and "2t_in" differ from "t_2m_tar" used for the ground truth
#               and the normalization parameters -- confirm the variable names in the dataset.
# NOTE(review): the purpose of the "+ 273.15" offset added to the differences is unclear -- confirm units.

y_pred_eval = y_pred_trans#.sel(time=dt.time(12))

# plot the full 2m temperature (downscaled vs. target)
plt_fname_exp = "./plot_temp_pred_real"
create_plots(y_pred_eval.isel(time=tind), ds_test["t2m_tar"].isel(time=tind), plt_fname_exp,
             opt_plot={"title1": "downscaled T2m", "title2": "target T2m", "levels": np.arange(-3, 27., 1.)})

# plot the differences to the target: coarse input vs. downscaled product
plt_fname_diff = "./plot_temp_diff"
diff_in_tar = ds_test["2t_in"].isel(time=tind)-ds_test["t2m_tar"].isel(time=tind) + 273.15
diff_down_tar = y_pred_eval.isel(time=tind)-ds_test["t2m_tar"].isel(time=tind) + 273.15
create_plots(diff_in_tar, diff_down_tar, plt_fname_diff,
             opt_plot={"title1": "diff. input-target", "title2": "diff. downscaled-target",
                       "levels": np.arange(-3., 3.1, .2)})