In [141]:
import os, sys
sys.path.append("../models/")
sys.path.append("../utils/")
sys.path.append("../handle_data/")
import tensorflow as tf
import tensorflow.keras as keras
# all the layers used for U-net
from tensorflow.keras.layers import (Activation, BatchNormalization, Concatenate, Conv2D,
                                     Conv2DTranspose, Input, MaxPool2D, Dense, Flatten, GlobalAveragePooling2D
)
from tensorflow.keras.models import Model
from unet_model import build_unet, conv_block
from wgan_model import *
from handle_data_unet import *
from handle_data_class import  *
#from plotting import *
from other_utils import provide_default
from typing import List, Tuple, Union
# import climetlab as cml 
import datetime as dt
import numpy as np
import xarray as xr
import json as js
import gc

from collections import OrderedDict

In [None]:
!pip install climetlab

In [137]:
# Report the TensorFlow version of the running kernel
print(tf.version.VERSION)

2.6.0


# Load the test dataset

In [207]:
# Directory holding the preprocessed netCDF files; only the test split is opened here
datadir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/"
test_fname = "preproc_era5_crea6_test.nc"
ds_test = xr.open_dataset(os.path.join(datadir, test_fname))

# Load the trained generator model

In [225]:
# Output directory of the training run "2009_2016_train"; it also holds norm_dict.json (read below)
savedir = "/p/project/deepacf/maelstrom/gong1/downscaling_ap5/trained_models/2009_2016_train"
model_name = "2009_2016_train_generator/"
f_dir = os.path.join(savedir, model_name)
# compile=False: the model is only used for inference (predict) in this notebook
new_model = tf.keras.models.load_model(filepath=f_dir, compile=False)

In [227]:
# Evaluate on a small sample only: label-based slice, both ends inclusive
test_period = slice("2018-01-01", "2018-01-30")
ds_test = ds_test.sel({"time": test_period})

In [228]:
# Keep the un-normalized 2m temperature target before ds_test gets reshaped and normalized below
groud_truth = ds_test.t_2m_tar

In [229]:
# Get the statistics of the training data ("mu" and "std" per variable) from the json file saved with the model
js_norm = os.path.join(savedir, "norm_dict.json")
norm_dims = ["time", "lat", "lon"]

with open(js_norm) as norm_file:
    norm_dict = js.load(norm_file)

    
# preprocess data (i.e. normalizing)
def reshape_ds(ds):
    """Stack all data variables of `ds` into one DataArray.

    The new "variables" dimension is moved to the last position, after all other dimensions.
    """
    stacked = ds.to_array(dim="variables")
    return stacked.transpose(..., "variables")

# Normalization statistics of the training data, kept as dicts keyed by variable name
# (np.asarray() on these dicts only produced 0-d object arrays).
mu_train, std_train = norm_dict["mu"], norm_dict["std"]
print("mu_train", mu_train)
print("std_train", std_train)

# A variable with zero spread (e.g. a constant field) cannot be z-normalized without dividing
# by zero, unless z_norm_data handles it internally -- flag such variables explicitly.
zero_std_vars = [var for var, std in std_train.items() if std == 0.]
if zero_std_vars:
    print("WARNING: zero standard deviation in norm_dict for", zero_std_vars)

# Stack all variables into one DataArray (trailing "variables" dim) and normalize it.
# Guarded so that re-running this cell neither fails on nor re-normalizes the converted ds_test.
if isinstance(ds_test, xr.Dataset):
    train_vars = list(ds_test.keys())
    ds_test = reshape_ds(ds_test)
    ds_test = HandleUnetData.z_norm_data(ds_test, norm_method="norm", save_path=savedir)


mu_train {'rotated_pole': 1.0, '2t_in': 281.45906591701794, 'sshf_in': -50942.02725137752, 'slhf_in': -163213.65913866778, 'blh_in': 502.4760395386998, '10u_in': 0.6583249490407379, '10v_in': 0.29964431657095814, 'z_in': 5686.443347884734, 't850_in': 277.5147786749324, 't925_in': 281.41634941763175, 'hsurf_tar': 571.5723795572917, 't_2m_tar': 281.918891182312}
std_train {'rotated_pole': 0.0, '2t_in': 8.417833978084378, 'sshf_in': 192018.66261724447, 'slhf_in': 248998.20731436837, 'blh_in': 469.94672652856195, '10u_in': 2.2966472941410414, '10v_in': 1.9018157740395092, 'z_in': 4484.579126295764, 't850_in': 6.938319414004714, 't925_in': 7.52898602570127, 'hsurf_tar': 497.46767440871554, 't_2m_tar': 8.50941746128903}
Loading file: /p/project/deepacf/maelstrom/gong1/downscaling_ap5/trained_models/2009_2016_train/norm_dict.json
norm_dict mu {'rotated_pole': 1.0, '2t_in': 281.45906591701794, 'sshf_in': -50942.02725137752, 'slhf_in': -163213.65913866778, 'blh_in': 502.4760395386998, '10u_in':

In [230]:
# Split the normalized data into the generator inputs and the target(s)
da_test_in, da_test_tar = HandleDataClass.split_in_tar(ds_test)
# Last expression of the cell is displayed: quick shape check of both parts
da_test_in.shape, da_test_tar.shape

# Use the trained model to make predictions

In [231]:
# Run the generator on the normalized inputs (size-1 dimensions squeezed away)
test_inputs = da_test_in.squeeze().values
y_pred = new_model.predict(test_inputs, batch_size=32)

# Denormalise the prediction values 

In [232]:
# Wrap the (still normalized) prediction into a DataArray carrying the coordinates of the first target variable.
# NOTE(review): y_pred[0] is assumed to be the first output (T2m) of a multi-output generator;
# with a single-output generator it would instead be the first sample -- confirm.
tar_template = da_test_tar.isel(variables=0).squeeze()
coords, dims = tar_template.coords, tar_template.dims
y_pred_trans = xr.DataArray(y_pred[0].squeeze(), coords=coords, dims=dims)

In [233]:
# Inspect the prediction wrapped as a DataArray (values are still normalized at this point)
y_pred_trans

In [235]:
# Denormalize with the training statistics of the target "t_2m_tar".
# Computed from the raw network output y_pred (not from y_pred_trans itself), so that
# re-running this cell does not denormalize the prediction a second time.
t2m_mu, t2m_std = norm_dict["mu"]["t_2m_tar"], norm_dict["std"]["t_2m_tar"]
y_pred_trans = xr.DataArray(y_pred[0].squeeze() * t2m_std + t2m_mu, coords=coords, dims=dims)

In [236]:
# Show the summarized DataArray repr of the ground truth instead of dumping the full value array
groud_truth

In [237]:
# Inspect the prediction after denormalization with the "t_2m_tar" statistics
y_pred_trans

In [238]:
# Mean and standard deviation of the squared T2m error for every hour of the day,
# taken over all grid points and all time steps that fall into that hour.
# Time steps are selected with an explicit hour-of-day mask rather than
# `.sel(time=dt.time(hh))`, whose datetime.time lookup on a DatetimeIndex is an obscure,
# version-dependent pandas feature.
# NOTE(review): assumes (at most) hourly time steps -- sub-hourly steps would be pooled into their hour.
sq_err = (y_pred_trans - groud_truth) ** 2
hour_of_day = sq_err["time"].dt.hour.values

mse_mean, mse_std = np.zeros(24), np.zeros(24)
for hh in range(24):
    sq_err_hh = sq_err.isel(time=np.flatnonzero(hour_of_day == hh))
    mse_mean[hh], mse_std[hh] = sq_err_hh.mean().values, sq_err_hh.std().values

In [242]:
 # Mean squared T2m error for each hour of the day (0-23)
 mse_mean

In [243]:
# Standard deviation of the squared T2m errors for each hour of the day (0-23)
mse_std

# Visualize the results

In [244]:
from matplotlib import pyplot as plt

# Attach the hour of the day as coordinate to the per-hour statistics
daytime = np.arange(0, 24)
mse_mean = xr.DataArray(mse_mean, coords={"daytime": daytime}, dims=["daytime"])
mse_std = xr.DataArray(mse_std, coords={"daytime": daytime}, dims=["daytime"])

mse_mean_v, mse_std_v = mse_mean.values, mse_std.values
fig, ax = plt.subplots(1, 1)
# mean MSE per hour with a +/- 1 std band of the squared errors
ax.plot(daytime, mse_mean_v, "k-", label="ERA5 DeepHRES")
ax.fill_between(daytime, mse_mean_v - mse_std_v, mse_mean_v + mse_std_v, facecolor="blue", alpha=0.2)
# fixed y-range: values above 4 K^2 are cut off from view
ax.set_ylim(0., 4.)

# title and axis labels
ax.set_title("MSE of downscaled T2m per hour of day", fontsize=16)
ax.set_xlabel("daytime [UTC]", fontsize=16)
ax.set_ylabel("MSE T2m [K$^2$]", fontsize=16)
ax.tick_params(axis="both", which="both", direction="out", labelsize=14)

ax.legend(loc="upper right", numpoints=1)
# save plot to file
fig.savefig("downscaling_wgan_t2m_mse.png")

In [None]:
# choose a time index (position along the "time" dimension of the predictions)
tind = 380

y_pred_eval = y_pred_trans

# NOTE: `create_plots` is provided by the project's `plotting` module, whose import
# (`#from plotting import *`) is commented out in the import cell -- enable it before running.
# At this point ds_test is a normalized DataArray without a "t2m_tar" entry, so the raw fields
# are taken from `groud_truth` ("t_2m_tar") and from the raw test file ("2t_in"),
# matched by the timestamp of the selected prediction.
pred_sel = y_pred_eval.isel(time=tind)
time_sel = pred_sel["time"].values
t2m_tar_sel = groud_truth.sel(time=time_sel)
with xr.open_dataset(os.path.join(datadir, "preproc_era5_crea6_test.nc")) as ds_raw:
    t2m_in_sel = ds_raw["2t_in"].sel(time=time_sel).load()

# plot the full 2m temperature
plt_fname_exp = "./plot_temp_pred_real"
create_plots(pred_sel, t2m_tar_sel, plt_fname_exp,
             opt_plot={"title1": "downscaled T2m", "title2": "target T2m", "levels": np.arange(-3, 27., 1.)})

# plot differences w.r.t. the target; the +273.15 offset presumably compensates a K -> degC
# conversion inside create_plots -- TODO confirm
plt_fname_diff = "./plot_temp_diff"
diff_in_tar = t2m_in_sel - t2m_tar_sel + 273.15
diff_down_tar = pred_sel - t2m_tar_sel + 273.15
create_plots(diff_in_tar, diff_down_tar, plt_fname_diff,
             opt_plot={"title1": "diff. input-target", "title2": "diff. downscaled-target",
                       "levels": np.arange(-3., 3.1, .2)})