In [1]:
import pandas as pd
import numpy as np
from numpy import mean
from glob import glob
from os import path
from scipy import stats
import pingouin as pg


import matplotlib.pyplot as plt
import seaborn as sns
from statannot import add_stat_annotation
from IPython.display import display

from dual_data_utils import make_clean_concat_data
from stopsignalmetrics import StopData, SSRTmodel, Violations, PostStopSlow
%matplotlib inline

  **kwargs
  **kwargs


In [2]:
# violation plotting imports
from cycler import cycler

from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri

# Turn on rpy2's global pandas <-> R conversion so DataFrames can be handed to R.
pandas2ri.activate()
# Load the rpy2 IPython extension (registers the %R / %%R magics for later cells).
%load_ext rpy2.ipython

def _import_or_install_r_package(name, repos="http://cran.rstudio.com/"):
    """Return the rpy2 handle for R package ``name``, installing it if needed.

    Parameters
    ----------
    name : str
        R package name passed to ``importr``.
    repos : str
        CRAN mirror handed to R's ``install.packages`` when the first import
        attempt fails.

    Returns
    -------
    The object returned by ``importr(name)``.

    Raises
    ------
    Whatever ``importr`` raises if the package still cannot be loaded after
    the install attempt.
    """
    try:
        return importr(name)
    except Exception:
        # `Exception` (rather than a bare except) lets KeyboardInterrupt /
        # SystemExit propagate instead of triggering an install.
        utils = importr('utils')
        # `repos` is the documented argument name of install.packages; the
        # previous `repo=` only worked through R's partial argument matching.
        utils.install_packages(name, repos=repos)
        return importr(name)


# Module-level handles, same names as before, one per required R package.
tidyverse = _import_or_install_r_package('tidyverse')
lme4 = _import_or_install_r_package('lme4')
lmerTest = _import_or_install_r_package('lmerTest')
emmeans = _import_or_install_r_package('emmeans')
rstatix = _import_or_install_r_package('rstatix')

In [3]:
# Analysis thresholds. SSD_THRESH is the default `thresh` of stop_summarize.
# NOTE(review): units are not stated here; SS_delay values in the data look
# like milliseconds -- confirm.
OUTLIER_THRESH = 3
SSD_THRESH = 200

# One entry per dual task, keyed by task name. Each entry names the column
# holding the secondary-task condition ('dual_col') and lists values of that
# column under 'DC', 'DE' and, for some tasks, 'OTHER'.
# NOTE(review): 'DC'/'DE' presumably stand for the dual-control and
# dual-experimental conditions -- confirm with the analysis cells that use them.
dual_dict = dict(
    stop_signal_with_cued_task_switching=dict(
        dual_col='cue_task_switch',
        DC='cue_stay_task_stay',
        DE='cue_switch_task_switch',
        OTHER=['cue_switch_task_stay'],
    ),
    stop_signal_with_directed_forgetting=dict(
        dual_col='directed_forgetting_condition',
        DC='con',
        DE='neg',
        OTHER=['pos'],
    ),
    stop_signal_with_flanker=dict(
        dual_col='flanker_condition',
        DC='congruent',
        DE='incongruent',
    ),
    stop_signal_with_go_no_go=dict(
        dual_col='go_nogo_condition',
        DC='go',
        DE='nogo',
    ),
    stop_signal_with_n_back=dict(
        dual_col='delay_condition',
        DC=1.,
        DE=2.,
        OTHER=[3.],
    ),
    stop_signal_with_predictable_task_switching=dict(
        dual_col='predictable_condition',
        DC='stay',
        DE='switch',
    ),
    stop_signal_with_shape_matching=dict(
        dual_col='shape_matching_condition',
        DC='CONTROL',
        DE='DISTRACTOR',
        OTHER=['DSD', 'SSS'],
    ),
)

# Variable dictionary handed to StopData(...): 'columns' pairs the names on
# the left with columns of the raw task data, 'key_codes' lists the codes for
# trial types and accuracy values.
stopmetrics_var_dict = dict(
    columns=dict(
        ID='worker_id',
        block='current_block',
        condition='SS_trial_type',
        SSD='SS_delay',
        goRT='rt',
        stopRT='rt',
        response='key_press',
        correct_response='choice_correct_response',
        choice_accuracy='choice_accuracy',
    ),
    key_codes=dict(
        go='go',
        stop='stop',
        correct=1,
        incorrect=0,
        noResponse=-1,
    ),
)

In [4]:
def get_query_str(col, condition):
    """Build a ``DataFrame.query`` expression matching ``col`` against ``condition``.

    Parameters
    ----------
    col : str
        Column name, inserted verbatim into the expression.
    condition : str, int, float, numpy integer/floating scalar, list or tuple
        A string or number yields an equality test; a list or tuple yields a
        membership (``in``) test.

    Returns
    -------
    str
        The query expression, e.g. ``"flanker_condition=='congruent'"``.

    Raises
    ------
    TypeError
        If ``condition`` is of an unsupported type. Previously this returned
        ``None`` silently, which ``stop_summarize`` treats as "no filter".
    """
    if isinstance(condition, str):
        # repr() supplies the quotes and escapes any quote characters inside
        # the value; for plain strings the result is the same as '%s'.
        return "%s==%r" % (col, str(condition))
    if isinstance(condition, (int, float, np.integer, np.floating)):
        return "%s==%s" % (col, condition)
    if isinstance(condition, (list, tuple)):
        return "%s in %s" % (col, list(condition))
    raise TypeError(
        "unsupported condition type for column %r: %s" % (col, type(condition).__name__)
    )

def mean_pss(data_df, stop_type='all', query_suffix=None):
    """Average ``post_goRT - pre_goRT`` within each ``pre_ID`` group.

    ``data_df``, ``stop_type`` and ``query_suffix`` are forwarded unchanged to
    ``PostStopSlow().fit_transform(..., level='group')``; the frame it returns
    is grouped on ``pre_ID`` and the mean difference is taken per group.

    Returns
    -------
    The result of the groupby-apply: one mean difference per ``pre_ID``.
    """
    pss_df = PostStopSlow().fit_transform(
        data_df,
        stop_type=stop_type,
        level='group',
        query_suffix=query_suffix,
    )

    def _mean_rt_change(group):
        # Row-wise post-minus-pre go RT, averaged over the group's rows.
        rt_change = group['post_goRT'] - group['pre_goRT']
        return rt_change.mean()

    return pss_df.groupby('pre_ID').apply(_mean_rt_change)

def stop_summarize(curr_data, thresh=SSD_THRESH, query_str=None):
    """Summarize SSRT-model metrics and post-stop slowing (PSS) for one dataset.

    Parameters
    ----------
    curr_data : DataFrame
        Trial-level data; the code queries its ``SSD``, ``condition`` and
        ``ID`` columns.
    thresh : number
        Threshold applied both to trial-level ``SSD`` and to the per-row
        ``mean_SSD`` of the SSRT summary.
    query_str : str or None
        Optional ``DataFrame.query`` expression restricting the trials used.

    Returns
    -------
    DataFrame
        Column-wise concatenation of:
        * the SSRTmodel summary of the (optionally filtered) data,
        * the same summary restricted to ``SSD >= thresh`` or go trials
          (suffix ``_wThresh_SSDs``),
        * the rows of the first summary with ``mean_SSD >= thresh``
          (suffix ``_wThresh_subs``),
        * ``PSS_<type>`` / ``PSS_<type>_wThresh_SSDs`` / ``PSS_<type>_wThresh_subs``
          for ``<type>`` in all, fail, success.

    Notes
    -----
    Assumes the SSRTmodel summary is indexed by subject ID (its index is used
    to select ``ID`` values) -- TODO confirm.
    """
    data_df = curr_data.query(query_str) if query_str is not None else curr_data

    # One trial-level filter shared by the SSRT and PSS "_wThresh_SSDs"
    # columns. The PSS variant previously used a strict ">" and therefore
    # dropped trials with SSD == thresh that the SSRT variant kept.
    thresh_ssd_query = f'SSD>={thresh} or condition=="go"'

    sum_df = SSRTmodel().fit_transform(data_df, level='group')
    sum_df_threshSSDs = SSRTmodel().fit_transform(
        data_df.query(thresh_ssd_query), level='group'
    ).add_suffix('_wThresh_SSDs')
    sum_df_threshSubs = sum_df.query(f"mean_SSD >= {thresh}").add_suffix('_wThresh_subs')
    # axis must be passed by keyword: pandas >= 2.0 rejects it positionally.
    sum_df = pd.concat(
        [sum_df, sum_df_threshSSDs, sum_df_threshSubs],
        axis=1,
        sort=True,
    )

    query_suffix = "& " + query_str if query_str is not None else None
    thresh_sub_ids = sum_df_threshSubs.index.tolist()
    for stop_type in ['all', 'fail', 'success']:
        # NOTE(review): the unthresholded PSS is computed from `curr_data`
        # (filtered only through query_suffix) while the two thresholded
        # variants start from the already-filtered `data_df`; kept as in the
        # original -- confirm this is intended.
        pss_all = mean_pss(
            curr_data, stop_type=stop_type, query_suffix=query_suffix
        ).rename('PSS_%s' % stop_type)
        pss_thresh_ssds = mean_pss(
            data_df.query(thresh_ssd_query), stop_type=stop_type, query_suffix=query_suffix
        ).rename('PSS_%s_wThresh_SSDs' % stop_type)
        pss_thresh_subs = mean_pss(
            data_df.query(f"ID in {thresh_sub_ids}"), stop_type=stop_type, query_suffix=query_suffix
        ).rename('PSS_%s_wThresh_subs' % stop_type)
        sum_df = pd.concat(
            [sum_df, pss_all, pss_thresh_ssds, pss_thresh_subs],
            axis=1,
            sort=True,
        )
    return sum_df

def get_inhib_func(curr_data, cond):
    """Per (ID, SSD) fraction of rows with a non-null ``stopRT``, tagged with ``cond``.

    Parameters
    ----------
    curr_data : DataFrame
        Must provide ``ID``, ``SSD`` and ``stopRT`` columns.
    cond : scalar
        Value written to every row of the ``condition`` column.

    Returns
    -------
    DataFrame
        Columns ``ID``, ``SSD``, ``0`` (the fraction) and ``condition``.
    """
    def _nonnull_stoprt_fraction(trials):
        # Count of non-null stopRT values over the number of rows in the group.
        return trials.stopRT.notnull().sum() / len(trials)

    fractions = curr_data.groupby(['ID', 'SSD']).apply(_nonnull_stoprt_fraction)
    inhib_df = pd.DataFrame(fractions).reset_index()
    inhib_df['condition'] = cond
    return inhib_df

In [5]:
# Load the task data through the project helper from dual_data_utils. The
# result is indexed by task name in the cells below.
# NOTE(review): what `stop_subset=True` selects is defined in dual_data_utils -- confirm there.
stop_data_dict = make_clean_concat_data(stop_subset=True)

In [7]:
# Display the raw frame stored under the single-task key as a sanity check.
stop_data_dict['stop_signal_single_task_network']

Unnamed: 0,SS_delay,SS_duration,SS_stimulus,SS_trial_type,att_check_percent,block_duration,choice_accuracy,choice_correct_response,correct,correct_response,...,possible_responses,rt,stim,stim_duration,stop_acc,stop_signal_condition,time_elapsed,timing_post_trial,trial_id,worker_id
0,500.0,500.0,<img class = center src='/static/experiments/s...,stop,,2000.0,1,77.0,,-1.0,...,"[90, 77]",791,square,1000.0,0.0,stop,110670,0.0,test_trial,s341
1,450.0,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,0,77.0,,77.0,...,"[90, 77]",-1,square,1000.0,,go,113196,0.0,test_trial,s341
2,450.0,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,1,77.0,,77.0,...,"[90, 77]",717,square,1000.0,,go,115706,0.0,test_trial,s341
3,450.0,500.0,<img class = center src='/static/experiments/s...,stop,,2000.0,0,90.0,,-1.0,...,"[90, 77]",-1,circle,1000.0,1.0,stop,118224,0.0,test_trial,s341
4,500.0,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,1,77.0,,77.0,...,"[90, 77]",686,square,1000.0,,go,120743,0.0,test_trial,s341
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
139,350.0,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,1,90.0,,90.0,...,"[90, 77]",533,circle,1000.0,,go,477034,0.0,test_trial,s360
140,350.0,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,1,90.0,,90.0,...,"[90, 77]",605,circle,1000.0,,go,479544,0.0,test_trial,s360
141,350.0,500.0,<img class = center src='/static/experiments/s...,stop,,2000.0,1,90.0,,-1.0,...,"[90, 77]",406,circle,1000.0,0.0,stop,482053,0.0,test_trial,s360
142,300.0,500.0,<img class = center src='/static/experiments/s...,go,,2000.0,1,77.0,,77.0,...,"[90, 77]",677,square,1000.0,,go,484563,0.0,test_trial,s360


In [8]:
# Run every task's frame through StopData using the shared variable dict.
# The existing 'correct_response' column is dropped first.
# NOTE(review): presumably this avoids a clash with the 'correct_response'
# name that stopmetrics_var_dict maps from 'choice_correct_response' -- confirm.
prepro_data = {
    task: StopData(stopmetrics_var_dict).fit_transform(
        task_df.drop(columns=['correct_response'])
    )
    for task, task_df in stop_data_dict.items()
}

In [23]:
# For each subject in the single-task data, collect the stop-trial SSD values
# that occur in the single task AND in every other task of prepro_data.
se_data = prepro_data['stop_signal_single_task_network']
dual_tasks = [t for t in prepro_data if t != 'stop_signal_single_task_network']
shared_ssd_map = {}
for sid in se_data.ID.unique():
    # The same per-subject stop-trial filter is applied to every task frame.
    stop_query = "ID=='%s' and condition=='stop'" % sid
    sampled_ssds = set(se_data.query(stop_query).SSD.values)
    for dual_task in dual_tasks:
        dual_ssds = set(prepro_data[dual_task].query(stop_query).SSD.unique())
        sampled_ssds = sampled_ssds & dual_ssds
    shared_ssd_map[sid] = list(sampled_ssds)

In [24]:
# Show the per-subject lists of SSDs shared across all tasks (empty list = none shared).
shared_ssd_map

{'s341': [600.0, 450.0, 500.0, 550.0],
 's142': [400.0],
 's126': [400.0, 250.0, 300.0, 350.0],
 's376': [400.0, 450.0, 500.0],
 's490': [600.0, 550.0],
 's295': [450.0, 500.0],
 's205': [400.0, 450.0, 300.0, 350.0],
 's454': [350.0],
 's090': [],
 's441': [400.0, 450.0, 350.0],
 's025': [300.0, 350.0],
 's396': [200.0, 250.0, 150.0],
 's365': [600.0, 650.0, 550.0],
 's135': [],
 's397': [600.0, 450.0, 500.0, 550.0],
 's380': [400.0, 450.0, 500.0, 350.0],
 's141': [200.0, 250.0, 300.0, 350.0],
 's320': [400.0, 450.0, 500.0, 350.0],
 's066': [350.0],
 's207': [250.0, 300.0, 350.0],
 's264': [400.0, 350.0],
 's069': [400.0, 450.0, 500.0],
 's005': [400.0],
 's248': [250.0, 300.0, 350.0],
 's419': [450.0, 550.0, 650.0, 400.0, 500.0, 600.0],
 's010': [400.0, 450.0, 500.0, 550.0],
 's214': [],
 's429': [600.0, 550.0],
 's044': [350.0],
 's539': [300.0, 350.0],
 's350': [400.0, 450.0, 500.0],
 's369': [450.0, 500.0],
 's360': [400.0, 450.0, 300.0, 350.0]}

In [32]:
# Diagnostic trace for one subject: print the stop-trial SSD set of each task
# and the running intersection, to see which task empties the shared set.
sid = 's214'
stop_query = "ID=='%s' and condition=='stop'" % sid

sub_ssds = se_data.query(stop_query).SSD.values
sampled_ssds = set(sub_ssds)
print(sampled_ssds)
for dual_task in prepro_data:
    if dual_task == 'stop_signal_single_task_network':
        # The single-task SSDs are the starting set; skip that entry here.
        continue
    curr_ssds = set(prepro_data[dual_task].query(stop_query).SSD.unique())
    print(dual_task, curr_ssds)
    sampled_ssds = sampled_ssds & curr_ssds
    print(sampled_ssds)
shared_ssd_map[sid] = list(sampled_ssds)

{450.0, 300.0, 400.0, 500.0, 350.0}
stop_signal_with_predictable_task_switching {450.0, 550.0, 650.0, 400.0, 500.0, 600.0, 700.0}
{400.0, 450.0, 500.0}
stop_signal_with_n_back {450.0, 550.0, 200.0, 300.0, 400.0, 500.0, 150.0, 600.0, 250.0, 350.0}
{400.0, 450.0, 500.0}
stop_signal_with_flanker {450.0, 200.0, 300.0, 400.0, 150.0, 250.0, 350.0}
{400.0, 450.0}
stop_signal_with_shape_matching {450.0, 550.0, 300.0, 400.0, 500.0, 250.0, 350.0}
{400.0, 450.0}
stop_signal_with_cued_task_switching {450.0, 200.0, 300.0, 400.0, 250.0, 350.0}
{400.0, 450.0}
stop_signal_with_directed_forgetting {100.0, 200.0, 300.0, 150.0, 250.0, 350.0}
set()
