## Prepare the training data:

Scripts for:
- Reading the datasets (model and reference).
- Aligning the datasets in time and calculating the mismatch (if needed).
- Masking the NaN values (and applying other masks if needed).
- Building a square or rectangular canvas of fixed size for the UNet.
- Assigning a unique ID (`data_unique_name`) to the training data and saving it.

Output:
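- canvas_x, canvas_y, and canvas_m, which are used for training the network, saved together as `<data_unique_name>.npz` (the name starts with `train_data`). canvas_x is the input data, canvas_y is the target output, and canvas_m is the mask of valid (non-missing) target pixels, used as the weights for training.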

What needs to be defined?
- Data pairs e.g. REFERENCE-MODEL or TSMP-COSMO
- Remap_type, which remapping method (e.g., remapbil)
- task_name (model-only, model-lag, temporal, spatiotemporal, spatial)
- Boundary parameters, what are the pixup, pixright, xpix, ypix?
- Number of channels
- Mapping method, direct mapping or mismatch mapping?
- Box parameters for preparing the data for the UNet.

Areas of improvement (to be developed):

- Dealing with negative values in MODEL. 
- Dealing with time- and coordinate-alignment issues.

In [4]:
from py_env_hpc import *

# Define the following:
# With several models (ensembles), TSMP must come first so that the mismatch
# is calculated correctly.
model_data = ["HRES"]
reference_data = ["HSAF"]
task_name = "spatiotemporal"
mm = "MM"  # "MM" (mismatch mapping) or "DM" (direct mapping)
date_start = "2020-10-01"
date_end = "2021-09-30"
variable = "pr"
mask_type = "no_na"
laginensemble = False

# The following is defined automatically:
n_ensembles = len(model_data)
n_channels = Func.calculate_channels(n_ensembles, task_name, laginensemble=laginensemble)

# Reference-specific settings. Every supported reference must set all four of
# canvas_size, topo_dir, trim and daily, because they are used below and in
# the data-preparation step.
if reference_data == ["COSMO_REA6"]:
    canvas_size = (400, 400)
    topo_dir = '/p/project/deepacf/kiste/patakchiyousefi1/IO/03-TOPOGRAPHY/EU-11-TOPO.npz'
    trim = True
    daily = True
elif reference_data == ["HSAF"]:
    canvas_size = (128, 256)
    topo_dir = '/p/project/deepacf/kiste/patakchiyousefi1/IO/03-TOPOGRAPHY/HSAF-TOPO.npz'
    trim = False
    daily = False
else:
    # Fail with a clear message instead of a NameError on `daily` further down.
    raise ValueError(
        f"Unsupported reference_data {reference_data!r}; "
        "expected ['COSMO_REA6'] or ['HSAF']."
    )

# The unique name encodes every setting that changes the generated data, so it
# doubles as the cache key for the saved .npz file. Its format must stay
# unchanged, otherwise previously generated files are no longer recognised.
data_unique_name = f"train_data{'_daily' if daily else '_hourly'}.{variable}.{model_data}.{reference_data}.{mm}.{n_channels}.{'laginensemble' if laginensemble else ''}.{task_name}.{'.'.join(map(str, canvas_size))}.{date_start}.{date_end}.{mask_type}"
filename = f"{data_unique_name}.npz"
print(filename)

train_data_hourly.pr.['HRES'].['HSAF'].MM.6..spatiotemporal.128.256.2020-10-01.2021-09-30.no_na.npz


In [5]:
# Build the training arrays (canvas_x, canvas_y, canvas_m) and save them under
# `filename`, unless a file with the same unique name already exists.
if filename not in os.listdir(TRAIN_FILES):

    # Only the "no_na" mask is implemented below. Reject anything else before
    # the expensive I/O instead of failing at save time on an undefined
    # canvas_m.
    if mask_type != "no_na":
        raise ValueError(f"Unsupported mask_type {mask_type!r}; only 'no_na' is implemented.")

    # 1) Open the datasets and restrict them to the requested period:
    datasets = []
    for model in model_data:
        dataset = xarray.open_dataset(f"{ATMOS_DATA}/{model}_{variable}.nc")
        dataset = dataset[variable].sel(time=slice(date_start, date_end))
        datasets.append(dataset)

    REFERENCE = xarray.open_dataset(f"{ATMOS_DATA}/{reference_data[0]}_{variable}.nc")
    REFERENCE = REFERENCE[variable].sel(time=slice(date_start, date_end))

    # 2) Align time-wise and calculate the mismatch
    # NOTE(review): with several models, REFERENCE is narrowed on every
    # iteration while models handled earlier are not filtered again -- confirm
    # that all models share the same time axis.
    for i, model_da in enumerate(datasets):
        model_da["time"] = model_da["time"].astype(REFERENCE["time"].dtype)
        if reference_data == ["HSAF"]:
            # Keep only the time steps present in both the model and the reference.
            REFERENCE = REFERENCE.where(REFERENCE['time'].isin(model_da['time']), drop=True)
            model_da = model_da.where(model_da['time'].isin(REFERENCE['time']), drop=True)
        datasets[i], REFERENCE = xarray.align(model_da, REFERENCE, join="override")

    # Calculate calendar data according to REFERENCE (starting the calendar one day later)
    reference_tail = REFERENCE[1:, ...]
    dayofyear = reference_tail.time.dt.dayofyear.values
    dayofyear_resh = np.tile(dayofyear[:, np.newaxis, np.newaxis], (1, reference_tail.shape[1], reference_tail.shape[2]))
    yeardate = reference_tail.time.dt.year.values
    yeardate_resh = np.tile(yeardate[:, np.newaxis, np.newaxis], (1, reference_tail.shape[1], reference_tail.shape[2]))
    CAL = np.stack((dayofyear_resh, yeardate_resh), axis=3)

    REFERENCE = REFERENCE.values[:, :, :, np.newaxis]  # add new axis along ensemble dimension
    datasets = [dataset.values for dataset in datasets]
    MODEL = np.stack(datasets, axis=-1)  # ensemble members are stacked on the last axis

    # Abort instead of saving an empty file: an empty file would be picked up
    # as "already available" on the next run and never be regenerated.
    if MODEL.shape[0] < 1:
        raise ValueError("The selected dates doesn't exist in the netcdf files!")

    if mm == "MM":
        # Mismatch = first ensemble member minus the reference. The member axis
        # is the LAST one, so it is sliced with `..., :1`, which also keeps the
        # channel axis. With a single member this equals MODEL - REFERENCE.
        TARGET = MODEL[..., :1] - REFERENCE
    else:
        TARGET = REFERENCE

    # 3) prepare dim-wise:
    Y_TRAIN = TARGET[1:, ...]  # t
    X_TRAIN = MODEL[1:, ...]  # t
    canvas_y = Func.make_canvas(Y_TRAIN, canvas_size, trim)
    canvas_y = np.nan_to_num(canvas_y, nan=-999)  # fill values

    # Mask for na values (-999): 1 where the target is valid, 0 where it was filled.
    canvas_m = np.zeros_like(canvas_y)
    canvas_m[canvas_y != -999] = 1.0

    if task_name == "model-lag":
        # MODEL is already a 4-D NumPy array here, so the lagged input is simply
        # the same array shifted by one time step.
        X_TRAIN_tminus = MODEL[:-1, ...]  # t-1
        X_TRAIN = np.concatenate((X_TRAIN_tminus, X_TRAIN), axis=3)

    if task_name == "temporal":
        X_TRAIN = np.concatenate((X_TRAIN, CAL), axis=3)

    if task_name == "spatial":
        SPP = Func.spatiodataloader(topo_dir, X_TRAIN.shape)
        X_TRAIN = np.concatenate((X_TRAIN, SPP), axis=3)

    if task_name == "spatiotemporal":
        SPP = Func.spatiodataloader(topo_dir, X_TRAIN.shape)
        X_TRAIN = np.concatenate((X_TRAIN, CAL, SPP), axis=3)

    canvas_x = Func.make_canvas(X_TRAIN, canvas_size, trim)
    np.savez(TRAIN_FILES + "/" + filename, canvas_x=canvas_x, canvas_y=canvas_y, canvas_m=canvas_m)
    print("data generated")
else:
    print("the data with the same unique name is already available")

data generated
