In [None]:
import nest_asyncio
nest_asyncio.apply()

from hddCRP.simulations import simulate_sessions
from hddCRP.modelBuilder import cdCRP
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import os

In [None]:
# Sweep sizes: `block_idx` (1..N_blocks) scales the amount of simulated data
# in the fitting cell, and each block is fit N_sims_per_block times (runs),
# each run with its own seed.
# NOTE: the two counts were previously swapped relative to how they were used
# (`block_range` was built from N_sims_per_block, `run_range` from N_blocks).
# The values below keep the sweep that was actually run (50 blocks x 20 runs),
# so cached results remain valid, but the names now match their use.
N_blocks = 50
N_sims_per_block = 20
simulation_range = [0, 1]

block_range = range(1, N_blocks + 1)
run_range = range(0, N_sims_per_block)

results_directory = "Results/Simulations/"
OVERWRITE = False  # True: ignore cached pickles and refit everything

os.makedirs(results_directory, exist_ok=True)

include_repeat_bias = True
context_depth = 1

# Repeat-bias depth (passed to the model as `same_nback_depth`).
nback_depth = 1 if include_repeat_bias else 0

In [None]:
# Fit the cdCRP model to simulated data for every (simulation, block, run)
# combination. Block `b` simulates one subject with one session of 25 * b
# trials, so later blocks probe parameter recovery with more data. Results are
# checkpointed to pickle files after every fit, and combinations already present
# in a loaded checkpoint are skipped, so an interrupted sweep resumes.
for simulation_id in simulation_range:

    # --- ground-truth generative parameters for this simulation ---
    if simulation_id == 0:
        alpha = 5
        different_context_weights = [0.2, 0.2]
        within_session_timescales = {"A": 20}
        between_session_timescales = None  # e.g. {("A", "A"): 2}
        repeat_bias_1_back = 0.5

        session_labels = lambda n_blocks: ["A"]
        session_lengths = lambda n_blocks: [25 * n_blocks]
        num_subjects = lambda n_blocks: 1
        num_responses = 3
        simulation_name = "sim. 1"
    elif simulation_id == 1:
        alpha = 3
        different_context_weights = [0.8, 0.8]
        within_session_timescales = {"A": 50}
        between_session_timescales = None  # e.g. {("A", "A"): 2}
        repeat_bias_1_back = 1.0

        session_labels = lambda n_blocks: ["A"]
        session_lengths = lambda n_blocks: [25 * n_blocks]
        num_subjects = lambda n_blocks: 1
        num_responses = 3
        simulation_name = "sim. 2"
    else:
        raise NotImplementedError("No sim found")

    if not nback_depth:
        repeat_bias_1_back = None
    # Keep only as many context weights as the model's context depth.
    different_context_weights = different_context_weights[:context_depth]

    file_suffix = f"{simulation_id}_cd{context_depth}_nb{nback_depth}.pkl"
    fit_file = os.path.join(results_directory, f"simulation_{file_suffix}")
    fit_summary_file = os.path.join(results_directory, f"simulation_summary_{file_suffix}")

    # Resume from existing checkpoints unless OVERWRITE is set.
    # (pickle files are only safe to load because this notebook wrote them.)
    have_checkpoints = os.path.isfile(fit_file) and os.path.isfile(fit_summary_file)
    if have_checkpoints and not OVERWRITE:
        simulation_fits = pd.read_pickle(fit_file)
        simulation_fit_metrics = pd.read_pickle(fit_summary_file)
    else:
        simulation_fits = pd.DataFrame()
        simulation_fit_metrics = pd.DataFrame()

    for block_idx in block_range:
        print(f"BLOCK {block_idx}")

        for run_idx in run_range:
            # Skip combinations that are already in the checkpoint.
            has_key_columns = {"simulation_id", "block", "run"}.issubset(simulation_fit_metrics.columns)
            already_fit = has_key_columns and (
                simulation_fit_metrics.query(
                    "simulation_id == @simulation_id and block == @block_idx and run == @run_idx"
                ).size > 0
            )
            if already_fit:
                print("Fit files found: not overriding")
                continue

            print(f"BLOCK {block_idx} - RUN {run_idx}")
            # Seeds are deterministic functions of the configuration.
            # sim_seed does not depend on block_idx, so every block of a given
            # run is simulated from the same seed; stan_seed also varies with
            # block_idx. NOTE(review): confirm this sharing is intended.
            seed_base = (simulation_id + 8) * 10000 + nback_depth * 1001 + context_depth * 1000 + run_idx * 100
            sim_seed = seed_base
            stan_seed = seed_base + block_idx

            seqs = []
            subject_labels = []
            session_labels_all = []
            for subject_idx in range(num_subjects(block_idx)):
                sim_rng = np.random.Generator(np.random.MT19937(sim_seed + subject_idx))
                subject_seqs = simulate_sessions(
                    session_lengths=session_lengths(block_idx),
                    session_labels=session_labels(block_idx),
                    num_responses=num_responses,
                    alpha=alpha,
                    different_context_weights=different_context_weights,
                    within_session_timescales=within_session_timescales,
                    between_session_timescales=between_session_timescales,
                    repeat_bias_1_back=repeat_bias_1_back,
                    rng=sim_rng,
                )
                subject_labels += [subject_idx] * len(subject_seqs)
                session_labels_all += session_labels(block_idx)
                seqs += subject_seqs

            model = cdCRP(seqs, subject_labels=subject_labels, session_labels=session_labels_all)
            model.same_nback_depth = nback_depth
            # The model's context depth follows the (truncated) true weights.
            model.context_depth = len(different_context_weights)
            model.build(random_seed=stan_seed)
            model.fit_model()

            fit_df = model.fit.to_frame()
            fit_df["block"] = block_idx
            fit_df["run"] = run_idx
            fit_df["simulation_id"] = simulation_id

            summary_df = model.fit_summary()
            summary_df["block"] = block_idx
            summary_df["run"] = run_idx
            summary_df["simulation_id"] = simulation_id
            summary_df["trials"] = model.session_lengths.sum()
            summary_df["sessions"] = model.num_sessions
            summary_df["n_subjects"] = model.num_subjects
            summary_df["simulation"] = simulation_name

            # Ground-truth values keyed by parameter name; the Series aligns
            # with the summary's parameter-name index.
            true_param = {
                "alpha": alpha,
                "timeconstant_within_session_A": within_session_timescales["A"],
            }
            if nback_depth >= 1:
                true_param["repeat_bias_1_back"] = repeat_bias_1_back
            # One similarity parameter per modeled context depth. This was
            # previously gated on nback_depth, and depth 2 on a typo'd `>= 21`.
            for depth, weight in enumerate(different_context_weights, start=1):
                true_param[f"context_similarity_depth_{depth}"] = weight
            summary_df["true"] = pd.Series(true_param)

            # Growing the frames inside the loop is deliberate: the full result
            # is re-pickled after every fit as a resumable checkpoint.
            simulation_fit_metrics = pd.concat([simulation_fit_metrics, summary_df])
            simulation_fits = pd.concat([simulation_fits, fit_df])

            simulation_fits.to_pickle(fit_file)
            simulation_fit_metrics.to_pickle(fit_summary_file)

In [None]:
# Collect the fit summaries of every configured simulation and plot, per
# parameter, the posterior `var_to_plot` against the number of simulated trials,
# with the true generative value overlaid as a dashed line.
parameter_labels = {
    "alpha": "alpha",
    "context_similarity_depth_1": "C",
    "context_similarity_depth_2": "C2",
    "repeat_bias_1_back": "B",
    "timeconstant_within_session_A": "tau",
}

summary_frames = [
    pd.read_pickle(
        os.path.join(
            results_directory,
            f"simulation_summary_{simulation_id}_cd{context_depth}_nb{nback_depth}.pkl",
        )
    )
    for simulation_id in simulation_range
]

# Parameters not listed in `parameter_labels` map to NaN; NaN-valued facet
# levels are not drawn by the FacetGrid below.
simulation_fit_metrics = (
    pd.concat(summary_frames)
    .astype({"simulation": "category"})
    .reset_index(names="parameter")
    .assign(parameter=lambda df: df["parameter"].map(parameter_labels))
)

var_to_plot = "median"
palette = "colorblind"
g = sns.FacetGrid(simulation_fit_metrics, row="parameter", height=1.5, aspect=10 / 1.5, sharey=False)
# Estimates aggregated across runs; error bars are 90% percentile intervals.
g.map_dataframe(sns.pointplot, x="trials", y=var_to_plot, errorbar=("pi", 90), dodge=0.1, hue="simulation", palette=palette)
# Ground truth: dashed line without markers.
g.map_dataframe(sns.pointplot, x="trials", y="true", linestyles="--", markers="", hue="simulation", palette=palette)
g.add_legend()
# Only the bottom facet carries a y-axis label.
for ax in g.axes[:, 0]:
    ax.set_ylabel(None)
g.axes[-1, 0].set_ylabel("estimate");