# Effects of natural image corruptions on a model's prediction

This notebook allows the replication of the results of Section 4.1 of the paper. We show that across backbones, under standard training, a corruption affects the salient scales (i.e., the important regions in the space-scale representation). In turn, the image is no longer matched with its class's typical scale profile, but rather with the scale profile of another class, resulting in a misclassification. 

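We first compute the class-wise scale profiles (the average WCAM of each class) from precomputed WCAMs. We then compute the WCAMs of corrupted held-out images and compare them with the scale profile of the predicted class.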

In [1]:
# Library imports
import json
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from torchvision.models import resnet50

# make the project's `utils` package (one directory up) importable
sys.path.append('../')
from utils import helpers, corruptions

## Set-up: model backbone, and source data

The requirement is to have a set of WCAMs precomputed on the ImageNet validation set. These WCAMs can be computed with the script `generate_wcams.py`. When necessary, model weights are accessible on the public repository attached to this article. Backbones include the following:
- CNN backbones: ResNet50, VGG16 (weights: taken from `torchvision`)
- Transformer backbones: ViT-B16 and ViT-B32 (weights: taken from `torchvision`)
- Self-Supervised backbones: Dino (ViT-B16, weights: taken from `torch hub`)

In [None]:
# download our folder from the Zenodo repository

In [4]:
# load the data
# Configuration: which backbone's WCAMs to analyse and where the data lives.
backbone = "resnet" # backbones : 'resnet', 'vgg', 'vitb16', 'vitb32', 'dino'
root_dir = '../../data/wcams-imagenet' # root: where wcams are stored
imagenet_dir = '../../data/ImageNet/' # imagenet : where imagenet is located

# folder containing the wcams of the corresponding backbone
# (for each backbone, 5000 wcams were computed)
wcams_dir = os.path.join(root_dir, backbone)
# txt file with the labels of the imagenet validation set
labels_dir = os.path.join(imagenet_dir, "val.txt")

# Build {image name: class index} from the label file. Only lines starting with
# the ImageNet file prefix are kept, which skips blank/trailing lines. The
# `with` statement closes the file handle once parsing is done.
with open(labels_dir) as labels_file:
    labels_dict = {
        name: int(label)
        for name, label in (
            line.split() for line in labels_file if line.startswith('ILSVRC2012')
        )
    }

# create a dataframe with the labels of each image
labels_true = pd.DataFrame.from_dict(labels_dict, orient = "index").reset_index()
labels_true.columns = ['name', 'label']

# entries of the wcam directory; assumed to be named after the image files
# (they are matched against the `name` column) -- TODO confirm naming scheme
targets = os.listdir(wcams_dir)
labels_wcam = labels_true[labels_true['name'].isin(targets)] # keep only the labels of the images with a wcam

# auxiliary file: the name of the imagenet classes (keys are class indices as strings)
with open(os.path.join(imagenet_dir, 'classes-imagenet.json')) as names_file:
    classes_names = json.load(names_file)

## Computation of the class-wise scale profiles

For each class, we compute its scale profile, obtained by averaging the WCAMs computed for the images of that class. We then plot these profiles.

In [None]:
# computation
# for the example, we plot the average WCAMs for 16 classes
num_classes = 16
n_cols = 4 # number of panels per row of the figure
# The grid is derived from num_classes so that changing it neither overflows
# a fixed 4x4 grid nor leaves empty panels.
n_rows = int(np.ceil(num_classes / n_cols))

# class index -> average WCAM of that class (computed by the project helper)
classes = helpers.compute_average_classes(labels_wcam, wcams_dir)

# plot

# squeeze=False keeps `ax` two-dimensional even when there is a single row
fig, ax = plt.subplots(n_rows, n_cols, figsize = (4 * n_cols, 4 * n_rows), squeeze = False)
input_shape = (224,224)

np.random.seed(42)
examples = np.random.choice(list(classes.keys()), num_classes, replace = False)

for i, e in enumerate(examples):
    row, col = divmod(i, n_cols)
    panel = ax[row, col]

    # keep only the first synonym of the class name
    label_class = classes_names[str(e)].split(',')
    panel.set_title('Class : {} ({})'.format(label_class[0], e))

    # upsample the WCAMs for better visualization
    upscaled = cv2.resize(classes[e], input_shape)
    im = panel.imshow(upscaled, cmap = 'jet')
    # presumably overlays the boundaries of the 3 wavelet levels -- see helpers.add_lines
    helpers.add_lines(input_shape[0], 3, panel)

    fig.colorbar(im, ax=panel, shrink=0.8)

# hide the panels left over when num_classes is not a multiple of n_cols
for unused in ax.flat[num_classes:]:
    unused.axis('off')

plt.suptitle('Average wavelet attributions for various ImageNet classes')
fig.tight_layout()
# plt.savefig('figs/classes-wcams.pdf')
plt.show()

## Computation of the WCAMs of corrupted images

Now that we have the average WCAM of each class, we consider a set of held-out samples (i.e., validation images for which no WCAM was precomputed). We apply a set of corruptions to these samples and compute the WCAM of each corrupted image. We then plot the WCAM of the (corrupted) image next to the average WCAM of the predicted class (computed in the previous section). We show that the prediction is consistent with the pattern of the average WCAM of the predicted class. 

If the model incorrectly classifies an image as belonging to class 42 while its true class is 999, then the WCAM of the image will look like the average WCAM of class 42. In particular, the *salient* scales are similar. 

In [None]:
# generate a list of corrupted images

# NOTE(review): `source_dir` (folder holding the ImageNet validation images) was
# never defined in this notebook. This default is an assumption -- confirm it
# matches your local layout.
source_dir = os.path.join(imagenet_dir, 'val')

# regenerate the dataframe of all validation labels. The original referenced an
# undefined `items`; `labels_dict` (image name -> label, built when loading the
# data) is exactly the mapping needed here.
complete_labels = pd.DataFrame.from_dict(labels_dict, orient = 'index').reset_index()
complete_labels.columns = ['name', 'label']

# restrict to the remaining samples (no precomputed WCAM)
# and to the targeted labels (the classes plotted above)
complete_labels = complete_labels[
    ~complete_labels['name'].isin(targets) & complete_labels['label'].isin(examples)
]

sample_size = 50
np.random.seed(42)
# Sample without replacement: duplicated names would collapse in dicts keyed by
# image name and break the lookup of an image's position in example_images.
example_images = np.random.choice(
    complete_labels['name'].values,
    min(sample_size, len(complete_labels)),
    replace = False,
)

# generate the corruptions for these images. Presumably element 0 of each list
# is the clean image (later cells treat it as the original) -- confirm in
# utils.corruptions.
corrupted_images = [corruptions.generate_corruptions(os.path.join(source_dir, im)) for im in example_images]

In [None]:
# inference
# Use the GPU of the original experiments when present, otherwise fall back to
# any available GPU, then to the CPU.
if torch.cuda.device_count() > 2:
    device = "cuda:2"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# eval() is required: a freshly built torchvision model is in training mode,
# where BatchNorm normalizes with the statistics of the current batch (here the
# corrupted versions of a single image) instead of its learned running stats.
model = resnet50(pretrained = True).to(device).eval()


# transforms: standard ImageNet evaluation preprocessing
normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

preprocessing = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize,
])

# image name -> true label, built once instead of filtering the dataframe per image
true_labels = dict(zip(complete_labels['name'], complete_labels['label']))

# image name -> (true label, predicted class for each version of the image)
results = {}

# The loop variable is deliberately not called `corruptions`: that would shadow
# the `utils.corruptions` module imported at the top of the notebook.
with torch.no_grad():  # inference only: no autograd graph needed
    for image, image_versions in zip(example_images, corrupted_images):

        label = true_labels[image]
        x = torch.stack([
            preprocessing(im.convert('RGB')) for im in image_versions
        ]).to(device)

        preds = model(x).argmax(dim = 1).cpu().numpy()

        # keep only the images whose first (presumably clean) version is
        # correctly classified
        if label == preds[0]:
            results[image] = (label, preds)


In [None]:
# compute the WCAM for the selected set of images

batch_size = 128
grid_size = 28
nb_design = 2
opt = {"approximation" : False, 'size' : grid_size}

# NOTE(review): WaveletSobol is not imported anywhere in this notebook; import it
# from the project's WCAM implementation before running this cell.
wavelet = WaveletSobol(model, grid_size = grid_size, nb_design = nb_design, batch_size = batch_size, opt = opt)

# image name -> list of its corrupted versions. A dict lookup replaces
# `np.where(example_images == image)[0].item()`, which raised as soon as a name
# appeared more than once in example_images.
corrupted_by_name = dict(zip(example_images, corrupted_images))

# image name -> WCAMs of all its versions, explained w.r.t. the predicted classes
wcams = {}

for image, (_, preds) in results.items():
    x = torch.stack([
        preprocessing(im.convert('RGB')) for im in corrupted_by_name[image]
    ]).to(device)

    y = np.array(preds)

    wcams[image] = wavelet(x, y)
        

In [None]:
images_list = list(wcams.keys())
index = 0 # which image (in wcams order) to display
k = 2     # which version of that image to display in the second row

# same spatial preprocessing as the model input, without tensor conversion
plt_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
])

image_name = images_list[index]

# get the predictions (one per version of the image) and the matching WCAMs
preds = results[image_name][1]
img_wcams = wcams[image_name]

# `index` is a position in images_list (only the images kept in `results`), not
# in example_images / corrupted_images, so the source images are looked up by
# name. Indexing corrupted_images with `index` showed the wrong picture as soon
# as any sampled image had been filtered out.
source_images = corrupted_images[list(example_images).index(image_name)]

# average WCAM of each predicted class; NaN placeholder when the predicted class
# has no precomputed profile
class_wcams = [
    classes[p] if p in classes else np.full((28, 28), np.nan)
    for p in preds
]

fig, ax = plt.subplots(2,3, figsize = (16, 8))

# row 0: first version of the image (displayed as the original);
# row 1: version k (a corruption).
# Columns: image, its WCAM for the predicted class, average WCAM of that class.
for row, (version, title) in enumerate([(0, 'Original image'), (k, 'Corrupted image')]):
    ax[row, 0].set_title(title)
    ax[row, 0].imshow(plt_transform(source_images[version]))

    ax[row, 1].set_title('Image WCAM - class : {}'.format(preds[version]))
    ax[row, 1].imshow(cv2.resize(img_wcams[version], input_shape), cmap = 'jet')
    ax[row, 2].set_title('Class average WCAM')
    ax[row, 2].imshow(cv2.resize(class_wcams[version], input_shape), cmap = 'jet')

    # `utils` itself is never imported (only `from utils import helpers, ...`);
    # add_lines lives in `helpers`, as in the class-profile plot.
    helpers.add_lines(input_shape[0], 3, ax[row, 2])
    helpers.add_lines(input_shape[0], 3, ax[row, 1])

plt.suptitle('WCAM for the original and corrupted image')
os.makedirs('figs', exist_ok = True) # savefig fails if the folder is missing
plt.savefig('figs/class-swap-example-3.pdf')
plt.show()