### Performance

In [64]:
import numpy as np
import pandas as pd
import mab_subjects

# Concatenate the four rnn_* experiment collections exposed by mab_subjects.
exps = (
    mab_subjects.rnn_u_on_u
    + mab_subjects.rnn_u_on_s
    + mab_subjects.rnn_s_on_s
    + mab_subjects.rnn_s_on_u
)

# One long-format frame per experiment; they are concatenated once after the
# loop so the list and the final DataFrame never share a name.
perf_frames = []

for exp in exps:
    task = exp.b2a
    # presumably one probability value per trial -- confirm against banditpy
    perf = task.get_optimal_choice_probability()

    perf_frames.append(
        pd.DataFrame(
            dict(
                trial_id=np.arange(perf.size) + 1,  # 1-based trial index
                perf=perf,
                name=exp.sub_name,
                # Rows are grouped by the experiment tag, not by
                # task.is_structured ("struc" / "unstruc").
                grp=exp.tag,
            )
        )
    )

perf_df = pd.concat(perf_frames, ignore_index=True)
# Saved under the "rnn_perf" key, which the plotting cell reads back through
# GroupData().rnn_perf.
mab_subjects.GroupData().save(perf_df, "rnn_perf")

rnn_perf saved


In [77]:
import matplotlib.pyplot as plt
import seaborn as sns
from mab_colors import colors_2arm_swap
from neuropy import plotting
from statplotannot.plots import SeabornPlotter, fix_legend
from mab_subjects import GroupData

# Load the per-trial performance table saved under the "rnn_perf" key.
df = GroupData().rnn_perf

fig = plotting.Fig(5, 3, num=1, fontsize=10)
ax = fig.subplot(fig.gs[1])

# One line per group ("grp"), with standard-error bands that have no edge colour.
perf_plotter = SeabornPlotter(
    data=df,
    x="trial_id",
    y="perf",
    hue="grp",
    ax=ax,
)
perf_plotter.lineplot(
    palette=colors_2arm_swap(),
    errorbar="se",
    err_kws={"ec": None},
)

fix_legend(
    ax=ax,
    loc="lower right",
    ncols=2,
    only_labels=True,
    fw="bold",
)
ax.set_title("LSTM performance over trials")
ax.set_ylabel("Pr (High)")

Text(0, 0.5, 'Pr (High)')

### Switching probability

In [66]:
from banditpy.analyses import SwitchProb2Arm
import pandas as pd

# Concatenate the four rnn_* experiment collections exposed by mab_subjects.
exps = (
    mab_subjects.rnn_u_on_u
    + mab_subjects.rnn_u_on_s
    + mab_subjects.rnn_s_on_s
    + mab_subjects.rnn_s_on_u
)

# One long-format frame per experiment; they are concatenated once after the
# loop so the list and the final DataFrame never share a name.
sp_frames = []

for exp in exps:
    task = exp.b2a
    # presumably one switch probability per trial -- confirm against banditpy
    sp = SwitchProb2Arm(task).by_trial()

    sp_frames.append(
        pd.DataFrame(
            dict(
                trial_id=np.arange(sp.size) + 1,  # 1-based trial index
                switch_prob=sp,
                name=exp.sub_name,
                grp=exp.tag,
            )
        )
    )

sp_df = pd.concat(sp_frames, ignore_index=True)
# Saved under the "rnn_switch_prob" key, which the plotting cell reads back
# through GroupData().rnn_switch_prob.
mab_subjects.GroupData().save(sp_df, "rnn_switch_prob")

rnn_switch_prob saved


In [78]:
import matplotlib.pyplot as plt
import seaborn as sns
from neuropy import plotting
import mab_subjects
import numpy as np
from statannotations.Annotator import Annotator
from statplot_utils import stat_kw
from mab_colors import colors_2arm_swap
from statplotannot.plots import SeabornPlotter, fix_legend

fig = plotting.Fig(4, 2, size=(8.5, 11), num=1, fontsize=10)

# Load the per-trial switch-probability table saved under "rnn_switch_prob".
sp_df = mab_subjects.GroupData().rnn_switch_prob

ax = fig.subplot(fig.gs[0])

# Arguments describing what is plotted and where; kept separate from styling.
plot_kw = {
    "data": sp_df,
    "x": "trial_id",
    "y": "switch_prob",
    "hue": "grp",
    "ax": ax,
}
# Alternative fixed palette tried previously: ["#E89317", "#3980ea"].
sns.lineplot(
    palette=colors_2arm_swap(),
    errorbar="se",
    err_kws=dict(edgecolor=None),
    **plot_kw,
)
fix_legend(
    ax=ax,
    loc="upper right",
    ncols=2,
    only_labels=True,
    fw="bold",
)

ax.set_title("LSTM switch probability over trials")
# Raise only the lower y-limit; the upper limit is left as it was.
ax.set_ylim(bottom=0.1)
(0.1, 0.3307073039413049)

### Conditional switching probability

In [68]:
from banditpy.analyses import SwitchProb2Arm
import pandas as pd

from banditpy.analyses import SwitchProb2Arm
import pandas as pd

# Concatenate the four rnn_* experiment collections exposed by mab_subjects.
exps = (
    mab_subjects.rnn_u_on_u
    + mab_subjects.rnn_u_on_s
    + mab_subjects.rnn_s_on_s
    + mab_subjects.rnn_s_on_u
)

# One long-format frame per experiment; they are concatenated once after the
# loop so the list and the final DataFrame never share a name.
cond_sp_frames = []

for exp in exps:
    task = exp.b2a
    # by_history returns the switch probabilities together with the matching
    # history labels (strings, as requested by history_as_str=True).
    # presumably conditioned on the previous 3 trials -- confirm against banditpy
    sp, seq = SwitchProb2Arm(task).by_history(n_past=3, history_as_str=True)

    cond_sp_frames.append(
        pd.DataFrame(
            dict(
                seq=seq,
                switch_prob=sp,
                name=exp.sub_name,
                grp=exp.tag,
            )
        )
    )

sp_df = pd.concat(cond_sp_frames, ignore_index=True)
# Saved under the "rnn_cond_switch_prob" key, which the plotting cell reads
# back through GroupData().rnn_cond_switch_prob.
mab_subjects.GroupData().save(sp_df, "rnn_cond_switch_prob")

rnn_cond_switch_prob saved


In [83]:
from neuropy import plotting
import mab_subjects
from statplotannot.plots import SeabornPlotter
from mab_colors import colors_2arm_swap, colors_2arm

# Groups shown in this panel; the same order is used for the hue levels.
shown_grps = ["u_on_u", "s_on_s"]

# Load the conditional switch-probability table and keep only those groups.
df = mab_subjects.GroupData().rnn_cond_switch_prob
in_shown_grps = df["grp"].isin(shown_grps)
df = df[in_shown_grps]

fig = plotting.Fig(7, 3, size=(8.5, 11), num=1, fontsize=10)

ax = fig.subplot(fig.gs[0, :2])

cond_plotter = SeabornPlotter(
    data=df,
    x="seq",
    y="switch_prob",
    hue="grp",
    hue_order=shown_grps,
    ax=ax,
)
# Bars with standard-error bars, followed by the plotter's bootstrap test.
cond_plotter.barplot(
    dodge=False,
    palette=colors_2arm(1),
    alpha=0.6,
    errorbar="se",
).bootstrap_test()

# History labels are shown rotated by 90 degrees on the x axis.
ax.tick_params(axis="x", rotation=90)