In [1]:
import numpy as np
import pandas as pd
from stopsignalmetrics import StopData, SSRTmodel, PostStopSlow, Violations, StopSummary


import matplotlib.pyplot as plt
import seaborn as sns

# MTURK DATA

In [2]:
variable_dict = {
   "columns": {
      "ID": "worker_id",
      "block": "current_block",
      "condition": "SS_trial_type",
      "SSD": "SS_delay",
      "goRT": "rt",
      "stopRT": "rt",
      "response": "key_press",
      "correct_response": "correct_response",
      "choice_accuracy": "choice_accuracy"
   },
   "key_codes": {
      "go": "go",
      "stop": "stop",
      "correct": 1,
      "incorrect": 0
   }
}
stop_data = StopData(var_dict=variable_dict)

# 1 Subject

In [3]:
subj_file = '../stopsignalmetrics/data/stop_signal_single_task_network_A3QAHF4UUBM7ZO.csv'
subj_df = pd.read_csv(subj_file, index_col=0)
cleaned_subj = stop_data.fit_transform(subj_df)

In [4]:
ssrt_model = SSRTmodel(model='all')
metrics = ssrt_model.fit_transform(cleaned_subj)
metrics

{'SSRT': {'mean': 301.53228449688623,
  'integration': 279.6296296296296,
  'omission': 282.6296296296296,
  'replacement': 282.6296296296296},
 'mean_SSD': 295.3703703703704,
 'p_respond': 0.5,
 'max_RT': 1675.0,
 'mean_go_RT': 596.9026548672566,
 'mean_stopfail_RT': 520.8888888888889,
 'omission_count': 1,
 'omission_rate': 0.008771929824561403,
 'go_acc': 0.911504424778761,
 'stopfail_acc': 0.0}

# Group 1

In [5]:
group_file = '../stopsignalmetrics/data//stop_signal_single_task_network.csv'
group_df = pd.read_csv(group_file, index_col=0)
group_df = group_df.reset_index()

cleaned_group = stop_data.fit_transform(group_df)
cleaned_group

Unnamed: 0,SSD,SS_duration,SS_stimulus,condition,att_check_percent,block_duration,correct,correct_response,original_correcttrial,block,...,stop_acc,stop_signal_condition,time_elapsed,timing_post_trial,trial_id,ID,correct_trial,goRT,stopRT,choice_accuracy
0,,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,,90.0,1.0,0.0,...,,go,98287,0.0,test_trial,A3QAHF4UUBM7ZO,1,622.0,,1
1,,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,,77.0,1.0,0.0,...,,go,100791,0.0,test_trial,A3QAHF4UUBM7ZO,1,519.0,,1
2,,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,,77.0,1.0,0.0,...,,go,103296,0.0,test_trial,A3QAHF4UUBM7ZO,1,491.0,,1
3,,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,,77.0,1.0,0.0,...,,go,105801,0.0,test_trial,A3QAHF4UUBM7ZO,1,389.0,,1
4,,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,,77.0,1.0,0.0,...,,go,108307,0.0,test_trial,A3QAHF4UUBM7ZO,1,314.0,,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4747,,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,,90.0,1.0,2.0,...,,go,683864,0.0,test_trial,A2581F7TDPAMBQ,1,747.0,,1
4748,400.0,500.0,<img class = center src='/static/experiments/s...,stop,,2000.0,,-1.0,0.0,2.0,...,0.0,stop,686479,0.0,test_trial,A2581F7TDPAMBQ,0,,654.0,0
4749,,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,,90.0,1.0,2.0,...,,go,689000,0.0,test_trial,A2581F7TDPAMBQ,1,680.0,,1
4750,,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,,77.0,1.0,2.0,...,,go,691512,0.0,test_trial,A2581F7TDPAMBQ,1,586.0,,1


In [10]:
ind_test = Violations().fit_transform(cleaned_group, level='group')

In [11]:
ind_test

Unnamed: 0,ID,SSD,n_go_stopfail_pairs,mean_violation,mean_stopFailureRT,mean_precedingGoRT
0,A1DS5O8MSI3ZH0,400.0,3,-65.000000,479.333333,544.333333
1,A1DS5O8MSI3ZH0,450.0,5,51.200000,596.800000,545.600000
2,A1DS5O8MSI3ZH0,500.0,5,-6.800000,632.800000,639.600000
3,A1DS5O8MSI3ZH0,550.0,2,-14.500000,737.000000,751.500000
4,A1L1SQ488YCCFJ,300.0,6,-113.500000,452.166667,565.666667
...,...,...,...,...,...,...
108,AVMIXXCHPD291,750.0,3,-132.000000,889.333333,1021.333333
109,AY7WPVKHVNBLG,200.0,4,-5.750000,376.750000,382.500000
110,AY7WPVKHVNBLG,250.0,7,-128.571429,425.428571,554.000000
111,AY7WPVKHVNBLG,300.0,5,-41.200000,446.200000,487.400000


In [None]:
all_ssdvals = g_test.index.levels[1]

In [None]:
all_ssdvals

In [None]:
g_test.loc[(slice(None) , 300.0), ]

In [None]:
ssrt_model = SSRTmodel()
cleaned_group.groupby('ID').apply(ssrt_model.fit_transform).apply(pd.Series)

In [None]:
StopSummary().fit_transform(cleaned_group, level='group')

In [None]:
Violations().fit_transform(cleaned_group, level='group')

In [None]:
ind_test = ind_test.set_index('SSD')

# Group 2 - Recreating Figure 1b

In [None]:
# load in data
group2_file = 'example_data/DataFixedSSDs2.xlsx'
group2_df = pd.read_excel(group2_file)
group2_df = group2_df.replace(r'^\s*$', np.nan, regex=True).replace('?', np.nan)

#build up "correct" responses for stop trials
addon_file = 'example_data/FixedSSD2StopTrialChoiceAccuracyInput.xlsx'
addon_df = pd.read_excel(addon_file)

group2_df['StopTrialCorrectResponse'] = np.nan

# circle response
for shape in ['Circle', 'Rhombus', 'Square', 'Triangle']:
    group2_df.loc[(group2_df['TrialType']=='stop') & (addon_df['Unnamed: 5']==f'{shape.lower()}.bmp'), 'StopTrialCorrectResponse'] = \
    addon_df.loc[(group2_df['TrialType']=='stop') & (addon_df['Unnamed: 5']==f'{shape.lower()}.bmp'), f'{shape}Response']

#combine go and stop into single columns
group2_df['response'] = np.where(group2_df['GoTrialResponse'].isnull(), group2_df['StopTrialResponse'], group2_df['GoTrialResponse'])
group2_df['CorrectResponse'] = np.where(group2_df['GoTrialCorrectResponse'].isnull(), group2_df['StopTrialCorrectResponse'], group2_df['GoTrialCorrectResponse'])
group2_df['CorrectResponse'] = group2_df['CorrectResponse'].str.lower()

# Preprocess data

In [None]:
variable_dict = {
   "columns": {
      "ID": "Subject",
      "block": "Block",
      "condition": "TrialType",
      "SSD": "StopSignalDelay",
      "goRT": "GoRT",
      "stopRT": "StopFailureRT",
      "response": "response",
      "correct_response": "CorrectResponse",
      "choice_accuracy": "choice_accuracy"
   },
   "key_codes": {
      "go": "go",
      "stop": "stop",
      "correct": 1,
      "incorrect": 0
   }
}

stop_data = StopData(var_dict=variable_dict)
cleaned_df = stop_data.fit_transform(group2_df)

# Get Group Violations - 1 line!

In [None]:
va_df = Violations().fit_transform(cleaned_df, level='group')

# Make a pretty & compelling plot

In [None]:
# Pivot & Plot
pivot_df = va_df.pivot_table(values='mean_violation', index=['SSD'],
                        columns=['ID'])


fig = plt.figure(figsize=(10,8))
plt.plot(pivot_df, linewidth=0, color='c', marker='o', alpha=.3, ) #plot individuals
ax = sns.lineplot(x='SSD', y="mean_violation", data=va_df, color='c')
ax.axis([0,500,-300,250])
ax.plot([0,pivot_df.index.max()],[0,0],color='k',linestyle=':',linewidth=3)

xticks=np.arange(0,pivot_df.index.max()+50,50)
xticks = [int(i) for i in xticks]
if len(xticks)> 16:
    xticks = [i for i in xticks if i%100==0]
ax.set_xticks(xticks)
ax.set_xticklabels(xticks,fontsize=18) 
yticks=np.arange(-300,250+50,50)
ax.set_yticks(yticks)
ax.set_yticklabels(yticks,fontsize=18)

ax.set_xlabel('Stop-signal delay (ms)',fontsize=24)
_ = ax.set_ylabel('Stop failure RT - No-stop RT',fontsize=24) 
plt.show()
plt.close()