# Simulate stochastic choices

author: steeve laquitaine

This tutorial simulates the stochastic choices made by a standard Bayesian model.

## Setup

In [4]:
# go to the project's root path
import os
os.chdir("..")

In [5]:
# import dependencies
from bsfit.nodes.models.bayes import StandardBayes
from bsfit.nodes.dataEng import (
    simulate_task_conditions,
)
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Set the parameters

In [19]:
# set the parameters
SUBJECT = "sub01"
PRIOR_SHAPE = "vonMisesPrior"
PRIOR_MODE = 225
OBJ_FUN = "maxLLH"
READOUT = "map"
PRIOR_NOISE = [80, 40]      # e.g., prior's std
STIM_NOISE = [0.33, 0.66, 1.0]
SIM_P = {
    "k_llh": [2.7, 10.7, 33],
    "k_prior": [2.7, 33],
    "k_card": [1],
    "prior_tail": [0],
    "p_rand": [0],
    "k_m": [2000],
}
GRANULARITY = "trial"
CENTERING = True

## Simulate task conditions (design matrix)

In [20]:
# simulate task conditions
conditions = simulate_task_conditions(
    stim_noise=STIM_NOISE,
    prior_mode=PRIOR_MODE,
    prior_noise=PRIOR_NOISE,
    prior_shape=PRIOR_SHAPE,
)

The task conditions are shown below.

In [8]:
conditions

Unnamed: 0,stim_mean,stim_std,prior_mode,prior_std,prior_shape
0,5,0.33,225,80,vonMisesPrior
1,10,0.33,225,80,vonMisesPrior
2,15,0.33,225,80,vonMisesPrior
3,20,0.33,225,80,vonMisesPrior
4,25,0.33,225,80,vonMisesPrior
...,...,...,...,...,...
67,340,1.00,225,40,vonMisesPrior
68,345,1.00,225,40,vonMisesPrior
69,350,1.00,225,40,vonMisesPrior
70,355,1.00,225,40,vonMisesPrior


## Simulate standard Bayes's trial predictions

In [21]:
# instantiate the model
model = StandardBayes(
    prior_shape=PRIOR_SHAPE, 
    prior_mode=PRIOR_MODE, 
    readout=READOUT
    )

# simulate trial predictions
output = model.simulate(
    dataset=conditions,
    sim_p=SIM_P,
    granularity=GRANULARITY,
    centering=CENTERING,
)

Running simulation ...

-logl:nan, aic:nan, kl:[ 2.7 10.7 33. ], kp:[ 2.7 33. ], kc:[1.], pt:0.00, pr:0.00, km:2000.00


The simulated dataset is shown below.

In [22]:
output["dataset"]

Unnamed: 0,stim_mean,stim_std,prior_mode,prior_std,prior_shape,estimate
0,5.0,0.33,225,80.0,vonMisesPrior,0
1,5.0,0.33,225,80.0,vonMisesPrior,9
2,5.0,0.33,225,80.0,vonMisesPrior,27
3,5.0,0.33,225,80.0,vonMisesPrior,15
4,5.0,0.33,225,80.0,vonMisesPrior,351
...,...,...,...,...,...,...
2155,360.0,1.00,225,40.0,vonMisesPrior,228
2156,360.0,1.00,225,40.0,vonMisesPrior,229
2157,360.0,1.00,225,40.0,vonMisesPrior,225
2158,360.0,1.00,225,40.0,vonMisesPrior,226


### Calculate prediction statistics

In [23]:
# [TODO]: fix 
# simulate predictions
from matplotlib import pyplot as plt
plt.figure(figsize=(15,5))

model = model.simulate(
    dataset=output["dataset"],
    sim_p=SIM_P,
    granularity="",
    centering=CENTERING,
)

Running simulation ...

Calculating predictions ...



IndexError: arrays used as indices must be of integer (or boolean) type

<Figure size 1080x360 with 0 Axes>

In [24]:
GRANULARITY = "mean"

# instantiate the model
model = StandardBayes(
    prior_shape=PRIOR_SHAPE, 
    prior_mode=PRIOR_MODE, 
    readout=READOUT
    )

# simulate trial predictions
output = model.simulate(
    dataset=conditions,
    sim_p=SIM_P,
    granularity=GRANULARITY,
    centering=CENTERING,
)

Running simulation ...

-logl:nan, aic:nan, kl:[ 2.7 10.7 33. ], kp:[ 2.7 33. ], kc:[1.], pt:0.00, pr:0.00, km:2000.00


In [25]:
output

{'PestimateGivenModel': array([[3.96452177e-002, 3.32554590e-002, 2.17299923e-002, ...,
         9.99988867e-320, 4.24043516e-019, 4.62592927e-019],
        [3.98449785e-002, 3.51428585e-002, 2.41347386e-002, ...,
         9.99988867e-320, 1.85037171e-018, 7.32438801e-019],
        [3.96452177e-002, 3.67673435e-002, 2.65415552e-002, ...,
         9.99988867e-320, 9.99988867e-320, 3.18032637e-018],
        ...,
        [3.67673435e-002, 2.65415552e-002, 1.49506586e-002, ...,
         6.50808271e-018, 5.07385130e-018, 9.99988867e-320],
        [3.80831160e-002, 2.89000307e-002, 1.71022242e-002, ...,
         1.08272976e-018, 9.99988867e-320, 9.99988867e-320],
        [3.90519839e-002, 3.11563910e-002, 1.93727456e-002, ...,
         9.99988867e-320, 9.99988867e-320, 9.99988867e-320]]),
 'map': array([  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,  11.,
         12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,
         23.,  24.,  25.,  26.,  27.,  28.,  29.,  

Tutorial complete !