In [106]:
import os, sys
sys.path.append("../models/")
sys.path.append("../utils/")
sys.path.append("../handle_data/")

import tensorflow as tf
import tensorflow.keras as keras

# all the layers used for U-net
from tensorflow.keras.layers import (Activation, BatchNormalization, Concatenate, Conv2D,
                                     Conv2DTranspose, Input, MaxPool2D, Dense, Flatten, GlobalAveragePooling2D
)
from tensorflow.keras.models import Model
from unet_model import build_unet, conv_block
from wgan_model import *
from handle_data_unet import *
from handle_data_class import  *
#from plotting import *
from other_utils import provide_default
from typing import List, Tuple, Union
import climetlab as cml 
import datetime as dt
import numpy as np
import xarray as xr
import json as js
import gc

from collections import OrderedDict

In [78]:
!pip install climetlab

Defaulting to user installation because normal site-packages is not writeable


In [79]:
print(tf.__version__)

2.6.0


In [107]:
datadir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/"

ds_train = xr.open_dataset(os.path.join(datadir, "preproc_era5_crea6_train.nc"))
ds_test = xr.open_dataset(os.path.join(datadir, "preproc_era5_crea6_test.nc"))
                      
print(ds_test)

<xarray.Dataset>
Dimensions:       (time: 8748, rlon: 120, rlat: 96)
Coordinates:
  * time          (time) datetime64[ns] 2018-01-01T01:00:00 ... 2018-12-31T23...
  * rlon          (rlon) float64 -8.273 -8.218 -8.163 ... -1.838 -1.783 -1.728
  * rlat          (rlat) float64 -3.933 -3.878 -3.823 ... 1.182 1.237 1.292
Data variables:
    rotated_pole  int32 ...
    2t_in         (time, rlat, rlon) float32 ...
    sshf_in       (time, rlat, rlon) float32 ...
    slhf_in       (time, rlat, rlon) float32 ...
    blh_in        (time, rlat, rlon) float32 ...
    10u_in        (time, rlat, rlon) float32 ...
    10v_in        (time, rlat, rlon) float32 ...
    z_in          (time, rlat, rlon) float32 ...
    t850_in       (time, rlat, rlon) float32 ...
    t925_in       (time, rlat, rlon) float32 ...
    hsurf_tar     (time, rlat, rlon) float32 ...
    t_2m_tar      (time, rlat, rlon) float32 ...
Attributes:
    CDI:          Climate Data Interface version 2.0.2 (https://mpimet.mpg.de...
    Co

In [108]:
np.max(ds_test["2t_in"])

In [109]:
modelname = "2014_2016_train"

savedir = os.path.join("/p/project/deepacf/maelstrom/gong1/downscaling_ap5/trained_models/", modelname)
if not os.path.isdir(os.path.join(savedir, f"{modelname}_generator")) or not os.path.isdir(os.path.join(savedir, f"{modelname}_critic")):
    raise ValueError("Cannot find generator and critic model '{0}' under '{1}' to postprocess.".format(modelname, savedir))

In [110]:
wgan_generator = keras.models.load_model(os.path.join(savedir, f"{modelname}_generator"))



In [111]:
wgan_model = WGAN(build_unet, critic_model, {"lr_decay": True, "lr": 5.e-06, "train_epochs": 10,
                                             "recon_weight": 1000., "d_steps": 6, "optimizer": "adam",
                                             "z_branch": True, "gp_weight": 10.})

# Load the previously saved weights
latest = tf.train.latest_checkpoint(savedir)
wgan_model.load_weights(latest)

The following parsed hyperparameters are unknown and thus are ignored: lr


AttributeError: 'NoneType' object has no attribute 'endswith'

In [None]:
latest 

In [None]:
#wgan_model

In [None]:
ds_test

In [112]:
stat_dir = "/p/project/deepacf/maelstrom/gong1/downscaling_ap5/trained_models/2014_2016_train/"

In [113]:
f_dir = "/p/project/deepacf/maelstrom/gong1/downscaling_ap5/trained_models/2014_2016_train/2014_2016_train_generator/"
new_model = tf.keras.models.load_model(f_dir,compile=False)

In [114]:
ds_test

In [115]:
ds_test = ds_test.sel(time=slice("2018-01-01", "2018-01-30"))

In [116]:
ds_test["2t_in"]

In [117]:
js_norm = os.path.join(savedir, "norm_dict.json")
norm_dims = ["time", "lat", "lon"]

with open(js_norm, "r") as f:
    norm_dict = js.load(f)
    
train_vars = list(ds_train.keys())
mu_train, std_train = np.asarray(norm_dict["mu"]), np.asarray(norm_dict["std"])
#mu_train = xr.DataArray(mu_train, coords={"variables": train_vars}, dims=["variables"])
#std_train = xr.DataArray(std_train, coords={"variables": train_vars}, dims=["variables"])
print("mu_train",mu_train)
print("std_train", std_train)
#da_test = reshape_ds(ds_test)
ds_test = HandleUnetData.z_norm_data(ds_test, norm_method="norm", save_path = stat_dir)
ds_test = ds_test.drop("rotated_pole")

mu_train {'rotated_pole': 1.0, '2t_in': 282.10606790804894, 'sshf_in': -51064.59748829226, 'slhf_in': -163613.4169778147, 'blh_in': 498.8113857893632, '10u_in': 0.6717324768282911, '10v_in': 0.3682871183282463, 'z_in': 5686.443347885874, 't850_in': 278.13217365043425, 't925_in': 282.0733933286382, 'hsurf_tar': 571.5723795572917, 't_2m_tar': 282.59556879231184}
std_train {'rotated_pole': 0.0, '2t_in': 7.8750459455031745, 'sshf_in': 193919.07785517865, 'slhf_in': 247231.89637676522, 'blh_in': 476.12699951069465, '10u_in': 2.279233517951644, '10v_in': 1.867399198139431, 'z_in': 4484.579126295651, 't850_in': 6.404352097562571, 't925_in': 6.934853246069489, 'hsurf_tar': 497.4676744087111, 't_2m_tar': 7.97646439233192}
Loading file: /p/project/deepacf/maelstrom/gong1/downscaling_ap5/trained_models/2014_2016_train/norm_dict.json
norm_dict mu {'rotated_pole': 1.0, '2t_in': 282.10606790804894, 'sshf_in': -51064.59748829226, 'slhf_in': -163613.4169778147, 'blh_in': 498.8113857893632, '10u_in': 0

In [120]:
ds_test

In [122]:
#da_test_in  = da_test_in.drop("rotated_pole")
da_test_in, da_test_tar = HandleDataClass.split_in_tar(ds_test)
# test_iter = tf.data.Dataset.from_tensor_slices((da_test_in.to_array(dim = "variables").squeeze(), da_test_tar.to_array(dim = "variables").squeeze()))
# test_iter = test_iter.batch(32)


In [139]:
da_test_tar["hsurf_tar"].shape

(719, 96, 120, 2)

In [149]:
da_test_tar.values


<bound method Mapping.values of <xarray.Dataset>
Dimensions:    (time: 719, rlon: 120, rlat: 96, variables: 2)
Coordinates:
  * time       (time) datetime64[ns] 2018-01-01T01:00:00 ... 2018-01-30T23:00:00
  * rlon       (rlon) float64 -8.273 -8.218 -8.163 ... -1.838 -1.783 -1.728
  * rlat       (rlat) float64 -3.933 -3.878 -3.823 -3.768 ... 1.182 1.237 1.292
  * variables  (variables) <U12 't_2m_tar' 'hsurf_tar'
Data variables:
    2t_in      (time, rlat, rlon, variables) float64 -0.6851 -0.5919 ... -0.5972
    sshf_in    (time, rlat, rlon, variables) float64 1.321e+04 211.3 ... 204.0
    slhf_in    (time, rlat, rlon, variables) float64 -4.037e+03 ... -21.14
    blh_in     (time, rlat, rlon, variables) float64 53.32 0.2741 ... -0.6506
    10u_in     (time, rlat, rlon, variables) float64 -35.11 -1.144 ... -1.144
    10v_in     (time, rlat, rlon, variables) float64 -35.19 -1.145 ... -1.146
    z_in       (time, rlat, rlon, variables) float64 975.4 15.06 ... 0.4475
    t850_in    (time, r

In [148]:
da_test_in.isel(variables=0).squeeze().to_array().shape

(11, 719, 96, 120)

In [75]:
_, _ = wgan_model.compile(ds_test.astype(np.float32), ds_test.astype(np.float32))
y_pred = wgan_model.predict(test_iter, batch_size=32)

AttributeError: 'Dataset' object has no attribute 'shape'

In [None]:
# convert predictions to xarray and denormalize
coords = da_test.isel(variables=0).squeeze().coords
dims = da_test.isel(variables=0).squeeze().dims

y_pred_trans = xr.DataArray(y_pred[0].squeeze(), coords=coords, dims=dims)

y_pred_trans = y_pred_trans.squeeze()*std_train[0].squeeze() + mu_train[0].squeeze()
y_pred_trans = xr.DataArray(y_pred_trans, coords=coords, dims=dims)

In [None]:
mse_mean, mse_std = np.zeros(24), np.zeros(24)

for i, hh in enumerate(np.arange(0, 24)):
    mse_all = ((y_pred_trans.sel(time=dt.time(hh)) - ds_test["t2m_tar"].sel(time=dt.time(hh)))**2).mean(dim=["lat", "lon"])
    mse_mean[i], mse_std[i] = mse_all.mean().values, mse_all.std().values

In [None]:
mse = ((y_pred_trans - ds_test["t2m_tar"])**2).mean(dim=["lat", "lon"])

print(mse.argmin())

In [None]:
from matplotlib import pyplot as plt

mse_mean = xr.DataArray(mse_mean, coords = {"daytime": np.arange(0,24)}, dims=["daytime"]) 
mse_std = xr.DataArray(mse_std, coords = {"daytime": np.arange(0,24)}, dims=["daytime"])

mse_mean_v, mse_std_v = mse_mean.values, mse_std.values
fig, (ax) = plt.subplots(1,1)
ax.plot(mse_mean["daytime"].values, mse_mean_v, 'k-', label="ERA5 DeepHRES")
ax.fill_between(mse_mean["daytime"].values, mse_mean_v-mse_std_v, mse_mean_v+mse_std_v, facecolor="blue", alpha=0.2)
ax.set_ylim(0.,4.)
# label axis
ax.set_xlabel("daytime [UTC]", fontsize=16)
ax.set_ylabel("MSE T2m [K$^2$]", fontsize=16)
ax.tick_params(axis="both", which="both", direction="out", labelsize=14)

## add MSE from previous non-augmented dataset
#ax.errorbar(12, 0.394, yerr=0.094, fmt='x', capsize=5., ecolor="black", mfc="red",
#            mec="red", ms=10, mew=2., label = "Unet small")
# Configure legend
# get handles
handles, labels = ax.get_legend_handles_labels()
# remove the errorbars
#handles[1] = handles[1][0]

ax.legend(handles, labels, loc='upper right', numpoints=1)
# save plot to file
fig.savefig("downscaling_wgan_t2m_mse.png")

In [None]:
# choose a time index
tind = 380

y_pred_eval = y_pred_trans#.sel(time=dt.time(12))

# plot the full 2m temperature
plt_fname_exp = "./plot_temp_pred_real"
create_plots(y_pred_eval.isel(time=tind), ds_test["t2m_tar"].isel(time=tind), plt_fname_exp,
             opt_plot={"title1": "downscaled T2m", "title2": "target T2m", "levels": np.arange(-3, 27., 1.)})

plt_fname_diff = "./plot_temp_diff"
diff_in_tar = ds_test["2t_in"].isel(time=tind)-ds_test["t2m_tar"].isel(time=tind) + 273.15
diff_down_tar = y_pred_eval.isel(time=tind)-ds_test["t2m_tar"].isel(time=tind) + 273.15
create_plots(diff_in_tar, diff_down_tar, plt_fname_diff,
             opt_plot={"title1": "diff. input-target", "title2": "diff. downscaled-target",
                       "levels": np.arange(-3., 3.1, .2)})