In [22]:
import os

# Root of the exp02 web-interface data. os.path.join inserts the separator, so $DATA works
# with or without a trailing slash; the final "" keeps the trailing "/" the path always had.
interface_dir = os.path.join(os.environ["DATA"], "webinterfaces", "exp02", "")

tasks_dir = os.path.join(interface_dir, "res", "tasks")  # holds <task name>_content.csv files
results_dir = os.path.join(interface_dir, "results")  # holds one <prolific id>.json per participant
protocols_dir = os.path.join(interface_dir, "protocols")
prolific_matching_dir = os.path.join(interface_dir, "prolific")  # prolific id -> protocol matchings

# Experimental condition (group label) -> protocol JSON used by that condition.
protocol_paths_d = {
    "H": os.path.join(protocols_dir, "H_0.json"),
    "H+AI": os.path.join(protocols_dir, "AI_0.json"),
    "H+AI+CF": os.path.join(protocols_dir, "XAI_CF_0.json"),
    "H+AI+SHAP": os.path.join(protocols_dir, "XAI_SHAP_0.json"),
    "H+AI+LLM": os.path.join(protocols_dir, "XAI_LLM_0.json"),
    "H+AI+GRADCAM": os.path.join(protocols_dir, "XAI_GRADCAM_0.json"),
}

# Minimum score on the comprehension task for a participant to be kept.
COMPREHENSION_THRESHOLD = 0.8

# Task names (each one has a <name>_content.csv in tasks_dir), grouped by role.
COMPREHENSION_TASKS = ["xeasy1_find_pattern_rot"]
TRAINING_TASKS = ["med3_find_pattern_rot"]
EASY_TASKS = ["easy1_find_pattern_rot", "easy3_find_pattern_rot"]
DIFFICULT_TASKS = ["hard1_find_pattern_rot", "hard3_find_pattern_rot"]

# The same four main-experiment tasks, split along the pressure factor instead.
MILD_PRESSURE_TASKS = ["easy1_find_pattern_rot", "hard1_find_pattern_rot"]
STRONG_PRESSURE_TASKS = ["easy3_find_pattern_rot", "hard3_find_pattern_rot"]

# Task name -> key of that task in the protocol files (used to extract the answers).
TASK_PROTOCOL_KEYS = {
    "easy1_find_pattern_rot": "mainexp_easy_mild_patrot_task",
    "easy3_find_pattern_rot": "mainexp_easy_strong_patrot_task",

    "hard1_find_pattern_rot": "mainexp_hard_mild_patrot_task",
    "hard3_find_pattern_rot": "mainexp_hard_strong_patrot_task",

    "xeasy1_find_pattern_rot": "intro_comprehension_task",
    "med3_find_pattern_rot": "intro_training_1_task"
}


In [23]:
import sys

# Make the local WebXAII checkout importable (presumably it provides the `pywebxaii`
# package imported below). Set WEBXAII_REPO to override the author's default location.
# The membership guard stops re-runs of this cell from piling up duplicate entries.
WEBXAII_REPO_DIR = os.environ.get("WEBXAII_REPO", "/home/jleguy/Documents/postdoc/git_repos/WebXAII/")
if WEBXAII_REPO_DIR not in sys.path:
    sys.path.append(WEBXAII_REPO_DIR)

In [24]:
import json
import csv
import numpy as np


def load_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 (the encoding RFC 8259 mandates for JSON) rather than
    the platform's locale default, so files with non-ASCII text load the same on every OS.
    """
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


def load_task_csv_file(path):
    """Read a task content CSV and return its ground-truth and AI-prediction columns.

    Args:
        path: CSV file with a header row containing at least the ``target`` and ``pred``
            columns, both holding integer labels.

    Returns:
        ``(y_true, y_pred)``: two 1-D integer numpy arrays with one entry per data row,
        in file order (integer dtype even when the file has no data rows).
    """
    y_true, y_pred = [], []
    # newline="" is what the csv module requires for correct line-ending / quoted-field
    # handling; "utf-8-sig" decodes UTF-8 and drops a leading BOM (spreadsheet exports),
    # which would otherwise corrupt the first header name.
    with open(path, newline="", encoding="utf-8-sig") as csv_data:
        for row in csv.DictReader(csv_data):
            y_true.append(int(row["target"]))
            y_pred.append(int(row["pred"]))

    return np.array(y_true, dtype=int), np.array(y_pred, dtype=int)


In [25]:
import json


# Explainer token (second "_"-separated part of an "XAI_*" protocol name) -> condition label.
_XAI_PROTOCOL_GROUPS = {
    "SHAP": "H+AI+SHAP",
    "CF": "H+AI+CF",
    "LLM": "H+AI+LLM",
    "GRADCAM": "H+AI+GRADCAM",
}


def _protocol_to_group(protocol_name):
    """Map a protocol name such as "H_0", "AI_0" or "XAI_CF_0" to its condition label (None if unknown)."""
    parts = protocol_name.split("_")
    if parts[0] == "H":
        return "H"
    if parts[0] == "AI":
        return "H+AI"
    if parts[0] == "XAI" and len(parts) > 1:
        return _XAI_PROTOCOL_GROUPS.get(parts[1])
    return None


def data_matching(protocols_paths_d, prolific_matching_files):
    """Group participants' result file names by experimental condition.

    Args:
        protocols_paths_d: dict whose keys are the condition labels ("H", "H+AI",
            "H+AI+CF", ...). Only its keys are used, to initialise the output.
        prolific_matching_files: paths to JSON files mapping a Prolific participant id
            to a dict with a "protocol" entry (e.g. "XAI_SHAP_0").

    Returns:
        dict condition label -> list of "<prolific id>.json" file names, in the order
        they appear across the matching files. Protocols matching no known condition
        are ignored.
    """
    # Built from the parameter (the previous version read the global protocol_paths_d).
    results_filenames_d = {k: [] for k in protocols_paths_d}

    for prolific_matching_file in prolific_matching_files:
        with open(prolific_matching_file, encoding="utf-8") as json_data:
            matching_d = json.load(json_data)

        for prolific_id, prot_dict in matching_d.items():
            group_key = _protocol_to_group(prot_dict["protocol"])
            if group_key is not None:
                results_filenames_d[group_key].append(prolific_id + ".json")

    return results_filenames_d


In [26]:
from pywebxaii.resretrieval import extract_p_questionnaire_results, get_protocol_entry_from_key


def extract_quest_results(results_dir, results_filenames_d, protocol_paths_d, quest_keys):
    """Collect questionnaire answers for every participant of every group.

    Args:
        results_dir: directory containing the participants' result JSON files.
        results_filenames_d: dict group key -> list of result file names.
        protocol_paths_d: dict group key -> path of that group's protocol JSON.
        quest_keys: questionnaire keys to extract.

    Returns:
        dict group key -> {"raw": {...}, "values": {...}, "times": {...}}. Each inner dict
        maps a questionnaire key to a list aligned with the group's file list (one entry
        per file, incomplete files included). An entry is the corresponding item returned
        by ``extract_p_questionnaire_results``, or ``None`` when the file is incomplete or
        the questionnaire lookup raised ``KeyError``.
    """
    fields = ("raw", "values", "times")
    output_res_d = {}

    for group_key, filenames_list in results_filenames_d.items():
        group_res = {field: {} for field in fields}
        output_res_d[group_key] = group_res
        if not filenames_list:
            continue

        # The protocol depends only on the group: read it once per group instead of once
        # per (file, questionnaire). NOTE(review): assumes the pywebxaii helpers do not
        # mutate protocol_d — confirm.
        curr_protocol_d = load_json(protocol_paths_d[group_key])

        for filename in filenames_list:
            curr_res_d = load_json(os.path.join(results_dir, filename))
            file_is_complete = curr_res_d["is_completed"]

            for quest_key in quest_keys:
                # Decided per questionnaire: a missing questionnaire must not blank out
                # the other questionnaires of the same file.
                answers_raw = answers_values = quest_times = None
                if file_is_complete:
                    try:
                        get_protocol_entry_from_key(curr_protocol_d, quest_key)
                        answers_raw, answers_values, quest_times = extract_p_questionnaire_results(
                            curr_res_d, quest_key, protocol_d=curr_protocol_d)
                    except KeyError:
                        # Questionnaire absent from the protocol or from the results:
                        # recorded as missing.
                        answers_raw = answers_values = quest_times = None

                for field, value in zip(fields, (answers_raw, answers_values, quest_times)):
                    group_res[field].setdefault(quest_key, []).append(value)

    return output_res_d

In [27]:
from pywebxaii.resretrieval import extract_p_task_results


def compute_scores(results_dir, results_filenames_d, protocol_paths_d, tasks_dir, tasks_names, task_protocol_keys):
    """Compute per-participant accuracy and reliance-on-AI metrics pooled over some tasks.

    Only files whose ``is_completed`` flag is true are scored. The others are skipped
    without leaving an entry, so per-participant lists are aligned with the *completed*
    files of each group.

    Args:
        results_dir: directory containing the participants' result JSON files.
        results_filenames_d: dict group key -> list of result file names.
        protocol_paths_d: dict group key -> path of that group's protocol JSON.
        tasks_dir: directory containing the ``<task name>_content.csv`` files.
        tasks_names: task names whose answers are pooled into each participant's metrics.
        task_protocol_keys: dict task name -> protocol key of that task.

    Returns:
        A 9-tuple of dicts, each keyed by group:
        ``(scores, reliance, overreliance, underreliance, answer_times,
        task_true, ai_pred, user_decision, quest_order)``.

        * scores / reliance: share of answers agreeing with the ground truth / with the
          AI prediction; one value per participant.
        * overreliance: share of wrong AI predictions the participant agreed with.
        * underreliance: share of correct AI predictions the participant disagreed with.
          Any ratio whose denominator is zero (e.g. the AI is never wrong) is NaN.
        * answer_times: time values divided by 1000 (presumably ms -> s), pooled over
          participants and tasks into one flat list.
        * task_true: one entry per participant, a list of per-task ground-truth arrays.
        * ai_pred / user_decision / quest_order: one entry per (participant, task).

        If a participant's data raises ``ValueError`` during the comparisons, a single
        ``None`` is appended to every output for that participant.
    """
    output_names = ("scores", "reliance", "overreliance", "underreliance", "answer_times",
                    "task_true", "ai_pred", "user_decision", "quest_order")
    outputs = {name: {} for name in output_names}

    # Task content files do not depend on the participant: read each at most once per call.
    task_content_cache = {}

    def _ratio(numerator, denominator):
        # Explicit NaN instead of numpy's 0/0 RuntimeWarning (or a ZeroDivisionError when
        # the counters are still plain ints, e.g. with an empty tasks_names).
        return numerator / denominator if denominator else float("nan")

    for group_key, filenames_list in results_filenames_d.items():
        for name in output_names:
            outputs[name][group_key] = []

        # Loaded lazily (only if the group has a completed file) and once per group.
        # NOTE(review): assumes extract_p_task_results does not mutate protocol_d — confirm.
        protocol_d = None

        for filename in filenames_list:
            curr_res_d = load_json(os.path.join(results_dir, filename))
            if not curr_res_d["is_completed"]:
                continue
            if protocol_d is None:
                protocol_d = load_json(protocol_paths_d[group_key])

            nb_questions = 0
            nb_quest_wrong_predictions = 0
            nb_quest_right_predictions = 0
            nb_correct = 0
            nb_reliance = 0
            nb_overreliance = 0
            nb_underreliance = 0
            answer_times = []
            task_true_l, ai_pred_l, user_decision_l, quest_order_l = [], [], [], []
            failed = False

            for task_name in tasks_names:
                if task_name not in task_content_cache:
                    task_content_cache[task_name] = load_task_csv_file(
                        os.path.join(tasks_dir, task_name + "_content.csv"))
                # Copies, so every participant's outputs own independent arrays.
                task_true, ai_pred = (np.copy(arr) for arr in task_content_cache[task_name])

                answers_idx_vect, _, quest_order_vect, time_vect, _, _ = extract_p_task_results(
                    curr_res_d, task_protocol_keys[task_name], protocol_d=protocol_d)

                nb_questions += len(answers_idx_vect)
                nb_quest_wrong_predictions += np.sum(task_true != ai_pred)
                nb_quest_right_predictions += np.sum(task_true == ai_pred)

                try:
                    # Answers are compared with np.logical_not(label): answer index 0
                    # corresponds to label 1 and vice versa.
                    nb_correct += np.sum(answers_idx_vect == np.logical_not(task_true))
                    nb_reliance += np.sum(answers_idx_vect == np.logical_not(ai_pred))
                    nb_overreliance += np.sum(np.logical_and(
                        answers_idx_vect == np.logical_not(ai_pred),
                        ai_pred != task_true
                    ))
                    nb_underreliance += np.sum(np.logical_and(
                        answers_idx_vect != np.logical_not(ai_pred),
                        ai_pred == task_true
                    ))
                    answer_times.extend((np.array(time_vect) / 1000).tolist())

                    task_true_l.append(task_true)
                    ai_pred_l.append(ai_pred)
                    user_decision_l.append(np.logical_not(answers_idx_vect))
                    quest_order_l.append(quest_order_vect)

                # Happens if the results file is not complete (presumably the answer
                # vector then cannot be compared element-wise with the task labels).
                except ValueError:
                    print("ValueError exception")
                    failed = True
                    break

            if failed:
                for name in output_names:
                    outputs[name][group_key].append(None)
                continue

            outputs["scores"][group_key].append(_ratio(nb_correct, nb_questions))
            outputs["reliance"][group_key].append(_ratio(nb_reliance, nb_questions))
            outputs["overreliance"][group_key].append(_ratio(nb_overreliance, nb_quest_wrong_predictions))
            outputs["underreliance"][group_key].append(_ratio(nb_underreliance, nb_quest_right_predictions))
            outputs["answer_times"][group_key].extend(answer_times)
            # NOTE(review): task_true is *appended* (one list per participant) while the
            # three outputs below are *extended* (one entry per task). Kept as-is because
            # downstream analysis may rely on these shapes.
            outputs["task_true"][group_key].append(task_true_l)
            outputs["ai_pred"][group_key].extend(ai_pred_l)
            outputs["user_decision"][group_key].extend(user_decision_l)
            outputs["quest_order"][group_key].extend(quest_order_l)

    return tuple(outputs[name] for name in output_names)


In [28]:
# Prolific export batches whose participant -> protocol matchings are pooled together.
PROLIFIC_MATCHING_FILENAMES = [
    "prolific.json",
    "prolific_21-1.json",
    "prolific_21-2.json",
    "prolific_22-1.json",
    "prolific_27-1.json",
    "prolific_28-1.json",
]
results_filenames_d = data_matching(
    protocol_paths_d,
    [os.path.join(prolific_matching_dir, batch_name) for batch_name in PROLIFIC_MATCHING_FILENAMES],
)

# results_filenames_d = data_matching(protocol_paths_d, [os.path.join(prolific_matching_dir, "prolific_28-1.json")])

In [29]:
# Number of matched result files per group, then overall.
total = 0
for group_key, group_files in results_filenames_d.items():
    group_size = len(group_files)
    print(f"{group_key}: {group_size}")
    total += group_size
print(f"Total: {total}")

H: 121
H+AI: 108
H+AI+CF: 115
H+AI+SHAP: 113
H+AI+LLM: 111
H+AI+GRADCAM: 110
Total: 678


In [30]:
def filter_not_completed(results_filenames_d, res_dir=None):
    """Drop the result files whose ``is_completed`` flag is false.

    Args:
        results_filenames_d: dict group key -> list of result file names.
        res_dir: directory containing the result files. Defaults to the notebook-level
            ``results_dir``; passing it explicitly mirrors the ``results_dir`` argument
            of ``extract_quest_results`` / ``compute_scores``.

    Returns:
        dict group key -> list of the completed file names, in original order. Every
        group key is kept, possibly with an empty list.
    """
    if res_dir is None:
        res_dir = results_dir
    return {
        group_key: [
            filename
            for filename in filenames_list
            if load_json(os.path.join(res_dir, filename))["is_completed"]
        ]
        for group_key, filenames_list in results_filenames_d.items()
    }

In [31]:
# Keep only participants who finished the whole experiment.
results_filenames_d = filter_not_completed(results_filenames_d=results_filenames_d)

In [32]:
# Snapshot of the completed files per group (before the attention / comprehension
# filters), used as the denominator of the pass rates below. A shallow copy is enough:
# the filter functions build new lists instead of mutating these.
results_filenames_before_filtering = results_filenames_d.copy()

In [33]:
# Number of completed result files per group, then overall.
total = 0
for group_key, group_files in results_filenames_d.items():
    group_size = len(group_files)
    print(f"{group_key}: {group_size}")
    total += group_size
print(f"Total: {total}")

H: 99
H+AI: 99
H+AI+CF: 100
H+AI+SHAP: 97
H+AI+LLM: 99
H+AI+GRADCAM: 95
Total: 589


In [34]:
def filter_on_attention_tests(results_filenames_d, expected_answers=None):
    """Keep only participants who answered every attention check as expected.

    Args:
        results_filenames_d: dict group key -> list of result file names.
        expected_answers: dict questionnaire key -> sequence of expected raw answers,
            compared position by position with the start of the participant's raw
            answers. Defaults to ``{"attentioncheck_1": (2, 0), "attentioncheck_2": (6, 0)}``.

    Returns:
        dict group key -> list of the file names passing every check, in original order.
        A missing questionnaire (``None``) or a raw answer shorter than expected counts
        as a failure.
    """
    if expected_answers is None:
        expected_answers = {"attentioncheck_1": (2, 0), "attentioncheck_2": (6, 0)}

    res = extract_quest_results(results_dir, results_filenames_d, protocol_paths_d,
                                list(expected_answers))

    def _passes(raw, expected):
        # Length check first, so that short answers fail instead of raising IndexError.
        return (raw is not None
                and len(raw) >= len(expected)
                and all(raw[pos] == value for pos, value in enumerate(expected)))

    filtered_results_filenames_d = {}
    for group_key, group_res in res.items():
        filtered_results_filenames_d[group_key] = [
            filename
            for i, filename in enumerate(results_filenames_d[group_key])
            if all(_passes(group_res["raw"][quest_key][i], expected)
                   for quest_key, expected in expected_answers.items())
        ]

    return filtered_results_filenames_d

In [35]:
# Drop participants who failed either attention check.
results_filenames_d = filter_on_attention_tests(results_filenames_d=results_filenames_d)

In [36]:
# Number of participants passing the attention checks per group, then overall.
total = 0
for group_key, group_files in results_filenames_d.items():
    group_size = len(group_files)
    print(f"{group_key}: {group_size}")
    total += group_size
print(f"Total: {total}")

H: 88
H+AI: 94
H+AI+CF: 89
H+AI+SHAP: 91
H+AI+LLM: 95
H+AI+GRADCAM: 90
Total: 547


In [37]:
# Share of completed submissions that also passed both attention checks, per group.
# (The previous version also ran a stray `total += len(v)` here, which silently
# corrupted the notebook-level `total`.)
print("Success rate at attention tests :")
for group_key, kept_files in results_filenames_d.items():
    n_completed = len(results_filenames_before_filtering[group_key])
    if n_completed == 0:
        print(f"{group_key}: n/a (no completed files)")
        continue
    print(f"{group_key}: {len(kept_files)/n_completed*100:.2f}%")

Success rate at attention tests :
H: 88.89%
H+AI: 94.95%
H+AI+CF: 89.00%
H+AI+SHAP: 93.81%
H+AI+LLM: 95.96%
H+AI+GRADCAM: 94.74%


In [38]:
def filter_comprehension_score(results_filenames_d, threshold=None):
    """Keep only participants whose comprehension-task score reaches the threshold.

    Args:
        results_filenames_d: dict group key -> list of result file names. All files must
            be completed (run ``filter_not_completed`` first): ``compute_scores`` skips
            incomplete files, which would misalign scores with file names.
        threshold: minimum score to keep a participant. Defaults to the notebook-level
            ``COMPREHENSION_THRESHOLD``.

    Returns:
        dict group key -> filtered list of file names, in original order. Participants
        whose score could not be computed (``None``) or is NaN are rejected.

    Raises:
        ValueError: if a group's number of scores differs from its number of files.
    """
    if threshold is None:
        threshold = COMPREHENSION_THRESHOLD

    comprehension_score_d = compute_scores(results_dir, results_filenames_d, protocol_paths_d, tasks_dir,
                                           COMPREHENSION_TASKS, TASK_PROTOCOL_KEYS)[0]
    filtered_results_filenames_d = {}

    for group_key, group_scores in comprehension_score_d.items():
        group_files = results_filenames_d[group_key]
        if len(group_scores) != len(group_files):
            raise ValueError(
                f"Group {group_key}: {len(group_scores)} comprehension scores for {len(group_files)} "
                f"files; remove incomplete files (filter_not_completed) first.")

        kept_files = []
        for filename, score in zip(group_files, group_scores):
            # None comes from compute_scores' ValueError path; comparing it would raise TypeError.
            if score is not None and score >= threshold:
                kept_files.append(filename)
            else:
                print(f"Rejecting sample due to comprehension score of {score}")
        filtered_results_filenames_d[group_key] = kept_files

    return filtered_results_filenames_d


In [39]:
# Drop participants below the comprehension threshold (rejections are printed).
results_filenames_d = filter_comprehension_score(results_filenames_d=results_filenames_d)

Rejecting sample due to comprehension score of 0.4
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.4
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.4
Rejecting sample due to comprehension score of 0.4
Rejecting sample due to comprehension score of 0.4
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.4
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.4
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to comprehension score of 0.6
Rejecting sample due to compreh

In [40]:
# Number of participants kept after every filter, per group and overall.
total = 0
for group_key, group_files in results_filenames_d.items():
    group_size = len(group_files)
    print(f"{group_key}: {group_size}")
    total += group_size
print(f"Total: {total}")

H: 85
H+AI: 89
H+AI+CF: 80
H+AI+SHAP: 86
H+AI+LLM: 85
H+AI+GRADCAM: 85
Total: 510


In [41]:
# Overall retention: participants passing all filters vs. completed submissions.
n_passing = np.sum([len(files) for files in results_filenames_d.values()])
n_completed_total = np.sum([len(files) for files in results_filenames_before_filtering.values()])
print(f"Total passing filters among complete files: {n_passing}/{n_completed_total}")

Total passing filters among complete files: 510/589


In [42]:
# Metrics pooled over every main-experiment task (easy and difficult).
main_task_names = EASY_TASKS + DIFFICULT_TASKS
(all_scores, all_reliance, all_overreliance, all_underreliance, all_times,
 all_task_true, all_ai_pred, all_user_decision, all_quest_order) = compute_scores(
    results_dir=results_dir,
    results_filenames_d=results_filenames_d,
    protocol_paths_d=protocol_paths_d,
    tasks_dir=tasks_dir,
    tasks_names=main_task_names,
    task_protocol_keys=TASK_PROTOCOL_KEYS,
)

In [45]:
def _compute_single_task_scores(task_name):
    """Run compute_scores on a single task for the retained participants (same 9-tuple)."""
    return compute_scores(results_dir, results_filenames_d, protocol_paths_d, tasks_dir,
                          [task_name], TASK_PROTOCOL_KEYS)


# One set of metrics per (difficulty, pressure) condition.
(scores_easy_mild, reliance_easy_mild, overreliance_easy_mild, underreliance_easy_mild, answertimes_easy_mild,
 task_true_easy_mild, ai_pred_easy_mild, user_decision_easy_mild, quest_order_easy_mild) = \
    _compute_single_task_scores("easy1_find_pattern_rot")
(scores_easy_strong, reliance_easy_strong, overreliance_easy_strong, underreliance_easy_strong, answertimes_easy_strong,
 task_true_easy_strong, ai_pred_easy_strong, user_decision_easy_strong, quest_order_easy_strong) = \
    _compute_single_task_scores("easy3_find_pattern_rot")
(scores_hard_mild, reliance_hard_mild, overreliance_hard_mild, underreliance_hard_mild, answertimes_hard_mild,
 task_true_hard_mild, ai_pred_hard_mild, user_decision_hard_mild, quest_order_hard_mild) = \
    _compute_single_task_scores("hard1_find_pattern_rot")
(scores_hard_strong, reliance_hard_strong, overreliance_hard_strong, underreliance_hard_strong, answertimes_hard_strong,
 task_true_hard_strong, ai_pred_hard_strong, user_decision_hard_strong, quest_order_hard_strong) = \
    _compute_single_task_scores("hard3_find_pattern_rot")

# Comprehension-task score of each retained participant (first element of the tuple).
comprehension_score = compute_scores(results_dir, results_filenames_d, protocol_paths_d, tasks_dir,
                                     COMPREHENSION_TASKS, TASK_PROTOCOL_KEYS)[0]
