## Likelihood-free estimation of stop model parameters

This notebook shows how to estimate the parameters of the stop model with approximate Bayesian computation via sequential Monte Carlo (ABC-SMC), using the pyABC package.

In [1]:
# imports

import pyabc
import json
from pyabc import (ABCSMC,
                   RV, Distribution)
import numpy as np
import scipy.stats as st
import tempfile
import os
import pandas as pd
import matplotlib.pyplot as plt
import logging

from ssd import fixedSSD
from stoptaskstudy import StopTaskStudy
%matplotlib inline



In [2]:
# Set to True for debugging outputs.
# When enabled, the 'Distance' logger is switched to DEBUG level; this is the
# logger that emits the "DEBUG:Distance:updated weights[...]" lines shown in
# the output of the ABC run further down.

debug = True

if debug:
    df_logger = logging.getLogger('Distance')
    df_logger.setLevel(logging.DEBUG)



In [6]:
# create a single-layer metrics dict for output
# since the pickler can't handle multilevel dicts
def cleanup_metrics(metrics):
    """Return a single-level copy of ``metrics`` with the 'SSRT' entry flattened.

    Every ``metrics['SSRT'][k]`` is exposed as a top-level ``'SSRT_<k>'`` key
    and the nested ``'SSRT'`` entry itself is left out of the result.

    The input dict is not modified, so the function can safely be called
    more than once on the same object. A dict without an ``'SSRT'`` entry is
    returned as a plain shallow copy.

    Parameters
    ----------
    metrics : dict
        Metrics dict, optionally containing a nested dict under ``'SSRT'``.

    Returns
    -------
    dict
        New dict with only top-level keys.
    """
    # copy everything except the nested entry, preserving key order
    flat = {key: value for key, value in metrics.items() if key != 'SSRT'}
    # append (or overwrite) the flattened SSRT_* entries
    for key, value in metrics.get('SSRT', {}).items():
        flat['SSRT_' + key] = value
    return flat


def stopsignal_model(parameters, paramfile='params.json'):
    """Run one simulated stop-task study and return flat summary statistics.

    Baseline settings are read from ``paramfile`` and the entries for
    ``mu`` (go/stop), ``mu_delta_incorrect``, ``noise_sd`` and
    ``nondecision`` are overridden with the values in ``parameters``. The
    study is simulated with fixed SSDs from 0 to 550 in steps of 50.

    Parameters
    ----------
    parameters : mapping
        Must provide 'mu_go', 'mu_stop_delta', 'mu_delta_incorrect',
        'noise_sd' and 'nondecision'. The mapping is only read, never
        modified.
    paramfile : str, optional
        Path of the JSON file holding the baseline parameters.

    Returns
    -------
    dict
        'mean_go_RT', 'mean_stopfail_RT' and 'go_acc' taken from the study
        metrics, plus one 'presp_<i>' entry per SSD >= 0 (ascending SSD
        order) holding the mean of ``resp`` at that SSD.
    """
    with open(paramfile) as f:
        params = json.load(f)

    # nondecision is truncated to an integer. Keep the truncated value in a
    # local instead of writing it back into ``parameters``: that object
    # belongs to the caller (e.g. the sampler) and must not be altered.
    nondecision = int(parameters['nondecision'])

    # install the parameters from the simulation
    params['mu']['go'] = parameters['mu_go']
    params['mu']['stop'] = parameters['mu_go'] + parameters['mu_stop_delta']
    params['mu_delta_incorrect'] = parameters['mu_delta_incorrect']
    params['noise_sd'] = {'go': parameters['noise_sd'],
                          'stop': parameters['noise_sd']}
    params['nondecision'] = {'go': nondecision,
                             'stop': nondecision}

    # TBD
    #    if args.p_guess_file is not None:
    #        p_guess = pd.read_csv(args.p_guess_file, index_col=0)
    #        assert 'SSD' in p_guess.columns and 'p_guess' in p_guess.columns

    #    if args.random_seed is not None:
    #        np.random.seed(args.random_seed)

    min_ssd, max_ssd, ssd_step = 0, 550, 50
    # max_ssd + ssd_step makes the upper bound inclusive
    ssd = fixedSSD(np.arange(min_ssd, max_ssd + ssd_step, ssd_step))

    study = StopTaskStudy(ssd, None, params=params)

    trialdata = study.run()
    metrics = cleanup_metrics(study.get_stopsignal_metrics())

    # summarize data - per the original note, go trials are labeled with an
    # SSD of -inf; they are dropped here by keeping only SSD >= 0.
    # Only ``resp`` is aggregated: averaging every column is unnecessary and
    # fails on pandas >= 2.0 if trialdata holds any non-numeric column.
    presp_by_ssd = trialdata.groupby('SSD')['resp'].mean()
    stop_data = presp_by_ssd[presp_by_ssd.index >= 0].values

    results = {k: metrics[k]
               for k in ['mean_go_RT', 'mean_stopfail_RT', 'go_acc']}
    for i, value in enumerate(stop_data):
        results[f'presp_{i}'] = value

    return results

# Prior over the five free model parameters, one uniform RV each.
# NOTE(review): RV forwards its arguments to the scipy.stats distribution of
# that name, so ("uniform", a, b) means loc=a, scale=b, i.e. support
# [a, a + b] -- e.g. nondecision spans [40, 140], not [40, 100], and
# noise_sd spans [1, 5]. Confirm that these are the intended ranges.
prior_specs = {
    'mu_go': ("uniform", 0, 1),
    'mu_stop_delta': ("uniform", 0, 1),
    'mu_delta_incorrect': ("uniform", 0, 1),
    'noise_sd': ("uniform", 1, 4),
    'nondecision': ("uniform", 40, 100),
}
parameter_prior = Distribution(
    **{name: RV(*spec) for name, spec in prior_specs.items()})
parameter_prior.get_parameter_names()


['mu_delta_incorrect', 'mu_go', 'mu_stop_delta', 'noise_sd', 'nondecision']

In [7]:
# One fixed parameter set, used to run a single example simulation
params = dict(
    mu_delta_incorrect=0.10386248711279868,
    mu_go=0.11422675271126799,
    mu_stop_delta=0.7850423871897488,
    noise_sd=3.1238287051634597,
    nondecision=50,
)
simulation = stopsignal_model(params)
simulation

{'mean_go_RT': 449.40858169520186,
 'mean_stopfail_RT': 336.0724946695096,
 'go_acc': 0.6871697315578101,
 'presp_0': 0.0,
 'presp_1': 0.05952380952380952,
 'presp_2': 0.11904761904761904,
 'presp_3': 0.24096385542168675,
 'presp_4': 0.4878048780487805,
 'presp_5': 0.5421686746987951,
 'presp_6': 0.5833333333333334,
 'presp_7': 0.6265060240963856,
 'presp_8': 0.6867469879518072,
 'presp_9': 0.7142857142857143,
 'presp_10': 0.7590361445783133,
 'presp_11': 0.8095238095238095}

In [9]:

def rmse(a, b):
    """Root-mean-square error between ``a`` and ``b``.

    Accepts scalars, sequences or arrays (anything ``np.asarray`` can
    convert to float); the usual numpy broadcasting applies. For scalar
    inputs this reduces to ``abs(a - b)``.

    The squared differences are averaged rather than summed, so the result
    does not grow with the number of elements compared.
    """
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return np.sqrt(np.mean(diff ** 2))

# sum errors for presp, gort, go accuracy and stopfailrt
# scaling factors were determined by hand to roughly equate the rmse
# for the different result variables
def distance(simulation, data):
    """Hand-weighted distance between a simulated and an observed summary.

    Both arguments are flat dicts with the keys 'mean_go_RT',
    'mean_stopfail_RT', 'go_acc' and one 'presp_<i>' entry per SSD.
    Every 'presp_*' key present in ``simulation`` must also exist in
    ``data``; a missing key raises ``KeyError``.

    NOTE(review): the presp term now covers every 'presp_*' entry instead
    of 'presp_0' alone, so the hand-tuned factor of 3 may need retuning.
    """
    presp_keys = [k for k in simulation if k.startswith('presp_')]
    # per-entry errors are summed, mirroring distance2
    presp_rmse = sum(rmse(simulation[k], data[k]) for k in presp_keys) * 3
    gort_rmse = rmse(simulation['mean_go_RT'], data['mean_go_RT']) / 10
    goacc_rmse = rmse(simulation['go_acc'], data['go_acc']) * 10
    # the summary dicts store this statistic under 'mean_stopfail_RT'
    stopfailrt_rmse = rmse(simulation['mean_stopfail_RT'],
                           data['mean_stopfail_RT']) / 10
    return presp_rmse + gort_rmse + stopfailrt_rmse + goacc_rmse

def distance2(simulation, data):
    """Unweighted distance: the sum of ``rmse`` over every key of ``simulation``.

    ``data`` must contain each key found in ``simulation``.
    """
    return sum(rmse(simulation[key], data[key]) for key in simulation)

# Two p-norm (p=2) distances are created; only the adaptive one is handed to
# the sampler below, the fixed one is kept around as an alternative.
distance_fixed = pyabc.PNormDistance(p=2)
distance_adaptive = pyabc.AdaptivePNormDistance(p=2)

abc = ABCSMC(
    stopsignal_model,
    parameter_prior,
    distance_function=distance_adaptive,
)


INFO:Sampler:Parallelizing the sampling on 4 cores.


In [12]:
# The run history is stored in a SQLite file inside the system temp directory
db_file = os.path.join(tempfile.gettempdir(), "test.db")
db_path = "sqlite:///" + db_file

# whitespace-delimited table; its `presp` column supplies the observed values
observed_presp = pd.read_csv('presp_by_ssd_inperson.txt',
                             delimiter=r"\s+",
                             index_col=0)

# Observed summary statistics, laid out with the same flat keys that
# stopsignal_model returns (one presp_<i> entry per row of the table)
observed_data = {'mean_go_RT': 455.367, 'mean_stopfail_RT': 219.364, 'go_acc': .935}
observed_data.update(
    (f'presp_{idx}', presp)
    for idx, presp in enumerate(observed_presp.presp.values))

abc.new(db_path, observed_data)

INFO:History:Start <ABCSMC(id=5, start_time=2020-12-06 11:23:29.384133, end_time=None)>


<pyabc.storage.history.History at 0x7fdd0c749af0>

In [13]:
# Sanity check: distance of the example simulation from the observed data,
# first with the adaptive p-norm distance, then with the unweighted distance2
for dist_fn in (distance_adaptive, distance2):
    print(dist_fn(simulation, observed_data))

116.86203473053882
124.64008614833743


In [None]:
# Run the sampler: the two arguments cap the number of populations at 12 and
# set the minimum epsilon to 0.1
history = abc.run(
    minimum_epsilon=0.1,
    max_nr_populations=12,
)

INFO:ABC:Calibration sample before t=0.
DEBUG:Distance:updated weights[0] = {'mean_go_RT': 0.001784751425403656, 'mean_stopfail_RT': 0.0024140436683980247, 'go_acc': 1.4063451505849232, 'presp_0': 1.0972415819411678, 'presp_1': 0.8229296093829475, 'presp_2': 0.7702541585305369, 'presp_3': 0.8640236160929868, 'presp_4': 0.9578275282057516, 'presp_5': 1.0256230971573053, 'presp_6': 1.152266777027476, 'presp_7': 1.1929478113672578, 'presp_8': 1.2895688557172515, 'presp_9': 1.4259502370661732, 'presp_10': 1.4564301367862342, 'presp_11': 1.5343926450461862}
INFO:Epsilon:initial epsilon is 1.2135460118448589
INFO:ABC:t: 0, eps: 1.2135460118448589.
INFO:ABC:Acceptance rate: 100 / 200 = 5.0000e-01, ESS=1.0000e+02.
DEBUG:Distance:updated weights[1] = {'mean_go_RT': 0.0016479132673832236, 'mean_stopfail_RT': 0.0021345288169192216, 'go_acc': 1.3123808677497324, 'presp_0': 1.3443697703981174, 'presp_1': 0.8005459776087006, 'presp_2': 0.7186858381600166, 'presp_3': 0.7825691651886789, 'presp_4': 0.

In [None]:
# Optional: overlay the 1-d KDE of mu_go for every population t = 0..max_t
plot_kde = False
if plot_kde:
    fig, ax = plt.subplots()
    for t in range(history.max_t + 1):
        particles, weights = history.get_distribution(m=0, t=t)
        pyabc.visualization.plot_kde_1d(
            particles,
            weights,
            x="mu_go",
            xmin=0,
            xmax=1,
            ax=ax,
            label=f"PDF t={t}",
        )
    ax.legend();


In [None]:
def get_map_estimates(ci_ax):
    """Read one point estimate per parameter off a credible-interval plot.

    For every axis in ``ci_ax`` the last y-value of the first plotted line is
    taken and stored under that axis' y-label.

    NOTE(review): whether that first line is really a MAP estimate (rather
    than e.g. the posterior median per population) depends on what
    ``plot_credible_intervals`` draws -- confirm against the pyABC docs.
    """
    return {ax.get_ylabel(): ax.get_lines()[0].get_ydata()[-1]
            for ax in ci_ax}


plot_ci = False
# stays None while the plot is disabled, so the name always exists
map_estimates = None
if plot_ci:
    ci_ax = pyabc.visualization.plot_credible_intervals(history)
    map_estimates = get_map_estimates(ci_ax)

# last top-level expression, so the estimates are displayed when available
map_estimates

In [None]:

# Re-simulate the model at the estimates read off the credible-interval plot.
# Those estimates only exist when plot_ci was enabled, so the whole step is
# skipped otherwise instead of failing.
plot_presp = False
if plot_ci:
    simulation = stopsignal_model(map_estimates)
    print(simulation)

    if plot_presp:
        # the simulation stores presp as separate presp_0, presp_1, ... keys;
        # collect them back into one ordered sequence for plotting
        n_presp = sum(key.startswith('presp_') for key in simulation)
        simulated_presp = [simulation[f'presp_{i}'] for i in range(n_presp)]

        fig, ax = plt.subplots()
        ax.plot(simulated_presp, label='simulated')
        ax.plot(observed_presp.presp.values, 'k', label='observed')
        ax.set_xlabel('SSD index')
        ax.set_ylabel('presp')
        ax.legend();
else:
    print('plot_ci is False: no estimates available, skipping re-simulation')


In [None]:
# show the fixed parameter set that was used for the example simulation
params