In [1]:
import pandas as pd
import glob
import os
import json
from sharedcontrolpaper.force_sensitive_stopping_task_utils import get_subject_label, rename_index_column
from sharedcontrolpaper.simple_stop_utils import compute_SSRT, preprocess_stop_data, analyze_violations

In [2]:
# Locate the raw CSVs relative to the notebook's working directory:
#   <parent of cwd>/data/experiment/final/*/simple_stop/*.csv
# The single '*' directory level is presumably one folder per subject -- verify
# against the data layout.
parent_directory = os.path.dirname(os.getcwd())
data_path = os.path.join(parent_directory, 'data', 'experiment')
exp_stage = "final/"
task = "simple_stop"

pattern = os.path.join(data_path, exp_stage, '*', task, '*.csv')

# glob returns matches in a filesystem-dependent order; sort them so that the
# processing order (and therefore any order-dependent output such as pooled
# lists) is reproducible across machines and runs.
data_files = sorted(glob.glob(pattern))

In [3]:
def process_all_data(data_files):
    """Compute per-subject stop-task metrics and a pooled violation summary.

    Each CSV in ``data_files`` is read, run through ``preprocess_stop_data``
    and ``analyze_violations``, and summarised into a dict of metrics stored
    under the label returned by ``get_subject_label``. Violation rows are
    pooled across files -- keeping, per file, only SSD values that occur at
    least twice in that file's violations table -- and aggregated by SSD.

    Args:
        data_files: Iterable of paths to per-subject CSV files. May be empty.

    Returns:
        tuple: ``(task_metrics, final_aggregated_results)``.
            ``task_metrics`` maps subject label -> dict of metrics.
            ``final_aggregated_results`` is a DataFrame with one row per SSD
            and the columns ``ssd``, ``count``, ``avg_difference`` and
            ``all_differences``; it is empty (but has these columns) when no
            violation rows were collected.
    """
    aggregate_columns = ['ssd', 'count', 'avg_difference', 'all_differences']
    # Collect per-file frames and concatenate once after the loop instead of
    # re-copying a growing DataFrame on every iteration.
    violation_frames = []
    task_metrics = {}

    for file in data_files:
        subject_label = get_subject_label(file)
        df = pd.read_csv(file)

        # Block 2 is excluded for subject s019 before any preprocessing.
        if subject_label == 's019':
            df = df[df["Block"] != 2]

        df = preprocess_stop_data(df)
        violations_df = analyze_violations(df)

        # Keep only SSD values that appear at least twice in this file's
        # violations table.
        ssd_counts = violations_df['ssd'].value_counts()
        repeated_ssds = ssd_counts[ssd_counts >= 2].index
        violations_df = violations_df[violations_df['ssd'].isin(repeated_ssds)]
        if not violations_df.empty:
            violation_frames.append(violations_df)

        go_data = df[df['trialType'] == 'go']
        stop_data = df[df['trialType'] == 'stop']

        stop_ssd = stop_data['ssd']
        mean_ssd = stop_ssd.mean()
        # Computed once and reused for both SSRT entries that take the same
        # arguments (assumes compute_SSRT is deterministic for a given df).
        ssrt = compute_SSRT(df)

        metrics = {
            'go_rt': go_data['rt'].mean(),
            'stop_fail_rt': stop_data['rt'].mean(),
            'go_acc': go_data['go_acc'].mean(),
            'stop_fail_acc': stop_data['stop_failure_acc'].mean(),
            'stop_success': stop_data['stop_acc'].mean(),
            'stop_fail_rate': 1 - stop_data['stop_acc'].mean(),
            'avg_ssd': mean_ssd,
            'min_ssd': stop_ssd.min(),
            'max_ssd': stop_ssd.max(),
            'min_ssd_count': (stop_ssd == stop_ssd.min()).sum(),
            'max_ssd_count': (stop_ssd == stop_ssd.max()).sum(),
            'ssrt': ssrt,
            'ssrt_without_short_ssd_trials': compute_SSRT(df, without_short_ssd_trials=True),
            # Subjects whose mean SSD is below 200 get None for this metric.
            'ssrt_without_short_ssd_subs': ssrt if mean_ssd >= 200 else None,
        }
        task_metrics[subject_label] = metrics

    # No files, or no violation rows survived the filter: return an empty
    # summary with the expected columns instead of failing on a missing
    # 'ssd' column in the groupby below.
    if not violation_frames:
        return task_metrics, pd.DataFrame(columns=aggregate_columns)

    all_violations_df = pd.concat(violation_frames, ignore_index=True)
    final_aggregated_results = all_violations_df.groupby('ssd').agg(
        count=('difference', 'count'),
        avg_difference=('difference', 'mean'),
        all_differences=('difference', list)
    ).reset_index()

    return task_metrics, final_aggregated_results
# Process every CSV in data_files: yields the per-subject metrics dict and the
# violation summary aggregated by SSD.
task_metrics, final_aggregated_results = process_all_data(data_files)

In [4]:
# One row per subject label (sorted), one column per metric.
simple_stop_metrics = pd.DataFrame(task_metrics).T.sort_index()

# Column-wise summary statistics; missing values are skipped (pandas default).
mean_row = simple_stop_metrics.mean()
sd_row = simple_stop_metrics.std()

# Append the two summaries as extra rows labelled 'mean' and 'sd'.
summary_rows = [mean_row.to_frame(name='mean').T, sd_row.to_frame(name='sd').T]
simple_stop_metrics = pd.concat([simple_stop_metrics, *summary_rows])
rename_index_column(simple_stop_metrics)

In [5]:
# Bundle both result tables as plain nested dicts so they can be serialised
# together into a single JSON file in the current working directory.
data_to_save = dict(
    final_aggregated_results=final_aggregated_results.to_dict(),
    simple_stop_metrics=simple_stop_metrics.to_dict(),
)

with open('simple_stop_data.json', 'w') as f:
    json.dump(data_to_save, f, indent=4)