# Perceptual Decision Making Task

This environment implements a two-alternative forced choice perceptual decision-making task, where the agent must integrate noisy sensory evidence over time to make accurate decisions. The task is based on classic motion discrimination experiments ([Britten et al. 1992](https://www.jneurosci.org/content/12/12/4745)) and has been adapted for studying neural mechanisms of decision-making in computational models. The key features of the task are:

1. On each trial, a noisy stimulus appears on either the left or right side of the visual field with varying coherence levels (evidence strength).

2. Choices are represented as angles evenly spaced around a circle. With the default of 2 choices (`dim_ring=2`), this corresponds to:

   - Position 1: 0° (left choice)
   - Position 2: 180° (right choice)

3. The stimulus is presented as a cosine modulation with additive Gaussian noise, requiring the agent to integrate evidence over time to overcome noise and make accurate decisions.

4. The agent can respond at any time after stimulus onset.

5. The environment includes blocks where one side is more likely than the other, and augments observations with previous actions and rewards.
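Points 2 and 3 can be made concrete with a minimal NumPy sketch of a ring stimulus and of evidence integration. The sketch makes these assumptions:

- Each stimulus channel carries `cos(theta - theta_correct) * coh / 200 + 0.5` plus Gaussian noise.
- The per-step noise standard deviation `noise_std` is fixed.
- The function names `ring_stimulus` and `integrate_and_choose` are hypothetical and not part of neurogym. neurogym's exact scaling may differ.

```python
import numpy as np


def ring_stimulus(gt_index, coh, dim_ring=2, noise_std=0.1, n_steps=20, rng=None):
    """Noisy ring stimulus of shape (n_steps, dim_ring).

    Assumed form: a cosine bump centred on the correct choice's angle,
    scaled by coherence (coh / 200) and offset by 0.5, plus Gaussian noise.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Choice angles evenly spaced on the circle: [0, pi] for dim_ring=2
    theta = np.linspace(0, 2 * np.pi, dim_ring + 1)[:-1]
    mean = np.cos(theta - theta[gt_index]) * (coh / 200) + 0.5
    return mean + rng.normal(0.0, noise_std, size=(n_steps, dim_ring))


def integrate_and_choose(stim):
    """Accumulate right-minus-left evidence over time and choose by the final total."""
    evidence = np.cumsum(stim[:, 1] - stim[:, 0])
    return int(evidence[-1] > 0)  # 0 = left (0 deg), 1 = right (180 deg)


# Usage: accuracy of the integrator on 500 trials at coherence 12.8
rng = np.random.default_rng(0)
gts = rng.integers(0, 2, size=500)
choices = np.array([integrate_and_choose(ring_stimulus(gt, 12.8, rng=rng)) for gt in gts])
accuracy = np.mean(choices == gts)
print(f"accuracy: {accuracy:.3f}")
```

Summing over time helps because the two signals grow at different rates. The small coherence-dependent difference between the channels grows linearly with the number of steps, while the accumulated noise grows only with its square root. This is why longer viewing helps at low coherence.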

In this notebook, we will:

1. Train an agent on the task using reinforcement learning with [Stable-Baselines3](https://stable-baselines3.readthedocs.io/).
2. Analyze the agent's psychometric curves and compare performance across different coherence levels and block contexts.


# 0. Install Dependencies

To begin, install the `neurogym` package. This will automatically install all required dependencies, including Stable-Baselines3.

For detailed instructions on how to install `neurogym` within a conda environment or in editable mode, refer to the [installation instructions](https://github.com/neurogym/neurogym?tab=readme-ov-file#installation).


In [None]:
# Uncomment to install
# ! pip install neurogym[rl]

# 1. Training an Agent on the Perceptual Decision Making Task


## 1.1 Environment Setup and Initial Agent Behavior

Let's now create and explore the environment using the `PerceptualDecisionMaking` class from neurogym. We'll use the following configuration:

- `dim_ring = 2`: Two possible choices (left/right) represented at 0° and 180°. Note that the ring architecture can support any number of choices, making it suitable for more complex decision-making scenarios.
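- `timing = {'fixation': TruncExp(600, 400, 700), 'stimulus': 2000, 'delay': 0, 'decision': 100}` (in milliseconds). The fixation duration is resampled on every trial from a truncated exponential distribution with mean parameter 600 ms, bounded between 400 and 700 ms.
- `rewards = {'abort': -0.1, 'correct': +1.0, 'fail': 0.0}`: the `abort` reward is a penalty applied when the agent breaks fixation, i.e., responds during the fixation period. It takes its name from the option of aborting (ending) the trial when fixation is broken. That behavior is controlled by the separate `abort` flag, which is set to `False` in the code below.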
- `sigma = 1.0`: Standard deviation of the noise added to the inputs.
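The reward rules above can be sketched as a per-step function. This is a hypothetical helper (`step_reward`), written from the description in this notebook rather than copied from neurogym. Actions are assumed to be 0 = fixate, 1 = left, 2 = right.

```python
REWARDS = {"abort": -0.1, "correct": +1.0, "fail": 0.0}


def step_reward(period, action, gt_action, rewards=REWARDS, abort=False):
    """Return (reward, trial_ends) for a single time step.

    `period` is "fixation" or "decision". With the ReactionTime wrapper,
    the stimulus period behaves like "decision" because the agent may
    respond at any time.
    """
    if period == "fixation":
        if action != 0:
            # Breaking fixation is penalised; the abort flag decides
            # whether it also ends the trial.
            return rewards["abort"], abort
        return 0.0, False
    if period == "decision" and action != 0:
        outcome = "correct" if action == gt_action else "fail"
        return rewards[outcome], True
    # Fixating outside the fixation period earns nothing and the trial continues
    return 0.0, False


print(step_reward("fixation", 1, 2))              # fixation break, trial continues
print(step_reward("fixation", 1, 2, abort=True))  # fixation break, trial aborted
print(step_reward("decision", 2, 2))              # correct choice ends the trial
```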

In this notebook, several wrappers are used to modify the environment's behavior:

- `ReactionTime` wrapper allows the agent to respond at any time after stimulus onset.
- `SideBias` wrapper introduces blocks where one side is more likely than the other. It uses two key parameters:
  - `probs = [[0.2, 0.8], [0.8, 0.2]]`: the probability of each choice (left, right) being the correct one in each block. In the first block the right side is correct 80% of the time; in the second, the left side is.
  - `block_dur = (20, 100)`: the length of each block, in trials, randomly sampled between 20 and 100. This determines how long each bias condition persists.
- `PassAction` and `PassReward` wrappers augment the observations with the previous step's action and reward, respectively, enabling the agent to use recent history in decision-making.
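Here is a sketch of the block structure that `SideBias` introduces. It assumes the following, and the wrapper's actual switching rule may differ:

- Blocks cycle through the rows of `probs`.
- Each block lasts a uniformly sampled number of trials between the two values of `block_dur`.
- Ground-truth index 0 means left and 1 means right.

`side_bias_schedule` is a hypothetical helper.

```python
import numpy as np


def side_bias_schedule(n_trials, probs, block_dur=(20, 100), rng=None):
    """Return (block_ids, gts): the active block and the correct side for each trial."""
    rng = np.random.default_rng() if rng is None else rng
    probs = np.asarray(probs)
    block_ids = np.empty(n_trials, dtype=int)
    gts = np.empty(n_trials, dtype=int)
    block = int(rng.integers(len(probs)))  # random starting block
    remaining = int(rng.integers(block_dur[0], block_dur[1] + 1))
    for t in range(n_trials):
        if remaining == 0:
            # Assumption: move to the next block and draw a new duration (in trials)
            block = (block + 1) % len(probs)
            remaining = int(rng.integers(block_dur[0], block_dur[1] + 1))
        block_ids[t] = block
        gts[t] = rng.choice(len(probs[block]), p=probs[block])  # 0 = left, 1 = right
        remaining -= 1
    return block_ids, gts


block_ids, gts = side_bias_schedule(5000, [[0.2, 0.8], [0.8, 0.2]], rng=np.random.default_rng(1))
print("P(right | block 1):", gts[block_ids == 0].mean())
print("P(right | block 2):", gts[block_ids == 1].mean())
```

Because the correct side is biased within a block, an agent can in principle infer the current block from recent rewards and actions and shift its choices toward the favored side. `PassReward` and `PassAction` expose exactly this history.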


### 1.1.1 Import Libraries


In [None]:
import os
import warnings

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env import DummyVecEnv

import neurogym as ngym
from neurogym.utils.ngym_random import TruncExp
from neurogym.utils.psychometric import plot_psychometric
from neurogym.wrappers.monitor import Monitor
from neurogym.wrappers.pass_action import PassAction
from neurogym.wrappers.pass_reward import PassReward
from neurogym.wrappers.reaction_time import ReactionTime
from neurogym.wrappers.side_bias import SideBias

# Suppress warnings
warnings.filterwarnings("ignore")

### 1.1.2 Environment Setup


In [None]:
# Environment parameters
# These settings are low to speed up testing; we recommend setting EVAL_TRIALS to at least 1000
EVAL_TRIALS = 100
dt = 100  # Time step in milliseconds
dim_ring = 2  # Number of choices in the ring representation
abort = False  # Whether breaking fixation ends (aborts) the trial
rewards = {
    "abort": -0.1,
    "correct": +1.0,
    "fail": 0.0
}
timing = {
    "fixation": TruncExp(600, 400, 700),  # Truncated exponential: mean parameter 600 ms, bounded to 400-700 ms
    "stimulus": 2000,
    "delay": 0,
    "decision": 100,
}
sigma = 1.0  # Standard deviation of the Gaussian noise added to the stimulus

kwargs = {
    "dt": dt,
    "dim_ring": dim_ring,
    "rewards": rewards,
    "timing": timing,
    "sigma": sigma,
    "abort": abort,
}
block_dur = (20, 100)  # Minimum and maximum block duration, in trials
probs = [[0.2, 0.8], [0.8, 0.2]]  # Probability that the correct side is (left, right) in each of the two blocks

# Create and wrap the environment
task = "PerceptualDecisionMaking-v0"
env = gym.make(task, **kwargs)
env = ReactionTime(env, end_on_stimulus=True)
env = PassReward(env)
env = PassAction(env)
env = SideBias(env, probs=probs, block_dur=block_dur)

# Print environment specifications
print("Trial timing (in milliseconds):")
print(env.timing)

print("\nObservation space structure:")
print(env.observation_space)

print("\nAction space structure:")
print(env.action_space)
print("Action mapping:")
print(env.action_space.name)
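The fixation period uses `TruncExp(600, 400, 700)`. As an illustration only, the sketch below makes two assumptions, neither taken from neurogym's source:

- Rejection sampling: draw from an exponential with mean 600 ms until the value lands in [400, 700) ms.
- Each duration is rounded to whole `dt = 100` ms steps.

`trunc_exp_sample` is a hypothetical name. Under these assumptions the fixation period spans roughly 4 to 7 steps. The realized mean is below 600 ms because the exponential density decreases across the window.

```python
import numpy as np


def trunc_exp_sample(rng, vmean, vmin, vmax):
    """Rejection-sample an exponential with mean `vmean` until it falls in [vmin, vmax)."""
    while True:
        x = rng.exponential(vmean)
        if vmin <= x < vmax:
            return x


rng = np.random.default_rng(0)
dt = 100  # ms per time step, as in the notebook
fixation_ms = np.array([trunc_exp_sample(rng, 600, 400, 700) for _ in range(2000)])
# Assumed conversion of durations to whole time steps
fixation_steps = np.round(fixation_ms / dt).astype(int)
print(f"mean fixation: {fixation_ms.mean():.0f} ms, step counts: {np.unique(fixation_steps)}")
```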

### 1.1.3 Random Agent Behavior

Let's now plot the behavior of a random agent on the task. At every time step the agent samples an action uniformly from the action space (fixate, left, or right). We visualize its behavior over 5 trials, along with the reward received at each time step and the performance on each trial. Note that performance is only defined at the end of a trial: it is 1 if the agent made the correct choice, and 0 otherwise.

To keep track of the agent's behavior, we will use the `Monitor` wrapper, which monitors training by:

- Tracking and saving behavioral data (rewards, actions, observations) at a frequency set by `trigger` (steps or trials) and `interval`.
- Generating visualization figures during training if `plot_create=True`.
- Providing progress information if `verbose=True`.

Here, we’ll use the wrapper solely to compute the agent’s performance, but later it will help us assess learning and save intermediate results.
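For intuition about the random-policy numbers, the sketch below simulates a uniformly random agent on a simplified trial:

- A fixed 5-step fixation period. Each non-fixate action costs the `abort` penalty but, as with `abort=False`, does not end the trial.
- A 20-step response window, in which the first left/right action ends the trial.

The trial structure and `simulate_random_policy` are simplifying assumptions, not neurogym code, so the exact values will not match `evaluate_policy`.

```python
import numpy as np

REWARDS = {"abort": -0.1, "correct": +1.0, "fail": 0.0}


def simulate_random_policy(n_trials, fix_steps=5, resp_steps=20, rewards=REWARDS, rng=None):
    """Return (mean performance, mean reward per trial) of a uniformly random agent.

    Hypothetical simplified trial; actions: 0 = fixate, 1 = left, 2 = right.
    """
    rng = np.random.default_rng() if rng is None else rng
    perf = np.zeros(n_trials)
    reward = np.zeros(n_trials)
    for i in range(n_trials):
        gt = rng.integers(1, 3)  # correct action: 1 (left) or 2 (right)
        # Fixation: each non-fixate action is penalised, but the trial continues
        fix_actions = rng.integers(0, 3, size=fix_steps)
        reward[i] += rewards["abort"] * np.count_nonzero(fix_actions)
        # Response window: the first left/right action ends the trial
        for a in rng.integers(0, 3, size=resp_steps):
            if a != 0:
                correct = a == gt
                perf[i] = float(correct)
                reward[i] += rewards["correct"] if correct else rewards["fail"]
                break
    return perf.mean(), reward.mean()


mean_perf, mean_reward = simulate_random_policy(4000, rng=np.random.default_rng(0))
print(f"performance: {mean_perf:.3f}, reward: {mean_reward:.3f}")
```

Under these assumptions, performance is close to 0.5 because almost every trial ends in a coin-flip choice. Fixation-break penalties pull the mean reward well below 0.5.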


In [None]:
# Visualize example trials
fig = ngym.utils.plot_env(
    env,
    name='Perceptual Decision Making',
    ob_traces=[
        'Fixation',
        'Stim 1',
        'Stim 2',
        'PassReward', # Reward for the previous action
        'PassAction' # Action taken in the previous step
    ],
    num_trials=5,
)

# Evaluate performance of the environment before training
eval_monitor = Monitor(env)
print("\nEvaluating random policy performance...")
metrics = eval_monitor.evaluate_policy(num_trials=EVAL_TRIALS)
print(f"\nRandom policy metrics ({EVAL_TRIALS:,} trials):")
print(f"Mean performance: {metrics['mean_performance']:.4f}")
print(f"Mean reward: {metrics['mean_reward']:.4f}")

As we can see, the agent's behavior is entirely random. Through training, we expect the agent to learn two things: to respect the fixation period, and to map the stronger stimulus channel to the corresponding choice in the ring representation. Let's move on to training the agent to see whether it can learn these key aspects of the task.


## 1.2 Training and Evaluating the Agent

We will now train the agent using Stable-Baselines3’s implementation of [PPO (Proximal Policy Optimization)](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html), a widely used reinforcement learning algorithm known for its stability and efficiency.

To support recurrent policies, we will use [RecurrentPPO](https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html#recurrent-ppo), which extends PPO with recurrent neural networks, specifically LSTMs.


### 1.2.1 Training the Agent


In [None]:
# Set the number of trials to train on
# These settings are low to speed up testing; we recommend setting TRAIN_TRIALS to at least 10000 and `interval` in Monitor to 1000
avg_timesteps = 7  # Approximate number of steps per trial, observed empirically for this setup
TRAIN_TRIALS = 100  # Choose the desired number of trials
total_timesteps = TRAIN_TRIALS * avg_timesteps
print(f"Training for {TRAIN_TRIALS:,} trials ≈ {total_timesteps:,} timesteps")

# Configure monitoring with trial-appropriate parameters
trials_per_figure = 10  # Show 10 trials in each figure
steps_per_figure = int(trials_per_figure * avg_timesteps)

train_monitor = Monitor(
    env,
    trigger="trial",              # Save based on completed trials
    interval=100,                 # Save data every 100 trials
    plot_create=True,             # Save visualization figures
    plot_steps=steps_per_figure,  # Number of steps to visualize on the figure
    verbose=True,                 # Print stats when data is saved
)

# DummyVecEnv is a Stable-Baselines3 wrapper that exposes the environment through the
# vectorized-environment interface SB3 algorithms require. It steps its environments
# sequentially in a single process (here, just one environment)
env_vec = DummyVecEnv([lambda: train_monitor])

# Create and train Recurrent PPO agent
# Set n_steps to be a multiple of your average trial length
trials_per_batch = 64
n_steps = int(avg_timesteps * trials_per_batch)  # Collect approximately 64 trials per update
batch_size = 32  # Small batch size for short episodes
policy_kwargs = {
    "lstm_hidden_size": 128,      # Small LSTM for short sequences
    "n_lstm_layers": 2,           # Two stacked LSTM layers
    "shared_lstm": True,          # Share LSTM to reduce parameters
    "enable_critic_lstm": False,  # Disable separate LSTM for critic when sharing
}
rl_model = RecurrentPPO(
    "MlpLstmPolicy",
    env_vec,
    learning_rate=3e-4,       # Learning rate for the optimizer
    n_steps=n_steps,          # Align with multiple complete episodes
    batch_size=batch_size,    # Minibatch size for each gradient update
    ent_coef=0.0,             # Entropy coefficient for exploration
    policy_kwargs=policy_kwargs,
    verbose=1
)

# For on-policy algorithms, log_interval counts training iterations (rollouts), not timesteps
rl_model.learn(total_timesteps=total_timesteps, log_interval=1)
env_vec.close()
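The cell above sizes training in time steps rather than trials. The arithmetic below reproduces that budget. It assumes that Stable-Baselines3's on-policy `learn()` keeps collecting full rollouts of `n_steps * n_envs` transitions until `total_timesteps` is reached, so the last rollout can overshoot the budget.

```python
import math

# Values copied from the training cell above
avg_timesteps = 7
TRAIN_TRIALS = 100
trials_per_batch = 64
batch_size = 32
n_envs = 1  # DummyVecEnv wraps a single environment here

total_timesteps = TRAIN_TRIALS * avg_timesteps  # training budget in steps
n_steps = avg_timesteps * trials_per_batch      # steps per rollout, per environment
rollout_size = n_steps * n_envs                 # transitions collected per PPO iteration

# Assumption: learn() collects whole rollouts until num_timesteps >= total_timesteps
n_iterations = math.ceil(total_timesteps / rollout_size)
timesteps_collected = n_iterations * rollout_size

print(f"{n_iterations} iterations, {timesteps_collected} timesteps collected")
print(f"rollout divisible by batch_size: {rollout_size % batch_size == 0}")
```

With these low test settings, the agent gets only two PPO iterations. SB3 on-policy algorithms count `log_interval` in iterations, so any value above the number of iterations suppresses the training logs entirely.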

### 1.2.2 Plot the Behavior of the Trained Agent


In [None]:
# Plot example trials with trained agent
fig = ngym.utils.plot_env(
    env_vec,
    name='Perceptual Decision Making (trained)',
    ob_traces=[
        'Fixation',
        'Stim 1',
        'Stim 2',
        'PassReward',
        'PassAction'
    ],
    num_trials=5,
    model=rl_model,
)

After training, we visualize the agent's behavior on a few example trials. In contrast to the random agent, we should now see:

- Consistent fixation during the fixation period
- Choices that follow the side supported by the stimulus, more reliably at higher coherence
- Performance above chance level (0.5), possibly with choices shifted toward the side favored in the current block

The plot shows the trained agent's behavior on 5 example trials, so we can check whether its choices match the stimulus side encoded in the ring representation.


### 1.2.3 Evaluate the Agent's Performance


In [None]:
# Evaluate performance of the last trained model
print("\nEvaluating trained model performance...")
rl_trained_metrics = eval_monitor.evaluate_policy(num_trials=EVAL_TRIALS, model=rl_model)
print(f"\nTrained model metrics ({EVAL_TRIALS:,} trials):")
print(f"Mean performance: {rl_trained_metrics['mean_performance']:.4f}")
print(f"Mean reward: {rl_trained_metrics['mean_reward']:.4f}")

fig = train_monitor.plot_training_history(figsize=(6, 4), plot_performance=False)
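Whether performance is genuinely above chance (0.5) can be checked with an exact one-sided binomial test. The sketch below is a standard-library-only helper; `binomial_p_value` is a hypothetical name. To apply it, take the number of correct trials out of `EVAL_TRIALS`, for example `round(rl_trained_metrics['mean_performance'] * EVAL_TRIALS)`. This assumes `mean_performance` is the fraction of correct trials. The numbers below are illustrative only, not results from this notebook.

```python
from math import comb


def binomial_p_value(n_correct, n_trials, p_chance=0.5):
    """One-sided exact binomial test: P(X >= n_correct) if the agent were at chance."""
    return sum(
        comb(n_trials, k) * p_chance**k * (1 - p_chance) ** (n_trials - k)
        for k in range(n_correct, n_trials + 1)
    )


# Illustrative numbers: 60 correct choices out of 100 evaluation trials
print(f"p = {binomial_p_value(60, 100):.4f}")
```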

### 1.2.4 Plot the Agent's Psychometric Curves


In [None]:
# Skip this cell when running under GitHub Actions (CI)
if not os.getenv("GITHUB_ACTIONS"):
    # Evaluate policy and extract data
    eval_monitor.evaluate_policy(num_trials=EVAL_TRIALS, model=rl_model)
    data = eval_monitor.data_eval

    # Extract trial-level fields
    trials = data['trial']
    coh = np.array([t['coh'] for t in trials])
    block = np.array([np.array_equal(t['probs'], np.array(probs[1])) for t in trials]).astype(int) # block 1 is 0, block 2 is 1

    # Filter out trials where action is 0 (no action taken)
    actions_only_mask = data['action'] != 0
    coh = coh[actions_only_mask]
    block = block[actions_only_mask]
    data_action = data['action'][actions_only_mask]
    data_gt = data['gt'][actions_only_mask]

    # Convert actions and ground truth to binary (0 = left, 1 = right)
    ch = (data_action == 2).astype(int)
    gt = (data_gt == 2).astype(int)

    # Plotting setup
    fig, ax = plt.subplots(figsize=(3, 3))

    # Plot psychometric curves for each block
    for blk in range(len(probs)):
        # Filter trials matching current block
        mask = block == blk
        ev = coh[mask]
        ch_m = ch[mask]
        ref = gt[mask]

        # Signed evidence: negative if correct answer is left
        sig_ev = np.where(ref == 0, -ev, ev)

        plot_psychometric(sig_ev, ch_m, ax, legend=f'Block {blk+1}')
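A psychometric curve plots the fraction of rightward choices against signed evidence. To summarize it with numbers rather than a plot, the NumPy-only sketch below does two things:

- Computes the empirical curve.
- Fits a two-parameter logistic, P(right) = 1 / (1 + exp(-(b0 + b1 * x))), by Newton's method.

The helpers `choice_curve` and `fit_logistic` are hypothetical and not necessarily what `plot_psychometric` fits. The coherence levels in the synthetic demo are an assumption. When applied to `sig_ev` and `ch_m` for each block, the two parameters carry different information. A bias `b0` that differs between blocks would indicate that the agent uses the block prior. The slope `b1` reflects sensitivity to coherence.

```python
import numpy as np


def choice_curve(signed_ev, choices):
    """Fraction of rightward choices (choice == 1) at each signed-evidence level."""
    levels = np.unique(signed_ev)
    p_right = np.array([choices[signed_ev == lv].mean() for lv in levels])
    return levels, p_right


def fit_logistic(x, y, n_iter=25):
    """Fit P(right) = 1 / (1 + exp(-(b0 + b1 * x))) by Newton's method (IRLS)."""
    X = np.column_stack([np.ones(len(x)), x])
    beta = np.zeros(2)
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X @ beta))
        w = p * (1 - p)
        grad = X.T @ (y - p)              # gradient of the log-likelihood
        hess = X.T @ (X * w[:, None])     # negative Hessian
        beta += np.linalg.solve(hess, grad)
    return beta  # (b0 = bias, b1 = slope)


# Synthetic demo: an agent with a rightward bias of 0.5 and slope 0.05
rng = np.random.default_rng(0)
cohs = np.array([0, 6.4, 12.8, 25.6, 51.2])  # assumed coherence levels
n = 20000
signed_ev = rng.choice(cohs, size=n) * rng.choice([-1, 1], size=n)
p_true = 1 / (1 + np.exp(-(0.5 + 0.05 * signed_ev)))
choices = (rng.random(n) < p_true).astype(float)

bias, slope = fit_logistic(signed_ev, choices)
levels, p_right = choice_curve(signed_ev, choices)
print(f"bias = {bias:.2f}, slope = {slope:.3f}")
```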
