# Data-augmented dataset for 2m temperature downscaling with U-net

So far, the dataset used for training the U-net in the first deliverable is fairly small (only ~300 MB of training data).
However, a larger dataset is more appropriate for benchmarking the [HPC-systems]().
For the sake of the related deliverable, the dataset is therefore augmented as follows:
- Add more daytimes (e.g. 10-16 UTC) instead of choosing only one daytime (e.g. 12 UTC)
- Perform simple data augmentation by flipping along the geographical axes (latitude and longitude)

In total, this increases the number of samples by a factor of 7x4=28 (seven daytimes times four variants, i.e. the original plus three flipped versions). The complete dataset should then comprise 25,620 samples.
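
The sample count can be cross-checked with a few lines of arithmetic. The sketch below assumes that the preprocessed data cover April to September of the years 2016 to 2020. This period is only inferred from the training/validation/test split used later in this notebook and is not stated explicitly, and all variable names are illustrative.

```python
import calendar

# Assumption (inferred from the later split, not stated in the notebook):
# the data cover April-September of the years 2016-2020.
years = range(2016, 2021)
months = range(4, 10)
n_days = sum(calendar.monthrange(yr, mo)[1] for yr in years for mo in months)

n_daytimes = len(range(10, 17))   # 10, 11, ..., 16 UTC
n_variants = 4                    # original, lat-flip, lon-flip, lat+lon-flip

n_samples = n_days * n_daytimes * n_variants
print(n_days, n_daytimes * n_variants, n_samples)   # 915 28 25620
```

Under this assumption, 915 days with 28 samples each reproduce the 25,620 samples quoted above.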

As a preparatory step, the original IFS HRES data (with lead times between 0 and 11 hours) must be preprocessed with the Python script `preprocess_downscaling_data.py`.
For this purpose, set up `preprocess_ifs_hres_data_template.sh` accordingly and run the resulting runscript.

In [None]:
import os
import glob
import numpy as np
import xarray as xr
import pandas as pd
import datetime as dt

In [None]:
# Set up data directory and load all merged netCDF-files of the complete dataset
datadir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ifs_hres/preprocessed/netcdf_data/workdir"
data_all = xr.open_mfdataset(os.path.join(datadir, "sfc_*_merged.nc"))

In [None]:
# Slice data to daytimes of interest (i.e. 10 to 16 UTC; the upper bound of range() is exclusive)
daytimes = list(range(10, 17))

data_all_sub = data_all.sel(time=data_all.time.dt.hour.isin(daytimes)).load()
times = data_all_sub["time"]

In [None]:
# Perform data augmentation

# generate data-array with flipped latitude axis
data_all_sub_invlat = data_all_sub.reindex(lat=data_all_sub.lat[::-1])
data_all_sub_invlat.coords["lat"] = data_all_sub["lat"]
# generate data-array with flipped longitude axis
data_all_sub_invlon = data_all_sub.reindex(lon=data_all_sub.lon[::-1])
data_all_sub_invlon.coords["lon"] = data_all_sub["lon"]
# generate data-array with flipped latitude and longitude axis
data_all_sub_invlatlon = data_all_sub_invlat.reindex(lon=data_all_sub_invlat.lon[::-1])
data_all_sub_invlatlon.coords["lon"] = data_all_sub["lon"]
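
Reindexing with a reversed coordinate and then restoring the original coordinate values amounts to flipping the data along that axis. The following NumPy sketch illustrates this on a toy array. The dimension order (time, lat, lon) is an assumption that matches the indexing used in the check cell, and the array values are arbitrary.

```python
import numpy as np

# Toy field with assumed dimension order (time, lat, lon); values are arbitrary
field = np.arange(2 * 3 * 4).reshape(2, 3, 4)

flip_lat = np.flip(field, axis=1)          # reverse the latitude axis
flip_lon = np.flip(field, axis=2)          # reverse the longitude axis
flip_latlon = np.flip(field, axis=(1, 2))  # reverse both axes

# Element [t, i, j] of the original ends up at [t, -1-i, j] after a lat-flip
assert flip_lat[0, -1, 1] == field[0, 0, 1]
# Flipping both axes equals a 180-degree rotation in the lat-lon plane
assert np.array_equal(flip_latlon, np.rot90(field, k=2, axes=(1, 2)))
```

Together with the unflipped field, this yields the four variants per time step that enter the factor of 28.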

In [None]:
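# check for correctness of data augmentation
# (index 5 along a flipped axis must correspond to index -6 in the original, assuming dimension order (time, lat, lon))
var = "t2m_in"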

assert data_all_sub_invlat[var][0,-6,5].values == data_all_sub[var][0,5,5].values, \
       "Latitude flipping did not work as expected. Check previous flipping method."
assert data_all_sub_invlon[var][0, 5, -6].values == data_all_sub[var][0, 5, 5].values, \
       "Longitude flipping did not work as expected. Check previous flipping method."
assert data_all_sub_invlatlon[var][0, -6, -6].values == data_all_sub[var][0, 5, 5].values, \
       "Latitude-Longitude flipping did not work as expected. Check previous flipping method."

print("Checks for data augmentation have been passed successfully.")

In [None]:
# shift time-coordinates of flipped datasets by 1-3 minutes so that all timestamps stay unique after concatenation
data_all_sub_invlat.coords["time"] = pd.to_datetime(times.values) + dt.timedelta(minutes=1)
data_all_sub_invlon.coords["time"] = pd.to_datetime(times.values) + dt.timedelta(minutes=2)
data_all_sub_invlatlon.coords["time"] = pd.to_datetime(times.values) + dt.timedelta(minutes=3)
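
The small offsets keep every timestamp unique along the concatenated time axis while leaving the year, month and hour untouched, so later selections by year or month still work. A minimal sketch with hypothetical base timestamps (standing in for the real `time` coordinate) illustrates this:

```python
import datetime as dt

# Hypothetical base timestamps standing in for the "time" coordinate
base_times = [dt.datetime(2020, 4, 1, hour) for hour in range(10, 17)]

# Offsets in minutes: 0 = original, 1 = lat-flip, 2 = lon-flip, 3 = lat+lon-flip
all_times = [t + dt.timedelta(minutes=m) for m in range(4) for t in base_times]

# All timestamps are unique and the hour of the day is preserved
assert len(set(all_times)) == len(all_times)
assert all(t.hour in range(10, 17) for t in all_times)
```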

In [None]:
# concatenate dataset along time-axis
ds_aug = xr.concat([data_all_sub, data_all_sub_invlat, data_all_sub_invlon, data_all_sub_invlatlon], dim="time")
# print to check dimensions
print(ds_aug)

Now that we have the complete, augmented dataset, let's split it up into training, validation and testing data.
The training data comprises all data for the years 2016 to 2019, i.e. four years. The other two subsets use data from 2020: May, July and August for the validation dataset, and April, June and September for the test dataset.

In [None]:
yr_train = list(range(2016, 2020))
yr_val = yr_test = [2020]
mo_val = [5, 7, 8]
mo_test = [4, 6, 9]

ds_train = ds_aug.sel(time=ds_aug.time.dt.year.isin(yr_train))
ds_val = ds_aug.sel(time=(ds_aug.time.dt.month.isin(mo_val) & ds_aug.time.dt.year.isin(yr_val)))
ds_test = ds_aug.sel(time=(ds_aug.time.dt.month.isin(mo_test) & ds_aug.time.dt.year.isin(yr_test)))

# just a check that every sample ends up in exactly one subset
assert len(ds_test["time"]) + len(ds_val["time"]) + len(ds_train["time"]) == len(ds_aug["time"]), \
   "Not all samples ({0:d}) from the augmented dataset have been used.".format(len(ds_aug["time"]))
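
The split rule can also be written as a small stand-alone function, for example to sanity-check individual timestamps. This is only a sketch: `assign_subset` is a hypothetical helper and not part of the preprocessing code.

```python
import datetime as dt

def assign_subset(timestamp):
    """Return 'train', 'val' or 'test' for a timestamp, or None if it fits no subset."""
    if 2016 <= timestamp.year <= 2019:
        return "train"
    if timestamp.year == 2020:
        if timestamp.month in (5, 7, 8):
            return "val"
        if timestamp.month in (4, 6, 9):
            return "test"
    return None

print(assign_subset(dt.datetime(2018, 6, 15, 12)))  # train
print(assign_subset(dt.datetime(2020, 7, 1, 10)))   # val
print(assign_subset(dt.datetime(2020, 9, 30, 16)))  # test
```

A return value of `None` marks a timestamp that falls into none of the three subsets, which is the situation the assertion above guards against.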

In [None]:
# finally, write augmented dataset to netCDF-file
datafile_train = os.path.join(datadir, "maelstrom-downscaling_train_aug.nc")
datafile_val = datafile_train.replace("_train_", "_val_")
datafile_test = datafile_train.replace("_train_", "_test_")

data_dict = {datafile_train: ds_train, datafile_val: ds_val, datafile_test: ds_test}

for dfile, ds in data_dict.items():
    if os.path.exists(dfile):
        print("File '{0}' already exists. Remove the file if you would like to create a new one.".format(dfile))
    else:
        print("Write augmented dataset to file '{0}'...".format(dfile))
        ds.to_netcdf(dfile)

In [None]:
print(ds_test)