In [1]:
import pickle
import numpy as np
import os

import tensorflow as tf
from tensorflow.keras.layers import AveragePooling2D, UpSampling2D

In [2]:
# Select GPU: expose only device 0 to CUDA / TensorFlow.
# The CUDA runtime reads this variable when it initialises, so it only takes
# effect if no GPU work has happened yet in this kernel -- restart the kernel
# after changing it.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

# Shapley Value Masks

In [3]:
# Shapley-value "remove top-k%" masks -> <method_dir>/masks.pkl
#
# For each image, take the Shapley values of the model's predicted class and
# threshold them at that image's own p-th percentile. Entries strictly below the
# threshold are set to 1 and entries at or above it to 0, so the 0s mark the top
# (100 - p)% of features. masks[str(100 - p)] holds an int array covering all
# images, with the same shape as the selected Shapley values.

# Model predictions do not depend on the explanation method, so load them once.
# NOTE: allow_pickle=True (and pickle.load below) can execute arbitrary code;
# only point these at trusted, locally generated files.
preds = np.load('../images/predictions.npy', allow_pickle=True)

# Other methods handled below: 'kernelshap', 'kernelshap_plus', 'deepshap'.
for method in ['fastshap']:

    ################ Load ################

    ### Load Shap Values
    if method == 'deepshap':
        method_dir = os.path.join('../', method)
    elif method == 'fastshap':
        # method_dir = os.path.join('../', method, '20210519_16_09_16')  # eff lambda = 0
        method_dir = os.path.join('../', method, '20210807_23_47_44')  # eff lambda = 1.0
    else:
        method_dir = os.path.join('../', method, 'results')

    with open(os.path.join(method_dir, 'shap_values.pkl'), 'rb') as f:
        shap_values = pickle.load(f)

    ################ Select Shapley Values for the Predicted Class ################

    # shap_values is indexed [class][sample]; for each sample keep the map of the
    # class with the largest entry in its prediction vector.
    shap_values_select = np.array(
        [shap_values[yp.argmax()][i] for i, yp in enumerate(preds)]
    )

    if method == 'deepshap':
        # Sum for super pixel selections.
        # Assumes per-pixel values shaped (N, H, W, C) -- TODO confirm.
        shap_values_select = shap_values_select.sum(3)  # sum across channels
        shap_values_select = np.expand_dims(shap_values_select, -1)
        # Average pooling * window area == sum pooling over 16x16 patches.
        shap_values_select = AveragePooling2D(pool_size=(16, 16))(shap_values_select) * (16 * 16)
        # Repeat each patch sum back onto every pixel of that patch.
        shap_values_select = UpSampling2D(size=(16, 16))(shap_values_select).numpy()

    ################ Extract Selection Masks: remove top 1/5/10/15/25/50/75/85/90/95/99% ################

    shap_values_flat = shap_values_select.reshape(shap_values_select.shape[0], -1)
    # One threshold per image, shaped to broadcast over all non-batch axes.
    per_image = (-1,) + (1,) * (shap_values_select.ndim - 1)
    masks = {}
    for p in [99, 95, 90, 85, 75, 50, 25, 15, 10, 5, 1]:
        thresholds = np.percentile(shap_values_flat, p, axis=1).reshape(per_image)
        masks[str(100 - p)] = (shap_values_select < thresholds).astype(int)

    ################ Save ################
    with open(os.path.join(method_dir, 'masks.pkl'), 'wb') as f:
        pickle.dump(masks, f)

In [4]:
# Shapley-value "keep top-k%" masks -> <method_dir>/masks_in.pkl
#
# For each image, take the Shapley values of the model's predicted class and
# threshold them at that image's own p-th percentile. Entries at or above the
# threshold are set to 1 and all others to 0, so the 1s mark the top (100 - p)%
# of features (the inverse selection of masks.pkl). masks[str(100 - p)] holds an
# int array covering all images, with the same shape as the selected Shapley
# values.

# Model predictions do not depend on the explanation method, so load them once.
# NOTE: allow_pickle=True (and pickle.load below) can execute arbitrary code;
# only point these at trusted, locally generated files.
preds = np.load('../images/predictions.npy', allow_pickle=True)

# Other methods handled below: 'kernelshap', 'kernelshap_plus', 'deepshap'.
for method in ['fastshap']:

    ################ Load ################

    ### Load Shap Values
    if method == 'deepshap':
        method_dir = os.path.join('../', method)
    elif method == 'fastshap':
        # method_dir = os.path.join('../', method, '20210519_16_09_16')  # eff lambda = 0
        method_dir = os.path.join('../', method, '20210807_23_47_44')  # eff lambda = 1.0
    else:
        method_dir = os.path.join('../', method, 'results')

    with open(os.path.join(method_dir, 'shap_values.pkl'), 'rb') as f:
        shap_values = pickle.load(f)

    ################ Select Shapley Values for the Predicted Class ################

    # shap_values is indexed [class][sample]; for each sample keep the map of the
    # class with the largest entry in its prediction vector.
    shap_values_select = np.array(
        [shap_values[yp.argmax()][i] for i, yp in enumerate(preds)]
    )

    if method == 'deepshap':
        # Sum for super pixel selections.
        # Assumes per-pixel values shaped (N, H, W, C) -- TODO confirm.
        shap_values_select = shap_values_select.sum(3)  # sum across channels
        shap_values_select = np.expand_dims(shap_values_select, -1)
        # Average pooling * window area == sum pooling over 16x16 patches.
        shap_values_select = AveragePooling2D(pool_size=(16, 16))(shap_values_select) * (16 * 16)
        # Repeat each patch sum back onto every pixel of that patch.
        shap_values_select = UpSampling2D(size=(16, 16))(shap_values_select).numpy()

    ################ Extract Selection Masks: keep top 1/5/10/15/25/50/75/85/90/95/99% ################

    shap_values_flat = shap_values_select.reshape(shap_values_select.shape[0], -1)
    # One threshold per image, shaped to broadcast over all non-batch axes.
    per_image = (-1,) + (1,) * (shap_values_select.ndim - 1)
    masks = {}
    for p in [99, 95, 90, 85, 75, 50, 25, 15, 10, 5, 1]:
        thresholds = np.percentile(shap_values_flat, p, axis=1).reshape(per_image)
        masks[str(100 - p)] = (shap_values_select >= thresholds).astype(int)

    ################ Save ################
    with open(os.path.join(method_dir, 'masks_in.pkl'), 'wb') as f:
        pickle.dump(masks, f)

# Gradient Explanation Masks

In [10]:
# Explanation "remove top-k%" masks -> <method_dir>/masks.pkl
#
# For each image, threshold the explanation at that image's own p-th percentile.
# Entries strictly below the threshold are set to 1 and entries at or above it
# to 0, so the 0s mark the top (100 - p)% of features. masks[str(100 - p)] holds
# an int array covering all images, with the same shape as the (possibly
# patch-summed) explanations.

# Fixed seed for the tie-breaking noise below so the masks are reproducible.
# Use the same seed wherever masks are built from these explanations so that,
# e.g., "remove top-k%" and "keep top-k%" masks remain exact complements.
NOISE_SEED = 0

# Other methods handled below: 'gradcam', 'smoothgrad', 'integratedgradients'.
for method in ['cxplain']:
    ################ Load ################

    ### Load Explanations
    if method == 'cxplain':
        method_dir = os.path.join('../', method, 'results')
    else:
        method_dir = os.path.join('../', method)

    # NOTE: allow_pickle=True can execute arbitrary code; only load trusted files.
    explanations = np.load(os.path.join(method_dir, 'explanations.npy'), allow_pickle=True).astype('float32')

    ### Add Small Random Noise To Break Ties (e.g. Among Exact 0s) Randomly
    # A fresh generator per method makes the noise independent of which other
    # methods this loop processes first. At 1e-8 scale the noise is mostly
    # rounded away by float32 for values of magnitude around 1 or larger, so in
    # practice it only reorders (near-)zero entries.
    rng = np.random.default_rng(NOISE_SEED)
    explanations += rng.standard_normal(explanations.shape) * 1e-8

    ### Sum for Super Pixel Selections
    if method != 'cxplain':
        # Assumes per-pixel maps shaped (N, H, W); add a channel axis for Keras
        # 2-D pooling -- TODO confirm.
        explanations = np.expand_dims(explanations, -1)
        # Average pooling * window area == sum pooling over 16x16 patches.
        explanations = AveragePooling2D(pool_size=(16, 16))(explanations) * (16 * 16)
        # Repeat each patch sum back onto every pixel of that patch.
        explanations = UpSampling2D(size=(16, 16))(explanations).numpy()

    ################ Extract Selection Masks: remove top 1/5/10/15/25/50/75/85/90/95/99% ################

    explanations_flat = explanations.reshape(explanations.shape[0], -1)
    # One threshold per image, shaped to broadcast over all non-batch axes.
    per_image = (-1,) + (1,) * (explanations.ndim - 1)
    masks = {}
    for p in [99, 95, 90, 85, 75, 50, 25, 15, 10, 5, 1]:
        thresholds = np.percentile(explanations_flat, p, axis=1).reshape(per_image)
        masks[str(100 - p)] = (explanations < thresholds).astype(int)

    ################ Save ################
    with open(os.path.join(method_dir, 'masks.pkl'), 'wb') as f:
        pickle.dump(masks, f)

In [11]:
# Explanation "keep top-k%" masks -> <method_dir>/masks_in.pkl
#
# For each image, threshold the explanation at that image's own p-th percentile.
# Entries at or above the threshold are set to 1 and all others to 0, so the 1s
# mark the top (100 - p)% of features (the inverse selection of masks.pkl).
# masks[str(100 - p)] holds an int array covering all images, with the same
# shape as the (possibly patch-summed) explanations.

# Fixed seed for the tie-breaking noise below so the masks are reproducible.
# Use the same seed wherever masks are built from these explanations so that,
# e.g., "remove top-k%" and "keep top-k%" masks remain exact complements.
NOISE_SEED = 0

# Other methods handled below: 'gradcam', 'smoothgrad', 'integratedgradients'.
for method in ['cxplain']:
    ################ Load ################

    ### Load Explanations
    if method == 'cxplain':
        method_dir = os.path.join('../', method, 'results')
    else:
        method_dir = os.path.join('../', method)

    # NOTE: allow_pickle=True can execute arbitrary code; only load trusted files.
    explanations = np.load(os.path.join(method_dir, 'explanations.npy'), allow_pickle=True).astype('float32')

    ### Add Small Random Noise To Break Ties (e.g. Among Exact 0s) Randomly
    # A fresh generator per method makes the noise independent of which other
    # methods this loop processes first. At 1e-8 scale the noise is mostly
    # rounded away by float32 for values of magnitude around 1 or larger, so in
    # practice it only reorders (near-)zero entries.
    rng = np.random.default_rng(NOISE_SEED)
    explanations += rng.standard_normal(explanations.shape) * 1e-8

    ### Sum for Super Pixel Selections
    if method != 'cxplain':
        # Assumes per-pixel maps shaped (N, H, W); add a channel axis for Keras
        # 2-D pooling -- TODO confirm.
        explanations = np.expand_dims(explanations, -1)
        # Average pooling * window area == sum pooling over 16x16 patches.
        explanations = AveragePooling2D(pool_size=(16, 16))(explanations) * (16 * 16)
        # Repeat each patch sum back onto every pixel of that patch.
        explanations = UpSampling2D(size=(16, 16))(explanations).numpy()

    ################ Extract Selection Masks: keep top 1/5/10/15/25/50/75/85/90/95/99% ################

    explanations_flat = explanations.reshape(explanations.shape[0], -1)
    # One threshold per image, shaped to broadcast over all non-batch axes.
    per_image = (-1,) + (1,) * (explanations.ndim - 1)
    masks = {}
    for p in [99, 95, 90, 85, 75, 50, 25, 15, 10, 5, 1]:
        thresholds = np.percentile(explanations_flat, p, axis=1).reshape(per_image)
        masks[str(100 - p)] = (explanations >= thresholds).astype(int)

    ################ Save ################
    with open(os.path.join(method_dir, 'masks_in.pkl'), 'wb') as f:
        pickle.dump(masks, f)