# Compute metrics for different runs and plot them
##### author: Elizabeth A. Barnes, Randal J. Barnes and Mark DeMaria
##### version: v0.2.0

In [1]:
import datetime
import os
import pickle
import pprint
import time

import experiment_settings
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import shash
from build_data import build_hurricane_data
import build_model
import model_diagnostics
from silence_tensorflow import silence_tensorflow
import prediction
from sklearn.neighbors import KernelDensity
import pandas as pd
from tqdm import tqdm
import imp

import warnings
# Silence library warnings so the tqdm progress bars and printed run names stay readable.
warnings.filterwarnings("ignore")

silence_tensorflow()
dpiFig = 400  # presumably the DPI for saved figures (not used in this part of the notebook)

mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["figure.dpi"] = 150

# `np.warnings` was removed in NumPy 1.24, and `np.VisibleDeprecationWarning`
# lives only in `np.exceptions` from NumPy 2.0 on (available there since 1.25).
# Resolve the class from wherever this NumPy version provides it and register
# the filter through the stdlib `warnings` module.
_np_exceptions = getattr(np, "exceptions", np)
warnings.filterwarnings("ignore", category=_np_exceptions.VisibleDeprecationWarning)

In [2]:
__author__  = "Randal J Barnes and Elizabeth A. Barnes"
__version__ = "17 March 2022"

EXP_NAME_LIST = (
    # "intensity1_AL48",
    # "intensity2_AL72",
    # "intensity3_AL96",    
    # "intensity4_EPCP48",
    "intensity201_AL24",
    # "intensity5_EPCP24",    
    # "intensity5_EPCP72",
    # "intensity6_EPCP96",
    # "intensity41_EPCP48",
    # "intensity42_EPCP48",      
)

OVERWRITE_METRICS = False
DATA_PATH = "data/"
MODEL_PATH = "saved_models/"
METRIC_PATH = "saved_metrics/"
PREDICTION_PATH = "saved_predictions/"

In [3]:
# Rapid-intensification thresholds keyed by settings["leadtime"]; values are
# presumably intensity-change thresholds (units not shown here — confirm).
# Only referenced by the currently disabled Pr(RI) code in the metrics cell.
RI_THRESH_DICT = {24: 30, 48: 55, 72: 65}

## Compute metrics: per-run SHASH test-set predictions and errors

In [4]:
# `imp` is deprecated (removed in Python 3.12); importlib.reload is its replacement.
import importlib

# Pick up any edits made to model_diagnostics since the kernel started.
importlib.reload(model_diagnostics)

for exp_name in EXP_NAME_LIST:
    settings = experiment_settings.get_settings(exp_name)

    # set testing data: the years that are held out, one at a time
    if settings["test_condition"] == "leave-one-out":
        TESTING_YEARS_LIST = np.arange(2013, 2021)
    elif settings["test_condition"] == "years":
        TESTING_YEARS_LIST = np.copy(settings["years_test"])
    else:
        raise NotImplementedError("no such testing condition")

    for testing_years in TESTING_YEARS_LIST:
        # evaluate on a single held-out year per pass
        settings["years_test"] = (testing_years,)

        for rng_seed in settings["rng_seed_list"]:
            settings["rng_seed"] = rng_seed
            NETWORK_SEED_LIST = [settings["rng_seed"]]
            network_seed = NETWORK_SEED_LIST[0]
            tf.random.set_seed(network_seed)  # This sets the global random seed.

            # ----------------------------------------------------------------------
            # Resolve file names first so runs that will be skipped never pay for
            # the data build and model construction below. The name presumably
            # mirrors the one used when the weights were saved during training.
            model_name = (
                exp_name + "_"
                + str(testing_years) + "_"
                + settings["uncertainty_type"] + "_"
                + f"network_seed_{network_seed}_rng_seed_{settings['rng_seed']}"
            )
            weights_filename = MODEL_PATH + model_name + "_weights.h5"
            metric_filename = PREDICTION_PATH + model_name + "_testingPredictions.csv"

            if not os.path.exists(weights_filename):
                print(model_name + ": model does not exist. skipping...")
                continue
            if os.path.exists(metric_filename) and not OVERWRITE_METRICS:
                continue

            # ----------------------------------------------------------------------
            # get the data
            (
                data_summary,
                x_train,
                onehot_train,
                x_val,
                onehot_val,
                x_test,
                onehot_test,
                x_valtest,
                onehot_valtest,
                df_train,
                df_val,
                df_test,
                df_valtest,
            ) = build_hurricane_data(DATA_PATH, settings, verbose=0)

            # ----------------------------------------------------------------------
            # rebuild the (uncompiled) network and load its trained weights
            tf.keras.backend.clear_session()
            model = build_model.make_model(
                settings,
                x_train,
                onehot_train,
                model_compile=False,
            )
            try:
                model.load_weights(weights_filename)
            except (OSError, ValueError) as err:
                # file exists but could not be loaded into this network
                # (e.g. unreadable file or mismatched layer shapes)
                print(f"{model_name}: could not load weights ({err}). skipping...")
                continue

            # compute the climatological errors (currently only consumed by the
            # disabled Pr(RI) code inside the sample loop)
            obs_dev_cons_hist, OBS_DEV_BINS = model_diagnostics.compute_clim_errors(
                onehot=np.append(onehot_train[:, 0], onehot_val[:, 0]),
                smooth=True,
            )

            pprint.pprint(model_name)

            # per-sample SHASH summaries; the PDF is evaluated on an integer grid
            n_samples = np.shape(x_test)[0]
            shash_incs = np.arange(-160, 161, 1)
            shash_cpd = np.zeros((n_samples, len(shash_incs)))
            shash_mean = np.zeros((n_samples,))
            shash_med = np.zeros((n_samples,))
            shash_mode = np.zeros((n_samples,))
            shash_25p = np.zeros((n_samples,))
            shash_75p = np.zeros((n_samples,))
            shash_90p = np.zeros((n_samples,))
            shash_pr_ri = np.zeros((n_samples,))
            clim_pr_ri = np.zeros((n_samples,))

            # loop through samples for shash calculation and get PDF for each sample
            for j in tqdm(range(n_samples)):
                mu_pred, sigma_pred, gamma_pred, tau_pred = prediction.params(x_test[np.newaxis, j], model)
                shash_cpd[j, :] = shash.prob(shash_incs, mu_pred, sigma_pred, gamma_pred, tau_pred)
                shash_mean[j] = shash.mean(mu_pred, sigma_pred, gamma_pred, tau_pred)
                shash_med[j] = shash.median(mu_pred, sigma_pred, gamma_pred, tau_pred)

                shash_25p[j] = shash.quantile(.25, mu_pred, sigma_pred, gamma_pred, tau_pred)
                shash_75p[j] = shash.quantile(.75, mu_pred, sigma_pred, gamma_pred, tau_pred)
                shash_90p[j] = shash.quantile(.9, mu_pred, sigma_pred, gamma_pred, tau_pred)

                # mode = grid point where the evaluated PDF peaks
                shash_mode[j] = shash_incs[np.argmax(shash_cpd[j, :])]

                # Pr(RI) is disabled. If re-enabled, note that clim_pr_ri does not
                # depend on j and could be computed once outside this loop.
                # try:
                #     ri_threshold = RI_THRESH_DICT[settings["leadtime"]]
                #     shash_pr_ri[j] = model_diagnostics.compute_pr_ri(shash_incs,shash_cpd[j,:], ri_threshold)
                #     clim_pr_ri[j] = model_diagnostics.compute_pr_ri(OBS_DEV_BINS,obs_dev_cons_hist, ri_threshold)
                # except:
                #     shash_pr_ri[j] = np.nan

            # add predictions to the data_frame
            df_predictions = df_test.copy()
            df_predictions["shash_median"] = shash_med
            df_predictions["shash_mean"] = shash_mean
            df_predictions["shash_mode"] = shash_mode
            df_predictions["shash_25p"] = shash_25p
            df_predictions["shash_75p"] = shash_75p
            df_predictions["shash_90p"] = shash_90p
            # df_predictions["shash_pr_ri"] = shash_pr_ri
            # df_predictions["clim_pr_ri"] = clim_pr_ri

            # errors of the SHASH median and of a constant-zero baseline
            # against the target column onehot_test[:, 0]
            df_predictions["shash_error"] = shash_med - onehot_test[:, 0]
            df_predictions["cons_error"] = 0.0 - onehot_test[:, 0]
            df_predictions["shash_improvement"] = (
                df_predictions["cons_error"].abs() - df_predictions["shash_error"].abs()
            )

            # save the dataframe (create the output directory on a fresh checkout)
            out_dir = os.path.dirname(metric_filename)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            df_predictions.to_csv(metric_filename)

'intensity201_AL24_2013_shash3_network_seed_416_rng_seed_416'


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:07<00:00, 15.32it/s]


'intensity201_AL24_2013_shash3_network_seed_739_rng_seed_739'


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:06<00:00, 17.56it/s]


'intensity201_AL24_2014_shash3_network_seed_416_rng_seed_416'


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 112/112 [00:06<00:00, 16.16it/s]


intensity201_AL24_2014_shash3_network_seed_739_rng_seed_739: model does not exist. skipping...
intensity201_AL24_2015_shash3_network_seed_416_rng_seed_416: model does not exist. skipping...
intensity201_AL24_2015_shash3_network_seed_739_rng_seed_739: model does not exist. skipping...
intensity201_AL24_2016_shash3_network_seed_416_rng_seed_416: model does not exist. skipping...
intensity201_AL24_2016_shash3_network_seed_739_rng_seed_739: model does not exist. skipping...
intensity201_AL24_2017_shash3_network_seed_416_rng_seed_416: model does not exist. skipping...
intensity201_AL24_2017_shash3_network_seed_739_rng_seed_739: model does not exist. skipping...
intensity201_AL24_2018_shash3_network_seed_416_rng_seed_416: model does not exist. skipping...
intensity201_AL24_2018_shash3_network_seed_739_rng_seed_739: model does not exist. skipping...
intensity201_AL24_2019_shash3_network_seed_416_rng_seed_416: model does not exist. skipping...
intensity201_AL24_2019_shash3_network_seed_739_rng

In [8]:
# Preview the predictions of the last run computed above. df_predictions is only
# assigned when a run is actually (re)computed; when every prediction CSV already
# exists, all runs are skipped. Fall back to an empty frame with the same
# columns instead of raising NameError on Restart & Run All.
PREVIEW_COLUMNS = ["ATCF", "shash_median", "shash_90p", "shash_error", "cons_error", "shash_improvement"]
df_preview = (
    df_predictions if "df_predictions" in globals()
    else pd.DataFrame(columns=PREVIEW_COLUMNS)
)
df_preview[PREVIEW_COLUMNS]

Unnamed: 0,ATCF,shash_median,shash_90p,shash_error,cons_error,shash_improvement
0,AL03,6.519927,17.692007,-4.680073,-11.200000,6.519927
1,AL08,9.851115,33.852802,-17.348886,-27.200001,9.851115
2,AL01,-4.559035,4.928643,14.640966,19.200001,4.559035
3,AL06,-2.508804,7.616689,6.991196,9.500000,2.508804
4,AL03,2.623022,8.893994,6.423022,3.800000,-2.623022
...,...,...,...,...,...,...
107,AL06,-4.629972,6.393348,5.870028,10.500000,4.629972
108,AL01,3.013457,19.346493,-2.986543,-6.000000,3.013457
109,AL06,2.197300,19.427334,-8.302700,-10.500000,2.197300
110,AL04,4.051876,18.572666,2.251876,-1.800000,-0.451876
