# ImageNet-R

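In this notebook, we evaluate the model on samples from ImageNet-R. This dataset covers 200 ImageNet classes, and for each class it contains a set of renditions of that class (e.g. paintings, cartoons, sketches, sculptures). For comparison, we also explain the model's decision on an ImageNet image of the same class.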

In [None]:
import sys
sys.path.append('../')

import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from utils import helpers
from spectral_sobol.torch_explainer import WaveletSobol
import pandas as pd
import json
import torchvision
import torch

In [None]:
device = 'cuda:1'
model_name = "baseline"
# NOTE(review): despite its name, this variable points at the ImageNet-R
# dataset; the name is kept so other cells referring to it keep working.
imagenet_a_dir = "../../data/ImageNet-R/"
imagenet_dir = "../../data/ImageNet/"

# imagenet labels: class index (as a string key) -> class name(s)
with open(os.path.join(imagenet_dir, 'classes-imagenet.json')) as classes_file:
    classes = json.load(classes_file)


# parameters
index = 9 # row of `labels` (i.e. the ImageNet-R class) to evaluate
batch_size = 128

# load the ImageNet-R labels: one row per class folder, with the folder name
# ('name') and the human-readable label ('label')
with open(os.path.join(imagenet_a_dir, "labels.json")) as labels_file:
    labels = pd.DataFrame.from_dict(json.load(labels_file), orient = "index").reset_index()
# a flat list of names: a nested list would build a MultiIndex on the columns
labels.columns = ['name', 'label']

# clean the folders (one-off maintenance step: strips spaces from file names)
# for i in range(labels.shape[0]):
#     directory = os.path.join(imagenet_a_dir, labels.loc[i]["name"])
#     names = os.listdir(directory)
#     for name in names:
#         new_name = name.replace(' ', '')
#         os.rename(os.path.join(directory, name), os.path.join(directory, new_name))

resize_and_crop = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224)
])


def load_image(path):
    """Open the image at `path` as RGB, resize/crop it, and close the file."""
    with Image.open(path) as img:
        return resize_and_crop(img.convert('RGB'))


directory = os.path.join(imagenet_a_dir, labels.iloc[index]['name'])

# sorted so the image order (and thus any seeded sampling over it) does not
# depend on the file system's listing order
images = [
    load_image(os.path.join(directory, name)) for name in sorted(os.listdir(directory))
]

## an imagenet example
true_label = labels.iloc[index]['label']
# ImageNet class names / indices whose name contains the ImageNet-R label.
# NOTE(review): this is a substring match, so it may select several classes;
# later code uses the first match as the reference class.
complete_label = [v for v in classes.values() if true_label in v]
in_class = [int(k) for k, v in classes.items() if true_label in v]
in_labels = helpers.format_dataframe(imagenet_dir, filter = None)

in_samples = in_labels[in_labels['label'].isin(in_class)]['name'].values

in_examples = [
    load_image(os.path.join(imagenet_dir, im)) for im in in_samples
]

In [None]:
# Evaluate the model on every ImageNet-R image of the selected class.

model = helpers.load_model(model_name, device)

# Standard ImageNet per-channel statistics used to standardise model inputs.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> float tensor -> per-channel standardisation.
normalize = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Stack all normalised images into a single batch tensor.
x = torch.stack(list(map(normalize, images)))

preds = helpers.evaluate_model_on_samples(x, model, batch_size)

In [None]:
# pick some images and compute their WCAMs

num_samples = 3
grid_size = 28
nb_design = 4
opt = {'approximation' : False}

np.random.seed(42)
# Sample without replacement so the same image is never explained (and
# plotted) twice; cap the count because sampling without replacement cannot
# draw more items than there are images.
num_samples = min(num_samples, len(preds))
indices = np.random.choice(len(preds), num_samples, replace=False)

# The sampled ImageNet-R images, followed by one ImageNet image of the same class.
candidates = [images[i] for i in indices]
candidates.append(in_examples[0])

x_wavelet = torch.stack([
    normalize(candidate) for candidate in candidates
])

# Target class per candidate: the model's prediction for each ImageNet-R image,
# then the first matching ImageNet class for the reference image (last entry).
y = np.array([
    preds[i] for i in indices
])
y = np.append(y, int(in_class[0]))

wavelet = WaveletSobol(model, grid_size = grid_size, nb_design = nb_design, batch_size = batch_size, opt = opt)
explanations = wavelet(x_wavelet, y.astype(int))

In [None]:
# Figure: one column per explained image. Top row: the image with its spatial
# CAM overlaid; bottom row: its WCAM. Column 0 holds the ImageNet reference
# image (the last explained candidate); the other columns hold the sampled
# ImageNet-R images.

# Set the font size before creating the figure so every text element uses it.
plt.rcParams.update({'font.size': 17})
fig, ax = plt.subplots(2, (num_samples + 1), figsize = (4 * (num_samples + 1), 8))


def plot_column(top, bottom, image, cam, explanation, title):
    """Draw `image` with `cam` overlaid on `top` and `explanation` on `bottom`."""
    top.set_title(title)
    top.imshow(image)
    top.imshow(cam, alpha = 0.5, cmap = 'jet')
    top.axis('off')

    bottom.set_title('WCAM')
    bottom.imshow(explanation, cmap = 'hot')
    # NOTE(review): presumably draws the wavelet-level separators for a
    # 224px image with 3 levels; confirm against helpers.add_lines.
    helpers.add_lines(224, 3, bottom)
    bottom.axis('off')


plot_column(ax[0, 0], ax[1, 0], in_examples[0],
            wavelet.spatial_cam[-1], explanations[-1], 'ImageNet example')

for k, idx in enumerate(indices):
    prediction = classes[str(int(preds[idx]))].split(',')[0]
    plot_column(ax[0, k + 1], ax[1, k + 1], images[idx],
                wavelet.spatial_cam[k], explanations[k],
                'Pred : {}'.format(prediction))

plt.suptitle('True class : {}'.format(labels.loc[index]['label']))
fig.tight_layout()

# savefig does not create missing directories
os.makedirs('../figs', exist_ok = True)
plt.savefig('../figs/rendition_pred_{}_{}.pdf'.format(model_name, index))

plt.show()