In [1]:
%config Completer.use_jedi = False

In [2]:
import os
import glob
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf

# Predict whole history

- Given $t_{-n}, ..., t_{-1}$ inputs, predict outputs at $t_{-n}, ..., t_{-1}$
- concatanates (put pictures side by side) two features
- For a given month output, 16th of that month until 16th of the month before inputs correspond

In [None]:
models = ['GFDL-ESM4','IPSL-CM6A-LR','MPI-ESM1-2-HR']  # models for temp, prec, LAI
time_len = 10  # for gen_data_card_history(), how long each training sample should be, in months

def gen_data_card_history():  
    model = np.random.choice(np.array(models))  # which of 3 models to choose from
    
    # MONTHLY PICK
    start_year = np.random.randint(1850,2014+1) # randomly select a start year of a time slice
    start_month = np.random.randint(1,12+1)
    end_year = start_year + ((start_month+time_len-1) // 12)
    end_month = (start_month+time_len) % 12
    if end_month == 0:
        end_month = 12
    month_index_start = (start_year-1850)*12 + start_month  # convert date into index with 01-1850 as 0
    month_index_end = month_index_start + time_len
#     print("index:", month_index_start, month_index_end)
#     print(model,start_year, start_month, end_year, end_month)
    
    # select appropriate time slices
    temp = xr.open_mfdataset('data/near_surface_air_temperature/historical/{}/*.nc'.format(model))
    temp = temp.tas.loc["{}-{}-16".format(start_year, start_month):"{}-{}-16".format(end_year, end_month)]  
    
    prec = xr.open_mfdataset('data/precipitation_flux/historical/{}/*.nc'.format(model))
    prec = prec.pr.loc["{}-{}-16".format(start_year, start_month):"{}-{}-16".format(end_year, end_month)]  
    
    # predict only single time step
    lai = xr.open_mfdataset('data/leaf_area_index/historical/{}/*.nc'.format(model))
    lai = np.array(lai.lai)[month_index_end]
    # lai = np.array(lai.lai)[month_index_start:month_index_end]  # predict whole history
    
    npp_files = glob.glob('data/net_primary_production_on_land/historical/**/*.nc', recursive=True) 
    npp = xr.open_mfdataset(np.random.choice(np.array(npp_files)))
    npp = np.array(npp.npp)[month_index_end]
    # npp = np.array(npp.npp)[month_index_start:month_index_end]
                
    # concatanate data
    inputs = np.array(xr.concat((temp,prec), dim='lat'))  # two maps next to each other
    outputs = np.concatenate((lai,npp), axis=-2)
    
    yield(inputs, outputs)

In [None]:
ds = tf.data.Dataset.from_generator(generator=gen_data_card, output_types=(tf.float32, tf.float32))
for inputs, outputs in ds.take(1):
    print(inputs.shape, outputs.shape)

# Predict single timeslice

- Given $t_{-n}, ..., t_{-1}$ inputs, predict outputs at $t_{-1}$
- Give two features as an extra dimension
- Given an output month, use fixed length of days input

In [3]:
import time

In [13]:
data_models = ['GFDL-ESM4','IPSL-CM6A-LR','MPI-ESM1-2-HR']  # models for temp, prec, LAI
day_len = 300  # for gen_data_card()

dmodel = np.random.choice(np.array(data_models))  # choose a model
# load all data externally to save computationtime
temp_ds = np.array(xr.open_mfdataset('data/near_surface_air_temperature/historical/{}/*.nc'.format(dmodel)).tas)
prec_ds = np.array(xr.open_mfdataset('data/precipitation_flux/historical/{}/*.nc'.format(dmodel)).pr)
npp_files = glob.glob('data/net_primary_production_on_land/historical/**/*.nc', recursive=True) 
npp_ds = np.array(xr.open_mfdataset(np.random.choice(np.array(npp_files))).npp)
lai_ds = xr.open_mfdataset('data/leaf_area_index/historical/{}/*.nc'.format(dmodel))

max_month = 1978
min_month = day_len//32

def gen_data_card():
    while True:
        start = time.process_time()
        timer = time.process_time()
        output_month_i = np.random.randint(min_month, max_month)  # y_pred timepoint in int
        print('0',time.process_time() - timer)
        timer = time.process_time()
        
        try:
            endstamp = lai_ds.indexes['time'].to_datetimeindex()[output_month_i]  # cfttimeindex to datetime
        except:
            endstamp = lai_ds.indexes['time'][output_month_i]
        output_day_i = (endstamp - pd.Timestamp('1850-01-01T12')).days  # output is i-th day in int
#         print(time.process_time() - timer)
#         timer = time.process_time()

        # continue with month-based metrics
        lai = np.array(lai_ds.lai)[output_month_i]
#         print('2',time.process_time() - timer)
#         timer = time.process_time()
        npp = npp_ds[output_month_i]
#         print('5',time.process_time() - timer)
#         timer = time.process_time()

        # day-based metrics
        temp = temp_ds[output_day_i-day_len:output_day_i]
#         print('6',time.process_time() - timer)
#         timer = time.process_time()
        prec = prec_ds[output_day_i-day_len:output_day_i]
#         print('7',time.process_time() - timer)
#         timer = time.process_time()

        inputs = np.stack((temp,prec), axis=-1)  # two features
#         print('8',time.process_time() - timer)
#         timer = time.process_time()
        outputs = np.stack((lai,npp), axis=-1)
#         print('9',time.process_time() - timer)
#         timer = time.process_time()
#         print('total', time.process_time() - start)

        yield (inputs, outputs)

In [14]:
mygen = gen_data_card()
next(mygen)

0 7.566200000042045e-05
1 4.375999992589641e-06
0.07601951699999177


  endstamp = lai_ds.indexes['time'].to_datetimeindex()[output_month_i]  # cfttimeindex to datetime


2 0.15587636499999746
3 4.760999999575688e-06
4 3.6250000050586095e-06
5 2.130199999328397e-05
6 2.2557999997729894e-05
7 1.5669000006823808e-05
8 0.006922856000002753
9 0.001188813000013056
total 0.24578256200000226


(array([[[[2.2519220e+02, 0.0000000e+00],
          [2.2591081e+02, 0.0000000e+00],
          [2.2620091e+02, 0.0000000e+00],
          ...,
          [2.2185773e+02, 0.0000000e+00],
          [2.2363916e+02, 0.0000000e+00],
          [2.2444591e+02, 0.0000000e+00]],
 
         [[2.5490581e+02, 0.0000000e+00],
          [2.5207217e+02, 0.0000000e+00],
          [2.5094653e+02, 0.0000000e+00],
          ...,
          [2.4697598e+02, 0.0000000e+00],
          [2.5160922e+02, 1.6463675e-06],
          [2.5381007e+02, 0.0000000e+00]],
 
         [[2.4678241e+02, 4.8839438e-06],
          [2.4748117e+02, 4.5087650e-06],
          [2.4793744e+02, 5.4810830e-06],
          ...,
          [2.5446584e+02, 6.0210050e-06],
          [2.4167412e+02, 3.3332858e-06],
          [2.4167220e+02, 2.9108071e-06]],
 
         ...,
 
         [[2.7282968e+02, 5.6363365e-06],
          [2.7246149e+02, 4.2714510e-06],
          [2.7206039e+02, 4.3346445e-06],
          ...,
          [2.7193304e+02, 5.62797