


# When Explanations Lie: Why Many Modified BP Attributions Fail


## Sanity Checks & Random Logits

This notebook produces the saliency-map figures and the SSIM results for the sanity checks (cascading parameter randomization and random logits).

In [None]:
# Select which GPU is visible to TensorFlow. `%env` is an IPython magic, so
# this cell only works inside Jupyter/IPython; it should run before
# TensorFlow initialises CUDA for the setting to take effect.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow
import tensorflow as tf

import innvestigate
import matplotlib.pyplot as plt

import numpy as np
import PIL 
import copy
import json
import contextlib

import imp
import numpy as np
import os

from skimage.measure import compare_ssim 
import pickle
import datetime
from collections import OrderedDict
from IPython.display import IFrame, display

import keras
import keras.backend
import keras.models
from keras.applications.resnet50 import preprocess_input


import innvestigate
import innvestigate.applications.imagenet
import innvestigate.utils as iutils
import innvestigate.utils as iutils
import innvestigate.utils.visualizations as ivis
from innvestigate.analyzer.relevance_based.relevance_analyzer import LRP
from innvestigate.analyzer.base import AnalyzerNetworkBase, ReverseAnalyzerBase
from innvestigate.analyzer.deeptaylor import DeepTaylor
from innvestigate.analyzer import DeepLIFTWrapper

import warnings
import time
from tqdm.notebook import tqdm as tqdm_notebook
import seaborn as sns

import itertools
import matplotlib as mpl

from tensorflow.python.client import device_lib

from when_explanations_lie import *

import deeplift_resnet
from deeplift_resnet import monkey_patch_deeplift_neg_pos_mxts
from monkey_patch_lrp_resnet import custom_add_bn_rule, get_custom_rule

In [None]:
# Create a TF1-style session that grows its GPU memory on demand
# (allow_growth) and register it as the Keras backend session.
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)
keras.backend.set_session(sess)
# Last expression of the cell: Jupyter displays the visible devices.
device_lib.list_local_devices()

In [None]:
# Path to the ImageNet validation set, looked up per machine.
# `! hostname` is IPython shell syntax; `host, =` unpacks its single output line.
host, = ! hostname

# imagenet_dir.json maps host names to the ImageNet directory on that host;
# a KeyError here means the current host has no entry.
with open('imagenet_dir.json') as f:
    imagenet_dir = json.load(f)[host]

# path to the exemplary image (presumably relative to `imagenet_dir` —
# confirm against ImageNetMeta)
ex_image_path = "n01534433/ILSVRC2012_val_00015410.JPEG"
# number of images to run the evaluation
n_selected_imgs = 200

model_names = ['resnet50', 'vgg16']

os.makedirs('figures', exist_ok=True)

In [None]:
def load_model_and_meta(model_name, load_weights=True, clear_session=True):
    """Load a model together with its evaluation metadata.

    Args:
        model_name: one of 'vgg16', 'resnet50' or 'cifar10'.
        load_weights: passed through to ``load_model``.
        clear_session: if True, reset the Keras session before loading.

    Returns:
        ``(model, meta)``: ``meta`` is an ``ImageNetMeta`` for 'vgg16' and
        'resnet50' and a ``CIFAR10Meta`` for 'cifar10'.

    Raises:
        ValueError: if ``model_name`` is not supported. The name is checked
            before the session is cleared, so a typo does not wipe the
            current session.
    """
    if model_name not in ('vgg16', 'resnet50', 'cifar10'):
        raise ValueError("Unknown model name: {!r}".format(model_name))
    if clear_session:
        keras.backend.clear_session()
    if model_name in ('vgg16', 'resnet50'):
        model, innv_net, _ = load_model(model_name, load_weights)
        meta = ImageNetMeta(model, model_name, innv_net, n_selected_imgs,
                            imagenet_dir, ex_image_path)
    else:
        model, _, _ = load_model('cifar10', load_weights)
        meta = CIFAR10Meta(model, n_selected_imgs)
    return model, meta
    

In [None]:
# Display the models evaluated in this notebook.
model_names

In [None]:
# Load each model once to collect its metadata in `metas`; afterwards the
# globals `model` and `meta` refer to the last entry of `model_names`.
metas = OrderedDict()
for model_name in model_names:
    model, metas[model_name] = load_model_and_meta(model_name)
    meta = metas[model_name]

In [None]:
# Display the name of the model whose meta is currently bound to `meta`.
meta.model_name

In [None]:
# Output shape of every layer of `model`; optionally print them as a table.
output_shapes = get_output_shapes(model)

print_output_shapes = False
if print_output_shapes:
    # Header and rows share one format string so the columns line up.
    row_fmt = "{:3}: {:20}  {:20}  {}"
    print(row_fmt.format("l", "layer", "input_at_0", "output_shape"))
    for i, layer in enumerate(model.layers):
        print(row_fmt.format(
            i, layer.name, str(layer.get_input_shape_at(0)), str(output_shapes[i])))

In [None]:
def hmap_postprocess_wrapper(name):
    """Return a one-argument callable ``f(x) == heatmap_postprocess(name, x)``."""
    def postprocess(x):
        return heatmap_postprocess(name, x)
    return postprocess

# Analyser configuration for the input range of the current example image.
# Each analyser entry is unpacked as a 5-tuple: element 0 is the display name
# and element 2 the name of its heatmap post-processing.
input_range = (meta.ex_image.min(), meta.ex_image.max())
analysers = get_analyser_params(input_range)

attr_names = [attr for (attr, _, _, _, _) in analysers]

hmap_postprocessing = dict(
    (attr, hmap_postprocess_wrapper(post)) for (attr, _, post, _, _) in analysers
)

In [None]:
# Preview the matplotlib style assigned to each attribution method: one
# horizontal line per method (stacked downwards), labelled with its marker.
for i, name in enumerate(attr_names):
    style = mpl_styles[name]
    plt.plot(np.arange(10), [20-i] * 10, 
             #markersize=5,
             label=name + " m=" + style['marker'], **style)
    
plt.legend(bbox_to_anchor=(1, 1))

## Sanity Checks: Random Parameters & Logit Check

In [None]:
# Quick check of the weight-copy machinery on ResNet-50: three models in one
# session -- trained `model`, `model_cascading` (gets trained weights, then
# some layers replaced) and `model_random` (built with load_weights=False).
model, meta = load_model_and_meta('resnet50', clear_session=True)
model_cascading, _ = load_model_and_meta('resnet50', clear_session=False)
model_random, _ = load_model_and_meta('resnet50', 
                                         load_weights=False, clear_session=False)
model_cascading.set_weights(model.get_weights())
out = model.predict(meta.ex_image)
out_cascading = model_cascading.predict(meta.ex_image)
print("mean-l1 distance of the outputs of the trained model and when weights are from trained model [should be 0]:", np.abs(out_cascading - out).mean())

n_layers = len(model_random.layers)
# Presumably copies model_random's weights into model_cascading for these
# layer indices -- confirm copy_weights' argument order.
# NOTE(review): range(n_layers - 3, n_layers) covers 3 layer indices while
# the message below says "last 2 layers"; possibly only 2 of them have
# weights -- confirm.
copy_weights(model_cascading, model_random, range(n_layers - 3, n_layers))

# `model` is unchanged, so recomputing `out` is redundant but harmless.
out = model.predict(meta.ex_image)
out_cascading = model_cascading.predict(meta.ex_image)
print( "mean-l1 distance of the outputs of the trained model when the last 2 layers are random [should not be 0]:", np.abs(out_cascading - out).mean())

In [None]:
# use to select a specific cache dir
# hmap_output_dir = 'cache/2020-01-26T18:07:39.494420'


In [None]:
# Create a new timestamped cache directory unless `hmap_output_dir` was
# already defined (e.g. to reuse an existing cache).
# NOTE(review): isoformat() contains ':' which is not valid in Windows paths.
if 'hmap_output_dir' not in globals():
    hmap_output_dir = 'cache/' + datetime.datetime.now().isoformat()
    os.makedirs(hmap_output_dir)
    print("Created new output dir:", hmap_output_dir)

In [None]:
# Show where the heatmap pickles are written/read.
print("Saving heatmaps to: ", hmap_output_dir)

In [None]:
@contextlib.contextmanager
def ctx_analyzer(model, meta, innv_name, kwargs):
    """Create an analyzer for ``model`` and yield it inside its required patches.

    Args:
        model: Keras model to analyse.
        meta: model metadata; ``meta.patterns`` is passed to analyzers whose
            name starts with "pattern".
        innv_name: iNNvestigate analyzer name, or 'deep_lift.wrapper' for the
            DeepLIFT wrapper (optional ``cross_mxts`` kwarg, default True).
        kwargs: analyzer keyword arguments. The caller's dict is never
            modified; a shallow copy is extended/reduced instead.

    Yields:
        The analyzer after ``create_analyzer_model()``. The DeepLIFT monkey
        patch or the custom BN rule stays active while the caller's ``with``
        body runs.
    """
    # Copy so that adding 'patterns', popping 'cross_mxts' or anything
    # get_custom_rule does to the dict never leaks back to the caller.
    kwargs = copy.copy(kwargs)
    if innv_name.startswith("pattern"):
        kwargs['patterns'] = meta.patterns

    if innv_name == 'deep_lift.wrapper':
        cross_mxts = kwargs.pop('cross_mxts', True)
        print("CROSS MIXTS", cross_mxts)
        with monkey_patch_deeplift_neg_pos_mxts(cross_mxts):
            analyzer = DeepLIFTWrapper(model, **kwargs)
            analyzer.create_analyzer_model()
            yield analyzer
    else:
        custom_rule = get_custom_rule(innv_name, kwargs)
        with custom_add_bn_rule(custom_rule):
            analyzer = innvestigate.create_analyzer(innv_name, model, **kwargs)
            analyzer.create_analyzer_model()
            yield analyzer

In [None]:
# Compute and cache all heatmaps for the sanity checks. For every selected
# (model, attribution method) pair three dicts are filled and pickled to
# `hmap_output_dir`/heatmap_<model>_<method>.pickle:
#   hmap_original[model, method, img_idx]                  -> heatmap, trained model
#   hmap_random_weights[model, method, img_idx, layer_idx] -> heatmap after
#                                     randomizing layers layer_idx .. end
#   hmap_random_target[model, method, img_idx]             -> (random target, heatmap)

selected_attr_names = attr_names
# selected_attr_names = ['DeepLIFT Abla.', ]

# If True the models/analyzer are rebuilt after every randomization step
# (this is always done for DeepLIFT, see below).
recreate_analyser = False
# NOTE(review): only VGG-16 is processed; the commented line runs all models.
for model_name in tqdm.tqdm_notebook(['vgg16']):
#for model_name in tqdm.tqdm_notebook(model_names):
    model, meta = load_model_and_meta(model_name,  clear_session=True)
    input_range = (meta.ex_image.min(), meta.ex_image.max())
    analysers = get_analyser_params(input_range)
    
    for i, (attr_name, innv_name, _, excludes, analyser_kwargs) in enumerate(tqdm.tqdm_notebook(
        analysers, desc=model_name)):
        if attr_name not in selected_attr_names:
            continue
            
        hmap_original = OrderedDict()
        hmap_random_weights = OrderedDict()
        hmap_random_target = OrderedDict()
        
        # `i % 1 == 0` is always true: session cleared and models reloaded
        # before every attribution method.
        if i % 1 == 0:
            # clear session from time to time to not OOM
            keras.backend.clear_session()

            model, innv_net, _ = load_model(model_name)
            # model_cascading: starts with trained weights, randomized step by step
            model_cascading, _, _ = load_model(model_name)
            # model_random: source of the random weights
            model_random, _, _ = load_model(model_name, load_weights=False)
            model_cascading.set_weights(model.get_weights())
            
        # Methods flagged with "exclude_<model>" are skipped for that model.
        if "exclude_" + model_name in excludes:
            continue
        
        
        fname = hmap_output_dir + '/heatmap_{}_{}.pickle'.format(model_name, attr_name)
        # An existing cache file is only warned about and then overwritten.
        if os.path.exists(fname):
            warnings.warn("File already exsists: " + fname)
            # continue
            
            
        # NOTE(review): cascading_heatmaps / cascading_outputs are never used.
        cascading_heatmaps = {}
        cascading_outputs = {}
        model_cascading.set_weights(model.get_weights())
        # Explain an explicitly given output neuron (`neuron_selection=` below).
        kwargs_w_idx = copy.copy(analyser_kwargs)
        kwargs_w_idx['neuron_selection_mode'] = "index"
        
        # Sentinel one past the last layer: range(original_idx, original_idx)
        # is empty, so the 'original' step copies no layers.
        original_idx = len(model.layers)
        with ctx_analyzer(model_cascading, meta, innv_name, kwargs_w_idx) as analyzer_cascading:
            # Random-logit check: heatmaps of the still trained model for a
            # randomly drawn target.
            for img_idx, (img_pp, target) in zip(meta.image_indices, meta.images):
                random_target = get_random_target(target)
                random_hmap = analyzer_cascading.analyze(img_pp, neuron_selection=random_target)[0]
                idx = model_name, attr_name, img_idx
                hmap_random_target[idx] = (random_target, random_hmap)
                
            # 'original' first, then meta.randomization_layers in reverse order.
            selected_layers = [('original', original_idx)] +  [
                (name, meta.names.nice_to_idx(name)) 
                 for name in meta.randomization_layers[::-1]
            ]
            
            for layer_name, layer_idx in tqdm.tqdm_notebook(selected_layers, desc=attr_name):
                # Cascading randomization: all layers from layer_idx to the end
                # (presumably) receive model_random's weights.
                copy_weights(model_cascading, model_random, range(layer_idx, original_idx))
                if recreate_analyser or innv_name.startswith('deep_lift'):
                    # Rebuild models and analyzer in a fresh session.
                    keras.backend.clear_session()

                    model, innv_net, _ = load_model(model_name)
                    model_cascading, _, _ = load_model(model_name)
                    model_random, _, _ = load_model(model_name, load_weights=False)
                    copy_weights(model_cascading, model_random, range(layer_idx, original_idx))
                    
                    # NOTE(review): the `with` body is empty, so the patch/rule
                    # context is exited before `analyze` runs below; this relies
                    # on create_analyzer_model() having built the graph inside
                    # the context -- confirm.
                    with ctx_analyzer(model_cascading, meta, innv_name, kwargs_w_idx) as analyzer_cascading:
                        pass

                for img_idx, (img_pp, target) in zip(meta.image_indices, meta.images):
                    hmap = analyzer_cascading.analyze(img_pp, neuron_selection=target)[0]
                    if layer_idx == original_idx:
                        hmap_original[model_name, attr_name, img_idx] = hmap
                        # progress marker
                        print('o', end='')
                    else:
                        hmap_random_weights[model_name, attr_name, img_idx, layer_idx] =  hmap

            #assert not os.path.exists(fname)
            with open(fname, 'wb') as f:
                pickle.dump((hmap_original, hmap_random_weights, hmap_random_target), f)

In [None]:
# Load one cached result file (VGG-16 / PatternAttr.) for inspection.
outpath = hmap_output_dir + '/heatmap_{}_{}.pickle'.format('vgg16', "PatternAttr.")
if not os.path.exists(outpath):
    # A bare `raise` outside an `except` block only produces
    # "RuntimeError: No active exception to reraise"; say what is missing.
    raise FileNotFoundError("not found: " + outpath)
with open(outpath, 'rb') as f:
    hmap_original, hmap_random_weights, hmap_random_target = pickle.load(f)

In [None]:
# List the cached heatmap files (IPython shell escape).
! ls -lh {hmap_output_dir}

In [None]:
# Sanity check (cascading weight randomization): SSIM between each heatmap of
# a partially randomized model and the heatmap of the trained model, keyed by
# (model_name, attr_name, img_idx, layer_idx).

ssim_random_weights = OrderedDict()
l2_random_weights = OrderedDict()

last_idx = len(model.layers)
for model_name in model_names:
    for attr_name in attr_names:
        outpath = hmap_output_dir + '/heatmap_{}_{}.pickle'.format(model_name, attr_name)
        if not os.path.exists(outpath):
            warnings.warn("not found: " + outpath)
            continue
        with open(outpath, 'rb') as f:
            hmap_original, hmap_random_weights, hmap_random_target = pickle.load(f)

        # RectGrad is compared at the 100th percentile, all others at 99.5.
        percentile = 100 if attr_name == 'RectGrad' else 99.5
        # Separate names for the key fields so the loop variables above are
        # not rebound by the unpacking.
        for (key_model, key_attr, img_idx, layer_idx), heatmap in tqdm_notebook(
                hmap_random_weights.items(), desc="{}.{}".format(model_name, attr_name)):
            postprocess = hmap_postprocessing[key_attr]
            original_heatmap = postprocess(hmap_original[key_model, key_attr, img_idx])
            heatmap = postprocess(heatmap)
            ssim_random_weights[key_model, key_attr, img_idx, layer_idx] = ssim_flipped(
                heatmap, original_heatmap, percentile=percentile)
            # l2_random_weights[model_name, name, img_idx, layer_idx] = l2_flipped(heatmap, original_heatmap)

In [None]:
# Random-logit check: SSIM between the heatmap for the true target and the
# heatmap for a random target; lists of scores keyed by (model_name, attr_name).
ssim_random_target = OrderedDict()

for model_name in model_names:
    for attr_name in attr_names:
        outpath = hmap_output_dir + '/heatmap_{}_{}.pickle'.format(model_name, attr_name)
        if not os.path.exists(outpath):
            warnings.warn("not found: " + outpath)
            continue
        with open(outpath, 'rb') as f:
            hmap_original, hmap_random_weights, hmap_random_target = pickle.load(f)

        # RectGrad is compared at the 100th percentile, all others at 99.5.
        percentile = 100 if attr_name == 'RectGrad' else 99.5
        # Distinct key names: the unpacking must not rebind the loop
        # variables above (nor IPython's `_`).
        for (key_model, key_attr, img_idx), (_rnd_target, hmap_random) in tqdm_notebook(
                hmap_random_target.items()):
            postprocess = hmap_postprocessing[key_attr]
            hmap = postprocess(hmap_original[key_model, key_attr, img_idx])
            hmap_random = postprocess(hmap_random)
            ssim_random_target.setdefault((key_model, key_attr), []).append(
                ssim_flipped(hmap, hmap_random, percentile=percentile))

In [None]:
# Box plot of the random-logit SSIM scores, one PDF per model.
os.makedirs('figures/sanity_checks/', exist_ok=True)
for model_name in model_names:
    xlabels = [n for (m, n) in ssim_random_target.keys() if m == model_name]
    if not xlabels:
        warnings.warn("no random-logit SSIM scores for " + model_name)
        continue
    with sns.axes_style('ticks', {"axes.grid": True, 'font.family': 'serif'}):
        fig, ax = plt.subplots(1, 1, figsize=(3.9, 2.3), squeeze=True)

        ax.boxplot([ssim_random_target[model_name, n] for n in xlabels])
        ax.set_ylabel('SSIM')
        ax.set_xticklabels(xlabels, rotation=90)
        ax.set_ylim(-0.05, 1.05)

        figpath = 'figures/sanity_checks/random-logit-boxplot-{}.pdf'.format(model_name)
        fig.savefig(figpath, bbox_inches='tight', pad_inches=0)
        display(IFrame(figpath, 800, 500))

In [None]:
# (layer name, layer index) pairs of the current `meta`'s randomization
# layers, in reverse order of meta.randomization_layers.
selected_layers = [
    (name, meta.names.nice_to_idx(name))
     for name in meta.randomization_layers[::-1]
]
selected_layers

In [None]:
from when_explanations_lie import mpl_styles

# Per model: reduced (median or mean) SSIM between the original heatmap and
# the heatmaps after cascading randomization, with bootstrap confidence bands.
metrics = [('SSIM', ssim_random_weights)]

ssim_reduce = 'median'        # 'median' or 'mean'
confidence_intervals = True
confidence_percentile = 99.5  # band spans bootstrap percentiles [100 - p, p]
n_bootstrap = 10000

reducers = {'mean': np.mean, 'median': np.median}
if ssim_reduce not in reducers:
    raise ValueError("ssim_reduce must be 'mean' or 'median', got {!r}".format(ssim_reduce))
reduce_fn = reducers[ssim_reduce]

os.makedirs('figures/sanity_checks', exist_ok=True)
for model_name in model_names:

    meta = metas[model_name]
    selected_layers = [
        (name, meta.names.nice_to_idx(name))
        for name in meta.randomization_layers[::-1]
    ]
    print(selected_layers)

    with sns.axes_style("ticks", {"axes.grid": True, 'font.family': 'serif'}):
        fig, axes = plt.subplots(1, len(metrics), figsize=(4.5 * len(metrics), 3.5), squeeze=False)
        axes = axes[0]
        for ax, (ylabel, metric) in zip(axes, metrics):
            for (name, _innv, _post, excludes, _kwargs) in analysers:
                if 'exclude_' + model_name in excludes:
                    continue

                # Skip methods without results for this model.
                probe = (model_name, name, meta.image_indices[0], selected_layers[-1][1])
                if probe not in metric:
                    warnings.warn("cound not find: " + str(probe))
                    continue

                metric_per_layer = []
                lower_conf = []
                upper_conf = []
                for (_layer_name, layer_idx) in selected_layers[::-1]:
                    vals = np.array(
                        [metric[model_name, name, img_idx, layer_idx] for img_idx in meta.image_indices]
                    )
                    metric_per_layer.append(vals)

                    if confidence_intervals:
                        # Bootstrap the same statistic that is plotted.
                        ridx = np.random.choice(len(vals), (n_bootstrap, len(vals)), replace=True)
                        stats = reduce_fn(vals[ridx], axis=1)
                        lower_conf.append(np.percentile(stats, 100 - confidence_percentile))
                        upper_conf.append(np.percentile(stats, confidence_percentile))

                ssims_reduced = reduce_fn(np.array(metric_per_layer), axis=1)

                ticks = np.arange(len(ssims_reduced))
                ax.plot(ticks, ssims_reduced[::-1], label=name, **mpl_styles[name])
                if confidence_intervals:
                    ax.fill_between(ticks, lower_conf[::-1], upper_conf[::-1],
                                    color=mpl_styles[name]['color'],
                                    alpha=0.25)

            xlabels = [layer_name for layer_name, _idx in selected_layers]
            ax.set_ylim([0, 1.05])
            ax.set_xticks(np.arange(len(xlabels)))
            ax.set_xticklabels(xlabels, rotation=90)
            ax.set_ylabel(ylabel)
        axes[-1].legend(bbox_to_anchor=(1.0, 1.00), labelspacing=0.33)
        figpath = 'figures/sanity_checks/ssim-random-weights-{}.pdf'.format(model_name)
        plt.savefig(figpath, bbox_inches='tight', pad_inches=0)
        plt.show()
        display(IFrame(figpath, 800, 500))

In [None]:
def plot_heatmap_grid(heatmaps, cols, row_labels=(), column_labels=(),
                      fig_path=None, figsize=None, labelpad=45):
    """Plot heatmaps row-major on a grid with ``cols`` columns.

    Args:
        heatmaps: iterable of arrays shown with ``imshow`` using the
            'seismic' colormap on a fixed [-1, 1] range, so signed heatmaps
            should already be normalised to that range.
        cols: number of columns. The number of rows is
            ``ceil(len(heatmaps) / cols)``; unused cells of an incomplete
            last row are hidden.
        row_labels: labels next to the first column (extra ones ignored).
        column_labels: titles of the first row (extra ones ignored).
        fig_path: if given, the figure is saved to this path.
        figsize: figure size in inches; defaults to ``(cols, rows)``.
        labelpad: padding of the horizontal row labels.

    Note: sets ``mpl.rcParams['font.family'] = 'serif'`` globally.
    """
    mpl.rcParams['font.family'] = 'serif'
    heatmaps = list(heatmaps)
    # Ceil division so an incomplete last row is drawn instead of being
    # dropped by the zip below; at least one row keeps plt.subplots valid.
    rows = max(1, -(-len(heatmaps) // cols))

    if figsize is None:
        figsize = (cols, rows)
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    fontsize = 9
    plt.subplots_adjust(wspace=0.05, hspace=0.05, top=1, bottom=0, left=0, right=1)
    for label, ax in zip(row_labels, axes[:, 0]):
        ax.set_ylabel(label, fontsize=fontsize + 1, labelpad=labelpad, rotation=0)

    print(axes.shape, column_labels, row_labels)
    for label, ax in zip(column_labels, axes[0, :]):
        ax.set_title(label, fontsize=fontsize)

    flat_axes = axes.flatten()
    for ax, heatmap in zip(flat_axes, heatmaps):
        ax.imshow(heatmap, cmap='seismic', vmin=-1, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
    # Hide the unused cells of an incomplete last row.
    for ax in flat_axes[len(heatmaps):]:
        ax.axis('off')

    if fig_path is not None:
        plt.savefig(fig_path, bbox_inches='tight', pad_inches=0, dpi=120)

In [None]:
# def normalize_neg(x):
#     vmax = np.percentile(x, 99)
#     vmin = np.percentile(x, 1)
#     vmax
#     x_pos = x * (x > 0)
#     x_neg = x * (x < 0)
#     
#     x_pos = x_pos / vmax
#     x_neg = - x_neg / vmin
#     return np.clip(x_pos + x_neg, -1, 1)


def load_examplary_heatmap():
    """Collect the cached heatmaps of the first evaluation image per model.

    For every model in ``model_names`` and every entry of ``analysers`` whose
    pickle exists in ``hmap_output_dir``, the returned OrderedDict maps
    ``(model_name, attr_name, key)`` to

    * ``'image'``: ``norm_image`` of the first evaluation image,
    * ``'original'``: the raw heatmap of the trained model,
    * ``<layer name>``: the raw heatmap after randomizing up to that layer,
      for each layer of ``meta.randomization_layers`` (in reverse order).

    Heatmaps are stored without post-processing. A method's entries are only
    added when all of them are available, so every (model, method) forms a
    complete row; incomplete ones are reported with a warning.
    """
    hmap_loaded = OrderedDict()
    for model_name in model_names:
        meta = metas[model_name]
        rnd_layers = meta.randomization_layers[::-1]
        for (attr_name, _innv, _post, excludes, _kwargs) in tqdm_notebook(analysers):
            if 'exclude_' + model_name in excludes:
                print(attr_name)
                continue
            outpath = hmap_output_dir + '/heatmap_{}_{}.pickle'.format(model_name, attr_name)
            if not os.path.exists(outpath):
                warnings.warn("not found: " + outpath)
                continue
            with open(outpath, 'rb') as f:
                hmap_original, hmap_random_weights, _ = pickle.load(f)

            # Gather the row first and commit it only when complete; a partial
            # row would shift every following cell of the row-major grids.
            row = OrderedDict()
            try:
                for img_idx in meta.image_indices[:1]:
                    row['image'] = norm_image(meta.images[0][0][0])
                    row['original'] = hmap_original[model_name, attr_name, img_idx]
                    for layer_name in rnd_layers:
                        layer_idx = meta.names.nice_to_idx(layer_name)
                        row[layer_name] = hmap_random_weights[
                            model_name, attr_name, img_idx, layer_idx]
            except KeyError as e:
                warnings.warn("incomplete heatmaps for {} / {}: missing {}".format(
                    model_name, attr_name, e))
                continue
            for key, hmap in row.items():
                hmap_loaded[model_name, attr_name, key] = hmap
    return hmap_loaded


hmap_examplary = load_examplary_heatmap()

In [None]:
# Output directory for the sanity-check figures.
os.makedirs('figures/sanity_checks', exist_ok=True)

In [None]:
def normalize_visual_equal(x, percentile=99):
    """Scale an attribution map into [-1, 1] for visualization.

    Positive and negative values are divided by the same factor, the larger
    of ``|percentile(x, percentile)|`` and ``|percentile(x, 100 - percentile)|``,
    so their magnitudes stay comparable; values beyond it are clipped. If all
    values share one sign, ``|x|`` is scaled by its ``percentile``-th
    percentile and the sign is restored.

    Args:
        x: numpy array of attribution values (any shape).
        percentile: percentile (0-100) that is mapped to +-1.

    Returns:
        Array of the same shape with values in [-1, 1] (empty input is
        returned unchanged).
    """
    x = np.asarray(x)
    if x.size == 0:
        return x

    if (x > 0).all() or (x < 0).all():
        # Single-signed map: normalise the magnitudes, keep the common sign.
        x_abs = np.abs(x)
        vmax = np.percentile(x_abs, percentile)
        return np.clip(np.sign(x.mean()) * x_abs / vmax, -1, 1)

    vmax = np.percentile(x, percentile)
    vmin = np.percentile(x, 100 - percentile)
    # One shared scale for both signs, also when one of the percentiles is 0.
    absmax = max(abs(vmax), abs(vmin))
    if absmax > 0:
        x = x / absmax
    return np.clip(x, -1, 1)


def _sanity_percentile(attr_name):
    """Normalization percentile: 100 for RectGrad, 99.5 for every other method."""
    return 100 if attr_name == "RectGrad" else 99.5


def postprocess_sanity(attr_name, hmap, visual=True):
    """Apply the method's heatmap post-processing, then ``normalize_sanity``.

    ``normalize_sanity`` is not defined in this notebook (presumably it comes
    from ``when_explanations_lie``). ``visual`` is unused and kept only for
    interface compatibility.
    """
    hmap_sum = hmap_postprocessing[attr_name](hmap)
    return normalize_sanity(hmap_sum, _sanity_percentile(attr_name))


def postprocess_visual(attr_name, hmap, visual=True):
    """Apply the method's heatmap post-processing, then ``normalize_visual_equal``.

    ``visual`` is unused and kept only for interface compatibility.
    """
    hmap_sum = hmap_postprocessing[attr_name](hmap)
    return normalize_visual_equal(hmap_sum, _sanity_percentile(attr_name))

In [None]:
# Display the names of all configured attribution methods.
attr_names

In [None]:
# Figure 1: VGG-16 heatmaps of selected methods -- input image, trained model
# and cascading randomization down to the listed layers.
selected_analysers = [
    'GuidedBP',
    'RectGrad',
    'DTD',
    'LRP $\\alpha1\\beta0$',
    'LRP $\\alpha2\\beta1$',
    'PatternAttr.',
    'DeepLIFT Resc.',
    'Gradient',
]

selected_layers = [
    "image",
    "original",
    "fc3",
    "conv5_3",
    "conv4_1",
    "conv2_1",
    "conv1_1",
]

# Row-major order: one row per method, one column per layer.
selected_hmaps = []
for attr_name, layer_name in itertools.product(selected_analysers, selected_layers):
    hmap = hmap_examplary['vgg16', attr_name, layer_name]
    if layer_name != 'image':
        hmap = postprocess_visual(attr_name, hmap)
    selected_hmaps.append(hmap)

plot_heatmap_grid(
    selected_hmaps,
    len(selected_layers),
    row_labels=selected_analysers,
    column_labels=selected_layers,
    figsize=(3.9, 0.55*len(selected_analysers) + 0.1),
    fig_path='figures/sanity_checks/heatmap_grid_figure1.pdf',
    labelpad=45,
)

In [None]:
# Show the saved Figure 1 PDF inline.
display(IFrame('figures/sanity_checks/heatmap_grid_figure1.pdf', width=1000, height=600))

In [None]:
# Display the randomization layers of ResNet-50.
metas['resnet50'].randomization_layers

In [None]:
# LRP-CMP grids, one per model: rows are the two LRP-CMP variants, columns
# the input image, the trained model and selected randomization depths.
selected_analysers = [
    'LRP CMP $\\alpha1\\beta0$',
    'LRP CMP $\\alpha2\\beta1$'
]

grid_layers_per_model = {
    'vgg16': [
        "image",
        "original",
        "fc3",
        "conv5_3",
        "conv5_1",
        "conv4_1",
        "conv1_1",
    ],
    'resnet50': [
        "image",
        "original",
        "dense",
        "block5_2",
        "block4_1",
        "block3_3",
        "conv1",
    ],
}

for model_name in model_names:
    selected_layers = grid_layers_per_model[model_name]
    selected_hmaps = []
    for attr_name, layer_name in itertools.product(selected_analysers, selected_layers):
        hmap = hmap_examplary[model_name, attr_name, layer_name]
        if layer_name != 'image':
            hmap = postprocess_visual(attr_name, hmap)
        selected_hmaps.append(hmap)

    outname = 'figures/sanity_checks/heatmap_grid_{}_lrp_cmp.pdf'.format(model_name)
    print(len(selected_hmaps))
    plot_heatmap_grid(
        selected_hmaps,
        len(selected_layers),
        row_labels=selected_analysers,
        column_labels=selected_layers,
        figsize=(3.99, 1.1),
        fig_path=outname,
    )
    display(IFrame(outname, width=1000, height=600))

In [None]:
# One grid per model (in reverse order of `model_names`) with every cached
# method and layer. GuidedBP, Deconv, DeepLIFT Abla. and PatternNet are
# rendered with `image(...)`; all other methods use postprocess_visual.
for selected_model_name in reversed(model_names):
    print(selected_model_name)
    attr_for_model = OrderedDict()
    layer_names = OrderedDict()
    hmap_plot = []
    for (model_name, attr_name, layer_name), hmap in hmap_examplary.items():
        if model_name != selected_model_name:
            continue
        attr_for_model[attr_name] = attr_name
        layer_names[layer_name] = layer_name
        if layer_name != 'image':
            as_image = attr_name in ('GuidedBP', "Deconv", "DeepLIFT Abla.", "PatternNet")
            hmap = image(hmap) if as_image else postprocess_visual(attr_name, hmap)
        hmap_plot.append(hmap)

    n_cols = len(layer_names)
    n_rows = len(attr_for_model)
    plot_heatmap_grid(
        hmap_plot,
        n_cols,
        row_labels=attr_for_model.keys(),
        column_labels=layer_names.keys(),
        figsize=(0.6 * n_cols, 0.6 * n_rows),
        fig_path='figures/sanity_checks/heatmap_image_grid_{}.pdf'.format(selected_model_name),
    )

In [None]:
# Show the saved image-grid PDFs of both models inline.
display(IFrame('figures/sanity_checks/heatmap_image_grid_vgg16.pdf', width=1000, height=600))
display(IFrame('figures/sanity_checks/heatmap_image_grid_resnet50.pdf', width=1000, height=600))

In [None]:
# One grid per model (in reverse order of `model_names`) with every cached
# method and layer; all heatmaps are normalised with postprocess_visual.
for selected_model_name in reversed(model_names):
    print(selected_model_name)
    attr_for_model = OrderedDict()
    layer_names = OrderedDict()
    hmap_plot = []
    for (model_name, attr_name, layer_name), hmap in hmap_examplary.items():
        if model_name != selected_model_name:
            continue
        attr_for_model[attr_name] = attr_name
        layer_names[layer_name] = layer_name
        hmap = hmap if layer_name == 'image' else postprocess_visual(attr_name, hmap)
        hmap_plot.append(hmap)

    n_cols = len(layer_names)
    n_rows = len(attr_for_model)
    plot_heatmap_grid(
        hmap_plot,
        n_cols,
        row_labels=attr_for_model.keys(),
        column_labels=layer_names.keys(),
        figsize=(0.6 * n_cols, 0.6 * n_rows),
        fig_path='figures/sanity_checks/heatmap_visual_grid_{}.pdf'.format(selected_model_name),
    )

In [None]:
# Show the saved visual-grid PDFs of both models inline.
display(IFrame('figures/sanity_checks/heatmap_visual_grid_vgg16.pdf', width=1000, height=600))
display(IFrame('figures/sanity_checks/heatmap_visual_grid_resnet50.pdf', width=1000, height=600))