In [55]:
import pandas as pd
import glob
import os
import sys
sys.path.append('../src')
from sharedcontrolpaper.force_sensitive_stopping_task_utils import get_subject_label
from sharedcontrolpaper.simple_stop_utils import compute_SSRT, preprocess_stop_data, analyze_violations

In [56]:
# Locate the experiment CSVs relative to the notebook's working directory:
#   <parent of cwd>/data/experiment/<exp_stage>/*/<task>/*.csv
# The middle '*' matches exactly one directory level (presumably one folder
# per subject -- confirm against the data layout).
parent_directory = os.path.dirname(os.getcwd())
data_path = os.path.join(parent_directory, 'data', 'experiment')
exp_stage = "final/"
task = "simple_stop"

pattern = os.path.join(data_path, exp_stage, '*', task, '*.csv')

# glob.glob returns matches in arbitrary (filesystem-dependent) order; sort so
# that subjects are processed in the same order on every machine and run.
data_files = sorted(glob.glob(pattern))

In [57]:
task_metrics = {}
all_violations_df = pd.DataFrame()
for file in data_files:
    subject_label = get_subject_label(file)
    
    df = pd.read_csv(file)

    if subject_label == 's019':
        df = df[df["Block"] != 2]
    # Preprocess data
    df = preprocess_stop_data(df)
    
    # Analyze violations and aggregate results
    violations_df = analyze_violations(df)
    all_violations_df = pd.concat([all_violations_df, violations_df], ignore_index=True)

    # Calculate other task metrics
    go_rt = df.query("trialType == 'go'").rt.mean()
    stop_fail_rt = df.query("trialType == 'stop'").rt.mean()
    go_acc = df.query("trialType == 'go'").go_acc.mean()
    stop_fail_acc = df.query("trialType == 'stop'").stop_failure_acc.mean()
    stop_success = df.query("trialType == 'stop'").stop_acc.mean()
    stop_fail_rate = 1 - stop_success
    avg_ssd = df.query("trialType == 'stop'").ssd.mean()
    min_ssd = df.query("trialType == 'stop'").ssd.min()
    max_ssd = df.query("trialType == 'stop'").ssd.max()
    min_ssd_count = (df.query("trialType == 'stop'").ssd == 0).sum()
    max_ssd_count = (df.query("trialType == 'stop'").ssd == 0.75).sum()

    ssrt = compute_SSRT(df)
    ssrt_without_short_ssd_trials = compute_SSRT(df, without_short_ssd_trials = True)
    if avg_ssd < 200:
        ssrt_without_short_ssd_subs = None
    else:
        ssrt_without_short_ssd_subs = ssrt

    
    task_metrics[subject_label] = {
        'go_rt': go_rt,
        'stop_fail_rt': stop_fail_rt,
        'go_acc': go_acc,
        'stop_fail_acc': stop_fail_acc,
        'stop_success': stop_success,
        'stop_fail_rate': stop_fail_rate,
        'avg_ssd': avg_ssd,
        'min_ssd': min_ssd,
        'max_ssd': max_ssd,
        'min_ssd_count': min_ssd_count,
        'max_ssd_count': max_ssd_count,
        'ssrt': ssrt,
        'ssrt_without_short_ssd_subs': ssrt_without_short_ssd_subs,
        'ssrt_without_short_ssd_trials': ssrt_without_short_ssd_trials,
    }
    
final_aggregated_results = all_violations_df.groupby('ssd').agg(
    count=('difference', 'count'),
    avg_difference=('difference', 'mean'),
    all_differences=('difference', lambda x: list(x))
    ).reset_index()
%store final_aggregated_results

Stored 'final_aggregated_results' (DataFrame)


In [58]:
metrics = pd.DataFrame(task_metrics).T
metrics = metrics.sort_index()
mean_row = metrics.mean(skipna=True)
sd_row = metrics.std(skipna=True)
metrics = pd.concat([metrics, mean_row.rename('mean').to_frame().T, sd_row.rename('sd').to_frame().T])
simple_stop_ssrt = metrics
%store simple_stop_ssrt

Stored 'simple_stop_ssrt' (DataFrame)
