# Load pre-trained models

In [None]:
import matplotlib
matplotlib.use('nbagg')
import matplotlib.pylab as plt

import numpy as np
import torch
import torchvision
import torchvision.models as models
from saliency_utilities import compute_saliency_for_methods
from methods.method_research_utilities import load_imagenet_saliency_data, load_imagenet_saliency_metric_eval_data, load_bgc_imagenet_saliency_data, load_cifar10_saliency_data
import copy
from datetime import datetime
import time
from methods.method_research_utilities import load_cifar10_saliency_data, post_process_maps
%config InlineBackend.figure_formats = ['svg']
%matplotlib inline
import os
from methods.saliency_utilities import plot_maps_method_vertical, plot_maps_method_horizontal
from methods.captum_post_process import _normalize_image_attr

import captum
from captum.attr import IntegratedGradients, Occlusion, LayerGradCam, LayerAttribution
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap


from methods.saliency_utilities import plot_maps_method_vertical, plot_maps_method_horizontal, plot_model_randomization_maps

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from matplotlib import gridspec
import seaborn as sns

# Plot styling: seaborn "whitegrid" look but with the grid lines switched off.
sns.set_style('whitegrid', {'axes.grid': False})

# Font-size tiers (small / medium / big). BSIZE is not used by the cells shown here.
SSIZE, MSIZE, BSIZE = 10, 12, 14

# Same rcParams the individual plt.rc(...) calls would set, applied in one go.
plt.rcParams.update({
    'font.size': SSIZE,
    'axes.titlesize': MSIZE,
    'axes.labelsize': MSIZE,
    'xtick.labelsize': MSIZE,
    'ytick.labelsize': MSIZE,
    'legend.fontsize': MSIZE,
    'figure.titlesize': MSIZE,
    'font.family': "sans-serif",
})


# Load, preprocess and prepare data for saliency computation

In [None]:
# Load the ImageNet saliency-evaluation data. The commented-out
# load_bgc_imagenet_saliency_data() is an alternative loader.
# The return values are project-defined; from their use below, FinalData is an
# indexable collection of image tensors and dataLoaderSal is iterable in batches
# (TODO confirm against methods.method_research_utilities).
FinalData, Labels, dataLoaderSal, categories =  load_imagenet_saliency_data() #load_bgc_imagenet_saliency_data()

# Peek at one batch to sanity-check the tensor shapes.
x_batch, y_batch = next(iter(dataLoaderSal))
print(x_batch.shape, y_batch.shape)

print(FinalData.shape)

# NOTE(review): duplicates the `device` definition from the setup cell.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Visualize the data

In [None]:
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized channel-first image tensor with matplotlib.

    The tensor is un-normalized per channel as ``std * x + mean`` and clipped
    to [0, 1] before plotting.

    Args:
        inp: image tensor laid out as (C, H, W). It may live on any device and
            may require grad.
        title: optional plot title.
        mean: per-channel mean used when the image was normalized
            (defaults to the ImageNet statistics used elsewhere in this notebook).
        std: per-channel std used when the image was normalized
            (defaults to the ImageNet statistics used elsewhere in this notebook).
    """
    # detach()/cpu() so tensors on the GPU or tracking gradients can be shown;
    # a plain .numpy() raises for both.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get one batch from the saliency data loader (`classes` is not used below)
inputs, classes = next(iter(dataLoaderSal))

# Tile the batch into a single image grid
out = torchvision.utils.make_grid(inputs)

imshow(out, title=None)

In [None]:
# Show the first 10 entries of FinalData in a 2x5 grid, un-normalized with the
# ImageNet mean/std and clipped to [0, 1].
fig = plt.figure(figsize=(9,4))

for i in range(10):
    plt.subplot(2,5,i+1)

    # (C, H, W) tensor -> (H, W, C) array, then undo the normalization.
    inp = FinalData[i].numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    
    # NOTE(review): matplotlib ignores `cmap` for 3-channel RGB input.
    plt.imshow(inp, interpolation=None, aspect='auto', cmap=plt.cm.inferno) 
    plt.xticks([])
    plt.yticks([])

plt.show()

# Post-process Maps for visualization

In [None]:
# Model identifier used to build the saliency file-name prefix below.
name = 'inception_3'

# Short method labels. NOTE(review): not used by the cells shown here; the
# row titles come from post_process_maps instead.
method_titles = ["GD", "ONLY.IG", "ONLY.M", "GDAsc.IG","GDAsc.M", "M.GDAsc.IG", "M.GDAsc.M","Wt.P.IG", "Wt.P.M", "IG", "CaptIG", "L.IG"]

# Dataset index -> dataset key; Dataset[data] is passed as `data_name` to the plotting helpers.
Dataset = {0: 'mnist', 1: 'fmnist', 2: 'cifar10', 3: 'imgnet'}
data = 3  # index into Dataset -> 'imgnet'
fname_common = "method_research_bgc_"+name
# visualize_maps(1, fname_common, 2)
# Load and post-process the saved saliency maps for the listed methods, for
# seed 0 only. list_of_saliency_dicts[0] is used by the next cells; the exact
# structure is defined in methods.method_research_utilities (TODO confirm).
list_of_saliency_dicts, titles = post_process_maps(data, fname_common, method_list=["GD", "ONLY.IG", "ONLY.M", "GDAsc", "M.GDAsc", "Wt.P"],\
                                                      random_seeds=list(range(0, 1)), viz=False, scale_flag=False)

### Normalize Saliency Maps

In [None]:
def normalize_saliency_maps(saliency_method_dict, sign='absolute_value'):
    """Normalize every saliency map of every method for display.

    Args:
        saliency_method_dict: mapping of method key -> iterable of per-image
            saliency maps laid out channel-first (C, H, W).
        sign: sign mode forwarded to ``_normalize_image_attr``
            (this notebook uses 'absolute_value').

    Returns:
        A new dict with the same keys. Each value holds that method's
        normalized maps, stacked along a new leading axis.
    """
    def _normalize_one(sal_image):
        # The helper is called on channel-last (H, W, C) attributions.
        return _normalize_image_attr(np.transpose(sal_image, (1, 2, 0)), sign)

    return {
        method_name: np.stack([_normalize_one(img) for img in images], axis=0)
        for method_name, images in saliency_method_dict.items()
    }

## Un-normalize images back to the [0, 1] range for plotting

In [None]:
def normalize_image_to_plot(img, data_name):
    """Undo dataset normalization so an image can be shown with ``imshow``.

    Applies ``std * img + mean`` per channel and clips the result to [0, 1].

    Args:
        img: channel-last (H, W, C) array that was normalized with the
            dataset's mean/std. The last axis must broadcast against the
            3-element mean/std.
        data_name: dataset key, either 'imgnet' or 'cifar10'.

    Returns:
        The un-normalized image as a float array clipped to [0, 1].

    Raises:
        ValueError: if ``data_name`` has no known normalization statistics.
            Previously this surfaced as an UnboundLocalError.
    """
    if data_name == 'imgnet':
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
    elif data_name == 'cifar10':
        # CIFAR-10 statistics are given on the 0-255 scale; rescale to 0-1.
        mean = np.array([125.3, 123.0, 113.9]) / 255.0
        std = np.array([63.0, 62.1, 66.7]) / 255.0
    else:
        raise ValueError(
            f"No normalization statistics for data_name={data_name!r}; "
            "expected 'imgnet' or 'cifar10'."
        )

    return np.clip(std * img + mean, 0, 1)

### Plotting test code

In [None]:
def plot_maps_method_vertical_testcode(
        images,
        name,
        data_name,
        saliency_dict,
        method_captions,
        p,
        range_to_display = np.asarray(range(95, 105, 1)),
        fig_size = (10, 5),
        cm=None,
        interp = 'none',
        vis_min = 0.0,
        vis_sign='positive',
        save = False,
        fname = "my_file"
    ):
    """Plot input images (top row) above one row of saliency maps per method.

    Args:
        images: indexable collection of channel-first image tensors
            (``.numpy()`` is called on each selected element).
        name: unused; kept for call compatibility.
        data_name: dataset key forwarded to ``normalize_image_to_plot``.
        saliency_dict: mapping 0..M-1 -> already-normalized saliency maps,
            indexable by sample id.
        method_captions: M + 1 row labels; the first one labels the input row.
        p: tag appended to the saved file name.
        range_to_display: sample ids to plot, one column each.
        fig_size: matplotlib figure size.
        cm: colormap (name, Colormap object or None) for the saliency maps.
        interp: ``imshow`` interpolation mode.
        vis_min: lower colour limit; the upper limit is fixed at 1.0.
        vis_sign: unused; kept for call compatibility.
        save: if True, write .svg and .pdf copies under ./Plots/Real/.
        fname: file-name prefix for saved plots.
    """
    n_methods = len(saliency_dict)
    nrows = n_methods + 1
    # One column per requested sample. This was hard-coded to 10, which raised
    # IndexError for shorter ranges and silently dropped extra samples.
    ncols = len(range_to_display)

    fig = plt.figure(figsize=fig_size)
    gs = gridspec.GridSpec(nrows, ncols, wspace=0.0, hspace=0.0)

    for col, sample_id in enumerate(range_to_display):
        # Top row: the un-normalized input image.
        img = images[sample_id].numpy()
        if np.squeeze(img).ndim == 3:
            img = np.transpose(img, (1, 2, 0))
        img = normalize_image_to_plot(img, data_name)

        ax = fig.add_subplot(gs[0, col])
        ax.imshow(img)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        if col == 0:
            ax.set_ylabel(method_captions[0], fontsize=9)

        # One row per saliency method below the input.
        for method_id in range(n_methods):
            ax = fig.add_subplot(gs[method_id + 1, col])
            ax.imshow(saliency_dict[method_id][sample_id], interpolation=interp,
                      vmin=vis_min, vmax=1.0, cmap=cm)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            if col == 0:
                ax.set_ylabel(method_captions[method_id + 1], fontsize=9)

    if save:
        out_dir = './Plots/Real/'
        os.makedirs(out_dir, exist_ok=True)
        # cm may be a colormap name, a Colormap object, or None; plain string
        # concatenation raised TypeError for the latter two.
        cm_label = cm if isinstance(cm, str) else getattr(cm, 'name', str(cm))
        path = os.path.join(out_dir, fname + '_' + str(interp) + '_' + cm_label + '_' + str(p))
        fig.savefig(path + '.svg', transparent=True, bbox_inches='tight', pad_inches=0)
        fig.savefig(path + '.pdf', format='pdf', dpi=300)
        print('Plots Saved...', path)

    plt.show()
    plt.close(fig)


# Do visualization

In [None]:
# Load a pre-trained Inception-v3 in eval mode.
# NOTE(review): `model` is not used by the plotting code in this cell.
model = models.inception_v3(pretrained=True)
model.to(device)
model.eval()

# Add the edge detector result

# Load the Sobel edge maps computed for the checkerboard-background-change images.
cbg_cb_sobel_edges = np.load("./models and saliencies/saliency/method_research_imgNet_cbg_cb_sobel_edges.npy")

# sobel_edges = np.load("./models and saliencies/saliency/method_research_imgNet_sobel_edges.npy")

# sobel_edges = np.moveaxis(sobel_edges, 3, 1)
# Move axis 3 to position 1. Presumably NHWC -> NCHW, matching the
# channel-first layout normalize_saliency_maps expects (TODO confirm against
# the saved array).
cbg_cb_sobel_edges = np.moveaxis(cbg_cb_sobel_edges, 3, 1)

print(np.min(cbg_cb_sobel_edges[0]), np.max(cbg_cb_sobel_edges[0]))

# print(cbg_cb_sobel_edges.shape)
# Deep-copy so the assignment below does not mutate list_of_saliency_dicts[0].
saliency_dict = copy.deepcopy(list_of_saliency_dicts[0])
# NOTE(review): this overwrites whatever method key 9 held with the edge maps,
# while title_set *appends* "EDGE.D"; confirm the titles stay aligned.
saliency_dict[9] = cbg_cb_sobel_edges
title_set = titles + ["EDGE.D"]

print(title_set)

# Read the imagenet categories (one class name per line).
# NOTE(review): imgnet_categories is not used by the cells shown here.
with open("imagenet_classes.txt", "r") as f:
    imgnet_categories = [s.strip() for s in f.readlines()]

name = 'inception_3'
# title_set = titles
# "BlWhRd"
# Sign mode -> colormap. NOTE(review): `colors` is not used by the cells shown here.
colors = {'positive': 'Reds', 'absolute_value': 'bwr', 'all': LinearSegmentedColormap.from_list("RdWhGn", ["red", "white","green"])}
sign = 'absolute_value'
# print(np.min(saliency_dict[0]), np.max(saliency_dict[0]))

# Normalized maps for every method (including the edge maps), consumed by the next cells.
all_method_saliency = normalize_saliency_maps(saliency_dict, sign=sign)


In [None]:
def get_only_required_method_saliency(all_saliency_dict, method_list):
    """Select a subset of methods and renumber them 0..K-1.

    Args:
        all_saliency_dict: mapping of method key -> saliency maps.
        method_list: keys to keep.

    Returns:
        A dict mapping consecutive integers (starting at 0) to the kept values.
        The order follows ``all_saliency_dict``'s iteration order, not
        ``method_list``'s.
    """
    kept = [maps for key, maps in all_saliency_dict.items() if key in method_list]
    return dict(enumerate(kept))
            

In [None]:
# Row labels: 'Input' for the image row, then one per selected method.
titles = ['Input', 'Grad', "Integrated\nGradients", "GGIG", "Edge\nDetector"]
# Keep method keys 0, 1, 6 and 9 (key 9 holds the edge maps assigned in the
# previous cell), renumbered 0..3 in dict order.
required_saliency_dict = get_only_required_method_saliency(all_method_saliency, [0, 1, 6, 9])

# Colormap -> lower colour limit: diverging maps use [-1, 1], sequential maps [0, 1].
cm_vs_min = {'bwr': -1.0, 'bwr_r': -1.0, 'coolwarm': -1.0, 'Reds': 0.0, 'gray': 0.0, 'inferno': 0.0, 'afmhot': 0.0}

# Render and save one figure of samples 0..9 per colormap.
for cm_value, min_value in cm_vs_min.items():
    
    plot_maps_method_vertical_testcode(FinalData, name, Dataset[data], 
                                       required_saliency_dict, titles, 0, 
                                       range_to_display = np.asarray(range(0, 10, 1)), 
                                       fig_size=(10,5), 
                                       cm=cm_value, interp='none',
                                       vis_min=min_value, vis_sign=sign, 
                                       save=True, fname='inception_bgc_interpolation_corrected')


# Post Process Model Randomization Test

In [None]:
name = 'inception_3'

# Short method labels. NOTE(review): not used by the cells shown here.
method_titles = ["GD", "ONLY.IG", "ONLY.M", "GDAsc.IG","GDAsc.M", "M.GDAsc.IG", "M.GDAsc.M","Wt.P.IG", "Wt.P.M", "IG", "CaptIG", "L.IG", "Inp X GD", "GBP", "LRP"]

Dataset = {0: 'mnist', 1: 'fmnist', 2: 'cifar10', 3: 'imgnet'}
data = 3  # index into Dataset -> 'imgnet'
# Model states for the cascading-randomization test, in plotting order.
# 'normal' is titled 'Original Explanation' and 'fc' 'Logits' by the plotting
# helpers. The rest are Inception-v3 module names, presumably ordered from the
# output end of the network towards the input, each file holding maps from a
# model randomized up to that layer (TODO confirm).
layer_randomization_order = ['normal',
                                 'fc',
                                 'Mixed_7c',
                                 'Mixed_7b',
                                 'Mixed_7a',
                                 'Mixed_6e',
                                 'Mixed_6d',
                                 'Mixed_6c',
                                 'Mixed_6b',
                                 'Mixed_6a',
                                 'Mixed_5d',
                                 'Mixed_5c',
                                 'Mixed_5b',
                                 'Conv2d_4a_3x3',
                                 'Conv2d_3b_1x1',
                                 'Conv2d_2b_3x3',
                                 'Conv2d_2a_3x3',
                                 'Conv2d_1a_3x3']

# layer name -> saliency dict (seed 0) loaded from that layer's saliency files.
all_layer_all_methods_saliency = {}

for i, layer_name in enumerate(layer_randomization_order):
    fname_common = 'method_research_inception_revised_randomization_test_'+layer_name
    # NOTE(review): `titles` is rebound on every iteration; only the last layer's value survives.
    list_of_saliency_dicts, titles = post_process_maps(data, fname_common, method_list=["GD", "ONLY.IG", "M.GDAsc"],\
                                                          random_seeds=list(range(0, 1)), viz=False, scale_flag=False)
    all_layer_all_methods_saliency[layer_name] = list_of_saliency_dicts[0]

### Plot multiple samples from randomized model for each saliency method

In [None]:
def plot_multiple_per_method_randomization_test(
        images,
        name,
        data_name,
        saliency_dict,
        method_id_no,
        sample_list,
        p,
        fig_size = (10, 5),
        cm=None,
        interp = 'none',
        vis_min = 0.0,
        vis_sign='positive',
        save = False,
        fname = 'temp'
    ):
    """Plot cascading-randomization results for one method over several samples.

    Each row is one sample: the input image, followed by the method's saliency
    map for every model state in ``saliency_dict``.

    Args:
        images: indexable collection of channel-first image tensors.
        name: unused; kept for call compatibility.
        data_name: dataset key forwarded to ``normalize_image_to_plot``.
        saliency_dict: ordered mapping layer name -> {method id -> raw
            per-image saliency maps}. Its keys become the column titles.
        method_id_no: which method's maps to plot.
        sample_list: sample ids to plot, one row each.
        p: tag appended to the saved file name.
        fig_size: matplotlib figure size.
        cm: colormap (name, Colormap object or None) for the saliency maps.
        interp: ``imshow`` interpolation mode.
        vis_min: lower colour limit; the upper limit is fixed at 1.0.
        vis_sign: sign mode forwarded to ``normalize_saliency_maps``.
        save: if True, write an .svg under ./Plots/Real/.
        fname: file-name prefix for saved plots.
    """
    def _column_title(layer_name):
        # Friendly headers for the unrandomized model and the logit layer.
        if "normal" in layer_name:
            return 'Original\nExplanation'
        if "fc" in layer_name:
            return "Logits"
        return layer_name

    # Normalize only the maps that are displayed, once per layer. The previous
    # version re-normalized every method's full map set for every
    # (sample, layer) pair. Normalization is per image, so results are identical.
    columns = [
        (layer_name,
         normalize_saliency_maps(
             {method_id_no: [layer_maps[method_id_no][s] for s in sample_list]},
             sign=vis_sign,
         )[method_id_no])
        for layer_name, layer_maps in saliency_dict.items()
    ]

    nrows = len(sample_list)
    ncols = len(columns) + 1

    fig = plt.figure(figsize=fig_size)
    gs = gridspec.GridSpec(nrows, ncols, wspace=0.0, hspace=0.0)

    for row, sample_no in enumerate(sample_list):
        # First column: the un-normalized input image.
        img = images[sample_no].numpy()
        if np.squeeze(img).ndim == 3:
            img = np.transpose(img, (1, 2, 0))
        img = normalize_image_to_plot(img, data_name)

        ax = fig.add_subplot(gs[row, 0])
        ax.imshow(img)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        if row == 0:
            ax.set_title("Input", fontsize=9, rotation=90, pad=12)

        # Remaining columns: this sample's map for each model state.
        for col, (layer_name, maps) in enumerate(columns, start=1):
            ax = fig.add_subplot(gs[row, col])
            ax.imshow(maps[row], interpolation=interp,
                      vmin=vis_min, vmax=1.0, cmap=cm)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            if row == 0:
                ax.set_title(_column_title(layer_name), fontsize=9, rotation=90, pad=12)

    if save:
        out_dir = './Plots/Real/'
        os.makedirs(out_dir, exist_ok=True)
        # cm may be a colormap name, a Colormap object, or None.
        cm_label = cm if isinstance(cm, str) else getattr(cm, 'name', str(cm))
        path = os.path.join(out_dir, fname + '_' + cm_label + '_' + str(p))
        fig.savefig(path + '.svg', transparent=True, bbox_inches='tight', pad_inches=0)
        print('Plots Saved...', path)

    plt.show()
    # Always release the figure; it previously leaked when save=False.
    plt.close(fig)
     

### Plot comparative saliency maps for model randomization (multiple methods - one example)

In [None]:

from methods.captum_post_process import _normalize_image_attr

def plot_model_randomization_maps_test(
        images,
        name,
        data_name,
        saliency_dict,
        method_captions,
        sample_no,
        p,
        fig_size = (10,4),
        cm=None,
        interp = 'none',
        vis_min = 0.0,
        vis_sign='positive',
        save = False,
        fname='temp'
    ):
    """Plot cascading-randomization results for several methods on one sample.

    Each row is one method (methods captioned 'GGIG_IG' are skipped): the input
    image, followed by that method's saliency map for ``sample_no`` under every
    model state in ``saliency_dict``.

    Args:
        images: indexable collection of channel-first image tensors.
        name: unused; kept for call compatibility.
        data_name: dataset key forwarded to ``normalize_image_to_plot``.
        saliency_dict: ordered mapping layer name -> {method id 0..M-1 -> raw
            per-image saliency maps}. Its keys become the column titles.
        method_captions: one row label per method id; 'GGIG_IG' hides a method.
        sample_no: which sample to plot.
        p: tag appended to the saved file name.
        fig_size: matplotlib figure size.
        cm: colormap (name, Colormap object or None) for the saliency maps.
        interp: ``imshow`` interpolation mode.
        vis_min: lower colour limit; the upper limit is fixed at 1.0.
        vis_sign: sign mode forwarded to ``normalize_saliency_maps``.
        save: if True, write an .svg under ./Plots/Real/.
        fname: file-name prefix for saved plots.
    """
    def _column_title(layer_name):
        # Friendly headers for the unrandomized model and the logit layer.
        if "normal" in layer_name:
            return 'Original\nExplanation'
        if "fc" in layer_name:
            return "Logits"
        return layer_name

    layers = list(saliency_dict.keys())
    no_of_saliency_methods = len(saliency_dict[layers[0]])

    # Rows are the methods actually drawn. The old code assumed exactly one
    # 'GGIG_IG' caption (nrows = methods - 1) and indexed past the grid otherwise.
    shown_methods = [m for m in range(no_of_saliency_methods)
                     if method_captions[m] != 'GGIG_IG']

    # Normalize only this sample's maps for the shown methods, once per layer.
    # Normalization is per image, so this matches normalizing everything.
    columns = [
        (layer_name,
         normalize_saliency_maps(
             {m: [layer_maps[m][sample_no]] for m in shown_methods},
             sign=vis_sign,
         ))
        for layer_name, layer_maps in saliency_dict.items()
    ]

    img = images[sample_no].numpy()
    if np.squeeze(img).ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    img = normalize_image_to_plot(img, data_name)

    nrows = len(shown_methods)
    ncols = len(layers) + 1

    fig = plt.figure(figsize=fig_size)
    gs = gridspec.GridSpec(nrows, ncols, wspace=0.0, hspace=0.0)

    for row, method_id in enumerate(shown_methods):
        # First column: the input image, labelled with the method name.
        ax = fig.add_subplot(gs[row, 0])
        ax.imshow(img)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_ylabel(method_captions[method_id], fontsize=9, rotation='horizontal', ha='right')
        if row == 0:
            ax.set_title("Input", fontsize=9, rotation=90, pad=12)

        # Remaining columns: this method's map under each model state.
        for col, (layer_name, normalized) in enumerate(columns, start=1):
            ax = fig.add_subplot(gs[row, col])
            ax.imshow(normalized[method_id][0], interpolation=interp,
                      vmin=vis_min, vmax=1.0, cmap=cm)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            if row == 0:
                ax.set_title(_column_title(layer_name), fontsize=9, rotation=90, pad=12)

    if save:
        out_dir = './Plots/Real/'
        os.makedirs(out_dir, exist_ok=True)
        # cm may be a colormap name, a Colormap object, or None.
        cm_label = cm if isinstance(cm, str) else getattr(cm, 'name', str(cm))
        path = os.path.join(out_dir, fname + '_' + str(sample_no) + '_' + cm_label + '_' + str(p))
        fig.savefig(path + '.svg', transparent=True, bbox_inches='tight', pad_inches=0)
        print('Plots Saved...', path)

    plt.show()
    # Always release the figure; it previously leaked when save=False.
    plt.close(fig)
     

# Do visualization (Model Randomization Test)

In [None]:

# Load a pre-trained Inception-v3 in eval mode.
# NOTE(review): `model` is not used by the plotting below.
model = models.inception_v3(pretrained=True)
model.to(device)
model.eval()

title_set = titles

print(title_set)
layer='Mixed_5d'
name = 'revised_rand_test'

# Row captions, one per method id in each layer's saliency dict. The entry
# captioned exactly 'GGIG_IG' is skipped by plot_model_randomization_maps_test.
# (Fixed the "Gemometrically" typo in the figure label.)
desired_title_set = ["Gradients", "Integrated\nGradients", "GGIG_IG", "Geometrically\nGuided IG"]

# "BlWhRd"
colors = {'positive': 'Reds', 'absolute_value': 'inferno', 'all': LinearSegmentedColormap.from_list("RdWhGn", ["red", "white","green"])}
sign = 'absolute_value'

print(list(all_layer_all_methods_saliency.keys()))
print(len(list(all_layer_all_methods_saliency.keys())))

# Colormap -> lower colour limit: diverging maps use [-1, 1], sequential maps [0, 1].
cm_vs_min = {'bwr': -1.0, 'bwr_r': -1.0, 'coolwarm': -1.0, 'Reds': 0.0, 'gray': 0.0, 'inferno': 0.0, 'afmhot': 0.0}

# One saved figure per (colormap, sample) pair.
for cm_value, min_value in cm_vs_min.items():
    
    for img_id in [0, 2,3, 8, 10, 15]:
        plot_model_randomization_maps_test( 
                                  FinalData, 
                                  name, 
                                  Dataset[data], 
                                  all_layer_all_methods_saliency, 
                                  desired_title_set, 
                                  img_id,
                                  0, 
                                  fig_size = (9,1.5),
                                  cm=cm_value,
                                  interp = 'none',
                                  vis_min=min_value,
                                  vis_sign=sign, 
                                  save = True,
                                  fname="rand_test_interpolation_corrected"
                                )

### Multiple Sample visualization

In [None]:
# Load a pre-trained Inception-v3 in eval mode.
# NOTE(review): `model` is not used by the plotting below.
model = models.inception_v3(pretrained=True)
model.to(device)
model.eval()

title_set = titles

print(title_set)

name = 'revised_rand_test_multiple_samples_for_paper'


# "BlWhRd"
# NOTE(review): `colors` is not used by the plotting call below.
colors = {'positive': 'Reds', 'absolute_value': 'inferno', 'all': LinearSegmentedColormap.from_list("RdWhGn", ["red", "white","green"])}
sign = 'absolute_value'

print(list(all_layer_all_methods_saliency.keys()))
print(len(list(all_layer_all_methods_saliency.keys())))

# Sample ids to plot, one row each.
sample_list = [0, 1, 2, 3, 4, 5, 8, 10, 15, 16]

# Colormap -> lower colour limit: diverging maps use [-1, 1], sequential maps [0, 1].
cm_vs_min = {'bwr': -1.0, 'bwr_r': -1.0, 'coolwarm': -1.0, 'Reds': 0.0, 'gray': 0.0, 'inferno': 0.0, 'afmhot': 0.0}

# One saved figure per colormap. Method id 3 is the method labelled by the
# fourth entry of desired_title_set in the previous cell (TODO confirm ordering).
for cm_value, min_value in cm_vs_min.items():
    plot_multiple_per_method_randomization_test(
            FinalData, 
            name,
            Dataset[data], 
            all_layer_all_methods_saliency, 
            3, 
            sample_list,
            0, 
            fig_size = (11, 6), 
            cm=cm_value, 
            interp = 'none',
            vis_min = min_value, 
            vis_sign=sign, 
            save = True,
            fname="rand_test_multiple_samples_interpolation_corrected"
        )