In [None]:
# Imports. The two sys.path entries make project-local modules importable;
# presumably `experiments` and `models` (imported at the bottom of this cell)
# live under one of them — TODO confirm which.
import sys
sys.path.append("../")
sys.path.append("../../../assets")

# Silence TensorFlow's C++ log messages. This must stay above
# `import tensorflow` below, because TensorFlow reads the variable when loaded.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import bayesflow as beef  # used below for beef.diagnostics.plot_losses
import pandas as pd

from experiments import NonStationaryDDMExperiment
from models import MixtureRandomWalkDDM

In [None]:
# GPU setup: enable on-demand memory growth on every visible GPU so TensorFlow
# does not claim all device memory up front.
# Looping over the device list (instead of indexing [0]) means a CPU-only
# machine no longer raises IndexError, and every GPU on a multi-GPU machine
# gets the setting, not just the first one.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, enable=True)
except RuntimeError as err:
    # TensorFlow refuses to change memory growth once the GPUs have been
    # initialized (e.g. when this cell is re-run in a live kernel);
    # report it and keep going instead of aborting the notebook.
    print(err)
print(physical_devices)

# Constants

In [None]:
# Toggle: True runs training; False skips it and reads the loss history
# already held by the experiment's trainer.
TRAIN_NETWORKS = True

# Data and sampling sizes
NUM_OBS = 768        # rows (trials) per subject in the empirical data
NUM_SAMPLES = 4_000  # posterior draws requested per subject

# Parameter metadata: human-readable labels and LaTeX symbols
# (presumably one entry per local model parameter — TODO confirm)
PARAM_LABELS = ['Drift rate', 'Threshold', 'Non-decision time']
PARAM_NAMES = [r'$v$', r'$a$', r'$\tau$']

# Figure font sizes, largest to smallest
FONT_SIZE_1, FONT_SIZE_2, FONT_SIZE_3 = 22, 20, 18

In [None]:
# Instantiate the generative model and wrap it in the experiment object used
# for training and sampling below. The checkpoint directory presumably lets a
# previous training run be restored — TODO confirm in NonStationaryDDMExperiment.
model = MixtureRandomWalkDDM()
experiment = NonStationaryDDMExperiment(
    model,
    checkpoint_path="../checkpoints/mixture_random_walk_ddm",
)

# Training

In [None]:
# Obtain the loss history: reuse the trainer's stored history when training is
# switched off; otherwise train online for 75 epochs of 1000 iterations each
# with batch size 16.
if not TRAIN_NETWORKS:
    history = experiment.trainer.loss_history.get_plottable()
else:
    history = experiment.run(
        epochs=75,
        iterations_per_epoch=1000,
        batch_size=16,
    )

In [None]:
# Plot the training loss history obtained in the previous cell.
f = beef.diagnostics.plot_losses(history)

# Evaluation

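## Calibration Error

*TODO: this section is currently empty — no calibration analysis has been added yet.*

## Parameter Recovery

*TODO: this section is currently empty — no parameter-recovery analysis has been added yet.*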

# Parameter Estimation

In [None]:
# Load the color-discrimination trials and sign-code the response times:
# rows with correct == 0 get a negated RT, every other row keeps its RT as is.
data = pd.read_csv('../data/data_color_discrimination.csv')
data['rt'] = data['rt'].mask(data['correct'] == 0, -data['rt'])

In [None]:
# Reshape the long-format trial table into an array of shape
# (NUM_SUBJECTS, NUM_OBS, 1): one sign-coded RT series per subject, with trials
# kept in their CSV row order.
# Subjects are taken from the ids actually present (np.unique returns them
# sorted), so ids do not have to be exactly 1..NUM_SUBJECTS. Each subject must
# contribute exactly NUM_OBS trials; anything else raises ValueError rather than
# leaving zero rows or silently broadcasting a single-trial subject over all
# NUM_OBS rows.
subject_ids = np.unique(data['id'])
NUM_SUBJECTS = len(subject_ids)
emp_data = np.zeros((NUM_SUBJECTS, NUM_OBS, 1), dtype=np.float32)
for i, subject_id in enumerate(subject_ids):
    tmp = data[data['id'] == subject_id]
    n_trials = len(tmp)
    if n_trials != NUM_OBS:
        raise ValueError(
            f"Subject {subject_id} has {n_trials} trials, expected NUM_OBS={NUM_OBS}."
        )
    emp_data[i] = tmp['rt'].to_numpy()[:, np.newaxis]

emp_data.shape

In [None]:
%%time
# Posterior sampling: draw NUM_SAMPLES samples for each subject separately.
# Output buffers. np.zeros defaults to float64, so local_post_samples needs
# NUM_SUBJECTS * NUM_OBS * NUM_SAMPLES * 3 * 8 bytes (~74 MB per subject with the
# constants above). The trailing 3 presumably matches the three local parameters
# in PARAM_NAMES and the 5 the number of global (hyper) parameters — TODO confirm
# against MixtureRandomWalkDDM.
local_post_samples = np.zeros((NUM_SUBJECTS, NUM_OBS, NUM_SAMPLES, 3))
hyper_post_samples = np.zeros((NUM_SUBJECTS, NUM_SAMPLES, 5))
# Sampling is pinned to the CPU.
with tf.device('/cpu:0'):
    for i in range(NUM_SUBJECTS):
        # The i:i+1 slice keeps a leading batch axis of size 1: (1, NUM_OBS, 1).
        tmp_data = {'summary_conditions': emp_data[i:i+1]}
        samples = experiment.amortizer.sample(tmp_data, NUM_SAMPLES)
        local_post_samples[i] = samples['local_samples']
        hyper_post_samples[i] = samples['global_samples']