# Detecting temperature targets
##### authors: Elizabeth A. Barnes and Noah Diffenbaugh
##### date: March 20, 2022


## Python stuff

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import pickle
import sys
import scipy.stats as stats
import json

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input, Dropout, Softmax
from tensorflow.keras import optimizers
from tensorflow.keras import regularizers
from tensorflow import keras

import experiment_settings
import file_methods

import matplotlib as mpl
# Figure defaults for the notebook: white background, 150 dpi on screen.
mpl.rcParams.update({"figure.facecolor": "white", "figure.dpi": 150})
# Resolution referenced by the (currently commented-out) plt.savefig calls.
savefig_dpi = 300
# Silence NumPy's VisibleDeprecationWarning for the rest of the notebook.
# `np.warnings` was only an accidental alias of the standard-library module
# and was removed in NumPy 1.24, so the stdlib module is used directly.
import warnings

# The warning class lives in `np.exceptions` on recent NumPy and was removed
# from the top-level namespace in NumPy 2.0; fall back for older releases.
try:
    visible_deprecation_warning = np.exceptions.VisibleDeprecationWarning
except AttributeError:
    visible_deprecation_warning = np.VisibleDeprecationWarning
warnings.filterwarnings("ignore", category=visible_deprecation_warning)

In [None]:
# Record the interpreter and library versions this run used.
version_report = (
    f"python version = {sys.version}",
    f"numpy version = {np.__version__}",
    f"xarray version = {xr.__version__}",
    f"tensorflow version = {tf.__version__}",
    f"tensorflow-probability version = {tfp.__version__}",
)
for version_line in version_report:
    print(version_line)

## User Choices

In [None]:
# Pick the experiment, load its settings and echo them in the notebook.
EXP_NAME = 'exp0'
exp_settings = experiment_settings.get_settings(EXP_NAME)
display(exp_settings)

# Relative directory names for saved models, saved files and input data.
MODEL_DIRECTORY, FILE_DIRECTORY, DATA_DIRECTORY = (
    'saved_models/',
    'saved_files/',
    'data/',
)

## Get the data and process it

In [None]:
# Ask file_methods for the CMIP file names that belong to this experiment.
filenames = file_methods.get_cmip_filenames(exp_settings)
# Number of returned file names; used later as the leading dimension when the
# validation predictions are reshaped.
# presumably one file per GCM, as the name suggests -- TODO confirm
N_GCMS = len(filenames)

## Load the observations

In [None]:
# Load the observational inputs and their global-mean series, then the
# train/validation/test member counts and the pool of member ids used when
# drawing the splits in the training loop.
# NOTE(review): `data_processing` is not among the imports visible in this
# notebook -- confirm it is imported before this cell runs.
# NOTE(review): later cells read `da_obs`, which is not assigned anywhere in
# the visible notebook -- confirm where it is meant to come from.
x_obs, global_mean_obs = data_processing.get_observations(exp_settings)
N_TRAIN, N_VAL, N_TEST, ALL_MEMBERS = data_processing.get_members(exp_settings)

## Network and XAI functions

In [None]:
# Early-stopping callback: watches the validation loss, waits
# exp_settings['patience'] epochs, and restores the best weights when it stops.
early_stopping_kwargs = dict(
    monitor='val_loss',
    patience=exp_settings['patience'],
    verbose=1,
    mode='auto',
    restore_best_weights=True,
)
early_stopping = tf.keras.callbacks.EarlyStopping(**early_stopping_kwargs)


## Train the network

In [None]:
# Train exp_settings['n_models'] networks. Every pass draws a fresh seed and a
# fresh train/val/test member split from `rng`, fits a newly compiled model,
# and stores that model's predictions for the observations.
rng = np.random.default_rng(exp_settings["rng_seed"])
# NaN-filled so that rows of networks that have not been trained yet (and
# output columns a network does not produce) stay recognisable.
pred_obs_vec = np.full((exp_settings['n_models'], x_obs.shape[0], 2), np.nan)

for iloop in np.arange(exp_settings['n_models']):
    # One seed per network, shared by TensorFlow and NumPy's global generator.
    SEED = rng.integers(low=1_000,high=10_000,size=1)[0]
    tf.random.set_seed(SEED)
    np.random.seed(SEED)

    # Disjoint draws: validation excludes the training members, and testing
    # excludes both the training and validation members.
    train_members = rng.choice(ALL_MEMBERS, size=N_TRAIN, replace=False)
    val_members   = rng.choice(np.setdiff1d(ALL_MEMBERS,train_members), size=N_VAL, replace=False)
    test_members  = rng.choice(np.setdiff1d(ALL_MEMBERS,np.append(train_members,val_members)), size=N_TEST, replace=False)
    print(train_members, val_members, test_members)

    # NOTE(review): the member splits drawn above are not passed to
    # get_cmip_data in this call -- confirm how it is meant to receive them.
    (x_train,
     x_val,
     x_test,
     y_train,
     y_val,
     y_test,
     onehot_train,
     onehot_val,
     onehot_test,
     y_yrs_train,
     y_yrs_val,
     y_yrs_test,
     target_years,
     map_shape) = data_processing.get_cmip_data(exp_settings)

    #----------------------------------------
    # start from a clean Keras state, then build and fit a new model
    tf.keras.backend.clear_session()
    model = compile_model()
    history = model.fit(x_train, onehot_train,
                        epochs=exp_settings['n_epochs'],
                        verbose=exp_settings['verbosity'],
                        batch_size=exp_settings['batch_size'],
                        shuffle=True,
                        # Keras documents validation_data as a tuple (x_val, y_val).
                        validation_data=(x_val, onehot_val),
                        callbacks=[early_stopping],
                        )
    #----------------------------------------
    # create predictions for observations with this model
    pred_obs = model.predict(x_obs)
    # fill only as many output columns as this network produces
    pred_obs_vec[iloop,:,:pred_obs.shape[1]] = pred_obs

    #----------------------------------------
    # save the tensorflow model
    model_name = files.get_model_name()
    if exp_settings["save_model"]:
        save_tf_model(model, model_name)
        # predictions are saved under model_name cut at its last '_seed',
        # with '_pred_obs' appended
        save_pred_obs(pred_obs_vec, model_name[:model_name.rfind('_seed')] + '_pred_obs')

    #----------------------------------------
    # training / validation loss curves for this network
    if exp_settings["show_plots"]:

        plt.plot(history.history['loss'], label='loss')
        plt.plot(history.history['val_loss'], label='val_loss')
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.legend()
        plt.show()

        
        

## Make loss plots following training

In [None]:
# For 'shash2' networks, show four training-history panels side by side:
# loss, custom_mae, interquartile_capture and sign_test.
if exp_settings["network_type"] == 'shash2':
    # (metric name, y-limits or None to keep the automatic limits)
    metric_panels = (
        ('loss', (0, 10.)),
        ('custom_mae', (0, 10)),
        ('interquartile_capture', None),
        ('sign_test', None),
    )
    try:
        # Raises KeyError up front when the custom metrics were not recorded.
        history.history['custom_mae']

        # plt.figure rather than plt.subplots: the latter would add one
        # full-figure Axes that stays behind the 1x4 panels created below.
        plt.figure(figsize=(20,4))
        for ipanel, (metric_name, y_limits) in enumerate(metric_panels, start=1):
            plt.subplot(1,4,ipanel)
            plot_metrics(history, metric_name)
            if y_limits is not None:
                plt.ylim(*y_limits)

        plt.show()
    except KeyError:
        # Only a missing history entry is treated as "nothing to plot";
        # any other error is allowed to surface.
        print('metrics were not saved')

In [None]:
# Output column treated as the point prediction. For 'shash2' this is
# column 0; otherwise None is used, which inserts a new axis instead of
# selecting a column, so .flatten() below returns every output value.
top_pred_idx = 0 if exp_settings['network_type'] == "shash2" else None

YEARS_UNIQUE = np.unique(y_yrs_train)

# Flattened predictions of the most recently trained model.
predict_train = model.predict(x_train)[:,top_pred_idx].flatten()
predict_val = model.predict(x_val)[:,top_pred_idx].flatten()

# Mean absolute error on the validation set.
abs_err_val = np.abs(predict_val-y_val[:])
mae = np.mean(abs_err_val)

In [None]:
# Two-panel summary of the last trained network:
#   left  - predicted vs. true years-until-target (training and validation)
#   right - predictions against the year of the input map, with the
#           validation predictions coloured by leading index (one colour per
#           entry along the N_GCMS axis)
clr = ('tab:purple','tab:orange', 'tab:blue', 'tab:green', 'gold', 'brown','black','darkorange')
#--------------------------------
plt.subplots(1,2,figsize=(15,6))

plt.subplot(1,2,1)
plt.plot(y_train, predict_train,'.',color='gray',alpha=.5, label='training')
plt.plot(y_val, predict_val,'.', label='validation')
# 1:1 line marking a perfect prediction
plt.plot(y_val,y_val,'--',color='fuchsia')
plt.axvline(x=0,color='gray',linewidth=1)
plt.axhline(y=0,color='gray',linewidth=1)
plt.title('Validation MAE = ' + str(mae.round(2)) + ' years')
plt.xlabel('true number of years until target is reached')
plt.ylabel('predicted number of years until target is reached')
plt.legend()


plt.subplot(1,2,2)
plt.plot(y_yrs_train, predict_train,'.',color='gray',alpha=.5, label='training')
plt.title('Time to Target Year for ' + str(exp_settings['target_temp']) + 'C using ssp' + str(exp_settings['ssp']))
plt.xlabel('year of map')
plt.ylabel('predicted number of years until target is reached')
plt.axhline(y=0, color='gray', linewidth=1)

predict_val_mat = predict_val.reshape(N_GCMS,N_VAL,len(YEARS_UNIQUE))
for i in np.arange(0,predict_val_mat.shape[0]):
    # Cycle through the palette so more than len(clr) entries do not raise IndexError.
    gcm_color = clr[i % len(clr)]
    plt.plot(YEARS_UNIQUE, predict_val_mat[i,:,:].swapaxes(1,0),'.', label='validation', color=gcm_color)
    # dashed vertical line at this entry's target year
    plt.axvline(x=target_years[i],linestyle='--',color=gcm_color)
# The former `if IN_COLAB==False: pass` branch did nothing and referenced a
# name that is not defined in the visible notebook, so it was removed.
# plt.savefig('figures/initial_result_seed' + str(SEED) + '.png', dpi=savefig_dpi)
plt.show()

## Predict observations

In [None]:
# Predict years-until-target for every observed map, fit a linear trend to
# the predictions from 2001 onward, and plot both next to the observed
# global-mean series.
y_predict_obs = model.predict(x_obs)[:,top_pred_idx].flatten()

# Year of each observation, looked up once and reused below.
obs_years = da_obs['time.year'].values

iy = np.where(obs_years >= 2001)[0]
x = obs_years[iy]
y = y_predict_obs[iy]
linear_model = stats.linregress(x=x,y=y)

#--------------------------------
plt.figure(figsize=(10,3))
plt.subplot(1,2,1)
plt.plot(da_obs['time.year'], y_predict_obs, '.r')
# fitted trend, drawn only over the years used for the fit
plt.plot(x, linear_model.slope*x+linear_model.intercept, '--k', alpha=.5, linewidth=2)
plt.xlabel('year of map')
plt.ylabel('predicted number of years \nuntil target is reached')
# NOTE(review): the title hard-codes 2021 as the year of the last observed
# map -- confirm this matches the final entry of da_obs.
plt.title(exp_settings["obsdata"] + ' Target Year for ' + str(exp_settings['target_temp']) + 'C using ssp' + str(exp_settings['ssp']) + ' = ' + 
          str(np.round(2021+y_predict_obs[-1],1)) +
          ' (' + str(y_predict_obs[-1].round()) + ' years)'+
          '\n slope = ' + str(linear_model.slope.round(2)) 
         )
plt.xlim(1970,2025)
plt.ylim(-10,80)
plt.axhline(y=0,color='gray')

#--------------------------------
plt.subplot(1,2,2)
global_mean_obs.plot(linewidth=2,label='data',color="tab:orange")
plt.title(exp_settings["obsdata"] + ' Observations')

plt.tight_layout()
plt.show()


In [None]:
# Gradient x Input for the last n_plot observed maps, averaged over those
# maps and drawn as a single map.
n_plot = 4
recent_obs = x_obs[-n_plot:,:]
recent_grads = get_gradients(recent_obs, top_pred_idx=top_pred_idx).numpy()
grads_obs = recent_grads*recent_obs
print('np.shape(grads_obs) = ' + str(np.shape(grads_obs)))

# average over the n_plot samples
grads_obs_mean = np.mean(grads_obs,axis=0)

plt.figure(figsize=(10,4))
obs_map_title = 'Observations ' + str(2021-n_plot+1) + '-' + str(2021) + ': Gradient x Input'
plot_map(
    grads_obs_mean.reshape((map_shape[0],map_shape[1])),
    clim=(-.02,.02),
    title=obs_map_title,
)

In [None]:
import scipy.stats as stats
import seaborn as sns

# For 'shash2' networks: treat the two network outputs as the mean and
# standard deviation of a normal distribution, then plot
#   left  - the 25th-75th percentile range of the prediction for every year,
#           with 2021 highlighted and a linear trend fitted from 2011 onward
#   right - the predicted density for the last observed map
if exp_settings["network_type"] == 'shash2':

    clr_choice = 'orange'
    y_predict_obs = model.predict(x_obs)

    iy = np.where(da_obs['time.year'].values >= 2011)[0]
    x = da_obs['time.year'].values[iy]
    y = y_predict_obs[iy,0]
    linear_model = stats.linregress(x=x,y=y)

    #--------------------------------
    # grid of "years until target" values on which the density is evaluated
    norm_incs = np.arange(-80,80,1)
    mu_pred = y_predict_obs[:,0]
    sigma_pred = y_predict_obs[:,1]
    norm_dist = tfp.distributions.Normal(mu_pred,sigma_pred)
    norm_perc_low = norm_dist.quantile(.25).numpy()   
    norm_perc_high = norm_dist.quantile(.75).numpy()      
    norm_perc_med = norm_dist.quantile(.5).numpy()      
    # density of the last prediction only
    norm_cpd = norm_dist[-1].prob(norm_incs)
    y_predict_obs = norm_perc_med
    
    print('2021 prediction = ' + str(mu_pred[-1]) + ' (' + str(norm_perc_low[-1]) + ' to ' + str(norm_perc_high[-1]) + ')')

    #------------------------------------------------------------
    plt.subplots(1,2,figsize=(16,4))
    # NOTE(review): assumes the observations run over 1850-2021, one entry per
    # year, so that years[i] belongs to prediction i -- confirm against da_obs.
    years = np.arange(1850,2022)

    plt.subplot(1,2,1)
    for iyear in np.arange(0,y_predict_obs.shape[0]):
        min_val = norm_perc_low[iyear]
        max_val = norm_perc_high[iyear]

        # A dedicated name is used here so that the notebook-level `clr`
        # palette tuple is not overwritten with a single colour string.
        if(years[iyear]==2021):
            bar_color = clr_choice
        else:
            bar_color = 'gray'
        plt.plot((years[iyear],years[iyear]),(min_val, max_val),
                 linestyle='-',
                 linewidth=4,
                 color=bar_color,
                )

    plt.plot(x,x*linear_model.slope+linear_model.intercept,'--', color='black')

    plt.xlim(1970.5,2023)    
    plt.ylim(-10,80)
    plt.ylabel('years until target')
    plt.xlabel('year')
    # str() around the ssp setting matches the other titles in this notebook
    # and keeps the concatenation valid when the setting is not a string.
    plt.title(exp_settings["obsdata"] + ' predictions for ' + str(exp_settings['target_temp']) + 'C using ssp' + str(exp_settings["ssp"]) + ' (norm)\n slope=' + str(linear_model.slope.round(2)))

    plt.subplot(1,2,2)
    plt.plot(norm_incs,norm_cpd,
             linewidth=5,
             color=clr_choice,
            )

    # dashed markers at the 25th and 75th percentiles, drawn up to the
    # density at the nearest grid point
    k = np.argmin(np.abs(norm_perc_low[-1]-norm_incs))
    plt.plot((norm_perc_low[-1],norm_perc_low[-1]),(0,norm_cpd[k]),'--',color=clr_choice)
    k = np.argmin(np.abs(norm_perc_high[-1]-norm_incs))
    plt.plot((norm_perc_high[-1],norm_perc_high[-1]),(0,norm_cpd[k]),'--',color=clr_choice)

    plt.xlabel('years until target')
    plt.title('Predictions for ' + exp_settings["obsdata"] + ' Observations under SSP' + str(exp_settings["ssp"]) + '\nYear = 2021')

    # x-range chosen per temperature target
    if exp_settings["target_temp"] == 1.1:
        plt.xlim(-20,20)
    elif exp_settings["target_temp"] == 1.5:
        plt.xlim(-10,40)
    elif exp_settings["target_temp"] == 2.0:
        plt.xlim(-10,70)
    else:
        plt.xlim(-70,70)

    plt.tight_layout()
    plt.show()

## Explainability via Input * Gradient and Integrated Gradients
We will use two attribution explainability methods, Input * Gradient and Integrated Gradients, to make heatmaps of the regions of the input that act as explanations for the network's prediction.

* https://keras.io/examples/vision/integrated_gradients/
* https://distill.pub/2020/attribution-baselines/

In [None]:
#=========================================
# Define the samples you want to explain
# A fixed-seed random subset of the validation samples, drawn without
# replacement. The size is capped at the number of available samples so that
# validation sets with fewer than 500 samples do not raise.
rng = np.random.default_rng(45)
n_explain = min(500, x_val.shape[0])
isubsample = rng.choice(np.arange(0,x_val.shape[0]),
                        size = n_explain,
                        replace = False,
                       )

inputs = np.copy(x_val[isubsample,:])
targets = np.copy(y_val[isubsample])
yrs = np.copy(y_yrs_val[isubsample])
preds = model.predict(inputs)

#=========================================
#---------------------------------------
# Gradient x Input
#---------------------------------------
# compute the multiplication of gradient * inputs
# and reshape into a map of latitude x longitude

# keyword form, matching the other get_gradients call in this notebook
grads = get_gradients(inputs,top_pred_idx=top_pred_idx).numpy()
grad_x_input = grads * inputs
grad_x_input = grad_x_input.reshape((len(targets),map_shape[0],map_shape[1]))
print(np.shape(grad_x_input))

#---------------------------------------
# Integrated Gradients
#---------------------------------------
# The training mean multiplied by 0. gives a baseline with the shape of one
# input sample that is zero wherever the training mean is finite.
baseline_mean = np.mean(x_train,axis=0)*0.    
print('shape(baseline_mean) = ' + str(np.shape(baseline_mean)))
print('model.predict(baseline_mean) = ' + str(model.predict(baseline_mean[np.newaxis,:])))

igrad = get_integrated_gradients(inputs, baseline=baseline_mean,top_pred_idx=top_pred_idx)
integrated_gradients = igrad.numpy().reshape((len(targets),map_shape[0],map_shape[1]))
print(np.shape(integrated_gradients))

In [None]:

# Grid of attribution maps: one row per label range (min_range to
# min_range + 5 years), one column per quantity, each averaged over the
# explained samples whose target falls inside that range.
plot_list = (40, 20, 10, 0)
NCOL = 4
n_rows = len(plot_list)
map_2d = (map_shape[0],map_shape[1])
plt.subplots(n_rows,NCOL,figsize=(35,5*n_rows))

for irow,min_range in enumerate(plot_list):

    max_range = min_range + 5
    isamples = np.where((targets >= min_range) & (targets <= max_range))[0]

    # sample-mean fields for this label range
    igrad_mean = np.mean(integrated_gradients[isamples,:,:],axis=0)
    grad_x_input_mean = np.mean(grad_x_input[isamples,:,:],axis=0)
    grad_mean = np.mean(grads[isamples,:],axis=0).reshape(map_2d)
    x_inputs_mean = np.mean(inputs[isamples,:],axis=0).reshape(map_2d)
    # express the mean input relative to the integrated-gradients baseline
    x_inputs_mean = x_inputs_mean - baseline_mean.reshape(map_2d)

    #------------------------------------------------------------------
    # annotation passed to every panel of this row
    text = (
        "\n"
        f"  label_range    = {min_range}-{max_range} yrs.\n"
        f"  n_samples      = {len(isamples)}\n"
    )

    #------------------------------------------------------------------
    # columns, left to right: average input map, gradient (saliency),
    # input x gradient, integrated gradients
    row_panels = (
        (x_inputs_mean, dict(clim=(-5,5), cmap='RdBu_r', title='Temperature anomaly from Baseline')),
        (grad_mean, dict(clim=(-0.02, .02), title='Gradient (Saliency)')),
        (grad_x_input_mean, dict(clim=(-.02,.02), title='Gradient x Input')),
        (igrad_mean, dict(clim=(-.02,.02), title='Integrated Gradients')),
    )
    for icol, (panel_field, panel_kwargs) in enumerate(row_panels, start=1):
        plt.subplot(n_rows,NCOL,irow*NCOL+icol)
        plot_map(panel_field, text=text, **panel_kwargs)

plt.tight_layout()
# plt.savefig('figures/xai_grid_' + str(min_range) +'-' + str(max_range) + '_baseline_' + str(BASELINE) + '.png', dpi=savefig_dpi)
plt.show()

