In [1]:
import os, sys
sys.path.append("../models/")
sys.path.append("../utils/")

import tensorflow as tf
import tensorflow.keras as keras

# all the layers used for U-net
from tensorflow.keras.layers import (Activation, BatchNormalization, Concatenate, Conv2D,
                                     Conv2DTranspose, Input, MaxPool2D, Dense, Flatten, GlobalAveragePooling2D
)
from tensorflow.keras.models import Model

from unet_model import build_unet, conv_block
from wgan_model import *
from plotting import *
from other_utils import provide_default

from typing import List, Tuple, Union

import climetlab as cml 
import datetime as dt
import numpy as np
import xarray as xr
import json as js
import gc

from collections import OrderedDict

In [2]:
# Log the TensorFlow version this notebook was executed with, for reproducibility.
print("{}".format(tf.__version__))

2.6.0


In [None]:
# Directory holding the preprocessed ERA5 -> IFS netCDF files.
# The default is the original scratch location on the cluster; set the environment
# variable AP5_DATADIR to point somewhere else without editing the notebook.
datadir = os.environ.get(
    "AP5_DATADIR",
    "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_ifs/netcdf_data/all_files/",
)

ds_train = xr.open_dataset(os.path.join(datadir, "era5_to_ifs_train_corrected.nc"))
ds_test = xr.open_dataset(os.path.join(datadir, "era5_to_ifs_test_corrected.nc"))

print(ds_test)

In [3]:
modelname = "wgan_lrdecay_test"
savedir = os.path.join("../trained_models/", modelname)

# The generator and the critic are expected as separate sub-directories
# "<modelname>_generator" and "<modelname>_critic" of savedir; abort early if either is missing.
required_dirs = [os.path.join(savedir, f"{modelname}_{part}") for part in ("generator", "critic")]
if not all(os.path.isdir(d) for d in required_dirs):
    raise ValueError(f"Cannot find generator and critic model '{modelname}' under '{savedir}' to postprocess.")

In [4]:
# Load the stand-alone generator saved under savedir.
# NOTE(review): wgan_generator is not referenced by the later cells shown here
# (predictions come from wgan_model) — confirm whether this cell is still needed.
generator_path = os.path.join(savedir, f"{modelname}_generator")
wgan_generator = keras.models.load_model(generator_path)

2022-06-30 21:54:12.960833: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2022-06-30 21:54:12.960890: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (jwlogin02.juwels): /proc/driver/nvidia/version does not exist
2022-06-30 21:54:12.962424: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX512F
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.




In [None]:
# Rebuild the WGAN (U-net generator + critic) and restore its weights from the latest checkpoint.
# NOTE(review): these hyperparameters are assumed to match the ones used to train
# this model — confirm against the training configuration.
wgan_model = WGAN(build_unet, critic_model, {"lr_decay": True, "lr": 5.e-06, "train_epochs": 10,
                                             "recon_weight": 1000., "d_steps": 6, "optimizer": "adam",
                                             "z_branch": True, "gp_weight": 10.})

# tf.train.latest_checkpoint returns None when savedir contains no checkpoint;
# fail with a clear message instead of an opaque error from load_weights(None).
latest = tf.train.latest_checkpoint(savedir)
if latest is None:
    raise FileNotFoundError(f"No TensorFlow checkpoint found under '{savedir}' for model '{modelname}'.")
wgan_model.load_weights(latest)

In [5]:
# Normalisation statistics stored alongside the trained model; the test data is
# standardised with these training-set values rather than with its own statistics.
js_norm = os.path.join(savedir, "norm_dict.json")
norm_dims = ["time", "lat", "lon"]

with open(js_norm, "r") as fnorm:
    norm_dict = js.load(fnorm)

# One mean/std entry per data variable, labelled in the order of the training file's variables.
train_vars = list(ds_train.keys())
mu_train = xr.DataArray(np.asarray(norm_dict["mu_train"]), coords={"variables": train_vars}, dims=["variables"])
std_train = xr.DataArray(np.asarray(norm_dict["std_train"]), coords={"variables": train_vars}, dims=["variables"])

# Stack variables into one array, standardise, then split into model inputs and targets.
da_test = z_norm_data(reshape_ds(ds_test), mu=mu_train, std=std_train)
da_test_in, da_test_tar = WGAN.split_in_tar(da_test)

test_iter = tf.data.Dataset.from_tensor_slices((da_test_in, da_test_tar)).batch(32)

FileNotFoundError: [Errno 2] No such file or directory: '../trained_models/wgan_lrdecay_test/norm_dict.json'

In [None]:
# Compile the restored model, then run inference on the normalised test set.
# NOTE(review): this WGAN.compile takes data arrays and returns two values; its exact
# semantics live in wgan_model — confirm.
_, _ = wgan_model.compile(da_test.astype(np.float32), da_test.astype(np.float32))
# test_iter is already a batched tf.data.Dataset, so batch_size must not be passed here
# (Keras documents it as unsupported/ignored for dataset inputs).
y_pred = wgan_model.predict(test_iter)

In [None]:
# convert predictions to xarray and denormalize
coords = da_test.isel(variables=0).squeeze().coords
dims = da_test.isel(variables=0).squeeze().dims

y_pred_trans = xr.DataArray(y_pred[0].squeeze(), coords=coords, dims=dims)

y_pred_trans = y_pred_trans.squeeze()*std_train[0].squeeze() + mu_train[0].squeeze()
y_pred_trans = xr.DataArray(y_pred_trans, coords=coords, dims=dims)

In [None]:
# Diurnal cycle of the error: for each hour of the day, the mean and standard deviation
# (over all test samples at that hour) of the spatially averaged squared T2m error.
sq_err_spatial = ((y_pred_trans - ds_test["t2m_tar"])**2).mean(dim=["lat", "lon"])

# Group once by hour instead of selecting each hour separately. Reindex to all 24 hours
# so the result always has 24 entries (NaN where an hour has no samples).
# NOTE(review): grouping by the hour component equals selecting time=dt.time(hh) only
# when timestamps fall exactly on the hour (true for hourly data).
hours_of_day = np.arange(24)
sq_err_by_hour = sq_err_spatial.groupby("time.hour")
mse_mean = sq_err_by_hour.mean().reindex(hour=hours_of_day).values
mse_std = sq_err_by_hour.std().reindex(hour=hours_of_day).values

In [None]:
# Spatially averaged MSE per test time step; report the index of the lowest-error one.
sq_err = (y_pred_trans - ds_test["t2m_tar"])**2
mse = sq_err.mean(dim=["lat", "lon"])
best_tind = mse.argmin()
print(best_tind)

In [None]:
from matplotlib import pyplot as plt

# Label the hourly statistics by daytime (0..23 UTC).
hours = np.arange(0, 24)
mse_mean = xr.DataArray(mse_mean, coords={"daytime": hours}, dims=["daytime"])
mse_std = xr.DataArray(mse_std, coords={"daytime": hours}, dims=["daytime"])

mse_mean_v = mse_mean.values
mse_std_v = mse_std.values
daytime = mse_mean["daytime"].values

# Mean MSE curve with a +/- one standard deviation band.
fig, ax = plt.subplots(1, 1)
ax.plot(daytime, mse_mean_v, 'k-', label="ERA5 DeepHRES")
ax.fill_between(daytime, mse_mean_v - mse_std_v, mse_mean_v + mse_std_v, facecolor="blue", alpha=0.2)
ax.set_ylim(0., 4.)

ax.set_xlabel("daytime [UTC]", fontsize=16)
ax.set_ylabel("MSE T2m [K$^2$]", fontsize=16)
ax.tick_params(axis="both", which="both", direction="out", labelsize=14)

# Reference from a previous non-augmented run ("Unet small"): MSE 0.394 +/- 0.094 K^2
# at 12 UTC. It was formerly drawn here via ax.errorbar and is currently disabled.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, loc='upper right', numpoints=1)

fig.savefig("downscaling_wgan_t2m_mse.png")

In [None]:
# choose a time index
tind = 380

y_pred_eval = y_pred_trans#.sel(time=dt.time(12))

# plot the full 2m temperature
plt_fname_exp = "./plot_temp_pred_real"
create_plots(y_pred_eval.isel(time=tind), ds_test["t2m_tar"].isel(time=tind), plt_fname_exp,
             opt_plot={"title1": "downscaled T2m", "title2": "target T2m", "levels": np.arange(-3, 27., 1.)})

plt_fname_diff = "./plot_temp_diff"
diff_in_tar = ds_test["2t_in"].isel(time=tind)-ds_test["t2m_tar"].isel(time=tind) + 273.15
diff_down_tar = y_pred_eval.isel(time=tind)-ds_test["t2m_tar"].isel(time=tind) + 273.15
create_plots(diff_in_tar, diff_down_tar, plt_fname_diff,
             opt_plot={"title1": "diff. input-target", "title2": "diff. downscaled-target",
                       "levels": np.arange(-3., 3.1, .2)})