In [3]:
import sys
# Make the parent directory and the shared assets folder importable. Presumably
# this is where the project-local `experiments` and `models` modules imported at
# the bottom of this cell live -- TODO confirm.
sys.path.append("../")
sys.path.append("../../../assets")

# Get rid of annoying tf warning
# (the variable is set before `import tensorflow` below so that it is already in
# place when TensorFlow is loaded)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import bayesflow as beef
import pandas as pd
import pickle

# Project-local code: the experiment wrapper and the three non-stationary DDM
# variants that are compared in this notebook.
from experiments import NonStationaryDDMExperiment
from models import MixtureRandomWalkDDM, LevyFlightDDM, RegimeSwitchingDDM

In [None]:
# GPU setting and checking.
# Enable on-demand memory allocation on every visible GPU. Iterating over the
# device list (instead of indexing element 0) keeps this cell working on
# CPU-only machines and applies the same setting to all GPUs, which TensorFlow
# requires when more than one is present.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, enable=True)
except RuntimeError as err:
    # Memory growth can only be changed before the GPUs are initialised, e.g.
    # this happens when the cell is re-run in a live kernel. Report, don't crash.
    print(err)
print(physical_devices)

In [2]:
# True: draw fresh posterior samples and cache them on disk.
# False: load the previously cached samples instead.
FIT_MODELS = True

# Trials per participant and posterior draws requested per participant.
NUM_OBS, NUM_SAMPLES = 768, 1000

# Human-readable labels and LaTeX symbols, in matching order.
LOCAL_PARAM_LABELS = ["Drift rate", "Threshold", "Non-decision time"]
LOCAL_PARAM_NAMES = ["v", "a", r"\tau"]
HYPER_PARAM_LABELS = ["Transition std. deviation", "Switch Probability"]
HYPER_PARAM_NAMES = [r"\sigma", "q"]

# Plot styling.
COMPARISON_COLOR = "#133a76"
FONT_SIZE_1, FONT_SIZE_2, FONT_SIZE_3 = 22, 20, 18

In [None]:
def _build_experiment(model_cls, checkpoint_path):
    """Instantiate `model_cls` and wrap it in a NonStationaryDDMExperiment.

    Returns the `(model, experiment)` pair; the experiment receives
    `checkpoint_path` unchanged (presumably the location of the trained
    networks -- confirm in NonStationaryDDMExperiment).
    """
    model = model_cls()
    experiment = NonStationaryDDMExperiment(model, checkpoint_path=checkpoint_path)
    return model, experiment


# One (model, experiment) pair per DDM variant, each with its own checkpoint folder.
mixture_random_walk_model, mixture_random_walk_experiment = _build_experiment(
    MixtureRandomWalkDDM, "../checkpoints/smoothing_mixture_random_walk_ddm"
)
levy_flight_model, levy_flight_experiment = _build_experiment(
    LevyFlightDDM, "../checkpoints/smoothing_levy_flight_ddm"
)
regime_switching_model, regime_switching_experiment = _build_experiment(
    RegimeSwitchingDDM, "../checkpoints/smoothing_regime_switching_ddm"
)

# Inference

In [None]:
# Load the trial-level data of the colour discrimination task.
data = pd.read_csv('../data/data_color_discrimination.csv')

# Fold accuracy into the sign of the response time: trials with correct == 0
# get a negative RT, all other trials keep their RT unchanged.
is_error = data['correct'] == 0
data['rt'] = np.where(is_error, -data['rt'], data['rt'])

In [None]:
# Sorted unique participant ids. Iterating over the ids that are actually
# present (rather than assuming they are exactly 1..N) gives the same ordering
# for contiguous 1-based ids and still works when ids are missing or shifted.
subject_ids = np.unique(data['id'])
NUM_SUBJECTS = len(subject_ids)

# One row of signed response times per participant: (subjects, trials, 1).
emp_data = np.zeros((NUM_SUBJECTS, NUM_OBS, 1), dtype=np.float32)
for i, subject_id in enumerate(subject_ids):
    subject_rts = data.loc[data['id'] == subject_id, 'rt'].to_numpy()
    # Fail with a readable message instead of a broadcasting error; a
    # single-trial subject would otherwise be silently broadcast to all rows.
    if subject_rts.shape[0] != NUM_OBS:
        raise ValueError(
            f"Subject {subject_id} has {subject_rts.shape[0]} trials, expected {NUM_OBS}."
        )
    emp_data[i] = subject_rts[:, np.newaxis]

emp_data.shape

In [None]:
# Folder holding one pickle of posterior draws per model and parameter level.
POSTERIOR_DIR = '../data/posteriors'

# File/variable prefix -> (experiment whose amortizer is sampled, model holding
# the prior moments). The insertion order fixes the per-subject sampling order
# and the order in which the files are written or read.
MODEL_REGISTRY = {
    'mrw': (mixture_random_walk_experiment, mixture_random_walk_model),
    'lf': (levy_flight_experiment, levy_flight_model),
    'rs': (regime_switching_experiment, regime_switching_model),
}


def posterior_path(prefix, level):
    """Return the pickle path for a model `prefix` ('mrw', 'lf', 'rs') and a
    parameter `level` ('local' or 'hyper')."""
    return f'{POSTERIOR_DIR}/{prefix}_{level}_post_samples.pkl'


# Posterior draws on the original parameter scale, keyed by model prefix.
local_post, hyper_post = {}, {}

if FIT_MODELS:
    # Create the output folder *before* sampling so that a missing directory
    # cannot discard the sampling work at the very end.
    os.makedirs(POSTERIOR_DIR, exist_ok=True)

    # Standardized draws as returned by the amortizers. The trailing 3 matches
    # the three local parameters in LOCAL_PARAM_NAMES; the trailing 5 is the
    # size of the last axis of 'global_samples' -- assumed identical for all
    # three models, TODO confirm.
    local_post_z = {
        prefix: np.zeros((NUM_SUBJECTS, NUM_OBS, NUM_SAMPLES, 3)) for prefix in MODEL_REGISTRY
    }
    hyper_post_z = {
        prefix: np.zeros((NUM_SUBJECTS, NUM_SAMPLES, 5)) for prefix in MODEL_REGISTRY
    }

    # Sampling is pinned to the CPU; each subject is passed as a batch of size 1.
    with tf.device('/cpu:0'):
        for i in range(NUM_SUBJECTS):
            tmp_data = {'summary_conditions': emp_data[i:i+1]}
            for prefix, (experiment, _) in MODEL_REGISTRY.items():
                samples = experiment.amortizer.sample(tmp_data, NUM_SAMPLES)
                local_post_z[prefix][i] = samples['local_samples']
                hyper_post_z[prefix][i] = samples['global_samples']

    for prefix, (_, model) in MODEL_REGISTRY.items():
        # Undo the standardization: z * prior std + prior mean.
        local_post[prefix] = local_post_z[prefix] * model.local_prior_stds + model.local_prior_means
        hyper_post[prefix] = hyper_post_z[prefix] * model.hyper_prior_std + model.hyper_prior_mean

        with open(posterior_path(prefix, 'local'), 'wb') as f:
            pickle.dump(local_post[prefix], f)
        with open(posterior_path(prefix, 'hyper'), 'wb') as f:
            pickle.dump(hyper_post[prefix], f)

    # Keep the standardized draws available under their established names.
    mrw_local_post_samples_z, mrw_hyper_post_samples_z = local_post_z['mrw'], hyper_post_z['mrw']
    lf_local_post_samples_z, lf_hyper_post_samples_z = local_post_z['lf'], hyper_post_z['lf']
    rs_local_post_samples_z, rs_hyper_post_samples_z = local_post_z['rs'], hyper_post_z['rs']
else:
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # caches that were written by this notebook.
    for prefix in MODEL_REGISTRY:
        with open(posterior_path(prefix, 'local'), 'rb') as f:
            local_post[prefix] = pickle.load(f)
        with open(posterior_path(prefix, 'hyper'), 'rb') as f:
            hyper_post[prefix] = pickle.load(f)

# Expose the results under the per-model variable names.
mrw_local_post_samples, mrw_hyper_post_samples = local_post['mrw'], hyper_post['mrw']
lf_local_post_samples, lf_hyper_post_samples = local_post['lf'], hyper_post['lf']
rs_local_post_samples, rs_hyper_post_samples = local_post['rs'], hyper_post['rs']