### Example sessions per environment, performance matrix, and overall performance

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from neuropy import plotting
import seaborn as sns
import mab_subjects
from statplotannot.plots import fix_legend
import seaborn as sns
from mab_colors import colors_2arm
from statplotannot.plots import fix_legend
import mab_subjects
import numpy as np
from scipy.ndimage import gaussian_filter
from palettable.scientific.sequential import GrayC_7

exps = mab_subjects.struc.Sterling + mab_subjects.unstruc.Debruyne


def mov_mean(x, w):
    """Return the running mean of ``x`` over a window of ``w`` samples.

    Uses ``np.convolve(..., mode="valid")``: no edge padding is applied, so for
    ``len(x) >= w`` the result has ``len(x) - w + 1`` samples.
    """
    return np.convolve(x, np.ones(w) / w, mode="valid")


# Hand-picked example windows: start trial for each experiment in `exps`
# (Sterling first, then Debruyne). Other good candidates:
#   Sterling: 22522    Debruyne: 56321
example_starts = [21022, 54821]
example_len = 1500  # number of trials shown in each example panel

window_size = 30  # moving-average window (trials) for the choice trace
Titles = ["Structured environment", "Unstructured environment"]

fig = plotting.Fig(11, 6, fontsize=11, constrained_layout=True)

subfig = fig.add_subfigure(fig.gs[:2, :])
axs = subfig.subplots(2, 2, height_ratios=[3, 1], sharex=True, sharey=True)


# --------- Example sessions ---------#
# zip() stops at the shorter input, so only experiments that have a listed
# start index get a column (the original if/if chain left n1 stale otherwise).
for i, (exp, n1) in enumerate(zip(exps, example_starts)):
    task = exp.b2a.filter_by_trials(min_trials=100, clip_max=100)
    probs = task.probs
    n_trials = probs.shape[0]
    n2 = n1 + example_len

    # To browse for new examples, pick a random window near the session end:
    # n1 = np.random.randint(int(0.8 * n_trials), n_trials)
    # n2 = n1 + example_len
    # print(n1, n2)

    prob1 = probs[:, 0][n1:n2]
    prob2 = probs[:, 1][n1:n2]

    # Top row: reward probability traces (column 0 -> "Left", 1 -> "Right")
    ax = axs[0, i]
    ax.plot(prob1, label="Left", color="blue")
    ax.plot(prob2, label="Right", color="orange")
    # ax.plot(prob1 - prob2, label="P(L) - P(R)", color="purple")
    ax.set_yticks([0.1, 0.50, 1.00])
    if i == 0:
        ax.set_ylabel("P (reward)")
    ax.legend(loc="upper right")
    fix_legend(ax, only_labels=True)

    # Bottom row: running fraction of left choices over `window_size` trials
    ax2 = axs[1, i]

    rewards = task.rewards[n1:n2]
    choices = task.choices[n1:n2]
    # assumes choice code 1 = left, 2 = right (as the variable names imply)
    choices_left = np.where(choices == 1, 1, 0)
    choices_right = np.where(choices == 2, 1, 0)

    reward_rate = mov_mean(rewards, window_size)
    choice_left_prob = mov_mean(choices_left, window_size)
    choice_right_prob = mov_mean(choices_right, window_size)

    # Optional traces:
    # ax2.plot(reward_rate, label=f"Reward rate ({window_size}-trial MA)", color="red")
    # ax2.plot(choice_right_prob, label="P(R)", color="purple")
    # ax2.plot(rewards, label="Reward", color="green", alpha=0.5)
    ax2.plot(choice_left_prob, label="P(L)", color="#7a7a7a")
    if i == 0:
        ax2.set_ylabel("Choice\nP(Left)")
    ax2.set_xlabel("Trials")

    # subfig.suptitle(Titles[i])

# ------- Performance matrix ------------
df = mab_subjects.GroupData().perf_probability_matrix.latest
df = df[df["lesion"] == "pre_lesion"]
prob_levels = [10, 20, 30, 40, 60, 70, 80, 90]  # tick labels, one per matrix cell
for g, grp in enumerate(["struc", "unstruc"]):
    df_grp = df[df["grp"] == grp]
    # Group average of the per-row "perf_mat" entries
    perf_mean = df_grp["perf_mat"].mean()
    # perf_mean = gaussian_filter(perf_mean, sigma=0.1)
    # Keep only the strictly-lower triangle; the diagonal and upper triangle
    # become NaN so pcolormesh draws them as masked (blank) cells.
    perf_mean = np.tril(perf_mean, k=-1)
    mask = np.triu(np.ones_like(perf_mean, dtype=bool), k=0)
    perf_mean[mask] = np.nan

    ax = fig.subplot(fig.gs[2:4, 2 * g : 2 * g + 2])

    im = ax.pcolormesh(
        perf_mean.T,
        # cmap=GrayC_7.mpl_colormap,
        cmap="Blues",
        shading="auto",
        vmin=0.5,
        vmax=1,
    )
    ticks = np.arange(len(prob_levels)) + 0.5  # cell centres
    ax.set_xticks(ticks, prob_levels)
    ax.set_yticks(ticks, prob_levels)
    ax.set_xlim(1, 8)
    ax.set_ylim(0, 7)
    ax.spines["right"].set_visible(True)
    ax.spines["top"].set_visible(True)
    ax.set_xlabel("Higher reward probability")
    ax.set_ylabel("Lower reward probability")

    if g == 0:
        # One colorbar for both panels (same vmin/vmax); the rect is in figure
        # coordinates. Attach it to this figure explicitly rather than plt.gcf().
        cax = fig.fig.add_axes((0.2, 0.5, 0.3, 0.01))
        cb = fig.fig.colorbar(im, cax=cax, orientation="horizontal")
        cb.set_label("Performance")


# ------ Overall performance ------------
perf_df = mab_subjects.GroupData().perf_AAdataset_Block1.latest
# .copy(): the column added below must go into an independent frame, not a
# boolean-indexed slice of perf_df (avoids SettingWithCopyWarning).
df1 = perf_df[perf_df["lesion"].isin(["pre_lesion"])].copy()
df1["grp_new"] = df1["grp"] + "_" + df1["lesion"]
df = df1  # same alias the original single-iteration loop left behind

ax = fig.subplot(fig.gs[2:4, 4:6])
sns.lineplot(
    data=df1,
    x="trial_id",
    y="performance",
    hue="grp",
    # palette=colors_2arm(),
    palette=colors_2arm() + ["#7a7a7a"],
    hue_order=["unstruc", "struc", "struc_in_unstruc"],
    errorbar="se",
    ax=ax,  # explicit target; don't rely on the implicit current axes
)
fix_legend(ax)
ax.set_ylim(0.2, 0.9)
ax.set_ylabel("P(High)")
ax.set_title("Overall performance")
# ax.grid(axis="y", zorder=-1, alpha=0.5)

In [None]:
# Quick look at the last mean performance matrix computed in the previous
# cell (untransposed, default imshow origin) on the same colour scale.
plt.imshow(perf_mean, vmin=0.5, vmax=1, cmap="Blues")

### Switch probability 

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from neuropy import plotting
import mab_subjects
import numpy as np
from statannotations.Annotator import Annotator
from statplot_utils import stat_kw
from statplotannot.plots import fix_legend, SeabornPlotter
from mab_colors import colors_2arm

fig = plotting.Fig(4, 4, size=(8.5, 11), num=1, fontsize=12)

# Pre-lesion switch-probability data: per-trial block-1 curves (df1) and the
# trial-history breakdown (df2).
df1 = mab_subjects.GroupData().swp_AAdataset_Block1.latest
df1 = df1.loc[df1["lesion"] == "pre_lesion"]

df2 = mab_subjects.GroupData().swp_trial_history.latest
df2 = df2.loc[df2["lesion"] == "pre_lesion"]

hue_order = ["unstruc", "struc"]
palette = colors_2arm()

# --- Panel 1: smoothed switch probability across block-1 trials ---
ax = fig.subplot(fig.gs[:2])
plot_kw = {
    "data": df1,
    "x": "trial_id",
    "y": "switch_prob_block1_smth",
    "hue": "grp",
    "hue_order": hue_order,
    "ax": ax,
}
# alternative palette: ["#E89317", "#3980ea"]
sns.lineplot(**plot_kw, palette=palette, errorbar="se", err_kws={"edgecolor": None})

ax.set(title="Switch probability", ylabel="Switch probability", xlabel="Trial number")
fix_legend(ax)
ax.set(ylim=(0, 0.25), xlim=(1, 100), xticks=[1, 25, 50, 75, 100])

# --- Panel 2: switch probability by trial-history sequence ---
ax = fig.subplot(fig.gs[2])
history_kw = {
    "data": df2,
    "x": "seq",
    "y": "switch_prob",
    "hue": "grp",
    "hue_order": hue_order,
    "ax": ax,
}
# Individual points drawn above (zorder=2) the group-mean bars
sns.stripplot(
    **history_kw,
    palette=colors_2arm(1.2),
    dodge=True,
    alpha=0.7,
    edgecolor="white",
    linewidth=0.5,
    zorder=2,
)
bar_plotter = SeabornPlotter(**history_kw).barplot(
    dodge=True, palette=colors_2arm(), alpha=0.8, errorbar="se"
)
bar_plotter.bootstrap_test()

ax.set(ylabel="Switch probability", xlabel="")
fix_legend(ax)
ax.set_yticks([0, 0.05, 0.1, 0.15, 0.2])

# fig.savefig(mab_subjects.figpath / "mab_switching_probability_block1", format="svg")

### Performance: pre-lesion vs naive OFC lesion

In [None]:
from statplotannot.plots import SeabornPlotter, fix_legend
from mab_colors import colors_2arm
import seaborn as sns
from neuropy import plotting
import mab_subjects

fig = plotting.Fig(5, 3, fontsize=11)

perf_df = mab_subjects.GroupData().perf_AAdataset_Block1.latest
# Alternative subsets:
# perf_df = perf_df[perf_df["dataset"] == "ACdataset"]
# df = perf_df[perf_df["lesion"].isin(["pre_lesion", "post_lesion_OFC"])]
df = perf_df[perf_df["lesion"].isin(["pre_lesion", "naive_lesion_OFC"])]

# .copy(): `grp_new` is added to these frames in the loop below; writing a
# column into a boolean-indexed slice raises SettingWithCopyWarning.
df_unstruc = df[df["grp"] == "unstruc"].copy()
df_struc = df[df["grp"] == "struc"].copy()
colors = colors_2arm()

# One panel per group; colors[d2] is the group colour (0 = unstruc, 1 = struc).
for d2, (grp, df_temp) in enumerate([("unstruc", df_unstruc), ("struc", df_struc)]):
    df_temp["grp_new"] = df_temp["grp"] + "_" + df_temp["lesion"]

    ax = fig.subplot(fig.gs[0, d2])
    sns.lineplot(
        data=df_temp,
        x="trial_id",
        y="performance",
        hue="grp_new",
        # Without hue_order seaborn orders string levels by first appearance
        # in the data, so which condition got the group colour vs gray
        # depended on row order. Pin it: pre-lesion = group colour,
        # naive OFC lesion = gray.
        hue_order=[f"{grp}_pre_lesion", f"{grp}_naive_lesion_OFC"],
        palette=[colors[d2], "gray"],
        errorbar="se",
        ax=ax,  # explicit target; don't rely on the implicit current axes
    )
    fix_legend(ax)
    ax.set_ylim(0.4, 0.9)
    ax.set_ylabel("P(High)")
    # ax.grid(axis="y", zorder=-1, alpha=0.5)

### Q-learning parameters: pre-lesion vs naive OFC lesion

In [None]:
from neuropy import plotting
import pandas as pd
from pathlib import Path
import seaborn as sns
from mab_colors import colors_2arm
from statplotannot.plots import fix_legend, SeabornPlotter
import mab_subjects

# NOTE(review): machine-specific absolute path — this cell only runs where
# this fit file lives; consider a project-relative results directory.
fp = Path(
    "C:/Users/asheshlab/OneDrive/academia/analyses/adlab/results/qlearn_2alphaH_fit_1stBlock_24112025_154819.csv"
)
data = pd.read_csv(fp)
# Keep pre-lesion and naive OFC-lesion fits and add the combined
# "<grp>_<lesion>" hue label once. `.assign` returns a new frame, so the
# per-parameter subsets below inherit `grp_new` and no column is ever written
# into a boolean-indexed slice (no SettingWithCopyWarning).
data = data[data["lesion"].isin(["pre_lesion", "naive_lesion_OFC"])].assign(
    grp_new=lambda d: d["grp"] + "_" + d["lesion"]
)

# One subset per panel
data1 = data[data["param"].isin(["alpha_chosen", "alpha_unchosen"])]
data2 = data[data["param"].isin(["persev"])]
data3 = data[data["param"].isin(["beta"])]

fig = plotting.Fig(4, 6, num=2, fontsize=11)
fig.fig.suptitle("Q-learning parameters in two-armed bandit task")

hue_order = [
    "unstruc_pre_lesion",
    "unstruc_naive_lesion_OFC",
    "struc_pre_lesion",
    "struc_naive_lesion_OFC",
]

# Paired with hue_order: colour index 0 for unstruc, 1 for struc;
# colors_2arm(1.3) for pre-lesion, colors_2arm(0.7) for naive lesion.
palette = [
    colors_2arm(1.3)[0],
    colors_2arm(0.7)[0],
    colors_2arm(1.3)[1],
    colors_2arm(0.7)[1],
]

# First panel spans two grid cells (two learning-rate params); others one each.
axs = [fig.subplot(fig.gs[0, :2]), fig.subplot(fig.gs[2]), fig.subplot(fig.gs[3])]

for i, (ax, df) in enumerate(zip(axs, [data1, data2, data3])):
    SeabornPlotter(
        data=df,
        x="param",
        y="param_values",
        hue="grp_new",
        hue_order=hue_order,
        ax=ax,
    ).barplot(
        palette=palette,
        linestyle="none",
        alpha=0.9,
        dodge=0.4,
        zorder=1,
        err_kws={"linewidth": 1},
        errorbar="se",
    ).bootstrap_test()

    ax.set_ylabel("Parameter value")
    ax.set_xlabel("")
    # ax.tick_params(axis="x", rotation=30)
    # NOTE(review): with rotation 0, ha="right" right-aligns each label at its
    # tick (text sits left of the tick); looks left over from a 30° rotation —
    # confirm this alignment is intended.
    for label in ax.get_xticklabels():
        label.set_rotation(0)
        label.set_rotation_mode("anchor")
        label.set_horizontalalignment("right")

    if i == 0:
        fix_legend(ax, only_labels=False)
        ax.set_xlim(-0.6, 1.8)
        ax.axhline(0, ls="--", color="gray", zorder=0, lw=0.8)
        ax.set_ylim(-0.1, 0.4)
    else:
        # Legend is shown only on the first panel. get_legend() is None when
        # no legend was drawn; `ax.legend_.remove()` would then raise.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        ax.set_xlim(-0.75, 0.75)