In [1]:
# Turn off Jedi-based tab completion in IPython (presumably to work around slow/unreliable completions).
%config Completer.use_jedi = False

In [2]:
import os
import glob
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf

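# Predict a single time slice

- Given inputs at $t_{-n}, \dots, t_{-1}$, predict the outputs at $t_{-1}$
- Stack the two input features (temperature and precipitation) along an extra trailing dimension
- For each target month, use a fixed-length window of daily inputs (`day_len` days) preceding it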

In [78]:
data_models = ['GFDL-ESM4','IPSL-CM6A-LR','MPI-ESM1-2-HR']  # models for temp, prec, LAI
day_len = 300  # length of the daily input window used by gen_data_card()
batch_size = 32
SEED = None  # set to an int for a reproducible model/file choice and batch sequence
if SEED is not None:
    np.random.seed(SEED)

# choose a model
dmodel = np.random.choice(np.array(data_models))
# load all data up front (as NumPy arrays) to save computation time in the generator
temp_ds = np.array(xr.open_mfdataset('data/near_surface_air_temperature/historical/{}/*.nc'.format(dmodel)).tas)
prec_ds = np.array(xr.open_mfdataset('data/precipitation_flux/historical/{}/*.nc'.format(dmodel)).pr)
# NOTE(review): NPP is drawn from *any* model's file, independently of `dmodel`;
# confirm this mixing is intended. sorted() makes the draw reproducible under
# SEED, because glob's result order depends on the file system.
npp_files = sorted(glob.glob('data/net_primary_production_on_land/historical/**/*.nc', recursive=True))
if not npp_files:
    raise FileNotFoundError('no NPP files found under data/net_primary_production_on_land/historical/')
npp_ds = np.array(xr.open_mfdataset(np.random.choice(np.array(npp_files))).npp)
lai_ds = xr.open_mfdataset('data/leaf_area_index/historical/{}/*.nc'.format(dmodel))

# range of target-month indices sampled by gen_data_card (max_month is exclusive)
max_month = 1978  # NOTE(review): magic number, presumably just under the 1980 months of 1850-2014; confirm
# ceil(day_len / 28): every month has >= 28 days, so a target month at this index
# or later lies at least day_len days into the daily record and its input window
# never starts before day 0 (assumes month index 0 is January 1850 -- TODO confirm).
min_month = -(-day_len // 28)

# Fail fast on inconsistencies gen_data_card would otherwise hit mid-training:
# it sizes every array from the NPP grid and indexes LAI/NPP by month index.
assert temp_ds.shape[1:3] == prec_ds.shape[1:3] == npp_ds.shape[1:3] == lai_ds.lai.shape[1:3], \
    'input/output spatial grids differ'
assert max_month <= min(npp_ds.shape[0], lai_ds.lai.shape[0]), \
    'max_month exceeds the length of the monthly LAI/NPP record'


def _days_since_1850(stamp):
    """Return whole days from 1850-01-01T12:00 to ``stamp``, counted in ``stamp``'s own calendar.

    ``replace`` keeps the calendar of both ``pd.Timestamp`` and ``cftime``
    dates, so e.g. a 'noleap' date is not re-counted with Gregorian leap days.
    Assumes the daily arrays start at 1850-01-01T12:00 on the same calendar as
    ``stamp`` -- TODO confirm per model.
    """
    ref = stamp.replace(year=1850, month=1, day=1, hour=12,
                        minute=0, second=0, microsecond=0)
    return (stamp - ref).days


def gen_data_card():
    """Yield an endless stream of random (inputs, outputs) training batches.

    Reads module-level globals set in the loading cell: ``batch_size``,
    ``day_len``, ``min_month``, ``max_month``, the NumPy arrays ``temp_ds``,
    ``prec_ds`` and ``npp_ds``, and ``lai_ds`` (needs a ``time`` index and a
    ``lai`` variable).

    Yields
    ------
    inputs : np.ndarray, shape (batch_size, day_len, lat, lon, 2)
        Daily temperature and precipitation for the ``day_len`` days ending
        just before each target month's day index, stacked on the last axis.
    outputs : np.ndarray, shape (batch_size, lat, lon, 2)
        LAI and NPP at each target month, stacked on the last axis, with NaNs
        replaced by 0.

    Raises
    ------
    ValueError
        If a sampled month's input window falls outside the daily arrays.
    """
    # Loop-invariant work, hoisted out of the batch loop: np.array() materialises
    # the full LAI array, which previously happened on every batch.
    time_index = lai_ds.indexes['time']
    lainp = np.array(lai_ds.lai)
    grid = tuple(npp_ds.shape[1:3])  # lat, lon
    n_days = min(temp_ds.shape[0], prec_ds.shape[0])

    while True:
        # y_pred timepoint of each sample as a month index (max_month is exclusive)
        output_month_i = np.random.randint(min_month, max_month, size=batch_size)

        # Map each month to its day index in the daily arrays, counting days in
        # the time index's native calendar. Converting a non-standard CFTimeIndex
        # (e.g. 'noleap') to a pandas DatetimeIndex would count leap days the
        # daily arrays do not contain, shifting the window up to ~40 days by 2014.
        output_day_i = np.array(
            [_days_since_1850(time_index[m]) for m in output_month_i], dtype=np.int_)

        lai = np.zeros((batch_size,) + grid)  # batch, lat, lon
        npp = np.zeros((batch_size,) + grid)
        temp = np.zeros((batch_size, day_len) + grid)  # batch, time, lat, lon
        prec = np.zeros((batch_size, day_len) + grid)

        # month-based outputs
        lai[:] = lainp[output_month_i]
        npp[:] = npp_ds[output_month_i]

        # day-based inputs: the day_len days immediately before the target day
        for i, (month, day) in enumerate(zip(output_month_i, output_day_i)):
            if not day_len <= day <= n_days:
                raise ValueError(
                    'month index {} maps to day {}, outside the valid range [{}, {}] '
                    'for a {}-day input window; adjust min_month/max_month'
                    .format(month, day, day_len, n_days, day_len))
            temp[i] = temp_ds[day - day_len:day]
            prec[i] = prec_ds[day - day_len:day]

        # merge features along a trailing axis
        inputs = np.stack((temp, prec), axis=-1)  # two features
        outputs = np.nan_to_num(np.stack((lai, npp), axis=-1))

        yield (inputs, outputs)

In [79]:
# Pull one batch to sanity-check the generator's shapes:
# inputs stack the (temp, prec) windows, outputs stack the (lai, npp) maps.
mygen = gen_data_card()
a = next(mygen)
batch_inputs, batch_outputs = a
batch_inputs.shape

total 1.6816175480000197


(32, 300, 36, 72, 2)