# Understanding how distribution shifts affect image classification models

This notebook displays images computed by the script `spectral_profile.py`. We display WCAMs obtained across various training methods and model backbones to show that the sensitivity to distribution shifts can be explained by a shift in the regions that are important for the prediction, as highlighted by the WCAM. Existing attribution methods cannot provide such qualitative evidence, while robustness analyses have only shown that robust methods are quantitatively less sensitive to distribution shifts, without making explicit whether this is because the models behave qualitatively differently. 

In this work, we provide evidence that the behavior of deep learning models with respect to the scale decomposition of an image is the same across (1) model baselines and (2) training approaches. Differences in robustness between those methods are only quantitative. 

In [None]:
import sys
sys.path.append('../')

# Libraries 
import os
from utils import helpers
import numpy as np

## Main plots

### Figure 4

In [None]:
import torchvision
from torchvision.models import resnet50
from utils import corruptions
from spectral_sobol.torch_explainer import WaveletSobol
import torch

# --- Configuration ---------------------------------------------------------
device = 'cuda'
source = '../assets'
batch_size = 128

# torchvision ResNet-50 with pretrained weights, in eval mode, moved to `device`
model = resnet50(pretrained=True).eval().to(device)

# Example images (file names under `source`) and their labels
# (the integers look like ImageNet class indices -- TODO confirm)
classes = {
    'fox.png': 278,
    'snow_fox.png': 279,
    'polar_bear.png': 296,
    'leopard.png': 288,
    'fox1.jpg': 277,
    'fox2.jpg': 277,
    'sea_turtle.jpg': 33,
    'lynx.jpg': 287,
    'cat.jpg': 281,
    'otter.jpg': 360,
}

# --- Transforms ------------------------------------------------------------
# Geometric step applied to every image returned by `generate_corruptions`
resize_and_crop = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
])

# Per-channel normalization applied right after the tensor conversion
normalize = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Image -> normalized tensor, i.e. what the model receives
preprocessing = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    normalize,
])

# --- Corrupted images ------------------------------------------------------
# For every example file, generate its corrupted versions and resize/crop them.
# In this dictionary, the 0th image of each list is the source (uncorrupted) image.
corrupted_images = {
    file_name: [
        resize_and_crop(picture)
        for picture in corruptions.generate_corruptions(os.path.join(source, file_name))
    ]
    for file_name in classes
}

# --- Predictions on one example --------------------------------------------
name = 'sea_turtle.jpg'
images = corrupted_images[name]
x = torch.stack([preprocessing(picture) for picture in images])

preds = helpers.evaluate_model_on_samples(x, model, batch_size)


In [None]:
# Display the value returned by `helpers.evaluate_model_on_samples` for the stacked images `x`
# (presumably one prediction per image -- TODO confirm against the helper)
preds

In [None]:
# Positions, within the list of corrupted images, of the three images to explain
# (position 0 is the uncorrupted source image)
indices = [0, 3, 7]

images = corrupted_images[name]
selected_images = [images[position] for position in indices]
x = torch.stack([preprocessing(picture) for picture in selected_images])

# Wavelet-domain explainer wrapped around the current model
wavelet = WaveletSobol(
    model,
    grid_size=28,
    nb_design=4,
    batch_size=128,
    opt={'approximation': False},
)

# One class index per selected image; 33 is the label listed for 'sea_turtle.jpg' in `classes`
explanations = wavelet(x, np.array([33, 33, 33], dtype=np.uint8))

In [None]:
# Imported in this cell because `plt` is first used here: the only other import of
# matplotlib sits in a later cell, so a top-to-bottom run would otherwise raise a NameError.
import matplotlib.pyplot as plt

# Quick side-by-side view of the three explanations
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
for col in range(3):
    ax[col].imshow(explanations[col], cmap='jet')


In [None]:
import json

imagenet_dir = '../../data/ImageNet/'

# Mapping queried in the following cells with a stringified class index as key.
# The file is opened in a context manager so the handle is closed once parsing is done.
with open(os.path.join(imagenet_dir, 'classes-imagenet.json')) as class_file:
    classes_names = json.load(class_file)


In [None]:
# Look up the class-name entry for the prediction at index 7 and split it on commas
classes_names[str(preds[7].astype(int))].split(',')

In [None]:
import matplotlib.pyplot as plt

# 2 x 3 grid: images with their predicted class on top, matching WCAMs below
fig, ax = plt.subplots(2, 3, figsize=(14, 8))
plt.rcParams.update({'font.size': 17})

# Top row. `preds` was computed on the full list of corrupted images, so it is
# indexed with the original positions (0, 3, 7), while `selected_images` holds
# only the three selected pictures.
for col, pred_position in enumerate((0, 3, 7)):
    prediction = classes_names[str(preds[pred_position].astype(int))].split(',')[0]
    ax[0, col].set_title(f'Prediction : {prediction}')
    ax[0, col].imshow(selected_images[col])

# Bottom row: one WCAM per column, with the lines drawn by `helpers.add_lines`
for col in range(3):
    ax[1, col].set_title('WCAM')
    ax[1, col].imshow(explanations[col], cmap='hot')
    helpers.add_lines(224, 3, ax[1, col])

# Hide the axes on all six panels
for panel in ax.flat:
    panel.axis('off')

# Uncomment to write the figure to disk
# plt.savefig('../figs/corruptions_baseline.pdf')

plt.show()

### Figure 5

In [None]:
# Models compared in this figure, loaded by name through the project helper
model_names = ['sin', 'adv']
models = [helpers.load_model(model_name, device) for model_name in model_names]

# Example images and their labels. Defined *before* the comprehension below,
# which iterates over it, so that this cell does not rely on a stale definition.
classes = {
    'fox.png': 278,
    'snow_fox.png': 279,
    'polar_bear.png': 296,
    'leopard.png': 288,
    'fox1.jpg': 277,
    'fox2.jpg': 277,
    'sea_turtle.jpg': 33,
    'lynx.jpg': 287,
    'cat.jpg': 281,
    'otter.jpg': 360,
}

# Corrupted versions of every example image, resized and cropped
corrupted_images = {
    image_name: [
        resize_and_crop(im)
        for im in corruptions.generate_corruptions(os.path.join(source, image_name))
    ]
    for image_name in classes
}

name = 'fox.png'
images = corrupted_images[name]
x = torch.stack([preprocessing(im) for im in images])

# Output of the evaluation helper on all corrupted images, keyed by model name
preds = {
    model_name: helpers.evaluate_model_on_samples(x, model, batch_size)
    for model_name, model in zip(model_names, models)
}

In [None]:
# WCAMs per model, keyed by model name
outputs = {}

images = corrupted_images[name]

# Only two images are explained: the source image (index 0) and the corrupted image at index 8
selected_images = [images[i] for i in [0, 8]]

# Stack the *selected* images so that the batch lines up with the two labels passed below
# (stacking all of `images` would pair the second label with the wrong picture).
x = torch.stack([
    preprocessing(im) for im in selected_images
])

for model_name, model in zip(model_names, models):

    wavelet = WaveletSobol(model, grid_size = 28, nb_design = 4, batch_size = 128, opt = {'approximation' : False})

    # Source image: explained with respect to 278, the label of 'fox.png' in `classes`.
    # Corrupted image: explained with respect to the model's own prediction on it.
    predictions = [278]
    predictions.append(preds[model_name][8])

    # int64 rather than uint8: class indices above 255 (such as 278) wrap around in uint8
    explanations = wavelet(x, np.array(predictions).astype(np.int64))

    outputs[model_name] = explanations

In [None]:
# 2 x 3 grid: source image on the left, then one column of WCAMs per model
fig, ax = plt.subplots(2, 3, figsize=(14, 8))
plt.rcParams.update({'font.size': 17})

# Short name of class 278: split the class-name entry itself, not the formatted title
source_label = classes_names[str(278)].split(',')[0]

ax[0, 0].imshow(selected_images[0])
ax[0, 0].axis('off')
ax[0, 0].set_title('Source image\n Label : {}'.format(source_label))

for col, (model_key, heading) in enumerate([('sin', 'SIN'), ('adv', 'Adv')], start=1):
    # Top panel: first explanation of this model, titled with the class-278 name
    ax[0, col].set_title('{} \n Pred : {}'.format(heading, source_label))
    ax[0, col].imshow(outputs[model_key][0], cmap='hot')
    helpers.add_lines(224, 3, ax[0, col])

    # Bottom panel: second explanation, titled with the model's prediction at index 8
    corrupted_label = classes_names[str(int(preds[model_key][8]))].split(',')[0]
    ax[1, col].set_title('Pred : {}'.format(corrupted_label))
    ax[1, col].imshow(outputs[model_key][1], cmap='hot')
    helpers.add_lines(224, 3, ax[1, col])

    ax[0, col].axis('off')
    ax[1, col].axis('off')

# The bottom-left slot has no content: remove it from the figure
fig.delaxes(ax[1, 0])
plt.savefig('../figs/alternative-models.pdf')
plt.show()

## Appendices

In [None]:
# data directory
root = '../../data/spectral-attribution-outputs/'
perturbations = ['corruptions']

# Names of the entries to be plotted for each perturbation.
# `os.listdir` returns entries in arbitrary order, so the listing is sorted to keep
# positions (used below as `[0]` and as the index in saved file names) reproducible.
images = {
    perturbation: sorted(os.listdir(os.path.join(root, perturbation)))
    for perturbation in perturbations
}

In the cell below, you can generate plots modeled on those shown in Figure 4 of the main paper.

In [None]:
# Settings for a single example figure
perturbation = "corruptions"
case = 'baseline'
device = 'cuda:1'
labels = ['source', 'corrupted, unaffected', "corrupted, affected"]
save = None  # forwarded as the `save` argument of the plotting helper

# First listed entry for this perturbation and the folder holding its outputs
image_name = images[perturbation][0]
directory = os.path.join(root, perturbation, image_name)

# Fix numpy's global seed before loading the model and plotting
np.random.seed(95)

model = helpers.load_model(case, device)
helpers.plot_stable_and_unstable_prediction(
    image_name,
    model,
    directory,
    perturbation,
    case,
    labels,
    save=save,
)

In [None]:
# Batch version: one figure per (training case, image) pair

perturbation = "corruptions"
device = 'cuda:1'
labels = ['source', 'corrupted, unaffected', "corrupted, affected"]

# Output path template, filled with (perturbation, case, image index)
save = '../figs/figure-4/{}_{}_{}.pdf'

cases = ['augmix']#, 'pixmix', "adv", "adv_free", 'fast_adv', "sin"]

# Fix numpy's global seed before any model is loaded or any figure is drawn
np.random.seed(95)

for case in cases:
    print(f'Case ............. {case}')

    model = helpers.load_model(case, device)
    for i, image_name in enumerate(images[perturbation]):
        directory = os.path.join(root, perturbation, image_name)
        helpers.plot_stable_and_unstable_prediction(
            image_name,
            model,
            directory,
            perturbation,
            case,
            labels,
            save=save.format(perturbation, case, i),
        )

In the cell below, you can reproduce figures modeled on those shown in Figure 5 of the paper.

In [None]:
# WCAM plots for the training cases listed in `cases`, one figure per image.
# `perturbation`, `root` and `images` are the values left by the previous cells.
cases = ['baseline', 'sin', 'augmix', "pixmix", "adv", "adv_free"]

# Output path template, filled with the image index
save = '../figs/figure-5/robust_training_corruptions_{}.pdf'

np.random.seed(95)

for i, image_name in enumerate(images[perturbation]):
    directory = os.path.join(root, perturbation, image_name)
    helpers.plot_wcams(directory, cases, save=save.format(i))