# Detecting temperature targets
##### authors: Elizabeth A. Barnes and Noah Diffenbaugh
##### date: March 20, 2022


## Python stuff

In [1]:
import sys, imp, os

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import tensorflow as tf
import tensorflow_probability as tfp

import experiment_settings
import file_methods, plots, custom_metrics, network, data_processing

import matplotlib as mpl
mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["figure.dpi"] = 150
savefig_dpi = 300
np.warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

In [2]:
print(f"python version = {sys.version}")
print(f"numpy version = {np.__version__}")
print(f"xarray version = {xr.__version__}")  
print(f"tensorflow version = {tf.__version__}")  
print(f"tensorflow-probability version = {tfp.__version__}")  

python version = 3.9.10 | packaged by conda-forge | (main, Feb  1 2022, 21:27:43) 
[Clang 11.1.0 ]
numpy version = 1.22.2
xarray version = 2022.3.0
tensorflow version = 2.7.0
tensorflow-probability version = 0.15.0


## User Choices

In [3]:
EXP_NAME = 'exp0'
settings = experiment_settings.get_settings(EXP_NAME)
display(settings)

OVERWRITE_MODEL = False

MODEL_DIRECTORY = 'saved_models/'        
PREDICTIONS_DIRECTORY = 'saved_predictions/'
DATA_DIRECTORY = 'data/'
DIAGNOSTICS_DIRECTORY = 'model_diagnostics/'
FIGURE_DIRECTORY = 'figures/'

{'save_model': True,
 'n_models': 3,
 'ssp': '370',
 'gcmsub': 'ALL',
 'obsdata': 'BEST',
 'target_temp': 1.5,
 'n_train_val_test': (7, 3, 0),
 'baseline_yr_bounds': (1850, 1899),
 'training_yr_bounds': (1970, 2100),
 'anomaly_yr_bounds': (1951, 1980),
 'anomalies': True,
 'remove_map_mean': False,
 'network_type': 'shash2',
 'hiddens': [10, 10],
 'dropout_rate': 0.0,
 'ridge_param': [1.0, 0.0],
 'learning_rate': 1e-05,
 'batch_size': 64,
 'rng_seed': 8889,
 'seed': None,
 'act_fun': ['relu', 'relu'],
 'n_epochs': 25000,
 'patience': 50,
 'exp_name': 'exp0'}

## Plotting functions

In [4]:
def plot_one_to_one_diagnostic():
    if settings['network_type'] == "shash2":
        top_pred_idx = 0
    else:
        top_pred_idx = None

    YEARS_UNIQUE = np.unique(y_yrs_train)
    predict_train = model.predict(x_train)[:,top_pred_idx].flatten()
    predict_val = model.predict(x_val)[:,top_pred_idx].flatten()
    mae = np.mean(np.abs(predict_val-y_val[:]))
    
    #--------------------------------
    clr = ('tab:purple','tab:orange', 'tab:blue', 'tab:green', 'gold', 'brown','black','darkorange')
    plt.subplots(1,2,figsize=(15,6))

    plt.subplot(1,2,1)
    plt.plot(y_train, predict_train,'.',color='gray',alpha=.5, label='training')
    plt.plot(y_val, predict_val,'.', label='validation')
    plt.plot(y_val,y_val,'--',color='fuchsia')
    plt.axvline(x=0,color='gray',linewidth=1)
    plt.axhline(y=0,color='gray',linewidth=1)
    plt.title('Validation MAE = ' + str(mae.round(2)) + ' years')
    plt.xlabel('true number of years until target is reached')
    plt.ylabel('predicted number of years until target is reached')
    plt.legend()


    plt.subplot(1,2,2)
    plt.plot(y_yrs_train, predict_train,'.',color='gray',alpha=.5, label='training')
    plt.title('Time to Target Year for ' + str(settings['target_temp']) + 'C using ssp' + str(settings['ssp']))
    plt.xlabel('year of map')
    plt.ylabel('predicted number of years until target is reached')
    plt.axhline(y=0, color='gray', linewidth=1)

    predict_val_mat = predict_val.reshape(N_GCMS,N_VAL,len(YEARS_UNIQUE))
    for i in np.arange(0,predict_val_mat.shape[0]):
        plt.plot(YEARS_UNIQUE, predict_val_mat[i,:,:].swapaxes(1,0),'.', label='validation', color=clr[i])
        plt.axvline(x=target_years[i],linestyle='--',color=clr[i])

## Initial housekeeping

In [5]:
## determine how many GCMs are being used
filenames = file_methods.get_cmip_filenames(settings, verbose=0)
N_GCMS = len(filenames)

# load observations for diagnostics plotting
x_obs, global_mean_obs = data_processing.get_observations(DATA_DIRECTORY, settings)
N_TRAIN, N_VAL, N_TEST, ALL_MEMBERS = data_processing.get_members(settings)

observations: filling NaNs with zeros
np.shape(x_obs) = (172, 10368)


In [6]:
# define early stopping callback (cannot be done elsewhere)
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                   patience=settings['patience'],
                                                   verbose=1,
                                                   mode='auto',
                                                   restore_best_weights=True)

## Train the network

In [None]:
rng = np.random.default_rng(settings["rng_seed"])
pred_obs_vec = np.zeros(shape=(settings['n_models'], x_obs.shape[0], 2))*np.nan

for iloop in np.arange(settings['n_models']):
    seed = rng.integers(low=1_000,high=10_000,size=1)[0]
    settings["seed"] = int(seed)
    tf.random.set_seed(settings["seed"])
    np.random.seed(settings["seed"])

    # get model name
    model_name = file_methods.get_model_name(settings)
    if os.path.exists(MODEL_DIRECTORY + model_name + "_model") and OVERWRITE_MODEL==False:
        print(model_name + 'exists. Skipping...')
        print("================================\n")
        continue    
    
    # get the data
    (x_train, 
     x_val, 
     x_test, 
     y_train, 
     y_val, 
     y_test, 
     onehot_train, 
     onehot_val, 
     onehot_test, 
     y_yrs_train, 
     y_yrs_val, 
     y_yrs_test, 
     target_years, 
     map_shape,
     settings) = data_processing.get_cmip_data(DATA_DIRECTORY,rng, settings)

    #----------------------------------------        
    tf.keras.backend.clear_session()                
    model = network.compile_model(x_train, y_train, settings)
    history = model.fit(x_train, onehot_train, 
                        epochs=settings['n_epochs'], 
                        batch_size = settings['batch_size'], 
                        shuffle=True,
                        validation_data=[x_val, onehot_val],
                        callbacks=[early_stopping,],
                        verbose=0,                        
                       )
    #----------------------------------------
    # create predictions for observations with this model
    pred_obs = model.predict(x_obs)
    pred_obs_vec[iloop,:,:pred_obs.shape[1]] = pred_obs
    
    #----------------------------------------
    # save the tensorflow model
    if settings["save_model"]:
        file_methods.save_tf_model(model, model_name, MODEL_DIRECTORY, settings)
        file_methods.save_pred_obs(pred_obs_vec, 
                                   PREDICTIONS_DIRECTORY+model_name[:model_name.rfind('_seed')] + '_obs_predictions',
                                  )

    #----------------------------------------
    # create and save diagnostics plots
    if settings["network_type"] == 'shash2':    
        plots.plot_metrics_panels(history,settings)
        plt.savefig(DIAGNOSTICS_DIRECTORY + model_name + '_metrics_diagnostic' + '.png', dpi=savefig_dpi)
        plt.show()             
    
    plot_one_to_one_diagnostic()
    plt.savefig(DIAGNOSTICS_DIRECTORY + model_name + '_one_to_one_diagnostic' + '.png', dpi=savefig_dpi)
    plt.show()   
    

exp0_seed1257exists. Skipping...

[5 4 2 8 3 0 6] [1 9 7] []
tas_Amon_historical_ssp370_CanESM5_r1-10_ncecat_ann_mean_2pt5degree.nc
TARGET_YEAR = 2011
tas_Amon_historical_ssp370_ACCESS-ESM1-5_r1-10_ncecat_ann_mean_2pt5degree.nc
TARGET_YEAR = 2035
tas_Amon_historical_ssp370_UKESM1-0-LL_r1-10_ncecat_ann_mean_2pt5degree.nc
TARGET_YEAR = 2024
tas_Amon_historical_ssp370_MIROC-ES2L_r1-10_ncecat_ann_mean_2pt5degree.nc
TARGET_YEAR = 2038
tas_Amon_historical_ssp370_GISS-E2-1-G_r1-10_ncecat_ann_mean_2pt5degree.nc
TARGET_YEAR = 2029
tas_Amon_historical_ssp370_IPSL-CM6A-LR_r1-10_ncecat_ann_mean_2pt5degree.nc
TARGET_YEAR = 2020
tas_Amon_historical_ssp370_CESM2-LE2-smbb_r1-10_ncecat_ann_mean_2pt5degree.nc
TARGET_YEAR = 2030
---------------------------
data_train.shape = (49, 131, 72, 144)
data_val.shape = (21, 131, 72, 144)
data_test.shape = (0, 131, 72, 144)
(6419, 10368) (6419,) (6419,)
(2751, 10368) (2751,) (2751,)
(0, 10368) (0,) (0,)


2022-03-20 16:14:37.349241: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 10368)]      0           []                               
                                                                                                  
 normalization (Normalization)  (None, 10368)        20737       ['input_1[0][0]']                
                                                                                                  
 dropout (Dropout)              (None, 10368)        0           ['normalization[0][0]']          
                                                                                                  
 dense (Dense)                  (None, 10)           103690      ['dropout[0][0]']                
                                                                                              