# Robustification effects

In this notebook, we measure how standard robustness-enhancement techniques improve the robustness of the models. We also show that the properties highlighted for baseline models also hold for robust models. 

We consider three settings:
* ST (standard training), which has been studied in the first notebook;
* AT (adversarial training): we study how adversarial training affects the model's reliance on frequencies;
* RT (robust training): we study how robust training affects the model's reliance on frequencies.

Punchline: a difference in degree, not in nature. The remaining (albeit fewer) errors behave in the same way as for the ST model.

In [None]:
import sys
sys.path.append('../')

# library imports
import numpy as np
import os
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
from utils import helpers, corruptions
import torch
import cv2
from torchvision.models import resnet50
from spectral_sobol.torch_explainer import FourierSobol, WaveletSobol
from robustness.datasets import ImageNet
from robustness.model_utils import make_and_restore_model

In [None]:
# ---- user-defined inputs ----
n_samples = 100                     # how many validation images are drawn and explained
data_dir = "../../data/ImageNet"    # directory scanned for the validation .JPEG files
device = 'cuda:2'                   # device onto which the baseline model is moved

In [None]:
# set ups

# data: a set of distinct samples from the IN validation set.
# The listing is sorted so that the seeded draw below is reproducible
# (os.listdir returns entries in an arbitrary, filesystem-dependent order).
samples = sorted(s for s in os.listdir(data_dir) if s.endswith('.JPEG'))
np.random.seed(42)
# sample WITHOUT replacement: the default (replace=True) can draw the same image
# several times, which would bias the averages computed in the next cells.
# Raises ValueError if fewer than n_samples images are available.
samples = np.random.choice(samples, n_samples, replace = False)

# load the labels dataframe, restricted to the selected samples
labels = helpers.format_dataframe(data_dir, filter = samples)

# transforms: geometric preprocessing applied to the PIL images ...
preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224)
])

# ... and conversion to a normalized tensor (applied right before the model)
normalize = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
])

# load the images; the context manager closes each file once it has been decoded
images = []
for s in samples:
    with Image.open(os.path.join(data_dir, s)) as img:
        images.append(preprocess(img.convert('RGB')))

# generate the labels (exactly one row per sample is expected in `labels`)
y = np.array([
    labels[labels['name'] == s]['label'].values.item() for s in samples
])

# images passed to the model
x = torch.stack([
    normalize(im) for im in images
])

assert len(x) == len(y) == n_samples

In [None]:
# model set up
# model zoo:
#   - RT : 'augmix', 'pixmix', 'sin' (highest accuracy on ImageNet-C)
#   - AT : 'adv_free', 'fast_adv' and 'adv'
#   - ST : standard training ('baseline')

models_dir = '../../models/spectral-attribution-baselines'
cases = ['baseline', 'augmix', 'pixmix', 'sin', 'adv_free', 'fast_adv', 'adv']


def _strip_module_prefix(state_dict):
    """Return `state_dict` with the 'module.' prefix (added by nn.DataParallel) removed from its keys.

    Keys that do not start with the prefix are kept unchanged.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# load the models, in the same order as `cases` (later cells zip them together)
# NOTE(review): only the baseline is moved to `device`, the other models stay on
# CPU as in the original analysis; the explainers are assumed to handle this.
models = []
for case in cases:

    if case == 'baseline':
        model = resnet50(pretrained = True).to(device).eval()
    elif case in ['augmix', 'pixmix', 'sin']:
        model = resnet50(pretrained = False)  # backbone, weights loaded below
        weights = torch.load(os.path.join(models_dir, '{}.pth.tar'.format(case)), map_location = 'cpu')
        # strict=False silently skips every key that does not match (e.g. keys
        # carrying a DataParallel 'module.' prefix), which would leave the model
        # randomly initialized: strip the prefix and report missing parameters.
        incompatible = model.load_state_dict(_strip_module_prefix(weights['state_dict']), strict = False)
        if incompatible.missing_keys:
            print('Warning ({}): {} parameters not found in the checkpoint'.format(
                case, len(incompatible.missing_keys)))
        model.eval()
    elif case in ['adv_free', 'fast_adv', 'adv']:
        model = resnet50(pretrained = False)  # backbone, weights loaded below
        weights = torch.load(os.path.join(models_dir, '{}.pth'.format(case)), map_location = 'cpu')
        model.load_state_dict(weights)
        model.eval()
    else:
        # otherwise the previous model would silently be appended a second time
        raise ValueError('Unknown case: {}'.format(case))

    # we will loop over the models
    models.append(model)

## Frequency importance

We leverage the Fourier-CAM to measure the average frequency importance over the selected batch of images. 

### Importance of the frequency components using the circular masks

In [None]:
# Fourier explainer with circular masks; 'square' and 'grid' replicate former
# papers' results (see the appendices)
perturbation = 'circle'
grid_size = 14

results = {}

for case, model in zip(cases, models):
    print('Case ........... {}'.format(case))
    fourier = FourierSobol(
        model,
        grid_size = grid_size,
        nb_design = 16,
        batch_size = 128,
        perturbation = perturbation,
    )
    # run the explainer; the per-component importances are then read from it
    fourier(x, y)
    results[case] = fourier.components


In [None]:
# importance of each frequency component, averaged per training mode
plt.rcParams.update({'font.size': 12})


def _mean_importance(case_name):
    """Sum of the component importances of one case divided by `n_samples`."""
    return np.sum(results[case_name], axis = 0) / n_samples


def _group_mean(case_names):
    """Average of `_mean_importance` over the cases of one training mode."""
    return np.sum([_mean_importance(c) for c in case_names], axis = 0) / len(case_names)


average_baseline = _mean_importance('baseline')
average_robust = _group_mean(['augmix', 'pixmix', 'sin'])
average_adversarial = _group_mean(['adv_free', 'fast_adv', "adv"])

# one entry per training mode, in plotting order
average_contributions = {
    'ST' : average_baseline,
    'RT' : average_robust,
    'AT' : average_adversarial,
}

# grouped bars: each mode is shifted by 0.33 around the component index
for i, (case, average_contribution) in enumerate(average_contributions.items()):
    positions = np.arange(len(average_contribution)) - 0.33 + 0.33 * i
    plt.bar(positions, average_contribution, label = case, width = 0.2)

plt.legend()

plt.xlabel('Frequency component\n (The higher the higher the frequency)')
plt.ylabel('Importance (STI)')
plt.title('Importance of the frequency components (circular case)')
# only the first components are shown: many circular components have no contribution
upper = (grid_size // 2) + (grid_size // 4)
plt.xlim(-1, upper)
plt.xticks(list(range(upper)), rotation = 60)

plt.show()

### WCAM frequency importance

We show that the WCAM yields results similar to those of the previous plot.

In [None]:
# wavelet explainer settings (the approximation coefficients are kept here)
grid_size = 28
levels = 5
opt = dict(approximation = True, size = grid_size)

results = {}

for case, model in zip(cases, models):
    print('Case ............... {}'.format(case))
    wavelet = WaveletSobol(
        model,
        grid_size = grid_size,
        nb_design = 8,
        batch_size = 128,
        opt = opt,
        levels = levels,
    )
    results[case] = wavelet(x, y)

In [None]:
# average the WCAMs of each case, then aggregate them per training mode
levels = 5  # must match the `levels` passed to the WaveletSobol explainers

case_names = list(results.keys())
means = np.zeros((len(case_names), grid_size, grid_size))

for i, case in enumerate(case_names):

    # NOTE(review): the first baseline explanation is skipped, as in the original
    # analysis; the reason is not visible here -- confirm it is still needed
    wcams = results[case][1:] if case == "baseline" else results[case]

    for wcam in wcams:
        means[i] += cv2.resize(wcam, (grid_size, grid_size))

    # average over the maps actually summed (the baseline has one fewer)
    means[i] /= len(wcams)
    # remove the 0th component (approximation coefficient) of THIS case;
    # `means[0, 0]` would instead wipe the entire first row of the baseline map
    means[i, 0, 0] = 0

# work on the explanations to recover the frequencies
averaged_coefficients = [helpers.reshape_wcam(means[i], grid_size, levels = levels) for i in range(means.shape[0])]

# aggregate the contributions per ST, RT, AT, selecting the cases by name
position = {c: i for i, c in enumerate(case_names)}
baseline_coefficients = averaged_coefficients[position['baseline']]
robust_coefficients = np.mean(
    [averaged_coefficients[position[c]] for c in ['augmix', 'pixmix', 'sin']], axis = 0)
adversarial_coefficients = np.mean(
    [averaged_coefficients[position[c]] for c in ['adv_free', 'fast_adv', 'adv']], axis = 0)

averaged_coefficients = [baseline_coefficients, robust_coefficients, adversarial_coefficients]

# plots
plt.rcParams.update({'font.size': 12})

group_cases = ['ST', 'RT', 'AT']

# the first coefficient (approximation) is dropped and the rest normalized to sum to 1
for i, (average_contribution, case) in enumerate(zip(averaged_coefficients, group_cases)):
    offset = 0.33 * i
    plt.bar(np.array([*range(len(average_contribution) - 1)]) - 0.33 + offset , average_contribution[1:] / sum(average_contribution[1:]), label = case, width = 0.2)

# one (h, d, v) triplet per level; a distinct name keeps the `labels` dataframe intact
tick_labels = [
    tick
    for level in range(1, levels + 1)
    for tick in ('h', 'd \n{}'.format(level), 'v')
]

plt.xticks([*range(len(tick_labels))], labels = tick_labels)
plt.title('Importance of the frequency components (WCAM case) \n (without approximation coefficient)')
plt.xlabel('Frequency/level')
plt.ylabel('Importance (normalized STIs)')
plt.legend()
plt.savefig('../figs/wcam-fcam-consistency-complete-no-approx.pdf')
plt.show()

## Instance-based behavior is the same across models

We partially replicate the results from the first notebook to show that, even for robust models, sensitivity to distribution shifts translates into associating an image with the wrong spectral profile. 

In [None]:
# parameters shared by the cells of this section
index = 2              # position, in `cases` / `models`, of the model studied here
case = cases[index]
model = models[index]
n = 3                  # how many images are corrupted and explained


In [None]:
# generate corrupted images for a set of distinct instances
np.random.seed(42)
# replace=False: a duplicated name would silently collapse in the dict below
image_names = np.random.choice(samples, n, replace = False)

corrupted_images = {
    image_name : [preprocess(im) for im in corruptions.generate_corruptions(os.path.join(data_dir, image_name))]
    for image_name in image_names
}

# evaluate the model on the corrupted images
batch_size = 128
labels = helpers.format_dataframe(data_dir)  # unfiltered label table

model = models[index]  # the model selected in the parameters cell

results = {}
for image_name, sequence in corrupted_images.items():
    # local name: the global `x` holds the clean samples reused by the appendix cells
    x_corrupted = torch.stack(
        [normalize(im) for im in sequence]
    )

    preds = helpers.evaluate_model_on_samples(x_corrupted, model, batch_size)
    results[image_name] = np.array(preds)

In [None]:
# compute the class spectral profiles on the fly

classes = {}
n_items = 10  # number of images of a class used to build its profile
grid_size = 28

opt = {
    'approximation' : False,
    'size' : grid_size
}

# explainer
wavelet = WaveletSobol(model, grid_size = grid_size, nb_design = 4, opt = opt)

# parameters passed to the computation of the spectral profile
params = {
    'source_dir' : data_dir,
    'preprocessing' : preprocess,
    'normalize' : normalize,
    "explainer" : wavelet

}


def _sample_class_items(class_label, excluded_image):
    """Draw, with a fixed seed, up to `n_items` distinct images of `class_label` other than `excluded_image`."""
    candidates = [
        name for name in labels[labels['label'] == class_label]['name'].values
        if name != excluded_image
    ]
    np.random.seed(42)
    return np.random.choice(candidates, size = min(n_items, len(candidates)), replace = False)


for image in results.keys():

    # predictions of the model on the sequence of versions of `image`;
    # index 0 is the version shown as 'Original image' in the plotting cell.
    # (`preds` must be read for THIS image before `label` is derived from it)
    preds = results[image]
    label = preds[0]
    # classes predicted on some version that differ from the original prediction
    shifted_labels = np.unique(preds[preds != label])

    # compute the WCAM of the predicted label
    if label not in classes:
        classes[label] = helpers.compute_spectral_profile(_sample_class_items(label, image), label, params)

    # compute the WCAM of the shifted labels, from images of THAT label
    for p in shifted_labels:
        if p not in classes:
            classes[p] = helpers.compute_spectral_profile(_sample_class_items(p, image), p, params)

In [None]:
# compute the wcams for every version of each image
wcams = {}

for image_name, sequence in corrupted_images.items():
    # local names: keep the global `x`, `y` and `images` (clean samples) intact
    x_corrupted = torch.stack([
        normalize(im) for im in sequence
    ])

    # explain the class predicted on each version; int64, not uint8: ImageNet
    # labels go up to 999 and would wrap around modulo 256 in uint8
    y_corrupted = results[image_name].astype(np.int64)

    # compute the wcams
    wcams[image_name] = wavelet(x_corrupted, y_corrupted)

In [None]:
# plot the class profiles next to the image-based wcams
plt.rcParams.update({'font.size': 17})
input_shape = (224,224)

images_list = list(wcams.keys())
image_index = 2  # which explained image to display (distinct from the model `index`)
k = 9            # version of the image displayed in the second row

image_name = images_list[image_index]

# get the predictions, WCAMs and images of every version of this image
preds = results[image_name]
img_wcams = wcams[image_name]
source_images = corrupted_images[image_name]

print(preds)
assert k < len(preds), 'only {} versions are available for {}'.format(len(preds), image_name)

fig, ax = plt.subplots(2,3, figsize = (15,8))

# row 0: version 0 (original image); row 1: version k (corrupted image)
for row, (version, image_title) in enumerate([(0, 'Original image'), (k, 'Corrupted image')]):
    label = preds[version]

    ax[row, 0].set_title(image_title)
    ax[row, 0].imshow(source_images[version])

    # WCAM of this version for its predicted class, and that class's average WCAM
    ax[row, 1].set_title('Image WCAM - class : {}'.format(preds[version].astype(int)))
    ax[row, 1].imshow(cv2.resize(img_wcams[version], input_shape), cmap = 'jet')
    ax[row, 2].set_title('Class average WCAM')
    ax[row, 2].imshow(cv2.resize(classes[label], input_shape), cmap = 'jet')
    helpers.add_lines(224, 3, ax[row, 2])
    helpers.add_lines(224, 3, ax[row, 1])

    for axis in ax[row]:
        axis.axis('off')

plt.suptitle('WCAM for the original and corrupted image. Training setup : {}'.format(case))
# tight_layout must run before savefig, otherwise the saved PDF keeps the untightened layout
fig.tight_layout()
plt.savefig('../figs/class-swap-example-{}.pdf'.format(case))
plt.show()

## Appendices

### Square masks

Replication of the frequency components importance plot with square masks [(Zhang et al., 2022)](https://www.sciencedirect.com/science/article/abs/pii/S0925231222001084)

In [None]:
# Fourier explainer with square masks (other values are 'circle' and 'grid'),
# replicating former paper's results
perturbation = 'square'
grid_size = 14
results = {}

# this cell explains the clean samples `x`, `y` built in the set-up cell: fail
# loudly if a later cell has overwritten them (e.g. with corrupted images)
assert len(x) == len(y) == n_samples, 'x / y no longer hold the {} clean samples'.format(n_samples)

for case, model in zip(cases, models):
    print('Case ........... {}'.format(case))
    fourier = FourierSobol(model, grid_size = grid_size, nb_design = 16, batch_size=128, perturbation = perturbation)
    fourier(x, y)
    results[case] = fourier.components

# plots
# represent the importance of each frequency component per training mode
plt.rcParams.update({'font.size': 12})


def _average_importance(case_names):
    """Per-sample average of the component importances, then averaged over `case_names`."""
    per_case = [np.sum(results[c], axis = 0) / n_samples for c in case_names]
    return np.sum(per_case, axis = 0) / len(per_case)


# average the contributions for each training mode
average_contributions = {
    'ST' : _average_importance(['baseline']),
    'RT' : _average_importance(['augmix', 'pixmix', 'sin']),
    'AT' : _average_importance(['adv_free', 'fast_adv', 'adv']),
}

# grouped bars: each mode is shifted by 0.33 around the component index
for i, (case, average_contribution) in enumerate(average_contributions.items()):
    positions = np.arange(len(average_contribution)) - 0.33 + 0.33 * i
    plt.bar(positions, average_contribution, label = case, width = 0.25)

plt.legend()

# plot labels
plt.xlabel('Frequency component\n (The higher the higher the frequency)')
plt.ylabel('Importance (STI)')
plt.title('Importance of the frequency components (square case)')
plt.xticks([*range(0, grid_size)], rotation = 60)

plt.savefig("../figs/frequency-components-importance-squares-complete.pdf")

plt.show()

### Visualization of the importance in the Nyquist square

Replication of the same results with grid masks [(Chen et al. 2022)](https://openreview.net/forum?id=rQ1cNbi07Vq)

In [None]:
# Fourier explainer with grid masks: importance over the Nyquist square
perturbation = 'grid'
grid_size = 17
results = {}

# this cell explains the clean samples `x`, `y` built in the set-up cell: fail
# loudly if a later cell has overwritten them (e.g. with corrupted images)
assert len(x) == len(y) == n_samples, 'x / y no longer hold the {} clean samples'.format(n_samples)

for case, model in zip(cases, models):
    print('Case ............. {}'.format(case))

    fourier = FourierSobol(model, grid_size = grid_size, nb_design = 8, batch_size=128, perturbation = perturbation)
    explanations = fourier(x,y)
    results[case] = explanations

In [None]:
# one heat map per case: average (over samples without NaNs) of the grid explanations
plt.rcParams.update({'font.size': 15})

case_names = list(results.keys())
# squeeze=False keeps `ax` indexable even with a single case
fig, ax = plt.subplots(1, len(case_names), figsize = (4 * len(case_names), 4), squeeze = False)
ax = ax[0]

# tick labels are offsets from the centre of the grid
tick_values = np.arange(grid_size) - (grid_size // 2)

for i, case in enumerate(case_names):

    mean = np.zeros((grid_size, grid_size))

    count = 0
    for explanation in results[case]:
        # explanations containing NaNs are left out of the average
        if not np.isnan(explanation).any():
            mean += cv2.resize(explanation, (grid_size, grid_size))
            count += 1

    if count > 0:
        mean /= count
    else:
        mean[:] = np.nan  # every explanation contained NaNs: nothing to average

    ax[i].set_title('{}'.format(case))
    im = ax[i].imshow(np.log(1 + mean), cmap = 'jet')

    ax[i].set_xticks(list(range(grid_size)))
    ax[i].set_xticklabels(tick_values)

    ax[i].set_yticks(list(range(grid_size)))
    ax[i].set_yticklabels(tick_values)

    fig.colorbar(im, ax = ax[i], shrink = 0.8)

plt.suptitle('Frequency importance in the Nyquist square for baseline, robust and adversarial training')
fig.tight_layout()

plt.show()
