In [None]:
import nest_asyncio
# Applied before any of the other imports below.
# NOTE(review): presumably needed so asyncio-based code used during model
# fitting can run inside the notebook's already-running event loop — confirm.
nest_asyncio.apply()

from hddCRP.modelBuilder import cdCRP
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

import itertools

from pandas.api.types import CategoricalDtype
import os

import arviz as az


In [None]:
# ---------------------------------------------------------------------------
# Run configuration: everything a reader might want to tune lives in this cell.
# ---------------------------------------------------------------------------

# When False, an existing fit file is reused and the models are not refit.
overwrite_existing_results = False
results_directory = "Results/population/modelSelection/"

# exist_ok=True creates the directory only if it is missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(results_directory, exist_ok=True)

# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only point this at data files from a trusted source.
data_filename = 'data/Data_turns_all_by_session.pkl'
with open(data_filename, 'rb') as data_file:
    data = pickle.load(data_file)

# Assigned to model.context_depth / model.same_nback_depth in the fitting cell.
context_depth = 2
nback_depth   = 1

# None  -> use the first `number_of_trials` trials pooled over each subject's
#          'C' sessions.
# list  -> 1-based indices of the 'C' sessions to fit (e.g. [1]).
session_numbers = None#[1]; # index by 1
number_of_trials    = 50

# Passed to the model as the set of possible observations.
action_labels = [0,1,2]


In [None]:


# ---------------------------------------------------------------------------
# Model selection: fit every combination of "shared across populations"
# flags and store the posterior draws, the fit summaries and per-model WAICs.
# ---------------------------------------------------------------------------

# --- Output file names for this configuration ------------------------------
if session_numbers is None:
    fit_file = f"{results_directory}/fits_trials_{number_of_trials}"
    fit_summary_file = f"{results_directory}/fit_summary_trials_{number_of_trials}"
    waics_file_prefix = f"{results_directory}/waics_trials_{number_of_trials}"
    seed_offset = number_of_trials
else:
    start_session = np.min(session_numbers)
    end_session = np.max(session_numbers)
    fit_file = f"{results_directory}/fits_session_{start_session}"
    fit_summary_file = f"{results_directory}/fit_summary_session_{start_session}"
    waics_file_prefix = f"{results_directory}/waics_session_{start_session}"
    if end_session != start_session:
        # The range suffix must name the *last* session, and must be applied to
        # the WAIC prefix too; otherwise results for different session ranges
        # that share a start session would overwrite each other.
        session_range_suffix = f"_to_{end_session}"
        fit_file += session_range_suffix
        fit_summary_file += session_range_suffix
        waics_file_prefix += session_range_suffix
    seed_offset = start_session

depth_suffix = f"_cd{context_depth}_nb{nback_depth}"
fit_file += depth_suffix
fit_summary_file += depth_suffix
waics_file_prefix += depth_suffix

fit_file += ".pkl"
fit_summary_file += ".pkl"

# NOTE(review): not used in the visible cells; kept in case another cell reads it.
WAICs = []
if (not os.path.isfile(fit_file)) or overwrite_existing_results:
    # --- Collect the sequences to fit, one entry per (subject, session) ------
    sequences = []
    session_types = []
    subject_labels = []
    population_labels = []
    for group in ["uniform", "diverse"]:
        for subject_p in data["group_definition"][group]:
            sequences_0 = data["data"][subject_p]["data"]      # turns in each session
            session_types_0 = data["data"][subject_p]["task"]  # which maze

            # Positions of this subject's 'C' sessions.
            c_sessions = np.where(np.array(session_types_0) == 'C')[0]

            if session_numbers is None:
                # Pool all 'C' sessions and keep the first `number_of_trials` trials.
                seqs_c = list(itertools.chain.from_iterable(sequences_0[xx] for xx in c_sessions))
                sequences.append(seqs_c[:number_of_trials])
                session_types.append('C')
                subject_labels.append(subject_p)
                population_labels.append(group)
            else:
                # `session_numbers` is 1-based, hence the -1.
                ii = list(c_sessions[np.array(session_numbers) - 1])
                sequences += [sequences_0[xx] for xx in ii]
                session_types += [session_types_0[xx] for xx in ii]
                subject_labels += [subject_p] * len(ii)
                population_labels += [group] * len(ii)

    model = cdCRP(sequences,
                  session_labels=session_types,
                  subject_labels=subject_labels,
                  population_labels=population_labels,
                  possible_observations=action_labels)
    model.same_nback_depth = nback_depth
    model.context_depth = context_depth

    # One on/off flag per parameter group: alpha, within-session timescale,
    # (optionally) between-session timescale, each context depth, each n-back depth.
    include_interaction_timescales = len(model.get_all_interaction_types()) > 0
    nvars = model.same_nback_depth + model.context_depth + 2 + include_interaction_timescales

    # Accumulate per-model frames and concatenate once after the loop instead
    # of re-concatenating (and copying) the growing frame on every iteration.
    fit_frames = []
    summary_frames = []

    for model_index in range(2**nvars - 1, -1, -1):
        # Decode the model index into exactly `nvars` flags. Padding to a fixed
        # width of 5 would leave too few flags whenever nvars > 5.
        var_settings = [bool(int(i)) for i in bin(model_index)[2:].zfill(nvars)]
        print(f"model {model_index} / {(2**nvars)}: {var_settings} ")

        ctr = 0

        model.population_shared_alpha = var_settings[ctr]
        ctr += 1
        model.population_shared_within_session_timescale = var_settings[ctr]
        ctr += 1
        if include_interaction_timescales:
            model.population_shared_between_session_timescale = var_settings[ctr]
            ctr += 1
        for ii in range(model.context_depth):
            model.population_shared_context[ii] = var_settings[ctr]
            ctr += 1

        for ii in range(model.same_nback_depth):
            model.population_shared_same_nback[ii] = var_settings[ctr]
            ctr += 1

        # Distinct, reproducible seed per (model, configuration).
        stan_seed = 20000 + 100 * model_index + seed_offset
        model.build(random_seed=stan_seed)
        model.fit_model()

        #map_fit = model.get_map()
        fit_df = model.fit.to_frame()
        fit_df["model_index"] = model_index
        summary_df = model.fit_summary()
        summary_df["model_index"] = model_index
        #summary_df["MAP"] = pd.Series(map_fit)

        # Record which data selection produced this fit; the unused fields are NA.
        if session_numbers is None:
            run_info = {"number_of_trials": number_of_trials,
                        "start_session_C": pd.NA,
                        "end_session_C": pd.NA}
        else:
            run_info = {"number_of_trials": pd.NA,
                        "start_session_C": start_session,
                        "end_session_C": end_session}
        for column, value in run_info.items():
            summary_df[column] = value
        for column, value in run_info.items():
            fit_df[column] = value

        summary_frames.append(summary_df)
        fit_frames.append(fit_df)

        # WAICs are written per model so they survive an interrupted run.
        model.waic(pointwise=True).to_pickle(f"{waics_file_prefix}_model_{model_index}.pkl")

    data_fits = pd.concat(fit_frames)
    data_fit_metrics = pd.concat(summary_frames)

    data_fits.to_pickle(fit_file)
    data_fit_metrics.to_pickle(fit_summary_file)
else:
    print("fit file found")



In [None]:
# ---------------------------------------------------------------------------
# Load the per-model pointwise WAICs written by the fitting cell and compare
# the models.
# ---------------------------------------------------------------------------

# Count the saved WAIC files instead of hard-coding 32 models: the number of
# models is 2**(number of flags) and changes with the configured depths.
n_models = 0
while os.path.isfile(f"{waics_file_prefix}_model_{n_models}.pkl"):
    n_models += 1
if n_models == 0:
    raise FileNotFoundError(f"no WAIC files found for prefix {waics_file_prefix!r}")
if n_models & (n_models - 1):
    # The fitting loop writes one file per flag combination, so anything other
    # than a power of two means files are missing.
    raise ValueError(f"expected a power-of-two number of WAIC files, found {n_models}")
n_flags = n_models.bit_length() - 1

ms = {}   # flag string (e.g. "01011") -> loaded WAIC result
vs = []   # per-model list of boolean flags
pointwise_waics = []
for model_index in range(n_models):
    df = pd.read_pickle(f"{waics_file_prefix}_model_{model_index}.pkl")
    pointwise_waics.append(np.asarray(df.waic_i.data, dtype=float))

    # Same index -> flags encoding as used when the models were fit.
    model_str = bin(model_index)[2:].zfill(n_flags)
    xx = [bool(int(i)) for i in model_str]
    vs += [xx]
    ms[model_str] = df

# Rows are models, columns are pointwise WAIC terms; the width follows the
# loaded data rather than a fixed 19.
waics = np.vstack(pointwise_waics)
ts = np.sum(waics, axis=1)

az.compare(ms)

In [None]:
# Inspection output only: `xx` is the flag list left over from the final
# iteration of the WAIC-loading loop in the previous cell, so this cell works
# only after that cell has been run.
str(xx)