In [None]:
# Reload imported modules automatically before each cell runs, so that edits to the
# project's helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import os
import importlib
import matplotlib
import scikit_posthocs

%cd ../..
from tools.mturk.mturk import MTurkHIT
from tools.mturk.spawn_experiment import get_verify_task_callback
from stimuli_generation import sg_utils as utils_stimuli_generation
%cd tools/data_analysis
from utils import utils_data
from utils import utils_analysis

In [None]:
from matplotlib import rcParams

# Font and export settings applied to every figure produced in this notebook.
rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": ["DejaVu Sans"],
    # output text as text and not paths
    "svg.fonttype": "none",
    "pdf.fonttype": "truetype",
})

# Palette written as 8-bit RGB triples; converted below into lists of floats in [0, 1].
_PALETTE_RGB = {
    "synthetic": (71, 120, 158),
    "natural": (255, 172, 116),

    "natural easy": (255, 172, 116),
    "natural medium": (197, 135, 96),
    "natural hard": (150, 105, 75),
    "natural very hard": (109, 75, 52),

    "synthetic easy": (71, 120, 158),
    "synthetic medium": (52, 81, 105),
    "synthetic very hard": (39, 61, 79),

    "natural low": (255, 172, 116),
    "natural high": (146, 100, 71),

    "synthetic low": (71, 120, 158),
    "synthetic high": (39, 61, 79),

    "c0": (245, 181, 121),
    "c1": (244, 170, 113),
    "c2": (203, 147, 94),
    "c3": (196, 134, 91),
    "c4": (150, 108, 68),
    "c5": (142, 98, 67),
    "c6": (116, 72, 44),
    "c7": (103, 76, 54),
    "c8": (88, 73, 59),
}

# Color name -> [r, g, b] with each channel scaled to the 0-1 range.
colors = {name: [channel / 255 for channel in rgb] for name, rgb in _PALETTE_RGB.items()}

# Load Data

In [None]:
# Maps an experiment label of the form "<model> (<condition>)" to the directory that
# holds the data of that experiment (the loading cell reads "structure.json" from it).
# A value of None marks a condition for which no data is available; the loading cell
# stores None for such entries instead of reading anything from disk.
experiments = {
    "GoogLeNet (natural)": "data/experiment_202303/googlenet_natural_20230312",
    "GoogLeNet (synthetic)": "data/experiment_202303/googlenet_optimized_20230312",

    "DenseNet (natural)": "data/experiment_202303/densenet_201_natural_20230412",
    "DenseNet (synthetic)": "data/experiment_202303/densenet_201_optimized_20230503",

    "ResNet (natural)": "data/experiment_202303/resnet50_natural_20230325",
    "ResNet (synthetic)": "data/experiment_202303/resnet50_optimized_20230325",

    "Hard85 ResNet (natural)": "data/experiment_202303/resnet50_hard85_natural_20230505",
    "Hard85 ResNet (synthetic)": None,

    "Hard95 ResNet (natural)": "data/experiment_202303/resnet50_hard95_natural_20230510",
    "Hard95 ResNet (synthetic)": "data/experiment_202303/resnet50_hard95_optimized_20230514",

    "Hard99 ResNet (natural)": "data/experiment_202303/resnet50_hard99_natural_20230803",
    "Hard99 ResNet (synthetic)": None,

    "Robust ResNet (natural)": "data/experiment_202303/resnet50-l2_natural_20230314",
    "Robust ResNet (synthetic)": "data/experiment_202303/resnet50-l2_optimized_20230314",

    "Clip ResNet (natural)": "data/experiment_202303/clip-resnet50_natural_20230412",
    "Clip ResNet (synthetic)": "data/experiment_202303/clip-resnet50_optimized_20230502",

    "Hard85 Clip ResNet (natural)": "data/experiment_202303/clip-resnet50_hard85_natural_20230505",
    "Hard85 Clip ResNet (synthetic)": None,

    "Hard95 Clip ResNet (natural)": "data/experiment_202303/clip-resnet50_hard95_natural_20230510",
    "Hard95 Clip ResNet (synthetic)": None,

    "Hard99 Clip ResNet (natural)": "data/experiment_202303/clip-resnet50_hard99_natural_20230803",
    "Hard99 Clip ResNet (synthetic)": None,

    "WideResNet (natural)": "data/experiment_202303/wide_resnet50_natural_20230412",
    "WideResNet (synthetic)": "data/experiment_202303/wide_resnet50_optimized_20230502",

    "ViT (natural)": "data/experiment_202303/in1k-vit_b32_natural_20230501",
    "ViT (synthetic)": None,

    "Clip ViT (natural)": "data/experiment_202303/clip-vit_b32_natural_20230501",
    "Clip ViT (synthetic)": None,

    "ConvNeXT (natural)": "data/experiment_202303/convnext_b_natural_20230429",
    "ConvNeXT (synthetic)": "data/experiment_202303/convnext_b_optimized_20230509",
}

In [None]:
# Per-experiment caches of parsed responses and check results. They live in their own
# cell so that re-running the loading cell can skip experiments that are already loaded.
dfs_results, dfs_checks = {}, {}

In [None]:
# Load and parse every experiment whose data is not in the caches yet.
for k, experiment_dir in experiments.items():
    cached_frames = (dfs_checks.get(k), dfs_results.get(k))
    if all(frame is not None and len(frame) > 0 for frame in cached_frames):
        # Both frames are present and non-empty: nothing to reload.
        continue

    if experiment_dir is None:
        # No data exists for this condition; record that explicitly.
        dfs_results[k] = dfs_checks[k] = None
        continue

    print(k)

    results = utils_data.load_results(experiment_dir)
    structure_path = os.path.join(experiment_dir, "structure.json")
    structure = utils_data.load_and_parse_trial_structure(structure_path)

    df_results = utils_data.append_trial_structure_to_results(
        utils_data.parse_results(results, use_raw_data=False), structure)

    df_checks = utils_analysis.apply_all_checks(utils_data.parse_check_results(results))
    dfs_checks[k] = df_checks

    # Attach the outcome of the checks to the responses before caching them.
    dfs_results[k] = utils_data.append_checks_to_results(df_results, df_checks)

In [None]:
# Experiments without data were stored as None by the loading cell. Replace them with
# empty frames that share the columns of a loaded experiment, so that later cells can
# treat every experiment the same way.
#
# The reference has to be an experiment that actually has data, so the filter must look
# at the value `experiments[k]`; the keys themselves are never None.
reference_key = next((k for k in experiments if experiments[k] is not None), None)
if reference_key is None:
    raise ValueError("No experiment with data found; cannot derive empty placeholder frames.")

reference_df_results = dfs_results[reference_key]
reference_df_checks = dfs_checks[reference_key]
for k in experiments:
    if experiments[k] is None:
        # head(0) keeps the columns but drops all rows; copy to detach from the reference.
        dfs_results[k] = reference_df_results.head(0).copy()
        dfs_checks[k] = reference_df_checks.head(0).copy()

In [None]:
# Ensure that the correct structures have been used and no two experiments use the same model & mode
unique_modes_and_models = {
    k: (set(df["mode"].unique()), set(df["model"].unique())) for k, df in dfs_results.items()
}

# Compare every unordered pair exactly once, so that a clash is reported once instead of
# twice (once as "A and B" and again as "B and A").
experiment_keys = list(unique_modes_and_models)
for position, k1 in enumerate(experiment_keys):
    for k2 in experiment_keys[position + 1:]:
        modes_and_models_1 = unique_modes_and_models[k1]
        modes_and_models_2 = unique_modes_and_models[k2]

        # Frames without any mode (e.g. empty placeholder frames) cannot clash.
        if len(modes_and_models_1[0]) == 0 or len(modes_and_models_2[0]) == 0:
            continue

        if modes_and_models_1 == modes_and_models_2:
            print(f"WARNING: {k1} and {k2} use the same model and mode!")

### Load & apply the list of responses flawed by multiple participation

In [None]:
import pickle

# NOTE(review): unpickling can execute arbitrary code; only load this file from a trusted source.
# The context manager closes the file handle once the object has been loaded.
with open("data/experiment_202303_neurips_submission/flawed_tasks_and_participants.pd.pkl", "rb") as f:
    flawed_tasks_and_participants = pickle.load(f)

In [None]:
# Start with every response marked as clean ...
for df_results in dfs_results.values():
    df_results["multi_participation"] = False

# ... then flag each response whose (task, worker) pair is on the list of flawed ones.
flawed_pairs = zip(flawed_tasks_and_participants["task_id"], flawed_tasks_and_participants["worker_id"])
for task_id, worker_id in flawed_pairs:
    # The pair is looked up in every experiment.
    for df_results in dfs_results.values():
        is_flawed = (df_results["task_id"] == task_id) & (df_results["worker_id"] == worker_id)
        df_results.loc[is_flawed, "multi_participation"] = True

In [None]:
def _select(dfs, mask_fn):
    """Return {experiment: rows of its frame for which mask_fn(frame) is True}."""
    return {k: df[mask_fn(df)] for k, df in dfs.items()}


# Row masks shared by all subsets below, so that e.g. "catch" means the same everywhere.
def _is_main(df):
    """Mask of rows that are neither catch trials nor demo trials."""
    return ~df["catch_trial"] & ~df["is_demo"]


def _is_catch(df):
    """Mask of catch-trial rows that are not demo trials."""
    return df["catch_trial"] & ~df["is_demo"]


def _is_demo(df):
    """Mask of demo-trial rows."""
    return df["is_demo"]


def _passed(df):
    """Mask of rows with passed_checks set."""
    return df["passed_checks"]


def _rejected(df):
    """Mask of rows with passed_checks not set."""
    return ~df["passed_checks"]


def _mp_or_rejected(df):
    """Mask of rows that failed the checks or are flagged as multi participation."""
    return ~df["passed_checks"] | df["multi_participation"]


# Split by multi participation.
dfs_results_mp = _select(dfs_results, lambda df: df["multi_participation"])
dfs_results_no_mp = _select(dfs_results, lambda df: ~df["multi_participation"])

# Responses without multi participation, split by trial type and check outcome.
dfs_results_main = _select(dfs_results_no_mp, _is_main)
dfs_results_catch = _select(dfs_results_no_mp, _is_catch)
dfs_results_demo = _select(dfs_results_no_mp, _is_demo)

dfs_results_no_mp_passed = _select(dfs_results_no_mp, _passed)

dfs_results_main_passed = _select(dfs_results_main, _passed)
dfs_results_catch_passed = _select(dfs_results_catch, _passed)
dfs_results_demo_passed = _select(dfs_results_demo, _passed)

dfs_results_main_rejected = _select(dfs_results_main, _rejected)
dfs_results_catch_rejected = _select(dfs_results_catch, _rejected)
dfs_results_demo_rejected = _select(dfs_results_demo, _rejected)

# Responses with multi participation, split the same way.
dfs_results_mp_main = _select(dfs_results_mp, _is_main)
dfs_results_mp_catch = _select(dfs_results_mp, _is_catch)
dfs_results_mp_demo = _select(dfs_results_mp, _is_demo)

dfs_results_mp_main_passed = _select(dfs_results_mp_main, _passed)
dfs_results_mp_catch_passed = _select(dfs_results_mp_catch, _passed)
dfs_results_mp_demo_passed = _select(dfs_results_mp_demo, _passed)

dfs_results_mp_main_rejected = _select(dfs_results_mp_main, _rejected)
dfs_results_mp_catch_rejected = _select(dfs_results_mp_catch, _rejected)
dfs_results_mp_demo_rejected = _select(dfs_results_mp_demo, _rejected)

# Everything that is unusable for the clean analysis: failed checks or multi participation.
dfs_results_mp_or_rejected = _select(dfs_results, _mp_or_rejected)
dfs_results_main_mp_or_rejected = _select(dfs_results, lambda df: _is_main(df) & _mp_or_rejected(df))
# Uses the same catch definition as the subsets above (demo trials excluded); previously
# this subset alone also contained demo trials that were catch trials.
dfs_results_catch_mp_or_rejected = _select(dfs_results, lambda df: _is_catch(df) & _mp_or_rejected(df))
dfs_results_demo_mp_or_rejected = _select(dfs_results, lambda df: _is_demo(df) & _mp_or_rejected(df))

In [None]:
def _n_responses(dfs_by_experiment):
    """Total number of rows over all experiments."""
    return sum([len(df) for df in dfs_by_experiment.values()])


def _n_participants(dfs_by_experiment):
    """Sum over experiments of the number of distinct worker ids in each experiment."""
    return sum([len(df["worker_id"].unique()) for df in dfs_by_experiment.values()])


# Each inner list is printed as one paragraph followed by an empty line.
response_report = [
    [
        ("Number of all responses:", dfs_results),
        ("Number of all clean responses:", dfs_results_no_mp),
        ("Number of all flawed responses:", dfs_results_mp),
    ],
    [
        ("Number of all clean responses:", dfs_results_no_mp),
        ("Number of all clean main responses:", dfs_results_main),
        ("Number of all clean passed main responses:", dfs_results_main_passed),
    ],
    [
        ("Number of all flawed responses:", dfs_results_mp),
        ("Number of all flawed main responses:", dfs_results_mp_main),
        ("Number of all flawed passed main responses:", dfs_results_mp_main_passed),
    ],
    [
        ("Number of all passed responses:", dfs_results_no_mp_passed),
        ("Number of all flawed/rejected responses:", dfs_results_mp_or_rejected),
        ("Number of all flawed/rejected main responses:", dfs_results_main_mp_or_rejected),
    ],
]
for paragraph in response_report:
    for caption, subset in paragraph:
        print(caption, _n_responses(subset))
    print()

participant_report = [
    ("Number of all participants:", dfs_results),
    ("Number of clean participants:", dfs_results_no_mp),
    ("Number of flawed/rejected participants:", dfs_results_mp_or_rejected),
]
for caption, subset in participant_report:
    print(caption, _n_participants(subset))

In [None]:
# Suffixes that mark difficulty variants of a model. They are stripped so that all
# variants of a model map to the same unit identifiers.
_HARD_SUFFIXES = ("_hard99", "_hard98", "_hard95", "_hard90", "_hard85")


def _base_model_name(model):
    """Return the model name as a string with all difficulty suffixes removed."""
    name = str(model)
    for suffix in _HARD_SUFFIXES:
        name = name.replace(suffix, "")
    return name


all_units = dict()  # experiment -> set of (model, layer, channel) identifiers
n_units = dict()    # experiment -> number of distinct (layer, channel) groups
for k, df in dfs_results.items():
    if len(df) == 0:
        continue
    # Units are kept as tuples instead of "model_layer_channel" strings: model and layer
    # names may contain underscores themselves, so such a string cannot be split reliably.
    all_units[k] = {
        (_base_model_name(model), str(layer), str(channel))
        for model, layer, channel in zip(df["model"], df["layer"], df["channel"])
    }

    n_units[k] = len(df.groupby(["layer", "channel"]).size())

unique_units = set().union(*all_units.values())

# The loop variable must not be called `n_units`, or it would overwrite the dict above.
models, model_unit_counts = np.unique([model for model, _, _ in unique_units], return_counts=True)
for model, n_model_units in zip(models, model_unit_counts):
    print(f"Model {model}: {n_model_units} units")

print("Total number of unique units:", len(unique_units))

## Analyze how many responses have been collected/accepted per unit/model

In [None]:
def _experiment_completed(n_units, n_responses_counts):
    """Tell whether an experiment has the expected number of responses for every unit.

    n_responses_counts maps a number of responses to how many units received it.
    Experiments with 80 units count as complete with 40 units at 31 and 40 units at 32
    responses; experiments with 84 units with all 84 units at 30 responses. For any
    other number of units the result is NaN (unknown).
    """
    if n_units == 80:
        return n_responses_counts.get(31, 0) == n_responses_counts.get(32, 0) == 40
    if n_units == 84:
        return n_responses_counts.get(30, 0) == 84
    return np.nan


dfs = dfs_results_main_passed
meta_datas = []
for k, df in dfs.items():
    if len(df) == 0:
        continue
    by_unit = df.groupby(["layer", "channel"])
    n_units = len(by_unit.mean(numeric_only=True))
    # How many units received how many responses.
    n_responses_counts = by_unit["correct"].size().value_counts().to_dict()
    meta_datas.append({
        "model_condition": k,
        "units": n_units,
        "responses_per_unit": sorted(n_responses_counts),
        "completed": _experiment_completed(n_units, n_responses_counts),
    })
meta_datas = pd.DataFrame(meta_datas).set_index("model_condition")

print("All Experiments")
display(meta_datas)

print("Incomplete Experiments")
display(meta_datas[meta_datas["completed"] == False])

# Visualizations

In [None]:
# Set this to True to use data of participants that participated in multiple experiments and
# data of participants that did not pass the checks.
use_flawed_data = False

# Figures based on the flawed data go to a directory of their own.
results_dir = "results_rejected" if use_flawed_data else "results"
os.makedirs(results_dir, exist_ok=True)

### Calculate Confidence Intervals/Stds/SEMs

In [None]:
import collections

# `import scipy` alone does not guarantee that the `stats` submodule is loaded.
import scipy.stats

if use_flawed_data:
    dfs = {k: dfs_results_main_mp_or_rejected[k].copy() for k in dfs_results_main_passed}
else:
    dfs = {k: dfs_results_main_passed[k].copy() for k in dfs_results_main_passed}

# Stand-in with NaN bounds for conditions that have too little data for a bootstrap.
dummy_test_results = collections.namedtuple("DummyTestResults", ("confidence_interval",))((np.nan, np.nan))

# Per-unit averages; computed once and shared by all statistics below.
unit_means = {k: dfs[k].groupby(["layer", "channel"]).mean(numeric_only=True) for k in dfs}

# Fixed seed so that the confidence intervals are identical on every run.
BOOTSTRAP_SEED = 0

# Bootstrap the mean over units. The guard looks at the data that is actually available
# (at least two units) rather than at whether a path was configured, so that conditions
# that end up empty after filtering get NaN bounds instead of raising.
confidences = {
    k: scipy.stats.bootstrap(
        (unit_means[k]["correct"].values,), statistic=np.mean, n_resamples=10_000,
        random_state=np.random.default_rng(BOOTSTRAP_SEED))
    if len(unit_means[k]) > 1 else dummy_test_results
    for k in dfs
}

means = {k: unit_means[k].mean(numeric_only=True) for k in experiments}
stds = {k: unit_means[k].std(numeric_only=True) for k in experiments}
# The standard deviation is taken over per-unit means, so the standard error has to be
# normalized by the number of units, not by the number of individual responses.
sems = {k: stds[k] / np.sqrt(len(unit_means[k])) if len(unit_means[k]) > 0 else stds[k] for k in experiments}

In [None]:
def plot_performance(means, confidences, relevant_experiments, labels=("Natural", "Synthetic"), color_names=("natural", "synthetic"), legend: bool = True,
                     chance_level: bool = True, xticks: bool = True, yticks: bool = True, xlabel="Model", ylabel="Proportion Correct",
                     bar_width: float = 1, grid: bool = True, rotate_ticks: bool = False, y_max: float = 0.9, legend_cols: int = 2,
                     show_errorbars: bool = True, ax=None):
    """Draw a grouped bar chart of the "correct" score of the given experiments.

    Args:
        means: Mapping from experiment key to an object that is indexed with
            "correct" to obtain the bar height.
        confidences: Mapping from experiment key to an object whose
            `confidence_interval` attribute is a (low, high) pair; used for the
            asymmetric error bars.
        relevant_experiments: Experiment keys in plotting order. Consecutive runs
            of `len(labels)` keys form one group of adjacent bars.
        labels: Legend labels; their number defines the group size.
        color_names: Names looked up (lower-cased) in the notebook-level `colors`
            dict; the i-th bar of a group uses the i-th name.
        legend: Whether to draw the legend.
        chance_level: Whether to draw a dash-dotted line at 0.5 with a "Chance" label.
        xticks: Whether to place one tick per group. The tick label is the part of
            the group's first key before "(", with spaces replaced by newlines.
        yticks: If False, all y ticks are removed.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        bar_width: Width of each bar; the x positions are scaled by it as well.
        grid: Whether to draw a dashed horizontal grid behind the bars.
        rotate_ticks: Whether to rotate the x tick labels by 45 degrees.
        y_max: Upper limit of the y axis (the lower limit is fixed at 0.475).
        legend_cols: Number of legend columns.
        show_errorbars: Whether to draw the error bars.
        ax: Axis to draw on. If None, a new figure sized to the number of bars is
            created; otherwise the recommended figure size is printed.

    Experiments whose mean is NaN count as missing: all later bars move one slot to
    the left. The function returns nothing and calls `plt.tight_layout()` at the end.
    """
    # One slot per bar plus one spacer slot per group.
    x = np.arange(len(relevant_experiments) // len(labels) * (len(labels) + 1), dtype=float)

    relevant_means = [means[k]["correct"] for k in relevant_experiments]
    # Distances from the mean to the lower/upper interval bound, clipped at 0;
    # transposed so that row 0 holds the lower and row 1 the upper errors.
    error_values = np.array([
        (max(0, -confidences[k].confidence_interval[0] + means[k]["correct"]),
        max(0, confidences[k].confidence_interval[1] - means[k]["correct"]))
        for k in relevant_experiments]).T
    #error_values = np.array([2 * sems[k]["correct"] for k in experiments])
    #error_values = np.array([stds[k]["correct"] for k in experiments])
    missing_indices = [i for i, it in enumerate(relevant_means) if np.isnan(it)]
    #relevant_means = [it for i, it in enumerate(relevant_means) if i not in missing_indices]
    #error_values = error_values[:, [i for i in range(len(relevant_means)) if i not in missing_indices]]
    for i in missing_indices:
        group_idx = i // len(labels)
        element_idx = i % len(labels)
        # Slot of the missing bar in the layout that includes the spacer slots.
        j = group_idx * (len(labels) + 1) + element_idx
        # Close the gap: every later slot takes over the position of its predecessor.
        x[j + 1:] = x[j:-1]

    
    # Move every group after the first half a slot closer to the previous group.
    for i in range(1, len(relevant_experiments) // len(labels)):
        x[i * (len(labels) + 1):] -= 0.5

    # Keep only the bar slots (drop each group's spacer slot), in group order.
    x = np.vstack([x[i::len(labels) + 1] for i in range(len(labels))]).reshape((-1,), order='F')

    x *= bar_width

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(1 + bar_width*0.6*(len(relevant_means) - len(missing_indices)), 4.8))
    else:
        print("Reusing existing axis. Ensure that the axis has the right size:", (1 + bar_width*0.6*(len(relevant_means) - len(missing_indices)), 4.8))
    ax.bar(
        x,
        [means[k]["correct"] for k in relevant_experiments],
        color=[colors[color_names[i % len(labels)].lower()] for i in range(len(relevant_experiments))],
        width=bar_width
        )
    if show_errorbars:
        ax.errorbar(x, relevant_means, 
                    error_values, fmt=".", color="k")

    if xticks:
        # Per group: half a bar width for every missing bar of that group.
        single_bar_offset = bar_width * np.array([0.5 if np.isnan(means[k]["correct"]) else 0 for k in relevant_experiments]).reshape(-1, len(labels)).sum(-1)
        # Midpoint between the first and the last bar of each group, corrected for missing bars.
        x_ticks = (x[:-1:len(labels)] + x[len(labels) - 1::len(labels)]) / 2 - single_bar_offset
        ax.set_xticks(x_ticks, [k.split("(")[0].replace(" ", "\n") for k in relevant_experiments][::len(labels)], rotation=45 if rotate_ticks else 0)
    else:
        ax.set_xticks([])

    if not yticks:
        ax.set_yticks([])

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel, labelpad=-10)
    ax.set_ylim((0.475, y_max))
    
    # Right-most position of a bar that actually has data.
    x_max = max([x[i] for i, k in enumerate(relevant_experiments) if not np.isnan(means[k]["correct"])])

    if legend:
        # Empty scatter plots serve as legend handles in the bar colors.
        hdls = [ax.scatter([], [], color=colors[k.lower()]) for k in color_names]
        ax.legend(hdls, labels, frameon=False, ncol=legend_cols)

    if chance_level:
        ax.text(x[1] + 0.6 if len(x) > 2 else x[0] - bar_width, 0.505, "Chance")
        ax.hlines(0.5, x.min() - 1, x.max() + 1, color="k", ls="-.", lw=1)

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(-1.0 * bar_width, x_max + 0.75 * bar_width)
    ax.spines["bottom"].set_bounds((-0.75 * bar_width, x_max + 0.75 * bar_width))

    if grid:
        ax.set_axisbelow(True)
        ax.grid(linestyle="dashed", axis="y")

    plt.tight_layout()

### Main Visualization

In [None]:
# Compare all models on the standard (non-"hard") conditions.
relevant_experiments = [name for name in experiments if "hard" not in name.lower()]
plot_performance(means, confidences, relevant_experiments, bar_width=0.6)
plt.xlabel("Model (Ordered by Scale)")

figure_path = os.path.join(results_dir, "model_comparison.pdf")
plt.savefig(figure_path, bbox_inches="tight")
plt.show()

In [None]:
# Small teaser figure: one bar per model for the natural, non-"hard" conditions.
relevant_experiments = [
    name for name in experiments
    if "natural" in name.lower() and "hard" not in name.lower()
]
n_bars = len(relevant_experiments)
print(n_bars, len(means))

# One label and one color per bar, so that all bars end up in a single group.
plot_performance(
    means, confidences, relevant_experiments,
    labels=[""] * n_bars,
    color_names=[f"c{i}" for i in range(n_bars)],
    legend=False,
    chance_level=False,
    xticks=False,
    yticks=True,
    xlabel="Model & Dataset Size",
    ylabel="Interpretability Score",
    bar_width=0.4,
    grid=False,
    show_errorbars=False,
)
plt.gcf().set_size_inches(2, 2)

figure_path = os.path.join(results_dir, "comparison_figure_1.pdf")
plt.savefig(figure_path, bbox_inches="tight")

In [None]:
class DummyInterval:
    """Stand-in for a bootstrap result whose confidence interval is fixed at (0, 0)."""

    # Same attribute name that the plotting code reads from real bootstrap results.
    confidence_interval: tuple = (0, 0)
relevant_experiments = [name for name in experiments if "hard" not in name.lower()]


# Each selector decides which experiments are shown with their real values in one
# variation of the figure; all other experiments get their values multiplied by 0.
def selector_fn_0(k):
    """Select no experiment."""
    return False


def selector_fn_1(k):
    """Select only the GoogLeNet natural condition."""
    return "GoogLeNet (natural)" == k


def selector_fn_2(k):
    """Select every natural condition."""
    return "natural" in k


for i, selector_fn in enumerate((selector_fn_0, selector_fn_1, selector_fn_2)):
    shown_means = {k: (means[k] if selector_fn(k) else means[k] * 0) for k in means}
    shown_confidences = {k: (confidences[k] if selector_fn(k) else DummyInterval()) for k in confidences}

    # The legend is only drawn for the last variation.
    plot_performance(shown_means, shown_confidences, relevant_experiments, bar_width=0.6, legend=i > 1)
    plt.xlabel("Model (Ordered by Scale)")

    figure_path = os.path.join(results_dir, f"model_comparison_slides_variation_{i + 1}.pdf")
    plt.savefig(figure_path, bbox_inches="tight")
    plt.show()

### Hard vs. Easy Condition

In [None]:
# plot_performance takes the experiments in groups of four and labels the bars of each
# group, in order, "Easy", "Medium", "Hard", "Very Hard". Both models therefore have to
# list their conditions in the same order (base, Hard99, Hard95, Hard85); otherwise the
# "Medium" and "Hard" bars of one model would show the wrong condition.
relevant_experiments = [
    "ResNet (natural)",
    "Hard99 ResNet (natural)",
    "Hard95 ResNet (natural)",
    "Hard85 ResNet (natural)",
    "Clip ResNet (natural)",
    "Hard99 Clip ResNet (natural)",
    "Hard95 Clip ResNet (natural)",
    "Hard85 Clip ResNet (natural)",
]

# Left panel: natural conditions; right panel: synthetic conditions (shared y axis).
fig, axs = plt.subplots(1, 2, figsize=(3.88+1.72, 4.8), sharey=True, width_ratios=[3.88, 1.12])
plot_performance(means, confidences, relevant_experiments, ("Easy", "Medium", "Hard", "Very Hard"),
                 color_names=("natural easy", "natural medium", "natural hard", "natural very hard"),  bar_width=0.6,
                 ax=axs[0])

# Uncomment to get plots separately.
# plt.tight_layout()
# plt.savefig("results/model_comparison_rn_easy_vs_hard_natural.pdf", bbox_inches="tight")
# plt.show()

relevant_experiments = [
    "ResNet (synthetic)",
    "Hard95 ResNet (synthetic)",
]
plot_performance(means, confidences, relevant_experiments, ("Easy", "Hard"),
                 color_names=("synthetic easy", "synthetic medium"), bar_width=0.6, legend_cols=1,
                 ax=axs[1])
# The y axis is shared, so the label of the left panel is sufficient.
axs[1].set_ylabel("")

plt.tight_layout()
# Uncomment to get plots separately.
#plt.savefig(os.path.join(results_dir, "model_comparison_rn_easy_vs_hard_optimized.pdf"), bbox_inches="tight")

plt.savefig(os.path.join(results_dir, "model_comparison_rn_easy_vs_hard_natural_and_optimized.pdf"), bbox_inches="tight")
plt.show()

### Per Unit

In [None]:
relevant_experiments = [k for k in experiments if "hard" not in k.lower()]

if use_flawed_data:
    dfs = {k: dfs_results_main_mp_or_rejected[k].copy() for k in relevant_experiments}
else:
    dfs = {k: dfs_results_main_passed[k].copy() for k in relevant_experiments}

# One row per unit (layer, channel) with the mean over all of its responses.
unit_mean_dfs = {k: dfs[k].groupby(["layer", "channel"]).mean(numeric_only=True).reset_index() for k in relevant_experiments}

# Experiment label (without the condition in parentheses) -> model identifier.
model_names = {
    "ResNet": "resnet50",
    "Robust ResNet": "resnet50-l2",
    "GoogLeNet": "googlenet",
    "Clip ResNet": "clip-resnet50",
    "WideResNet": "wide_resnet50",
    "DenseNet": "densenet_201",
    "ConvNeXT": "convnext_b",
    "Clip ViT": "clip-vit_b32",
    "ViT": "in1k-vit_b32",
    "Hard99 Clip ResNet": "clip-resnet50",
    "Hard99 ResNet": "resnet50",
    "Hard95 Clip ResNet": "clip-resnet50",
    "Hard95 ResNet": "resnet50",
    "Hard85 Clip ResNet": "clip-resnet50",
    "Hard85 ResNet": "resnet50",
}

# Several experiments share a model (natural and synthetic condition). Load every model
# only once and reuse its list of layers for all experiments that refer to it.
network_layers_by_model = {}
for k in list(unit_mean_dfs.keys()):
    model_name = model_names[k.split(" (")[0]]
    if model_name not in network_layers_by_model:
        model = utils_stimuli_generation.load_model(model_name)
        network_layers = utils_stimuli_generation.get_relevant_layers(model, model_name)

        if model_name == "clip-vit_b32":
            # Presumably the results name these layers with a "visual_" prefix - TODO confirm.
            network_layers = ["visual_" + it for it in network_layers]

        network_layers_by_model[model_name] = network_layers
    network_layers = network_layers_by_model[model_name]

    # Position of each unit's layer within the model's list of relevant layers.
    unit_mean_dfs[k]["layer_index"] = unit_mean_dfs[k]["layer"].map(lambda l: network_layers.index(l))

    unit_mean_dfs[k] = unit_mean_dfs[k].sort_values("layer_index")

In [None]:
n_cols = 3

# `matplotlib.cm.get_cmap` is deprecated in recent Matplotlib releases; `plt.get_cmap`
# is the supported way to look up a colormap by name.
cmap = plt.get_cmap("tab10")

relevant_experiments = [k for k in experiments if "natural" in k.lower() and "hard" not in k.lower()]

color_values = [cmap(i / len(relevant_experiments)) for i in range(len(relevant_experiments))]

relevant_unit_mean_dfs = {k: unit_mean_dfs[k] for k in relevant_experiments}

# Number of units with more than 95% correct responses and total number of units.
for k in relevant_experiments:
    print(k, np.sum(relevant_unit_mean_dfs[k]["correct"] > 0.95), len(relevant_unit_mean_dfs[k]["correct"]))

# Pool units in the same layer and compute mean proportion correct. This is computed
# once for both figures; selecting the column on the groupby avoids `groupby.apply`
# on the grouping column, which newer pandas versions no longer pass to the function.
layer_mean_dfs = {
    k: relevant_unit_mean_dfs[k].groupby("layer_index", as_index=False)["correct"].mean()
    for k in relevant_experiments
}

# --- Figure 1: all models in a single axis ---
fig, ax = plt.subplots(1, 1)

for i, k in enumerate(relevant_experiments):
    label = k.split(" (")[0]
    grouped_df = layer_mean_dfs[k]

    # Layer position relative to the deepest layer of the model that has data.
    x = grouped_df["layer_index"] / grouped_df["layer_index"].max()
    y = grouped_df["correct"]

    ax.scatter(x, y, s=6, label=label, color=color_values[i])

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.set_ylabel("Proportion Correct")
ax.set_xlabel("Relative Layer Index")
ax.legend(ncol=3)
plt.show()

# --- Figure 2: one panel per model ---
n_rows = int(np.ceil(len(relevant_unit_mean_dfs) / n_cols))
# squeeze=False keeps `axs` two-dimensional even if there is only a single row.
fig, axs = plt.subplots(n_rows, n_cols, sharex=False, sharey=False, squeeze=False)
fig.set_size_inches(6, 5)

# Hide all panels first; the ones that receive data are switched on again below.
axs_f = axs.flatten()
for ax in axs_f:
    ax.axis("off")

for i, (k, ax) in enumerate(zip(relevant_unit_mean_dfs, axs_f)):
    label = k.split(" (")[0]
    grouped_df = layer_mean_dfs[k]

    x = grouped_df["layer_index"] / grouped_df["layer_index"].max()
    y = grouped_df["correct"]
    ax.scatter(
        x, y,
        s=7,
        color=color_values[i],
        alpha=0.8,
        linewidth=0,
        clip_on=False
    )

    print(f"{k}\t", scipy.stats.spearmanr(x, y))
    ax.set_title(label)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_bounds((0, 1))
    ax.spines["left"].set_bounds((0.25, 1))
    ax.set_ylim((0.25, 1))
    ax.axis("on")

    ax.tick_params(axis='both', which='both', labelsize=11)

# Shared axis labels: y label on the middle row, x label on the middle column.
axs[n_rows // 2, 0].set_ylabel("Proportion Correct (Natural)", fontsize=12)
axs[-1, n_cols // 2].set_xlabel("Relative Layer Position", fontsize=12)

# Only the left column shows y tick labels and only the bottom row x tick labels.
for i, ax in enumerate(axs_f):
    in_left_column = i % n_cols == 0
    in_bottom_row = i >= len(axs_f) - n_cols
    if in_left_column:
        ax.set_yticks([0.25, 0.5, 0.75, 1.0], ["0.25", "0.50", "0.75", "1.0"])
    else:
        ax.set_yticks([0.25, 0.5, 0.75, 1.0], ["", "", "", ""])
    if in_bottom_row:
        ax.set_xticks([0.0, 0.5, 1.0], ["0.0", "0.5", "1.0"])
    else:
        ax.set_xticks([0.0, 0.5, 1.0], ["", "", ""])

plt.tight_layout()

plt.savefig(os.path.join(results_dir, "model_comparison_unit_performance.pdf"), bbox_inches="tight")

In [None]:
# `matplotlib.cm.get_cmap` is deprecated in recent Matplotlib releases; `plt.get_cmap`
# is the supported way to look up a colormap by name.
cmap = plt.get_cmap("tab10")

# (natural, synthetic) experiment pairs for every synthetic condition that has data.
# Building pairs (instead of two separately filtered lists) keeps both sides in step.
experiment_pairs = [
    (k.replace("synthetic", "natural"), k)
    for k in experiments
    if "synthetic" in k.lower() and "hard" not in k.lower() and len(unit_mean_dfs[k]) > 0
]
experiment_pairs = [pair for pair in experiment_pairs if pair[0] in experiments]

# Every model gets the color of its position among all natural, non-"hard" experiments.
# The colors are derived here so that this cell does not depend on variables that
# another plotting cell happens to leave behind.
natural_color_order = [k for k in experiments if "natural" in k.lower() and "hard" not in k.lower()]
color_by_experiment = {k: cmap(i / len(natural_color_order)) for i, k in enumerate(natural_color_order)}

fig, axs = plt.subplots(int(np.ceil(len(experiment_pairs) / 3)), 3, sharex=False, sharey=False)
fig.set_size_inches(6, 5)

# Hide all panels first; the ones that receive data are switched on again below.
axs_f = axs.flatten()
for ax in axs_f:
    ax.axis("off")

if len(experiment_pairs) == 7:
    # Leave the first panel of the last row empty, so that the seventh model ends up
    # in the middle of that row.
    experiment_pairs = experiment_pairs[:6] + [None, experiment_pairs[-1]]

relevant_natural_experiments = [pair[0] if pair is not None else None for pair in experiment_pairs]
relevant_synthetic_experiments = [pair[1] if pair is not None else None for pair in experiment_pairs]

for pair, ax in zip(experiment_pairs, axs_f):
    if pair is None:
        continue
    k_natural, k_synthetic = pair
    label = k_natural.split(" (")[0]

    natural_correct = unit_mean_dfs[k_natural].set_index(["layer", "channel"])["correct"]
    synthetic_correct = unit_mean_dfs[k_synthetic].set_index(["layer", "channel"])["correct"]

    # Pair the two conditions by unit. Neither scatter() nor spearmanr() aligns on the
    # index, so without this step the values would be matched by row position only.
    paired = pd.concat(
        [natural_correct.rename("natural"), synthetic_correct.rename("synthetic")],
        axis=1, join="inner")
    x = paired["natural"]
    y = paired["synthetic"]

    print(f"{label}\t{x.mean()} {y.mean()}")

    ax.scatter(
        x, y,
        s=7,
        color=color_by_experiment[k_natural],
        alpha=0.8,
        linewidth=0,
        clip_on=False
    )

    # ax.scatter([x.mean()], [y.mean()], color="red")

    print(f"{label}\t", scipy.stats.spearmanr(x, y))
    ax.set_title(label)
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_bounds((0.40, 1))
    ax.spines["left"].set_bounds((0.00, 1))
    ax.set_ylim((0.00, 1))
    ax.set_xlim((0.35, 1))
    ax.axis("on")

    ax.tick_params(axis='both', which='both', labelsize=11)

# Tick labels are shown only on selected panels (by index in the flattened grid).
for i, ax in enumerate(axs_f):
    if not (i == 0 or i == 3):
        ax.set_yticks([0, 0.5, 1.0], ["", "", ""])
    else:
        ax.set_yticks([0, 0.5, 1.0], ["0.0", "0.5", "1.0"])
    if i == 3 or i == 5 or i == 7:
        ax.set_xticks([0.4, 0.7, 1.0], ["0.4", "0.7", "1.0"])
    else:
        ax.set_xticks([0.4, 0.7, 1.0], ["", "", ""])


axs[1, 0].set_ylabel("Proportion Correct (Synthetic)", fontsize=12)
axs[-1, 1].set_xlabel("Proportion Correct (Natural)", fontsize=12)

plt.tight_layout()

axs_f[-2].set_autoscale_on(False)
# Adding yticklabels for last plot manually
ticks = [0.0, 0.5, 1.0]
for tick in ticks:
    axs_f[-2].text(
        x=-0.225,
        y=tick - 0.05, # small offset for some reason
        s=str(tick),
        transform=axs_f[-2].transAxes,
        fontsize=11
    )

plt.savefig(os.path.join(results_dir, "model_comparison_unit_performance_natural_vs_synthetic.pdf"), bbox_inches="tight")

In [None]:
# Pool units in the same layer and compute mean proportion correct.
# Selecting the column on the groupby avoids `groupby.apply` on the grouping column,
# which newer pandas versions no longer pass to the applied function.
k = "Clip ResNet (natural)"
grouped_df = unit_mean_dfs[k].groupby("layer_index", as_index=False)["correct"].mean()

# Restrict the correlation to layers whose mean proportion correct is below 0.90.
grouped_df = grouped_df[grouped_df["correct"] < 0.90]

x = grouped_df["layer_index"] / grouped_df["layer_index"].max()
y = grouped_df["correct"]
print(f"{k}\t", scipy.stats.spearmanr(x, y))

In [None]:
from matplotlib.collections import LineCollection
from matplotlib.colors import LinearSegmentedColormap

relevant_experiments = list(experiments.keys())

if use_flawed_data:
    dfs = {k: dfs_results_main_mp_or_rejected[k].copy() for k in relevant_experiments}
else:
    dfs = {k: dfs_results_main_passed[k].copy() for k in relevant_experiments}
unit_mean_dfs = {k: dfs[k].groupby(["layer", "channel"]).mean(numeric_only=True).reset_index() for k in relevant_experiments}

joined_easy_hard_dfs = {}

easy_hard_condition_names = {
    #"ResNet50 (natural)": ("Hard85 ResNet (natural)", "Hard95 ResNet (natural)", "Hard99 ResNet (natural)", "ResNet (natural)"),
    #"Clip ResNet50 (natural)": ("Hard85 Clip ResNet (natural)", "Hard95 Clip ResNet (natural)", "Hard99 Clip ResNet (natural)", "Clip ResNet (natural)")
    "ResNet50 (natural)": ("Hard95 ResNet (natural)", "Hard99 ResNet (natural)", "ResNet (natural)"),
    "Clip ResNet50 (natural)": ("Hard95 Clip ResNet (natural)", "Hard99 Clip ResNet (natural)", "Clip ResNet (natural)")
}

# Color of each condition in the gradient lines that connect the points of a unit.
condition_colors = {
    "hard85": (247/255, 125/255, 40/255),
    "hard95": (51/255, 120/255, 177/255),
    "hard99": (44/255, 160/255, 44/255),
    "easy": (0, 0, 0)
}

# Layer lists per model, so that a model is loaded once and not once per condition.
network_layers_cache = {}

# squeeze=False keeps `axs` two-dimensional even if there is only a single panel.
fig, axs = plt.subplots(1, len(easy_hard_condition_names), sharey=True, squeeze=False)
fig.set_size_inches(2+4*len(easy_hard_condition_names), 4)

for ax, mk in zip(axs[0], easy_hard_condition_names):
    condition_names = easy_hard_condition_names[mk]

    restricted_unit_mean_dfs = {}
    for k in condition_names:
        restricted_df = unit_mean_dfs[k][["layer", "channel", "correct"]].copy()
        restricted_df["layer_channel"] = restricted_df.apply(lambda row: f"{row['layer']}:{row['channel']}", axis=1)
        restricted_unit_mean_dfs[k] = restricted_df.sort_values("layer_channel")

    # Only units that were measured in every condition can be compared.
    shared_units = set.intersection(*[set(restricted_unit_mean_dfs[k]["layer_channel"].to_list()) for k in condition_names])
    for k in condition_names:
        model_name = model_names[k.split(" (")[0]]
        if model_name not in network_layers_cache:
            model = utils_stimuli_generation.load_model(model_name)
            network_layers_cache[model_name] = utils_stimuli_generation.get_relevant_layers(model, model_name)
        network_layers = network_layers_cache[model_name]

        restricted_df = restricted_unit_mean_dfs[k]
        # Copy the filtered rows so that the new columns are not written to a view.
        restricted_df = restricted_df[restricted_df["layer_channel"].isin(shared_units)].copy()
        restricted_df["layer_index"] = restricted_df["layer"].map(lambda l: network_layers.index(l))

        # Position of each unit when all units are ordered by (layer, channel). This has
        # to be the rank of the sort key; `argsort` would instead return, for position p,
        # the index of the p-th smallest key, which is not the unit's own position.
        sort_key = restricted_df["layer_index"] + restricted_df["channel"].astype(int) / 10000
        restricted_df["layer_unit_index"] = sort_key.rank(method="first").astype(int) - 1
        restricted_df["relative_layer_unit_index"] = restricted_df["layer_unit_index"] / restricted_df["layer_unit_index"].max()

        restricted_df = restricted_df.set_index(["layer", "channel"])

        # Prefix all columns with the condition, so that they stay distinct after joining.
        prefix = 'hard95' if 'Hard95' in k else 'hard85' if 'Hard85' in k else 'hard99' if 'Hard99' in k else 'easy'
        restricted_df.columns = [f"{prefix}_{c}" for c in restricted_df.columns]
        restricted_unit_mean_dfs[k] = restricted_df

    joined_df = pd.concat(list(restricted_unit_mean_dfs.values()), ignore_index=False, axis=1)
    joined_easy_hard_dfs[mk] = joined_df

    if len(condition_names) == 4:
        column_names = ("hard85_correct", "hard95_correct", "hard99_correct", "easy_correct")
    elif len(condition_names) == 3:
        column_names = ("hard95_correct", "hard99_correct", "easy_correct")
    else:
        raise ValueError(f"Invalid number of easy/hard conditions. Found {len(condition_names)} but expected 3 or 4.")

    # x position of a unit: taken from the first condition that is present. Each branch
    # reads the column it tested for (previously the hard99 case read the hard95 column).
    x_column_candidates = [f"{prefix}_relative_layer_unit_index" for prefix in ("hard85", "hard95", "hard99", "easy")]
    x_column = next((c for c in x_column_candidates if c in joined_df), x_column_candidates[-1])
    x = joined_df[x_column]

    colors_raw = [condition_colors[c.split("_")[0]] for c in column_names]

    # Plain arrays for positional access; integer keys on the (layer, channel)
    # MultiIndex would otherwise be interpreted as labels by newer pandas versions.
    x_values = x.to_numpy()
    y_values = joined_df[list(column_names)].to_numpy()

    # Draw, for every unit, a vertical line from its lowest to its highest score whose
    # color is interpolated between the colors of the conditions it passes through.
    for i in range(len(joined_df)):
        y_raw = y_values[i]
        ys = np.linspace(y_raw.min(), y_raw.max(), 100)
        points = np.column_stack([np.full_like(ys, x_values[i]), ys]).reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        sorted_idxs = np.argsort(y_raw)
        y_sorted = [y_raw[j] for j in sorted_idxs]
        colors_sorted = [colors_raw[j] for j in sorted_idxs]

        # Anchor one color per condition (all of them, also with four conditions) at the
        # condition's score; the colormap has to start at 0 and end at 1.
        cmap_list = [(0, (0, 0, 0))] + [(y_sorted[j], colors_sorted[j]) for j in range(len(y_sorted))] + [(1, (0, 0, 0))]
        cmap = LinearSegmentedColormap.from_list(f"interpolation-{mk}-{i}", colors=cmap_list)
        # One color value per segment: the score at the start of the segment.
        lc = LineCollection(segments, array=ys[:-1], cmap=cmap, norm=plt.Normalize(0, 1), alpha=0.5)
        ax.add_collection(lc)

    # NOTE(review): with three conditions the marker colors (C0 for hard99, C1 for hard95)
    # differ from the gradient colors in `condition_colors` - confirm that this is intended.
    if len(colors_raw) == 4:
        legend_items = [("black", "v", "Easy"), ("C2", "o", "Medium"), ("C0", "x", "Hard"), ("C1", "^", "Very Hard")]
        scatter_items = [("black", "v", "easy_correct"), ("C2", "o", "hard99_correct"), ("C0", "x", "hard95_correct"), ("C1", "^", "hard85_correct")]
    else:
        legend_items = [("black", "v", "Easy"), ("C0", "x", "Medium"), ("C1", "^", "Hard")]
        scatter_items = [("black", "v", "easy_correct"), ("C0", "x", "hard99_correct"), ("C1", "^", "hard95_correct")]

    for color, m, k in scatter_items:
        ax.scatter(joined_df[k.replace("correct", "relative_layer_unit_index")], joined_df[k],
                s=8, color=color, zorder=2, marker=m)

    ax.set_ylabel("Proportion Correct")
    ax.set_xlabel("Relative Layer Position")

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    ax.spines["bottom"].set_bounds(0, 1)
    ax.spines["left"].set_bounds(0.3, 1)

    ax.set_title(mk)

    # Empty scatter plots serve as legend handles.
    ax.legend([ax.scatter([], [], c=c, s=10, marker=m) for (c, m, _) in legend_items], [it[-1] for it in legend_items], frameon=False, ncol=len(legend_items))

plt.tight_layout()
# Save next to the other figures, so that the flawed-data variant goes to its own
# directory instead of overwriting the figure of the clean analysis.
plt.savefig(os.path.join(results_dir, "model_comparison_units_rn_easy_vs_hard.pdf"), bbox_inches="tight")

In [None]:
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = ["DejaVu Sans"]

relevant_experiments = list(experiments.keys())

# Work on copies of the selected response tables so the loaded data is not modified.
if use_flawed_data:
    dfs = {k: dfs_results_main_mp_or_rejected[k].copy() for k in relevant_experiments}
else:
    dfs = {k: dfs_results_main_passed[k].copy() for k in relevant_experiments}
# One row per (layer, channel) unit holding the mean of every numeric column.
unit_mean_dfs = {k: dfs[k].groupby(["layer", "channel"]).mean(numeric_only=True).reset_index() for k in relevant_experiments}

# Model family -> joined per-unit table of its three conditions.
joined_easy_hard_dfs = {}

# Model family -> (Hard85 condition, Hard95 condition, easy condition).
easy_hard_condition_names = {
    "ResNet": ("Hard85 ResNet (natural)", "Hard95 ResNet (natural)", "ResNet (natural)"),
    "Clip ResNet": ("Hard85 Clip ResNet (natural)", "Hard95 Clip ResNet (natural)", "Clip ResNet (natural)")
}

# One row of four panels per model family:
#   columns 0/1 show the easy-minus-hard85 gap, columns 2/3 the easy-minus-hard95 gap;
#   columns 0/2 plot it over the relative layer position, columns 1/3 over the easy score.
fig, axss = plt.subplots(len(easy_hard_condition_names), 4, sharex=False, sharey=True)
fig.set_size_inches(6*2, 2*len(easy_hard_condition_names))

for i, (axs, mk) in enumerate(zip(axss, easy_hard_condition_names)):
    restricted_unit_mean_dfs = {}
    for k in easy_hard_condition_names[mk]:
        restricted_unit_mean_dfs[k] = unit_mean_dfs[k].copy()[["layer", "channel", "correct"]]
        restricted_unit_mean_dfs[k]["layer_channel"] = restricted_unit_mean_dfs[k].apply(lambda row: f"{row['layer']}:{row['channel']}", axis=1)
        restricted_unit_mean_dfs[k] = restricted_unit_mean_dfs[k].sort_values("layer_channel")

    # Only units that were measured in all three conditions can be compared.
    shared_units = set.intersection(*[set(restricted_unit_mean_dfs[k]["layer_channel"].to_list()) for k in easy_hard_condition_names[mk]])
    for k in easy_hard_condition_names[mk]:
        model_name = model_names[k.split(" (")[0]]
        model = utils_stimuli_generation.load_model(model_name)
        network_layers = utils_stimuli_generation.get_relevant_layers(model, model_name)
        # .copy() so the column assignments below operate on an independent frame
        # instead of a slice of the unfiltered one (avoids SettingWithCopyWarning).
        restricted_unit_mean_dfs[k] = restricted_unit_mean_dfs[k][restricted_unit_mean_dfs[k]["layer_channel"].map(lambda lc: lc in shared_units)].copy()
        restricted_unit_mean_dfs[k]["layer_index"] = restricted_unit_mean_dfs[k]["layer"].map(lambda l: network_layers.index(l))
        # Order units by layer first and channel second (assumes fewer than 10000 channels per layer).
        unit_order_keys = restricted_unit_mean_dfs[k].apply(lambda row: row["layer_index"] + int(row["channel"]) / 10000, axis=1)
        # Each unit needs its own 0-based rank in that ordering. np.argsort would instead
        # return the permutation that sorts the rows, which is the inverse mapping and
        # assigns positions to the wrong units whenever the rows are not already ordered.
        restricted_unit_mean_dfs[k]["layer_unit_index"] = unit_order_keys.rank(method="first").astype(int) - 1
        restricted_unit_mean_dfs[k]["relative_layer_unit_index"] = restricted_unit_mean_dfs[k]["layer_unit_index"] / restricted_unit_mean_dfs[k]["layer_unit_index"].max()

        restricted_unit_mean_dfs[k] = restricted_unit_mean_dfs[k].set_index(["layer", "channel"])

        # Prefix every column with its condition so the three tables can be joined side by side.
        restricted_unit_mean_dfs[k].columns = [f"{'hard95' if 'Hard95' in k else 'hard85' if 'Hard85' in k else 'easy'}_{c}" for c in restricted_unit_mean_dfs[k].columns]

    joined_easy_hard_dfs[mk] = pd.concat(restricted_unit_mean_dfs.values(), ignore_index=False, axis=1)

    # Min-max normalised easy score.
    joined_easy_hard_dfs[mk]["relative_easy_correct"] = joined_easy_hard_dfs[mk]["easy_correct"] - joined_easy_hard_dfs[mk]["easy_correct"].min()
    joined_easy_hard_dfs[mk]["relative_easy_correct"] = joined_easy_hard_dfs[mk]["relative_easy_correct"] / joined_easy_hard_dfs[mk]["relative_easy_correct"].max()

    joined_easy_hard_dfs[mk]["correct_gap_85"] = joined_easy_hard_dfs[mk]["easy_correct"] - joined_easy_hard_dfs[mk]["hard85_correct"]
    joined_easy_hard_dfs[mk]["correct_gap_95"] = joined_easy_hard_dfs[mk]["easy_correct"] - joined_easy_hard_dfs[mk]["hard95_correct"]

    # Sort before plotting so the dotted connecting line runs left to right.
    df = joined_easy_hard_dfs[mk].sort_values("hard85_relative_layer_unit_index")
    axs[0].plot(df["hard85_relative_layer_unit_index"], df["correct_gap_85"], ls=":", alpha=0.5)
    axs[0].scatter(df["hard85_relative_layer_unit_index"], df["correct_gap_85"], s=5)
    df = joined_easy_hard_dfs[mk].sort_values("hard95_relative_layer_unit_index")
    axs[2].plot(df["hard95_relative_layer_unit_index"], df["correct_gap_95"], ls=":", alpha=0.5)
    axs[2].scatter(df["hard95_relative_layer_unit_index"], df["correct_gap_95"], s=5)
    df = joined_easy_hard_dfs[mk].sort_values("easy_correct")
    axs[1].plot(df["easy_correct"], df["correct_gap_85"], ls=":", alpha=0.5)
    axs[3].plot(df["easy_correct"], df["correct_gap_95"], ls=":", alpha=0.5)
    axs[1].scatter(df["easy_correct"], df["correct_gap_85"], s=5)
    axs[3].scatter(df["easy_correct"], df["correct_gap_95"], s=5)

    for ax in axs:
        ax.tick_params(axis="both", which="both", labelsize=11)

    if i == len(axss) - 1:
        # Only the bottom row carries x-axis labels ...
        axs[0].set_xlabel("Relative Layer Position", fontsize=12)
        axs[2].set_xlabel("Relative Layer Position", fontsize=12)

        axs[1].set_xlabel("Proportion Correct (Easy)", fontsize=12)
        axs[3].set_xlabel("Proportion Correct (Easy)", fontsize=12)
    else:
        # ... and all rows above it hide their x tick labels.
        for ax in axs:
            ax.set_xticklabels([])

    for j in (1, 3):
        axs[j].set_xlim(0.43, 1.05)
        axs[j].spines["bottom"].set_bounds(0.48, 1)

    # Blank y-labels reserve space for the shared labels placed with fig.text below.
    # Raw strings avoid the invalid "\q" escape sequence warning.
    axs[0].set_ylabel(r"$\quad$")
    axs[2].set_ylabel(r"$\quad$")

    for j in (0, 2):
        axs[j].set_xlim(-0.05, 1.05)
        axs[j].spines["bottom"].set_bounds(0.0, 1)

    for ax in axs:
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

    # Row title (model family), written vertically at the left figure edge.
    fig.text(-0.0085, 1.0 - (i+0.5)/len(axss) + (0.1 if i == 1 else 0), mk, va='center', rotation='vertical', size=12)

# NOTE(review): the left panel pair plots correct_gap_85 under the label "Easy → Medium" and
# the right pair plots correct_gap_95 under "Easy → Hard" - confirm this matches the
# difficulty naming used for Hard85 / Hard95 elsewhere in the notebook.
fig.text(0.015, 0.55, 'Proportion Correct Gap (Easy → Medium)', va='center', rotation='vertical', size=12)
fig.text(0.52, 0.55, 'Proportion Correct Gap (Easy → Hard)', va='center', rotation='vertical', size=12)

plt.tight_layout()
plt.savefig(os.path.join(results_dir, "model_comparison_units_rn_easy_vs_hard_gaps.pdf"), bbox_inches="tight")

### Confidence

In [None]:
# Accuracy conditioned on the reported confidence (1 = low, 3 = high), one figure per stimulus mode.
dfs = dfs_results_main_passed

for mode in ("natural", "synthetic"):
    # Non-"hard" experiments of this mode that actually contain responses.
    relevant_experiments = [
        k for k in experiments
        if mode in k.lower() and "hard" not in k.lower() and len(dfs_results_main_passed[k]) > 0
    ]

    means_conditioned_on_confidence = {}
    confidences_conditioned_on_confidence = {}
    for k in relevant_experiments:
        for c in (1, 3):
            # Per-unit means over the responses given with confidence level c.
            unit_means_at_confidence = (
                dfs[k][dfs[k]["confidence"] == c]
                .groupby(["layer", "channel"])
                .mean(numeric_only=True)
            )

            if experiments[k] is not None:
                # Bootstrap the mean of the per-unit proportion correct.
                confidences_conditioned_on_confidence[f"{k}_{c}"] = scipy.stats.bootstrap(
                    unit_means_at_confidence["correct"].values.reshape((1, -1)),
                    statistic=np.mean, n_resamples=10_000)
            else:
                confidences_conditioned_on_confidence[f"{k}_{c}"] = dummy_test_results

            means_conditioned_on_confidence[f"{k}_{c}"] = unit_means_at_confidence.mean(numeric_only=True)

    plot_performance(
        means_conditioned_on_confidence,
        confidences_conditioned_on_confidence,
        list(means_conditioned_on_confidence.keys()),
        ("Low Confidence", "High Confidence"),
        (f"{mode.capitalize()} Low", f"{mode.capitalize()} High"),
        bar_width=0.6,
        y_max=0.925,
        legend_cols=1,
    )
    plt.xlabel("Model (Ordered by Scale)")
    plt.savefig(os.path.join(results_dir, f"model_comparison_{mode}_confidence.pdf"), bbox_inches="tight")

In [None]:
def plot_stacked_confidence(means, relevant_experiments, labels=("Low", "Medium", "High", "Low", "Medium", "High"),
                            legend_labels=("Confidence (Natural)", "Confidence (Synthetic)", "Low", "Low", "Medium", "Medium", "High", "High"),
                            color_names=("natural easy", "natural medium", "natural hard", "synthetic easy", "synthetic medium", "synthetic hard"), legend: bool = True,
                            legend_color_names=(None, None, "natural easy", "synthetic easy", "natural medium", "synthetic medium", "natural hard", "synthetic hard"),

                     xticks: bool = True, yticks: bool = True, xlabel="Model", ylabel="Relative Number of Responses",
                     bar_width: float = 1, grid: bool = True, rotate_ticks: bool = False, y_max: float = 0.9, legend_cols: int = 2):
    """Draw one stacked bar per group of three consecutive entries of `relevant_experiments`.

    Creates a new figure; nothing is returned.

    Args:
        means: Mapping from key to an object whose ``["count"]`` entry is the bar
            segment height (NaN marks a missing entry).
        relevant_experiments: Keys into `means`. Every three consecutive keys form
            one bar whose segments are stacked in the given order. A group whose
            first entry is NaN is left out and the following bars move up.
        labels: Not used by this function.
        legend_labels: Legend texts, paired with `legend_color_names`.
        color_names: Six keys into the module-level `colors` dict; the first three
            colour the segments of even-numbered groups, the last three those of
            odd-numbered groups.
        legend: Whether to draw the legend.
        legend_color_names: Keys into `colors` for the legend handles; ``None``
            creates an invisible handle.
        xticks: Whether to draw x ticks. The tick text is derived from the first
            key of each group; equal neighbouring labels are merged into one tick
            centred between their bars.
        yticks: Whether to keep the y ticks.
        xlabel: x-axis label.
        ylabel: y-axis label.
        bar_width: Width of a bar; also scales the bar positions and figure width.
        grid: Whether to draw a dashed horizontal grid behind the bars.
        rotate_ticks: Rotate the x tick labels by 45 degrees.
        y_max: Upper y-axis limit.
        legend_cols: Number of legend columns.

    NOTE(review): the defaults look up the colour key "synthetic hard"; confirm that
    the module-level `colors` dict defines it, otherwise the lookup raises KeyError.
    """
    n_groups = len(relevant_experiments) // 3

    relevant_means = [means[k]["count"] for k in relevant_experiments]
    missing_indices = {i for i, it in enumerate(relevant_means) if np.isnan(it)}

    x = []
    color_values = []
    idx_correction = 0
    for i in range(n_groups):
        if i * 3 in missing_indices:
            # Close the hole left by a missing group.
            idx_correction -= 1
        else:
            x.append(i + idx_correction)
            color_values.extend(colors[color_names[3 * (i % 2) + j].lower()] for j in range(3))
        if i % 2 == 1:
            # Extra gap after every second group (between pairs of bars).
            idx_correction += 0.5

    x = np.array(x, dtype=float) * bar_width

    fig, ax = plt.subplots(1, 1, figsize=(1 + bar_width*0.6*(len(relevant_means) // 3), 4.8))

    # Segment heights of all non-missing entries; computed once instead of once per segment.
    present_counts = np.array(
        [value for i, value in enumerate(relevant_means) if i not in missing_indices], dtype=float)

    bottom = np.zeros_like(present_counts[0::3])
    for ci in range(3):
        # Every third value, starting at ci, is the ci-th segment of each bar.
        y = present_counts[ci::3]
        ax.bar(
            x,
            y,
            color=color_values[ci::3],
            width=bar_width,
            bottom=bottom
            )
        bottom = bottom + y

    if xticks:
        x_tick_labels = []
        for i in range(n_groups):
            if i * 3 not in missing_indices:
                x_tick_labels += [relevant_experiments[3 * i].split("_1")[0].split("(")[0].replace(" ", "\n")]

        x_ticks = list(x)
        # Walk backwards so deletions do not shift the indices still to be visited.
        for j in range(len(x_tick_labels) - 1, 0, -1):
            if x_tick_labels[j] == x_tick_labels[j - 1]:
                del x_tick_labels[j]
                x_ticks[j - 1] = (x_ticks[j - 1] + x_ticks[j]) / 2
                del x_ticks[j]

        ax.set_xticks(x_ticks, x_tick_labels, rotation=45 if rotate_ticks else 0)
    else:
        ax.set_xticks([])

    if not yticks:
        ax.set_yticks([])

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel, labelpad=-10)
    ax.set_ylim((0, y_max))

    # max() of an empty sequence raises; fall back to 0 when no bar was drawn.
    x_max = x.max() if x.size > 0 else 0.0

    if legend:
        hdls = [ax.scatter([], [], color=colors[k.lower()] if k is not None else "white", alpha=0.0 if k is None else 1.0) for k in legend_color_names]
        ax.legend(hdls, legend_labels, frameon=False, ncol=legend_cols, handletextpad=0.4, columnspacing=1)

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(-1.0 * bar_width, x_max + 0.75 * bar_width)
    ax.spines["bottom"].set_bounds((-0.75 * bar_width, x_max + 0.75 * bar_width))

    if grid:
        ax.set_axisbelow(True)
        ax.grid(linestyle="dashed", axis="y")

    plt.tight_layout()


# Maps "<experiment>_<confidence level>" to the mean row of that confidence level;
# its "count" entry is the share of the experiment's responses given at that level.
confidence_means = {}
relevant_experiments = [k for k in experiments if "hard" not in k.lower()]


def tmp_f(df):
    """Return the numeric column means of one confidence group plus its row count.

    The group frame itself is left unmodified.
    """
    group_means = df.mean(numeric_only=True)
    group_means["count"] = float(len(df))
    return group_means


# The loop variable is deliberately not called `re`, which would shadow the stdlib module name.
for experiment_key in relevant_experiments:
    if use_flawed_data:
        tmp = dfs_results_main_mp_or_rejected[experiment_key].groupby(["confidence"], group_keys=False).apply(tmp_f)
    else:
        tmp = dfs_results_main_passed[experiment_key].groupby(["confidence"], group_keys=False).apply(tmp_f)
    if len(tmp) > 0:
        # Turn absolute counts into relative frequencies within the experiment.
        tmp["count"] /= tmp["count"].sum()
    for c in (1, 2, 3):
        if len(tmp) == 0:
            # No responses at all: store a placeholder whose NaN count marks it as missing.
            confidence_means[f"{experiment_key}_{c}"] = tmp.copy().mean()
            confidence_means[f"{experiment_key}_{c}"]["count"] = np.nan
        else:
            confidence_means[f"{experiment_key}_{c}"] = tmp.loc[c]
plot_stacked_confidence(confidence_means, list(confidence_means.keys()), bar_width=0.6, legend_cols=4, y_max=1.15)
plt.xlabel("Model (Ordered by Scale)")
plt.savefig(os.path.join(results_dir, "model_comparison_confidence_distribution.pdf"), bbox_inches="tight")

### Correlation of performance with model properties

In [None]:
relevant_experiments = [k for k in experiments if "hard" not in k.lower()]

# Classification accuracy per model (plotted as "ImageNet Validation Top-1 Accuracy").
accuracies = {
    "ResNet": 0.7613,
    "Robust ResNet": 0.6238,
    "GoogLeNet": 0.6915,
    "Clip ResNet": 0.7430,  # The first number is linear probing, 0-shot equals: 0.5983
    "WideResNet": 0.8160,  # Using 232 instead of 256 resize before central crop
    "DenseNet": 0.7689,
    "ConvNeXT": 0.838,
    "Clip ViT": 0.666,
    "ViT": 0.74904
}

# Mean model rank per model (plotted as "Inverse Human-Likeness (Mean Model Rank)").
mean_ranks = {
    "ResNet": 7 + 1/3,
    "Robust ResNet": 5 + 1/3,
    "GoogLeNet": 9,
    "Clip ResNet": 3,
    "WideResNet": 5 + 2/3,
    "DenseNet": 4 + 2/3,
    "ConvNeXT": 4,
    "Clip ViT": 1,
    "ViT": 5
}

# One record per non-empty experiment; turned into a DataFrame indexed by model below.
data = []
for k in relevant_experiments:
    if len(dfs_results_main_passed[k]) == 0:
        continue
    model = k.replace(" (natural)", "").replace(" (synthetic)", "")
    mode = "natural" if "natural" in k else "synthetic"
    data.append({
        "model": model,
        "mean_proportion_correct": means[k]["correct"],
        # (lower, upper) distance from the mean to the confidence bounds, clipped at 0
        # so the pair can be used directly as asymmetric error-bar lengths.
        "confidence_interval_proportion_correct": (
        max(0, -confidences[k].confidence_interval[0] + means[k]["correct"]),
        max(0, confidences[k].confidence_interval[1] - means[k]["correct"])
        ),
        "accuracy": accuracies[model],
        "mean_rank": mean_ranks[model],
        "mode": mode
        })
data = pd.DataFrame(data).set_index("model")

#### With accuracy

In [None]:
# Scatter of per-model proportion correct over classification accuracy, one colour per mode.
fig, ax = plt.subplots(1, 1)
fig.set_size_inches(4, 3)
for mode in ("natural", "synthetic"):
    mode_data = data[data["mode"] == mode]
    # Stored as (lower, upper) pairs per model; stacking on axis 1 yields the 2 x N layout for yerr.
    mode_yerr = np.stack(mode_data["confidence_interval_proportion_correct"], 1)
    ax.errorbar(mode_data["accuracy"], mode_data["mean_proportion_correct"],
                yerr=mode_yerr, linestyle="None", capsize=5, color=colors[mode])
    ax.scatter(mode_data["accuracy"], mode_data["mean_proportion_correct"], color=colors[mode])

ax.set_ylim((0.55, 0.95))
ax.set_xlim(0.59, 0.85)
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
ax.spines["bottom"].set_bounds((0.6, 0.85))
ax.spines["left"].set_bounds((0.55, 0.95))

ax.set_axisbelow(True)
ax.grid(linestyle="dashed", axis="y")

ax.set_ylabel("Proportion Correct")
ax.set_xlabel("ImageNet Validation Top-1 Accuracy")

# Empty scatters serve purely as legend handles.
legend_handles = (plt.scatter([],[], c=colors["natural"]), plt.scatter([],[], c=colors["synthetic"]))
ax.legend(legend_handles, ("Natural", "Synthetic"), frameon=False, ncol=2, bbox_to_anchor=(0.5, 0.935), loc="center")

plt.tight_layout()

plt.savefig(os.path.join(results_dir, "model_accuracy_vs_performance.pdf"), bbox_inches="tight")

#### With human-likeness

In [None]:
# Scatter of per-model proportion correct over the mean model rank, one colour per mode.
fig, ax = plt.subplots(1, 1)
fig.set_size_inches(4, 3)
for mode in ("natural", "synthetic"):
    mode_data = data[data["mode"] == mode]
    # Stored as (lower, upper) pairs per model; stacking on axis 1 yields the 2 x N layout for yerr.
    mode_yerr = np.stack(mode_data["confidence_interval_proportion_correct"], 1)
    ax.errorbar(mode_data["mean_rank"], mode_data["mean_proportion_correct"],
                yerr=mode_yerr, linestyle="None", capsize=5, color=colors[mode])
    ax.scatter(mode_data["mean_rank"], mode_data["mean_proportion_correct"], color=colors[mode])

ax.set_ylim((0.55, 0.95))
ax.set_xlim(0.5, 9.5)
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
ax.spines["bottom"].set_bounds((1, 9))
ax.spines["left"].set_bounds((0.55, 0.95))

ax.set_axisbelow(True)
ax.grid(linestyle="dashed", axis="y")

ax.set_ylabel("Proportion Correct")
ax.set_xlabel("Inverse Human-Likeness (Mean Model Rank)")

ax.set_xticks([1, 3, 5, 7, 9])

# Empty scatters serve purely as legend handles.
legend_handles = (plt.scatter([],[], c=colors["natural"]), plt.scatter([],[], c=colors["synthetic"]))
ax.legend(legend_handles, ("Natural", "Synthetic"), frameon=False, ncol=2, bbox_to_anchor=(0.5, 0.935), loc="center")

plt.tight_layout()

plt.savefig(os.path.join(results_dir, "model_human_likeness_vs_performance.pdf"), bbox_inches="tight")

## Significance Tests

In [None]:
# Run utils_analysis.run_kruskal_wallis across the non-"hard" experiments, separately per mode.
relevant_experiments = [k for k in experiments if "hard" not in k.lower()]

for mode in ("natural", "synthetic"):
    # Collect the non-empty response tables whose experiment name contains the mode.
    relevant_dfs_results_main_passed = []
    for k in relevant_experiments:
        if mode not in k:
            continue
        it = dfs_results_main_mp_or_rejected[k] if use_flawed_data else dfs_results_main_passed[k]
        if len(it) > 0:
            relevant_dfs_results_main_passed.append(it)

    n_trials = sum(len(it) for it in relevant_dfs_results_main_passed)
    print("Mode:", mode, "Number of trials:", n_trials)
    utils_analysis.run_kruskal_wallis(relevant_dfs_results_main_passed)

In [None]:
def safe_mean(df):
    """Reduce a (grouped) DataFrame to one value per column, tolerating non-numeric columns.

    Args:
        df: The frame (typically one group of a groupby) to reduce.

    Returns:
        A Series indexed by the column names of `df`. Non-object columns hold
        their mean. Object columns hold their single unique value if there is
        exactly one, and NaN otherwise (several values, no values, or values
        that cannot be compared for uniqueness).
    """
    new_df = {}
    for col in df.columns:
        if df[col].dtype == object:
            try:
                unique_values = df[col].unique()
            except Exception:
                # E.g. unhashable cell values such as lists. Deliberately best-effort,
                # but no longer swallows KeyboardInterrupt/SystemExit like a bare except.
                value = np.nan
            else:
                value = unique_values[0] if len(unique_values) == 1 else np.nan
        else:
            value = df[col].mean()
        new_df[col] = value

    return pd.Series(new_df)

In [None]:
def merge_dfs(dfs):
    """Aggregate every non-empty response table per unit and stack the results.

    Args:
        dfs: Mapping from a condition name to its response DataFrame.

    Returns:
        The concatenation of the per-(layer, channel) `safe_mean` aggregates of
        all non-empty tables, with two added columns: "model_condition" (the
        mapping key) and "unit" ("<layer>:<channel>").
    """
    per_condition = []
    for condition, responses in dfs.items():
        if len(responses) > 0:
            unit_rows = responses.groupby(["layer", "channel"]).apply(safe_mean)
            unit_rows["model_condition"] = condition
            unit_rows["unit"] = unit_rows["layer"] + ":" + unit_rows["channel"]
            per_condition.append(unit_rows)
    return pd.concat(per_condition)

# Per-unit aggregates over all non-"hard" experiments, once for each of the two response tables.
non_hard_experiments = [k for k in experiments if "hard" not in k.lower()]
df_results_main_passed_units_merged = merge_dfs(
    {k: dfs_results_main_passed[k] for k in non_hard_experiments})
df_results_main_mp_or_rejected_units_merged = merge_dfs(
    {k: dfs_results_main_mp_or_rejected[k] for k in non_hard_experiments})

In [None]:
# Display name -> internal model identifier.
model_names = {
    "ResNet": "resnet50",
    "Rob. ResNet": "resnet50-l2",
    "GoogLeNet": "googlenet",
    "Clip ResNet": "clip-resnet50",
    "WideResNet": "wide_resnet50",
    "DenseNet": "densenet_201",
    "ConvNeXT": "convnext_b",
    "Clip ViT": "clip-vit_b32",
    "ViT": "in1k-vit_b32",
}

# Format: diagonal, non-significant, p<0.001, p<0.01, p<0.05
cmap = ['1', '#fb6a4a',  '#08306b',  '#4292c6', '#c6dbef']

# Pairwise Conover post-hoc tests (Holm-adjusted) between models, separately per mode.
for mode in ("natural", "optimized"):
    print("Mode:", mode)
    if use_flawed_data:
        units_of_mode = df_results_main_mp_or_rejected_units_merged[df_results_main_mp_or_rejected_units_merged["mode"] == mode]
    else:
        units_of_mode = df_results_main_passed_units_merged[df_results_main_passed_units_merged["mode"] == mode]

    pc = scikit_posthocs.posthoc_conover(
        units_of_mode,
        val_col="correct",
        group_col="model",
        p_adjust="holm")

    # Show display names instead of internal identifiers on both axes of the p-value matrix.
    internal_to_display_name = {internal: shown for shown, internal in model_names.items()}
    pc = pc.rename(index=internal_to_display_name)
    pc = pc.rename(columns=internal_to_display_name)

    display(pc)
    hax, _ = scikit_posthocs.sign_plot(
        pc, cmap=cmap, linewidth=1, square=True, cbar_ax_bbox=[-0.3, -0.25, 0.035, 0.25])
    hax.figure.set_size_inches(2, 2)
    plt.savefig(os.path.join(results_dir, f"model_comparison_significance_{mode}.pdf"), bbox_inches="tight")
    plt.show()

## Exclusion Criteria

In [None]:
# One bar chart per exclusion criterion showing how often each check outcome occurred.
keys = ("catch_trials_result", "row_variability_result",
        "total_response_time_result", "instruction_time_result",
        "demo_trials_result")
fig, axs = plt.subplots(1, len(keys), figsize=(1.75*len(keys), 2.4), sharey=True)
axs = axs.flatten()
df_checks = pd.concat(dfs_checks, axis=0).reset_index(drop=True)
for k, ax in zip(keys, axs):
    df_checks[k].value_counts().plot(kind="bar", ax=ax)

    # Column name -> panel title, e.g. "demo_trials_result" -> "Practice Trials".
    value_name = k.replace("_extracted", "").replace("_result", "").replace("demo", "practice")
    value_name = " ".join(w.capitalize() for w in value_name.split("_"))

    ax.set_title(value_name)

    for side in ("right", "top"):
        ax.spines[side].set_visible(False)

# The x label is attached to the middle panel only, the y label to the first one.
axs[2].set_xlabel("Passed Exclusion Criteria")
axs[0].set_ylabel("Count")


# Drop the concatenated frame again once plotted.
del df_checks
plt.tight_layout()
plt.savefig(os.path.join(results_dir, "exclusion_criteria_decision_distribution.pdf"), bbox_inches="tight")

In [None]:
# Histogram of the measured value behind each exclusion criterion, with its thresholds marked.
keys = ('instruction_time_details_extracted',
       'total_response_time_details_extracted',
       'row_variability_details_details_upper_extracted',
       'row_variability_details_details_lower_extracted',
       'catch_trials_details_ratio_extracted',
       'demo_trials_details_extracted')
# (minimum required, maximum allowed) per key, in the same order; None means no such limit.
thresholds = (
    (15, None),
    (135, 2500),
    (5, 40),
    (5, 40),
    (0.8, None),
    (None, 3)
)
df_checks = pd.concat(dfs_checks, axis=0).reset_index(drop=True)
fig, axs = plt.subplots(int(np.ceil(len(keys) / 3)), 3, figsize=(8, 5))
axs = axs.flatten()
# Start with every panel hidden; panels that receive data are switched on again below.
for ax in axs:
    ax.axis("off")
for k, ax, ths in zip(keys, axs, thresholds):
    ax.axis("on")
    ax.hist(df_checks[k], bins=20)
    value_name = k.replace("_extracted", "").replace("_details", "").replace("demo", "practice")
    value_name = " ".join(w.capitalize() for w in value_name.split("_"))
    ax.set_xlabel(value_name)

    # Minimum in red, then maximum in black; the y-limit is re-read before each line is drawn.
    for th, th_color in zip(ths, ("red", "black")):
        if th is not None:
            ax.vlines(th, 0, ax.get_ylim()[1], color=th_color, linestyle="dashed")

    for side in ("right", "top"):
        ax.spines[side].set_visible(False)

# Empty lines serve purely as legend handles.
min_handle = axs[0].plot([], [], c="red", ls="dashed")[0]
max_handle = axs[0].plot([], [], c="black", ls="dashed")[0]
axs[0].legend((min_handle, max_handle), ("Min. Required", "Max. Allowed"), frameon=False, ncol=1)

# First panel of each of the two rows carries the y label.
axs[0].set_ylabel("Count")
axs[3].set_ylabel("Count")
del df_checks
plt.tight_layout()
plt.savefig(os.path.join(results_dir, "exclusion_criteria_value_distribution.pdf"), bbox_inches="tight")

In [None]:
# Print, per non-empty check table, the distinct fractions of correctly answered catch trials.
for k in dfs_checks:
    if len(dfs_checks[k]) > 0:
        catch_trial_ratios = dfs_checks[k]["catch_trials_details_correctly_answered_extracted"].map(
            lambda x: sum(x) / len(x))
        print(k, catch_trial_ratios.unique())

# Export Data

In [None]:
# Every configured experiment must have been loaded before exporting.
for k in experiments:
    assert k in dfs_results, f"Missing results for {k}"

joined_df = pd.concat(dfs_results.values(), axis=0).reset_index(drop=True)

print("#Responses:", len(joined_df))
print("#Unique Participants:", len(joined_df["worker_id"].unique()))

# Replace the original participant ids by sequential pseudonyms (P0001, P0002, ...).
unique_participant_id_map = {k: f"P{i + 1:04}" for i, k in enumerate(joined_df["participant_id"].unique())}

# Overwrite or drop columns that must not be part of the exported data.
joined_df["min_query"] = "anonymized"
joined_df["max_query"] = "anonymized"
joined_df["min_references"] = "anonymized"
joined_df["max_references"] = "anonymized"
joined_df["participant_id"] = joined_df["participant_id"].map(lambda pid: unique_participant_id_map[pid])
del joined_df["worker_id"]
del joined_df["result_file_name"]
del joined_df["result_creation_time"]
del joined_df["query_path"]

# Transform int columns containing NaNs to Int64Dtype to properly represent
# NaNs after serialization. The cast to the nullable integer dtype turns NaN into
# <NA> by itself (an explicit `values == np.nan` mask never matches, as NaN is not
# equal to itself), and float64 keeps every integer up to 2**53 exact.
for c in ["batch", "channel"]:
    joined_df[c] = joined_df[c].astype(np.float64).astype(pd.Int64Dtype())

# Split into the main data set and the lower-quality remainder; the two masks are complementary.
joined_df_passed_and_no_mp = joined_df[~joined_df["multi_participation"] & joined_df["passed_checks"]].reset_index(drop=True)
joined_df_rejected_or_mp = joined_df[joined_df["multi_participation"] | ~joined_df["passed_checks"]].reset_index(drop=True)

joined_df.to_csv("results/responses_all.csv", index=False)
joined_df.to_pickle("results/responses_all.pd.pkl")

joined_df_passed_and_no_mp.to_csv("results/responses_main.csv", index=False)
joined_df_passed_and_no_mp.to_pickle("results/responses_main.pd.pkl")

joined_df_rejected_or_mp.to_csv("results/responses_lower_quality.csv", index=False)
joined_df_rejected_or_mp.to_pickle("results/responses_lower_quality.pd.pkl")

print("Responses:", len(joined_df))
print("Responses (Passed Checks and No Multi-Participation):", len(joined_df_passed_and_no_mp))
print("Responses (Rejected or Multi-Participation):", len(joined_df_rejected_or_mp))

# Sanity check: every response ended up in exactly one of the two subsets.
assert len(joined_df) == len(joined_df_passed_and_no_mp) + len(joined_df_rejected_or_mp)