# Benchmark

Notebook to reproduce Table 2 using models pretrained on Google images.

The models all have a ResNet-50 backbone and are the following:

- Standard: no particular augmentation beyond crops, resizes and rotations.
- AugMix: data augmentation using AugMix.
- Blurring and noising: random noise and blur are applied to improve robustness against these perturbations.
- Target: a model trained on IGN images, used to evaluate the oracle accuracy.


Our methods are the following:
- Blurring: a surrogate for scale removal that only requires applying a well-calibrated Gaussian blur to the image.
- Scale removal: based on the wavelet transform of the image.

In [None]:
# Libraries
import os
import pandas as pd
import numpy as np
import json
from PIL import Image, ImageEnhance
import matplotlib.pyplot as plt
from src import bdappv, utils, helpers
import torchvision
from torch.utils.data import DataLoader
import tqdm
import torch
import random
from functools import reduce 
import copy
from spectral_sobol.torch_explainer import WaveletSobol
from torch.nn import functional as F
import os
import pywt
from src import reconstruction
import cv2

In [None]:
# helpers

def rgb_to_wavelet_array(image, wavelet='haar', level=3):
    """
    Compute a displayable 2D wavelet decomposition of ``image``.

    The image is converted to grayscale (PIL mode 'L') and decomposed with
    ``pywt.wavedec2``. Each coefficient array (approximation and every detail
    band) is scaled independently by its maximum absolute value for better
    visibility, then the pyramid is packed into one array with
    ``pywt.coeffs_to_array``.

    Args:
        image: PIL image (any mode; it is converted to grayscale).
        wavelet: wavelet name passed to ``pywt.wavedec2``.
        level: number of decomposition levels.

    Returns:
        np.ndarray: the packed, per-band normalized coefficients.
    """
    def _normalize(coeffs):
        # Bands that are identically zero (e.g. the details of a constant
        # image) are returned unchanged instead of becoming 0/0 = NaN.
        peak = np.abs(coeffs).max()
        return coeffs / peak if peak > 0 else coeffs

    # Convert the PIL image to a single-channel (grayscale) NumPy array
    img_array = np.array(image.convert('L'))

    c = pywt.wavedec2(img_array, wavelet, level=level)
    c[0] = _normalize(c[0])
    # c[1:] holds one (horizontal, vertical, diagonal) tuple per level
    for detail_level in range(1, len(c)):
        c[detail_level] = [_normalize(d) for d in c[detail_level]]
    arr, _ = pywt.coeffs_to_array(c)

    return arr


def plot_wcam(ax, image, wcam, levels, vmin = None, vmax = None):
    """
    Draw a WCAM heatmap on top of the wavelet transform of ``image``.

    The grayscale wavelet decomposition of ``image`` (``rgb_to_wavelet_array``
    with ``level=levels``) is shown in gray. ``log(1 + wcam)``, shifted by its
    own minimum, is overlaid with the "hot" colormap at 70% opacity. The axis
    is hidden, and the level separators are drawn with ``utils.add_lines``.

    Args:
        ax: matplotlib axes to draw on.
        image: PIL image; its width ``image.size[0]`` is passed to
            ``utils.add_lines``.
        wcam: array of WCAM values to overlay.
        levels: number of wavelet decomposition levels.
        vmin, vmax: optional color limits in the raw WCAM scale; they are
            mapped through ``log(1 + x)`` before being used.
    """
    width = image.size[0]

    # background: wavelet coefficients of the image
    background = rgb_to_wavelet_array(image, level = levels)
    ax.imshow(background, cmap = 'gray')

    # color limits are expressed on the same log(1 + x) scale as the overlay
    if vmin is not None:
        vmin = np.log(1 + vmin)
    if vmax is not None:
        vmax = np.log(1 + vmax)

    log_wcam = np.log(1 + wcam)
    # NOTE(review): the overlay is shifted by its own minimum but vmin/vmax
    # are not, so explicit color limits do not line up with the drawn values.
    ax.imshow(np.min(log_wcam) + log_wcam, cmap = "hot", alpha = 0.7, vmin = vmin, vmax = vmax)

    ax.axis('off')
    utils.add_lines(width, levels, ax)

In [None]:
# Models
device = "cuda"
models_dir = 'models-google'

# Saved models are named "<prefix>_<name>.pth". Keep everything after the
# first underscore so that names containing underscores survive intact.
# Entries that are not .pth files (e.g. the "weights" entry) are skipped.
models_name = [
    os.path.splitext(m)[0].split('_', 1)[1]
    for m in os.listdir(models_dir)
    if m.endswith('.pth')
]
# Only the standard and oracle models are evaluated in this notebook.
# torch.load unpickles arbitrary objects: only load trusted checkpoints.
models = {
    m : torch.load(os.path.join(models_dir, "model_{}.pth".format(m))) for m in ["standard", "oracle"]
}

# set up the classification thresholds, read from the 'best_threshold' key
# of the "<prefix>_<name>.json" files in results_dir
thresholds = {}

results_dir = "results-google"

for file in os.listdir(results_dir):
    # skip anything that is not a results file (e.g. .ipynb_checkpoints)
    if not file.endswith('.json'):
        continue
    with open(os.path.join(results_dir, file)) as fh:
        f = json.load(fh)
    name = os.path.splitext(file)[0].split('_', 1)[1]
    thresholds[name] = f['best_threshold']

print(models.keys())

## Evaluation of the models (Table 2)

### Overall results

In [None]:
# Generate the test datasets: Google and IGN images, same test split

with open("data/images_lists.json") as fh:
    images_list = json.load(fh)

dataset_dir = "../../data/bdappv"
batch_size = 512

# baseline transforms: no corruptions
baseline = torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)),
])

# Per-source extra arguments: only the Google images are built with
# downsample=200 (presumably to match the IGN resolution - TODO confirm).
extra_kwargs = {
    'google': {'downsample': 200},
    'ign': {},
}

datasets = {}

for case, kwargs in extra_kwargs.items():

    path = os.path.join(dataset_dir, case)

    dataset = bdappv.BDAPPVClassification(
        path,
        size = 200,
        transform = baseline,
        images_list = images_list["test"],
        random = False,
        **kwargs,
    )
    datasets[case] = DataLoader(dataset, batch_size=batch_size)

In [None]:
# Plot a few examples of each test set (one row per dataset, one column per index).

indices = [10, 30, 128, 42]  # positions within the first batch of each loader
fig, ax = plt.subplots(2, 4, figsize = (16, 12))

def NormalizeData(data):
    """Min-max rescale ``data`` to [0, 1] for the plots.

    A constant input (zero range) returns an all-zero array of the same
    shape instead of the NaNs a 0/0 division would produce.
    """
    data = np.asarray(data)
    lo, hi = np.min(data), np.max(data)
    span = hi - lo
    if span == 0:
        return np.zeros(data.shape)
    return (data - lo) / span

for row, key in enumerate(datasets):
    # display examples from the first batch of each dataset
    images, labels, _ = next(iter(datasets[key]))

    for col, index in enumerate(indices):
        # Batched tensors are (C, H, W) while imshow expects (H, W, C).
        # swapaxes(0, 2) would give (W, H, C), i.e. a transposed image.
        example = images[index].numpy().transpose(1, 2, 0)
        ax[row, col].imshow(NormalizeData(example))
        ax[row, col].axis('off')

fig.tight_layout()
# plt.savefig('figures/examples.pdf')
plt.show()

In [None]:
# Evaluation: score every model on every test set using its own threshold.

results = {model_name: {} for model_name in models}

for name, model in models.items():
    threshold = thresholds[name]

    print('Evaluating {} .... '.format(name))

    for case, loader in datasets.items():
        results[name][case] = utils.evaluate(model, loader, device, threshold)

In [None]:
# Persist the evaluation results; helpers.NpEncoder is passed as the JSON
# encoder class (presumably to serialize NumPy values - verify in helpers).
with open('results-updated.json', 'w') as results_file:
    json.dump(results, results_file, cls=helpers.NpEncoder)

In [None]:
# Load the results saved by the evaluation cell (the file must already exist)
with open('results-updated.json') as fh:
    results = json.load(fh)

labels = {
    'ign' :    {},
    'google' : {}
}

probs = {
    "ign"    : {},
    'google' : {}
}

# IGN is the main-matter table, Google goes to the appendix
table_headers = {
    'ign': '---- Main matter table - case : {} ----',
    'google': '\n ---- Appendix table - case : {} ----',
}

for case, header in table_headers.items():

    print(header.format(case))

    for name in results.keys():

        # NOTE: 'fp' here is the false-positive count, not a file handle
        f1, tp, fp, tn, fn, predictions, probabilities, truth = results[name][case]
        s = '{} & {:0.2f} & {} & {} & {} & {}'.format(name, f1, tp, tn, fp, fn)

        labels[case][name] = predictions
        probs[case][name] = (probabilities, truth)

        print(s)

## Augmentations

Plots of the different augmentation strategies benchmarked in the paper. 

In [None]:
from src import transforms

# AutoAugment policy for the baseline strategies (RandAugment, AutoAugment,
# AugMix); swap those into transforms_set below to plot them instead.
policy = torchvision.transforms.AutoAugmentPolicy.IMAGENET

dataset_dir = "../../data/bdappv"
indices = [42, 95, 20]

# Load the example Google images.
# NOTE(review): os.listdir order is filesystem-dependent, so the same indices
# can select different images on another machine; consider sorted(...).
google_img_dir = os.path.join(dataset_dir, 'google', 'img')
google_img_files = os.listdir(google_img_dir)
input_images = []
for index in indices:
    # convert() loads the pixels into a new image, so the file can be closed
    with Image.open(os.path.join(google_img_dir, google_img_files[index])) as raw:
        input_images.append(raw.convert('RGB'))

# Our augmentation strategies, paired with their display names
transforms_set = [
    transforms.NoiseAndBlur(sigma_b = 2., sigma_n = 0., deterministic = True),
    transforms.NoiseAndBlur(sigma_b = 2., sigma_n = 40., deterministic = False),
    transforms.Spectral(),
]
augmentation_names = ["Blurring", "Blur and Noise", "Blurring + WP"]

imgs = {}

for name, trans in zip(augmentation_names, transforms_set):

    tr = torchvision.transforms.Compose([
        torchvision.transforms.CenterCrop(224),
        trans
        ])

    imgs[name] = [tr(input_image) for input_image in input_images]

# One row for the originals plus one per augmentation; one column per image.
# squeeze=False keeps `ax` 2D even with a single row or column.
n_rows, n_cols = 1 + len(imgs), len(input_images)
fig, ax = plt.subplots(n_rows, n_cols, figsize = (4 * n_cols, 4 * n_rows), squeeze = False)
plt.rcParams.update({'font.size': 18})

def NormalizeData(data):
    """Min-max rescale ``data`` to [0, 1] for the plots.

    A constant input (zero range) returns an all-zero array of the same
    shape instead of the NaNs a 0/0 division would produce.
    """
    data = np.asarray(data)
    lo, hi = np.min(data), np.max(data)
    span = hi - lo
    if span == 0:
        return np.zeros(data.shape)
    return (data - lo) / span

crop = torchvision.transforms.CenterCrop(224)

# Row 0: the center-cropped originals, one column per input image.
# (Indexing by image rather than by transform keeps this correct when the
# number of augmentations differs from the number of images.)
for col, input_image in enumerate(input_images):
    ax[0, col].imshow(crop(input_image))
    ax[0, col].set_title('Original image')
    ax[0, col].axis('off')

# Rows 1..: one row per augmentation strategy, same column order.
for row, (name, images) in enumerate(imgs.items(), start=1):

    for col, img in enumerate(images):

        ax[row, col].imshow(img)
        ax[row, col].set_title(name)
        ax[row, col].axis('off')


fig.tight_layout()
plt.savefig('examples_augmentations_ours.pdf')
plt.show()