In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from notorious.agents import AgentsPIT
## Plot styling: white background and notebook-sized text, scaled up by 25%.
sns.set_style('white')
sns.set_context('notebook', font_scale=1.25)
## Render figures inline (already the default in recent IPython; harmless to keep).
%matplotlib inline

## Section 1: Simulate Data

In [None]:
np.random.seed(47404)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Agent parameters.
n_agents = 20

## True -> every agent gets its own parameters (the per-agent draws below);
## False -> all agents share the scalar values below.
heterogeneous = False

## Per-agent draws. These are ALWAYS drawn (even when unused) so that the global
## RNG stream -- and hence the simulated outcomes further down -- is identical
## whichever value `heterogeneous` takes.
beta_i = np.random.uniform( 4.0, 8.0, n_agents)
eta_i  = np.random.uniform( 0.1, 0.4, n_agents)
tau_i  = np.random.uniform(-0.1, 0.3, n_agents)
nu_i   = np.random.uniform( 0.2, 0.6, n_agents)

if heterogeneous:
    beta, eta, tau, nu = beta_i, eta_i, tau_i, nu_i
else:
    beta, eta, tau, nu = 8.00, 0.25, 0.20, 0.00

## Task parameters.
n_trials = 20
n_blocks = 4
## Each condition is (p1, p2, s): p1/p2 are the probabilities that action 0/1
## yields a 1 (before shifting), and s shifts the outcomes from {0, 1} to {s, s+1}.
## NOTE(review): condition names presumably encode go/no-go x win/avoid-loss -- confirm.
params = dict(GW  = (0.2,0.8, 0), NGW  = (0.8,0.2, 0),
              GAL = (0.2,0.8,-1), NGAL = (0.8,0.2,-1))

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Simulate behavior.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Initialize agents.
## NOTE(review): this single AgentsPIT instance is trained on every block and
## condition in turn; whether `train` resets learned state between calls is not
## visible here -- confirm in notorious.agents.
agents = AgentsPIT(beta, eta, tau, nu, w=0, n_agents=n_agents)

## Preallocate choices (Y) and obtained outcomes (R),
## laid out as [block, condition, agent, trial].
dims = (n_blocks, len(params), n_agents, n_trials)
Y = np.zeros(dims, dtype=int)
R = np.zeros(dims)

for i in range(n_blocks):
    for j, condition in enumerate(params.values()):
        p1, p2, s = condition

        ## Draw the outcome of BOTH actions on every trial: the last axis is the
        ## action, with probabilities (p1, p2) broadcast along it; s shifts {0,1}.
        r = np.random.binomial(1, (p1, p2), (n_agents, n_trials, 2)) + s

        ## Simulate the agents' choices given these outcomes.
        y = agents.train(r)

        ## Keep the outcome of the chosen action only, then the choices themselves.
        n_rows, n_cols = y.shape
        rows, cols = np.ogrid[:n_rows, :n_cols]
        R[i, j] = r[rows, cols, y]
        Y[i, j] = y
        
## Reshape behavior for Stan: reorder [block, condition, agent, trial] to
## [block, trial, agent, condition], then merge the last two axes so that
## column c = agent * len(params) + condition.
Y = Y.transpose(0, 3, 2, 1).reshape(n_blocks, n_trials, -1)
R = R.transpose(0, 3, 2, 1).reshape(n_blocks, n_trials, -1)

## Section 2: Fit Models with Stan

In [None]:
import os, pystan
from stantools.io import load_model, save_fit

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Define parameters.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## I/O parameters.
stan_model = 'pit_m1'

## Sampling parameters.
## NOTE: PyStan 2's `iter` (passed `samples` below) INCLUDES warmup, so each
## chain keeps (samples - warmup) / thin = 500 post-warmup draws, not 1500.
samples = 1500
warmup = 1000
chains = 4
thin = 1
n_jobs = 4

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Prepare data.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Define metadata.
N = n_agents
K = n_blocks
S = len(params)
T = n_trials

## Define mapping: column c of Y/R belongs to agent ix[c] = c // S + 1
## (1-based for Stan), matching the agent-major column layout of the Section 1 reshape.
ix = np.repeat(np.arange(N) + 1, S)

## Guard the layout `ix` relies on before handing the data to Stan.
assert Y.shape == R.shape == (K, T, N * S), (Y.shape, R.shape)
assert ix.shape == (N * S,)

## Assemble data.
dd = dict(N=N, K=K, S=S, T=T, Y=Y, R=R, ix=ix)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
### Fit Stan Model.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#

## Load the StanModel named by `stan_model` from the stan_models directory.
model_path = os.path.join('stan_models', stan_model)
StanModel = load_model(model_path)

## Sample the posterior (fixed seed so reruns are reproducible).
StanFit = StanModel.sampling(
    data=dd,
    iter=samples,
    warmup=warmup,
    chains=chains,
    thin=thin,
    n_jobs=n_jobs,
    seed=0,
)

## Save.
## TODO: `f` (the save target) is never defined in this notebook; define it
## before re-enabling the line below, or it will raise NameError.
# save_fit(f, StanFit, data=dd)