In [1]:
import numpy as np
import pickle
import sys
sys.path.append("../../../assets")
from models import RandomWalkDDM, MixtureRandomWalkDDM, LevyFlightDDM, RegimeSwitchingDDM

2023-08-04 13:17:05.175964: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from tqdm.autonotebook import tqdm


# Constants

In [2]:
NUM_TRAINING_SIMULATIONS = 5000
NUM_VALIDATION_SIMULATIONS = 200

# Data Simulation

In [3]:
# initialize generative models
random_walk_ddm = RandomWalkDDM()
mixture_random_walk_ddm = MixtureRandomWalkDDM()
levy_flight_ddm = LevyFlightDDM()
regime_switching_ddm = RegimeSwitchingDDM()

INFO:root:Performing 2 pilot runs with the random_walk_ddm model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 800)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 3)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:No shared_prior_draws provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.
INFO:root:Performing 2 pilot runs with the mixture_random_walk_ddm model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 800)
INFO:root:Shape of hyper_prior_draws batch 

In [6]:
random_walk_ddm.generate(1)
mixture_random_walk_ddm.generate(1)
levy_flight_ddm.generate(1)
regime_switching_ddm.generate(1)

{'sim_data': array([[ 1.16211157,  1.17614336,  1.12792421,  1.17235578,  1.19918685,
          1.20085339,  1.22842443,  1.23751711,  1.1958213 ,  1.14589916,
          1.16666637,  1.22980245,  1.22605084,  1.17730779,  1.34112833,
          1.17961469,  1.25196646,  1.24905137,  1.17965818,  1.20679438,
          1.20351982,  1.27705842,  1.24587622,  1.56079626,  1.5186478 ,
          1.60228885,  1.40570484,  1.45681397,  1.75840947,  1.61401672,
          1.62659517,  1.63518482,  1.63964582,  1.39670603,  1.19297274,
          1.17336196, -1.18680744,  1.22132334,  1.28351116,  1.24928119,
          1.2616619 ,  1.24110041,  1.27251385, -1.22455598,  1.19893844,
          1.24626378,  1.42900165,  1.20129671,  1.22714474,  1.24570309,
          1.22272016,  1.23247817,  1.20884088, -1.21054191, -1.21422394,
          1.23711015,  1.19857451,  1.17764361,  1.26530936,  1.1730383 ,
          1.20211691,  1.18628848,  1.23221478,  1.24861053,  1.23149488,
         -1.20526163,  1.2

## Training Data

In [7]:
%%time
training_data_m1 = random_walk_ddm.generate(NUM_TRAINING_SIMULATIONS)
training_data_m2 = mixture_random_walk_ddm.generate(NUM_TRAINING_SIMULATIONS)
training_data_m3 = levy_flight_ddm.generate(NUM_TRAINING_SIMULATIONS)
training_data_m4 = regime_switching_ddm.generate(NUM_TRAINING_SIMULATIONS)

CPU times: user 9min 21s, sys: 283 ms, total: 9min 21s
Wall time: 9min 21s


In [8]:
training_data = {
    'model_outputs': [
        {'sim_data': training_data_m1['sim_data'][:, :, None].astype(np.float32)},
        {'sim_data': training_data_m2['sim_data'][:, :, None].astype(np.float32)},
        {'sim_data': training_data_m3['sim_data'][:, :, None].astype(np.float32)},
        {'sim_data': training_data_m4['sim_data'][:, :, None].astype(np.float32)}
        ],
    'model_indices': np.arange(4)
}

In [9]:
with open('../data/training_data.pkl', 'wb') as f:
    pickle.dump(training_data, f)

## Validation Data

In [10]:
%%time
validation_data_m1 = random_walk_ddm.generate(NUM_VALIDATION_SIMULATIONS)
validation_data_m2 = mixture_random_walk_ddm.generate(NUM_VALIDATION_SIMULATIONS)
validation_data_m3 = levy_flight_ddm.generate(NUM_VALIDATION_SIMULATIONS)
validation_data_m4 = regime_switching_ddm.generate(NUM_VALIDATION_SIMULATIONS)

CPU times: user 22.9 s, sys: 4.02 ms, total: 22.9 s
Wall time: 22.9 s


In [11]:
validation_data = {
    'model_outputs': [
        {'sim_data': validation_data_m1['sim_data'][:, :, None].astype(np.float32)},
        {'sim_data': validation_data_m2['sim_data'][:, :, None].astype(np.float32)},
        {'sim_data': validation_data_m3['sim_data'][:, :, None].astype(np.float32)},
        {'sim_data': validation_data_m4['sim_data'][:, :, None].astype(np.float32)}
        ],
    'model_indices': np.arange(4)
}

In [12]:
with open('../data/validation_data.pkl', 'wb') as f:
    pickle.dump(validation_data, f)