### Performance across sessions binned
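- Compute performance for each subject, with sessions grouped into bins of roughly 10 sessions each. For subjects in the "unstruc" group, performance is additionally computed on the subset of sessions whose session-start probabilities (`probs`) sum to 1; these rows are labelled "struc_in_unstruc".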

In [None]:
import numpy as np
import pandas as pd
import mab_subjects

def _perf_to_long_df(perf, name, grp):
    """Reshape a 2-D performance array into a long-format DataFrame.

    Parameters
    ----------
    perf : 2-D array
        Performance values; rows become ``session_bin`` and columns become
        ``trial_id`` (both numbered from 1).
    name : str
        Subject name stored in the ``name`` column of every row.
    grp : str
        Group label stored in the ``grp`` column of every row.

    Returns
    -------
    pandas.DataFrame
        Columns ``session_bin``, ``trial_id``, ``perf``, ``name``, ``grp``,
        sorted by ``session_bin`` then ``trial_id``.
    """
    wide = pd.DataFrame(data=perf, columns=np.arange(perf.shape[1]) + 1)
    wide["session_bin"] = np.arange(perf.shape[0]) + 1
    long_df = wide.melt(id_vars="session_bin", var_name="trial_id", value_name="perf")
    long_df.sort_values(by=["session_bin", "trial_id"], inplace=True)

    long_df["name"] = name
    long_df["grp"] = grp
    return long_df


# Pool every experiment from both groups.
exps = mab_subjects.unstruc.allsess + mab_subjects.struc.allsess

perf_df = []

for exp in exps:
    print(exp.sub_name)
    # The 100/150/10 parameters are mirrored in the saved file name below.
    mab = exp.mab.keep_by_trials(min_trials=100, clip_max=150)
    perf = mab.get_performance(bin_size=10)

    df = _perf_to_long_df(
        perf, exp.sub_name, "struc" if mab.is_structured else "unstruc"
    )
    perf_df.append(df)

    # Use `not`, never `~`: on a plain Python bool `~True == -2` and
    # `~False == -1` are both truthy, which would run this branch for
    # structured subjects as well. `not` is correct for bool and numpy.bool_.
    if not mab.is_structured:
        mab2 = exp.mab.keep_by_trials(min_trials=100, clip_max=150)
        session_prob_sum = mab2.probs[mab2.is_session_start.astype(bool)].sum(axis=1)
        # Tolerant comparison: a sum of floats may differ from 1 by rounding.
        good_sessions = mab2.sessions[np.isclose(session_prob_sum, 1)]
        # NOTE(review): this passes only if every selected session id is
        # truthy (non-zero); it also passes for an empty selection. Confirm
        # that this is the intended sanity check.
        assert np.all(good_sessions)
        mab2 = mab2.keep_sessions_by_id(good_sessions)

        perf = mab2.get_performance(bin_size=10)

        df2 = _perf_to_long_df(perf, exp.sub_name, "struc_in_unstruc")
        perf_df.append(df2)

perf_df = pd.concat(perf_df, ignore_index=True)
mab_subjects.GroupData().save(perf_df, "perf_100min150max_10bin")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from neuropy import plotting
import mab_subjects
import numpy as np
from statannotations.Annotator import Annotator
from statplot_utils import stat_kw

# Figure container created with a (1, 3) layout argument; only the first
# grid slot is used in this cell.
fig = plotting.Fig(1, 3, size=(11, 3), num=1)

# Fetch the stored performance table (presumably the one saved under the
# same name by the cell that builds perf_df -- confirm against GroupData).
grpdata = mab_subjects.GroupData()
df = grpdata.perf_100min150max_10bin

ax = fig.subplot(fig.gs[0])
# ax.axhline(0, color="gray", lw=0.8, zorder=0)

# Shared plotting arguments: one line per group, performance vs. trial.
plot_kw = {"data": df, "x": "trial_id", "y": "perf", "hue": "grp", "ax": ax}

# Disabled styling options kept for reference:
#   palette=["gray", "gray"], edgecolor="white",
#   facecolor=(0, 0, 0, 0), alpha=0.4
sns.lineplot(errorbar="se", **plot_kw)


# Disabled statistical annotation, kept for reference:
# orders = ["unstruc", "struc"]
# pairs = [(("unstruc"), ("struc"))]
# annotator = Annotator(pairs=pairs, order=orders, **plot_kw)
# annotator.configure(test="Kruskal", **stat_kw, color="k", verbose=True)
# annotator.apply_and_annotate()
# annotator.reset_configuration()

ax.set_title("Performance")
ax.set_xticks([1, 50, 100, 150])
# ax.legend("")