### Performance

In [1]:
import numpy as np
import pandas as pd
import mab_subjects

# All four train/test combinations of the LSTM agents.
exps = (
    mab_subjects.rnn_u_on_u
    + mab_subjects.rnn_u_on_s
    + mab_subjects.rnn_s_on_s
    + mab_subjects.rnn_s_on_u
)

# One long-format frame per model: trial index (1-based), optimal-choice
# probability on that trial, model name and group tag.
frames = []
for exp in exps:
    task = exp.b2a
    perf = task.get_optimal_choice_probability()
    frames.append(
        pd.DataFrame(
            {
                "trial_id": np.arange(1, perf.size + 1),
                "perf": perf,
                "name": exp.sub_name,
                "grp": exp.tag,
            }
        )
    )

perf_df = pd.concat(frames, ignore_index=True)
# NOTE(review): saving is disabled, so the plotting cell below reads a
# previously saved "rnn_perf" table rather than this `perf_df` — confirm intent.
# mab_subjects.GroupData().save(perf_df, "rnn_perf")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from mab_colors import colors_2arm_swap
from neuropy import plotting
from statplotannot.plots import SeabornPlotter, fix_legend
from mab_subjects import GroupData

# Per-trial LSTM performance, read from the saved "rnn_perf" group table.
df = GroupData().rnn_perf

fig = plotting.Fig(5, 3, num=1, fontsize=10)
ax = fig.subplot(fig.gs[1])

# Mean +/- SE line per group; error band drawn without an edge.
line_kw = dict(palette=colors_2arm_swap(), errorbar="se", err_kws=dict(ec=None))
plotter = SeabornPlotter(data=df, x="trial_id", y="perf", hue="grp", ax=ax)
plotter.lineplot(**line_kw)

fix_legend(ax=ax, loc="lower right", ncols=2, only_labels=True, fw="bold")
ax.set_title("LSTM performance over trials")
ax.set_ylabel("Pr (High)")

### Switching probability

In [None]:
from banditpy.analyses import SwitchProb2Arm
import pandas as pd

# Imported here as well so this cell does not depend on the `np` /
# `mab_subjects` names left behind by the performance cell.
import numpy as np
import mab_subjects

# All four train/test combinations of the LSTM agents.
exps = (
    mab_subjects.rnn_u_on_u
    + mab_subjects.rnn_u_on_s
    + mab_subjects.rnn_s_on_s
    + mab_subjects.rnn_s_on_u
)

sp_df = []

for exp in exps:
    task = exp.b2a
    # One switch-probability value per trial (trial_id is 1-based below).
    sp = SwitchProb2Arm(task).by_trial()

    df = pd.DataFrame(
        dict(
            trial_id=np.arange(sp.size) + 1,
            switch_prob=sp,
            name=exp.sub_name,
            grp=exp.tag,
        )
    )
    sp_df.append(df)

sp_df = pd.concat(sp_df, ignore_index=True)
mab_subjects.GroupData().save(sp_df, "rnn_switch_prob")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from neuropy import plotting
import mab_subjects
import numpy as np
from statannotations.Annotator import Annotator
from statplot_utils import stat_kw
from mab_colors import colors_2arm_swap
from statplotannot.plots import SeabornPlotter, fix_legend

# Per-trial switch probability, read from the table saved by the previous cell.
fig = plotting.Fig(4, 2, size=(8.5, 11), num=1, fontsize=10)
sp_df = mab_subjects.GroupData().rnn_switch_prob
ax = fig.subplot(fig.gs[0])

# Mean +/- SE per group; error band drawn without an edge colour.
line_style = {
    "palette": colors_2arm_swap(),
    "errorbar": "se",
    "err_kws": dict(edgecolor=None),
}
sns.lineplot(
    data=sp_df, x="trial_id", y="switch_prob", hue="grp", ax=ax, **line_style
)
fix_legend(ax=ax, loc="upper right", ncols=2, only_labels=True, fw="bold")

ax.set_title("LSTM switch probability over trials")
ax.set_ylim(bottom=0.1)

### Conditional switching probability

In [None]:
from banditpy.analyses import SwitchProb2Arm
import pandas as pd

# Imported here so the cell does not rely on a name defined in an earlier cell.
import mab_subjects

# All four train/test combinations of the LSTM agents.
exps = (
    mab_subjects.rnn_u_on_u
    + mab_subjects.rnn_u_on_s
    + mab_subjects.rnn_s_on_s
    + mab_subjects.rnn_s_on_u
)

sp_df = []

for exp in exps:
    task = exp.b2a
    # Switch probability conditioned on the preceding 3-trial history;
    # `seq` holds each history as a string label (history_as_str=True).
    sp, seq = SwitchProb2Arm(task).by_history(n_past=3, history_as_str=True)

    df = pd.DataFrame(
        dict(
            seq=seq,
            switch_prob=sp,
            name=exp.sub_name,
            grp=exp.tag,
        )
    )
    sp_df.append(df)

sp_df = pd.concat(sp_df, ignore_index=True)
mab_subjects.GroupData().save(sp_df, "rnn_cond_switch_prob")

In [None]:
from neuropy import plotting
import mab_subjects
from statplotannot.plots import SeabornPlotter
from mab_colors import colors_2arm_swap, colors_2arm

# Only the matched train/test groups are compared here.
matched_grps = ["u_on_u", "s_on_s"]
df = mab_subjects.GroupData().rnn_cond_switch_prob
df = df.loc[df["grp"].isin(matched_grps)]

fig = plotting.Fig(7, 3, size=(8.5, 11), num=1, fontsize=10)
ax = fig.subplot(fig.gs[0, :2])

# Overlaid (non-dodged) bars per history sequence, then a bootstrap test
# between the two groups on the object returned by barplot().
plotter = SeabornPlotter(
    data=df,
    x="seq",
    y="switch_prob",
    hue="grp",
    hue_order=matched_grps,
    ax=ax,
)
bars = plotter.barplot(dodge=False, palette=colors_2arm(1), alpha=0.6, errorbar="se")
bars.bootstrap_test()

ax.tick_params(axis="x", rotation=90)

### Performance matrix
- Each bin is one (arm 1 prob., arm 2 prob.) combination; the bin colour is the mean of the metric over all models in the group.

In [3]:
import numpy as np
import pandas as pd
import mab_subjects
from scipy.ndimage import gaussian_filter1d

exps = mab_subjects.rnn_exps

# Trial windows used to summarise each probability block.
N_EDGE_TRIALS = 3  # trials averaged at the start / end of a block
MIDDLE_TRIALS = slice(50, 60)  # window used for "middle" performance
PERF_COLUMNS = ["prob1", "prob2", "middle_perf", "final_perf", "delta_improvement"]

prob_perf_df = []

for exp in exps:
    print(exp.sub_name)
    task = exp.b2a.filter_by_trials(min_trials=100, clip_max=200)

    # Distinct arm-probability pairs, taken from the first trial of each session.
    probs = np.unique(task.probs[task.is_session_start.astype(bool)], axis=0)

    rows = []
    for prob in probs:
        prob_perf = task.filter_by_probs(prob).get_optimal_choice_probability()
        # prob_perf = gaussian_filter1d(prob_perf, sigma=2)  # optional smoothing (disabled)
        first_perf = np.mean(prob_perf[:N_EDGE_TRIALS])
        final_perf = np.mean(prob_perf[-N_EDGE_TRIALS:])
        middle_perf = np.mean(prob_perf[MIDDLE_TRIALS])
        rows.append(
            [prob[0], prob[1], middle_perf, final_perf, final_perf - first_perf]
        )

    # Reshape explicitly so an empty `rows` still gives a (0, 5) frame;
    # np.array([]) alone is 1-D and the DataFrame constructor would raise.
    values = np.asarray(rows, dtype=float).reshape(-1, len(PERF_COLUMNS))
    df = pd.DataFrame(values, columns=PERF_COLUMNS)
    df["name"] = exp.sub_name
    df["grp"] = exp.tag

    prob_perf_df.append(df)

prob_perf_df = pd.concat(prob_perf_df, ignore_index=True)
mab_subjects.GroupData().save(prob_perf_df, "rnn_perf_probability_matrix")

unstructured_2arm_model0_unstructured
unstructured_2arm_model1_unstructured
unstructured_2arm_model2_unstructured
unstructured_2arm_model3_unstructured
unstructured_2arm_model4_unstructured
unstructured_2arm_model5_unstructured
unstructured_2arm_model6_unstructured
unstructured_2arm_model7_unstructured
unstructured_2arm_model8_unstructured
unstructured_2arm_model9_unstructured
unstructured_2arm_model0_structured
unstructured_2arm_model1_structured
unstructured_2arm_model2_structured
unstructured_2arm_model3_structured
unstructured_2arm_model4_structured
unstructured_2arm_model5_structured
unstructured_2arm_model6_structured
unstructured_2arm_model7_structured
unstructured_2arm_model8_structured
unstructured_2arm_model9_structured
structured_2arm_model0_structured
structured_2arm_model1_structured
structured_2arm_model2_structured
structured_2arm_model3_structured
structured_2arm_model4_structured
structured_2arm_model5_structured
structured_2arm_model6_structured
structured_2arm_model7

In [35]:
import matplotlib.pyplot as plt
from neuropy.plotting import Fig
import seaborn as sns
from scipy.stats import binned_statistic_2d
from mab_colors import colors_2arm_swap
from statplotannot.plots import fix_legend
import mab_subjects
from statplotannot.plots import xtick_format
import numpy as np

fig = Fig(6, 6, size=(12, 11), num=1, fontsize=10)

# Read the matrix saved by the previous cell. Previously this was immediately
# overwritten by the in-memory `prob_perf_df`, which made the load dead code
# and tied the figure to kernel state.
mat_df = mab_subjects.GroupData().rnn_perf_probability_matrix
# Exclude blocks where both arms have the same probability.
half_prob_indx = mat_df["prob1"] == mat_df["prob2"]
mat_df = mat_df[~half_prob_indx]

metrics = ["middle_perf", "final_perf", "delta_improvement"]
groups = ["u_on_u", "u_on_s", "s_on_s", "s_on_u"]
vmin = [0.4, 0.4, 0.1]
vmax = [0.9, 0.9, 0.4]
titles = [
    "\nMiddle performance",
    "\nfinal_perf",
    "\nDelta performance\n (last 3 - first 3 trials)",
]

# 9 bins of width 0.1 with edges 0.05 ... 0.95, i.e. centred on 0.1 ... 0.9.
bins = np.linspace(0, 0.9, 10) + 0.05

# Pairs whose probabilities sum to 1; isclose avoids exact float equality.
df_correlated = mat_df[np.isclose(mat_df["prob1"] + mat_df["prob2"], 1.0)]

# Tick labels "p1-p2" for p1 in 0.1..0.9 excluding 0.5 (removed above).
xticks = np.delete(np.arange(0.1, 1, 0.1).round(1), 4)
xtick_labels = [f"{p1}-{p2}" for p1, p2 in zip(xticks, xticks[::-1])]

for c, col_name in enumerate(metrics):

    # Columns 0-3: one binned heatmap per group for this metric.
    for g, grp in enumerate(groups):
        df = mat_df[mat_df["grp"] == grp]
        x, y, values = df["prob1"], df["prob2"], df[col_name]

        H, xedges, yedges, _ = binned_statistic_2d(
            x, y, values, statistic=np.mean, bins=bins
        )

        ax = fig.subplot(fig.gs[c, g])
        cplot = ax.pcolormesh(
            xedges,
            yedges,
            H.T,  # binned_statistic_2d returns H indexed [x, y]; pcolormesh wants [y, x]
            cmap="plasma",
            vmin=vmin[c],
            vmax=vmax[c],
            shading="auto",
        )
        ax.set_xlabel("Arm 1 prob.")
        ax.set_title(f"{grp} {titles[c]}")
        if g == 0:
            ax.set_ylabel("Arm 2 prob.")
        if g == len(groups) - 1:
            cb = plt.colorbar(
                cplot, ax=ax, label="Performance", shrink=0.7, anchor=(0, 0.9)
            )
            cb.outline.set_visible(False)

    # Column 4: metric vs. prob1 for the sum-to-one pairs, all groups overlaid.
    ax = fig.subplot(fig.gs[c, 4])
    plot_kw = dict(
        data=df_correlated,
        x="prob1",
        y=col_name,
        hue="grp",
        ax=ax,
        palette=colors_2arm_swap(),
    )
    sns.pointplot(
        **plot_kw,
        dodge=True,
        lw=2,
        errorbar="se",
        alpha=0.6,
    )

    ax.set_xlabel("Arm 1 - Arm 2 prob.")
    ax.set_xticks(np.arange(len(xticks)), xtick_labels)
    xtick_format(ax, rotation=45)
    ax.set_title("Performance for \ncorrelated probabilities")
    ax.set_ylabel(titles[c][1:])  # drop the leading newline used in heatmap titles
    fix_legend(ax, frameon=True)
    ax.legend_.remove()