In [23]:
import pandas as pd
import numpy as np
import glob
import os
import sys
sys.path.append('../src')
from sharedcontrolpaper.utils import get_subject_label, string_to_numbers, plot_trial_pressure_individual

## Generate Table 1 - Simple Stop

In [24]:
# Raw CSVs are expected at <parent of cwd>/data/experiment/<exp_stage>/<*>/<task>/*.csv
# (the '*' level is presumably one directory per subject — confirm against the data layout)
parent_directory = os.path.dirname(os.getcwd())
data_path = os.path.join(parent_directory, 'data', 'experiment')
exp_stage = "final/"
task = "simple_stop"

pattern = os.path.join(data_path, exp_stage, '*', task, '*.csv')

# glob's order is filesystem-dependent; sort so files are processed in a reproducible order
data_files = sorted(glob.glob(pattern))

In [25]:
def compute_SSRT(df, max_go_rt = 2):
    """
    Compute Stop Signal Reaction Time (SSRT) for the simple stop task.

    Integration method: go-trial RTs are sorted (go trials without a response
    are assigned ``max_go_rt``). The RT at rank round(p_respond * n_go) is taken,
    where p_respond is the fraction of stop trials with a recorded RT. The mean
    stop-signal delay is then subtracted from it.

    Parameters:
    - df: DataFrame containing trial data with 'Phase', 'trialType', 'rt' and
      'ssd' columns. Only rows with Phase == "test" are used.
    - max_go_rt: Reaction time substituted for go trials with a missing (NaN) rt.

    Returns:
    - SSRT: The computed Stop Signal Reaction Time (in the units of 'rt'/'ssd').
      Returns None when there are no test-phase go trials or no test-phase
      stop trials.
    """
    df = df.query('Phase == "test"')

    go_rts = df.loc[df.trialType == 'go', 'rt']
    stop_df = df.loc[df.trialType == 'stop']

    # SSRT is undefined without both trial types (previously this raised
    # UnboundLocalError / KeyError instead of returning None)
    if go_rts.empty or stop_df.empty:
        return None

    # Omitted go responses count as the slowest possible RT
    sorted_go = go_rts.fillna(max_go_rt).sort_values(ascending=True, ignore_index=True)

    # Probability of responding on a stop trial (stop failure = any recorded RT)
    p_respond = stop_df['rt'].notna().mean()
    avg_SSD = stop_df['ssd'].mean()

    # 1-based rank -> 0-based position, clamped into the valid range.
    # Positional access (.iloc) is required: [] on a RangeIndex is label-based.
    nth_index = int(np.rint(p_respond * len(sorted_go))) - 1
    nth_index = min(max(nth_index, 0), len(sorted_go) - 1)
    nth_RT = sorted_go.iloc[nth_index]

    # A mean SSD of 0 is a valid delay and must not be treated as "missing"
    return nth_RT - avg_SSD

In [26]:
def preprocess_stop_data(df):
    """
    Restrict raw stop-task trial data to the test phase and add accuracy columns.

    Parameters:
    - df: Raw trial DataFrame containing at least the columns selected below.
      It is not modified.

    Returns:
    - A new DataFrame with the selected columns. 'goResp_test.keys',
      'goResp_test.corr' and 'goResp_test.rt' are renamed to 'response',
      'correct' and 'rt'. Three columns are added:
      - stop_acc: on stop trials, 1 if no key was recorded, else 0; NaN otherwise.
      - go_acc: on go trials, 1 if the recorded key equals correctResponse,
        else 0; NaN otherwise.
      - stop_failure_acc: on stop trials with a recorded RT, 1 if the key equals
        correctResponse, else 0; NaN otherwise.
    """
    columns = ['Block', 'Phase', 'trialType', 'goStim', 'correctResponse', 'ssd',
               'goResp_test.keys', 'goResp_test.corr', 'goResp_test.rt']
    # Explicit copy: the column assignments below would otherwise write into a
    # slice of a query result and trigger SettingWithCopyWarning
    df = df.query('Phase == "test"')[columns].copy()

    is_go = df['trialType'] == 'go'
    is_stop = df['trialType'] == 'stop'
    key_matches = df['goResp_test.keys'] == df['correctResponse']

    df['stop_acc'] = np.where(is_stop,
                              np.where(df['goResp_test.keys'].isnull(), 1, 0),
                              np.nan)

    df['go_acc'] = np.where(is_go,
                            np.where(key_matches, 1, 0),
                            np.nan)

    df['stop_failure_acc'] = np.where(is_stop & df['goResp_test.rt'].notna(),
                                      np.where(key_matches, 1, 0),
                                      np.nan)

    return df.rename(columns={'goResp_test.keys': 'response',
                              'goResp_test.corr': 'correct',
                              'goResp_test.rt': 'rt'})

In [27]:
# SSD values whose trial counts are reported per subject (same units as the 'ssd'
# column). NOTE(review): presumably the staircase floor/ceiling — confirm against the task config.
MIN_SSD = 0
MAX_SSD = 0.75

task_metrics = {}
for file in data_files:
    subject_label = get_subject_label(file)

    df = pd.read_csv(file)

    if subject_label == 's019':
        df = df[df["stopBlocks.thisN"] != 2] # Excluded block 2 of s019

    df = preprocess_stop_data(df)

    # Split by trial type once and reuse the subsets for every metric below
    go_trials = df.query("trialType == 'go'")
    stop_trials = df.query("trialType == 'stop'")

    go_rt = go_trials.rt.mean()
    # Mean over stop trials; NaN rts (successful stops) are skipped by .mean()
    stop_fail_rt = stop_trials.rt.mean()
    go_acc = go_trials.go_acc.mean()
    stop_failure_acc = stop_trials.stop_failure_acc.mean()
    stop_success = stop_trials.stop_acc.mean()
    avg_ssd = stop_trials.ssd.mean()
    min_ssd = stop_trials.ssd.min()
    max_ssd = stop_trials.ssd.max()
    # Tolerance-based match: SSDs read back from CSV may carry floating-point error
    min_ssd_count = np.isclose(stop_trials.ssd, MIN_SSD).sum()
    max_ssd_count = np.isclose(stop_trials.ssd, MAX_SSD).sum()
    ssrt = compute_SSRT(df)

    task_metrics[subject_label] = {
        'go_rt': go_rt,
        'stop_fail_rt': stop_fail_rt,
        'go_acc': go_acc,
        'stop_failure_acc': stop_failure_acc,
        'stop_success': stop_success,
        'avg_ssd': avg_ssd,
        'min_ssd': min_ssd,
        'max_ssd': max_ssd,
        'min_ssd_count': min_ssd_count,
        'max_ssd_count': max_ssd_count,
        'ssrt': ssrt,
    }

In [28]:
metrics = pd.DataFrame(task_metrics).T
metrics = metrics.sort_index()

mean_row = metrics.mean()
sd_row = metrics.std()
metrics = pd.concat([metrics, mean_row.rename('mean').to_frame().T, sd_row.rename('sd').to_frame().T])
simple_stop_ssrt = metrics
%store simple_stop_ssrt

Stored 'simple_stop_ssrt' (DataFrame)


In [29]:
# Derive the per-subject table from the stored frame by dropping the summary rows by
# label. This keeps the cell idempotent: slicing and rescaling `metrics` itself would
# drop further subjects and multiply by 1000 again on every re-run.
metrics = simple_stop_ssrt.drop(index=['mean', 'sd']).copy()

# Time-based metrics are multiplied by 1000 so the table reports milliseconds
ms_columns = ['go_rt', 'stop_fail_rt', 'avg_ssd', 'ssrt']
metrics[ms_columns] = metrics[ms_columns] * 1000

# Calculate the mean and (sample) standard deviation across subjects for the reported metrics
table_columns = ['go_rt', 'go_acc', 'stop_fail_rt', 'stop_success', 'avg_ssd', 'ssrt']
mean_values = metrics[table_columns].mean().to_frame().T
sd_values = metrics[table_columns].std().to_frame().T

# Combine mean and standard deviation into a single DataFrame
metrics_mean_sd = pd.concat([mean_values, sd_values], ignore_index=True)
metrics_mean_sd.index = ['Mean', 'SD']

new_column_names = {
    'go_rt': 'Go Task Reaction Time (ms)',
    'go_acc': 'Go Task Accuracy',
    'stop_fail_rt': 'Stop Failure Reaction Time (ms)',
    'stop_success': 'Stop Success Rate',
    'avg_ssd': 'Average Stop Signal Delay (ms)',
    'ssrt': 'Stop Signal Reaction Time (ms)'
}

metrics_mean_sd = metrics_mean_sd.rename(columns=new_column_names).round(2).T

# Create the output directory if needed so a fresh checkout can write the table
tables_dir = os.path.join(parent_directory, 'tables')
os.makedirs(tables_dir, exist_ok=True)
metrics_mean_sd.to_csv(os.path.join(tables_dir, 'table1.csv'))
metrics_mean_sd


Unnamed: 0,Mean,SD
Go Task Reaction Time (ms),557.62,132.88
Go Task Accuracy,0.97,0.03
Stop Failure Reaction Time (ms),481.66,101.76
Stop Success Rate,0.51,0.04
Average Stop Signal Delay (ms),327.71,121.41
Stop Signal Reaction Time (ms),215.55,31.15
