## Monitoring and visualising model activation traces

In this notebook, we will use reinforcement learning (RL) to train a network on a simple task and visualise the activation traces of the neurons in that network.

### Task

We will use the [AnnubesEnv](https://github.com/neurogym/neurogym/blob/dev/neurogym/envs/annubes.py) environment for this demo. We create the environment with a fixed time step and its default trial timing (printed below), then wrap it with a `Monitor` wrapper, which keeps track of activation traces. After training, we can evaluate the agent and visualise the recorded traces.

In [1]:
# Ignore warnings.
import warnings
warnings.filterwarnings("ignore")

import numpy as np

from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3 import A2C

import neurogym as ngym
from neurogym.envs.annubes import AnnubesEnv
from neurogym.wrappers.monitor import Monitor

# Time step duration in ms
dt = 10

env = AnnubesEnv(dt=dt)

# check the custom environment and output additional warnings (if any)
check_env(env)

# check the environment with a random agent
obs, info = env.reset()
n_steps = 10
for _ in range(n_steps):
    # random action
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    # reset when the episode ends for either reason
    if terminated or truncated:
        obs, info = env.reset()

print(env.timing)
print("----------------")
print(env.observation_space)
print(env.observation_space.name)
print("----------------")
print(env.action_space)
print(env.action_space.name)

Neurogym | 2025-02-07@14:43:15 | Logger configured.
{'fixation': 500, 'stimulus': 1000, 'iti': 0}
----------------
Box(0.0, 1.0, (4,), float32)
{'fixation': 0, 'start': 1, 'v': 2, 'a': 3}
----------------
Discrete(2)
{'fixation': 0, 'choice': [1]}


Check that the time step is correctly registered.

In [2]:
env.dt

10
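
As a quick sanity check on what this time step implies, the sketch below converts the period durations printed by `env.timing` into a number of steps. It assumes each period simply lasts `duration / dt` steps; the environment may count boundary steps differently, so recorded trial lengths can be off by a step.

```python
# Timing (in ms) as printed by env.timing above, and the time step used in this notebook.
timing_ms = {"fixation": 500, "stimulus": 1000, "iti": 0}
dt = 10

# Assumption: a period of `duration` ms spans duration // dt steps.
steps_per_period = {period: duration // dt for period, duration in timing_ms.items()}
steps_per_trial = sum(steps_per_period.values())

print(steps_per_period)
print(steps_per_trial)
```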

### Train the agent

The agent can be trained on different environments. Training calls both the actor and the critic networks, so if the activations of the value net are needed, the training phase is the time to record them.

In [3]:
# Create a fresh environment to be monitored
annubes_env = AnnubesEnv(dt=dt)

# Wrap that environment in a monitor
env = Monitor(annubes_env, name=f"NeuroGym Monitor | {annubes_env.__class__.__qualname__}")

# Create the agent that will be trained in this environment
model = A2C(ActorCriticPolicy, env)

# Set the monitoring phases
phases = {ngym.MonitorPhase.Evaluation, ngym.MonitorPhase.Training}

Register the action net and its layers with the monitor. We do not register the value net because the A2C model does not call it during inference.

In [4]:
# Register networks with the monitor
act_net = model.policy.action_net
net_monitor = env.add_network(act_net, phases, "Action net")
for layer in act_net.modules():
    net_monitor.add_layer(layer, [ngym.NetParam.Activation], "Linear layer")
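
How the monitor captures activations internally is not shown in this notebook. One common way to record a layer's output in PyTorch is a forward hook, sketched below on a stand-in linear layer. The layer size, the `recorded` list and the `record_activation` function are hypothetical names for illustration, not NeuroGym's implementation.

```python
import torch
from torch import nn

# Hypothetical stand-in for the policy's action net: 4 inputs, 2 actions.
layer = nn.Linear(4, 2)
recorded = []  # one tensor per forward pass


def record_activation(module, inputs, output):
    # Detach and copy so the stored trace does not keep the autograd graph alive.
    recorded.append(output.detach().clone())


# The hook fires after every forward pass through the layer.
handle = layer.register_forward_hook(record_activation)

with torch.no_grad():
    for _ in range(3):
        layer(torch.rand(1, 4))

# Removing the handle stops the recording.
handle.remove()
with torch.no_grad():
    layer(torch.rand(1, 4))  # this pass is not recorded

print(len(recorded), tuple(recorded[0].shape))
```

Keeping the handle returned by `register_forward_hook` makes it possible to switch recording off again, for example between phases.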

In [12]:
# Set the monitoring phase.
env.set_phase(ngym.MonitorPhase.Training)

# Set the number of timesteps for training
total_timesteps = 3000

# Train the agent.
model.learn(total_timesteps=total_timesteps)

Neurogym | 2025-02-07@14:45:38 | NeuroGym Monitor | AnnubesEnv | Trial:    30 | Time:     0 / (max     0, total 45330) | Avg. reward: 951.600
Neurogym | 2025-02-07@14:45:39 | NeuroGym Monitor | AnnubesEnv | Trial:    40 | Time:     0 / (max     0, total 60230) | Avg. reward: 991.000


<stable_baselines3.a2c.a2c.A2C at 0x7beaaef3bad0>

### Evaluate the agent

In [7]:
# Set the phase to evaluation
env.set_phase(ngym.MonitorPhase.Evaluation)

# Evaluate the policy
evaluate_policy(model, env, n_eval_episodes=10)

Neurogym | 2025-02-07@14:43:19 | NeuroGym Monitor | AnnubesEnv | Trial:    20 | Time:     0 / (max     0, total 30430) | Avg. reward: 952.600


(100.0, 0.0)

In [17]:
# Retrieve the activations recorded for the action net's linear layer, per phase
layer_monitor = env.networks["Action net"].layers["Linear layer"].monitors[ngym.NetParam.Activation]
tr_data = layer_monitor.activations[ngym.MonitorPhase.Training]
ev_data = layer_monitor.activations[ngym.MonitorPhase.Evaluation]

In [18]:
[len(t) for t in tr_data], [len(t) for t in ev_data]

([0,
  178,
  179,
  179,
  179,
  178,
  179,
  179,
  179,
  179,
  178,
  179,
  179,
  179,
  76,
  178,
  179,
  179,
  179,
  178,
  179,
  179,
  179,
  179,
  178,
  179,
  179,
  179,
  179,
  178,
  179,
  179,
  179,
  179,
  178,
  25],
 [0, 149, 149, 149, 149, 149, 149, 149, 149, 149, 149, 0])
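
The recorded trials are not all the same length: some entries are empty and a few are much shorter than the rest. A minimal sketch of one way to keep only full-length trials before analysis is shown below. The `complete_trials` helper and its `min_fraction` threshold are hypothetical, and the example data only mimics the lengths printed above.

```python
from collections import Counter


def complete_trials(trials, min_fraction=0.9):
    """Keep trials whose length is close to the most common non-zero length."""
    lengths = [len(t) for t in trials if len(t) > 0]
    if not lengths:
        return []
    # The most frequent non-zero length is taken as the typical trial length.
    typical = Counter(lengths).most_common(1)[0][0]
    return [t for t in trials if len(t) >= min_fraction * typical]


# Placeholder traces with lengths similar to the training output above.
example = [[0.0] * n for n in (0, 178, 179, 179, 76, 25)]
kept = complete_trials(example)
print([len(t) for t in kept])
```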

Plot everything that has been recorded.

In [8]:
env.plot()

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (14,) + inhomogeneous part.
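
This is the error NumPy raises when asked to build an array from sequences of different lengths, which matches the unequal trial lengths listed earlier. One possible workaround when analysing the traces by hand is to pad every trial to a common length with NaN, as sketched below. The `pad_trials` helper is hypothetical, and it assumes each trial is an array of shape `(steps, units)`, which may not match the layout the monitor actually stores.

```python
import numpy as np


def pad_trials(trials, fill=np.nan):
    """Stack variable-length trials into a (trials, steps, units) array, padded with `fill`."""
    # Drop empty trials and make sure every trial is a float array.
    arrays = [np.asarray(t, dtype=float) for t in trials if len(t) > 0]
    if not arrays:
        return np.empty((0, 0, 0))
    max_len = max(a.shape[0] for a in arrays)
    n_units = arrays[0].shape[1]
    padded = np.full((len(arrays), max_len, n_units), fill)
    for i, a in enumerate(arrays):
        padded[i, : a.shape[0]] = a
    return padded


# Placeholder data: three trials of 3, 0 and 5 steps with 2 units each.
trials = [np.ones((3, 2)), np.zeros((0, 2)), np.ones((5, 2))]
padded = pad_trials(trials)

# NaN-aware statistics ignore the padding when averaging across trials.
mean_trace = np.nanmean(padded, axis=0)
print(padded.shape, mean_trace.shape)
```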

### Visualise the results

In [None]:
data = ngym.utils.plot_env(
    env, num_trials=10, ob_traces=["fixation", "start", "v", "a"], model=model
)

In [16]:
env.close()

We can now train a more complex agent (a recurrent policy) on the same environment.

In [None]:
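from sb3_contrib import RecurrentPPO  # recurrent PPO is provided by sb3-contrib
model = RecurrentPPO("MlpLstmPolicy", AnnubesEnv(dt=dt))  # LSTM policy on the same task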