# Zoom-in and accuracy: how much of the new information do models use?

This notebook contains the source code used to generate the figures presented in Section 4.3.

In [1]:
# Libraries and imports
import sys
# Add the parent directory to the module search path. Presumably this is what
# lets the project-local `utils` package be imported from this notebook's
# folder -- confirm against the repository layout.
sys.path.append('../')

## Table generation

Retrieve the values reported in Table 1.

In [2]:
import numpy as np
import json
from utils import helpers
import os

2023-05-04 10:23:46.410898: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
data_path = 'path/to/data/outputs'
imagenet_path = 'path/to/imagenet'
filename = 'zoom_importance.json'

# Read the results inside a context manager so the file handle is closed as
# soon as the JSON has been parsed (a bare `open(...)` leaves it dangling).
with open(os.path.join(data_path, filename)) as results_file:
    results = json.load(results_file)
images = results['images']

# every top-level key except 'images' is treated as a case
cases = [k for k in results if k != 'images']

# for all cases, retrieve the tables and averages
# Layout: 5 rows, and two adjacent columns per case (regular, then zoomed).
aggregate = np.zeros((5, 2 * len(cases)))
stds = np.zeros((5, 2 * len(cases)))
for i, case in enumerate(cases):
    # mean and standard deviation along axis 1 of the stored values
    regular = np.mean(results[case]['regular'], axis = 1)
    reg_std = np.std(results[case]['regular'], axis = 1)
    zoomed = np.mean(results[case]['zoomed'], axis = 1)
    zoom_std = np.std(results[case]['zoomed'], axis = 1)

    # add a 0 value for the 4th level of the regular wcam
    # (its std is set to NaN so that it can be told apart from a real value)
    regular = np.append(regular, 0.)
    reg_std = np.append(reg_std, np.nan)

    aggregate[:,2*i] = regular
    aggregate[:,2*i+1] = zoomed

    stds[:,2*i] = reg_std
    stds[:,2*i+1] = zoom_std

In [21]:
cases = [k for k in results.keys() if not k == 'images']


def _case_statistics(results, case):
    """Return (regular, reg_std, zoomed, zoom_std) for one case.

    Each array is the mean (or standard deviation) along axis 1 of
    results[case]['regular'] / results[case]['zoomed'].  The regular arrays
    get one extra trailing entry (0 for the mean, NaN for the std) for the
    4th level of the regular wcam, so that they line up with the zoomed ones.
    """
    regular = np.mean(results[case]['regular'], axis = 1)
    reg_std = np.std(results[case]['regular'], axis = 1)
    zoomed = np.mean(results[case]['zoomed'], axis = 1)
    zoom_std = np.std(results[case]['zoomed'], axis = 1)

    # add a 0 value for the 4th level of the regular wcam
    regular = np.append(regular, 0.)
    reg_std = np.append(reg_std, np.nan)

    return regular, reg_std, zoomed, zoom_std


def _group_average(results, group):
    """Average the statistics of several cases.

    Returns two (5, 2) arrays (means, stds) whose columns are
    (regular, zoomed).  Values are accumulated case by case and divided by
    the number of cases at the end.
    """
    means = np.zeros((5, 2))
    deviations = np.zeros((5, 2))
    for case in group:
        regular, reg_std, zoomed, zoom_std = _case_statistics(results, case)
        means[:,0] += regular
        deviations[:,0] += reg_std
        means[:,1] += zoomed
        deviations[:,1] += zoom_std
    return means / len(group), deviations / len(group)


# averages over the two groups of three cases
# columns 0-1: first group (regular, zoomed); columns 2-3: second group
mean_cases = np.zeros((5, 4))
mean_stds = np.zeros((5, 4))

mean_cases[:,:2], mean_stds[:,:2] = _group_average(results, ['sin', 'augmix', 'pixmix'])
mean_cases[:,2:], mean_stds[:,2:] = _group_average(results, ['adv', 'adv_free', 'fast_adv'])

# for all cases, retrieve the tables and averages
# column layout: baseline | first group | second group | vit,
# each as a (regular, zoomed) pair

aggregate = np.zeros((5, 8))
stds = np.zeros((5, 8))

aggregate[:,0], stds[:,0], aggregate[:,1], stds[:,1] = _case_statistics(results, 'baseline')
aggregate[:,-2], stds[:,-2], aggregate[:,-1], stds[:,-1] = _case_statistics(results, 'vit')

aggregate[:,2:4] = mean_cases[:,:2]
aggregate[:,4:6] = mean_cases[:,2:]
stds[:,2:4] = mean_stds[:,:2]
stds[:,4:6] = mean_stds[:,2:]

In [31]:
print('Table coefficients')

# One pair of printed lines per row: the means, then the standard deviations
# in parentheses.  Cells are '&'-separated; a NaN standard deviation is
# rendered as '(-)'.
for level, (means, deviations) in enumerate(zip(aggregate, stds)):

    mean_cells = []
    std_cells = []
    for mean_value, std_value in zip(means, deviations):
        mean_cells.append('&{:0.3f}\t '.format(mean_value))
        std_cells.append(
            '&(-) ' if np.isnan(std_value) else '&({:0.3f})\t '.format(std_value)
        )

    print(' {}  '.format(level) + ''.join(mean_cells))
    print(''.join(std_cells) + '\n')

Table coefficients
 0  &0.837	 &0.752	 &0.895	 &0.869	 &0.954	 &0.906	 &0.830	 &0.611	 
&(0.064)	 &(0.137)	 &(0.082)	 &(0.101)	 &(0.033)	 &(0.080)	 &(0.075)	 &(0.164)	 

 1  &0.130	 &0.190	 &0.077	 &0.095	 &0.039	 &0.080	 &0.137	 &0.295	 
&(0.053)	 &(0.107)	 &(0.063)	 &(0.077)	 &(0.029)	 &(0.071)	 &(0.063)	 &(0.132)	 

 2  &0.028	 &0.047	 &0.023	 &0.030	 &0.006	 &0.013	 &0.029	 &0.078	 
&(0.012)	 &(0.030)	 &(0.024)	 &(0.033)	 &(0.005)	 &(0.011)	 &(0.018)	 &(0.043)	 

 3  &0.005	 &0.010	 &0.004	 &0.005	 &0.001	 &0.001	 &0.004	 &0.014	 
&(0.002)	 &(0.006)	 &(0.005)	 &(0.005)	 &(0.001)	 &(0.001)	 &(0.002)	 &(0.008)	 

 4  &0.000	 &0.001	 &0.000	 &0.001	 &0.000	 &0.000	 &0.000	 &0.002	 
&(-) &(0.001)	 &(-) &(0.001)	 &(-) &(0.000)	 &(-) &(0.001)	 



In [4]:
# NOTE(review): the header printed here lists every case, which only lines up
# with `aggregate` when it holds two columns per case -- confirm which
# aggregation cell was run last before trusting the column order.
print(cases)
print('Table coefficients')

# One pair of printed lines per row: the means, then the standard deviations
# in parentheses (a NaN standard deviation is rendered as '(-)').
for level, (means, deviations) in enumerate(zip(aggregate, stds)):

    mean_cells = []
    std_cells = []
    for mean_value, std_value in zip(means, deviations):
        mean_cells.append('{:0.3f}\t'.format(mean_value))
        std_cells.append(
            '(-)  \t' if np.isnan(std_value) else '({:0.3f})\t'.format(std_value)
        )

    print('Level : {}   '.format(level) + ''.join(mean_cells))
    print('            ' + ''.join(std_cells) + '\n')

['baseline', 'sin', 'adv', 'augmix', 'vit', 'pixmix', 'adv_free', 'fast_adv']
Table coefficients
Level : 0   0.837	0.752	0.890	0.873	0.953	0.903	0.892	0.873	0.830	0.611	0.904	0.863	0.951	0.906	0.959	0.909	
            (0.064)	(0.137)	(0.082)	(0.102)	(0.036)	(0.077)	(0.088)	(0.088)	(0.075)	(0.164)	(0.074)	(0.114)	(0.033)	(0.081)	(0.030)	(0.081)	

Level : 1   0.130	0.190	0.082	0.091	0.040	0.082	0.080	0.093	0.137	0.295	0.069	0.100	0.042	0.079	0.035	0.078	
            (0.053)	(0.107)	(0.063)	(0.075)	(0.031)	(0.069)	(0.070)	(0.070)	(0.063)	(0.132)	(0.055)	(0.086)	(0.029)	(0.073)	(0.027)	(0.072)	

Level : 2   0.028	0.047	0.024	0.030	0.006	0.013	0.024	0.029	0.029	0.078	0.022	0.032	0.006	0.013	0.005	0.012	
            (0.012)	(0.030)	(0.025)	(0.032)	(0.006)	(0.010)	(0.023)	(0.030)	(0.018)	(0.043)	(0.023)	(0.036)	(0.006)	(0.011)	(0.005)	(0.011)	

Level : 3   0.005	0.010	0.004	0.005	0.001	0.001	0.004	0.005	0.004	0.014	0.004	0.005	0.001	0.002	0.001	0.001	
            (0.002)	(0.006)	(0.006)	(0.00

## Illustrations

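Take an image and prepare two versions of it: a regular one (resized to 256 pixels, then center-cropped to 224 × 224) and a zoomed-in one (resized to 512 pixels, then center-cropped to 224 × 224). Then compare how the WCAM changes between the two versions.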

In [None]:
import torch
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
import os
import numpy as np
from spectral_sobol.torch_explainer import WaveletSobol
import cv2

In [None]:
# parameters, images and explanations

num_examples = 10 # number of examples to evaluate

model_type = 'baseline' # model type (not called `type`, which would shadow the builtin)
device = 'multi' # device on which the model is sent
images_dir = '../../data/ImageNet/'
labels = helpers.format_dataframe(images_dir, images[:num_examples])

grid_size = 32 # grid size and options
opt = {'size' : grid_size}

# model and explainer
# two explainers that differ only in the number of levels (3 vs 4)
model = helpers.load_model(model_type, device)
wavelet_3 = WaveletSobol(model, grid_size = grid_size, nb_design = 4, batch_size = 256, levels = 3, opt = opt)
wavelet_4 = WaveletSobol(model, grid_size = grid_size, nb_design = 4, batch_size = 256, levels = 4, opt = opt)

# preprocessing pipelines
# regular view: resize to 256, then keep the central 224 x 224 crop
baseline_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224)
])

# zoomed-in view: resize to 512, then keep the central 224 x 224 crop
zoomed_in = torchvision.transforms.Compose([
    torchvision.transforms.Resize(512),
    torchvision.transforms.CenterCrop(224)
])

normalize = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])


def _load_rgb(path):
    """Open the image at `path`, return an RGB copy and close the file."""
    with Image.open(path) as handle:
        return handle.convert('RGB')


# images and their label
# Each file is decoded once; both views are derived from the same RGB image.
source_images = [
    _load_rgb(os.path.join(images_dir, name)) for name in labels['name']
]

images_baseline = [baseline_transforms(im) for im in source_images]
images_zoomed_in = [zoomed_in(im) for im in source_images]

x_baseline = torch.stack([
    normalize(im) for im in images_baseline
])

x_zoomed_in = torch.stack([
    normalize(im) for im in images_zoomed_in
])

# Use a 64-bit integer dtype: uint8 silently wraps any value above 255, which
# would corrupt class indices (presumably ImageNet indices, up to 999 -- confirm).
y = labels['label'].values.astype(np.int64)

# compute the explanations
expl_baseline = wavelet_3(x_baseline,y)
expl_zoomed_in = wavelet_4(x_zoomed_in,y)

In [None]:
size = 224 # side (in pixels) of the displayed images and WCAMs

# which example to plot
index = 7

# where the figure is written; create the folder if it is missing so that
# savefig does not fail on a fresh checkout
figure_path = '../figs/wcam_zoom_example.pdf'
os.makedirs(os.path.dirname(figure_path), exist_ok=True)

fig, ax = plt.subplots(2,2, figsize = (8,8))
plt.rcParams.update({'font.size': 15})

# top row: the two views of the image
ax[0,0].set_title('Regular image')
ax[0,0].imshow(images_baseline[index])
ax[0,0].axis('off')

ax[0,1].set_title('Zoomed in image')
ax[0,1].imshow(images_zoomed_in[index])
ax[0,1].axis('off')

# bottom row: the corresponding WCAMs, resized to the image size
# (the resize target uses `size` instead of a second hard-coded 224)
ax[1,0].set_title('Regular WCAM')
wcam_baseline = cv2.resize(expl_baseline[index], (size, size))
ax[1,0].imshow(wcam_baseline, cmap = 'jet')
ax[1,0].axis('off')
helpers.add_lines(size, 3, ax[1,0])

wcam_zoom = cv2.resize(expl_zoomed_in[index], (size, size))
ax[1,1].imshow(wcam_zoom, cmap = 'jet')
ax[1,1].axis('off')
helpers.add_lines(size, 4, ax[1,1])
ax[1,1].set_title('Zoomed-in WCAM')

fig.tight_layout()

plt.savefig(figure_path)

plt.show()