### This notebook studies the influence of natural perturbations under increasing severity levels.
We first show perturbed image samples under different perturbation functions and severity levels. We then study how natural perturbations of increasing severity affect model/detector performance and robustness metrics.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml
from torchvision.utils import save_image
from tqdm.notebook import tqdm

# Move from the notebook's directory to its parent, where `config.yaml` and the
# `utils` package used below are expected to live (assumes the kernel starts
# one level below that directory). Only move when `config.yaml` is not already
# visible: an unconditional chdir climbs one more directory every time this
# cell is re-run, after which the imports and config loading below fail.
if not os.path.exists("config.yaml"):
    os.chdir(os.path.dirname(os.getcwd()))
print("Current working directory: ", os.getcwd())

from utils.attackers import build_attacker
from utils.test_utils import setup_seed
from utils.dataloader import load_dataset
from utils.visualize import *

# Load the experiment configuration (benchmarks, model variants, OoD datasets,
# perturbation/score functions, ...) from config.yaml.
with open('config.yaml', 'r') as config_file:
    configs = yaml.safe_load(config_file)

# Unpack the top-level settings used throughout the notebook
# (the data directory is stored under the key "datadir").
(score_functions, perturb_functions, batch_size, device,
 rand_seed, data_dir, severity) = (configs[key] for key in (
    "score_functions", "perturb_functions", "batch_size", "device",
    "rand_seed", "datadir", "severity",
))

# Display order of the perturbation functions and model variants in plots.
perturb_function_sorter = [
    "rotation", "translation", "scale",        # geometric transformations
    "hue", "saturation", "bright_contrast",    # colour shifts (HSB)
    "blur",                                    # Gaussian blur
    "Linf", "L2",                              # L_p-bounded additive noise
    "average",
]
variant_sorter = ["NT", "DA", "AT", "PAT"]

sns.set_theme(style="whitegrid")
print("device:", device)
setup_seed(rand_seed)


#### 1. Demonstration of perturbed samples under increasing levels of perturbation severity.
We randomly select 5 in-distribution (ID) samples from each benchmark (CIFAR10, ImageNet100) and show the corresponding perturbed samples under increasing levels of perturbation severity. The example images are saved in the `results/eval/severity_levels/perturbed_samples_demo/` folder.

- Severity levels range from 1 to 5; the perturbation parameters grow with the level.
- We consider 9 functional perturbations in 4 categories as follows:

<div align="center">

<table>
    <tr>
        <th>Category</th>
        <th>Perturbation</th>
        <th>Parameters</th>
    </tr>
    <tr>
        <td rowspan="3">Geometric Transformation</td>
        <td>rotation</td>
        <td>±[6°, 12°, 18°, 24°, 30°]</td>
    </tr>
    <tr>
        <td>translation</td>
        <td>±[6°, 12°, 18°, 24°, 30°]</td>
    </tr>
    <tr>
        <td>scale</td>
        <td>±[6%, 12%, 18%, 24%, 30%]</td>
    </tr>
    <tr>
        <td rowspan="3">Colour-Shifted Function based on HSB</td>
        <td>hue</td>
        <td>±[0, 0.06, 0.09, 0.12, 0.15, 0.18]π</td>
    </tr>
    <tr>
        <td>saturation</td>
        <td>1±[0, 0.16, 0.32, 0.48, 0.64, 0.80]</td>
    </tr>
    <tr>
        <td>bright_contrast</td>
        <td>bright=±[0.06, 0.12, 0.18, 0.24, 0.30]<br>
        cont=±[0.06, 0.12, 0.18, 0.24, 0.30]
        </td>
    </tr>
    <tr>
        <td>Gaussian Blur</td>
        <td>blur</td>
        <td>CIFAR10: σ=[0.4, 0.6, 0.7, 0.8, 1.0]<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;kernel_size=[5, 5, 7, 7, 9]<br>
        ImageNet: σ=[1, 2, 3, 4, 6]<br>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;kernel_size=[9, 17, 25, 33, 49]
        </td>
    </tr>
    <tr>
        <td rowspan="2">Additive perturbation bounded by $L_p$ norm</td>
        <td>L<sub>inf</sub> noise</td>
        <td>CIFAR10: ε=[0.016, 0.032, 0.048, 0.064, 0.080]<br>
        ImageNet: ε=[0.04, 0.08, 0.12, 0.16, 0.20]
        </td>
    </tr>
    <tr>
        <td>L<sub>2</sub> noise</td>
        <td>CIFAR10: ε=[0.016, 0.032, 0.048, 0.064, 0.080]<br>
        ImageNet: ε=[0.04, 0.08, 0.12, 0.16, 0.20]
        </td>
    </tr>
</table>

</div>


In [None]:
# Demonstrate the perturbed samples under severity levels 1-5.
save_dir = os.path.join("results", "eval", "severity_levels", "perturbed_samples_demo")
os.makedirs(save_dir, exist_ok=True)

n_demo_samples = 5          # images per benchmark (one column each)
demo_levels = range(1, 6)   # severity levels 1-5 (one row each, below the originals)


def _to_hwc_image(img):
    """Return a CHW image tensor as an HWC NumPy array accepted by ``imshow``.

    The tensor is detached and copied to the CPU first, because perturbed
    batches may live on ``device`` (e.g. a GPU), where they cannot be
    converted to NumPy. Single-channel images are squeezed to (H, W), since
    matplotlib's ``imshow`` rejects arrays of shape (H, W, 1).
    """
    img = img.detach().cpu().permute(1, 2, 0).numpy()
    if img.shape[-1] == 1:
        img = img.squeeze(-1)
    return img


def _show_row(row_axes, images, title, **title_kwargs):
    """Draw ``images`` into one row of axes and put ``title`` on the first axis only."""
    for col, (ax, img) in enumerate(zip(row_axes, images)):
        ax.imshow(_to_hwc_image(img), cmap='viridis', vmax=1.0, vmin=0.0)
        ax.axis("off")
        if col == 0:
            ax.set_title(title, **title_kwargs)


for benchmark in configs["benchmark"]:
    img_size = configs["benchmark"][benchmark]["img_size"]
    # Read from the configured data directory, like the corrupted-dataset cell.
    data_set, _ = load_dataset(data_dir, benchmark, img_size=img_size, benchmark=benchmark,
                               split="test", batch_size=1)

    # Random test images; .tolist() gives plain int indices for dataset lookup.
    sample_ids = torch.randint(0, len(data_set), (n_demo_samples,)).tolist()
    x = torch.stack([data_set[i][0] for i in sample_ids], dim=0)

    for perb_func in perturb_functions:
        fig, axes = plt.subplots(nrows=1 + len(demo_levels), ncols=n_demo_samples,
                                 figsize=(10, 15), facecolor="white")
        # Row 0: the unperturbed images.
        _show_row(axes[0], x, "Original image")

        # Rows 1-5: the same images perturbed at increasing severity.
        for row, severity_level in enumerate(demo_levels, start=1):
            attacker = build_attacker(perb_func, severity_level, benchmark=benchmark)
            x_perb = attacker.random_perturb(x, n_repeat=1, device=device)
            # Report the attacker's non-zero attributes as its parameters.
            params = {k: v for k, v in attacker.__dict__.items() if v != 0}
            _show_row(axes[row], x_perb,
                      f"Severity level {severity_level}, \nparameters: {params}",
                      loc="left")

        plt.suptitle(f"Perturbation: {perb_func}. Benchmark: {benchmark}")
        save_path = os.path.join(save_dir, f"{benchmark}_{perb_func}.png")
        plt.savefig(save_path)
        plt.close(fig)

#### 2. Generate and save perturbed samples as corrupted dataset.
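For each perturbation function listed in the config, we apply that natural perturbation at a predefined severity (1-5 or average) to every image of the ID test set and of each OoD dataset. This produces one corrupted copy per (dataset, perturbation) pair. The corrupted images are saved in the `perturbed/` sub-folder of the configured data directory (`datadir` in `config.yaml`, e.g. `dataset/perturbed/`).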

In [None]:
# Generate and save perturbed samples as a corrupted dataset.
# `severity` comes straight from config.yaml and may be an int (1-5) or a string
# such as "all"; convert it before calling str methods so integer levels work too.
severity_tag = str(severity).replace('all', 'avg')

for benchmark in configs["benchmark"]:
    img_size = configs["benchmark"][benchmark]["img_size"]
    ood_datasets = configs["benchmark"][benchmark]["ood_datasets"]
    for dataset in [benchmark] + ood_datasets:
        print("Dataset: ", dataset)
        # The loaded data does not depend on the perturbation function, so load
        # it once per dataset instead of once per perturbation function.
        # Try your own dataset here by replacing load_dataset()
        data_set, data_loader = load_dataset(data_dir, dataset_name=dataset, img_size=img_size,
                                             split="test", benchmark=benchmark, batch_size=batch_size)
        for attacker_name in configs["perturb_functions"]:
            print("perturb_function: ", attacker_name)
            attacker = build_attacker(attacker_name, severity_level=severity, img_size=img_size)

            # One perturbed copy (n_repeat=1) of every image, stored as
            # <save_dir>/<label>/<idx>.png; existing files are not overwritten.
            save_dir = os.path.join(data_dir, "perturbed", f"{dataset}_{attacker_name}_{severity_tag}")
            os.makedirs(save_dir, exist_ok=True)
            idx = 0
            # Wrap the loader itself (not enumerate) so tqdm knows the number of batches.
            for i, (x, y) in enumerate(tqdm(data_loader, desc=f"{dataset} / {attacker_name}")):
                x = x.to(device)
                # NOTE(review): the same seed is passed for every batch; confirm whether
                # random_perturb re-seeds on each call (which would repeat draws per batch).
                x_perb = attacker.random_perturb(x, n_repeat=1, seed=rand_seed, device=device)
                if i == 0:
                    # Quick visual check: the first 8 perturbed images of the first batch.
                    save_image(x_perb[:8], os.path.join(save_dir, "demo.png"))

                # Labels only name the class sub-directories, so they stay on the CPU.
                for x_p, y_p in zip(x_perb, y):
                    cls_save_dir = os.path.join(save_dir, str(y_p.item()))
                    os.makedirs(cls_save_dir, exist_ok=True)
                    img_save_path = os.path.join(cls_save_dir, f"{idx}.png")
                    if not os.path.exists(img_save_path):
                        save_image(x_p, img_save_path)
                    idx += 1
        