In [1]:
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt

np.set_printoptions(suppress=True)
from scipy.stats import levy_stable
from tqdm import tqdm

import sys
sys.path.append("../../assets")

from helpers import get_setup
from configurations import model_names, default_bounds
from likelihoods import _sample_diffusion_trial, sample_non_stationary_diffusion_process

  from tqdm.autonotebook import tqdm


In [2]:
import matplotlib

# Register Palatino as the sans-serif typeface and select the sans-serif
# family, so all figure text is rendered in Palatino.
matplotlib.rcParams.update({
    'font.sans-serif': "Palatino",
    'font.family': "sans-serif",
})

# Sizes used to allocate and index the prediction array
# (NUM_SUBS x NUM_RESIMS x NUM_OBS) and to drive the fitting loop.
NUM_OBS = 768        # length of each subject's series along the last axis
NUM_SUBS = 14        # number of subjects iterated over
NUM_SAMPLES = 1000   # posterior draws requested from the amortizer per fit
HORIZON_SIZE = 68    # trailing observations that are predicted step by step
NUM_RESIMS = 250     # posterior draws carried forward into re-simulation

# Fixed seed so that every draw made through RNG is repeatable across
# Restart & Run All. Randomness that does not go through RNG (e.g. inside the
# imported simulators or the amortizer) is not controlled by this seed.
SEED = 2023
RNG = np.random.default_rng(SEED)

In [3]:
def calc_sma(data, period=5):
    """Simple moving average with ``None`` padding over the warm-up positions.

    Position ``i`` of the result holds the mean of ``data[i - period + 1 : i + 1]``.
    Positions for which no full window is available yet (everything before
    index ``j + period - 1``, where ``j`` is the index of the first entry that
    is not ``None``) are filled with ``None``.

    Parameters
    ----------
    data : sequence
        Input series; may start with ``None`` entries.
    period : int, default 5
        Window length, must be at least 1.

    Returns
    -------
    np.ndarray
        Array with exactly ``len(data)`` entries. It has object dtype whenever
        at least one ``None`` padding entry is present.

    Raises
    ------
    ValueError
        If ``period`` is smaller than 1.
    """
    if period < 1:
        raise ValueError("period must be a positive integer")
    n = len(data)
    # Fall back to n when every entry is None (or data is empty) so the result
    # is all padding instead of an uncaught StopIteration.
    first_valid = next((i for i, x in enumerate(data) if x is not None), n)
    # Cap the warm-up at n: a series shorter than the window must still yield
    # an output of the same length as the input.
    warmup = min(n, first_valid + period - 1)
    padding = [None] * warmup
    averages = [np.mean(data[i - period + 1: i + 1]) for i in range(warmup, n)]
    return np.array(padding + averages)

def post_resim(samples, num_resims):
    """Re-simulate data from an evenly spaced subset of posterior draws.

    Parameters
    ----------
    samples : np.ndarray
        Posterior draws. Axis 0 is used as the number of observations and
        axis 1 as the number of draws; any trailing axes are passed through to
        the simulator unchanged.
        (assumes shape (num_obs, num_samples, num_params) -- TODO confirm)
    num_resims : int
        Number of draws to re-simulate.

    Returns
    -------
    np.ndarray of shape (num_resims, num_obs)
        Absolute value of the simulator output, one row per selected draw.
    """
    num_obs = samples.shape[0]
    num_samples = samples.shape[1]
    # Exact integer spacing. np.arange with a float step and an integer dtype
    # can return the wrong number of indices (e.g. num_resims - 1 when
    # num_samples == num_resims) or wrong spacing for non-integer ratios.
    idx = np.arange(num_resims) * num_samples // num_resims
    theta = samples[:, idx]
    pred_data = np.zeros((num_resims, num_obs))
    for sim in range(num_resims):
        pred_data[sim] = np.abs(sample_non_stationary_diffusion_process(theta[:, sim]))
    return pred_data

def get_next_theta(theta_t, eta, model_name,
                   lower_bounds=default_bounds["lower"],
                   upper_bounds=default_bounds["upper"]):
    """Draw the next local parameter vector from a model's transition rule.

    The three entries of the parameter vector are treated as (v, a, tau).

    Parameters
    ----------
    theta_t : array-like
        Current parameters; entries 0, 1 and 2 are read.
    eta : array-like
        Transition hyperparameters. ``eta[0:3]`` scale the increments of
        v, a and tau. ``eta[3:]`` must hold two values:
        for "mrw" they are compared against uniform draws to decide whether
        v and a take a Gaussian step or are redrawn uniformly within bounds;
        for "lf" they are the stability parameters of the Levy-stable
        increments of v and a (clipped to [0.01, 2.0]).
    model_name : str
        Either "mrw" or "lf".
    lower_bounds, upper_bounds : sequence
        Per-parameter limits. Every updated value is clipped to them.

    Returns
    -------
    np.ndarray of shape (3,)
        The next (v, a, tau).

    Raises
    ------
    ValueError
        If ``model_name`` is neither "mrw" nor "lf". Previously an all-zero
        vector was returned silently in that case.
    """
    theta_next = np.zeros(3)
    if model_name == "mrw":
        z = RNG.normal(size=3)
        switch_samples = RNG.random(size=2)
        stay = switch_samples > eta[3:]
        # v (k=0) and a (k=1): Gaussian step when staying, otherwise a fresh
        # uniform draw within the bounds.
        for k in range(2):
            if stay[k]:
                theta_next[k] = np.clip(
                    theta_t[k] + eta[k] * z[k],
                    a_min=lower_bounds[k], a_max=upper_bounds[k]
                )
            else:
                theta_next[k] = RNG.uniform(lower_bounds[k], upper_bounds[k])
        # tau always takes a Gaussian step
        theta_next[2] = np.clip(
            theta_t[2] + eta[2] * z[2],
            a_min=lower_bounds[2], a_max=upper_bounds[2]
        )
    elif model_name == "lf":
        levy_scale = np.clip(eta[:2] / np.sqrt(2), a_min=0.0000001, a_max=np.inf)
        levy_alpha = np.clip(eta[3:], a_min=0.01, a_max=2.0)
        # Scalar draw: assigning a shape-(1,) array to a single element is
        # deprecated in recent NumPy releases.
        z_norm = RNG.normal()
        # random_state=RNG keeps all draws on the notebook's generator instead
        # of the global NumPy random state.
        z_levy = levy_stable.rvs(levy_alpha, 0, scale=levy_scale, size=2, random_state=RNG)
        # v and a take Levy-stable steps
        theta_next[:2] = np.clip(
            theta_t[:2] + z_levy,
            a_min=lower_bounds[:2], a_max=upper_bounds[:2]
        )
        # tau takes a Gaussian step
        theta_next[2] = np.clip(
            theta_t[2] + eta[2] * z_norm,
            a_min=lower_bounds[2], a_max=upper_bounds[2]
        )
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'mrw' or 'lf'."
        )
    return theta_next

In [4]:
# Trial-level data for all subjects. The fitting loop selects one subject at a
# time through the `id` column and reads the `rt` and `correct` columns.
data = pd.read_csv('data/data_color_discrimination.csv')
# Pickled analysis artifacts. pickle.load can execute arbitrary code contained
# in the file, so only load files produced by this project.
with open('data/posteriors/samples_per_model.pkl', 'rb') as file:
    samples_per_model = pickle.load(file)
# Indexed by zero-based subject position in the fitting loop; each entry is
# used there as an index into `model_names`.
with open('data/winning_model_per_person.pkl', 'rb') as file:
    winning_model_per_person = pickle.load(file)

In [5]:
# For every subject: re-simulate the first NUM_OBS - HORIZON_SIZE observations
# from the posterior of the subject's winning model, then fill the remaining
# HORIZON_SIZE positions one at a time from repeated refits on a growing window.
pred_data = np.zeros((NUM_SUBS, NUM_RESIMS, NUM_OBS))

# Evenly spaced posterior draws that are carried forward for prediction.
# Loop-invariant, so computed once; exact integer spacing avoids the
# float-step / integer-dtype pitfalls of np.arange.
idx = np.arange(NUM_RESIMS) * NUM_SAMPLES // NUM_RESIMS
num_fitted = NUM_OBS - HORIZON_SIZE

for sub in range(NUM_SUBS):
    winning_model = winning_model_per_person[sub]
    model_name = model_names[winning_model]
    model, trainer = get_setup(model_name, "smoothing")
    person_data = data.loc[data.id == sub + 1]
    # Sign-coded series: the value is negated where `correct == 0`.
    # Shaped (1, n_trials, 1) for the amortizer.
    person_rt = np.where(person_data['correct'] == 0, -person_data['rt'], person_data['rt'])[None, :, None]
    # The former extra sampling pass before this loop was removed: its results
    # were overwritten at i == 0 before ever being read.
    for i in tqdm(range(HORIZON_SIZE+1)):
        tmp_post = trainer.amortizer.sample(
            {"summary_conditions": person_rt[:, :num_fitted+i, :]}, NUM_SAMPLES
        )
        # Undo the standardisation of the network outputs.
        tmp_local = tmp_post['local_samples'] * model.local_prior_stds + model.local_prior_means
        if i == 0:
            # Re-simulation of the fitted window (post_resim returns absolute values).
            pred_data[sub, :, :num_fitted] = post_resim(tmp_local, NUM_RESIMS)
        else:
            tmp_hyper = tmp_post['global_samples'] * model.hyper_prior_stds + model.hyper_prior_means
            last_theta = tmp_local[-1, idx, :]
            last_eta = tmp_hyper[idx, :]
            # NOTE(review): the posterior above is conditioned on observations
            # [0, num_fitted + i), whose last index is num_fitted + i - 1 -- the
            # same index the prediction is written to below. Confirm this
            # alignment is intended for a one-step-ahead forecast; conditioning
            # on [0, num_fitted + i - 1) would exclude the predicted trial.
            for j in range(NUM_RESIMS):
                next_theta = get_next_theta(last_theta[j], last_eta[j], model_name)
                pred_data[sub, j, num_fitted+i-1] = _sample_diffusion_trial(
                    next_theta[0], next_theta[1], next_theta[2]
                )

np.save('data/rt_time_series_resim.npy', pred_data)

INFO:root:Performing 2 pilot runs with the levy_flight_ddm model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 800)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 5)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 800, 3)
INFO:root:No shared_prior_draws provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.
INFO:root:Loaded loss history from checkpoints/smoothing_levy_flight_ddm/history_75.pkl.
INFO:root:Networks loaded from checkpoints/smoothing_levy_flight_ddm/ckpt-75
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.
  1%|▏         | 1/69 [00:17<19:59, 17.6

In [None]:
# Summarise the re-simulated series of the first subject (pred_data[0]):
# take absolute values, smooth every re-simulation with the moving average,
# then reduce across re-simulations.
pred_rt = np.abs(pred_data[0])
for resim, series in enumerate(pred_rt):
    # The warm-up padding returned by calc_sma ends up as NaN in this float array.
    pred_rt[resim] = calc_sma(series)
pred_rt_mean = np.mean(pred_rt, axis=0)
pred_rt_quantiles = np.quantile(pred_rt, q=[0.025, 0.975], axis=0)

In [None]:
# The predictive summaries (pred_rt_mean / pred_rt_quantiles) are built from
# pred_data[0], i.e. the first subject (id == 1). `person_rt` left over from
# the fitting loop belongs to the last subject processed there, so the
# empirical series is rebuilt here for the matching subject instead of reused.
plot_sub = 0
plot_rt = np.abs(data.loc[data.id == plot_sub + 1, 'rt'].to_numpy())
trials = range(NUM_OBS)

fig, ax = plt.subplots()
ax.plot(trials, calc_sma(plot_rt), color='black', alpha=0.8, lw=1, label='observed')
ax.plot(trials, pred_rt_mean, color='maroon', alpha=0.8, lw=1, label='predicted mean')
ax.fill_between(trials, pred_rt_quantiles[0], pred_rt_quantiles[1], color='maroon', alpha=0.4,
                label='2.5%-97.5% predictive quantiles')
ax.set_xlabel('Trial')
ax.set_ylabel('|rt| (moving average)')
ax.legend();