In [1]:
import numpy as np
import pandas as pd
from stopsignalmetrics import StopData, SSRTmodel, PostStopSlow, Violations, StopSummary


import matplotlib.pyplot as plt
import seaborn as sns

# MTURK DATA

In [7]:
variable_dict = {
   "columns": {
      "ID": "worker_id",
      "block": "current_block",
      "condition": "SS_trial_type",
      "SSD": "SS_delay",
      "goRT": "rt",
      "stopRT": "rt",
      "response": "key_press",
      "correct_response": "correct_response",
      "choice_accuracy": "choice_accuracy"
   },
   "key_codes": {
      "go": "go",
      "stop": "stop",
      "correct": 1,
      "incorrect": 0,
      "noResponse": -1
   }
}
stop_data = StopData(var_dict=variable_dict)

In [None]:
STANDARDS_FILE = '../stopsignalmetrics/data/standards.json'
with open(STANDARDS_FILE) as json_file:
            jstring = json_file.read().replace('\n', '').replace('null', '"np.nan"')
            print(jstring)

# 1 Subject

In [8]:
subj_file = '../stopsignalmetrics/data/stop_signal_single_task_network_A3QAHF4UUBM7ZO.csv'
subj_df = pd.read_csv(subj_file, index_col=0)
subj_df = subj_df.fillna(-1.)
cleaned_subj = stop_data.fit_transform(subj_df)

In [18]:
cleaned_subj['stopRT'].notnull().index

Int64Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
            ...
            363, 364, 365, 366, 367, 368, 369, 370, 371, 372],
           dtype='int64', length=373)

In [None]:
mapping = {'go': 'go', 'stop': 'stop', 1: 1, 0: 0, -1: np.nan}
cleaned_subj.query("condition=='stop' ")['stopRT'].map(lambda x: mapping.get(x,x))

In [15]:
cleaned_subj.query("condition=='stop' and stopRT == stopRT", engine='python')

Unnamed: 0,SSD,SS_duration,SS_stimulus,condition,att_check_percent,block_duration,correct,correct_response,correct_trial,block,...,text,time_elapsed,timing_post_trial,trial_id,trial_index,trial_type,view_history,goRT,stopRT,choice_accuracy
6,350.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,23084,0.0,practice_trial,6,stop-signal,-1,,580.0,0
27,300.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,44131,0.0,practice_trial,27,stop-signal,-1,,427.0,0
39,250.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,56158,0.0,practice_trial,39,stop-signal,-1,,515.0,0
45,250.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,62171,0.0,practice_trial,45,stop-signal,-1,,470.0,0
91,250.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,113315,0.0,test_trial,91,stop-signal,-1,,460.0,0
117,350.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,145876,0.0,test_trial,117,stop-signal,-1,,695.0,0
129,350.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,160903,0.0,test_trial,129,stop-signal,-1,,464.0,0
137,350.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,170918,0.0,test_trial,137,stop-signal,-1,,566.0,0
139,300.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,173424,0.0,test_trial,139,stop-signal,-1,,497.0,0
157,300.0,500.0,<img class = center src='/static/experiments/s...,stop,-1.0,2000.0,-1,-1.0,0.0,0.0,...,-1,195964,0.0,test_trial,157,stop-signal,-1,,429.0,0


In [None]:
ssrt_model = SSRTmodel(model='all')
metrics = ssrt_model.fit_transform(cleaned_subj)
metrics

# Group 1

In [None]:
group_file = '../stopsignalmetrics/data//stop_signal_single_task_network.csv'
group_df = pd.read_csv(group_file, index_col=0)
group_df = group_df.reset_index()

cleaned_group = stop_data.fit_transform(group_df)
cleaned_group

In [None]:
ind_test = Violations().fit_transform(cleaned_group, level='group')

In [None]:
ind_test

In [None]:
all_ssdvals = g_test.index.levels[1]

In [None]:
all_ssdvals

In [None]:
g_test.loc[(slice(None) , 300.0), ]

In [None]:
ssrt_model = SSRTmodel()
cleaned_group.groupby('ID').apply(ssrt_model.fit_transform).apply(pd.Series)

In [None]:
StopSummary().fit_transform(cleaned_group, level='group')

In [None]:
Violations().fit_transform(cleaned_group, level='group')

In [None]:
ind_test = ind_test.set_index('SSD')

# Group 2 - Recreating Figure 1b

In [None]:
# load in data
group2_file = 'example_data/DataFixedSSDs2.xlsx'
group2_df = pd.read_excel(group2_file)
group2_df = group2_df.replace(r'^\s*$', np.nan, regex=True).replace('?', np.nan)

#build up "correct" responses for stop trials
addon_file = 'example_data/FixedSSD2StopTrialChoiceAccuracyInput.xlsx'
addon_df = pd.read_excel(addon_file)

group2_df['StopTrialCorrectResponse'] = np.nan

# circle response
for shape in ['Circle', 'Rhombus', 'Square', 'Triangle']:
    group2_df.loc[(group2_df['TrialType']=='stop') & (addon_df['Unnamed: 5']==f'{shape.lower()}.bmp'), 'StopTrialCorrectResponse'] = \
    addon_df.loc[(group2_df['TrialType']=='stop') & (addon_df['Unnamed: 5']==f'{shape.lower()}.bmp'), f'{shape}Response']

#combine go and stop into single columns
group2_df['response'] = np.where(group2_df['GoTrialResponse'].isnull(), group2_df['StopTrialResponse'], group2_df['GoTrialResponse'])
group2_df['CorrectResponse'] = np.where(group2_df['GoTrialCorrectResponse'].isnull(), group2_df['StopTrialCorrectResponse'], group2_df['GoTrialCorrectResponse'])
group2_df['CorrectResponse'] = group2_df['CorrectResponse'].str.lower()

# Preprocess data

In [None]:
variable_dict = {
   "columns": {
      "ID": "Subject",
      "block": "Block",
      "condition": "TrialType",
      "SSD": "StopSignalDelay",
      "goRT": "GoRT",
      "stopRT": "StopFailureRT",
      "response": "response",
      "correct_response": "CorrectResponse",
      "choice_accuracy": "choice_accuracy"
   },
   "key_codes": {
      "go": "go",
      "stop": "stop",
      "correct": 1,
      "incorrect": 0
   }
}

stop_data = StopData(var_dict=variable_dict)
cleaned_df = stop_data.fit_transform(group2_df)

# Get Group Violations - 1 line!

In [None]:
va_df = Violations().fit_transform(cleaned_df, level='group')

# Make a pretty & compelling plot

In [None]:
# Pivot & Plot
pivot_df = va_df.pivot_table(values='mean_violation', index=['SSD'],
                        columns=['ID'])


fig = plt.figure(figsize=(10,8))
plt.plot(pivot_df, linewidth=0, color='c', marker='o', alpha=.3, ) #plot individuals
ax = sns.lineplot(x='SSD', y="mean_violation", data=va_df, color='c')
ax.axis([0,500,-300,250])
ax.plot([0,pivot_df.index.max()],[0,0],color='k',linestyle=':',linewidth=3)

xticks=np.arange(0,pivot_df.index.max()+50,50)
xticks = [int(i) for i in xticks]
if len(xticks)> 16:
    xticks = [i for i in xticks if i%100==0]
ax.set_xticks(xticks)
ax.set_xticklabels(xticks,fontsize=18) 
yticks=np.arange(-300,250+50,50)
ax.set_yticks(yticks)
ax.set_yticklabels(yticks,fontsize=18)

ax.set_xlabel('Stop-signal delay (ms)',fontsize=24)
_ = ax.set_ylabel('Stop failure RT - No-stop RT',fontsize=24) 
plt.show()
plt.close()

In [None]:
import json

In [None]:
with open("sample.json", "w") as outfile:  
    json.dump({'check': np.nan}, outfile)

