## Generate the Explainability Output

In this notebook, we compute the explainability output used in the paper. The methods include:
1. Accumulated Local Effects (ALE)  
2. SHAP (Shapley Additive Explanations)
3. SAGE (Shapley Additive Global Explanations)
4. Grouped SAGE


#### Import Python packages (internal and third party)

In [1]:
import sys, os 
from os.path import dirname
path = dirname(dirname(os.getcwd()))
sys.path.insert(0, path)
sys.path.insert(0, '/home/monte.flora/python_packages/scikit-explain')

import skexplain 
from skexplain.common.importance_utils import to_skexplain_importance
from src.io.io import load_data_and_model
from src.common.util import subsampler, normalize_importance, compute_sage

import pickle
import shap
import itertools
import numpy as np

#### Setting the user constants (paths, parameters, etc)

In [None]:
# User-set constants for this notebook.
# Number of bootstrap resamples; passed as `n_bootstrap` to `explainer.ale`.
N_BOOTSTRAP = 10
# Number of bins; passed as `n_bins` to `explainer.ale`.
N_BINS = 30
# NOTE(review): not referenced in the visible part of this notebook;
# presumably names a scoring metric -- confirm where it is consumed.
EVALUATION_FN = 'norm_aupdc'

# Root directory; the results, datasets and models directories hang off it.
BASE_PATH = '/work/mflora/explainability_work/'

# RESULTS_PATH is where the compute_* functions write their output files.
# DATA_BASE_PATH and MODEL_BASE_PATH are handed to `load_data_and_model`.
RESULTS_PATH = os.path.join(BASE_PATH,    'results')
DATA_BASE_PATH = os.path.join(BASE_PATH,  'datasets')
MODEL_BASE_PATH = os.path.join(BASE_PATH, 'models')

In [1]:
# Compute ALE 
def compute_ale(explainer, dataset, option, est_name, **kwargs):
    """Compute bootstrapped ALE for all features and save the result.

    Parameters
    ----------
    explainer : object
        Must provide ``ale(...)`` and ``save(path, data)``; the driver in
        this notebook builds ``skexplain.ExplainToolkit`` instances.
    dataset, option : object
        Interpolated into the output file name
        ``ale_{dataset}_{option}.nc`` under ``RESULTS_PATH``.
    est_name : object
        Unused here; accepted so every ``compute_*`` function shares one
        call signature.
    **kwargs
        Must contain ``'X'``, an object with a ``columns`` attribute.

    Raises
    ------
    KeyError
        If ``'X'`` is not supplied in ``kwargs``.
    """
    # Read X from kwargs rather than relying on a module-level ``X`` that
    # only exists once the driver loop has run.
    X = kwargs['X']
    # One job per feature column.
    n_jobs = len(X.columns)
    ale = explainer.ale(features='all', n_bootstrap=N_BOOTSTRAP, n_bins=N_BINS, n_jobs=n_jobs)
    # Save the raw ALE results.
    explainer.save(os.path.join(RESULTS_PATH, f'ale_{dataset}_{option}.nc'), ale)


# ## 6. SHAP Values
# Compute SHAP (Approx. Owen Values)
# The default explainer is the PermutationExplainer. The PermutationExplainer uses a 
# simple forward- and backward-permutation scheme to compute the SHAP values. 
# The SHAP documentation claims this method is exact for 2nd order interactions.
# For the maskers, we are using correlations and as such we are computing 
# approximate Owen values. 

# Check if each SHAP example can be run in parallel. 
def compute_shap(explainer, dataset, option, est_name, **kwargs):
    """Compute SHAP local attributions plus a ranking, and save both.

    Parameters
    ----------
    explainer : object
        Must provide ``local_attributions(...)`` and ``save(path, data)``.
    dataset, option : object
        Interpolated into the two output file names
        (``shap_{dataset}_{option}.nc`` and
        ``shap_rank_{dataset}_{option}.nc``) under ``RESULTS_PATH``.
    est_name : object
        Estimator name; selects the ``shap_values__{est_name}`` entry of
        the attribution results and is forwarded to
        ``to_skexplain_importance``.
    **kwargs
        Must contain ``'X'``; it is handed to the Partition masker and its
        ``columns`` are used as the feature names.
    """
    X_data = kwargs['X']
    feature_names = X_data.columns

    # Partition masker with correlation-based clustering, paired with the
    # permutation algorithm.
    masker = shap.maskers.Partition(X_data, max_samples=50,
                                    clustering="correlation")
    shap_options = {'masker': masker, 'algorithm': 'permutation'}
    results = explainer.local_attributions('shap', shap_kws=shap_options)

    # Convert the per-example SHAP values into an importance dataset.
    raw_values = results[f'shap_values__{est_name}'].values
    shap_rank = to_skexplain_importance(
        raw_values,
        estimator_name=est_name,
        feature_names=feature_names,
        method='shap_sum',
        normalize=False,
    )

    # Write the raw attributions first, then the ranking.
    values_path = os.path.join(RESULTS_PATH, f'shap_{dataset}_{option}.nc')
    rank_path = os.path.join(RESULTS_PATH, f'shap_rank_{dataset}_{option}.nc')
    explainer.save(values_path, results)
    explainer.save(rank_path, shap_rank)


# ## 7. SAGE Values
# Compute SAGE
def compute_sage_(explainer, dataset, option, est_name, **kwargs):
    """Compute per-feature SAGE values and save them.

    Parameters
    ----------
    explainer : object
        Must provide ``estimators`` (indexable by ``est_name``), ``X``,
        ``y`` and ``save(path, data)``.
    dataset, option : object
        Interpolated into the output file name
        ``sage_{dataset}_{option}.nc`` under ``RESULTS_PATH``.
    est_name : object
        Key into ``explainer.estimators``; also forwarded to
        ``to_skexplain_importance``.
    **kwargs
        Must contain ``'X'``; it is passed as the fourth argument of
        ``compute_sage`` and its ``columns`` are used as feature names.
    """
    estimator = explainer.estimators[est_name]
    X_explain, y_explain = explainer.X, explainer.y
    X_passed = kwargs['X']
    feature_names = X_passed.columns

    # One job per column of the explainer's X.
    sage_values = compute_sage(
        estimator,
        X_explain.values,
        y_explain,
        X_passed,
        n_jobs=X_explain.shape[1],
    )

    # The ranking is built but not written out; its save call is disabled
    # below.
    sage_rank = to_skexplain_importance(
        sage_values,
        estimator_name=est_name,
        feature_names=feature_names,
        method='sage',
        normalize=False,
    )

    # Save the SAGE values returned by compute_sage.
    out_path = os.path.join(RESULTS_PATH, f'sage_{dataset}_{option}.nc')
    explainer.save(out_path, sage_values)
    #explainer.save(os.path.join(RESULTS_PATH, f'new_sage_rank_{dataset}_{option}.nc'), sage_rank)


# Compute Grouped SAGE
def compute_group_sage(explainer, dataset, option, est_name, **kwargs):
    """Compute SAGE values for groups of features and save them.

    Parameters
    ----------
    explainer : object
        Must provide ``estimators`` (indexable by ``est_name``), ``X``
        (with ``columns`` and ``values``), ``y`` and ``save(path, data)``.
    dataset, option : object
        Interpolated into the output file name
        ``grouped_sage_{dataset}_{option}.nc`` under ``RESULTS_PATH``.
    est_name : object
        Key into ``explainer.estimators``; also forwarded to
        ``to_skexplain_importance``.
    **kwargs
        Must contain ``'groups'``, a mapping of group name to an iterable
        of feature names found in ``explainer.X.columns``, and ``'X'``,
        which is passed as the fourth argument of ``compute_sage``.

    Raises
    ------
    KeyError
        If ``'groups'`` or ``'X'`` is missing from ``kwargs``.
    ValueError
        If a grouped feature name is not a column of ``explainer.X``.
    """
    X = explainer.X
    feature_groups = kwargs['groups']

    # Translate each group's feature names into column positions of X.
    cols = list(X.columns)
    groups = [[cols.index(feature) for feature in members]
              for members in feature_groups.values()]
    # The group names stand in for feature names in the ranking.
    features = list(feature_groups)

    estimator = explainer.estimators[est_name]

    y = explainer.y
    X_orig = kwargs['X']

    sage_values = compute_sage(estimator, X.values, y, X_orig, groups=groups)
    # The ranking is built but not written out; its save call is disabled
    # below.
    sage_rank = to_skexplain_importance(sage_values,
                                        estimator_name=est_name,
                                        feature_names=features,
                                        method='sage',
                                        normalize=False)

    # Use a file name distinct from compute_sage_'s ``sage_{dataset}_{option}.nc``
    # so running both methods for one dataset does not overwrite the
    # per-feature results with the grouped ones.
    explainer.save(os.path.join(RESULTS_PATH, f'grouped_sage_{dataset}_{option}.nc'), sage_values)
    #explainer.save(os.path.join(RESULTS_PATH, f'new_grouped_sage_rank_{dataset}_{option}.nc'), sage_rank)
        
# Driver: for every dataset, load the model and data, subsample, and run
# the local and global explanation methods.

# Dataset names; each one is passed to `load_data_and_model`.
DATASETS = ['new_severe_wind', 'lightning', 'road_surface']
# Methods run with the explainer built on the GLOBAL_SIZE subsample.
# `compute_ale` is in neither list, so this loop does not run it.
global_methods = [compute_sage_, compute_group_sage,]
# Methods run with the explainer built on the LOCAL_SIZE subsample.
local_methods = [compute_shap]

# TODO: perhaps make the global size a fraction of the total dataset? 
# Implement a try/except? 

# Size arguments handed to `subsampler` for the global and local subsets.
GLOBAL_SIZE = 50000
LOCAL_SIZE = 2500

# NOTE(review): `option` and `est_name` are used in the calls below but are
# not assigned anywhere in the visible part of this notebook -- confirm they
# are defined before this cell runs, otherwise the first call raises NameError.
for dataset in DATASETS:
    
    # Load model and data.
    model, X, y, groups = load_data_and_model(dataset, DATA_BASE_PATH, MODEL_BASE_PATH, 
                                     return_groups=True)
    
    # Subsample the dataset. 
    X_sub, y_sub = subsampler(X,y, GLOBAL_SIZE)
     
    # The local subset is drawn from the global subset, not from the full data.
    X_local, y_local = subsampler(X_sub, y_sub, LOCAL_SIZE)
    local_explainer = skexplain.ExplainToolkit(model, X_local, y_local)
    for method in local_methods:
        print(method)
        # The full, un-subsampled X is what the methods receive as kwargs['X'].
        # `model` is forwarded too, but no visible method reads kwargs['model'].
        method(local_explainer, dataset, option, est_name, X=X, model=model)

    # Initialize the explainer. 
    explainer = skexplain.ExplainToolkit(model, X_sub, y_sub)    
    for method in global_methods:
        print(method)
        method(explainer, dataset, option, est_name, X=X, model=model, groups=groups)

(61426, 91)
<function compute_group_sage at 0x1459168235e0>


  0%|          | 0/1 [00:00<?, ?it/s]