# Spectral profile of model predictions

This notebook displays images computed with the script `spectral_profile.py`. We display WCAMs obtained across various training methods and model backbones to show that the sensitivity to distribution shifts can be explained by a shift in the regions that are important for the prediction, as highlighted by the WCAM. Existing attribution methods cannot provide such qualitative evidence. Robustness analyses have only shown that robust methods are quantitatively less subject to distribution shifts, without making explicit whether this is because the models behave qualitatively differently.

In this work, we provide evidence that the behavior of deep learning models with respect to the scale decomposition of an image is the same across (1) model backbones and (2) training approaches. Differences in robustness between these methods are only quantitative.

In [1]:
import sys
sys.path.append('../')

# Libraries 
import os
import matplotlib.pyplot as plt
from utils import helpers

In [None]:
# data directory
# Root folder holding the outputs of `spectral_profile.py`. It is read below as
# <root>/<perturbation>/<image_name>/... (layout inferred from plot_wcams).
root = '../../data/spectral-attribution-outputs/'


def _list_subdirectories(path):
    """
    Returns the sorted names of the sub-directories of ``path``.

    Plain files (e.g. ``.DS_Store``) are skipped so that they are never taken
    for a perturbation or an image folder, and the result is sorted because
    ``os.listdir`` returns entries in an arbitrary, filesystem-dependent order
    (indexing such as ``images[...][0]`` would otherwise not be reproducible).
    """
    return sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )


perturbations = _list_subdirectories(root)

# get the name of the images to be plotted for each perturbation
images = {
    perturbation : _list_subdirectories(os.path.join(root, perturbation)) for perturbation in perturbations
}



In [None]:
# plot of the baseline case
# (requires the cell defining `plot_wcams` to have been run first)
save = None  # set to a file path to also save the figure
cases = ['baseline']

# `perturbation` must be defined explicitly: the variable of the comprehension
# that builds `images` is not visible outside of it in Python 3.
perturbation = 'corruptions'
image_name = images[perturbation][0]

directory = os.path.join(root, perturbation, image_name)

plot_wcams(directory, cases, save = save)

In [None]:
# plot of the adversarial and SIN cases
# (requires the cell defining `plot_wcams` to have been run first)
save = None  # set to a file path to also save the figure
cases = ['adv', 'sin']

# `perturbation` must be defined explicitly: the variable of the comprehension
# that builds `images` is not visible outside of it in Python 3.
perturbation = 'corruptions'
image_name = images[perturbation][0]

directory = os.path.join(root, perturbation, image_name)

plot_wcams(directory, cases, save = save)

In [2]:
def plot_wcams(directory, cases, opt = None, save = None):
    """
    Plots the source and altered images stored in ``directory`` next to the
    WCAMs computed for each case listed in ``cases`` (one column per case).

    Files read from ``directory``:
        source.png, altered_<index>.png,
        <case>/wcam_source.png, <case>/wcam_target_<index>.png

    Parameters
    ----------
    directory : str
        folder of one image, containing the files listed above.
    cases : list of str
        names of the sub-folders of ``directory`` whose WCAMs are plotted
        (e.g. ['baseline'] or ['adv', 'sin']).
    opt : dict, optional
        optional parameters. ``opt['index']`` selects which altered image
        (``altered_<index>.png``) is displayed; defaults to 1.
    save : str, optional
        if given, path where the figure is saved.

    Raises
    ------
    ValueError
        if ``cases`` is empty (there would be no WCAM to plot).
    """
    if not cases:
        raise ValueError("`cases` must contain at least one case to plot.")

    # TODO. Consider a random index in the set (altered_1, ... altered_n)
    index = opt['index'] if opt is not None and 'index' in opt else 1

    # NOTE(review): `Image` is not imported in the visible import cell;
    # presumably `from PIL import Image` is required — confirm and add it.
    source = Image.open(os.path.join(directory, "source.png")).convert('RGB')
    target = Image.open(os.path.join(directory, "altered_{}.png".format(index)))

    # open the wcams (source and altered image) for every case
    source_wcams = []
    target_wcams = []
    for case in cases:
        destination = os.path.join(directory, case)
        source_wcams.append(
            Image.open(os.path.join(destination, "wcam_source.png")).convert('L')
        )
        target_wcams.append(
            Image.open(os.path.join(destination, "wcam_target_{}.png".format(index))).convert('L')
        )

    # build a single figure once every case has been loaded (calling this
    # inside the loop would draw one partial figure per case and overwrite
    # `save` each time)
    generate_plot([source, target], source_wcams, target_wcams, cases, save = save)

    return None


def generate_plot(images, source_wcams, target_wcams, cases, save = None,
                  levels = 3, size = 224):
    """
    Draws a 2-row figure. The first column shows the source (top) and altered
    (bottom) images. Each following column shows, for one case, the WCAM of
    the source image (top) and the WCAM of the altered image (bottom).

    Parameters
    ----------
    images : sequence of two images
        ``(source, target)``, displayed with ``imshow``.
    source_wcams, target_wcams : sequences of images
        WCAMs of the source / altered image, one per case, in the same order
        as ``cases``.
    cases : list of str
        case names used in the column titles.
    save : str, optional
        if given, the figure is saved to this path before being shown.
    levels : int, optional
        forwarded to ``helpers.add_lines`` (default 3, the previously
        hard-coded value; presumably the number of decomposition levels).
    size : int, optional
        forwarded to ``helpers.add_lines`` (default 224, the previously
        hard-coded value; presumably the image size in pixels).

    Raises
    ------
    ValueError
        if ``source_wcams`` and ``target_wcams`` have different lengths.
    """
    if len(source_wcams) != len(target_wcams):
        raise ValueError("`source_wcams` and `target_wcams` must have the same length.")

    n_cols = 1 + len(source_wcams)

    # squeeze=False keeps `ax` 2-D even with a single column, so the
    # ax[row, col] indexing below always works
    fig, ax = plt.subplots(2, n_cols, figsize = (4 * n_cols, 8), squeeze = False)

    # first column : images
    source, target = images
    ax[0, 0].imshow(source)
    ax[0, 0].axis('off')
    ax[0, 0].set_title("Source image")

    ax[1, 0].imshow(target)
    ax[1, 0].axis('off')
    ax[1, 0].set_title("Altered image")

    # subsequent columns: wcams (top row: source, bottom row: altered image)
    for i, (source_wcam, target_wcam) in enumerate(zip(source_wcams, target_wcams)):
        col = i + 1

        ax[0, col].imshow(source_wcam, cmap = 'jet')
        ax[0, col].axis('off')
        ax[0, col].set_title('{} base \n WCAM on the source image'.format(cases[i]))
        helpers.add_lines(size, levels, ax[0, col])

        ax[1, col].imshow(target_wcam, cmap = 'jet')
        ax[1, col].axis('off')
        ax[1, col].set_title('WCAM on the altered image')
        helpers.add_lines(size, levels, ax[1, col])

    fig.tight_layout()

    if save is not None:
        fig.savefig(save)

    plt.show()
    return None
