# Compute metrics for different runs and plot them
##### authors: Elizabeth A. Barnes, Randal J. Barnes, and Mark DeMaria
##### version: v0.2.0

In [1]:
import datetime
import os
import pickle
import pprint
import time

import experiment_settings
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import shash_tfp
from build_data import build_hurricane_data
import build_model
import model_diagnostics
from silence_tensorflow import silence_tensorflow
import prediction
from sklearn.neighbors import KernelDensity
import pandas as pd
from tqdm import tqdm
import imp

import warnings
# Notebook-wide setup: silence warnings, keep TensorFlow off the GPU, and set
# matplotlib defaults.
warnings.filterwarnings("ignore")

silence_tensorflow()
tf.config.set_visible_devices([], "GPU")  # turn-off tensorflow-metal if it is on
dpiFig = 400

mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["figure.dpi"] = 150
# Use the standard-library `warnings` module directly: `np.warnings` was only an
# accidental re-export of it and is no longer available in recent NumPy
# releases, where accessing it raises AttributeError.
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

In [2]:
__author__ = "Randal J Barnes and Elizabeth A. Barnes"
__version__ = "16 September 2022"

# Names of the experiments to process; each one is passed to
# experiment_settings.get_settings() by the cells below.
EXP_NAME_LIST = (
    # "AL" experiments
    "intensity201_AL24",
    "intensity202_AL48",
    "intensity203_AL72",
    "intensity204_AL96",
    "intensity205_AL120",
    # "EPCP" experiments
    "intensity301_EPCP24",
    "intensity302_EPCP48",
    "intensity303_EPCP72",
    "intensity304_EPCP96",
    "intensity305_EPCP120",
)

# When False, prediction files that already exist on disk are skipped.
OVERWRITE_METRICS = False

# Directory prefixes (relative paths, each ending in "/").
DATA_PATH = "data/"
MODEL_PATH = "saved_models/"
METRIC_PATH = "saved_metrics/"
PREDICTION_PATH = "saved_predictions/"

In [3]:
# Maps settings["leadtime"] to the increment that is added to the "VMAX0"
# column to form the RI threshold (RI presumably stands for rapid
# intensification -- confirm).
RI_THRESH_DICT = {
    24: 30,
    48: 55,
    72: 65,
    # no threshold is defined for the two longest lead times
    96: None,
    120: None,
}

## Compute Predictions for ALL of the data

In [None]:
# For every experiment / testing year / random seed: rebuild the data split,
# load the trained weights, evaluate the predicted SHASH distribution for every
# sample (training + validation + testing), and write one
# "<model_name>_allPredictions.csv" file to PREDICTION_PATH.
import importlib  # `imp` is deprecated; importlib.reload is its replacement

importlib.reload(model_diagnostics)  # pick up edits to the module without a kernel restart

for exp_name in EXP_NAME_LIST:
    settings = experiment_settings.get_settings(exp_name)
    print(exp_name)

    # set testing data
    if settings["test_condition"] == "leave-one-out":
        TESTING_YEARS_LIST = np.arange(2013, 2022)
    elif settings["test_condition"] == "years":
        TESTING_YEARS_LIST = np.copy(settings["years_test"])
    else:
        raise NotImplementedError('no such testing condition')

    # RI increment for this experiment's lead time. None means "no RI threshold",
    # in which case the RI probabilities below are left as NaN.
    try:
        ri_increment = RI_THRESH_DICT[settings["leadtime"]]
    except KeyError:
        ri_increment = None

    for testing_years in TESTING_YEARS_LIST:
        # set testing year
        settings["years_test"] = (testing_years,)

        for rng_seed in settings['rng_seed_list']:
            settings['rng_seed'] = rng_seed
            NETWORK_SEED_LIST = [settings["rng_seed"]]
            network_seed = NETWORK_SEED_LIST[0]
            tf.random.set_seed(network_seed)  # This sets the global random seed.

            model_name = (
                exp_name + "_" +
                str(testing_years) + '_' +
                settings["uncertainty_type"] + '_' +
                f"network_seed_{network_seed}_rng_seed_{settings['rng_seed']}"
            )
            #----------------------------------------------------------------------------------------------------
            # check if the metric filename exists already
            metric_filename = PREDICTION_PATH + model_name + '_allPredictions.csv'
            if os.path.exists(metric_filename) and not OVERWRITE_METRICS:
                print(metric_filename + ' exists. Skipping...')
                continue

            #----------------------------------------------------------------------------------------------------
            # get the data
            (
                data_summary,
                x_train,
                onehot_train,
                x_val,
                onehot_val,
                x_test,
                onehot_test,
                x_valtest,
                onehot_valtest,
                df_train,
                df_val,
                df_test,
                df_valtest,
            ) = build_hurricane_data(DATA_PATH, settings, verbose=0)

            #----------------------------------------------------------------------------------------------------
            # get the model (built uncompiled; the weights are loaded below)
            tf.keras.backend.clear_session()
            model = build_model.make_model(
                settings,
                x_train,
                onehot_train,
                model_compile=False,
            )

            #----------------------------------------------------------------------------------------------------
            # load the model
            try:
                model.load_weights(MODEL_PATH + model_name + "_weights.h5")
            except Exception as err:
                # Best-effort: any failure to load the weights skips this model.
                # `Exception` (not a bare except) lets KeyboardInterrupt through.
                print(model_name + ': model does not exist. skipping...')
                print(f"    ({type(err).__name__}: {err})")
                continue

            # compute the climatological errors from the training + validation targets
            obs_dev_cons_hist, OBS_DEV_BINS = model_diagnostics.compute_clim_errors(
                onehot=np.append(onehot_train[:, 0], onehot_val[:, 0]),
                smooth=True,
            )

            # report which model is being processed
            pprint.pprint(model_name)

            # Concatenate inputs, targets and dataframes in the same
            # train -> val -> test order so that row j of each refers to the
            # same sample.
            x_data = np.concatenate([x_train, x_val, x_test])
            onehot_data = np.concatenate([onehot_train, onehot_val, onehot_test])
            df_data = pd.concat([df_train, df_val, df_test])

            # get prediction metrics of interest
            n_samples = np.shape(x_data)[0]
            SHASH_INCS = np.arange(-160, 161, 1)  # grid on which the density is evaluated
            shash_cpd = np.zeros((n_samples, len(SHASH_INCS)))
            shash_mu = np.zeros((n_samples,))
            shash_sigma = np.zeros((n_samples,))
            shash_gamma = np.zeros((n_samples,))
            shash_tau = np.zeros((n_samples,))
            shash_mean = np.zeros((n_samples,))
            shash_med = np.zeros((n_samples,))
            shash_mode = np.zeros((n_samples,))
            shash_25p = np.zeros((n_samples,))
            shash_75p = np.zeros((n_samples,))
            shash_90p = np.zeros((n_samples,))
            # RI probabilities default to NaN ("not available")
            shash_pr_ri = np.full((n_samples,), np.nan)
            clim_pr_ri = np.full((n_samples,), np.nan)

            # loop through samples for shash calculation and get PDF for each sample
            for j in range(n_samples):
                # x_data[np.newaxis, j] keeps a leading axis of length one
                mu_pred, sigma_pred, gamma_pred, tau_pred = prediction.params(
                    x_data[np.newaxis, j], model
                )
                shash_mu[j] = mu_pred
                shash_sigma[j] = sigma_pred
                shash_gamma[j] = gamma_pred
                shash_tau[j] = tau_pred

                dist = shash_tfp.Shash(mu_pred, sigma_pred, gamma_pred, tau_pred)
                shash_cpd[j, :] = dist.prob(SHASH_INCS)
                shash_mean[j] = dist.mean()
                shash_med[j] = dist.median()

                shash_25p[j] = dist.quantile(.25)
                shash_75p[j] = dist.quantile(.75)
                shash_90p[j] = dist.quantile(.9)

                # mode approximated by the grid point with the largest density
                shash_mode[j] = SHASH_INCS[np.argmax(shash_cpd[j, :])]

                if ri_increment is None:
                    continue
                try:
                    # Sample j of x_data corresponds to row j (by position) of
                    # df_data, so look it up there with .iloc. Indexing df_test
                    # by label j would pick an unrelated row or fail.
                    cons_intensity = df_data["VMXC"].iloc[j]
                    ri_threshold = df_data["VMAX0"].iloc[j] + ri_increment
                    shash_pr_ri[j] = model_diagnostics.compute_pr_ri(
                        SHASH_INCS + cons_intensity, shash_cpd[j, :], ri_threshold
                    )
                    clim_pr_ri[j] = model_diagnostics.compute_pr_ri(
                        OBS_DEV_BINS + cons_intensity, obs_dev_cons_hist, ri_threshold
                    )
                except (KeyError, IndexError, TypeError, ValueError):
                    # best-effort: leave both probabilities undefined for this sample
                    shash_pr_ri[j] = np.nan
                    clim_pr_ri[j] = np.nan

            # add predictions to the data_frame
            df_predictions = df_data.copy()
            df_predictions["shash_mu"] = shash_mu
            df_predictions["shash_sigma"] = shash_sigma
            df_predictions["shash_gamma"] = shash_gamma
            df_predictions["shash_tau"] = shash_tau
            df_predictions["shash_median"] = shash_med
            df_predictions["shash_mean"] = shash_mean
            df_predictions["shash_mode"] = shash_mode
            df_predictions["shash_25p"] = shash_25p
            df_predictions["shash_75p"] = shash_75p
            df_predictions["shash_90p"] = shash_90p
            df_predictions["shash_pr_ri"] = shash_pr_ri
            df_predictions["clim_pr_ri"] = clim_pr_ri

            # Errors relative to column 0 of the targets; "cons_error" is the
            # error of a prediction of exactly zero.
            df_predictions["shash_error"] = shash_med - onehot_data[:, 0]
            df_predictions["cons_error"] = 0.0 - onehot_data[:, 0]
            df_predictions["shash_improvement"] = (
                df_predictions["cons_error"].abs() - df_predictions["shash_error"].abs()
            )

            # save the dataframe
            df_predictions.to_csv(metric_filename)

intensity201_AL24
saved_predictions/intensity201_AL24_2013_shash3_network_seed_222_rng_seed_222_allPredictions.csv exists. Skipping...
saved_predictions/intensity201_AL24_2013_shash3_network_seed_333_rng_seed_333_allPredictions.csv exists. Skipping...
saved_predictions/intensity201_AL24_2013_shash3_network_seed_416_rng_seed_416_allPredictions.csv exists. Skipping...
saved_predictions/intensity201_AL24_2013_shash3_network_seed_599_rng_seed_599_allPredictions.csv exists. Skipping...
saved_predictions/intensity201_AL24_2013_shash3_network_seed_739_rng_seed_739_allPredictions.csv exists. Skipping...
saved_predictions/intensity201_AL24_2014_shash3_network_seed_222_rng_seed_222_allPredictions.csv exists. Skipping...
saved_predictions/intensity201_AL24_2014_shash3_network_seed_333_rng_seed_333_allPredictions.csv exists. Skipping...
saved_predictions/intensity201_AL24_2014_shash3_network_seed_416_rng_seed_416_allPredictions.csv exists. Skipping...
saved_predictions/intensity201_AL24_2014_shash

In [None]:
# NOTE(review): `error` is not defined anywhere in the visible notebook, so this
# call presumably raises NameError on purpose to stop a "Run All" at this
# point -- confirm before removing or changing it.
error('here')

# Compute predictions for testing only

## Create one prediction file

In [None]:
# Gather, for every experiment and testing year, the testing predictions of the
# seed recorded in the best-validation table, and write them to a single CSV.
df_bestval = pd.read_pickle(PREDICTION_PATH + "best_shash3_validation_seeds.pickle")

# Collect the per-model frames and concatenate once at the end
# (DataFrame.append was removed in pandas 2.0).
best_frames = []
for exp_name in EXP_NAME_LIST:
    settings = experiment_settings.get_settings(exp_name)

    # set testing data
    if settings["test_condition"] == "leave-one-out":
        TESTING_YEARS_LIST = np.arange(2013, 2022)
    elif settings["test_condition"] == "years":
        TESTING_YEARS_LIST = np.copy(settings["years_test"])
    else:
        raise NotImplementedError('no such testing condition')

    for testing_years in TESTING_YEARS_LIST:
        # set testing year
        settings["years_test"] = (testing_years,)

        BEST_SEED = None
        try:
            best_rows = df_bestval[
                (df_bestval["exp_name"] == exp_name)
                & (df_bestval["testing_years"] == testing_years)
            ]
            # positional access: the filtered frame's index labels are not
            # guaranteed to contain 0
            BEST_SEED = best_rows["rng_seed"].iloc[0]
        except (KeyError, IndexError):
            # no matching row (or missing column): nothing to collect
            print(f"no best validation seed for {exp_name} {testing_years}. skipping...")
            continue

        for rng_seed in settings['rng_seed_list']:

            if rng_seed != BEST_SEED:
                continue

            settings['rng_seed'] = rng_seed
            NETWORK_SEED_LIST = [settings["rng_seed"]]
            network_seed = NETWORK_SEED_LIST[0]
            tf.random.set_seed(network_seed)  # This sets the global random seed.

            model_name = (
                exp_name + "_" +
                str(testing_years) + '_' +
                settings["uncertainty_type"] + '_' +
                f"network_seed_{network_seed}_rng_seed_{settings['rng_seed']}"
            )

            #----------------------------------------------------------------------------------------------------
            # skip models whose testing predictions have not been computed
            metric_filename = PREDICTION_PATH + model_name + '_testingPredictions.csv'
            if not os.path.exists(metric_filename):
                continue
            best_frames.append(pd.read_csv(metric_filename))

# pd.concat raises on an empty list, so fall back to an empty frame
df_bestpred = pd.concat(best_frames) if best_frames else pd.DataFrame()

df_bestpred.to_csv(PREDICTION_PATH + "shash3_bestValTestingPredictions.csv")
print('number of rows = ' + str(len(df_bestpred)))
df_bestpred.head()
