# Compute metrics for different runs and plot them
##### author: Elizabeth A. Barnes, Randal J. Barnes and Mark DeMaria
##### version: v0.1.0

```
conda create --name env-hurr-tfp python=3.9
conda activate env-hurr-tfp
pip install tensorflow==2.7.0
pip install tensorflow-probability==0.15.0
pip install --upgrade numpy scipy pandas statsmodels matplotlib seaborn 
pip install --upgrade palettable progressbar2 tabulate icecream flake8
pip install --upgrade keras-tuner scikit-learn
pip install --upgrade jupyterlab black isort jupyterlab_code_formatter
pip install silence-tensorflow
pip install tqdm
```

To record the exact installed package versions, run
`python -m pip freeze > requirements.txt`.

In [1]:
import datetime
import os
import pickle
import pprint
import time

import experiment_settings
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import shash
from build_data import build_hurricane_data
from build_model import build_shash_model, build_bnn_model, build_mcdrop_model
import model_diagnostics
from silence_tensorflow import silence_tensorflow
import prediction
from sklearn.neighbors import KernelDensity
import pandas as pd
from tqdm import tqdm

import warnings
# Silence every Python warning for the rest of the session (this also hides
# deprecation warnings coming from numpy/pandas/seaborn).
warnings.filterwarnings("ignore")

# Presumably suppresses TensorFlow log/warning output (silence_tensorflow package).
silence_tensorflow()
# Resolution (dpi) passed to savefig when writing the summary figures.
dpiFig = 400

In [2]:
__author__ = "Randal J Barnes and Elizabeth A. Barnes"
__version__ = "20 January 2022"

# --- Input / output locations (relative to the notebook directory) ----------
DATA_PATH = "data/"
MODEL_PATH = "saved_models/"
METRIC_PATH = "model_metrics/"
FIGURE_PATH = "figures/summary_plots/"

# Set to True to recompute metric pickles that already exist in METRIC_PATH.
OVERWRITE_METRICS = False

# --- Experiments to evaluate -------------------------------------------------
# Only the uncommented names are processed, in the order listed.
# APPEND_NAME is appended to figure titles and figure filenames. The special
# value "_clusterExtrapolationAllClusters" additionally selects the
# "*_metrics_allClusters.pkl" metric files and disables out-of-sample
# cluster subsetting.
EXP_NAME_LIST = (
    "intensity24_EPCP72",
    "intensity25_EPCP72",
    "intensity23_EPCP72",
    # Other EPCP72 experiments, kept for reference:
    # "intensity4_EPCP72",
    # "intensity5_EPCP72",
    # "intensity14_EPCP72",  # mcdrop 75% [100,50]
    # "intensity16_EPCP72",  # mcdrop 75% [300,200]
    # "intensity19_EPCP72",  # mcdrop 95% [300,200]
    # "intensity17_EPCP72",
    # "intensity18_EPCP72",
    # "intensity20_EPCP72",
    # "intensity21_EPCP72",  # mcdrop 75% [500,300]
    # "intensity1000_EPCP72",
    # "intensity1001_EPCP72",
)
APPEND_NAME = '_mcdrop_no2020'
# APPEND_NAME = '_mcdropComparison'

# --- Alternative experiment sets (swap in together with their APPEND_NAME) ---
# AL / EPCP experiments at leads 48, 72 and 96:
# EXP_NAME_LIST = (
#     "intensity8_AL96",
#     "intensity9_AL96",
#     "intensity10_EPCP96",
#     "intensity11_EPCP96",
#     "intensity0_AL72",
#     "intensity1_AL72",
#     "intensity4_EPCP72",
#     "intensity5_EPCP72",
#     "intensity12_AL48",
#     "intensity13_AL48",
#     "intensity2_EPCP48",
#     "intensity3_EPCP48",
# )
# APPEND_NAME = ''
#
# Cluster-extrapolation experiments:
# EXP_NAME_LIST = (
#     "intensity204_AL72",
#     "intensity205_AL72",
#     "intensity202_EPCP72",
#     "intensity203_EPCP72",
#     "intensity206_AL48",
#     "intensity207_AL48",
#     "intensity200_EPCP48",
#     "intensity201_EPCP48",
# )
# APPEND_NAME = '_clusterExtrapolation'
# APPEND_NAME = '_clusterExtrapolationAllClusters'

In [3]:
# Notebook-wide matplotlib defaults.
mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["figure.dpi"] = 150

# `np.warnings` was removed in NumPy 1.24, and `VisibleDeprecationWarning`
# moved to `np.exceptions` in NumPy 1.25 (it is gone from the top-level
# namespace in 2.0). Look it up in a version-independent way and use the
# standard-library `warnings` module directly.
import warnings

_VisibleDeprecationWarning = getattr(np, "exceptions", np).VisibleDeprecationWarning
warnings.filterwarnings("ignore", category=_VisibleDeprecationWarning)

## Define get_metrics()

In [4]:
def get_metrics(runs=5_000, kde_bandwidth=4.0, n_kde_bins=1000):
    """Compute evaluation metrics for the currently loaded model on the eval set.

    This function takes its data and model from notebook-level globals, which
    must be set (as the evaluation loop does) before it is called:
    ``settings``, ``network_seed``, ``x_eval``, ``onehot_eval`` and, depending
    on ``settings["uncertainty_type"]``, one of ``model_bnn``,
    ``model_mcdrop`` or ``model_shash``.

    For "bnn" and "mcdrop" models, each sample's predictive distribution is
    approximated by ``runs`` stochastic forward passes. Its mode is estimated
    as the peak of a Gaussian kernel density estimate. Any other uncertainty
    type is treated as a SHASH model, whose distribution is evaluated on the
    integer grid -160..160.

    Args:
        runs: number of stochastic forward passes for "bnn"/"mcdrop" models.
        kde_bandwidth: bandwidth of the Gaussian KDE used to estimate the mode
            of the "bnn"/"mcdrop" predictive samples.
        n_kde_bins: number of grid points on which the KDE is evaluated.

    Returns:
        Tuple ``(mean_error, median_error, mode_error, pit_D, iqr_capture,
        cons_error, iqr_error_spearman, iqr_error_pearson)``.
    """
    tf.random.set_seed(network_seed)
    shash_incs = np.arange(-160, 161, 1)
    n_samples = np.shape(x_eval)[0]
    uncertainty_type = settings["uncertainty_type"]

    if uncertainty_type in ("bnn", "mcdrop"):
        # Choose the stochastic forward pass once, rather than re-checking the
        # uncertainty type on every one of the `runs` iterations.
        if uncertainty_type == "bnn":
            def sample_once():
                return model_bnn.predict(x_eval)
        else:
            def sample_once():
                # Keras: training=True runs dropout layers in training mode (MC dropout).
                return model_mcdrop(x_eval, training=True)

        bins_plot = np.linspace(np.min(shash_incs), np.max(shash_incs), n_kde_bins)
        bnn_cpd = np.zeros((n_samples, runs))
        for i in tqdm(range(runs)):
            bnn_cpd[:, i] = np.reshape(sample_once(), n_samples)

        bnn_mean = np.mean(bnn_cpd, axis=1)
        bnn_median = np.median(bnn_cpd, axis=1)

        # Mode of each sample: argmax of a Gaussian KDE fit to its draws,
        # evaluated on bins_plot.
        bnn_mode = np.zeros((n_samples,))
        for j in tqdm(range(n_samples)):
            kde = KernelDensity(kernel="gaussian", bandwidth=kde_bandwidth).fit(bnn_cpd[j, :].reshape(-1, 1))
            log_dens = kde.score_samples(bins_plot.reshape(-1, 1))
            bnn_mode[j] = bins_plot[np.argmax(log_dens)]

        mean_error, median_error, mode_error = model_diagnostics.compute_errors(onehot_eval, bnn_mean, bnn_median, bnn_mode)
        _, _, pit_D, _ = model_diagnostics.compute_pit('bnn', onehot_eval, bnn_cpd)
        iqr_capture = model_diagnostics.compute_interquartile_capture('bnn', onehot_eval, bnn_cpd)
        iqr_error_spearman, iqr_error_pearson = model_diagnostics.compute_iqr_error_corr(
            'bnn',
            onehot_val=onehot_eval,
            bnn_cpd=bnn_cpd,
            pred_median=bnn_median,
        )

    else:
        # SHASH model: evaluate the predicted distribution on shash_incs for
        # each sample and take its mean, median and grid-point mode.
        shash_cpd = np.zeros((n_samples, len(shash_incs)))
        shash_mean = np.zeros((n_samples,))
        shash_med = np.zeros((n_samples,))
        shash_mode = np.zeros((n_samples,))

        for j in tqdm(range(n_samples)):
            mu_pred, sigma_pred, gamma_pred, tau_pred = prediction.params(x_eval[np.newaxis, j], model_shash)
            shash_cpd[j, :] = shash.prob(shash_incs, mu_pred, sigma_pred, gamma_pred, tau_pred)
            shash_mean[j] = shash.mean(mu_pred, sigma_pred, gamma_pred, tau_pred)
            shash_med[j] = shash.median(mu_pred, sigma_pred, gamma_pred, tau_pred)
            # Mode: grid point with the largest shash.prob value.
            shash_mode[j] = shash_incs[np.argmax(shash_cpd[j, :])]

        mean_error, median_error, mode_error = model_diagnostics.compute_errors(onehot_eval, shash_mean, shash_med, shash_mode)
        _, _, pit_D, _ = model_diagnostics.compute_pit('shash', onehot_eval, x_val=x_eval, model_shash=model_shash)
        iqr_capture = model_diagnostics.compute_interquartile_capture('shash', onehot_eval, x_val=x_eval, model_shash=model_shash)
        iqr_error_spearman, iqr_error_pearson = model_diagnostics.compute_iqr_error_corr(
            'shash',
            onehot_val=onehot_eval,
            pred_median=shash_med,
            x_val=x_eval,
            model_shash=model_shash,
        )

    # By definition Consensus is a correction of zero, so its error is the
    # mean absolute value of the target column.
    cons_error = np.mean(np.abs(0.0 - onehot_eval[:, 0]))

    return mean_error, median_error, mode_error, pit_D, iqr_capture, cons_error, iqr_error_spearman, iqr_error_pearson


## Evaluate the models

In [5]:
# Deliberate guard: stops "Run All" here so the expensive metric computation
# in the following cells only runs when executed on purpose.
raise ValueError('do not compute metrics yet')

ValueError: do not compute metrics yet

In [None]:
# Scratch check (disabled): histogram of MC-dropout predictions for a single
# evaluation sample, with that sample's df_eval 'OBDV' value marked.
# # y_predict = np.stack([model_mcdrop(x_eval,training=True) for sample in range (5_000)])
# # np.shape(y_predict)

# isample = 20
# plt.hist(y_predict[:,isample,0])
# plt.axvline(x=df_eval['OBDV'].to_numpy()[isample],linestyle='--')
# plt.show()

In [6]:
import importlib

# `imp` is deprecated (removed in Python 3.12); importlib.reload picks up
# edits to model_diagnostics without restarting the kernel.
importlib.reload(model_diagnostics)

for exp_name in EXP_NAME_LIST:
    settings = experiment_settings.get_settings(exp_name)
    RNG_SEED_LIST = np.copy(settings['rng_seed'])

    for rng_seed in RNG_SEED_LIST:
        settings['rng_seed'] = rng_seed
        network_seed = settings["rng_seed"]  # the network seed is the rng seed
        tf.random.set_seed(network_seed)  # This sets the global random seed.
        uncertainty_type = settings["uncertainty_type"]

        # Fail loudly on an unknown type. Otherwise model_name/model would be
        # stale values left over from the previous iteration.
        if uncertainty_type not in ("bnn", "mcdrop") and not uncertainty_type.startswith("shash"):
            raise NotImplementedError(f"unknown uncertainty_type: {uncertainty_type}")

        #----------------------------------------------------------------------------------------------------
        # file names for this run (identical naming for every uncertainty type)
        model_name = (
            exp_name + "_" + uncertainty_type + '_' + f"network_seed_{network_seed}_rng_seed_{settings['rng_seed']}"
        )
        weights_filename = MODEL_PATH + model_name + "_weights.h5"
        if APPEND_NAME == "_clusterExtrapolationAllClusters":
            metric_filename = METRIC_PATH + model_name + '_metrics_allClusters.pkl'
        else:
            metric_filename = METRIC_PATH + model_name + '_metrics.pkl'

        # Decide whether to skip before doing the (expensive) data/model loading.
        if os.path.exists(metric_filename) and not OVERWRITE_METRICS:
            continue
        # Explicit existence check instead of a bare `except:` around
        # load_weights, so real load errors (e.g. shape mismatch) still surface.
        if not os.path.exists(weights_filename):
            print('tf model does not exist. skipping...')
            continue

        #----------------------------------------------------------------------------------------------------
        # get the data
        (
            x_train,
            onehot_train,
            x_val,
            onehot_val,
            x_eval,
            onehot_eval,
            data_summary,
            df_val,
            df_eval,
        ) = build_hurricane_data(DATA_PATH, settings, verbose=0)

        #----------------------------------------------------------------------------------------------------
        # build the model and load its weights; get_metrics() reads the model
        # from the global name matching the uncertainty type
        if uncertainty_type == "bnn":
            model_bnn = build_bnn_model(
                x_train,
                onehot_train,
                hiddens=settings["hiddens"],
                output_shape=onehot_train.shape[1],
                ridge_penalty=settings["ridge_param"],
                act_fun=settings["act_fun"],
            )
            model_bnn.load_weights(weights_filename)
        elif uncertainty_type == "mcdrop":
            model_mcdrop = build_mcdrop_model(
                x_train,
                onehot_train,
                dropout_rate=settings["dropout_rate"],
                hiddens=settings["hiddens"],
                output_shape=onehot_train.shape[1],
                ridge_penalty=settings["ridge_param"],
                act_fun=settings["act_fun"],
            )
            model_mcdrop.load_weights(weights_filename)
        else:  # shash*
            model_shash = build_shash_model(
                x_train,
                onehot_train,
                hiddens=settings["hiddens"],
                output_shape=onehot_train.shape[1],
                ridge_penalty=settings["ridge_param"],
                act_fun=settings["act_fun"],
            )
            model_shash.load_weights(weights_filename)

        # get metrics and put into a dictionary
        pprint.pprint(model_name)

        # if running out of cluster comparison, grab only the cluster that was not seen during training
        if (settings["train_condition"] in ('cluster', 'no_2020')) and APPEND_NAME != "_clusterExtrapolationAllClusters":
            print('grabbing out of sample ' + 'CLUSTER')
            i_index = np.where(data_summary["cluster_eval"] == data_summary["cluster_out"])[0]
            x_eval = x_eval[i_index, :]
            onehot_eval = onehot_eval[i_index, :]
            df_eval = df_eval.iloc[i_index]

        # compute the metrics (get_metrics reads x_eval/onehot_eval/settings globals)
        mean_error, median_error, mode_error, pit_D, iqr_capture, cons_error, iqr_error_spearman, iqr_error_pearson = get_metrics()
        d = {'uncertainty_type': settings["uncertainty_type"],
             'network_seed': network_seed,
             'rng_seed': settings['rng_seed'],
             'exp_name': exp_name,
             'mean_error': mean_error,
             'median_error': median_error,
             'mode_error': mode_error,
             'cons_error': cons_error,
             'pit_d': pit_D,
             'iqr_capture': iqr_capture,
             'iqr_error_spearman': iqr_error_spearman[0],
             'iqr_error_pearson': iqr_error_pearson[0],
             'iqr_error_spearman_p': iqr_error_spearman[1],
             'iqr_error_pearson_p': iqr_error_pearson[1],
        }
        pprint.pprint(d, width=80)
        df = pd.DataFrame(data=d, index=[0])
        df.to_pickle(metric_filename)

'intensity24_EPCP72_shash3_network_seed_786_rng_seed_786'
grabbing out of sample CLUSTER


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 24.85it/s]


{'cons_error': 15.5199995,
 'exp_name': 'intensity24_EPCP72',
 'iqr_capture': 0.275,
 'iqr_error_pearson': 0.7409340969829505,
 'iqr_error_pearson_p': 3.925350237666821e-15,
 'iqr_error_spearman': 0.73328645100797,
 'iqr_error_spearman_p': 1.0384302082796517e-14,
 'mean_error': 14.852831913530826,
 'median_error': 14.761508359014988,
 'mode_error': 14.577499955892563,
 'network_seed': 786,
 'pit_d': 0.059686656591438236,
 'rng_seed': 786,
 'uncertainty_type': 'shash3'}
'intensity24_EPCP72_shash3_network_seed_311_rng_seed_311'
grabbing out of sample CLUSTER


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 23.95it/s]


{'cons_error': 15.520001,
 'exp_name': 'intensity24_EPCP72',
 'iqr_capture': 0.425,
 'iqr_error_pearson': 0.45327421744965934,
 'iqr_error_pearson_p': 2.4196308331190068e-05,
 'iqr_error_spearman': 0.4509845288326301,
 'iqr_error_spearman_p': 2.6903391093005504e-05,
 'mean_error': 14.193206411600112,
 'median_error': 14.131599602103233,
 'mode_error': 14.339999955892562,
 'network_seed': 311,
 'pit_d': 0.03400367512150567,
 'rng_seed': 311,
 'uncertainty_type': 'shash3'}
'intensity24_EPCP72_shash3_network_seed_888_rng_seed_888'
grabbing out of sample CLUSTER


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 24.31it/s]


{'cons_error': 15.5199995,
 'exp_name': 'intensity24_EPCP72',
 'iqr_capture': 0.475,
 'iqr_error_pearson': 0.12869667695010265,
 'iqr_error_pearson_p': 0.2552373294276671,
 'iqr_error_spearman': 0.28681770768153464,
 'iqr_error_spearman_p': 0.009896090170514852,
 'mean_error': 16.19446311891079,
 'median_error': 15.912954232096672,
 'mode_error': 15.357499957084656,
 'network_seed': 888,
 'pit_d': 0.03791435422849953,
 'rng_seed': 888,
 'uncertainty_type': 'shash3'}
'intensity24_EPCP72_shash3_network_seed_999_rng_seed_999'
grabbing out of sample CLUSTER


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 25.26it/s]


{'cons_error': 15.5199995,
 'exp_name': 'intensity24_EPCP72',
 'iqr_capture': 0.3625,
 'iqr_error_pearson': 0.6881001404676874,
 'iqr_error_pearson_p': 1.7624870738776882e-12,
 'iqr_error_spearman': 0.7192920768870137,
 'iqr_error_spearman_p': 5.6654789639802595e-14,
 'mean_error': 14.898902255296708,
 'median_error': 14.793109402060509,
 'mode_error': 14.599999952316285,
 'network_seed': 999,
 'pit_d': 0.05448621594103336,
 'rng_seed': 999,
 'uncertainty_type': 'shash3'}
'intensity24_EPCP72_shash3_network_seed_578_rng_seed_578'
grabbing out of sample CLUSTER


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 25.77it/s]


{'cons_error': 15.520001,
 'exp_name': 'intensity24_EPCP72',
 'iqr_capture': 0.3125,
 'iqr_error_pearson': 0.7714474772282929,
 'iqr_error_pearson_p': 5.608314438429943e-17,
 'iqr_error_spearman': 0.7738865447726209,
 'iqr_error_spearman_p': 3.882366392276879e-17,
 'mean_error': 14.669826517999173,
 'median_error': 14.54936280399561,
 'mode_error': 14.369999957084655,
 'network_seed': 578,
 'pit_d': 0.046770701006409746,
 'rng_seed': 578,
 'uncertainty_type': 'shash3'}
'intensity24_EPCP72_shash3_network_seed_331_rng_seed_331'
grabbing out of sample CLUSTER


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 25.51it/s]


{'cons_error': 15.5199995,
 'exp_name': 'intensity24_EPCP72',
 'iqr_capture': 0.3,
 'iqr_error_pearson': 0.8089877475533984,
 'iqr_error_pearson_p': 1.1094670946861355e-19,
 'iqr_error_spearman': 0.8019690576652602,
 'iqr_error_spearman_p': 3.9274915505865866e-19,
 'mean_error': 14.7242941737175,
 'median_error': 14.55645612180233,
 'mode_error': 14.342499947547912,
 'network_seed': 331,
 'pit_d': 0.06123721741106746,
 'rng_seed': 331,
 'uncertainty_type': 'shash3'}
'intensity24_EPCP72_shash3_network_seed_908_rng_seed_908'
grabbing out of sample CLUSTER


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:03<00:00, 25.51it/s]


{'cons_error': 15.5199995,
 'exp_name': 'intensity24_EPCP72',
 'iqr_capture': 0.25,
 'iqr_error_pearson': 0.7760033995072754,
 'iqr_error_pearson_p': 2.810888885879553e-17,
 'iqr_error_spearman': 0.7781762775433662,
 'iqr_error_spearman_p': 2.0104040768829403e-17,
 'mean_error': 15.726457534730434,
 'median_error': 15.550112229585647,
 'mode_error': 15.199999952316285,
 'network_seed': 908,
 'pit_d': 0.0707106496695263,
 'rng_seed': 908,
 'uncertainty_type': 'shash3'}
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model does not exist. skipping...
tf model 

In [None]:
# Deliberate guard: stops "Run All" before the saved metrics are collected and plotted.
raise ValueError('do not plot yet')

In [7]:
# Collect the per-run metric pickles written by the evaluation loop into one table.
metric_frames = []

for exp_name in EXP_NAME_LIST:
    settings = experiment_settings.get_settings(exp_name)
    RNG_SEED_LIST = np.copy(settings['rng_seed'])

    for rng_seed in RNG_SEED_LIST:
        settings['rng_seed'] = rng_seed
        network_seed = settings["rng_seed"]  # the network seed is the rng seed

        model_name = (
            exp_name + "_" + settings["uncertainty_type"] + '_' + f"network_seed_{network_seed}_rng_seed_{settings['rng_seed']}"
        )

        # load the metric filename
        if APPEND_NAME == "_clusterExtrapolationAllClusters":
            metric_filename = METRIC_PATH + model_name + '_metrics_allClusters.pkl'
        else:
            metric_filename = METRIC_PATH + model_name + '_metrics.pkl'
        if not os.path.exists(metric_filename):
            print(metric_filename + ' DOES NOT exist. Skipping...')
            continue

        df = pd.read_pickle(metric_filename)
        # basin/lead tag is the text after the last underscore, e.g. "EPCP72"
        df['basin_lead'] = exp_name[exp_name.rfind('_') + 1:]
        # positive reduction = lower error than the Consensus (zero correction)
        for stat in ('mean', 'median', 'mode'):
            df[f'{stat}_error_reduction'] = df['cons_error'] - df[f'{stat}_error']
        metric_frames.append(df)

# Concatenate once (pd.concat inside the loop is quadratic). Every pickled
# frame carries index 0, so build a fresh 0..n-1 index for downstream plotting.
df_metrics = pd.concat(metric_frames, ignore_index=True) if metric_frames else pd.DataFrame()
df_metrics

model_metrics/intensity24_EPCP72_shash3_network_seed_444_rng_seed_444_metrics.pkl DOES NOT exist. Skipping...
model_metrics/intensity25_EPCP72_bnn_network_seed_786_rng_seed_786_metrics.pkl DOES NOT exist. Skipping...
model_metrics/intensity25_EPCP72_bnn_network_seed_311_rng_seed_311_metrics.pkl DOES NOT exist. Skipping...
model_metrics/intensity25_EPCP72_bnn_network_seed_888_rng_seed_888_metrics.pkl DOES NOT exist. Skipping...
model_metrics/intensity25_EPCP72_bnn_network_seed_999_rng_seed_999_metrics.pkl DOES NOT exist. Skipping...
model_metrics/intensity25_EPCP72_bnn_network_seed_578_rng_seed_578_metrics.pkl DOES NOT exist. Skipping...
model_metrics/intensity25_EPCP72_bnn_network_seed_331_rng_seed_331_metrics.pkl DOES NOT exist. Skipping...
model_metrics/intensity25_EPCP72_bnn_network_seed_908_rng_seed_908_metrics.pkl DOES NOT exist. Skipping...
model_metrics/intensity25_EPCP72_bnn_network_seed_444_rng_seed_444_metrics.pkl DOES NOT exist. Skipping...
model_metrics/intensity23_EPCP72_m

Unnamed: 0,uncertainty_type,network_seed,rng_seed,exp_name,mean_error,median_error,mode_error,cons_error,pit_d,iqr_capture,iqr_error_spearman,iqr_error_pearson,iqr_error_spearman_p,iqr_error_pearson_p,basin_lead,mean_error_reduction,median_error_reduction,mode_error_reduction
0,shash3,605,605,intensity24_EPCP72,14.981316,14.85844,14.645,15.520001,0.049371,0.325,0.703891,0.746417,3.269345e-13,1.913222e-15,EPCP72,0.538685,0.661561,0.875001
0,shash3,122,122,intensity24_EPCP72,14.135733,14.019756,14.05,15.52,0.049054,0.3625,0.793835,0.746973,1.5978690000000001e-18,1.776832e-15,EPCP72,1.384267,1.500244,1.47
0,shash3,786,786,intensity24_EPCP72,14.852832,14.761508,14.5775,15.52,0.059687,0.275,0.733286,0.740934,1.03843e-14,3.92535e-15,EPCP72,0.667168,0.758491,0.9425
0,shash3,311,311,intensity24_EPCP72,14.193206,14.1316,14.34,15.520001,0.034004,0.425,0.450985,0.453274,2.690339e-05,2.419631e-05,EPCP72,1.326795,1.388402,1.180001
0,shash3,888,888,intensity24_EPCP72,16.194463,15.912954,15.3575,15.52,0.037914,0.475,0.286818,0.128697,0.00989609,0.2552373,EPCP72,-0.674464,-0.392955,0.1625
0,shash3,999,999,intensity24_EPCP72,14.898902,14.793109,14.6,15.52,0.054486,0.3625,0.719292,0.6881,5.665479e-14,1.762487e-12,EPCP72,0.621097,0.72689,0.92
0,shash3,578,578,intensity24_EPCP72,14.669827,14.549363,14.37,15.520001,0.046771,0.3125,0.773887,0.771447,3.882366e-17,5.608314e-17,EPCP72,0.850175,0.970639,1.150001
0,shash3,331,331,intensity24_EPCP72,14.724294,14.556456,14.3425,15.52,0.061237,0.3,0.801969,0.808988,3.9274919999999996e-19,1.1094669999999998e-19,EPCP72,0.795705,0.963543,1.1775
0,shash3,908,908,intensity24_EPCP72,15.726458,15.550112,15.2,15.52,0.070711,0.25,0.778176,0.776003,2.0104040000000002e-17,2.8108890000000005e-17,EPCP72,-0.206458,-0.030113,0.32
0,bnn,605,605,intensity25_EPCP72,14.371846,14.356804,14.518611,15.520001,0.064469,0.4375,0.670394,0.67587,1.031588e-11,6.050739e-12,EPCP72,1.148156,1.163198,1.00139


In [8]:
import seaborn as sns

# Palette for the uncertainty_type hue levels.
# NOTE(review): entries 2 and 4 are the same color; confirm this is intended.
colors = ('#D95980', '#284E60', '#E1A730', '#284E60')

x_axis_list = ("basin_lead", "exp_name")
metric_list = (
    'mean_error',
    'mean_error_reduction',
    'median_error',
    'median_error_reduction',
    'iqr_error_spearman',
    'iqr_error_pearson',
    # 'mode_error',
    # 'mode_error_reduction',
    'pit_d',
    'iqr_capture',
)

# savefig does not create missing directories.
os.makedirs(FIGURE_PATH, exist_ok=True)

for x_axis in x_axis_list:
    fig, axs = plt.subplots(4, 2, figsize=(15, 20))
    axs = axs.flatten()

    for imetric, metric in enumerate(metric_list):
        ax = axs[imetric]
        # Faint boxes for the distribution, points for the individual runs.
        sns.boxplot(x=x_axis,
                    y=metric,
                    hue="uncertainty_type",
                    data=df_metrics,
                    palette=colors,
                    boxprops={'alpha': .2,
                              'edgecolor': 'white',
                              },
                    fliersize=0,
                    ax=ax)
        sns.swarmplot(x=x_axis,
                      y=metric,
                      hue="uncertainty_type",
                      palette=colors,
                      data=df_metrics,
                      dodge=True,
                      ax=ax)

        # Reference lines and fixed axis limits per metric family.
        if metric == 'iqr_capture':
            ax.axhline(y=0.5, linewidth=3, linestyle='--', color='gray')
            ax.set_ylim(0, 1.0)
        if metric == 'pit_d':
            ax.set_ylim(0, None)
        if 'reduction' in metric:
            ax.axhline(y=0.0, linewidth=3, linestyle='--', color='gray')
            ax.set_ylim(-4., 4.)
        elif metric.endswith('_error'):
            ax.set_ylim(0., 22.)
        if 'iqr_error' in metric:
            ax.set_ylim(-.3, 1.)
            ax.axhline(y=0.0, linewidth=3, linestyle='--', color='gray')

        ax.set_title(metric + APPEND_NAME)

        # boxplot and swarmplot each register one legend entry per hue level;
        # keep a single entry per label.
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        ax.legend(list(by_label.values()), list(by_label.keys()), fontsize=10, frameon=True)
        ax.tick_params(axis='x', labelrotation=30)

    fig.tight_layout()
    fig.savefig(FIGURE_PATH + 'comparisonsMetrics' + APPEND_NAME + '_' + x_axis + '.png', dpi=dpiFig)
    plt.close(fig)