In [1]:
import sys
# Extend the module search path with the parent directory and the shared
# diffusion assets directory (presumably where the local `experiments` and
# `models` modules imported below live - confirm against the repo layout).
sys.path.append("../")
sys.path.append("../../../assets/diffusion")

# Silence TensorFlow's native (C++) log output. This is set before
# `import tensorflow` below; keep that ordering, since the variable is
# expected to be read when TensorFlow is loaded.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import numpy as np
# Print floats in plain decimal notation instead of scientific notation.
np.set_printoptions(suppress=True)
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import bayesflow as beef
import pandas as pd
import pickle

# Project-local modules, resolved through the sys.path entries added above.
from experiments import RandomWalkDiffusionExperiment
from models import RandomWalkDiffusion

  from tqdm.autonotebook import tqdm


In [2]:
# Run switches for the expensive steps of this notebook.
# SIMULATE_DATA:  True  -> simulate fresh data sets and overwrite the pickles in ../data
#                 False -> load the previously pickled simulations from ../data
# TRAIN_NETWORKS: True  -> train the networks via experiment.run(...)
#                 False -> reuse the loss history already held by the experiment's trainer
SIMULATE_DATA = True
TRAIN_NETWORKS = False

In [3]:
# Human-readable parameter labels and the matching LaTeX symbols.
# Both lists have one entry per model parameter, in the same order.
PARAM_LABELS = ['Drift rate', 'Threshold', 'Non-decision time']
PARAM_NAMES  = [r'$v$', r'$a$', r'$\tau$']

# Font sizes for figures, from largest to smallest.
# The original cell assigned FONT_SIZE_1 twice (20, then 16), which silently
# discarded the value 20 and never defined FONT_SIZE_3; the third assignment
# is now bound to FONT_SIZE_3 as the numbering implies.
FONT_SIZE_1 = 20
FONT_SIZE_2 = 18
FONT_SIZE_3 = 16

In [4]:
# Build the random-walk diffusion simulator and wrap it in its experiment object.
# According to the log recorded under this cell, running it performs pilot
# simulations, loads a stored loss history and network checkpoint, and runs a
# consistency check.
# NOTE(review): the recorded log reports the checkpoint being loaded from an
# `optimal_policy/checkpoints` directory - confirm that this is the intended
# checkpoint path for the diffusion experiment.
random_walk_ddm = RandomWalkDiffusion()
experiment = RandomWalkDiffusionExperiment(random_walk_ddm)

INFO:root:Performing 2 pilot runs with the random_walk_diffusion_model model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 1320, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 1320)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 3)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 1320, 3)
INFO:root:No shared_prior_draws provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.
INFO:root:Loaded loss history from ../../optimal_policy/checkpoints/optimal_policy/history_50.pkl.
INFO:root:Networks loaded from ../../optimal_policy/checkpoints/optimal_policy/ckpt-50
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.


## Simulation

In [None]:
# Number of data sets simulated per model; it is also part of the cache file names.
NUM_SIMULATIONS = 200
# Directory holding the pickled simulations, relative to the working directory.
SIM_DATA_DIR = '../data'
# One cache file per model variant, listed in the order the data are generated.
SIM_MODEL_NAMES = ('static_ddm', 'stationary_ddm', 'random_walk_ddm', 'regime_switching_ddm')
sim_paths = [f'{SIM_DATA_DIR}/{name}_sim_{NUM_SIMULATIONS}.pkl' for name in SIM_MODEL_NAMES]

if SIMULATE_DATA:
    # The additional simulators are only required when data are regenerated,
    # so they are imported and instantiated in this branch only.
    from models import StaticDiffusion, StationaryDiffusion, RegimeSwitchingDiffusion
    static_ddm = StaticDiffusion()
    stationary_ddm = StationaryDiffusion()
    regime_switching_ddm = RegimeSwitchingDiffusion()

    # Generation order is kept fixed (static, stationary, random walk,
    # regime switching) so any shared random state is consumed as before.
    static_ddm_sim = static_ddm.generate(NUM_SIMULATIONS)
    stationary_ddm_sim = stationary_ddm.generate(NUM_SIMULATIONS)
    random_walk_ddm_sim = random_walk_ddm.generate(NUM_SIMULATIONS)
    regime_switching_ddm_sim = regime_switching_ddm.generate(NUM_SIMULATIONS)

    # Create the output directory if needed; previously a missing ../data
    # directory made the first open(..., 'wb') fail with FileNotFoundError
    # after all simulations had already been computed.
    os.makedirs(SIM_DATA_DIR, exist_ok=True)
    simulations = (
        static_ddm_sim,
        stationary_ddm_sim,
        random_walk_ddm_sim,
        regime_switching_ddm_sim,
    )
    for sim_path, sim in zip(sim_paths, simulations):
        with open(sim_path, 'wb') as f:
            pickle.dump(sim, f)
else:
    # CAUTION: pickle.load can execute arbitrary code. Only load cache files
    # that were written by this notebook (or another trusted source).
    loaded_sims = []
    for sim_path in sim_paths:
        with open(sim_path, 'rb') as f:
            loaded_sims.append(pickle.load(f))
    (
        static_ddm_sim,
        stationary_ddm_sim,
        random_walk_ddm_sim,
        regime_switching_ddm_sim,
    ) = loaded_sims

## Training

In [None]:
# Obtain the loss history: reuse the one stored on the experiment's trainer
# unless a fresh training run was requested via TRAIN_NETWORKS.
if not TRAIN_NETWORKS:
    history = experiment.trainer.loss_history.get_plottable()
else:
    # Training schedule: 50 epochs of 1000 iterations each, 8 simulations per batch.
    history = experiment.run(epochs=50, iterations_per_epoch=1000, batch_size=8)

In [None]:
# Plot the loss history obtained in the previous cell with BayesFlow's
# diagnostics helper. The return value is kept in `f` (presumably the
# matplotlib figure - confirm), which also rebinds the `f` name that the
# simulation cell used for file handles.
f = beef.diagnostics.plot_losses(history)

## Evaluation