In [1]:
import sys
sys.path.append("../")
sys.path.append("../../../assets/diffusion")

# Get rid of annoying tf warning
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import numpy as np
np.set_printoptions(suppress=True)
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import bayesflow as beef
import pandas as pd
import pickle

from experiments import RandomWalkDiffusionExperiment
from models import RandomWalkDiffusion
from configuration import default_num_steps

  from tqdm.autonotebook import tqdm


In [2]:
# Run switches. Each flag guards one expensive stage further down:
# when a flag is False the stage loads previously stored results from disk,
# when it is True the stage is re-run and its results are written to disk.
SIMULATE_DATA = False   # simulate new data sets instead of unpickling them
TRAIN_NETWORKS = False  # train the networks instead of reusing the loaded loss history
FIT_MODEL = False       # draw new posterior samples instead of loading the .npy files

# Number of data sets generated per model (also the number of fits per model).
NUM_SIM = 200
# Length of the time axis of the local posterior arrays.
# NOTE(review): presumably has to agree with the number of steps configured in
# `default_num_steps` (the pilot-run log reports 400) -- confirm.
NUM_STEPS = 400
# Number of posterior draws requested from the amortizer per data set.
NUM_SAMPLES = 2000

In [3]:
# Human-readable parameter names and their LaTeX symbols, in the order of the
# last axis of the parameter arrays (used to build the subplot titles).
PARAM_LABELS = ['Drift rate', 'Threshold', 'Non-decision time']
PARAM_NAMES  = [r'$v$', r'$a$', r'$\tau$']

# Font sizes from largest to smallest.
# The smallest size used to be assigned to FONT_SIZE_1 a second time, which
# silently overwrote the largest size (20 -> 16); it now has its own name.
FONT_SIZE_1 = 20
FONT_SIZE_2 = 18
FONT_SIZE_3 = 16

In [4]:
# Build the random-walk diffusion model; `default_num_steps` is unpacked as
# keyword arguments of the constructor.
random_walk_ddm = RandomWalkDiffusion(**default_num_steps)
# Wrap the model in the experiment object. Later cells use its `.run(...)`,
# `.trainer.loss_history` and `.amortizer.sample(...)`.
# NOTE(review): judging by the log output, construction also restores a
# checkpoint and loss history from disk -- confirm the path it reports is the
# intended one.
experiment = RandomWalkDiffusionExperiment(random_walk_ddm)

INFO:root:Performing 2 pilot runs with the random_walk_diffusion_model model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 400, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 400)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 3)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 400, 3)
INFO:root:No shared_prior_draws provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.
INFO:root:Loaded loss history from ../../optimal_policy/checkpoints/optimal_policy/history_50.pkl.
INFO:root:Networks loaded from ../../optimal_policy/checkpoints/optimal_policy/ckpt-50
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.


## Simulation

In [None]:
# Simulate NUM_SIM data sets from each of the four diffusion models, or load
# the previously pickled simulations.
# NOTE: the file names carry a fixed "200" suffix regardless of NUM_SIM.
SIM_FILE_TEMPLATE = '../data/{}_ddm_sim_200.pkl'
simulations = {}

if SIMULATE_DATA:
    from models import StaticDiffusion, StationaryDiffusion, RegimeSwitchingDiffusion

    # Instantiate the three additional simulators first, then generate in a
    # fixed order (static, stationary, random walk, regime switching).
    static_ddm = StaticDiffusion()
    stationary_ddm = StationaryDiffusion()
    regime_switching_ddm = RegimeSwitchingDiffusion()
    simulators = {
        'static': static_ddm,
        'stationary': stationary_ddm,
        'random_walk': random_walk_ddm,
        'regime_switching': regime_switching_ddm,
    }
    for model_name, simulator in simulators.items():
        simulations[model_name] = simulator.generate(NUM_SIM)

    # Persist every simulation so later runs can skip this step.
    for model_name, simulated in simulations.items():
        with open(SIM_FILE_TEMPLATE.format(model_name), 'wb') as f:
            pickle.dump(simulated, f)
else:
    # Only unpickle files you created yourself: pickle can execute arbitrary
    # code while loading.
    for model_name in ('static', 'stationary', 'random_walk', 'regime_switching'):
        with open(SIM_FILE_TEMPLATE.format(model_name), 'rb') as f:
            simulations[model_name] = pickle.load(f)

# Expose the results under the names the following cells expect.
static_ddm_sim = simulations['static']
stationary_ddm_sim = simulations['stationary']
random_walk_ddm_sim = simulations['random_walk']
regime_switching_ddm_sim = simulations['regime_switching']

## Training

In [None]:
# Obtain the loss history: reuse the one held by the trainer unless a fresh
# training run was requested.
if not TRAIN_NETWORKS:
    history = experiment.trainer.loss_history.get_plottable()
else:
    training_settings = dict(
        epochs=50,
        iterations_per_epoch=1000,
        batch_size=8,
    )
    history = experiment.run(**training_settings)

In [None]:
# Plot the loss trajectory obtained in the training cell; the returned object
# is kept in `f`.
f = beef.diagnostics.plot_losses(history)

## Evaluation

In [None]:
# Obtain local (per time step) and global posterior samples for every
# simulated data set of every model, either by sampling from the amortizer or
# by loading previously saved arrays.
# NOTE: `post_samples` is only bound when FIT_MODEL is True; the load branch
# does not define it.
MODEL_NAMES = ('static', 'stationary', 'random_walk', 'regime_switching')
LOCAL_POST_FILE = '../data/local_post_{}.npy'
GLOBAL_POST_FILE = '../data/global_post_{}.npy'
local_posts, global_posts = {}, {}

with tf.device('/cpu:0'):
    if FIT_MODEL:
        sim_sets = {
            'static': static_ddm_sim,
            'stationary': stationary_ddm_sim,
            'random_walk': random_walk_ddm_sim,
            'regime_switching': regime_switching_ddm_sim,
        }

        # Pre-allocate one result array pair per model.
        for model_name in MODEL_NAMES:
            local_posts[model_name] = np.zeros((NUM_SIM, NUM_STEPS, NUM_SAMPLES, 3))
            global_posts[model_name] = np.zeros((NUM_SIM, NUM_SAMPLES, 3))

        for sim in range(NUM_SIM):
            # Fit the same trained networks to data set `sim` of each model,
            # always in the order given by MODEL_NAMES.
            for model_name in MODEL_NAMES:
                # Slice keeps a leading batch axis of length 1; None appends a trailing axis.
                post_samples = experiment.amortizer.sample(
                    {'summary_conditions': sim_sets[model_name]['sim_data'][sim:sim+1, :, None]},
                    NUM_SAMPLES
                )
                local_posts[model_name][sim] = post_samples['local_samples']
                global_posts[model_name][sim] = post_samples['global_samples']

            print(sim)

        for model_name in MODEL_NAMES:
            np.save(LOCAL_POST_FILE.format(model_name), local_posts[model_name])
            np.save(GLOBAL_POST_FILE.format(model_name), global_posts[model_name])

    else:
        for model_name in MODEL_NAMES:
            local_posts[model_name] = np.load(LOCAL_POST_FILE.format(model_name))
            global_posts[model_name] = np.load(GLOBAL_POST_FILE.format(model_name))

# Expose the results under the names the following cells expect.
local_post_static = local_posts['static']
global_post_static = global_posts['static']
local_post_stationary = local_posts['stationary']
global_post_stationary = global_posts['stationary']
local_post_random_walk = local_posts['random_walk']
global_post_random_walk = global_posts['random_walk']
local_post_regime_switching = local_posts['regime_switching']
global_post_regime_switching = global_posts['regime_switching']
        

In [None]:
# Shape of the local posterior samples of a single regime-switching data set.
# Reads the stored array rather than `post_samples`, which is only bound when
# FIT_MODEL is True and therefore raised a NameError when results were loaded.
local_post_regime_switching[0].shape

In [None]:
# Local posterior samples of the FIRST regime-switching data set, i.e. the same
# data set whose ground-truth parameters are taken from `prior_draws[0]`.
# Previously this read `post_samples`, which (a) does not exist when results
# are loaded from disk (FIT_MODEL = False) and (b) held the samples of the
# LAST fitted data set, so posterior and ground truth did not correspond.
# Assumes the stored per-data-set slice is (time steps, posterior draws,
# parameters), as allocated in the evaluation cell.
local_samples_z = local_post_regime_switching[0]

# Undo the standardisation using the prior means / stds of the random-walk model.
post_samples_not_z = local_samples_z * random_walk_ddm.local_prior_stds + random_walk_ddm.local_prior_means

# Summarise over the posterior-draw axis: one mean / std per time step and parameter.
post_means = post_samples_not_z.mean(axis=1)
post_stds = post_samples_not_z.std(axis=1)

In [None]:
# Prior draws of the first (index 0) regime-switching simulation; the plotting
# cell draws `true_params[:, i]` along its first axis as the ground-truth line.
true_params = regime_switching_ddm_sim['prior_draws'][0]

In [None]:
# One panel per parameter: posterior mean +/- one posterior standard deviation
# (maroon) against the ground-truth trajectory (black).
f, axarr = plt.subplots(3, 1, figsize=(20, 16))

num_post_steps = post_means.shape[0]
post_time = np.arange(num_post_steps)
true_time = np.arange(true_params.shape[0])

for i, ax in enumerate(axarr.flat):
    center = post_means[:, i]
    spread = post_stds[:, i]

    # Posterior summary
    ax.plot(post_time, center, color="maroon")
    ax.fill_between(post_time, center - spread, center + spread, alpha=0.5, color="maroon")

    # Ground truth
    ax.plot(true_time, true_params[:, i], color="black")

    ax.set_xlim([0, num_post_steps])
    # Axis labels are only attached to the first panel.
    if i == 0:
        ax.set_xlabel('Time', fontsize=18)
        ax.set_ylabel("Parameter value", fontsize=18)

    ax.set_title('{} ({})'.format(PARAM_LABELS[i], PARAM_NAMES[i]), fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=16)

sns.despine()
f.tight_layout()

In [None]:
# post_samples = np.zeros((N_SUBS, N_OBS, N_SAMPLES, 3))
# with tf.device('/cpu:0'):
#     for i in range(N_SUBS):
#         tmp_data = {'summary_conditions': emp_data[i:i+1]}
#         samples = amortizer.sample(tmp_data, N_SAMPLES)
#         post_samples[i] = samples['local_samples']

# post_samples.shape