In [4]:
# --- Standard library ---
import pathlib
import pickle
import sys
import warnings
from os.path import join

# --- Third-party ---
import matplotlib.pyplot as plt
import numpy as np

# All paths are built relative to the directory the notebook is run from.
path_to_file = str(pathlib.Path().resolve())
dir_path = join(path_to_file, "../../")
retro_path = join(path_to_file, "..", "Results", "Retrospective")

# --- Project-local ---
# `helper` is imported from <dir_path>/HelperFiles, so that directory must be
# appended to sys.path before the import statement runs.
sys.path.append(join(dir_path, "HelperFiles"))
import helper

# Significance levels evaluated in every table below.
alphas = [0.05, 0.1, 0.2]

# Silence all warnings for the remainder of the notebook.
warnings.filterwarnings('ignore')

### Iterate through all datasets with both methods, at multiple alphas

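The displayed matrices have the same layout as Table 1 in the paper, restricted to the ranking methods: one row per dataset and, for each method, one column per alpha. Note that the first matrix (maximum FWER) is printed in percent, while the second (average FWER) is printed as a fraction.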

In [2]:
datasets = ["census", "bank", "brca", "credit", "breast_cancer"]
methods = ['ss', 'kernelshap']

# Result layout: one row per dataset; columns are grouped by method (in the
# order of `methods`), with one column per entry of `alphas` inside each group.
# NaN-initialised so that a cell that is never filled is obvious when printed.
n_cols = len(methods) * len(alphas)
max_mat = np.full((len(datasets), n_cols), np.nan)
avg_mat = np.full((len(datasets), n_cols), np.nan)

for dataset_idx, dataset in enumerate(datasets):
    for method_idx, method in enumerate(methods):
        # NOTE: pickle.load can execute arbitrary code; only load result files
        # that were produced by this project.
        with open(join(retro_path, method+"_"+dataset), 'rb') as f:
            retro_results = pickle.load(f)
        shap_vals = retro_results["shap_vals"]
        N_verified_all = retro_results["N_verified"]
        N_pts, N_runs, N_alphas = N_verified_all.shape

        # Ranking of every individual run, by absolute SHAP value.
        all_ranks = helper.shap_vals_to_ranks(shap_vals, abs=True)

        # Reference ranking per point: rank the SHAP values averaged over axis 1
        # (the run axis, per the unpacking of N_verified_all.shape above).
        avg_shap = np.mean(shap_vals, axis=1)
        avg_ranks = np.array([helper.get_ranking(avg_shap[i], abs=True) for i in range(N_pts)])

        # NOTE(review): fwers is presumably (N_alphas, N_pts), so reducing over
        # axis 1 leaves one value per alpha - confirm against the helper.
        fwers = helper.calc_all_retro_fwers(N_verified_all, all_ranks, avg_ranks)
        max_fwers = np.round(np.nanmax(fwers, axis=1), 3)
        avg_fwers = np.round(np.nanmean(fwers, axis=1), 3)

        # Column group of this method, derived from its position in `methods`
        # instead of a hard-coded offset tied to exactly three alphas.
        col_start_idx = method_idx * len(alphas)
        col_stop_idx = col_start_idx + len(alphas)
        max_mat[dataset_idx, col_start_idx:col_stop_idx] = max_fwers
        avg_mat[dataset_idx, col_start_idx:col_stop_idx] = avg_fwers

# The maximum is printed in percent; the average is printed as a fraction.
print(max_mat*100)
print(avg_mat)

[[ 2.  6. 10.  4.  4.  8.]
 [ 2.  4. 12.  2.  4.  8.]
 [ 2.  8. 10.  2.  2. 10.]
 [ 2.  8. 10.  2.  2.  6.]
 [ 2.  8. 10.  4.  6. 10.]]
[[0.004 0.011 0.031 0.003 0.005 0.026]
 [0.005 0.011 0.027 0.002 0.005 0.026]
 [0.004 0.012 0.026 0.001 0.005 0.017]
 [0.001 0.011 0.033 0.001 0.002 0.012]
 [0.003 0.007 0.023 0.003 0.005 0.011]]


# Stability of top K set

In [None]:
# Get rejection results for the top-K *set* (order within the set is ignored).
K = 5

# Same layout as the ranking tables: one row per dataset; columns grouped by
# method (in the order of `methods`), one column per entry of `alphas`.
# NaN-initialised so that a cell that is never filled is obvious when printed.
n_cols = len(methods) * len(alphas)
max_mat_set = np.full((len(datasets), n_cols), np.nan)
avg_mat_set = np.full((len(datasets), n_cols), np.nan)

for dataset_idx, dataset in enumerate(datasets):
    for method_idx, method in enumerate(methods):
        # NOTE: pickle.load can execute arbitrary code; only load result files
        # that were produced by this project.
        with open(join(retro_path, method+"_"+dataset), 'rb') as f:
            retro_results = pickle.load(f)
        shap_vals = retro_results["shap_vals"]
        shap_vars = retro_results["shap_vars"]

        # Take the sizes from the file that was just loaded. Previously N_pts
        # and N_runs were left over from an earlier cell (i.e. the sizes of the
        # last file loaded there), which fails on a fresh kernel and is wrong
        # whenever result files differ in size.
        N_pts, N_runs = shap_vals.shape[:2]

        all_ranks = helper.shap_vals_to_ranks(shap_vals, abs=True)
        avg_shap = np.mean(shap_vals, axis=1)
        avg_ranks = np.array([helper.get_ranking(avg_shap[i], abs=True) for i in range(N_pts)])

        max_fwers = []
        avg_fwers = []
        for alpha in alphas:
            fwers_all = []
            for i in range(N_pts):
                # Reference set: first K entries of the ranking of the
                # run-averaged SHAP values; sorted so sets compare order-free.
                true_top_K_set = np.sort(avg_ranks[i,:K])
                num_false_rejections = 0
                for j in range(N_runs):
                    ss_vals, ss_vars = shap_vals[i,j,:], shap_vars[i,j,:]
                    result = helper.test_top_k_set(ss_vals, ss_vars, K=K, alpha=alpha, abs=True)
                    # Only runs where the test returns "reject" can count as
                    # errors; such a run is a false rejection when its top-K
                    # set differs from the reference set.
                    if result=="reject":
                        est_top_K_set = np.sort(all_ranks[i,j,:K])
                        if not np.array_equal(true_top_K_set, est_top_K_set):
                            num_false_rejections += 1
                # Fraction of this point's runs that were false rejections.
                fwers_all.append(num_false_rejections/N_runs)
            max_fwers.append(np.round(np.nanmax(fwers_all), 3))
            avg_fwers.append(np.round(np.nanmean(fwers_all), 3))

        # Column group of this method, derived from its position in `methods`
        # instead of a hard-coded offset tied to exactly three alphas.
        col_start_idx = method_idx * len(alphas)
        col_stop_idx = col_start_idx + len(alphas)
        max_mat_set[dataset_idx, col_start_idx:col_stop_idx] = max_fwers
        avg_mat_set[dataset_idx, col_start_idx:col_stop_idx] = avg_fwers


In [5]:
# Print both top-K set FWER tables in percent: worst case over points first,
# then the average over points.
for fwer_table in (max_mat_set, avg_mat_set):
    print(fwer_table*100)

[[ 4.  4. 10.  4.  6. 12.]
 [ 0.  2.  8.  0.  0.  2.]
 [ 2.  4.  8.  2.  4.  6.]
 [ 2.  6.  8.  2.  4. 10.]
 [ 0.  4.  6.  0.  0.  2.]]
[[0.2 0.7 1.7 0.2 0.7 2.1]
 [0.  0.3 1.3 0.  0.  0.1]
 [0.2 0.5 1.6 0.1 0.2 1.3]
 [0.2 0.7 2.1 0.1 0.5 1.3]
 [0.  0.2 0.5 0.  0.  0.1]]


### Great: even in the worst case, the FWER is always controlled (the maximum FWER never exceeds the corresponding alpha)!