### This notebook inspects the image discrepancy between each seed and its perturbed samples under different perturbations and severity levels.
We measure image similarity with three commonly used metrics:
- Mean squared error (MSE)
- Structural Similarity Index Measure (SSIM)
- Pearson correlation coefficient (PCC)
<!-- - Normalized Phase-Cross-Correlation Coefficient (NCC). -->

The results are visualized as line charts. The summary of the average image discrepancy and the visualizations are saved in the `results/eval/image_discrepency/` folder.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml

# Presumably the notebook lives one level below the project root: move up so that the
# relative paths used below (config.yaml, dataset/, results/) and the `utils` package
# resolve. Only move when config.yaml is not already here, so re-running this cell does
# not keep climbing the directory tree.
if not os.path.isfile("config.yaml"):
    os.chdir(os.path.dirname(os.getcwd()))
print("Current working directory: ", os.getcwd())
# These project imports must come after the chdir above so that `utils` is importable.
from utils.test_utils import cal_image_sim
from utils.dataloader import load_dataset
from utils.attackers import build_attackers
from utils.eval import *
from utils.visualize import *

# Load configs: benchmarks, model variants, OoD datasets and save directory.
with open('config.yaml', 'r') as f:
    configs = yaml.safe_load(f)

score_functions = configs["score_functions"]
perturb_functions = configs["perturb_functions"]
rand_seed = configs["rand_seed"]
batch_size = configs["batch_size"]
# This notebook uses a tenth of the configured seeds / samples per seed.
n_seeds = configs["n_seeds"] // 10
n_sampling = configs["n_sampling"] // 10

# Seed the RNGs so the random seed-image selection (torch.randperm) and, presumably,
# the random perturbations are reproducible across runs.
torch.manual_seed(rand_seed)
np.random.seed(rand_seed)

# Define the order of perturbation functions and model variants in visualizations.
perturb_function_sorter = configs["perturb_functions"] + ["average"]
variant_sorter = ["NT", "DA", "AT", "PAT"]

sns.set_style("whitegrid")

# Accept both "cuda" and explicit device strings such as "cuda:1"; fall back to CPU
# when CUDA is requested but unavailable, or when another device is configured.
if str(configs["device"]).startswith("cuda") and torch.cuda.is_available():
    device = torch.device(configs["device"])
else:
    device = torch.device("cpu")
print("Device:", device)


In [None]:
save_dir = os.path.join("results", "eval", "image_discrepency")
os.makedirs(save_dir, exist_ok=True)


def _seed_similarity(x_seed, attackers, n_sampling, device):
    """Return per-sample similarity metrics between one seed image and its perturbations.

    Args:
        x_seed: a single image tensor without a batch dimension, already on ``device``.
        attackers: list of ``{perturb_function: attacker}`` dicts; entry ``i`` holds the
            attackers for severity level ``i + 1``.
        n_sampling: number of perturbed samples drawn per attacker.
        device: device forwarded to ``attacker.random_perturb``.

    Returns:
        DataFrame with one row per perturbed sample: the keys returned by
        ``cal_image_sim`` plus ``perturb_function`` and ``severity``.
    """
    records = []
    for severity_level, attacker_dict in enumerate(attackers, start=1):
        for perb_func, attacker in attacker_dict.items():
            # Generate perturbed samples.
            with torch.no_grad():
                x_perbs = attacker.random_perturb(x_seed.unsqueeze(0), n_repeat=n_sampling,
                                                  device=device)
            for x_perb in x_perbs:
                # Move the first (presumably channel) axis last before comparing images.
                sim_dict = cal_image_sim(x_seed.permute(1, 2, 0), x_perb.permute(1, 2, 0))
                sim_dict["perturb_function"] = perb_func
                sim_dict["severity"] = severity_level
                records.append(sim_dict)
    # Build the frame once; concatenating a one-row frame per sample is quadratic.
    return pd.DataFrame(records)


def _summarize_seed(df_sim):
    """Average per-sample metrics per (perturbation, severity), plus marginal averages.

    Rows averaged over all severities get ``severity == "average"``; rows averaged over
    all perturbations get ``perturb_function == "average"``.
    """
    def group_mean(keys):
        # numeric_only=True: pandas >= 2.0 no longer silently drops the non-numeric
        # ``perturb_function`` column when grouping by ``severity`` and raises instead.
        return df_sim.groupby(keys).mean(numeric_only=True).reset_index()

    df_avg = group_mean(["perturb_function", "severity"])
    df_avg_severity = group_mean(["perturb_function"])
    df_avg_severity["severity"] = "average"
    df_avg_perturb = group_mean(["severity"])
    df_avg_perturb["perturb_function"] = "average"
    return pd.concat([df_avg, df_avg_severity, df_avg_perturb], ignore_index=True, axis=0)


def _plot_by_perturbation(df_dist, title, path):
    """Save a line chart of severity-averaged distances, one x tick per perturbation."""
    fig, ax = plt.subplots(figsize=(6, 4), layout="constrained")
    data = df_dist[df_dist["severity"] == "average"]
    sns.lineplot(data=data, marker="o", errorbar=None,
                 x="perturb_function", y="distance", hue="metric", ax=ax)
    ax.set_title(title)
    ax.xaxis.set_tick_params(rotation=45)
    fig.savefig(path)
    plt.close(fig)


def _plot_by_severity(df_dist, perturb_functions, title, path, ncols=3):
    """Save a grid of line charts (one panel per perturbation) of distance vs. severity."""
    nrows = int(np.ceil(len(perturb_functions) / float(ncols)))
    # squeeze=False keeps `axes` a 2-D array even for a single row/column.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols*6, nrows*4),
                             layout="constrained", sharey=True, squeeze=False)
    axes = axes.flatten()
    per_severity = df_dist[df_dist["severity"] != "average"]
    for ax, perb_func in zip(axes, perturb_functions):
        sns.lineplot(data=per_severity[per_severity["perturb_function"] == perb_func],
                     marker="o", errorbar=None,
                     x="severity", y="distance", hue="metric", ax=ax)
        ax.set_title(f"Perturbation={perb_func}")
    # Hide the unused panels in the last row.
    for ax in axes[len(perturb_functions):]:
        ax.axis("off")
    fig.suptitle(title)
    fig.savefig(path)
    plt.close(fig)


for benchmark in configs["benchmark"]:
    print("--------------------")
    print("Benchmark:", benchmark)
    img_size = configs["benchmark"][benchmark]["img_size"]
    # Build one {perturb_function: attacker} dict per severity level (1..5).
    attackers = [build_attackers(perturb_functions, severity_level, img_size=img_size)
                 for severity_level in range(1, 6)]

    for dataset in [benchmark] + configs["benchmark"][benchmark]["ood_datasets"]:
        print("--------------------")
        print(f"Dataset: {dataset}")
        # Load ID & OoD test dataset
        data_set, data_loader = load_dataset("dataset/", dataset, img_size=img_size, benchmark=benchmark,
                                             batch_size=batch_size, split="test",
                                             shuffle=False)
        # Randomly pick up to n_seeds seed images.
        idx = torch.randperm(len(data_set))[:n_seeds]

        # Average the similarity between each seed and its perturbed samples.
        seed_frames = []
        for idx_temp in idx:
            x_temp, _ = data_set[int(idx_temp)]
            df_sim = _seed_similarity(x_temp.to(device), attackers, n_sampling, device)
            df_sim_avg = _summarize_seed(df_sim)
            df_sim_avg["idx"] = int(idx_temp)
            seed_frames.append(df_sim_avg)
        if not seed_frames:
            print(f"No seeds selected for {dataset} (n_seeds={n_seeds}); skipping.")
            continue
        df_sim_all = pd.concat(seed_frames, axis=0, ignore_index=True)

        # Turn similarities into distances so that larger always means more discrepant.
        df_sim_all["1-ssim"] = 1 - df_sim_all["ssim"]
        df_sim_all["1-pcc"] = 1 - df_sim_all["pcc"]
        # melt keeps only id_vars and value_vars, so the raw ssim/pcc columns drop out here.
        df_sim_all = df_sim_all.melt(id_vars=["idx", "perturb_function", "severity"],
                                     value_vars=["mse", "1-ssim", "1-pcc"],
                                     var_name="metric", value_name="distance")

        # Order perturbations as configured, with "average" last.
        df_sim_all["perturb_function"] = pd.Categorical(
            df_sim_all["perturb_function"], categories=perturb_function_sorter, ordered=True)
        df_sim_all.sort_values(["perturb_function", "severity"], inplace=True)

        # Save the per-seed summary table alongside the figures.
        df_sim_all.to_csv(
            os.path.join(save_dir, f"{benchmark}_{dataset}_image_discrepency.csv"),
            index=False)

        scope = "ID" if dataset == benchmark else "OoD"
        title = f"Average Image Discrepency on {dataset} dataset ({scope}).\nBenchmark={benchmark}"
        # Average discrepancy per perturbation (considering all severity levels).
        _plot_by_perturbation(
            df_sim_all, title,
            os.path.join(save_dir, f"{benchmark}_{dataset}_image_discrepency_perturbation.png"))
        # Discrepancy per perturbation with increasing severity.
        _plot_by_severity(
            df_sim_all, perturb_functions, title,
            os.path.join(save_dir, f"{benchmark}_{dataset}_image_discrepency_severity.png"))

