# When Explanations Lie: Why Modified BP Attribution Fails

This notebook produces the heatmap figures and the SSIM results.

In [None]:
# uncomment to install packages
#!pip install tensorflow-gpu==1.13.1
#!pip install innvestigate seaborn tqdm

In [None]:
#%env CUDA_VISIBLE_DEVICES=1

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow
import tensorflow as tf

import innvestigate
import matplotlib.pyplot as plt

import numpy as np
import PIL 
import copy
import contextlib

import imp
import numpy as np
import os

from skimage.measure import compare_ssim 
import pickle
from collections import OrderedDict
from IPython.display import IFrame, display

import keras
import keras.backend
import keras.models


import innvestigate
import innvestigate.applications.imagenet
import innvestigate.utils as iutils
import innvestigate.utils as iutils
import innvestigate.utils.visualizations as ivis
from innvestigate.analyzer.relevance_based.relevance_analyzer import LRP
from innvestigate.analyzer.base import AnalyzerNetworkBase, ReverseAnalyzerBase
from innvestigate.analyzer.deeptaylor import DeepTaylor

import time
import tqdm

import seaborn as sns

import itertools
import matplotlib as mpl

from tensorflow.python.client import device_lib

from utils import *

In [None]:
def _prepare_model(self, model):
    """Prepare `model` with the implementation that follows DeepTaylor in the MRO.

    ``super(DeepTaylor, self)`` starts the attribute lookup after DeepTaylor,
    so DeepTaylor's own ``_prepare_model`` is bypassed.
    """
    parent_prepare = super(DeepTaylor, self)._prepare_model
    return parent_prepare(model)


# Monkey-patch DeepTaylor: otherwise DTD does not work on negative outputs.
DeepTaylor._prepare_model = _prepare_model

In [None]:
# device_lib.list_local_devices()

In [None]:
# Path to the ImageNet validation images. Set the IMAGENET_VAL_DIR environment
# variable to use a different location than the default below.
# imagenet_val_dir = "/home/leonsixt/tmp/imagenet/imagenet-raw/validation"
imagenet_val_dir = os.environ.get(
    "IMAGENET_VAL_DIR", "/mnt/ssd/data/imagenet/imagenet-raw/validation")
# path to exemplary image
ex_image_path = "n01534433/ILSVRC2012_val_00015410.JPEG"
# number of images to run the evaluation
n_selected_imgs = 10

model_names = ['resnet50', 'vgg16']

# Fail early, and say which path was tried, when the dataset is missing.
assert os.path.exists(imagenet_val_dir), (
    "ImageNet validation directory not found: {!r} "
    "(set IMAGENET_VAL_DIR to override)".format(imagenet_val_dir))
# Output directory for the heatmap-grid figures.
os.makedirs('figures', exist_ok=True)

In [None]:

# First pass with VGG-16: only `ex_image_vgg` survives this pass, every other
# name assigned here is re-assigned by the ResNet-50 pass below.
model, innv_net, color_conversion = load_model('vgg16')
ex_image_vgg, ex_target, val_images, selected_img_idxs = load_val_images(
    innv_net, imagenet_val_dir, ex_image_path, n_selected_imgs)

# Reset the Keras backend session before the second network is loaded.
keras.backend.clear_session()
# Second pass with ResNet-50: `model`, `innv_net`, `val_images` and
# `selected_img_idxs` refer to this network from here on.
model, innv_net, color_conversion = load_model('resnet50')
ex_image, ex_target, val_images, selected_img_idxs = load_val_images(
    innv_net, imagenet_val_dir, ex_image_path, n_selected_imgs)


# Sanity check: the example image loaded in both passes is element-wise identical.
assert ((ex_image - ex_image_vgg) == 0).all()

# NOTE(review): computed from the ResNet-50 model only, but later passed to
# get_layer_idx_full together with either model name — confirm this is intended.
nice_layer_names = get_nice_layer_names(model)

In [None]:
# Per-model layer counts (presumably the total number of layers — confirm).
n_layers = dict(vgg16=23, resnet50=177)

# Layer names at which the cascading weight randomization is evaluated,
# listed from the input side to the output side of each network.
randomization_layers = dict(
    vgg16=[
        "conv1_1",
        "conv2_1",
        "conv3_1",
        "conv4_1",
        "conv4_3",
        "conv5_1",
        "conv5_3",
        "fc1",
        "fc3",
    ],
    resnet50=[
        'conv1',
        'block2_2',
        'block3_1',
        'block3_3',
        'block4_1',
        'block4_6',
        'block5_2',
        'dense',
    ],
)



In [None]:
output_shapes = get_output_shapes(model)

# Flip to True to print one line per layer: index, name, input shape at node 0
# and the entry of `output_shapes` for that layer.
print_output_shapes = False
if print_output_shapes:
    header_fmt = "{:3}{:20}{:20}{}"
    row_fmt = "{:3}: {:20}  {:20}  {}"
    print(header_fmt.format("l", "layer", "input_at_0", "output_shape"))
    for idx in range(len(model.layers)):
        layer = model.get_layer(index=idx)
        row = (
            idx,
            layer.name,
            str(layer.get_input_shape_at(0)),
            str(output_shapes[idx]),
        )
        print(row_fmt.format(*row))

In [None]:
def hmap_postprocess_wrapper(name):
    """Return a one-argument callable that applies `heatmap_postprocess` with `name` fixed.

    The returned callable takes a heatmap `x` and forwards it as
    ``heatmap_postprocess(name, x)``.
    """
    def _postprocess(x):
        return heatmap_postprocess(name, x)

    return _postprocess

# Value range of the example image, handed to the analyser configuration.
input_range = ex_image.min(), ex_image.max()
analysers = get_analyser_params(input_range)

# Every analyser entry is a 5-tuple: the first field is the attribution name,
# the third field is the name of the heatmap post-processing to apply.
attr_names = [attr for attr, _, _, _, _ in analysers]

hmap_postprocessing = dict(
    (attr, hmap_postprocess_wrapper(post_name))
    for attr, _, post_name, _, _ in analysers
)

In [None]:
# Preview the line style of every attribution method: one horizontal line per
# method, stacked downwards from y=20, with the marker shown in the legend.
xs = np.arange(10)
for i, name in enumerate(attr_names):
    style = mpl_styles[name]
    legend_label = name + " m=" + style['marker']
    ys = [20 - i] * 10
    plt.plot(xs, ys, label=legend_label, **style)

plt.legend(bbox_to_anchor=(1, 1))

## Sanity Checks: Random Parameters & Logit Check

In [None]:
# Two extra ResNet-50 instances: `model_cascading` receives the trained weights
# of `model`, `model_random` is built with load_weights=False.
model_cascading, _, _ = load_model('resnet50')
model_random, _, _ = load_model('resnet50', load_weights=False)
model_cascading.set_weights(model.get_weights())
out = model.predict(ex_image)
out_cascading = model_cascading.predict(ex_image)
print("mean-l1 distance of the outputs of the trained model and when weights are from trained model [should be 0]:", np.abs(out_cascading - out).mean())

# Use a dedicated name here: `n_layers` is the per-model dict from the
# configuration cell and must not be overwritten with an int.
num_model_layers = len(model_random.layers)
# Presumably copies the weights of the listed layers from `model_random` into
# `model_cascading` — confirm against utils.copy_weights.
# NOTE(review): the range covers the last three layer indices while the message
# below speaks of the last 2 layers — confirm which one is intended.
copy_weights(model_cascading, model_random, range(num_model_layers - 3, num_model_layers))

out = model.predict(ex_image)
out_cascading = model_cascading.predict(ex_image)
print( "mean-l1 distance of the outputs of the trained model when the last 2 layers are random [should not be 0]:", np.abs(out_cascading - out).mean())

In [None]:
# heatmaps are saved in those dicts
#   hmap_original[model_name, attr_name, img_idx]                  -> heatmap
#   hmap_random_weights[model_name, attr_name, img_idx, layer_idx] -> heatmap
#   hmap_random_target[model_name, attr_name, img_idx]             -> (random_target, heatmap)
hmap_original = OrderedDict()
hmap_random_weights = OrderedDict()
hmap_random_target = OrderedDict()

# If True, a fresh analyzer is created after every weight-randomization step
# instead of reusing the one created per attribution method.
recreate_analyser = False
for model_name in tqdm.tqdm_notebook(model_names[::-1]):
    def get_layer_idx(layer_name):
        """Resolve `layer_name` to a layer index for the current `model_name`."""
        return get_layer_idx_full(model_name, nice_layer_names, layer_name)

    for i, (attr_name, innv_name, _, excludes, analyser_kwargs) in enumerate(tqdm.tqdm_notebook(
        analysers, desc=model_name)):
        if i % 5 == 0:
            # clear session from time to time to not OOM
            keras.backend.clear_session()

            model, innv_net, _ = load_model(model_name)
            model_cascading, _, _ = load_model(model_name)
            model_random, _, _ = load_model(model_name, load_weights=False)
            model_cascading.set_weights(model.get_weights())

        if "exclude_" + model_name in excludes:
            continue
        if innv_name == 'pattern.attribution':
            # Work on a copy: the kwargs dict belongs to the shared `analysers`
            # list (which is also pickled later) and must not be mutated.
            analyser_kwargs = dict(analyser_kwargs, patterns=innv_net['patterns'])

        # Start every attribution method from the trained weights.
        model_cascading.set_weights(model.get_weights())

        # One past the last layer index; marks the "no layer randomized" step.
        original_idx = len(model.layers)

        analyzer_cascading = innvestigate.create_analyzer(
            innv_name, model_cascading,
            neuron_selection_mode="index", **analyser_kwargs)

        # Random-logit check: explain a randomly drawn target with trained weights.
        for img_idx, (img_pp, target) in zip(selected_img_idxs, val_images):
            random_target = get_random_target(target)
            hmap_random_target[model_name, attr_name, img_idx] = (
                random_target, analyzer_cascading.analyze(img_pp, neuron_selection=random_target)[0])

        # Randomization steps, from the 'original' step (nothing randomized) through
        # the configured layers in reverse order.
        selected_layers = [('original', original_idx)] + [
            (name, get_layer_idx(name))
            for name in randomization_layers[model_name][::-1]
        ]
        for layer_name, layer_idx in tqdm.tqdm_notebook(selected_layers, desc=attr_name):
            # For the 'original' step the range is empty, so nothing is copied.
            copy_weights(model_cascading, model_random, range(layer_idx, original_idx))
            if recreate_analyser:
                analyzer_cascading = innvestigate.create_analyzer(
                    innv_name, model_cascading,
                    neuron_selection_mode="index", **analyser_kwargs)

            for img_idx, (img_pp, target) in zip(selected_img_idxs, val_images):
                hmap = analyzer_cascading.analyze(img_pp, neuron_selection=target)[0]
                if layer_idx == original_idx:
                    hmap_original[model_name, attr_name, img_idx] = hmap
                else:
                    hmap_random_weights[model_name, attr_name, img_idx, layer_idx] = hmap

In [None]:
# Set to True to persist the computed heatmaps together with the analyser settings.
dump_heatmaps = False

if dump_heatmaps:
    with open('heatmaps_v4.pickle', 'wb') as dump_file:
        heatmap_bundle = (hmap_original, hmap_random_weights, hmap_random_target, analysers)
        pickle.dump(heatmap_bundle, dump_file)

In [None]:
# print size of dump
! ls -lh 'heatmaps.pickle'

In [None]:
# Set to True to restore previously dumped heatmaps instead of recomputing them.
load_heatmaps = False
if load_heatmaps:
    # Read the same file name the dump cell writes ('heatmaps_v4.pickle').
    # SECURITY: pickle.load can execute arbitrary code — only load dumps that
    # were created by this notebook.
    with open('heatmaps_v4.pickle', 'rb') as f:
        hmap_original, hmap_random_weights, hmap_random_target, analysers = pickle.load(f)

In [None]:
# Display the attribution method names in their current order.
#attr_names = sorted(set([n for (n, _) in hmap_original.keys()]))
attr_names

In [None]:
def get_sorting(n):
    """Return the sort rank of the attribution name `n`.

    Names listed explicitly get their fixed rank. Any other name containing
    "epsilon" ranks 0, any other name containing "cmp-LRP" ranks 10, and
    everything else ranks 5.
    """
    fixed_ranks = {
        'DTD': -1,
        'GuidedBP': -1,
        'SmoothGrad': -1,
        '$\\alpha=1, \\beta=0$-LRP': 1,
        '$\\alpha=2, \\beta=1$-LRP': 2,
        '$\\alpha=100, \\beta=99$-LRP': 3,
    }
    rank = fixed_ranks.get(n)
    if rank is not None:
        return rank
    # "epsilon" is checked first, so it wins over "cmp-LRP".
    if "epsilon" in n:
        return 0
    if "cmp-LRP" in n:
        return 10
    return 5
    
def _attr_sort_key(attr):
    """Order primarily by the rank from `get_sorting`, then by the name itself."""
    return get_sorting(attr), attr


attr_names = sorted(attr_names, key=_attr_sort_key)
attr_names

In [None]:
# Similarity between every randomized-weights heatmap and the heatmap of the
# trained model, keyed by (attr_name, img_idx, layer_idx).
ssim = OrderedDict()
l2_random_weights = OrderedDict()

last_idx = len(model.layers)

# The heatmap dicts are keyed with the model name first, while the plotting cells
# index the metrics by (name, img_idx, layer_idx) and walk `selected_layers`, which
# the heatmap loop leaves set to the layers of the last model it processed. That
# loop iterates model_names[::-1], so the last model is model_names[0]; only its
# heatmaps are evaluated here.
eval_model_name = model_names[0]
for (hmap_model_name, name, img_idx, layer_idx), heatmap in tqdm.tqdm_notebook(
        hmap_random_weights.items()):
    if hmap_model_name != eval_model_name:
        continue
    postprocess = hmap_postprocessing[name]
    original_heatmap = postprocess(hmap_original[hmap_model_name, name, img_idx])
    heatmap = postprocess(heatmap)
    ssim[(name, img_idx, layer_idx)] = ssim_flipped(heatmap, original_heatmap)
    l2_random_weights[(name, img_idx, layer_idx)] = l2_flipped(heatmap, original_heatmap)

In [None]:
# Per attribution method: SSIM values between the heatmap for the true target
# and the heatmap for a random target, one value per image.
ssim_random_target = OrderedDict()

# hmap_random_target is keyed by (model_name, attr_name, img_idx) while the plots
# index the result by attribution name only. To avoid mixing networks, only the
# model processed last by the heatmap loop (model_names[0], because that loop
# iterates model_names[::-1]) is evaluated.
eval_model_name = model_names[0]
for (hmap_model_name, name, img_idx), (_, hmap_random) in tqdm.tqdm_notebook(
        hmap_random_target.items()):
    if hmap_model_name != eval_model_name:
        continue

    postprocess = hmap_postprocessing[name]

    hmap = postprocess(hmap_original[hmap_model_name, name, img_idx])
    hmap_random = postprocess(hmap_random)
    ssim_random_target.setdefault(name, []).append(
        ssim_flipped(hmap, hmap_random))

In [None]:
# Bar chart: mean random-target SSIM per attribution method.
with sns.axes_style('whitegrid', {"axes.grid": True, 'font.family': 'serif'}):
    fig, ax = plt.subplots(1, 1, figsize=(4, 3.0), squeeze=True)

    # Derive bars, colours and tick labels from the same list so they cannot get
    # out of step (attr_names may be ordered differently or contain extra methods).
    names = list(ssim_random_target.keys())
    mean_ssim = [np.mean(ssim_random_target[name]) for name in names]
    positions = np.arange(len(names))

    ax.bar(positions, mean_ssim,
           color=[linestyles[name]['color'] for name in names])

    ax.set_ylabel('SSIM')
    ax.set_xticks(positions)
    ax.set_xticklabels(names, rotation=90)

    fig.savefig('check-random-logit.pdf',  bbox_inches='tight', pad_inches=0)
    display(IFrame('check-random-logit.pdf', 800, 500))

In [None]:
# Display the attribution methods that have random-target SSIM values.
ssim_random_target.keys()

In [None]:
# Box plot: distribution of the random-target SSIM per attribution method.
with sns.axes_style('ticks', {"axes.grid": True, 'font.family': 'serif'}):
    fig, ax = plt.subplots(1, 1, figsize=(4, 3.0), squeeze=True)

    # Methods that were skipped during heatmap computation (excluded for the
    # model) have no SSIM values; leave them out instead of raising KeyError.
    xlabels = [n for n in attr_names if n in ssim_random_target]
    ax.boxplot([ssim_random_target[n] for n in xlabels])
    ax.set_ylabel('SSIM')
    #ax.set_xticks(np.arange(len(xlabels)))
    ax.set_xticklabels(xlabels, rotation=90)
    ax.set_ylim(-0.05, 1.05)

    fig.savefig('check-random-logit-boxplot.pdf',  bbox_inches='tight', pad_inches=0)
    display(IFrame('check-random-logit-boxplot.pdf', 800, 500))

In [None]:
# Statistic plotted per layer ('mean' or 'median'); the bootstrap confidence band
# is computed for the same statistic.
ssim_reduce = 'median'
confidence_intervals = True
confidence_percentile = 99.5
n_bootstrap = 10000

reducers = {'mean': np.mean, 'median': np.median}
if ssim_reduce not in reducers:
    raise ValueError(
        "ssim_reduce must be 'mean' or 'median', got {!r}".format(ssim_reduce))
reduce_fn = reducers[ssim_reduce]
# Dedicated, seeded generator so the bootstrap band is reproducible across runs.
bootstrap_rng = np.random.RandomState(0)

with sns.axes_style("ticks", {"axes.grid": True, 'font.family': 'serif'}):
    # metrics = [('SSIM', ssim), ('MSE', l2_random_weights)]
    metrics = [('SSIM', ssim)]
    fig, axes = plt.subplots(1, len(metrics), figsize=(4.5 * len(metrics), 3.5), squeeze=False)
    axes = axes[0]
    for ax, (ylabel, metric) in zip(axes, metrics):
        # `selected_layers` is left over from the heatmap-computation loop. Its
        # ('original', ...) entry has no randomized-weights heatmap and hence no
        # metric value, so only layers that occur in `metric` are plotted.
        layers_with_data = {key[-1] for key in metric}
        plot_layers = [
            (layer_name, layer_idx)
            for layer_name, layer_idx in selected_layers[::-1]
            if layer_idx in layers_with_data
        ]
        ticks = np.arange(len(plot_layers))

        for name in attr_names:
            keys_per_layer = [
                [(name, img_idx, layer_idx) for img_idx in selected_img_idxs]
                for _, layer_idx in plot_layers
            ]
            # Skip methods without a complete set of values, e.g. methods that
            # were excluded for this model.
            if not keys_per_layer or not all(
                    key in metric for layer_keys in keys_per_layer for key in layer_keys):
                continue

            # shape: (number of plotted layers, number of images)
            metric_per_layer = np.array(
                [[metric[key] for key in layer_keys] for layer_keys in keys_per_layer])
            ssims_reduced = reduce_fn(metric_per_layer, axis=1)
            ax.plot(ticks, ssims_reduced, label=name, **linestyles[name])

            if confidence_intervals:
                lower_conf = []
                upper_conf = []
                for vals in metric_per_layer:
                    # Bootstrap: resample the images with replacement and
                    # recompute the statistic for every resample.
                    ridx = bootstrap_rng.choice(
                        len(vals), (n_bootstrap, len(vals)), replace=True)
                    stats = reduce_fn(vals[ridx], axis=1)
                    lower_conf.append(np.percentile(stats, 100 - confidence_percentile))
                    upper_conf.append(np.percentile(stats, confidence_percentile))
                ax.fill_between(ticks, lower_conf, upper_conf,
                                color=linestyles[name]['color'],
                                alpha=0.25
                               )

        xlabels = [layer_name for layer_name, _ in plot_layers]
        ax.set_ylim([0, 1.05])
        ax.set_xticks(np.arange(len(xlabels)))
        ax.set_xticklabels(xlabels, rotation=90)
        ax.set_ylabel(ylabel)
    axes[-1].legend(bbox_to_anchor=(1.0, 1.00))
    plt.savefig('check-random-weights.pdf',  bbox_inches='tight', pad_inches=0)
    plt.show()
    display(IFrame('check-random-weights.pdf', 800, 500))

In [None]:
def plot_heatmap_grid(heatmaps, cols, row_labels=(), column_labels=(), fig_path=None):
    """Draw `heatmaps` row by row on a grid with `cols` columns.

    Args:
        heatmaps: sequence of images accepted by ``imshow``; shown with the
            'seismic' colormap on the fixed value range [-1, 1].
        cols: number of grid columns (positive integer).
        row_labels: labels put on the first column's y-axis, one per row from the top.
        column_labels: titles put on the first row, one per column from the left.
        fig_path: if not None, the figure is saved to this path.

    Raises:
        ValueError: if `cols` is smaller than 1 or `heatmaps` is empty.

    Note:
        Sets the global matplotlib rcParam 'font.family' to 'serif'.
    """
    if cols < 1:
        raise ValueError("cols must be a positive integer, got {!r}".format(cols))
    n_heatmaps = len(heatmaps)
    if n_heatmaps == 0:
        raise ValueError("heatmaps must not be empty")

    mpl.rcParams['font.family'] = 'serif'
    # Ceiling division: a trailing, partially filled row is still drawn instead
    # of being dropped silently.
    rows = -(-n_heatmaps // cols)
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    fontsize = 12
    plt.subplots_adjust(wspace=0.05, hspace=0.05, top=1, bottom=0, left=0, right=1)
    for label, ax in zip(row_labels, axes[:, 0]):
        ax.set_ylabel(label, fontsize=fontsize + 1, labelpad=55, rotation=0)

    for label, ax in zip(column_labels, axes[0, :]):
        ax.set_title(label, fontsize=fontsize)

    flat_axes = axes.flatten()
    for ax, heatmap in zip(flat_axes, heatmaps):
        ax.imshow(heatmap, cmap='seismic', vmin=-1, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])

    # Hide the frame of grid cells that received no heatmap.
    for ax in flat_axes[n_heatmaps:]:
        ax.axis('off')

    #plt.tight_layout()
    if fig_path is not None:
        plt.savefig(fig_path, bbox_inches='tight', pad_inches=0, dpi=150)

In [None]:
def normalize_neg(x):
    """Scale positive and negative values of `x` separately into [-1, 1].

    Positive values are divided by the 99th percentile of `x`, negative values
    by the magnitude of the 1st percentile; the result is clipped to [-1, 1].
    If a percentile cannot serve as a scale for its side (it is zero or has the
    wrong sign, e.g. for an all non-negative heatmap), the largest magnitude on
    that side is used instead. A side without any values stays zero.

    Args:
        x: array-like of numbers.

    Returns:
        Float array with the shape of `x` and values in [-1, 1].
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x.copy()

    x_pos = x * (x > 0)
    x_neg = x * (x < 0)

    pos_scale = np.percentile(x, 99)
    if not pos_scale > 0:
        # Percentile unusable for the positive side: fall back to its maximum.
        pos_scale = x_pos.max()
    neg_scale = -np.percentile(x, 1)
    if not neg_scale > 0:
        # Percentile unusable for the negative side: fall back to its largest magnitude.
        neg_scale = -x_neg.min()

    out = np.zeros_like(x)
    # Only divide by strictly positive scales; this avoids 0/0 = NaN.
    if pos_scale > 0:
        out = out + x_pos / pos_scale
    if neg_scale > 0:
        out = out + x_neg / neg_scale
    return np.clip(out, -1, 1)

# One heatmap grid per model: one row per attribution method with the columns
# input image, heatmap for the trained model, and one heatmap per randomized layer.
for model_name in model_names:
    hmap_plot = []
    rnd_layers = randomization_layers[model_name][::-1]
    attr_for_model = []
    for (attr_name, _, _, excludes, _) in analysers:
        if 'exclude_' + model_name in excludes:
            print(attr_name)
            continue
        try:
            if attr_name in ['GuidedBP', 'Deconv']:
                postp = hmap_postprocess_wrapper('sum')
            else:
                postp = hmap_postprocessing[attr_name]

            # Collect the row separately first: if a heatmap is missing, nothing
            # of this row may end up in the grid, otherwise every following row
            # would be shifted against its label.
            row = []
            for img_idx in selected_img_idxs[:1]:
                row.append(norm_image(val_images[0][0][0]))
                row.append(normalize_neg(postp(hmap_original[model_name, attr_name, img_idx])))
                for layer_name in rnd_layers:
                    layer_idx = get_layer_idx_full(model_name, nice_layer_names, layer_name)
                    row.append(normalize_neg(postp(hmap_random_weights[model_name, attr_name, img_idx, layer_idx])))
        except KeyError as e:
            # No heatmap stored under this key: report it and skip the method.
            print(e)
            continue
        hmap_plot.extend(row)
        attr_for_model.append(attr_name)

    plot_heatmap_grid(
        hmap_plot, 2+len(rnd_layers), row_labels=attr_for_model,
        column_labels=['input', 'original'] + rnd_layers,
        fig_path='figures/heatmap_grid_{}.pdf'.format(model_name)
    )

In [None]:
# Embed the saved heatmap grids of both models.
for grid_pdf in ('figures/heatmap_grid_vgg16.pdf', 'figures/heatmap_grid_resnet50.pdf'):
    display(IFrame(grid_pdf, width=1000, height=600))

In [None]:
# Scratch cell: fetch one randomized-weights heatmap for inspection.
# NOTE(review): this first assignment is immediately overwritten by the next line.
attr_name = 'LRP-$\\alpha=5, \\beta=4$'
attr_name = 'PatternAttr.'
# Key layout is (model_name, attr_name, img_idx, layer_idx).
# NOTE(review): 673 and 1 are hard-coded image / layer indices; the lookup raises
# KeyError unless exactly this key was computed — confirm it is still valid.
hmap = hmap_random_weights['vgg16', attr_name, 673, 1]