In [None]:
!git clone https://github.com/AI4LIFE-GROUP/OpenXAI.git
!cd OpenXAI/ && pip install -e .

In [None]:
# !pip install git+https://github.com/AI4LIFE-GROUP/OpenXAI

In [None]:
# Utils
import torch
import numpy as np
import pickle

# ML models
from openxai.LoadModel import LoadModel

# Data loaders
from openxai.dataloader import return_loaders

# Explanation models
from openxai.Explainer import Explainer

# Evaluation methods
from openxai.evaluator import Evaluator

# Perturbation methods required for the computation of the relative stability metrics
from openxai.explainers.catalog.perturbation_methods import NormalPerturbation
from openxai.explainers.catalog.perturbation_methods import NewDiscrete_NormalPerturbation

In [None]:
# COMPAS. The dataset contains criminal-record and demographic features for 18,876 defendants who
# were released on bail in U.S. state courts during 1990-2009. The task is to classify defendants as
# bail (i.e., unlikely to commit a violent crime if released) vs. no bail (i.e., likely to commit a violent
# crime) [36].

In [None]:
# Choose the model and the data set you wish to generate explanations for
SUPPORTED_DATA_NAMES = ('heloc', 'adult', 'german', 'compas', 'synthetic')
SUPPORTED_MODEL_NAMES = ('lr', 'ann')

data_loader_batch_size = 32
data_name = 'compas'  # must be one of SUPPORTED_DATA_NAMES
model_name = 'lr'     # must be one of SUPPORTED_MODEL_NAMES

# Fail fast here instead of with a NameError / loader error several cells later
assert data_name in SUPPORTED_DATA_NAMES, f"unknown data_name: {data_name!r}"
assert model_name in SUPPORTED_MODEL_NAMES, f"unknown model_name: {model_name!r}"

### (1) Data Loaders

In [None]:
# Get training and test loaders
loader_train, loader_test = return_loaders(data_name=data_name,
                                           download=True,
                                           batch_size=data_loader_batch_size)
# Take the first test batch. Use the builtin next(): DataLoader iterators in
# recent PyTorch releases no longer provide a `.next()` method.
data_iter = iter(loader_test)
inputs, labels = next(data_iter)
labels = labels.type(torch.int64)

In [None]:
# get full training data set
# Full training split as a float tensor; its first rows serve as background data for the explainers
data_all = torch.FloatTensor(loader_train.dataset.data)

In [None]:
# Display (rows, features) of the training matrix
data_all.shape

torch.Size([4937, 7])

### (2) Load a pretrained ML model

In [None]:
# Load the pretrained ML model for the chosen data set and model type
model = LoadModel(
    data_name=data_name,
    ml_model=model_name,
    pretrained=True,
)

### (3) Choose an explanation method

### (4) Choose an evaluation metric

In [None]:
def generate_mask(explanation, top_k):
    """Return a boolean mask marking the ``top_k`` largest entries of ``explanation``.

    Ranking uses the signed attribution values (``torch.topk`` on the raw
    values), not their magnitudes.

    :param explanation: attribution tensor; callers pass a flattened 1-D vector.
    :param top_k: number of entries to mark. Values larger than the number of
        entries are clamped, so every entry is marked instead of raising.
    :return: bool tensor with the same shape as ``explanation``.
    """
    k = min(top_k, explanation.numel())
    mask = torch.zeros(explanation.shape, dtype=torch.bool)
    mask[torch.topk(explanation, k).indices] = True
    return mask

In [None]:
# Perturbation class parameters
perturbation_mean = 0.0
perturbation_std = 0.10
perturbation_flip_percentage = 0.03

# Per-feature type codes passed to the evaluator as `feature_metadata`
# (presumably 'c' = continuous, 'd' = discrete — confirm against OpenXAI)
if data_name == 'compas':
    feature_types = ['c', 'd', 'c', 'c', 'd', 'd', 'd']
elif data_name == 'adult':
    feature_types = ['c'] * 6 + ['d'] * 7
elif data_name == 'synthetic':
    # Gaussian synthetic data
    feature_types = ['c'] * 20
elif data_name == 'heloc':
    feature_types = ['c'] * 23
elif data_name == 'german':
    # NOTE(review): pickle.load can execute arbitrary code; only load this
    # metadata file from a trusted OpenXAI download.
    with open('./data/German_Credit_Data/german-feature-metadata.p', 'rb') as metadata_file:
        feature_types = pickle.load(metadata_file)
else:
    # Without this, `feature_types` would be undefined and fail much later
    raise ValueError(f"no feature types defined for data_name={data_name!r}")

In [None]:
# Perturbation methods: the German data set needs the special
# NewDiscrete_NormalPerturbation class; all others use NormalPerturbation.
perturbation = (NewDiscrete_NormalPerturbation if data_name == 'german' else NormalPerturbation)(
    "tabular",
    mean=perturbation_mean,
    std_dev=perturbation_std,
    flip_percentage=perturbation_flip_percentage,
)

In [None]:
def get_evaluator(model, inputs, labels, explanation, explainer, index=0,
                  top_k=3, perturb_max_distance=0.4, p_norm=2):
  """Build an OpenXAI ``Evaluator`` for the ``index``-th sample of a batch.

  Reads the module-level ``perturbation`` and ``feature_types`` objects
  (defined in earlier cells) for the perturbation-related entries.

  :param model: trained model; called on row ``index`` to get the predicted class.
  :param inputs: batch of inputs; row ``index`` is the sample being evaluated.
  :param labels: true labels aligned with ``inputs``.
  :param explanation: attributions for the batch; row ``index`` is used.
  :param explainer: explainer object forwarded to the evaluator.
  :param index: which row of the batch to evaluate.
  :param top_k: number of top features used for the mask and the ``top_k`` entry.
  :param perturb_max_distance: value stored under ``'perturb_max_distance'``.
  :param p_norm: value stored under ``'p_norm'``.
  :return: the configured ``Evaluator``.
  """
  input_dict = dict()

  # inputs and models
  input_dict['x'] = inputs[index].reshape(-1)
  input_dict['input_data'] = inputs
  input_dict['explainer'] = explainer
  input_dict['explanation_x'] = explanation[index,:].flatten()
  input_dict['model'] = model

  # perturbation method used for the stability metric (same object under both keys)
  input_dict['perturbation'] = perturbation
  input_dict['perturb_method'] = perturbation
  input_dict['perturb_max_distance'] = perturb_max_distance
  input_dict['feature_metadata'] = feature_types
  input_dict['p_norm'] = p_norm
  input_dict['eval_metric'] = None

  # true label, predicted label, and masks
  input_dict['top_k'] = top_k
  input_dict['y'] = labels[index].detach().item()
  # Prediction only: no autograd graph is needed for this forward pass
  with torch.no_grad():
    model_output = model(inputs[index].unsqueeze(0).float())
  input_dict['y_pred'] = torch.max(model_output, 1).indices.item()
  input_dict['mask'] = generate_mask(input_dict['explanation_x'].reshape(-1), top_k)

  # required for the representation stability measure
  input_dict['L_map'] = model

  evaluator = Evaluator(input_dict,
                        inputs=inputs,
                        labels=labels,
                        model=model,
                        explainer=explainer)
  return evaluator

In [None]:
from collections import defaultdict

def evaluate_on_inputs_batch(explanation, explainer, model, inputs, labels,
                             metrics=('RC', 'FA', 'RA', 'SA', 'SRA', 'PGU', 'PGI')):
  """Evaluate every sample of a batch on a set of OpenXAI metrics.

  Default metrics: RC rank correlation, FA feature agreement, RA rank agreement,
  SA sign agreement, SRA signed rank agreement, PGU prediction gap on unimportant
  features, PGI prediction gap on important features.

  :param explanation: attributions for the whole batch (one row per sample).
  :param explainer: explainer that produced ``explanation``.
  :param model: model being explained.
  :param inputs: batch of inputs; one evaluator is built per row.
  :param labels: true labels aligned with ``inputs``.
  :param metrics: metric names passed to ``Evaluator.evaluate(metric=...)``, in order.
  :return: ``defaultdict(list)`` keyed by ``'<metric>:'`` (note the trailing colon),
      holding one value per sample.
  """
  metrics_dict = defaultdict(list)

  for index in range(inputs.shape[0]):
    evaluator = get_evaluator(model, inputs, labels, explanation, explainer, index=index)
    for metric in metrics:
      metric_val = evaluator.evaluate(metric=metric)
      # Some metrics return a tuple; the reported value is its second element
      if isinstance(metric_val, tuple):
        metric_val = metric_val[1]
      metrics_dict[f'{metric}:'].append(metric_val)

  return metrics_dict

In [None]:
import scipy.special
import numpy as np
import itertools
from tqdm import tqdm


def powerset(iterable):
    """Lazily yield every subset of ``iterable`` as a tuple, smallest subsets first."""
    items = list(iterable)
    subsets_by_size = (itertools.combinations(items, size) for size in range(len(items) + 1))
    return itertools.chain.from_iterable(subsets_by_size)


def shapley_kernel(M, s):
    """Kernel SHAP weight of a coalition of size ``s`` among ``M`` features.

    The empty and the full coalition have infinite weight in theory; the
    constant 10000 stands in for it so the weighted regression stays finite.
    """
    if s not in (0, M):
        return (M - 1) / (scipy.special.binom(M, s) * s * (M - s))
    return 10000  # finite stand-in for an infinite weight


def get_weights(X_data, s, x, sigma=0.4):
    """Gaussian-kernel weights of each row of ``X_data`` relative to ``x`` on columns ``s``.

    Uses a Mahalanobis-type distance on the columns ``s`` (pseudo-inverse of
    their sample covariance), divided by ``len(s)``:

        D_j = (x_s - X_js)^T pinv(cov_s) (x_s - X_js) / len(s)
        w_j = exp(-D_j / (2 * sigma**2))

    :param X_data: 2-D array, one row per reference example.
    :param s: non-empty list of column indices (the conditioning coalition).
    :param x: 1-D array, the instance being explained.
    :param sigma: kernel bandwidth.
    :return: 1-D array of length ``X_data.shape[0]`` with weights in (0, 1].
    """
    diffs = x[s] - X_data[:, s]
    # atleast_2d: np.cov returns a 0-d array for a single column, which pinv rejects
    cov_inv = np.linalg.pinv(np.atleast_2d(np.cov(X_data[:, s], rowvar=False)))
    sq_dist = np.einsum("ji,ji->j", diffs @ cov_inv, diffs) / len(s)
    # Round-off can make the quadratic form slightly negative; a sqrt of that
    # would produce NaN weights, so clip at zero instead.
    sq_dist = np.maximum(sq_dist, 0.0)
    return np.exp(-sq_dist / (2 * sigma**2))


def _as_flat_float_array(values):
    """Convert model output (numpy array or torch tensor) to a flat float64 numpy array."""
    if hasattr(values, "detach"):  # torch tensor, possibly attached to the autograd graph
        values = values.detach().cpu().numpy()
    return np.asarray(values, dtype=float).reshape(-1)


def get_weighted_mean(w_s, s, f, x, reference, top_k_weights):
    """Weighted average of ``f`` over reference rows whose columns ``s`` are set to ``x[s]``.

    :param w_s: 1-D array of non-negative weights, one per row of ``reference``.
    :param s: list of column indices whose values are taken from ``x``.
    :param f: model function. It is called ONCE on a 2-D batch of rows and must
        return one value per row (numpy array or torch tensor).
    :param x: 1-D array, the instance being explained.
    :param reference: 2-D array of reference rows.
    :param top_k_weights: if truthy, only the ``top_k_weights`` rows with the largest
        weights are used; otherwise all rows are used.
    :return: the weighted mean as a Python float. If every selected weight is zero
        (e.g. all weights underflowed), the plain mean over the selected rows is
        returned instead of NaN.
    """
    w_s = np.asarray(w_s, dtype=float)
    if top_k_weights:
        # Descending-weight order; only needed when truncating to the top rows
        selected = np.argsort(w_s)[::-1][:top_k_weights]
    else:
        selected = np.arange(len(w_s))

    # Fancy indexing returns a copy, so the reference data is never modified
    x_eval = np.asarray(reference)[selected]
    x_eval[:, s] = x[s]
    # One batched model call instead of one call per reference row
    preds = _as_flat_float_array(f(x_eval))

    weights = w_s[selected]
    w_sum = weights.sum()
    if w_sum == 0:
        return float(preds.mean())
    return float(np.dot(weights, preds) / w_sum)


def kernel_shapr(f, x, reference, M, sigma, top_k_weights=None):
    """Estimate Shapley values of ``f`` at ``x`` via KernelSHAP regression over all 2**M coalitions.

    The value of each coalition is a kernel-weighted average of ``f`` over
    ``reference``. Coalitions with more than one feature use ``get_weights``;
    empty and single-feature coalitions weight all reference rows equally.

    :param f: model function passed to ``get_weighted_mean``.
    :param x: 1-D array, the instance being explained.
    :param reference: 2-D array of reference rows.
    :param M: number of features.
    :param sigma: kernel bandwidth for ``get_weights``.
    :param top_k_weights: forwarded to ``get_weighted_mean``.
    :return: array of length ``M + 1``: the ``M`` per-feature values followed by the intercept.
    """
    n_coalitions = 2**M
    X = np.zeros((n_coalitions, M + 1))
    X[:, -1] = 1  # intercept column
    weights = np.zeros(n_coalitions)
    y = np.zeros(n_coalitions)

    for i, s in enumerate(powerset(range(M))):
        s = list(s)
        w_s = np.ones(len(reference))
        if len(s) > 1:
            w_s = get_weights(reference, s, x, sigma)
        X[i, s] = 1
        weights[i] = shapley_kernel(M, len(s))
        y[i] = get_weighted_mean(w_s, s, f, x, reference, top_k_weights=top_k_weights)

    # Weighted least squares: solve (X^T W X) phi = X^T W y. Scaling the columns
    # of X^T avoids materialising the dense 2**M x 2**M diagonal matrix W.
    XtW = X.T * weights
    return np.linalg.solve(np.dot(XtW, X), np.dot(XtW, y))

def brute_kernel_shap(f, x, reference, M):
    """Exact KernelSHAP against a single baseline: absent features take values from ``reference``.

    :param f: model function; called once on a ``(2**M, M)`` batch, one output per row.
    :param x: 1-D input vector of length ``M``.
    :param reference: baseline values for features outside a coalition
        (broadcast into every row, e.g. a length-``M`` vector).
    :param M: number of features.
    :return: array of length ``M + 1``: per-feature values followed by the intercept.
    """
    n_coalitions = 2**M
    X = np.zeros((n_coalitions, M + 1))
    X[:, -1] = 1  # intercept column
    weights = np.zeros(n_coalitions)
    # One row per coalition, initialised to the baseline
    V = np.empty((n_coalitions, M))
    V[:] = reference

    for i, s in enumerate(powerset(range(M))):
        s = list(s)
        V[i, s] = x[s]
        X[i, s] = 1
        weights[i] = shapley_kernel(M, len(s))
    y = f(V)

    # Weighted least squares without building the dense 2**M x 2**M diagonal matrix
    XtW = X.T * weights
    return np.linalg.solve(np.dot(XtW, X), np.dot(XtW, y))


class ShapR:
    """Kernel-weighted conditional KernelSHAP explainer over a fixed reference data set.

    :param f: model function mapping a 2-D batch of rows to one output per row.
    :param X: 2-D reference array whose rows are used as conditioning data.
    :keyword sigma: kernel bandwidth for the conditional weights (default 0.4).
    :keyword top_k_weights: if set, only the highest-weighted reference rows are
        used per coalition (default: all rows).
    """

    def __init__(self, f, X, **kwargs):
        self.f = f
        self.X = X
        self.M = X.shape[1]
        self.sigma = kwargs.get("sigma", 0.4)
        self.top_k_weights = kwargs.get("top_k_weights", None)

    def explain(self, x):
        """Return per-feature Shapley values, shape ``(x.shape[0], M)``, for each row of ``x``.

        The per-row intercepts are stored in ``self.expected_values``.
        """
        phi = np.zeros((x.shape[0], self.M + 1))
        # Iterate the rows directly (not `enumerate`) so tqdm knows the total;
        # `row` avoids shadowing the `x` parameter.
        for idx, row in enumerate(tqdm(x, total=x.shape[0])):
            phi[idx] = kernel_shapr(self.f, row, self.X, self.M, self.sigma, self.top_k_weights)
        self.expected_values = phi[:, -1]
        return phi[:, :-1]


In [None]:
from openxai.explainers.api import Explainer

In [None]:
import xgboost
import shap
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F

class SHAPRExplainer(Explainer):

    '''
    OpenXAI ``Explainer`` wrapper around the ``ShapR`` conditional KernelSHAP estimator.

    param: model: torch model; column 1 of ``model(batch)`` is the explained output
    param: data: torch.FloatTensor of reference rows (converted with ``.numpy()``)
    param: domain: stored on the instance, otherwise unused
    param: link, function_class, feature_perturbation: accepted for signature
        compatibility with ``SHAPExplainer``; unused here
    param: top_k_weights: if set, only the top-k kernel-weighted reference rows
        are used per coalition
    param: sigma: kernel bandwidth for the conditional weights (default 0.4)
    '''

    def __init__(self, model, data: torch.FloatTensor, domain: str = 'non_deep', link: str = 'identity',
                 function_class: str = 'non_tree', feature_perturbation: str = 'interventional',
                 top_k_weights=None, sigma: float = 0.4):
        super().__init__(model)
        self.data = data.numpy()
        self.domain = domain

        def predict_fn(x):
            # Column 1 of the model output (presumably the positive-class score — confirm).
            # Detach so downstream numpy code never receives a tensor that requires grad.
            return self.model(torch.FloatTensor(x))[:, 1].detach().cpu().numpy()

        self.shapr = ShapR(predict_fn, self.data, top_k_weights=top_k_weights, sigma=sigma)


    def get_explanation(self, data_x: torch.FloatTensor, label = None) -> torch.FloatTensor:
        '''
        Returns ShapR values as the explanation of the decision made for the input data (data_x)
        :param data_x: data samples to explain decision for
        :param label: ignored
        :return: ShapR values [dim (shap_vals) == dim (data_x)]
        '''
        shap_vals = self.shapr.explain(data_x.numpy())
        return torch.FloatTensor(shap_vals)


In [None]:
import xgboost
import shap
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from shap import KernelExplainer
from shap import DeepExplainer

class SHAPExplainer(Explainer):
    '''
    OpenXAI ``Explainer`` wrapper around ``shap.KernelExplainer``.

    param: model: torch model; column 1 of ``model(batch)`` is the explained output
    param: data: torch.FloatTensor of background rows (converted with ``.numpy()``)
    param: domain: stored on the instance, otherwise unused
    param: link: str, 'identity' or 'logit', forwarded to ``shap.KernelExplainer``
    param: function_class, feature_perturbation: accepted but unused
    '''

    def __init__(self, model, data: torch.FloatTensor, domain: str = 'non_deep', link: str = 'identity',
                 function_class: str = 'non_tree', feature_perturbation: str = 'interventional'):
        super().__init__(model)
        self.data = data.numpy()
        self.domain = domain

        def predict_fn(batch):
            scores = model(torch.FloatTensor(batch))
            return scores[:, 1].detach().numpy()

        self.explainer = shap.KernelExplainer(predict_fn, self.data, link=link)

    def get_explanation(self, data_x: torch.FloatTensor, label = None) -> torch.FloatTensor:
        '''
        Returns SHAP values explaining the model's decisions on ``data_x``.
        :param data_x: data samples to explain decision for
        :param label: ignored
        :return: SHAP values [dim (shap_vals) == dim (data_x)]
        '''
        values = self.explainer.shap_values(data_x.numpy())
        return torch.FloatTensor(values)


In [None]:
from openxai.explainers import RandomBaseline
from functools import partial

N_TRAINING_EXAMPLES = 128

# Background / reference rows shared by all explainers
explainer_training_data = data_all[:N_TRAINING_EXAMPLES]
validation_inputs, validation_labels = inputs, labels

# Explainer factories, evaluated in insertion order
explainers = dict(
    random=RandomBaseline,
    shap=SHAPExplainer,
    shapr=SHAPRExplainer,
    shapr_top5=partial(SHAPRExplainer, top_k_weights=5),
    shapr_top10=partial(SHAPRExplainer, top_k_weights=10),
)

In [None]:
data_iter = iter(loader_test)

N_EVAL_BATCHES = 5
SEED = 0

# Seed torch and numpy so the random baseline (and any loader shuffling) is
# reproducible across Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

metrics_per_explainer = defaultdict(list)
for validation_inputs, validation_labels in itertools.islice(data_iter, N_EVAL_BATCHES):
  for explainer_name, explainer_factory in explainers.items():
    print(f'Evaluating explainer: {explainer_name}')

    # The random baseline takes no background data; all other explainers do
    if explainer_name != 'random':
      explainer = explainer_factory(model=model, data=explainer_training_data)
    else:
      explainer = explainer_factory(model=model)
    explanation = explainer.get_explanation(validation_inputs.float(), label=validation_labels)

    metrics = evaluate_on_inputs_batch(explanation, explainer, model, validation_inputs, validation_labels)
    metrics_per_explainer[explainer_name].append(metrics)

# Pool the per-sample metric values across batches, then average them
for explainer_name in explainers:
  cumulative_metrics = defaultdict(list)
  for metrics_batch in metrics_per_explainer[explainer_name]:
    for metric_name, metric_vals in metrics_batch.items():
      cumulative_metrics[metric_name].extend(metric_vals)
  res = {k: np.mean(v) for k, v in cumulative_metrics.items()}
  print(f'{explainer_name} Results: ', res)

Evaluating explainer: random




Evaluating explainer: shap


  0%|          | 0/32 [00:00<?, ?it/s]

Evaluating explainer: shapr


32it [00:44,  1.38s/it]


Evaluating explainer: shapr_top5


32it [00:04,  7.03it/s]


Evaluating explainer: shapr_top10


32it [00:06,  4.69it/s]


Evaluating explainer: random




Evaluating explainer: shap


  0%|          | 0/32 [00:00<?, ?it/s]

Evaluating explainer: shapr


32it [00:43,  1.35s/it]


Evaluating explainer: shapr_top5


32it [00:04,  7.00it/s]


Evaluating explainer: shapr_top10


32it [00:06,  5.23it/s]


Evaluating explainer: random




Evaluating explainer: shap


  0%|          | 0/32 [00:00<?, ?it/s]

Evaluating explainer: shapr


32it [00:43,  1.36s/it]


Evaluating explainer: shapr_top5


32it [00:04,  6.98it/s]


Evaluating explainer: shapr_top10


32it [00:06,  5.13it/s]


Evaluating explainer: random




Evaluating explainer: shap


  0%|          | 0/32 [00:00<?, ?it/s]

Evaluating explainer: shapr


32it [00:46,  1.44s/it]


Evaluating explainer: shapr_top5


32it [00:04,  6.94it/s]


Evaluating explainer: shapr_top10


32it [00:06,  5.20it/s]


Evaluating explainer: random




Evaluating explainer: shap


  0%|          | 0/32 [00:00<?, ?it/s]

Evaluating explainer: shapr


32it [00:43,  1.35s/it]


Evaluating explainer: shapr_top5


32it [00:04,  6.97it/s]


Evaluating explainer: shapr_top10


32it [00:06,  5.13it/s]


random Results:  {'RC:': 0.05223214285714285, 'FA:': 0.45833333333333337, 'RA:': 0.15208333333333332, 'SA:': 0.2333333333333333, 'SRA:': 0.07499999999999998, 'PGU:': 0.04975014, 'PGI:': 0.05772712}
shap Results:  {'RC:': 0.6200892857142857, 'FA:': 0.5541666666666666, 'RA:': 0.3333333333333333, 'SA:': 0.3, 'SRA:': 0.18541666666666665, 'PGU:': 0.054703783, 'PGI:': 0.0582888}
shapr Results:  {'RC:': 0.5234375, 'FA:': 0.50625, 'RA:': 0.2666666666666667, 'SA:': 0.3, 'SRA:': 0.18124999999999997, 'PGU:': 0.055929255, 'PGI:': 0.05762993}
shapr_top5 Results:  {'RC:': 0.5254464285714285, 'FA:': 0.51875, 'RA:': 0.25416666666666665, 'SA:': 0.3083333333333333, 'SRA:': 0.1833333333333333, 'PGU:': 0.054112118, 'PGI:': 0.058240663}
shapr_top10 Results:  {'RC:': 0.5462053571428571, 'FA:': 0.5333333333333333, 'RA:': 0.3, 'SA:': 0.2895833333333333, 'SRA:': 0.18541666666666665, 'PGU:': 0.054317106, 'PGI:': 0.057669174}
