# Compare across random seeds
##### authors: Elizabeth A. Barnes and Noah Diffenbaugh
##### date: March 25, 2022


## Python setup

In [None]:
%%javascript
// Classic-notebook front-end tweak: once the "notebook/js/outputarea" module is
// loaded, set OutputArea.auto_scroll_threshold to -1 and log that it was done.
// NOTE(review): presumably this stops long cell outputs from being collapsed
// into a scrolling box; it relies on the classic Notebook `require` loader.
require(
        ["notebook/js/outputarea"],
        function (oa) {
            oa.OutputArea.auto_scroll_threshold = -1;
            console.log("Setting auto_scroll_threshold to -1");
        });

%%javascript
// Replace Notebook.prototype.scroll_to_bottom with an empty function so calls to it do nothing.
// NOTE(review): presumably keeps the page from jumping while cells run; classic Notebook only.
require("notebook/js/notebook").Notebook.prototype.scroll_to_bottom = function () {}

In [None]:
import sys, os, copy, tqdm
import importlib as imp

import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import tensorflow as tf
import tensorflow_probability as tfp
import silence_tensorflow
silence_tensorflow

import scipy.stats as stats
import seaborn as sns
from tqdm import tqdm

import experiment_settings
import file_methods, plots, data_processing, custom_metrics, network

import matplotlib as mpl
# Figure defaults: white canvas at 150 dpi on screen; `savefig_dpi` is passed
# explicitly to every plots.savefig(...) call further down.
mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["figure.dpi"] = 150
savefig_dpi = 300

import warnings

# `np.warnings` was only an accidental alias of the stdlib module and was
# removed in NumPy 1.24, so the filter is registered through `warnings`
# directly. VisibleDeprecationWarning lives in `np.exceptions` on recent NumPy
# (the top-level name is gone in NumPy 2.0); fall back to the old location on
# releases that predate `np.exceptions`.
try:
    _visible_deprecation_warning = np.exceptions.VisibleDeprecationWarning
except AttributeError:
    _visible_deprecation_warning = np.VisibleDeprecationWarning
warnings.filterwarnings("ignore", category=_visible_deprecation_warning)

# Blanket filter: every remaining warning is silenced for the rest of the session.
warnings.filterwarnings("ignore")

In [None]:
# Report the interpreter and key library versions, one "<name> version = <value>" line each.
_version_report = (
    ("python", sys.version),
    ("numpy", np.__version__),
    ("xarray", xr.__version__),
    ("tensorflow", tf.__version__),
    ("tensorflow-probability", tfp.__version__),
)
for _package, _version in _version_report:
    print(f"{_package} version = {_version}")

## User Choices

In [None]:
# Names of the experiments to evaluate. Each name is passed to
# experiment_settings.get_settings() by the metrics loop below.
EXP_NAME_VEC = (
                'exp11C_370','exp15C_370','exp20C_370','exp15C_126','exp20C_126','exp11C_126',
                'exp11C_245','exp15C_245','exp20C_245',
                'exp0','exp1','exp2','exp3','exp4','exp5',    # ridge sweep w/ hiddens = [10,10]
                'exp10','exp11','exp12','exp13','exp14',      # hiddens sweep w/ ridge = 10.0
                'exp20','exp21','exp22','exp23',              # hiddens sweep w/ ridge = 5.0                
                'exp30','exp31','exp32','exp33',              # ridge sweep w/ hiddens = [25,25]
                'exp15C_370_uniform','exp20C_370_uniform','exp15C_126_uniform','exp20C_126_uniform', 
                'exp20C_126_force','exp20C_126_extended','exp20C_126_max','exp20C_126_all7','exp20C_126_all7_b',
                'exp15C_126_all10',
                'exp20C_126_all7_baseAnoms',
                'exp19C_126_all7','exp19C_126_all7_smooth','exp20C_126_smooth',
                'exp15C_126_noM6','exp15C_126_test','exp15C_126_noSH',
                'exp15C_370_smooth','exp13C_126','exp13C_370',
                'exp15C_126_nohigh10','exp15C_126_nohigh7','exp15C_126_nohigh5',
                'exp15C_126_smooth_nohigh10','exp15C_126_smooth_nohigh7','exp15C_126_smooth_nohigh5',
)

# Switches read by the metrics cell:
#   LOOP_THROUGH_EXP - run the per-experiment / per-seed loop (saving and display happen inside it)
#   SAVE_FILE        - pickle the metrics table once the loop has finished
#   LOAD_METRICS     - start from the previously pickled table and skip (exp_name, seed) pairs already in it
#   OVERWRITE        - with LOAD_METRICS on, remove the stored row and recompute it instead of skipping
LOOP_THROUGH_EXP = True
SAVE_FILE = True
LOAD_METRICS = True
OVERWRITE = False

#-------------------------------------------------------

# Relative directories, used by plain string concatenation (hence the trailing slash).
MODEL_DIRECTORY = 'saved_models/'        
PREDICTIONS_DIRECTORY = 'saved_predictions/'
DATA_DIRECTORY = 'data/'
DIAGNOSTICS_DIRECTORY = 'model_diagnostics/'  # not referenced by the cells visible in this notebook
FIGURE_DIRECTORY = 'figures/'

## Plotting Functions

In [None]:
# Base font size; the plotting cells scale it (FS*0.8 for ticks, FS*1.2 for titles).
FS = 10
# Named colours; the plotting cells pick entries by position through `clr_order`.
palette=("tab:gray","tab:purple","tab:orange","tab:blue","tab:red","tab:green","tab:pink","tab:brown","tab:olive")

### for white background...
# plt.rc('text',usetex=True)
plt.rc('text', usetex=False)
# plt.rc('font',**{'family':'sans-serif','sans-serif':['Avant Garde']})
plt.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']})
plt.rc('savefig', facecolor='white')
# One call sets both axes properties (the label colour was previously set twice).
plt.rc('axes', facecolor='white', labelcolor='dimgrey')
plt.rc('xtick', color='dimgrey')
plt.rc('ytick', color='dimgrey')
################################  
################################  
def adjust_spines(ax, spines):
    """Keep only the requested spines of ``ax`` and hide all the others.

    Parameters
    ----------
    ax : matplotlib Axes (or any object exposing ``spines``, ``xaxis``, ``yaxis``)
        Axes to modify in place.
    spines : container of str
        Names of the spines to keep, e.g. ``['left', 'bottom']``.

    Kept spines are offset outward by 5; every other spine gets the colour
    ``'none'``. Ticks go on the left / bottom when that spine is kept and are
    removed from the corresponding axis otherwise.
    """
    for name, spine in ax.spines.items():
        if name not in spines:
            spine.set_color('none')
            continue
        spine.set_position(('outward', 5))

    # y ticks follow the left spine, x ticks follow the bottom spine.
    for side, axis in (('left', ax.yaxis), ('bottom', ax.xaxis)):
        if side in spines:
            axis.set_ticks_position(side)
        else:
            axis.set_ticks([])

def format_spines(ax):
    """Apply the notebook's axis styling to ``ax`` in place.

    Only the left and bottom spines are kept (via ``adjust_spines``); they are
    drawn in ``'dimgrey'`` with line width 2, and major ticks on both axes get
    length 4, width 2 and the same colour.
    """
    adjust_spines(ax, ['left', 'bottom'])

    for hidden_side in ('top', 'right'):
        ax.spines[hidden_side].set_color('none')

    for visible_side in ('left', 'bottom'):
        visible_spine = ax.spines[visible_side]
        visible_spine.set_color('dimgrey')
        visible_spine.set_linewidth(2)

    ax.tick_params('both', length=4, width=2, which='major', color='dimgrey')
#     ax.yaxis.grid(zorder=1,color='dimgrey',alpha=0.35)    
    

## Load observations

In [None]:
# load observations for diagnostics plotting
# The "exp0" settings are fetched here only to configure the observation loader;
# the metrics loop below replaces `settings` for every experiment.
settings = experiment_settings.get_settings("exp0")
da_obs, x_obs, global_mean_obs = data_processing.get_observations(DATA_DIRECTORY, settings)

# Load a second observational product by switching settings["obsdata"] in place,
# so the order of these two calls matters.
# NOTE(review): the first call presumably loads the BEST product (later columns
# are named "best_2021_*") -- confirm against the exp0 settings.
settings["obsdata"] = 'GISS'
da_obs_giss, x_obs_giss, global_mean_obs_giss = data_processing.get_observations(DATA_DIRECTORY, settings)

## Analyze CMIP results across random seeds

In [None]:
# Build the metrics table: one row per (experiment, random seed) holding the
# validation/test skill of the saved model and its 2021 predictions for the
# two observational products.
df_metrics = pd.DataFrame()
if LOAD_METRICS:
    # NOTE(review): the table is read from "df_random_seed.pickle" but written
    # below to "df_random_seed_rev2.pickle" -- confirm the asymmetry is intended.
    df_metrics = pd.read_pickle(PREDICTIONS_DIRECTORY + "df_random_seed.pickle")


if LOOP_THROUGH_EXP:
    for exp_name in EXP_NAME_VEC:
        settings = experiment_settings.get_settings(exp_name)
        rng = np.random.default_rng(settings["rng_seed"])
        print(exp_name)

        for iloop in np.arange(settings['n_models']):
            # The seed is drawn before any `continue` so the per-experiment
            # sequence of seeds does not depend on which models get skipped.
            seed = rng.integers(low=1_000,high=10_000,size=1)[0]
            settings["seed"] = int(seed)
            tf.random.set_seed(settings["seed"])
            np.random.seed(settings["seed"])

            # check if entry exists
            if LOAD_METRICS:
                is_entry = (
                    (df_metrics["exp_name"] == settings["exp_name"])
                    & (df_metrics["seed"] == settings["seed"])
                )
                if OVERWRITE:
                    if is_entry.any():
                        print('removing entry: ')
                        display(df_metrics[is_entry])
                    # Remove by boolean mask rather than by index label: labels
                    # are not a safe row identifier while rows are being appended.
                    df_metrics = df_metrics[~is_entry]
                elif is_entry.any():
                    continue

            # get model name; silently skip seeds for which no model was saved
            # NOTE(review): in OVERWRITE mode the stored row has already been
            # removed at this point even if the model is missing -- confirm intended.
            model_name = file_methods.get_model_name(settings)
            if not os.path.exists(MODEL_DIRECTORY + model_name + "_model"):
                continue
            model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)

            # get the data (get_cmip_data also returns an updated `settings`)
            (x_train,
             x_val,
             x_test,
             y_train,
             y_val,
             y_test,
             onehot_train,
             onehot_val,
             onehot_test,
             y_yrs_train,
             y_yrs_val,
             y_yrs_test,
             target_years,
             map_shape,
             settings) = data_processing.get_cmip_data(DATA_DIRECTORY, settings, verbose=0)

            #----------------------------------------
            # make predictions for observations and cmip results
            # (training-set predictions are not needed by any metric below)
            pred_val = model.predict(x_val)
            pred_test = model.predict(x_test)
            pred_obs = model.predict(x_obs)
            pred_obs_giss = model.predict(x_obs_giss)

            #----------------------------------------
            # compute metrics to compare
            # mean absolute error between column 0 of the prediction and column 0 of the target
            error_val = np.mean(np.abs(pred_val[:,0] - onehot_val[:,0]))
            error_test = np.mean(np.abs(pred_test[:,0] - onehot_test[:,0]))
            # third value returned by compute_pit is stored as the "d" metric
            __, __, d_val, __ = custom_metrics.compute_pit(onehot_val, x_data=x_val, model_shash = model)
            __, __, d_test, __ = custom_metrics.compute_pit(onehot_test, x_data=x_test, model_shash = model)
            __, __, d_valtest, __ = custom_metrics.compute_pit(
                np.append(onehot_val,onehot_test,axis=0),
                x_data=np.append(x_val,x_test,axis=0),
                model_shash = model,
            )
            loss_val = network.RegressLossExpSigma(onehot_val,pred_val).numpy()
            loss_test = network.RegressLossExpSigma(onehot_test,pred_test).numpy()

            # The last observation sample supplies the "2021" prediction;
            # column 0 is stored as mu and column 1 as sigma.
            d = {
                "exp_name": settings["exp_name"],
                "seed": settings["seed"],
                "hiddens": str(settings["hiddens"]),
                "ridge_param": settings["ridge_param"][0],
                "error_val": error_val,
                "error_test": error_test,
                "loss_val": loss_val,
                "loss_test": loss_test,
                "d_val": d_val,
                "d_test": d_test,
                "d_valtest": d_valtest,
                "best_2021_mu": pred_obs[-1][0],
                "best_2021_sigma": pred_obs[-1][1],
                "giss_2021_mu": pred_obs_giss[-1][0],
                "giss_2021_sigma": pred_obs_giss[-1][1],
            }

            df = pd.DataFrame(d, index=[0])
            # ignore_index keeps the row labels unique; without it every new
            # row would carry the label 0.
            df_metrics = pd.concat([df_metrics, df], ignore_index=True)

    # there should NOT be any duplicates
    is_duplicate = df_metrics.duplicated(subset=["exp_name", "seed"], keep=False)
    if is_duplicate.any():
        display(df_metrics[is_duplicate])
        raise ValueError('there are duplicated entries!')
    # Later cells index this table by position, so leave it with a clean 0..N-1 index.
    df_metrics = df_metrics.reset_index(drop=True)

    if SAVE_FILE:
        df_metrics.to_pickle(PREDICTIONS_DIRECTORY + "df_random_seed_rev2.pickle")

    display(df_metrics)

In [None]:
# NOTE(review): `error` is not defined or imported anywhere in the visible notebook,
# so this line raises NameError and halts "Run All" here -- presumably a
# deliberate stop between the (slow) metrics loop and the plotting cells; confirm.
error('here')

## Random seeds for obs

In [None]:
# PLOT ACROSS SSPs and TARGETS
# One column of points per experiment: every seed's 2021 prediction for the
# observations (mu, with sigma as the error bar), shifted by 2021 so the y axis
# reads as a calendar year.
EXP_FOR_PLOTTING = ('exp11C_370','exp11C_245','exp11C_126','exp15C_370','exp15C_245','exp15C_126','exp20C_370','exp20C_245','exp20C_126')
clr_order = [0,2,1,3,4,5,6,7,8,]
x_labels = EXP_FOR_PLOTTING
# Vertical offset of each text label above the tallest error bar; hand-tuned
# exceptions avoid overlaps, everything else uses the default of 2.
label_offsets = {'exp15C_370': 3, 'exp20C_126': -1, 'exp15C_126': -1}
#------------------------------------------------------------
fig, ax = plt.subplots(1,1,figsize=(6.5,2.75))
exp_colors = np.array(palette)[clr_order]

for obs_type in ('best',):
    # GISS points (if added to the tuple above) are drawn fainter and nudged right.
    if obs_type=='giss':
        alpha, shift_extra = 0.3, .05
    else:
        alpha, shift_extra = 1.0, 0.

    for iexp,exp_name in enumerate(EXP_FOR_PLOTTING):
        df_exp = df_metrics[df_metrics["exp_name"]==exp_name]
        if len(df_exp) == 0:
            # np.max below fails on an empty selection; skip experiments without stored seeds.
            print('no metrics stored for ' + exp_name + ', skipping')
            continue

        mu = df_exp[obs_type + "_2021_mu"]
        sigma = df_exp[obs_type + "_2021_sigma"]
        clr = exp_colors[iexp]

        # spread the seeds across +/-0.4 around the experiment's integer x position
        ax.errorbar(iexp + np.linspace(-.4,.4,len(df_exp)) + shift_extra,
                    mu + 2021,
                    yerr=sigma,
                    color=clr,
                    marker='.',
                    linestyle='',
                    elinewidth=.25,
                    markersize=2,
                    alpha=alpha,
                   )

        # plot the text above the bars
        if obs_type=='best':
            max_y_value = np.max(mu + sigma)
            add_val = label_offsets.get(exp_name, 2)
            # Label is read from fixed character positions of names shaped like
            # 'exp15C_370', giving 'SSP3-7.0' over '1.5C'.
            text_name = f"SSP{exp_name[7]}-{exp_name[8]}.{exp_name[9]}\n{exp_name[3]}.{exp_name[4]}C"

            ax.text(iexp,
                    max_y_value+add_val+2021,
                    text_name,
                    fontsize=FS*0.8,
                    color=clr,
                    horizontalalignment='center',
                   )


ax.set_ylabel('year threshold is reached')
# raw string: "\m", "\p" and "\s" are not valid Python escape sequences
ax.set_title(r'Observations 2021 Predicted $\mu \pm \sigma$')
ax.set_xlabel(None)
format_spines(ax)
# ticks at every experiment position, with empty labels (the text above the bars names them)
ax.set_xticks(np.arange(0,len(x_labels)),'', fontsize=FS*0.8,rotation=45)
ax.set_yticks(np.arange(1950,2100,10),np.arange(1950,2100,10).round())
ax.set_ylim(-10+2021,60+2021)
plt.grid(which='major',axis='y',linewidth=.25,linestyle='--',alpha=.5)

plt.tight_layout()
# plots.savefig(FIGURE_DIRECTORY + 'obs_BEST_GISS' + '_params_ssp_target_comparison',dpi=savefig_dpi)
plots.savefig(FIGURE_DIRECTORY + 'obs_BEST' + '_params_ssp_target_comparison',dpi=savefig_dpi)
plt.show()


In [None]:

# Predicted 2021 distributions for one seed: a Normal(mu, sigma) curve per SSP,
# one panel per warming threshold.
norm_incs = np.arange(-80,80,1)   # offsets from 2021 at which the density is evaluated
#------------------------------------------------------------
fig, axs = plt.subplots(3,1,figsize=(4.9,2.75*3))

PLOT_SEED = 2247
obs_type = 'best'

# One panel per threshold: (panel label, experiments drawn, x position of the annotations).
threshold_panels = (
    ('1.1C', ('exp11C_370','exp11C_245','exp11C_126'), 2040),
    ('1.5C', ('exp15C_370','exp15C_245','exp15C_126'), 2050),
    ('2.0C', ('exp20C_370','exp20C_245','exp20C_126'), 2068),
)
# Keyed by the last three characters of the experiment name:
# (line colour, annotation y position, annotation label). Anything else is drawn as SSP1-2.6.
ssp_styles = {
    '370': ("tab:red", .05, 'SSP3-7.0'),
    '245': ("tab:purple", .035, 'SSP2-4.5'),
}
ssp_default_style = ("tab:blue", .02, 'SSP1-2.6')

for ax, (thresh_text, EXP_FOR_PLOTTING, text_x) in zip(axs, threshold_panels):
    for iexp,exp_name in enumerate(EXP_FOR_PLOTTING):
        df_seed = df_metrics[(df_metrics["exp_name"]==exp_name) & (df_metrics["seed"]==PLOT_SEED)]
        if len(df_seed) == 0:
            # fail with the missing key spelled out instead of an opaque IndexError
            raise ValueError(f"no metrics stored for exp_name={exp_name!r} with seed={PLOT_SEED}")

        mu_pred = df_seed[obs_type + "_2021_mu"].values[0]
        sigma_pred = df_seed[obs_type + "_2021_sigma"].values[0]
        norm_dist = tfp.distributions.Normal(mu_pred,sigma_pred)
        norm_pdf = norm_dist.prob(norm_incs)

        clr, text_y, ssp_text = ssp_styles.get(exp_name[-3:], ssp_default_style)

        ax.plot(norm_incs+2021,
                norm_pdf,
                color=clr,
                linewidth=2.5,
               )

        # annotate with the central year and the mu -/+ sigma range, rounded to whole years
        year_mid = int(np.round(mu_pred+2021))
        year_low = int(np.round(mu_pred+2021-sigma_pred))
        year_high = int(np.round(mu_pred+2021+sigma_pred))
        text_name = f"{ssp_text}\n{year_mid} ({year_low} to {year_high})"
        ax.text(text_x,
                text_y,
                text_name,
                fontsize=FS*0.8,
                color=clr,
                horizontalalignment='left',
               )

    ax.text(1998,
            .1,
            thresh_text + ' threshold',
            fontsize=FS,
            horizontalalignment="left",
            verticalalignment="top",
            color='k',
            # weight="bold",
           )

    ax.set_xlabel('year')
    format_spines(ax)
    ax.set_xlim(-25+2021,70+2021)
    ax.set_ylim(-0.001,.1)

plt.tight_layout()
plots.savefig(FIGURE_DIRECTORY + 'obs_2021PDF_allSSPs',dpi=savefig_dpi)
plt.show()


## Plots across hyperparameters

### Plots across obs predictions

In [None]:
# PLOT ACROSS RIDGE CHOICES
# Swarm plots of the 2021 observation prediction (mu and sigma) for every seed,
# one column per experiment, labelled by that experiment's ridge parameter.
EXP_FOR_PLOTTING = ('exp12','exp30','exp31','exp32','exp33')
df_metrics_plot = df_metrics[df_metrics["exp_name"].isin(EXP_FOR_PLOTTING)]
df_metrics_plot = df_metrics_plot.sort_values("ridge_param")
clr_order = [0,0,0,2,0,0,0,0,]
print('PARAMETER CHECK: ' + str(df_metrics_plot["hiddens"].unique()))

# Columns are drawn in order of first appearance in the sorted frame; the tick
# labels only line up if each experiment has exactly one ridge value.
exp_order = list(df_metrics_plot["exp_name"].unique())
x_labels = df_metrics_plot["ridge_param"].unique()
if len(x_labels) != len(exp_order):
    raise ValueError('ridge_param values do not map one-to-one onto the plotted experiments')
swarm_colors = np.array(palette)[clr_order]
#------------------------------------------------------------
fig, axs = plt.subplots(1,2,figsize=(7,2.75))

# (column plotted, y label, title, extra title kwargs, y limits) -- raw strings keep the LaTeX backslashes literal
panels = (
    ("best_2021_mu", r'$\mu$', r'Obs. 2021 predicted $\mu$', {}, (0.0,25.0)),
    ("best_2021_sigma", r'$\sigma$', r'Obs. 2021 predicted $\sigma$', dict(fontsize=FS*1.2), (0,None)),
)
for ax, (y_column, y_label, title, title_kwargs, y_limits) in zip(axs, panels):
    sns.swarmplot(x="exp_name",
                  y=y_column,
                  data=df_metrics_plot,
                  order=exp_order,
                  palette=swarm_colors,
                  size=2.5,
                  ax=ax,
                 )
    ax.set_ylabel(y_label)
    ax.set_title(title, **title_kwargs)
    ax.set_ylim(*y_limits)
    format_spines(ax)
    ax.set_xlabel('ridge parameter',fontsize=FS)
    ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)


plt.tight_layout()
plots.savefig(FIGURE_DIRECTORY + 'obsBEST_params_ridge_comparison',dpi=savefig_dpi)
plt.show()


In [None]:
# PLOT ACROSS HIDDEN CHOICES
# Swarm plots of the 2021 observation prediction (mu and sigma) for every seed,
# one column per experiment, labelled by that experiment's hidden-layer setting.

EXP_FOR_PLOTTING = ('exp5','exp10','exp11','exp12','exp13','exp14')
# EXP_FOR_PLOTTING = ('exp4','exp20','exp21','exp22','exp23',)
df_metrics_plot = df_metrics[df_metrics["exp_name"].isin(EXP_FOR_PLOTTING)]
# "hiddens" is stored as a string, so this sort is lexicographic
df_metrics_plot = df_metrics_plot.sort_values("hiddens")
x_labels = df_metrics_plot["hiddens"].unique()
x_labels[x_labels=='[2]'] = '[2]\nlinear'
clr_order = [0,0,0,2,0,0,0,0,]
print('PARAMETER CHECK: ' + str(df_metrics_plot["ridge_param"].unique()))

# Columns are drawn in order of first appearance in the sorted frame; the tick
# labels only line up if each experiment has exactly one "hiddens" value.
exp_order = list(df_metrics_plot["exp_name"].unique())
if len(x_labels) != len(exp_order):
    raise ValueError('hiddens values do not map one-to-one onto the plotted experiments')
swarm_colors = np.array(palette)[clr_order]
#------------------------------------------------------------
fig, axs = plt.subplots(1,2,figsize=(7,2.75))

# (column plotted, y label, title, extra title kwargs, y limits) -- raw strings keep the LaTeX backslashes literal
panels = (
    ("best_2021_mu", r'$\mu$', r'Obs. 2021 predicted $\mu$', {}, (0.0,20.0)),
    ("best_2021_sigma", r'$\sigma$', r'Obs. 2021 predicted $\sigma$', dict(fontsize=FS*1.2), (2,None)),
)
for ax, (y_column, y_label, title, title_kwargs, y_limits) in zip(axs, panels):
    sns.swarmplot(x="exp_name",
                  y=y_column,
                  data=df_metrics_plot,
                  order=exp_order,
                  palette=swarm_colors,
                  size=2.5,
                  ax=ax,
                 )
    ax.set_ylabel(y_label)
    ax.set_title(title, **title_kwargs)
    ax.set_ylim(*y_limits)
    format_spines(ax)
    ax.set_xlabel('hidden layers x nodes',fontsize=FS)
    ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)


plt.tight_layout()
plots.savefig(FIGURE_DIRECTORY + 'obsBEST_params_hiddens_comparison',dpi=savefig_dpi)
plt.show()


### Error and PIT

In [None]:
# PLOT ACROSS RIDGE CHOICES
# For each experiment: box = spread of the validation metric across seeds,
# swarm = the test metric of every individual seed.
EXP_FOR_PLOTTING = ('exp12','exp30','exp31','exp32','exp33')
df_metrics_plot = df_metrics[df_metrics["exp_name"].isin(EXP_FOR_PLOTTING)]
df_metrics_plot = df_metrics_plot.sort_values("ridge_param")
clr_order = [0,0,0,2,0,0,0,0,]

print('PARAMETER CHECK: ' + str(df_metrics_plot["hiddens"].unique()))

# Columns are drawn in order of first appearance in the sorted frame; the tick
# labels only line up if each experiment has exactly one ridge value.
exp_order = list(df_metrics_plot["exp_name"].unique())
x_labels = df_metrics_plot["ridge_param"].unique()
if len(x_labels) != len(exp_order):
    raise ValueError('ridge_param values do not map one-to-one onto the plotted experiments')
panel_colors = np.array(palette)[clr_order]
#------------------------------------------------------------
fig, axs = plt.subplots(1,3,figsize=(8.5,2.5))

# (validation column, test column, y label or None to keep the default, title, extra title kwargs, y limits)
panels = (
    ("error_val", "error_test", 'error (years)', 'Mean Absolute Error', {}, (2.0,5.0)),
    ("d_val", "d_test", None, 'PIT D Metric', dict(fontsize=FS*1.2), (0,.055)),
    ("loss_val", "loss_test", None, 'Loss', dict(fontsize=FS*1.2), (2.,5.0)),
)
for ax, (val_column, test_column, y_label, title, title_kwargs, y_limits) in zip(axs, panels):
    # whis=100. stretches the whiskers far enough that no seed is drawn as an outlier
    sns.boxplot(x="exp_name",
                y=val_column,
                data=df_metrics_plot,
                order=exp_order,
                palette=panel_colors,
                whis=100.,
                boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),
                whiskerprops=dict(color='gray',linewidth=1.),
                medianprops=dict(color='gray',linewidth=1.),
                capprops=dict(color='gray',linewidth=1.),
                ax=ax,
               )
    sns.swarmplot(x="exp_name",
                  y=test_column,
                  data=df_metrics_plot,
                  order=exp_order,
                  palette=panel_colors,
                  size=2.5,
                  ax=ax,
                 )
    if y_label is not None:
        ax.set_ylabel(y_label)
    ax.set_title(title, **title_kwargs)
    ax.set_ylim(*y_limits)
    format_spines(ax)
    ax.set_xlabel('ridge parameter',fontsize=FS)
    ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)


plt.tight_layout()
plots.savefig(FIGURE_DIRECTORY + 'cmip6_metrics_ridge_comparison',dpi=savefig_dpi)
plt.show()


In [None]:
# PLOT ACROSS HIDDEN CHOICES
# For each experiment: box = spread of the validation metric across seeds,
# swarm = the test metric of every individual seed.
EXP_FOR_PLOTTING = ('exp5','exp10','exp11','exp12','exp13','exp14')
df_metrics_plot = df_metrics[df_metrics["exp_name"].isin(EXP_FOR_PLOTTING)]
# "hiddens" is stored as a string, so this sort is lexicographic
df_metrics_plot = df_metrics_plot.sort_values("hiddens")
x_labels = df_metrics_plot["hiddens"].unique()
x_labels[x_labels=='[2]'] = '[2]\nlinear'
clr_order = [0,0,0,2,0,0,0,0,]

print('PARAMETER CHECK: ' + str(df_metrics_plot["ridge_param"].unique()))

# Columns are drawn in order of first appearance in the sorted frame; the tick
# labels only line up if each experiment has exactly one "hiddens" value.
exp_order = list(df_metrics_plot["exp_name"].unique())
if len(x_labels) != len(exp_order):
    raise ValueError('hiddens values do not map one-to-one onto the plotted experiments')
panel_colors = np.array(palette)[clr_order]
#------------------------------------------------------------
fig, axs = plt.subplots(1,3,figsize=(8.5,2.5))

# (validation column, test column, y label or None to keep the default, title, extra title kwargs, y limits)
panels = (
    ("error_val", "error_test", 'error (years)', 'Mean Absolute Error', {}, (2.0,6.0)),
    ("d_val", "d_test", None, 'PIT D Metric', dict(fontsize=FS*1.2), (0,.08)),
    ("loss_val", "loss_test", None, 'Loss', dict(fontsize=FS*1.2), (2.0,5.0)),
)
for ax, (val_column, test_column, y_label, title, title_kwargs, y_limits) in zip(axs, panels):
    # whis=100. stretches the whiskers far enough that no seed is drawn as an outlier
    sns.boxplot(x="exp_name",
                y=val_column,
                data=df_metrics_plot,
                order=exp_order,
                palette=panel_colors,
                whis=100.,
                boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),
                whiskerprops=dict(color='gray',linewidth=1.),
                medianprops=dict(color='gray',linewidth=1.),
                capprops=dict(color='gray',linewidth=1.),
                ax=ax,
               )
    sns.swarmplot(x="exp_name",
                  y=test_column,
                  data=df_metrics_plot,
                  order=exp_order,
                  palette=panel_colors,
                  size=2.5,
                  ax=ax,
                 )
    if y_label is not None:
        ax.set_ylabel(y_label)
    ax.set_title(title, **title_kwargs)
    ax.set_ylim(*y_limits)
    format_spines(ax)
    ax.set_xlabel('hidden layers x nodes',fontsize=FS)
    ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)


plt.tight_layout()
plots.savefig(FIGURE_DIRECTORY + 'cmip6_metrics_hiddens_comparison',dpi=savefig_dpi)
plt.show()


## Plot all hyperparameter experiments

In [None]:
# PLOT ACROSS ALL EXPERIMENTS
# Same three metric panels as the sweeps above, but for every experiment in
# EXP_NAME_VEC and with seaborn's default colours: box = validation metric
# across seeds, swarm = test metric of each seed.
EXP_FOR_PLOTTING = EXP_NAME_VEC
df_metrics_plot = df_metrics[df_metrics["exp_name"].isin(EXP_FOR_PLOTTING)]
x_labels = df_metrics_plot["exp_name"].unique()
clr_order = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,]
#------------------------------------------------------------
fig, axs = plt.subplots(1,3,figsize=(11,3.))

# (validation column, test column, title, extra title kwargs, y limits)
panel_specs = (
    ("error_val", "error_test", 'Mean Absolute Error', {}, (2.5,5.5)),
    ("d_val", "d_test", 'PIT D Metric', dict(fontsize=FS*1.2), (0,.06)),
    ("loss_val", "loss_test", 'Loss', dict(fontsize=FS*1.2), (2.0,5.0)),
)
for panel_index, (val_column, test_column, title, title_kwargs, y_limits) in enumerate(panel_specs):
    ax = axs[panel_index]
    # whis=100. stretches the whiskers far enough that no seed is drawn as an outlier
    sns.boxplot(x="exp_name",
                y=val_column,
                data=df_metrics_plot,
                whis=100.,
                boxprops=dict(alpha=.3, edgecolor='gray',linewidth=1.),
                whiskerprops=dict(color='gray',linewidth=1.),
                medianprops=dict(color='gray',linewidth=1.),
                capprops=dict(color='gray',linewidth=1.),
                ax=ax,
               )
    sns.swarmplot(x="exp_name",
                  y=test_column,
                  data=df_metrics_plot,
                  size=2.5,
                  ax=ax,
                 )
    if panel_index == 0:
        # only the error panel gets explicit half-year ticks and a y label
        ax.set_yticks(np.arange(0,10,.5))
        ax.set_ylabel('error (years)')
    ax.set_title(title, **title_kwargs)
    ax.set_ylim(*y_limits)
    format_spines(ax)
    ax.set_xticks(np.arange(0,len(x_labels)),x_labels, fontsize=FS*0.8,rotation=45)
    ax.grid(alpha=.3)

plt.tight_layout()
plots.savefig(FIGURE_DIRECTORY + 'cmip6_metrics_all_comparison',dpi=savefig_dpi)
plt.show()


## Explore the dataframe

In [None]:
# Inspect one experiment: pick the seed with the lowest test loss and list
# all of its seeds ordered by validation loss.
EXP_NAME = 'exp12'
df = df_metrics[df_metrics["exp_name"]==EXP_NAME]
# idxmin() returns an index *label*, which is only a valid position for .iloc
# when the frame happens to carry a default 0..N-1 index. Take the position of
# the minimum within `df` itself so the lookup cannot land on the wrong row.
best_position = df["loss_test"].reset_index(drop=True).idxmin()
PLOT_SEED = df.iloc[best_position]["seed"]
display(df.sort_values("loss_val"))
# df['loss_val'].idxmax()
# display(df_metrics[df_metrics["exp_name"]=="exp4"].sort_values("error_val").head())
# PLOT_SEED = 1257


In [None]:
# display(df_metrics[df_metrics["exp_name"]=="exp4"].sort_values("error_val").head())
# display(df_metrics[df_metrics["exp_name"]=="exp4"].sort_values("d_val").head())
# display(df_metrics[df_metrics["exp_name"]=="exp4"].sort_values("error_test").head())