In [1]:
import nest_asyncio
nest_asyncio.apply()

from hddCRP.simulations import simulate_sessions
from hddCRP.modelBuilder import cdCRP
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sb

import arviz as az

import os

In [2]:


# Simulation settings. `simulation_id` selects one entry of SIMULATION_CONFIGS;
# the fitting cell reads the module-level names unpacked from it below.
simulation_id = 0
max_blocks = 5                # largest block count (session-length multiplier) to fit
run_range = range(0, 50)      # simulated datasets ("runs") per block
results_directory = "Results/Simulations/"


def _single_session_labels(n_blocks):
    """Return the session labels for a single-session ("A") design.

    The result does not depend on ``n_blocks``; the argument is accepted so
    the callable has the same signature as ``_single_session_lengths``.
    """
    return ["A"]


def _single_session_lengths(n_blocks):
    """Return the trial count of the single session: 50 trials per block."""
    return [50 * n_blocks]


# Generative parameters per simulation id. A table (instead of copy-pasted
# if/elif branches) makes the differences between simulations visible at a
# glance, and adding a simulation becomes a one-entry change.
SIMULATION_CONFIGS = {
    0: {
        "alpha": 2,
        "different_context_weights": [0.75, 0.75],
        "within_session_timescales": {"A": 20},
        "between_session_timescales": {("A", "A"): 2},
        "repeat_bias_1_back": 0.75,
        "min_blocks": 1,
        "session_labels": _single_session_labels,
        "session_lengths": _single_session_lengths,
        "num_responses": 3,
    },
    1: {
        "alpha": 8,
        "different_context_weights": [0.25, 0.25],
        "within_session_timescales": {"A": 40},
        "between_session_timescales": {("A", "A"): 2},
        "repeat_bias_1_back": 0.25,
        "min_blocks": 1,
        "session_labels": _single_session_labels,
        "session_lengths": _single_session_lengths,
        "num_responses": 3,
    },
}

if simulation_id not in SIMULATION_CONFIGS:
    raise NotImplementedError("No sim found")
_config = SIMULATION_CONFIGS[simulation_id]

# Unpack into the names the fitting cell passes to simulate_sessions(...).
alpha = _config["alpha"]
different_context_weights = _config["different_context_weights"]
within_session_timescales = _config["within_session_timescales"]
between_session_timescales = _config["between_session_timescales"]
repeat_bias_1_back = _config["repeat_bias_1_back"]
min_blocks = _config["min_blocks"]
session_labels = _config["session_labels"]
session_lengths = _config["session_lengths"]
num_responses = _config["num_responses"]

# Inclusive range of block counts to fit.
block_range = range(min_blocks, max_blocks + 1)

# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(results_directory, exist_ok=True)



In [3]:



def _fit_single_run(block_idx, run_idx):
    """Simulate one dataset for (block_idx, run_idx) and fit cdCRP to it.

    Reads the simulation settings defined in the configuration cell
    (simulation_id, alpha, different_context_weights, the timescale dicts,
    repeat_bias_1_back, session_labels, session_lengths, num_responses).

    Returns
    -------
    (fit_df, summary_df)
        ``model.fit.to_frame()`` and ``model.fit_summary()``, both tagged with
        ``block``/``run`` columns; ``summary_df`` also gets a ``MAP`` column
        built from ``model.get_map()``.
    """
    # sim_seed omits block_idx, so a given run index starts from the same
    # RNG state in every block; the Stan seed additionally varies by block.
    # Seeds stay distinct across simulation ids / runs only while
    # run_idx < 100 and block_idx < 100.
    sim_seed = (simulation_id + 1) * 10000 + run_idx * 100
    stan_seed = sim_seed + block_idx
    sim_rng = np.random.Generator(np.random.MT19937(sim_seed))

    seqs = simulate_sessions(
        session_lengths=session_lengths(block_idx),
        session_labels=session_labels(block_idx),
        num_responses=num_responses,
        alpha=alpha,
        different_context_weights=different_context_weights,
        within_session_timescales=within_session_timescales,
        between_session_timescales=between_session_timescales,
        repeat_bias_1_back=repeat_bias_1_back,
        rng=sim_rng,
    )

    model = cdCRP(seqs)
    model.build(random_seed=stan_seed)
    model.fit_model()

    map_fit = model.get_map()

    fit_df = model.fit.to_frame()
    fit_df["block"] = block_idx
    fit_df["run"] = run_idx

    summary_df = model.fit_summary()
    summary_df["block"] = block_idx
    summary_df["run"] = run_idx
    # pd.Series(map_fit) is aligned on summary_df's index; presumably the
    # keys of get_map() match the summary's row labels — TODO confirm.
    summary_df["MAP"] = pd.Series(map_fit)
    return fit_df, summary_df


def _concat_or_empty(frames):
    """Concatenate a list of DataFrames once; an empty list gives an empty DataFrame.

    pd.concat([]) raises ValueError, so the empty case is handled explicitly.
    """
    return pd.concat(frames) if frames else pd.DataFrame()


# Fit every block count; results for a block are cached as two pickles and
# skipped on re-run when both already exist.
for block_idx in block_range:
    print(f"BLOCK {block_idx}")
    fit_file = os.path.join(results_directory, f"sim_{simulation_id}_block_{block_idx}.pkl")
    fit_summary_file = os.path.join(results_directory, f"sim_summary_{simulation_id}_block_{block_idx}.pkl")
    if os.path.isfile(fit_file) and os.path.isfile(fit_summary_file):
        print("Fit files found: not overriding")
        continue

    # Collect per-run frames and concatenate once (concat inside the loop is
    # quadratic in the number of runs).
    fit_frames = []
    summary_frames = []
    for run_idx in run_range:
        print(f"BLOCK {block_idx} - RUN {run_idx}")
        fit_df, summary_df = _fit_single_run(block_idx, run_idx)
        fit_frames.append(fit_df)
        summary_frames.append(summary_df)

    # NOTE(review): results are written only after all runs of a block finish,
    # so an exception in any run discards that block's earlier runs.
    _concat_or_empty(fit_frames).to_pickle(fit_file)
    _concat_or_empty(summary_frames).to_pickle(fit_summary_file)

BLOCK 1
BLOCK 1 - RUN 0
Building...



Building: found in cache, done.Messages from stanc:
    means either no prior is provided, or the prior(s) depend on data
    variables. In the later case, this may be a false positive.
    prior is provided, or the prior(s) depend on data variables. In the later
    case, this may be a false positive.
    either no prior is provided, or the prior(s) depend on data variables. In
    the later case, this may be a false positive.
    either no prior is provided, or the prior(s) depend on data variables. In
    the later case, this may be a false positive.
    provided, or the prior(s) depend on data variables. In the later case,
    this may be a false positive.
Sampling:   0%
Sampling:   0% (1/8000)
Sampling:   0% (2/8000)
Sampling:   0% (3/8000)
Sampling:   0% (4/8000)
Sampling:   1% (103/8000)
Sampling:   3% (202/8000)
Sampling:   5% (401/8000)
Sampling:   6% (500/8000)
Sampling:   8% (600/8000)
Sampling:   9% (700/8000)
Sampling:  10% (800/8000)
Sampling:  11% (900/8000)
Sampling:  

RuntimeError: 

In [None]:
# 97.5th percentile of the integers 0..99 using NumPy's default (linear) interpolation.
np.percentile(np.arange(100), 97.5)