In [None]:
import nest_asyncio
nest_asyncio.apply()

from hddCRP.modelBuilder import cdCRP
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from statannotations.Annotator import Annotator

import os

In [None]:
# ---- Run configuration ----
# Which map-C sessions to fit for each subject (1-based: 1 is the first
# session whose task label is 'C'; see the selection in the fitting cell).
session_numbers = [1]

# When False, an existing results file is reused instead of refitting.
overwrite_existing_results = False
results_directory = "Results/individualFit/"

# exist_ok=True creates the directory only when it is missing, without a
# separate existence check that could race with another process.
os.makedirs(results_directory, exist_ok=True)

# Depth settings copied onto the model in the fitting cell.
context_depth = 1
nback_depth = 1

# ---- Load data ----
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only point data_filename at files from a trusted source.
data_filename = 'data/Data_turns_all_by_session.pkl'
with open(data_filename, 'rb') as data_file:
    data = pickle.load(data_file)

# Sorted subject identifiers (the keys of data["data"]) for a stable fit order.
subjects = sorted(data["data"].keys())
print("subjects = " + str(subjects))

# Possible observation values handed to the model as `possible_observations`.
action_labels = [0,1,2]

In [None]:
# ---- Fit one model per subject and cache the results on disk ----
start_session = np.min(session_numbers)
end_session = np.max(session_numbers)

# Tag identifying the fitted session range.
session_tag = f"session_{start_session}"
if end_session != start_session:
    # Use the LAST session here. The previous name repeated start_session, so
    # every range starting at the same session shared one cache file.
    # Multi-session results cached under that old name will simply be refit.
    session_tag += f"_to_{end_session}"

# Depths are only encoded in the name when they differ from (nback=1, context=2).
depth_tag = ""
if nback_depth != 1 or context_depth != 2:
    depth_tag = f"_cd{context_depth}_nb{nback_depth}"

fit_file = os.path.join(results_directory, f"fits_{session_tag}{depth_tag}.pkl")
fit_summary_file = os.path.join(results_directory, f"fit_summary_{session_tag}{depth_tag}.pkl")

# Later cells read fit_summary_file, so refit when either cached file is missing.
cache_complete = os.path.isfile(fit_file) and os.path.isfile(fit_summary_file)

if overwrite_existing_results or not cache_complete:
    # Collect per-subject frames and concatenate once at the end instead of
    # growing a DataFrame inside the loop.
    fit_frames = []
    summary_frames = []
    for subject_index, subject in enumerate(subjects):
        print(f"subject {subject} ")
        sequences_0 = data["data"][subject]["data"]      # turns in each session
        session_types_0 = data["data"][subject]["task"]  # which maze

        # Positions of this subject's map-C sessions, then the requested
        # (1-based) ones among them. Raises IndexError if the subject has
        # fewer map-C sessions than requested.
        map_c_positions = np.where(np.array(session_types_0) == 'C')[0]
        selected = list(map_c_positions[np.array(session_numbers) - 1])
        sequences     = [sequences_0[xx] for xx in selected]
        session_types = [session_types_0[xx] for xx in selected]

        # Deterministic, subject-specific seed.
        stan_seed = (subject_index+1) * 1000 + start_session

        # build model with given sequences
        model = cdCRP(sequences, session_labels=session_types, subject_labels=subject, possible_observations=action_labels)

        # set model depth
        model.same_nback_depth = nback_depth
        model.context_depth = context_depth

        # fit with Stan
        model.build(random_seed=stan_seed)
        model.fit_model()

        map_fit = model.get_map()
        fit_df  = model.fit.to_frame()
        summary_df = model.fit_summary()

        fit_df["subject"] = subject
        fit_df["start_session_C"] = start_session
        fit_df["end_session_C"] = end_session
        summary_df["subject"] = subject
        # pd.Series aligns on the summary's index; presumably map_fit is keyed
        # by the same parameter names -- TODO confirm.
        summary_df["MAP"] = pd.Series(map_fit)
        summary_df["start_session_C"] = start_session
        summary_df["end_session_C"] = end_session

        summary_frames.append(summary_df)
        fit_frames.append(fit_df)

    # With no subjects, fall back to empty frames (pd.concat rejects an empty list).
    data_fit_metrics = pd.concat(summary_frames) if summary_frames else pd.DataFrame()
    data_fits = pd.concat(fit_frames) if fit_frames else pd.DataFrame()

    data_fits.to_pickle(fit_file)
    data_fit_metrics.to_pickle(fit_summary_file)
else:
    print("fit file found")



In [None]:
# ---- Pair plot of the per-subject estimates, coloured by group ----
# Reload the cached summary so this cell gives the same result on every run.
fit_summary = pd.read_pickle(fit_summary_file)
group_names = np.array(["uniform", "diverse"])

# Lookup from subject to the name of the group that lists it.
mapping = {}
for group_name in group_names:
    for subject_id in data["group_definition"][group_name]:
        mapping[subject_id] = group_name
fit_summary["group"] = fit_summary["subject"].map(mapping)

# Distinct labels of the summary's row index (treated as parameter names below).
params = list(fit_summary.index.unique())

# Move the row index into a regular "parameter" column (long format).
fit_summary.index.name = 'parameter'
fit_summary = fit_summary.reset_index()

# Summary column used as the point estimate in all plots.
estimator_to_plot = "median"

# Wide table: one row per subject, one column per parameter.
formatted_fit = fit_summary.pivot(index="subject", columns=["parameter"],values=[estimator_to_plot]).droplevel(0,axis=1)
# Attach each subject's group as a Series aligned on the subject index.
formatted_fit["group"] = fit_summary[["group","subject"]].drop_duplicates("subject").set_index("subject")["group"]

# PairGrid creates its own figure, so no plt.figure() call is made first
# (one would only leave an empty extra figure in the output).
g = sns.PairGrid(formatted_fit, hue="group")
g.map_upper(sns.scatterplot)
g.map_diag(sns.histplot)
g.add_legend();

In [None]:

# Parameter panels: the parameters in each inner list share one figure.
grps = [["timeconstant_within_session_C"], ["alpha"], ["repeat_bias_1_back", "context_similarity_depth_1", "context_similarity_depth_2"]]
ii = 0
plt.figure(figsize=(len(grps[ii])*3,4))

# Long-format view holding only the columns the plots need.
fit_summary2 = fit_summary[["group", estimator_to_plot, "parameter"]]

# Keyword arguments common to the boxplot, the annotator and the strip plot.
group_order = ["uniform", "diverse"]
shared_kwargs = dict(
    data=fit_summary2,
    x="parameter",
    y=estimator_to_plot,
    order=grps[ii],
    hue="group",
    hue_order=group_order,
)

ax = sns.boxplot(**shared_kwargs)

# Compare the two groups within each parameter (Mann-Whitney test).
pairs = [((param, "uniform"), (param, "diverse")) for param in grps[ii]]
annot = Annotator(ax, pairs=pairs, **shared_kwargs)
annot.configure(test='Mann-Whitney', verbose=2, text_format='simple')
annot.apply_test()
annot.annotate()

# Overlay the individual subjects on top of the boxes.
sp = sns.stripplot(dodge=True, size=6, alpha=1, edgecolor=[0,0,0], linewidth=0.5, **shared_kwargs)
sns.despine(offset=5, trim=True);
plt.xlabel(None);


In [None]:
# Second parameter panel: same figure recipe as the first panel.
ii = 1
plt.figure(figsize=(len(grps[ii])*3,4))

# Keyword arguments common to the boxplot, the annotator and the strip plot.
group_order = ["uniform", "diverse"]
shared_kwargs = dict(
    data=fit_summary2,
    x="parameter",
    y=estimator_to_plot,
    order=grps[ii],
    hue="group",
    hue_order=group_order,
)

ax = sns.boxplot(**shared_kwargs)

# Compare the two groups within each parameter (Mann-Whitney test).
pairs = [((param, "uniform"), (param, "diverse")) for param in grps[ii]]
annot = Annotator(ax, pairs=pairs, **shared_kwargs)
annot.configure(test='Mann-Whitney', verbose=2, text_format='simple')

annot.apply_test()
annot.annotate()

# Overlay the individual subjects on top of the boxes.
sns.stripplot(size=6, dodge=True, alpha=1, edgecolor=[0,0,0], linewidth=0.5, **shared_kwargs)
sns.despine(offset=5, trim=True)
plt.xlabel(None)

In [None]:
# Third parameter panel: repeat bias and context-similarity parameters.
ii = 2
group_order = ["uniform", "diverse"]

# Plot only the requested parameters that are present in the fit summary.
# This panel asks for "context_similarity_depth_2", which is presumably absent
# when the model was fitted with context_depth = 1 (TODO confirm against the
# model's parameter names). A missing parameter would otherwise leave an
# empty slot and give the group comparison nothing to test.
fitted_params = set(fit_summary2["parameter"])
params_to_plot = [param for param in grps[ii] if param in fitted_params]
skipped_params = [param for param in grps[ii] if param not in fitted_params]
if skipped_params:
    print(f"parameters not in the fit summary, skipped: {skipped_params}")

if params_to_plot:
    plt.figure(figsize=(len(params_to_plot)*3,4))

    # Keyword arguments common to the boxplot, the annotator and the strip plot.
    shared_kwargs = dict(
        data=fit_summary2,
        x="parameter",
        y=estimator_to_plot,
        order=params_to_plot,
        hue="group",
        hue_order=group_order,
    )

    ax = sns.boxplot(**shared_kwargs)

    # Compare the two groups within each parameter (Mann-Whitney test).
    pairs = [((param, "uniform"), (param, "diverse")) for param in params_to_plot]
    annot = Annotator(ax, pairs=pairs, **shared_kwargs)
    annot.configure(test='Mann-Whitney', verbose=2, text_format='simple')
    annot.apply_test()
    annot.annotate()

    # Overlay the individual subjects on top of the boxes.
    sns.stripplot(size=6, dodge=True, alpha=1, edgecolor=[0,0,0], linewidth=0.5, **shared_kwargs)
    sns.despine(offset=5, trim=True)
    plt.xlabel(None);

In [None]:
# Preview the first rows of fit_summary to check its columns.
fit_summary.head()

In [None]:
# Wide table of the chosen estimator: one row per subject, one column per parameter.
wide_estimates = fit_summary.pivot(index="subject", columns=["parameter"], values=[estimator_to_plot])
# Drop the outer column level (the estimator name) so columns are parameter names only.
x = wide_estimates.droplevel(0, axis=1)

# One group label per subject, indexed by subject so the assignment aligns on the row index.
subject_groups = fit_summary[["group","subject"]].drop_duplicates("subject")
x["group"] = subject_groups.set_index("subject")
x