# Hierarchical Wald Model of Response Times

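# Wald Distribution Intuition

The interactive plot below shows the shifted Wald (inverse Gaussian) density with boundary $\alpha$ (`db`), drift rate $\gamma$ (`dr`) and shift / non-decision time $\theta$ (`ndt`):

$$f(x \mid \alpha, \gamma, \theta) = \frac{\alpha}{\sqrt{2\pi (x-\theta)^3}} \exp\left(-\frac{\left(\alpha - \gamma (x-\theta)\right)^2}{2 (x-\theta)}\right), \qquad x > \theta.$$

Move the sliders to see how each parameter changes the shape of the response-time distribution.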

In [4]:
import numpy as np
from ipywidgets import interact
import matplotlib.pyplot as plt

def wald(x, alpha, gamma, theta):
    """Density of the shifted Wald (inverse Gaussian) distribution.

        f(x) = alpha / sqrt(2*pi*(x - theta)**3)
               * exp(-(alpha - gamma*(x - theta))**2 / (2*(x - theta)))

    The density is only defined for x > theta. Points at or below the shift
    get density 0, so there is no division by zero and no NaN from taking
    the square root of a negative number.

    Parameters
    ----------
    x : array_like
        Points at which to evaluate the density.
    alpha : float
        Boundary (threshold) parameter.
    gamma : float
        Drift-rate parameter.
    theta : float
        Shift (non-decision time); the lower edge of the support.

    Returns
    -------
    numpy.ndarray
        Density values, broadcast to the shape of ``x``.
    """
    x = np.asarray(x, dtype=np.float64)
    t = x - theta
    on_support = t > 0
    # Off the support, substitute a harmless positive value so the formula
    # never sees t <= 0; those entries are replaced by 0 in the final
    # np.where. This avoids adding an epsilon to the denominators, which
    # would bias the density for points close to theta.
    t_safe = np.where(on_support, t, 1.0)
    first = alpha / np.sqrt(2 * np.pi * t_safe ** 3)
    second = np.exp(-(alpha - gamma * t_safe) ** 2 / (2 * t_safe))
    return np.where(on_support, first * second, 0.0)

def plot_wald(db, dr, ndt):
    """Plot the shifted Wald density for one setting of the sliders.

    Parameters
    ----------
    db : float
        Boundary; passed to ``wald`` as ``alpha``.
    dr : float
        Drift rate; passed to ``wald`` as ``gamma``.
    ndt : float
        Shift; passed to ``wald`` as ``theta``. The curve starts here.
    """
    # Evaluate only from the shift upward: the density's support is x > ndt.
    x = np.arange(ndt, 1.75, .001, dtype=np.float64)
    fig, ax = plt.subplots()
    ax.plot(x, wald(x, db, dr, ndt))
    # Fixed x-range so curves stay comparable while dragging the sliders.
    ax.set_xlim((0, 1.75))
    ax.set_xlabel('response time')
    ax.set_ylabel('density')
    ax.set_title('Wald density (db=%.2f, dr=%.2f, ndt=%.2f)' % (db, dr, ndt))
    plt.show()
    
# Each slider is a (min, max, step) tuple for the matching plot_wald argument.
# ndt stops at 1.75, the right edge of the plot window: beyond that,
# plot_wald's np.arange(ndt, 1.75, ...) is empty and the plot would be blank.
# The trailing semicolon suppresses interact's return-value repr.
interact(plot_wald, db=(0, 5, .01), dr=(0, 8, .01), ndt=(0, 1.75, .01));

# Model Structure

# Prior Determination

# Model Fitting

In [1]:
import pandas as pd
import sys
sys.path.append('../src')
from utils import select_subjects

# Load the group behavioral data and drop bad trials: a trial is kept only if
# none of the exclusion flag columns is set.
data = pd.read_csv('../data/derivatives/behavior/group_data.tsv', sep='\t',
                   na_values='n/a')
exclusions = ['no_response', 'error', 'post_error', 'fast_rt']
data = data[data[exclusions].sum(axis=1) == 0]

# Keep only the subjects returned by select_subjects('both').
subjects = select_subjects('both')
data = data[data.participant_id.isin(subjects)]

# Assemble the data dictionary passed to Stan as plain numpy arrays.
# groupby() sorts its keys, and astype('category') sorts its categories, so
# min_rt[s - 1] belongs to the subject whose 1-based code in sub_ix is s.
data_in = {}
data_in['Nt'] = data.shape[0]
data_in['rt'] = data.response_time.values
data_in['Ns'] = len(data.participant_id.unique())
# .as_matrix() was removed in pandas 1.0; .values works in every version.
# The .001 margin keeps each subject's bound strictly below their fastest RT
# (presumably an upper bound for ndt in the Stan model -- TODO confirm).
data_in['min_rt'] = data.groupby(['participant_id']).response_time.min().values - .001
data_in['sub_ix'] = data.participant_id.astype('category').cat.codes.values + 1
data_in['tt'] = data.trial_type.astype('category').cat.codes.values
data_in['mod'] = data.modality.astype('category').cat.codes.values

# Sanity checks: Stan would reject missing RTs, and the per-subject arrays
# must line up with the subject index.
assert data.response_time.notna().all(), 'missing response times after exclusions'
assert len(data_in['min_rt']) == data_in['Ns']
assert data_in['sub_ix'].max() == data_in['Ns']

In [2]:
import pickle
import os
import pystan

model_name = 'wald_hierarchical'

model_f = '../models/%s.stan' % model_name
cache_f = '../models/%s.pkl' % model_name


def _cache_is_stale(model_path, cache_path):
    """Return True when the compiled-model pickle must be (re)built.

    The cache is stale if it does not exist, or if the Stan source exists and
    was modified after the pickle was written.
    """
    if not os.path.exists(cache_path):
        return True
    if not os.path.exists(model_path):
        return False
    return os.path.getmtime(model_path) > os.path.getmtime(cache_path)


if _cache_is_stale(model_f, cache_f):
    model = pystan.StanModel(file=model_f, model_name=model_name)
    # Pickles are bytes, so open in binary mode (text mode fails on Python 3).
    # `with` ensures the file is flushed and closed.
    with open(cache_f, 'wb') as fh:
        pickle.dump(model, fh)
else:
    # Unpickling can execute arbitrary code: only load caches this notebook
    # wrote itself.
    with open(cache_f, 'rb') as fh:
        model = pickle.load(fh)

In [7]:
# Sampler settings. In PyStan, `iter` includes the warmup iterations, so each
# chain keeps (niter - nwarmup) / nthin = 200 draws, or 600 across 3 chains.
niter = 3000
nwarmup = 2000
nchains = 3
nthin = 5

# init=0 starts every chain at zero on the unconstrained scale; the fixed seed
# makes the run reproducible.
fit = model.sampling(data=data_in, iter=niter, warmup=nwarmup,
                     chains=nchains, init=0, seed=13, thin=nthin)

In [4]:
import sys
sys.path.append('../src')
from stan_utils import extract_summary, extract_samples

# Variables left out of the saved samples; the summary additionally drops
# rt_pred and log_lik.
sample_exclude = ['mu', 'lambda', 'db', 'dr', 'ndt', 'ndt_tmp', 'ndt_sub_tmp']
summary_exclude = sample_exclude + ['rt_pred', 'log_lik']

# Derive output paths from model_name, using the same scheme as the .stan and
# .pkl files, so that fitting a different model cannot overwrite these CSVs.
samples_f = '../models/%s_samples.csv' % model_name
summary_f = '../models/%s_summary.csv' % model_name

samples = extract_samples(fit, sample_exclude)
samples.to_csv(samples_f, index=False)

summary = extract_summary(fit, summary_exclude)
summary.to_csv(summary_f, index=True)

# Model Diagnostics

In [8]:
import sys
sys.path.append('../src')
from stan_utils import param_plot
from ipywidgets import interact
import pandas as pd


# Sample columns whose names contain 'dr', 'db' or 'ndt' (substring match),
# in column order.
params = [name for name in samples.columns.values
          if any(key in name for key in ('dr', 'db', 'ndt'))]

def view_param(param):
    """Show the figure that param_plot builds for a single parameter.

    Parameters
    ----------
    param : str
        Column name in ``samples``, chosen from the ``params`` dropdown.
    """
    param_plot(param, samples, summary).show()

# Dropdown over the matching parameter names. The trailing semicolon keeps
# interact's return value from being echoed as cell output.
interact(view_param, param=params);

<function __main__.view_param>