### Performance

In [None]:
import numpy as np
import pandas as pd
import mab_subjects

# Build one long table of the per-trial values returned by
# get_optimal_choice_probability(), one frame per experiment, then save it
# under the key "rnn_perf".
exps = mab_subjects.rnn_exps1

perf_df = []

for exp in exps:
    optimal_prob = exp.b2a.get_optimal_choice_probability()
    n_trials = optimal_prob.size

    perf_df.append(
        pd.DataFrame(
            {
                # Trial ids are 1-based.
                "trial_id": np.arange(1, n_trials + 1),
                "perf": optimal_prob,
                "name": exp.sub_name,
                # The group label is the experiment tag.
                "grp": exp.tag,
            }
        )
    )

# Replace the list of per-experiment frames with a single stacked frame.
perf_df = pd.concat(perf_df, ignore_index=True)
mab_subjects.GroupData().save(perf_df, "rnn_perf")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from mab_colors import colors_2arm_swap
from neuropy import plotting
from statplotannot.plots import SeabornPlotter, fix_legend
from mab_subjects import GroupData

# Plot from the saved "rnn_perf" table.
# To plot the in-memory table instead, use: df = perf_df
df = GroupData().rnn_perf

fig = plotting.Fig(5, 3, num=1, fontsize=10)
ax = fig.subplot(fig.gs[1])

# One line per group over trials; errorbar="se" requests standard-error bands
# and the band edge colour is switched off.
perf_plotter = SeabornPlotter(data=df, x="trial_id", y="perf", hue="grp", ax=ax)
perf_plotter.lineplot(
    palette=colors_2arm_swap(),
    errorbar="se",
    err_kws={"ec": None},
)

fix_legend(ax=ax, loc="lower right", ncols=2, only_labels=True, fw="bold")
ax.set_title("LSTM performance over trials")
ax.set_ylabel("Pr (High)")

### Switching probability

In [None]:
from banditpy.analyses import SwitchProb2Arm
import mab_subjects
import numpy as np
import pandas as pd

# `np` and `mab_subjects` are imported here (like every other cell does) so
# this cell does not raise NameError when run without the Performance cell.

# Build one long table of SwitchProb2Arm(...).by_trial() values, one frame per
# experiment, then save it under the key "rnn_switch_prob".
exps = mab_subjects.rnn_exps1
sp_df = []

for exp in exps:
    task = exp.b2a
    sp = SwitchProb2Arm(task).by_trial()

    df = pd.DataFrame(
        dict(
            # Trial ids are 1-based.
            trial_id=np.arange(sp.size) + 1,
            switch_prob=sp,
            name=exp.sub_name,
            grp=exp.tag,
        )
    )
    sp_df.append(df)

# Replace the list of per-experiment frames with a single stacked frame.
sp_df = pd.concat(sp_df, ignore_index=True)
mab_subjects.GroupData().save(sp_df, "rnn_switch_prob")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from neuropy import plotting
import mab_subjects
import numpy as np
from statannotations.Annotator import Annotator
from statplot_utils import stat_kw
from mab_colors import colors_2arm_swap
from statplotannot.plots import SeabornPlotter, fix_legend

fig = plotting.Fig(4, 2, size=(8.5, 11), num=1, fontsize=10)
ax = fig.subplot(fig.gs[0])

# Plot from the saved "rnn_switch_prob" table.
sp_df = mab_subjects.GroupData().rnn_switch_prob

# One line per group over trials with standard-error bands (no band edge).
sns.lineplot(
    data=sp_df,
    x="trial_id",
    y="switch_prob",
    hue="grp",
    ax=ax,
    palette=colors_2arm_swap(),
    errorbar="se",
    err_kws=dict(edgecolor=None),
)
fix_legend(ax=ax, loc="upper right", ncols=2, only_labels=True, fw="bold", fs=10)

ax.set_title("LSTM switch probability over trials")
# Only the lower y-limit is given; the upper limit is not set here.
ax.set_ylim(0.1)

### Conditional switching probability

In [None]:
from banditpy.analyses import SwitchProb2Arm
import mab_subjects
import pandas as pd

# The import pair was previously pasted twice; it is listed once here, and
# `mab_subjects` is imported explicitly because this cell uses it.

# Build one long table of switch probability conditioned on recent history
# (by_history with n_past=3, histories returned as strings), one frame per
# experiment, then save it under the key "rnn_cond_switch_prob".
exps = mab_subjects.rnn_exps1

sp_df = []

for exp in exps:
    task = exp.b2a
    sp, seq = SwitchProb2Arm(task).by_history(n_past=3, history_as_str=True)

    df = pd.DataFrame(
        dict(
            seq=seq,
            switch_prob=sp,
            name=exp.sub_name,
            grp=exp.tag,
        )
    )
    sp_df.append(df)

# Replace the list of per-experiment frames with a single stacked frame.
sp_df = pd.concat(sp_df, ignore_index=True)
mab_subjects.GroupData().save(sp_df, "rnn_cond_switch_prob")

In [None]:
from neuropy import plotting
import mab_subjects
from statplotannot.plots import SeabornPlotter
from mab_colors import colors_2arm_swap, colors_2arm

# Only these two groups are plotted, in this hue order.
groups = ["u_on_u", "s_on_s"]

# Plot from the saved "rnn_cond_switch_prob" table, restricted to `groups`.
cond_sp = mab_subjects.GroupData().rnn_cond_switch_prob
df = cond_sp[cond_sp["grp"].isin(groups)]

fig = plotting.Fig(7, 3, size=(8.5, 11), num=1, fontsize=10)
ax = fig.subplot(fig.gs[0, :2])

# Overlaid (non-dodged) bars per history string, then the plotter's
# bootstrap test on the drawn bars.
plotter = SeabornPlotter(
    data=df, x="seq", y="switch_prob", hue="grp", hue_order=groups, ax=ax
)
bar_result = plotter.barplot(
    dodge=False,
    palette=colors_2arm(1),
    alpha=0.6,
    errorbar="se",
)
bar_result.bootstrap_test()

# History strings are long, so the tick labels are drawn vertically.
ax.tick_params(axis="x", rotation=90)

### Performance matrix
- Each bin corresponds to one combination of the two arms' reward probabilities

In [None]:
import numpy as np
import pandas as pd
import mab_subjects
from scipy.ndimage import gaussian_filter1d

# For every experiment and every pair of arm probabilities, summarise the
# optimal-choice probability curve (middle, final, and final-minus-initial),
# then save the stacked table under "rnn_perf_probability_matrix".
exps = mab_subjects.rnn_exps1

prob_perf_df = []

for exp in exps:
    print(exp.sub_name)
    task = exp.b2a.filter_by_trials(min_trials=100, clip_max=200)

    # Probabilities at session starts, rounded to one decimal.
    # Assumes task.probs is 2-D with two columns, as the [0]/[1] indexing
    # below requires.
    start_probs = np.round(task.probs[task.is_session_start.astype(bool)], 1)

    # Sort within each row so (p, q) and (q, p) collapse to one unordered
    # pair. Without this, both orientations survive np.unique and each one
    # emits both orientations again, so every row is written twice and the
    # sample count behind downstream standard errors is inflated.
    pairs = np.unique(np.sort(start_probs, axis=1), axis=0)

    arr = []
    for pair in pairs:
        # Pool sessions of both arm orientations for this pair.
        both_orders = np.array([pair, pair[::-1]])
        prob_perf = task.filter_by_probs(both_orders).get_optimal_choice_probability()

        final_perf = np.mean(prob_perf[-5:])
        delta_improvement = final_perf - np.mean(prob_perf[:5])
        middle_perf = np.mean(prob_perf[50:60])

        # Emit the pair in both orientations so the matrix is symmetric;
        # an equal-probability pair has only one orientation.
        arr.append([pair[0], pair[1], middle_perf, final_perf, delta_improvement])
        if pair[0] != pair[1]:
            arr.append([pair[1], pair[0], middle_perf, final_perf, delta_improvement])

    # Passing the row list directly also yields a correctly-shaped empty
    # frame when no pairs were found.
    df = pd.DataFrame(
        arr,
        columns=["prob1", "prob2", "middle_perf", "final_perf", "delta_improvement"],
    )
    df["name"] = exp.sub_name
    df["grp"] = exp.tag

    prob_perf_df.append(df)

# Replace the list of per-experiment frames with a single stacked frame.
prob_perf_df = pd.concat(prob_perf_df, ignore_index=True)
mab_subjects.GroupData().save(prob_perf_df, "rnn_perf_probability_matrix")

In [None]:
# Show the stacked per-probability-pair table built in the cell above
# (columns: prob1, prob2, middle_perf, final_perf, delta_improvement, name, grp).
prob_perf_df

In [None]:
import matplotlib.pyplot as plt
from neuropy.plotting import Fig
import seaborn as sns
from scipy.stats import binned_statistic_2d
from mab_colors import colors_2arm_swap, colors_2arm
from statplotannot.plots import fix_legend
import mab_subjects
from statplotannot.plots import xtick_format
import numpy as np

fig = Fig(6, 6, size=(12, 11), num=1, fontsize=10)

# Groups shown in the heatmap columns, in this order.
groups = ["u_on_u", "s_on_s"]

# Plot from the table saved under "rnn_perf_probability_matrix". The load used
# to be overwritten by an in-kernel variable, which made this cell depend on
# the compute cell having run in the same session.
mat_df = mab_subjects.GroupData().rnn_perf_probability_matrix
# Drop equal-probability sessions and keep only the plotted groups.
mat_df = mat_df[mat_df["prob1"] != mat_df["prob2"]]
mat_df = mat_df[mat_df["grp"].isin(groups)]

# Per-row colour limits and titles, indexed like the metric loop below.
vmin = [0.4, 0.4, 0.1]
vmax = [0.9, 0.9, 0.4]
titles = [
    "\nMiddle performance",
    "\nfinal_perf",
    "\nDelta performance\n (last 5 - first 5 trials)",
]

# --- Loop-invariant pieces, computed once ---
# Ten edges from 0.05 to 0.95, i.e. nine bins centred on 0.1 ... 0.9.
bins = np.linspace(0, 0.9, 10) + 0.05
# Rows whose two probabilities sum to 1; isclose avoids exact float equality.
df_correlated = mat_df[np.isclose(mat_df["prob1"] + mat_df["prob2"], 1)]
# Labels "p-(1-p)" for p in 0.1 ... 0.9 without 0.5 (index 4 is removed).
xticks = np.delete(np.arange(0.1, 1, 0.1).round(1), 4)
xtick_labels = [f"{p1}-{p2}" for p1, p2 in zip(xticks, xticks[::-1])]

for c, col_name in enumerate(["middle_perf", "final_perf", "delta_improvement"]):

    # Columns 0-1: one probability-by-probability heatmap per group.
    for g, grp in enumerate(groups):
        df = mat_df[mat_df["grp"] == grp]

        # Mean of the metric in each (prob1, prob2) bin, ignoring NaNs.
        # The last return value is named instead of `_`, which would shadow
        # IPython's last-output variable.
        H, xedges, yedges, binnumber = binned_statistic_2d(
            df["prob1"], df["prob2"], df[col_name], statistic=np.nanmean, bins=bins
        )

        ax = fig.subplot(fig.gs[c, g])
        # H is indexed [x, y]; pcolormesh expects [y, x], hence the transpose.
        cplot = ax.pcolormesh(
            xedges,
            yedges,
            H.T,
            cmap="plasma",
            vmin=vmin[c],
            vmax=vmax[c],
            shading="auto",
        )
        ax.set_xticks([0.1, 0.3, 0.5, 0.7, 0.9])
        ax.set_yticks([0.1, 0.3, 0.5, 0.7, 0.9])
        ax.set_xlabel("Arm 1 prob.")
        ax.set_title(f"{grp} {titles[c]}")
        if g == 0:
            ax.set_ylabel("Arm 2 prob.")
        if g == 1:
            # A single colourbar per row, attached to the second heatmap.
            cb = plt.colorbar(
                cplot, ax=ax, label="Performance", shrink=0.7, anchor=(0, 0.9)
            )
            cb.outline.set_visible(False)

    # Column 2: the same metric along the prob1 + prob2 == 1 rows.
    ax = fig.subplot(fig.gs[c, 2])
    sns.pointplot(
        data=df_correlated,
        x="prob1",
        y=col_name,
        hue="grp",
        ax=ax,
        palette=colors_2arm(),
        dodge=True,
        lw=2,
        errorbar="se",
        alpha=0.6,
    )

    ax.set_xlabel("Arm 1 - Arm 2 prob.")
    # Tick positions 0..7 assume all eight prob1 levels are present.
    ax.set_xticks(np.arange(8), xtick_labels)
    xtick_format(ax, rotation=45)
    ax.set_title("Performance for \ncorrelated probabilities")
    # Strip the leading newline from the title to use it as the y-label.
    ax.set_ylabel(titles[c][1:])
    fix_legend(ax, frameon=True)
    ax.legend_.remove()

### 2 $\alpha$ + H model 
- When we estimated separate alpha parameters for chosen and unchosen arms, we found that the unstructured environment had higher alpha values for chosen arms than the structured environment. So we asked whether 'perseverance' on arms/choices is what makes alpha_chosen higher in the unstructured environment.

In [None]:
import numpy as np
import pandas as pd
from banditpy.models import Qlearn2Arm
import mab_subjects

# Fit the "persev" Q-learning model to every experiment and collect the
# estimated parameters in long format, then save under "rnn_qlearn_2alphaH".
exps = mab_subjects.rnn_exps1

# Labels paired element-wise with qlearn.estimated_params below.
PARAM_NAMES = ("alpha_chosen", "alpha_unchosen", "persev", "scaler", "beta")
# (lower, upper) search bounds handed to fit().
# NOTE(review): presumably in the same order as PARAM_NAMES. Confirm against
# Qlearn2Arm.
PARAM_BOUNDS = ((-1, 1), (-1, 1), (0, 1), (1, 10), (0.005, 20))

params_df = []

for exp in exps:
    # Keep sessions with at least 100 trials, clipped to 100.
    mab = exp.b2a.filter_by_trials(min_trials=100, clip_max=100)
    print(exp.sub_name)

    qlearn = Qlearn2Arm(mab, model="persev", n_cpu=4)
    qlearn.fit(bounds=np.array(PARAM_BOUNDS), n_optimize=5)
    qlearn.print_params()

    subject_params = pd.DataFrame(
        dict(
            name=exp.sub_name,
            param=list(PARAM_NAMES),
            param_values=qlearn.estimated_params,
            # The group label is the experiment tag.
            grp=exp.tag,
        )
    )
    params_df.append(subject_params)

# Replace the list of per-experiment frames with a single stacked frame.
params_df = pd.concat(params_df, ignore_index=True)
mab_subjects.GroupData().save(params_df, "rnn_qlearn_2alphaH")

In [None]:
# Stack (if still needed) and save the fitted parameters.
# `params_df` is a list of per-experiment frames only while the fitting cell
# has not reached its final concat. After a complete run it is already a
# DataFrame, and calling pd.concat on a bare DataFrame raises TypeError, so
# the concat is guarded to keep this cell re-runnable.
if not isinstance(params_df, pd.DataFrame):
    params_df = pd.concat(params_df, ignore_index=True)
mab_subjects.GroupData().save(params_df, "rnn_qlearn_2alphaH")

In [None]:
import mab_subjects
from neuropy import plotting
import pandas as pd
import seaborn as sns
from statannotations.Annotator import Annotator
from statannotations.stats.StatTest import StatTest
from statplot_utils import stat_kw


# Plot from the saved "rnn_qlearn_2alphaH" table.
params_df = mab_subjects.GroupData().rnn_qlearn_2alphaH

fig = plotting.Fig(4, 5, num=1)
ax1 = fig.subplot(fig.gs[0])
ax2 = fig.subplot(fig.gs[1])

# Shared data mapping for every seaborn / statannotations call below.
plot_kw = {
    "x": "param",
    "y": "param_values",
    "hue": "grp",
    "hue_order": ["u_on_u", "s_on_s"],
}

# Point markers with standard-error bars and no connecting line.
bar_kw = {
    "errorbar": "se",
    "palette": "dark:black",
    "linestyle": "none",
    "alpha": 0.5,
    "dodge": 0.4,
    "zorder": 1,
    "marker": ".",
    "markersize": 10,
    "markeredgewidth": 0,
    "err_kws": {"linewidth": 1},
}

# Individual values drawn on top of the summary markers.
strip_kw = {"palette": "husl", "size": 4, "alpha": 0.7, "dodge": True, "zorder": 2}


def _draw_param_panel(ax, panel_df, orders):
    """Draw one parameter panel and annotate the group comparison.

    Plots `panel_df` on `ax` as a pointplot plus a stripplot, then uses
    statannotations to run and draw a "t-test_ind" between "u_on_u" and
    "s_on_s" for every parameter named in `orders`.
    """
    sns.pointplot(panel_df, ax=ax, **plot_kw, **bar_kw)
    sns.stripplot(panel_df, ax=ax, **plot_kw, **strip_kw)

    pairs = [((param, "u_on_u"), (param, "s_on_s")) for param in orders]
    annotator = Annotator(pairs=pairs, data=panel_df, ax=ax, **plot_kw, order=orders)
    annotator.configure(test="t-test_ind", **stat_kw, color="k", verbose=True)
    annotator.apply_and_annotate()
    annotator.reset_configuration()


# "scaler" and "beta" go on the second axis; everything else on the first.
indx_bool = params_df["param"].isin(["scaler", "beta"])
alpha_df = params_df[~indx_bool]
beta_df = params_df[indx_bool]

_draw_param_panel(ax1, alpha_df, ["alpha_chosen", "alpha_unchosen", "persev"])
_draw_param_panel(ax2, beta_df, ["scaler", "beta"])

# Zero reference line on the first panel only.
ax1.axhline(0, ls="--", color="gray", zorder=0, lw=0.8)
ax2.set_xlim(-1, 2)
ax2.set_ylim(1, 15)

# Cosmetics shared by both panels.
for panel_ax in (ax1, ax2):
    panel_ax.tick_params(axis="x", rotation=30)
    panel_ax.set_xlabel("")
    panel_ax.legend_.remove()

ax1.set_ylabel("Estimated alpha values")
ax2.set_ylabel("Estimated beta values")
fig.fig.suptitle("Q-learning in two-armed bandit task")