# Collective Operations

The `brainstate.nn._collective_ops` module provides helpers for managing *all* modules inside a model. These functions make it easy to initialise, reset, batch, and restore stateful objects without manually traversing the module graph. This notebook introduces the core APIs with practical examples.


## Prerequisites

- Familiarity with `brainstate.nn` modules and states
- `brainunit` installed (required by the BrainState package)
- Basic understanding of JAX and `vmap`


```python
import brainstate
import jax.numpy as jnp
```

## Overview of the API

`brainstate.nn._collective_ops` exposes several utilities, all available directly under `brainstate.nn`:

- `call_order` — decorator that fixes the execution order of methods
- `call_all_fns` / `vmap_call_all_fns` — call the same method on each node in a model
- `init_all_states` / `vmap_init_all_states` — initialise state variables everywhere
- `reset_all_states` / `vmap_reset_all_states` — reset existing states
- `assign_state_values` — restore state values from dictionaries keyed by absolute paths

We'll examine each group below.


## Ordering Calls with `call_order`

By default `call_all_fns` visits nodes in the order they appear in the graph, but complex modules may need explicit ordering. The `call_order(level)` decorator stores `level` as a `call_order` attribute on the method. When that method is dispatched across the graph, nodes with lower levels run first.


```python
class EncoderDecoder(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = brainstate.nn.Linear((16,), (32,))
        self.decoder = brainstate.nn.Linear((32,), (16,))

    @brainstate.nn.call_order(0)
    def init_state(self):
        self.encoder.init_state()
        self.decoder.init_state()

    @brainstate.nn.call_order(1)
    def reset_state(self):
        self.encoder.reset_state()
        self.decoder.reset_state()
```


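Here `EncoderDecoder` simply forwards the calls to its children. A level only matters relative to other nodes that decorate the *same* method. When a collective helper dispatches `init_state` (or `reset_state`) across the graph, nodes with a lower level for that method run before nodes with a higher level. The two levels above therefore do not order `init_state` against `reset_state`. Also note that the collective helpers visit `encoder` and `decoder` directly. The forwarding inside the parent is only needed when you call the parent's method by hand.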


## Initialising Every Module

The simplest helper is `init_all_states`. It walks the module graph and calls `init_state` on each node, forwarding any extra positional and keyword arguments (such as `batch_size`). Use `node_to_exclude` to skip nodes that match a filter.


```python
model = brainstate.nn.Sequential(
    brainstate.nn.Linear((10,), (32,)),
    brainstate.nn.GELU(),
    brainstate.nn.Dropout(prob=0.1)
)

# Initialise the entire stack at once; keyword arguments such as `batch_size`
# are forwarded to every `init_state` call.
brainstate.nn.init_all_states(model, batch_size=4)

# Skip nodes matching a filter (here: every Dropout layer).
brainstate.nn.init_all_states(model, node_to_exclude=brainstate.nn.Dropout)

# The function returns its target, so you can chain it during construction.
model = brainstate.nn.init_all_states(model)
```


## Resetting State Between Sequences

For recurrent models you often initialise once and then reset after processing a sequence. `reset_all_states` automates the reset pass across the entire module.


```python
rnn = brainstate.nn.ValinaRNNCell(num_in=8, num_out=16)
brainstate.nn.init_all_states(rnn, batch_size=2)

# ... run some inference / training ...

# Reset hidden states before the next sequence.
brainstate.nn.reset_all_states(rnn)
```

```
ValinaRNNCell(
  in_size=(8,),
  out_size=(16,),
  num_out=16,
  num_in=8,
  state_initializer=ZeroInit(
    unit=Unit(10.0^0)
  ),
  activation=<function relu at 0x000001863944C360>,
  W=Linear(
    in_size=(24,),
    out_size=(16,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[16]),
        'weight': ShapedArray(float32[24,16])
      }
    )
  ),
  h=HiddenState(
    value=ShapedArray(float32[16])
  )
)
```

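The printed hidden state `h` has shape `(16,)` even though the cell was initialised with `batch_size=2`. This is because `reset_all_states` forwards its extra arguments to every `reset_state` call, and without `batch_size` the cell resets `h` to an unbatched shape. Pass the same arguments you used for initialisation (here `batch_size=2`) to keep the batch axis. You can also exclude nodes just like with `init_all_states`. The `call_order` levels still apply, so you can reset buffers before hidden states if needed.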


## Batched Initialisation with `vmap_*`

To create multiple independent copies of a model's state (ensembles or Monte-Carlo batches), use the vectorised variants. They run initialisation under `vmap`, so every state created inside `init_state` gains a leading axis of size `axis_size`, and each copy gets its own random key.


```python
policy = brainstate.nn.Sequential(
    brainstate.nn.Linear((4,), (64,)),
    brainstate.nn.GELU(),
    brainstate.nn.Linear((64,), (2,))
)

# Run initialisation for 8 independent copies of the policy.
brainstate.nn.vmap_init_all_states(policy, axis_size=8)

# Parameters are created in `__init__`, not in `init_state`, so the vmapped
# initialisation does not batch them: the weight keeps its original shape.
weights = policy.layers[0].weight.value
print('Weight shape with batching:', weights['weight'].shape)
```

```
Weight shape with batching: (4, 64)
```


```python
# When finished with a rollout, reset all batched states at once.
brainstate.nn.vmap_reset_all_states(policy, axis_size=8)
```

```
Sequential(
  in_size=(4,),
  out_size=(2,),
  layers=[
    Linear(
      in_size=(4,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64]),
          'weight': ShapedArray(float32[4,64])
        }
      )
    ),
    GELU(approximate=False),
    Linear(
      in_size=(64,),
      out_size=(2,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[2]),
          'weight': ShapedArray(float32[64,2])
        }
      )
    )
  ]
)
```

If certain states should stay shared (for example statistics buffers), pass a `state_to_exclude` filter to `vmap_init_all_states`. Excluded states retain their original shape across the batch.


## Calling Arbitrary Methods Collectively

`call_all_fns` is the primitive behind the init/reset helpers. You can dispatch *any* method, provided that each child module implements it.


```python
class LoggingLayer(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linear = brainstate.nn.Linear((size,), (size,))
        self.logged = []

    def init_state(self):
        self.linear.init_state()

    def log_stats(self):
        weight = self.linear.weight.value['weight']
        self.logged.append(jnp.mean(weight))

net = brainstate.nn.Sequential(
    LoggingLayer(size=8),
    LoggingLayer(size=8)
)

brainstate.nn.init_all_states(net)
for layer in net.layers:
    layer.log_stats()

stats = [layer.logged for layer in net.layers]
print('Logged means per layer:', stats)
```

```
Logged means per layer: [[Array(0.0521806, dtype=float32)], [Array(0.03177379, dtype=float32)]]
```


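The cell above calls `log_stats` on each layer by hand. The inner `Linear` layers do not define that method, so dispatching it with `call_all_fns(net, 'log_stats')` would fail on those nodes. `vmap_call_all_fns` repeats a collective call across `axis_size` independent instances and accepts the same filter options.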


## Restoring States with `assign_state_values`

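Serialisation often involves mapping absolute state paths back to state objects. `assign_state_values` assigns each value whose key equals a path returned by the model's `.states()`. It returns two lists: *unexpected* keys (present in the dictionary but not in the model) and *missing* keys (present in the model but not in the dictionary).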


```python
autoencoder = brainstate.nn.Sequential(
    brainstate.nn.Linear((16,), (8,)),
    brainstate.nn.ReLU(),
    brainstate.nn.Linear((8,), (16,))
)
brainstate.nn.init_all_states(autoencoder)

# Save values in a dict keyed by absolute state paths. Dict-valued states
# (each Linear's ParamState holds {'weight', 'bias'}) are flattened one level
# further here, so those keys no longer match any state path.
state_snapshot = {}
for path, state in autoencoder.states().items():
    if isinstance(state.value, dict):
        for key, value in state.value.items():
            new_path = path + (key,)
            state_snapshot[new_path] = value
    else:
        state_snapshot[path] = state.value

# ... modify weights or states ...

unexpected, missing = brainstate.nn.assign_state_values(autoencoder, state_snapshot)
print('Unexpected keys:', unexpected)
print('Missing keys:', missing)
```

```
Unexpected keys: [('layers', 0, 'weight', 'bias'), ('layers', 0, 'weight', 'weight'), ('layers', 2, 'weight', 'bias'), ('layers', 2, 'weight', 'weight')]
Missing keys: [('layers', 0, 'weight'), ('layers', 2, 'weight')]
```

The only states in this model are the two `ParamState` objects, and each holds a dictionary. Every flattened key is therefore reported as unexpected and both `weight` paths as missing, so nothing was restored. Keys must match the paths returned by `.states()` exactly.


## Putting It All Together

The snippet below demonstrates a typical lifecycle for a batched recurrent network. It initialises the states, runs a rollout, resets, and restores states from a snapshot.


```python
rnn = brainstate.nn.ValinaRNNCell(num_in=4, num_out=8)
brainstate.nn.vmap_init_all_states(rnn, axis_size=4)

# Save a snapshot of the initial states. Dict-valued states are flattened,
# so their keys will not match (see the returned lists below).
snapshot = {}
for path, state in rnn.states().items():
    if isinstance(state.value, dict):
        for key, value in state.value.items():
            snapshot[path + (key,)] = value
    else:
        snapshot[path] = state.value

# Simulate a rollout of 12 time steps over a batch of 4.
inputs = brainstate.random.randn(12, 4, 4)
for t in range(inputs.shape[0]):
    output = rnn(inputs[t])

# Reset the batched states before the next episode.
print("Resetting states...")
brainstate.nn.vmap_reset_all_states(rnn, axis_size=4)

# Restore states from the snapshot. The returned (unexpected, missing) lists
# show that the hidden state matched but the flattened W entries did not.
brainstate.nn.assign_state_values(rnn, snapshot)
```

```
Resetting states...
```

```
([('W', 'weight', 'bias'), ('W', 'weight', 'weight')], [('W', 'weight')])
```

## Best Practices

- Always call `init_all_states` once after constructing a module.
- Decorate stateful methods with `call_order` when their interaction matters.
- Use filters (`node_to_exclude`, `state_to_exclude`) to fine-tune traversal.
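- Inspect the return values from `assign_state_values` to catch mismatched checkpoints; its keys must match the paths returned by `.states()` exactly.
- Employ the vmapped helpers for ensembles, but remember that states created during initialisation gain a leading axis of size `axis_size`.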


## Further Reading

- [Module Basics](01_module_basics.ipynb)
- [Recurrent Networks](04_recurrent_networks.ipynb)
- API reference: `brainstate.nn._collective_ops`
