# Wasserstein Globalness Example

This notebook loads sample SmoothGrad explanations for CIFAR10 images and computes the Wasserstein Globalness of each set of explanations.

```python
import os
import sys
import pickle

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

# Make the repository root importable so that `utils` can be found.
sys.path.append('../')


def load_dict(path):
    '''Load a pickled dictionary from `path`.'''
    with open(path, 'rb') as handle:
        dictionary = pickle.load(handle)
    return dictionary


def merge_dict(dict_list):
    '''
    Merge a list of dictionaries into a single dictionary.

    Each key maps to a list that collects the values of that key
    across all input dictionaries, in order.
    '''
    output = {}
    for d in dict_list:
        for k, v in d.items():
            output.setdefault(k, []).append(v)
    return output
```
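To illustrate how `merge_dict` combines its inputs, here is a toy example with made-up dictionaries (`a` and `b` are hypothetical, not notebook data). Every key maps to a list, so a key that appears in several inputs collects all of its values, and a key that appears only once still ends up wrapped in a one-element list.

```python
def merge_dict(dict_list):
    '''Merge a list of dicts; each key maps to the list of its values, in order.'''
    output = {}
    for d in dict_list:
        for k, v in d.items():
            output.setdefault(k, []).append(v)
    return output


# Toy inputs (hypothetical, for illustration only).
a = {1.0: 'exp_a', 'shared': 1}
b = {'labels': 'lab_b', 'shared': 2}

merged = merge_dict([a, b])
print(merged)
# {1.0: ['exp_a'], 'shared': [1, 2], 'labels': ['lab_b']}
```

This list wrapping is why the globalness cell indexes `output_dict[std][0]`: each sigma appears in only one input dictionary, but its value is still stored inside a list.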

## Load saved CIFAR10 samples

```python
samples = load_dict('./samples/cifar10_samples_0.pkl')
```
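The samples file is read with `load_dict`, which is plain `pickle`. The sketch below does a save/load round trip on a toy dictionary that uses the same two keys printed later (`'test_images'` and `'test_labels'`). The array shapes, the (N, C, H, W) layout, and the file name `toy_samples.pkl` are assumptions for illustration; the real file's contents are not shown in this notebook.

```python
import os
import pickle
import tempfile

import numpy as np


def load_dict(path):
    '''Load a pickled dictionary from `path`.'''
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Hypothetical stand-in for cifar10_samples_0.pkl: 4 random 32x32 RGB images.
# The (N, C, H, W) layout is an assumption, not read from the real file.
toy = {
    'test_images': np.random.rand(4, 3, 32, 32).astype(np.float32),
    'test_labels': np.array([3, 8, 8, 0]),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'toy_samples.pkl')
    with open(path, 'wb') as handle:
        pickle.dump(toy, handle, protocol=pickle.HIGHEST_PROTOCOL)
    loaded = load_dict(path)

print(sorted(loaded), loaded['test_images'].shape)
# ['test_images', 'test_labels'] (4, 3, 32, 32)
```

Note that `pickle.load` can execute arbitrary code, so only load pickle files from a source you trust.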

## Load saved sample explanations

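Sample explanations are generated by applying SmoothGrad to a ResNet-18 model. We compute explanations for $\sigma \in \{100, 10, 1, 0.1, 0.01, 0\}$, where $\sigma$ is SmoothGrad's smoothing parameter: the standard deviation of the Gaussian noise added to the input before the gradients are averaged. With $\sigma = 0$, SmoothGrad reduces to the plain gradient.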
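For readers unfamiliar with SmoothGrad, the sketch below averages a gradient over `n_samples` noisy copies of the input, with noise drawn from $\mathcal{N}(0, \sigma^2)$. The "model" is a hypothetical toy function $f(x) = \sum_i \sin(x_i)$ with analytic gradient $\cos(x)$, standing in for the ResNet-18; `toy_grad`, `smoothgrad`, and `n_samples` are illustrative names, not the code that produced the saved explanations.

```python
import numpy as np


def toy_grad(x):
    '''Gradient of the hypothetical toy model f(x) = sum(sin(x)).'''
    return np.cos(x)


def smoothgrad(grad_fn, x, sigma, n_samples=50, rng=None):
    '''Average grad_fn over n_samples copies of x perturbed by N(0, sigma^2) noise.'''
    rng = np.random.default_rng(0) if rng is None else rng
    noise = rng.normal(0.0, sigma, size=(n_samples,) + x.shape)
    grads = np.stack([grad_fn(x + eps) for eps in noise])
    return grads.mean(axis=0)


x = np.linspace(-1.0, 1.0, 5)
for sigma in [100, 10, 1, 0.1, 0.01, 0]:
    print(sigma, np.round(smoothgrad(toy_grad, x, sigma), 3))
```

With $\sigma = 0$ every noisy copy equals $x$, so the result is exactly $\cos(x)$. For this toy model the expected smoothed gradient is $\cos(x)\,e^{-\sigma^2/2}$, so large $\sigma$ washes out the input-specific structure, leaving mostly Monte Carlo noise.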

```python
# SmoothGrad noise levels, converted to floats (file names use e.g. "100.0").
sigma_list = [float(s) for s in [100, 10, 1, 0.1, 0.01, 0]]
exp = {}
for sigma in sigma_list:
    exp[sigma] = load_dict('./samples/cifar10_smoothgrad_%s_0.pkl' % sigma)

# Combine the explanations (keyed by sigma) and the samples into one dictionary.
output_dict = merge_dict([exp, samples])
print(output_dict.keys())
```

```
dict_keys([100.0, 10.0, 1.0, 0.1, 0.01, 0.0, 'test_images', 'test_labels'])
```
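Two details of the loading cell are easy to miss. Because `sigma_list` is converted to floats, the formatted file names read e.g. `cifar10_smoothgrad_100.0_0.pkl`, so the saved files must use that spelling. Also, the globalness cell indexes `output_dict` with the integer list `[100, 10, 1, 0.1, 0.01, 0]`; this works because Python dictionaries treat numbers that compare equal (and hash equally) as the same key, so `100` finds the entry stored under `100.0`. The snippet checks both on a toy dictionary whose values are placeholders.

```python
sigma_list = [float(s) for s in [100, 10, 1, 0.1, 0.01, 0]]

# File names produced by the loading cell's format string.
names = ['./samples/cifar10_smoothgrad_%s_0.pkl' % s for s in sigma_list]
print(names[0], names[-1])
# ./samples/cifar10_smoothgrad_100.0_0.pkl ./samples/cifar10_smoothgrad_0.0_0.pkl

# Toy dict keyed by the float sigmas (placeholder values, wrapped in lists like merge_dict).
toy = {s: [{'explanations': None}] for s in sigma_list}

# Integer keys find the float entries because 100 == 100.0 and hash(100) == hash(100.0).
for std in [100, 10, 1, 0.1, 0.01, 0]:
    assert std in toy
print(hash(100) == hash(100.0), toy[0] is toy[0.0])
# True True
```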


## Calculate Wasserstein Globalness for the sample explanations

```python
from utils.locality_utilities import wasserstein_globalness

n_unif = 10000       # number of uniform samples used to approximate U_\mathcal{E}
n_projections = 500  # number of random projections for the sliced Wasserstein distance solver
l2_bound = 47.402    # radius k of U_\mathcal{E}

for std in [100, 10, 1, 0.1, 0.01, 0]:
    # output_dict[std] is a one-element list holding the loaded explanation dict
    exps = output_dict[std][0]['explanations']
    glob = wasserstein_globalness(exps, n_unif=n_unif, l2_bound=l2_bound,
                                  n_projections=n_projections)
    print(r'$\sigma$ %s: %s' % (std, glob))
```

```
$\sigma$ 100: 0.9804279710103189
$\sigma$ 10: 0.9803831638893189
$\sigma$ 1: 0.9569426544963487
$\sigma$ 0.1: 0.7681273049491278
$\sigma$ 0.01: 0.7091159546892561
$\sigma$ 0: 0.7003042706571309
```
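The cell above delegates to `wasserstein_globalness`, whose internals are not shown in this notebook. As a rough, hypothetical sketch of the ingredients its parameters suggest, the code below (i) draws `n_unif` points uniformly from an L2 ball of radius `l2_bound`, assuming that is what $U_\mathcal{E}$ denotes, and (ii) estimates the sliced 2-Wasserstein distance between flattened explanations and those points using `n_projections` random directions. The function names and toy data are illustrative; the real utility may sample, normalize, or rescale differently, and the sketch stops at the distance, so its output is not expected to match the globalness values printed above.

```python
import numpy as np


def sample_l2_ball(n, dim, radius, rng):
    '''Draw n points uniformly from the L2 ball of the given radius in R^dim.'''
    directions = rng.normal(size=(n, dim))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # Radius ~ R * U^(1/dim) makes the density uniform over the ball's volume.
    radii = radius * rng.uniform(size=(n, 1)) ** (1.0 / dim)
    return directions * radii


def sliced_wasserstein(x, y, n_projections, rng, n_quantiles=200):
    '''Monte Carlo sliced 2-Wasserstein distance between point clouds x and y.'''
    theta = rng.normal(size=(x.shape[1], n_projections))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)  # unit directions
    levels = np.linspace(0.0, 1.0, n_quantiles)
    # Project onto each direction and compare the 1-D quantile functions,
    # which lets x and y have different numbers of points.
    qx = np.quantile(x @ theta, levels, axis=0)
    qy = np.quantile(y @ theta, levels, axis=0)
    return np.sqrt(np.mean((qx - qy) ** 2))


rng = np.random.default_rng(0)
dim, l2_bound = 16, 5.0                        # toy sizes, not the CIFAR10 values
exps = rng.normal(scale=0.1, size=(200, dim))  # hypothetical, tightly clustered "explanations"
unif = sample_l2_ball(10000, dim, l2_bound, rng)

d_far = sliced_wasserstein(exps, unif, n_projections=500, rng=rng)
d_near = sliced_wasserstein(unif[:5000], unif[5000:], n_projections=500, rng=rng)
print(d_far, d_near)
```

Explanations concentrated in a small region sit far from the uniform reference, while two halves of the uniform sample are close to each other; comparing quantile functions rather than sorted samples is a design choice that avoids requiring equal sample sizes.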
