### This notebook studies the influence of natural perturbations under increasing severity levels.
We first demonstrate some perturbed image samples under different perturbation functions and severity levels. Then we study the influence on model/detector performance and robustness metrics of natural perturbations under increasing severity levels.

In [None]:
import yaml
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Switch to the project root (the parent of the folder this notebook starts in) so that
# `config.yaml`, `utils/` and `results/` resolve relative to the working directory.
# The switch happens only once per kernel session: re-running this cell must not climb
# one directory higher each time, which would break the imports and file paths below.
if "_PROJECT_ROOT" not in globals():
    _PROJECT_ROOT = os.path.dirname(os.getcwd())
    os.chdir(_PROJECT_ROOT)
print("Current working directory: ", os.getcwd())

from utils.attackers import build_attacker
from utils.test_utils import setup_seed
from utils.dataloader import load_dataset
from utils.visualize import *

# Read the experiment configuration: benchmarks, model variants, OoD datasets and save directory.
with open('config.yaml', 'r') as f:
    configs = yaml.safe_load(f)

# Configuration entries that the cells below refer to by name.
score_functions, perturb_functions, rand_seed = (
    configs["score_functions"],
    configs["perturb_functions"],
    configs["rand_seed"],
)

# Display order of perturbation functions and model variants in the visualizations.
perturb_function_sorter = [
    "rotation",
    "translation",
    "scale",
    "hue",
    "saturation",
    "bright_contrast",
    "blur",
    "Linf",
    "L2",
    "average",
]
variant_sorter = [
    "NT",
    "DA",
    "AT",
    "PAT",
]

# Use the white-grid seaborn theme for every figure.
sns.set_theme(style="whitegrid")


#### 1. Model / detector robustness under increasing levels of perturbation severity.
For each perturbation, we demonstrate the MAE/DAE rate with increasing levels of severity.

The visualization is in the format of line charts. All the statistics can be inspected in the tables generated in the `results/eval/severity_levels/robustness/` folder. The generated figures are saved in the `results/eval/severity_levels/robustness/figures/` folder.

- The severity level ranges from 1 to 5, as demonstrated in Cells 3-4.
- We consider different OoD detectors and various perturbations.
- We compare among different model variants (NT, DA, AT, PAT).

In [None]:
# Input tables live in `robustness_dir`; figures are written below `save_dir`.
robustness_dir = os.path.join("results", "eval", "severity_levels", "robustness")
save_dir = os.path.join("results", "eval", "severity_levels", "robustness", "figures")
os.makedirs(save_dir, exist_ok=True)
os.makedirs(os.path.join(save_dir, "perturbations"), exist_ok=True)
os.makedirs(os.path.join(save_dir, "detectors"), exist_ok=True)


def _load_rate_table(path):
    """Read a MAE/DAE rate CSV and make its grouping columns ordered categoricals.

    The `perturb_function` and `variant` columns are ordered according to
    `perturb_function_sorter` and `variant_sorter` so that sorting (and therefore
    the panel / legend order) follows the intended display order.

    Returns:
        The loaded DataFrame, or None when `path` does not exist.
    """
    if not os.path.exists(path):
        return None
    table = pd.read_csv(path)
    for column, order in (("perturb_function", perturb_function_sorter), ("variant", variant_sorter)):
        table[column] = table[column].astype("category").cat.set_categories(order, ordered=True)
    return table


def _save_rate_grid(panels, hue, y_label, suptitle, out_path, n_cols=3):
    """Draw one rate-vs-severity line chart per panel on a shared-y grid and save it.

    Args:
        panels: list of `(panel_title, frame, y_column)` tuples, one per subplot.
        hue: column used for line colour (line style is always `variant`).
        y_label: y-axis label applied to every panel.
        suptitle: figure-level title.
        out_path: path of the image file to write.
        n_cols: number of subplot columns.
    """
    if not panels:
        # Nothing to draw; a grid with zero rows cannot be created.
        return
    n_rows = (len(panels) + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols*6, n_rows*6),
                             layout="constrained", sharey=True, squeeze=False)
    for ax_i, ax in enumerate(axes.flatten()):
        if ax_i >= len(panels):
            # Hide the unused trailing cells of the grid.
            ax.axis("off")
            continue
        panel_title, frame, y_column = panels[ax_i]
        sns.lineplot(data=frame, x="severity", y=y_column, hue=hue, style="variant", ax=ax, marker="o")
        ax.set_title(panel_title)
        ax.set_ylabel(y_label)
        ax.set_xlabel("Severity level")
        # NOTE(review): presumably the five severity levels sit at x positions 0-4 — confirm.
        ax.set_xlim(0, 4.2)
        ax.legend(loc="upper left", ncols=2)
    fig.suptitle(suptitle)
    fig.savefig(out_path)
    plt.close("all")


for benchmark in configs["benchmark"]:

    # MAE/DAE rate table on the ID dataset and DAE rate table on the OoD datasets.
    df_id = _load_rate_table(
        os.path.join(robustness_dir, f"{benchmark.lower()}_id_local_mae_dae_rate_mean_std.csv"))
    df_ood = _load_rate_table(
        os.path.join(robustness_dir, f"{benchmark.lower()}_ood_local_dae_rate_mean_std.csv"))

    # Both tables are required; otherwise this benchmark is skipped.
    if (df_id is None) or (df_ood is None):
        continue

    for model_name in configs["benchmark"][benchmark]["model"]:

        # For OoD only the rows whose dataset is "average" are kept and relabelled "OoD".
        df_ood_ = df_ood[(df_ood["model"] == model_name) & (df_ood["dataset"] == "average")].copy()
        df_ood_["dataset"] = "OoD"
        df_id_ = df_id[df_id["model"] == model_name].copy()
        df_id_["dataset"] = "ID"

        data = pd.concat([df_id_, df_ood_], axis=0)
        data = data.sort_values(by=["perturb_function", "variant", "severity"])
        # Drop the rows whose severity is the "average" aggregate; only individual levels are plotted.
        data = data[data["severity"] != "average"].copy()
        if len(data) == 0:
            continue
        perturb_functions = data["perturb_function"].unique()

        # NOTE: standard-deviation bands are not drawn. The table file names suggest
        # matching `*_std` columns exist (TODO confirm); they could be added with
        # `ax.fill_between` if needed.

        # MAE rate - severity level plot (ID data only), one panel per perturbation.
        id_rows = data[data["dataset"] == "ID"]
        mae_panels = [
            (f"Perturbation={perturb_func}", id_rows[id_rows["perturb_function"] == perturb_func], "mae_mean")
            for perturb_func in perturb_functions
        ]
        _save_rate_grid(
            mae_panels, "variant", "MAE rate (%)",
            f"Model robustness - MAE rate with increasing severity levels of perturbations.\nBenchmark={benchmark}, Model={model_name}",
            os.path.join(save_dir, f"{benchmark.lower()}_{model_name}_mae_severity.png"),
        )

        # Rows of each perturbation, shared by the two DAE figure families below.
        rows_by_perturbation = [
            (perturb_func, data[data["perturb_function"] == perturb_func])
            for perturb_func in perturb_functions
        ]

        # ID & OoD DAE rate - severity level plot under different perturbations
        # (one figure per OoD detector, one panel per perturbation).
        for score_func in score_functions:
            dae_panels = [
                (f"Perturbation={perturb_func}", frame, f"dae_{score_func}_mean")
                for perturb_func, frame in rows_by_perturbation
            ]
            _save_rate_grid(
                dae_panels, "dataset", "DAE rate (%)",
                f"OoD Detector robustness - DAE rate with increasing severity levels of perturbations.\nBenchmark={benchmark}, Model={model_name}, OoD detector={score_func}",
                os.path.join(save_dir, "perturbations", f"{benchmark.lower()}_{model_name}_{score_func}_dae_severity.png"),
            )

        # ID & OoD DAE rate - severity level plot for different OoD detectors
        # (one figure per perturbation, one panel per OoD detector).
        for perb_func, frame in rows_by_perturbation:
            dae_panels = [
                (f"Detector={score_func}", frame, f"dae_{score_func}_mean")
                for score_func in score_functions
            ]
            _save_rate_grid(
                dae_panels, "dataset", "DAE rate (%)",
                f"OoD Detector robustness - DAE rate with increasing severity levels of perturbations.\nBenchmark={benchmark}, Model={model_name}, Perturbation={perb_func}",
                os.path.join(save_dir, "detectors", f"{benchmark.lower()}_{model_name}_{perb_func}_dae_severity.png"),
            )


#### 2. Model / detector performance under increasing levels of perturbation severity.
For each perturbation, we demonstrate the model accuracy / detector FPR95/AUROC/AUPR with increasing levels of severity.

The visualization is in the format of line charts. All the statistics can be inspected in the tables generated in the `results/eval/severity_levels/performance/` folder. The generated figures are saved in the `results/eval/severity_levels/performance/figures/` folder.

- The severity level ranges from 1 to 5, as demonstrated in Cells 3-4.
- We consider different OoD detectors and various perturbations.
- We compare among different model variants (NT, DA, AT, PAT).

In [None]:
# Figures are written to `save_dir`; the performance tables are read from `performance_dir`.
save_dir = os.path.join("results", "eval", "severity_levels", "performance", "figures")
os.makedirs(save_dir, exist_ok=True)

performance_dir = os.path.join("results", "eval", "severity_levels", "performance")


def _load_performance_table(file_name, index_columns):
    """Read `file_name` from `performance_dir` and index it by `index_columns`.

    Returns:
        The indexed DataFrame, or None when the file does not exist.
    """
    path = os.path.join(performance_dir, file_name)
    if not os.path.exists(path):
        return None
    return pd.read_csv(path).set_index(index_columns)


def _with_original_baseline(perturbed, baseline):
    """Append the baseline rows to `perturbed` once per perturbation function.

    The baseline rows are labelled with severity "0 - original" and duplicated for
    every perturbation function found in `perturbed`, so that each per-perturbation
    panel also contains the baseline point. The `perturb_function` and `variant`
    columns are then turned into ordered categoricals (display order).

    Returns:
        `(data, perturb_functions)`: the combined frame and the perturbation
        functions of `perturbed` in order of first appearance.
    """
    baseline = baseline.copy()
    baseline["severity"] = "0 - original"
    perturb_functions = perturbed["perturb_function"].unique()

    frames = [perturbed]
    for perb_func in perturb_functions:
        baseline_rows = baseline.copy()
        baseline_rows["perturb_function"] = perb_func
        frames.append(baseline_rows)
    data = pd.concat(frames, axis=0)

    for column, order in (("perturb_function", perturb_function_sorter), ("variant", variant_sorter)):
        data[column] = data[column].astype("category").cat.set_categories(order, ordered=True)
    return data, perturb_functions


def _save_performance_grid(data, perturb_functions, y_column, y_label, suptitle, out_path, n_cols=3):
    """Draw `y_column` against severity, one panel per perturbation, and save the figure.

    Args:
        data: frame with `perturb_function`, `variant`, `severity` and `y_column`.
        perturb_functions: perturbation functions to draw, one panel each.
        y_column: name of the metric column to plot.
        y_label: y-axis label applied to every panel.
        suptitle: figure-level title.
        out_path: path of the image file to write.
        n_cols: number of subplot columns.
    """
    n_rows = (len(perturb_functions) + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols*6, n_rows*6),
                             layout="constrained", sharey=True, squeeze=False)
    for ax_i, ax in enumerate(axes.flatten()):
        if ax_i >= len(perturb_functions):
            # Hide the unused trailing cells of the grid.
            ax.axis("off")
            continue
        perturb_func = perturb_functions[ax_i]
        panel_rows = data[data["perturb_function"] == perturb_func]
        sns.lineplot(data=panel_rows, x="severity", y=y_column, hue="variant", ax=ax, style="variant",
                     marker="o", errorbar=None)
        ax.set_title(f"Perturbation={perturb_func}")
        ax.set_ylabel(y_label)
        ax.set_xlabel("Severity level")
        # Same x-range for every metric. NOTE(review): presumably "0 - original" plus
        # five severity levels occupy x positions 0-5 — confirm.
        ax.set_xlim(0, 5.2)
        ax.legend(loc="lower left", ncols=4)
    fig.suptitle(suptitle)
    fig.savefig(out_path)
    plt.close("all")


# Load DNN model and OoD detectors' performance metrics (None when a file is missing).
df_model_perf_seed = _load_performance_table("model_performance_seed.csv", ["benchmark", "model"])
df_model_perf_perb = _load_performance_table("model_performance_perb.csv", ["benchmark", "model"])
df_detector_perf_seed = _load_performance_table("detector_performance_seed.csv", ["benchmark", "model", "score_function"])
df_detector_perf_perb = _load_performance_table("detector_performance_perb.csv", ["benchmark", "model", "score_function"])

# Model and detector plots are guarded independently: each needs only its own pair of tables.
model_tables_ready = (df_model_perf_seed is not None) and (df_model_perf_perb is not None)
detector_tables_ready = (df_detector_perf_seed is not None) and (df_detector_perf_perb is not None)

# (metric column, y-axis label, name used in the figure title, tag used in the file name)
detector_metrics = [
    ("FPR95", "FPR95 (%)", "FPR95", "fpr95"),
    ("AUROC", "AUROC", "AUROC", "auroc"),
    ("AUPR_IN", "AUPR", "AUPR", "aupr"),
]

for benchmark in configs["benchmark"]:

    for model_name in configs["benchmark"][benchmark]["model"]:

        # Model performance - severity level plot
        model_key = (benchmark, model_name)
        if model_tables_ready and model_key in df_model_perf_seed.index and model_key in df_model_perf_perb.index:
            data, perturb_functions = _with_original_baseline(
                df_model_perf_perb.loc[[model_key]].reset_index(drop=True),
                df_model_perf_seed.loc[[model_key]].reset_index(drop=True),
            )
            print(perturb_functions)

            # Drop the rows whose severity is the "average" aggregate; only individual levels are plotted.
            data = data[data["severity"] != "average"]
            data = data.sort_values(by=["perturb_function", "variant", "severity"])

            if len(data) > 0:
                _save_performance_grid(
                    data, perturb_functions, "accuracy", "Model accuracy (%)",
                    f"Model performance - Accuracy with increasing severity levels of perturbations.\nBenchmark={benchmark}, Model={model_name}",
                    os.path.join(save_dir, f"{benchmark.lower()}_{model_name}_accuracy_severity.png"),
                )

        if not detector_tables_ready:
            continue

        # Detector performance (FPR95 / AUROC / AUPR) - severity level plots
        for score_func in score_functions:
            detector_key = (benchmark, model_name, score_func)
            if detector_key not in df_detector_perf_seed.index or detector_key not in df_detector_perf_perb.index:
                continue

            data, perturb_functions = _with_original_baseline(
                df_detector_perf_perb.loc[[detector_key]].reset_index(drop=True),
                df_detector_perf_seed.loc[[detector_key]].reset_index(drop=True),
            )

            # Drop the rows where severity or dataset is the "average" aggregate.
            data = data[(data["severity"] != "average") & (data["dataset"] != "average")]
            data = data.sort_values(by=["perturb_function", "variant", "severity"])

            if len(data) > 0:
                for y_column, y_label, title_name, file_tag in detector_metrics:
                    _save_performance_grid(
                        data, perturb_functions, y_column, y_label,
                        f"OoD Detector performance - {title_name} with increasing severity levels of perturbations.\nBenchmark={benchmark}, Model={model_name}, OoD detector={score_func}",
                        os.path.join(save_dir, f"{benchmark.lower()}_{model_name}_{score_func}_{file_tag}_severity.png"),
                    )
