In [None]:
proj_dir='/path/to/main_project_folder/' # edit this line

import numpy as np
import random
import xarray as xr
import time
import h5py
import pandas as pd
from denseweight import DenseWeight
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tensorflow.random import set_seed as tf_set_seed
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers.schedules import InverseTimeDecay
#from tensorflow.keras import optimizers
from tensorflow.compat.v1.keras import optimizers
from sklearn.model_selection import train_test_split
import tensorflow.compat.v1 as tf_compat_v1
from scipy.stats import loguniform
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from matplotlib import colors
import sys
sys.path.append(proj_dir)
from project_utils import parameters as param
from project_utils import load_region
from project_utils import prepare_inputs
from project_utils import utils as util
from project_utils import model_utils as mu
import analysis_fxns as fxns
import importlib
importlib.reload(fxns)
importlib.reload(param)
importlib.reload(prepare_inputs)
importlib.reload(util)
importlib.reload(mu)
importlib.reload(load_region)
import multiprocessing as mp

## set random seeds ##
# Fix the global seeds of Python's `random`, NumPy's legacy global RNG and
# TensorFlow so that a fresh kernel reproduces the same results.
random.seed(a=201)
np.random.seed(seed=101)
tf_set_seed(seed=333)

# TF1-style session configured with device_count={'CPU': 24}.
# NOTE(review): `sess` is created but not registered with Keras in this cell
# (e.g. via tf.compat.v1.keras.backend.set_session); under TF2 eager execution
# this config may have no effect on model training — confirm.
session_conf = tf_compat_v1.ConfigProto(
    device_count={'CPU': 24},
)
sess = tf_compat_v1.Session(config=session_conf)

# Define constants

In [None]:
# ---- CNN architecture (passed to mu.build_model) ----
ninputs = 1  # number of geospatial CNN input channels: 1 because only the first channel (GPH) is kept; the SM channel is dropped before training
w = 200  # window size for smoothing PDP curve (not referenced by the training cell)
SNOWFREE = True  # remove winter months from training data
ndense = 1  # number of dense layers (build_model `dense_layers`)
ncnn = 1  # number of CNN layers (build_model `cnn_layers`)
ncrp = 2  # number of conv/relu/pool layers (build_model `conv_relu_pool_layers`)
conv_filters = 8  # filters per convolutional layer
dense_neurons = 32  # neurons per dense layer

# ---- Training settings ----
verb = 0  # Keras verbosity level for model.fit
curr_patience = 100  # EarlyStopping patience, in epochs without val_loss improvement
curr_loss = 'mean_squared_error'  # Keras loss name passed to build_model

# Regions to train a model for (one CNN per region per dataset).
region_list = ['northcentral_north_america', 
               'southcentral_north_america', 
               'southeastern_north_america', 
               'southwestern_europe', 
               'western_europe', 
               'central_europe', 
               'eastern_europe', 
               'northeastern_europe', 
               'northeastern_asia', 
               'southeastern_asia', 
               'northsouthern_south_america', 
               'southsouthern_south_america', 
               'southwestern_africa', 
               'southeastern_africa', 
               'southwestern_australia', 
               'southeastern_australia'
              ]
nlag_tpl = [1]  # number of days to lag the soil moisture input behind the prediction day

# Train Convolutional Neural Networks and Calculate Partial Dependence Plots

In [None]:
# Train one CNN per (dataset, lag, region) on the first input channel only
# (the soil-moisture channel is dropped), report its skill and save it.
# Hyperparameters come from the per-dataset loaders in `load_region`;
# only the first entry of each hyperparameter list is used.
hyperparam_loaders = {
    "ERA5": load_region.load_ERA5_region_cnn_hyperparams,
    "NCEP": load_region.load_NCEP_region_cnn_hyperparams,
}

for dset in ["ERA5", "NCEP"]:
    for nlag in nlag_tpl:
        for ij, region_str in enumerate(region_list):
            (hem, region_input_lat_bbox, region_input_lon_bbox, region_box_x, region_box_y,
             region_lat, region_lon, region_lon_EW, region_t62_lats,
             region_t62_lons) = load_region.load_region_constants(region_str)
            print('region: ', region_str)

            ####### load CNN hyperparameters ########
            (tts, activation_list, optimizer_list, lr_list, dw_list, decay_list,
             decay_steps, decay_rate, curr_batch, nepochs,
             reg_list) = hyperparam_loaders[dset](region_str, shuffle_sm=False, nlag=nlag)

            x_dat_daily, y_dat, ind, caldays, time_vec = prepare_inputs.get_model_inputs(
                region_str, nlag, SNOWFREE, hemisphere=hem, dset=dset)
            time_vec = pd.to_datetime(time_vec).reset_index(drop=True)

            ######## Training / Testing Split ########
            # Split by year (presumably so that days of one year never land in two sets).
            # With 43 years (1979-2021) and test_size=0.37: 27 train years (~63%) and
            # 16 held-out years, which are then halved into 8 test + 8 unseen years.
            rs = tts[0]
            yrs = np.arange(1979, 2022)
            yrs_train, yrs_eval = train_test_split(yrs, test_size=0.37, random_state=rs, shuffle=True)
            yrs_test, yrs_unseen = train_test_split(yrs_eval, test_size=0.50, random_state=rs + 1, shuffle=True)
            yrs_train = sorted(yrs_train)
            yrs_test = sorted(yrs_test)
            yrs_unseen = sorted(yrs_unseen)

            # Preliminary DenseWeight fit (alpha=1.0) that feeds the split helper; the
            # weights used for training/evaluation are recomputed below with the tuned alpha.
            dw_alpha = 1.0
            dw = DenseWeight(alpha=dw_alpha)
            sample_weights = dw.fit(y_dat.values)

            x_train, x_test, x_unseen, \
                y_train, y_test, y_unseen, \
                ind_train, ind_test, ind_unseen, \
                sweight_train, sweight_test, sweight_unseen, \
                cday_train, cday_test, cday_unseen, \
                time_train, time_test, time_unseen = util.train_test_unseen_split_by_years(yrs_train, yrs_test, yrs_unseen,
                                                                           x_dat_daily, y_dat,
                                                                           ind, sample_weights, caldays, time_vec)

            print(np.shape(x_train), np.shape(x_test), np.shape(x_unseen), 
                  np.shape(y_train), np.shape(y_test), np.shape(y_unseen), 
                  np.shape(ind_train), np.shape(ind_test), np.shape(ind_unseen), 
                  np.shape(sweight_train), np.shape(sweight_test), np.shape(sweight_unseen), 
                  np.shape(cday_train), np.shape(cday_test), np.shape(cday_unseen))

            # Target distribution of each split, for a visual check that they are comparable.
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            nbins = 50
            alph = 0.5
            ax.hist(y_train, label="train", bins=nbins, density=True, alpha=1)
            ax.hist(y_test, label="test", bins=nbins, density=True, alpha=alph)
            ax.hist(y_unseen, label="unseen", bins=nbins, density=True, alpha=alph / 2)
            ax.set_ylabel('density')  # density=True normalises the histograms, so this is not a count
            ax.set_xlabel('tmax (K)')
            ax.legend()
            ax.set_title(region_str + ' temp distribution ' + str(rs))
            plt.show()
            plt.close(fig)  # this loop makes one figure per dataset/region; free each one

            ######## Remove SM Layer ########
            # Keep only channel 0; this drops the channel axis of every input array.
            print("shapes with SM:", np.shape(x_train), np.shape(x_test), np.shape(x_unseen))
            print('removing SM input layer!')
            x_train = x_train[:, :, :, 0]
            x_test = x_test[:, :, :, 0]
            x_unseen = x_unseen[:, :, :, 0]
            x_dat_daily = x_dat_daily[:, :, :, 0]
            print("shapes without SM:", np.shape(x_train), np.shape(x_test), np.shape(x_unseen))

            ######## Set Training Constants ########
            callback = EarlyStopping(monitor='val_loss', patience=curr_patience, restore_best_weights=True)
            # Re-weight samples with the tuned DenseWeight alpha and refresh the weights of
            # every split, including unseen (used when reporting model skill below).
            curr_dw_alpha = dw_list[0]
            dw = DenseWeight(alpha=curr_dw_alpha)
            sample_weights = dw.fit(y_dat.values)
            sweight_train = sample_weights[ind_train]
            sweight_test = sample_weights[ind_test]
            sweight_unseen = sample_weights[ind_unseen]
            curr_reg = reg_list[0]
            curr_act_func = activation_list[0]
            curr_lr = lr_list[0]
            dec_steps = decay_steps[0]
            lr_sched = InverseTimeDecay(curr_lr, dec_steps, decay_rate, staircase=True)
            optimizer_dict = {'RMSprop': optimizers.RMSprop(learning_rate=lr_sched)}
            op = optimizer_list[0]
            curr_opt = optimizer_dict[op]
            rand_seed = 0

            print('loss:', curr_loss)
            print('denseweight:', curr_dw_alpha)
            print('l2 regularization:', curr_reg)
            print('activation function:', curr_act_func)
            print('learning rate:', curr_lr)
            print('lr decay steps:', dec_steps)
            print('optimizer:', op, curr_opt)
            print('rand_seed:', rand_seed)

            # ensure training is reproducible: re-seed every RNG right before building the model
            np.random.seed(101 + rand_seed)
            random.seed(201 + rand_seed)
            tf_set_seed(333 + rand_seed)

            # build CNN
            model_daily = None  # drop the reference to the previous iteration's model
            model_daily = mu.build_model(lr=lr_sched, conv_filters=conv_filters,
                                         dense_neurons=dense_neurons, dense_layers=ndense,
                                         cnn_layers=ncnn, conv_relu_pool_layers=ncrp,
                                         activity_reg=curr_reg, input_channels=ninputs,
                                         loss_str=curr_loss, opt=curr_opt,
                                         act_func=curr_act_func,
                                         # x_dat_daily lost its channel axis above, so read the
                                         # grid size from its shape rather than 4-D indexing.
                                         nlats=np.shape(x_dat_daily)[1], nlons=np.shape(x_dat_daily)[2])

            # train CNN (test split doubles as the early-stopping validation set)
            history_daily = model_daily.fit({"stacked_input": x_train, "calday": cday_train}, y_train,
                                            batch_size=curr_batch,
                                            epochs=nepochs,
                                            sample_weight=sweight_train,
                                            validation_data=({"stacked_input": x_test, "calday": cday_test}, y_test, sweight_test),
                                            verbose=verb,
                                            callbacks=[callback])

            tmax_predictions_daily = model_daily.predict({"stacked_input": x_dat_daily, "calday": caldays})[:, 0]
            tmax_predictions_unseen = model_daily.predict({"stacked_input": x_unseen, "calday": cday_unseen})[:, 0]
            tmax_predictions_train = model_daily.predict({"stacked_input": x_train, "calday": cday_train})[:, 0]
            tmax_predictions_test = model_daily.predict({"stacked_input": x_test, "calday": cday_test})[:, 0]

            # Plot model skill #
            fxns.model_skill(region_str=region_str,
                             nlag=nlag,
                             tmax_predictions_unseen=tmax_predictions_unseen,
                             tmax_predictions_train=tmax_predictions_train,
                             tmax_predictions_test=tmax_predictions_test,
                             y_unseen=y_unseen,
                             y_train=y_train,
                             y_test=y_test,
                             sweight_unseen=sweight_unseen,
                             sweight_train=sweight_train,
                             sweight_test=sweight_test,
                             dset=dset)
            # save model weights #
            fxns.save_model(region_str=region_str,
                            nlag=nlag,
                            model=model_daily,
                            history=history_daily,
                            tmax_predictions=tmax_predictions_daily,
                            ind_test=ind_test,
                            ind_unseen=ind_unseen,
                            time_vec=time_vec,
                            y_dat=y_dat,
                            shuffle_sm=False,
                            no_sm=True,
                            dset=dset)
