### Load Network

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from banditpy.models import BanditTrainer2Arm
from pathlib import Path

def _load_trainer(model_file):
    """Create a BanditTrainer2Arm for `model_file` and call its `load_model()`."""
    trainer = BanditTrainer2Arm(model_path=model_file)
    trainer.load_model()
    return trainer


basepath = Path("D:/Data/mab/rnn_models/")
struc_models, unstruc_models = [], []

# Model files are numbered 0-9; the structured trainer for each index is
# loaded before its unstructured counterpart.
for i in range(10):
    # ------ Structured network ----------
    struc_models.append(_load_trainer(basepath / f"structured_2arm_model{i}.pt"))

    # ------ Unstructured network ----------
    unstruc_models.append(_load_trainer(basepath / f"unstructured_2arm_model{i}.pt"))

### Test Network

In [None]:
# Display the file stem of the first structured trainer's model path.
struc_models[0].model_path.stem

In [None]:
from banditpy.core import Bandit2Arm
import matplotlib.pyplot as plt
from neuropy import plotting
from banditpy.core import Bandit2Arm


def df_to_b2a(df):
    """Build a Bandit2Arm from a DataFrame of trial records.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns "arm1_reward_prob", "arm2_reward_prob",
        "chosen_action", "reward" and "session_id".

    Returns
    -------
    Bandit2Arm
        Constructed from the column values converted to NumPy arrays; the two
        reward-probability columns are passed together as one 2-column array.
    """
    prob_columns = ["arm1_reward_prob", "arm2_reward_prob"]
    return Bandit2Arm(
        probs=df[prob_columns].to_numpy(),
        choices=df["chosen_action"].to_numpy(),
        rewards=df["reward"].to_numpy(),
        session_ids=df["session_id"].to_numpy(),
    )


basepath = Path("D:/Data/mab/rnn_data/")

# Every loaded network is evaluated in both modes ("S" first, then "U") for
# 300 sessions, and each result is written to
#   <basepath>/<model_name>/<model_name>_<env>/<model_name>_<env>.csv
for models in [struc_models, unstruc_models]:
    for model in models:
        model_name = model.model_path.stem

        for mode, env_name in (("S", "structured"), ("U", "unstructured")):
            exp_folder = basepath / model_name / f"{model_name}_{env_name}"
            # parents=True also creates the base and per-model folders, so the
            # export does not fail when the data directory is missing.
            exp_folder.mkdir(parents=True, exist_ok=True)

            df = model.evaluate(mode=mode, n_sessions=300)
            df.to_csv(exp_folder / f"{model_name}_{env_name}.csv", index=False)

    # print(f"Structured Model {i}: {model.model_path}")

# s_on_s = [df_to_b2a(_.evaluate(mode="S", n_sessions=300)) for _ in struc_models]
# s_on_u = [df_to_b2a(_.evaluate(mode="U", n_sessions=300)) for _ in struc_models]

# u_on_s = [df_to_b2a(_.evaluate(mode="S", n_sessions=300)) for _ in unstruc_models]
# u_on_u = [df_to_b2a(_.evaluate(mode="U", n_sessions=300)) for _ in unstruc_models]

In [None]:
from pathlib import Path

basepath = Path("D:/Data/mab/rnn_data/")

# Collect the entries directly under `basepath` whose name starts with
# "unstructured", sorted by path. A single `*` is the within-component
# wildcard; pathlib only treats `**` specially when it is an entire path
# component, and older Python versions raise ValueError when it appears
# inside one (as in "unstructured**").
folders = sorted(basepath.glob("unstructured*"))

In [None]:
# Display the stem of the first entry in `folders`;
# raises IndexError when `folders` is empty.
folders[0].stem

### Plotting

In [None]:
from neuropy import plotting
from mab_colors import colors_2arm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

fig = plotting.Fig(5, 7, num=1, size=(14, 11), fontsize=10)

# Training-loss curves in the first grid column: row 0 for b2a_u, row 1 for b2a_s.
loss_panels = [
    ("Unstructured", b2a_u.training_loss_history),
    ("Structured", b2a_s.training_loss_history),
]
for row, (panel_title, loss) in enumerate(loss_panels):
    ax = fig.subplot(fig.gs[row, 0])
    ax.plot(loss, colors_2arm()[row])
    ax.set_title(panel_title)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Log loss")

def _plot_choice_curves(ax, tasks, color_args, color_idx, legend_labels, title):
    """Draw one optimal-choice-probability curve per task on `ax`.

    Each task's curve is colored with `colors_2arm(color_arg)[color_idx]`,
    where `color_arg` is the entry of `color_args` paired with that task.
    """
    for task, color_arg in zip(tasks, color_args):
        ax.plot(task.get_optimal_choice_probability(), colors_2arm(color_arg)[color_idx])
        ax.set_xlabel("Trials")
        ax.set_ylim(0.4, 1)
        ax.set_ylabel("Pr (High)")
    ax.legend(legend_labels, loc="lower right")
    ax.set_title(title)


# Row 0, column 1: the u_on_u and u_on_s tasks.
ax = fig.subplot(fig.gs[0, 1])
_plot_choice_curves(
    ax, [u_on_u, u_on_s], [1, 1.3], 0, ["self", "struc"], "Unstructured network"
)

# Row 1, column 1: the s_on_s and s_on_u tasks.
ax = fig.subplot(fig.gs[1, 1])
_plot_choice_curves(
    ax, [s_on_s, s_on_u], [1, 1.4], 1, ["self", "unstruc"], "Structured network"
)

cmap = sns.color_palette("hot", as_cmap=True)
# Alternative colormap: sns.cubehelix_palette(as_cmap=True)

# One grid column (starting at column 2) per reward-probability pair,
# one grid row per network (row 0: b2a_u, row 1: b2a_s).
reward_probs = np.array([[0.8, 0.2], [0.56, 0.44], [0.25, 0.92], [0.38, 0.46]])
for col, rw in enumerate(reward_probs):
    # Action labels compared against below: index of the larger / smaller
    # reward probability, shifted by one.
    high_arm = np.argmax(rw) + 1
    low_arm = np.argmin(rw) + 1

    for row, rnn in enumerate([b2a_u, b2a_s]):
        ax = fig.subplot(fig.gs[row, col + 2])

        session_data = rnn.analyze_hidden_states(reward_probs=rw, n_trials=100)
        # Expected shape (per the original notebook): (n_trials, hidden_size)
        hidden_states = np.array(session_data["hidden_states"])
        # Expected shape (per the original notebook): (n_trials,)
        actions = np.array(session_data["actions"])

        # Marker size encodes the chosen arm: 10 where the action equals
        # high_arm, 30 where it equals low_arm, 1 otherwise. The low-arm
        # assignment runs last, so it wins if both labels coincide.
        marker_size = np.ones_like(actions)
        marker_size[actions == high_arm] = 10
        marker_size[actions == low_arm] = 30

        # Project the hidden states onto their first two principal components.
        pca = PCA(n_components=2)
        hidden_states_2d = pca.fit_transform(hidden_states)
        pc1 = hidden_states_2d[:, 0]
        pc2 = hidden_states_2d[:, 1]

        # Points are colored by their position in the sequence; a faint black
        # line connects consecutive points.
        scatter1 = ax.scatter(
            pc1,
            pc2,
            c=range(len(hidden_states_2d)),
            cmap=cmap,
            s=marker_size,
        )
        ax.plot(pc1, pc2, "k-", alpha=0.3, linewidth=1)

        ax.set_title(f"Trajectory (PCA)\nreward: {rw}")

        # Axis labels carry the explained-variance ratio of each component.
        ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]:.3f})")
        ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]:.3f})")

        # Optional colorbar (disabled):
        # if col == 0:
        #     plt.colorbar(scatter1, ax=ax, label="Trial")

### Performance

In [1]:
import numpy as np
import pandas as pd
import mab_subjects

exps = (
    mab_subjects.rnn_s_on_s
    + mab_subjects.rnn_s_on_u
    + mab_subjects.rnn_u_on_s
    + mab_subjects.rnn_u_on_u
)


def _perf_frame(exp):
    """Return one experiment's optimal-choice probability as a long-format table.

    Columns: 1-based "trial_id", "perf", and the experiment's `sub_name`
    ("name") and `tag` ("grp") repeated on every row.
    """
    perf = exp.b2a.get_optimal_choice_probability()
    return pd.DataFrame(
        {
            "trial_id": np.arange(1, perf.size + 1),
            "perf": perf,
            "name": exp.sub_name,
            "grp": exp.tag,
        }
    )


# Stack the per-experiment tables into a single frame with a fresh index.
perf_df = pd.concat([_perf_frame(exp) for exp in exps], ignore_index=True)

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns

# Line plot of "perf" against "trial_id", one line per "grp" value, with
# standard-error bands.
perf_plot_kw = {
    "data": perf_df,
    "x": "trial_id",
    "y": "perf",
    "hue": "grp",
    "palette": "tab10",
    "errorbar": "se",
}
sns.lineplot(**perf_plot_kw)

<Axes: xlabel='trial_id', ylabel='perf'>

### Switching probability

In [None]:
from banditpy.analyses import SwitchProb2Arm
import pandas as pd

# Switch probability by trial for the two tasks, stacked into one
# long-format table.
network_names = ["Unstructured", "Structured"]
frames = []
for name, task in zip(network_names, [u_on_u, s_on_s]):
    sp = SwitchProb2Arm(task=task).by_trial()
    grp = "struc" if task.is_structured else "unstruc"
    frames.append(
        pd.DataFrame(
            dict(
                trial_id=np.arange(1, len(sp) + 1),  # 1-based trial number
                switch_prob=sp,
                name=name,
                grp=grp,
            )
        )
    )

sp_df = pd.concat(frames, ignore_index=True)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from neuropy import plotting
import mab_subjects
import numpy as np
from statannotations.Annotator import Annotator
from statplot_utils import stat_kw
from statplotannot.plots.colormaps import colors_mab

fig = plotting.Fig(4, 2, size=(8.5, 11), num=1, fontsize=10)

grpdata = mab_subjects.GroupData()

# First cell of the grid holds the switch-probability curves.
ax = fig.subplot(fig.gs[0])

# Shared data/axis arguments, kept in a dict and unpacked into the plot call.
plot_kw = {
    "data": sp_df,
    "x": "trial_id",
    "y": "switch_prob",
    "hue": "grp",
    "ax": ax,
}
sns.lineplot(
    palette=colors_mab(),
    errorbar="se",
    err_kws=dict(edgecolor=None),
    **plot_kw,
)

### Conditional switching probability

In [None]:
from banditpy.analyses import SwitchProb2Arm
import pandas as pd

# Switch probability conditioned on the last three trials of history (returned
# as strings), for the two tasks, stacked into one long-format table.
network_names = ("Unstructured", "Structured")
frames = []
for i, task in enumerate([u_on_u, s_on_s]):
    switch_probs, histories = SwitchProb2Arm(task=task).by_history(
        n_past=3, history_as_str=True
    )
    grp = "struc" if task.is_structured else "unstruc"
    frames.append(
        pd.DataFrame(
            dict(
                trial_id=np.arange(1, len(switch_probs) + 1),  # 1-based row counter
                seq=histories,
                switch_prob=switch_probs,
                name=network_names[i],
                grp=grp,
            )
        )
    )

sp_df = pd.concat(frames, ignore_index=True)

In [None]:
# Display the distinct values in the "grp" column of sp_df.
sp_df["grp"].unique()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from neuropy import plotting
import mab_subjects
import numpy as np
from statannotations.Annotator import Annotator
from statplot_utils import stat_kw
from statplotannot.plots.colormaps import colors_mab
from statplotannot.plots import SeabornPlotter

fig = plotting.Fig(4, 2, size=(8.5, 11), num=1, fontsize=10)

# Grid row that holds the bar plot. The row index used to come from a loop
# variable that this cell never defines, so the panel position depended on
# which cell had run last. It is pinned here to 1, the value that variable
# holds after running the notebook top to bottom; change it to move the panel.
panel_row = 1

# The panel spans both grid columns of that row.
ax = fig.subplot(fig.gs[panel_row, :2])

# Overlaid (non-dodged) bars of switch probability per history sequence,
# one color per group, with standard-error bars.
SeabornPlotter(
    data=sp_df,
    x="seq",
    y="switch_prob",
    hue="grp",
    hue_order=["unstruc", "struc"],
    ax=ax,
).barplot(dodge=False, palette=colors_mab(1), alpha=0.6, errorbar="se")

# Rotate the x tick labels by 90 degrees.
ax.tick_params(axis="x", rotation=90)