In [None]:
# Silence TensorFlow's C++ log output ('3' is the most restrictive level).
# This has to run BEFORE tensorflow is imported below (and before bayesflow,
# which presumably imports tensorflow itself), because the variable is read
# when the library is loaded.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import pickle
import numpy as np
# Print small floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)

# NOTE(review): neither `beef` nor `tf` is referenced in the cells visible
# here; they may be used further down the notebook.
import bayesflow as beef
import tensorflow as tf

import sys
# Make the project-local `experiments` and `models` modules importable. The
# paths are relative to the kernel's working directory, so the notebook
# presumably has to be started from its own folder. Each append must precede
# the import that relies on it.
sys.path.append("../")
from experiments import SmoothingExperiment, FilteringExperiment
sys.path.append("../../../assets/")
from models import MixtureRandomWalkDDM

In [None]:
# --- Run switches ------------------------------------------------------------
# True: simulate a fresh validation set and overwrite the cached pickle;
# False: reload the previously cached validation set.
SIMULATE_VALIDATION_DATA = False
# Backward-compatible alias for the original, misspelled flag name, which the
# data-loading cell of this notebook reads. Edit SIMULATE_VALIDATION_DATA
# above rather than this line.
SIMULATE_VALIATION_DATA = SIMULATE_VALIDATION_DATA
# True: draw posterior samples for every validation data set and cache them;
# False: reload previously cached posterior samples.
FIT_MODEL = True

# --- Problem sizes -------------------------------------------------------------
NUM_OBS = 800                      # size of the per-observation axis of the local posteriors
NUM_SAMPLES = 4000                 # posterior draws requested per validation data set
NUM_VALIDATION_SIMULATIONS = 1000  # number of simulated validation data sets

# --- Parameter labels -------------------------------------------------------
# Human-readable labels paired (by position) with LaTeX symbols; the raw
# strings keep the backslash in '\tau' and '\sigma' intact.
MICRO_PARAM_LABELS = ['Drift rate', 'Threshold', 'Non-decision time']
MICRO_PARAM_NAMES  = [r'v', r'a', r'\tau']
MACRO_PARAM_LABELS = ['Transition std. deviation', 'Switch Probability']
MACRO_PARAM_NAMES  = [r'\sigma', r'q']

# --- Plot styling (presumably consumed by plotting cells later on) ----------
COMPARISON_COLOR = '#133a76'  # hex RGB colour

# Font sizes, from largest to smallest.
FONT_SIZE_1 = 22
FONT_SIZE_2 = 20
FONT_SIZE_3 = 18

FIG_SIZE = (18, 8)  # presumably a matplotlib (width, height) figsize in inches

In [None]:
# A single generative-model instance is shared by both experiment wrappers.
model = MixtureRandomWalkDDM()

# Each wrapper receives the shared model plus a path under checkpoints/
# (presumably where its trained networks are restored from -- confirm in
# experiments.py).
smoothing_experiment = SmoothingExperiment(
    model,
    "checkpoints/smoothing_summary_network",
)
filtering_experiment = FilteringExperiment(
    model,
    "checkpoints/filtering_summary_network",
)

In [None]:
# Either simulate a fresh validation set and cache it, or reload the cached
# copy so repeated runs evaluate against the same simulations.
validation_data_path = '../data/validation_data.pkl'
if SIMULATE_VALIATION_DATA:
    validation_data = model.generate(NUM_VALIDATION_SIMULATIONS)
    # open(..., 'wb') does not create missing parent directories.
    os.makedirs(os.path.dirname(validation_data_path), exist_ok=True)
    with open(validation_data_path, 'wb') as f:
        pickle.dump(validation_data, f)
else:
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # load a cache this notebook produced itself.
    with open(validation_data_path, 'rb') as f:
        validation_data = pickle.load(f)

In [None]:
def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, creating the parent directory if missing."""
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _load_pickle(path):
    """Return the object unpickled from ``path``.

    NOTE: pickle.load can execute arbitrary code embedded in the file; only
    load caches this project wrote itself.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _sample_posteriors(amortizer, conditions, num_sims, num_obs, num_samples,
                       num_global_params=5, num_local_params=3):
    """Draw posterior samples for each of the first ``num_sims`` data sets.

    Args:
        amortizer: object whose ``sample(inputs, n)`` returns a dict with
            'global_samples' and 'local_samples' entries.
        conditions: array-like of configured summary conditions, indexed by
            data set along its first axis.
        num_sims: number of validation data sets to process.
        num_obs: size of the per-observation axis of the local samples.
        num_samples: posterior draws requested per data set.
        num_global_params: trailing size of each global draw. The default 5
            matches the original hard-coded value.
        num_local_params: trailing size of each local draw. The default 3
            matches the original hard-coded value and presumably corresponds
            to the three MICRO_PARAM_NAMES (v, a, tau). Confirm against
            MixtureRandomWalkDDM.

    Returns:
        ``(post_hyper, post_local)``: float64 arrays of shape
        ``(num_sims, num_samples, num_global_params)`` and
        ``(num_sims, num_obs, num_samples, num_local_params)``.

    Raises:
        ValueError: if ``conditions`` holds fewer than ``num_sims`` data sets
            (e.g. a stale cached validation set). Without this check, the
            slice below would silently be empty.
    """
    if len(conditions) < num_sims:
        raise ValueError(
            'need {} validation data sets, got {}'.format(num_sims, len(conditions))
        )
    # NOTE(review): np.zeros defaults to float64. With the notebook defaults
    # (1000 x 800 x 4000 x 3) the local array is ~77 GB in memory, and the
    # pickle written below is just as large. Consider float32 or one file per
    # data set. Pickling objects over 4 GiB needs pickle protocol >= 4.
    post_hyper = np.zeros((num_sims, num_samples, num_global_params))
    post_local = np.zeros((num_sims, num_obs, num_samples, num_local_params))
    for i in range(num_sims):
        # Slice i:i+1 (not i) so the amortizer still gets a batch axis of size 1.
        draws = amortizer.sample({'summary_conditions': conditions[i:i+1]}, num_samples)
        post_hyper[i] = draws['global_samples']
        post_local[i] = draws['local_samples']
    return post_hyper, post_local


post_hyper_path = '../data/post_hyper_params.pkl'
post_local_path = '../data/post_local_params.pkl'

if FIT_MODEL:
    # Fit: sample the smoothing amortizer on every validation data set, then cache.
    configured_val_data = model.configure(validation_data)['summary_conditions']
    post_hyper_params, post_local_params = _sample_posteriors(
        smoothing_experiment.amortizer,
        configured_val_data,
        NUM_VALIDATION_SIMULATIONS,
        NUM_OBS,
        NUM_SAMPLES,
    )
    _dump_pickle(post_hyper_params, post_hyper_path)
    _dump_pickle(post_local_params, post_local_path)
else:
    # Reuse the posterior samples cached by a previous fit.
    post_hyper_params = _load_pickle(post_hyper_path)
    post_local_params = _load_pickle(post_local_path)