In [69]:
import pandas as pd
import numpy as np
import glob
import os
import re

In [70]:
pd.set_option('display.max_columns', None)

In [71]:
data_path = "/Users/jahrios/Documents/Stanford/PoldrackLab/Projects/SharedControl/Data/raw/"
exp_stage = "final/"
task = "simple_stop"

pattern = os.path.join(data_path, exp_stage, '*', task, '*.csv')

data_files = glob.glob(pattern)

In [72]:
def get_subject_label(file):
    
    match = re.search(r'/sub-(s\d{3})/', file)
    
    if match:
        subject_label = match.group(1)
        print("Subject label:", subject_label)
        return subject_label
    else:
        print("No subject label found.")
        return None

In [73]:
def compute_SSRT(df, max_go_rt = 2, violation_flag = False):
    
    df = df.query('Phase == "test"')
    
    go_trials = df.loc[df.trialType == 'go']
    stop_df = df.loc[df.trialType == 'stop']

    go_replacement_df = go_trials.where(~go_trials['rt'].isna(), max_go_rt)
    sorted_go = go_replacement_df.rt.sort_values(ascending = True, ignore_index=True)
    stop_failure = stop_df.loc[stop_df['rt'].notna()]

    p_respond = len(stop_failure)/len(stop_df)
    avg_SSD = stop_df.ssd.mean()

    nth_index = int(np.rint(p_respond*len(sorted_go))) - 1

    if nth_index < 0:
        nth_RT = sorted_go[0]
    elif nth_index >= len(sorted_go):
        nth_RT = sorted_go[-1]
    else:
        nth_RT = sorted_go[nth_index]

    SSRT = nth_RT - avg_SSD

    return SSRT

In [74]:
def preprocess_stop_data(df):
    df = df.query(f'Phase == "test"')
    df = df[['Block', 'Phase', 'trialType', 'goStim', 'correctResponse', 'ssd', 'goResp_test.keys', 'goResp_test.corr','goResp_test.rt']]
    
    df.loc[:, 'stop_acc'] = np.where(df['trialType'] == 'stop', 
                                 np.where(df['goResp_test.keys'] == 'None', 1, 0), 
                                 np.nan)

    df.loc[:, 'go_acc'] = np.where(df['trialType'] == 'go', 
                                   np.where(df['goResp_test.keys'] == df['correctResponse'], 1, 0), 
                                   np.nan)

    df.loc[:, 'stop_failure_acc'] = np.where(
        (df['trialType'] == 'stop') & (df['goResp_test.rt'].notna()),
        np.where(df['goResp_test.keys'] == df['correctResponse'], 1, 0),
        np.nan)
    
    df.rename(columns={'goResp_test.keys': 'response', 'goResp_test.corr': 'correct', 'goResp_test.rt': 'rt'}, inplace=True)
    
    return df

In [75]:
task_metrics = {}
for file in data_files:
    
    subject_label = get_subject_label(file)
    
    df = pd.read_csv(file)
    
    df = preprocess_stop_data(df)
    
    
    go_rt = df.query("trialType == 'go'").rt.mean()
    stop_fail_rt = df.query("trialType == 'stop'").rt.mean()
    go_acc = df.query("trialType == 'go'").go_acc.mean()
    stop_failure_acc = df.query("trialType == 'stop'").stop_failure_acc.mean()
    stop_success = df.query("trialType == 'stop'").stop_acc.mean()
    avg_ssd = df.query("trialType == 'stop'").ssd.mean()
    min_ssd = df.query("trialType == 'stop'").ssd.min()
    max_ssd = df.query("trialType == 'stop'").ssd.max()
    min_ssd_count = (df.query("trialType == 'stop'").ssd == min_ssd).sum()
    max_ssd_count = (df.query("trialType == 'stop'").ssd == max_ssd).sum()
    ssrt = compute_SSRT(df)
    
    
    
    task_metrics[subject_label] = {
            'go_rt': go_rt,
            'stop_fail_rt': stop_fail_rt,
            'go_acc': go_acc,
            'stop_failure_acc': stop_failure_acc,
            'stop_success': stop_success,
            'avg_ssd': avg_ssd,
            'min_ssd': min_ssd,
            'max_ssd': max_ssd,
            'min_ssd_count': min_ssd_count,
            'max_ssd_count': max_ssd_count,
            'ssrt': ssrt
        }
    

Subject label: s016
Subject label: s011
Subject label: s018
Subject label: s019
Subject label: s010
Subject label: s017
Subject label: s004
Subject label: s005
Subject label: s012
Subject label: s015
Subject label: s014
Subject label: s013
Subject label: s009
Subject label: s007
Subject label: s006
Subject label: s008


In [76]:
metrics = pd.DataFrame(task_metrics).T
metrics = metrics.sort_index()

In [78]:
metrics.to_csv('output/simple_stop_metrics.csv')