In [4]:
import argparse
import json
import os
import numpy as np
from trial import Trial
from collections import namedtuple
import pandas as pd
from stopsignalmetrics import StopData, SSRTmodel, PostStopSlow, Violations, StopSummary

from stoptaskstudy import StopTaskStudy, fixedSSD

In [5]:
def get_args(argv=None):
    """Build the simulator's argument parser and parse ``argv``.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When omitted, an empty list is parsed so
        that every option takes its default; this keeps the call usable from
        a notebook, where ``sys.argv`` holds the kernel's own arguments.

    Returns
    -------
    argparse.Namespace
        Parsed options: ``paramfile``, ``min_ssd``, ``max_ssd``, ``ssd_step``,
        ``random_seed``, ``tracking``, ``n_subjects`` and ``out_dir``.
    """
    parser = argparse.ArgumentParser(description='ABCD data simulator')
    parser.add_argument('--paramfile', help='json file containing parameters')
    # The SSD bounds/step are fed to np.arange by the caller, so they must be
    # parsed as numbers; without type=int a supplied value would stay a str.
    parser.add_argument('--min_ssd', help='minimum SSD value', type=int, default=0)
    parser.add_argument('--max_ssd', help='maximum SSD value', type=int, default=550)
    parser.add_argument('--ssd_step', help='SSD step size', type=int, default=50)
    parser.add_argument('--random_seed', help='random seed', type=int)
    parser.add_argument('--tracking', help='use tracking algorithm', action='store_true')
    parser.add_argument('--n_subjects', type=int,
                        help='number of subjects to simulate', default=1)
    parser.add_argument('--out_dir',
                        default='./simulated_data/pseudosubjects',
                        help='location to save simulated data')
    return parser.parse_args([] if argv is None else list(argv))

In [6]:
# Driver cell: parse options, build the SSD generator and the study, then
# simulate each subject and print per-subject summaries.
args = get_args()
print(f'simulating stop task for {args.n_subjects} subjects')

# Optional JSON parameter file; params stays None when no file was given.
# NOTE(review): params is only printed below and never handed to the study -
# confirm whether StopTaskStudy is supposed to receive it.
if args.paramfile is not None:
    with open(args.paramfile) as f:
        params = json.load(f)
else:
    params = None
print(params)

# Seed numpy's global RNG only when a seed was requested.
if args.random_seed is not None:
    np.random.seed(args.random_seed)

if args.tracking:
    # trackingSSD is not among the names imported at the top of the notebook,
    # so referencing it directly raised NameError. Import it at the point of use.
    # NOTE(review): assumed to be defined in stoptaskstudy next to fixedSSD - confirm.
    from stoptaskstudy import trackingSSD
    ssd = trackingSSD()
else:
    # Adding one step to the upper bound makes max_ssd itself part of the grid,
    # because np.arange excludes its stop value.
    ssd = fixedSSD(np.arange(args.min_ssd, args.max_ssd + args.ssd_step, args.ssd_step))
study = StopTaskStudy(ssd, args.out_dir)

# save some extra params for output to json
# (a copy of the namespace dict, so the stored params do not alias args)
study.params['args'] = dict(vars(args))
study.params['pwd'] = os.getcwd()

for i in range(args.n_subjects):
    print(f'running subject {i + 1}')
    trialdata = study.run()
    study.save_trialdata()

    # summarize data - go trials are labeled with SSD of -inf so that
    # they get included in the summary
    # numeric_only=True keeps the summary to numeric columns explicitly;
    # pandas >= 2.0 raises on non-numeric columns instead of dropping them.
    print(trialdata.groupby('SSD').mean(numeric_only=True))
    print('go_accuracy', trialdata.query('trialtype=="go"').correct.mean())
    print(study.get_stopsignal_metrics())

simulating stop task for 1 subjects
None
running subject 1
                rt      resp
SSD                         
-inf    357.063432  0.999500
 0.0    265.000000  0.005988
 50.0   283.000000  0.029940
 100.0  311.125000  0.095808
 150.0  342.909091  0.066667
 200.0  369.724138  0.175758
 250.0  394.851852  0.323353
 300.0  404.043956  0.544910
 350.0  397.130435  0.826347
 400.0  397.186667  0.898204
 450.0  393.716981  0.952096
 500.0  370.383234  1.000000
 550.0  348.024096  0.994012
go_accuracy 0.5574787393696848
{'SSRT': {'mean': 81.86343171585793, 'integration': 60.80000000000001, 'omission': 60.80000000000001, 'replacement': 60.80000000000001}, 'mean_SSD': 275.2, 'p_respond': 0.4935, 'max_RT': 985.0, 'mean_go_RT': 357.0634317158579, 'mean_stopfail_RT': 380.8014184397163, 'omission_count': 5, 'omission_rate': 0.0005, 'go_acc': 0.9341670835417709, 'stopfail_acc': 0.839918946301925}
