# Benchmarking the application "Downscaling of 2m temperature from IFS HRES with a U-Net"

In [None]:
# Use the %pip magic (not !pip) so the packages are installed into the environment of the running kernel.
%pip install climetlab==0.8.14
%pip install climetlab-maelstrom-downscaling==0.1.0

In [None]:
import os
import sys
import time
import tensorflow.keras as keras
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.utils as ku
# make the project-local modules importable
sys.path += ["../handle_data/", "../models", "../postprocess/"]
from handle_data_unet import *
from unet_model import build_unet, get_lr_scheduler
# imported after the star import so that these names cannot be shadowed by it
import numpy as np
import tensorflow as tf
import xarray as xr
import datetime as dt

In [None]:
# Root directory of the downscaling dataset. The default is the original HPC location; set the
# environment variable DOWNSCALING_DATADIR to run the notebook elsewhere.
datadir = os.environ.get("DOWNSCALING_DATADIR", "/p/project/deepacf/maelstrom/data/downscaling_unet/")

# load the training data, then append the validation and test data to the same handler object
data_obj = HandleUnetData(datadir, "train")
for split in ("val", "test"):
    data_obj.append_data(split)

In [None]:
# set daytime for which downscaling model is trained (i.e. either 0 or 12)
hour = 12
assert hour in (0, 12), "hour must be either 0 or 12 (UTC)"

# preprocess data for training; the normalization parameters (opt_norm) returned for the training
# data are reused for the validation data. Both splits must use the same daytime.
int_data, tart_data, opt_norm = data_obj.normalize("train", daytime=hour)
inv_data, tarv_data = data_obj.normalize("val", daytime=hour, opt_norm=opt_norm)

print(data_obj.timing)
print(data_obj.data_info["memory_datasets"])
print(data_obj.data_info["nsamples"])

In [None]:
# List the devices visible to TensorFlow (e.g. to check that GPUs are available for training).
# tf.config is the public API; tensorflow.python.client.device_lib is an internal module.
tf.config.list_physical_devices()

In [None]:
# Input shape of the U-Net; presumably (lat, lon, channels) = 96 x 128 grid points with
# 3 input fields — TODO confirm against build_unet.
shape_in = (96, 128, 3)

# Plot the architecture only on login nodes (presumably because the plotting dependencies are
# only available there — TODO confirm). `ku` is imported in the import cell.
if "login" in data_obj.host:
    unet_model = build_unet(shape_in, z_branch=True)
    ku.plot_model(unet_model, to_file=os.path.join(os.getcwd(), "unet_downscaling_model.png"), show_shapes=True)

In [None]:
# define class for creating timer callback
class TimeHistory(keras.callbacks.Callback):
    """Keras callback that records the duration of every training epoch.

    After ``fit`` has finished, ``epoch_times`` holds one float per completed epoch
    (elapsed seconds measured with the monotonic ``time.perf_counter`` clock).
    """

    def on_train_begin(self, logs=None):
        # reset the record at the start of every fit() call
        self.epoch_times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_times.append(time.perf_counter() - self.epoch_time_start)
        
# Flag: if True, the model is additionally trained on the surface elevation (second output "output_z").
z_branch = True

# Callbacks passed to model.fit: the scheduler from get_lr_scheduler (presumably a learning-rate
# schedule) and the per-epoch timer defined above.
lr_scheduler = get_lr_scheduler()
time_tracker = TimeHistory()
callback_list = [lr_scheduler, time_tracker]

In [None]:
# build, compile and train the model
nepochs = 70
unet_model = build_unet(shape_in, z_branch=z_branch)

# Variable 0 of the target data is the 2m temperature. With z_branch, variable 1 (surface elevation)
# is trained as a second, equally weighted output.
if z_branch:
    train_targets = {"output_temp": tart_data.isel(variable=0).values,
                     "output_z": tart_data.isel(variable=1).values}
    val_targets = {"output_temp": tarv_data.isel(variable=0).values,
                   "output_z": tarv_data.isel(variable=1).values}
    compile_kwargs = {"loss": {"output_temp": "mae", "output_z": "mae"},
                      "loss_weights": {"output_temp": 1.0, "output_z": 1.0}}
else:
    train_targets = tart_data.isel(variable=0).values
    val_targets = tarv_data.isel(variable=0).values
    compile_kwargs = {"loss": "mae"}

unet_model.compile(optimizer=Adam(learning_rate=5*10**(-4)), **compile_kwargs)

history = unet_model.fit(x=int_data.values, y=train_targets, batch_size=32, epochs=nepochs,
                         callbacks=callback_list, validation_data=(inv_data.values, val_targets))

In [None]:
epoch_times = time_tracker.epoch_times

# With z_branch the model has named outputs and the history holds per-output losses
# ("output_temp_loss"); the single-output model only reports "loss" / "val_loss".
temp_loss_key = "output_temp_loss" if z_branch else "loss"
print("Final training loss (2m temperature): {0:.4f}".format(history.history[temp_loss_key][-1]))
print("Final validation loss (2m temperature): {0:.4f}".format(history.history["val_" + temp_loss_key][-1]))

print("Total training time: {0:.2f}s".format(np.sum(epoch_times)))
print("Max. time per epoch: {0:.4f}s, min. time per epoch: {1:.4f}s".format(np.amax(epoch_times), np.amin(epoch_times)))

In [None]:
# Preprocess the test data with the same call (and the training normalization parameters opt_norm)
# that is used for the validation data.
inte_data, tarte_data = data_obj.normalize("test", daytime=hour, opt_norm=opt_norm)

# generate the downscaled fields (still in normalized units)
y_pred_test = unet_model.predict(inte_data.values, verbose=1)
y_pred_val = unet_model.predict(inv_data.values, verbose=1)

In [None]:
comparison_type = "validation"            # change here to switch between validation and testing data
if comparison_type == "validation":
    y_pred = y_pred_val
    ds_ref = data_obj.data["val"].sel(time=dt.time(hour))
    var_ref = tarv_data.isel(variable=0)
elif comparison_type == "test":
    y_pred = y_pred_test
    ds_ref = data_obj.data["test"].sel(time=dt.time(hour))
    var_ref = tarte_data.isel(variable=0)
else:
    raise ValueError("Unknown comparison_type '{0}' chosen.".format(comparison_type))

# With z_branch the model has two outputs, so predict() returns a list of arrays and np.ndim counts
# the extra list axis. Keep only the first output (the 2m temperature, as assumed for output_temp).
if np.ndim(y_pred) == 5:
    y_pred = y_pred[0]

# mean absolute error of the (still normalized) 2m temperature, averaged over space and time
mae_norm = np.abs(np.squeeze(y_pred) - var_ref).mean(dim=["lat", "lon"]).mean()
print("MAE of normalized 2m temperature: {0:.4f}".format(mae_norm.values))

In [None]:
# Take coordinates and dimension names from the (squeezed) reference field so the prediction can be
# wrapped as a labelled DataArray on the same grid.
var_ref_squeezed = var_ref.squeeze()
coords, dims = var_ref_squeezed.coords, var_ref_squeezed.dims

# Undo the target normalization: x * std + mu, using the parameters stored in opt_norm.
std_tar = opt_norm["std_tar"].squeeze().values
mu_tar = opt_norm["mu_tar"].squeeze().values
y_pred_trans = xr.DataArray(np.squeeze(y_pred) * std_tar + mu_tar,
                            coords=coords, dims=dims, name="t2m_downscaled")

In [None]:
# squared error of the denormalized prediction against the reference "t2m_tar" field,
# averaged over the spatial dimensions (one value per remaining sample)
squared_err = (y_pred_trans - ds_ref["t2m_tar"]) ** 2
mse = squared_err.mean(dim=["lat", "lon"])

mse_mean, mse_std = mse.mean().values, mse.std().values
print("MSE of downscaled 2m temperature: {0:.3f} K**2 (+/-{1:.3f} K**2)".format(mse_mean, mse_std))

As we can see, the model has learned to recover many details that result mainly from the topography. Especially over the Alps, but also over the German low mountain ranges, the differences have become smaller and less structured. The differences near the coast (e.g. at the Baltic Sea) have also become smaller. <br>
However, some systematic features are still visible: the differences can still be as large as 3 K, and especially in the Alps the differences are somewhat 'blurry'. Thus, there is still room for further improvement. 
These improvements will not only pertain to the model architecture, but will also aim to incorporate more meteorological variables. The latter will also enable the network to generalize with respect to daytime and season. Note that this has not been done yet, since we trained the U-Net with data between April and September at 12 UTC only.
 