# Multi-Task Contextual Decision-Making (Yang19 Cognitive Tasks)

This environment implements a suite of 20 context-dependent cognitive tasks inspired by the work of [Yang et al. (2019)](https://pubmed.ncbi.nlm.nih.gov/30643294/#:~:text=The%20brain%20has%20the%20ability,task%20representations%2C%20a%20critical%20feature). In these tasks, the agent must flexibly switch between different context rules on each trial, similar in spirit to the classic [context-dependent decision-making](https://neurogym.github.io/neurogym/latest/api/envs/#neurogym.envs.native.contextdecisionmaking.ContextDecisionMaking) experiment of [Mante et al. (2013)](https://www.nature.com/articles/nature12742). The key features of this multi-task environment are:

1. **Two-Choice Outputs** (`dim_ring=2`): All tasks share a common action space represented as positions on a ring. With `dim_ring = 2`, there are two choice outputs located at 0° and 180° on this ring (interpreted as "left" and "right" choices, respectively). The agent's goal on each trial is to select the correct one of these two outputs.
2. **Stimulus Inputs with Modality-Specific Evidence**: Each task provides sensory evidence (stimuli) that may come from one or two modalities. The stimuli are often encoded via a cosine-tuned input bump centered on one of the choice positions, with its amplitude representing the evidence strength. For example, a stimulus favoring the left choice might produce an activity peak at 0° on the input ring.
3. **Randomized Ground-Truth Choices**: The correct choice (ground truth) is randomized on each trial (for two choices: left or right). Left is correct on some trials and right on others, so the agent cannot perform well by always favoring one action.
4. **Variable Difficulty Evidence**: The strength of the stimulus evidence (e.g., coherence level in a motion stimulus) is also randomly sampled each trial. Sometimes the evidence strongly favors one choice, and other times it is weak or ambiguous.
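The ring encoding and per-trial randomization described above can be sketched in a few lines of NumPy. This is a minimal illustration, not NeuroGym's implementation: the function names, the evidence strengths in `cohs`, and the plain `strength * cos(pref - theta)` tuning curve are assumptions chosen for clarity.

```python
import numpy as np


def ring_preferences(dim_ring: int) -> np.ndarray:
    """Preferred angles of the ring units, evenly spaced on [0, 2*pi)."""
    return np.linspace(0.0, 2 * np.pi, dim_ring, endpoint=False)


def sample_toy_trial(rng, dim_ring=2, cohs=(0.08, 0.16, 0.32), sigma=0.0):
    """Hypothetical trial sampler: random correct choice and random evidence strength.

    Returns the stimulus on the input ring, the ring index of the correct choice
    (0 = left at 0 degrees, 1 = right at 180 degrees), and the evidence strength.
    """
    pref = ring_preferences(dim_ring)
    gt_index = int(rng.integers(dim_ring))  # randomized ground truth
    coh = float(rng.choice(cohs))  # randomized evidence strength (assumed values)
    theta = pref[gt_index]  # stimulus direction points at the correct choice
    stim = coh * np.cos(pref - theta)  # cosine-tuned bump peaking at theta
    stim = stim + sigma * rng.standard_normal(dim_ring)  # optional Gaussian input noise
    return stim, gt_index, coh


rng = np.random.default_rng(0)
for _ in range(3):
    stim, gt_index, coh = sample_toy_trial(rng)
    print(f"correct choice={gt_index}, strength={coh:.2f}, stimulus={np.round(stim, 3)}")
```

Without noise, the most active input unit always sits at the correct choice. Setting `sigma > 0` makes weak-evidence trials ambiguous, because the noise can flip which unit is most active.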

This collection of 20 tasks exercises various cognitive skills, including working memory (e.g., remembering stimuli across a delay in matching tasks), perceptual decision-making (integrating noisy sensory evidence over time), rule-based categorization, and inhibitory control (suppressing or initiating actions under certain rules). All tasks are implemented in a consistent format so that a single neural network agent can learn them together via supervised training.

In this notebook, we will:

1. Train an agent on the 20-task suite using supervised learning. We will generate trial data from the environment and train a recurrent neural network to predict the correct choices.
2. Evaluate the agent's performance and behavior across tasks. After training, we will examine how well the agent learned each of the 20 tasks and whether it can flexibly switch contexts from trial to trial.


# 0. Install Dependencies

To begin, install the `neurogym` package. This will automatically install all required dependencies, including Stable-Baselines3.

For detailed instructions on how to install `neurogym` within a conda environment or in editable mode, refer to the [installation instructions](https://github.com/neurogym/neurogym?tab=readme-ov-file#installation).


In [None]:
# Uncomment to install
# ! pip install neurogym[rl]

# 1. Exploring the Yang19 Cognitive Tasks


## 1.1 Environment Setup and Initial Agent Behavior

Let's begin by creating and exploring the environment using the `yang19` collection of tasks from NeuroGym. We'll mostly use the default configuration, setting only the time step `dt` and `dim_ring`; the latter is set to 2 to represent two alternative choices (left/right) arranged on a ring. For the two example tasks below, we also pass explicit `rewards`, `timing`, and noise (`sigma`) settings.

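To get a sense of the environment dynamics, we'll first visualize two representative tasks: `dm1`, a perceptual decision based on the stimulus in modality 1, and `multidm`, which requires integrating evidence across both modalities. This shows their trial structure and how the agent is expected to respond.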


### 1.1.1 Import Libraries


In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch

import neurogym as ngym
from neurogym.wrappers import ScheduleEnvs, Monitor
from neurogym.utils import plot_env, RandomSchedule
from neurogym.envs import get_collection

from IPython.display import clear_output
warnings.filterwarnings("ignore")
clear_output()

### 1.1.2 Environment Setup


In [None]:
# Environment parameters
# Number of evaluation trials. This value is kept low to speed up testing;
# we recommend setting it to at least 1000
EVAL_TRIALS = 100
dt = 100  # Time step in milliseconds
dim_ring = 2  # Number of choices in the ring representation

rewards = {
    "abort": -0.1,
    "correct": +1.0,
    "fail": 0.0
}
sigma = 1.0  # Standard deviation of the Gaussian noise in the ring representation

In [None]:
### dm1
timing = {
    "fixation": ("uniform", (200, 500)),
    "stimulus": ("choice", [200, 400, 600]),
    "decision": 200,
}
kwargs = {
    "dt": dt,
    "dim_ring": dim_ring,
    "rewards": rewards,
    "timing": timing,
    "sigma": sigma,
}
task = "dm1"
task = f"yang19.{task}-v0"
env_dm = ngym.make(task, **kwargs)

# Print environment specifications
print("Trial timing (in milliseconds):")
print(env_dm.timing)

print("\nObservation space structure:")
print(env_dm.observation_space)

print("\nAction space structure:")
print(env_dm.action_space)
print("Action mapping:")
print(env_dm.action_space.name)

obs, info = env_dm.reset()

stim1 = [f'Stim {i}, Mod 1' for i in range(1, kwargs['dim_ring'] + 1)]
stim2 = [f'Stim {i}, Mod 2' for i in range(1, kwargs['dim_ring'] + 1)]

# Visualize example trials
fig = plot_env(
    env_dm,
    name='DM1',
    ob_traces=['Fixation'] + stim1 + stim2,
    num_trials=5,
    plot_performance=True,
    fig_kwargs={'figsize': (9, 5)},
)
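The `timing` dictionaries mix three kinds of specification: a fixed duration, `("uniform", (low, high))`, and `("choice", [...])`, all in milliseconds. The sketch below shows one plausible way to turn such specs into a number of `dt`-sized steps. The helper names are hypothetical, and rounding to the nearest whole step is our assumption; NeuroGym's own sampler may differ in detail.

```python
import numpy as np


def sample_duration(spec, rng):
    """Hypothetical sampler for a timing spec (durations in ms).

    Supports a fixed number, ("uniform", (low, high)) and ("choice", [options]).
    """
    if isinstance(spec, (int, float)):
        return float(spec)
    kind, params = spec
    if kind == "uniform":
        low, high = params
        return float(rng.uniform(low, high))
    if kind == "choice":
        return float(rng.choice(params))
    raise ValueError(f"Unsupported timing spec: {spec!r}")


def duration_to_steps(duration_ms, dt):
    # Assumption: durations are rounded to a whole number of dt-sized steps
    return int(round(duration_ms / dt))


example_timing = {
    "fixation": ("uniform", (200, 500)),
    "stimulus": ("choice", [200, 400, 600]),
    "decision": 200,
}
timing_rng = np.random.default_rng(0)
example_steps = {
    name: duration_to_steps(sample_duration(spec, timing_rng), dt=100)
    for name, spec in example_timing.items()
}
print(example_steps)
```

With `dt = 100`, the `dm1` settings give 2 to 5 fixation steps, 2, 4, or 6 stimulus steps, and 2 decision steps, so trials are short and vary in length.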

In [None]:
### multidm
timing = {
    "fixation": ("uniform", (200, 500)),
    "stimulus": 500,
    "decision": 200,
}
kwargs = {
    "dt": dt,
    "dim_ring": dim_ring,
    "rewards": rewards,
    "timing": timing,
    "sigma": sigma,
}
task = "multidm"
task = f"yang19.{task}-v0"
env_multidm = ngym.make(task, **kwargs)

# Print environment specifications
print("Trial timing (in milliseconds):")
print(env_multidm.timing)

print("\nObservation space structure:")
print(env_multidm.observation_space)

print("\nAction space structure:")
print(env_multidm.action_space)
print("Action mapping:")
print(env_multidm.action_space.name)

obs, info = env_multidm.reset()

stim1 = [f'Stim {i}, Mod 1' for i in range(1, kwargs['dim_ring'] + 1)]
stim2 = [f'Stim {i}, Mod 2' for i in range(1, kwargs['dim_ring'] + 1)]

# Visualize example trials
fig = plot_env(
    env_multidm,
    name='MultiDM',
    ob_traces=['Fixation'] + stim1 + stim2,
    num_trials=5,
    plot_performance=True,
    fig_kwargs={'figsize': (9, 5)},
)

We wrap the `yang19` environments using NeuroGym’s `ScheduleEnvs` wrapper, which allows us to interleave multiple tasks into a single training loop. Each cognitive task is instantiated as a separate environment, and a scheduling policy (`RandomSchedule`) determines which task is sampled on each trial. When `env_input=True`, a one-hot vector is appended to the observation to indicate the currently active task, allowing the agent to learn task-specific behavior within a unified architecture.
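The scheduling and task-cue mechanism can be illustrated with a toy version. This is a sketch, not NeuroGym's code: `ToyRandomSchedule` and `append_task_one_hot` are hypothetical stand-ins for `RandomSchedule` and the `env_input=True` behavior. Appending the one-hot code after the sensory inputs is the same layout the evaluation loop in Section 2.2.2 assumes when it rebuilds the network input from `ob` and a one-hot vector.

```python
import numpy as np


class ToyRandomSchedule:
    """Hypothetical stand-in for a random task schedule: returns a task index per trial."""

    def __init__(self, n_tasks: int, seed=None):
        self.n_tasks = n_tasks
        self.rng = np.random.default_rng(seed)

    def __call__(self) -> int:
        # Draw the next task uniformly at random
        return int(self.rng.integers(self.n_tasks))


def append_task_one_hot(ob: np.ndarray, task_index: int, n_tasks: int) -> np.ndarray:
    """Append the same one-hot task code to every time step of a (T, ob_dim) observation."""
    one_hot = np.zeros((ob.shape[0], n_tasks), dtype=ob.dtype)
    one_hot[:, task_index] = 1.0
    return np.concatenate([ob, one_hot], axis=1)


toy_schedule = ToyRandomSchedule(n_tasks=20, seed=0)
task_index = toy_schedule()
ob_demo = np.zeros((8, 5), dtype=np.float32)  # 8 steps: 1 fixation + 2 x 2 stimulus inputs
ob_with_task = append_task_one_hot(ob_demo, task_index, n_tasks=20)
print(task_index, ob_with_task.shape)  # shape is (8, 25)
```

Because the task cue is constant within a trial, the network can read it at any time step and use it to select the rule for the current trial.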


In [None]:
kwargs = {'dt': dt, 'dim_ring': dim_ring}

# Build one environment per yang19 task and interleave them with a random schedule
tasks = get_collection('yang19')
envs = [ngym.make(task, **kwargs) for task in tasks]
schedule = RandomSchedule(len(envs))
env = ScheduleEnvs(envs, schedule=schedule, env_input=True)

# Print environment specifications
print("Trial timing (in milliseconds):")
print(env.timing)

print("\nObservation space structure:")
print(env.observation_space)

print("\nAction space structure:")
print(env.action_space)
print("Action mapping:")
print(env.action_space.name)

### 1.1.3 Random Agent Behavior

Let's now plot the behavior of a random agent on the multi-task environment. At each time step, the agent picks an action uniformly at random from the action space (fixation, left, or right; blue line), and we visualize its behavior over 5 trials. We also plot the reward received at each time step and the performance on each trial. Note that performance is only defined at the end of a trial: it is 1 if the agent made the correct choice and 0 otherwise.


In [None]:
obs, info = env.reset()

# Visualize example trials
fig = plot_env(
    env,
    name='Yang et al.',
    ob_traces=None,
    num_trials=5,
    plot_performance=True,
    fig_kwargs={'figsize': (9, 5)},
)

# Evaluate performance of random policy
eval_monitor = Monitor(env)
print("\nEvaluating random policy performance...")
metrics = eval_monitor.evaluate_policy(num_trials=EVAL_TRIALS)
print(f"\nRandom policy metrics ({EVAL_TRIALS:,} trials):")
print(f"Mean performance: {metrics['mean_performance']:.4f}")
print(f"Mean reward: {metrics['mean_reward']:.4f}")

# 2. Learning the Tasks as a Supervised Problem

We will now train the agent using supervised learning. NeuroGym provides functionality to generate a dataset directly from the environment, allowing us to sample batches of inputs and corresponding labels for training.
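Conceptually, turning trials into a supervised dataset means stitching variable-length trials into fixed-length sequences and stacking them into a batch. The sketch below is a simplified, hypothetical version of that idea (`make_toy_batch` and `toy_trial` are not NeuroGym functions, and the real `ngym.Dataset` may differ in details). It returns time-major arrays of shape `(seq_len, batch, ...)`, the layout the model below expects with `batch_first=False`.

```python
import numpy as np


def make_toy_batch(sample_trial, batch_size: int, seq_len: int):
    """Hypothetical sketch of turning trials into a fixed-length, time-major batch.

    Each batch column is filled by concatenating whole trials until at least
    seq_len steps are available, then truncated to exactly seq_len steps.
    """
    obs_cols, gt_cols = [], []
    for _ in range(batch_size):
        obs_chunks, gt_chunks, n_steps = [], [], 0
        while n_steps < seq_len:
            ob, gt = sample_trial()
            obs_chunks.append(ob)
            gt_chunks.append(gt)
            n_steps += len(gt)
        obs_cols.append(np.concatenate(obs_chunks)[:seq_len])
        gt_cols.append(np.concatenate(gt_chunks)[:seq_len])
    # Stack along axis 1 so the layout is (time, batch, ...)
    return np.stack(obs_cols, axis=1), np.stack(gt_cols, axis=1)


toy_rng = np.random.default_rng(0)


def toy_trial(ob_dim: int = 25):
    """Toy trial of random length: label 0 (fixate), then a choice label (1 or 2) on the last step."""
    n = int(toy_rng.integers(6, 14))
    ob = toy_rng.standard_normal((n, ob_dim)).astype(np.float32)
    gt = np.zeros(n, dtype=np.int64)
    gt[-1] = int(toy_rng.integers(1, 3))
    return ob, gt


toy_inputs, toy_targets = make_toy_batch(toy_trial, batch_size=4, seq_len=30)
print(toy_inputs.shape, toy_targets.shape)  # (30, 4, 25) (30, 4)
```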


## 2.1 Converting the Environment to a Supervised Dataset


### 2.1.1 Dataset Setup


In [None]:
seq_len = 100
print(f"Using sequence length: {seq_len}")

# Make supervised dataset
batch_size = 32
print(f"Creating dataset with batch_size={batch_size}")
dataset = ngym.Dataset(env, batch_size=batch_size, seq_len=seq_len)

env = dataset.env

# Extract dimensions.
# Observations: 1 fixation input + 2 inputs per modality (2 modalities, 4 inputs)
# + 20 one-hot inputs identifying the active task = 25 in total
ob_size = env.observation_space.shape[0]
# Actions: 1 fixation action + 2 choices (left/right) = 3 in total
act_size = env.action_space.n
print(f"Observation size: {ob_size}")
print(f"Action size: {act_size}")

# Get a batch of data
inputs, target = dataset()
print(f"Input batch shape: {inputs.shape}")
print(f"Target batch shape: {target.shape}")

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
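The printed sizes follow from simple arithmetic. The hypothetical helper below encodes that arithmetic, assuming the layout described in the code comments (one fixation input, `dim_ring` inputs per modality, and a one-hot task cue) and an action space of one fixation action plus one action per ring position.

```python
def expected_sizes(n_tasks: int, dim_ring: int, n_modalities: int = 2):
    """Observation and action sizes for the scheduled Yang19 suite (hypothetical helper).

    Observation: 1 fixation input + dim_ring inputs per modality + one-hot task code.
    Action: 1 fixation action + one action per ring position.
    """
    ob_size = 1 + n_modalities * dim_ring + n_tasks
    act_size = 1 + dim_ring
    return ob_size, act_size


print(expected_sizes(n_tasks=20, dim_ring=2))  # (25, 3)
```

The same formula shows how quickly the input grows with a finer ring: with `dim_ring=16`, the observation would have 53 inputs and the action space 17 actions.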

### 2.1.2 Model Setup


In [None]:
# Define the LSTM model
class Net(nn.Module):
    """LSTM followed by a linear readout applied at every time step.

    Inputs are time-major, (seq_len, batch, input_size), because batch_first=False.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=False)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        lstm_out, hidden = self.lstm(x, hidden)
        output = self.fc(lstm_out)  # Logits of shape (seq_len, batch, output_size)
        return output, hidden

    def init_hidden(self, batch_size, device):
        # Zero hidden and cell states of shape (num_layers=1, batch, hidden_size)
        h0 = torch.zeros(1, batch_size, self.hidden_size, device=device)
        c0 = torch.zeros(1, batch_size, self.hidden_size, device=device)
        return (h0, c0)

# Create the model
hidden_size = 128
sl_model = Net(
    input_size=ob_size,
    hidden_size=hidden_size,
    output_size=act_size,
).to(device)
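Because `Net` uses `batch_first=False`, it expects time-major tensors of shape `(seq_len, batch, ob_size)`, matching the dataset batches above. A quick shape check with random data can catch layout mistakes before training. The sketch below uses a standalone LSTM and linear readout configured like `Net`, with the sizes from this notebook (25 inputs, 128 hidden units, 3 actions); `check_lstm_shapes` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def check_lstm_shapes(seq_len=100, batch_size=32, ob_size=25, hidden_size=128, act_size=3):
    """Run random data through an LSTM + linear readout and return the output shapes."""
    lstm = nn.LSTM(ob_size, hidden_size, batch_first=False)  # expects (time, batch, features)
    readout = nn.Linear(hidden_size, act_size)  # applied independently at every time step
    x = torch.randn(seq_len, batch_size, ob_size)
    h0 = torch.zeros(1, batch_size, hidden_size)  # (num_layers, batch, hidden)
    c0 = torch.zeros(1, batch_size, hidden_size)
    with torch.no_grad():
        out, (h_n, c_n) = lstm(x, (h0, c0))
        logits = readout(out)
    return logits.shape, h_n.shape


logits_shape, h_shape = check_lstm_shapes()
print(logits_shape, h_shape)  # torch.Size([100, 32, 3]) torch.Size([1, 32, 128])
```

The network produces one set of action logits per time step, which is why the training loop below flattens the time and batch dimensions before computing the loss.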

## 2.2 Training and Evaluating a Neural Network Model


### 2.2.1 Training the Model


In [None]:
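# Number of training iterations (one batch each). This value is kept low to speed up
# testing; we recommend setting it to at least 1000
EPOCHS = 40

# Most target time steps are class 0 (fixation), because the network must fixate for most
# of each trial. Down-weighting class 0 while keeping the act_size - 1 = 2 choice classes
# equally weighted stops fixation steps from dominating the loss, loosely mirroring the
# RL setting, where the reward comes mainly from the final choice
class_weights = torch.tensor([0.05] + [1.0]*(act_size - 1)).to(device)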
# Define the optimizer and loss function
optimizer = torch.optim.Adam(sl_model.parameters(), lr=0.01, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Training loop
loss_history = []

for i in range(EPOCHS):
    # Get a batch of data
    inputs, targets = dataset()

    # Convert to PyTorch tensors
    inputs = torch.from_numpy(inputs).float().to(device)
    targets = torch.from_numpy(targets).long().to(device)

    # Initialize hidden state
    hidden = sl_model.init_hidden(inputs.size(1), device)

    # Zero gradients
    optimizer.zero_grad()

    # Forward pass with hidden state tracking
    outputs, _ = sl_model(inputs, hidden)

    # Reshape for CrossEntropyLoss
    outputs_flat = outputs.reshape(-1, outputs.size(2))
    targets_flat = targets.reshape(-1)

    # Calculate loss
    # Weight the loss to account for class imbalance (very low weight to 0s, higher weights to 1s and 2s)
    loss = criterion(outputs_flat, targets_flat)

    # Backward pass and optimize
    loss.backward()
    optimizer.step()

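    # Record the loss and print it about ten times over the run
    loss_history.append(loss.item())
    if i % max(1, EPOCHS // 10) == 0:
        print(f'Epoch [{i}/{EPOCHS}], Loss: {loss.item():.4f}')

print('Finished Training')

# Plot the raw loss together with a moving average (window of up to 50 iterations)
window = max(1, min(50, len(loss_history) // 5))
smoothed = np.convolve(loss_history, np.ones(window) / window, mode='valid')
plt.figure(figsize=(8, 4))
plt.plot(loss_history, alpha=0.3, label='Raw loss')
plt.plot(np.arange(window - 1, len(loss_history)), smoothed,
         label=f'{window}-iteration moving average')
plt.title('Training Loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()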
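With `weight=` set and the default `reduction="mean"`, PyTorch's `CrossEntropyLoss` divides by the summed weights of the targets in the batch rather than by the number of time steps. The toy example below (random logits and a hypothetical sequence of 8 fixation steps followed by 2 choice steps, not data from this notebook) checks that formula by hand and shows how the 0.05 weight shrinks the fixation steps' share of the total weight from 80% to about 17%.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_steps, n_classes = 10, 3
demo_logits = torch.randn(n_steps, n_classes)
# Hypothetical targets: 8 fixation steps (class 0) followed by 2 choice steps
demo_targets = torch.tensor([0] * 8 + [1, 2])
demo_weights = torch.tensor([0.05, 1.0, 1.0])

weighted_loss = nn.CrossEntropyLoss(weight=demo_weights)(demo_logits, demo_targets)

# Same quantity by hand: weighted sum of per-step negative log-likelihoods,
# normalized by the total weight of the targets that actually occur
nll = -torch.log_softmax(demo_logits, dim=1)[torch.arange(n_steps), demo_targets]
w = demo_weights[demo_targets]
manual_loss = (w * nll).sum() / w.sum()

# Share of the total weight carried by fixation steps, with and without weighting
fixation_share_weighted = (w[demo_targets == 0].sum() / w.sum()).item()
fixation_share_unweighted = (demo_targets == 0).float().mean().item()
print(weighted_loss.item(), manual_loss.item())
print(fixation_share_weighted, fixation_share_unweighted)
```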

### 2.2.2 Evaluate the Model's Performance


In [None]:
# Evaluate performance of the trained model (gradients are disabled below)
sl_model.eval()
# Initialize per-task storage for accuracies
sl_mean_performance = {e_i.spec.id: [] for e_i in env.envs}

# Evaluate for specified number of trials
print(f"Evaluating model performance across {EVAL_TRIALS} trials...\n")

for env_idx, e_i in enumerate(env.envs):
    data = {"action": [], "gt": [], "trial": []}
    total_correct = 0
    for trial_idx in range(EVAL_TRIALS):
        trial = e_i.new_trial()
        data["trial"].append(trial)
        ob, gt = e_i.ob, e_i.gt
        data["gt"].append(gt[-1])
        trial_length = ob.shape[0]

        # Add one-hot encoding for the environment
        env_one_hot = np.zeros((trial_length, len(env.envs)))
        env_one_hot[:, env_idx] = 1.0  # Set the current environment index to 1

        # Concatenate original observation with one-hot encoding
        ob_with_env = np.concatenate([ob, env_one_hot], axis=1)

        ob_with_env = ob_with_env[:, np.newaxis, :]  # Add batch dimension

        inputs = torch.from_numpy(ob_with_env).float().to(device)
        hidden = sl_model.init_hidden(1, device)

        with torch.no_grad():
            outputs, _ = sl_model(inputs, hidden)
            pred_actions = torch.argmax(outputs, dim=2)
            data["action"].append(pred_actions[-1, 0].cpu().numpy())

        # Score the trial on its final time step only (the end of the decision period)
        decision_idx = trial_length - 1
        is_correct = gt[decision_idx] == pred_actions[decision_idx, 0].item()
        total_correct += int(is_correct)

    accuracy = total_correct / EVAL_TRIALS
    sl_mean_performance[e_i.spec.id].append(accuracy)
    for key in data:
        if key != "trial":
            data[key] = np.array(data[key])


# Print average performance
print("Average performance across all environments:")
for e_i in env.envs:
    mean_acc = np.mean(sl_mean_performance[e_i.spec.id])
    print(f"{e_i.spec.id}: {mean_acc:.4f}")
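The evaluation loop above scores each trial using only the prediction on its last time step. A stricter alternative, sketched below on toy arrays (not model output), requires the correct choice on every step of the decision period, i.e. every step where the target is not 0 (fixation). Both helpers are hypothetical and assume equal-length trials stored as integer arrays of shape `(n_trials, T)`.

```python
import numpy as np


def last_step_accuracy(pred: np.ndarray, gt: np.ndarray) -> float:
    """Fraction of trials whose prediction on the final time step matches the ground truth."""
    return float(np.mean(pred[:, -1] == gt[:, -1]))


def decision_period_accuracy(pred: np.ndarray, gt: np.ndarray) -> float:
    """Fraction of trials that are correct on every step where the target is a choice (gt != 0)."""
    decision = gt != 0
    correct_everywhere = np.all((pred == gt) | ~decision, axis=1)
    return float(np.mean(correct_everywhere))


# Toy data: 3 trials, 3 fixation steps followed by a 2-step decision period
gt = np.array([[0, 0, 0, 1, 1],
               [0, 0, 0, 2, 2],
               [0, 0, 0, 1, 1]])
pred = np.array([[0, 0, 0, 1, 1],   # correct throughout
                 [0, 0, 0, 1, 2],   # wrong at the start of the decision period
                 [0, 1, 0, 1, 2]])  # switches to the wrong answer on the last step
print(last_step_accuracy(pred, gt), decision_period_accuracy(pred, gt))
```

The stricter metric penalizes a network that changes its answer partway through the response, which the last-step criterion cannot detect; neither metric here checks whether fixation was maintained before the decision period.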