In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import seaborn as sns
import matplotlib.pyplot as plt
import pickle

from tqdm import tqdm
from scipy.stats import median_abs_deviation
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

import matplotlib
matplotlib.rcParams['font.sans-serif'] = "Palatino"
matplotlib.rcParams['font.family'] = "sans-serif"

from helpers import get_setup
from configurations import model_names

# Sizes: trials per participant, posterior draws, and re-simulated data sets.
NUM_OBS, NUM_SAMPLES, NUM_RESIM = 768, 1000, 100

# Display names and LaTeX symbols of the three local (per-trial) parameters.
LOCAL_PARAM_LABELS = ['Drift rate', 'Threshold', 'Non-decision time']
LOCAL_PARAM_NAMES = ['v', 'a', r'\tau']

# Font sizes used for titles, axis labels/legends, and tick labels respectively.
FONT_SIZE_1, FONT_SIZE_2, FONT_SIZE_3 = 24, 20, 16

  from tqdm.autonotebook import tqdm


In [None]:
# Load the stored posterior samples of every model.
# pickle.load can execute arbitrary code: only open files you produced yourself.
with open('data/posteriors/samples_per_model.pkl', mode='rb') as pkl_file:
    samples_per_model = pickle.load(pkl_file)

In [None]:
# Build one (model, trainer) pair per configured model, all with the "smoothing" setup.
setup = [get_setup(model_spec, "smoothing") for model_spec in model_names]
models = [pair[0] for pair in setup]
trainers = [pair[1] for pair in setup]

In [None]:
# NOTE(review): exact repeat of the setup cell above; it calls get_setup again for every model.
setup = list(map(lambda spec: get_setup(spec, "smoothing"), model_names))
models = list(map(lambda entry: entry[0], setup))
trainers = list(map(lambda entry: entry[1], setup))

In [None]:
# NOTE(review): duplicates the configuration cell at the top of the notebook.
NUM_OBS, NUM_SAMPLES, NUM_RESIM = 768, 1000, 100

FONT_SIZE_1, FONT_SIZE_2, FONT_SIZE_3 = 24, 20, 16

import matplotlib
matplotlib.rcParams.update({
    'font.sans-serif': "Palatino",
    'font.family': "sans-serif",
})

# Fit models

In [None]:
# Participant id analysed in the single-person section below.
which = 6
# Path aligned with every other file access in this notebook ('data/...'); the previous
# '../data/...' did not match that layout.
data = pd.read_csv('data/data_color_discrimination.csv')
person_data = data.loc[data.id == which]
# Sign-code RTs (errors negative, correct responses positive) and add a leading and a
# trailing singleton axis: shape (1, n_trials, 1).
person_rt = np.where(person_data['correct'] == 0, -person_data['rt'], person_data['rt'])[None, :, None]
person_rt.shape

In [None]:
# `mrw_ddm_exp` is not defined anywhere in this notebook, so this cell raised NameError on a
# fresh kernel. Bind it to the mixture-random-walk trainer: index 1 follows the model order
# implied by LABELS / MODEL_NAMES further below (TODO confirm against configurations.model_names).
mrw_ddm_exp = trainers[1]
# Amortized posterior draws given this participant's sign-coded RT series.
samples_z = mrw_ddm_exp.amortizer.sample({'summary_conditions': person_rt}, NUM_SAMPLES)

In [None]:
# `mrw_ddm` is likewise undefined in this notebook; use the matching model object
# (index 1 = mixture random walk per the LABELS / MODEL_NAMES ordering — TODO confirm).
mrw_ddm = models[1]
# Map the sampled local parameters back with the model's prior means/stds
# (presumably the samples are standardized, hence `samples_z` — verify).
local_post = samples_z['local_samples'] * mrw_ddm.local_prior_stds + mrw_ddm.local_prior_means
# Swap the first two axes so that posterior draws can be indexed first in the next cell.
local_post_t = np.transpose(local_post, (1, 0, 2))

In [None]:
# Pick NUM_RESIM distinct posterior draws and re-simulate data from them; abs() drops the
# error/correct sign coding so only response times remain.
idx = np.random.choice(np.arange(NUM_SAMPLES), size=NUM_RESIM, replace=False)
resim_params = local_post_t[idx, :, :]
pred_data = np.abs(mrw_ddm.likelihood(resim_params)['sim_data'])

In [None]:
# One row per simulated trial; every re-simulation repeats the participant's trial design.
pred_df = pd.DataFrame(dict(
    speed_condition=np.tile(person_data['speed_condition'], NUM_RESIM),
    difficulty=np.tile(person_data['difficulty'], NUM_RESIM),
    rt=pred_data.flatten(),
))

In [None]:
# Median and median absolute deviation of simulated RTs per difficulty x speed cell.
grouped = pred_df.groupby(['difficulty', 'speed_condition'])
pred_summary = (
    grouped['rt']
    .agg(median='median', mad=lambda rts: np.median(np.abs(rts - np.median(rts))))
    .reset_index()
)
pred_summary

In [None]:
# Empirical RT median and MAD per speed x difficulty cell for the participant fitted above.
# Previously this grouped `data` (all participants pooled), so the "Empiric" points were not
# the data the re-simulated model had been fitted to.
grouped = person_data.groupby(['speed_condition', 'difficulty'])
true_summary = grouped.agg({
    'rt': ['median', lambda x: np.median(np.abs(x - np.median(x)))]
})
true_summary = true_summary.reset_index(drop=False)
true_summary.columns = ['speed_condition', 'difficulty', 'median', 'mad']

In [None]:
# Show the empirical and predicted summary tables together.
list((true_summary, pred_summary))

In [None]:
# Posterior predictive check for the fitted participant: RT median ± MAD per difficulty
# level for the re-simulated mixture random walk DDM and the empirical data, one panel
# per speed condition.
bar_width = 0.1
x_labels = ['1', '2', '3', '4']
x_positions = [0, 2, 4, 6]
condition_titles = ["Accuracy Condition", "Speed Condition"]

# (summary table, horizontal offset from the tick, colour, label) for each plotted series.
plot_series = [
    (pred_summary, -bar_width, 'maroon', "Mixture random walk DDM"),
    (true_summary, +bar_width, 'black', "Empiric"),
]

fig, ax = plt.subplots(1, 2, figsize=(16, 6))
for cond_idx, cond_ax in enumerate(ax):
    for frame, x_shift, series_color, series_label in plot_series:
        cond_rows = frame.loc[frame.speed_condition == cond_idx]
        # difficulty is scaled by 2 to land on the tick positions above
        x_vals = cond_rows['difficulty'] * 2 + x_shift
        cond_ax.scatter(
            x_vals, cond_rows['median'],
            color=series_color, alpha=0.8, label=series_label
        )
        cond_ax.errorbar(
            x_vals, cond_rows['median'], yerr=cond_rows['mad'],
            fmt='none', capsize=5, elinewidth=1,
            color=series_color, alpha=0.8
        )
    cond_ax.set_title(condition_titles[cond_idx], fontsize=FONT_SIZE_1)
    cond_ax.set_xticks(x_positions, x_labels)
    cond_ax.set_ylim([0.3, 1.6])
    cond_ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE_3)
    cond_ax.set_xlabel("Difficulty", labelpad=10, fontsize=FONT_SIZE_2)
ax[0].set_ylabel("Response Time", fontsize=FONT_SIZE_2)

# Figure-level legend from proxy artists, shared by both panels.
handles = [
    Line2D(
        xdata=[], ydata=[], marker='o', markersize=5,
        color=series_color, alpha=0.8, label=series_label
    )
    for _, _, series_color, series_label in plot_series
]
fig.legend(
    handles,
    [series_label for *_, series_label in plot_series],
    fontsize=FONT_SIZE_2, bbox_to_anchor=(0.5, -0.1),
    loc="center", ncol=2
)
sns.despine()
fig.tight_layout()

In [None]:
# Reload the stored posterior samples of every model (trusted local file; pickle can run code).
with open('data/posteriors/samples_per_model.pkl', mode='rb') as samples_handle:
    samples_per_model = pickle.load(samples_handle)

In [None]:
def plot_parameter_trajectory(person_data, local_samples, lw=2):
    """Plot the posterior trajectory of each local parameter across trials.

    One panel per parameter shows the posterior median with a ±MAD band.
    Trials run under the speed condition are shaded, and the first panel
    also overlays the (rescaled) difficulty manipulation.

    Parameters
    ----------
    person_data : pd.DataFrame
        One participant's trials in presentation order, with a
        ``speed_condition`` column (1 = speed) and a ``difficulty`` column.
    local_samples : np.ndarray
        Posterior samples of the local parameters. Median/MAD are taken over
        axis 1 and parameters 0..2 are read from the last axis, i.e.
        presumably shaped (trials, samples, 3) — TODO confirm.
    lw : float, optional
        Line width of the median and difficulty lines.

    Returns
    -------
    matplotlib.figure.Figure
        The three-panel figure.
    """
    condition = person_data['speed_condition'].to_numpy()
    n_obs = condition.shape[0]
    trials = np.arange(n_obs)

    # Contiguous runs of speed trials as half-open [start, stop) trial ranges.
    # Padding with 0 on both ends gives every run exactly one +1 and one -1 edge,
    # so runs that touch the first or last trial work. The previous change-point
    # bookkeeping indexed past the end for some block orders and shifted the
    # first/last spans by one trial.
    is_speed = np.concatenate(([0], (condition == 1).astype(int), [0]))
    edges = np.diff(is_speed)
    speed_starts = np.flatnonzero(edges == 1)
    speed_stops = np.flatnonzero(edges == -1)

    # posterior median and median absolute deviation over the sample axis
    post_median = np.median(local_samples, axis=1)
    post_mad = median_abs_deviation(local_samples, axis=1)

    # tick every 48 trials, first tick labelled as trial 1
    tick_positions = np.arange(0, n_obs + 1, 48)
    tick_positions[0] = 1

    fig, axarr = plt.subplots(3, 1, figsize=(18, 14))
    for i, ax in enumerate(axarr.flat):
        # parameter trajectory
        ax.plot(
            trials, post_median[:, i],
            color='maroon', alpha=1.0, lw=lw, label="Posterior median"
        )
        ax.fill_between(
            trials,
            post_median[:, i] - post_mad[:, i],
            post_median[:, i] + post_mad[:, i],
            color='maroon', alpha=0.5, label="Posterior MAD"
        )

        # yellow shades for speed-condition blocks
        for start, stop in zip(speed_starts, speed_stops):
            ax.axvspan(start, stop, alpha=0.2, color='#f0c654')

        # difficulty manipulation, flipped and rescaled, on the first panel only
        if i == 0:
            ax.plot(
                trials,
                (person_data['difficulty'] - 3) * -2,
                color='black', alpha=0.5, lw=lw, label="Difficulty manipulation"
            )

        # aesthetics
        ax.set_title(f'{LOCAL_PARAM_LABELS[i]} (${LOCAL_PARAM_NAMES[i]}$)', fontsize=FONT_SIZE_1)
        ax.grid(alpha=0.3)
        ax.set_xticks(tick_positions)
        ax.margins(x=0.01)
        ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE_3)
        ax.set_ylabel("Parameter\nValue", rotation=0, labelpad=70, fontsize=FONT_SIZE_2)
        if i == 2:
            ax.set_xlabel("Trial", labelpad=20, fontsize=FONT_SIZE_2)

    sns.despine()
    fig.subplots_adjust(hspace=0.5)

    # shared legend built from proxy artists
    handles = [
        Line2D(xdata=[], ydata=[], color='maroon', alpha=0.8, lw=3, label="Posterior median"),
        Patch(facecolor='maroon', alpha=0.5, edgecolor=None, label="Posterior MAD"),
        Patch(facecolor='#f0c654', alpha=0.2, edgecolor=None, label="Speed condition"),
        Line2D(xdata=[], ydata=[], color='black', alpha=0.5, lw=3, label="Difficulty manipulation")
    ]
    fig.legend(
        handles,
        ["Posterior median", "Posterior MAD", "Speed condition", "Difficulty manipulation"],
        fontsize=FONT_SIZE_2, bbox_to_anchor=(0.5, -0.001),
        loc="center", ncol=4
    )

    return fig

In [None]:
# Local posterior samples of model 0 (`rw_` = random walk) for all participants.
rw_local_post = samples_per_model[0]['local_samples']
np.shape(rw_local_post)

In [None]:
data = pd.read_csv('data/data_color_discrimination.csv')
# Participant ids are 1-based, while the first axis of the sample array is 0-based.
person_data = data.loc[data.id == 1]
local_samples = rw_local_post[0]

In [None]:
# Parameter trajectories of the random walk DDM for participant 1.
f = plot_parameter_trajectory(person_data=person_data, local_samples=local_samples, lw=2)

In [None]:
# Per-person winning-model results (trusted local pickle; pickle.load can run code).
with open('data/winning_model_per_person.pkl', mode='rb') as winners_file:
    winning_model_per_person = pickle.load(winners_file)

In [None]:
# Inspect the loaded per-person winning-model object.
winning_model_per_person



In [None]:
# Reload posterior samples and per-person winning models (trusted local pickles).
with open('data/posteriors/samples_per_model.pkl', mode='rb') as samples_fh:
    samples_per_model = pickle.load(samples_fh)
with open('data/winning_model_per_person.pkl', mode='rb') as winners_fh:
    winning_model_per_person = pickle.load(winners_fh)

In [None]:
# Rebuild the model objects; only the models (first element of each pair) are kept here.
setup = []
for spec in model_names:
    setup.append(get_setup(spec, "smoothing"))
models = [entry[0] for entry in setup]

In [None]:
# NOTE(review): re-declares the sizes from the configuration cell.
NUM_OBS, NUM_SAMPLES, NUM_RESIM = 768, 1000, 100

In [None]:
# Shape of one participant's local samples for model 0.
np.shape(samples_per_model[0]['local_samples'][0])

In [None]:
# Posterior re-simulation of every model for one participant, compared against that
# participant's own empirical RT summaries.
sub = 0
person_data = data.loc[data.id == sub+1]  # participant ids are 1-based
# Empirical summary for this participant only (previously grouped the pooled `data`,
# so the "Empiric" series did not belong to the re-simulated participant).
grouped = person_data.groupby(['speed_condition', 'difficulty'])
person_summary = grouped.agg({
    'rt': ['median', lambda x: np.median(np.abs(x - np.median(x)))]
})
person_summary = person_summary.reset_index(drop=False)
person_summary.columns = ['speed_condition', 'difficulty', 'median', 'mad']


# posterior re-simulation for all models
idx = np.random.choice(np.arange(NUM_SAMPLES), NUM_RESIM, replace=False)
summaries = []
summaries.append(person_summary)
for i, model in enumerate(models):
    # `sub` and `idx` are advanced indices separated by a slice, so NumPy places the
    # drawn-sample axis first in the result: (NUM_RESIM, trials, params).
    pred_data = np.abs(
        model.likelihood(samples_per_model[i]['local_samples'][sub, :, idx, :])['sim_data']
        )
    pred_df = pd.DataFrame({
        'speed_condition': np.tile(person_data['speed_condition'], NUM_RESIM),
        'difficulty': np.tile(person_data['difficulty'], NUM_RESIM),
        'rt': pred_data.flatten(),
        })
    grouped = pred_df.groupby(['difficulty', 'speed_condition'])
    pred_summary = grouped.agg({
        'rt': ['median', lambda x: np.median(np.abs(x - np.median(x)))]
    })
    pred_summary = pred_summary.reset_index(drop=False)
    pred_summary.columns = ['difficulty', 'speed_condition', 'median', 'mad']
    summaries.append(pred_summary)

In [None]:
# Inspect the candidate horizontal offsets.
np.arange(start=-0.6, stop=0.7, step=0.2)

In [None]:
MODEL_NAMES = [
    'Random walk DDM', 'Mixture random walk DDM',
    'Levy flight DDM', 'Regime switching DDM'
    ]
# x positions of the four difficulty levels (matching the [0, 2, 4, 6] ticks).
X_AXIS_VALUES = np.arange(4) * 2
# Series labels: empirical data first, then one entry per model.
LABELS = [
    'Empiric', 'Random walk', 'Mixture random walk',
    'Levy flight', 'Regime switching'
    ]
COLORS = [
    "black", "orange", "maroon", "#133a76", "green"
]
# Panel titles, indexed by speed_condition.
CONDITIONS = ["Accuracy Condition", "Speed Condition"]
# Horizontal offset of each series (one per LABELS entry) around its difficulty tick,
# symmetric so every group is centred on the tick: [-0.4, -0.2, 0.0, 0.2, 0.4].
# The previous np.arange(-0.6, 0.7, 0.2) had 7 entries for 5 series, so the offsets in use
# ran from -0.6 to 0.2 and every group sat left of its tick.
# (Name kept for the plotting cells; these are offsets, not widths.)
BAR_WIDTH = (np.arange(len(LABELS)) - (len(LABELS) - 1) / 2) * 0.2

In [None]:
# y-axis limits used by the re-simulation figures: (lower, upper). The original cell only
# displayed its second expression and used +0.3, while the plots use +0.4.
(summaries[0]["median"].min() - 0.1, summaries[0]["median"].max() + 0.4)

In [None]:
# Series palette: empirical data first, then one colour per model (same order as LABELS).
# Alternative palette previously tried: ['black', '#FFD700', '#DC143C', '#008080', '#008000']
COLORS = ['black', '#F68F2F', '#902F2F', '#31546A', '#4F632E']

In [None]:
# Single-participant posterior re-simulation: RT median ± MAD per difficulty level for the
# empirical data and every model, one panel per speed condition.
fig, axarr = plt.subplots(1, 2, figsize=(16, 6))
for i, ax in enumerate(axarr.flat):
    for t in range(len(summaries)):
        cond_rows = summaries[t].loc[summaries[t].speed_condition == i]
        ax.scatter(
            X_AXIS_VALUES + BAR_WIDTH[t],
            cond_rows['median'],
            s=75, color=COLORS[t], alpha=0.8, label=LABELS[t]
        )
        ax.errorbar(
            X_AXIS_VALUES + BAR_WIDTH[t],
            cond_rows['median'],
            yerr=cond_rows['mad'],
            fmt='none', capsize=5, elinewidth=1,
            color=COLORS[t], alpha=0.8
            )

    ax.set_title(CONDITIONS[i], fontsize=FONT_SIZE_1)

    x_labels = ['1', '2', '3', '4']
    x_positions = [0, 2, 4, 6]
    ax.set_xticks(x_positions, x_labels)

    ax.set_ylim([
        summaries[0]["median"].min() - 0.1,
        summaries[0]["median"].max() + 0.4])

    ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE_3)
    if i == 0:
        ax.set_ylabel("Response Time", fontsize=FONT_SIZE_2)

    ax.set_xlabel("Difficulty", labelpad=10, fontsize=FONT_SIZE_2)

# One legend proxy per series, built once. It was previously appended inside the panel loop
# (doubling the list) and used ALPHA, which is only defined in a later cell.
handles = [
    Line2D(
        xdata=[], ydata=[], marker='o', markersize=5,
        color=COLORS[t], alpha=0.8, label=LABELS[t]
    )
    for t in range(len(summaries))
]

fig.legend(
    handles,
    LABELS,
    fontsize=FONT_SIZE_2, bbox_to_anchor=(0.5, -0.1),
    loc="center", ncol=5
)
sns.despine()
fig.tight_layout()

In [None]:
# Inspect the legend proxy artists built by the plotting cell above.
handles

In [None]:
# Reload posterior samples for the pooled re-simulation section (trusted local pickle).
with open('data/posteriors/samples_per_model.pkl', mode='rb') as pkl:
    samples_per_model = pickle.load(pkl)

In [None]:
# Trial-level data of all participants.
data = pd.read_csv(filepath_or_buffer='data/data_color_discrimination.csv')

In [None]:
# Rebuild the model objects for the pooled re-simulation.
setup = list(map(lambda spec: get_setup(spec, "smoothing"), model_names))
models = list(map(lambda entry: entry[0], setup))

In [None]:
# Empirical RT median and MAD per speed x difficulty cell, pooled over all participants.
grouped_data = data.groupby(['speed_condition', 'difficulty'])
overall_summary = (
    grouped_data['rt']
    .agg(median='median', mad=lambda rts: np.median(np.abs(rts - np.median(rts))))
    .reset_index()
)

In [None]:
# Posterior re-simulation pooled over all participants: for every model, simulate NUM_RESIM
# data sets per participant and summarise RT median/MAD per speed x difficulty cell.
# summary_per_model[0] is the empirical summary; entry i+1 belongs to models[i].

# NUM_SUBS was never defined in this notebook (NameError on a fresh kernel). Take it from the
# first axis of the stored samples, which the loop below indexes by participant.
NUM_SUBS = samples_per_model[0]['local_samples'].shape[0]

summary_per_model = []
summary_per_model.append(overall_summary)
for i, model in enumerate(models):
    model_samples = samples_per_model[i]['local_samples']
    model_resim_data = np.zeros((NUM_SUBS, NUM_RESIM, NUM_OBS, 3))
    for sub in range(NUM_SUBS):
        person_data = data.loc[data.id == sub+1]  # participant ids are 1-based
        idx = np.random.choice(np.arange(NUM_SAMPLES), NUM_RESIM, replace=False)
        # `sub` and `idx` are advanced indices separated by a slice, so NumPy places the
        # drawn-sample axis first: (NUM_RESIM, trials, params).
        pred_data = np.abs(
            model.likelihood(model_samples[sub, :, idx, :])['sim_data']
        )
        pred_rt = pred_data[:, :, None]
        condition = np.tile(person_data['speed_condition'], (NUM_RESIM, 1))[:, :, None]
        difficulty = np.tile(person_data['difficulty'], (NUM_RESIM, 1))[:, :, None]
        # Stack (rt, speed_condition, difficulty) along the last axis.
        model_resim_data[sub] = np.concatenate([pred_rt, condition, difficulty], axis=-1)

    # Flatten to (NUM_SUBS * NUM_RESIM * NUM_OBS, 3): one row per simulated trial.
    reshaped_data = model_resim_data.reshape(-1, 3)
    df = pd.DataFrame(reshaped_data, columns=['rt', 'speed_condition', 'difficulty'])
    grouped_data = df.groupby(['speed_condition', 'difficulty'])
    summary = grouped_data.agg({
        'rt': ['median', lambda x: np.median(np.abs(x - np.median(x)))]
    })
    summary = summary.reset_index(drop=False)
    summary.columns = ['speed_condition', 'difficulty', 'median', 'mad']
    summary_per_model.append(summary)


In [None]:
# Alias the pooled summaries under the name the plotting cells read; same list object, not a copy.
summaries = summary_per_model

In [None]:
ALPHA = 0.95
# Series palette for the final figure: empirical data first, then one colour per model.
# Previously `['']`: not a valid colour, and with a single entry COLORS[t] could not serve
# the five plotted series.
COLORS = ['black', '#F68F2F', '#902F2F', '#31546A', '#4F632E']

In [None]:
# Final figure: posterior re-simulation pooled over participants. Shows the RT median ± MAD
# per difficulty level for the empirical data and every model, one panel per speed
# condition, and saves it to plots/.
fig, axarr = plt.subplots(1, 2, figsize=(16, 6))
for i, ax in enumerate(axarr.flat):
    for t in range(len(summaries)):
        cond_rows = summaries[t].loc[summaries[t].speed_condition == i]
        ax.scatter(
            X_AXIS_VALUES + BAR_WIDTH[t],
            cond_rows['median'],
            s=75, color=COLORS[t], label=LABELS[t]
        )

        ax.errorbar(
            X_AXIS_VALUES + BAR_WIDTH[t],
            cond_rows['median'],
            yerr=cond_rows['mad'],
            fmt='o', color=COLORS[t], markersize=8, elinewidth=2, capsize=0
            )

    ax.set_title(CONDITIONS[i], fontsize=FONT_SIZE_1)

    x_labels = ['1', '2', '3', '4']
    x_positions = [0, 2, 4, 6]
    ax.set_xticks(x_positions, x_labels)

    ax.set_ylim([
        summaries[0]["median"].min() - 0.1,
        summaries[0]["median"].max() + 0.4])

    ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE_3)
    if i == 0:
        ax.set_ylabel("Response Time", labelpad=10, fontsize=FONT_SIZE_2)

    ax.set_xlabel("Difficulty", labelpad=10, fontsize=FONT_SIZE_2)

# legend: one proxy per series, built once (it was previously appended inside the panel
# loop, so every entry appeared twice in `handles`)
handles = [
    Line2D(
        xdata=[], ydata=[], marker='o', markersize=10, lw=3,
        color=COLORS[t], label=LABELS[t]
    )
    for t in range(len(summaries))
]
fig.legend(
    handles,
    LABELS,
    fontsize=FONT_SIZE_2, bbox_to_anchor=(0.5, -0.05),
    loc="center", ncol=5
    )
sns.despine()
fig.tight_layout()
# Save this figure object explicitly rather than whatever pyplot considers "current".
fig.savefig("plots/post_resimulation_overall.pdf", dpi=300, bbox_inches="tight")