# Zoom-in and accuracy: how much of the new information do models use?

This notebook contains the source code used to generate the figures presented in Section 4.3.

In [None]:
# Libraries and imports
import sys
sys.path.append('../')

## Table generation

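Retrieve the values reported in Table 1: for each case, the mean importance of each wavelet level for the regular and the zoomed-in WCAMs, with standard deviations in parentheses.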

In [None]:
import numpy as np
import json
from utils import helpers
import os

In [None]:
data_path = '../../data/spectral-attribution-outputs'
imagenet_path = '../../data/ImageNet/'
filename = 'zoom_importance.json'

# Load the precomputed importance results. The context manager guarantees the
# file handle is closed as soon as the JSON has been parsed.
with open(os.path.join(data_path, filename)) as results_file:
    results = json.load(results_file)
images = results['images']

# every top-level key except 'images' is an experimental case
cases = [k for k in results.keys() if k != 'images']

# For all cases, retrieve the per-level means and standard deviations.
# Columns are interleaved per case: [case0 regular, case0 zoomed, case1 regular, ...].
# NOTE(review): the mean/std are taken over axis 1, which presumably indexes the
# images while axis 0 indexes the wavelet levels; confirm against the JSON producer.
n_levels = 5  # number of table rows (levels of the zoomed-in WCAM)
aggregate = np.zeros((n_levels, 2 * len(cases)))
stds = np.zeros((n_levels, 2 * len(cases)))
for i, case in enumerate(cases):
    regular = np.mean(results[case]['regular'], axis=1)
    reg_std = np.std(results[case]['regular'], axis=1)
    zoomed = np.mean(results[case]['zoomed'], axis=1)
    zoom_std = np.std(results[case]['zoomed'], axis=1)

    # The regular WCAM has one level fewer than the zoomed-in one. Pad its mean
    # with 0 and its std with NaN; the NaN is printed as '(-)' in the table.
    regular = np.append(regular, 0.)
    reg_std = np.append(reg_std, np.nan)

    aggregate[:, 2 * i] = regular
    aggregate[:, 2 * i + 1] = zoomed

    stds[:, 2 * i] = reg_std
    stds[:, 2 * i + 1] = zoom_std

In [None]:
print(cases)
print('Table coefficients')

# Two output lines per wavelet level: first the mean of every column, then the
# matching standard deviations in parentheses ('(-)' where the std is NaN,
# i.e. the padded level of the regular WCAM).
for level, (means_row, stds_row) in enumerate(zip(aggregate, stds)):
    mean_cells = ''.join('{:0.3f}\t'.format(m) for m in means_row)
    std_cells = ''.join(
        '(-)  \t' if np.isnan(s) else '({:0.3f})\t'.format(s)
        for s in stds_row
    )
    print('Level : {}   '.format(level) + mean_cells)
    print('            ' + std_cells + '\n')

## Illustrations

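Take an image and preprocess it in two ways. The regular version is resized to 256 and center-cropped to 224. The zoomed-in version is resized to 512 and center-cropped to 224. Then compare the WCAM computed on the regular image (3 levels) with the one computed on the zoomed-in image (4 levels) to see how the explanation changes.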

In [None]:
import torch
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
import os
import numpy as np
from spectral_sobol.torch_explainer import WaveletSobol
import cv2

In [None]:
# parameters, images and explanations

num_examples = 10  # number of examples to evaluate

# model type; named `model_type` (not `type`) so the builtin `type` is not
# shadowed for the rest of the kernel session
model_type = 'baseline'
device = 'multi'  # device on which the model is sent
images_dir = '../../data/ImageNet/'
labels = helpers.format_dataframe(images_dir, images[:num_examples])

grid_size = 32  # grid size and options
opt = {'size': grid_size}

# model and explainers: 3 levels for the regular images, 4 for the zoomed-in ones
model = helpers.load_model(model_type, device)
wavelet_3 = WaveletSobol(model, grid_size=grid_size, nb_design=4, batch_size=256, levels=3, opt=opt)
wavelet_4 = WaveletSobol(model, grid_size=grid_size, nb_design=4, batch_size=256, levels=4, opt=opt)

# preprocessing pipelines
baseline_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224)
])

# resizing to 512 before the same 224 center crop zooms in on the image center
zoomed_in = torchvision.transforms.Compose([
    torchvision.transforms.Resize(512),
    torchvision.transforms.CenterCrop(224)
])

normalize = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])


def _load_rgb(path):
    """Open the image at ``path``, convert it to RGB and release the file handle.

    ``convert`` returns a new in-memory image, so the result stays valid after
    the underlying file is closed.
    """
    with Image.open(path) as img:
        return img.convert('RGB')


# images and their label; each file is read once and both pipelines are
# applied to the same decoded image
raw_images = [_load_rgb(os.path.join(images_dir, name)) for name in labels['name']]

images_baseline = [baseline_transforms(im) for im in raw_images]
images_zoomed_in = [zoomed_in(im) for im in raw_images]

x_baseline = torch.stack([normalize(im) for im in images_baseline])
x_zoomed_in = torch.stack([normalize(im) for im in images_zoomed_in])

# int64, not uint8: uint8 silently wraps every label >= 256 (ImageNet class
# indices go up to 999), which would make the explainers target wrong classes
y = labels['label'].values.astype(np.int64)

# compute the explanations
expl_baseline = wavelet_3(x_baseline, y)
expl_zoomed_in = wavelet_4(x_zoomed_in, y)

In [None]:
size = 224  # side length (pixels) of the displayed WCAMs

# which example to plot
index = 7

figs_dir = '../figs'


def _show_wcam(axis, explanation, levels, title, size=size):
    """Resize a WCAM to ``size`` x ``size`` and draw it with its level grid.

    Returns the resized WCAM so the caller can keep a reference to it.
    """
    wcam = cv2.resize(explanation, (size, size))
    axis.imshow(wcam, cmap='jet')
    axis.axis('off')
    helpers.add_lines(size, levels, axis)
    axis.set_title(title)
    return wcam


fig, ax = plt.subplots(2, 2, figsize=(8, 8))
plt.rcParams.update({'font.size': 15})

ax[0, 0].set_title('Regular image')
ax[0, 0].imshow(images_baseline[index])
ax[0, 0].axis('off')

ax[0, 1].set_title('Zoomed in image')
ax[0, 1].imshow(images_zoomed_in[index])
ax[0, 1].axis('off')

# the regular WCAM has 3 levels, the zoomed-in one 4 (see the explainers above)
wcam_baseline = _show_wcam(ax[1, 0], expl_baseline[index], 3, 'Regular WCAM')
wcam_zoom = _show_wcam(ax[1, 1], expl_zoomed_in[index], 4, 'Zoomed-in WCAM')

fig.tight_layout()

# create the output directory if needed so saving does not fail on a fresh checkout
os.makedirs(figs_dir, exist_ok=True)
fig.savefig(os.path.join(figs_dir, 'wcam_zoom_example.pdf'))

plt.show()