### This notebook studies interesting cases of high-confidence detector adversarial examples (HDAE) by visualizing their statistics at different confidence levels.

- ID dataset:
    - Level 1 - DAE sample: g(x) <= theta, g(x') > theta
    - Level 2 - HDAE sample: g(x) <= theta, g(x') > theta, s(x') > 1-alpha
    - Level 3
        - Case 1 - HDAE-HCS sample: g(x) <= theta, g(x') > theta, s(x') > 1-alpha, s(x) > 1-alpha
        - Case 2 - HDAE-LCS sample: g(x) <= theta, g(x') > theta, s(x') > 1-alpha, s(x) <= 1-alpha
- OoD dataset: 
    - Level 1 - DAE sample: g(x) > theta, g(x') <= theta
    - Level 2 - HDAE sample: g(x) > theta, g(x') <= theta, s(x') > 1-alpha
    - Level 3
        - Case 1 - HDAE-HCS sample: g(x) > theta, g(x') <= theta, s(x') > 1-alpha, s(x) > 1-alpha
        - Case 2 - HDAE-LCS sample: g(x) > theta, g(x') <= theta, s(x') > 1-alpha, s(x) <= 1-alpha

In [None]:
import numpy as np
import yaml
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# This notebook is run from a directory one level below the project root, while
# every relative path below (config.yaml, results/..., utils/) is rooted at the
# project root. Only move up when config.yaml is not already reachable, so that
# re-executing this cell does not keep climbing the directory tree.
if not os.path.exists("config.yaml"):
    os.chdir(os.path.dirname(os.getcwd()))
print("Current working directory: ", os.getcwd())
# Must come after the chdir: the `utils` package is resolved relative to the project root.
from utils.visualize import *

# Load configs: benchmarks, model variants, OoD datasets and save directory.
with open("config.yaml", "r", encoding="utf-8") as f:
    configs = yaml.safe_load(f)

score_functions = configs["score_functions"]
perturb_functions = configs["perturb_functions"]
rand_seed = configs["rand_seed"]

# Define the order of perturbation functions and model variants in visualizations.
perturb_function_sorter = configs["perturb_functions"]
variant_sorter = ["NT", "DA", "AT", "PAT"]

sns.set_style("whitegrid")

#### 1. Calculate and assemble the local HDAE rate for each benchmark, model variant and OoD detector.


In [None]:
# Load and assemble local DAE rates, and search for HDAE samples.
def _local_high_conf_rates(file_path, alphas, is_rate_col):
    """Return per-seed high-confidence rates for every alpha, stacked row-wise.

    For each alpha, every indicator column selected by ``is_rate_col`` is AND-ed
    with ``conf_perb > 1 - alpha`` (the perturbed sample must be high-confidence),
    then all columns are averaged per (idx, perturb_function). An ``alpha`` column
    records which threshold produced each row. Returns an empty DataFrame when
    ``alphas`` is empty.
    """
    # Read the record once; each alpha works on its own copy.
    record = pd.read_csv(file_path)
    rate_cols = [col for col in record.columns if is_rate_col(col)]
    frames = []
    for alpha in alphas:
        df = record.copy()
        high_conf = df["conf_perb"] > (1 - alpha)
        for col in rate_cols:
            # Indicator columns are expected to be boolean (or 0/1) per perturbed sample.
            df[col] = df[col] & high_conf
        df = df.groupby(["idx", "perturb_function"]).mean(numeric_only=False).reset_index()
        df["alpha"] = alpha
        frames.append(df)
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, axis=0)


def get_local_hdae_rate(benchmark, ood_datasets, model_name, weights, alphas, refresh=False):
    """
    Calculate the local high-confidence MAE/DAE (HMAE/HDAE) rate of each ID/OoD seed,
    for each OoD detector under different perturbations, at several confidence levels.

    For every weight variant, the per-sample records in
    "results/eval/intermediate_results/<benchmark>_<model_name>_<variant>/" are read.
    The ID record is "id_mae_dae_record.csv" and the OoD records are
    "ood_dae_record_<dataset>.csv". Indicator columns are restricted to perturbed
    samples with ``conf_perb > 1 - alpha`` and averaged per seed
    (idx, perturb_function). The results for all alphas are written next to the
    records as "id_local_hmae_hdae_rate.csv" and "ood_local_hdae_rate_<dataset>.csv".
    Missing directories or records are reported and skipped.

    Args:
        benchmark (str): The benchmark for the evaluation. Can be selected from ["CIFAR10", "Imagenet100"].
        ood_datasets (list): List of OoD datasets.
        model_name (str): The name of the model. Can be selected from ["wrn_40_2", "resnet50"].
        weights (list): List of weight variants. Can be selected from ["NT", "DA", "AT", "PAT"].
        alphas (list): Confidence tolerances; a perturbed sample counts only if conf_perb > 1 - alpha.
        refresh (bool, optional): Whether to recompute outputs that already exist. Defaults to False.

    Returns:
        None
    """
    print("\n=========================")
    print("Calculating local MAE/DAE rates.")
    print("=========================")
    print("Benchmark:", benchmark)
    print("Model_name:", model_name)
    for weight_variant in weights:
        print("-------------------------")
        print("weight_variant:", weight_variant)

        rlt_dir = os.path.join("results/eval/intermediate_results",
                               f"{benchmark.lower()}_{model_name}_{weight_variant}")
        if not os.path.exists(rlt_dir):
            print("  "+rlt_dir+" not exists!")
            continue

        print("> Calculating the local HMAE/HDAE rate on ID dataset.")
        save_path = os.path.join(rlt_dir, "id_local_hmae_hdae_rate.csv")
        if refresh or not os.path.exists(save_path):
            file_path = os.path.join(rlt_dir, "id_mae_dae_record.csv")
            if os.path.exists(file_path):
                # ID records carry both DAE ("dae_*") and MAE ("*mae*") indicators.
                df_hdae = _local_high_conf_rates(
                    file_path, alphas, lambda col: "dae_" in col or "mae" in col)
                df_hdae.to_csv(save_path, index=False)
                print("ID local HMAE/HDAE rate saved to: "+save_path)
            else:
                print("   "+file_path+" not exists!")

        print("> Calculating the local HDAE rate on OoD datasets.")
        for dataset in ood_datasets:
            save_path = os.path.join(rlt_dir, f"ood_local_hdae_rate_{dataset}.csv")
            if not (refresh or not os.path.exists(save_path)):
                continue  # up-to-date output already present
            print(" - Dataset:", dataset)
            file_path = os.path.join(rlt_dir, f"ood_dae_record_{dataset}.csv")
            if not os.path.exists(file_path):
                print("   "+file_path+" not exists!")
                continue
            df_hdae = _local_high_conf_rates(file_path, alphas, lambda col: "dae_" in col)
            df_hdae.to_csv(save_path, index=False)
            print("OoD local HDAE rate saved to: "+save_path)

# Compute and cache the local HDAE rate for every benchmark / model / weight
# variant declared in the config, over a grid of confidence tolerances alpha.
alphas = [0.0, 0.02, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
for benchmark, benchmark_cfg in configs["benchmark"].items():
    ood_datasets = benchmark_cfg["ood_datasets"]
    for model_name, variant_cfgs in benchmark_cfg["model"].items():
        # The variant names are the keys of the per-model mapping.
        weights = [variant for variant in variant_cfgs]
        get_local_hdae_rate(benchmark, ood_datasets, model_name, weights, alphas, refresh=False)


#### 2. Make the HDAE rate - (1-alpha) plot on ID & OoD datasets.
- We study 12 OoD detectors under 9 perturbation functions.
- We compare different model variants (NT, DA, AT, PAT).
- For each model variant, we compare the HDAE rate among different OoD datasets and under various perturbation functions.

#### 3. Make the HDAE-HCS/LCS - (1-alpha) plot on ID & OoD datasets.
- We study the percentage of HDAEs which originate from high-confidence seeds (HCS) and low-confidence seeds (LCS), respectively.

In [None]:
def hdae_alpha_plot(df_hdae, hue, score_functions, suptitle="", savepath="hdae_alpha_plot.png"):
    """Plot the HDAE rate against 1-alpha, one subplot per OoD detector.

    Args:
        df_hdae (pd.DataFrame): Long-format rates with a "1-alpha" column and one
            "dae_<score_func>" column per detector. The y-axis is labelled in
            percent, so values are assumed to be already scaled by 100.
        hue (str): Column used to split and style the lines
            (e.g. "variant", "perturb_function", "dataset").
        score_functions (list): Detector names, laid out three per row;
            detectors without a "dae_<name>" column get a hidden subplot.
        suptitle (str): Figure title.
        savepath (str): Output image path; its parent directory is created if needed.

    Returns:
        matplotlib.figure.Figure: The saved figure, so callers can close it
        when producing many plots.
    """
    save_dir = os.path.dirname(savepath)
    if save_dir:  # a bare filename (e.g. the default) has no parent directory to create
        os.makedirs(save_dir, exist_ok=True)

    ncols = 3
    # Ceiling division: keep a partial last row and always draw at least one row.
    nrows = max(1, int(np.ceil(len(score_functions) / ncols)))
    # squeeze=False guarantees a 2-D Axes array, so flatten() works for any grid size.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(8*ncols, 4*nrows),
                             layout="constrained", squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i >= len(score_functions):
            ax.axis("off")  # unused cell in the last row
            continue
        score_func = score_functions[i]
        y_col = f"dae_{score_func}"
        if y_col not in df_hdae.columns:
            ax.axis("off")  # no results for this detector
            continue
        sns.lineplot(data=df_hdae, x="1-alpha", y=y_col, ax=ax, hue=hue,
                     style=hue, palette="tab10", markers=True)
        ax.set_xlabel("1-alpha")
        ax.set_ylabel("HDAE rate (%)")
        ax.set_title(f"Detector={score_func}")
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.suptitle(suptitle)
    fig.savefig(savepath)
    return fig

def hdae_hcs_lcs_alpha_plot(data, score_functions, suptitle="", savepath="hdae_hcs_lcs_alpha_plot.png"):
    """Plot the HDAE rate against 1-alpha, split into high/low-confidence-seed parts.

    For each alpha, the per-seed rates are summed separately over rows whose
    "HCS" flag is True (high-confidence seeds) and False (low-confidence
    seeds). Both sums are divided by the total number of rows at that alpha,
    so, when there are no missing values, the HDAE_HCS and HDAE_LCS curves add
    up to the mean HDAE curve drawn from ``data``.

    Args:
        data (pd.DataFrame): Long-format local rates, with columns "alpha",
            "1-alpha", "idx", "conf_seed", "conf_perb", a boolean "HCS" and
            "dae_<score_func>" rate columns.
        score_functions (list): Detector names, laid out three per row;
            detectors without a "dae_<name>" column get a hidden subplot.
        suptitle (str): Figure title.
        savepath (str): Output image path; its parent directory is created if needed.

    Returns:
        matplotlib.figure.Figure: The saved figure.
    """
    save_dir = os.path.dirname(savepath)
    if save_dir:  # a bare filename (e.g. the default) has no parent directory to create
        os.makedirs(save_dir, exist_ok=True)

    drop_cols = ["alpha", "1-alpha", "idx", "conf_seed", "conf_perb"]
    hcs_parts, lcs_parts = [], []
    for alpha in data["alpha"].unique():
        # Select on the stored alpha itself rather than a recomputed 1-alpha,
        # so the lookup never hinges on floating-point equality.
        df_alpha = data[data["alpha"] == alpha].drop(drop_cols, axis=1)
        is_hcs = df_alpha["HCS"].astype(bool)
        rates = df_alpha.drop("HCS", axis=1)
        n_rows = len(rates)
        hcs_part = rates[is_hcs].sum(numeric_only=True) / n_rows
        lcs_part = rates[~is_hcs].sum(numeric_only=True) / n_rows
        hcs_part["1-alpha"] = 1 - alpha
        lcs_part["1-alpha"] = 1 - alpha
        hcs_parts.append(hcs_part)
        lcs_parts.append(lcs_part)
    # One row per alpha, one column per rate (plus "1-alpha").
    hdae_hcs = pd.DataFrame(hcs_parts)
    hdae_lcs = pd.DataFrame(lcs_parts)

    ncols = 3
    # Ceiling division: keep a partial last row and always draw at least one row.
    nrows = max(1, int(np.ceil(len(score_functions) / ncols)))
    # squeeze=False guarantees a 2-D Axes array, so flatten() works for any grid size.
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(8*ncols, 4*nrows),
                             layout="constrained", squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i >= len(score_functions):
            ax.axis("off")  # unused cell in the last row
            continue
        score_func = score_functions[i]
        y_col = f"dae_{score_func}"
        # Also hide the subplot when there was nothing to decompose (empty data).
        if y_col not in data.columns or y_col not in hdae_hcs.columns:
            ax.axis("off")
            continue
        sns.lineplot(data=data, x="1-alpha", y=y_col, label="HDAE", errorbar=None, marker="o", alpha=0.7, ax=ax)
        sns.lineplot(data=hdae_hcs, x="1-alpha", y=y_col, label="HDAE_HCS", marker="^", alpha=0.7, ax=ax)
        sns.lineplot(data=hdae_lcs, x="1-alpha", y=y_col, label="HDAE_LCS", marker="v", alpha=0.7, ax=ax)
        ax.set_xlabel("1-alpha")
        ax.set_ylabel("HDAE rate (%)")
        ax.set_title(f"Detector={score_func}")
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    fig.suptitle(suptitle)
    fig.savefig(savepath)
    return fig


save_dir = os.path.join("results/eval/robustness/figures", "hdae_alpha_plot")
os.makedirs(save_dir, exist_ok=True)


def _read_local_rates(filepath, **tags):
    """Read one local-rate CSV and add constant tag columns (e.g. variant, dataset).

    Returns None, after printing a notice, when the file is missing, so callers
    can skip variants/datasets whose intermediate results were not produced.
    """
    if not os.path.exists(filepath):
        print("File "+filepath+" does not exist!")
        return None
    df = pd.read_csv(filepath)
    for column, value in tags.items():
        df[column] = value
    return df


def _assemble_local_rates(frames):
    """Stack local-rate frames, add a "1-alpha" column and scale rate columns to percent.

    Returns an empty DataFrame when ``frames`` is empty.
    """
    if not frames:
        return pd.DataFrame()
    # ignore_index gives a unique index, which keeps column assignment and seaborn happy.
    df = pd.concat(frames, axis=0, ignore_index=True)
    df["1-alpha"] = 1 - df["alpha"]
    rate_cols = [col for col in df.columns if "dae" in col]
    df[rate_cols] = df[rate_cols] * 100
    return df


def _variants_present(df, weight_names):
    """Return the configured variants, in config order, that have rows in ``df``."""
    present = set(df["variant"])
    return [w for w in weight_names if w in present]


def _plot_id_hdae(df_hdae_id, benchmark, model_name, weight_names):
    """Draw every ID-side figure for one benchmark/model into ``save_dir``.

    Adds an "HCS" (high-confidence seed) column to ``df_hdae_id`` in place.
    """
    prefix = f"{benchmark.lower()}_{model_name}"
    suptitle = f"HDAE rate - (1-alpha) plot on ID dataset.\nBenchmark={benchmark}, Model={model_name}"
    savepath = os.path.join(save_dir, f"{prefix}_id_hdae_alpha_plot.png")
    hdae_alpha_plot(df_hdae_id, "variant", score_functions, suptitle, savepath)

    variants = _variants_present(df_hdae_id, weight_names)
    # Ablation on perturbation functions
    for weight_name in variants:
        data = df_hdae_id[df_hdae_id["variant"] == weight_name].copy()
        suptitle = f"HDAE rate - (1-alpha) plot on ID dataset under different perturbation functions.\nBenchmark={benchmark}, Model={model_name}, Variant={weight_name}"
        savepath = os.path.join(save_dir, f"{prefix}_{weight_name}_id_hdae_alpha_plot_perturb_function.png")
        hdae_alpha_plot(data, "perturb_function", score_functions, suptitle, savepath)

    # We study the percentage of HDAEs which originate from high-confidence/low-confidence seeds.
    df_hdae_id["HCS"] = df_hdae_id["conf_seed"] > df_hdae_id["1-alpha"]
    for weight_name in variants:
        data = df_hdae_id[df_hdae_id["variant"] == weight_name].copy()
        suptitle = f"HDAE_HCS and HDAE_LCS rate - (1-alpha) plot on ID dataset.\nBenchmark={benchmark}, Model={model_name}, Variant={weight_name}"
        savepath = os.path.join(save_dir, f"{prefix}_{weight_name}_id_hdae_hcs_lcs_alpha_plot.png")
        hdae_hcs_lcs_alpha_plot(data, score_functions, suptitle, savepath)


def _plot_ood_hdae(df_hdae_ood, benchmark, model_name, weight_names):
    """Draw every OoD-side figure for one benchmark/model into ``save_dir``.

    Adds an "HCS" (high-confidence seed) column to ``df_hdae_ood`` in place.
    """
    prefix = f"{benchmark.lower()}_{model_name}"
    suptitle = f"HDAE rate - (1-alpha) plot on OoD datasets.\nBenchmark={benchmark}, Model={model_name}"
    savepath = os.path.join(save_dir, f"{prefix}_ood_hdae_alpha_plot.png")
    hdae_alpha_plot(df_hdae_ood, "variant", score_functions, suptitle, savepath)

    variants = _variants_present(df_hdae_ood, weight_names)
    # Ablation on perturbation functions, OoD datasets
    for weight_name in variants:
        data = df_hdae_ood[df_hdae_ood["variant"] == weight_name].copy()
        suptitle = f"HDAE rate - (1-alpha) plot on OoD datasets under different perturbation functions.\nBenchmark={benchmark}, Model={model_name}, Variant={weight_name}"
        savepath = os.path.join(save_dir, f"{prefix}_{weight_name}_ood_hdae_alpha_plot_perturb_function.png")
        hdae_alpha_plot(data, "perturb_function", score_functions, suptitle, savepath)

        suptitle = f"HDAE rate - (1-alpha) plot on OoD data for different datasets.\nBenchmark={benchmark}, Model={model_name}, Variant={weight_name}"
        savepath = os.path.join(save_dir, f"{prefix}_{weight_name}_ood_hdae_alpha_plot_dataset.png")
        hdae_alpha_plot(data, "dataset", score_functions, suptitle, savepath)

    # We study the percentage of HDAEs which originate from high-confidence/low-confidence seeds.
    df_hdae_ood["HCS"] = df_hdae_ood["conf_seed"] > df_hdae_ood["1-alpha"]
    for weight_name in variants:
        data = df_hdae_ood[df_hdae_ood["variant"] == weight_name].copy()
        suptitle = f"HDAE_HCS and HDAE_LCS rate - (1-alpha) plot on OoD dataset.\nBenchmark={benchmark}, Model={model_name}, Variant={weight_name}"
        savepath = os.path.join(save_dir, f"{prefix}_{weight_name}_ood_hdae_hcs_lcs_alpha_plot.png")
        hdae_hcs_lcs_alpha_plot(data, score_functions, suptitle, savepath)


# ID & OoD HDAE rate - alpha plots for every benchmark and model.
for benchmark in configs["benchmark"]:
    print("\n=========================")
    print(f"Benchmark: {benchmark}")
    benchmark_cfg = configs["benchmark"][benchmark]
    for model_name in benchmark_cfg["model"]:
        weight_names = list(benchmark_cfg["model"][model_name])
        print("-------------------------")
        print(f"Model: {model_name}")

        print("> Plotting the HDAE rate - alpha relation on ID dataset.")
        id_frames = []
        for weight_name in weight_names:
            print(f"> Variant: {weight_name}")
            filepath = os.path.join("results/eval/intermediate_results", f"{benchmark.lower()}_{model_name}_{weight_name}",
                                    "id_local_hmae_hdae_rate.csv")
            df = _read_local_rates(filepath, variant=weight_name)
            if df is not None:
                id_frames.append(df)
        df_hdae_id = _assemble_local_rates(id_frames)
        if df_hdae_id.empty:
            print("  No ID local HDAE rate files found; skipping ID plots.")
        else:
            _plot_id_hdae(df_hdae_id, benchmark, model_name, weight_names)

        print("> Plotting the HDAE rate - alpha relation on OoD datasets.")
        ood_frames = []
        for weight_name in weight_names:
            print(f"> Variant: {weight_name}")
            for dataset in benchmark_cfg["ood_datasets"]:
                print(f" - Dataset: {dataset}")
                filepath = os.path.join("results/eval/intermediate_results", f"{benchmark.lower()}_{model_name}_{weight_name}",
                                        f"ood_local_hdae_rate_{dataset}.csv")
                df = _read_local_rates(filepath, variant=weight_name, dataset=dataset)
                if df is not None:
                    ood_frames.append(df)
        df_hdae_ood = _assemble_local_rates(ood_frames)
        if df_hdae_ood.empty:
            print("  No OoD local HDAE rate files found; skipping OoD plots.")
        else:
            _plot_ood_hdae(df_hdae_ood, benchmark, model_name, weight_names)
            