In [1]:
import sys
sys.path.append("../")
sys.path.append("../../../assets/diffusion")

# Get rid of annoying tf warning
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import bayesflow as beef
import pandas as pd

from experiments import RandomWalkDiffusionExperiment
from models import RandomWalkDiffusion

  from tqdm.autonotebook import tqdm


In [2]:
# Switch between re-running the experiments and reusing stored results:
#   False -> existing results will be loaded
#   True  -> the experiments are re-run
TRAIN_NETWORKS = False

## Neural Experiment

In [None]:
# Model object; later cells call its generate() and configure() methods.
model = RandomWalkDiffusion()

In [None]:
# Experiment wrapper built around the model; later cells use its
# `run`, `trainer` and `amortizer` attributes.
neural_experiment = RandomWalkDiffusionExperiment(model)

### Training

In [None]:
# Obtain the loss history: either take what the experiment's trainer already
# holds, or run a fresh training session when TRAIN_NETWORKS is enabled.
if not TRAIN_NETWORKS:
    history = neural_experiment.trainer.loss_history.get_plottable()
else:
    training_settings = dict(epochs=50, iterations_per_epoch=1000, batch_size=8)
    history = neural_experiment.run(**training_settings)

In [None]:
# Plot the loss history obtained in the training cell; the returned object is kept in `f`.
f = beef.diagnostics.plot_losses(history)

## Evaluation

In [None]:
# Human-readable parameter labels and their LaTeX symbols, in matching order.
PARAM_LABELS = ['Drift rate', 'Threshold', 'Non-decision time']
PARAM_NAMES  = [r'$v$', r'$a$', r'$\tau$']

# Three distinct font sizes, from largest to smallest.
# Each needs its own name: assigning the third value to FONT_SIZE_1 again
# would overwrite 20 with 16 and leave FONT_SIZE_3 undefined.
FONT_SIZE_1 = 20
FONT_SIZE_2 = 18
FONT_SIZE_3 = 16

### Prior predictive checks

In [None]:
# Generate simulated data from the model.
# NOTE(review): the argument 1 is presumably the number of simulated data sets — confirm against RandomWalkDiffusion.generate.
sim_data = model.generate(1)

In [None]:
# Configure the simulated data with the model and pass it to the amortizer's sampler.
# NOTE(review): 1000 is presumably the number of posterior draws — confirm against the amortizer API.
post_samples = neural_experiment.amortizer.sample(model.configure(sim_data), 1000)

### Fit to empirical data

In [None]:
# Load the data set from CSV; later cells read its 'rt' and 'id' columns.
data = pd.read_csv('../data/optimal_policy_data.csv')

In [None]:
# Zero-based index of the person to fit; rows are matched on id == which + 1.
which = 0
# 'rt' values of that person, with a new leading and a new trailing axis added.
person_rts = data.loc[data['id'] == which + 1, 'rt'].to_numpy()
person_data = {"summary_conditions": person_rts[None, ..., None]}

In [None]:
# Run the amortizer's sampler on this person's data; the result is indexed by 'local_samples' below.
# NOTE(review): 2000 is presumably the number of posterior draws — confirm against the amortizer API.
posterior_samples = neural_experiment.amortizer.sample(person_data, 2000)

In [None]:
# Mean and standard deviation of the 'local_samples' entry, reduced over axis 1
# (presumably the sample axis — confirm against the amortizer's output layout).
local_samples = posterior_samples['local_samples']
post_means = local_samples.mean(axis=1)
post_stds = local_samples.std(axis=1)

In [None]:
# One panel per parameter: the mean trajectory with a +/- one standard
# deviation band, plotted against the first axis of `post_means`.
f, axarr = plt.subplots(3, 1, figsize=(20, 16))

# The x-axis is the same for every panel, so build it once outside the loop.
time_steps = np.arange(post_means.shape[0])

for i, ax in enumerate(axarr.flat):
    mean_i = post_means[:, i]
    std_i = post_stds[:, i]

    ax.plot(time_steps, mean_i, color="maroon")
    ax.fill_between(
        time_steps,
        mean_i - std_i,
        mean_i + std_i,
        alpha=0.5,
        color="maroon"
    )
    ax.set_xlim([0, post_means.shape[0]])

    # Set the title exactly once per panel: label followed by its symbol.
    ax.set_title(PARAM_LABELS[i] + ' ({})'.format(PARAM_NAMES[i]), fontsize=20)

    # Axis labels are only drawn on the first panel.
    if i == 0:
        ax.set_xlabel('Time', fontsize=18)
        ax.set_ylabel("Parameter value", fontsize=18)

    ax.tick_params(axis='both', which='major', labelsize=16)

sns.despine()
f.tight_layout()