In [9]:
import os
import pickle
import sys

import numpy as np

sys.path.append("../../../assets")
from models import RandomWalkDDM, MixtureRandomWalkDDM, LevyFlightDDM, RegimeSwitchingDDM

# Constants

In [10]:
# Number of simulated data sets generated from *each* of the four models,
# for the training set and the validation set respectively.
# NOTE(review): no random seed is set before the simulations in this notebook's
# visible cells, so regenerated data will differ between runs — whether a
# global numpy seed would help depends on how `models` draws random numbers.
NUM_TRAINING_SIMULATIONS = 25_000
NUM_VALIDATION_SIMULATIONS = 250

# Data Simulation

In [11]:
# Instantiate the four competing generative models (m1..m4, in this order).
# Construction appears to run 2 pilot simulations per model and log the
# resulting batch shapes — see the cell output below.
random_walk_ddm, mixture_random_walk_ddm, levy_flight_ddm, regime_switching_ddm = (
    RandomWalkDDM(),
    MixtureRandomWalkDDM(),
    LevyFlightDDM(),
    RegimeSwitchingDDM(),
)

INFO:root:Performing 2 pilot runs with the random_walk_ddm model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 800)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 3)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:No shared_prior_draws provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.


INFO:root:No optional prior non-batchable context provided.
INFO:root:Performing 2 pilot runs with the mixture_random_walk_ddm model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 800)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 5)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:No shared_prior_draws provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.
INFO:root:Performing 2 pilot runs with the levy_flight_ddm model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batc

## Training Data

In [12]:
%%time
# Simulate NUM_TRAINING_SIMULATIONS data sets from each model, one model at a
# time and in the fixed m1..m4 order that the model indices rely on later.
training_data_m1, training_data_m2, training_data_m3, training_data_m4 = (
    simulator.generate(NUM_TRAINING_SIMULATIONS)
    for simulator in (
        random_walk_ddm,
        mixture_random_walk_ddm,
        levy_flight_ddm,
        regime_switching_ddm,
    )
)

CPU times: user 24min 49s, sys: 8.36 s, total: 24min 57s
Wall time: 25min 5s


In [13]:
# Collect the per-model simulations in model-index order (m1..m4).
training_results = [training_data_m1, training_data_m2, training_data_m3, training_data_m4]

training_data = {
    # Each model's `sim_data` gets a trailing singleton axis and is cast to
    # float32; the pilot-run logs show it as 2-D (batch, 800) before this.
    'model_outputs': [
        {'sim_data': result['sim_data'][:, :, None].astype(np.float32)}
        for result in training_results
    ],
    # Derived from the number of models instead of a hard-coded 4, so the
    # indices cannot drift out of sync with `model_outputs`.
    'model_indices': np.arange(len(training_results)),
}

# Sanity check: the 2-D simulation batch plus the added axis must be 3-D.
assert all(out['sim_data'].ndim == 3 for out in training_data['model_outputs']), \
    "expected each training sim_data array to be 3-D after adding the trailing axis"

In [15]:
# Persist the training set. Create the output directory first: a missing
# `../data` folder would otherwise raise FileNotFoundError only *after* the
# long simulation above (about 25 min wall time when last run), losing its results.
training_path = '../data/training_data.pkl'
os.makedirs(os.path.dirname(training_path), exist_ok=True)
with open(training_path, 'wb') as f:
    pickle.dump(training_data, f)

## Validation Data

In [16]:
%%time
# Simulate NUM_VALIDATION_SIMULATIONS held-out data sets from each model,
# one model at a time and in the same m1..m4 order as the training data.
validation_data_m1, validation_data_m2, validation_data_m3, validation_data_m4 = (
    simulator.generate(NUM_VALIDATION_SIMULATIONS)
    for simulator in (
        random_walk_ddm,
        mixture_random_walk_ddm,
        levy_flight_ddm,
        regime_switching_ddm,
    )
)

CPU times: user 1min, sys: 397 ms, total: 1min 1s
Wall time: 1min 1s


In [17]:
# Collect the per-model validation simulations in model-index order (m1..m4).
validation_results = [validation_data_m1, validation_data_m2, validation_data_m3, validation_data_m4]

validation_data = {
    # Same layout as the training data: trailing singleton axis, float32.
    'model_outputs': [
        {'sim_data': result['sim_data'][:, :, None].astype(np.float32)}
        for result in validation_results
    ],
    # Derived from the number of models instead of a hard-coded 4, so the
    # indices cannot drift out of sync with `model_outputs`.
    'model_indices': np.arange(len(validation_results)),
}

# Sanity check: the 2-D simulation batch plus the added axis must be 3-D.
assert all(out['sim_data'].ndim == 3 for out in validation_data['model_outputs']), \
    "expected each validation sim_data array to be 3-D after adding the trailing axis"

In [18]:
# Persist the validation set, creating the output directory if needed so the
# write cannot fail with FileNotFoundError after the simulation has finished.
validation_path = '../data/validation_data.pkl'
os.makedirs(os.path.dirname(validation_path), exist_ok=True)
with open(validation_path, 'wb') as f:
    pickle.dump(validation_data, f)