In [1]:
# import sys
# sys.path.append("../")
# sys.path.append("../../../assets")

# Get rid of annoying tf warning
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.backend import clear_session
import bayesflow as beef
import pandas as pd
import pickle

# from experiments import NonStationaryDDMExperiment
# from models import RandomWalkDDM, MixtureRandomWalkDDM, LevyFlightDDM, RegimeSwitchingDDM
from helpers import get_setup
from configurations import model_names

  from tqdm.autonotebook import tqdm


In [14]:
# Load cached posterior samples (presumably one entry per model — confirm against the
# code that wrote this file).
# NOTE(review): `samples_per_model` is not referenced in the visible cells. pickle.load can
# execute arbitrary code, so only open files produced by this project.
with open('data/posteriors/samples_per_model.pkl', mode='rb') as posterior_file:
    samples_per_model = pickle.load(posterior_file)

In [8]:
# Exploratory call: build the setup of the first model with checks skipped and only
# display the result (a [model, experiment] list, per the output below). The setups
# actually used later are rebuilt in the next cell.
get_setup(
    model_names[0],
    skip_checks=True,
)

INFO:root:Performing 2 pilot runs with the random_walk_ddm model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 800)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 3)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:No shared_prior_draws provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.
INFO:root:Loaded loss history from checkpoints/smoothing_random_walk_ddm/history_75.pkl.
INFO:root:Networks loaded from checkpoints/smoothing_random_walk_ddm/ckpt-75


[<models.RandomWalkDDM at 0x2974bb010>,
 <experiments.NonStationaryDDMExperiment at 0x2979ce290>]

In [9]:
# Build a [model, experiment] pair for every model, passing "smoothing"
# (the logs show networks loaded from `checkpoints/smoothing_<model>`), then split
# the pairs into two parallel lists. The experiments are kept under the name `trainers`.
setup = [get_setup(model_name, "smoothing") for model_name in model_names]
models, trainers = [], []
for model_setup in setup:
    models.append(model_setup[0])
    trainers.append(model_setup[1])

INFO:root:Performing 2 pilot runs with the random_walk_ddm model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 800)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 3)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:No shared_prior_draws provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.
INFO:root:Loaded loss history from checkpoints/smoothing_random_walk_ddm/history_75.pkl.
INFO:root:Networks loaded from checkpoints/smoothing_random_walk_ddm/ckpt-75
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.
INFO:root:Performing 2 pilot runs with t

In [3]:
# Display the model objects collected in the setup cell (one per entry of `model_names`).
models

[<models.RandomWalkDDM at 0x28a34c950>,
 <models.MixtureRandomWalkDDM at 0x28de20f10>,
 <models.LevyFlightDDM at 0x28de91290>,
 <models.RegimeSwitchingDDM at 0x28f2496d0>]

In [None]:
# --- Posterior predictive check settings ---
NUM_OBS = 768            # NOTE(review): not referenced in the visible cells
NUM_SAMPLES = 1_000      # posterior draws requested from the amortizer
NUM_RESIMULATIONS = 100  # posterior draws that are re-simulated for the predictive check

# --- Figure font sizes: titles, axis labels / legend, tick labels ---
FONT_SIZE_1, FONT_SIZE_2, FONT_SIZE_3 = 24, 20, 16

import matplotlib

# Render text in Palatino: register it in the sans-serif slot, then select
# the sans-serif family (set in this order, as before).
matplotlib.rcParams.update({
    'font.sans-serif': "Palatino",
    'font.family': "sans-serif",
})

# Fit the mixture random walk DDM to one participant and check its posterior predictions

In [None]:
which = 6
data = pd.read_csv('../data/data_color_discrimination.csv')
person_data = data.loc[data.id == which]
person_rt = np.where(person_data['correct'] == 0, -person_data['rt'], person_data['rt'])[None, :, None]
person_rt.shape

In [None]:
# Bind the mixture random walk DDM and its experiment from the setup cell. These names
# were previously undefined in the notebook (hidden state from a deleted cell).
# Look the model up by class, not by position, so a reordered `model_names` still works.
mrw_idx = next(
    i for i, model in enumerate(models)
    if type(model).__name__ == "MixtureRandomWalkDDM"
)
mrw_ddm, mrw_ddm_exp = models[mrw_idx], trainers[mrw_idx]

# Draw NUM_SAMPLES posterior samples conditioned on the participant's signed RTs.
samples_z = mrw_ddm_exp.amortizer.sample({'summary_conditions': person_rt}, NUM_SAMPLES)

In [None]:
# Undo the prior standardisation: the local samples are presumably z-scored with respect
# to the model's local prior, so rescale by the prior stds and shift by the prior means.
z_local = samples_z['local_samples']
local_post = z_local * mrw_ddm.local_prior_stds + mrw_ddm.local_prior_means
# Swap the first two axes so that posterior draws run along axis 0 (the axis indexed by `idx` later).
local_post_t = np.transpose(local_post, axes=(1, 0, 2))

In [None]:
# Re-simulate data from a random subset of posterior draws (sampled without replacement).
# A seeded generator makes the chosen subset reproducible across kernel restarts.
# NOTE(review): the simulator inside `likelihood` may use its own RNG; seed it as well
# if the simulated RTs themselves must be reproducible.
RESIMULATION_SEED = 2023
resim_rng = np.random.default_rng(RESIMULATION_SEED)
idx = resim_rng.choice(NUM_SAMPLES, size=NUM_RESIMULATIONS, replace=False)
# np.abs drops the sign (presumably response correctness, as in `person_rt`), keeping the RT magnitude.
pred_data = np.abs(mrw_ddm.likelihood(local_post_t[idx, :, :])['sim_data'])

In [None]:
# Long-format table of re-simulated RTs. Every re-simulation reuses the participant's trial
# design, so the design columns are tiled once per re-simulation. This assumes `pred_data`
# is (NUM_RESIMULATIONS, num_trials), so its row-major flattening lines up — confirm.
pred_columns = {
    column: np.tile(person_data[column], NUM_RESIMULATIONS)
    for column in ('speed_condition', 'difficulty')
}
pred_columns['rt'] = pred_data.flatten()
pred_df = pd.DataFrame(pred_columns)

In [None]:
def _median_abs_deviation(values):
    """Median absolute deviation from the median (unscaled)."""
    return np.median(np.abs(values - np.median(values)))


# Per (difficulty, speed_condition) cell: median and MAD of the re-simulated RTs.
grouped = pred_df.groupby(['difficulty', 'speed_condition'])
pred_summary = (
    grouped
    .agg(median=('rt', 'median'), mad=('rt', _median_abs_deviation))
    .reset_index(drop=False)
)
pred_summary

In [None]:
# Empirical median and MAD of RTs for the same participant whose posterior was
# re-simulated (`person_data`, not the pooled dataset `data`). Grouping uses the same key
# order as `pred_summary`, so the two tables line up row by row.
grouped = person_data.groupby(['difficulty', 'speed_condition'])
true_summary = grouped.agg({
    'rt': ['median', lambda x: np.median(np.abs(x - np.median(x)))]
})
true_summary = true_summary.reset_index(drop=False)
true_summary.columns = ['difficulty', 'speed_condition', 'median', 'mad']

In [None]:
# Show both summaries as rich (HTML) tables; a bare list of DataFrames renders only as plain text.
display(true_summary, pred_summary)

In [None]:
# Posterior predictive check. For each difficulty level the figure shows the median response
# time (points) ± median absolute deviation (error bars). Re-simulated data (maroon) sits next
# to the participant's empirical data (black), with one panel per speed condition.
bar_width = 0.1  # horizontal offset separating predicted (left) and empirical (right) points

# One entry per plotted series: (summary table, x offset, colour, legend label).
# The same label string feeds both the artists and the legend, so the two cannot diverge.
plot_series = [
    (pred_summary, -bar_width, 'maroon', "Mixture random walk DDM"),
    (true_summary, bar_width, 'black', "Empiric"),
]
# speed_condition value -> panel title
panel_titles = {0: "Accuracy Condition", 1: "Speed Condition"}

# Points sit at x = 2 * difficulty; ticks 0, 2, 4, 6 are labelled 1-4
# (presumably difficulty is coded 0-3 — confirm).
x_labels = ['1', '2', '3', '4']
x_positions = [0, 2, 4, 6]

fig, ax = plt.subplots(1, 2, figsize=(16, 6))

for panel, (speed, title) in zip(ax, panel_titles.items()):
    for summary, offset, color, label in plot_series:
        rows = summary.loc[summary.speed_condition == speed]
        x = rows['difficulty'] * 2 + offset
        panel.scatter(x, rows['median'], color=color, alpha=0.8, label=label)
        panel.errorbar(
            x, rows['median'], yerr=rows['mad'],
            fmt='none', capsize=5, elinewidth=1,
            color=color, alpha=0.8
        )
    panel.set_title(title, fontsize=FONT_SIZE_1)
    panel.set_xticks(x_positions, x_labels)
    panel.set_ylim([0.3, 1.6])
    panel.tick_params(axis='both', which='major', labelsize=FONT_SIZE_3)
    panel.set_xlabel("Difficulty", labelpad=10, fontsize=FONT_SIZE_2)

ax[0].set_ylabel("Response Time", fontsize=FONT_SIZE_2)

# Shared legend below both panels, built from the same series definitions as the data.
handles = [
    Line2D(
        xdata=[], ydata=[], marker='o', markersize=5,
        color=color, alpha=0.8, label=label
    )
    for _, _, color, label in plot_series
]
fig.legend(
    handles=handles,
    fontsize=FONT_SIZE_2, bbox_to_anchor=(0.5, -0.1),
    loc="center", ncol=2
)
sns.despine(fig=fig)
fig.tight_layout()