## Monitoring and visualising model activation traces

In this notebook, we will use RL to train a network on a simple task, and we will visualise the activation traces of the neurons in the network.

### Task

We will use the [AnnubesEnv](https://github.com/neurogym/neurogym/blob/dev/neurogym/envs/annubes.py) environment for this demo. We set the duration of each trial period (fixation, stimulus, decision) and wrap the environment with a Monitor wrapper, which can be used to keep track of activation traces. After training, we can evaluate the agent and visualise the recorded traces. 

In [2]:
# Ignore warnings.
import warnings

warnings.filterwarnings("ignore")

import numpy as np

from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy

import awkward as ak

from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy
from sb3_contrib import RecurrentPPO

import neurogym as ngym
from neurogym.envs.annubes import AnnubesEnv
from neurogym.wrappers.monitor import Monitor

Neurogym | 2025-03-18@13:11:23 | Logger configured.


In [3]:
# Time step duration in ms
dt = 10

env = AnnubesEnv(dt=dt)

# check the custom environment and output additional warnings (if any)
check_env(env)

# check the environment with a random agent
obs, info = env.reset()
n_steps = 10
for _ in range(n_steps):
    # random action
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated:
        obs, info = env.reset()

print(env.timing)
print("----------------")
print(env.observation_space)
print(env.observation_space.name)
print("----------------")
print(env.action_space)
print(env.action_space.name)

{'fixation': 500, 'stimulus': 1000, 'iti': 0}
----------------
Box(0.0, 1.0, (4,), float32)
{'fixation': 0, 'start': 1, 'v': 2, 'a': 3}
----------------
Discrete(2)
{'fixation': 0, 'choice': [1]}


Make sure that the time step is registered correctly:

In [4]:
env.dt

10
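The timing dictionary printed above can be converted into a number of time steps per period by dividing each duration by `dt`. The sketch below does that arithmetic with the values copied from the outputs above, assuming the durations are in milliseconds like `dt`. The names `steps_per_period` and `steps_per_trial` are illustrative and not part of NeuroGym.

```python
# Values copied from the outputs above; durations are assumed to be in ms.
timing = {"fixation": 500, "stimulus": 1000, "iti": 0}
dt = 10  # time step duration in ms

# Number of time steps spanned by each period (illustrative helper variables).
steps_per_period = {period: duration // dt for period, duration in timing.items()}
steps_per_trial = sum(steps_per_period.values())

print(steps_per_period)  # {'fixation': 50, 'stimulus': 100, 'iti': 0}
print(steps_per_trial)   # 150
```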

### Train the agent

The agent can be trained on different environments. Training calls both the actor and the critic networks, so if the activations of the value net are needed, the training phase is the time to record them.

In [5]:
# Create an environment for evaluation
annubes_env = AnnubesEnv(dt=dt)

# Create a monitor
env = Monitor(annubes_env, name=f"NeuroGym Monitor | {annubes_env.__class__.__qualname__}")

# Create the agent that will be trained in this environment
model = RecurrentPPO(
    RecurrentActorCriticPolicy,
    env,
    policy_kwargs={'lstm_hidden_size': 64}
)

# Set the monitoring phases
phases = {ngym.MonitorPhase.Evaluation}

In [6]:
model.policy

RecurrentActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (pi_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (vf_features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (policy_net): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=2, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
  (lstm_actor): LSTM(4, 64)
  (lstm_critic): LSTM(4, 64)
)

Register the action net and its layers with the evaluation monitor. We are not going to register the value net because the RecurrentPPO model does not call it during inference.
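To see why only networks that are actually called leave a trace, here is a toy, pure-Python sketch of call-time recording. It is not the NeuroGym `Monitor` implementation: `RecordingLayer`, `predict` and `train_step` are hypothetical stand-ins for a monitored layer, an inference call and a training step.

```python
class RecordingLayer:
    """Wrap a callable and store every output it produces (toy stand-in)."""

    def __init__(self, fn):
        self.fn = fn
        self.history = []

    def __call__(self, x):
        out = self.fn(x)
        self.history.append(out)
        return out


# Hypothetical "networks": the actor doubles its input, the critic sums it.
actor = RecordingLayer(lambda x: [2 * v for v in x])
critic = RecordingLayer(lambda x: sum(x))


def predict(obs):
    # Inference only needs an action, so only the actor is called.
    return actor(obs)


def train_step(obs):
    # A training step needs both the action and the value estimate.
    return actor(obs), critic(obs)


for obs in ([1, 2], [3, 4], [5, 6]):
    predict(obs)

print(len(actor.history), len(critic.history))  # 3 0
```

In this toy setting the critic's history stays empty until `train_step` is called, which mirrors the reason given above for not registering the value net for evaluation.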

In [7]:
# Register the action network with the monitor
act_net = model.policy.action_net
act_net_monitor = env.add_network(act_net, phases, "Action net")
for layer in act_net.modules():
    act_net_monitor.add_layer(layer, [ngym.NetParam.Activation], "Linear layer")

# Register the LSTM actor with the monitor
lstm_actor = model.policy.lstm_actor
lstm_actor_monitor = env.add_network(lstm_actor, phases, "LSTM actor")
for layer in lstm_actor.modules():
    lstm_actor_monitor.add_layer(layer, [ngym.NetParam.Activation], "LSTM layer")

In [8]:
# # Set the monitoring phase.
# env.set_phase(ngym.MonitorPhase.Training)

# # Set the number of timesteps for training
# total_timesteps = 10000

# # Train the agent.
# model.learn(total_timesteps=total_timesteps)

In [9]:
lstm_layer_activations = env.networks['LSTM actor'].layer_monitors['LSTM layer'].param_monitors[ngym.NetParam.Activation]

### Evaluate the agent

In [10]:
lstm_states = None
episode_starts = np.ones((1,), dtype=bool)

# Set the phase to evaluation
env.set_phase(ngym.MonitorPhase.Evaluation)

# Evaluate the policy
evaluate_policy(model, env, n_eval_episodes=10)

Neurogym | 2025-03-18@13:11:27 | NeuroGym Monitor | AnnubesEnv | Trial:    10 | Time:     0 / (max     0, total 14900) | Avg. reward: 900.000


(90.0, 30.0)
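`evaluate_policy` returns the mean and the standard deviation of the per-episode rewards. As an illustration only, the sketch below shows one hypothetical set of ten episode rewards that would produce the pair above. The actual per-episode rewards were not recorded in this notebook, so `episode_rewards` is an assumption.

```python
import statistics

# Hypothetical rewards: nine episodes with reward 100 and one with reward 0.
episode_rewards = [100.0] * 9 + [0.0]

mean_reward = statistics.fmean(episode_rewards)
# Population standard deviation (divides by N rather than N - 1).
std_reward = statistics.pstdev(episode_rewards)

print(mean_reward, std_reward)
```

Many other reward sequences give the same mean and standard deviation. Passing `return_episode_rewards=True` to `evaluate_policy` returns the per-episode values instead of the summary.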

In [11]:
[len(tr) for tr in lstm_layer_activations.history[ngym.MonitorPhase.Evaluation]]

[149, 149, 149, 149, 149, 149, 149, 149, 149, 149, 0]

In [12]:
len(lstm_layer_activations.history[ngym.MonitorPhase.Evaluation][1][1])

64

In [13]:
traces = env.get_traces("Action net")

In [14]:
eval_hist = traces["Action net"]["Linear layer"][ngym.NetParam.Activation][ngym.MonitorPhase.Evaluation]

In [15]:
eval_arr = ak.Array([item for item in eval_hist if len(item) > 0])

In [16]:
eval_arr = np.array([e for e in eval_arr])

In [17]:
eval_arr.mean(axis=0).shape

(149, 2)
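The filtering and averaging steps in the cells above can be illustrated on a small toy history. This is a pure-Python sketch with made-up numbers: `history` is a hypothetical stand-in for the recorded traces, with two trials of three steps and two units each, plus an empty trailing entry like the one seen in the evaluation history.

```python
# Hypothetical history: trials x steps x units, with an empty trailing entry.
history = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],  # trial 0
    [[3.0, 2.0], [2.0, 1.0], [1.0, 3.0]],  # trial 1
    [],                                    # empty placeholder
]

# Drop empty trials so that the remaining ones form a regular array.
trials = [trial for trial in history if len(trial) > 0]
n_trials = len(trials)
n_steps = len(trials[0])
n_units = len(trials[0][0])

# Average over trials, keeping one value per (step, unit) pair.
mean_trace = [
    [sum(trial[s][u] for trial in trials) / n_trials for u in range(n_units)]
    for s in range(n_steps)
]

print(mean_trace)  # [[2.0, 1.0], [1.0, 1.0], [1.0, 2.0]]
```

The result has one row per time step and one column per unit, the same layout as the `(149, 2)` shape above.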

In [18]:
env.plot_notebook()

BokehModel(combine_events=True, render_bundle={'docs_json': {'6eaaf3c1-f2eb-458b-8b70-25253d2f4c39': {'version…

In [19]:
env.plot_browser()

Launching server at http://localhost:46109


Close the environment

In [20]:
env.close()

In [21]:
apar = ngym.wrappers.monitors.parameters.activation.ActivationParams(bg_colour="#ff00ff")

AttributeError: module 'neurogym.wrappers.monitors.parameters.activation' has no attribute 'ActivationParams'

In [None]:
import bokeh.plotting
import panel as pn

fig = bokeh.plotting.figure()

apar.param.bg_colour

pn.panel(
    pn.Row(
        pn.Column(pn.widgets.Button(name="Mute", width=200), apar.param),
        pn.pane.Bokeh(fig),
    )
)

In [None]:
pn.serve(pn.Column(pn.widgets.Button(name="Mute", width=200), apar.param))