# When Explanations Lie: Why Modified BP Attribution fails

This notebook computes the cosine similarities of the relevance vectors.

In [None]:
# uncomment to install the required packages
# !pip install tensorflow-gpu==1.13.1
# !pip install innvestigate seaborn tqdm deeplift

In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow
import tensorflow as tf
import warnings

import innvestigate
import matplotlib.pyplot as plt

import numpy as np
import PIL 
import copy
import contextlib
import datetime

import imp
import numpy as np
import os

from skimage.measure import compare_ssim 
import pickle
from collections import OrderedDict
from IPython.display import IFrame, display

import keras
import keras.backend
import keras.models


import innvestigate
import innvestigate.applications.imagenet
import innvestigate.utils as iutils
import innvestigate.utils as iutils
import innvestigate.utils.visualizations as ivis
from innvestigate.analyzer.relevance_based.relevance_analyzer import LRP
from innvestigate.analyzer.base import AnalyzerNetworkBase, ReverseAnalyzerBase
from innvestigate.analyzer.deeptaylor import DeepTaylor
from innvestigate.analyzer import DeepLIFTWrapper

import time
import tqdm

import seaborn as sns

import itertools
import matplotlib as mpl
from when_explanations_lie import *
from monkey_patch_lrp_resnet import custom_add_bn_rule, get_custom_rule
import deeplift_resnet  
from deeplift_resnet import DeepLiftRelevanceReplacer
from tensorflow.python.client import device_lib

In [None]:
# Create a TF1 session that allocates GPU memory on demand (allow_growth)
# and register it as the Keras backend session.
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)
keras.backend.set_session(sess)
# List the devices TensorFlow can see (sanity check for CUDA_VISIBLE_DEVICES).
device_lib.list_local_devices()

In [None]:
# path to imagenet validation

host = ! hostname
host = host[0]

imagenet_dir = {
 "morty": "/mnt/ssd/data/imagenet/imagenet-raw",
 "snuffles": "/srv/public/leonsixt/data/imagenet",
}[host]
#imagenet_val_dir = "/home/leonsixt/tmp/imagenet/imagenet-raw/validation/"
# path to examplary image
ex_image_path = "n01534433/ILSVRC2012_val_00015410.JPEG"
# number of images to run the evaluation
#n_selected_imgs = 10
n_selected_imgs = 200

load_weights = True
model_names = ['cifar10', 'vgg16', 'resnet50']

In [None]:
#! ls -l cache/csc_200_2020-01-24T10:02:02.563084

# Reuse the results of an earlier run. Comment out this assignment to start a
# fresh run; the next cell then creates a new timestamped cache directory.
cache_dir = 'cache/csc_200_2020-01-26T22:19:11.601426'

In [None]:


# Only create a new (UTC-timestamped) cache directory if none was chosen above.
if 'cache_dir' not in globals():
    cache_dir = 'cache/csc_200_' + datetime.datetime.utcnow().isoformat()
    os.makedirs(cache_dir)

print("results will be saved in: ", cache_dir)

In [None]:
def load_model_and_meta(model_name, load_weights=True, clear_session=True):
    """Load a model together with its dataset metadata object.

    Args:
        model_name: ``'vgg16'`` or ``'resnet50'`` (ImageNet) or ``'cifar10'``.
        load_weights: forwarded to ``load_model``.
        clear_session: if True, call ``keras.backend.clear_session()`` first.

    Returns:
        ``(model, meta)``. ``meta`` is an ``ImageNetMeta`` (built from the
        notebook-level ``n_selected_imgs``, ``imagenet_dir`` and
        ``ex_image_path``) or a ``CIFAR10Meta`` (built from ``n_selected_imgs``).

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    if clear_session:
        keras.backend.clear_session()
    if model_name in ['vgg16', 'resnet50']:
        model, innv_net, color_conversion = load_model(model_name, load_weights)
        meta = ImageNetMeta(model, model_name, innv_net, n_selected_imgs,
                            imagenet_dir, ex_image_path)
    elif model_name == 'cifar10':
        model, _, _ = load_model('cifar10', load_weights)
        meta = CIFAR10Meta(model, n_selected_imgs)
    else:
        raise ValueError(
            "unknown model_name {!r}; expected 'vgg16', 'resnet50' or "
            "'cifar10'".format(model_name))
    return model, meta
    

In [None]:
model, meta = load_model_and_meta('vgg16', load_weights=True)

In [None]:
def hmap_postprocess_wrapper(name):
    # Bind the post-processing name so the returned callable only takes the heatmap.
    return lambda x: heatmap_postprocess(name, x)

input_range = (meta.ex_image.min(), meta.ex_image.max())
analysers = get_analyser_params(input_range)

# Analyser entries are 5-tuples; the first element is the display name.
attr_names = [n for (n, _, _, _, _) in analysers]

# Display name -> heatmap post-processing function (third tuple element).
hmap_postprocessing = {
    n: hmap_postprocess_wrapper(post_name) for n, _, post_name, _, _ in analysers
}

In [None]:
# Histogram bin edges for |cosine similarity|: 0.0, 0.1, ..., 0.9, followed by
# the finer edges 0.99, 0.999, 0.9999 and 1 to resolve values close to 1.
bins = np.linspace(0, 0.9, 10).tolist() + [0.99, 0.999, 0.9999, 1]

In [None]:
def parse_reversed(hidden):
    # Drop the first entry and keep element [1] of each remaining entry
    # (presumably the reversed tensor of a (key, tensor) pair - confirm).
    return [h[1] for h in hidden[1:]]


# Per model: one boolean mask per reversed tensor marking positions where the
# batch-mean gradient (over the first 20 images) is zero in every channel.
dead_neuron_mask = {}

for model_name in model_names:
    model, meta = load_model_and_meta(model_name)
    analyser = innvestigate.create_analyzer(
        "gradient", model, reverse_keep_tensors=True)

    analyser.analyze(np.concatenate([img for (img, _) in meta.images[:20]], 0))

    grad_hidden = parse_reversed(analyser._reversed_tensors)
    # NOTE(review): a zero *mean* over the batch does not imply a zero gradient
    # (positive and negative values can cancel); `(g == 0).all(0)` may be the
    # intended test for inactive units.
    dead_neuron_mask[model_name] = [(0 == np.mean(g, 0, keepdims=True)).all(-1, keepdims=True) for g in grad_hidden]

In [None]:
# Per layer: fraction of positions flagged in dead_neuron_mask.
for model_name in model_names:
    plt.title(model_name + " - inactive neurons")
    # NOTE(review): the masks keep a last axis of size 1 (keepdims above), so
    # the ratio below is simply the mask value itself.
    plt.plot([(m.sum(-1) / m.shape[-1] > 0.999999).mean()
              for m in dead_neuron_mask[model_name]])
    plt.show()

In [None]:
def cosine_similarities_from_relevances(relevance_per_layers):
    """Pairwise cosine similarities between the samples of every layer.

    ``relevance_per_layers`` is a sequence with one batch of relevances per
    layer. Each sample is flattened with ``conv_as_matrix`` and the samples of
    a layer are compared with ``pairwise_cosine_similarity``.

    Returns:
        A list with one flattened similarity array per layer.
    """
    return [
        pairwise_cosine_similarity(
            [conv_as_matrix(sample[None]) for sample in layer_batch]
        ).flatten()
        for layer_batch in relevance_per_layers
    ]

In [None]:
def conv_as_matrix(x):
    """Flatten an array into a 2-D ``(positions, channels)`` matrix.

    The last axis is kept as the columns. All leading axes (e.g. batch and
    spatial axes of a ``(b, h, w, c)`` array) are merged into rows. 2-D inputs
    are returned unchanged.

    Raises:
        ValueError: if ``x`` has fewer than two dimensions.
    """
    n_dims = len(x.shape)
    if n_dims == 2:
        return x
    if n_dims < 2:
        raise ValueError(
            "conv_as_matrix expects at least 2 dimensions, got shape {}".format(x.shape))
    # Explicit row count instead of -1 so zero-channel arrays reshape cleanly.
    n_rows = int(np.prod(x.shape[:-1]))
    return np.reshape(x, (n_rows, x.shape[-1]))



In [None]:
def create_replacement_class(analyser_cls):
    """Create a subclass of ``analyser_cls`` whose backward pass can be
    restarted from an arbitrary layer.

    The subclass forces ``reverse_keep_tensors=True`` so the reversed
    (relevance) tensors of all layers remain accessible. ``get_relevances``
    feeds a custom value into the reversed tensor of one layer and evaluates
    the reversed tensors of other layers.

    Args:
        analyser_cls: an iNNvestigate analyser class derived from
            ``ReverseAnalyzerBase``.

    Returns:
        The generated ``ReplaceBackward`` class (not an instance).
    """
    assert issubclass(analyser_cls, ReverseAnalyzerBase)

    class ReplaceBackward(analyser_cls):
        def __init__(self, model, *args, **kwargs):
            # Keep all intermediate reversed tensors; get_relevances feeds and
            # fetches them per layer.
            kwargs['reverse_keep_tensors'] = True
            super().__init__(model, *args, **kwargs)

        def _create_analysis(self, *args, **kwargs):
            # Pass-through override, kept as a hook for inspecting the
            # per-layer relevances while the analyser graph is built.
            outputs, relevances_per_layer = super()._create_analysis(*args, **kwargs)
            return outputs, relevances_per_layer

        def _get_layer_idx(self, name):
            """Return the index of the layer called ``name`` in ``self._model.layers``."""
            layer = self._model.get_layer(name=name)
            return self._model.layers.index(layer)

        def get_relevances(self, input_value, relevance_value,
                           set_layer, output_layers):
            """Inject ``relevance_value`` at ``set_layer`` and return the
            resulting relevances at ``output_layers``.

            A layer is described by its name (meaning its first output), by
            ``(name, 'input' | 'output')``, or by
            ``(name, ('input' | 'output', index))`` for layers with several
            inputs/outputs.

            Args:
                input_value: batch fed to the analyser model's first input.
                relevance_value: value fed into the reversed tensor of
                    ``set_layer`` (presumably batch-first, same shape as that
                    tensor - confirm).
                set_layer: description of the layer where relevance is injected.
                output_layers: list of descriptions of the layers to read.

            Returns:
                The result of ``sess.run``: one array per entry of ``output_layers``.
            """
            sess = keras.backend.get_session()
            inp = self._analyzer_model.inputs[0]

            def parse_input_output(desc):
                # Normalize a layer description to (layer_name, (kind, index)).
                if isinstance(desc, tuple):
                    layer_name, input_or_output = desc
                else:
                    layer_name = desc
                    input_or_output = 'output'

                if isinstance(input_or_output, str):
                    input_or_output = (input_or_output, 0)
                # Return for every form: an explicit (kind, index) pair must not
                # fall through to an implicit None.
                return layer_name, input_or_output

            def get_rel_tensor(layer_name, input_or_output):
                # Reversed tensor for the forward input/output tensor of the
                # layer; any kind other than 'input' selects the output.
                layer = self._model.get_layer(name=layer_name)
                if input_or_output[0] == 'input':
                    forward_tens = layer.input
                else:
                    forward_tens = layer.output
                if isinstance(forward_tens, list):
                    forward_tens = forward_tens[input_or_output[1]]
                # NOTE(review): `_reversed_tensors_raw` is presumably provided by
                # a patched iNNvestigate version - confirm.
                return self._reversed_tensors_raw[forward_tens]['final_tensor']

            set_layer_name, set_input_or_output = parse_input_output(set_layer)
            rel_tensor = get_rel_tensor(set_layer_name, set_input_or_output)
            output_rel_tensors = [get_rel_tensor(*parse_input_output(desc))
                                  for desc in output_layers]

            return sess.run(
                output_rel_tensors,
                feed_dict={
                    inp: input_value,
                    rel_tensor: relevance_value,
                })

    return ReplaceBackward


def get_replacement_analyser(model, analyser_cls, **kwargs):
    """Instantiate a relevance-replacement analyser for ``model``.

    ``analyser_cls`` is either an analyser class or a key of
    ``innvestigate.analyzer.analyzers`` (e.g. ``"gradient"``). The remaining
    keyword arguments are passed to the analyser's constructor.
    """
    base_cls = (innvestigate.analyzer.analyzers[analyser_cls]
                if type(analyser_cls) == str else analyser_cls)
    return create_replacement_class(base_cls)(model, **kwargs)

In [None]:
# Debug check (disabled by default): inject random relevance at every
# replacement layer of a GuidedBackprop replacement analyser and read the
# relevance at the model's first layer.
debug = False
if debug:
    for model_name in model_names:
        print(model_name)
        model, meta = load_model_and_meta(model_name)
        from innvestigate.analyzer import GuidedBackprop
        gb_repl = create_replacement_class(GuidedBackprop)(model)
        gb_repl.create_analyzer_model()

        for layer_name in meta.csc_replacement_layers:
            layer = gb_repl._model.get_layer(name=layer_name)

            relv_shape = layer.output.shape.as_list()

            layer_idx = gb_repl._model.layers.index(layer)

            # Two copies of the example image, each with its own random relevance.
            n = 2
            relvs = gb_repl.get_relevances(
                input_value=np.repeat(meta.ex_image, n, axis=0),
                relevance_value=np.random.normal(size=[n, ] + relv_shape[1:]),
                set_layer=layer_name,
                output_layers=[(model.layers[0].name, 'output')],
            )

In [None]:
class DeepLiftRelevanceReplacer:
    """Restart a DeepLIFT backward pass from an arbitrary layer.

    Wraps an iNNvestigate ``DeepLIFTWrapper`` and offers the same
    ``get_relevances`` interface as the classes built by
    ``create_replacement_class``. A custom value is fed into the multipliers of
    one layer, and the resulting target contributions are read at other layers.

    NOTE: this definition shadows the ``DeepLiftRelevanceReplacer`` imported
    from ``deeplift_resnet`` at the top of the notebook.
    """

    def __init__(self, deeplift_wrapper):
        self.deeplift_wrapper = deeplift_wrapper
        # Build the wrapper's DeepLIFT function on demand; presumably this is
        # what populates `_deeplift_model` used below.
        if not hasattr(self.deeplift_wrapper, "_deep_lift_func"):
            self.deeplift_wrapper._create_deep_lift_func()
        self.model = self.deeplift_wrapper._deeplift_model
        self.layers = list(self.model._name_to_layer.values())
        self.layer_names = list(self.model._name_to_layer.keys())
        self.input_layer = self.layers[0]

    def _get_layer_idx(self, name):
        """Index of the DeepLIFT layer for Keras layer ``name`` (looked up as ``<name>_0``)."""
        return self.layer_names.index(name + '_0')

    def get_relevances(self, input_value, relevance_value,
                       set_layer, output_layers, reference=None):
        """Feed ``relevance_value`` into ``set_layer`` and return the target
        contributions at ``output_layers``.

        Samples are processed one at a time.

        Args:
            input_value: batch of inputs (first axis = samples).
            relevance_value: per-sample values fed into both the positive and
                the negative multipliers of ``set_layer``.
            set_layer: Keras name of the layer whose multipliers are overwritten.
            output_layers: Keras names of the layers to read contributions from.
            reference: DeepLIFT reference input; defaults to all zeros.

        Returns:
            One array per entry of ``output_layers``, concatenated over samples.
        """
        set_layer_idx = self._get_layer_idx(set_layer)
        changed_layer = self.layers[set_layer_idx]
        selected_layer_idxs = [self._get_layer_idx(name) for name in output_layers]

        if reference is None:
            reference = np.zeros_like(input_value)

        def run_single(single_image, single_relevance_value, single_reference):
            sess = keras.backend.get_session()
            return sess.run(
                [self.layers[idx]._target_contrib_vars
                 for idx in selected_layer_idxs],
                feed_dict={
                    self.input_layer.get_activation_vars(): single_image,
                    self.input_layer.get_reference_vars(): single_reference,
                    changed_layer._pos_mxts: single_relevance_value,
                    changed_layer._neg_mxts: single_relevance_value,
                })

        aggregated_contribs = [[] for _ in selected_layer_idxs]
        # The last layer is activated only for the duration of this call
        # (presumably selecting it as the DeepLIFT target); deactivate it again
        # even if a session run fails.
        self.layers[-1].set_active()
        try:
            for sample_idx in range(len(input_value)):
                contribs = run_single(
                    input_value[sample_idx:sample_idx + 1],
                    relevance_value[sample_idx:sample_idx + 1],
                    reference[sample_idx:sample_idx + 1],
                )
                for out_idx, contrib in enumerate(contribs):
                    aggregated_contribs[out_idx].append(contrib)
        finally:
            self.layers[-1].set_inactive()

        return [np.concatenate(contrib) for contrib in aggregated_contribs]

In [None]:
from deeplift_resnet import monkey_patch_deeplift_neg_pos_mxts

# Debug check (disabled by default): inject random relevance at the first
# replacement layer of vgg16 with DeepLIFT and inspect the resulting maps.
debug = False
if debug:
    for model_name in model_names[1:2]:
        print(model_name)
        model, meta = load_model_and_meta(model_name)
        with monkey_patch_deeplift_neg_pos_mxts(cross_mxts=False):
            dp_lift = DeepLIFTWrapper(model)
            deeplift_csc = DeepLiftRelevanceReplacer(dp_lift)
            # Two copies of the example image, each with its own random relevance.
            n = 2
            for layer_name in meta.csc_replacement_layers[:1]:

                layer = model.get_layer(name=layer_name)

                relv_shape = layer.output.shape.as_list()

                relvs = deeplift_csc.get_relevances(
                    input_value=np.repeat(meta.ex_image, n, axis=0),
                    set_layer=layer_name,
                    relevance_value=np.random.normal(size=[n, ] + relv_shape[1:]),
                    # NOTE: this comprehension's `n` is local to it; the outer n stays 2.
                    output_layers=[meta.names.to_raw(n) for n in meta.names.nice_names()],
                )
                # Median cosine similarity between the two samples, per layer.
                print([np.median(c) for c in cosine_similarities_from_relevances(relvs)])
                # for i in range(len(relvs[0])):
                #     plt.imshow(relvs[0][i].sum(-1))
                #     plt.colorbar()
                #     plt.show()

                # Channel-summed relevance of the first output layer for both samples.
                plt.imshow(relvs[0][0].sum(-1))
                plt.colorbar()
                plt.show()
                plt.imshow(relvs[0][1].sum(-1))
                plt.colorbar()

In [None]:
# Show, per analyser, the custom rule returned by get_custom_rule.
for label, innv_name, _, excludes, kwargs in analysers:
    print(label, innv_name, get_custom_rule(innv_name, kwargs))

In [None]:
from when_explanations_lie import mpl_styles

# Rebuild the analyser list for the current model's input range.
input_range = (meta.ex_image.min(), meta.ex_image.max())
analysers = get_analyser_params(input_range)

attr_names = [n for (n, _, _, _, _) in analysers]
print(attr_names)

# Preview the line style of every method; each styled name must be an analyser.
for i, (name, style) in enumerate(mpl_styles.items()):
    assert name in attr_names
    plt.plot(np.arange(10), [20-i] * 10,
             #markersize=5,
             label=name, #+ " m=" + style['marker'],
             **style)

plt.legend(bbox_to_anchor=(1, 1))

In [None]:
#replacement_layers= {'vgg16': ['fc3'], 'resnet50': ['dense']}

In [None]:
list(enumerate(attr_names))

In [None]:
#model_names = ['vgg16', 'resnet50', 'cifar10']

In [None]:
from deeplift_resnet import monkey_patch_deeplift_neg_pos_mxts

In [None]:
# get_replacement_analyser is defined earlier, next to create_replacement_class.

@contextlib.contextmanager
def ctx_replacement_analyzer(model, meta, innv_name, kwargs):
    """Yield an analyser whose backward pass can be restarted at any layer.

    - ``'deep_lift.wrapper'``: a ``DeepLiftRelevanceReplacer`` built inside
      ``monkey_patch_deeplift_neg_pos_mxts(cross_mxts)``, where ``cross_mxts`` is
      taken from ``kwargs`` (default True).
    - any other iNNvestigate name: a replacement analyser created inside
      ``custom_add_bn_rule(get_custom_rule(...))``.
    - names starting with ``'pattern'`` additionally get ``meta.patterns``.

    The caller's ``kwargs`` dict is not modified; a copy is used.
    """
    kwargs = copy.copy(kwargs)
    if innv_name.startswith("pattern"):
        kwargs['patterns'] = meta.patterns

    if innv_name == 'deep_lift.wrapper':
        cross_mxts = kwargs.pop('cross_mxts', True)
        with monkey_patch_deeplift_neg_pos_mxts(cross_mxts):
            analyser = DeepLIFTWrapper(model, **kwargs)
            repl_analyser = DeepLiftRelevanceReplacer(analyser)
            yield repl_analyser
    else:
        custom_rule = get_custom_rule(innv_name, kwargs)
        with custom_add_bn_rule(custom_rule):
            repl_analyser = get_replacement_analyser(
                model, innv_name, **kwargs)
            repl_analyser.create_analyzer_model()
            yield repl_analyser

In [None]:
attr_names

In [None]:
%pdb off
# replacement_layer_indices = [22]
n_sampled_v = 5

cos_sim_histograms = OrderedDict()
cos_mean = OrderedDict()
selected_percentiles = [0, 1, 5, 10, 20, 50, 80, 95, 99, 100]
cos_sim_percentiles = OrderedDict()
override_results = False

selected_attr_names = attr_names
selected_attr_names = [
    'PatternNet',
    'DeepLIFT Abla.',
]

for model_name in model_names:
    model, meta = load_model_and_meta(model_name, load_weights, clear_session=True)
    input_range = (meta.ex_image.min(), meta.ex_image.max())
    analysers = get_analyser_params(input_range)
        
    for attr_name, innv_name, _, excludes, kwargs in tqdm.tqdm_notebook(analysers[::-1]):
        if attr_name not in selected_attr_names:
            continue
            
        if 'exclude_cos_sim' in excludes:
            continue
        if 'exclude_' + model_name in excludes:
            continue
        
        fname = os.path.join(cache_dir, "csc_{}_{}.pickle".format(model_name, attr_name))
        if os.path.exists(fname) and not override_results:
            warnings.warn("Results already exists at: \n"
                          "{}\nUse override_results=True to replace".format(fname))
            continue
        model, meta = load_model_and_meta(model_name, load_weights, clear_session=True)
        
        selected_layers = [meta.names.to_raw(nice_name) 
                           for nice_name in meta.names.nice_names()]
        
        with ctx_replacement_analyzer(model, meta, innv_name, kwargs) as repl_analyser:
            for repl_layer_raw in meta.csc_replacement_layers:
                repl_shape = model.get_layer(name=repl_layer_raw).output_shape

                cos_per_img = OrderedDict()

                lower_layers = list(itertools.takewhile(lambda n: n != repl_layer_raw, selected_layers))
                relevance_layers = lower_layers + [repl_layer_raw]
                for img_idx, (img, _) in tqdm.tqdm_notebook(
                    zip(meta.image_indices, meta.images),  
                    desc="[{}.{}] {}".format(model_name, meta.names.to_nice(repl_layer_raw), attr_name), 
                    total=len(meta.images)): 
                    channels = repl_shape[-1]

                    img_tiled = np.repeat(img, n_sampled_v, axis=0)
                    random_relevance = np.random.normal(size=(n_sampled_v, ) + repl_shape[1:]) 

                    relevances = repl_analyser.get_relevances(
                        img_tiled, random_relevance, 
                        set_layer=repl_layer_raw, 
                        output_layers=lower_layers)
            
                    cos_sim = cosine_similarities_from_relevances(
                        relevances + [random_relevance])
                    for layer_raw, cs_for_layer in zip(relevance_layers, cos_sim):
                        cos_per_img[model_name, layer_raw, img_idx] = np.abs(cs_for_layer)

                median_for_label = []
                percentile_for_label = OrderedDict([(p, []) for p in selected_percentiles])
                
                mean_idx = (attr_name, model_name, repl_layer_raw)
                cos_mean[mean_idx] = []
                for layer_raw in relevance_layers:
                    cos_per_layer = np.concatenate([cos_per_img[model_name, layer_raw, img_idx]  
                                                    for img_idx in meta.image_indices])
                    cos_per_layer = cos_per_layer.flatten()

                    # we filter nans as they appear when the gradients are zero
                    cos_mean[mean_idx].append(np.nanmean(cos_per_layer))

                    cos_per_layer = cos_per_layer[~np.isnan(cos_per_layer)]

                    if len(cos_per_layer) == 0:
                        raise ValueError()
                        #import pdb
                        #pdb.set_trace()
                        cos_per_layer = np.array([np.nan])


                    perc_values = np.percentile(cos_per_layer,  selected_percentiles)
                    for p, val in zip(selected_percentiles, perc_values):
                        percentile_for_label[p].append(val)

                    if layer_raw in meta.csc_histogram_layers:
                        if len(cos_per_layer) > 50000:
                            ridx = np.random.choice(len(cos_per_layer), 50000, replace=False)
                            cos_per_layer_sel = cos_per_layer[ridx]
                        else:
                            cos_per_layer_sel = cos_per_layer

                        idx = (attr_name, model_name, repl_layer_raw, layer_raw)
                        cos_sim_histograms[idx] = np.histogram(cos_per_layer_sel, bins)


                cos_mean[mean_idx] = np.array(cos_mean[mean_idx])
                for p, values in percentile_for_label.items():
                    cos_sim_percentiles[attr_name, model_name, repl_layer_raw, p] = np.array(values)
                
        assert not os.path.exists(fname)
        with open(fname, 'wb') as f:
            def filter_dict(dictionary):
                return OrderedDict([
                    (k, v) for k, v in dictionary.items()
                    if k[0] == attr_name and k[1] == model_name])
            
            m = filter_dict(cos_mean)
            perc = filter_dict(cos_sim_percentiles)
            hist = filter_dict(cos_sim_histograms)
            print('saving at {}, m{}, cs{}, h{}'.format(fname, len(m), len(perc), len(hist)))
            pickle.dump((m, perc, hist), f)

In [None]:
# Set to True to additionally store all in-memory results in one pickle.
# NOTE(review): 'all.pickle' is written into cache_dir, where the loading cell
# below will also pick it up.
save_results = False
if save_results:
    with open(os.path.join(cache_dir, 'all.pickle'), 'wb') as f:
        pickle.dump((cos_mean, cos_sim_percentiles, cos_sim_histograms), f)

In [None]:
! ls -l  'cache/csc_200_2020-01-26T22:19:11.601426'

In [None]:
cache_dir

In [None]:
# Replace the in-memory results with the per-(model, method) pickles in cache_dir.
load_results = True

if load_results:
    # Guard against accidentally discarding results computed in this kernel.
    if input("Do you really want to replace the current results?") != "y":
        raise Exception("loading aborted; the current results were kept")
    cos_sim_histograms = OrderedDict()
    cos_mean = OrderedDict()
    cos_sim_percentiles = OrderedDict()

    # Only pickle files, in sorted order, so the merge is deterministic.
    # NOTE: pickle.load can execute arbitrary code - only load trusted caches.
    result_files = sorted(fn for fn in os.listdir(cache_dir) if fn.endswith('.pickle'))
    for filename in tqdm.tqdm_notebook(result_files):
        with open(os.path.join(cache_dir, filename), 'rb') as f:
            mean, perc, hist = pickle.load(f)
        cos_mean.update(mean)
        cos_sim_percentiles.update(perc)
        cos_sim_histograms.update(hist)

In [None]:
def cosine_similarity(U, V):
    """Pairwise cosine similarity between the columns of ``U`` and ``V``.

    Returns a matrix ``S`` with ``S[i, j] = cos(V[:, i], U[:, j])``. Columns
    with zero norm produce NaN/inf (division by zero).
    """
    v_norm = V / np.linalg.norm(V, axis=0, keepdims=True)
    u_norm = U / np.linalg.norm(U, axis=0, keepdims=True)
    return v_norm.T @ u_norm


def get_sample_cos_sim_per_layer(output_shapes, n_samples=1000):
    """Chance-level baseline: median |cosine similarity| of random vectors.

    For every layer, two sets of ``n_samples`` standard-normal vectors are
    drawn. Each vector has as many dimensions as the layer has channels (last
    entry of its output shape). The function takes the median absolute cosine
    similarity over the lower triangle (including the diagonal) of their
    similarity matrix.

    Args:
        output_shapes: mapping layer -> output shape, iterated in its own order;
            only ``shape[-1]`` is used.
        n_samples: number of random vectors per set and layer.

    Returns:
        1-D array with one value per entry of ``output_shapes``.
    """
    values = []
    for shape in output_shapes.values():
        n_channels = shape[-1]
        u = np.random.normal(size=(n_channels, n_samples))
        v = np.random.normal(size=(n_channels, n_samples))
        cos = cosine_similarity(v, u)
        lower_triangle = np.tri(cos.shape[0]) == 1
        values.append(np.median(np.abs(cos[lower_triangle])))
    return np.array(values)
        

In [None]:
# Chance-level baseline per layer, plus the metadata of every model (used by
# the plotting cells below).
cos_sim_baseline = {}

model_metas = {}
for model_name in model_names:
    model, meta = load_model_and_meta(model_name)
    model_metas[model_name] = meta
    output_shapes = get_output_shapes(model)
    print(len(output_shapes))
    cos_sim_baseline[model_name] = get_sample_cos_sim_per_layer(output_shapes)

In [None]:
model_names

In [None]:
cos_sim_baseline['vgg16'].shape, cos_sim_baseline['resnet50'].shape

In [None]:
list(cos_sim_percentiles.keys())[:2]

In [None]:
# Which models, and which attribution methods (for vgg16), have percentile results.
# NOTE: the loop variables (e.g. model_name) stay defined as notebook globals.
model_names_in_cos_sim = set()
attr_names_in_cos_sim = set()
for (label, model_name, replacement_layer, percentile), values in cos_sim_percentiles.items():
    if model_name == 'vgg16':
        attr_names_in_cos_sim.add(label)
    model_names_in_cos_sim.add(model_name)

In [None]:
model_names_in_cos_sim, attr_names_in_cos_sim

In [None]:
# Model order used by the plotting cells below.
model_names = ['vgg16', 'cifar10', 'resnet50']


In [None]:
# Layers shown on the x-axis of the convergence plots, as (index, nice name)
# pairs, where the index is the position in the model's nice_names().
csc_shown_layers = {}

# For CIFAR-10 only a subset of the layers is shown.
cifar_nice = model_metas['cifar10'].names.nice_names()
csc_shown_layers['cifar10'] = [
    (cifar_nice.index(name), name)
    for name in ['input', 'conv1', 'conv2', 'pool2',
                 'conv3', 'conv4', 'pool4', 'fc5', 'fc6']]
csc_shown_layers['resnet50'] = list(enumerate(model_metas['resnet50'].names.nice_names()))
csc_shown_layers['vgg16'] = list(enumerate(model_metas['vgg16'].names.nice_names()))

In [None]:
csc_shown_layers['cifar10']

In [None]:
def draw_order(attr_name):
    """Matplotlib zorder for a method's line: LRP CMP variants above the
    default (0), and PatternNet/PatternAttribution on top (2)."""
    for order, prefix in ((1, "LRP CMP"), (2, "Pattern")):
        if attr_name.startswith(prefix):
            return order
    return 0

# NOTE(review): rebuilds `analysers` with a dummy input range [0, 1], replacing
# the list built earlier from the example image's range.
analysers = get_analyser_params([0, 1])

In [None]:
meta.model_name

In [None]:
global_save = True

def plot_convergence(meta, replacement_layer,
                     metrics, include_cos_sim_baseline=False, log=False, save=True, save_marker=None):
    """Plot the per-layer cosine similarity for relevance injected at ``replacement_layer``.

    Args:
        meta: model metadata (uses ``model_name`` and ``names``).
        replacement_layer: raw name of the layer where random relevance was injected.
        metrics: list of ``(label, values)`` pairs. ``values`` is indexed by the
            indices stored in ``csc_shown_layers[model_name]``, and ``label`` must
            be a key of ``mpl_styles``.
        include_cos_sim_baseline: also draw ``cos_sim_baseline[model_name]``.
        log: plot ``1 - value`` on an inverted log axis, which resolves values
            close to 1.
        save: save the figure to ``./figures/cosine_similarity/<model>/``
            (only if the global ``global_save`` is also True).
        save_marker: optional suffix for the file name.

    Returns:
        ``(legend, fname)``: ``legend`` maps labels to line styles; ``fname`` is
        the saved path, or None if nothing was saved.
    """
    def handle_log(data):
        return 1 - data if log else data

    def one_minus_label(loc):
        # Format the exponent numerically ("1 - 1e-2"); stripping every "0"
        # would corrupt exponents such as -10.
        mantissa, exponent = "{:.0e}".format(loc).split("e")
        return "1 - {}e{}".format(mantissa, int(exponent))

    legend = OrderedDict()
    model_name = meta.model_name
    repl_idx = meta.names.raw_to_idx(replacement_layer)

    # Shown layers up to and including the replacement layer, top layer first.
    shown = [(idx, name) for idx, name in csc_shown_layers[model_name]
             if meta.names.nice_to_idx(name) <= repl_idx][::-1]
    selected_idx = [idx for idx, _ in shown]
    selected_layers = [name for _, name in shown]
    layer_idx = np.array([meta.names.nice_to_idx(name) for name in selected_layers],
                         dtype=int)

    plt.figure(figsize=(max(3, len(layer_idx) / 3.8), 2.7))

    # NOTE(review): lines are drawn at x - 0.5 while the tick labels sit at
    # integer x; confirm this offset is intended.
    for label, cos_sim_per_label in metrics:
        cos_sim_per_label = cos_sim_per_label[selected_idx]

        style = copy.copy(mpl_styles[label])
        plt.plot(-0.5 + np.arange(len(cos_sim_per_label)), handle_log(cos_sim_per_label),
                 label=label, zorder=draw_order(label), **style)

        if label not in legend:
            legend[label] = mpl_styles[label]

    if include_cos_sim_baseline:
        label = 'Cos Similarity BL'
        style = {'color': (0.25, 0.25, 0.25)}
        # Same transform as the metric lines, so the baseline is also correct on the log axis.
        plt.plot(-0.5 + np.arange(len(layer_idx)),
                 handle_log(cos_sim_baseline[model_name][layer_idx]),
                 label=label,
                 **style)
        if label not in legend:
            legend[label] = style

    plt.ylabel('cosine similarity')
    plt.xticks(np.arange(len(selected_layers)), selected_layers, rotation=90)
    plt.grid(True, alpha=0.35)
    if log:
        plt.yscale('log')
        ymin, _ = plt.ylim()
        ymin = max(ymin, 1e-8)
        plt.ylim(ymin, 1.05)
        locs, _ = plt.yticks()
        locs = [loc for loc in locs if ymin < loc < 1]
        plt.yticks(locs, labels=[one_minus_label(loc) for loc in locs])
        # Values close to 1 (small 1 - value) end up at the top.
        plt.gca().invert_yaxis()
    else:
        plt.ylim(-0.05, 1.05)
    log_str = "log" if log else "linear"

    if save and global_save:
        marker = model_name
        if save_marker is not None:
            marker += "_" + save_marker
        outdir = "./figures/cosine_similarity/{}/".format(model_name)
        os.makedirs(outdir, exist_ok=True)
        fname = os.path.join(outdir, "{}_layer_{}_{}.pdf".format(
            marker, repl_idx, log_str))
        plt.savefig(fname, bbox_inches='tight', pad_inches=0.01)
        plt.show()
        plt.close()
        return legend, fname

    plt.show()
    plt.close()
    return legend, None

In [None]:
meta.names.nice_names()

In [None]:
attr_names

In [None]:
from when_explanations_lie import mpl_styles

def filter_metric(metric, model_name, attr_names, replacement_layer):
    """Collect ``(attr_name, metric[(attr_name, model_name, replacement_layer)])``
    for every attribution name, in order. Missing entries are skipped with a warning."""
    selected = []
    for attr_name in attr_names:
        key = (attr_name, model_name, replacement_layer)
        try:
            value = metric[key]
        except KeyError:
            warnings.warn("not found: " + str(key))
        else:
            selected.append((attr_name, value))
    return selected

# Quick look at two convergence plots (these also save figures, since
# plot_convergence saves by default when global_save is True).
meta = model_metas['cifar10']
plot_convergence(meta, 'fc6', filter_metric(cos_mean, 'cifar10', attr_names, 'fc6'), log=True)


meta = model_metas['vgg16']
plot_convergence(meta, 'dense_1', filter_metric(cos_mean, 'vgg16', attr_names, 'dense_1'))

In [None]:
set([p for (l, m, r, p), v in cos_sim_percentiles.items()])

In [None]:
len(cos_mean.keys())

In [None]:

from when_explanations_lie import mpl_styles


def plot_legend(legend, marker='all'):
    """Render ``legend`` (label -> line-style kwargs) as a stand-alone figure and
    save it to ``./figures/cosine_similarity/cos_sim_legend_<marker>.pdf``."""
    plt.figure(figsize=(2.5, 3))
    for label, style in legend.items():
        # Empty lines: only their legend entries are visible.
        plt.plot([], label=label, alpha=1, **style)

    plt.axis('off')
    plt.legend(loc='center', labelspacing=0.33)
    outdir = "./figures/cosine_similarity"
    os.makedirs(outdir, exist_ok=True)
    plt.savefig(os.path.join(outdir, "cos_sim_legend_{}.pdf".format(marker)),
                bbox_inches='tight', pad_inches=0.02,
               )
    plt.show()
    # Called in loops; release the figure explicitly.
    plt.close()

    
os.makedirs('figures/cosine_similarity', exist_ok=True)

# Median (50th percentile) |cos sim| per (attr_name, model, replacement layer).
cos_sim_median = {(l, m, r) : v for (l, m, r, p), v in cos_sim_percentiles.items()
                  if p == 50}

metric = cos_sim_median

selected_attr_names = [n for n in attr_names if not n.startswith("LRP CMP")]
lrp_cmp_names = [
    'LRP $\\alpha1\\beta0$',
    'LRP $\\alpha2\\beta1$',
    'LRP-z',
    'LRP CMP $\\alpha1\\beta0$',
    'LRP CMP $\\alpha2\\beta1$',
    'Gradient'
]

figpaths = []
for log in [True, False]:
    # The random baseline curve is drawn on the linear-scale plots only.
    include_cos_sim_baseline = not log

    # LRP vs. LRP CMP comparison at the first replacement layer.
    for model_name in model_names:
        meta = model_metas[model_name]
        replacement_layer = meta.csc_replacement_layers[0]

        metrics = filter_metric(metric, model_name, lrp_cmp_names, replacement_layer)
        legend, figpath = plot_convergence(meta, replacement_layer, metrics,
                                           include_cos_sim_baseline=include_cos_sim_baseline,
                                           log=log, save_marker='lrp_cmp')
        print(legend)
        plot_legend(legend, model_name + "_lrp_cmp")
        figpaths.append(figpath)

    # All other methods at every replacement layer.
    all_model_legend = OrderedDict()
    for model_name in model_names:
        meta = model_metas[model_name]
        print("meta", meta.model_name)
        model_legend = OrderedDict()
        for replacement_layer in meta.csc_replacement_layers:
            metrics = filter_metric(metric, model_name, selected_attr_names, replacement_layer)
            legend, figpath = plot_convergence(meta, replacement_layer, metrics,
                                               include_cos_sim_baseline=include_cos_sim_baseline,
                                               log=log)
            figpaths.append(figpath)
            model_legend.update(legend)
        all_model_legend.update(model_legend)
        # Legend of all replacement layers of this model, not just the last one.
        plot_legend(model_legend, model_name)

    plot_legend(all_model_legend, "all")

In [None]:
figpaths

In [None]:
list(cos_mean.keys())[0]

In [None]:
from IPython.display import IFrame, display

In [None]:
legend

In [None]:
figures = ! ls figures/cosine_similarity/
print(figures)
for figure in figures:
    display(IFrame("figures/cosine_similarity/" + figure, 800, 600))

In [None]:
def select_keys(x, select_key_values):
    """Yield the ``(key, value)`` items of ``x`` whose tuple key matches a pattern.

    ``select_key_values[i]`` must equal ``key[i]``; ``None`` acts as a wildcard.
    """
    for keys, val in x.items():
        mismatches = [
            sel_value is not None and sel_value != keys[pos]
            for pos, sel_value in enumerate(select_key_values)
        ]
        if not any(mismatches):
            yield keys, val

In [None]:
# list(select_keys(cos_sim_histograms, ['Gradient', 'vgg16', 'dense_1', None]))

In [None]:

def sort_hist(x):
    """Sort key for ``(key, (counts, bins))`` items: position of the
    attribution method (first key element) in the global ``attr_names``."""
    (attr_name, model_name, repl_layer, layer_name) = x[0]
    return attr_names.index(attr_name)

prev_layer_name = None
n_plots = 0



def plot_histograms(histograms):
    """Stacked bar chart of |cosine similarity| counts, one stack per method.

    All bins below 0.9 are merged into the leftmost bar; the bins from 0.9
    upwards are kept. Methods sharing a colour are told apart by hatching.

    Args:
        histograms: mapping ``(attr_name, model_name, repl_layer, layer_name)``
            -> ``(counts, bins)`` as returned by ``np.histogram``. All entries
            are assumed to share the same bin edges (the last one's are used
            for the tick labels).

    Returns:
        OrderedDict label -> ``{'hatch': ..., 'color': ...}`` for a legend.

    Raises:
        ValueError: if ``histograms`` is empty.
    """
    if not histograms:
        raise ValueError("plot_histograms() got no histograms to plot")

    legend = OrderedDict()
    attr_counts = []
    labels = []
    for (attr_name, _, _, _), (counts, hist_bins) in sorted(
            histograms.items(), key=sort_hist, reverse=True):
        # Merge all bins below 0.9 into a single leading count.
        lower_09 = counts[hist_bins[:-1] < 0.9].sum()
        counts_collapsed = np.concatenate([lower_09[None], counts[hist_bins[:-1] >= 0.9]])
        attr_counts.append(counts_collapsed)
        labels.append(attr_name)
    # One integer position per collapsed bin.
    bins_int = np.arange(len(attr_counts[-1]) + 1)

    plt.figure(figsize=(3., 2.5))

    colors = [mpl_styles[l]['color'] for l in labels]
    _, _, patches = plt.hist(
        [bins_int[:-1]] * len(attr_counts), bins_int,
        weights=attr_counts, stacked=True, label=labels,
        color=colors,
        rwidth=0.9,
        edgecolor='black', linewidth=1.2,
    )
    # For a single dataset plt.hist returns one container instead of a list.
    if len(attr_counts) == 1:
        patches = [patches]
    xticks = ["{:.4g}".format(b) for b in [0] + hist_bins[hist_bins >= 0.9].tolist()]
    plt.xticks(bins_int, xticks, rotation=90)

    hatches = ['////',  '...', '\\\\\\\\',  'xxxx','OO', 'xxx', '**'] * 4
    # The first method (in reversed drawing order) with a given colour stays
    # plain; later ones with the same colour get successive hatches.
    color_to_hatch = {}
    for label, patch_set, patch_color in reversed(list(zip(labels, patches, colors))):
        if patch_color not in color_to_hatch:
            hatch = None
            color_to_hatch[patch_color] = 0
        else:
            hatch = hatches[color_to_hatch[patch_color]]
            color_to_hatch[patch_color] += 1

        for patch in patch_set.patches:
            patch.set_hatch(hatch)
        legend[label] = {'hatch': hatch, 'color': patch_color}
    plt.yticks([])
    sns.despine(left=True)
    return legend
    
# One stacked histogram per model, for the first replacement layer and the
# second-to-last histogram layer; the legends of all models are merged.
legend = OrderedDict()
for model_name in model_names:
    meta = model_metas[model_name]
    repl_layer = meta.csc_replacement_layers[0]
    hist_layer = meta.csc_histogram_layers[-2]
    hists = OrderedDict(select_keys(
        cos_sim_histograms, [None, model_name, repl_layer, hist_layer]))

    legend.update(plot_histograms(hists))
    outdir = 'figures/csc_hists/'
    os.makedirs(outdir, exist_ok=True)

    figpath =os.path.join(outdir, "hist_{}_repl_{}_hist_{}.pdf".format(
        model_name,
        meta.names.to_nice(repl_layer),
        meta.names.to_nice(hist_layer)))
    plt.savefig(figpath, bbox_inches='tight', pad_inches=0.01)
    print(figpath)
    plt.show()

In [None]:
# Shared legend for the histogram figures, in the order of attr_names; empty
# histograms only contribute their legend entries.
plt.figure(figsize=(2.5, 3))
for label in attr_names:
    if label not in legend:
        continue
    plt.hist([], label=label, alpha=1, **legend[label])

plt.axis('off')
plt.legend(loc='center', fontsize='medium', frameon=True, handlelength=3, labelspacing=0.33)

fname = 'figures/csc_hists/csc_hist_legend.pdf'
plt.savefig(fname, bbox_inches='tight', pad_inches=0.01)
print(fname)
plt.show()

In [None]:
load_weights

In [None]:
# Scratch: load ResNet-50 and its selected raw layer names for inspection.
# model_name is set explicitly instead of relying on a leaked loop variable.
model_name = 'resnet50'
keras.backend.clear_session()
model, innv_net, _ = load_model(model_name, load_weights=load_weights)
model_output_shapes = get_output_shapes(model)

# NOTE(review): `layer_names` is not defined in this notebook; presumably it
# comes from the `when_explanations_lie` star import - confirm.
selected_layers = [layer_names[model_name].to_raw(nice_name)
                   for nice_name in layer_names[model_name].nice_names()]

In [None]:
# NOTE(review): `kwargs` is whatever the last analyser loop left behind (hidden
# state). ctx_replacement_analyzer pops 'cross_mxts' before calling
# DeepLIFTWrapper, so passing these kwargs directly may fail - pass explicit
# arguments instead.
analyser = DeepLIFTWrapper(model, **kwargs)
repl_analyser = DeepLiftRelevanceReplacer(analyser)

In [None]:
# Inspect the replacement analyser object.
repl_analyser

In [None]:
bins

In [None]:

# NOTE(review): `hist` is the last dict saved/loaded above, whose keys are
# tuples, so `hist[1]` most likely raises KeyError.
hist[1]

In [None]:
# NOTE(review): `counts` is only a local variable inside plot_histograms; at
# top level this raises NameError unless some other cell defined it.
counts