# Detecting temperature targets
##### authors: Elizabeth A. Barnes and Noah Diffenbaugh
##### date: March 20, 2022

README:
This is the main training script for all TF models. Here are some tips for using this refactored code.

* ```experiment_settings.py``` is now your go-to place; it works something like a research log. Keep copying and pasting new experimental designs under unique names (e.g. ```exp23```). This way you can always refer back to an experiment you ran before without having to change a bunch of parameters again.

* If all goes well and we don't need more data, you should only be modifying ```experiment_settings.py``` and this notebook (although ```plots.py``` might change too).

* To train a set of models, go into ```experiment_settings.py```, make a new experiment with a new name (e.g. ```exp1```), and then specify that same name here in Cell 3 for ```EXP_NAME```.

* The parameter in settings called ```n_models``` is now more useful. If you set it to a larger number, e.g. 20, the code will train 20 models with the same experimental design but with different random seeds, and therefore different random training/validation/testing sets, etc. You can then analyze these models in another notebook.

* The other choice you make here (outside of the usual experiment settings) is whether to overwrite existing models with the same experiment name. Typically, you want ```OVERWRITE_MODEL = False``` so that the code continues training new random seeds where you left off (rather than starting over again).

* Plots for model diagnostics are saved in the ```model_diagnostics``` directory. 

* Predictions for observations are saved in the ```saved_predictions``` directory. You can always re-load a TF model and re-make the predictions in another notebook, but loading the saved predictions should be faster and easier.

* TF models and their metadata are saved in the ```saved_models``` directory.

* Once training is done, run the following notebooks to perform the analysis and make/save plots for the paper:
    * ```_analyze_models_vX.X.ipynb```
    * ```_visualize_xai_vX.X.ipynb```
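The copy-and-rename workflow described above can be sketched as a plain dictionary of experiment designs. This is a hypothetical illustration of how something like `experiment_settings.get_settings` might be organized, not the actual file: the experiment names and all values are placeholders, and only the keys are taken from what this notebook reads.

```python
import copy

# Hypothetical registry: one entry per experiment design (values are placeholders).
EXPERIMENTS = {
    "exp1": {
        "network_type": "shash2",
        "target_temp": 1.5,
        "ssp": "370",
        "n_models": 20,
        "rng_seed": 8889,
        "patience": 20,
        "n_epochs": 1_000,
        "batch_size": 64,
        "save_model": True,
    },
}

# A new design is a copy of an old one under a new, unique name,
# here differing only in the target temperature.
EXPERIMENTS["exp2"] = {**EXPERIMENTS["exp1"], "target_temp": 2.0}


def get_settings(exp_name):
    """Return a fresh copy of the named experiment's settings."""
    if exp_name not in EXPERIMENTS:
        raise ValueError(f"experiment {exp_name!r} is not defined")
    return copy.deepcopy(EXPERIMENTS[exp_name])
```

Returning a copy is the important design choice here: the training loop writes `settings["seed"]` into the dict it receives, and a copy keeps that from silently modifying the stored experiment design.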

## Python stuff

In [None]:
import sys, os
import importlib  # replaces the deprecated imp module (removed in Python 3.12)

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import tensorflow as tf
import tensorflow_probability as tfp

import experiment_settings
import file_methods, plots, custom_metrics, network, data_processing

import matplotlib as mpl
mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["figure.dpi"] = 150
savefig_dpi = 300
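import warnings
# np.warnings is not available in recent NumPy releases; use the stdlib module instead.
# VisibleDeprecationWarning lives in np.exceptions from NumPy 1.25 onward.
warnings.filterwarnings("ignore", category=getattr(np, "exceptions", np).VisibleDeprecationWarning)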

In [None]:
print(f"python version = {sys.version}")
print(f"numpy version = {np.__version__}")
print(f"xarray version = {xr.__version__}")  
print(f"tensorflow version = {tf.__version__}")  
print(f"tensorflow-probability version = {tfp.__version__}")  

## User Choices

In [None]:
EXP_NAME = 'exp15C_126'
OVERWRITE_MODEL = True

#-----------------------------------------------------

settings = experiment_settings.get_settings(EXP_NAME)
display(settings)

MODEL_DIRECTORY = 'saved_models/'        
PREDICTIONS_DIRECTORY = 'saved_predictions/'
DATA_DIRECTORY = 'data/'
DIAGNOSTICS_DIRECTORY = 'model_diagnostics/'
FIGURE_DIRECTORY = 'figures/'
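The training loop writes figures with `plt.savefig`, which raises an error if the target directory does not exist. This notebook does not show whether `file_methods` creates these folders, so a defensive option, sketched here under that assumption, is to create the output directories up front. The helper name `ensure_directories` is hypothetical.

```python
import os


def ensure_directories(*directories):
    """Create each output directory (and any missing parents) if it does not exist."""
    for directory in directories:
        os.makedirs(directory, exist_ok=True)


# In this notebook the call would be:
# ensure_directories(MODEL_DIRECTORY, PREDICTIONS_DIRECTORY,
#                    DIAGNOSTICS_DIRECTORY, FIGURE_DIRECTORY)
```

`exist_ok=True` makes the call safe to repeat on every run. `DATA_DIRECTORY` is left out because it is an input that should already exist.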

## Plotting functions

In [None]:
def plot_one_to_one_diagnostic():
    # NOTE: relies on notebook globals (model, the x_*/y_* splits, y_yrs_train,
    # target_years, N_GCMS, N_VAL, settings) that are set in the training loop.
    # For "shash2", plot the first output column as the prediction; otherwise this
    # assumes a single output column, so [:, None] plus flatten() keeps one value per sample.
    if settings['network_type'] == "shash2":
        top_pred_idx = 0
    else:
        top_pred_idx = None

    YEARS_UNIQUE = np.unique(y_yrs_train)
    predict_train = model.predict(x_train)[:,top_pred_idx].flatten()
    predict_val = model.predict(x_val)[:,top_pred_idx].flatten()
    predict_test = model.predict(x_test)[:,top_pred_idx].flatten()
    mae = np.mean(np.abs(predict_test-y_test[:]))
    
    #--------------------------------
    clr = ('tab:purple','tab:orange', 'tab:blue', 'tab:green', 'gold', 'brown','black','darkorange')
    plt.subplots(1,2,figsize=(15,6))

    # left panel: predicted vs. true number of years until the target is reached
    plt.subplot(1,2,1)
    plt.plot(y_train, predict_train,'.',color='gray',alpha=.25, label='training')
    plt.plot(y_val, predict_val,'.', label='validation',color='gray',alpha=.75,)
    plt.plot(y_test, predict_test,'.', label='testing')    
    plt.plot(y_train,y_train,'--',color='fuchsia')
    plt.axvline(x=0,color='gray',linewidth=1)
    plt.axhline(y=0,color='gray',linewidth=1)
    plt.title('Testing MAE = ' + str(mae.round(2)) + ' years')
    plt.xlabel('true number of years until target is reached')
    plt.ylabel('predicted number of years until target is reached')
    plt.legend()

    # right panel: predictions vs. year of the map, with each GCM's target year dashed
    plt.subplot(1,2,2)
    plt.plot(y_yrs_train, predict_train,'.',color='gray',alpha=.5, label='training')
    plt.title('Time to Target Year for ' + str(settings['target_temp']) + 'C using ssp' + str(settings['ssp']))
    plt.xlabel('year of map')
    plt.ylabel('predicted number of years until target is reached')
    plt.axhline(y=0, color='gray', linewidth=1)

    # validation samples are assumed to be ordered (GCM, member, year)
    predict_val_mat = predict_val.reshape(N_GCMS,N_VAL,len(YEARS_UNIQUE))
    for i in np.arange(0,predict_val_mat.shape[0]):
        color = clr[i % len(clr)]  # cycle colors if there are more GCMs than colors
        plt.plot(YEARS_UNIQUE, predict_val_mat[i,:,:].swapaxes(1,0),'.', label='validation', color=color)
        plt.axvline(x=target_years[i],linestyle='--',color=color)
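Two details of `plot_one_to_one_diagnostic` are easy to misread. First, `top_pred_idx = None` does not select a column: `[:, None]` inserts a new axis, so after `.flatten()` it yields one value per sample only when the network has a single output column. A multi-column output, such as the `shash2` network presumably produces, needs an explicit column index. Second, the validation reshape assumes samples are ordered by GCM, then ensemble member, then year. The sketch below checks both with small synthetic arrays; the sizes are arbitrary.

```python
import numpy as np

n_samples = 6

# Single-output network: predictions have shape (n_samples, 1)
single = np.arange(n_samples, dtype=float).reshape(n_samples, 1)
# [:, None] gives shape (n_samples, 1, 1); flatten() still leaves one value per sample
single_flat = single[:, None].flatten()

# Two-column output: [:, None] keeps both columns and doubles the length,
# so the column must be selected explicitly with [:, 0]
two_param = np.column_stack([np.arange(n_samples), np.ones(n_samples)])
two_param_none = two_param[:, None].flatten()
two_param_col0 = two_param[:, 0].flatten()

# Validation reshape, assuming samples are ordered (GCM, member, year)
n_gcms, n_val, n_years = 2, 3, 4
flat = np.arange(n_gcms * n_val * n_years)
mat = flat.reshape(n_gcms, n_val, n_years)
# For one GCM, swapaxes gives (n_years, n_val): plt.plot draws one series per member
series_gcm1 = mat[1, :, :].swapaxes(1, 0)

print(single_flat.size, two_param_none.size, two_param_col0.size)  # 6 12 6
print(series_gcm1.shape)  # (4, 3)
```

Each column of `series_gcm1` is one member's consecutive years, which is why the plot passes the swapped array straight to `plt.plot`.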

## Initial housekeeping

In [None]:
## determine how many GCMs are being used
filenames = file_methods.get_cmip_filenames(settings, verbose=0)
N_GCMS = len(filenames)

# load observations for diagnostics plotting
da_obs, x_obs, global_mean_obs = data_processing.get_observations(DATA_DIRECTORY, settings)
N_TRAIN, N_VAL, N_TEST, ALL_MEMBERS = data_processing.get_members(settings)

In [None]:
# define early stopping callback (cannot be done elsewhere)
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                   patience=settings['patience'],
                                                   verbose=1,
                                                   mode='auto',
                                                   restore_best_weights=True)
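With `monitor='val_loss'` and `mode='auto'`, Keras treats a lower validation loss as an improvement. The plain-Python sketch below approximates the documented behavior of `EarlyStopping` with the default `min_delta=0`. Training stops once `patience` consecutive epochs pass without improvement, and with `restore_best_weights=True` the weights from the best epoch are the ones restored. The function `early_stopping_epochs` is a hypothetical helper for illustration, not part of Keras. Reusing the same callback object for every model in the loop is fine, because Keras resets the callback's counters at the start of each `fit` call.

```python
def early_stopping_epochs(val_losses, patience):
    """Return (best_epoch, stopped_epoch) for a sequence of validation losses.

    Mimics min-mode early stopping with min_delta=0: an epoch counts as an
    improvement only if its loss is strictly lower than the best so far.
    stopped_epoch is None if the patience criterion is never met.
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.65]
print(early_stopping_epochs(losses, patience=3))  # (2, 5)
```

With `patience=3`, training stops at epoch 5 and never sees the better loss at epoch 6. With `patience=4` it would. This trade-off is what `settings['patience']` controls.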

## Train the network

In [None]:
rng = np.random.default_rng(settings["rng_seed"])

for iloop in np.arange(settings['n_models']):
    seed = rng.integers(low=1_000,high=10_000,size=1)[0]
    settings["seed"] = int(seed)
    tf.random.set_seed(settings["seed"])
    np.random.seed(settings["seed"])

    # get model name
    model_name = file_methods.get_model_name(settings)
    if os.path.exists(MODEL_DIRECTORY + model_name + "_model") and not OVERWRITE_MODEL:
        print(model_name + ' exists. Skipping...')
        print("================================\n")
        continue    
    
    # get the data
    (x_train, 
     x_val, 
     x_test, 
     y_train, 
     y_val, 
     y_test, 
     onehot_train, 
     onehot_val, 
     onehot_test, 
     y_yrs_train, 
     y_yrs_val, 
     y_yrs_test, 
     target_years, 
     map_shape,
     settings) = data_processing.get_cmip_data(DATA_DIRECTORY, settings)

    #----------------------------------------        
    tf.keras.backend.clear_session()                
    model = network.compile_model(x_train, y_train, settings)
    history = model.fit(x_train, onehot_train, 
                        epochs=settings['n_epochs'], 
                        batch_size = settings['batch_size'], 
                        shuffle=True,
                        validation_data=(x_val, onehot_val),
                        callbacks=[early_stopping,],
                        verbose=0,                        
                       )
    #----------------------------------------
    # create predictions for observations with this model
    pred_obs = model.predict(x_obs)
    
    #----------------------------------------
    # save the tensorflow model and obs predictions
    if settings["save_model"]:
        file_methods.save_tf_model(model, model_name, MODEL_DIRECTORY, settings)
        file_methods.save_pred_obs(pred_obs, 
                                   PREDICTIONS_DIRECTORY+model_name + '_obs_predictions',
                                  )

    #----------------------------------------
    # create and save diagnostics plots
    plots.plot_metrics_panels(history,settings)
    plt.savefig(DIAGNOSTICS_DIRECTORY + model_name + '_metrics_diagnostic' + '.png', dpi=savefig_dpi)
    plt.show()             
    
    plot_one_to_one_diagnostic()
    plt.savefig(DIAGNOSTICS_DIRECTORY + model_name + '_one_to_one_diagnostic' + '.png', dpi=savefig_dpi)
    plt.show()   
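The loop draws every model's seed from one generator seeded with `settings["rng_seed"]`, and it draws the seed *before* checking whether the model already exists. That ordering is what lets `OVERWRITE_MODEL = False` resume where a previous run left off. Skipped models still consume their seed, so the remaining models get the same seeds they would have had in an uninterrupted run. The sketch below illustrates that logic. The naming scheme in `model_name_for` is hypothetical (the notebook uses `file_methods.get_model_name`), and a set of names stands in for the saved-model directory.

```python
import numpy as np


def model_name_for(exp_name, seed):
    # Hypothetical naming scheme; the notebook uses file_methods.get_model_name
    return f"{exp_name}_seed{seed}"


def models_to_train(exp_name, rng_seed, n_models, existing, overwrite=False):
    """Return the (name, seed) pairs that would be trained, skipping existing names."""
    rng = np.random.default_rng(rng_seed)
    to_train = []
    for _ in range(n_models):
        # Draw the seed first so that skipped models still advance the generator
        seed = int(rng.integers(low=1_000, high=10_000, size=1)[0])
        name = model_name_for(exp_name, seed)
        if name in existing and not overwrite:
            continue
        to_train.append((name, seed))
    return to_train


# A fresh run trains all five models...
full = models_to_train("exp1", rng_seed=8889, n_models=5, existing=set())
# ...and a resumed run, after the first two were saved, trains only the rest
done = {name for name, _ in full[:2]}
resumed = models_to_train("exp1", rng_seed=8889, n_models=5, existing=done)
print(resumed == [pair for pair in full if pair[0] not in done])  # True
```

If the seed were drawn only for models that are actually trained, a resumed run would hand the first new model the seed of an already-saved one and the two runs would diverge.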
    