In [6]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import pathlib
import sys
from os.path import join
# NOTE: pathlib.Path().resolve() is the kernel's current working directory, not
# this notebook file's location, so these paths assume the notebook is run from
# its own directory.
path_to_file = str(pathlib.Path().resolve())
dir_path = join(path_to_file, "../../")
# Make the project's HelperFiles modules importable. Only append if missing so
# that re-running this cell does not keep adding duplicate sys.path entries.
helper_dir = join(dir_path, "HelperFiles")
if helper_dir not in sys.path:
    sys.path.append(helper_dir)
import helper
# Directory holding the pickled retrospective results, one file per
# "<method>_<dataset>".
retro_path = join(path_to_file, "..", "Results", "Retrospective")

# Nominal significance levels; the results table has one column per alpha for
# each method.
alphas = [0.05, 0.1, 0.2]

### Iterate through all datasets with both methods, at multiple alphas

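Results are arranged in the same layout as Table 1 of the paper. There is one row per dataset (census, bank, brca, breast_cancer, credit). Columns 1–3 are the SS method and columns 4–6 are KernelSHAP, each at α = 0.05, 0.1, 0.2. The first printed matrix holds the maximum FWERs and the second the average FWERs.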

In [9]:
datasets = ["census", "bank", "brca", "breast_cancer", "credit"]
methods = ['ss', 'kernelshap']
# Table layout: one row per dataset; one contiguous block of len(alphas) columns
# per method, in the order of `methods`. Pre-fill with NaN (not np.empty) so any
# cell that never gets written shows up as nan instead of uninitialized garbage.
block_width = len(alphas)
max_mat = np.full((len(datasets), len(methods) * block_width), np.nan)
avg_mat = np.full((len(datasets), len(methods) * block_width), np.nan)
for i, dataset in enumerate(datasets):
    for method_idx, method in enumerate(methods):
        # NOTE: pickle.load can execute arbitrary code -- only load result files
        # produced by this project's own retrospective runs.
        with open(join(retro_path, method + "_" + dataset), 'rb') as f:
            retro_results = pickle.load(f)
        shap_vals = retro_results["shap_vals"]
        N_verified_all = retro_results["N_verified"]
        # Unpacking also checks that N_verified is 3-D.
        N_pts, N_runs, N_alphas = N_verified_all.shape

        all_ranks = helper.shap_vals_to_ranks(shap_vals, abs=True)

        # Reference ranking per point, from SHAP values averaged over axis 1
        # (presumably the repeated runs -- TODO confirm). The comprehension uses
        # `pt` so it is not confused with the dataset row index `i`.
        avg_shap = np.mean(shap_vals, axis=1)
        avg_ranks = np.array([helper.get_ranking(avg_shap[pt], abs=True) for pt in range(N_pts)])

        fwers = helper.calc_all_retro_fwers(N_verified_all, all_ranks, avg_ranks)
        max_fwers = np.round(np.nanmax(fwers, axis=1), 3)
        avg_fwers = np.round(np.nanmean(fwers, axis=1), 3)
        # Each method must supply exactly one value per alpha; otherwise the slice
        # assignment below would fail or (for length 1) silently broadcast.
        assert max_fwers.shape == (block_width,), (
            f"{method}_{dataset}: expected {block_width} FWER values, got shape {max_fwers.shape}"
        )
        col_start_idx = method_idx * block_width
        max_mat[i, col_start_idx:col_start_idx + block_width] = max_fwers
        avg_mat[i, col_start_idx:col_start_idx + block_width] = avg_fwers
print(max_mat)
print(avg_mat)

[[0.02 0.02 0.1  0.04 0.04 0.08]
 [0.02 0.04 0.12 0.02 0.04 0.08]
 [0.02 0.08 0.1  0.02 0.02 0.1 ]
 [0.02 0.08 0.1  0.04 0.06 0.1 ]
 [0.02 0.08 0.1  0.02 0.02 0.06]]
[[0.004 0.008 0.021 0.003 0.005 0.026]
 [0.005 0.011 0.027 0.002 0.005 0.026]
 [0.004 0.012 0.026 0.001 0.005 0.017]
 [0.003 0.007 0.023 0.003 0.005 0.011]
 [0.001 0.011 0.033 0.001 0.002 0.012]]


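### Great: even in the worst case, FWER is controlled

Every maximum FWER (first matrix) is below its nominal $\alpha$ for every dataset, method, and $\alpha$. For example, the largest value in the $\alpha = 0.2$ columns is 0.12. The average FWERs (second matrix) are lower still.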