# Collective Operations for Neural Network Modules

When building complex neural networks, especially recurrent neural networks (RNNs) and deep hierarchical models, you often need to manage states across multiple modules. BrainState provides a set of **collective operations** that initialize, reset, and otherwise manage the states of every module in a network hierarchy with a single call.

This tutorial covers:

1. **`call_order`**: A decorator that controls the order in which modules run a method during a collective call
2. **`call_all_fns`**: Call any method across all modules in a network
3. **`init_all_states`**: Initialize states for all modules
4. **`reset_all_states`**: Reset states for all modules
5. **Vectorized operations**: `vmap_*` variants for creating batched ensembles

These operations are essential for:
- Managing recurrent neural network states
- Building ensemble models
- Creating complex hierarchical networks
- Implementing stateful computations

Content Coverage:

1. Introduction - Overview of collective operations and their use cases
2. call_order Decorator (Section 1)
    - Basic usage and execution order control
    - Practical examples with ordered methods
3. Basic State Management (Section 2)
    - init_all_states - Initializing module states
    - reset_all_states - Resetting states between sequences
    - Practical examples with GRU cells
4. Nested Modules (Section 3)
    - Working with hierarchical networks
    - Stacked RNN example
    - Automatic traversal demonstration
5. Custom Methods with call_all_fns (Section 4)
    - Calling custom methods across modules
    - Passing arguments and keyword arguments
6. Vectorized Operations (Section 5)
    - vmap_init_all_states - Creating ensembles
    - vmap_reset_all_states - Resetting ensemble states
    - Practical ensemble learning examples
7. Practical Example (Section 6)
    - Complete sequence classification task
    - Training loop with proper state management
    - Synthetic data generation
8. Advanced Features (Section 7)
    - Selective operations with node_to_exclude
    - Using fn_if_not_exist parameter
    - Filter functions
9. Performance Tips (Section 8)
    - JIT compilation with state management
    - Batch processing multiple sequences
    - Performance benchmarking
10. Summary - Best practices and key takeaways




In [1]:
import jax
import jax.numpy as jnp
import brainstate

In [2]:
print(f"BrainState version: {brainstate.__version__}")
print(f"JAX version: {jax.__version__}")

BrainState version: 0.2.0
JAX version: 0.7.1


## 1. The `call_order` Decorator

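The `call_order` decorator controls the order in which modules run a method when it is called collectively (by `call_all_fns`, `init_all_states`, or `reset_all_states`). This is useful when, for example, one module's `init_state` must run only after the states of other modules already exist.

### Basic Usage

During a single collective call, modules whose version of the method is undecorated run first. Modules whose method is decorated with `@call_order(level)` run afterwards, in ascending order of `level`. Because the ordering is applied across modules for one method name, decorating different methods of the same module does not change the order in which you call those methods yourself.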

In [3]:
class OrderedModule(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.execution_log = []
    
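    # No decorator: in a collective call of `init_state`, this version runs
    # together with the other undecorated ones, before any decorated version
    def init_state(self):
        self.execution_log.append('init_state (no order)')
        self.state = brainstate.State(jnp.zeros(3))
    
    # Level 0: in a collective call of `setup_connections`, this version runs
    # after the undecorated versions on other modules
    @brainstate.nn.call_order(0)
    def setup_connections(self):
        self.execution_log.append('setup_connections (order 0)')
    
    # Level 1: in a collective call of `finalize`, this version runs after
    # undecorated and level-0 versions on other modules
    @brainstate.nn.call_order(1)
    def finalize(self):
        self.execution_log.append('finalize (order 1)')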

module = OrderedModule()

# Call init_state on all modules
brainstate.nn.call_all_fns(module, 'init_state')
print("After calling init_state:")
print(module.execution_log)

# Call setup_connections
brainstate.nn.call_all_fns(module, 'setup_connections')
print("\nAfter calling setup_connections:")
print(module.execution_log)

# Call finalize
brainstate.nn.call_all_fns(module, 'finalize')
print("\nAfter calling finalize:")
print(module.execution_log)

After calling init_state:
['init_state (no order)']

After calling setup_connections:
['init_state (no order)', 'setup_connections (order 0)']

After calling finalize:
['init_state (no order)', 'setup_connections (order 0)', 'finalize (order 1)']


## 2. Basic State Management: `init_all_states` and `reset_all_states`

### Initializing States

The `init_all_states` function calls the `init_state` method on all modules in a network hierarchy.

In [None]:
# Create a simple RNN cell
gru = brainstate.nn.GRUCell(input_size=10, hidden_size=20)

# Initialize states for batch_size=5
brainstate.nn.init_all_states(gru, batch_size=5)

# Check the hidden state
print("Hidden state shape:", gru.h.value.shape)
print("Hidden state (initialized to zeros):")
print(gru.h.value)

### Resetting States

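The `reset_all_states` function calls `reset_state` on every module. For RNNs this is essential: it returns the hidden states to their initial values between independent sequences, so information from one sequence does not leak into the next.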

In [None]:
# Process a sequence
sequence_length = 10
input_data = jax.random.normal(jax.random.PRNGKey(0), (sequence_length, 5, 10))

print("Processing first sequence...")
for t in range(sequence_length):
    output = gru(input_data[t])
    if t == 0:
        print(f"  Step {t}, hidden state mean: {gru.h.value.mean():.6f}")

print(f"\nAfter sequence, hidden state mean: {gru.h.value.mean():.6f}")

# Reset states before processing next sequence
brainstate.nn.reset_all_states(gru, batch_size=5)
print(f"After reset, hidden state mean: {gru.h.value.mean():.6f}")

# Process second sequence with reset states
print("\nProcessing second sequence...")
for t in range(3):
    output = gru(input_data[t])
    if t == 0:
        print(f"  Step {t}, hidden state mean: {gru.h.value.mean():.6f}")

## 3. Working with Nested Modules

Collective operations automatically traverse the entire module hierarchy, making it easy to manage complex networks.

In [None]:
class StackedRNN(brainstate.nn.Module):
    """A stacked RNN with multiple GRU layers"""
    def __init__(self, input_size, hidden_sizes):
        super().__init__()
        self.layers = []
        
        # Create multiple GRU layers
        prev_size = input_size
        for i, hidden_size in enumerate(hidden_sizes):
            layer = brainstate.nn.GRUCell(prev_size, hidden_size)
            setattr(self, f'gru_{i}', layer)
            self.layers.append(layer)
            prev_size = hidden_size
    
    def __call__(self, x):
        # Forward pass through all layers
        for layer in self.layers:
            x = layer(x)
        return x

# Create a 3-layer stacked RNN
stacked_rnn = StackedRNN(input_size=10, hidden_sizes=[20, 15, 10])

# Initialize all states at once
brainstate.nn.init_all_states(stacked_rnn, batch_size=8)

# Check states in each layer
print("States after initialization:")
for i, layer in enumerate(stacked_rnn.layers):
    print(f"  Layer {i}: hidden state shape = {layer.h.value.shape}")

# Process some data
x = jax.random.normal(jax.random.PRNGKey(1), (8, 10))
output = stacked_rnn(x)
print(f"\nAfter forward pass:")
for i, layer in enumerate(stacked_rnn.layers):
    print(f"  Layer {i}: hidden state mean = {layer.h.value.mean():.6f}")

# Reset all states at once
brainstate.nn.reset_all_states(stacked_rnn, batch_size=8)
print(f"\nAfter reset:")
for i, layer in enumerate(stacked_rnn.layers):
    print(f"  Layer {i}: hidden state mean = {layer.h.value.mean():.6f}")

## 4. Custom Methods with `call_all_fns`

You can use `call_all_fns` to call any custom method across all modules.

In [None]:
class CustomModule(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.size = size
        self.counter = 0
    
    def init_state(self, batch_size=1):
        self.state = brainstate.State(jnp.zeros((batch_size, self.size)))
    
    def custom_operation(self, scale=1.0):
        """A custom method that scales the state"""
        self.counter += 1
        if hasattr(self, 'state'):
            self.state.value = self.state.value * scale
        print(f"  Module with size {self.size}: called {self.counter} times")

class Network(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.module1 = CustomModule(size=5)
        self.module2 = CustomModule(size=10)
        self.module3 = CustomModule(size=3)

network = Network()

# Initialize all states
brainstate.nn.init_all_states(network, batch_size=4)

# Call custom operation on all modules with a scaling factor
print("Calling custom_operation with scale=2.0:")
brainstate.nn.call_all_fns(network, 'custom_operation', kwargs={'scale': 2.0}, fn_if_not_exist='pass')

print("\nCalling again with scale=0.5:")
brainstate.nn.call_all_fns(network, 'custom_operation', kwargs={'scale': 0.5}, fn_if_not_exist='pass')

## 5. Vectorized Operations: Creating Ensembles

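The `vmap_*` variants use JAX vectorization to give every state created by `init_state` an extra leading axis of size `axis_size`, one slice per ensemble member. Parameters created in `__init__`, such as the GRU weights, are not duplicated, so the members share weights and differ only in their states. This is useful for ensembles of state trajectories and for population-based methods.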

### `vmap_init_all_states`: Creating Ensembles

In [None]:
# Create a single GRU cell
gru_ensemble = brainstate.nn.GRUCell(input_size=5, hidden_size=10)

# Give the GRU's hidden state a leading ensemble axis of size 20
# (20 ensemble members, each with batch_size=3)
ensemble_size = 20
brainstate.nn.vmap_init_all_states(gru_ensemble, batch_size=3, axis_size=ensemble_size)

# Check the shape: (ensemble_size, batch_size, hidden_size)
print(f"Hidden state shape: {gru_ensemble.h.value.shape}")
print(f"Expected: ({ensemble_size}, 3, 10)")

# With GRUCell's default zero state initializer, every member starts at zero,
# so both statistics below are 0; members only diverge once they receive
# different inputs or use a random state initializer
print(f"\nHidden state statistics across ensemble:")
print(f"  Mean: {gru_ensemble.h.value.mean():.6f}")
print(f"  Std:  {gru_ensemble.h.value.std():.6f}")

### Processing Data with Ensembles

In [None]:
# Create input data: (batch_size, input_size)
x = jax.random.normal(jax.random.PRNGKey(42), (3, 5))

# We need to vmap the forward pass to handle the ensemble dimension
@jax.vmap
def ensemble_forward(ensemble_member):
    # This function is called for each ensemble member
    return ensemble_member(x)

# Note: For simplicity, we'll just show the shape transformation
# In practice, you'd integrate this with the actual GRU forward pass
print(f"Input shape: {x.shape}")
print(f"After ensemble forward: each member processes the same input")
print(f"Output would have shape: ({ensemble_size}, 3, 10)")
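One way to picture the vectorized forward pass is in pure JAX: the per-member step function is mapped over the leading ensemble axis of the hidden state, while the input and the shared weights are broadcast (`in_axes=None`). The sketch below uses a hypothetical vanilla RNN step, `rnn_step`, in place of the GRU update; it illustrates the idea and is not BrainState's own transform.

```python
import jax
import jax.numpy as jnp

def rnn_step(h, x, W, U):
    """Hypothetical vanilla RNN step: h' = tanh(x @ W + h @ U)."""
    return jnp.tanh(x @ W + h @ U)

ensemble_size, batch, n_in, n_hid = 20, 3, 5, 10
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
W = 0.1 * jax.random.normal(k1, (n_in, n_hid))     # shared input weights
U = 0.1 * jax.random.normal(k2, (n_hid, n_hid))    # shared recurrent weights
x = jax.random.normal(k3, (batch, n_in))           # same input for every member
h_ens = jax.random.normal(k4, (ensemble_size, batch, n_hid))  # one state per member

# Map over axis 0 of h only; x, W and U are shared by all members
ensemble_step = jax.vmap(rnn_step, in_axes=(0, None, None, None))
h_next = ensemble_step(h_ens, x, W, U)
print(h_next.shape)  # (20, 3, 10)
```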

### Resetting Ensemble States

In [None]:
# Reset all ensemble members at once
brainstate.nn.vmap_reset_all_states(gru_ensemble, batch_size=3, axis_size=ensemble_size)

print("After reset:")
print(f"  Hidden state shape: {gru_ensemble.h.value.shape}")
print(f"  All zeros? {jnp.allclose(gru_ensemble.h.value, 0.0)}")

## 6. Practical Example: Sequence Classification with State Management

Let's build a complete example that uses collective operations for a sequence classification task.

In [None]:
class SequenceClassifier(brainstate.nn.Module):
    """A simple RNN-based sequence classifier"""
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.rnn = brainstate.nn.GRUCell(input_size, hidden_size)
        self.output_layer = brainstate.nn.Linear(hidden_size, num_classes)
    
    def __call__(self, sequence):
        """
        Process a sequence and return classification logits.
        
        Args:
            sequence: (seq_len, batch_size, input_size)
        
        Returns:
            logits: (batch_size, num_classes)
        """
        # Process sequence step by step
        for t in range(sequence.shape[0]):
            hidden = self.rnn(sequence[t])
        
        # Use final hidden state for classification
        logits = self.output_layer(hidden)
        return logits

# Create model
model = SequenceClassifier(input_size=8, hidden_size=16, num_classes=3)

# Initialize states
batch_size = 4
brainstate.nn.init_all_states(model, batch_size=batch_size)

print("Model initialized!")
print(f"RNN hidden state shape: {model.rnn.h.value.shape}")

### Generate Synthetic Data

In [None]:
# Generate random sequences
num_sequences = 20
seq_length = 15
input_size = 8

key = jax.random.PRNGKey(123)
sequences = jax.random.normal(key, (num_sequences, seq_length, input_size))
labels = jax.random.randint(jax.random.PRNGKey(456), (num_sequences,), 0, 3)

print(f"Generated {num_sequences} sequences")
print(f"Sequence shape: {sequences.shape}")
print(f"Labels: {labels[:10]}...")

### Training Loop with State Reset

In [None]:
# Setup optimizer
optimizer = brainstate.optim.Adam(lr=1e-3)
optimizer.register_trainable_weights(model.states(brainstate.ParamState))

def loss_fn(predictions, targets):
    """Cross-entropy loss"""
    return -jnp.mean(jnp.sum(jax.nn.one_hot(targets, 3) * jax.nn.log_softmax(predictions), axis=-1))

@brainstate.compile.jit
def train_step(sequence, label):
    """Single training step"""
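    def forward():
        logits = model(sequence)
        return loss_fn(logits, label)
    
    # Compute gradients and the loss from a single forward pass. Calling
    # forward() again afterwards would continue from the hidden state left by
    # this pass (and use the updated parameters), not from the reset state.
    grads, loss = brainstate.augment.grad(
        forward, model.states(brainstate.ParamState), return_value=True
    )()
    
    # Update parameters
    optimizer.update(grads)
    
    return loss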

# Training loop
print("Training...")
for epoch in range(3):
    epoch_loss = 0.0
    
    for i in range(0, num_sequences, batch_size):
        # Get batch
        batch_seq = sequences[i:i+batch_size]
        batch_labels = labels[i:i+batch_size]
        
        # Transpose to (seq_len, batch_size, input_size)
        batch_seq = jnp.transpose(batch_seq, (1, 0, 2))
        
        # Reset hidden states before each batch of independent sequences
        brainstate.nn.reset_all_states(model, batch_size=batch_seq.shape[1])
        
        # Train step
        loss = train_step(batch_seq, batch_labels)
        epoch_loss += loss
    
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / (num_sequences // batch_size):.4f}")
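The `loss_fn` above builds the cross-entropy from a one-hot mask. The pure-JAX check below compares it with a direct lookup of log p(target) via `jnp.take_along_axis`, using made-up logits. It also evaluates uniform logits, which give ln 3 ≈ 1.0986 for three classes; since the labels in this example are random, a loss near that value is the chance-level baseline.

```python
import jax
import jax.numpy as jnp

def loss_fn(predictions, targets):
    """Same cross-entropy as above: mean over the batch of -log p(target)."""
    return -jnp.mean(jnp.sum(jax.nn.one_hot(targets, 3) * jax.nn.log_softmax(predictions), axis=-1))

# Made-up logits for a batch of two examples and three classes
logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 3.0]])
targets = jnp.array([0, 2])

# Pick log p(target) for each row directly instead of via a one-hot mask
log_probs = jax.nn.log_softmax(logits)
picked = jnp.take_along_axis(log_probs, targets[:, None], axis=1)[:, 0]
direct = -jnp.mean(picked)

# Uniform logits: every class has probability 1/3, so the loss is ln 3
uniform_loss = loss_fn(jnp.zeros((4, 3)), jnp.array([0, 1, 2, 0]))
print(loss_fn(logits, targets), direct, uniform_loss)
```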

## 7. Advanced: Selective Operations

### Excluding Specific Modules

Sometimes you may want to exclude certain modules from collective operations.

In [None]:
class MixedNetwork(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn1 = brainstate.nn.GRUCell(10, 20)
        self.rnn2 = brainstate.nn.GRUCell(20, 15)
        self.static_layer = brainstate.nn.Linear(15, 5)  # No state to initialize

network = MixedNetwork()

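# Linear holds no state, but it inherits a no-op init_state from
# brainstate.nn.Module, so the whole network can be initialized in one call.
# (fn_if_not_exist, used with call_all_fns in Section 4, is only needed for
# methods that some modules lack.)
brainstate.nn.init_all_states(network, batch_size=8)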

print("States initialized:")
print(f"  RNN1 hidden state: {network.rnn1.h.value.shape}")
print(f"  RNN2 hidden state: {network.rnn2.h.value.shape}")

### Using Filters to Exclude Nodes

In [None]:
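# Run one step so that both RNNs hold non-zero hidden states
x = jax.random.normal(jax.random.PRNGKey(2), (8, 10))
network.rnn2(network.rnn1(x))
print("Before reset:")
print(f"  RNN1 hidden state mean: {network.rnn1.h.value.mean():.6f}")
print(f"  RNN2 hidden state mean: {network.rnn2.h.value.mean():.6f}")

# Reset every module except rnn2; a callable filter receives (path, node)
brainstate.nn.reset_all_states(
    network,
    batch_size=8,
    node_to_exclude=lambda path, node: node is network.rnn2
)

print("After selective reset:")
print(f"  RNN1 hidden state mean: {network.rnn1.h.value.mean():.6f}")
print(f"  RNN2 hidden state mean: {network.rnn2.h.value.mean():.6f}")
print("  (RNN1 is reset to zeros, RNN2 keeps its values)")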

## 8. Performance Tips

### 1. Use JIT Compilation with State Management

In [None]:
# Create a model
rnn = brainstate.nn.GRUCell(10, 20)
brainstate.nn.init_all_states(rnn, batch_size=32)

# Define a JIT-compiled function that processes sequences
@brainstate.compile.jit
def process_sequence(sequence):
    outputs = []
    for t in range(sequence.shape[0]):
        output = rnn(sequence[t])
        outputs.append(output)
    return jnp.stack(outputs)

# Generate test data
test_seq = jax.random.normal(jax.random.PRNGKey(0), (50, 32, 10))

# First call: includes tracing and compilation. JAX dispatches work
# asynchronously, so wait for the result before stopping the timer.
import time
start = time.perf_counter()
result = jax.block_until_ready(process_sequence(test_seq))
compile_time = time.perf_counter() - start

# Second call: reuses the compiled function
brainstate.nn.reset_all_states(rnn, batch_size=32)
start = time.perf_counter()
result = jax.block_until_ready(process_sequence(test_seq))
run_time = time.perf_counter() - start

print(f"First call (with compilation): {compile_time*1000:.2f}ms")
print(f"Second call (compiled): {run_time*1000:.2f}ms")
print(f"Ratio (first / second): {compile_time/run_time:.1f}x")

### 2. Batch Processing Multiple Sequences

In [None]:
def process_multiple_sequences(sequences_list):
    """
    Process multiple independent sequences one after another,
    resetting the hidden state before each one.
    
    Args:
        sequences_list: List of sequences, each (seq_len, batch_size, input_size)
    """
    results = []
    
    for sequence in sequences_list:
        # Reset states before each sequence
        brainstate.nn.reset_all_states(rnn, batch_size=sequence.shape[1])
        
        # Process sequence
        result = process_sequence(sequence)
        results.append(result)
    
    return results

# Create multiple sequences (length 30; process_sequence is traced again for this new shape)
sequences = [jax.random.normal(jax.random.PRNGKey(i), (30, 32, 10)) for i in range(5)]

# Process all sequences
outputs = process_multiple_sequences(sequences)
print(f"Processed {len(outputs)} sequences")
print(f"Each output shape: {outputs[0].shape}")
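`process_sequence` was first compiled for sequences of length 50, so the length-30 sequences here trigger one more trace and compilation; later calls with the same shape reuse it. The pure-JAX sketch below makes this visible by counting traces with a Python side effect, which runs only while a function is being traced.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def summed(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return x.sum()

summed(jnp.ones((50, 32, 10)))  # first shape: traced and compiled
summed(jnp.ones((50, 32, 10)))  # same shape: cached version reused
summed(jnp.ones((30, 32, 10)))  # new shape: traced again
print(trace_count)  # 2
```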

## Summary

BrainState's collective operations manage the states of an entire module hierarchy through a handful of functions:

1. **`call_order(level)`**: Control the order in which modules run a method during a collective call
2. **`call_all_fns(target, fn_name, ...)`**: Call any method across all modules
3. **`init_all_states(target, ...)`**: Initialize all module states
4. **`reset_all_states(target, ...)`**: Reset all module states (critical for RNNs)
5. **`vmap_*` variants**: Create vectorized ensembles

### Best Practices:

- Always reset RNN states between independent sequences
- Use `fn_if_not_exist='pass'` when calling methods that may not exist on all modules
- Leverage `vmap_init_all_states` for ensemble methods
- Combine with JIT compilation for best performance
- Use `node_to_exclude` for fine-grained control

These operations make it easy to build and manage complex hierarchical networks while maintaining clean, readable code.