# Analyze models
##### authors: Elizabeth A. Barnes and Noah Diffenbaugh
##### date: March 20, 2022


## Python stuff

In [None]:
import sys, imp, os, copy

import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import tensorflow as tf
import tensorflow_probability as tfp

import scipy.stats as stats
import seaborn as sns

import experiment_settings
import file_methods, plots, data_processing

from scipy.signal import savgol_filter

import matplotlib as mpl
# Figure defaults: white figure background, 150 dpi for rendered figures.
mpl.rcParams.update({"figure.facecolor": "white", "figure.dpi": 150})
# Resolution used when figures are written to disk with plots.savefig.
savefig_dpi = 300
# `np.warnings` was only an accidental re-export of the standard-library
# module and is no longer available in recent NumPy releases, so use the
# stdlib `warnings` module directly.
import warnings

# VisibleDeprecationWarning lives in `np.exceptions` on newer NumPy; fall back
# to the top-level name on older releases that lack that namespace.
try:
    VisibleDeprecationWarning = np.exceptions.VisibleDeprecationWarning
except AttributeError:
    VisibleDeprecationWarning = np.VisibleDeprecationWarning
warnings.filterwarnings("ignore", category=VisibleDeprecationWarning)

In [None]:
# Report interpreter and key library versions, one per line.
print(
    f"python version = {sys.version}",
    f"numpy version = {np.__version__}",
    f"xarray version = {xr.__version__}",
    f"tensorflow version = {tf.__version__}",
    f"tensorflow-probability version = {tfp.__version__}",
    sep="\n",
)

## User Choices

In [None]:
# Experiment to analyze. Other experiment names previously used here:
# 'exp20C_126_all7', 'exp15C_370_uniform', 'exp15C_126_uniform'.
EXP_NAME = "exp20C_126"
# -------------------------------------------------------

# Settings for the chosen experiment.
settings = experiment_settings.get_settings(EXP_NAME)
# display(settings)

# Relative directories used by the cells below.
MODEL_DIRECTORY = "saved_models/"
PREDICTIONS_DIRECTORY = "saved_predictions/"
DATA_DIRECTORY = "data/"
DIAGNOSTICS_DIRECTORY = "model_diagnostics/"
FIGURE_DIRECTORY = "figures/"

## Get seed to show in plot
First run `compare_random_seeds.ipynb` so that the metrics for your experiments are saved in `saved_predictions/df_random_seed.pickle`, which the next cell reads.

In [None]:
# Load the per-seed metrics table and keep only the rows for this experiment.
df_metrics = pd.read_pickle(PREDICTIONS_DIRECTORY + "df_random_seed.pickle")
df = df_metrics[df_metrics["exp_name"] == EXP_NAME]
if df.empty:
    raise ValueError(
        f"no rows for experiment {EXP_NAME!r} in df_random_seed.pickle; "
        "run compare_random_seeds.ipynb first"
    )

# Seed with the lowest validation loss. `idxmin()` returns an index *label*,
# so it has to be resolved with `.loc`; `.iloc` would interpret it as a row
# position and can select the wrong row when the index is not 0..n-1.
PLOT_SEED = df.loc[df["loss_val"].idxmin(), "seed"]
# Manual override: the seed used for the figures is pinned here, so the value
# selected above is not used.
PLOT_SEED = 2247
print(PLOT_SEED)
display(df)

## Plotting Functions

In [None]:
# Font size used for axis labels and text annotations in the figures below.
FS = 10

### for white background...
# plt.rc('text',usetex=True)
plt.rc("text", usetex=False)
# plt.rc('font',**{'family':'sans-serif','sans-serif':['Avant Garde']})
plt.rc("font", family="sans-serif", **{"sans-serif": ["Helvetica"]})
plt.rc("savefig", facecolor="white")
# Axes: white background and dim-grey axis labels (set in a single call).
plt.rc("axes", facecolor="white", labelcolor="dimgrey")
plt.rc("xtick", color="dimgrey")
plt.rc("ytick", color="dimgrey")
################################
################################
def adjust_spines(ax, spines):
    """Offset the requested spines outward and hide all the others.

    Parameters
    ----------
    ax : axes-like object
        Must expose a ``spines`` mapping and ``xaxis`` / ``yaxis`` attributes.
    spines : container of str
        Names of the spines to keep, e.g. ``['left', 'bottom']``.

    Spines that are kept are moved 5 points outward; every other spine has
    its color set to ``'none'``. Y ticks are removed unless ``'left'`` is
    kept, and x ticks are removed unless ``'bottom'`` is kept.
    """
    for side, spine in ax.spines.items():
        if side not in spines:
            spine.set_color('none')
            continue
        spine.set_position(('outward', 5))

    # y-axis ticks only make sense when the left spine is shown
    if 'left' not in spines:
        ax.yaxis.set_ticks([])
    else:
        ax.yaxis.set_ticks_position('left')

    # x-axis ticks only make sense when the bottom spine is shown
    if 'bottom' not in spines:
        ax.xaxis.set_ticks([])
    else:
        ax.xaxis.set_ticks_position('bottom')

def format_spines(ax):
    """Apply the notebook's axis styling to ``ax``.

    Keeps only the left and bottom spines (offset outward through
    ``adjust_spines``), draws them in dim grey with a line width of 2, and
    styles the major ticks on both axes to match.
    """
    adjust_spines(ax, ['left', 'bottom'])

    # hide the top/right spines
    for hidden in ('top', 'right'):
        ax.spines[hidden].set_color('none')

    # visible spines: dim grey, then thicker lines
    for shown in ('left', 'bottom'):
        ax.spines[shown].set_color('dimgrey')
    for shown in ('left', 'bottom'):
        ax.spines[shown].set_linewidth(2)

    ax.tick_params('both', which='major', length=4, width=2, color='dimgrey')
#     ax.yaxis.grid(zorder=1,color='dimgrey',alpha=0.35)
    

In [None]:
# Base name for saved figures: "<experiment>_<seed>".
model_name_plot = f"{EXP_NAME}_{PLOT_SEED}"

In [None]:
# Reload the project helper modules so on-disk edits are picked up without
# restarting the kernel. `importlib.reload` replaces `imp.reload`: the `imp`
# module is deprecated and has been removed in Python 3.12.
import importlib

importlib.reload(file_methods)
importlib.reload(data_processing)

rng = np.random.default_rng(settings["rng_seed"])
settings["seed"] = PLOT_SEED

# get model name
model_name = file_methods.get_model_name(settings)

# load the model
model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)

# NOTE(review): `settings_new` is an alias of `settings`, not a copy, so the
# two overrides below are also what `get_cmip_data` receives through
# `settings`. Later cells read `settings["n_train_val_test"]` as well, so
# confirm that dependency before replacing this with a copy.
settings_new = settings
settings_new["gcmsub"] = "OOS"
settings_new["n_train_val_test"] = (0, 0, 5)

# get the data; `settings` is rebound to whatever get_cmip_data returns
(
    x_train,
    x_val,
    x_test,
    y_train,
    y_val,
    y_test,
    onehot_train,
    onehot_val,
    onehot_test,
    y_yrs_train,
    y_yrs_val,
    y_yrs_test,
    target_years,
    map_shape,
    settings,
) = data_processing.get_cmip_data(DATA_DIRECTORY, settings)

       

In [None]:
# Out-of-sample inference figure.
# Left panel: global-mean temperature anomalies (individual members and the
# ensemble mean) for each plotted model.
# Right panel: the network's predicted mean for the test members of each model.
filenames = file_methods.get_cmip_filenames(settings_new, verbose=0)
N_GCMS = len(filenames)
N_MEMBERS = settings["n_train_val_test"][-1]
target_list = []

# One color per plotted model. The loops below used to be hard-coded to three
# models; bound them by the files and colors actually available so that a
# shorter file list does not raise an IndexError.
clr = ('lawngreen', 'tab:green', 'chocolate')
n_models_plot = min(3, N_GCMS, len(clr))
fig, axs = plt.subplots(1, 2, figsize=(3*2.5, 2.25))


#---------------------------------------------
plt.subplot(1, 2, 1)
for imodel in range(n_models_plot):
    f = filenames[imodel]
    print(f)
    da = file_methods.get_netcdf_da(DATA_DIRECTORY + f)
    f_labels, f_years, f_target_year = data_processing.get_labels(da, settings_new)

    # compute global mean and its anomaly relative to the baseline period
    global_mean = data_processing.compute_global_mean(da)
    baseline_mean = global_mean.sel(
        time=slice(str(settings["baseline_yr_bounds"][0]),
                   str(settings["baseline_yr_bounds"][1]))
    ).mean('time')
    global_mean_anomalies = global_mean - baseline_mean

    # mean over the first axis, optionally smoothed with a
    # Savitzky-Golay filter (window length 15, polynomial order 3)
    mean_curve = np.mean(global_mean_anomalies, axis=0)
    if settings["smooth"] == True:
        mean_curve = savgol_filter(mean_curve, 15, 3)

    # plot the members
    plt.plot(f_years,
             np.swapaxes(global_mean_anomalies.to_numpy(), 1, 0),
             color='gray',
             linewidth=.5,
             alpha=.3,
             zorder=1,
            )
    # plot ensemble mean
    plt.plot(f_years,
             mean_curve,
             color=clr[imodel],
             linewidth=1.,
             alpha=1.,
             zorder=4,
            )

    # mark the target year; a value of 2100 is skipped
    # (presumably meaning the threshold is never reached - TODO confirm)
    target_list.append(f_target_year)
    if f_target_year != 2100:
        plt.axvline(x=f_target_year,
                    color=clr[imodel],
                    linewidth=1.,
                    alpha=1.,
                    linestyle='--',
                   )

# plt.title('Global Mean Temperatures for SSP'+ str(settings["ssp"]),fontsize=12)
plt.xlabel('year', fontsize=FS)
plt.ylabel('temperature anomaly', fontsize=FS)
plt.xticks(np.arange(1850, 2150, 50), np.arange(1850, 2150, 50))

plt.ylim(-.4, 2.9)
plt.axhline(y=0, color='black', linewidth=0.5)
plt.axhline(y=1.1, color='gray', linewidth=1.0, linestyle='--')
plt.axhline(y=1.5, color='gray', linewidth=1.0, linestyle='--')
plt.axhline(y=2.0, color='gray', linewidth=1.0, linestyle='--')

plt.text(1850,
         2.0,
         str(settings["target_temp"]) + "C\nSSP" + settings["ssp"][0] + '-' + settings["ssp"][1] + '.' + settings["ssp"][-1],
         fontsize=FS,
         horizontalalignment="left",
         verticalalignment="bottom",
         color='k',
         weight="bold",
        )

format_spines(plt.gca())


#--------------------------------------
plt.subplot(1, 2, 2)
# plot the predictions for every test member of each model
YEARS_UNIQUE = np.unique(y_yrs_test)

# Column 0 / column 1 of the prediction are reshaped to
# (model, member, year); this assumes x_test is ordered that way.
miroc_pred = model.predict(x_test)
mu_pred = miroc_pred[:, 0].reshape(N_GCMS, N_MEMBERS, len(YEARS_UNIQUE))
sigma_pred = miroc_pred[:, 1].reshape(N_GCMS, N_MEMBERS, len(YEARS_UNIQUE))

# index of 2021 along the year axis; loop-invariant, so computed once
iy = np.where(YEARS_UNIQUE == 2021)[0]

for imodel in range(n_models_plot):
    print(filenames[imodel])
    mu_members = mu_pred[imodel, :, :].swapaxes(1, 0)        # (year, member)
    sigma_members = sigma_pred[imodel, :, :].swapaxes(1, 0)  # (year, member)
    mu_mean = np.mean(mu_members, axis=1)
    print(mu_mean[iy], np.mean(sigma_members, axis=1)[iy])

    # individual members (faint) and their mean (bold)
    plt.plot(YEARS_UNIQUE, mu_members, color=clr[imodel], linewidth=1., alpha=.25)
    plt.plot(YEARS_UNIQUE, mu_mean,
             color=clr[imodel],
             linewidth=2.,
             alpha=.75,
             # label=label1,
            )

    if target_list[imodel] != 2100:
        plt.axvline(x=target_list[imodel],
                    color=clr[imodel],
                    linewidth=1.,
                    alpha=1.,
                    linestyle='--',
                   )

# No artist in this panel carries a label, so a legend call here would only
# trigger matplotlib's "no labelled artists" warning; it is omitted until
# labels are added.
# plt.ylim(-17,30)
plt.ylim(-27, 65)
plt.xlim(2020, 2100)
ax = plt.gca()
format_spines(ax)
plt.ylabel('predicted years\nuntil ' + str(settings["target_temp"]) + 'C threshold')
# plt.title(model_name_plot)

plt.axhline(y=0, color='black', linewidth=0.5)

plt.tight_layout()
plots.savefig(FIGURE_DIRECTORY + model_name_plot + '_OOS_inference', dpi=savefig_dpi)
plt.show()
    

In [None]:
# from scipy.signal import savgol_filter

# raw_values = np.mean(global_mean,axis=0)
# baseline_mean = raw_values.sel(time=slice(str(settings["baseline_yr_bounds"][0]),str(settings["baseline_yr_bounds"][1]))).mean('time')
# iwarmer = np.where(raw_values > baseline_mean.values+settings["target_temp"])[0]
# target_year = raw_values["time"].values[iwarmer[0]].year
# plt.plot(raw_values["time.year"],raw_values)

# smoothed_values = np.mean(global_mean,axis=0)
# smoothed_values = savgol_filter(smoothed_values, 15, 3) # window size 51, polynomial order 3

# # poly = np.poly1d(np.polyfit(raw_values["time.year"],smoothed_values,deg=10))
# # smoothed_values = poly(raw_values["time.year"])
# baseline_mean = raw_values.sel(time=slice(str(settings["baseline_yr_bounds"][0]),str(settings["baseline_yr_bounds"][1]))).mean('time')
# iwarmer = np.where(smoothed_values > baseline_mean.values+settings["target_temp"])[0]
# target_year = raw_values["time"].values[iwarmer[0]].year
# plt.plot(raw_values["time.year"],smoothed_values)

# print(target_year)

In [None]:
# filenames = file_methods.get_cmip_filenames(settings_new, verbose=0)
# N_GCMS = len(filenames)
# N_MEMBERS = settings["n_train_val_test"][-1]
# target_list = []

# # loop through the models and plot
# clr = ('fuchsia','tab:green','tab:blue','gold','tab:purple','tab:orange','k')
# fig,axs = plt.subplots(1,2,figsize=(3*2.5,2.25))


# #---------------------------------------------
# plt.subplot(1,2,1)
# for imodel in np.arange(0,3):
#     f = filenames[imodel]
#     print(f)
#     da = file_methods.get_netcdf_da(DATA_DIRECTORY + f)
#     f_labels, f_years, f_target_year = data_processing.get_labels(da, settings_new,)

#     # compute global mean
#     global_mean = data_processing.compute_global_mean(da)
#     baseline_mean = global_mean.sel(time=slice(str(settings["baseline_yr_bounds"][0]),str(settings["baseline_yr_bounds"][1]))).mean('time')
#     global_mean_anomalies = global_mean - baseline_mean
    
#     # plot the members
#     plt.plot(f_years, 
#              np.swapaxes(global_mean_anomalies.to_numpy(),1,0), 
#              color='gray',
#              linewidth=.5,
#              alpha=.3,
#              zorder=1,
#             )
#     # plot ensemble mean
#     print('max temp = ' + str(np.round(np.max(np.mean(global_mean_anomalies,axis=0)).values,2)))
#     print(np.round((np.mean(global_mean_anomalies,axis=0)).values,2))
#     year_list = da["time.year"].values
#     print('argmax   = ' + str(year_list[np.argmax(np.mean(global_mean_anomalies,axis=0).values)]))
#     plt.plot(f_years, 
#              np.mean(global_mean_anomalies,axis=0), 
#              color=clr[imodel],
#              linewidth=1.,
#              alpha=1.,
#              zorder=4,
#             )
    
#     #plot the year
#     target_list.append(f_target_year)
#     if(f_target_year != 2100):
#         plt.axvline(x=f_target_year,
#                     color=clr[imodel],
#                     linewidth=1.,
#                     alpha=1.,
#                     linestyle='--',
#                    )                
    
# # plt.title('Global Mean Temperatures for SSP'+ str(settings["ssp"]),fontsize=12)
# plt.xlabel('year',fontsize=FS)
# plt.ylabel('temperature anomaly',fontsize=FS)
# plt.xticks(np.arange(1850,2150,50),np.arange(1850,2150,50))

# plt.ylim(-.4,2.9)
# plt.axhline(y=0, color='black', linewidth=0.5)
# plt.axhline(y=1.1, color='gray', linewidth=1.0, linestyle='--')
# plt.axhline(y=1.5, color='gray', linewidth=1.0, linestyle='--')
# plt.axhline(y=2.0, color='gray', linewidth=1.0, linestyle='--')

# plt.text(1850,
#          2.0,
#          str(settings["target_temp"]) + "C\nSSP" + settings["ssp"][0] + '-' + settings["ssp"][1] + '.' + settings["ssp"][-1],
#          fontsize=FS,
#          horizontalalignment="left",
#          verticalalignment="bottom",
#          color='k', 
#          weight="bold",
#         )

# format_spines(plt.gca())


# #--------------------------------------
# plt.subplot(1,2,2)
# # plot the predictions for 8 members
# YEARS_UNIQUE = np.unique(y_yrs_test)

# miroc_pred = model.predict(x_test)
# mu_pred = miroc_pred[:,0].reshape(N_GCMS,N_MEMBERS,len(YEARS_UNIQUE))
# sigma_pred = miroc_pred[:,1].reshape(N_GCMS,N_MEMBERS,len(YEARS_UNIQUE))

# for imodel in np.arange(0,3):
#     print(filenames[imodel])
#     # plt.plot(YEARS_UNIQUE,mu_pred[imodel,:,:].swapaxes(1,0),color=clr[imodel],linewidth=1.,alpha=.25)
#     plt.errorbar(YEARS_UNIQUE,np.mean(mu_pred[imodel,:,:].swapaxes(1,0),axis=1),yerr=np.mean(sigma_pred[imodel,:,:].swapaxes(1,0),axis=1),
#              color=clr[imodel],
#              linewidth=2.,
#              alpha=.75,
#              # label=label1,
#             )

#     if(target_list[imodel] != 2100):
#         plt.axvline(x=target_list[imodel],
#                     color=clr[imodel],
#                     linewidth=1.,
#                     alpha=1.,
#                     linestyle='--',
#                    )                

# plt.legend(frameon=False)    
# # plt.ylim(-17,30)    
# plt.ylim(-27,65)    
# plt.xlim(2020,2100)
# ax = plt.gca()
# format_spines(ax)
# plt.ylabel('predicted years\nuntil ' + str(settings["target_temp"]) + 'C threshold')
# # plt.title(model_name_plot)

# plt.axhline(y=0, color='black', linewidth=0.5)

# plt.tight_layout()
# # plots.savefig(FIGURE_DIRECTORY + model_name_plot + '_OOS_inference', dpi=savefig_dpi)
# plt.show()    
    

In [None]:
# Quick look: member-mean prediction for `year_plot`, added to that year.
year_plot = 2021
i = np.where(YEARS_UNIQUE == year_plot)[0]
# sanity check on the selected year (not the cell's last expression)
YEARS_UNIQUE[i]
year_plot + np.mean(mu_pred[:, :, i], axis=1)

In [None]:
# Observations for the settings as they stand on entry to this cell
# (labelled BEST by the original author); used for diagnostics plotting.
(
    da_obs,
    x_obs,
    global_mean_obs,
) = data_processing.get_observations(DATA_DIRECTORY, settings)

# Switch the observational product to GISS and load it as well.
# Note that settings["obsdata"] stays 'GISS' after this cell.
settings["obsdata"] = 'GISS'
(
    da_obs_giss,
    x_obs_giss,
    global_mean_obs_giss,
) = data_processing.get_observations(DATA_DIRECTORY, settings)

# Run the network on the first set of observation inputs (loaded before the
# switch to GISS).
pred_obs = model.predict(x_obs)

#### 

In [None]:
# Predicted normal densities for PLOT_YEAR: the observation-based prediction
# (orange) against every test member of every model.
PLOT_YEAR = 2021
iyear = np.where(YEARS_UNIQUE == PLOT_YEAR)[0]
if iyear.size == 0:
    # fail early with a clear message instead of an opaque shape error below
    raise ValueError(f"PLOT_YEAR={PLOT_YEAR} is not present in YEARS_UNIQUE")
# offsets, in years relative to PLOT_YEAR, at which the densities are evaluated
norm_incs = np.arange(-80, 80, 1)

plt.figure(figsize=(3, 2))
# Last row of the observation prediction: column 0 is used as the mean and
# column 1 as the standard deviation (presumably the most recent year - confirm).
norm_dist = tfp.distributions.Normal(pred_obs[-1, 0], pred_obs[-1, 1])
norm_cpd = norm_dist.prob(norm_incs)
plt.plot(norm_incs + PLOT_YEAR,
         norm_cpd,
         color='tab:orange',
         linewidth=2.,
         alpha=1.,
         zorder=50,
        )

# One faint curve per (model, member). The unused per-model accumulator that
# was allocated here has been removed.
for imodel in range(N_GCMS):
    for ens in range(mu_pred.shape[1]):
        norm_dist = tfp.distributions.Normal(mu_pred[imodel, ens, iyear], sigma_pred[imodel, ens, iyear])
        norm_cpd = norm_dist.prob(norm_incs)
        plt.plot(norm_incs + PLOT_YEAR,
                 norm_cpd,
                 color=clr[imodel],
                 linewidth=.75,
                 alpha=.25,
                )

plt.text(2021,
         .15,
         str(settings["target_temp"]) + "C\nSSP" + settings["ssp"][0] + '-' + settings["ssp"][1] + '.' + settings["ssp"][-1],
         fontsize=FS,
         horizontalalignment="left",
         verticalalignment="top",
         color='k',
         # weight="bold",
        )

plt.xlim(2020, 2100)
plt.yticks(np.arange(0, .25, .02), np.arange(0, .25, .02).round(2))
plt.ylim(-0.001, .15)
format_spines(plt.gca())
    

In [None]:
# Predicted normal densities for PLOT_YEAR: the observation-based prediction
# (orange) against a single test member from each of the first two models.
PLOT_YEAR = 2021
iyear = np.where(YEARS_UNIQUE == PLOT_YEAR)[0]
if iyear.size == 0:
    # fail early with a clear message instead of an opaque shape error below
    raise ValueError(f"PLOT_YEAR={PLOT_YEAR} is not present in YEARS_UNIQUE")
# offsets, in years relative to PLOT_YEAR, at which the densities are evaluated
norm_incs = np.arange(-80, 80, 1)

plt.figure(figsize=(3, 2))
# Last row of the observation prediction: column 0 is used as the mean and
# column 1 as the standard deviation (presumably the most recent year - confirm).
norm_dist = tfp.distributions.Normal(pred_obs[-1, 0], pred_obs[-1, 1])
norm_cpd = norm_dist.prob(norm_incs)
plt.plot(norm_incs + PLOT_YEAR,
         norm_cpd,
         color='tab:orange',
         linewidth=2.,
         alpha=1.,
         zorder=100,
        )

# Member to highlight. Index 7 was hard-coded, which is out of range whenever
# fewer than eight test members are loaded, so clamp it to the last available
# member. With eight or more members the choice is unchanged.
ens = min(7, mu_pred.shape[1] - 1)
for imodel in (0, 1):
    norm_dist = tfp.distributions.Normal(mu_pred[imodel, ens, iyear], sigma_pred[imodel, ens, iyear])
    norm_cpd = norm_dist.prob(norm_incs)
    plt.plot(norm_incs + PLOT_YEAR,
             norm_cpd,
             color=clr[imodel],
             linewidth=2.,
             alpha=1.,
             zorder=100,
            )

plt.text(2021,
         .15,
         str(settings["target_temp"]) + "C\nSSP" + settings["ssp"][0] + '-' + settings["ssp"][1] + '.' + settings["ssp"][-1],
         fontsize=FS,
         horizontalalignment="left",
         verticalalignment="top",
         color='k',
         # weight="bold",
        )

plt.xlim(2020, 2100)
plt.yticks(np.arange(0, .25, .05), np.arange(0, .25, .05).round(2))
plt.ylim(-0.001, .15)
format_spines(plt.gca())
    