In [1]:
%load_ext autoreload
%autoreload 2

In [48]:
from evaluation.main import load_xai_results
from utils import load_json_file
import pandas as pd
import pickle
from tqdm import tqdm
import ast

base_config_path = "/home/hjall/work/qai/xai/xai-nlp-benchmark/artifacts/xai-nlp-benchmark-2024-04-23-21-20-02/configs"
config_paths = ["gender_no_sub_samp_project_config.json", "gender_project_config.json", "sentiment_project_config.json"]
configs = [(load_json_file(f"{base_config_path}/{config_path}"), config_path.replace(".json", "")) for config_path in config_paths]

def prepare_data(config: dict, key: str):
    xai_results = load_xai_results(config)
    df = pd.DataFrame(xai_results)

    pred_diffs = []
    attribution_diffs = []
    attribution_diffs_gt = []
    attribution_diffs_not_gt = []

    group_columns = [
        'model_name',
        'model_version',
        'model_repetition_number',
        'dataset_type',
        'attribution_method',
        'sentence_idx',
    ]
    for keys, group in tqdm(df.groupby(group_columns)):
        info = {key: value for key, value in zip(group_columns, keys)}

        female = group[group["target"] == 0].iloc[0]
        male = group[group['target'] == 1].iloc[0]

        if female["pred_probabilities"] is not None:
            pred_diff = female["pred_probabilities"][0] - male["pred_probabilities"][0]
            pred_diffs.append({**info, "pred_diff": pred_diff})

        for female_word, male_word, female_attribution, male_attribution, gt in zip(
            ast.literal_eval(female["sentence"]),
            ast.literal_eval(male["sentence"]),
            female["attribution"],
            male["attribution"],
            female["ground_truth"],
        ):
            attribution_diff = female_attribution - male_attribution

            diff_obj = {
                **info,
                "female_word": female_word.lower(),
                "male_word": male_word.lower(),
                "attribution_diff": attribution_diff,
            }

            attribution_diffs.append(diff_obj)

            if gt:
                attribution_diffs_gt.append(diff_obj)
            else:
                attribution_diffs_not_gt.append(diff_obj)

    pred_diffs_df = pd.DataFrame(pred_diffs)
    attribution_diffs_df = pd.DataFrame(attribution_diffs)
    attribution_diffs_gt_df = pd.DataFrame(attribution_diffs_gt)
    attribution_diffs_not_gt_df = pd.DataFrame(attribution_diffs_not_gt)

    with open(f"{key}_diffs.pkl", "wb") as f:
        pickle.dump({
            "pred_diffs_df": pred_diffs_df,
            "attribution_diffs_df": attribution_diffs_df,
            "attribution_diffs_gt_df": attribution_diffs_gt_df,
            "attribution_diffs_not_gt_df": attribution_diffs_not_gt_df,
        }, f)

for config, key in configs:
    prepare_data(config, key)

Loading XAI results: 100%|██████████| 100/100 [02:00<00:00,  1.21s/it]
100%|██████████| 322000/322000 [06:06<00:00, 877.92it/s]
Loading XAI results: 100%|██████████| 100/100 [01:58<00:00,  1.18s/it]
100%|██████████| 322000/322000 [05:58<00:00, 897.63it/s]
Loading XAI results: 100%|██████████| 100/100 [01:55<00:00,  1.16s/it]
100%|██████████| 322000/322000 [06:00<00:00, 894.09it/s]


In [64]:
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
import matplotlib.cm as cm

model_name_mapping = {

    'bert_only_classification': "BERTC",
    'bert_only_embedding_classification': "BERTCEf",
    'bert_randomly_init_embedding_classification': "BERTCE",
    'bert_all': "BERTAll",
    'one_layer_attention_classification': "OneLayerAtt"
}

# Load two arrays of samples

def get_cutoff_cmap(alpha: float = 0.05, cmap_name: str = "magma"):
	vird = cm.get_cmap(cmap_name, 256)
	new_colors = vird(np.concatenate([np.linspace(0, 0.1, int(np.ceil(alpha * (256)))), np.linspace(0.6, 0.7, int(np.ceil((1 - alpha) * 256)))]))
	return ListedColormap(new_colors, name='cutoff')

def apply_prediction_test(pred_diffs_df: pd.DataFrame, test: str = "ttest"):
	group_by = [
		'model_name',
		'model_version',
		'dataset_type',
		'attribution_method',
		'model_repetition_number',
	]

	# Predictions are model based and therefore same for all attribution methods
	# so we can just use the first one
	attributions_methods = pred_diffs_df["attribution_method"].unique()
	pred_diffs_df = pred_diffs_df[pred_diffs_df["attribution_method"] == attributions_methods[0]]

	results = []
	for keys, group in pred_diffs_df.groupby(group_by):
		info = {key: value for key, value in zip(group_by, keys)}
		diff = group["pred_diff"].values

		alpha = 0.05
		if test == "ttest":
			mu = 0
			t_stat, p_value = stats.ttest_1samp(diff, mu)

			results.append({
				**info,
				"t_stat": t_stat,
				"p_value": p_value,
				"reject": p_value < alpha,
			})
		elif test == "wilcoxon":
			w_stat, p_value = stats.wilcoxon(diff)

			results.append({
				**info,
				"w_stat": w_stat,
				"p_value": p_value,
				"reject": p_value < alpha,
			})

	df = pd.DataFrame(results)
	df["model_name"] = df["model_name"].map(model_name_mapping)

	return df

def plot_prediction_heatmap(df: pd.DataFrame, run_name: str, test: str = "ttest"):
	cmap = get_cutoff_cmap()
	gender_all = df[(df["model_version"] == "best") & (df["dataset_type"] == "gender_all")].pivot(
		index="model_name", columns="model_repetition_number", values="p_value"
	)

	gender_subj = df[(df["model_version"] == "best") & (df["dataset_type"] == "gender_subj")].pivot(
		index="model_name", columns="model_repetition_number", values="p_value"
	)

	max_value = max(gender_all.values.max(), gender_subj.values.max())

	fig, axs = plt.subplots(1, 2, figsize=(12, 5), sharey=True, gridspec_kw={'width_ratios': [1, 1.2]})
	sns.heatmap(gender_all, annot=True, fmt=".3f", ax=axs[0], vmax=max_value, cbar=False, cmap=cmap)
	axs[0].set_title("$D_A$")
	axs[0].set_ylabel("Model")
	axs[0].set_xlabel("Repetition")

	sns.heatmap(gender_subj, annot=True, fmt=".3f", ax=axs[1], vmax=max_value, cbar=True, cmap=cmap)
	axs[1].set_title("$D_S$")
	axs[1].set_ylabel("")
	axs[1].set_xlabel("Repetition")

	fig.suptitle(f"Prediction Difference {run_name.replace('_config', '')} with {test}")
	plt.savefig(f"prediction_diff_{run_name}_{test}.png")

In [65]:

def apply_attribution_test(cur_df: pd.DataFrame, test: str, include_repetitions: bool = False):
    results = []

    group_by = [
        'model_name',
        'model_version',
        'dataset_type',
        'attribution_method',
    ]

    if include_repetitions:
        group_by += ["model_repetition_number"]

    for keys, group in cur_df.groupby(group_by):
        info = {key: value for key, value in zip(group_by, keys)}
        diff = group["attribution_diff"].values

        alpha = 0.05

        if test == "ttest":
            mu = 0
            t_stat, p_value = stats.ttest_1samp(diff, mu)
            results.append(
                {
                    **info,
                    "t_stat": t_stat,
                    "p_value": p_value,
                    "reject": p_value < alpha,
                }
            )
        elif test == "wilcoxon":
            w_stat, p_value = stats.wilcoxon(diff)
            results.append(
                {
                    **info,
                    "w_stat": w_stat,
                    "p_value": p_value,
                    "reject": p_value < alpha,
                }
            )

    results_df = pd.DataFrame(results)
    results_df = results_df[results_df["model_version"] == "best"]
    results_df["model_name"] = results_df["model_name"].map(model_name_mapping)

    return results_df

def plot_attribution_heatmap_row(df: pd.DataFrame, axs: list[plt.Axes], max_value: float = 1) -> None:
    cmap = get_cutoff_cmap()

    # Filter out "Correlation" attribution method
    df = df[df["attribution_method"] != "Correlation"]

    gender_all = df[df["dataset_type"] == "gender_all"].pivot(
        index="model_name", columns="attribution_method", values="p_value"
    )

    gender_subj = df[df["dataset_type"] == "gender_subj"].pivot(
        index="model_name", columns="attribution_method", values="p_value"
    )

    sns.heatmap(gender_all, annot=True, fmt=".2f",  ax=axs[0], vmax=max_value, cbar=False, cmap=cmap)
    axs[0].set_title("$D_A$")
    axs[0].set_ylabel("Model")
    axs[0].set_xlabel("Attribution Method")

    sns.heatmap(gender_subj, annot=True, fmt=".2f", ax=axs[1], vmax=max_value, cbar=True, cmap=cmap)
    axs[1].set_title("$D_S$")
    axs[1].set_ylabel("")
    axs[1].set_xlabel("Attribution Method")


def plot_attribution_heatmap(results_df: pd.DataFrame, run_name: str, test: str, title_version: str = None) -> None:
    max_value = results_df["p_value"].values.max()

    fig, axs = plt.subplots(1, 2, figsize=(11, 6), sharey=True, gridspec_kw={'width_ratios': [1, 1.2]})
    fig.suptitle(f"P-values of the Attribution Methods for {title_version} for run {run_name.replace('_config', '')} with {test}")
    plot_attribution_heatmap_row(results_df, axs, max_value=max_value)

    fig.tight_layout()

    plt.savefig(f"./temp_results/attribution_{run_name}_{test}_{title_version}.png")
    plt.close()

def plot_attribution_heatmap_with_rep(results_df: pd.DataFrame, run_name: str, test: str, title_version: str = None) -> None:
    
        max_value = results_df["p_value"].values.max()

        model_repetitions = results_df["model_repetition_number"].unique()
        # Sort model repetitions
        model_repetitions = sorted(model_repetitions)

        fig, axs = plt.subplots(len(model_repetitions), 2, figsize=(11, 6 * len(model_repetitions)), sharey=True, gridspec_kw={'width_ratios': [1, 1.2]})
        fig.suptitle(f"P-values of the Attribution Methods for {title_version} for run {run_name.replace('_config', '')} with {test}")
        
        for i, model_repetition in enumerate(model_repetitions):
            cur_results_df = results_df[results_df["model_repetition_number"] == model_repetition]
            plot_attribution_heatmap_row(cur_results_df, axs[i], max_value=max_value)
            axs[i][0].set_title(f"Repetition {model_repetition}")
    
        fig.tight_layout()
    
        plt.savefig(f"./temp_results/attribution_rep_{run_name}_{test}_{title_version}.png")
        plt.close()



In [62]:
def test_and_plot(key: str):
    tests = ["ttest", "wilcoxon"]

    with open(f"{key}_diffs.pkl", "rb") as f:
        data = pickle.load(f)

        pred_diffs_df =  data["pred_diffs_df"]
        attribution_diffs_df = data["attribution_diffs_df"]
        attribution_diffs_gt_df = data["attribution_diffs_gt_df"]
        attribution_diffs_not_gt_df = data["attribution_diffs_not_gt_df"]
    
    for test in tests:
        print(pred_diffs_df)
        if len(pred_diffs_df) > 0:
            pred_results = apply_prediction_test(pred_diffs_df, test)
            plot_prediction_heatmap(pred_results, key, test)
        
        results_all = apply_attribution_test(attribution_diffs_df, test)
        results_gt = apply_attribution_test(attribution_diffs_gt_df, test)
        results_not_gt = apply_attribution_test(attribution_diffs_not_gt_df, test)

        results_all_rep = apply_attribution_test(attribution_diffs_df, test, include_repetitions=True)
        results_gt_rep = apply_attribution_test(attribution_diffs_gt_df, test, include_repetitions=True)
        results_not_gt_rep = apply_attribution_test(attribution_diffs_not_gt_df, test, include_repetitions=True)

        plot_attribution_heatmap(results_all, key, test, "All")
        plot_attribution_heatmap(results_gt, key, test, "GT")
        plot_attribution_heatmap(results_not_gt, key, test, "Not GT")

        plot_attribution_heatmap_with_rep(results_all_rep, key, test, "All")
        plot_attribution_heatmap_with_rep(results_gt_rep, key, test, "GT")
        plot_attribution_heatmap_with_rep(results_not_gt_rep, key, test, "Not GT")
        

In [63]:
for _, key in configs[2:]:
    test_and_plot(key)

                                model_name model_version  \
0                                 bert_all          best   
1                                 bert_all          best   
2                                 bert_all          best   
3                                 bert_all          best   
4                                 bert_all          best   
...                                    ...           ...   
321995  one_layer_attention_classification          last   
321996  one_layer_attention_classification          last   
321997  one_layer_attention_classification          last   
321998  one_layer_attention_classification          last   
321999  one_layer_attention_classification          last   

        model_repetition_number dataset_type attribution_method  sentence_idx  \
0                             0   gender_all        Correlation             0   
1                             0   gender_all        Correlation             1   
2                             0   ge

  vird = cm.get_cmap(cmap_name, 256)


ValueError: Index contains duplicate entries, cannot reshape