This notebook performs a counterfactual simulation on the exclusion criteria thresholds.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import datetime
import json
import os
import importlib

%cd ../..
from tools.mturk.mturk import MTurkHIT
from tools.mturk.spawn_experiment import get_verify_task_callback
%cd tools/data_analysis
from utils import utils_data

In [None]:
# Experiment directory (relative path) from which the stored results and
# `structure.json` are read by the cells below.
folder = "data/experiment_202303/resnet50_natural_20230325"

In [None]:
# Load everything stored for this experiment; the entries expose
# `raw_responses`, which a later cell feeds to the verification callback.
stored_results = utils_data.load_results(folder)

In [None]:
# df_results: response table (columns such as "correct", "rt", "participant_id"
# are used below). df_checks: exclusion-check outcomes, presumably one row per
# worker (it carries a "worker_id" column) - confirm against utils_data.
df_results = utils_data.parse_results(stored_results, use_raw_data=False)
df_checks = utils_data.parse_check_results(stored_results)

In [None]:
# Read the trial structure stored next to the results and attach it to the
# response table.
structure_path = os.path.join(folder, "structure.json")
structure = utils_data.load_and_parse_trial_structure(structure_path)
df_results = utils_data.append_trial_structure_to_results(df_results, structure)

In [None]:
# Counterfactual simulation on the instruction-time criterion.
#
# Workers that were rejected and failed the instruction-time check are
# re-admitted on that check if their total instruction time exceeds the bound
# below; their overall status is then recomputed from the individual checks.

# Same value as `min_instruction_time` passed to the verification callback
# further below; change it here to simulate a different threshold.
counterfactual_min_instruction_time = 15

# Individual checks whose conjunction forms the overall pass status.
# The original tuple listed "row_variability_result" twice, so the
# total-response-time check was never consulted.
# NOTE(review): "total_response_time_result" is inferred from the
# `<check>_result` / `<check>_details` naming pattern - confirm it exists.
# NOTE(review): "demo_trials_result" is not part of the conjunction (as in the
# original) - confirm this is intended.
check_keys = ("catch_trials_result", "row_variability_result",
              "total_response_time_result", "instruction_time_result")

for i, row in df_checks.iterrows():
    # Only previously rejected workers that failed the instruction-time check.
    if row["passed_checks"] or row["instruction_time_result"]:
        continue
    # Written as `not (x > t)` so that a NaN time is skipped, as before.
    if not (row["instruction_time_details"]["total_time"] > counterfactual_min_instruction_time):
        continue

    df_checks.loc[i, "instruction_time_result"] = True
    passed = all(bool(df_checks.loc[i, k]) for k in check_keys)
    df_checks.loc[i, "passed_checks"] = passed
    print("Changing overall check status from False to", passed)

In [None]:
# Attach the (possibly counterfactually modified) check outcomes to the
# responses; the "passed_checks" column of the result is used for the split below.
# NOTE(review): the cell overwrites df_results with its own output - re-running
# it without re-running the cells above may not be idempotent; confirm.
df_results = utils_data.append_checks_to_results(df_results, df_checks)

In [None]:
# Independent copies of the responses, split by overall check status.
passed_mask = df_results["passed_checks"]
df_results_passed = df_results.loc[passed_mask].copy(deep=True)
df_results_rejected = df_results.loc[~passed_mask].copy(deep=True)

In [None]:
# Sanity check: identical sizes for both subsets hint at a broken split.
n_passed, n_rejected = len(df_results_passed), len(df_results_rejected)
if n_passed == n_rejected:
    print("WARNING: Number of rejected trials equals that of passed trials; this could be a bug.")

# Mean correctness over all responses of each subset.
for label, subset in (("only using passed responses:", df_results_passed),
                      ("only using rejected responses:", df_results_rejected)):
    print(label, subset["correct"].mean())

In [None]:
# Thresholds of the exclusion criteria handed to the verification callback.
# (Units are not visible here - presumably seconds for the time limits; confirm.)
catch_trial_ratio_threshold = 0.8
min_total_response_time = 135
max_total_response_time = 2500
min_instruction_time = 15
max_instruction_time = 180
row_variability_threshold = 5
max_demo_trials_attempts = 3

verify_task_callback = get_verify_task_callback(
    "2afc",
    catch_trial_ratio_threshold,
    min_total_response_time, max_total_response_time,
    min_instruction_time, max_instruction_time,
    row_variability_threshold,
    max_demo_trials_attempts,
)

# Placeholder HIT: the callback needs a HIT object, so one is built from
# dummy field values.
dummy_string_fields = ("1",) * 5
hit = MTurkHIT(
    *dummy_string_fields,
    1,
    datetime.datetime.now(),
    datetime.datetime.now(),
    1,
    "2afc",
)

# Run the checks on the first stored raw response; the cell displays the result.
raw_response = stored_results[0].raw_responses[0]
verify_task_callback(hit, raw_response)

In [None]:
def _split_by_trial_type(df):
    """Split a response table into main, catch and demo trials.

    Args:
        df: DataFrame with boolean columns "catch_trial" and "is_demo".

    Returns:
        Tuple ``(main, catch, demo)`` of row subsets of ``df``:
        main  - neither catch nor demo trials,
        catch - catch trials that are not demo trials,
        demo  - demo trials.
    """
    is_catch = df["catch_trial"]
    is_demo = df["is_demo"]
    return df[~is_catch & ~is_demo], df[is_catch & ~is_demo], df[is_demo]


# The original cell negated "catch_trial" for df_results_catch and
# df_results_passed_catch, making them duplicates of the main-trial frames;
# one shared helper keeps the three splits consistent.
df_results_main, df_results_catch, df_results_demo = _split_by_trial_type(df_results)

(df_results_passed_main,
 df_results_passed_catch,
 df_results_passed_demo) = _split_by_trial_type(df_results_passed)

(df_results_rejected_main,
 df_results_rejected_catch,
 df_results_rejected_demo) = _split_by_trial_type(df_results_rejected)

In [None]:
# Shape and mean correctness of the main trials: passed first, rejected second.
main_subsets = (df_results_passed_main, df_results_rejected_main)
print(*(subset.shape for subset in main_subsets))
print(*(subset["correct"].mean() for subset in main_subsets))

In [None]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import
# cell at the top so all dependencies are visible in one place.
from utils import utils_analysis
# Presumably adds the `*_extracted` columns that the cells below read from
# df_checks - confirm against utils_analysis. The return value is displayed.
utils_analysis.apply_all_checks(df_checks)

In [None]:
# Relate the number of demo repetitions of each worker to the accuracy on
# the main trials.

# Demo repetitions per worker, indexed by participant id.
df_demo = df_checks[["worker_id", "demo_trials_details_extracted"]].rename(
    columns={"worker_id": "participant_id", "demo_trials_details_extracted": "demo_repetitions"})
df_demo = df_demo.set_index("participant_id")

# Mean correctness per participant on *main* trials of passed workers only.
# The original used df_results_passed, which still contains catch and demo
# trials and therefore did not match the "main performance" label.
df_main = df_results_passed_main[["participant_id", "correct"]].groupby("participant_id").mean()
df_main = df_main.rename(columns={'correct': 'correct_main'})

# Align both tables on participant id; workers without passed main trials
# get NaN for `correct_main` and are not drawn.
df_merged = pd.concat((df_main, df_demo), axis=1)

plt.scatter(df_merged["demo_repetitions"], df_merged["correct_main"])
plt.xlabel("#demo repetitions")
plt.ylabel("main performance")
plt.show()

In [None]:
# One bar chart per exclusion check, counting how many rows of df_checks
# passed (True) or failed (False) it.
# The original tuple listed "row_variability_result" twice, which drew the same
# panel twice and left out the total-response-time check.
# NOTE(review): "total_response_time_result" is inferred from the
# `<check>_result` naming pattern - confirm the column exists.
keys = ("catch_trials_result", "row_variability_result",
        "total_response_time_result", "instruction_time_result",
        "demo_trials_result")
fig, axs = plt.subplots(1, len(keys), figsize=(2.5*len(keys), 3))
axs = axs.flatten()
for k, ax in zip(keys, axs):
    df_checks[k].value_counts().plot(kind="bar", ax=ax)
    ax.set_title(k)
    ax.set_xlabel("Passed Exclusion Criteria")
    ax.set_ylabel("Count")
plt.tight_layout()

In [None]:
# Histograms of the raw quantities behind each exclusion check, laid out on a
# grid with three columns.
keys = ('instruction_time_details_extracted',
        'total_response_time_details_extracted',
        'row_variability_details_details_upper_extracted',
        'row_variability_details_details_lower_extracted',
        'catch_trials_details_ratio_extracted',
        'demo_trials_details_extracted')
n_cols = 3
fig, axs = plt.subplots(int(np.ceil(len(keys) / n_cols)), n_cols, figsize=(8, 5))
axs = axs.flatten()

# Hide grid cells that have no metric to show.
for ax in axs[len(keys):]:
    ax.axis("off")

for k, ax in zip(keys, axs):
    # Missing values would make the automatic bin range non-finite.
    ax.hist(df_checks[k].dropna(), bins=20)
    ax.set_title(k.replace("_extracted", ""))
    # The x-axis shows the measured quantity, not a pass/fail flag (the
    # original label was copied from the bar-chart cell).
    ax.set_xlabel("Value")
    ax.set_ylabel("Count")
plt.tight_layout()

### Correlation between mean/min/max RT and Accuracy per Participant

In [None]:
def f(df):
    """Summarise one group of responses.

    Args:
        df: DataFrame with a "correct" column and an "rt" column.

    Returns:
        pd.Series indexed by "accuracy", "min_rt", "max_rt", "mean_rt" and
        "median_rt": the mean of "correct" followed by the minimum, maximum,
        mean and median of "rt".
    """
    reaction_times = df["rt"]
    summary = {
        "accuracy": df["correct"].mean(),
        "min_rt": reaction_times.min(),
        "max_rt": reaction_times.max(),
        "mean_rt": reaction_times.mean(),
        "median_rt": reaction_times.median(),
    }
    return pd.Series(list(summary.values()), index=list(summary.keys()))
# One summary row per participant.
pdf_results = df_results.groupby("participant_id").apply(f)

# Only the minimum and mean reaction times are drawn; "median_rt" and "max_rt"
# are available in pdf_results but intentionally left out of the figure.
# Division by 1000 presumably converts milliseconds to seconds - confirm.
for rt_column, rt_label in (("min_rt", "min"), ("mean_rt", "mean")):
    plt.scatter(pdf_results[rt_column] / 1000, pdf_results["accuracy"], label=rt_label)
plt.xlabel("Reaction Time [s]")
plt.ylabel("Accuracy")
plt.legend()

### Correlation between RT and Correctness of Individual Responses

In [None]:
# One point per response: reaction time against whether it was correct.
# Division by 1000 presumably converts milliseconds to seconds - confirm.
rt_scaled = df_results["rt"] / 1000
is_correct = df_results["correct"]
plt.scatter(rt_scaled, is_correct)
plt.xlabel("Reaction Time [s]")
plt.ylabel("Trial Correctly Solved?")