In [17]:
import nest_asyncio
nest_asyncio.apply()

from hddCRP.simulations import simulate_sessions
from hddCRP.modelBuilder import cdCRP
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sb

import arviz as az

import os
import sys

In [18]:



# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# args = [simulation id, smallest block count, largest block count]
args = [2, 10, 10]
# Disabled command-line override, kept for reference:
# if __name__ == "__main__":
#     for ii, arg in enumerate(sys.argv):
#         if(ii < len(args)):
#             args[ii] = arg

simulation_id = args[0]
min_blocks = args[1]
max_blocks = args[2]

run_range = range(0, 10)                    # independent runs per block count
results_directory = "Results/Simulations/"  # pickled fits are written here
OVERWRITE = False                           # True -> refit even if result files exist


# --- session layout helpers (all take the block count as their argument) ----
def _one_session_labelled_a(n_blocks):
    """Session labels: a single session labelled "A", whatever the block count."""
    return ["A"]


def _session_grows_with_blocks(n_blocks):
    """Session lengths: one session of 50 trials per block."""
    return [50 * n_blocks]


def _session_fixed_length(n_blocks):
    """Session lengths: one session of 50 trials, whatever the block count."""
    return [50]


def _single_subject(n_blocks):
    """Subject count: always one subject."""
    return 1


def _one_subject_per_block(n_blocks):
    """Subject count: as many subjects as blocks."""
    return n_blocks


# simulation id -> (alpha, context weight, within-session timescale for "A",
#                   repeat bias, session_lengths function, num_subjects function)
_SIMULATION_TABLE = {
    0: (2, 0.3, 20, 0.75, _session_grows_with_blocks, _single_subject),
    1: (8, 0.6, 40, 0.25, _session_grows_with_blocks, _single_subject),
    2: (2, 0.3, 20, 0.75, _session_fixed_length, _one_subject_per_block),
    3: (8, 0.6, 40, 0.25, _session_fixed_length, _one_subject_per_block),
    4: (2, 0.3, 40, 0.75, _session_fixed_length, _one_subject_per_block),
}

if simulation_id not in _SIMULATION_TABLE:
    raise NotImplementedError("No sim found")

(alpha,
 _context_weight,
 _within_timescale,
 repeat_bias_1_back,
 session_lengths,
 num_subjects) = _SIMULATION_TABLE[simulation_id]

# Settings that differ only in value between simulations.
different_context_weights = [_context_weight, _context_weight]
within_session_timescales = {"A": _within_timescale}

# Settings shared by every simulation.
between_session_timescales = {("A", "A"): 2}
session_labels = _one_session_labelled_a
num_responses = 3

# NOTE(review): this unconditionally replaces the per-simulation repeat bias
# chosen above, so every simulation runs with a value of 1 -- confirm intended.
repeat_bias_1_back = 1
block_range = range(min_blocks, max_blocks + 1)  # inclusive of max_blocks

# Make sure the output directory exists before any fit tries to write to it.
if not os.path.exists(results_directory):
    os.makedirs(results_directory)



In [19]:



# For every block count: simulate `run_range` independent data sets, fit a
# cdCRP model to each, and pickle the stacked posterior draws and summaries.
# Block counts whose two result files already exist are skipped unless
# OVERWRITE is set.
for block_idx in block_range:
    print(f"BLOCK {block_idx}")
    # os.path.join avoids a doubled separator when results_directory already
    # ends in "/"; the resulting path still points at the same file.
    fit_file = os.path.join(results_directory, f"sim_{simulation_id}_block_{block_idx}.pkl")
    fit_summary_file = os.path.join(results_directory, f"sim_summary_{simulation_id}_block_{block_idx}.pkl")

    # Both files must be present for the cached result to count as complete.
    have_cached_results = os.path.isfile(fit_file) and os.path.isfile(fit_summary_file)
    if have_cached_results and not OVERWRITE:
        print("Fit files found: not overriding")
        continue

    # Collect one frame per run and concatenate once after the loop: growing a
    # DataFrame with pd.concat on every iteration re-copies all earlier rows.
    fit_frames = []
    summary_frames = []
    for run_idx in run_range:
        print(f"BLOCK {block_idx} - RUN {run_idx}")
        # sim_seed does not include block_idx, so a given run uses the same
        # simulation seed for every block count; stan_seed also varies with
        # block_idx.
        sim_seed  = (simulation_id+1) * 10000 + run_idx*100
        stan_seed = (simulation_id+1) * 10000 + run_idx*100 + block_idx
        sim_rng = np.random.Generator(np.random.MT19937(sim_seed))

        # Simulate each subject in turn (all drawing from the same generator)
        # and flatten the results into parallel per-sequence lists.
        seqs = []
        subject_labels = []
        session_labels_all = []
        for jj in range(num_subjects(block_idx)):
            seqs_c = simulate_sessions(session_lengths=session_lengths(block_idx), session_labels=session_labels(block_idx), num_responses=num_responses,
                                       alpha=alpha,
                                       different_context_weights=different_context_weights,
                                       within_session_timescales=within_session_timescales, between_session_timescales=between_session_timescales,
                                       repeat_bias_1_back=repeat_bias_1_back, rng=sim_rng)
            subject_labels.extend([jj] * len(seqs_c))
            session_labels_all.extend(session_labels(block_idx))
            seqs.extend(seqs_c)

        model = cdCRP(seqs, subject_labels=subject_labels, session_labels=session_labels_all)
        # NOTE(review): set before build(); presumably configures the model's
        # n-back depth -- confirm against the cdCRP implementation.
        model.same_nback_depth = 0

        model.build(random_seed=stan_seed)
        model.fit_model()

        map_fit = model.get_map()

        # Posterior draws, tagged with where they came from.
        fit_df = model.fit.to_frame()
        fit_df["block"] = block_idx
        fit_df["run"]   = run_idx
        fit_df["simulation"]   = simulation_id

        # Per-parameter summary with the MAP estimate added as a column.
        summary_df = model.fit_summary()
        summary_df["block"] = block_idx
        summary_df["run"]   = run_idx
        summary_df["MAP"] = pd.Series(map_fit)
        summary_df["simulation"]   = simulation_id

        summary_frames.append(summary_df)
        fit_frames.append(fit_df)

    # pd.concat raises on an empty list, so fall back to an empty frame when
    # run_range is empty (this matches what was written in that case before).
    simulation_fits = pd.concat(fit_frames) if fit_frames else pd.DataFrame()
    simulation_fit_metrics = pd.concat(summary_frames) if summary_frames else pd.DataFrame()

    simulation_fits.to_pickle(fit_file)
    simulation_fit_metrics.to_pickle(fit_summary_file)

BLOCK 10
BLOCK 10 - RUN 0
Building...

In file included from /home/latimerk/miniconda3/envs/JaiYuLab/lib/python3.10/site-packages/httpstan/include/boost/multi_array/multi_array_ref.hpp:32,
                 from /home/latimerk/miniconda3/envs/JaiYuLab/lib/python3.10/site-packages/httpstan/include/boost/multi_array.hpp:34,
                 from /home/latimerk/miniconda3/envs/JaiYuLab/lib/python3.10/site-packages/httpstan/include/boost/numeric/odeint/algebra/multi_array_algebra.hpp:22,
                 from /home/latimerk/miniconda3/envs/JaiYuLab/lib/python3.10/site-packages/httpstan/include/boost/numeric/odeint.hpp:63,
                 from /home/latimerk/miniconda3/envs/JaiYuLab/lib/python3.10/site-packages/httpstan/include/stan/math/prim/functor/ode_rk45.hpp:9,
                 from /home/latimerk/miniconda3/envs/JaiYuLab/lib/python3.10/site-packages/httpstan/include/stan/math/prim/functor/integrate_ode_rk45.hpp:6,
                 from /home/latimerk/miniconda3/envs/JaiYuLab/lib/python3.10/site-packages/httpstan/include/st





Building: 90.2s, done.Messages from stanc:
    means either no prior is provided, or the prior(s) depend on data
    variables. In the later case, this may be a false positive.
    either no prior is provided, or the prior(s) depend on data variables. In
    the later case, this may be a false positive.
    either no prior is provided, or the prior(s) depend on data variables. In
    the later case, this may be a false positive.
    provided, or the prior(s) depend on data variables. In the later case,
    this may be a false positive.
Sampling:   0%
Sampling:   0% (1/8000)
Sampling:   0% (2/8000)
Sampling:   0% (3/8000)
Sampling:   0% (4/8000)
Sampling:   1% (103/8000)
Sampling:   3% (202/8000)
Sampling:   4% (301/8000)
Sampling:   5% (400/8000)
Sampling:   6% (500/8000)
Sampling:   8% (600/8000)
Sampling:   9% (700/8000)
Sampling:  10% (800/8000)
Sampling:  11% (900/8000)
Sampling:  12% (1000/8000)
Sampling:  14% (1100/8000)
Sampling:  15% (1200/8000)
Sampling:  16% (1300/8000)
Samp

In [None]:
# Display the stacked per-draw fit table built by the fitting loop (it holds
# the runs of the most recent block count that was actually fitted).
simulation_fits
# Note: observations are conditionally independent given a latent variable: a random measure ~ DP(alpha G)

parameters,lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__,alpha,timeconstant_within_session_A,repeat_bias_1_back,context_similarity_depth_1,context_similarity_depth_2,block,run
draws,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
0,-276.669262,0.460032,0.360409,4.0,15.0,0.0,279.101993,7.581567,37.469237,0.354694,0.013035,0.224014,5,0
1,-278.196453,0.960354,0.427796,3.0,7.0,0.0,281.533132,5.027058,118.357600,0.535352,0.318274,0.018293,5,0
2,-276.122066,1.000000,0.305839,4.0,15.0,0.0,277.976747,1.927722,37.109624,0.420531,0.032847,0.070761,5,0
3,-275.456136,1.000000,0.377108,3.0,7.0,0.0,277.929641,3.107184,31.258102,0.372779,0.304212,0.031636,5,0
4,-275.723382,0.959229,0.360409,3.0,7.0,0.0,277.265186,8.867392,39.796667,0.368387,0.062805,0.135762,5,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3995,-276.408433,0.701595,0.347483,4.0,15.0,0.0,278.472980,2.190141,39.430017,0.472903,0.374472,0.041620,5,4
3996,-276.639459,0.992297,0.395161,3.0,7.0,0.0,279.478432,1.821074,24.766905,0.342678,0.025211,0.268210,5,4
3997,-276.594190,0.992716,0.390217,3.0,7.0,0.0,280.729019,8.435622,20.119449,0.546593,0.081900,0.383203,5,4
3998,-274.964737,0.959844,0.408716,3.0,7.0,0.0,277.302278,3.540950,32.220975,0.410393,0.087387,0.059269,5,4


In [None]:
# Look up the "subject_start_idx" entry of the dict returned by model.to_dict().
# `model` is whatever the most recent iteration of the fitting loop left behind,
# so this cell only works after that loop has fitted a model in this kernel.
model.to_dict()["subject_start_idx"]

array([  1,  51, 101, 151, 201, 251])