### This notebook aims to study the perturbation-induced details, including OoD scores distribution shift and local error rate statistics.

In [1]:
import numpy as np
import yaml
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Move to the parent directory so that `config.yaml` and the `utils` package
# resolve (assumes the notebook lives one level below the project root — TODO confirm).
# The guard keeps this cell idempotent: without it, every re-run of the cell
# would climb one more directory and break the relative paths below.
if not os.path.exists("config.yaml"):
    os.chdir(os.path.dirname(os.getcwd()))
print("Current working directory: ", os.getcwd())

from utils.eval import get_ood_measures
from utils.visualize import *

# Load configs: benchmarks, model variants, OoD datasets and save directory.
with open('config.yaml', 'r') as f:
    configs = yaml.safe_load(f)

# Detector names, perturbation names and the random seed used in result paths.
score_functions = configs["score_functions"]
perturb_functions = configs["perturb_functions"]
rand_seed = configs["rand_seed"]

# Define the order of perturbation functions and model variants in visualizations.
perturb_function_sorter = configs["perturb_functions"]
variant_sorter = ["NT", "DA", "AT", "PAT"]


Current working directory:  /home/dingw/work/ood_robustness_eva


: 

#### 1. OoD scores distribution shift
We show the distribution of ID & OoD datasets' OoD scores for both seeds and perturbed samples. A good post-hoc OoD detection method should have a good separation of OoD scores between ID & OoD datasets. 

We consider 2 benchmarks (CIFAR10, ImageNet100), 2 model architectures (WRN-40-2, ResNet50), and 4 model variants (NT, DA, AT, PAT) for each architecture. We show the OoD score distribution shift under 9 perturbation functions and for 12 OoD detectors.

The visualizations are in the format of histogram and box plot. The generated figures are saved in `results/eval/performance/figures/statistics` folder.

In [None]:
# OoD scores distribution for ID & OoD seeds and perturbed samples.
save_dir = os.path.join("results", "eval", "performance", "figures", "statistics")
os.makedirs(save_dir, exist_ok=True)


def _load_labelled_scores(path, label, columns=None):
    """Read a scores CSV (optionally only `columns`) and tag every row with `label`.

    Args:
        path: CSV file to read.
        label: value written to the "dataset" column ("ID" or "OoD").
        columns: optional list of columns to load; all columns when None.

    Returns:
        The loaded frame with an extra "dataset" column.
    """
    scores = pd.read_csv(path, usecols=columns)
    scores["dataset"] = label
    return scores


def _plot_score_histograms(scores_seed, scores_perb, suptitle, out_path):
    """Plot seed (top row) vs. perturbed (bottom row) score histograms and save them.

    One column of axes is drawn per entry of `score_functions`. Both frames must
    hold a "dataset" column ("ID"/"OoD") and a "<score_func>_score" column for
    every score function.

    Args:
        scores_seed: scores of the unperturbed ID and OoD seeds.
        scores_perb: scores of the perturbed ID and OoD samples.
        suptitle: title of the whole figure.
        out_path: path of the image file to write.
    """
    ncols = len(score_functions)
    nrows = 2
    # squeeze=False keeps `axes` two-dimensional even with a single score function.
    fig, axes = plt.subplots(nrows, ncols, figsize=(8*ncols, 4*nrows),
                             layout="constrained", squeeze=False)

    for ax_seed, ax_perb, score_func in zip(axes[0], axes[1], score_functions):
        score_col = f"{score_func}_score"
        id_scores_seed_ = scores_seed.loc[scores_seed["dataset"] == "ID", score_col].to_numpy()
        ood_scores_seed_ = scores_seed.loc[scores_seed["dataset"] == "OoD", score_col].to_numpy()
        ood_scores_perb_ = scores_perb.loc[scores_perb["dataset"] == "OoD", score_col].to_numpy()

        auroc_seed, aupr_in_seed, _, fpr_seed, threshold_seed, _, _ = \
            get_ood_measures(id_scores_seed_, ood_scores_seed_)
        # NOTE(review): the perturbed metrics compare the *seed* ID scores with the
        # perturbed OoD scores; the perturbed ID scores are only plotted, never
        # scored. Kept as in the original — confirm this is the intended protocol.
        auroc_perb, aupr_in_perb, _, fpr_perb, threshold_perb, _, _ = \
            get_ood_measures(id_scores_seed_, ood_scores_perb_)

        sns.histplot(data=scores_seed, x=score_col, hue="dataset", stat="percent",
                     bins=30, common_norm=False, ax=ax_seed)
        sns.histplot(data=scores_perb, x=score_col, hue="dataset", stat="percent",
                     bins=30, common_norm=False, ax=ax_perb)
        # The threshold is drawn negated (presumably get_ood_measures works on
        # negated scores — TODO confirm).
        ax_seed.axvline(x=-threshold_seed, color="red", linestyle="--")
        ax_perb.axvline(x=-threshold_perb, color="red", linestyle="--")

        # Share both axis ranges between the two rows so the shift is comparable.
        min_ = min(scores_seed[score_col].min(), scores_perb[score_col].min())
        max_ = max(scores_seed[score_col].max(), scores_perb[score_col].max())
        ax_seed.set_xlim(min_, max_)
        ax_perb.set_xlim(min_, max_)
        ymax = max(ax_seed.get_ylim()[1], ax_perb.get_ylim()[1])
        ax_seed.set_ylim(0, ymax)
        ax_perb.set_ylim(0, ymax)

        ax_seed.set_xlabel(f"{score_func} score")
        ax_perb.set_xlabel(f"{score_func} score")
        ax_seed.set_title(f"{score_func} score distribution\nSeeds: FPR95={fpr_seed*100:.2f}%, AUROC={auroc_seed:.2f}, AUPR={aupr_in_seed:.2f}", loc="left")
        ax_perb.set_title(f"Perturbed samples: FPR95={fpr_perb*100:.2f}%, AUROC={auroc_perb:.2f}, AUPR={aupr_in_perb:.2f}", loc="left")

    fig.suptitle(suptitle)
    fig.savefig(out_path)
    # Many figures are produced by the loops below; closing each one right after
    # saving avoids matplotlib's "more than 20 figures opened" warning and the
    # associated memory growth. The figures are only written to disk.
    plt.close(fig)


score_columns = [f"{score_func}_score" for score_func in score_functions]

for benchmark in configs["benchmark"]:
    benchmark_cfg = configs["benchmark"][benchmark]
    for model_name in benchmark_cfg["model"]:
        for variant in benchmark_cfg["model"][model_name]:
            scores_dir = f"results/{benchmark.lower()}/{rand_seed}/{model_name}/{variant}"

            id_seed_path = os.path.join(scores_dir, benchmark, "scores", "temp_scores.csv")
            if not os.path.exists(id_seed_path):
                print("File", id_seed_path, "does not exist.")
                continue
            id_scores_seed = _load_labelled_scores(id_seed_path, "ID")

            for dataset in benchmark_cfg["ood_datasets"]:
                ood_seed_path = os.path.join(scores_dir, dataset, "scores", "temp_scores.csv")
                if not os.path.exists(ood_seed_path):
                    print("File", ood_seed_path, "does not exist.")
                    continue
                ood_scores_seed = _load_labelled_scores(ood_seed_path, "OoD")
                scores_seed = pd.concat([id_scores_seed, ood_scores_seed], ignore_index=True)

                # Read every perturbed score file once; the frames are reused for
                # the pooled figure and for the per-perturbation figures.
                scores_perb_by_func = {}
                for perb_func in perturb_functions:
                    perb_file = f"perb_{perb_func}_scores.csv"
                    scores_perb_by_func[perb_func] = pd.concat(
                        [
                            _load_labelled_scores(os.path.join(scores_dir, benchmark, "scores", perb_file), "ID", score_columns),
                            _load_labelled_scores(os.path.join(scores_dir, dataset, "scores", perb_file), "OoD", score_columns),
                        ],
                        ignore_index=True,
                    )

                # Plot score distributions for seeds and perturbed samples under all perturbations.
                _plot_score_histograms(
                    scores_seed,
                    pd.concat(list(scores_perb_by_func.values()), ignore_index=True),
                    f"OoD scores distribution for seeds and perturbed samples\nBenchmark={benchmark}, Model={model_name}, Variant={variant}, OoD dataset={dataset}, perturbation=all",
                    os.path.join(save_dir, f"ood_scores_{benchmark}_{model_name}_{variant}_{dataset}_hist.png"),
                )

                # Plot score distributions for seeds and perturbed samples under each perturbation.
                for perb_func, scores_perb in scores_perb_by_func.items():
                    _plot_score_histograms(
                        scores_seed,
                        scores_perb,
                        f"OoD scores distribution for seeds and perturbed samples\nBenchmark={benchmark}, Model={model_name}, Variant={variant}, OoD dataset={dataset}, perturbation={perb_func}",
                        os.path.join(save_dir, f"ood_scores_{benchmark}_{model_name}_{variant}_{dataset}_{perb_func}_hist.png"),
                    )
    

  fig, axes = plt.subplots(nrows, ncols, figsize=(8*ncols, 4*nrows), layout="constrained")


In [None]:
# Box plot of OoD scores for ID & OoD seeds and perturbed samples
save_dir = os.path.join("results", "eval", "performance", "figures", "statistics")
os.makedirs(save_dir, exist_ok=True)


def _load_scores_by_perturbation(dataset_dir, score_columns):
    """Stack the seed scores and the perturbed scores stored under `dataset_dir`/scores.

    Args:
        dataset_dir: directory of one dataset inside a model-variant result folder.
        score_columns: names of the score columns to load.

    Returns:
        One frame holding `score_columns` plus a "perturb_function" column that is
        "original" for seed rows and the perturbation name otherwise. Rows keep the
        order: seeds first, then the perturbations in `perturb_functions` order.
    """
    sources = [("original", "temp_scores.csv")]
    sources += [(perb_func, f"perb_{perb_func}_scores.csv") for perb_func in perturb_functions]

    frames = []
    for perb_name, filename in sources:
        frame = pd.read_csv(os.path.join(dataset_dir, "scores", filename), usecols=score_columns)
        frame["perturb_function"] = perb_name
        frames.append(frame)
    # A single concat instead of growing the frame inside the loop.
    return pd.concat(frames, ignore_index=True)


def _draw_score_boxes(ax, scores, score_col, mean_color, threshold):
    """Draw per-perturbation box plots plus mean/sd markers on `ax`; return its y-limits."""
    sns.boxplot(data=scores, x="perturb_function", y=score_col,
                ax=ax, width=.1, fill=False, fliersize=5, color='k')
    sns.pointplot(data=scores, x="perturb_function", y=score_col,
                  errorbar="sd", capsize=.2, color=mean_color, linewidth=2,
                  linestyle="none", markers="_", markersize=40, ax=ax)
    # Threshold computed on the seeds, drawn negated as in the histogram figures.
    ax.axhline(-threshold, color="red", linestyle="--")
    ax.set_xlabel("Perturbation")
    return ax.get_ylim()


score_columns = [f"{score_func}_score" for score_func in score_functions]

for benchmark in configs["benchmark"]:
    benchmark_cfg = configs["benchmark"][benchmark]
    for model_name in benchmark_cfg["model"]:
        for variant in benchmark_cfg["model"][model_name]:
            scores_dir = f"results/{benchmark.lower()}/{rand_seed}/{model_name}/{variant}"

            id_seed_path = os.path.join(scores_dir, benchmark, "scores", "temp_scores.csv")
            if not os.path.exists(id_seed_path):
                print("File", id_seed_path, "does not exist.")
                continue

            # ID scores do not depend on the OoD dataset: load them lazily, once per variant.
            id_scores = None

            for dataset in benchmark_cfg["ood_datasets"]:
                ood_seed_path = os.path.join(scores_dir, dataset, "scores", "temp_scores.csv")
                if not os.path.exists(ood_seed_path):
                    print("File", ood_seed_path, "does not exist.")
                    continue

                if id_scores is None:
                    id_scores = _load_scores_by_perturbation(os.path.join(scores_dir, benchmark), score_columns)
                ood_scores = _load_scores_by_perturbation(os.path.join(scores_dir, dataset), score_columns)

                id_is_seed = id_scores["perturb_function"] == "original"
                ood_is_seed = ood_scores["perturb_function"] == "original"

                # Box plot of OoD scores for seeds and perturbed samples under different perturbations.
                ncols = 2
                nrows = len(score_functions)
                fig, axes = plt.subplots(nrows, ncols, figsize=(10*ncols, 4*nrows),
                                         layout="constrained", squeeze=False)

                # One row per detector: ID on the left, OoD on the right.
                for (ax_id, ax_ood), score_func in zip(axes, score_functions):
                    score_col = f"{score_func}_score"

                    # Only the threshold of the seed-vs-seed evaluation is used here.
                    _, _, _, _, threshold_seed, _, _ = get_ood_measures(
                        id_scores.loc[id_is_seed, score_col].to_numpy(),
                        ood_scores.loc[ood_is_seed, score_col].to_numpy(),
                    )

                    ymin0, ymax0 = _draw_score_boxes(ax_id, id_scores, score_col, "r", threshold_seed)
                    ax_id.set_title(f"Detector={score_func}, Dataset=ID", loc="left")

                    ymin1, ymax1 = _draw_score_boxes(ax_ood, ood_scores, score_col, "b", threshold_seed)
                    ax_ood.set_title("Dataset=OoD", loc="left")

                    # Share the y-range so ID and OoD panels are directly comparable.
                    ymin = min(ymin0, ymin1)
                    ymax = max(ymax0, ymax1)
                    ax_id.set_ylim(ymin, ymax)
                    ax_ood.set_ylim(ymin, ymax)

                fig.suptitle(f"OoD scores distribution for seeds and perturbed samples\nBenchmark={benchmark}, Model={model_name}, Variant={variant}, OoD dataset={dataset}")
                fig.savefig(os.path.join(save_dir, f"ood_scores_{benchmark}_{model_name}_{variant}_{dataset}_box.png"))
                plt.show()
                # Release the figure explicitly so non-inline backends do not accumulate them.
                plt.close(fig)
                

#### 2. Local DAE rate statistics
We show the distribution of the local DAE rate under various functional perturbations and for different OoD datasets as histograms. The average DAE rate, with its 95% confidence interval, under each perturbation and for each OoD dataset is shown as a bar plot.

The generated figures are saved in `results/eval/robustness/figures/statistics/` folder.

In [None]:
# Histogram of local DAE rate for ID & OoD seeds.
dae_record_dir = os.path.join("results", "eval", "intermediate_results")
save_dir = os.path.join("results", "eval", "robustness", "figures", "statistics")
os.makedirs(save_dir, exist_ok=True)


def _prepare_dae_rates(dae_rates):
    """Return a copy with ordered perturbation categories and DAE rates in percent.

    The "perturb_function" column becomes an ordered categorical following
    `perturb_function_sorter`; every "dae_<score_func>" column that is present
    is multiplied by 100.
    """
    prepared = dae_rates.copy()
    prepared["perturb_function"] = (
        prepared["perturb_function"]
        .astype("category")
        .cat.set_categories(perturb_function_sorter, ordered=True)
    )
    for score_func in score_functions:
        dae_col = f"dae_{score_func}"
        if dae_col in prepared.columns:
            prepared[dae_col] = prepared[dae_col] * 100
    return prepared


def _plot_dae_histograms(dae_rates, hue, suptitle, out_path, **suptitle_kwargs):
    """Draw one DAE-rate histogram per detector, coloured by `hue`, and save the figure.

    Args:
        dae_rates: frame returned by `_prepare_dae_rates` (rates already in percent).
        hue: column used to split the histogram ("perturb_function" or "dataset").
        suptitle: title of the whole figure.
        out_path: path of the image file to write.
        **suptitle_kwargs: extra keyword arguments forwarded to `Figure.suptitle`.
    """
    ncols = 1
    nrows = len(score_functions)
    # squeeze=False: with a single detector, subplots would otherwise return a
    # bare Axes object that has no `.flatten()`.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 4*nrows),
                             layout="constrained", squeeze=False)
    for ax, score_func in zip(axes.flatten(), score_functions):
        dae_col = f"dae_{score_func}"
        if dae_col not in dae_rates.columns:
            # Detector not evaluated for this frame: leave its axis empty.
            continue

        sns.histplot(data=dae_rates, x=dae_col, hue=hue,
                     stat="percent", ax=ax, legend=True, shrink=0.8,
                     bins=30, common_norm=False, multiple="dodge", palette="tab10")

        ax.set_xlabel("DAE rate (%)")
        # stat="percent" plots shares of seeds, not counts.
        ax.set_ylabel("Percentage of seeds (%)")
        ax.set_title(f"Detector={score_func}")
        ax.set_xlim(0, 100)

    fig.suptitle(suptitle, **suptitle_kwargs)
    fig.savefig(out_path)
    plt.close(fig)


for benchmark in configs["benchmark"]:
    benchmark_cfg = configs["benchmark"][benchmark]
    for model_name in benchmark_cfg["model"]:
        for variant in benchmark_cfg["model"][model_name]:
            record_dir = os.path.join(dae_record_dir, f"{benchmark.lower()}_{model_name}_{variant}")
            file_prefix = f"{benchmark.lower()}_{model_name}_{variant}"

            # Histogram of local DAE rate for ID seeds under different perturbations.
            filepath = os.path.join(record_dir, "id_local_mae_dae_rate.csv")
            if not os.path.exists(filepath):
                print("File", filepath, "does not exist.")
            else:
                df_id_dae = _prepare_dae_rates(pd.read_csv(filepath))
                _plot_dae_histograms(
                    df_id_dae,
                    "perturb_function",
                    f"Local DAE rate distribution for ID seeds.\nBenchmark={benchmark}, Model={model_name}, Variant={variant}",
                    os.path.join(save_dir, f"{file_prefix}_id_local_dae_rate_perturb_hist.png"),
                )

            # Collect the local DAE rates of every available OoD dataset.
            ood_frames = []
            for dataset in benchmark_cfg["ood_datasets"]:
                filepath = os.path.join(record_dir, f"ood_local_dae_rate_{dataset}.csv")
                if not os.path.exists(filepath):
                    print("File", filepath, "does not exist.")
                    continue
                df = pd.read_csv(filepath)
                df["dataset"] = dataset
                ood_frames.append(df)

            if not ood_frames:
                # Without any OoD record there is no "perturb_function" column to plot.
                print("No OoD DAE records found in", record_dir)
                continue

            df_ood_dae = _prepare_dae_rates(pd.concat(ood_frames, ignore_index=True))

            # Histogram of local DAE rate for OoD seeds under different perturbations.
            _plot_dae_histograms(
                df_ood_dae,
                "perturb_function",
                f"Local DAE rate distribution for OoD seeds.\nBenchmark={benchmark}, Model={model_name}, Variant={variant}",
                os.path.join(save_dir, f"{file_prefix}_ood_local_dae_rate_perturb_hist.png"),
            )

            # Histogram of local DAE rate for OoD seeds from different datasets.
            _plot_dae_histograms(
                df_ood_dae,
                "dataset",
                f"Local DAE rate distribution for OoD seeds.\nBenchmark={benchmark}, Model={model_name}, Variant={variant}",
                os.path.join(save_dir, f"{file_prefix}_ood_local_dae_rate_dataset_hist.png"),
                y=1.02,
            )
        

In [None]:
# Histogram of local MAE rate for ID seeds.
# Paths are defined here as well so the cell does not depend on variables
# leaking from another cell.
dae_record_dir = os.path.join("results", "eval", "intermediate_results")
save_dir = os.path.join("results", "eval", "robustness", "figures", "statistics")
os.makedirs(save_dir, exist_ok=True)

for benchmark in configs["benchmark"]:
    for model_name in configs["benchmark"][benchmark]["model"]:
        for variant in configs["benchmark"][benchmark]["model"][model_name]:

            # Histogram of local MAE rate for ID seeds under different perturbations.
            filepath = os.path.join(dae_record_dir, f"{benchmark.lower()}_{model_name}_{variant}",
                                    "id_local_mae_dae_rate.csv")
            if not os.path.exists(filepath):
                # Nothing to plot: skip instead of saving an empty figure.
                print("File", filepath, "does not exist.")
                continue

            df_id_mae = pd.read_csv(filepath)
            if "mae" not in df_id_mae.columns:
                # Checked before the figure is created so no figure is leaked.
                print("Column 'mae' not found in", filepath)
                continue

            df_id_mae["perturb_function"] = (
                df_id_mae["perturb_function"]
                .astype("category")
                .cat.set_categories(perturb_function_sorter, ordered=True)
            )
            # Convert the rate to percent for display.
            df_id_mae["mae"] = df_id_mae["mae"] * 100

            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(16, 4), layout="constrained")
            sns.histplot(data=df_id_mae, x="mae", hue="perturb_function",
                         stat="percent", ax=ax, legend=True, shrink=0.8,
                         bins=30, common_norm=False, multiple="dodge", palette="tab10")

            ax.set_xlabel("MAE rate (%)")
            # stat="percent" plots shares of seeds, not counts.
            ax.set_ylabel("Percentage of seeds (%)")
            ax.set_xlim(0, 100)
            ax.set_title(f"Local MAE rate distribution for ID seeds.\nBenchmark={benchmark}, Model={model_name}, Variant={variant}")

            fig.savefig(os.path.join(save_dir, f"{benchmark.lower()}_{model_name}_{variant}_id_local_mae_rate_perturb_hist.png"))
            plt.close(fig)

In [None]:
# Average DAE rate bar plot for different OoD datasets & perturbation functions.
dae_record_dir = os.path.join("results", "eval", "intermediate_results")
save_dir = os.path.join("results", "eval", "robustness", "figures", "statistics")
os.makedirs(save_dir, exist_ok=True)

for benchmark in configs["benchmark"]:
    for model_name in configs["benchmark"][benchmark]["model"]:
        for variant in configs["benchmark"][benchmark]["model"][model_name]:
            record_dir = os.path.join(dae_record_dir, f"{benchmark.lower()}_{model_name}_{variant}")

            # Collect the local DAE rates of every available OoD dataset.
            ood_frames = []
            for dataset in configs["benchmark"][benchmark]["ood_datasets"]:
                filepath = os.path.join(record_dir, f"ood_local_dae_rate_{dataset}.csv")
                if not os.path.exists(filepath):
                    print("File", filepath, "does not exist.")
                    continue
                df = pd.read_csv(filepath)
                df["dataset"] = dataset
                ood_frames.append(df)

            if not ood_frames:
                # Without any OoD record there is no "perturb_function" column to plot.
                print("No OoD DAE records found in", record_dir)
                continue

            df_ood_dae = pd.concat(ood_frames, ignore_index=True)
            df_ood_dae["perturb_function"] = (
                df_ood_dae["perturb_function"]
                .astype("category")
                .cat.set_categories(perturb_function_sorter, ordered=True)
            )

            ncols = 1
            nrows = len(score_functions)
            # squeeze=False: with a single detector, subplots would otherwise return
            # a bare Axes object that has no `.flatten()`.
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 4*nrows),
                                     layout="constrained", squeeze=False)
            for ax, score_func in zip(axes.flatten(), score_functions):
                dae_col = f"dae_{score_func}"
                if dae_col not in df_ood_dae.columns:
                    # Detector not evaluated for these records: leave its axis empty.
                    continue

                # Convert the rate to percent for display.
                df_ood_dae[dae_col] = df_ood_dae[dae_col] * 100

                # Bars show the mean per perturbation and dataset with seaborn's
                # default error bars.
                sns.barplot(data=df_ood_dae, y=dae_col, x="perturb_function",
                            hue="dataset", ax=ax)

                # x holds the perturbation, y the DAE rate.
                ax.set_xlabel("Perturbation")
                ax.set_ylabel("DAE rate (%)")
                ax.set_title(f"Detector={score_func}")

            fig.suptitle(f"Average DAE rate for OoD seeds from different datasets.\nBenchmark={benchmark}, Model={model_name}, Variant={variant}")
            fig.savefig(os.path.join(save_dir, f"{benchmark.lower()}_{model_name}_{variant}_ood_dae_rate_perturb_dataset_bar.png"))
            plt.close(fig)

