# When Explanations Lie: Why Modified BP Attribution Fails

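This notebook computes the cosine similarities between the relevance vectors that different attribution methods produce when random relevance is injected at a higher layer and propagated back to the lower layers.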

In [None]:
# uncomment to install the required packages
# !pip install tensorflow-gpu==1.13.1
# !pip install innvestigate seaborn tqdm deeplift

In [None]:
# Make only GPU 0 visible to CUDA; set here, before tensorflow is imported in the next cell.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow
import tensorflow as tf
import warnings

import innvestigate
import matplotlib.pyplot as plt

import numpy as np
import PIL 
import copy
import contextlib

import imp
import numpy as np
import os

from skimage.measure import compare_ssim 
import pickle
from collections import OrderedDict
from IPython.display import IFrame, display

import keras
import keras.backend
import keras.models


import innvestigate
import innvestigate.applications.imagenet
import innvestigate.utils as iutils
import innvestigate.utils as iutils
import innvestigate.utils.visualizations as ivis
from innvestigate.analyzer.relevance_based.relevance_analyzer import LRP
from innvestigate.analyzer.base import AnalyzerNetworkBase, ReverseAnalyzerBase
from innvestigate.analyzer.deeptaylor import DeepTaylor

import time
import tqdm

import seaborn as sns

import itertools
import matplotlib as mpl
from when_explanations_lie import *
from monkey_patch_lrp_resnet import custom_add_bn_rule
import deeplift_resnet  
from tensorflow.python.client import device_lib

In [None]:
# Show the devices TensorFlow can see (sanity check for the CUDA_VISIBLE_DEVICES setting).
device_lib.list_local_devices()

In [None]:
# Path to the ImageNet validation set. Can be overridden with the
# IMAGENET_VAL_DIR environment variable so the notebook runs on other machines.
imagenet_val_dir = os.environ.get(
    "IMAGENET_VAL_DIR", "/mnt/ssd/data/imagenet/imagenet-raw/validation")
# Path to the exemplary image (presumably relative to imagenet_val_dir).
ex_image_path = "n01534433/ILSVRC2012_val_00015410.JPEG"
# Number of validation images used in the evaluation
# (10 for a quick run; the larger evaluation used 200).
n_selected_imgs = 10

# Whether load_model should load the trained weights.
load_weights = True
model_names = ['resnet50', 'vgg16']

In [None]:
# Load both models and their preprocessed validation images. VGG16 is loaded
# first only to obtain its version of the example image; the ResNet50 objects
# loaded second are what remain in `model`, `innv_net`, `val_images`, ...
keras.backend.clear_session()
model, innv_net, color_conversion = load_model('vgg16', load_weights)
ex_image_vgg, ex_target, val_images, selected_img_idxs = load_val_images(
    innv_net, imagenet_val_dir, ex_image_path, n_selected_imgs)

keras.backend.clear_session()
model, innv_net, color_conversion = load_model('resnet50', load_weights)
ex_image, ex_target, val_images, selected_img_idxs = load_val_images(
    innv_net, imagenet_val_dir, ex_image_path, n_selected_imgs)


# Both models must receive the identical preprocessed example image.
assert ((ex_image - ex_image_vgg) == 0).all()

# NOTE(review): called with the ResNet50 model, but later cells index the
# result by model name (nice_layer_names[model_name]) — presumably it covers
# all models; confirm against get_nice_layer_names.
nice_layer_names = get_nice_layer_names(model)

In [None]:
# Per-model layer count (presumably the number of layers counted by the
# layer-naming helpers; used to derive `start_layer` in the plotting cell).
n_layers = dict(vgg16=22, resnet50=177)

# "Nice" layer names at which random relevance is injected in the
# cosine-similarity experiment.
replacement_layers = dict(
    vgg16=['fc3', 'fc1', 'conv4_3', 'conv3_3', 'conv2_2'],
    resnet50=['dense', 'block5_1', 'block4_2', 'block3_4', 'block3_2', 'block2_2'],
)

output_shapes = get_output_shapes(model)

In [None]:
def hmap_postprocess_wrapper(name):
    """Return a one-argument callable: ``x -> heatmap_postprocess(name, x)``.

    Binding ``name`` through this factory (instead of a lambda inside a
    comprehension) gives each analyser its own post-processing name.
    """
    def _postprocess(x):
        return heatmap_postprocess(name, x)

    return _postprocess

input_range = (ex_image.min(), ex_image.max())
analysers = get_analyser_params(input_range)

# Each analyser entry unpacks into five fields:
# (label, innvestigate name, post-processing name, excludes, kwargs).
attr_names = [label for (label, _, _, _, _) in analysers]

# Label -> heatmap post-processing function.
hmap_postprocessing = dict(
    (label, hmap_postprocess_wrapper(post_name))
    for (label, _, post_name, _, _) in analysers
)

In [None]:
# Display the models evaluated in this notebook.
model_names

In [None]:
# Histogram bin edges for |cosine similarity|: coarse steps of 0.1 up to 0.9,
# then increasingly fine edges close to 1.
bins = [*np.linspace(0, 0.9, 10).tolist(), 0.99, 0.999, 0.9999, 1]

In [None]:
# Display the per-model layer counts.
n_layers

In [None]:
def parse_reversed(hidden):
    """Return the second element of every entry of ``hidden`` except the first.

    ``hidden`` is a sequence of (key, value)-like entries, e.g. an analyser's
    ``_reversed_tensors``. The first entry is skipped and only the values are kept.
    """
    remaining = hidden[1:]
    return [entry[1] for entry in remaining]


# For each model, record which hidden units receive no gradient at all on a
# batch of validation images.
dead_neuron_mask = {}

for model_name in model_names:
    keras.backend.clear_session()
    model, innv_net, _ = load_model(model_name, load_weights=True)
    analyser = innvestigate.create_analyzer(
        "gradient", model, reverse_keep_tensors=True)

    analyser.analyze(np.concatenate([img for (img, _) in val_images[:20]], 0))

    grad_hidden = parse_reversed(analyser._reversed_tensors)
    # True where the gradient is exactly zero for every image (axis 0) and for
    # every entry of the last axis (presumably channels). Testing every value
    # for zero, rather than the batch mean, does not mistake positive and
    # negative gradients that happen to cancel for a dead unit.
    dead_neuron_mask[model_name] = [
        (g == 0).all(axis=0, keepdims=True).all(-1, keepdims=True)
        for g in grad_hidden
    ]

In [None]:
for model_name in model_names:
    # dead_neuron_mask is True where the gradient vanished, so the plotted
    # value is the fraction of dead positions per entry of the mask list.
    plt.title(model_name + " - dead neurons")
    plt.plot([(m.sum(-1) / m.shape[-1] > 0.999999).mean() for m in dead_neuron_mask[model_name]])
    plt.xlabel('layer position')
    plt.ylabel('fraction of dead units')
    plt.show()

In [None]:
# Inspect the per-model mapping of layer index -> readable layer name.
nice_layer_names

In [None]:
# Histogram layers: each model's replacement layers plus one early layer and
# the input. Fresh list copies keep replacement_layers untouched.
histogram_layers = {arch: list(layers) for arch, layers in replacement_layers.items()}
histogram_layers['vgg16'] += ['conv1_1', 'input']
histogram_layers['resnet50'] += ['conv2_1a', 'input']

In [None]:

# Translate each model's histogram layer names into layer indices
# (via get_layer_idx_full) and display the result.
histogram_layers_idx = OrderedDict()
for model_name in model_names:
    histogram_layers_idx[model_name] = []
    for layer_name in histogram_layers[model_name]:
        idx = get_layer_idx_full(model_name, nice_layer_names, layer_name)
        histogram_layers_idx[model_name].append(idx) 
histogram_layers_idx

In [None]:
# Number of per-layer dead-neuron masks computed for VGG16.
len(dead_neuron_mask['vgg16'])

In [None]:
# Display the replacement layers currently configured.
replacement_layers

In [None]:
from innvestigate.analyzer import GuidedBackprop

In [None]:

# Reload VGG16 for the smoke tests below. This overwrites the global `model`,
# `innv_net`, `ex_image` and `val_images` with the VGG16 versions.
keras.backend.clear_session()
model_name = 'vgg16'
model, innv_net, color_conversion = load_model(model_name, load_weights)
ex_image, ex_target, val_images, selected_img_idxs = load_val_images(
    innv_net, imagenet_val_dir, ex_image_path, n_selected_imgs)

In [None]:
def create_replacement_class(analyser_cls):
    """Derive an analyser class whose backward pass can be restarted mid-network.

    The returned ``ReplaceBackward`` subclass of ``analyser_cls`` always keeps
    the reversed (per-layer relevance) tensors and adds ``get_relevances``.
    That method feeds a given value in place of one layer's relevance tensor
    and evaluates the relevance tensors of other layers.

    Args:
        analyser_cls: an innvestigate analyser class; must subclass
            ``ReverseAnalyzerBase``.

    Returns:
        The ``ReplaceBackward`` class (not an instance).

    Raises:
        AssertionError: if ``analyser_cls`` is not a ``ReverseAnalyzerBase``.
    """
    assert issubclass(analyser_cls, ReverseAnalyzerBase)

    class ReplaceBackward(analyser_cls):
        def __init__(self, model, *args, **kwargs):
            # Keep the per-layer reversed tensors so they can be fed/fetched.
            kwargs['reverse_keep_tensors'] = True
            super().__init__(model, *args, **kwargs)

        def _create_analysis(self, *args, **kwargs):
            outputs, relevances_per_layer = super()._create_analysis(*args, **kwargs)
            # Stored in reverse order so that position i is expected to line up
            # with self._model.layers[i].
            # NOTE(review): this alignment is assumed, not checked — confirm.
            self._relevances_per_layer = relevances_per_layer[::-1]
            return outputs, relevances_per_layer

        def _get_layer_idx(self, name):
            """Return the position of the layer called ``name`` in ``self._model.layers``."""
            # Resolve the requested layer; a hard-coded name here would map every
            # set/selected layer onto the same relevance tensor.
            layer = self._model.get_layer(name=name)
            return self._model.layers.index(layer)

        def get_relevances(self, input_value, relevance_value,
                           set_layer, selected_layers):
            """Back-propagate ``relevance_value`` from ``set_layer`` downwards.

            Args:
                input_value: input batch fed to the analyser model.
                relevance_value: value fed in place of the relevance tensor of
                    ``set_layer`` (presumably must match that tensor's shape).
                set_layer: name of the layer whose relevance is replaced.
                selected_layers: iterable of layer names whose relevances are
                    returned.

            Returns:
                List of arrays, one per entry of ``selected_layers``.
            """
            sess = keras.backend.get_session()
            inp = self._analyzer_model.inputs[0]
            rel_tensor = self._relevances_per_layer[self._get_layer_idx(set_layer)]
            fetches = [self._relevances_per_layer[self._get_layer_idx(name)]
                       for name in selected_layers]
            return sess.run(
                fetches,
                feed_dict={
                    inp: input_value,
                    rel_tensor: relevance_value,
                })

    return ReplaceBackward

In [None]:
def get_replacement_analyser(model, analyser_cls, **kwargs):
    """Instantiate the relevance-replacing variant of an analyser for ``model``.

    Args:
        model: the keras model to analyse.
        analyser_cls: an innvestigate analyser class, or the name under which
            it is registered in ``innvestigate.analyzer.analyzers``.
        **kwargs: forwarded to the analyser constructor.

    Returns:
        An instance of the class built by ``create_replacement_class``.
    """
    if isinstance(analyser_cls, str):
        analyser_cls = innvestigate.analyzer.analyzers[analyser_cls]
    replacement_cls = create_replacement_class(analyser_cls)

    return replacement_cls(model, **kwargs)

In [None]:
class DeepLiftRelevanceReplacer:
    """Restart a DeepLIFT backward pass from an arbitrary layer.

    Wraps ``deeplift_wrapper._deeplift_model``. ``get_relevances`` feeds a given
    value into the positive and negative multipliers of one layer and
    evaluates the target contributions of other layers.
    """

    def __init__(self, deeplift_wrapper):
        self.deeplift_wrapper = deeplift_wrapper
        self.model = self.deeplift_wrapper._deeplift_model
        # Layers and their names, in the iteration order of ``_name_to_layer``.
        self.layers = list(self.model._name_to_layer.values())
        self.layer_names = list(self.model._name_to_layer.keys())
        # assumes the first registered layer is the input layer — TODO confirm
        self.input_layer = self.layers[0]

    def _get_layer_idx(self, name):
        """Return the position of ``name + '_0'`` in ``self.layer_names``.

        Raises:
            ValueError: if no such layer exists.
        """
        return self.layer_names.index(name + '_0')

    def get_relevances(self, input_value, relevance_value,
                       set_layer, selected_layers, reference=None):
        """Feed ``relevance_value`` at ``set_layer`` and collect contributions.

        Images are processed one at a time (rows of axis 0 of ``input_value``,
        ``relevance_value`` and ``reference``). The per-image results are
        concatenated along axis 0.

        Args:
            input_value: batch of network inputs.
            relevance_value: batch of values fed into ``set_layer``'s multipliers.
            set_layer: keras name of the layer whose multipliers are replaced.
            selected_layers: iterable of keras layer names to evaluate.
            reference: DeepLIFT reference inputs; zeros when ``None``.

        Returns:
            List with one array per entry of ``selected_layers``.
        """
        changed_layer = self.layers[self._get_layer_idx(set_layer)]
        # Materialised once, since ``selected_layers`` may be a one-shot iterator.
        selected_layer_idxs = [self._get_layer_idx(name) for name in selected_layers]
        if reference is None:
            reference = np.zeros_like(input_value)

        aggregated_contribs = [[] for _ in selected_layer_idxs]
        self.layers[-1].set_active()
        try:
            sess = keras.backend.get_session()
            fetches = [self.layers[idx]._target_contrib_vars
                       for idx in selected_layer_idxs]
            for img_idx in range(len(input_value)):
                row = slice(img_idx, img_idx + 1)
                contribs = sess.run(
                    fetches,
                    feed_dict={
                        self.input_layer.get_activation_vars(): input_value[row],
                        self.input_layer.get_reference_vars(): reference[row],
                        # NOTE(review): the same value is fed to both the
                        # positive and negative multipliers — confirm intent.
                        changed_layer._pos_mxts: relevance_value[row],
                        changed_layer._neg_mxts: relevance_value[row],
                    })
                for layer_pos, contrib in enumerate(contribs):
                    aggregated_contribs[layer_pos].append(contrib)
        finally:
            # Always deactivate, even if a session run fails.
            self.layers[-1].set_inactive()

        return [np.concatenate(contrib) for contrib in aggregated_contribs]

In [None]:

# Smoke test: build a GuidedBackprop replacement analyser for the current model,
# inspect the 'dense_2' layer, and feed random relevance at 'dense_2'.
gb_repl = create_replacement_class(GuidedBackprop)(model)
gb_repl.create_analyzer_model()

layer = gb_repl._model.get_layer(name='dense_2')
print(layer, layer.name, layer.weights[0].shape)
gb_repl._model.layers.index(layer)

n = 2
# The 1000 relevance entries presumably match dense_2's output width — confirm.
relvs = gb_repl.get_relevances(
    input_value=np.repeat(ex_image, n, axis=0), 
    relevance_value=np.random.normal(size=(n, 1000)),
    set_layer="dense_2", 
    selected_layers=[model.layers[1].name],
)

In [None]:

# Smoke test for DeepLiftRelevanceReplacer: inject random relevance at
# 'dense_2' and plot the channel-summed contributions at the second layer.
# NOTE(review): `dp_lift` is not defined in any visible cell — presumably it
# comes from `from when_explanations_lie import *` or an earlier session; confirm.
deeplift_csc = DeepLiftRelevanceReplacer(dp_lift)

n = 2
relvs = deeplift_csc.get_relevances(
    input_value=np.repeat(ex_image, n, axis=0), 
    set_layer="dense_2", 
    relevance_value=np.random.normal(size=(n, 1000)),
    selected_layers=[model.layers[1].name, model.layers[-3].name],
)
print(relvs[0].shape, model.layers[1].name, len(relvs))

for i in range(len(relvs[0])):
    plt.imshow(relvs[0][i].sum(-1))
    plt.colorbar()
    plt.show()

In [None]:
from innvestigate.analyzer.relevance_based.relevance_rule import AlphaBetaRule

def alpha_beta_wrapper(alpha, beta):
    """Return an ``AlphaBetaRule`` subclass with ``alpha`` and ``beta`` fixed.

    The returned class is constructed with ``(layer, state, bias=True,
    copy_weights=False)``, i.e. without alpha/beta arguments, and forwards the
    captured values to ``AlphaBetaRule``.
    """
    fixed_params = dict(alpha=alpha, beta=beta)

    class AlphaBetaRuleWrapper(AlphaBetaRule):
        def __init__(self, layer, state, bias=True, copy_weights=False):
            super().__init__(layer, state, bias=bias,
                             copy_weights=copy_weights, **fixed_params)

        def __repr__(self):
            template = "AlphaBetaRuleWrapper(alpha={}, beta={})"
            return template.format(self._alpha, self._beta)

    return AlphaBetaRuleWrapper

def get_custom_rule(innv_name, kwargs):
    """Return the alpha-beta rule class matching an LRP analyser, or ``None``.

    ``'lrp.alpha_beta'`` takes alpha/beta from ``kwargs``. The sequential presets
    A and B map to (1, 0) and (2, 1). Any other analyser name yields ``None``.
    """
    preset_params = {
        'lrp.sequential_preset_a': (1, 0),
        'lrp.sequential_preset_b': (2, 1),
    }
    if innv_name == 'lrp.alpha_beta':
        alpha, beta = kwargs['alpha'], kwargs['beta']
    elif innv_name in preset_params:
        alpha, beta = preset_params[innv_name]
    else:
        return None
    return alpha_beta_wrapper(alpha, beta)
        
# Show which custom LRP rule (if any) each analyser gets.
for label, innv_name, _, excludes, kwargs in analysers:
    print(innv_name, get_custom_rule(innv_name, kwargs))

In [None]:
# Scratch check that a LaTeX-style label renders in a matplotlib legend.
%pdb off
x = np.linspace(0, 2)
plt.plot(x, np.sin(x), label="LRP ${\\alpha1\\beta0}$")
plt.legend()

In [None]:

# Recompute the analyser parameters for the current example image and define
# a marker/colour style per analyser label, then preview all styles.
input_range = (ex_image.min(), ex_image.max())
analysers = get_analyser_params(input_range)

attr_names = [n for (n, _, _, _, _) in analysers]


# NOTE(review): `colors` is not defined in any visible cell — presumably it comes
# from `from when_explanations_lie import *`; confirm.
mpl_styles = OrderedDict([
    ('GuidedBP',                   {'marker': '$G$', 'color': colors[0]}),
    ('Deconv',                     {'marker': '$V$', 'color': colors[1]}),
    ('LRP-z',                      {'marker': 'D',   'color': colors[2]}),
    ('DTD',                        {'marker': '$T$', 'color': colors[3]}),
    ('PatternAttr.',               {'marker': '$P$', 'color': colors[4]}),
    ('LRP $\\alpha1\\beta0$',      {'marker': '<',   'color': colors[0]}),
    ('LRP $\\alpha2\\beta1$',      {'marker': '>',   'color': colors[1]}),
    ('LRP $\\alpha5\\beta4$',      {'marker': '^',   'color': colors[2]}),
    ('LRP CMP $\\alpha1\\beta0$',  {'marker': 's',   'color': colors[3]}),
    ('LRP CMP $\\alpha2\\beta1$',  {'marker': 'P',   'color': colors[4]}),
    ('DeepLift R.Can.',            {'marker': '$D$',   'color': colors[0]}),
    ('DeepLift Resc.',             {'marker': '$D$',   'color': colors[1]}),
    ('SmoothGrad',                 {'marker': 'o',   'color': colors[2]}),
    ('Gradient',                   {'marker': 'v',   'color': 'black'}),
])

# Every styled label must correspond to an analyser; plot one line per style.
for i, (name, style) in enumerate(mpl_styles.items()):
    assert name in attr_names
    plt.plot(np.arange(10), [20-i] * 10, 
             #markersize=5,
             label=name, #+ " m=" + style['marker'], 
             **style)
    
plt.legend(bbox_to_anchor=(1, 1))

In [None]:
# NOTE: overrides the earlier replacement_layers — from here on only the top
# layer of each model is used as replacement layer.
replacement_layers= {'vgg16': ['fc3'], 'resnet50': ['dense']}

In [None]:
# Display the models and replacement layers used by the experiment below.
model_names, replacement_layers

In [None]:
%pdb off
# replacement_layer_indices = [22]
n_sampled_v = 5

cos_sim_histograms = OrderedDict()
cos_mean = OrderedDict()
selected_percentiles = [0, 1, 5, 10, 20, 50, 100]
cos_sim_percentiles = OrderedDict()

for label, innv_name, _, excludes, kwargs in tqdm.tqdm_notebook(analysers):
    if 'exclude_cos_sim' in excludes:
        continue
    for model_name in model_names[:1]:
        if 'exclude_' + model_name in excludes:
            continue
        keras.backend.clear_session()
        model, innv_net, _ = load_model(model_name, load_weights=load_weights)
        model_output_shapes = get_output_shapes(model)
            
        selected_layers = [model.layers[idx].name 
                           for idx in nice_layer_names[model_name].keys() ]
        if innv_name == "pattern.attribution":
            kwargs['patterns'] = innv_net['patterns']

        if innv_name == 'deeplift.wrapper':
            repl_analyser = DeepLiftRelevanceReplacer(analyser)
        else:
            custom_rule = get_custom_rule(innv_name, kwargs)
            with custom_add_bn_rule(custom_rule):
                repl_analyser = get_replacement_analyser(
                    model, innv_name, **kwargs)
                repl_analyser.create_analyzer_model()
        for repl_layer_nice in replacement_layers[model_name]:
            replacement_layer_idx = get_layer_idx_full(
                model_name, nice_layer_names, repl_layer_nice)
            repl_shape = model_output_shapes[replacement_layer_idx]
            repl_layer_raw = model.layers[replacement_layer_idx].name
            cos_per_img = OrderedDict()
            print(repl_shape)
            print("selected", list(itertools.takewhile(lambda l: l != repl_layer_raw, selected_layers)))
            
            lower_layers = itertools.takewhile(lambda n: n != repl_layer_raw, selected_layers)
            for img_idx, (img, _) in tqdm.tqdm_notebook(zip(selected_img_idxs, val_images), 
                desc="[{}.{}] {}".format(model_name, repl_layer_nice, label)):
                channels = repl_shape[-1]
                if label == "$\\alpha=100, \\beta=99$-LRP":
                    # a=100,b=99 sufferes numerical instabilities with std = 1
                    std = 1 / np.sqrt(channels)
                else:
                    std = 1
                
                img_tiled = np.repeat(img, n_sampled_v, axis=0)
                random_relevance = std*np.random.normal(size=(n_sampled_v, ) + repl_shape[1:]) 
                
                relevances = repl_analyser.get_relevances(
                    img_tiled, random_relevance, repl_layer_raw, 
                    lower_layers)
                    
                cos_sim = cosine_similarities_from_relevances(relevances)
                
                for layer_raw, cs_for_layer in zip(lower_layers, cos_sim):
                    # we filter 0 cosine similarites as they only appear practically when the gradients are zero
                    cos_per_img[model_name, layer_raw, img_idx] = np.abs(cs_for_layer)

            median_for_label = []
            percentile_for_label = OrderedDict([(p, []) for p in selected_percentiles])
            for layer_raw in lower_layers:
                cos_per_layer = np.concatenate([cos_per_img[model_name, layer_raw, img_idx]  for img_idx in selected_img_idxs])
                cos_per_layer = cos_per_layer.flatten()

                idx = (label, model_name, repl_layer_raw,  layer_idx)
                cos_mean[idx] = np.mean(cos_per_layer)

                perc_values = np.percentile(cos_per_layer,  selected_percentiles)
                for p, val in zip(selected_percentiles, perc_values):
                    percentile_for_label[p].append(val)

                if layer_idx in histogram_layers_idx[model_name]:

                    if len(cos_per_layer) > 50000:
                        ridx = np.random.choice(len(cos_per_layer), 50000, replace=False)
                        cos_per_layer_sel = cos_per_layer[ridx]
                    else:
                        cos_per_layer_sel = cos_per_layer

                    cos_sim_histograms[idx] = np.histogram(cos_per_layer_sel, bins)


            for p, values in percentile_for_label.items():
                cos_sim_percentiles[label, model_name, replacement_layer_idx, p] = np.array(values)

In [None]:
# NOTE(review): both get_relevances implementations in this notebook require
# input_value, relevance_value, set_layer and selected_layers, so this call
# raises TypeError as written — scratch cell, confirm intent.
repl_analyser.get_relevances()

In [None]:
# Scratch check: TensorFlow's cosine similarity (1 - cosine distance) for two
# all-zero vectors.
with keras.backend.get_session().as_default():
    print(1 - tf.losses.cosine_distance([0, 0], [0, 0], 0).eval())

In [None]:
# NOTE(review): `outs` is not defined in any visible cell; this raises NameError
# unless it was created elsewhere in the session.
outs[0].shape, outs[-4].shape

In [None]:
# Optionally cache the percentiles and histograms computed above.
save_results = False
if save_results:
    os.makedirs('cache', exist_ok=True)
    # Name the cache after the weights actually used, so it matches the file the
    # loading cell reads (trained weights -> cos_sim_with_hist.pickle).
    cache_name = ('cos_sim_with_hist.pickle' if load_weights
                  else 'cos_sim_with_hist_random_weights.pickle')
    with open(os.path.join('cache', cache_name), 'wb') as f:
        pickle.dump((cos_sim_percentiles, cos_sim_histograms), f)

In [None]:
# Optionally restore previously cached percentiles and histograms.
load_results = False
if load_results:
    # Same naming scheme as the saving cell: file depends on load_weights.
    cache_name = ('cos_sim_with_hist.pickle' if load_weights
                  else 'cos_sim_with_hist_random_weights.pickle')
    # NOTE: pickle.load can execute arbitrary code — only load cache files
    # produced by this notebook.
    with open(os.path.join('cache', cache_name), 'rb') as f:
        cos_sim_percentiles, cos_sim_histograms = pickle.load(f)

In [None]:
def cosine_similarity(U, V):
    """Pairwise cosine similarities between the columns of two matrices.

    Args:
        U: array of shape (d, m); each column is one vector.
        V: array of shape (d, n); each column is one vector.

    Returns:
        (n, m) array whose entry [i, j] is the cosine between V[:, i] and
        U[:, j]. All-zero columns yield NaN (0/0).
    """
    def _unit_columns(M):
        return M / np.linalg.norm(M, axis=0, keepdims=True)

    return _unit_columns(V).T @ _unit_columns(U)

def get_sample_cos_sim_per_layer(output_shapes, n_samples=1000):
    """Random baseline for the cosine similarity of each layer.

    For every layer, ``n_samples`` vectors ``u`` and ``n_samples`` vectors ``v``
    are drawn i.i.d. from a standard normal. Each vector has as many entries as
    the layer's last shape dimension. The median absolute cosine over the
    lower triangle (incl. diagonal) of their pairwise cosine matrix is recorded.

    Args:
        output_shapes: mapping from layer to shape; only ``shape[-1]`` is used.
        n_samples: number of random vectors drawn per side (default 1000).

    Returns:
        ``np.ndarray`` with one median per entry of ``output_shapes``, in its
        iteration order.
    """
    values = []
    for shp in output_shapes.values():
        ch = shp[-1]
        u = np.random.normal(size=(ch, n_samples))
        v = np.random.normal(size=(ch, n_samples))
        # cos[i, j] is the cosine between u[:, i] and v[:, j].
        cos = cosine_similarity(v, u)
        # u and v are independent, so the triangle is simply a subset of
        # independent pairs.
        mask = np.tri(cos.shape[0])
        values.append(np.median(np.abs(cos[mask == 1])))
    return np.array(values)
        

In [None]:
# Random cosine-similarity baseline per model, one value per entry of the
# model's output shapes. Only layer shapes matter here, so load_model is
# called without an explicit load_weights argument. Note that this overwrites
# the global `model` and `output_shapes`.
cos_sim_baseline = {}

for model_name in model_names:
    keras.backend.clear_session()
    model, _, _ = load_model(model_name)
    output_shapes = get_output_shapes(model)
    print(len(output_shapes))
    cos_sim_baseline[model_name] = get_sample_cos_sim_per_layer(output_shapes)

In [None]:
# Display the models again before plotting.
model_names

In [None]:
# Number of baseline values per model (one per output shape).
cos_sim_baseline['vgg16'].shape, cos_sim_baseline['resnet50'].shape

In [None]:
# Plot, per model and replacement layer, the median |cosine similarity| of
# every analyser against layer depth, together with the random baseline.
# `legend` collects the styles used so a shared legend can be drawn later.
legend = OrderedDict()

os.makedirs('figures/cosine_similarity', exist_ok=True)
for model_name in model_names[::-1]:
    for replacement_layer in replacement_layers[model_name]:
        repl_idx = get_layer_idx_full(model_name, nice_layer_names, replacement_layer)
        start_layer = n_layers[model_name] - repl_idx

        # Tick labels: the replacement layer and every nice layer below it,
        # from top to bottom.
        layer_names = [name for idx, name in nice_layer_names[model_name].items()
                       if idx <= repl_idx][::-1]
        # Indices strictly below the replacement layer. An integer dtype keeps
        # the array usable as an index even when it is empty.
        layer_idx = np.array([idx for idx, name in nice_layer_names[model_name].items()
                              if idx < repl_idx][::-1], dtype=int)

        print(layer_idx, repl_idx, start_layer)

        plt.figure(figsize=(max(3, len(layer_idx) / 4), 3.5))

        for label, _, _, _, _ in analysers:
            idx = (label, model_name, repl_idx, 50)  # 50th percentile = median
            if idx not in cos_sim_percentiles:
                warnings.warn("not found: " + str(idx))
                continue
            cos_sim_per_label = cos_sim_percentiles[idx][layer_idx]

            plt.plot(0.5 + np.arange(len(cos_sim_per_label)), cos_sim_per_label,
                     label=label, **mpl_styles[label])

            if label not in legend:
                legend[label] = mpl_styles[label]

        # Cosine similarity of independent random vectors (baseline).
        baseline_label = 'Cos Similarity BL'
        baseline_style = {'color': (0.25, 0.25, 0.25)}
        plt.plot(0.5 + np.arange(len(layer_idx)), cos_sim_baseline[model_name][layer_idx],
                 label=baseline_label,
                 **baseline_style)
        if baseline_label not in legend:
            legend[baseline_label] = baseline_style

        plt.ylabel('cosine similarity')
        plt.xticks(np.arange(len(layer_names)), layer_names, rotation=90)
        plt.ylim(-0.05, 1.05)
        # matplotlib expects a bool here; string values such as 'on'/'off' are
        # deprecated and unreliable ('off' is truthy).
        plt.grid(True, alpha=0.35)
        plt.savefig("./figures/cosine_similarity/{}_layer_{}.pdf".format(model_name, repl_idx),
                    bbox_inches='tight', pad_inches=0)
        plt.show()
        plt.close()

In [None]:
# Inspect the key layout of cos_mean.
list(cos_mean.keys())[0]

In [None]:
# Draw a stand-alone legend figure from the styles collected in `legend`
# (empty plots, axes hidden) and save it as a PDF.
plt.figure(figsize=(2.5, 3))
for label, style in legend.items():
    plt.plot([], label=label, alpha=1, **style)

plt.axis('off')
plt.legend(loc='center')
plt.savefig("./figures/cos_sim_legend.pdf",
            bbox_inches='tight', pad_inches=0.02)

In [None]:
# Show the saved legend PDF inline.
display(IFrame("./figures/cos_sim_legend.pdf", 800, 600))

In [None]:
# List the stored percentile keys for GuidedBP on ResNet50.
# Note: the third key element is the replacement-layer index, despite the
# name `layer_idx`.
for attr_name, model_name, layer_idx, percentile in cos_sim_percentiles.keys():
    if attr_name == 'GuidedBP' and model_name == 'resnet50':
        print(attr_name, model_name, layer_idx, percentile)
    

In [None]:
# Inspect which (label, model, replacement layer, layer index) histograms exist.
cos_sim_histograms.keys()

In [None]:


# Stacked histogram of |cosine similarity| per analyser for one lower layer,
# with every bin below 0.9 merged into a single bar.
hist_layer_idx = 7  # NOTE(review): hard-coded lower-layer index to plot

attr_counts = []
labels = []
tick_edges = None
# The bin edges get their own name so the global `bins` list is not clobbered.
for (attr_name, model_name, repl_layer, layer_idx), (counts, edges) in cos_sim_histograms.items():
    if layer_idx != hist_layer_idx:
        continue
    left_edges = edges[:-1]
    lower_09 = counts[left_edges < 0.9].sum()
    print(attr_name, counts.sum())
    counts_collapsed = np.concatenate([lower_09[None], counts[left_edges >= 0.9]])
    attr_counts.append(counts_collapsed)
    labels.append(attr_name)
    tick_edges = edges

if tick_edges is None:
    warnings.warn("no histograms stored for layer index {}".format(hist_layer_idx))
else:
    # One unit-wide bar per collapsed bin.
    bins_int = np.arange(len(attr_counts[-1]) + 1)
    plt.hist([bins_int[:-1]] * len(attr_counts), bins_int,
             weights=attr_counts, stacked=True, label=labels)
    plt.xticks(bins_int,
               ["{:.4g}".format(b) for b in [0] + tick_edges[tick_edges >= 0.9].tolist()],
               rotation=0)
    plt.legend()

In [None]:
# Inspect the current histogram bin edges.
bins

In [None]:

# NOTE(review): `hist` is not defined in any visible cell; this raises NameError
# unless it was created elsewhere in the session.
hist[1]

In [None]:
# Inspect the counts of the last histogram iterated in the histogram cell.
counts