# Training an Agent on the Perceptual Decision Making Task


## Import Libraries


In [None]:
import warnings
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import gymnasium as gym
from neurogym.wrappers.monitor import Monitor
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env import DummyVecEnv
from pathlib import Path
from neurogym.utils.ngym_random import TruncExp
from neurogym.wrappers.reaction_time import ReactionTime
from neurogym.wrappers.pass_action import PassAction
from neurogym.wrappers.pass_reward import PassReward
from neurogym.wrappers.side_bias import SideBias

# The local helper modules (`plotting`, `utilities`) live one directory above
# the notebook's working directory, so put that directory on the import path.
root_dir = str(Path.cwd().parent)
if root_dir not in sys.path:
    sys.path.append(root_dir)

from plotting import plot_env
from utilities import compute_folder_metrics, collect_aggregate_pdm, plot_aggregate_pdm

# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

## RL Configuration and Training


### Environment Configuration


In [None]:
# ---- Task / environment parameters -----------------------------------
EVAL_TRIALS = 1000  # trials used for every policy evaluation below
dt = 100  # simulation time step passed to the environment
dim_ring = 2  # number of choices in the ring representation
# Whether a trial may be aborted if the agent does not fixate.
# TODO: add this as attribute to the environment
abort = False
rewards = {"abort": -0.1, "correct": +1.0, "fail": 0.0}
timing = {
    "fixation": TruncExp(600, 400, 700),  # fixation duration sampled per trial
    "stimulus": 2000,
    "delay": 0,
    "decision": 100,
}
sigma = 1.0  # standard deviation of the Gaussian noise in the ring representation

kwargs = dict(
    dt=dt,
    dim_ring=dim_ring,
    rewards=rewards,
    timing=timing,
    sigma=sigma,
    abort=abort,
)

# Side-bias block settings handed to the SideBias wrapper.
block_dur = (20, 100)
probs = [[0.2, 0.8], [0.8, 0.2]]

# ---- Build the environment; wrappers apply from the innermost outwards --
task = "PerceptualDecisionMaking-v0"
env = SideBias(
    PassAction(
        PassReward(
            ReactionTime(gym.make(task, **kwargs), end_on_stimulus=True)
        )
    ),
    probs=probs,
    block_dur=block_dur,
)

### Untrained Environment Visualization


In [None]:
# Plot a sample of trials from the environment before any training.
fig = plot_env(
    env,
    name="IBL",
    ob_traces=["Fixation", "Stim 1", "Stim 2", "PassReward", "PassAction"],
    num_trials=15,
    plot_performance=True,
    fig_kwargs={"figsize": (9, 5)},
)
# Give every text artist in the figure the same font size.
for txt in fig.findobj(match=plt.Text):
    txt.set_fontsize(11)
fig.savefig("untrained_env.pdf", dpi=300, bbox_inches="tight")

# Baseline: evaluate_policy is called without a model; this notebook reports
# the result as the random-policy baseline.
eval_monitor = Monitor(env)
print("\nEvaluating random policy performance...")
metrics = eval_monitor.evaluate_policy(num_trials=EVAL_TRIALS)
print(f"\nRandom policy metrics ({EVAL_TRIALS:,} trials):")
print(f"Mean performance: {metrics['mean_performance']:.4f}")
print(f"Mean reward: {metrics['mean_reward']:.4f}")

### RL Training


In [None]:
# ---- Training constants -----------------------------------------------
TRAIN_TRIALS = 10000
EVAL_TRIALS = 1000
NUM_MODELS = 10

# Environment steps per trial, measured empirically for this configuration;
# used to convert trial counts into timestep counts.
avg_timesteps = 7  # observed
total_timesteps = TRAIN_TRIALS * avg_timesteps
trials_per_figure = 10
steps_per_figure = trials_per_figure * avg_timesteps
trials_per_batch = 64
n_steps = avg_timesteps * trials_per_batch  # rollout length per PPO update
batch_size = 32

# SB3 on-policy algorithms interpret `log_interval` as a number of rollout
# iterations (each `n_steps` long), NOT timesteps. The previous value
# `total_timesteps // 10` exceeded the number of iterations actually run, so
# training statistics were never logged. Aim for ~10 log dumps per run.
n_iterations = total_timesteps // n_steps
log_interval = max(1, n_iterations // 10)

# Policy configuration
policy_kwargs = {
    "lstm_hidden_size": 64,
    "n_lstm_layers": 1,
    "shared_lstm": True,
    "enable_critic_lstm": False,
}

for i in range(NUM_MODELS):
    print(f"\n=== Training model {i + 1}/{NUM_MODELS} ===")

    # Seed the environment and NumPy's global RNG for this run.
    seed = i
    env.seed(seed)
    np.random.seed(seed)

    # Fresh Monitor per model; the trained model is saved into its save_dir.
    train_monitor = Monitor(
        env,
        trigger="trial",
        interval=1000,
        plot_create=True,
        plot_steps=steps_per_figure,
        verbose=True,
    )

    # Vectorize the monitored environment (single env).
    env_vec = DummyVecEnv([lambda: train_monitor])

    model = RecurrentPPO(
        "MlpLstmPolicy",
        env_vec,
        learning_rate=5e-4,
        n_steps=n_steps,
        batch_size=batch_size,
        ent_coef=0.01,
        policy_kwargs=policy_kwargs,
        seed=seed,
        verbose=1,
    )

    model.learn(total_timesteps=total_timesteps, log_interval=log_interval)

    # Save the model next to this run's monitor data.
    model_path = f"rl_model_{i+1}.zip"
    model.save(train_monitor.save_dir / model_path)

    env_vec.close()

## RL Evaluation


### Skip Training and Load a Previously Saved RL Model


In [None]:
# ======================================================================
# ONLY TO AVOID RERUNNING THE TRAINING
# Re-creates the training constants and monitor, then points the monitor at
# an existing run directory so the following cells can load a saved model.

# Constants
TRAIN_TRIALS = 10000
EVAL_TRIALS = 1000
NUM_MODELS = 10

# Get average timesteps per trial once
avg_timesteps = 7  # observed
total_timesteps = TRAIN_TRIALS * avg_timesteps
trials_per_figure = 10
steps_per_figure = trials_per_figure * avg_timesteps
trials_per_batch = 64
n_steps = avg_timesteps * trials_per_batch
batch_size = 32

# Policy configuration
policy_kwargs = {
    "lstm_hidden_size": 64,
    "n_lstm_layers": 1,
    "shared_lstm": True,
    "enable_critic_lstm": False,
}

train_monitor = Monitor(
    env,
    trigger="trial",
    interval=1000,
    plot_create=True,
    plot_steps=steps_per_figure,
    verbose=True,
)
env_vec = DummyVecEnv([lambda: train_monitor])

# The run directory and checkpoint name can be overridden via environment
# variables so the notebook is usable on other machines; the defaults are the
# originally hard-coded run.
train_monitor.save_dir = Path(
    os.environ.get(
        "PDM_RUN_DIR",
        "/Users/giuliacrocioni/Desktop/docs/eScience/projects/ANNUBeS/paper/figure_2/runs/PerceptualDecisionMaking/2025-06-05_14-31-08",
    )
)
model_path = os.environ.get("PDM_MODEL_FILE", "rl_model_1.zip")

# Fail here with an actionable message instead of deep inside RecurrentPPO.load.
if not (train_monitor.save_dir / model_path).is_file():
    raise FileNotFoundError(
        f"Saved model not found: {train_monitor.save_dir / model_path}. "
        "Set PDM_RUN_DIR (and optionally PDM_MODEL_FILE) to an existing run."
    )
# ======================================================================

### Visualize Trained RL Agent


In [None]:
# Restore the selected checkpoint from the monitor's run directory.
loaded_model = RecurrentPPO.load(train_monitor.save_dir / model_path)

# Roll out a few example trials with the trained agent and plot them.
fig = plot_env(
    env_vec,
    name="IBL (trained)",
    ob_traces=["Fixation", "Stim 1", "Stim 2", "PassReward", "PassAction"],
    num_trials=5,
    model=loaded_model,
    plot_performance=True,
    fig_kwargs={"figsize": (9, 5)},
)
# Uniform font size across every text artist in the figure.
for txt in fig.findobj(match=plt.Text):
    txt.set_fontsize(11)
fig.savefig("trained_env.pdf", dpi=300, bbox_inches="tight")

### Evaluate Trained RL Model


In [None]:
# Evaluate the currently loaded model (`loaded_model`): the last trained model
# after the training cell, or the selected checkpoint after the skip cell.
print("\nEvaluating trained model performance...")
rl_trained_metrics = eval_monitor.evaluate_policy(num_trials=EVAL_TRIALS, model=loaded_model)
print(f"\nTrained model metrics ({EVAL_TRIALS:,} trials):")
print(f"Mean performance: {rl_trained_metrics['mean_performance']:.4f}")
print(f"Mean reward: {rl_trained_metrics['mean_reward']:.4f}")

# Plot the training history recorded by `train_monitor`.
fig = train_monitor.plot_training_history(figsize=(6, 4), plot_performance=False)

# Strip titles for the paper figure. `Figure._suptitle` is a private attribute
# that is None when no suptitle was set, so guard before removing it.
ax = fig.axes[0]
ax.set_title("")
if fig._suptitle is not None:
    fig._suptitle.remove()
for text_obj in fig.findobj(match=plt.Text):
    text_obj.set_fontsize(11)

fig.savefig("training_history.pdf", dpi=300, bbox_inches='tight')

### Training History for All RL Models


In [None]:
# Plot training history for ALL models: thin lines per run, thick lines for
# the across-run average.
data_base_dir = train_monitor.config.local_dir / train_monitor.config.env.name
fig, ax = plt.subplots(figsize=(6, 4))

all_rewards = []
all_performances = []
all_indices = []

for data_dir in sorted(data_base_dir.glob("*")):
    # Only run directories that contain saved monitor data (.npz) are used.
    if not data_dir.is_dir() or not any(data_dir.glob("*.npz")):
        continue
    file_indices, cum_rewards, performances = compute_folder_metrics(data_dir)
    ax.plot(file_indices, performances, color="lightblue", linewidth=1, alpha=0.5)
    all_performances.append(performances)
    ax.plot(file_indices, cum_rewards, color="lightcoral", linewidth=1, alpha=0.5)
    all_rewards.append(cum_rewards)
    all_indices.append(file_indices)

if not all_indices:
    raise FileNotFoundError(f"No run directories with .npz data found under {data_base_dir}")

# Runs can log a different number of checkpoints (e.g. an interrupted run).
# np.array on ragged lists fails or yields an object array, so average only
# over the prefix shared by every run.
n_common = min(len(seq) for group in (all_indices, all_rewards, all_performances) for seq in group)
avg_rewards = np.mean([np.asarray(r)[:n_common] for r in all_rewards], axis=0)
avg_performances = np.mean([np.asarray(p)[:n_common] for p in all_performances], axis=0)
common_indices = np.asarray(all_indices[0])[:n_common]

ax.plot(common_indices, avg_performances, "o-", color="blue", label="Average Accuracy", linewidth=2)
ax.plot(common_indices, avg_rewards, "o-", color="red", label="Average Reward", linewidth=2)

ax.set_xlabel("Trials")
ax.set_ylabel("Accuracy / Reward")
ax.set_ylim(-0.05, 1)
ax.legend(loc="center right", bbox_to_anchor=(1, 0.2))
ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.subplots_adjust(top=0.8)

# Separate loop variable so the main `ax` handle is not shadowed.
for axis in fig.axes:
    axis.spines["top"].set_visible(False)
    axis.spines["right"].set_visible(False)
for text_obj in fig.findobj(match=plt.Text):
    text_obj.set_fontsize(11)

fig.savefig("training_history_all_models.pdf", dpi=300, bbox_inches='tight')

### Evaluate All RL Models and Store Accuracy


In [None]:
# Evaluate every saved RL model (one checkpoint per run directory) and report
# the mean and standard deviation of their accuracy.
models_base_dir = train_monitor.config.local_dir / train_monitor.config.env.name

rl_mean_performance = []

for model_dir in sorted(models_base_dir.glob("*")):
    if not model_dir.is_dir():
        continue
    # Path.glob order is filesystem-dependent; sort so the chosen checkpoint
    # is deterministic if a directory holds more than one .zip.
    model_files = sorted(model_dir.glob("*.zip"))
    if not model_files:
        continue
    model_path = model_files[0]
    print(f"Loading model from: {model_path}")

    model = RecurrentPPO.load(model_path)
    metrics = eval_monitor.evaluate_policy(num_trials=EVAL_TRIALS, model=model)
    rl_mean_performance.append(metrics['mean_performance'])
    print(f"{model_path} accuracy: {metrics['mean_performance']:.4f}")

# Avoid silently reporting NaN when nothing was found.
if not rl_mean_performance:
    raise FileNotFoundError(f"No saved RL models (*.zip) found under {models_base_dir}")

mean_performance = np.mean(rl_mean_performance)
std_performance = np.std(rl_mean_performance)  # population std (ddof=0)
print(f"\nAverage RL model performance: {mean_performance:.4f} ± {std_performance:.4f}")

### RL Models Psychometric Curves


In [None]:
# Evaluate each saved model once more and keep its evaluation data for the
# psychometric analysis below.
data_list = []
models_base_dir = train_monitor.config.local_dir / train_monitor.config.env.name
for model_dir in sorted(models_base_dir.glob("*")):
    if not model_dir.is_dir():
        continue
    # Sorted so the selected checkpoint does not depend on filesystem order.
    model_files = sorted(model_dir.glob("*.zip"))
    if not model_files:
        continue
    model = RecurrentPPO.load(model_files[0])
    eval_monitor.evaluate_policy(num_trials=EVAL_TRIALS, model=model)
    # NOTE(review): this stores a reference to eval_monitor.data_eval. If
    # evaluate_policy mutates that container in place rather than rebinding
    # it, every entry would alias the last evaluation. Confirm against Monitor.
    data_list.append(eval_monitor.data_eval)

In [None]:
# Aggregate the per-model evaluation data by side-bias block and plot one
# psychometric curve per block.
aggregate = collect_aggregate_pdm(data_list, probs)

fig, ax = plt.subplots(figsize=(3, 3))

for blk in range(len(probs)):
    plot_aggregate_pdm(ax, blk, aggregate)

ax.set_ylabel("P(Right)")

# Styling
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
# Create the legend before the font pass so its labels get the same size.
ax.legend(loc="center right", bbox_to_anchor=(1, 0.2))
for text_obj in fig.findobj(match=plt.Text):
    text_obj.set_fontsize(11)
# Lay out only after fonts are final so spacing matches the rendered text.
fig.tight_layout()

fig.savefig("psychometric_avg.pdf", dpi=300, bbox_inches='tight')
plt.show()