# Motivating example

Generates Figure 1 of the paper. Takes as input the illustration images from Fel et al. [1](), [2]() and their corrupted counterparts, generated with the corruptions from [3](). 

Returns the predictions and their consistency across corruptions. It then plots the predictions alongside several explainability methods to show that traditional methods fail to provide insight into why the model fails to predict the correct label under a given shift.

While the failure of explanation methods to provide informative explanations has been studied in earlier works [4](), our focus is on explaining why a model fails to predict the label under a given shift, leveraging the wavelet attribution method. 

In [None]:
# Libraries and imports
import sys
sys.path.append('../')

import os
import matplotlib.pyplot as plt
import pandas as pd
from utils import helpers, corruptions
import torchvision
from PIL import Image
from torchvision.models import resnet50
import torch
from spectral_sobol.torch_explainer import WaveletSobol, SobolAttributionMethod
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import json

In [None]:
# Load the pretrained classifier, define the image transforms and build, for
# every example image, the list of its corrupted versions.
# NOTE(review): the notebook targets the second GPU; adjust to the local setup.
device = 'cuda:1'
source = '../assets'  # directory holding the example images
batch_size = 128

model = resnet50(pretrained = True)
model = model.eval().to(device)

# example image file name -> expected class index (the same index space as the
# predictions looked up in classes-imagenet.json further down)
classes = {
    'fox.png': 278,
    'snow_fox.png': 279,
    'polar_bear.png': 296,
    'leopard.png': 288,
    'fox1.jpg': 277,
    'fox2.jpg': 277,
    'sea_turtle.jpg': 33,
    'lynx.jpg': 287,
    'cat.jpg': 281,
    'otter.jpg': 360,
}

# Geometric transform applied to every (source or corrupted) image.
resize_and_crop = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
])

# Model input preprocessing: tensor conversion followed by per-channel
# normalisation with the mean/std below.
normalize = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
preprocessing = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    normalize,
])

# Generate the corruptions of each example image and resize/crop them.
# Index 0 of each list is the source (uncorrupted) image.
corrupted_images = {
    name: [
        resize_and_crop(im)
        for im in corruptions.generate_corruptions(os.path.join(source, name))
    ]
    for name in classes
}

# Alternatively, load previously saved corrupted images from disk:
#
# corrupted_images = {}
# for image_name in classes.keys():
#     destination = '../assets/corrupted/{}'.format(image_name[:-4])
#     items = os.listdir(destination)
#     # NOTE(review): os.listdir returns entries in arbitrary order, so index 0
#     # is not guaranteed to be the source image; consider sorting `items`.
#     corrupted_images[image_name] = [Image.open(os.path.join(destination, it)).convert('RGB') for it in items]


## Inference 

Run inference on the example images and their corrupted versions.

In [None]:
# Predict the class of every source/corrupted image and tabulate the results.
results = {}

for image_name, batch in corrupted_images.items():
    x = torch.stack([preprocessing(im) for im in batch])
    preds = helpers.evaluate_model_on_samples(x, model, batch_size)

    results[image_name] = {
        'preds': preds,
        'label': classes[image_name],
    }

# display the results
# One name per image in corrupted_images[...], in generation order
# (index 0 is the source, uncorrupted image).
corruption_names = [
    'source',
    'motion-blur',
    'defocus-blur',
    'brightness',
    'spatter',
    'jpeg',
    'saturate',
    'pixelate',
    'impulse-noise',
    'gaussian-noise',
    'contrast',
    'glass-blur',
    'elastic-transformation',
    'shot-noise',
    'gaussian-blur',
]

# Build a new list: inserting 'label' into `corruption_names` in place would
# add a duplicate 'label' every time this cell is re-run and break the
# column assignment below.
columns = ['label'] + corruption_names

# one row per image: [label, pred_source, pred_corruption_1, ...]
results_df = {
    name: [res['label']] + list(res['preds'])
    for name, res in results.items()
}

df = pd.DataFrame.from_dict(results_df, orient = 'index')
df.columns = columns

# The following cell replaces `results` and `df` with the cached results
# stored in ../assets/results_inference.csv.
df

In [None]:
# Load the cached inference results and rebuild `results` from them.
# For every row, the first column is skipped and the remaining values are kept
# as the predictions; the label is taken from `classes`.
df = pd.read_csv('../assets/results_inference.csv', index_col=0)

results = {
    item: {
        'preds': df.loc[item].iloc[1:].values,
        'label': classes[item],
    }
    for item in df.index
}

df

## Generate explanations across corruptions

We now set up several attribution methods and generate explanations for the source image and one of its corrupted versions.

In [None]:
torch.cuda.empty_cache()

# image to explain
img_name = 'polar_bear.png'

# set up the attribution methods
# Sobol attribution method
sobol = SobolAttributionMethod(model, batch_size = 128)
# Wavelet Sobol; opt={'approximation': False} removes the approximation
# coefficients from the computation
wavelet = WaveletSobol(model, grid_size = 28, nb_design = 16, batch_size = 128, opt = {'approximation' : False})
# white-box explainer: Grad-CAM++ on the last block of model.layer4
target_layers = [model.layer4[-1]]
# Construct the CAM object once, and then re-use it on many images.
# NOTE(review): with use_cuda=True, pytorch_grad_cam may move the model to the
# default CUDA device rather than `device`; confirm on multi-GPU setups.
# cam = GradCAM(model=model.to(device), target_layers=target_layers, use_cuda=True)
campp = GradCAMPlusPlus(model=model, target_layers=target_layers, use_cuda=True)

# indices in corrupted_images[img_name] of the images to explain
# (index 0 is the source image, index 1 the first corruption)
source_index = 0
corruption_index = 1
selected = [source_index, corruption_index]

target_images = [corrupted_images[img_name][i] for i in selected]
x = torch.stack([preprocessing(im) for im in target_images])

# Explain the classes the model predicted for the selected images.
# `preds` holds one prediction per image of corrupted_images[img_name];
# selecting it with the same indices keeps `y` aligned with `x` (and the same
# length) for any choice of source_index / corruption_index.
y = results[img_name]['preds'][selected].astype(int)
# targets for the white-box explainer
targets = [ClassifierOutputTarget(c) for c in y]

# compute the explanations
sobols = sobol(x, y)
print('Sobol complete')

wavelets = wavelet(x, y)
print('Wavelet complete')
torch.cuda.empty_cache()

# white-box explanations
# cams = cam(input_tensor=x.to(device), targets=targets, aug_smooth = True)
cams_pp = campp(input_tensor=x.to(device), targets=targets, aug_smooth = True)
print('WB complete')

In [None]:
# Figure 1: source image (top row) vs corrupted image (bottom row), each with
# its Grad-CAM++, Sobol and Wavelet-CAM explanation.
imagenet_dir = "../../data/ImageNet/"
with open(os.path.join(imagenet_dir, 'classes-imagenet.json')) as classes_file:
    # keys are class indices as strings
    classes_names = json.load(classes_file)

# set the font size before creating the figure so all of its text uses it
plt.rcParams.update({'font.size': 17})
fig, ax = plt.subplots(2, 4, figsize = (16, 8))

size = 224   # side of the cropped images (CenterCrop(224))
levels = 3   # passed to helpers.add_lines; presumably the number of wavelet levels — confirm

# Reuse the indices of the explanation cell so captions, images and maps
# always refer to the same pair of images.
target_index = corruption_index

# `df` has a 'label' column first, then one column per image of
# corrupted_images[img_name] in the same order, so the corrupted image's
# column is derived from its index rather than hard-coded.
perturbation = df.columns[1 + target_index]

pred_source = df.loc[img_name]["source"]
pred_target = df.loc[img_name][perturbation]

images = corrupted_images[img_name]

# source image and altered image
# NOTE(review): the source caption uses the second comma-separated entry of
# the class name, the corrupted caption the first — confirm this is intended.
prediction = classes_names[str(int(pred_source))].split(',')[1]
prediction_corrupted = classes_names[str(int(pred_target))].split(',')[0]

ax[0,0].set_title('Source image \n Prediction : {}'.format(prediction))
ax[0,0].imshow(images[source_index])
ax[0,0].axis('off')
ax[1,0].set_title('Corrupted image \n Prediction : {}'.format(prediction_corrupted))
ax[1,0].imshow(images[target_index])
ax[1,0].axis('off')

# grad cam
ax[0,1].set_title('Grad-CAM ++ \n Source')
ax[0,1].imshow(images[source_index])
ax[0,1].imshow(cams_pp[0], cmap = "jet", alpha = 0.5)
ax[0,1].axis('off')
ax[1,1].set_title('Grad-CAM ++ \n Target')
ax[1,1].imshow(images[target_index])
ax[1,1].imshow(cams_pp[1], cmap = "jet", alpha = 0.5)
ax[1,1].axis('off')

# sobol
ax[0,2].set_title('Sobol attribution method \n Source')
ax[0,2].imshow(images[source_index])
ax[0,2].imshow(sobols[0], cmap = "jet", alpha = 0.5)
ax[0,2].axis('off')
ax[1,2].set_title('Sobol attribution method \n Corrupted')
ax[1,2].imshow(images[target_index])
ax[1,2].imshow(sobols[1], cmap = "jet", alpha = 0.5)
ax[1,2].axis('off')

# wavelet (shown on its own, with the lines drawn by helpers.add_lines)
ax[0,3].set_title('Wavelet-CAM (ours) \n Source')
ax[0,3].imshow(wavelets[0], cmap = "hot")
ax[0,3].axis('off')
helpers.add_lines(size, levels, ax[0,3])
ax[1,3].set_title('Wavelet-CAM (ours)\n Corrupted')
ax[1,3].imshow(wavelets[1], cmap = "hot")
helpers.add_lines(size, levels, ax[1,3])
ax[1,3].axis('off')

fig.tight_layout()
# make sure the output directory exists before saving
os.makedirs('../figs', exist_ok=True)
fig.savefig('../figs/motivating_example.pdf')
plt.show()

In [None]:
# Stand-alone Wavelet-CAM of the source image (hot colormap), with the lines
# drawn by helpers.add_lines — presumably the wavelet level boundaries.
fig, ax = plt.subplots(nrows=1, ncols=1)

ax.imshow(wavelets[0], cmap="hot")
ax.set_axis_off()
helpers.add_lines(224, 3, ax)

# Uncomment to export the panel:
# plt.savefig('../figs/workflow-diagram/wcam-hot.pdf')
plt.show()