In [1]:
import numpy as np
import pandas as pd
import argparse
from glob import glob
from os import path

from sympy.solvers import solve
from sympy import Symbol
import scipy.stats as sstats

from utils import SimulateData
from stopsignalmetrics import SSRTmodel

In [2]:
# Run the simulator with guessing enabled, passing arrays for the number of
# stop trials and the number of stop-trial guesses.
# NOTE(review): the two aranges have different lengths (12 vs 13 values) --
# confirm that `simulate` is meant to receive arrays that do not line up.
simulator = SimulateData(guesses=True)
params = dict(
    n_trials_stop=np.arange(0, 600, 50),
    n_guess_stop=np.arange(0, 560, 45),
)

tsts_data = simulator.simulate(params)

In [3]:
# Inspect the trial table returned by simulator.simulate(params).
tsts_data

Unnamed: 0,condition,SSD,trial_idx,mu_go,mu_stop,accum_go,accum_stop,process_go,process_stop,block,goRT,stopRT
0,stop,50.0,0,0.2,0.6,32.995480,101.512016,"[1.1228058859341903, 0, 0, 0.5266842054572151,...","[0.7322789504604907, 3.7958496540778, 2.917113...",0,,
1,stop,50.0,1,0.2,0.6,74.681183,100.396750,"[0, 2.1833643050332023, 3.4249850214325503, 2....","[0, 0.6665193095796709, 1.45077921380828, 0.15...",0,,
2,stop,50.0,2,0.2,0.6,29.196587,101.332944,"[1.8606452469010342, 0.162229528895256, 0.9307...","[3.6203777962329564, 2.9030782165567004, 1.829...",0,,
3,stop,50.0,3,0.2,0.6,57.849907,100.268183,"[1.4305284370461606, 1.5847763258813474, 2.255...","[0, 0.42176275492078563, 0, 0, 1.2612256027382...",0,,
4,stop,50.0,4,0.2,0.6,57.648952,100.111675,"[1.0470023023938138, 0.8345209970179899, 0, 0....","[0.5064049141548703, 0, 2.8960931403509207, 4....",0,,
...,...,...,...,...,...,...,...,...,...,...,...,...
4295,go,,995,0.2,0.6,100.681059,0.000000,"[1.9916852223569887, 3.419591988091699, 2.3246...",[],0,680.0,
4296,go,,996,0.2,0.6,101.561984,0.000000,"[0.4801801977031876, 0, 0, 0.19593739666170887...",[],0,333.0,
4297,go,,997,0.2,0.6,102.152071,0.000000,"[0.7584042620262557, 0, 0.5218836821677137, 1....",[],0,750.0,
4298,go,,998,0.2,0.6,100.157670,0.000000,"[0, 2.6534121222326603, 3.3962379139223207, 2....",[],0,656.0,


In [4]:
# Check the values that were passed as params['n_guess_stop'].
np.arange(0, 560, 45)

array([  0,  45,  90, 135, 180, 225, 270, 315, 360, 405, 450, 495, 540])

In [5]:
# Peek at the simulator's private stop-trial counts after simulate() has run.
simulator._n_trials_stop

{0: 0,
 50: 50,
 100: 100,
 150: 150,
 200: 200,
 250: 250,
 300: 300,
 350: 350,
 400: 400,
 450: 450,
 500: 500,
 550: 550}

In [6]:
# Peek at the simulator's private guess counts after simulate() has run.
# NOTE(review): every value displayed is 0.0 even though a non-zero
# n_guess_stop array was supplied -- confirm this is the intended result.
simulator._n_guess_stop

{0: 0.0,
 50: 0.0,
 100: 0.0,
 150: 0.0,
 200: 0.0,
 250: 0.0,
 300: 0.0,
 350: 0.0,
 400: 0.0,
 450: 0.0,
 500: 0.0,
 550: 0.0}

# Figuring out graded mu go

In [7]:
import numpy as np
import pandas as pd
import argparse
from os import path
from glob import glob
import scipy.stats as sstats

from stopsignalmetrics import SSRTmodel
from utils import SimulateData
from simulate import generate_exgauss_sampler_from_fit


def get_args():
    """Build the default configuration for the ABCD simulation analysis.

    The parser is run against an empty argument list (``parse_args([])``), so
    the process's real command line is ignored and every option takes its
    default value.

    Returns
    -------
    argparse.Namespace
        With string attributes ``abcd_dir``, ``sim_dir`` and ``out_dir`` and
        the integer attribute ``n_graded_go_trials``.
    """
    parser = argparse.ArgumentParser(description='ABCD data simulations')
    parser.add_argument('--abcd_dir', default='./abcd_data',
                        help='location of ABCD data')
    parser.add_argument('--sim_dir', default='./simulated_data',
                        help='location of simulated data')
    parser.add_argument('--out_dir', default='./ssrt_metrics',
                        help='location to save ssrt metrics')
    # type=int keeps the value numeric if it is ever supplied as text;
    # without it argparse would hand back a str and break trial-count maths.
    parser.add_argument('--n_graded_go_trials', default=10000, type=int,
                        help='number of trials to simulate per SSD for the '
                             'graded go RT distributions')
    return parser.parse_args([])


def generate_out_df(data, SSD_guess_dict, graded_go_dict):
    """Compute SSRT metrics separately for every SSD in a simulated dataset.

    For each non-missing SSD, the go trials plus the stop trials at that SSD
    are fit with ``SSRTmodel(model='replacement')``. Unless ``p_respond`` is
    exactly 0 or 1, two additional SSRT estimates are computed with
    ``SSRT_wReplacement``: one against the go RTs mixed with guess RTs, and
    one against the graded go RTs for that SSD.

    Parameters
    ----------
    data : pandas.DataFrame
        Trials with at least ``condition``, ``SSD`` and ``goRT`` columns.
    SSD_guess_dict : dict
        Maps SSD -> guess probability (forwarded to
        ``add_guess_RTs_and_sort``).
    graded_go_dict : dict
        Maps SSD -> sorted go RTs; each entry is copied before use.

    Returns
    -------
    pandas.DataFrame
        One row per SSD. Columns are the keys of the fitted metrics followed
        by ``SSD``, ``SSRT_w_guesses`` and ``SSRT_w_graded`` (the last two are
        NaN when ``p_respond`` is 0 or 1). If ``data`` has no non-missing SSD
        the frame is empty and carries only those three trailing columns.
    """
    ssrtmodel = SSRTmodel(model='replacement')
    goRTs = data.loc[data.goRT.notnull(), 'goRT'].values
    # NaN != NaN, so the self-comparison filters out missing SSDs.
    SSDs = sorted(i for i in data.SSD.unique() if i == i)

    extra_cols = ['SSD', 'SSRT_w_guesses', 'SSRT_w_graded']
    # Bound before the loop so an input without any SSD yields an empty frame
    # instead of an UnboundLocalError at the return statement.
    cols = extra_cols
    info = []

    for SSD in SSDs:
        curr_df = data.query(
            "condition=='go' | condition=='stop' and SSD == %s" % SSD
            ).copy()
        curr_metrics = ssrtmodel.fit_transform(curr_df)
        if (curr_metrics['p_respond'] == 0) | (curr_metrics['p_respond'] == 1):
            # No usable position in the go RT distribution at these extremes.
            SSRT_w_guesses = np.nan
            SSRT_w_graded = np.nan
        else:
            verbose = SSD < 200  # only trace the short SSDs
            goRTs_w_guesses = add_guess_RTs_and_sort(goRTs,
                                                     SSD,
                                                     SSD_guess_dict)
            if verbose:
                print('w guesses:')
            SSRT_w_guesses = SSRT_wReplacement(curr_metrics,
                                               goRTs_w_guesses,
                                               verbose=verbose)
            if verbose:
                print('w graded:')
            SSRT_w_graded = SSRT_wReplacement(curr_metrics,
                                              graded_go_dict[SSD].copy(),
                                              verbose=verbose)

        info.append(list(curr_metrics.values()) +
                    [SSD, SSRT_w_guesses, SSRT_w_graded])
        cols = list(curr_metrics.keys()) + extra_cols

    return pd.DataFrame(info, columns=cols)


def add_guess_RTs_and_sort(goRTs, SSD, SSD_guess_dict, sampler=None):
    """Return sorted RTs that mix go RTs with sampled guess RTs for one SSD.

    The mixing proportion is ``p_guess = SSD_guess_dict[SSD]``:

    * ``p_guess <= 0``: no guesses are added; a sorted copy of ``goRTs`` is
      returned (the caller's array is left untouched).
    * ``p_guess == 1.0``: the result is ``len(goRTs)`` sampled guess RTs only.
    * otherwise: ``n_guess`` guess RTs are appended to ``goRTs`` so that the
      guesses make up a fraction ``p_guess`` of the combined set.

    Parameters
    ----------
    goRTs : array-like
        Go reaction times.
    SSD : hashable
        Key into ``SSD_guess_dict``.
    SSD_guess_dict : dict
        Maps SSD -> guess probability.
    sampler : callable, optional
        ``sampler(n)`` returns ``n`` guess RTs. When omitted, the notebook-level
        ``sample_exgauss`` is looked up, but only if guesses are actually
        needed.

    Returns
    -------
    numpy.ndarray
        The sorted reaction times.
    """
    curr_n = len(goRTs)
    p_guess = SSD_guess_dict[SSD]

    if p_guess <= 0:
        # Estimated guess rates can come out slightly negative
        # (e.g. SSDs 550 and 650); treat those as "no guessing".
        return np.sort(goRTs)

    if sampler is None:
        # Resolved lazily so the no-guess path works without the global.
        sampler = sample_exgauss

    if p_guess == 1.0:
        return np.sort(sampler(curr_n))

    # Equation logic:
    # p_guess = n_guess / (n_guess + curr_n) =>
    # n_guess = (p_guess * curr_n) / (1 - p_guess)
    n_guess = int(np.rint(float((p_guess * curr_n) / (1 - p_guess))))
    return np.sort(np.concatenate([goRTs, sampler(n_guess)]))


def simulate_graded_RTs_and_sort(n_trials, SSD, verbose=False):
    """Simulate go trials with an SSD-adjusted ``mu_go`` and return sorted RTs.

    A fresh ``SimulateData`` is set up with its default parameters, the go and
    stop trial counts are both set to ``n_trials``, and ``mu_go`` is replaced
    by ``_log_mu_go(mu_go, SSD)`` before the go trials are simulated.

    Parameters
    ----------
    n_trials : int
        Value used for both ``n_trials_go`` and ``n_trials_stop``.
    SSD : number
        Stop-signal delay passed to ``_log_mu_go``.
    verbose : bool
        If true, print the SSD followed by the RT at every 5th percentile
        from 0 to 95.

    Returns
    -------
    The ``'RT'`` entry of the simulated data dict, sorted in place.
    """
    sim = SimulateData()
    sim_params = sim._init_params({})
    sim_params['n_trials_stop'] = n_trials
    sim_params['n_trials_go'] = n_trials
    sim_params['mu_go'] = sim._log_mu_go(sim_params['mu_go'], SSD)

    sim._set_n_trials(sim_params)
    sim._set_n_guesses(sim_params)  # no guessing is happening

    simulated = sim._simulate_go_trials(sim._init_data_dict(), sim_params)
    sorted_RTs = simulated['RT']
    sorted_RTs.sort()

    if not verbose:
        return sorted_RTs

    print(SSD)
    for pct in np.arange(0, 100, 5):
        print(pct, sstats.scoreatpercentile(sorted_RTs, pct))
    return sorted_RTs


def get_nth_RT(P_respond, goRTs):
    """Pick the go RT at the rank implied by P(response|signal).

    The rank is ``round(P_respond * len(goRTs)) - 1`` (``np.rint`` rounding);
    ranks that fall before the first or after the last element are clamped to
    the first or last RT respectively. ``goRTs`` is expected to be sorted.
    """
    last_index = len(goRTs) - 1
    rank = int(np.rint(P_respond * len(goRTs))) - 1

    if rank < 0:
        return goRTs[0]
    if rank > last_index:
        return goRTs[-1]
    return goRTs[rank]


def SSRT_wReplacement(metrics, sorted_go_RTs, verbose=False):
    """Estimate SSRT as the nth go RT minus the mean SSD.

    ``metrics['omission_count']`` copies of ``metrics['max_RT']`` are appended
    to ``sorted_go_RTs``; the nth RT is then chosen by ``get_nth_RT`` using
    ``metrics['p_respond']``, and ``metrics['mean_SSD']`` is subtracted.

    NOTE(review): the appended values are not re-sorted, so the combined array
    is only sorted if ``max_RT`` is >= every entry of ``sorted_go_RTs`` --
    confirm that holds for the distributions passed in.

    With ``verbose`` set, the mean SSD, ``p_respond`` and the chosen nth RT
    are printed.
    """
    p_respond = metrics['p_respond']
    omission_fill = [metrics['max_RT']] * metrics['omission_count']
    padded_RTs = np.concatenate((sorted_go_RTs, omission_fill))

    nth_RT = get_nth_RT(p_respond, padded_RTs)
    if verbose:
        trace = (('SSD', metrics['mean_SSD']),
                 ('p_respond', p_respond),
                 ('nrt', nth_RT))
        for label, value in trace:
            print(label, value)
    return nth_RT - metrics['mean_SSD']

In [9]:
    args = get_args()

    # LOAD ABCD INPUTS
    abcd_data = pd.read_csv('%s/minimal_abcd_clean.csv' % args.abcd_dir)
    p_guess_df = pd.read_csv('%s/p_guess_per_ssd.csv' % args.abcd_dir)

    # Column headers are SSDs; the first row holds the guess rate for each.
    SSD_guess_dict = dict(
        (float(col), float(p_guess_df[col].values[0]))
        for col in p_guess_df.columns
    )
    print(SSD_guess_dict)

    # Fit the guess-RT sampler to the RTs of SSD-0 trials with correct_stop == 0.
    ssd0_rows = abcd_data.query(
        "SSDDur == 0.0 and correct_stop==0.0"
        )
    SSD0_RTs = ssd0_rows.stop_rt_adjusted.values
    sample_exgauss = generate_exgauss_sampler_from_fit(SSD0_RTs)

    # BUILD ONE GRADED GO RT DISTRIBUTION PER OBSERVED SSD
    # (ssd == ssd is False only for NaN, so missing SSDs are skipped)
    observed_SSDs = [ssd for ssd in abcd_data.SSDDur.unique() if ssd == ssd]
    graded_go_dict = {}
    for SSD in observed_SSDs:
        graded_go_dict[SSD] = simulate_graded_RTs_and_sort(
            args.n_graded_go_trials,
            SSD,
            verbose=(SSD < 200))

#     # CALCULATE SSRT
#     for data_file in [glob(path.join(args.sim_dir, '*.csv'))[0]]:
#         sim_type = path.basename(
#             data_file
#             ).replace('.csv', '')
#         out_df = generate_out_df(pd.read_csv(data_file),
#                                  SSD_guess_dict,
#                                  graded_go_dict)
#         out_df.to_csv(path.join(args.out_dir, '%s.csv' % sim_type))

{50.0: 0.845301348829293, 100.0: 0.610622825783264, 0.0: 1.0, 150.0: 0.452841954188218, 200.0: 0.2941470114067, 250.0: 0.200997922993901, 300.0: 0.14331766365641901, 350.0: 0.0798946466477947, 400.0: 0.0564683334369033, 450.0: 0.0271712158267143, 500.0: 0.0191213813425612, 550.0: -0.0013543997871993901, 600.0: 0.00573360220001983, 650.0: -0.00349382177060603, 700.0: 0.0748603699110067, 750.0: 0.035672034056798, 800.0: 0.185669151023814, 850.0: 0.156473540451239, 900.0: 0.26911106708549204}
50.0
0 98.0
5 197.0
10 234.0
15 267.0
20 298.0
25 328.0
30 360.0
35 396.0
40 433.0
45 470.0
50 509.5
55 555.0
60 607.3999999999996
65 661.0
70 724.0
75 802.0
80 896.0
85 1024.1499999999996
90 1191.2000000000007
95 1490.0
100.0
0 93.0
5 187.0
10 219.0
15 247.0
20 274.0
25 302.0
30 329.0
35 358.0
40 387.0
45 418.0
50 452.0
55 490.0
60 528.0
65 575.0
70 633.2999999999993
75 697.0
80 776.0
85 877.0
90 1019.1000000000004
95 1266.0499999999993
0.0
0 106.0
5 225.0
10 278.0
15 328.0
20 374.0
25 424.0
30 478.

In [13]:
# Sorted simulated go RTs for SSD 100, as returned by
# simulate_graded_RTs_and_sort.
# NOTE(review): this displays the entire list; consider showing a slice
# (e.g. the first 20 values) to keep the notebook output short.
graded_go_dict[100.0]

[93,
 110,
 115,
 120,
 123,
 124,
 127,
 127,
 129,
 131,
 131,
 131,
 132,
 132,
 133,
 134,
 135,
 136,
 137,
 137,
 139,
 139,
 140,
 140,
 140,
 141,
 141,
 141,
 142,
 146,
 147,
 148,
 148,
 149,
 153,
 153,
 154,
 154,
 155,
 155,
 156,
 159,
 161,
 161,
 162,
 164,
 164,
 165,
 166,
 166,
 167,
 167,
 167,
 169,
 171,
 172,
 172,
 173,
 173,
 173,
 174,
 174,
 175,
 175,
 176,
 177,
 177,
 178,
 178,
 179,
 179,
 180,
 181,
 181,
 182,
 184,
 184,
 186,
 186,
 186,
 186,
 187,
 188,
 189,
 190,
 190,
 190,
 192,
 192,
 193,
 194,
 195,
 195,
 196,
 196,
 197,
 197,
 198,
 198,
 199,
 200,
 200,
 201,
 201,
 202,
 202,
 203,
 203,
 204,
 204,
 204,
 204,
 204,
 204,
 204,
 206,
 206,
 206,
 207,
 207,
 208,
 209,
 209,
 210,
 210,
 211,
 211,
 211,
 212,
 212,
 212,
 212,
 212,
 213,
 213,
 213,
 214,
 215,
 215,
 215,
 216,
 216,
 216,
 216,
 217,
 217,
 218,
 218,
 219,
 220,
 220,
 220,
 220,
 221,
 221,
 222,
 222,
 223,
 224,
 224,
 226,
 226,
 226,
 227,
 227,
 227,
 228,