### Load Network

In [31]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from banditpy.models import BanditTrainer2Arm

# Directory holding the trained RNN checkpoints. Defaults to the original local
# path; set the MAB_MODEL_DIR environment variable to load from elsewhere
# without editing the notebook.
basepath = Path(os.environ.get("MAB_MODEL_DIR", "D:/Data/mab/rnn_models/"))

# ------ Structured network ----------
b2a_s = BanditTrainer2Arm(model_path=basepath / "structured_2arm_model.pt")
b2a_s.load_model()

# ------ Unstructured network ----------
b2a_u = BanditTrainer2Arm(model_path=basepath / "unstructured_2arm_model.pt")
b2a_u.load_model()

Model and training history loaded from D:\Data\mab\rnn_models\structured_2arm_model.pt
Loaded 30000 training loss values
Model and training history loaded from D:\Data\mab\rnn_models\unstructured_2arm_model.pt
Loaded 40000 training loss values


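### Test Networks

Evaluate each trained network in both task modes (`"S"` and `"U"`), converting every evaluation into a `Bandit2Arm` task for analysis.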

In [54]:
from banditpy.core import Bandit2Arm
import matplotlib.pyplot as plt
from neuropy import plotting
from banditpy.core import Bandit2Arm


def df_to_bandit2arm(df):
    """Wrap an evaluation DataFrame as a ``Bandit2Arm`` task.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``arm1_reward_prob``, ``arm2_reward_prob``,
        ``chosen_action``, ``reward`` and ``session_id``.

    Returns
    -------
    Bandit2Arm
        Task built from the two reward-probability columns (stacked as one
        array, arm 1 first) plus the choice, reward and session-id columns.
    """
    # Bandit2Arm keyword -> source DataFrame column, for the 1-D fields.
    per_trial_columns = {
        "choices": "chosen_action",
        "rewards": "reward",
        "session_ids": "session_id",
    }
    reward_probs = df.loc[:, ["arm1_reward_prob", "arm2_reward_prob"]].to_numpy()
    per_trial_arrays = {
        keyword: df[column].to_numpy() for keyword, column in per_trial_columns.items()
    }
    return Bandit2Arm(probs=reward_probs, **per_trial_arrays)


# Number of sessions simulated for each (network, mode) evaluation.
N_EVAL_SESSIONS = 500

# Cross-evaluate both trained networks in both task modes. Names read
# <network>_on_<mode>: b2a_s -> s_*, b2a_u -> u_*; mode "S" -> *_on_s, "U" -> *_on_u.
# NOTE(review): no random seed is set here; if evaluate() samples choices/rewards,
# these results will differ between runs — TODO confirm and seed if needed.
s_on_s = df_to_bandit2arm(b2a_s.evaluate(mode="S", n_sessions=N_EVAL_SESSIONS))
s_on_u = df_to_bandit2arm(b2a_s.evaluate(mode="U", n_sessions=N_EVAL_SESSIONS))

u_on_s = df_to_bandit2arm(b2a_u.evaluate(mode="S", n_sessions=N_EVAL_SESSIONS))
u_on_u = df_to_bandit2arm(b2a_u.evaluate(mode="U", n_sessions=N_EVAL_SESSIONS))

Starting evaluation with fixed weights...
Model and training history loaded from D:\Data\mab\rnn_models\structured_2arm_model.pt
Loaded 30000 training loss values


100%|██████████| 500/500 [00:24<00:00, 20.33it/s]


Evaluation complete.
Starting evaluation with fixed weights...
Model and training history loaded from D:\Data\mab\rnn_models\structured_2arm_model.pt
Loaded 30000 training loss values


100%|██████████| 500/500 [00:24<00:00, 20.63it/s]


Evaluation complete.
Starting evaluation with fixed weights...
Model and training history loaded from D:\Data\mab\rnn_models\unstructured_2arm_model.pt
Loaded 40000 training loss values


100%|██████████| 500/500 [00:24<00:00, 20.41it/s]


Evaluation complete.
Starting evaluation with fixed weights...
Model and training history loaded from D:\Data\mab\rnn_models\unstructured_2arm_model.pt
Loaded 40000 training loss values


100%|██████████| 500/500 [00:24<00:00, 20.74it/s]

Evaluation complete.





### Plotting

In [62]:
from neuropy import plotting
from statplotannot.plots.colormaps import colors_mab
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np

# Figure layout on the plotting.Fig grid:
#   row 0 = unstructured network (b2a_u), row 1 = structured network (b2a_s)
#   col 0 = training loss, col 1 = hidden-state PCA trajectory, col 2 = Pr(High) per trial
NETWORK_TITLES = ["Unstructured", "Structured"]
# Probe session used for the hidden-state PCA panels.
PCA_REWARD_PROBS = [0.8, 0.2]
PCA_N_TRIALS = 200


def _plot_training_loss(ax, loss, color, title):
    """Draw one network's training-loss history on ``ax``."""
    # Colour passed by keyword: a positional second argument to ``plot`` is parsed
    # as y-data or a format string, which breaks for non-string colours.
    ax.plot(loss, color=color)
    ax.set_title(title)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Log loss")


def _plot_performance(ax, tasks, colors, title, legend_loc):
    """Overlay ``task.get_performance()`` curves for two evaluations of one network.

    ``tasks`` and ``colors`` are matched by position; the first curve is labelled
    "congruent" and the second "incongruent" in the legend.
    """
    for task, color in zip(tasks, colors):
        ax.plot(task.get_performance(), color=color)
    ax.set_xlabel("Trials")
    ax.set_ylabel("Pr (High)")
    ax.set_ylim(0.4, 1)
    ax.legend(["congruent", "incongruent"], loc=legend_loc)
    ax.set_title(title)


def _plot_hidden_state_pca(ax, rnn, reward_probs, n_trials):
    """Plot a 2-component PCA projection of ``rnn``'s hidden states for one probe session."""
    session_data = rnn.analyze_hidden_states(
        reward_probs=list(reward_probs), n_trials=n_trials
    )
    # Assumed shape (n_trials, hidden_size) — TODO confirm against analyze_hidden_states.
    hidden_states = np.array(session_data["hidden_states"])

    pca = PCA(n_components=2)
    hidden_states_2d = pca.fit_transform(hidden_states)
    pc1, pc2 = hidden_states_2d[:, 0], hidden_states_2d[:, 1]
    explained = pca.explained_variance_ratio_

    # Points coloured by trial index so the direction of travel is visible;
    # a faint line connects consecutive trials.
    scatter = ax.scatter(
        pc1, pc2, c=range(len(hidden_states_2d)), cmap="copper", s=60
    )
    ax.plot(pc1, pc2, "k-", alpha=0.3, linewidth=1)
    ax.set_title(
        f"Hidden State Trajectory (PCA)\nVar explained: {explained.sum():.3f}"
    )
    ax.set_xlabel(f"PC1 ({explained[0]:.3f})")
    ax.set_ylabel(f"PC2 ({explained[1]:.3f})")
    plt.colorbar(scatter, ax=ax, label="Trial")


fig = plotting.Fig(6, 5)
networks = [b2a_u, b2a_s]

# Column 0: training-loss curves.
for row, (rnn, title) in enumerate(zip(networks, NETWORK_TITLES)):
    _plot_training_loss(
        fig.subplot(fig.gs[row, 0]), rnn.training_loss_history, colors_mab()[row], title
    )

# Column 2: "congruent" = network evaluated in its matching mode (u_on_u / s_on_s),
# "incongruent" = the other mode (u_on_s / s_on_u).
_plot_performance(
    fig.subplot(fig.gs[0, 2]),
    tasks=[u_on_u, u_on_s],
    colors=[colors_mab(1)[0], colors_mab(1.2)[0]],
    title="Unstructured network",
    legend_loc="upper right",
)
_plot_performance(
    fig.subplot(fig.gs[1, 2]),
    tasks=[s_on_s, s_on_u],
    colors=[colors_mab(1)[1], colors_mab(1.3)[1]],
    title="Structured network",
    legend_loc="lower right",
)

# Column 1: hidden-state trajectories for a fixed probe session.
for row, rnn in enumerate(networks):
    _plot_hidden_state_pca(
        fig.subplot(fig.gs[row, 1]), rnn, PCA_REWARD_PROBS, PCA_N_TRIALS
    )