In [80]:
import numpy as np
import matplotlib.pyplot as plt  # NOTE(review): not used in the cells shown here
import pickle
import pathlib
import sys
from os.path import join
# Path() with no arguments is "."; resolve() makes it absolute, so this is the
# kernel's current working directory (typically the notebook's folder), not a file path.
path_to_file = str(pathlib.Path().resolve())
# Repository root, assumed to sit two levels above the working directory -- TODO confirm.
dir_path = join(path_to_file, "../../")
sys.path.append(join(dir_path, "HelperFiles"))
# Star import; presumably supplies shap_vals_to_ranks and get_ranking used below
# (they are not defined in this notebook) -- verify against HelperFiles/helper.py.
from helper import *
results_path = join(dir_path, "Experiments", "Results", "Retrospective")

In [89]:
def calc_retro_fwer(GTranks, rankings, nVerified, alphaIdx, min_stable_frac=0.05):
    """Retrospective family-wise error rate (FWER) of the verified top ranks for one point.

    In run ``r`` the first ``nVerified[r, alphaIdx]`` entries of ``rankings[r]`` were
    verified. A run commits a family-wise error when any of those verified positions
    disagrees with the reference ranking ``GTranks``. The FWER is the fraction of runs
    that verified at least one rank and committed such an error.

    Args:
        GTranks: reference ranking (1-D sequence of feature indices).
        rankings: 2-D array; row r is the ranking produced by run r.
        nVerified: 2-D array of verified-rank counts, indexed [run, alphaIdx].
        alphaIdx: column of ``nVerified`` to evaluate.
        min_stable_frac: if no more than this fraction of runs verified any rank,
            the estimate is considered too sparse and None is returned
            (default 0.05, the value previously hard-coded).

    Returns:
        The FWER rounded to 3 decimals, or None when too few runs verified anything.
    """
    N_runs = rankings.shape[0]
    counts = nVerified[:, alphaIdx]
    nStable = np.sum(counts > 0)  # runs that verified at least one rank
    # Too few runs verified anything for the FWER to be meaningful.
    if nStable <= min_stable_frac * N_runs:
        return None
    n_correct = 0  # runs whose verified ranks all match the reference
    for runIdx in range(N_runs):
        nVerif = int(counts[runIdx])  # int() so float-valued counts still slice
        if nVerif > 0:
            n_correct += np.array_equal(rankings[runIdx, :nVerif], GTranks[:nVerif])
    fwer = 1 - n_correct / nStable
    return round(fwer, 3)

def calc_avg_fwers(all_fwers):
    """Mean FWER per alpha level, skipping points whose FWER is None.

    Args:
        all_fwers: 2-D array (alpha levels x points), as built by calc_all_fwers;
            None entries mark points that had too few verified runs.

    Returns:
        Array with one mean per row, rounded to 3 decimals. A row with no usable
        entries yields nan explicitly (instead of numpy's "Mean of empty slice" warning).
    """
    avg_fwers = []
    for row in all_fwers:
        valid = [f for f in row if f is not None]
        avg_fwers.append(np.mean(valid) if valid else np.nan)
    return np.round(avg_fwers, 3)

def calc_prop_controlled(all_fwers, alphas):
    """Fraction of points whose FWER is at or below the nominal level, per alpha.

    Args:
        all_fwers: 2-D array (alpha levels x points), as built by calc_all_fwers;
            None entries are skipped.
        alphas: nominal levels; alphas[i] is the bound for row i of all_fwers.

    Returns:
        Array with one proportion per row, rounded to 3 decimals. A row with no
        usable entries yields nan explicitly (no empty-slice warning).
    """
    prop_controlled = []
    for i, row in enumerate(all_fwers):
        alpha = alphas[i]  # indexed (not zipped) so a too-short alphas still errors
        controlled = [f <= alpha for f in row if f is not None]
        prop_controlled.append(np.mean(controlled) if controlled else np.nan)
    return np.round(prop_controlled, 3)

def calc_all_fwers(verif, ranks, avgRanks):
    """Retrospective FWER for every (alpha level, point) pair.

    Args:
        verif: 3-D array of verified-rank counts, unpacked as (points, runs, alphas).
        ranks: per-point run rankings; ranks[p] is handed to calc_retro_fwer.
        avgRanks: per-point reference rankings.

    Returns:
        np.array with one row per alpha level and one column per point. Entries are
        None wherever calc_retro_fwer declined to estimate.
    """
    n_points, _, n_alphas = verif.shape
    table = [
        [calc_retro_fwer(avgRanks[p], ranks[p], verif[p], a) for p in range(n_points)]
        for a in range(n_alphas)
    ]
    return np.array(table)
    
# Nominal FWER levels. calc_prop_controlled compares row i of a calc_all_fwers result
# (i.e. alpha index i of the *_N_verified arrays) against alphas[i], so this order must
# match the alpha axis of the saved results.
alphas = [0.05, 0.1, 0.2]

In [90]:
dataset = "bank"
ssTitle = "ss_" + dataset
kshapTitle = "kernelshap_" + dataset
# Verified-rank counts per (point, run, alpha index) for Shapley Sampling and KernelSHAP.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted result files.
with open(join(results_path, ssTitle+"_N_verified"), "rb") as fp:
    ssVerif = pickle.load(fp)
with open(join(results_path, kshapTitle+"_N_verified"), "rb") as fp:
    kshapVerif = pickle.load(fp)

# Per-run SHAP estimates, converted to per-run rankings by absolute value.
with open(join(results_path, ssTitle+"_shap_vals"), "rb") as fp:
    ssVals = pickle.load(fp)
    ssRanks = shap_vals_to_ranks(ssVals, abs=True)
with open(join(results_path, kshapTitle+"_shap_vals"), "rb") as fp:
    kshapVals = pickle.load(fp)
    kshapRanks = shap_vals_to_ranks(kshapVals, abs=True)
N_pts, N_runs, N_alphas = ssVerif.shape

# Median (over runs) number of verified ranks per point at alpha index 1
# (alphas[1] == 0.1, assuming the saved alpha axis follows the `alphas` list).
print("# Ranks Verified, SS:", np.median(ssVerif[:,:,1], axis=1))
print("KernelSHAP:", np.median(kshapVerif[:,:,1], axis=1))

# Ranks Verified, SS: [2. 2. 1. 0. 2. 2. 2. 0. 2. 0.]
KernelSHAP: [1. 0. 0. 0. 0. 0. 0. 0. 1. 0.]


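### KernelSHAP verifies fewer ranks than Shapley Sampling. Are its estimates actually less stable, or is the variance estimate / test just poorly calibrated?
- Its estimates are considerably less stable: the summed across-run variance of its SHAP values is roughly 4x that of Shapley Sampling (≈0.0093 vs. ≈0.0025, below).
- Open question: why is KernelSHAP so much noisier here?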

How different are its rankings?
- Reasonably similar: the top-ranked features largely agree with Shapley Sampling's.

In [91]:
# Total across-run variance of the SHAP estimates (axis 1 presumably indexes runs),
# summed over all remaining entries: Shapley Sampling first, then KernelSHAP.
for method_vals in (ssVals, kshapVals):
    print(np.sum(np.var(method_vals, axis=1)))
# Top-5 features from run 0 for the first three points, SS above KernelSHAP.
for i in range(3):
    for method_ranks in (ssRanks, kshapRanks):
        print(method_ranks[i, 0, :5])
    print("#" * 10)

0.002466453691209922
0.009328117455284862
[11  8 15  6  9]
[11  8 15  6  4]
##########
[11  8 15  9  6]
[11  8  6 15  4]
##########
[ 8  6  3 15  9]
[ 8 15  6  3 14]
##########


In [92]:
# Reference ("ground truth") ranking per point: rank the across-run mean SHAP values.
# NOTE(review): per-run ranks come from shap_vals_to_ranks(..., abs=True); confirm that
# get_ranking also ranks by |value| so reference and per-run rankings are comparable.
avgSS, avgKshap = np.mean(ssVals, axis=1), np.mean(kshapVals, axis=1)
# Iterate over each method's own points rather than reusing N_pts from another cell.
avgSSRanks = np.array([get_ranking(pt_vals) for pt_vals in avgSS])
avgKshapRanks = np.array([get_ranking(pt_vals) for pt_vals in avgKshap])

ssFwers_all = calc_all_fwers(ssVerif, ssRanks, avgSSRanks)
kshapFwers_all = calc_all_fwers(kshapVerif, kshapRanks, avgKshapRanks)

# Mean FWER per alpha level (SS, then KernelSHAP) ...
print(calc_avg_fwers(ssFwers_all))
print(calc_avg_fwers(kshapFwers_all))

# ... and the fraction of points whose FWER is <= the corresponding alpha.
print(calc_prop_controlled(ssFwers_all, alphas))
print(calc_prop_controlled(kshapFwers_all, alphas))


[0.002 0.011 0.028]
[0. 0. 0.]
[1. 1. 1.]
[1. 1. 1.]


Results look good on every dataset: bank, brca, breast_cancer, census, and credit.
- All good, but the results are very sparse (few ranks get verified).

Next, summarize all datasets in a single table.

In [93]:
datasets = ["census", "bank", "brca", "breast_cancer", "credit"]


def load_retro_results(title):
    """Load one method/dataset's pickled retrospective results from results_path.

    Reads ``<title>_N_verified`` and ``<title>_shap_vals`` and returns
    (verified-rank counts, SHAP values, per-run rankings by absolute value).
    NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    """
    with open(join(results_path, title + "_N_verified"), "rb") as fp:
        verif = pickle.load(fp)
    with open(join(results_path, title + "_shap_vals"), "rb") as fp:
        vals = pickle.load(fp)
    return verif, vals, shap_vals_to_ranks(vals, abs=True)


props_controlled_ss, props_controlled_kshap = [], []
fwers_ss, fwers_kshap = [], []
for dataset in datasets:
    ssTitle = "ss_" + dataset
    kshapTitle = "kernelshap_" + dataset
    ssVerif, ssVals, ssRanks = load_retro_results(ssTitle)
    kshapVerif, kshapVals, kshapRanks = load_retro_results(kshapTitle)
    N_pts, N_runs, N_alphas = ssVerif.shape

    # Reference ranking per point from the across-run mean; each method ranks its own points.
    avgSS, avgKshap = np.mean(ssVals, axis=1), np.mean(kshapVals, axis=1)
    avgSSRanks = np.array([get_ranking(pt_vals) for pt_vals in avgSS])
    avgKshapRanks = np.array([get_ranking(pt_vals) for pt_vals in avgKshap])

    ssFwers_all = calc_all_fwers(ssVerif, ssRanks, avgSSRanks)
    kshapFwers_all = calc_all_fwers(kshapVerif, kshapRanks, avgKshapRanks)

    fwers_ss.append(calc_avg_fwers(ssFwers_all))
    fwers_kshap.append(calc_avg_fwers(kshapFwers_all))

    props_controlled_ss.append(calc_prop_controlled(ssFwers_all, alphas))
    props_controlled_kshap.append(calc_prop_controlled(kshapFwers_all, alphas))

# Both arrays are indexed [method (0 = SS, 1 = KernelSHAP), dataset, alpha level].
fwers = np.array([fwers_ss, fwers_kshap])
props_controlled = np.array([props_controlled_ss, props_controlled_kshap])
print(np.round(fwers*100))
print("#"*10)
print(props_controlled)


[[[ 0.  1.  2.]
  [ 0.  1.  3.]
  [ 1. 10. 16.]
  [ 2.  2.  4.]
  [ 1.  3.  6.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]]
##########
[[[1.  1.  1. ]
  [1.  1.  1. ]
  [1.  0.6 0.6]
  [0.9 1.  1. ]
  [0.9 0.9 1. ]]

 [[1.  1.  1. ]
  [1.  1.  1. ]
  [1.  1.  1. ]
  [1.  1.  1. ]
  [1.  1.  1. ]]]
