# L-WISE Data Analysis and Figure Generation

This notebook generates (almost) all figure panels and performs all statistical analyses for the L-WISE paper. The one exception is the heatmaps of perturbations to example images, which are generated by `imgproc_code/notebooks/visualize_heatmaps.ipynb`.

A de-identified version of all experimental data is released with this project. Set the boolean `DEIDENTIFIED_DATA` to `True` to run the analysis on the de-identified dataset; setting it to `False` allows authors who have access to the raw source data to generate demographic tables.

In [None]:
# True: analyze the released de-identified dataset (cached CSVs are used when present).
# False: rebuild the datasets from the raw .h5 source files (requires raw-data access).
DEIDENTIFIED_DATA: bool = True

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import scipy.stats as stats
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import os

# Font type 42 (TrueType) instead of matplotlib's default Type 3 keeps text
# editable when exported PDF/PS figures are opened in a vector-graphics editor.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

def get_df_from_xarray(data_paths, drop_columns=None):
  """Load xarray result files and combine them into one trial-level DataFrame.

  Each file is opened with ``xr.open_dataset`` and flattened with ``to_dataframe``.
  The flattened frame appears to contain one row per choice slot; only the row for
  the slot the participant chose (``choice_slot == i_choice``) is kept, or slot 0
  when no choice was recorded (``i_choice`` is NaN). Participant IDs from each later
  file are offset past the largest ID already used, so IDs never collide.

  Args:
    data_paths: iterable of paths to dataset files readable by ``xr.open_dataset``.
    drop_columns: optional list of column names to drop; absent names are ignored.

  Returns:
    DataFrame sorted by ``participant`` then ``obs``, with a fixed set of commonly
    used columns first (those that exist) followed by all remaining columns.

  Raises:
    FileNotFoundError: if any entry of ``data_paths`` is not an existing file.
  """
  start_pt_idx = 0
  dfs = []
  for data_path in data_paths:
    if not os.path.isfile(data_path):
      raise FileNotFoundError(f"File {data_path} does not exist")
    # Context manager releases the file handle once the data is copied into pandas.
    with xr.open_dataset(data_path) as ds:
      raw_df = ds.to_dataframe().reset_index()
    chose_this_slot = raw_df['choice_slot'] == raw_df['i_choice']
    no_choice_first_slot = raw_df['i_choice'].isna() & (raw_df['choice_slot'] == 0)
    # .copy() so the participant offset below writes to an independent frame, not a
    # view of raw_df (avoids pandas' SettingWithCopyWarning).
    df = raw_df[chose_this_slot | no_choice_first_slot].copy()
    df["participant"] = start_pt_idx + df["participant"]
    if not df.empty:
      # max()+1 rather than nunique(): stays collision-free even if some participant
      # indices are absent from this file.
      start_pt_idx = int(df["participant"].max()) + 1
    dfs.append(df)

  combined_df = pd.concat(dfs, axis=0).reset_index(drop=True)

  if drop_columns is not None:
    combined_df = combined_df.drop(drop_columns, axis=1, errors='ignore')

  # Sort dataframe such that trials for each participant appear in order
  combined_df = combined_df.sort_values(by=['participant', 'obs'])

  # Put commonly used columns first; skip any that are absent (e.g. dropped above)
  ordered_cols = ['participant', 'condition_idx', 'block', 'obs', 'trial_type', 'class', 'stimulus_image_url', 'stimulus_name', 'choice_name', 'i_correct_choice', 'i_choice', 'perf', 'reaction_time_msec', 'rel_timestamp_response', 'timestamp_start', 'monitor_width_px', 'monitor_height_px', 'stimulus_width_px', 'choice_width_px', 'stimulus_duration_msec', 'post_stimulus_delay_duration_msec', 'pre_choice_lockout_delay_duration_msec']
  leading_cols = [col for col in ordered_cols if col in combined_df.columns]
  other_cols = [col for col in combined_df.columns if col not in ordered_cols]
  combined_df = combined_df[leading_cols + other_cols]

  return combined_df


# Function to calculate bootstrap confidence intervals
def bootstrap_ci(data, num_bootstrap_samples=10000, confidence_level=0.95, rng=None):
    """Percentile-bootstrap confidence interval for the mean of ``data``.

    Args:
        data: 1-D array-like of numbers (e.g. a pandas Series of 0/1 ``perf`` values).
        num_bootstrap_samples: number of resamples, each drawn with replacement and
            the same size as ``data``.
        confidence_level: central coverage of the interval (0.95 -> 2.5th/97.5th pct).
        rng: optional seed or ``np.random.Generator`` for reproducible intervals.
            When None, the global ``np.random`` state is used.

    Returns:
        np.ndarray ``[lower, upper]``; ``[nan, nan]`` when ``data`` is empty.
    """
    # Convert once instead of letting every resample re-convert (e.g.) a Series.
    values = np.asarray(data)
    if values.size == 0:
        return np.array([np.nan, np.nan])
    sampler = np.random if rng is None else np.random.default_rng(rng)
    bootstrap_means = np.array([
        sampler.choice(values, size=values.size, replace=True).mean()
        for _ in range(num_bootstrap_samples)
    ])
    return np.percentile(bootstrap_means, [(1 - confidence_level) / 2 * 100, (1 + confidence_level) / 2 * 100])


def perform_chi_square(df, condition, control=0):
    """Chi-square test of independence between group (control vs. condition) and ``perf``.

    Only rows whose ``condition_idx`` is ``control`` or ``condition`` take part.

    Returns:
        (chi2, p, dof, expected, contingency_table): the first four come straight
        from ``scipy.stats.chi2_contingency`` (with its default settings);
        ``contingency_table`` is the ``condition_idx`` x ``perf`` count table tested.
    """
    in_pair = df['condition_idx'].isin([control, condition])
    pair_df = df.loc[in_pair]
    table = pd.crosstab(pair_df['condition_idx'], pair_df['perf'])
    chi2, p, dof, expected = stats.chi2_contingency(table)
    return chi2, p, dof, expected, table


def chi_square_comparisons(df_trials_test, condition_idx_ordering, condition_labels, control_condition_idx=0, alpha=0.05):
    """Print a chi-square comparison of each condition against a reference condition.

    For every entry of ``condition_idx_ordering`` other than ``control_condition_idx``,
    runs ``perform_chi_square`` on that condition plus the reference. It then prints
    the statistic, p-value, degrees of freedom, a significance verdict at level
    ``alpha``, the observed contingency table and the expected frequencies.
    No multiple-comparison correction is applied.

    Args:
        df_trials_test: trial-level DataFrame with 'condition_idx' and 'perf'.
        condition_idx_ordering: list of condition indices to compare.
        condition_labels: display label for each entry of ``condition_idx_ordering``.
        control_condition_idx: condition index used as the reference group.
        alpha: significance threshold for the printed verdict.

    Raises:
        ValueError: if ``control_condition_idx`` is not in ``condition_idx_ordering``.
    """
    control_group_name = condition_labels[condition_idx_ordering.index(control_condition_idx)]

    for label, condition_idx in zip(condition_labels, condition_idx_ordering):
        if condition_idx == control_condition_idx:
            continue
        chi2, p, dof, expected, contingency_table = perform_chi_square(df_trials_test, condition_idx, control=control_condition_idx)

        # Fixed-point formatting alone would print very small p-values as 0.0000
        p_text = f"{p:.4f}" if p >= 1e-4 else f"{p:.3e}"

        print(f"\nComparing {label} to {control_group_name}:")
        print(f"Chi-square value: {chi2:.4f}")
        print(f"P-value: {p_text}")
        print(f"Degrees of freedom: {dof}")

        # Interpret the results
        if p < alpha:
            print(f"There is a significant difference in performance between {label} and {control_group_name}.")
        else:
            print(f"There is no significant difference in performance between {label} and {control_group_name}.")

        # Display the contingency table
        print("\nContingency Table:")
        print(contingency_table)

        # Display expected frequencies
        print("\nExpected frequencies:")
        print(pd.DataFrame(expected, index=contingency_table.index, columns=contingency_table.columns))

        print("\n" + "="*50)


def _per_participant_time_summary(df, condition_idx_ordering, condition_labels, mean_col):
    """Summarize each participant's last ``rel_timestamp_response`` by condition, in minutes.

    Every participant contributes exactly one value (the max over their rows in ``df``).
    Values are divided by 1000*60, i.e. the column is treated as milliseconds.

    Returns:
        DataFrame with columns ['condition', mean_col, 'ci_lower', 'ci_upper'], one row
        per entry of ``condition_idx_ordering`` that is not None.
    """
    ms_per_minute = 1000 * 60
    last_times = df.groupby('participant')['rel_timestamp_response'].max().reset_index()
    # One (participant, condition) pair per participant. Merging against the raw trial
    # rows would repeat each participant once per trial, inflating the bootstrap
    # sample size (too-narrow CIs) and weighting the mean by trial count.
    participant_conditions = df[['participant', 'condition_idx']].drop_duplicates()
    per_participant = last_times.merge(participant_conditions, on='participant', how='left')

    rows = []
    for label, condition in zip(condition_labels, condition_idx_ordering):
        if condition is None:
            continue
        times = per_participant.loc[per_participant['condition_idx'] == condition, 'rel_timestamp_response']
        ci_lower, ci_upper = bootstrap_ci(times)
        rows.append({
            'condition': label,
            mean_col: round(np.mean(times) / ms_per_minute, 4),
            'ci_lower': round(ci_lower / ms_per_minute, 4),
            'ci_upper': round(ci_upper / ms_per_minute, 4),
        })
    return pd.DataFrame(rows)


def print_main_stats(df_trials, condition_idx_ordering, condition_labels, chance_level=0.25, test_blocks=(8, 9)):
    """Print and return test accuracy, training time and completion time per condition.

    Args:
        df_trials: trial-level DataFrame with 'participant', 'condition_idx', 'block',
            'perf' and 'rel_timestamp_response' (milliseconds).
        condition_idx_ordering: condition indices to report, in display order. The FIRST
            entry is treated as the control for the margin-above-chance comparison.
        condition_labels: display label for each entry of ``condition_idx_ordering``.
        chance_level: accuracy expected by guessing; margins are measured above it.
        test_blocks: blocks used for test accuracy. All other blocks count as training.

    Returns:
        (condition_acc_data, training_time_results_df, completion_time_results_df).
        Accuracy CIs are trial-level bootstraps. Time summaries use one value per
        participant: the last response timestamp in the training blocks, or over
        all blocks for completion time.
    """
    is_test = df_trials["block"].isin(test_blocks)

    condition_accuracies = []
    condition_cis = []
    for cond in condition_idx_ordering:
        cond_perf = df_trials.loc[(df_trials["condition_idx"] == cond) & is_test, 'perf']
        condition_accuracies.append(cond_perf.mean())
        condition_cis.append(bootstrap_ci(cond_perf))

    condition_acc_data = pd.DataFrame({
        'Condition': condition_labels,
        'Accuracy': condition_accuracies,
        'CI_lower': [ci[0] for ci in condition_cis],
        'CI_upper': [ci[1] for ci in condition_cis]
    })

    # Error-bar lengths for plotting
    condition_acc_data['yerr_lower'] = condition_acc_data['Accuracy'] - condition_acc_data['CI_lower']
    condition_acc_data['yerr_upper'] = condition_acc_data['CI_upper'] - condition_acc_data['Accuracy']

    print("Condition accuracies (mean with 95% CIs):")
    for condition, accuracy, ci in zip(condition_acc_data['Condition'], condition_acc_data['Accuracy'], condition_cis):
        print(f"{condition}: {accuracy:.2f} ({ci[0]:.2f}, {ci[1]:.2f})")

    # The first condition in the ordering is the control
    control_accuracy = condition_acc_data['Accuracy'].iloc[0]
    margin_control = control_accuracy - chance_level

    print("\nPercentage increase in margin above chance:")
    for condition, accuracy in zip(condition_acc_data['Condition'].iloc[1:], condition_acc_data['Accuracy'].iloc[1:]):
        margin_condition = accuracy - chance_level
        # A control exactly at chance has no margin to compare against
        if margin_control == 0:
            percentage_increase = float('nan')
        else:
            percentage_increase = ((margin_condition - margin_control) / margin_control) * 100

        print("Condition acc:", accuracy, "| Control acc:", control_accuracy)
        print("Margin condition:", margin_condition, "| Margin control:", margin_control)
        print(f"{condition}: {percentage_increase:.1f}%")

    training_time_results_df = _per_participant_time_summary(
        df_trials[~is_test], condition_idx_ordering, condition_labels, 'mean_training_time')
    print("Training times (minutes):")
    print(training_time_results_df)

    completion_time_results_df = _per_participant_time_summary(
        df_trials, condition_idx_ordering, condition_labels, 'mean_completion_time')
    print("Completion times (minutes):")
    print(completion_time_results_df)

    return condition_acc_data, training_time_results_df, completion_time_results_df


def assert_constant_counts(df, trial_types=('calibration', 'repeat_stimulus')):
    """Report whether each given trial type occurs equally often in every trialset.

    Counts rows per ``trial_type`` within each (experiment_id, trialset_id) group and
    prints, for each entry of ``trial_types``, whether that count is the same in every
    group. Despite the name, nothing is raised; mismatches are only printed.

    Args:
        df: DataFrame with 'experiment_id', 'trialset_id' and 'trial_type' columns.
        trial_types: trial types whose per-trialset counts should be constant.

    Returns:
        bool: True if every trial type in ``trial_types`` is present and has a constant
        count across all trialsets, False otherwise.
    """
    # Print unique trial types for debugging
    print("Unique trial types in the dataset:", df['trial_type'].unique())

    # Rows: (experiment_id, trialset_id); columns: trial_type; values: row counts
    counts = df.groupby(['experiment_id', 'trialset_id'])['trial_type'].value_counts().unstack(fill_value=0)

    all_constant = True
    for trial_type in trial_types:
        if trial_type not in counts.columns:
            print(f"\nWarning: Trial type '{trial_type}' is not present in counts DataFrame")
            print("This might indicate an issue with data processing")
            all_constant = False
            continue
        if counts[trial_type].nunique() != 1:
            print(f"\nWarning: Count of {trial_type} is not constant across all trialset_ids")
            print(f"Unique counts for {trial_type}: {counts[trial_type].unique()}")
            all_constant = False
        else:
            print(f"\nCount of {trial_type} is constant ({counts[trial_type].iloc[0]}) across all trialset_ids")

    print("\nAssertion check completed.")
    return all_constant


def reassign_blocks(df, verbose=True, block_structure=None):
    """
    Reassign ``block`` for participants who have any trial_type containing 'shuffle'.

    Each of those participants' rows gets its block from its ``obs`` value. The blocks
    of sizes ``block_structure[0], block_structure[1], ...`` are laid end to end, and a
    row gets the block whose span contains ``obs``. Values past the last boundary fall
    into the last block. Other participants are left untouched. Block sizes are then
    verified for the affected participants, and the result is printed.

    Parameters:
    df (pd.DataFrame): Input dataframe containing columns 'participant', 'trial_type', 'obs', and 'block'.
        It is not modified; a modified copy is returned.
    verbose (bool): If True, also print the block structure, per-participant confirmations
        and the per-block summary. Warnings about wrong block sizes are always printed.
    block_structure (dict[int, int] | None): Number of trials per block, keyed 0..K-1.
        Defaults to 18 trials in block 0, 19 in blocks 1-7 and 25 in blocks 8-9.

    Returns:
    pd.DataFrame: DataFrame with updated block values
    """
    df_copy = df.copy()

    # Participants with 'shuffle' in any of their trial_type values
    shuffle_participants = df_copy[df_copy['trial_type'].str.contains('shuffle', na=False)]['participant'].unique()

    if block_structure is None:
        block_structure = {
            0: 18,                              # Block 0 has 18 trials
            **{i: 19 for i in range(1, 8)},     # Blocks 1-7 have 19 trials each
            **{i: 25 for i in range(8, 10)},    # Blocks 8-9 have 25 trials each
        }

    if verbose:
        print(block_structure)

    # Exclusive upper obs boundary of each block, e.g. [18, 37, 56, ...] for the default
    block_ends = np.cumsum([block_structure[i] for i in range(len(block_structure))])
    last_block = len(block_structure) - 1

    in_shuffle = df_copy['participant'].isin(shuffle_participants)
    obs_values = df_copy.loc[in_shuffle, 'obs'].to_numpy()
    # side='right' -> index of the first block whose end is > obs; anything past the
    # final boundary is clamped to the last block.
    new_blocks = np.minimum(np.searchsorted(block_ends, obs_values, side='right'), last_block)
    df_copy.loc[in_shuffle, 'block'] = new_blocks

    # Verify block sizes for each participant
    print("\nVerifying block sizes for each participant:")
    if verbose:
        print("------------------------------------------")

    shuffled_rows = df_copy[in_shuffle]
    per_participant_counts = shuffled_rows.groupby(['participant', 'block']).size()
    per_block_totals = shuffled_rows.groupby('block').size()

    all_correct = True
    for participant in shuffle_participants:
        for block, expected_trials in block_structure.items():
            block_trials = int(per_participant_counts.get((participant, block), 0))

            if block_trials != expected_trials:
                print(f"WARNING: Participant {participant} has {block_trials} trials in block {block} (expected {expected_trials})")
                all_correct = False
            elif verbose:
                print(f"Participant {participant} has correct number of trials ({expected_trials}) in block {block}")

    if all_correct:
        print("\nVERIFICATION PASSED: All participants have the correct number of trials in each block!")
    else:
        print("\nVERIFICATION FAILED: Some participants have incorrect numbers of trials in certain blocks.")

    # Additional summary across all affected participants
    if verbose:
        print("\nSummary across all participants with shuffled trials:")
        print("--------------------------------------------------")
    num_participants = len(shuffle_participants)
    for block, expected_trials in block_structure.items():
        total_trials = int(per_block_totals.get(block, 0))
        if total_trials == expected_trials * num_participants:
            if verbose:
                print(f"✓ Block {block}: All participants have exactly {expected_trials} trials")
        else:
            print(f"✗ Block {block}: Expected {expected_trials} trials per participant, "
                  f"found {total_trials/num_participants:.1f} on average")

    return df_copy

In [None]:
# Move from "notebooks/" up to the repository root so that relative paths such as
# "psych_data/..." and "results/..." resolve. changed_dir is reset on every run of
# this cell, so re-running only avoids a second chdir if the parent directory does
# not itself contain a make_figs.ipynb.
changed_dir = False
if os.path.exists("./make_figs.ipynb") and not changed_dir:
    os.chdir(os.path.dirname(os.getcwd()))
    changed_dir = True
assert os.path.exists("./notebooks/make_figs.ipynb"), "Make sure your working directory starts in 'notebooks'"

# Destination for every exported figure
os.makedirs("notebooks/fig_outputs", exist_ok=True)

In [None]:
# Columns not needed by the analysis. With the de-identified release, the
# identifier columns (assignment_id, worker_id) are dropped as well.
drop_columns = [
    "stimulus_image_url_l",
    "stimulus_image_url_r",
    "class_l",
    "class_r",
    "mask_duration_msec",
    "mask_image_url",
    "choice_slot",
    "choice_image_urls",
    "keep_stimulus_on",
    "query_string",
    "platform",
    "bonus_usd_if_correct",
]
if DEIDENTIFIED_DATA:
    drop_columns += ["assignment_id", "worker_id"]

## "idaea4" moth learning task (from iNaturalist)

In [None]:
# LOAD IDAEA4 DATASET
# Builds df_idaea4 (one row per trial). With de-identified data, a cached CSV is used
# when it exists; otherwise the dataset is rebuilt from the two raw .h5 result files.

# Trial-type names; a trial type's position in this list matches its condition_idx
# after the remap below (see trial_type_names_idaea4.index(...) usage later).
control_cond = "natural"
trial_type_names_idaea4 = [control_cond, "curriculum_sampling", "curriculum_sampling_shuffle", "enhancement_taper", "enhancement_taper_shuffle", "enhancement_taper_curriculum_sampling"]

if os.path.isfile("psych_data/df_idaea4.csv") and DEIDENTIFIED_DATA:
  print("Reading idaea4 dataset from saved .csv")
  df_idaea4 = pd.read_csv("psych_data/df_idaea4.csv")
else: # Load from .h5
  print("Reading idaea4 dataset from .h5 file")
  
  df_1 = get_df_from_xarray(["./results/idaea4_learn_1_PARTIAL/idaea4_learn_1_combined_dataset.h5"], drop_columns=drop_columns)
  df_2 = get_df_from_xarray(["./results/idaea4_learn_2/idaea4_learn_2_combined_dataset.h5"], drop_columns=drop_columns)

  df_1["experiment_id"] = "idaea4_learn_1"
  df_2["experiment_id"] = "idaea4_learn_2"

  # Shift the second run's participant IDs past the first run's so they do not collide
  df_2["participant"] = df_2["participant"] + df_1["participant"].max() + 1

  df_idaea4 = pd.concat([df_1, df_2])

  # Raw condition indices 0 and 1 both map to 0; 2-6 shift down by one. Indices not in
  # the map are kept unchanged via fillna. NOTE(review): presumably raw 0 and 1 are two
  # variants of the control condition — confirm against the experiment configs.
  condition_idx_remap = {
    0: 0,
    1: 0, 
    2: 1,
    3: 2, 
    4: 3, 
    5: 4, 
    6: 5,
  }
  df_idaea4["condition_idx"] = df_idaea4["condition_idx"].map(condition_idx_remap).fillna(df_idaea4["condition_idx"])

  # Trials without a recorded perf are counted as incorrect (0)
  df_idaea4['perf'] = df_idaea4['perf'].fillna(0)

  # Keep only rows with trialset_id > 0. NOTE(review): what trialset_id <= 0 denotes
  # is not visible here — confirm.
  df_idaea4 = df_idaea4[df_idaea4["trialset_id"]>0] 

  # Order each participant's trials by response time, then renumber the index
  df_idaea4 = df_idaea4.sort_values(by=["participant", "rel_timestamp_response"])

  df_idaea4 = df_idaea4.reset_index(drop=True)

  # Recompute block labels (from obs) for participants in 'shuffle' conditions
  df_idaea4 = reassign_blocks(df_idaea4, verbose=False)

  # Cache to CSV only in de-identified mode (drop_columns then also removes the
  # assignment_id / worker_id identifier columns)
  if DEIDENTIFIED_DATA:
    print("Saving de-identified version of the dataset")
    df_idaea4.to_csv("psych_data/df_idaea4.csv", index=False)

In [None]:
## FILTER OUT GUESSING PARTICIPANTS IN IDAEA4

# Calibration trials show simple shapes (circle / triangle). Their accuracy is used
# to screen out participants who appear to be guessing.
df_idaea4_calib = df_idaea4[df_idaea4["stimulus_name"].isin(["circle", "triangle"])]
calib_means = df_idaea4_calib.groupby("participant")["perf"].mean()

# Keep participants whose mean calibration accuracy is at least 0.9
participants_calib_above09 = calib_means[calib_means >= 0.9].index

print(f"All participants: {df_idaea4['participant'].nunique()}")
df_idaea4 = df_idaea4[df_idaea4['participant'].isin(participants_calib_above09)]
print(f"Participants with calib acc of 0.9 and above: {df_idaea4['participant'].nunique()}")

# Retained participants' trials whose trial_type is one of the condition names,
# and the subset of those from the test blocks (8 and 9)
df_idaea4_trials = df_idaea4[df_idaea4["trial_type"].isin(trial_type_names_idaea4)]
df_idaea4_trials_test = df_idaea4_trials[df_idaea4_trials["block"].isin([8, 9])]

### Moth task learning performance statistics (Table 1)

In [None]:
# Display order of conditions (by condition_idx) and the matching labels
condition_idx_ordering_idaea4 = [0, 3, 4, 1, 2, 5]
condition_labels_idaea4 = [
    "Control",
    "Enhance",
    "Enhance (shuffle)",
    "Select",
    "Select (shuffle)",
    "L-WISE",
]

(
    idaea4_accuracy_df,
    idaea4_training_time_df,
    idaea4_completion_time_df,
) = print_main_stats(
    df_idaea4_trials,
    condition_idx_ordering_idaea4,
    condition_labels_idaea4,
    chance_level=0.25,
    test_blocks=[8, 9],
)

In [None]:
# Compare every condition against the control group (condition_idx 0, the default)
chi_square_comparisons(
    df_idaea4_trials_test,
    condition_idx_ordering_idaea4,
    condition_labels_idaea4,
)

In [None]:
# Re-run the analysis with L-WISE as the reference ("control") group.
# control_condition_idx expects a condition index, so translate the "L-WISE" label
# through condition_idx_ordering_idaea4 rather than passing the label's list position;
# the two only coincide here because "L-WISE" sits at position 5 and has condition_idx 5.
lwise_condition_idx = condition_idx_ordering_idaea4[condition_labels_idaea4.index("L-WISE")]
chi_square_comparisons(df_idaea4_trials_test, condition_idx_ordering_idaea4, condition_labels_idaea4, control_condition_idx=lwise_condition_idx)

In [None]:
## PLOT IDAEA4 LEARNING CURVES

from scipy.signal import savgol_filter
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)


def get_difficulty(image_path, dirmap):
    """Return the ``difficulty`` recorded in ``dirmap`` for the image at ``image_path``.

    The file name is the last '/'-separated component of ``image_path`` with any
    '?query' suffix removed. Rows of ``dirmap['im_path']`` containing that name
    are candidates; rows whose base name equals it exactly are preferred. Returns
    NaN if nothing matches.
    """
    image_name = image_path.split('/')[-1].split('?')[0]
    # Literal match: file names contain '.', and may contain '+', '(' etc., which a
    # regex match would misinterpret.
    candidates = dirmap[dirmap['im_path'].str.contains(image_name, regex=False, na=False)]
    # A substring can also hit longer names ("1.jpg" inside "11.jpg"); prefer exact
    # base-name matches when there are any.
    exact = candidates[candidates['im_path'].str.rsplit('/', n=1).str[-1] == image_name]
    if not exact.empty:
        candidates = exact
    difficulty = candidates['difficulty'].values
    return difficulty[0] if len(difficulty) > 0 else np.nan

def get_difficulty_percentile(image_path, dirmap):
    """Return the image's difficulty as a percentile within its class and split.

    The image is located in ``dirmap`` by file name: the last '/'-separated component
    of ``image_path`` without any '?query'. Matching is literal, and exact base-name
    matches are preferred. The result is the percentage of ``dirmap`` rows with the
    same 'class' and 'split' whose 'difficulty' is <= this image's (0-100). Returns
    NaN if the image is not found.
    """
    image_name = image_path.split('/')[-1].split('?')[0]
    # Literal (non-regex) substring match, then prefer exact base-name matches so that
    # e.g. "1.jpg" does not resolve to "11.jpg".
    candidates = dirmap[dirmap['im_path'].str.contains(image_name, regex=False, na=False)]
    exact = candidates[candidates['im_path'].str.rsplit('/', n=1).str[-1] == image_name]
    row = exact if not exact.empty else candidates

    if row.empty:
        return np.nan

    difficulty = row['difficulty'].values[0]
    class_value = row['class'].values[0]
    split_value = row['split'].values[0]

    # Compare only against images of the same class and split
    subset = dirmap[(dirmap['class'] == class_value) & (dirmap['split'] == split_value)]

    return (subset['difficulty'] <= difficulty).mean() * 100

def get_enhancement(url):
    """Return the enhancement epsilon encoded in a stimulus URL, or 0.

    URLs containing 'natural' return 0. Otherwise the URL is split on '-' and the first
    piece that parses as a float is returned, after 'dot' is replaced with '.' (so
    '8dot0' -> 8.0). Returns 0 if no piece parses.
    """
    if 'natural' in url:
        return 0
    for part in url.split('-'):
        try:
            return float(part.replace('dot', '.'))
        except ValueError:
            # Not numeric (e.g. a path fragment); keep scanning
            continue
    return 0

def normalize_enhancement(enhancement, max_enhancement):
    """Scale ``enhancement`` by ``max_enhancement``; returns 0 when the maximum is 0."""
    return 0 if max_enhancement == 0 else enhancement / max_enhancement

def add_trial_index(df):
    """Number each participant's rows 1, 2, 3, ... in a new ``trial_index`` column.

    Numbering follows the current row order of ``df``, so sort it first if needed.
    The column is added to ``df`` in place, and ``df`` is also returned.
    """
    position_within_participant = df.groupby('participant').cumcount()
    df['trial_index'] = position_within_participant + 1
    return df

def smooth_with_forced_start(x, window_length, polyorder, force_points=1):
    """Savitzky-Golay smooth ``x`` but keep its first ``force_points`` samples unchanged.

    ``mode='nearest'`` extends the signal at the edges with the nearest input value.
    The first samples are then overwritten with the raw values, so the curve starts
    at the observed data rather than at a fitted value.
    """
    result = savgol_filter(x, window_length, polyorder, mode='nearest')
    head = slice(None, force_points)
    result[head] = x[head]
    return result

def smooth_performance(df, window_length=31, polyorder=2, force_points=1):
    """Add a per-participant smoothed accuracy column ``smoothed_perf`` to ``df``.

    Each participant's ``perf`` sequence, in the current row order of ``df``, is passed
    through ``smooth_with_forced_start`` with the given Savitzky-Golay parameters. The
    first ``force_points`` values of each participant stay unsmoothed. The defaults
    reproduce the fixed 31-trial window / quadratic fit used for the learning curves.

    Modifies ``df`` in place (adds or overwrites ``smoothed_perf``) and returns it.
    """
    df['smoothed_perf'] = df.groupby('participant')['perf'].transform(
        lambda x: smooth_with_forced_start(x, window_length=window_length, polyorder=polyorder, force_points=force_points)
    )
    return df

def process_condition_data(df, condition, dirmap):
    """Per-trial-index averages across participants for one condition.

    ``df`` must already contain ``trial_index`` and ``smoothed_perf`` (see
    ``add_trial_index`` and ``smooth_performance``). Difficulty (percentile within the
    image's class/split) and enhancement epsilon are derived from each trial's
    ``stimulus_image_url``.

    Returns:
        DataFrame indexed by ``trial_index`` with columns 'perf' (mean smoothed
        accuracy), 'perf_sem', 'difficulty' and 'enhancement' (means).
    """
    df_cond = df[df["condition_idx"] == condition].copy()

    # Resolve each distinct image once: the same URL recurs across many trials, and
    # every difficulty lookup scans the whole dirmap.
    unique_urls = df_cond['stimulus_image_url'].unique()
    difficulty_by_url = {url: get_difficulty_percentile(url, dirmap) for url in unique_urls}
    enhancement_by_url = {url: get_enhancement(url) for url in unique_urls}
    df_cond['difficulty'] = df_cond['stimulus_image_url'].map(difficulty_by_url)
    df_cond['enhancement'] = df_cond['stimulus_image_url'].map(enhancement_by_url)

    # Group by trial index and calculate means across participants
    grouped = df_cond.groupby('trial_index')
    avg_data = pd.DataFrame({
        'perf': grouped['smoothed_perf'].mean(),
        'perf_sem': grouped['smoothed_perf'].sem(),
        'difficulty': grouped['difficulty'].mean(),
        'enhancement': grouped['enhancement'].mean()
    })

    return avg_data

def create_comparison_plot(df_trials, dirmap, control_condition, condition, trial_type_names, chance_level=0.25, ylim=(0, 1), max_enhance_eps=8, test_start_trial=128):
    """Plot enhancement, difficulty and smoothed accuracy over trials: control vs. one condition.

    Three stacked panels share the trial axis:
      * top: mean enhancement epsilon per trial index,
      * middle: mean difficulty percentile per trial index,
      * bottom: mean smoothed accuracy with a +/- SEM band per trial index.
    The figure is saved to ``notebooks/fig_outputs/<experiment>_<condition name>.pdf``.
    ``<experiment>`` is the part of the first row's ``experiment_id`` before the first
    underscore, and ``<condition name>`` is ``trial_type_names[condition]``.

    Args:
        df_trials: trial-level DataFrame. ``add_trial_index`` and ``smooth_performance``
            add ``trial_index`` and ``smoothed_perf`` columns to it in place.
        dirmap: image metadata used for difficulty lookups.
        control_condition, condition: condition indices to compare.
        trial_type_names: list mapping condition index -> name (used for labels/filename).
        chance_level: y position of the dashed chance line.
        ylim: (lower, upper) limits of the accuracy panel.
        max_enhance_eps: upper end of the enhancement panel's y axis.
        test_start_trial: x position of the dashed vertical line on the accuracy panel.
            Default 128, presumably the trial index where the test phase begins
            (TODO confirm).
    """
    cond_dict = dict(enumerate(trial_type_names))

    df_trials = add_trial_index(df_trials)
    df_trials = smooth_performance(df_trials)

    trial_counts = df_trials.groupby('participant')['trial_index'].max()
    print(f"Minimum trial count: {trial_counts.min()}")
    print(f"Maximum trial count: {trial_counts.max()}")

    control_data = process_condition_data(df_trials, control_condition, dirmap)
    condition_data = process_condition_data(df_trials, condition, dirmap)

    max_trial = max(control_data.index.max(), condition_data.index.max())

    plt.style.use('seaborn-v0_8-white')
    sns.set_palette("deep")

    # Create figure with three subplots
    fig, (ax_top, ax_middle, ax_bottom) = plt.subplots(3, 1, figsize=(9, 7), height_ratios=[0.75, 0.75, 2.75], sharex=True)

    # Top subplot for enhancement (epsilon)
    sns.lineplot(x=control_data.index, y=control_data['enhancement'], ax=ax_top, label='Control ϵ', color='black', linestyle=':', linewidth=1)
    sns.lineplot(x=condition_data.index, y=condition_data['enhancement'], ax=ax_top, label='L-WISE ϵ', color='red', linewidth=1)

    ax_top.set_ylabel('Enhance ϵ', fontsize=14, fontweight='bold')
    ax_top.set_ylim(-max_enhance_eps/40, max_enhance_eps + (max_enhance_eps/40))
    ax_top.set_yticks([0, max_enhance_eps/2, max_enhance_eps])
    ax_top.legend(loc='lower right', facecolor='white', fontsize=14)

    ax_top.spines['top'].set_visible(False)
    ax_top.spines['right'].set_visible(False)

    # Middle subplot for difficulty
    sns.lineplot(x=control_data.index, y=control_data['difficulty'], ax=ax_middle, label='Control Difficulty', color='black', linestyle=':', linewidth=1)
    sns.lineplot(x=condition_data.index, y=condition_data['difficulty'], ax=ax_middle, label='L-WISE Difficulty', color='red', linewidth=1)

    ax_middle.set_ylabel('Diff. %ile', fontsize=14, fontweight='bold')
    ax_middle.set_ylim(0, 80)
    ax_middle.set_yticks([0, 40, 80])
    ax_middle.legend(loc='lower right', facecolor='white', fontsize=14)

    ax_middle.spines['top'].set_visible(False)
    ax_middle.spines['right'].set_visible(False)

    # Bottom subplot for smoothed accuracy with +/- SEM bands
    sns.lineplot(x=control_data.index, y=control_data['perf'], ax=ax_bottom, label='Control Group', color='black', linestyle=':', linewidth=2)
    ax_bottom.fill_between(control_data.index,
                           control_data['perf'] - control_data['perf_sem'],
                           control_data['perf'] + control_data['perf_sem'],
                           color='black', alpha=0.2)

    sns.lineplot(x=condition_data.index, y=condition_data['perf'], ax=ax_bottom, label=f'{cond_dict[condition]} Group', color='red', linewidth=2)
    ax_bottom.fill_between(condition_data.index,
                           condition_data['perf'] - condition_data['perf_sem'],
                           condition_data['perf'] + condition_data['perf_sem'],
                           color='red', alpha=0.2)

    ax_bottom.set_xlabel('Trial', fontsize=14, fontweight='bold')
    ax_bottom.set_ylabel('Accuracy (original groundtruth)', fontsize=14, fontweight='bold')

    # Set x-axis range
    ax_bottom.set_xlim(0, max_trial)
    ax_bottom.set_ylim(*ylim)

    # Vertical marker (test_start_trial) and horizontal chance line
    ax_bottom.axvline(x=test_start_trial, color='grey', linestyle='--', linewidth=1)
    ax_bottom.axhline(y=chance_level, color='black', linestyle='--', linewidth=1)

    ax_bottom.spines['top'].set_visible(False)
    ax_bottom.spines['right'].set_visible(False)

    ax_bottom.legend(loc='upper right', facecolor='white', fontsize=14)

    # Adjust tick label sizes
    for ax in [ax_top, ax_middle, ax_bottom]:
        ax.tick_params(axis='both', which='major', labelsize=12)

    compare_title = f"{cond_dict[condition]}"
    # Computed outside the f-string: reusing double quotes inside a double-quoted
    # f-string is a SyntaxError before Python 3.12.
    experiment_prefix = df_trials.iloc[0]['experiment_id'].split('_')[0]
    plt.tight_layout()
    plt.savefig(f"notebooks/fig_outputs/{experiment_prefix}_{compare_title}.pdf", dpi=300, format='pdf', bbox_inches='tight')

In [None]:
# Learning curves: control group vs. the "enhancement_taper_curriculum_sampling"
# condition (L-WISE)
control_condition = 0
dirmap_idaea4 = pd.read_csv('psych_data/dataset_dirmaps/idaea4_dataset_dirmap.csv')
create_comparison_plot(
    df_idaea4_trials,
    dirmap_idaea4,
    control_condition,
    trial_type_names_idaea4.index("enhancement_taper_curriculum_sampling"),
    trial_type_names_idaea4,
    chance_level=0.25,
    ylim=[0.15, 0.75],
)

### Appendix Fig. S14A1: relationship between human accuracy and ground-truth logit in the moth task

In [None]:
def get_robust_gt_logit(image_path, dirmap):
    """Return ``robust_gt_logit`` from ``dirmap`` for the image at ``image_path``, or NaN.

    The file name is the last '/'-separated component of ``image_path`` with any
    '?query' removed. It is matched literally against ``dirmap['im_path']``, and exact
    base-name matches are preferred over longer names that merely contain it.
    """
    image_name = image_path.split('/')[-1].split('?')[0]
    # Literal substring match (file names contain regex metacharacters such as '.')
    candidates = dirmap[dirmap['im_path'].str.contains(image_name, regex=False, na=False)]
    # Avoid resolving e.g. "1.png" to "11.png" when an exact base-name match exists
    exact = candidates[candidates['im_path'].str.rsplit('/', n=1).str[-1] == image_name]
    if not exact.empty:
        candidates = exact
    robust_gt_logit = candidates['robust_gt_logit'].values
    return robust_gt_logit[0] if len(robust_gt_logit) > 0 else np.nan

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import cross_val_score, StratifiedKFold
from scipy.stats import pearsonr
import statsmodels.api as sm

def _bin_accuracy_by_logit(df_test, bins):
    """Bin trials by ``robust_gt_logit``; summarize logit and accuracy (with bootstrap CI) per bin."""
    logit_bins = pd.cut(df_test['robust_gt_logit'], bins=bins)
    intervals = logit_bins.cat.categories
    bin_codes = logit_bins.cat.codes.to_numpy()

    rows = []
    # Iterate over all bins (including empty ones) in ascending order
    for code, interval in enumerate(intervals):
        in_bin = df_test[bin_codes == code]
        perf = in_bin['perf']
        if len(perf) > 0:
            # A single bootstrap run per bin, so both bounds come from the same resamples
            ci_lower, ci_upper = bootstrap_ci(perf)
        else:
            ci_lower = ci_upper = np.nan
        rows.append({
            'robust_gt_logit_bin': interval,
            'bin_center': in_bin['robust_gt_logit'].mean(),
            'count': in_bin['robust_gt_logit'].count(),
            'robust_gt_logit_std': in_bin['robust_gt_logit'].std(),
            'perf_mean': perf.mean(),
            'perf_ci_lower': ci_lower,
            'perf_ci_upper': ci_upper,
        })

    binned = pd.DataFrame(rows, columns=['robust_gt_logit_bin', 'bin_center', 'count', 'robust_gt_logit_std',
                                         'perf_mean', 'perf_ci_lower', 'perf_ci_upper'])
    binned['robust_gt_logit_bin'] = pd.Categorical(binned['robust_gt_logit_bin'], categories=intervals, ordered=True)
    return binned


def analyze_logit_accuracy_relationship(df_trials, dirmap, test_blocks=None, trial_type_names=None, ylim=(0, 1), xlim=None, control_only=False, bins=10, print_num_data_points=False, retrieve_logits=True, fig_file_name=None, fig_size=(3,3), all_black=False):
    """Analyze the relationship between ``robust_gt_logit`` and participant accuracy.

    Steps:
      1. Filter trials by ``test_blocks``, ``trial_type_names`` and (if ``control_only``)
         ``condition_idx == 0``.
      2. Optionally look up each image's ``robust_gt_logit`` in ``dirmap``.
      3. Drop trials whose logit lies more than 3 SD from the mean. NaN logits are
         dropped too.
      4. Bin by logit and plot per-bin mean accuracy. The y error bars are bootstrap
         CIs and the x error bars are the within-bin logit SD.
      5. Fit a one-predictor logistic regression of ``perf`` on the logit and overlay
         its curve. The figure is saved as a PDF under ``notebooks/fig_outputs``.

    Returns:
        (fig, ax, results). ``results`` has these keys:
        'p_value': slope p-value from the statsmodels Logit fit.
        'auc': in-sample ROC AUC of the sklearn fit.
        'cv_auc_mean', 'cv_auc_std': 10-fold stratified cross-validated AUC.
        'binned_data': the per-bin summary table.
        'logistic_regression': the fitted sklearn ``LogisticRegression``.
    """
    plt.style.use('seaborn-v0_8-white')
    plt.rcParams['xtick.bottom'] = True
    plt.rcParams['ytick.left'] = True

    # Build the trial filter
    mask = pd.Series(True, index=df_trials.index)
    if test_blocks is not None:
        mask &= df_trials['block'].isin(test_blocks)
    if trial_type_names is not None:
        mask &= df_trials['trial_type'].isin(trial_type_names)
    if control_only:
        mask &= df_trials['condition_idx'] == 0

    df_test = df_trials[mask].copy()

    if retrieve_logits:
        # One dirmap lookup per distinct image rather than per trial
        logit_by_url = {url: get_robust_gt_logit(url, dirmap) for url in df_test['stimulus_image_url'].unique()}
        df_test['robust_gt_logit'] = df_test['stimulus_image_url'].map(logit_by_url)

    # Remove outliers (beyond 3 std from mean)
    logit_mean = df_test['robust_gt_logit'].mean()
    logit_std = df_test['robust_gt_logit'].std()
    df_test = df_test[
        (df_test['robust_gt_logit'] >= logit_mean - 3 * logit_std) &
        (df_test['robust_gt_logit'] <= logit_mean + 3 * logit_std)
    ]

    binned_data = _bin_accuracy_by_logit(df_test, bins)

    fig, ax = plt.subplots(figsize=fig_size, dpi=300)

    errorbar_params = {
        'fmt': 'o',
        'capsize': 3,
        'capthick': 1.5,
        'elinewidth': 1.5,
        'markersize': 4,
        'zorder': 2
    }
    if all_black:
        errorbar_params['ecolor'] = 'black'
        errorbar_params['color'] = 'black'
    ax.errorbar(binned_data['bin_center'],
                binned_data['perf_mean'],
                xerr=binned_data['robust_gt_logit_std'],
                yerr=[(binned_data['perf_mean'] - binned_data['perf_ci_lower']),  # lower error
                      (binned_data['perf_ci_upper'] - binned_data['perf_mean'])], # upper error
                **errorbar_params)

    if print_num_data_points:
        for x, y, count, yerr in zip(binned_data['bin_center'],
                                     binned_data['perf_mean'],
                                     binned_data['count'],
                                     binned_data['perf_ci_upper'] - binned_data['perf_mean']):
            # Position the count slightly above the upper error bar
            ax.text(x, y + yerr + 0.02, int(count),
                    ha='center', va='bottom',
                    fontsize=8)

    # Logistic regression via statsmodels (for the p-value)
    X = df_test['robust_gt_logit'].values.reshape(-1, 1)
    y = df_test['perf']

    X_sm = sm.add_constant(X)
    logit_model_sm = sm.Logit(y, X_sm).fit()
    print(logit_model_sm.summary())
    # pvalues may be a labelled Series (y is a Series); index positionally via numpy
    # rather than relying on deprecated positional Series indexing.
    p_value = float(np.asarray(logit_model_sm.pvalues)[1])

    log_reg = LogisticRegression().fit(X, y)

    # Logistic curve over the retained logit range
    X_plot = np.linspace(
        logit_mean - 3 * logit_std,
        logit_mean + 3 * logit_std,
        500
    ).reshape(-1, 1)
    y_plot = log_reg.predict_proba(X_plot)[:, 1]
    ax.plot(X_plot, y_plot, 'k-', label='Logistic Regression Fit', zorder=1)

    ax.set_xlabel('Ground truth logit', fontsize=12, fontweight='bold')
    ax.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
    ax.set_ylim(ylim)
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.tick_params(axis='both', which='major', labelsize=12)

    # In-sample and 10-fold cross-validated AUC
    auc = roc_auc_score(y, log_reg.predict_proba(X)[:, 1])
    cv_auc_scores = cross_val_score(
        LogisticRegression(),
        X,
        y,
        cv=StratifiedKFold(n_splits=10, shuffle=True, random_state=42),
        scoring='roc_auc'
    )
    cv_auc_mean = np.mean(cv_auc_scores)
    cv_auc_std = np.std(cv_auc_scores)

    # Axis styling: no top/right spines, thicker remaining spines and ticks
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)
    ax.tick_params(axis='both', which='major', length=6, width=2)
    ax.tick_params(axis='both', which='minor', length=3, width=1)

    plt.tight_layout()

    if fig_file_name:
        fname = fig_file_name
    else:
        fname = f"{df_trials.iloc[0]['experiment_id'].split('_')[0]}_robust_gt_logit_accuracy.pdf"
    plt.savefig(os.path.join("notebooks/fig_outputs", fname), dpi=300, format="pdf", bbox_inches="tight")

    print(f"P-value: {p_value:.3e}")
    print(f"AUC score: {auc:.3f}")
    print(f"Cross-validated AUC: {cv_auc_mean:.3f} ± {cv_auc_std:.3f}")

    return fig, ax, {
        'p_value': p_value,
        'auc': auc,
        'cv_auc_mean': cv_auc_mean,
        'cv_auc_std': cv_auc_std,
        'binned_data': binned_data,
        'logistic_regression': log_reg
    }

In [None]:
# Control-group test-block trials only: accuracy vs. ground-truth logit in 6 bins
analyze_logit_accuracy_relationship(
    df_idaea4_trials,
    dirmap_idaea4,
    test_blocks=[8, 9],
    trial_type_names=trial_type_names_idaea4,
    ylim=[0.2, 0.7],
    control_only=True,
    bins=6,
    print_num_data_points=True,
    all_black=True,
)

### Appendix Fig. S14A2: relationship between human accuracy and enhancement epsilon (ϵ) in the moth task

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

def analyze_enhancement_accuracy_relationship(df_trials, dirmap, condition_idx=None, test_blocks=None, trial_type_names=None, min_acc=None, ylim=(0, 1)):
    """Analyze the relationship between image-enhancement level and participant accuracy.

    Plots mean accuracy (with ``bootstrap_ci`` intervals) at each discrete
    enhancement epsilon, overlays a logistic-regression fit of per-trial
    correctness on epsilon, saves the figure as a PDF under
    ``notebooks/fig_outputs/`` and prints the slope's p-value and the
    in-sample ROC AUC.

    Args:
        df_trials: trial-level DataFrame with ``block``, ``trial_type``,
            ``participant``, ``perf``, ``condition_idx``,
            ``stimulus_image_url`` and ``experiment_id`` columns.
        dirmap: unused by this function.
        condition_idx: if given, only trials of this condition are analyzed.
        test_blocks: blocks to include (default ``[2, 3, 4, 5]``).
        trial_type_names: if truthy, only these trial types are included.
        min_acc: if given, drop participants whose mean ``perf`` over the
            block/trial-type selection is below this value.
        ylim: y-axis limits, or None to keep matplotlib's default.

    Returns:
        ``(fig, ax, stats)`` where ``stats`` holds ``p_value``, ``auc``,
        ``grouped_data`` (per-epsilon mean / count / CI bounds) and
        ``logistic_regression`` (the fitted sklearn model).
    """
    if test_blocks is None:
        test_blocks = [2, 3, 4, 5]

    plt.style.use('seaborn-v0_8-white')
    plt.rcParams['xtick.bottom'] = True
    plt.rcParams['ytick.left'] = True

    # Select trials from the requested blocks (and trial types, if given).
    mask = df_trials['block'].isin(test_blocks)
    if trial_type_names:
        mask &= df_trials['trial_type'].isin(trial_type_names)
    df_test = df_trials[mask]

    if min_acc is not None:
        # Keep only participants whose mean accuracy on the selected trials reaches min_acc.
        participant_accs = df_test.groupby('participant')['perf'].mean()
        include_participants = participant_accs[participant_accs >= min_acc].index
        df_test = df_test[df_test['participant'].isin(include_participants)]

    if condition_idx is not None:
        df_test = df_test[df_test['condition_idx'] == condition_idx]

    # Own the filtered data before adding a column, so pandas does not raise
    # SettingWithCopyWarning on the boolean-indexed slices above.
    df_test = df_test.copy()

    # Report the count after *all* filters, including the condition filter.
    print(len(df_test), "total trials included in analysis")

    df_test['enhance_eps'] = df_test['stimulus_image_url'].apply(get_enhancement)
    # Summarize how many trials fall at each enhancement level.
    print(df_test['enhance_eps'].value_counts().sort_index())

    # Call bootstrap_ci once per level so both bounds come from the same call
    # (if it resamples randomly, two calls could pair bounds from different resamples).
    rows = []
    for eps, perf in df_test.groupby('enhance_eps')['perf']:
        ci = bootstrap_ci(perf)
        rows.append({
            'enhance_eps': eps,
            'perf_mean': perf.mean(),
            'count': perf.count(),
            'perf_ci_lower': ci[0],
            'perf_ci_upper': ci[1],
        })
    grouped_data = pd.DataFrame(rows, columns=['enhance_eps', 'perf_mean', 'count', 'perf_ci_lower', 'perf_ci_upper'])

    fig, ax = plt.subplots(figsize=(3, 3), dpi=300)

    # Mean accuracy with asymmetric CI error bars at each discrete enhancement level.
    ax.errorbar(grouped_data['enhance_eps'],
                grouped_data['perf_mean'],
                yerr=[grouped_data['perf_mean'] - grouped_data['perf_ci_lower'],
                      grouped_data['perf_ci_upper'] - grouped_data['perf_mean']],
                fmt='o', capsize=3, capthick=1.5, elinewidth=1.5, markersize=4, color='black', ecolor='black')

    X = df_test['enhance_eps'].values.reshape(-1, 1)
    y = df_test['perf']

    # The sklearn model (default L2 penalty) drives the plotted curve and AUC;
    # the unpenalized statsmodels Logit supplies the slope's p-value.
    log_reg = LogisticRegression().fit(X, y)
    logit_model_sm = sm.Logit(y, sm.add_constant(X)).fit()
    print(logit_model_sm.summary())
    # pvalues may be a label-indexed Series (y is a Series); index positionally via
    # NumPy instead of the deprecated positional Series.__getitem__ fallback.
    p_value = np.asarray(logit_model_sm.pvalues)[1]

    # Fitted probability curve across the observed epsilon range.
    X_plot = np.linspace(
        grouped_data['enhance_eps'].min(),
        grouped_data['enhance_eps'].max(),
        500
    ).reshape(-1, 1)
    y_plot = log_reg.predict_proba(X_plot)[:, 1]
    ax.plot(X_plot, y_plot, 'k-', label='Logistic Regression Fit')

    ax.set_xlabel('Enhancement ϵ', fontsize=12, fontweight='bold')
    ax.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=12)

    # In-sample AUC of the sklearn fit.
    auc = roc_auc_score(y, log_reg.predict_proba(X)[:, 1])

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)
    ax.tick_params(axis='both', which='major', length=6, width=2)
    ax.tick_params(axis='both', which='minor', length=3, width=1)

    plt.tight_layout()

    # File name prefix: experiment id up to the first underscore, plus the condition if filtered.
    save_name = df_trials.iloc[0]['experiment_id'].split('_')[0]
    if condition_idx is not None:
        save_name += f"_condition{condition_idx}"
    plt.savefig(os.path.join("notebooks/fig_outputs", f"{save_name}_enhancement_accuracy.pdf"),
                dpi=300, format="pdf", bbox_inches="tight")

    print(f"P-value: {p_value:.3e}")
    print(f"AUC score: {auc:.3f}")

    return fig, ax, {
        'p_value': p_value,
        'auc': auc,
        'grouped_data': grouped_data,
        'logistic_regression': log_reg
    }

In [None]:
# Accuracy vs. enhancement epsilon for condition_idx 4, using blocks 0-5.
analyze_enhancement_accuracy_relationship(
    df_idaea4_trials,
    dirmap_idaea4,
    test_blocks=list(range(6)),
    trial_type_names=trial_type_names_idaea4,
    condition_idx=4,
    min_acc=None,
    ylim=[0.3, 0.6],
)

### Appendix Fig. S15 (analyzing the effect of Greek name alias assignment on participant accuracy in the moth task)

In [None]:
## SUPPLEMENTARY CLASS MAPPING ANALYSIS PLOTS FOR IDAEA4

from statsmodels.stats.multicomp import pairwise_tukeyhsd

def analyze_class_mapping(df, class_name, stimuli=('Ajax', 'Eris', 'Leda', 'Tyro')):
    """Test whether the alias shown for ``class_name`` affects participant accuracy.

    For each participant, looks up the ``stimulus_name`` (alias) shown on their
    first trial of ``class_name``. It also computes a z-score of the
    participant's mean ``perf``. The z-score is centered on the mean and scaled
    by the standard deviation of *trial-level* ``perf`` in that participant's
    condition. A one-way ANOVA then compares these z-scores across the aliases
    listed in ``stimuli``. Participants whose alias is not listed stay in the
    returned table but are excluded from the test.

    Args:
        df: trial-level DataFrame with ``participant``, ``class``,
            ``stimulus_name``, ``condition_idx`` and ``perf`` columns.
        class_name: value of the ``class`` column to analyze.
        stimuli: aliases that form the ANOVA groups.

    Returns:
        ``(analysis_df, f_statistic, p_value, eta_squared, df_between, df_within)``.
        ``analysis_df`` is indexed by participant and has columns
        ``stimulus``, ``z_score`` and ``class``. ``eta_squared`` is
        ``F*df_between / (F*df_between + df_within)``.

    Raises:
        IndexError: if a participant in ``df`` has no trial of ``class_name``.
    """
    def get_class_stimulus(participant_data):
        # Alias shown on this participant's first trial of `class_name`.
        mask = participant_data['class'] == class_name
        return participant_data.loc[mask, 'stimulus_name'].iloc[0]

    # Selecting the needed columns keeps apply() off the grouping column
    # (operating on it is deprecated in pandas >= 2.2).
    by_participant = df.groupby('participant')
    class_mappings = by_participant[['class', 'stimulus_name']].apply(get_class_stimulus)

    condition_stats = df.groupby('condition_idx')['perf'].agg(['mean', 'std'])
    condition_means = condition_stats['mean'].to_dict()
    condition_stds = condition_stats['std'].to_dict()

    def calculate_z_score(group):
        # Participant mean accuracy, normalized by their condition's trial-level mean/std.
        condition = group['condition_idx'].iloc[0]
        return (group['perf'].mean() - condition_means[condition]) / condition_stds[condition]

    accuracy_z_scores = by_participant[['condition_idx', 'perf']].apply(calculate_z_score)

    analysis_df = pd.DataFrame({
        'stimulus': class_mappings,
        'z_score': accuracy_z_scores,
        'class': class_name
    })

    groups = [group['z_score'].values for name, group in analysis_df.groupby('stimulus') if name in stimuli]
    f_statistic, p_value = stats.f_oneway(*groups)

    df_between = len(groups) - 1
    # Only observations inside the ANOVA groups count toward the within-group df;
    # participants whose alias is not in `stimuli` were excluded above.
    df_within = sum(len(g) for g in groups) - len(groups)
    eta_squared = (f_statistic * df_between) / (f_statistic * df_between + df_within)

    return analysis_df, f_statistic, p_value, eta_squared, df_between, df_within


def class_mapping_plots(df_trials_test, classes, class_names):
    """Plot alias effects on condition-normalized accuracy per class and report ANOVAs.

    Runs ``analyze_class_mapping`` for every class in ``classes``. The results
    are drawn as a grouped box plot (one box per alias within each class),
    annotated with each class's ANOVA statistics and saved as
    ``notebooks/fig_outputs/<experiment prefix>_class_mapping_analysis.pdf``.
    Per-class ANOVA results and descriptive statistics are printed. A Tukey
    HSD post-hoc test is printed when p < 0.05.

    Args:
        df_trials_test: trial-level DataFrame (its first row's ``experiment_id``
            prefix names the output file).
        classes: ``class`` values to analyze, in plotting order.
        class_names: mapping from class value to x-axis display label.
    """
    results = {cls: analyze_class_mapping(df_trials_test, cls) for cls in classes}

    # Stack the per-class participant tables and attach display names.
    combined_df = pd.concat([result[0] for result in results.values()])
    combined_df['class_name'] = combined_df['class'].map(class_names)

    plt.style.use('seaborn-v0_8-whitegrid')
    sns.set_palette("deep")
    plt.figure(figsize=(12, 12))

    ax = sns.boxplot(x='class_name', y='z_score', hue='stimulus', data=combined_df,
                     width=0.7, fliersize=3, linewidth=2)

    plt.xlabel(None)
    plt.ylabel('Condition-Normalized Accuracy (Z-Score)', fontsize=28)
    plt.xticks(fontsize=24)
    plt.yticks(fontsize=24)

    # Fixed y-range so the boxplots appear smaller, leaving room for the stats text.
    plt.ylim(-0.7, 0.75)

    # ANOVA summary text at the top of each class's column.
    for i, cls in enumerate(classes):
        _, f_stat, p_val, eta_sq, df_between, df_within = results[cls]
        stats_text = f"F({df_between},{df_within})={f_stat:.2f}\np={p_val:.4f}\nη²={eta_sq:.3f}"
        plt.text(i, plt.ylim()[1]-0.01, stats_text, ha='center', va='top', fontsize=20,
                 bbox=dict(facecolor='white', edgecolor='none', alpha=0.8))

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.2)

    # Re-draw the alias legend below the axes.
    handles, labels = ax.get_legend_handles_labels()
    plt.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, -0.08),
               ncol=4, fontsize=26, title='Alias', title_fontsize=28)

    plt.tight_layout()
    # Compute the prefix outside the f-string: reusing the outer quote character
    # inside a replacement field is a SyntaxError before Python 3.12 (PEP 701).
    experiment_prefix = df_trials_test.iloc[0]['experiment_id'].split('_')[0]
    plt.savefig(f"notebooks/fig_outputs/{experiment_prefix}_class_mapping_analysis.pdf", dpi=300, bbox_inches='tight', format='pdf')
    plt.show()

    for cls, (class_df, f_stat, p_val, eta_sq, df_between, df_within) in results.items():
        print(f"\nANOVA results for '{cls}' mapping:")
        print(f"F({df_between},{df_within}) = {f_stat:.4f}")
        print(f"p-value: {p_val:.4f}")
        print(f"Effect size (η²): {eta_sq:.4f}")
        print("\nDescriptive statistics:")
        print(class_df.groupby('stimulus')['z_score'].describe())

        if p_val < 0.05:
            # NOTE(review): Tukey uses every alias present, while the ANOVA uses only
            # analyze_class_mapping's default `stimuli` list.
            print("\nTukey's HSD post-hoc test:")
            tukey_results = pairwise_tukeyhsd(class_df['z_score'], class_df['stimulus'])
            print(tukey_results)

In [None]:
# The four Idaea species of the moth task. Each display name is the species
# epithet, i.e. the last underscore-separated token of the class id.
classes = [
    '01233_Animalia_Arthropoda_Insecta_Lepidoptera_Geometridae_Idaea_aversata',
    '01234_Animalia_Arthropoda_Insecta_Lepidoptera_Geometridae_Idaea_biselata',
    '01239_Animalia_Arthropoda_Insecta_Lepidoptera_Geometridae_Idaea_seriata',
    '01240_Animalia_Arthropoda_Insecta_Lepidoptera_Geometridae_Idaea_tacturata',
]
class_names = {class_id: class_id.split('_')[-1] for class_id in classes}
class_mapping_plots(df_idaea4_trials_test, classes, class_names)

## HAM10000 dermoscopy learning task (4-class version "ham4")

In [None]:
## LOAD HAM4 DATA

# Trial type of the control condition, followed by every HAM4 trial type kept for analysis.
control_cond = "natural"
trial_type_names_ham4 = [control_cond, "curriculum_sampling", "curriculum_sampling_shuffle", "enhancement_taper", "enhancement_taper_shuffle", "enhancement_taper_curriculum_sampling"]

# Use the cached de-identified CSV when it exists (and DEIDENTIFIED_DATA is set);
# otherwise rebuild the dataframe from the raw .h5 exports.
if os.path.isfile("psych_data/df_ham4.csv") and DEIDENTIFIED_DATA:
  print("Reading ham4 dataset from saved .csv")
  df_ham4 = pd.read_csv("psych_data/df_ham4.csv")
else: # Load from .h5
  print("Reading ham4 dataset from .h5 file")

  # First deployment (ham4_learn_4).
  df_ham4_learn_4 = get_df_from_xarray(["./results/ham4_learn_4/ham4_learn_4_combined_dataset.h5"], drop_columns=drop_columns)

  # Remap condition indices and trial types (in the first deployment of the experiment, the shuffling procedure was not performed at all due to a bug)
  # After this remap, condition_idx 2 and 4 no longer occur in ham4_learn_4. map() yields
  # NaN for keys not listed, so fillna() restores the original value for those rows.
  condition_idx_remap = {
    2: 1, 
    4: 3,
  }
  trial_type_remap = {
    "curriculum_sampling_shuffle": "curriculum_sampling",
    "enhancement_taper_shuffle": "enhancement_taper",
  }
  df_ham4_learn_4["condition_idx"] = df_ham4_learn_4["condition_idx"].map(condition_idx_remap).fillna(df_ham4_learn_4["condition_idx"])
  df_ham4_learn_4["trial_type"] = df_ham4_learn_4["trial_type"].map(trial_type_remap).fillna(df_ham4_learn_4["trial_type"])
  df_ham4_learn_4["data_round"] = 4

  df_ham4_learn_4["experiment_id"] = "ham4_learn_4"

  # Later deployment (ham4_learn_5).
  df_ham4_learn_5 = get_df_from_xarray(["./results/ham4_learn_5/ham4_learn_5_combined_dataset.h5"], drop_columns=drop_columns)

  # Remap condition indices to match the condition_idx convention used for ham4_learn_4.
  # Only condition_idx is remapped here; trial_type values are kept as recorded.
  condition_idx_remap = {
    1: 2, 
    2: 4,
  }
  df_ham4_learn_5["condition_idx"] = df_ham4_learn_5["condition_idx"].map(condition_idx_remap).fillna(df_ham4_learn_5["condition_idx"])
  # Shift participant ids past ham4_learn_4's largest id, because ids from separate
  # get_df_from_xarray calls may overlap.
  df_ham4_learn_5["participant"] = df_ham4_learn_5["participant"] + df_ham4_learn_4["participant"].max() + 1
  df_ham4_learn_5["experiment_id"] = "ham4_learn_5"
  df_ham4_learn_5["data_round"] = 5

  df_ham4 = pd.concat([df_ham4_learn_4, df_ham4_learn_5])

  # Missing perf values (presumably trials without a response) are scored as incorrect.
  df_ham4['perf'] = df_ham4['perf'].fillna(0)

  df_ham4 = df_ham4.reset_index(drop=True)


  ## Impute 4 missing condition_idx and trial_type values
  # For each trialset_id, fill missing condition_idx / trial_type with that trialset's
  # most frequent (mode) value.
  mode_values = df_ham4.groupby('trialset_id').agg({
      'condition_idx': lambda x: x.mode().iloc[0] if not x.mode().empty else None,
      'trial_type': lambda x: x.mode().iloc[0] if not x.mode().empty else None
  })
  mode_values.columns = ['condition_idx_mode', 'trial_type_mode']

  df_ham4 = df_ham4.merge(mode_values, on='trialset_id', how='left')

  df_ham4['condition_idx'] = df_ham4['condition_idx'].fillna(df_ham4['condition_idx_mode'])
  df_ham4['trial_type'] = df_ham4['trial_type'].fillna(df_ham4['trial_type_mode'])

  df_ham4 = df_ham4.drop(['condition_idx_mode', 'trial_type_mode'], axis=1)

  # Sanity check and block re-assignment via helpers defined elsewhere in this notebook.
  assert_constant_counts(df_ham4)

  df_ham4 = reassign_blocks(df_ham4, verbose=False)

  # Cache the processed dataframe so later runs take the CSV branch above.
  if DEIDENTIFIED_DATA:
    print("Saving de-identified version of the dataset")
    df_ham4.to_csv("psych_data/df_ham4.csv", index=False)

In [None]:
# Distinct participants per condition, as loaded (calibration filtering happens in the next cell).
print(df_ham4.groupby("condition_idx").participant.nunique())

In [None]:
## FILTER OUT GUESSING PARTICIPANTS IN HAM4

df_ham4_trials = df_ham4[df_ham4["trial_type"].isin(trial_type_names_ham4)]

# Calibration trials show simple shapes. A participant is kept only if their mean
# accuracy on them is at least 0.9; participants without calibration trials are dropped.
df_ham4_calib = df_ham4[df_ham4["stimulus_name"].isin(["circle", "triangle"])]
calib_means = df_ham4_calib.groupby("participant")["perf"].mean()
participants_calib_above09 = calib_means.loc[calib_means >= 0.9].index

print(f"All participants: {df_ham4['participant'].nunique()}")
df_ham4 = df_ham4.loc[df_ham4['participant'].isin(participants_calib_above09)]
print(f"Participants with calib acc of 0.9 and above: {df_ham4['participant'].nunique()}")

# Apply the same participant filter to the trial subset, then keep the test blocks (8, 9).
df_ham4_trials = df_ham4_trials.loc[df_ham4_trials['participant'].isin(participants_calib_above09)]
df_ham4_trials_test = df_ham4_trials.loc[df_ham4_trials["block"].isin([8, 9])]

### Dermoscopy task learning performance statistics (Table 1)

In [None]:
# HAM4 conditions in display order, as (condition_idx, label) pairs.
condition_idx_ordering_ham4, condition_labels_ham4 = (
    list(column) for column in zip(*[
        (0, "Control"),
        (3, "Enhance"),
        (4, "Enhance (shuffle)"),
        (1, "Select"),
        (2, "Select (shuffle)"),
        (5, "L-WISE"),
    ])
)

ham4_accuracy_df, ham4_training_time_df, ham4_completion_time_df = print_main_stats(
    df_ham4_trials,
    condition_idx_ordering_ham4,
    condition_labels_ham4,
    chance_level=0.25,
    test_blocks=[8, 9],
)

In [None]:
# Pairwise chi-square tests on test-block accuracy, using the function's default reference condition.
chi_square_comparisons(
    df_ham4_trials_test,
    condition_idx_ordering_ham4,
    condition_labels_ham4,
)

In [None]:
# Same comparisons, with L-WISE as the reference.
# NOTE(review): this passes the label's position (5). It equals condition_idx 5 only
# because L-WISE is last in both lists; confirm which one chi_square_comparisons expects.
chi_square_comparisons(
    df_ham4_trials_test,
    condition_idx_ordering_ham4,
    condition_labels_ham4,
    control_condition_idx=condition_labels_ham4.index("L-WISE"),
)

### Appendix Fig. S12E-G (learning curves for dermoscopy task)

In [None]:
## PLOT HAM4 LEARNING CURVES

# Control (condition 0) vs. the L-WISE trial type, whose index in trial_type_names_ham4 is passed.
control_condition = 0
dirmap_ham4 = pd.read_csv('psych_data/dataset_dirmaps/ham4_dataset_dirmap.csv')
create_comparison_plot(
    df_ham4_trials,
    dirmap_ham4,
    control_condition,
    trial_type_names_ham4.index("enhancement_taper_curriculum_sampling"),
    trial_type_names_ham4,
    chance_level=0.25,
    ylim=[0.15, 0.85],
)

### Appendix Fig. S14B1 (relationship between human accuracy and ground-truth logit in the dermoscopy task)

In [None]:
# Accuracy vs. ground-truth logit on the dermoscopy test blocks (8 and 9).
analyze_logit_accuracy_relationship(
    df_ham4_trials,
    dirmap_ham4,
    test_blocks=[8, 9],
    trial_type_names=trial_type_names_ham4,
    ylim=[0.1, 0.75],
    control_only=True,
    bins=6,
    print_num_data_points=True,
    all_black=True,
)

### Appendix Fig. S14B2 (relationship between human accuracy and enhancement epsilon in the dermoscopy task)

In [None]:
# Accuracy vs. enhancement epsilon for condition_idx 4 ("Enhance (shuffle)" in HAM4), blocks 0-5.
analyze_enhancement_accuracy_relationship(
    df_ham4_trials,
    dirmap_ham4,
    test_blocks=list(range(6)),
    trial_type_names=trial_type_names_ham4,
    condition_idx=4,
    min_acc=None,
    ylim=[0.3, 0.55],
)

### Appendix Fig. S16 (analyzing the effect of Greek name alias assignment on participant accuracy in the dermoscopy task)

In [None]:
## SUPPLEMENTARY CLASS MAPPING ANALYSIS PLOTS FOR HAM4

# HAM10000 diagnostic codes and their display labels; classes are plotted in code order.
class_names = {
    'mel': 'Melanoma',
    'nv': 'Benign\nMole',
    'bcc': 'Basal Cell\nCarcinoma',
    'bkl': 'Benign\nKeratosis'
}
classes = sorted(class_names)
class_mapping_plots(df_ham4_trials_test, classes, class_names)

## MHIST histology learning task

In [None]:
trial_type_names_mhist = [control_cond, "enhancement_taper_curriculum_sampling"]

# Prefer the cached de-identified CSV; otherwise build the dataframe from the raw .h5 export.
if DEIDENTIFIED_DATA and os.path.isfile("psych_data/df_mhist.csv"):
  print("Reading MHIST dataset from saved .csv")
  df_mhist = pd.read_csv("psych_data/df_mhist.csv")
else:
  print("Reading MHIST dataset from .h5 file")
  df_mhist = (
    get_df_from_xarray(["./results/mhist_learn_1_PARTIAL/mhist_learn_1_combined_dataset.h5"], drop_columns=drop_columns)
    .assign(experiment_id="mhist_learn_1")
    # Missing perf values (presumably unanswered trials) are scored as incorrect.
    .assign(perf=lambda frame: frame["perf"].fillna(0))
    .reset_index(drop=True)
  )

  if DEIDENTIFIED_DATA:
    print("Saving de-identified version of the dataset")
    df_mhist.to_csv("psych_data/df_mhist.csv", index=False)

In [None]:
# Distinct participants per condition, as loaded (calibration filtering happens in the next cell).
print(df_mhist.groupby("condition_idx").participant.nunique())

In [None]:
## FILTER OUT GUESSING PARTICIPANTS IN MHIST

df_mhist_trials = df_mhist[df_mhist["trial_type"].isin(trial_type_names_mhist)]

# Calibration trials show simple shapes. A participant is kept only if their mean
# accuracy on them is at least 0.9; participants without calibration trials are dropped.
df_mhist_calib = df_mhist[df_mhist["stimulus_name"].isin(["circle", "triangle"])]
calib_means = df_mhist_calib.groupby("participant")["perf"].mean()
participants_calib_above09 = calib_means.loc[calib_means >= 0.9].index

print(f"All participants: {df_mhist['participant'].nunique()}")
df_mhist = df_mhist.loc[df_mhist['participant'].isin(participants_calib_above09)]
print(f"Participants with calib acc of 0.9 and above: {df_mhist['participant'].nunique()}")

# Apply the same participant filter to the trial subset, then keep the test blocks (8, 9).
df_mhist_trials = df_mhist_trials.loc[df_mhist_trials['participant'].isin(participants_calib_above09)]
df_mhist_trials_test = df_mhist_trials.loc[df_mhist_trials["block"].isin([8, 9])]

### Histology task learning performance statistics (Table 1)

In [None]:
# MHIST conditions in display order, as (condition_idx, label) pairs.
condition_idx_ordering_mhist, condition_labels_mhist = (
    list(column) for column in zip(*[
        (0, "Control"),
        (1, "L-WISE"),
    ])
)

mhist_accuracy_df, mhist_training_time_df, mhist_completion_time_df = print_main_stats(
    df_mhist_trials,
    condition_idx_ordering_mhist,
    condition_labels_mhist,
    chance_level=0.5,
    test_blocks=[8, 9],
)

In [None]:
# Pairwise chi-square tests on test-block accuracy, using the function's default reference condition.
chi_square_comparisons(
    df_mhist_trials_test,
    condition_idx_ordering_mhist,
    condition_labels_mhist,
)

### Appendix Fig. S13D-F (learning curves for histology task)

In [None]:
# Control (condition 0) vs. the L-WISE trial type, whose index in trial_type_names_mhist is passed.
control_condition = 0
dirmap_mhist = pd.read_csv('psych_data/dataset_dirmaps/mhist_dataset_dirmap.csv')
create_comparison_plot(
    df_mhist_trials,
    dirmap_mhist,
    control_condition,
    trial_type_names_mhist.index("enhancement_taper_curriculum_sampling"),
    trial_type_names_mhist,
    chance_level=0.5,
    ylim=[0.4, 1.0],
)

### Appendix Fig. S14C (relationship between human accuracy and ground-truth logit in the histology task)


In [None]:
# Accuracy vs. ground-truth logit on the histology test blocks (8 and 9).
analyze_logit_accuracy_relationship(
    df_mhist_trials,
    dirmap_mhist,
    test_blocks=[8, 9],
    trial_type_names=trial_type_names_mhist,
    ylim=[0.2, 1.0],
    control_only=True,
    bins=6,
    print_num_data_points=True,
    all_black=True,
)

### Appendix Fig. S13C (analysis of the relationship between ground-truth logit and inter-annotator agreement in the MHIST histology dataset)

In [None]:
# Inter-annotator agreement per image: the SSA vote count n (out of 7 annotators)
# is rescaled to |n - 3.5| / 3.5, so a 4-3 split gives ~0.14 and a unanimous vote gives 1.0.
dirmap_mhist['annotator_agreement'] = abs(dirmap_mhist['Number of Annotators who Selected SSA (Out of 7)'] - 3.5) / 3.5

# Mean robust_gt_logit and its bootstrap CI at each agreement level.
unique_agreements = dirmap_mhist['annotator_agreement'].unique()
means = []
cis = []

for agreement in unique_agreements:
    data = dirmap_mhist[dirmap_mhist['annotator_agreement'] == agreement]['robust_gt_logit']
    means.append(np.mean(data))
    cis.append(bootstrap_ci(data))

# Order the levels by increasing agreement.
sorted_indices = np.argsort(unique_agreements)
unique_agreements = unique_agreements[sorted_indices]
means = np.array(means)[sorted_indices]
cis = np.array(cis)[sorted_indices]

fig, ax = plt.subplots(figsize=(2, 4))

bar_color = '#005885'  # Darker blue color
ax.barh(unique_agreements, means, color=bar_color, edgecolor='black', linewidth=2, height=0.1)

# Asymmetric error bars: distance from each mean to its lower / upper CI bound.
ax.errorbar(means, unique_agreements, xerr=np.abs(cis.T - means), fmt='none', ecolor='black', capsize=5, capthick=2, linewidth=2)

ax.set_ylabel('Annotator Agreement', fontsize=14)
ax.set_xlabel(r'$L_{\mathrm{gt}}$', fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Label each tick with the size of the majority (4-7 of 7 annotators), i.e. 3.5 + 3.5 * agreement.
# Deriving the labels from the data keeps them aligned with the ticks even if an agreement
# level is absent; hard-coded labels would mislabel the ticks or raise a length mismatch.
majority_sizes = np.rint(3.5 + 3.5 * unique_agreements).astype(int)
ax.set_yticks(unique_agreements, labels=[str(size) for size in majority_sizes])

plt.tight_layout()
plt.savefig("notebooks/fig_outputs/mhist_agreement_Lgt.pdf", dpi=300, format='pdf', bbox_inches='tight')

# Pearson correlation across images between robust_gt_logit and agreement.
correlation = dirmap_mhist['robust_gt_logit'].corr(dirmap_mhist['annotator_agreement'])
print(f"Correlation between robust_gt_logit and annotator_agreement: {correlation:.4f}")

## Combined figures for the moth, dermoscopy, and histology tasks

### Fig. 4A (Plot showing L-WISE improvements in test accuracy and training speed across 3 tasks)

In [None]:
def create_combined_acc_time_plot(idaea4_data, idaea4_training, ham4_data, ham4_training, mhist_data, mhist_training):
    """Plot test accuracy against training-phase duration for Control vs. L-WISE in three tasks.

    Each task contributes two points with x/y error bars: an open marker for
    Control and a filled marker for L-WISE. An arrow runs from the Control
    point to the L-WISE point. The figure is saved to
    ``notebooks/fig_outputs/combined_acc_time_plot.pdf``.

    Args:
        idaea4_data, ham4_data, mhist_data: accuracy tables with ``Condition``,
            ``Accuracy``, ``yerr_lower`` and ``yerr_upper`` columns.
        idaea4_training, ham4_training, mhist_training: training-time tables
            with ``condition``, ``mean_training_time``, ``ci_lower`` and
            ``ci_upper`` columns.
    """
    plt.style.use('seaborn-v0_8-white')
    fig, ax = plt.subplots(figsize=(12, 8))

    # (legend label, colour, accuracy table, training-time table) for each task.
    task_specs = [
        ('Natural images\n(Moths)', 'orange', idaea4_data, idaea4_training),
        ('Dermoscopy\nimages', 'green', ham4_data, ham4_training),
        ('Histology\nimages', 'blue', mhist_data, mhist_training),
    ]
    gap = 0.1  # length (data coordinates) trimmed from each end of the arrow

    for task_label, color, accuracy_data, training_data in task_specs:
        rows_by_condition = {}
        for condition in ('Control', 'L-WISE'):
            acc_rows = accuracy_data[accuracy_data['Condition'] == condition]
            time_rows = training_data[training_data['condition'] == condition]
            rows_by_condition[condition] = (acc_rows, time_rows)

            is_control = condition == 'Control'
            mean_time = time_rows['mean_training_time']
            ax.errorbar(
                mean_time, acc_rows['Accuracy'],
                xerr=[mean_time - time_rows['ci_lower'], time_rows['ci_upper'] - mean_time],
                yerr=[acc_rows['yerr_lower'], acc_rows['yerr_upper']],
                fmt='o', color=color,
                markerfacecolor='none' if is_control else color,
                capsize=5, label=task_label if is_control else None, markersize=10,
                linewidth=2, elinewidth=2, capthick=2,
            )

        # Arrow from the Control point to the L-WISE point, shortened by `gap` at both ends.
        control_acc, control_time = rows_by_condition['Control']
        lwise_acc, lwise_time = rows_by_condition['L-WISE']
        start_x = control_time['mean_training_time'].values[0]
        start_y = control_acc['Accuracy'].values[0]
        end_x = lwise_time['mean_training_time'].values[0]
        end_y = lwise_acc['Accuracy'].values[0]

        dx = end_x - start_x
        dy = end_y - start_y
        arrow_length = np.sqrt(dx**2 + dy**2)
        shift_x = dx * gap / arrow_length
        shift_y = dy * gap / arrow_length

        ax.annotate('', xy=(end_x - shift_x, end_y - shift_y), xytext=(start_x + shift_x, start_y + shift_y),
                    arrowprops=dict(arrowstyle='->', color=color, linewidth=3, mutation_scale=20))

    ax.set_xlabel('Training phase duration [min]', fontsize=24)
    ax.set_ylabel('Test accuracy\n[% correct]', fontsize=24)
    ax.tick_params(axis='both', which='major', labelsize=22)

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    # Task legend (colours), kept as an extra artist so the marker legend below doesn't replace it.
    handles, labels = ax.get_legend_handles_labels()
    task_legend = ax.legend(handles, labels, loc='upper right', fontsize=22, title_fontsize=22)
    ax.add_artist(task_legend)

    # Marker legend: open = control, filled = L-WISE.
    marker_handles = [
        plt.Line2D([0], [0], marker='o', color='k', markerfacecolor='none', markersize=10, label='"Naive" visual learning (control)'),
        plt.Line2D([0], [0], marker='o', color='k', markerfacecolor='k', markersize=10, label='L-WISE (Ours)'),
    ]
    ax.legend(handles=marker_handles, loc='lower left', fontsize=22, title_fontsize=22)

    plt.tight_layout()
    plt.savefig("notebooks/fig_outputs/combined_acc_time_plot.pdf", dpi=300, format='pdf', bbox_inches='tight')


create_combined_acc_time_plot(idaea4_accuracy_df, idaea4_training_time_df, ham4_accuracy_df, ham4_training_time_df, mhist_accuracy_df, mhist_training_time_df)

### Fig. 4B (class-wise precision and recall for L-WISE participants vs controls)

In [None]:
def calculate_class_metrics(df_trials, class_name, condition_idx, trial_type_names, test_blocks=(8, 9)):
    """Calculate mean per-participant precision and recall for one class in one condition.

    Args:
        df_trials: trial-level DataFrame with ``block``, ``condition_idx``,
            ``trial_type``, ``participant``, ``class``, ``perf`` and
            ``i_choice`` columns.
        class_name: value of the ``class`` column to score.
        condition_idx: condition whose trials are analyzed.
        trial_type_names: trial types to include.
        test_blocks: blocks to include (default 8 and 9).

    Returns:
        dict with ``class`` (last ``_``-separated token of ``class_name``), the
        mean ``precision`` / ``recall`` across participants, and their
        ``bootstrap_ci`` intervals (``precision_ci`` / ``recall_ci``).
        Participants with no trials of ``class_name`` are skipped, since
        recall is undefined for them.
    """
    df_test = df_trials[
        df_trials['block'].isin(test_blocks)
        & (df_trials['condition_idx'] == condition_idx)
        & df_trials['trial_type'].isin(trial_type_names)
    ]

    short_class = class_name.split('_')[-1]

    participant_metrics = []
    for participant in df_test['participant'].unique():
        df_participant = df_test[df_test['participant'] == participant]
        is_class = df_participant['class'] == class_name
        if not is_class.any():
            # No trials of this class: recall is undefined, so skip this participant.
            continue

        # True positives / false negatives: this class's trials answered correctly / incorrectly.
        true_pos = int((is_class & (df_participant['perf'] == 1)).sum())
        false_neg = int((is_class & (df_participant['perf'] == 0)).sum())

        # False positives: other-class trials where the participant chose this class's
        # response slot. The slot is read from a *correctly* answered trial of the class.
        # The first trial's i_choice would belong to another class if it was answered
        # wrongly. Like before, this assumes each class keeps a fixed slot per participant.
        correct_class_choices = df_participant.loc[is_class & (df_participant['perf'] == 1), 'i_choice']
        if correct_class_choices.empty:
            # TP == 0, so precision is 0 whatever the false-positive count is.
            false_pos = 0
        else:
            class_slot = correct_class_choices.iloc[0]
            false_pos = int((
                ~is_class
                & df_participant['i_choice'].notna()
                & (df_participant['i_choice'] == class_slot)
            ).sum())

        # Small epsilon prevents division by zero (an empty denominator yields 0).
        precision = true_pos / (true_pos + false_pos + 1e-10)
        recall = true_pos / (true_pos + false_neg + 1e-10)

        participant_metrics.append({
            'precision': precision,
            'recall': recall
        })

    metrics_df = pd.DataFrame(participant_metrics, columns=['precision', 'recall'])

    precision_ci = bootstrap_ci(metrics_df['precision'].values)
    recall_ci = bootstrap_ci(metrics_df['recall'].values)

    return {
        'class': short_class,
        'precision': metrics_df['precision'].mean(),
        'precision_ci': precision_ci,
        'recall': metrics_df['recall'].mean(),
        'recall_ci': recall_ci
    }

def create_precision_recall_comparison_scatter(tasks, ylim=(0, 1), plot_confidence_intervals=True,
                                               save_path="notebooks/fig_outputs/precision_recall_scatter.pdf"):
    """
    Scatter per-class precision and recall under Control (x) vs L-WISE (y) for several tasks.

    For every class of every task two points are drawn, both filled with the task color:
    a triangle for precision and a square for recall. Points above the dashed diagonal
    are classes where the L-WISE condition scored higher than Control.

    Parameters:
    tasks (list): List of dictionaries containing:
        - df: DataFrame with trial data
        - name: Display name for the task
        - color: Color for plotting
        - class_labels: Dictionary mapping short class names to display labels
        - trial_type_names: List of trial type names
        - control_condition: Index for control condition
        - lwise_condition: Index for L-WISE condition
    ylim (sequence of two floats): Limits applied to BOTH axes (the axes are square).
    plot_confidence_intervals (bool): Whether to draw the bootstrap-CI error bars.
    save_path (str or None): PDF output path (its directory is created if needed);
        None skips saving.

    Returns:
    tuple: (figure, axis)
    """
    from matplotlib.lines import Line2D

    marker_area = 60  # scatter marker area in pt^2 (matplotlib's `s`)

    def _asymmetric_err(metrics, key):
        # matplotlib wants [[lower], [upper]] distances from the point. Clip at 0: the
        # bootstrap bounds can land a float-rounding error past the mean (e.g. when every
        # participant has the same value), and matplotlib rejects negative lengths.
        value = metrics[key]
        lower, upper = metrics[f'{key}_ci']
        return [[max(value - lower, 0.0)], [max(upper - value, 0.0)]]

    def _legend_marker(marker, facecolor, label):
        # Legend-only proxy artist; unlike plt.scatter([], []) it is never added to the axes.
        return Line2D([], [], linestyle='none', marker=marker, markersize=marker_area ** 0.5,
                      markerfacecolor=facecolor, markeredgecolor='black', markeredgewidth=1,
                      label=label)

    # Set up the style
    plt.style.use('seaborn-v0_8-white')
    plt.rcParams['xtick.bottom'] = True
    plt.rcParams['ytick.left'] = True

    fig = plt.figure(figsize=(7.5, 5))
    # Manually placed axes ([left, bottom, width, height]) leave room for the legends on the right
    ax = fig.add_axes([0.15, 0.15, 0.6, 0.75])

    # Diagonal: equal performance under both conditions
    ax.plot([ylim[0], ylim[1]], [ylim[0], ylim[1]], '--', color='gray', linewidth=1.5)

    errorbar_style = dict(fmt='none', color='black', capsize=5, elinewidth=1, capthick=1)
    task_handles = []

    for task in tasks:
        df = task['df']
        color = task['color']
        class_labels = task['class_labels']
        trial_type_names = task['trial_type_names']

        for class_name in df['class'].unique():
            control_metrics = calculate_class_metrics(df, class_name, task['control_condition'], trial_type_names)
            lwise_metrics = calculate_class_metrics(df, class_name, task['lwise_condition'], trial_type_names)

            # The lookup also enforces that every class has a display label (KeyError otherwise)
            display_label = class_labels[control_metrics['class']]

            # Precision as triangles, then recall as squares
            for key, marker in (('precision', '^'), ('recall', 's')):
                if plot_confidence_intervals:
                    ax.errorbar(control_metrics[key], lwise_metrics[key],
                                xerr=_asymmetric_err(control_metrics, key),
                                yerr=_asymmetric_err(lwise_metrics, key),
                                **errorbar_style)
                ax.scatter(control_metrics[key], lwise_metrics[key],
                           marker=marker, s=marker_area, c=color, label=display_label,
                           edgecolor='black', linewidth=1)

        # One legend entry per task
        task_handles.append(_legend_marker('o', color, task['name']))

    metric_handles = [
        _legend_marker('^', 'gray', 'Precision'),
        _legend_marker('s', 'gray', 'Recall'),
    ]

    ax.set_xlabel('Control', fontsize=14, fontweight='bold')
    ax.set_ylabel('L-WISE', fontsize=14, fontweight='bold')

    # Square plot with identical limits on both axes
    ax.set_aspect('equal')
    ax.set_xlim(ylim)
    ax.set_ylim(ylim)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)
    ax.tick_params(axis='both', which='major', length=6, width=2, labelsize=12)
    ax.tick_params(axis='both', which='minor', length=3, width=1)

    task_legend = ax.legend(handles=task_handles, title='Tasks',
                            bbox_to_anchor=(1.05, 1), loc='upper left',
                            frameon=False, fontsize=10, title_fontsize=11)
    # A second ax.legend() call replaces the first, so keep the task legend as a plain artist
    ax.add_artist(task_legend)
    ax.legend(handles=metric_handles, title='Metrics',
              bbox_to_anchor=(1.05, 0.3), loc='center left',
              frameon=False, fontsize=10, title_fontsize=11)

    # No tight_layout(): it only manages subplot-grid axes, so for this manually placed axes
    # it would just emit a warning. bbox_inches='tight' trims the saved figure instead.
    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, format='pdf', bbox_inches='tight')

    return fig, ax

In [None]:
# Display labels keyed by short class name (the suffix after the last "_" of the
# 'class' column). Moth species names are italicized with matplotlib mathtext.
idaea4_class_label_map = {
  species: f"$\\it{{{species}}}$"
  for species in ("seriata", "tacturata", "biselata", "aversata")
}

ham4_class_label_map = dict(
  nv="Benign mole",
  mel="Melanoma",
  bcc="Basal cell carcinoma",
  bkl="Benign keratosis",
)

mhist_class_label_map = dict(
  ssa="Sessile serrated adenoma (malignant)",
  hp="Hyperplastic polyp (benign)",
)

# One plotting spec per task; condition index 0 is the control condition everywhere.
tasks = [
    dict(df=task_df, name=task_name, color=task_color, class_labels=label_map,
         trial_type_names=type_names, control_condition=0, lwise_condition=lwise_idx)
    for task_df, task_name, task_color, label_map, type_names, lwise_idx in (
        (df_idaea4_trials, 'Moth photographs', 'orange', idaea4_class_label_map,
         trial_type_names_idaea4, trial_type_names_idaea4.index("enhancement_taper_curriculum_sampling")),
        (df_ham4_trials, 'Dermoscopy images', 'green', ham4_class_label_map,
         trial_type_names_ham4, trial_type_names_ham4.index("enhancement_taper_curriculum_sampling")),
        # NOTE(review): the MHIST L-WISE condition index is hard-coded to 1 (assumed to be the
        # correct index) -- confirm against trial_type_names_mhist.
        (df_mhist_trials, 'Histology images', 'blue', mhist_class_label_map,
         trial_type_names_mhist, 1),
    )
]

fig, ax = create_precision_recall_comparison_scatter(tasks, ylim=[0.2, 0.8], plot_confidence_intervals=True)

### Appendix Fig. S12L-M: HAM4 dermoscopy pilot study (initial enhancement epsilon was too high)

In [None]:
control_cond = "natural"
non_natural_cond = "enhanced"
trial_type_names_ham4_pilot = [control_cond, non_natural_cond]

if not (os.path.isfile("psych_data/df_ham4_pilot.csv") and DEIDENTIFIED_DATA):
  # No cached de-identified CSV to reuse (or raw-data mode): rebuild from the .h5 exports
  print("Reading ham4 pilot dataset from .h5 file")

  data_paths = [ # ham4_learn_0
    "./results/combined_dataset_37YYO3NWHDBLNHVEBAN4G8XJE1YCCC.h5",
    "./results/combined_dataset_31MCUE39BK7ARTF0K38MDWE46HH3GL.h5",
  ]

  # Tag the experiment and count trials with a missing perf (NaN) as incorrect
  df_ham4_pilot = get_df_from_xarray(data_paths, drop_columns=drop_columns).assign(
    experiment_id="ham4-pilot_learn_0",
    perf=lambda d: d['perf'].fillna(0),
  )

  if DEIDENTIFIED_DATA:
    print("Saving de-identified version of the dataset")
    df_ham4_pilot.to_csv("psych_data/df_ham4_pilot.csv", index=False)
else:
  print("Reading ham4 pilot dataset from saved .csv")
  df_ham4_pilot = pd.read_csv("psych_data/df_ham4_pilot.csv")



In [None]:
## FILTER OUT GUESSING PARTICIPANTS IN HAM4 (pilot)

# Trials from the two experimental conditions only (natural / enhanced)
df_ham4_pilot_trials = df_ham4_pilot.loc[df_ham4_pilot["trial_type"].isin(trial_type_names_ham4_pilot)]

# Calibration trials show simple shapes; mean accuracy on them per participant
df_ham4_pilot_calib = df_ham4_pilot.loc[df_ham4_pilot["stimulus_name"].isin(["circle", "triangle"])]
calib_means = df_ham4_pilot_calib.groupby("participant")["perf"].mean()

# Participants whose calibration accuracy is at least 0.9 are kept
participants_calib_above09 = calib_means.index[(calib_means >= 0.9).to_numpy()]

print(f"All participants: {df_ham4_pilot['participant'].nunique()}")

df_ham4_pilot = df_ham4_pilot.loc[df_ham4_pilot['participant'].isin(participants_calib_above09)]
print(f"Participants with calib acc of 0.9 and above: {df_ham4_pilot['participant'].nunique()}")

df_ham4_pilot_trials = df_ham4_pilot_trials.loc[df_ham4_pilot_trials['participant'].isin(participants_calib_above09)]

# Test phase = blocks 6 and 7
df_ham4_pilot_trials_test = df_ham4_pilot_trials.loc[df_ham4_pilot_trials["block"].isin([6, 7])]

In [None]:
# Condition indices in reporting order, with matching display labels
condition_idx_ordering = list(range(2))
condition_labels = ["Control", "Enhance"]

print_main_stats(
    df_ham4_pilot_trials,
    condition_idx_ordering,
    condition_labels,
    chance_level=0.25,  # 4 classes
    test_blocks=[6, 7],
)

In [None]:
# Chi-square comparisons between the conditions (Control vs Enhance), restricted to the
# test-phase trials (blocks 6 and 7) of participants who passed calibration
chi_square_comparisons(df_ham4_pilot_trials_test, condition_idx_ordering, condition_labels)

In [None]:
## PLOT HAM4 PILOT LEARNING CURVES

control_condition = 0
create_comparison_plot(
    df_ham4_pilot_trials,
    dirmap_ham4,
    control_condition,
    trial_type_names_ham4_pilot.index("enhanced"),  # index of the enhanced condition
    trial_type_names_ham4_pilot,
    chance_level=0.25,
    ylim=[0.15, 0.85],
    max_enhance_eps=20,
)

## ImageNet 16-way animal classification experiments

In [None]:
def load_imagenet16_dataset(data_path, RUN_TESTS=True, drop_columns=None):
  """
  Load one ImageNet16 experiment export into a DataFrame with one row per trial.

  Parameters:
  data_path (str): Path to the .h5 file (opened with xarray).
  RUN_TESTS (bool): If True, check that stimulus_name equals choice_name exactly on the
    trials where i_choice == i_correct_choice; raises AssertionError at the first
    violating trial.
  drop_columns (list or None): Columns to drop before returning (missing ones are ignored).

  Returns:
  pandas.DataFrame sorted by (participant, obs), with the core trial columns first and an
  added "split" column taken from the first path segment after the S3 host in the stimulus URL.

  Side effect:
  Writes the DataFrame (before drop_columns is applied, including its index) to data_path
  with ".h5" replaced by ".csv".
  """
  # The context manager closes the underlying file handle. to_dataframe() materializes the
  # data, so nothing is read lazily after the file is closed.
  with xr.open_dataset(data_path) as ds:
    raw_df = ds.to_dataframe().reset_index()

  # Keep only the row of the chosen option. NOTE(review): unlike get_df_from_xarray, this also
  # drops trials with no response (i_choice is NaN) -- confirm that is intended.
  df = raw_df[raw_df['choice_slot'] == raw_df['i_choice']].copy()

  # Sort dataframe such that trials for each participant appear in order
  df = df.sort_values(by=['participant', 'obs'])

  # Sort the columns in a logical order
  ordered_cols = ['participant', 'condition_idx', 'block', 'obs', 'trial_type', 'class', 'stimulus_image_url', 'stimulus_name', 'choice_name', 'i_correct_choice', 'i_choice', 'perf', 'reaction_time_msec', 'rel_timestamp_response', 'timestamp_start', 'monitor_width_px', 'monitor_height_px', 'stimulus_width_px', 'choice_width_px', 'stimulus_duration_msec', 'post_stimulus_delay_duration_msec', 'pre_choice_lockout_delay_duration_msec']
  other_cols = [col for col in df.columns if col not in ordered_cols]
  df = df[ordered_cols + other_cols]

  # Recover info about whether each image was from train, val, etc. (first path segment after the bucket host)
  df["split"] = df['stimulus_image_url'].map(lambda url: url.split(".s3.amazonaws.com/")[1].split("/")[0])

  if RUN_TESTS:
    # Sanity check that will almost certainly fail if stimulus_name and choice_name are calculated
    # incorrectly: names must match exactly on correct trials and differ on incorrect ones.
    # Object-array == compares element by element with Python equality, like the per-row check.
    is_correct = (df["i_choice"] == df["i_correct_choice"]).to_numpy()
    names_match = df["stimulus_name"].to_numpy(dtype=object) == df["choice_name"].to_numpy(dtype=object)
    violations = np.flatnonzero(is_correct != names_match)
    if violations.size:
      # Explicit raise instead of `assert`, which is stripped under `python -O`
      bad_row = df.iloc[violations[0]]
      raise AssertionError("stim=" + str(bad_row["stimulus_name"]) + ", choice=" + str(bad_row["choice_name"]))

  df.to_csv(data_path.replace(".h5", ".csv"))

  if drop_columns is not None:
    df = df.drop(drop_columns, axis=1, errors='ignore')

  return df

In [None]:
# Load the three ImageNet16 (16-way animal) experiments: main, backbone comparison, and loss
# ablation. In de-identified mode the cached CSVs are reused when all three exist; otherwise
# (or in raw-data mode) everything is rebuilt from the raw .h5 exports.
if DEIDENTIFIED_DATA and os.path.isfile("psych_data/df_main_i16.csv") and os.path.isfile("psych_data/df_backbone_compare_i16.csv") and os.path.isfile("psych_data/df_loss_ablation_i16.csv"):
    print("Reading imagenet datasets from saved .csv")
    # The cached CSVs are written at the end of the else-branch, i.e. after the participant-id
    # offsets, trial-type remapping and trial-type removal, so no further processing is needed.
    df_main_i16 = pd.read_csv("psych_data/df_main_i16.csv")
    df_backbone_compare_i16 = pd.read_csv("psych_data/df_backbone_compare_i16.csv")
    df_loss_ablation_i16 = pd.read_csv("psych_data/df_loss_ablation_i16.csv")
else: # Load from raw .h5
    print("Reading imagenet datasets from .h5 files")

    df_main_i16 = load_imagenet16_dataset("./results/imagenet16_v1_mod_2/imagenet16_v1_mod_2_combined_dataset.h5", drop_columns=drop_columns)
    df_main_i16["experiment_id"] = "imagenet16_main"

    # Shift participant ids past the previous experiment's maximum so ids stay unique when the
    # experiments are pooled later (assumes ids are non-negative -- TODO confirm).
    df_backbone_compare_i16 = load_imagenet16_dataset("./results/imagenet16_v1_mod_4/imagenet16_v1_mod_4_combined_dataset.h5", drop_columns=drop_columns)
    df_backbone_compare_i16['participant'] += int(df_main_i16["participant"].max()+1)
    df_backbone_compare_i16["experiment_id"] = "imagenet16_backbone_compare"

    # Same offset scheme, stacked on top of the already-shifted backbone-comparison ids
    df_loss_ablation_i16 = load_imagenet16_dataset("./results/imagenet16_v1_mod_1/imagenet16_v1_mod_1_combined_dataset.h5", drop_columns=drop_columns)
    df_loss_ablation_i16['participant'] += int(df_backbone_compare_i16["participant"].max()+1)
    df_loss_ablation_i16["experiment_id"] = "imagenet16_loss_ablation"

    # Trial type remapping for backbone comparison experiment (forgot to change trial type names in config)
    trial_type_remap = {
        # Erroneous -> Correct mappings
        "natural": "natural",  # already correct
        "enhanced_logit_5": "enhanced_vanilla_resnet50",
        "enhanced_logit_10": "enhanced_eps1_resnet50",
        "enhanced_logit_15": "enhanced_eps3_resnet50",
        "enhanced_logit_20": "enhanced_eps10_resnet50",
        "attacked_logit_10": "enhanced_cutmix_resnet50",
        "enhanced_auto_lr": "enhanced_xcit_augmented",
        "enhanced_clahe_2": "enhanced_vit_harmonized_augmented",
        "enhanced_msrcr": "enhanced_eps3_resnet50_augmented",
        
        # Correct -> Correct mappings (in case they appear)
        "enhanced_vanilla_resnet50": "enhanced_vanilla_resnet50",
        "enhanced_eps1_resnet50": "enhanced_eps1_resnet50",
        "enhanced_eps3_resnet50": "enhanced_eps3_resnet50",
        "enhanced_eps10_resnet50": "enhanced_eps10_resnet50",
        "enhanced_cutmix_resnet50": "enhanced_cutmix_resnet50",
        "enhanced_xcit_augmented": "enhanced_xcit_augmented",
        "enhanced_vit_harmonized_augmented": "enhanced_vit_harmonized_augmented",
        "enhanced_eps3_resnet50_augmented": "enhanced_eps3_resnet50_augmented"
    }

    # Map the trial types using the dictionary, keeping original values if not in mapping
    df_backbone_compare_i16['trial_type'] = df_backbone_compare_i16['trial_type'].map(lambda x: trial_type_remap.get(x, x))

    # Remove irrelevant/unused trial types
    # NOTE: the two backbone-comparison filters run AFTER the remap, so they drop the trials that
    # were labelled "enhanced_msrcr" and "enhanced_clahe_2" in that experiment's raw data.
    df_main_i16 = df_main_i16[df_main_i16["trial_type"] != "attacked_logit_10"]
    df_backbone_compare_i16 = df_backbone_compare_i16[df_backbone_compare_i16["trial_type"] != "enhanced_eps3_resnet50_augmented"]
    df_backbone_compare_i16 = df_backbone_compare_i16[df_backbone_compare_i16["trial_type"] != "enhanced_vit_harmonized_augmented"]

    # Cache the processed frames so later de-identified runs can skip the .h5 files
    if DEIDENTIFIED_DATA:
        print("Saving de-identified version of the dataset")
        df_main_i16.to_csv("psych_data/df_main_i16.csv", index=False)
        df_backbone_compare_i16.to_csv("psych_data/df_backbone_compare_i16.csv", index=False)
        df_loss_ablation_i16.to_csv("psych_data/df_loss_ablation_i16.csv", index=False)

In [None]:
from lwise_psych_modules.generate_trials_helpers import replace_bucket_name_in_url

# Filter out participants with calibration accuracy below 0.9

def filter_calib_below_09(df, df_name):
  """
  Keep only participants whose mean accuracy on calibration trials is at least 0.9,
  and drop the calibration trials themselves.

  Participants without any calibration trial are dropped as well. Prints how many
  participants of `df_name` remain. Returns a filtered DataFrame; `df` is not modified.
  """
  n_participants_before = df['participant'].nunique()

  is_calibration = df["trial_type"] == "calibration"
  calib_accuracy = df.loc[is_calibration].groupby("participant")["perf"].mean()
  passing_participants = calib_accuracy.index[(calib_accuracy >= 0.9).to_numpy()]

  kept = df.loc[df['participant'].isin(passing_participants) & ~is_calibration]
  print(f"In {df_name}, {kept['participant'].nunique()} out of {n_participants_before} participants had calibration accuracy of 0.9 or above")
  return kept

# Keep participants with calibration accuracy >= 0.9 (calibration trials are removed as well)
df_main_i16 = filter_calib_below_09(df_main_i16, "df_main_i16")
df_backbone_compare_i16 = filter_calib_below_09(df_backbone_compare_i16, "df_backbone_compare_i16")
df_loss_ablation_i16 = filter_calib_below_09(df_loss_ablation_i16, "df_loss_ablation_i16")

# Remove screening and warmup trials (every block <= 1)
df_main_i16 = df_main_i16.loc[lambda d: d["block"] > 1]
df_backbone_compare_i16 = df_backbone_compare_i16.loc[lambda d: d["block"] > 1]
df_loss_ablation_i16 = df_loss_ablation_i16.loc[lambda d: d["block"] > 1]

# Retrieve ground truth logit values from ImageNet_eps3.pt model

def find_value_by_url(url, df, col_name):
  """
  Return df[col_name] from the single row whose 'url' equals `url`.

  Returns None when no row matches; raises ValueError when more than one row matches.
  """
  matches = df.loc[df['url'] == url]
  n_matches = len(matches)
  if n_matches == 0:
    return None
  if n_matches > 1:
    raise ValueError("More than one row found for the given URL.")
  return matches.iloc[0][col_name]

def get_gt_logit_of_image_version(trial_df, logit_df):
  """
  Add a "robust_gt_logit" column to trial_df with the ground-truth logit of each shown image version.

  The S3 bucket in each stimulus URL selects which logit_df column to read: the bucket name is
  normalized into a prefix, giving "<prefix>_robust_gt_logit", or "robust_gt_logit" when the
  prefix is empty. The image row is found by matching logit_df['url'] against the stimulus URL
  with its query string removed and its bucket rewritten to "morgan-imagenet16".

  Same contract as a per-trial find_value_by_url lookup:
  - a URL with no row in logit_df yields None;
  - a URL appearing in several rows raises ValueError;
  - a missing logit column raises KeyError.
  trial_df is modified in place and returned.
  """
  # Index logit_df by url once, instead of scanning the whole table for every trial
  # (the per-trial scan was O(n_trials * n_rows)). Duplicated urls are tracked separately
  # so that looking one up still raises, as find_value_by_url does.
  duplicated_urls = set(logit_df.loc[logit_df['url'].duplicated(keep=False), 'url'])
  unique_logits = logit_df.drop_duplicates(subset='url', keep=False).set_index('url')

  def logit_for_stimulus(stimulus_url):
    bucket_name = stimulus_url.split("://")[1].split(".s3.amazonaws.com")[0]
    col_prefix = (bucket_name.replace("dot", ".").replace("-", "_").replace("morgan_", "")
                  .replace("imagenet16_", "").replace("imagenet16", ""))
    col_name = col_prefix + "_robust_gt_logit" if col_prefix else "robust_gt_logit"

    lookup_url = replace_bucket_name_in_url(stimulus_url.split("?")[0], "morgan-imagenet16")
    if lookup_url in duplicated_urls:
      raise ValueError("More than one row found for the given URL.")
    if lookup_url not in unique_logits.index:
      return None
    return unique_logits.at[lookup_url, col_name]

  trial_df["robust_gt_logit"] = [logit_for_stimulus(url) for url in trial_df["stimulus_image_url"]]
  return trial_df

logit_df = pd.read_csv("psych_data/dataset_dirmaps/imagenet_animals_dataset_dirmap.csv")

# Attach the robust ground-truth logit of each shown image version to all three experiments
df_main_i16, df_backbone_compare_i16, df_loss_ablation_i16 = [
    get_gt_logit_of_image_version(trial_df, logit_df)
    for trial_df in (df_main_i16, df_backbone_compare_i16, df_loss_ablation_i16)
]

### Fig. 1B1 (effect of image enhancement on human accuracy in ImageNet animal task)

In [None]:
NORMALIZE = False  # True: plot accuracy relative to each participant's natural-trial accuracy

# Normalization baseline: each participant's mean 'perf' on unmodified ('natural') trials
natural_means = (
    df_main_i16.loc[df_main_i16['trial_type'] == 'natural']
    .groupby('participant')['perf']
    .mean()
)

# Normalize one trial's accuracy by its participant's natural-trial mean accuracy
def normalize_performance(row):
    """
    Return row['perf'] divided by natural_means[row['participant']].

    Reads the module-level `natural_means` Series; a participant without natural
    trials raises KeyError.
    """
    trial_perf = row['perf']
    participant_baseline = natural_means[row['participant']]
    return trial_perf / participant_baseline

# Per-trial accuracy relative to the participant's natural-trial baseline (used when NORMALIZE)
df_main_i16['normalized_perf'] = df_main_i16.apply(normalize_performance, axis=1)
perf_metric = 'normalized_perf' if NORMALIZE else 'perf'

# Mean of `perf_metric` per participant and trial type, for the trial types in the figure
accuracy_by_group = (
    df_main_i16[df_main_i16['trial_type'].isin([
        'natural',
        'enhanced_logit_5', 'enhanced_logit_10', 'enhanced_logit_15', 'enhanced_logit_20',
        'enhanced_clahe_2', 'enhanced_msrcr', 'enhanced_auto_lr',
    ])]
    .groupby(['participant', 'trial_type'])[perf_metric]
    .mean()
    .reset_index()
)

# Percentile bootstrap confidence interval for the mean
def bootstrap_ci(data, num_bootstrap_samples=10000, ci=0.95):
    """
    Return [lower, upper] percentile-bootstrap bounds for the mean of `data`.

    Draws all `num_bootstrap_samples` same-size resamples (with replacement) in a single
    call to the global NumPy RNG, then takes the (1 - ci)/2 and (1 + ci)/2 percentiles
    of the resample means.
    """
    sample_size = len(data)
    resamples = np.random.choice(data, (num_bootstrap_samples, sample_size), replace=True)
    resample_means = resamples.mean(axis=1)
    lower_pct = (1 - ci) / 2 * 100
    upper_pct = (1 + ci) / 2 * 100
    return np.percentile(resample_means, [lower_pct, upper_pct])

# Calculate overall mean accuracy and bootstrap CI for each group
mean_accuracy_by_group = accuracy_by_group.groupby('trial_type')[perf_metric].mean().reset_index()
ci_accuracy_by_group = accuracy_by_group.groupby('trial_type')[perf_metric].apply(bootstrap_ci).reset_index()
ci_accuracy_by_group[['ci_lower', 'ci_upper']] = pd.DataFrame(ci_accuracy_by_group[perf_metric].tolist(), index=ci_accuracy_by_group.index)
ci_accuracy_by_group = ci_accuracy_by_group.drop(perf_metric, axis=1)

# Merge mean and CI dataframes
result = pd.merge(mean_accuracy_by_group, ci_accuracy_by_group, on='trial_type')

# x-axis order of trial types and their tick labels (ϵ values correspond to the
# enhanced_logit_<ϵ> trial types; CLAHE / MSRCR / LR are the other enhancement methods)
trial_types = ['natural', 'enhanced_logit_5', 'enhanced_logit_10', 'enhanced_logit_15', 'enhanced_logit_20', 'enhanced_clahe_2', 'enhanced_msrcr', 'enhanced_auto_lr']
labels = ['ϵ = 0', 'ϵ = 5', 'ϵ = 10', 'ϵ = 15', 'ϵ = 20', 'CLAHE', 'MSRCR', 'LR']

plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("deep")

fig, ax = plt.subplots(figsize=(7, 3), dpi=300)

# Plot data, ordered like trial_types
x = np.arange(len(trial_types))
means = result.set_index('trial_type').loc[trial_types, perf_metric]
ci_lower = result.set_index('trial_type').loc[trial_types, 'ci_lower']
ci_upper = result.set_index('trial_type').loc[trial_types, 'ci_upper']
# Asymmetric error lengths, clipped at 0: matplotlib rejects negative lengths, and a bootstrap
# bound can land a float-rounding error past the mean
yerr = np.clip(np.array([means - ci_lower, ci_upper - means]), 0, None)

ax.errorbar(x, means, yerr=yerr, fmt='o', capsize=5, capthick=2, elinewidth=2, markersize=8)

# Connect the ϵ sweep (first 5 points) with a line
ax.plot(x[:5], means.iloc[:5], '-', linewidth=2)

# Red dotted reference line at the ϵ = 0 (natural) level. `means` is indexed by trial-type
# name, so use .iloc: integer [] lookup on such a Series is deprecated (KeyError in pandas 3).
ax.axhline(y=means.iloc[0], color='red', linestyle=':', linewidth=2)

ylabel = 'Normalized Mean Accuracy' if NORMALIZE else 'Mean Accuracy'
ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=12, fontweight='bold')
ax.tick_params(axis='both', which='major', labelsize=12)

ax.set_ylim(0.675 if not NORMALIZE else 0.8, 0.875 if not NORMALIZE else 1.2)

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)

fig.tight_layout()

# Save the figure as a high-resolution PDF
fig.savefig('notebooks/fig_outputs/imagenet16_accuracy_plot.pdf', dpi=2400, format='pdf', bbox_inches='tight')

### Fig. 1A1 (relationship between robust ground truth logit and human accuracy)

In [None]:
# Pool natural trials with the CLAHE / MSRCR / LR and vanilla / CutMix ResNet-50 trial types,
# plus the validation-split natural trials from the loss-ablation experiment
df_main_i16_nat = df_main_i16.loc[
    df_main_i16["trial_type"].isin(['natural', 'enhanced_clahe_2', 'enhanced_msrcr', 'enhanced_auto_lr'])
]
df_backbone_compare_i16_nat = df_backbone_compare_i16.loc[
    df_backbone_compare_i16["trial_type"].isin(['natural', 'enhanced_vanilla_resnet50', 'enhanced_cutmix_resnet50'])
]
df_loss_ablation_i16_nat = df_loss_ablation_i16.loc[
    df_loss_ablation_i16["trial_type"].eq('natural') & df_loss_ablation_i16["split"].eq("val")
]

df_combined_i16_nat = pd.concat([df_main_i16_nat, df_backbone_compare_i16_nat, df_loss_ablation_i16_nat])
print("Number of trials for plot:", len(df_combined_i16_nat))
print("Number of participants for plot:", df_combined_i16_nat["participant"].nunique())

analyze_logit_accuracy_relationship(
    df_combined_i16_nat, None,
    test_blocks=None, trial_type_names=None,
    ylim=[0.2, 1], xlim=[-3, 28],
    control_only=False, bins=10,
    print_num_data_points=False, retrieve_logits=False,
    fig_size=(3.5, 3), all_black=True,
)

### Appendix Fig. S5 (relationship between robust ground-truth logit and human accuracy, strictly including only original, unmodified images — none from Retinex, vanilla models, etc.)

In [None]:
# Strict variant: only 'natural' trials (loss-ablation trials additionally restricted to the val split)
df_main_i16_nat = df_main_i16.loc[df_main_i16["trial_type"].eq('natural')]
df_backbone_compare_i16_nat = df_backbone_compare_i16.loc[df_backbone_compare_i16["trial_type"].eq('natural')]
df_loss_ablation_i16_nat = df_loss_ablation_i16.loc[
    df_loss_ablation_i16["trial_type"].eq('natural') & df_loss_ablation_i16["split"].eq("val")
]

# Fewer trials than the pooled Fig. 1A1 version, so the resulting curve is noisier
df_combined_i16_nat = pd.concat([df_main_i16_nat, df_backbone_compare_i16_nat, df_loss_ablation_i16_nat])
print("Number of trials for plot:", len(df_combined_i16_nat))
print("Number of participants for plot:", df_combined_i16_nat["participant"].nunique())

analyze_logit_accuracy_relationship(
    df_combined_i16_nat, None,
    test_blocks=None, trial_type_names=None,
    ylim=[0.2, 1], xlim=[-3, 28],
    control_only=False, bins=10,
    print_num_data_points=False, retrieve_logits=False,
    fig_size=(3.5, 3),
    fig_file_name="imagenet16_robust_gt_logit_accuracy_strict_natural.pdf",
    all_black=True,
)

### Appendix Fig. S6 (comparison of image difficulty prediction methods)

In [None]:
import json

# Strictly unmodified images only. .copy() makes this an independent frame, so the column
# assignments below don't write into a slice (pandas SettingWithCopyWarning).
df_combined_i16_nat_strict = df_combined_i16_nat[df_combined_i16_nat["trial_type"] == "natural"].copy()

# Vanilla-model ground-truth logit per image, looked up by canonical (morgan-imagenet16) URL
df_combined_i16_nat_strict["vanilla_gt_logit"] = df_combined_i16_nat_strict.apply(lambda row: find_value_by_url(replace_bucket_name_in_url(row["stimulus_image_url"].split("?")[0], "morgan-imagenet16"), logit_df, "vanilla_gt_logit"), axis=1)

def _load_difficulty_metric(file_stem):
  """Load one JSON mapping (image key -> score) from the image-difficulty metrics directory."""
  with open(f"psych_data/imagenet_animals_image_difficulty_metrics/{file_stem}.json", "r") as file:
    return json.load(file)

cscore_robust = _load_difficulty_metric("cscore_proxy_robust")
cscore_vanilla = _load_difficulty_metric("cscore_proxy_vanilla")
pred_depth_robust = _load_difficulty_metric("pred_depth_robust")
pred_depth_vanilla = _load_difficulty_metric("pred_depth_vanilla")
adv_eps_robust = _load_difficulty_metric("adv_eps_robust")
adv_eps_vanilla = _load_difficulty_metric("adv_eps_vanilla")

# JSON keys are the URL path after ".com/" with the query string removed; compute them once
_image_keys = df_combined_i16_nat_strict["stimulus_image_url"].map(lambda url: url.split("?")[0].split(".com/")[1])

for _column, _scores in [
  ("cscore_robust", cscore_robust),
  ("cscore_vanilla", cscore_vanilla),
  ("pred_depth_robust", pred_depth_robust),
  ("pred_depth_vanilla", pred_depth_vanilla),
  ("adv_eps_robust", adv_eps_robust),
  ("adv_eps_vanilla", adv_eps_vanilla),
]:
  # Function (not dict) mapper so an image missing from a JSON file still raises KeyError
  df_combined_i16_nat_strict[_column] = _image_keys.map(lambda key: _scores[key])

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler
import statsmodels.api as sm

def bootstrap_ci(data, num_bootstrap_samples=10000, confidence_level=0.95):
    """
    Return [lower, upper] percentile-bootstrap bounds for the mean of `data`.

    Each of the `num_bootstrap_samples` resamples is drawn separately (with replacement,
    same size as `data`) from the global NumPy RNG.
    """
    resample_means = np.empty(num_bootstrap_samples)
    for draw in range(num_bootstrap_samples):
        resample_means[draw] = np.mean(np.random.choice(data, size=len(data), replace=True))
    lower_pct = (1 - confidence_level) / 2 * 100
    upper_pct = (1 + confidence_level) / 2 * 100
    return np.percentile(resample_means, [lower_pct, upper_pct])

def perform_cv_analysis(X, y, feature_name, n_splits=500, random_state=42):
    """
    Cross-validated ROC AUC of a logistic regression predicting `y` from `X`.

    Rows of X containing NaN are dropped, and the number dropped is printed under
    `feature_name`. Standardization is fit inside each training fold via a Pipeline, so
    held-out rows never contribute to the scaling statistics (no train/test leakage).

    Parameters:
    X (ndarray): 2-D feature matrix.
    y (array-like): Binary target aligned with the rows of X.
    feature_name (str): Label used in the printed message.
    n_splits (int): Number of stratified folds (default 500, as before).
    random_state (int): Shuffle seed for the fold assignment.

    Returns:
    tuple: (mean fold AUC, CI lower, CI upper, number of rows used).
    NOTE(review): the CI bootstraps the per-fold AUCs. Folds share training data, so treat it
    as approximate.
    """
    from sklearn.pipeline import make_pipeline

    # Remove any rows with NaN values
    mask = ~np.isnan(X).any(axis=1)
    X_clean = X[mask]
    y_clean = y[mask]

    n_removed = len(X) - len(X_clean)
    if n_removed:
        print(f"{feature_name}: {n_removed} rows removed due to NaN values")

    model = make_pipeline(StandardScaler(), LogisticRegression())
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    auc_scores = cross_val_score(model, X_clean, y_clean, cv=cv, scoring='roc_auc')

    mean_auc = np.mean(auc_scores)
    ci_lower, ci_upper = bootstrap_ci(auc_scores)

    return mean_auc, ci_lower, ci_upper, len(X_clean)

def perform_statistical_analysis(X, y, feature_names):
    """
    Fit a statsmodels Logit (with intercept) of `y` on standardized `X` and tabulate the terms.

    Rows of X containing NaN are dropped before fitting.

    Returns:
    pandas.DataFrame with one row per feature:
    - Coefficient (per standard deviation of the feature), Std Error, z-value, P-value;
    - the feature's pre-standardization mean and std.
    Rows are sorted by descending absolute coefficient.
    """
    # Remove any rows with NaN values
    mask = ~np.isnan(X).any(axis=1)
    X_clean = X[mask]
    y_clean = y[mask]

    print(f"Statistical analysis on {len(X_clean)} samples with {len(feature_names)} features")

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_clean)

    fit = sm.Logit(y_clean, sm.add_constant(X_scaled)).fit(disp=0)

    # Position 0 is the intercept added by add_constant; keep only the feature terms
    stats_df = pd.DataFrame({
        'Feature': feature_names,
        'Coefficient': fit.params[1:],
        'Std Error': fit.bse[1:],
        'z-value': fit.tvalues[1:],
        'P-value': fit.pvalues[1:],
        'Original Mean': scaler.mean_,
        'Original Std': scaler.scale_,
    })

    # Sort by absolute coefficient value
    return stats_df.assign(Abs_Coef=stats_df['Coefficient'].abs()).sort_values(
        'Abs_Coef', ascending=False).drop('Abs_Coef', axis=1)

def get_feature_sets(variant):
    """
    Named feature-column groups for one model variant ('vanilla' or 'robust').

    Keys are plot labels and values are lists of DataFrame column names. The last two
    groups combine the three difficulty scores without and with the ground-truth logit.
    """
    cscore_col = f'cscore_{variant}'
    depth_col = f'pred_depth_{variant}'
    adv_col = f'adv_eps_{variant}'
    logit_col = f'{variant}_gt_logit'

    return {
        'C-Score': [cscore_col],
        'Pred. Depth': [depth_col],
        'Adv. Robustness': [adv_col],
        'L$_{gt}$': [logit_col],
        'Combined w/o L$_{gt}$': [cscore_col, depth_col, adv_col],
        'All features': [cscore_col, depth_col, adv_col, logit_col],
    }

# Perform analyses: cross-validated AUC for every feature set of both variants (all vanilla
# sets first, then all robust ones), plus a Logit coefficient table for each multi-feature
# set. A failed statistical fit is reported and skipped.
results = []
statistical_results = {}

for variant in ('vanilla', 'robust'):
    print(f"{variant.upper()} analysis")
    feature_sets = get_feature_sets(variant)

    for feature_name, feature_list in feature_sets.items():
        X = df_combined_i16_nat_strict[feature_list].values
        y = df_combined_i16_nat_strict['perf']

        print(f"\nAnalyzing {feature_name}")
        print(f"Initial shape: {X.shape}")

        mean_auc, ci_lower, ci_upper, n_samples = perform_cv_analysis(X, y, feature_name)
        results.append((feature_name, mean_auc, ci_lower, ci_upper))

        # Single-feature sets only get the cross-validated AUC
        if len(feature_list) < 2:
            continue

        try:
            stat_results = perform_statistical_analysis(X, y, feature_list)
        except Exception as e:
            print(f"Error in statistical analysis for {feature_name}:")
            print(str(e))
        else:
            statistical_results[feature_name] = stat_results

# Create visualization: AUC bars for the vanilla feature sets, a gap, then the robust ones
plt.style.use('seaborn-v0_8-white')
fig, ax = plt.subplots(figsize=(12, 6))

# Hatching marks the robust bars
hatch_pattern = '//'

# `results` holds every vanilla feature set followed by every robust one, in get_feature_sets()
# order (the analysis loop appends vanilla first). Split by the number of feature sets rather
# than a hard-coded count, so adding or removing a set cannot silently mix the two groups.
n_bars = len(get_feature_sets('vanilla'))
vanilla_results = results[:n_bars]
robust_results = results[n_bars:]

bar_width = 0.8
vanilla_positions = np.arange(n_bars)
robust_positions = np.arange(n_bars) + n_bars + 1  # one empty slot between the groups

# Feature-set names double as tick labels for both groups
feature_names = [r[0] for r in vanilla_results]

# Error bars are a (2, n) array of [lower, upper] lengths, clipped at 0 because matplotlib rejects
# negative lengths (a bootstrap bound can land a rounding error past the mean)
mean_aucs = [r[1] for r in vanilla_results]
errors = np.clip(np.array([[r[1] - r[2], r[3] - r[1]] for r in vanilla_results]).T, 0, None)
ax.bar(vanilla_positions, mean_aucs, bar_width, yerr=errors, capsize=5,
       color='lightskyblue', edgecolor='black', linewidth=2,
       error_kw=dict(lw=2, capthick=2), label='Vanilla')

mean_aucs = [r[1] for r in robust_results]
errors = np.clip(np.array([[r[1] - r[2], r[3] - r[1]] for r in robust_results]).T, 0, None)
ax.bar(robust_positions, mean_aucs, bar_width, yerr=errors, capsize=5,
       color='lightcoral', edgecolor='black', linewidth=2,
       error_kw=dict(lw=2, capthick=2), label='Robust',
       hatch=hatch_pattern)

# Vertical dotted line between the groups
midpoint = (vanilla_positions[-1] + robust_positions[0]) / 2
ax.axvline(x=midpoint, color='black', linestyle=':', linewidth=2)

ax.set_ylabel('AUC (predict human correct or not)', fontsize=14, fontweight='bold')
ax.set_xlabel('Feature Sets for Logistic Regression', fontsize=14, fontweight='bold')
ax.tick_params(axis='both', which='major', labelsize=12)
ax.set_ylim(0.5, 0.75)

# Rotated, right-aligned labels; the small offset makes each label end under its bar
all_positions = np.concatenate([vanilla_positions, robust_positions])
offset = 0.2
ax.set_xticks(all_positions + offset)
ax.set_xticklabels(feature_names * 2, rotation=45, ha='right')

ax.legend(loc='upper left', frameon=True, framealpha=1.0, fontsize=12)

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(2)
ax.spines['bottom'].set_linewidth(2)

fig.tight_layout()
fig.savefig("notebooks/fig_outputs/difficulty_prediction_systematic.pdf", dpi=300, format="pdf", bbox_inches="tight")

# Print the cross-validated AUCs, then the coefficient tables of the multi-feature fits
print("\nCross-validation Results:\n" + "-" * 50)
for feature_name, mean_auc, ci_lower, ci_upper in results:
    print(f"\n{feature_name}:\n"
          f"  Mean AUC: {mean_auc:.3f}\n"
          f"  95% CI: [{ci_lower:.3f}, {ci_upper:.3f}]")

print("\nStatistical Analysis Results:\n" + "-" * 50)
for feature_name, stat_results in statistical_results.items():
    print(f"\n{feature_name} Statistical Analysis:\n"
          "(Coefficients are standardized - features scaled to mean=0, std=1)\n"
          + stat_results.to_string(float_format='{:.4f}'.format))

### Appendix Fig. S7 (predicting difficulty using ground truth logits from different models)

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, StratifiedKFold

def analyze_model_logits(nat_df, models, output_path='notebooks/fig_outputs/logit_prediction_model_comparison.pdf'):
    """
    Analyze and compare model logits using cross-validated logistic regression.

    For each model, every trial's stimulus URL is matched to a row of the
    model's logit CSV. That row's ground-truth logit is the single feature of
    a logistic regression predicting the trial's 'perf' (human correct or not).
    Each model is scored by the mean ROC AUC over 500 stratified CV folds, with
    a bootstrap 95% CI computed over the per-fold AUCs.

    Parameters:
    -----------
    nat_df : pandas.DataFrame
        DataFrame containing human response data; the 'stimulus_image_url'
        and 'perf' columns are used.
    models : list of dict
        Each dict should contain:
            - 'model_name': str, name of the model
            - 'logit_df_path': str, path to CSV containing logit values
              (a 'url' column plus 'robust_gt_logit' and/or 'vanilla_gt_logit')
    output_path : str
        Path where to save the resulting plot

    Returns:
    --------
    tuple: (DataFrame with numerical results, matplotlib figure, matplotlib axis)

    Raises:
    -------
    ValueError
        If a stimulus URL matches more than one row of a logit CSV, if a
        matched row has neither logit column filled, or if no stimulus matched
        any row of a model's logit CSV.
    """
    from collections import Counter

    # Imported here because this cell's imports (LogisticRegression,
    # cross_val_score, StratifiedKFold) do not include it.
    from sklearn.preprocessing import StandardScaler

    def bootstrap_ci(data, num_bootstrap_samples=10000, confidence_level=0.95):
        # Percentile CI of the mean of `data`, resampling with replacement.
        bootstrap_means = np.array([np.mean(np.random.choice(data, size=len(data), replace=True))
                                    for _ in range(num_bootstrap_samples)])
        return np.percentile(bootstrap_means,
                             [(1 - confidence_level) / 2 * 100,
                              (1 + confidence_level) / 2 * 100])

    def logit_from_row(row, model_name):
        # Robust logits take precedence over vanilla ones when both are present.
        if 'robust_gt_logit' in row and not pd.isna(row['robust_gt_logit']):
            return row['robust_gt_logit'], 'robust'
        if 'vanilla_gt_logit' in row and not pd.isna(row['vanilla_gt_logit']):
            return row['vanilla_gt_logit'], 'vanilla'
        raise ValueError(f"No valid logit found for {model_name}")

    # The normalized stimulus URLs are the same for every model, so compute them once.
    stimulus_urls = [
        replace_bucket_name_in_url(url.split("?")[0], 'morgan-imagenet16')
        for url in nat_df["stimulus_image_url"]
    ]

    results = []

    # Analyze each model
    for model_info in models:
        model_name = model_info['model_name']
        print(f"\nAnalyzing {model_name}...")

        # Load logit dataframe
        logit_df = pd.read_csv(model_info['logit_df_path'])

        # url -> positional indices of every row with that url. A single pass
        # over logit_df replaces a full-table scan for each trial.
        rows_by_url = logit_df.groupby('url', sort=False).indices
        resolved = {}  # url -> (logit value, logit type), memoized per model

        logit_values = []
        logit_types = []
        for url in stimulus_urls:
            if url not in resolved:
                positions = rows_by_url.get(url)
                if positions is None:
                    resolved[url] = (None, None)  # stimulus absent from this CSV
                elif len(positions) > 1:
                    raise ValueError(f"Multiple rows found for URL in {model_name}")
                else:
                    resolved[url] = logit_from_row(logit_df.iloc[positions[0]], model_name)
            value, logit_type = resolved[url]
            logit_values.append(value)
            logit_types.append(logit_type)

        # Majority logit type for this model. Ties go to the type encountered
        # first, so the choice is reproducible (set iteration order is not).
        type_counts = Counter(t for t in logit_types if t is not None)
        if not type_counts:
            raise ValueError(f"No stimulus URLs found in logit file for {model_name}")
        logit_type_used = type_counts.most_common(1)[0][0]
        print(f"{model_name}: Using {logit_type_used} logits")

        # Remove trials whose stimulus has no logit (None becomes NaN as float)
        logits = pd.Series(logit_values, dtype=float).to_numpy()
        mask = ~np.isnan(logits)
        X = logits[mask].reshape(-1, 1)
        y = nat_df['perf'].to_numpy()[mask]

        # Print class distribution (minlength=2 so a missing class prints as 0)
        class_counts = np.bincount(y.astype(int), minlength=2)
        print(f"Class distribution - {class_counts[0]} incorrect, {class_counts[1]} correct")

        # Standardize features. With a single feature this is a positive affine
        # rescaling, so it cannot change the ranking that the AUC measures.
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        # Perform cross-validation (fixed seed keeps the fold assignment reproducible)
        cv = StratifiedKFold(n_splits=500, shuffle=True, random_state=42)
        log_reg = LogisticRegression()
        auc_scores = cross_val_score(log_reg, X_scaled, y, cv=cv, scoring='roc_auc')

        # Calculate statistics over the per-fold AUCs
        mean_auc = np.mean(auc_scores)
        ci_lower, ci_upper = bootstrap_ci(auc_scores)

        results.append({
            'model': model_name,
            'mean_auc': mean_auc,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'n_samples': len(X),
            'logit_type': logit_type_used
        })

    # Create results DataFrame
    results_df = pd.DataFrame(results)

    # Create visualization
    plt.style.use('seaborn-v0_8-white')
    plt.rcParams['ytick.left'] = True

    fig, ax = plt.subplots(figsize=(8, 6))

    # Calculate bar positions and width
    x = np.arange(len(results))
    width = 0.8

    # Create bars with different styles based on logit type
    vanilla_bar = None
    robust_bar = None

    for i, row in results_df.iterrows():
        color = 'lightcoral' if row['logit_type'] == 'robust' else 'lightskyblue'
        hatch = '//' if row['logit_type'] == 'robust' else ''

        bar = ax.bar(i, row['mean_auc'], width,
                     yerr=[[row['mean_auc'] - row['ci_lower']], [row['ci_upper'] - row['mean_auc']]],
                     capsize=5,
                     color=color,
                     edgecolor='black',
                     linewidth=2,
                     error_kw=dict(lw=2, capthick=2),
                     hatch=hatch)

        # Keep the first bar of each type as its legend handle
        if row['logit_type'] == 'vanilla' and vanilla_bar is None:
            vanilla_bar = bar
        elif row['logit_type'] == 'robust' and robust_bar is None:
            robust_bar = bar

    # Create legend handles
    legend_elements = []
    if vanilla_bar is not None:
        legend_elements.append(vanilla_bar)
    if robust_bar is not None:
        legend_elements.append(robust_bar)

    # Add legend only if both types are present (the labels assume vanilla, then robust)
    if len(legend_elements) > 1:
        ax.legend(legend_elements, ['Non-adversarially trained models', 'Adversarially trained models'],
                  loc='upper left', frameon=False, framealpha=1.0, fontsize=14)

    # Customize plot
    ax.set_ylabel('AUC (predict human correct or not)',
                  fontsize=16, fontweight='bold')
    ax.set_xlabel('Models', fontsize=16, fontweight='bold')

    # Set tick parameters
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.tick_params(axis='y', which='major', length=6, width=2)

    ax.set_ylim([0.5, 0.75])

    # Set x-tick positions and labels
    plt.xticks(x, results_df['model'], rotation=45, ha='right')

    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, format="pdf", bbox_inches="tight")

    # Print numerical results
    print("\nNumerical Results:")
    print("-" * 50)
    for _, row in results_df.iterrows():
        print(f"\n{row['model']} ({row['logit_type']} logits):")
        print(f"  Mean AUC: {row['mean_auc']:.3f}")
        print(f"  95% CI: [{row['ci_lower']:.3f}, {row['ci_upper']:.3f}]")
        print(f"  N samples: {row['n_samples']}")

    return results_df, fig, ax

In [None]:
# (display name, per-model logit CSV) for each model compared in Fig. S7
_model_logit_csvs = [
    ("Vanilla RN50", "psych_data/dataset_dirmaps/imagenet_animals_different_models/dataset_dirmap_logits_resnet50.csv"),
    ("CutMix RN50", "psych_data/dataset_dirmaps/imagenet_animals_different_models/dataset_dirmap_logits_resnet50_cutmix.csv"),
    ("ϵ = 1 RN50", "psych_data/dataset_dirmaps/imagenet_animals_different_models/dataset_dirmap_logits_resnet50_eps1.csv"),
    ("ϵ = 3 RN50", "psych_data/dataset_dirmaps/imagenet_animals_different_models/dataset_dirmap_logits_resnet50_eps3.csv"),
    ("ϵ = 10 RN50", "psych_data/dataset_dirmaps/imagenet_animals_different_models/dataset_dirmap_logits_resnet50_eps10.csv"),
    ("ϵ = 4 XCiT", "psych_data/dataset_dirmaps/imagenet_animals_different_models/dataset_dirmap_logits_xcit_large.csv"),
]
models = [{"model_name": name, "logit_df_path": path} for name, path in _model_logit_csvs]

analyze_model_logits(df_combined_i16_nat_strict, models, output_path='notebooks/fig_outputs/logit_prediction_model_comparison.pdf')

### Appendix Fig. S8: analysis of image difficulty predictions from epoch-wise robust model checkpoints

In [None]:
def analyze_epoch_logits(nat_df, dirmap_path, accuracy_csv_path, output_path='notebooks/fig_outputs/epoch_logit_prediction.pdf'):
    """
    Analyze how model logits from different training epochs predict human performance.

    For every ``epoch<N>_gt_logit`` column of the dirmap CSV, each trial's
    stimulus is matched by URL (the first matching row wins). That epoch's
    logit is the single feature of a logistic regression predicting the
    trial's 'perf'. Each epoch is scored by the mean ROC AUC over 500
    stratified CV folds, with a bootstrap 95% CI over the fold AUCs. The
    result is plotted alongside the model's per-epoch training and
    validation accuracy.

    Parameters:
    -----------
    nat_df : pandas.DataFrame
        DataFrame containing human response data; the 'stimulus_image_url'
        and 'perf' columns are used.
    dirmap_path : str
        Path to CSV containing a 'url' column and epoch logit values (epoch0_gt_logit, epoch1_gt_logit, etc.)
    accuracy_csv_path : str
        Path to CSV containing model training and validation accuracy per epoch
        ('epoch', 'train_acc', 'val_acc' columns)
    output_path : str
        Path where to save the resulting plot

    Returns:
    --------
    tuple: (DataFrame with numerical results, matplotlib figure)

    Raises:
    -------
    ValueError
        If the dirmap CSV has no epoch*_gt_logit columns, or the accuracy CSV
        lacks any of 'epoch', 'train_acc', 'val_acc'.
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score, StratifiedKFold
    from sklearn.preprocessing import StandardScaler

    # Bootstrap CI calculation function (same as in original)
    def bootstrap_ci(data, num_bootstrap_samples=10000, confidence_level=0.95):
        bootstrap_means = np.array([np.mean(np.random.choice(data, size=len(data), replace=True))
                                    for _ in range(num_bootstrap_samples)])
        return np.percentile(bootstrap_means,
                             [(1 - confidence_level) / 2 * 100,
                              (1 + confidence_level) / 2 * 100])

    # Load dataframes
    logit_df = pd.read_csv(dirmap_path)
    accuracy_df = pd.read_csv(accuracy_csv_path)

    # Ensure accuracy dataframe is sorted by epoch
    accuracy_df = accuracy_df.sort_values('epoch').reset_index(drop=True)

    # Identify epoch columns in the dirmap file
    epoch_columns = [col for col in logit_df.columns if col.startswith('epoch') and col.endswith('_gt_logit')]

    # Check if we found any epoch columns
    if not epoch_columns:
        raise ValueError("No epoch_*_gt_logit columns found in the dirmap CSV file.")

    # Check if accuracy_df has the required columns
    required_columns = ['epoch', 'train_acc', 'val_acc']
    missing_columns = [col for col in required_columns if col not in accuracy_df.columns]
    if missing_columns:
        raise ValueError(f"Accuracy CSV missing required columns: {missing_columns}")

    # Sort the epoch columns by epoch number
    epoch_columns.sort(key=lambda x: int(x.replace('epoch', '').replace('_gt_logit', '')))

    # Row of logit_df for each trial's stimulus (first match per URL, -1 if none).
    # This does not depend on the epoch, so it is computed once instead of
    # rescanning logit_df for every trial of every epoch.
    first_row_by_url = {}
    for position, url in enumerate(logit_df['url']):
        first_row_by_url.setdefault(url, position)
    stimulus_rows = np.array(
        [first_row_by_url.get(replace_bucket_name_in_url(url.split("?")[0], 'morgan-imagenet16'), -1)
         for url in nat_df["stimulus_image_url"]],
        dtype=int,
    )
    has_match = stimulus_rows >= 0
    perf = nat_df['perf'].to_numpy()

    # Store results for each epoch
    results = []

    # Process each epoch's logits
    for epoch_col in epoch_columns:
        epoch_num = int(epoch_col.replace('epoch', '').replace('_gt_logit', ''))
        print(f"\nAnalyzing epoch {epoch_num}...")

        # This epoch's logit for every trial (NaN where the stimulus is not in logit_df)
        epoch_logits = logit_df[epoch_col].to_numpy(dtype=float)
        logits = np.full(len(nat_df), np.nan)
        logits[has_match] = epoch_logits[stimulus_rows[has_match]]

        # Remove rows with missing values
        mask = ~np.isnan(logits)
        X = logits[mask].reshape(-1, 1)
        y = perf[mask]

        # Print class distribution (minlength=2 so a missing class prints as 0)
        class_counts = np.bincount(y.astype(int), minlength=2)
        print(f"Class distribution - {class_counts[0]} incorrect, {class_counts[1]} correct")

        # Standardize features
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        # Perform cross-validation (fixed seed keeps the fold assignment reproducible)
        cv = StratifiedKFold(n_splits=500, shuffle=True, random_state=42)
        log_reg = LogisticRegression(max_iter=1000)
        auc_scores = cross_val_score(log_reg, X_scaled, y, cv=cv, scoring='roc_auc')

        # Calculate statistics over the per-fold AUCs
        mean_auc = np.mean(auc_scores)
        ci_lower, ci_upper = bootstrap_ci(auc_scores)

        # Store results
        results.append({
            'epoch': epoch_num,
            'mean_auc': mean_auc,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'n_samples': len(X)
        })

    # Create results DataFrame
    results_df = pd.DataFrame(results)

    # Create visualization with two y-axes
    plt.style.use('seaborn-v0_8-white')

    # Enable tick marks on all axes (note: these rcParams changes persist for later figures)
    plt.rcParams['ytick.left'] = True
    plt.rcParams['ytick.right'] = True
    plt.rcParams['xtick.bottom'] = True

    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Primary y-axis: AUC for human performance prediction
    ax1.set_xlabel('Training epoch', fontsize=18, fontweight='bold')
    ax1.set_ylabel('AUC (predict human error rate)', fontsize=18, fontweight='bold')

    # Plot AUC values with confidence intervals
    line1, = ax1.plot(results_df['epoch'], results_df['mean_auc'], 'o-', color='navy',
                      linewidth=2, markersize=5, label='Human error rate prediction (AUC)')
    ax1.fill_between(results_df['epoch'],
                     results_df['ci_lower'],
                     results_df['ci_upper'],
                     alpha=0.2, color='navy')

    # Set y-axis range for AUC
    ax1.set_ylim([0.5, 0.75])

    # Secondary y-axis: Model training/validation accuracy
    ax2 = ax1.twinx()
    ax2.set_ylabel('ImageNet classification accuracy', fontsize=18, fontweight='bold')

    # Plot training and validation accuracy with updated labels
    line2, = ax2.plot(accuracy_df['epoch'], accuracy_df['train_acc'], 's-', color='forestgreen',
                      linewidth=2, markersize=5, label='ImageNet training accuracy')
    line3, = ax2.plot(accuracy_df['epoch'], accuracy_df['val_acc'], '^-', color='crimson',
                      linewidth=2, markersize=6, label='ImageNet validation accuracy')

    # Set y-axis range for model accuracy
    ax2.set_ylim([0, 1.0])

    # Combine legends from both axes
    ax1.legend([line1, line2, line3],
               ['Human error rate prediction (AUC)',
                'ImageNet training accuracy',
                'ImageNet validation accuracy'],
               loc='lower right', frameon=True, fontsize=16)  # Larger font and bottom right position

    # Increase tick label font sizes
    ax1.tick_params(axis='both', which='major', labelsize=16)
    ax2.tick_params(axis='both', which='major', labelsize=16)
    ax1.tick_params(axis='both', which='both', length=6, width=1.5)
    ax2.tick_params(axis='both', which='both', length=6, width=1.5)

    # Style customization
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(True)
    ax1.spines['left'].set_linewidth(2)
    ax1.spines['bottom'].set_linewidth(2)
    ax2.spines['right'].set_linewidth(2)

    # Add grid for easier reading
    ax1.grid(True, axis='y', linestyle='--', alpha=0.3)

    # Adjust x-axis to show integer ticks at appropriate intervals
    max_epoch = max(results_df['epoch'])
    if max_epoch <= 20:
        tick_spacing = 1
    elif max_epoch <= 50:
        tick_spacing = 5
    else:
        tick_spacing = 10

    ax1.set_xticks(range(0, max_epoch + 1, tick_spacing))

    plt.rcParams['xtick.direction'] = 'out'
    plt.rcParams['ytick.direction'] = 'out'
    plt.rcParams['xtick.major.size'] = 6
    plt.rcParams['ytick.major.size'] = 6

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")

    # Print numerical results
    print("\nNumerical Results:")
    print("-" * 50)
    for _, row in results_df.iterrows():
        print(f"\nEpoch {row['epoch']}:")
        print(f"  Mean AUC: {row['mean_auc']:.3f}")
        print(f"  95% CI: [{row['ci_lower']:.3f}, {row['ci_upper']:.3f}]")
        print(f"  N samples: {row['n_samples']}")

    return results_df, fig

In [None]:
# Fig. S8: epoch-wise robust-model logits vs. human accuracy (arguments in signature order)
analyze_epoch_logits(
    df_combined_i16_nat_strict,
    "psych_data/imagenet_animals_model_epochs_analysis/imagenet_animals_90epochs_logits_dirmap.csv",
    "psych_data/imagenet_animals_model_epochs_analysis/epoch_wise_acc.csv",
    output_path='notebooks/fig_outputs/epoch_logit_prediction.pdf',
)

### Appendix Fig. S9 (comparing image enhancement with different guide models)

In [None]:
from matplotlib.ticker import AutoMinorLocator

def create_model_comparison_plot(df, normalize=False, output_path=None):
    """
    Plot human accuracy for each guide model as bars with bootstrap 95% CIs.

    Accuracy is first averaged within each (participant, trial_type) pair.
    Each bar is the mean of those participant-level values, and its error bar
    is a bootstrap percentile interval over participants. A dotted red line
    marks the mean for 'natural' (unmodified image) trials.

    Parameters:
    -----------
    df : pandas.DataFrame
        DataFrame containing columns: 'participant', 'trial_type', 'perf'
    normalize : bool, default=False
        If True, divide each trial's 'perf' by that participant's mean 'perf'
        on 'natural' trials before aggregating.
    output_path : str, optional
        Path to save the figure as PDF. If falsy, the figure is shown instead.

    Returns:
    --------
    tuple: (matplotlib figure, matplotlib axis)
    """
    # Tick label for each trial type; the dict order is also the bar order.
    guide_labels = {
        'natural': 'Original',
        'enhanced_vanilla_resnet50': 'Vanilla RN50',
        'enhanced_cutmix_resnet50': 'CutMix RN50',
        'enhanced_eps1_resnet50': 'ϵ = 1 RN50',
        'enhanced_eps3_resnet50': 'ϵ = 3 RN50',
        'enhanced_eps10_resnet50': 'ϵ = 10 RN50',
        'enhanced_xcit_augmented': 'ϵ = 4 XCiT'
    }
    trial_types = list(guide_labels)

    if not normalize:
        perf_metric = 'perf'
    else:
        per_participant_natural = df[df['trial_type'] == 'natural'].groupby('participant')['perf'].mean()
        df = df.copy()  # leave the caller's DataFrame untouched

        def _ratio_to_natural(row):
            return row['perf'] / per_participant_natural[row['participant']]

        df['normalized_perf'] = df.apply(_ratio_to_natural, axis=1)
        perf_metric = 'normalized_perf'

    # Participant-level mean for each of the selected trial types
    selected = df[df['trial_type'].isin(trial_types)]
    participant_means = selected.groupby(['participant', 'trial_type'])[perf_metric].mean()

    def _bootstrap_interval(values, n_resamples=10000, level=0.95):
        # A single participant (or none) gives no usable interval.
        if len(values) < 2:
            return np.nan, np.nan
        resampled = np.random.choice(values, (n_resamples, len(values)), replace=True)
        return np.percentile(resampled.mean(axis=1), [(1-level)/2*100, (1+level)/2*100])

    rows = []
    for trial_type in trial_types:
        per_participant = participant_means.xs(trial_type, level='trial_type')
        low, high = _bootstrap_interval(per_participant)
        rows.append({
            'trial_type': trial_type,
            'mean': per_participant.mean(),
            'ci_lower': low,
            'ci_upper': high
        })
    summary = pd.DataFrame(rows)

    # Global styling (pyplot/seaborn state) applied before the figure is created
    plt.style.use('seaborn-v0_8-white')
    sns.set_palette("deep")
    plt.rcParams['ytick.left'] = True

    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)

    positions = np.arange(len(trial_types))
    centers = summary['mean']
    # Asymmetric error-bar extents: [mean - lower, upper - mean]
    error_extent = np.array([centers - summary['ci_lower'],
                             summary['ci_upper'] - centers])

    ax.bar(positions, centers, 0.6,
           yerr=error_extent,
           capsize=10,
           alpha=0.8,
           edgecolor='black',
           linewidth=2,
           error_kw={'elinewidth': 2, 'capthick': 2})

    if normalize:
        ylabel = 'Normalized rate humans choose\nground truth [%]'
    else:
        ylabel = 'Rate humans choose\nground truth [%]'
    ax.set_ylabel(ylabel, fontsize=16, fontweight='bold')
    ax.set_xlabel("Guide Models", fontsize=16, fontweight='bold')

    ax.set_xticks(positions)
    ax.set_xticklabels([guide_labels[t] for t in trial_types], fontsize=14, rotation=45, ha='right')

    ax.set_ylim(*((0.8, 1.2) if normalize else (0.7, 0.95)))
    ax.yaxis.set_minor_locator(AutoMinorLocator())

    # Reference line at the mean of the unmodified-image condition
    baseline = centers[summary['trial_type'] == 'natural'].iloc[0]
    ax.axhline(y=baseline, color='red', linestyle=':', linewidth=2,
               label='Mean acc. on unmodified images')

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('left', 'bottom'):
        ax.spines[side].set_linewidth(1.5)
    ax.tick_params(axis='y', which='major', length=6, width=2, labelsize=14)
    ax.tick_params(axis='y', which='both', right=False)

    ax.legend(fontsize=14, frameon=False, edgecolor='black', loc='upper left')
    plt.tight_layout()

    if not output_path:
        plt.show()
    else:
        fig.canvas.draw()  # render once before writing the PDF
        plt.savefig(output_path, dpi=300, format='pdf', bbox_inches='tight')

    return fig, ax

# Fig. S9: raw (un-normalized) accuracy for each guide model
create_model_comparison_plot(
    df_backbone_compare_i16,
    output_path='notebooks/fig_outputs/guide_model_enhancement_comparison.pdf',
    normalize=False,
)

### Supplementary Fig. S11 (ablation study on image enhancement with ImageNet animal images)

In [None]:
plt.style.use('seaborn-v0_8-white')
sns.set_palette("deep")

NORMALIZE = False
df_s8 = df_loss_ablation_i16.copy()

# Mean 'perf' of each participant on 'natural' trials (the normalization baseline)
natural_means = df_s8[df_s8['trial_type'] == 'natural'].groupby('participant')['perf'].mean()

# Divide a trial's 'perf' by its participant's mean natural-trial 'perf'
def normalize_performance(row):
    return row['perf'] / natural_means[row['participant']]

# Apply the normalization to all rows
df_s8['normalized_perf'] = df_s8.apply(normalize_performance, axis=1)
perf_metric = 'normalized_perf' if NORMALIZE else 'perf'

# Step 1: Filter out 'calibration' and 'new_stimulus' trial types
filtered_df_s8 = df_s8[~df_s8['trial_type'].isin(['calibration', 'new_stimulus'])]

# Trial types in reverse order of first appearance in the data; this order sets
# bar position, legend entry and hatch pattern.
# NOTE(review): depends on row order — confirm it yields Natural / Logit / Cross-entropy.
trial_types_s8 = np.flip(filtered_df_s8['trial_type'].unique())

# Step 2: Group by trial_type and split, and calculate mean accuracy and bootstrap CI for each group
def calculate_bootstrap_for_group(group):
    mean_val = group[perf_metric].mean()
    ci_lower, ci_upper = bootstrap_ci(group[perf_metric])
    return pd.Series({'mean': mean_val, 'ci_lower': ci_lower, 'ci_upper': ci_upper})

# Selecting the metric column first keeps the grouping columns out of .apply
# (pandas deprecates passing them in); the per-group results are unchanged.
grouped_results = (
    filtered_df_s8.groupby(['trial_type', 'split'])[[perf_metric]]
    .apply(calculate_bootstrap_for_group)
    .reset_index()
)

# Order in which the splits are drawn and labelled along the x axis
splits = ["val", "train"]

# groupby/pivot order the 'split' index alphabetically ('train' before 'val').
# Reindex to `splits` so each bar sits above its own x-tick label; without this
# the validation and training bars were drawn under each other's labels.
pivot_result = (
    grouped_results
    .pivot(index='split', columns='trial_type', values=['mean', 'ci_lower', 'ci_upper'])
    .reindex(splits)
)

# Increase the figure size and DPI for better quality
fig, ax = plt.subplots(figsize=(12, 5), dpi=300)

trial_type_mapping = {
    'natural': 'Natural',
    'enhanced_logit': 'Enhanced (Logit)',
    'enhanced_cross_entropy': 'Enhanced (Cross-entropy)',
}
trial_types_disp = [trial_type_mapping.get(t, t) for t in trial_types_s8]

split_label_mapping = {
    'val': 'ImageNet Validation Set',
    'train': 'ImageNet Training Set',
}

# Define hatch patterns for texture coding (cycled if there are more trial types)
hatch_patterns = ['', '///', '...']

# Plot bars for each trial type, grouped by split and centred on each tick
# (offset is (i - 1) * width for three trial types)
x = np.arange(len(splits))
width = 0.25  # Adjust bar width
n_types = len(trial_types_s8)
for i, trial_type in enumerate(trial_types_s8):
    means = pivot_result['mean'][trial_type].values
    ci_lower = pivot_result['ci_lower'][trial_type].values
    ci_upper = pivot_result['ci_upper'][trial_type].values
    yerr = np.array([means - ci_lower, ci_upper - means])
    ax.bar(x + (i - (n_types - 1) / 2) * width, means, width, yerr=yerr, capsize=5,
           label=trial_types_disp[i], alpha=0.8, edgecolor='black', linewidth=2,
           error_kw={'elinewidth': 2, 'capthick': 2}, hatch=hatch_patterns[i % len(hatch_patterns)])

# Customize the plot
if NORMALIZE:
    ylabel = 'Normalized Mean Accuracy'
else:
    ylabel = 'Mean Accuracy'
ax.set_ylabel(ylabel, fontsize=20, fontweight='bold')
ax.grid(axis='y', linestyle='--', alpha=0.7, color='gray')

# Set x-ticks and labels
ax.set_xticks(x)
ax.set_xticklabels([split_label_mapping[split] for split in splits], fontsize=20, fontweight='bold')

# Increase tick label size
ax.tick_params(axis='both', which='major', labelsize=20)

# Add a horizontal line at y=1 to represent the baseline (natural trial type)
if NORMALIZE:
    ax.axhline(y=1, color='red', linestyle='--', linewidth=2, label='Natural Trial Baseline')

# Legend is built after every labelled artist (including the baseline) exists
ax.legend(fontsize=16, frameon=True, edgecolor='black', loc='upper left')

# Set y-axis limits
if NORMALIZE:
    ax.set_ylim(0.8, 1.2)
else:
    ax.set_ylim(0.675, 0.875)

# Add subtle spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)

# Adjust layout and save the figure as a high-resolution PDF
plt.tight_layout()
plt.savefig('notebooks/fig_outputs/accuracy_by_split_plot.pdf', dpi=300, format='pdf', bbox_inches='tight')

In [None]:
enhance_type_for_interaction_plot = 'enhanced_logit_20'

# Select only 'natural' trials and trials of the chosen enhancement type/level.
# .copy() makes this an independent DataFrame, so assigning the column below
# does not trigger pandas' SettingWithCopyWarning.
df_selected_s8b = df_main_i16[df_main_i16['trial_type'].isin(['natural', enhance_type_for_interaction_plot])].copy()

# Set robust_gt_logit to the value for the ORIGINAL image for this analysis (not for the enhanced image)
# NOTE(review): `logit_df` here is the notebook-level variable from an earlier cell,
# not the locals of the analyze_* functions — confirm it is the intended dirmap.
df_selected_s8b["robust_gt_logit"] = df_selected_s8b.apply(lambda row: find_value_by_url(replace_bucket_name_in_url(row["stimulus_image_url"].split("?")[0], "morgan-imagenet16"), logit_df, "robust_gt_logit"), axis=1)


def calc_perf_diff_with_ci(natural_data, enhanced_data):
    """Return (enhanced - natural) mean 'perf' and its bootstrap 95% CI.

    Each group's trials are resampled independently with replacement
    (10,000 resamples); the CI is the 2.5th-97.5th percentile range of the
    resampled differences.
    """
    natural_mean = np.mean(natural_data['perf'])
    enhanced_mean = np.mean(enhanced_data['perf'])
    diff = enhanced_mean - natural_mean

    # Bootstrap for confidence interval
    diff_samples = []
    for _ in range(10000):
        natural_sample = natural_data['perf'].sample(n=len(natural_data), replace=True)
        enhanced_sample = enhanced_data['perf'].sample(n=len(enhanced_data), replace=True)
        diff_samples.append(np.mean(enhanced_sample) - np.mean(natural_sample))

    ci_lower, ci_upper = np.percentile(diff_samples, [2.5, 97.5])
    return diff, ci_lower, ci_upper


# Quartile boundaries of robust_gt_logit over the selected (natural + enhanced) trials
quartiles = df_selected_s8b['robust_gt_logit'].quantile([0.25, 0.5, 0.75])

print("quartiles:")
print(quartiles)

# Initialize lists to store results
quartile_labels = ['Q1', 'Q2', 'Q3', 'Q4']
diff_means = []
diff_ci_lower = []
diff_ci_upper = []

# Calculate differences for each quartile: Q1 is <= q25, Q2/Q3 are half-open
# (lower, upper] intervals, Q4 is > q75. Rows with NaN logits fall in no quartile.
for i in range(4):
    if i == 0:
        mask = df_selected_s8b['robust_gt_logit'] <= quartiles.iloc[0]
    elif i == 3:
        mask = df_selected_s8b['robust_gt_logit'] > quartiles.iloc[2]
    else:
        mask = (df_selected_s8b['robust_gt_logit'] > quartiles.iloc[i-1]) & (df_selected_s8b['robust_gt_logit'] <= quartiles.iloc[i])

    natural_data = df_selected_s8b[(df_selected_s8b['trial_type'] == 'natural') & mask]
    enhanced_data = df_selected_s8b[(df_selected_s8b['trial_type'] == enhance_type_for_interaction_plot) & mask]

    diff, ci_lower, ci_upper = calc_perf_diff_with_ci(natural_data, enhanced_data)
    diff_means.append(diff)
    diff_ci_lower.append(ci_lower)
    diff_ci_upper.append(ci_upper)

# Plotting
fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
orange_color = sns.color_palette('deep')[1]  # Get the orange color

# Create the bar plot with edge color
bars = ax.bar(quartile_labels, diff_means, color=orange_color, width=0.9,
              edgecolor='black', linewidth=2, hatch='///')

# Add asymmetric error bars [diff - lower, upper - diff] with increased linewidth
ax.errorbar(quartile_labels, diff_means,
            yerr=[np.array(diff_means) - np.array(diff_ci_lower),
                  np.array(diff_ci_upper) - np.array(diff_means)],
            fmt='none', color='black', capsize=5, capthick=2, elinewidth=2)

# Customize the plot
ax.set_xlabel('Quartiles of Groundtruth Logit', fontsize=20, fontweight='bold')
ax.set_ylabel('Δ Accuracy ϵ = 0 → ϵ = 20', fontsize=20, fontweight='bold')
ax.set_ylim([0, 0.30])
ax.tick_params(axis='both', which='major', labelsize=16)
ax.grid(axis='y', linestyle='--', alpha=0.7, color='gray')

# Add subtle spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)

plt.tight_layout()
plt.savefig("notebooks/fig_outputs/quartile_difference_plot.pdf", format="pdf", dpi=300, bbox_inches="tight")
plt.show()

# Print out the numerical results
for i, label in enumerate(quartile_labels):
    print(f"{label}: Difference = {diff_means[i]:.3f}, 95% CI: [{diff_ci_lower[i]:.3f}, {diff_ci_upper[i]:.3f}]")

## Dropout probability analysis (Appendix Section S9)

### Dropout analysis for idaea4 experiment

Total dropouts: 9

Dropouts in control condition: 6

Probability of being assigned to control: 2/7 ~= 0.2857

In [None]:
# idaea4 dropouts: 9 in total, 6 of them in the control condition; p(assigned to control) = 2/7 ≈ 0.2857.
# NOTE(review): the statistic actually computed lives in scripts/dropout_statistics.py (not visible here).
!python scripts/dropout_statistics.py --total 9 --condition 6 --n_conditions 6 --p_condition 0.2857 -v

### Dropout analysis for ham4 experiment

Total dropouts: 13

Dropouts in control condition: 6

(Note: 4 additional dropouts had been assigned to the shuffled enhancement taper condition)

Probability of being assigned to control: (227/317)(1/6)  +  (90/317)(1/3) = 0.214

In [None]:
# ham4 dropouts: 13 in total, 6 of them in the control condition;
# p(assigned to control) = (227/317)(1/6) + (90/317)(1/3) ≈ 0.214.
# NOTE(review): the statistic actually computed lives in scripts/dropout_statistics.py (not visible here).
!python scripts/dropout_statistics.py --total 13 --condition 6 --n_conditions 6 --p_condition 0.214 -v

## Demographics table generation (Appendix Section S11)

In [None]:
# Helper functions for demographics calculations

def clean_age(age, participant_id=None):
    """
    Clean age values, handling invalid entries.

    Parameters:
    age: Age value that could be string or numeric
    participant_id: Identifier for the participant (for warning messages).
        Warnings are only emitted when this is provided.

    Returns:
    float or None: The age as a float if it parses and lies in [0, 120],
    otherwise None (missing, unparsable, or out of range).
    """
    # Imported locally: `warnings` is not imported in this cell or in the
    # notebook's opening import cell, so the warning branches would otherwise
    # raise NameError.
    import warnings

    def _warn(message):
        # Warnings are opt-in: only emitted when the caller identifies the participant.
        if participant_id is not None:
            warnings.warn(message)

    if pd.isna(age):
        _warn(f"Missing age value for participant {participant_id}")
        return None

    try:
        # str() first so numeric and string inputs take the same parsing path
        age_float = float(str(age).strip())
    except (ValueError, TypeError):
        _warn(f"Invalid age value '{age}' for participant {participant_id}")
        return None

    # Basic validation: age should be between 0 and 120
    if 0 <= age_float <= 120:
        return age_float
    _warn(f"Age value {age_float} outside reasonable range (0-120) for participant {participant_id}")
    return None


def _category_breakdown(values, n_total):
    """
    Summarize a categorical Series as {category: "count (pct%)"}.

    Categories follow ``value_counts`` order (descending count); percentages are
    relative to ``n_total`` and rounded to one decimal place.
    """
    counts = values.value_counts()
    percentages = (counts / n_total * 100).round(1)
    return {
        category: f"{count} ({percentage}%)"
        for category, count, percentage in zip(counts.index, counts.values, percentages.values)
    }


def demographic_analysis(trial_data, demog_data, trial_key='worker_id', demog_key='Participant id'):
    """
    Analyze demographic characteristics of study participants, with diagnostics for missing data.

    Diagnostics (counts of participants with/without demographic rows) are printed to stdout.

    Parameters:
    trial_data (pd.DataFrame): Trial-level data; only the ``trial_key`` column is used, to
        determine the set of unique participants.
    demog_data (pd.DataFrame): Demographic rows with at least ``demog_key``, 'Age', 'Sex' and
        'Ethnicity simplified' columns. It should hold at most one row per participant;
        duplicates would be double-counted.
    trial_key (str): Participant-ID column in ``trial_data``.
    demog_key (str): Participant-ID column in ``demog_data``.

    Returns:
    tuple[str, dict]: A human-readable summary, and the results dictionary with keys
    'n_participants', 'n_with_demographics', 'age', 'sex' and 'ethnicity'.
    """
    unique_workers = trial_data[trial_key].unique()
    worker_ids = set(unique_workers)
    demog_ids = set(demog_data[demog_key])

    # Diagnostic information
    print("\nDiagnostic Information:")
    print(f"Total unique worker_ids in trial_data: {len(unique_workers)}")
    print(f"Total rows in demographic data: {len(demog_data)}")

    missing_demographics = worker_ids - demog_ids
    print(f"Number of workers missing from demographic data: {len(missing_demographics)}")
    if missing_demographics:
        print("\nSample of missing worker_ids (up to 5):")
        # Sorted so the sample is reproducible (set order of strings varies between runs)
        for worker_id in sorted(missing_demographics, key=str)[:5]:
            print(f"  {worker_id}")

    # Demographic rows that don't correspond to any trial participant
    extra_demographics = demog_ids - worker_ids
    if extra_demographics:
        print(f"\nNumber of demographic entries not in trial_data: {len(extra_demographics)}")

    study_demog = demog_data[demog_data[demog_key].isin(unique_workers)].copy()
    n_with_demographics = len(study_demog)

    # Clean age data. Built as a float Series (None -> NaN) rather than via
    # DataFrame.apply(axis=1), which returns a DataFrame when there are no rows.
    study_demog['Age_clean'] = pd.Series(
        [clean_age(age, pid) for age, pid in zip(study_demog['Age'], study_demog[demog_key])],
        index=study_demog.index,
        dtype=float,
    )

    # Clean Sex and Ethnicity data. Blank (NaN) entries are folded into 'Not specified'
    # too; value_counts() would otherwise drop them and percentages would not sum to 100%.
    not_specified = {
        'Prefer not to say': 'Not specified',
        'DATA_EXPIRED': 'Not specified'
    }
    for column in ('Sex', 'Ethnicity simplified'):
        study_demog[column] = study_demog[column].replace(not_specified).fillna('Not specified')

    # Age analysis (using cleaned age data)
    valid_ages = study_demog['Age_clean'].dropna()
    if len(valid_ages) > 0:
        age_summary = {
            'n': len(valid_ages),
            'mean': np.round(valid_ages.mean(), 1),
            'std': np.round(valid_ages.std(), 1),
            'range': f"{int(valid_ages.min())}-{int(valid_ages.max())}",
            'missing': n_with_demographics - len(valid_ages)
        }
    else:
        age_summary = {
            'n': 0,
            'mean': None,
            'std': None,
            'range': None,
            'missing': n_with_demographics
        }

    results = {
        'n_participants': len(unique_workers),
        'n_with_demographics': n_with_demographics,
        'age': age_summary,
        'sex': _category_breakdown(study_demog['Sex'], n_with_demographics),
        'ethnicity': _category_breakdown(study_demog['Ethnicity simplified'], n_with_demographics)
    }

    # Generate formatted output (guard against an empty trial_data when computing coverage)
    n_participants = results['n_participants']
    coverage = round(n_with_demographics / n_participants * 100, 1) if n_participants else 0.0
    output = (
        f"Demographic Characteristics (Total N = {n_participants})\n"
        f"Available Demographic Data for {n_with_demographics} participants ({coverage}% of sample)\n"
        "\n"
        f"Age (n = {age_summary['n']})"
    )

    if age_summary['mean'] is not None:
        output += (
            f"\n  Mean (SD): {age_summary['mean']} ({age_summary['std']}) years"
            f"\n  Range: {age_summary['range']} years"
        )

    if age_summary['missing'] > 0:
        output += f"\n  Missing: {age_summary['missing']}"

    sex_lines = "\n".join(f"  {category}: {count}" for category, count in results['sex'].items())
    ethnicity_lines = "\n".join(f"  {category}: {count}" for category, count in results['ethnicity'].items())
    output += f"\n\nSex\n{sex_lines}\n\nEthnicity\n{ethnicity_lines}"

    return output, results


def demog_results_to_latex(results):
    """
    Convert demographic analysis results to a LaTeX table.

    Parameters:
    results (dict): Results dictionary from demographic_analysis function

    Returns:
    str: LaTeX source for a two-column ``table`` environment, one table row per line
    """
    # Characters with special meaning in LaTeX text mode and their escaped forms.
    # Mapping per character (instead of chained str.replace) ensures the backslashes
    # introduced by one escape are never re-escaped by another.
    latex_escapes = {
        '\\': r'\textbackslash{}',
        '&': r'\&',
        '%': r'\%',
        '$': r'\$',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
        '~': r'\textasciitilde{}',
        '^': r'\textasciicircum{}',
    }

    def clean_value(value):
        """Escape LaTeX special characters in str(value)."""
        return ''.join(latex_escapes.get(char, char) for char in str(value))

    # Calculate percentage of sample with demographics (guarding against an empty sample)
    n_with_demos = results['n_with_demographics']
    total_n = results['n_participants']
    demo_percent = round(n_with_demos / total_n * 100, 1) if total_n else 0.0

    # Total participants and demographic coverage
    rows = [
        f"Total participants & {total_n} \\\\",
        f"Pts. w/ demographic data & {n_with_demos} ({demo_percent}\\%) \\\\",
    ]

    # Age section
    age = results['age']
    if age['mean'] is not None:
        rows.append("Age & \\\\")
        rows.append(f"\\quad Mean (SD) & {age['mean']} ({age['std']}) years \\\\")
        rows.append(f"\\quad Range & {age['range']} years \\\\")
        if age['missing'] > 0:
            # missing > 0 implies n_with_demos > 0, so this division is safe
            missing_percent = round(age['missing'] / n_with_demos * 100, 1)
            rows.append(f"\\quad Missing & {age['missing']} ({missing_percent}\\%) \\\\")

    # Sex and ethnicity sections: a heading row followed by one indented row per category
    for heading, breakdown in (("Sex", results['sex']), ("Ethnicity", results['ethnicity'])):
        rows.append(f"{heading} & \\\\")
        rows.extend(
            f"\\quad {clean_value(category)} & {clean_value(value)} \\\\"
            for category, value in breakdown.items()
        )

    header = r"""
\begin{table}[htbp]
\centering
\begin{tabular}{ll}"""

    footer = r"""
\end{tabular}
\caption{Demographic Characteristics of Study Participants}
\label{tab:demographics}
\end{table}
"""

    return header + "".join("\n" + row for row in rows) + footer

In [None]:
# Demographics for all of the experiments combined
# You can generate an experiment-specific table with (for example) "demog_output, demog_results = demographic_analysis(df_idaea4, demog_idaea4)"

def _approved_only(demog_frame, require_unique=False):
    """Keep only rows whose Status is APPROVED; optionally assert participant IDs are unique."""
    approved = demog_frame[demog_frame["Status"] == "APPROVED"]
    if require_unique:
        assert approved["Participant id"].is_unique, "There are duplicate Participant id (worker_id) values in the dataframe"
    return approved


if not DEIDENTIFIED_DATA:
    # ImageNet16 experiments: one demographics export each
    demog_imagenet16 = _approved_only(pd.read_csv("results/demographics_imagenet16_v1_mod_2.csv"))
    demog_imagenet16_loss_ablation = _approved_only(pd.read_csv("results/demographics_imagenet16_v1_mod_1.csv"))
    demog_imagenet16_guide_comparison = _approved_only(pd.read_csv("results/demographics_imagenet16_v1_mod_4.csv"))

    # idaea4 and HAM4 span two exports each; these (and MHIST) are checked for duplicate IDs
    demog_idaea4_learn_1 = pd.read_csv("results/demographics_idaea4_learn_1.csv")
    demog_idaea4_learn_2 = pd.read_csv("results/demographics_idaea4_learn_2.csv")
    demog_idaea4 = _approved_only(pd.concat([demog_idaea4_learn_1, demog_idaea4_learn_2]), require_unique=True)

    demog_ham4_learn_4 = pd.read_csv("results/demographics_ham4_learn_4.csv")
    demog_ham4_learn_5 = pd.read_csv("results/demographics_ham4_learn_5.csv")
    demog_ham4 = _approved_only(pd.concat([demog_ham4_learn_4, demog_ham4_learn_5]), require_unique=True)

    demog_mhist = _approved_only(pd.read_csv("results/demographics_mhist_learn_1.csv"), require_unique=True)

    # Pool every experiment; a participant who took part in several keeps their first row
    demog_all = pd.concat([
        demog_imagenet16,
        demog_imagenet16_loss_ablation,
        demog_imagenet16_guide_comparison,
        demog_idaea4,
        demog_ham4,
        demog_mhist,
    ])
    demog_unique = demog_all.drop_duplicates(subset=['Participant id'], keep='first')

    worker_ids_all = pd.concat([
        frame[["worker_id"]]
        for frame in (df_main_i16, df_loss_ablation_i16, df_backbone_compare_i16, df_idaea4, df_ham4, df_mhist)
    ])

    demog_output, demog_results = demographic_analysis(worker_ids_all, demog_unique)

    print(demog_output)
    print("----------")
    print(demog_results_to_latex(demog_results))
else:
    print("Demographic tables cannot be re-generated from the de-identified dataset. Please see the demographics table in the paper.")