# Load all the required packages

In [None]:
import pixellib
import cv2
from cv2 import GaussianBlur

import torch
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
import timeit
from scipy import interpolate
from scipy.stats import entropy

from sklearn import metrics
from functools import partial

from PIL import Image
from methods.saliency_recent_real_metrics import add_random_pixels, interpolate_missing_pixels, \
                generate_saliency_focused_images_prev, generate_revised_saliency_focused_images, interpolate_img, calculate_webp_size

from methods.method_research_utilities import load_cifar10_saliency_data, post_process_maps
from methods.method_research_utilities import load_imagenet_saliency_data, load_imagenet_saliency_metric_eval_data, load_bgc_imagenet_saliency_data, load_cifar10_saliency_data, load_imgnet_val_data

from methods.captum_post_process import _normalize_image_attr

import os
import sys
import copy
from datetime import datetime
import time

from matplotlib.colors import LinearSegmentedColormap
from methods.saliency_utilities import plot_maps_method_vertical, plot_maps_method_horizontal

import io, os
import skimage.io
import skimage.filters
from skimage import color

from math import log, e
import seaborn as sns
import pandas as pd

%config InlineBackend.figure_formats = ['svg']
%matplotlib inline

from matplotlib import gridspec
import seaborn as sns

# Some plotting defaults: seaborn white grid without grid lines, small base font,
# medium size for titles, axis labels, tick labels, legends and figure titles.
sns.set_style('whitegrid', {'axes.grid': False})
SSIZE = 10  # small font size (base font)
MSIZE = 12  # medium font size
BSIZE = 14  # big font size (not referenced in this section)
plt.rc('font', size=SSIZE)
for rc_group, rc_key in (('axes', 'titlesize'),
                         ('axes', 'labelsize'),
                         ('xtick', 'labelsize'),
                         ('ytick', 'labelsize'),
                         ('legend', 'fontsize'),
                         ('figure', 'titlesize')):
    plt.rc(rc_group, **{rc_key: MSIZE})
plt.rcParams['font.family'] = "sans-serif"

# First CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Prepare the images for entropy calculation. 

It requires unnormalized, original images. This function resizes images to $224 \times 224$ (resize to 256, then center-crop to 224) and keeps pixel values in the range $[0, 255]$.

In [None]:
def get_unnormalized_images(images, target_labels):
    """Resize/crop images to 224x224 and return them as integer tensors in [0, 255].

    Square ``np.ndarray`` inputs are treated as already-resized HWC arrays and
    are only permuted to CHW. Every other input goes through
    ``Resize(256) -> CenterCrop(224) -> ToTensor()`` and is scaled back to 0-255.

    Args:
        images: iterable of PIL images or square HWC numpy arrays.
        target_labels: labels, one per image, paired with the images in the loader.

    Returns:
        (unnormalized_images, loader): the stacked (N, C, H, W) integer tensor
        and a DataLoader that yields the whole set as one unshuffled batch.
    """
    transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor()])

    unnormalized_images = []
    for image in images:
        if isinstance(image, np.ndarray) and image.shape[0] == image.shape[1]:
            img_tensor = torch.from_numpy(image).permute(2, 0, 1)
        else:
            # ToTensor divides by 255 in float32, and x / 255 * 255 does not always
            # round-trip exactly (it can land just below the integer). Round before
            # the integer cast instead of truncating, which would lose one level.
            img_tensor = torch.round(transform(image) * 255)

        unnormalized_images.append(img_tensor.to(int))

    unnormalized_images = torch.stack(unnormalized_images, dim=0)

    unnormalized_img_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(unnormalized_images, target_labels),
        batch_size=unnormalized_images.shape[0], shuffle=False)
    return unnormalized_images, unnormalized_img_loader
    

# Normalize Images
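- argument images are already resized and within 0-255
- output images are first scaled to [0, 1] by `ToTensor` (for uint8 input) and then z-scored per channel with the ImageNet mean and standard deviation, so the final values are no longer confined to [0, 1]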

In [None]:
def normalize_images(images, target_labels, samples_per_batch=None):
    """Convert images to ImageNet-normalized tensors and wrap them in a DataLoader.

    Each image goes through ``transforms.ToTensor`` followed by per-channel
    normalization with the ImageNet mean/std.

    Args:
        images: iterable of images accepted by ``transforms.ToTensor``.
        target_labels: iterable of label tensors, one per image.
        samples_per_batch: loader batch size; ``None`` puts every sample in one batch.

    Returns:
        (normalized_images, loader): the stacked image tensor and an unshuffled
        DataLoader over (image, label) pairs with int64 labels.
    """
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    pairs = [(to_normalized_tensor(img), lbl) for img, lbl in zip(images, target_labels)]
    stacked_images = torch.stack([img for img, _ in pairs], dim=0)
    stacked_labels = torch.stack([lbl for _, lbl in pairs], dim=0).long()

    if samples_per_batch is None:
        batch = stacked_images.shape[0]
    else:
        batch = samples_per_batch

    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(stacked_images, stacked_labels),
        batch_size=batch, shuffle=False)
    return stacked_images, loader

## Plot Images/Saliency Maps (for one method at a time)

In [None]:
def plot_images_or_maps(images, labels=None, categories=None, nrows=3, ncols=3, samples_to_show=list(range(100)), plot_type='images', save=False):
    """Plot a grid of images or saliency maps.

    Args:
        images: indexable collection of images/maps, each a numpy array or a
            tensor with ``.numpy()``. Arrays whose first dimension is <= 3 are
            treated as channel-first and transposed to HWC for display.
        labels: optional labels aligned with ``images``. When given, each panel
            is titled ``categories[labels[idx].item()]`` for the shown index.
        categories: sequence mapping label ids to names (needed with ``labels``).
        nrows, ncols: grid shape.
        samples_to_show: indices into ``images`` to display, in grid order.
            Panels beyond the number of indices are left blank.
        plot_type: ``'images'`` draws raw images. Anything else draws maps with
            the ``Reds`` colormap on a fixed [0, 1] range. Also used in the file name.
        save: when True, save the figure as a PDF under ``./Plots/Real/``.
    """
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 9),
                             sharex=True, sharey=True, squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i >= len(samples_to_show):
            # More panels than requested samples: leave the rest empty.
            ax.axis('off')
            continue

        sample_idx = samples_to_show[i]
        img = images[sample_idx]
        if not isinstance(img, np.ndarray):
            img = img.numpy()

        if img.shape[0] <= 3:
            img = np.transpose(img, (1, 2, 0))

        if plot_type == 'images':
            ax.imshow(img, interpolation=None, aspect='equal')
        else:
            ax.imshow(img, cmap='Reds', vmin=0, vmax=1)

        ax.set_xticks([])
        ax.set_yticks([])

        if labels is not None:
            # Use the same index as the image so the title matches the panel.
            ax.set_title(categories[labels[sample_idx].item()], fontsize=6, y=0.95)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.2)

    path = os.path.join('./Plots/Real/', "method_research_" + plot_type)
    if save:
        # Matplotlib recommends saving before show().
        fig.savefig(path + '.pdf', format='pdf', dpi=300)
        print('Plots Saved...', path)

    plt.show()
    plt.close(fig)

## Load CAT vs DOG vs BIRD samples        

In [None]:
def load_cat_dog_etc_data():
    """Load three hand-picked photos as an ImageNet-normalized batch.

    Each file is resized to 256, center-cropped to 224, converted to a tensor
    and normalized with the ImageNet mean/std.

    Returns:
        data: stacked tensor of the three images, shape (3, 3, 224, 224).
        labels: int64 tensor of shape (3,); every image gets target 13.
        all_img_loader: DataLoader yielding all images in a single batch.
        dataLoaderSal: DataLoader yielding one image per batch.
        classes: class names read from ``imagenet_classes.txt``, one per line.
    """
    # NOTE(review): all three images share target 13 (presumably the ImageNet
    # 'junco' class, matching my_junco.jpg) even though two are cat/dog photos.
    target = torch.as_tensor(13)

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    files = [
        './ImageNet/collection/my_junco.jpg',
        './ImageNet/collection/cat_dog_1.jpeg',
        './ImageNet/collection/cat_dog_2.jpeg',
    ]

    data = []
    labels = []
    for filename in files:
        # The with-block closes the file handle. convert('RGB') guarantees three
        # channels, because grayscale or RGBA inputs would break Normalize/stack.
        with Image.open(filename) as img:
            input_image = img.convert('RGB')

        input_tensor = preprocess(input_image)
        print(input_tensor.unsqueeze(0).shape)  # shape as a mini-batch of one

        data.append(input_tensor)
        labels.append(target)

    data = torch.stack(data, dim=0)
    labels = torch.stack(labels, dim=0).long()
    print(data.shape, labels.shape)

    all_img_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(data, labels), batch_size=data.shape[0], shuffle=False)

    dataLoaderSal = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(data, labels), batch_size=1, shuffle=False)

    # Read the categories
    with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
        classes = [s.strip() for s in f]

    return data, labels, all_img_loader, dataLoaderSal, classes


In [None]:
# Load the three sample photos. `normalized_images`, `targets` and `categories` are reused by later cells.
normalized_images, targets, all_img_loader, dataLoaderSal, categories = load_cat_dog_etc_data() 
# plot_images_or_maps(x_batch, labels=targets, categories=categories, nrows=1, ncols=2, samples_to_show=[0, 2])

## Visualize the data

In [None]:
def imshow(inp, title=None):
    """Show a CHW image tensor after undoing the ImageNet mean/std normalization.

    Args:
        inp: (C, H, W) tensor normalized with the ImageNet mean/std.
        title: optional plot title.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    hwc = np.transpose(inp.numpy(), (1, 2, 0))
    restored = np.clip(hwc * imagenet_std + imagenet_mean, 0, 1)

    plt.imshow(restored)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # brief pause so the figure gets redrawn


# Fetch the single batch (all samples) from all_img_loader.
# NOTE(review): `classes` here holds the label tensor, not class names.
inputs, classes = next(iter(all_img_loader))

# Make a grid from batch (only used by the commented-out call below)
out = torchvision.utils.make_grid(inputs)

# imshow(out, title=None)
imshow(normalized_images[2])    

## Checking model predictions

In [None]:
since = time.time()


name = 'inception_3'
model = models.inception_v3(pretrained=True)
model.to(device)
model.eval()

print(normalized_images.shape)

FinalData = normalized_images.to(device)
Labels = targets.to(device)

# Pure inference: disable autograd so no graph or activations are kept for backward.
# NOTE(review): torchvision documents Inception v3 as expecting 299x299 inputs,
# while these images are 224x224. Confirm the predictions are still acceptable.
with torch.no_grad():
    outputs = model(FinalData)
_, preds = torch.max(outputs, 1)

print(preds)

print(preds.shape)

# Print the predicted class name for each image.
for pred in preds:
    pred = pred.item()

    print(categories[pred])

# List every category containing the space-separated word 'cat', with its index.
for idx, item in enumerate(categories):
    if 'cat' in item.split(' '):
        print(item, idx)

## Test the unnormalized but resized image generation module

In [None]:
matplotlib.rcParams.update(matplotlib.rcParamsDefault)

# Load Data
images, normalized_2, target_labels, dataLoaderSal, categories = load_imgnet_val_data() 

# These are the images we used for perturbation test
img_ids = [111, 114, 115, 122, 193]
raw_imgs = []
imgs = []
targets = []

for sample_idx in img_ids:
    raw_imgs.append(images[sample_idx])
    # Take the normalized image from the validation set loaded above
    # (`normalized_2`). The earlier `normalized_images` only holds the three
    # junco/cat/dog photos, so indexing it with these ids raised IndexError.
    # NOTE(review): assumes normalized_2[i] is a tensor, so torch.stack works below.
    imgs.append(normalized_2[sample_idx])
    targets.append(target_labels[sample_idx])

imgs = torch.stack(imgs, dim=0)
targets = torch.stack(targets, dim=0)

print('Dataset Shape {}, {}'.format(imgs.shape, targets.shape))

# Resized, 0-255 integer versions of the same five images, served as one unshuffled batch.
unnormalized_images, unnormalized_img_loader = get_unnormalized_images(raw_imgs, targets)
print(normalized_images.shape)
print(unnormalized_images.shape)
x_batch, y_batch = next(iter(unnormalized_img_loader))
image_sample = x_batch[2].numpy()
print(image_sample.shape)
print(np.min(image_sample), np.max(image_sample))

# plot_images_or_maps(x_batch, labels=targets, categories=categories, nrows=2, ncols=2, samples_to_show=[0, 2, 3, 4])


#### Set the computational device

In [None]:
# Re-create the compute device (same expression as the setup cell): first CUDA GPU if available, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Generate softmax scores

This function is to generate softmax scores on the **original** or the **saliency-focused** images

In [None]:
def find_softmax_scores(model, normalized_images, pred_indices=None):
    """Return softmax scores for a batch, plus the score of one chosen class per sample.

    Args:
        model: callable mapping a batch to (N, num_classes) logits.
        normalized_images: input batch, passed straight to ``model``.
        pred_indices: optional int64 tensor of N class indices whose score should
            be read out. When ``None``, the arg-max class of each sample is used.

    Returns:
        (scores, best_prob_scores, pred_indices): the (N, num_classes) softmax
        scores, an (N,) tensor with the selected class score per sample, and the
        class indices that were used.
    """
    with torch.no_grad():
        outputs = model(normalized_images)
        scores = F.softmax(outputs, dim=1)

        if pred_indices is None:
            best_prob_scores, pred_indices = torch.max(scores, dim=1)
        else:
            # Squeeze only the gathered column. A plain squeeze() collapsed a
            # batch of one to a 0-d tensor, unlike the torch.max branch above.
            best_prob_scores = scores.gather(1, pred_indices.view(-1, 1)).squeeze(1)

    return scores, best_prob_scores, pred_indices
    

class MyDataset(Dataset):
    """Minimal map-style dataset that serves the items of ``X`` unchanged (no labels)."""

    def __init__(self, X):
        self.data = X

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
    
def find_revised_softmax_scores(model, normalized_images, b_size=101, device=device):
    """Compute softmax scores for a stack of images in mini-batches of ``b_size``.

    Args:
        model: callable mapping a batch to (batch, num_classes) logits.
        normalized_images: indexable collection of input images (batched via MyDataset).
        b_size: mini-batch size.
        device: device each mini-batch is moved to before calling ``model``.

    Returns:
        (scores, best_prob_scores, predictions): an (N, num_classes) numpy array
        of softmax scores, an (N,) numpy array with the top score per image, and
        an (N,) tensor of arg-max class indices.
    """
    print("Original Data Shape:", normalized_images.shape)
    test_loader = torch.utils.data.DataLoader(MyDataset(normalized_images), batch_size=b_size, shuffle=False)

    all_scores = []
    all_best_prob_scores = []
    all_predictions = []
    with torch.no_grad():
        for i, images in enumerate(test_loader):
            if i % 10 == 0:
                print('Generating scores for {}-th batch'.format(i))
                print('Image Shape: {}'.format(images.shape))
            images = images.to(device)
            outputs = model(images)
            scores = F.softmax(outputs, dim=1).detach().cpu()
            all_scores.append(scores)

            best_prob_scores, pred_indices = torch.max(scores, dim=1)
            all_best_prob_scores.append(best_prob_scores)
            all_predictions.append(pred_indices)

    # Concatenate along the sample axis. torch.stack (used previously) raised
    # when the last batch was smaller than b_size, and with several full batches
    # it produced (num_batches, b_size, ...) instead of one row per image.
    # For a single batch the result is identical.
    all_scores = torch.cat(all_scores, dim=0)
    all_best_prob_scores = torch.cat(all_best_prob_scores, dim=0)
    all_predictions = torch.cat(all_predictions, dim=0)

    print(all_scores.shape)
    print(all_best_prob_scores.shape)
    print(all_predictions.shape)

    return all_scores.numpy(), all_best_prob_scores.numpy(), all_predictions
    

## Test the softmax score generation function as defined above

- Define the model
- Load the weights
- Generate the images
- Call the softmax score generation function

In [None]:
pretrained_models = {
    'Resnet18': models.resnet18,
    'Resnet34': models.resnet34,
    'resnet50': models.resnet50,
    'Resnet101': models.resnet101,
    'inception_3': models.inception_v3,
}


def load_pretrained_model(model_name):
    """Build a torchvision model with pretrained weights, looked up by registry key.

    Args:
        model_name: key of ``pretrained_models`` (note that ``'resnet50'`` is lowercase).

    Returns:
        The model constructed with ``pretrained=True``.

    Raises:
        KeyError: if ``model_name`` is not a key of ``pretrained_models``.
    """
    print("Loading pre-trained", model_name, "model")
    constructor = pretrained_models[model_name]
    return constructor(pretrained=True)

## post-process saliency maps

In [None]:
# Dataset id -> name used in saliency file names. Deliberately NOT called
# `Dataset`: that name is torch.utils.data.Dataset, the base class of MyDataset,
# and rebinding it to a dict makes any re-run of the MyDataset cell fail.
dataset_names = {0: 'mnist', 1: 'fmnist', 2: 'cifar10', 3: 'imgnet'}
method_titles = ["GD", "ONLY.IG", "ONLY.M", "GDAsc.IG","GDAsc.M", "M.GDAsc.IG", "M.GDAsc.M","Wt.P.IG", "Wt.P.M", "IG", "CaptIG", "L.IG"]

def post_process_saliency_maps(dataset_id, model_name=None, methods=None, saliency_path_prefix=None):
    """Post-process saved saliency maps for one dataset/model via ``post_process_maps``.

    Args:
        dataset_id: key of ``dataset_names``; only used to build the default prefix.
        model_name: model name used in the default prefix (default ``'inception_3'``).
        methods: method names to load (default: GD, ONLY.IG, ONLY.M, GDAsc, M.GDAsc, Wt.P).
        saliency_path_prefix: file-name prefix of the saved maps. When None it is
            ``method_research_<dataset>_valSet_metricEval_<model>``.

    Returns:
        (saliency_dict, titles): the first saliency dict returned by
        ``post_process_maps`` (only random seed 0 is requested) and the titles.
    """
    name = 'inception_3' if model_name is None else model_name
    methods_used = ["GD", "ONLY.IG", "ONLY.M", "GDAsc", "M.GDAsc", "Wt.P"] if methods is None else methods
    if saliency_path_prefix is None:
        fname_common = "method_research_" + dataset_names[dataset_id] + "_valSet_metricEval_" + name
    else:
        fname_common = saliency_path_prefix

    list_of_saliency_dicts, titles = post_process_maps(dataset_id, fname_common, method_list=methods_used,
                                                       random_seeds=list(range(0, 1)), viz=False, scale_flag=False)
    return list_of_saliency_dicts[0], titles

### Post-process saliency maps for all perturbations and methods 

In [None]:
# Perturbation test: which saliency methods to load, and the model name used in the saved file names.
method_list = ["GD", "ONLY.IG", "ONLY.M", "M.GDAsc"]
name = 'Resnet101'

# Post-process the saved maps of every perturbation threshold 1..100, keyed by the integer threshold.
# `method_saliency_dict` and `title_set` keep the values of the last threshold and are read by later cells.
all_randomizations_saliency_dict = {}
for threshold in np.linspace(1, 100, 100):
    prefix = 'method_research_'+name+'_perturbation_test_'+ str(int(threshold))

    method_saliency_dict, title_set = post_process_saliency_maps(3, model_name=name, methods=method_list, saliency_path_prefix=prefix) # dataset_id 3 = 'imgnet'
    all_randomizations_saliency_dict[int(threshold)] = method_saliency_dict

## Save and load processed saliency dictionary

In [None]:

def change_keys_and_save_saliency_dict(method_saliency_dict, title_set, fname="process_saliency_for_metricEval"):
    """Save saliency arrays to ``<fname>.npz``, keyed by method title instead of the original key.

    Keys of ``method_saliency_dict`` are paired with ``title_set`` in order.
    Unpaired entries on either side are dropped (``zip`` semantics), and for a
    repeated title the later array wins.
    """
    renamed = {
        title: method_saliency_dict[old_key]
        for old_key, title in zip(method_saliency_dict.keys(), title_set)
    }
    np.savez(fname + ".npz", **renamed)

def load_saliency_dict_and_rename_keys(path=None):
    """Load an ``.npz`` archive of saliency arrays and re-key it by position.

    Args:
        path: ``.npz`` file to read. Defaults to ``process_saliency_for_metricEval.npz``.

    Returns:
        (saliency_by_index, names): a dict mapping 0..k-1 to the loaded arrays in
        archive order, and the list of original array names in the same order.
    """
    npz_path = "process_saliency_for_metricEval.npz" if path is None else path

    # np.load on an .npz returns an NpzFile that keeps the archive open. The
    # with-block closes it once every array has been read into memory.
    with np.load(npz_path) as saliency_archive:
        names = list(saliency_archive.files)
        new_saliency_dict = {idx: saliency_archive[key] for idx, key in enumerate(names)}

    return new_saliency_dict, names

# method_saliency_dict, title_set = load_saliency_dict_and_rename_keys()
# print(method_saliency_dict.keys())
# print(title_set)

## Do visualization of the comparable maps across several methods

In [None]:
name = 'Resnet101_perturbation_test'

def visualize_maps(name, display_range=[0, 10], f_size=(16, 16)):
    """Draw horizontal and vertical comparison grids of saliency maps.

    Reads the notebook globals ``model``, ``normalized_images``, ``target_labels``,
    ``categories``, ``method_saliency_dict`` and ``title_set``.

    Args:
        name: label forwarded to the plotting helpers.
        display_range: [start, stop) range of sample indices to display.
        f_size: figure size used by both plots.
    """
    sign = 'absolute_value'
    # Colormap per attribution-sign mode ("BlWhRd"). Only the entry for `sign` is used.
    colors = {
        'positive': 'Reds',
        'absolute_value': 'Reds',
        'all': LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"]),
    }

    for plot_fn in (plot_maps_method_horizontal, plot_maps_method_vertical):
        plot_fn(model, normalized_images, target_labels, categories, name,
                'imgnet', method_saliency_dict, title_set, 0,
                range_to_display=np.asarray(range(display_range[0], display_range[1], 1)),
                fig_size=f_size, cm=colors[sign], vis_sign=sign)


# visualize_maps(name, display_range=[100, 110], f_size=(16, 16))
    

## Normalize saliency maps 

Given the processed saliency maps (not yet normalized, channels not yet collapsed) of all methods in a dictionary, this function returns a new dictionary of normalized maps for every method.

- The output saliency maps do not have channel dimension (i.e. now it is 2D for images)
- attribution values are within 0-1 for absolute masks.

In [None]:
def normalize_saliency_maps(saliency_method_dict, sign='absolute_value'):
    """Normalize every saliency map of every method with ``_normalize_image_attr``.

    Args:
        saliency_method_dict: dict mapping method key -> iterable of channel-first maps.
        sign: attribution-sign mode forwarded to ``_normalize_image_attr``.

    Returns:
        A new dict with the same keys, where each value is the normalized maps
        stacked along a new leading axis.
    """
    def _normalize_one(sal_image):
        # Move the channel axis last (CHW -> HWC) before normalizing.
        return _normalize_image_attr(np.transpose(sal_image, (1, 2, 0)), sign)

    return {
        method_name: np.stack([_normalize_one(sal_image) for sal_image in all_saliency_images], axis=0)
        for method_name, all_saliency_images in saliency_method_dict.items()
    }

In [None]:
print(title_set)
# Normalize the maps of every perturbation threshold and save each set to its own .npz, keyed by method title.
# NOTE(review): `title_set` is whatever the post-processing cell left behind (the last threshold);
# this assumes the method titles are the same for every threshold.
normalized_saliency_maps_all_method = {}
for threshold in np.linspace(1, 100, 100):
    
    method_saliency_dict = all_randomizations_saliency_dict[int(threshold)] 
    normalized_saliency_maps_all_method[int(threshold)] = normalize_saliency_maps(method_saliency_dict, sign='absolute_value')

    change_keys_and_save_saliency_dict(normalized_saliency_maps_all_method[int(threshold)], title_set, fname="./MetricEvalEntropyMaps/normalized_saliency_perturbation_test_"+str(int(threshold)))

### Load from disk if the normalized saliency maps are already saved there

In [None]:
# Reload the per-threshold normalized maps saved above. Keys are re-numbered 0..k-1
# and `title_set` receives the matching method titles.
normalized_maps_all_method = {}
for threshold in np.linspace(1, 100, 100):
    normalized_maps_all_method[int(threshold)], title_set = load_saliency_dict_and_rename_keys(path="./MetricEvalEntropyMaps/normalized_saliency_perturbation_test_"+str(int(threshold))+".npz")
    
    # Show the keys/titles once, for the first threshold only.
    if int(threshold) == 1:
        print(normalized_maps_all_method[int(threshold)].keys())
        print(title_set)

## Plot the comparative saliency-masked images for a group of methods

In [None]:
def _saliency_focused_image(main_image, saliency_mask, threshold_pct):
    """Keep the top ``threshold_pct`` salient pixels of an HWC image and fill the rest with its integer mean."""
    mask = get_thresholded_saliency_mask_numpy(saliency_mask, threshold_pct)
    mask_3d = np.stack((mask, mask, mask), axis=2)  # repeat the 2-D mask over 3 channels
    return np.where(mask_3d == 1, main_image, int(np.mean(main_image)))


def plot_comparable_saliency_focused_mages(images, 
                                           original_saliency,
                                           original_img_id,
                                           saliency_all_methods,
                                           method_ids,
                                           method_names, 
                                           sample_id, 
                                           after=False, 
                                           save=False, 
                                           fname_hint=None, 
                                           fig_size=(12,9),
                                           interp = 'none',
                                           cm = 'afmhot',
                                           vis_min = 0.0
                                          ):
    """Plot one image's top-10% saliency regions across methods and perturbation levels.

    Grid layout, with one row per entry of ``method_ids``:
      * column 0: the input image;
      * column 1: the image masked by the method's original saliency map;
      * columns 2..10: the image masked by the saliency maps obtained at
        perturbation thresholds 1%, 11%, ..., 81%.
    In the masked panels, pixels outside the 10% most-salient positions are
    replaced by the integer mean of the image.

    Args:
        images: batch of CHW image tensors. ``images[sample_id]`` is plotted.
        original_saliency: dict row index -> per-image saliency maps, indexed by
            ``original_img_id``.
        original_img_id: index of this image in ``original_saliency``. Also
            used in the saved file name.
        saliency_all_methods: dict threshold -> dict row index -> per-sample
            saliency maps, indexed by ``sample_id``.
        method_ids: only its length is used, as the number of rows.
        method_names: row labels. ``method_names[0]`` is the "Input" caption,
            and row ``r`` is labelled ``method_names[r + 1]``.
        sample_id: index of the image in ``images`` and in the perturbed maps.
        after: unused.
        save: if True, save an SVG under ``./Plots/Real/``.
        fname_hint: file-name prefix for the saved plot. Defaults to
            ``"comparable_saliency"`` when None.
        fig_size: figure size.
        interp, cm, vis_min: forwarded to ``imshow`` for the masked panels.
    """
    ncols, nrows = 11, len(method_ids)
    threshold_pct = 0.1  # fraction of most-salient pixels kept in every masked panel

    # Column 0 shows the raw input, transposed to HWC only when it is 3-D.
    input_img = images[sample_id].numpy()
    if np.squeeze(input_img).ndim == 3:
        input_img = np.transpose(input_img, (1, 2, 0))
    # The masked panels always work on an HWC image.
    main_image = np.transpose(images[sample_id].numpy(), (1, 2, 0))

    masked_kwargs = dict(interpolation=interp, vmin=vis_min, vmax=1.0, cmap=cm)

    fig = plt.figure(figsize=fig_size)
    gs = gridspec.GridSpec(nrows, ncols, wspace=0.0, hspace=0.0)

    def _panel(row, col, img, title=None, **imshow_kwargs):
        # Draw one panel without tick labels. Column titles go on the first row only.
        ax = plt.subplot(gs[row, col])
        ax.imshow(img, **imshow_kwargs)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        if title is not None and row == 0:
            ax.set_title(title, fontsize=12)
        return ax

    for row in range(nrows):
        input_ax = _panel(row, 0, input_img, "Input", interpolation='none')
        # Keep a reference instead of re-fetching the axes via plt.subplot later.
        input_ax.set_ylabel(method_names[row + 1], fontsize=12)

        original_mask = original_saliency[row][original_img_id]
        _panel(row, 1, _saliency_focused_image(main_image, original_mask, threshold_pct),
               "Original\nExplanation", **masked_kwargs)

        # Perturbed explanations at thresholds 1%, 11%, ..., 81%.
        for col in range(ncols - 2):
            pct = 10 * col + 1
            perturbed_mask = saliency_all_methods[pct][row][sample_id]
            _panel(row, col + 2, _saliency_focused_image(main_image, perturbed_mask, threshold_pct),
                   "{}%".format(pct), **masked_kwargs)

    if save:
        # A None hint used to raise TypeError when building the path.
        hint = "comparable_saliency" if fname_hint is None else fname_hint
        path = os.path.join('./Plots/Real/', hint + "_imageID_" + str(original_img_id))
        # Matplotlib recommends saving before show().
        fig.savefig(path + '.svg', transparent=True, bbox_inches='tight', pad_inches=0, dpi=300)
        print('Plots Saved...', path)

    plt.show()
    plt.close(fig)

### Get only required saliency dictionary

In [None]:
def get_only_required_method_saliency(all_saliency_dict, method_list):
    """Keep the entries whose key is in ``method_list`` and re-key them 0, 1, 2, ...

    The kept entries follow the iteration order of ``all_saliency_dict``, not
    the order of ``method_list``.
    """
    kept = [saliency for key, saliency in all_saliency_dict.items() if key in method_list]
    return dict(enumerate(kept))

### Load val images original saliency

In [None]:
# Original (unperturbed) normalized maps of the validation set. Keys are re-numbered 0..k-1
# and `title_loaded` holds the matching method titles.
val_normalized_maps_all_method, title_loaded = load_saliency_dict_and_rename_keys(path="./MetricEvalEntropyMaps/normalized_saliency_for_valSet_metricEval_revised_with_edge_detector.npz")
print(val_normalized_maps_all_method.keys())
print(title_loaded)

In [None]:
# Keep methods 0, 1 and 6 (re-keyed to 0, 1, 2) for the comparison plots below.
val_required_saliency_dict = get_only_required_method_saliency(val_normalized_maps_all_method, [0, 1, 6])
print(val_required_saliency_dict.keys())

#### Visualizing model perturbation images across the methods

In [None]:
def get_thresholded_saliency_mask_numpy(saliency_mask, threshold_percent):
    """Return a 0/1 mask marking the top ``threshold_percent`` fraction of saliency values.

    Args:
        saliency_mask: numpy array of saliency values, of any shape.
        threshold_percent: fraction (0-1) of positions to keep. The count is
            ``int(size * threshold_percent)``, clamped to ``size``.

    Returns:
        Array with the same shape and dtype as ``saliency_mask``: 1 at the kept
        positions, 0 elsewhere. Ties at the cut-off are broken arbitrarily
        (``np.argpartition``).
    """
    flat = saliency_mask.flatten()  # flatten once and reuse
    thresholded_saliency_mask = np.zeros_like(flat)
    # Clamp so a fraction above 1 selects everything instead of making argpartition fail.
    k = min(int(flat.shape[0] * threshold_percent), flat.shape[0])

    if k > 0:
        topk_salient_indices = np.argpartition(flat, -k)[-k:]
        thresholded_saliency_mask[topk_salient_indices] = 1

    return thresholded_saliency_mask.reshape(saliency_mask.shape)

def get_masked_image(saliency_mask, main_img, threshold_p=0.1):
    """Zero every pixel of an HWC image except the top ``threshold_p`` fraction of salient ones.

    ``saliency_mask`` is the 2-D (H, W) map; its 0/1 mask is repeated over 3 channels.
    """
    mask = get_thresholded_saliency_mask_numpy(saliency_mask, threshold_p)
    mask_3d = np.stack((mask, mask, mask), axis=2)
    return mask_3d * main_img



# Display names of the saliency methods, keyed by method id.
method_list = {0 : "Gradients", 1 : "Integrated\nGradients", 2 : "IG_Max", \
               3 : "GGIG_IG", \
               4 : "GGIG"}

desired_methods = [0, 1, 4]

# Validation-set indices of the five perturbation-test images (same ids as `img_ids`).
samples_to_show= [111, 114, 115, 122, 193]

# saliency_all_methods[t] ends up holding every method's maps for perturbation threshold t.
# method_captions gets "Input" followed by the display name of each desired method.
# NOTE(review): the inner loop does not depend on `sal_method_id`, so it rebuilds the same dict
# once per method. The plot rows read keys 0..len(desired_methods)-1 of each threshold's dict,
# not the ids in `desired_methods`. Confirm the row/label pairing.
saliency_all_methods = {}
method_captions = ["Input"]
for sal_method_id in desired_methods:
    
    sal_method_new_name = method_list[sal_method_id]
    
    for threshold in np.linspace(1, 100, 100):
        
        all_image_saliency_masks = normalized_maps_all_method[int(threshold)]
        saliency_all_methods[int(threshold)]= all_image_saliency_masks
    method_captions.append(sal_method_new_name)

# matplotlib.rcParams.update(matplotlib.rcParamsDefault)

# One figure per image: `sample_id` indexes the 5-image x_batch, and `original_img_id`
# indexes the validation-set saliency maps.
for sample_id, original_img_id in enumerate(samples_to_show):
    plot_comparable_saliency_focused_mages(x_batch, 
                                               val_required_saliency_dict,
                                               original_img_id,
                                               saliency_all_methods, 
                                               desired_methods, 
                                               method_captions,
                                               sample_id, 
                                               after=False, 
                                               save=True, 
                                               fname_hint="perturbation_test", 
                                               fig_size=(12,5),
                                               interp = 'none',
                                               cm = 'afmhot',
                                               vis_min = 0.0
                                        )


### Plot the normalized softmax scores of the perturbations

In [None]:
# Colour palettes (not referenced in this cell).
pallette_1 = ['#b2182b','#ef8a62','#fddbc7','#f7f7f7','#d1e5f0','#67a9cf','#2166ac']
pallette_2 = ['#d73027','#f46d43','#fdae61','#fee090','#e0f3f8','#abd9e9','#74add1','#4575b4']
pallette_3 = ['#8c510a','#bf812d','#2166ac','#80cdc1','#35978f','#01665e', '#4575b4']

sns.set(style="whitegrid", font_scale=1.0)
img_id = 4  # overwritten by the loop below
basename = os.path.join('models and saliencies', 'saliency')

# Legend entries, in the same order as the method file names iterated below.
used_legends = ['GRAD', 'IG', 'GGIG']

# One figure per test image: softmax score vs. perturbation scale for each method,
# normalized by the score at the first perturbation level.
for img_id in [0,1, 2, 3, 4]:
    
    fig = plt.figure(figsize=(5,5))

    for method in ['Grad', 'LocalIGAll', "MultiGradAsc"]:

        # Scores for seed 0, reloaded for every image. Rows are plotted against
        # np.linspace(0, 100, 101) below; presumably shape (101, n_images). TODO confirm.
        score_path = os.path.join(basename, 'method_research_perturbation_test_all_scores_'+method+'_'+str(0)+'.npy')
        comparative_prob_scores = np.load(score_path)
        print("All comparative softmax scores saved here: {}".format(score_path))
        print(comparative_prob_scores.shape)

        img_score = comparative_prob_scores[:, img_id] 
        # In-place on a view of the loaded array. Yields inf/nan if the first score is 0.
        img_score /= img_score[0]
        img_score = np.clip(img_score, 0, 1)
        plt.plot(np.linspace(0, 100, 101), img_score)

    plt.legend(used_legends)
    plt.xlabel("Perturbation Scale (%)", fontsize=12)

    y_label = "Normalized Softmax Score"
    plt.ylabel(y_label, fontsize=12)
    plt.show()

    # targets[img_id] is the label of the img_id-th perturbation-test image.
    path = os.path.join('./Plots/Real/', "perturbation_test_"+categories[targets[img_id]]+"_softmax_scores")
    print(path)
    fig.savefig(path+'.svg', transparent=True, bbox_inches='tight', pad_inches=0, dpi=300)
    # fig.savefig(path+'.pdf', format='pdf', dpi=300)
    print('Plots Saved...', path)