# Batching Online Learning


The online learning module of the BrainScale framework supports two batching strategies for training neural networks:

- **Manual Batching**: Explicitly manage the batch dimension, with model states of shape `(B, M)`, where `B` is the batch size and `M` is the state dimension of a single sample
- **Automatic Batching**: Use `vmap`-style vectorization from `brainstate.transform` (in this tutorial, `brainstate.transform.vmap_new_states` together with `brainstate.nn.Vmap`), so model states keep the single-sample shape `(M,)` while the batch dimension is handled automatically

**💡 Memory and Computational Advantages of Batching**

**1. Memory Layout Optimization:**
```python
# Sample-by-sample: 128 separate calls, each working on its own small buffer
for i in range(128):
    process_sample(i)  # hypothetical per-sample function

# Batched: one call over a single contiguous (128, ...) array
process_batch(all_128_samples)  # hypothetical batched function
```

**2. Parallel Computing Advantages:**
```python
# Sample-by-sample: roughly 128 × (time for one sample)
# Batched on a GPU: can approach the time for one sample (ideal case, when the
#                   batch fits within the device's parallel capacity)
```
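
A small NumPy check makes the two points concrete. It computes the same linear layer once per sample in a Python loop and once as a single batched matrix product; the results agree, but the batched version is one call over a contiguous `(128, 100)` array. The weight matrix `W`, the array sizes, and the use of NumPy instead of BrainScale are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_batch, n_in, n_hidden = 128, 100, 200
W = rng.standard_normal((n_in, n_hidden))       # illustrative weight matrix
x_batch = rng.standard_normal((n_batch, n_in))  # 128 samples in one contiguous array

# Sample-by-sample: 128 separate matrix-vector products
out_loop = np.stack([x_batch[i] @ W for i in range(n_batch)])

# Batched: a single matrix-matrix product over the whole batch
out_batch = x_batch @ W

print(out_batch.shape)                   # (128, 200)
print(x_batch.flags['C_CONTIGUOUS'])     # True
print(np.allclose(out_loop, out_batch))  # True
```

How much faster the batched call runs depends on the hardware and array sizes, so no timing is claimed here.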

The core idea of manual batching is to **control the batch dimension explicitly so that computations run in parallel as efficiently as possible**. The code is slightly more complex, but it can yield substantial speedups in large-scale training.

This tutorial will provide a detailed comparison of the implementation differences, applicable scenarios, and performance characteristics of these two approaches.

```python
import brainstate
import braintools
import brainscale
import brainunit as u
import jax

brainstate.environ.set(dt=1.0 * u.ms)
```


## Preparation: Dataset + Model

First, we create a simulated classification task dataset.



```python
# Dataset parameter configuration
n_time = 16        # Time steps
n_batch = 128      # Batch size
n_in = 100         # Input feature dimension
n_hidden = 200     # Hidden layer neuron count
n_out = 10         # Output class count

# Generate random training data
xs = brainstate.random.rand(n_time, n_batch, n_in)  # Input data shape: (16, 128, 100)
ys = brainstate.random.randint(0, n_out, n_batch)   # Label data shape: (128,)
```
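
Three slices of this array appear later in the tutorial: the training loop scans over the leading time axis (`xs[t]`, shape `(B, D)`), manual batching compiles the graph with `inputs[0]`, and automatic batching compiles it with the single sample `inputs[0, 0]`. The NumPy sketch below reproduces these shapes; `np.random.default_rng` stands in for `brainstate.random` and is only an assumption for illustration.

```python
import numpy as np

n_time, n_batch, n_in, n_out = 16, 128, 100, 10
rng = np.random.default_rng(42)  # stand-in for brainstate.random

xs = rng.random((n_time, n_batch, n_in))  # (T, B, D)
ys = rng.integers(0, n_out, n_batch)      # (B,) integer labels

step_batch = xs[0]        # one time step for the whole batch -> (B, D)
single_sample = xs[0, 0]  # one time step for one sample      -> (D,)

print(xs.shape, ys.shape)                     # (16, 128, 100) (128,)
print(step_batch.shape, single_sample.shape)  # (128, 100) (100,)
```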


Next, we construct a recurrent neural network based on Leaky Integrate-and-Fire (LIF) neurons to perform classification on this dataset.


```python
class LIFNet(brainstate.nn.Module):
    """
    LIF Neural Network Model

    Architecture: Input layer -> LIF neuron layer (with recurrent connections) -> Output layer
    """

    def __init__(self, n_in, n_hidden, n_out):
        super().__init__()

        # LIF neuron layer: simulates biological neuron leaky integrate-and-fire behavior
        self.neu = brainscale.nn.LIF(n_hidden)

        # Weight initialization strategy
        rec_init = brainstate.init.KaimingNormal(unit=u.mV)    # Recurrent connection weights
        ff_init = brainstate.init.KaimingNormal(unit=u.mV)     # Feedforward connection weights

        # Synaptic connection layer: integrates feedforward input and recurrent feedback
        self.syn = brainstate.nn.DeltaProj(
            comm=brainscale.nn.Linear(
                n_in + n_hidden, n_hidden,
                # Connection weight matrix: [feedforward weights; recurrent weights]
                w_init=u.math.concatenate([
                    ff_init([n_in, n_hidden]),
                    rec_init([n_hidden, n_hidden])
                ], axis=0),
                b_init=brainstate.init.ZeroInit(unit=u.mV)
            ),
            post=self.neu
        )

        # Output layer: converts spike activity to classification output
        self.out = brainstate.nn.LeakyRateReadout(n_hidden, n_out)

    def update(self, x):
        """
        Model forward propagation

        Args:
            x: Input data

        Returns:
            Network output (classification logits)
        """
        # Integrate current input and recurrent spike feedback
        combined_input = u.math.concatenate([x, self.neu.get_spike()], axis=-1)
        self.syn(combined_input)

        # Return current timestep output
        return self.out(self.neu())
```
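
To make the data flow of `LIFNet.update` concrete, here is a simplified single-sample step written in plain JAX. It is not BrainScale's `LIF` or `LeakyRateReadout`: the time constant `tau`, threshold `v_th`, reset-to-zero rule, readout decay, weight scaling, and the helper names `init_params`, `init_state`, and `lif_step` are all hypothetical, and physical units are dropped. It does mirror the structure of the model above: concatenate `[x, previous spikes]`, multiply by one weight matrix whose first `n_in` rows are feedforward and last `n_hidden` rows are recurrent, integrate the result into the membrane potential, and pass the spikes to a leaky readout.

```python
import jax
import jax.numpy as jnp

n_in, n_hidden, n_out = 100, 200, 10


def init_params(key):
    """Hypothetical parameter initializer mirroring LIFNet's weight layout."""
    k_ff, k_rec, k_out = jax.random.split(key, 3)
    # Stacked [feedforward; recurrent] weights, shape (n_in + n_hidden, n_hidden)
    w = jnp.concatenate([
        jax.random.normal(k_ff, (n_in, n_hidden)) / jnp.sqrt(n_in),
        jax.random.normal(k_rec, (n_hidden, n_hidden)) / jnp.sqrt(n_hidden),
    ], axis=0)
    w_out = jax.random.normal(k_out, (n_hidden, n_out)) / jnp.sqrt(n_hidden)
    return {'w': w, 'b': jnp.zeros(n_hidden), 'w_out': w_out}


def init_state():
    """Single-sample state: (membrane potential, last spikes, readout rate)."""
    return jnp.zeros(n_hidden), jnp.zeros(n_hidden), jnp.zeros(n_out)


def lif_step(params, state, x, tau=10.0, dt=1.0, v_th=1.0):
    """One simplified LIF + leaky-readout step for a single sample."""
    v, spike, r = state
    inp = jnp.concatenate([x, spike], axis=-1)               # (n_in + n_hidden,)
    v = v - dt / tau * v + inp @ params['w'] + params['b']   # leak + delta input
    new_spike = (v >= v_th).astype(v.dtype)                  # hard threshold
    v = jnp.where(new_spike > 0, 0.0, v)                     # reset after a spike
    r = (1.0 - dt / tau) * r + new_spike @ params['w_out']   # leaky rate readout
    return (v, new_spike, r), r


params = init_params(jax.random.PRNGKey(0))
state, out = lif_step(params, init_state(), jnp.ones(n_in))
print(out.shape)       # (10,)
print(state[0].shape)  # (200,)
```

The hard threshold here has zero gradient almost everywhere; spiking models trained with gradients typically replace it with a surrogate gradient.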


## Manual Batching

Manual batching requires the following conditions:

1. Initialize the model with batched states of shape $\mathbb{R}^{B \times M}$, where $B$ is the batch size and $M$ is the state dimension of a single sample.
2. When calling the model's `.update` function, pass a batch of samples with shape $\mathbb{R}^{B \times D}$, where $D$ is the input dimension of one sample.
3. When creating the online learning algorithm, set the `mode` parameter to `brainstate.mixin.Batching()` to enable manual batching. Alternatively, use `brainstate.environ.set(mode=brainstate.mixin.Batching())` to set batching mode globally.

### Core Features

Manual batching mode requires developers to explicitly handle batch dimensions:

1. **State Initialization**: Model state shape must be `(B, M)`
2. **Data Format**: Input data shape is `(B, D)`
3. **Mode Setting**: Use `brainstate.mixin.Batching()` to enable batching mode

### Specific Example

Here's a simple example of manual batching.



```python
class TrainerManualBatching:
    """Manual Batching Trainer"""

    def __init__(self, n_in, n_hidden, n_out):
        self.model = LIFNet(n_in, n_hidden, n_out)
        self.optimizer = brainstate.optim.Adam(lr=1e-3)
        # Register trainable parameters
        self.optimizer.register_trainable_weights(self.model.states(brainstate.ParamState))

    @brainstate.transform.jit(static_argnums=0)
    def train(self, inputs, targets):
        """
        Single training step

        Args:
            inputs: Input sequence shape: (T, B, D)
            targets: Target labels shape: (B,)
        """
        # Step 1: Initialize batch model states
        brainstate.nn.init_all_states(self.model, batch_size=inputs.shape[1])

        # Step 2: Create online learning algorithm instance
        model = brainscale.ES_D_RTRL(
            self.model,
            decay_or_rank=0.9,                  # Eligibility trace decay factor
            mode=brainstate.mixin.Batching()    # Enable manual batching mode
        )

        # Step 3: Compile the computation graph, using one batched time step (B, D) as example input
        model.compile_graph(inputs[0])

        # Step 4: Get trainable parameters
        weights = self.model.states(brainstate.ParamState)

        def _etrace_grad(inp):
            """Compute the single-step loss (the function differentiated below)"""
            out = model(inp)
            loss = braintools.metric.softmax_cross_entropy_with_integer_labels(
                out, targets
            ).mean()
            return loss, out

        def _etrace_step(prev_grads, x):
            """Eligibility trace gradient accumulation step"""
            # Calculate current step gradients
            f_grad = brainstate.augment.grad(
                _etrace_grad, weights,
                has_aux=True, return_value=True
            )
            cur_grads, local_loss, out = f_grad(x)

            # Sum per-step gradients over time (each one is computed by ES_D_RTRL using its eligibility traces)
            next_grads = jax.tree.map(lambda a, b: a + b, prev_grads, cur_grads)
            return next_grads, (out, local_loss)

        # Step 5: Temporal forward propagation and gradient accumulation
        grads = jax.tree.map(u.math.zeros_like, weights.to_dict_values())
        grads, (outs, losses) = brainstate.compile.scan(_etrace_step, grads, inputs)

        # Step 6: Gradient clipping and parameter update
        grads = brainstate.functional.clip_grad_norm(grads, 1.0)
        self.optimizer.update(grads)

        return losses.mean()

    def f_train(self, n_epochs, inputs, targets):
        """Complete training process"""
        for epoch in range(n_epochs):
            loss = self.train(inputs, targets)
            print(f'Epoch {epoch + 1}/{n_epochs}, Loss: {loss:.4f}')
```

```python
# Create trainer and start training
trainer_manual = TrainerManualBatching(n_in, n_hidden, n_out)
trainer_manual.f_train(10, xs, ys)
```

```
Epoch 1/10, Loss: 4.8891
Epoch 2/10, Loss: 3.3748
Epoch 3/10, Loss: 2.9881
Epoch 4/10, Loss: 2.7006
Epoch 5/10, Loss: 2.7921
Epoch 6/10, Loss: 2.7267
Epoch 7/10, Loss: 2.5467
Epoch 8/10, Loss: 2.5019
Epoch 9/10, Loss: 2.4054
Epoch 10/10, Loss: 2.3258
```


### Code Explanation

**Data Flow Overview**

Before diving into the analysis, let's understand the entire batching data flow:

```
Input data: inputs(T, B, D) → Model states(B, M) → Batch computation → Gradient accumulation → Parameter update
```

Where:
- `T`: Number of time steps
- `B`: Batch size
- `D`: Input feature dimension
- `M`: Model state dimension

**🔑 Three Core Elements of Batch Operations**

**1. Batch State Initialization - Why is this critical?**

```python
brainstate.nn.init_all_states(self.model, batch_size=inputs.shape[1])
```

What this line of code actually does:

```python
# Original state (single sample)
# Membrane potential V:  shape (200,)
# Spike output:          shape (200,)

# After batching (128 samples)
# Membrane potential V:  shape (128, 200)
# Spike output:          shape (128, 200)
```

Why is this necessary?

- RNNs and SNNs carry state from one time step to the next
- Batching means processing 128 independent sequences simultaneously
- Each sequence therefore needs its own copy of the state

**2. Batching Mode - How does the algorithm perceive batching?**

```python
mode=brainstate.mixin.Batching()
```

This parameter tells the ES_D_RTRL algorithm:

```python
# Without Batching mode, the algorithm expects:
#   Input:  (100,)      a single sample
#   Output: (10,)       a single prediction

# With Batching mode, the algorithm expects:
#   Input:  (128, 100)  a batch of samples
#   Output: (128, 10)   a batch of predictions
```
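
Part of what makes this work on the model side is visible in `LIFNet.update`: it concatenates along `axis=-1`, and matrix products act on the last axis, so the same code accepts either a single sample or a batch with a leading axis. The stripped-down forward pass below (plain JAX, constant illustrative weights, no neuron dynamics, hypothetical function name `forward`) shows this shape behaviour; the `Batching()` mode is what tells ES_D_RTRL that such a leading axis is a batch axis.

```python
import jax.numpy as jnp

n_in, n_hidden, n_out = 100, 200, 10
W_in = jnp.full((n_in + n_hidden, n_hidden), 0.01)  # illustrative constant weights
W_out = jnp.full((n_hidden, n_out), 0.01)


def forward(x, spike):
    """Concatenate along the last axis, then project; leading axes pass through."""
    h = jnp.concatenate([x, spike], axis=-1) @ W_in
    return h @ W_out


y_single = forward(jnp.ones(n_in), jnp.zeros(n_hidden))               # one sample
y_batch = forward(jnp.ones((128, n_in)), jnp.zeros((128, n_hidden)))  # 128 samples
print(y_single.shape, y_batch.shape)  # (10,) (128, 10)
```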

**3. Temporal Batch Processing - The most complex part**

```python
grads, (outs, losses) = brainstate.compile.scan(_etrace_step, grads, inputs)
```

What happens here:

```python
# inputs shape: (16, 128, 100) - 16 time steps, 128 samples per step

# Time step 0:  process inputs[0]  -> (128, 100) -> update 128 states -> compute gradients
# Time step 1:  process inputs[1]  -> (128, 100) -> update 128 states -> accumulate gradients
# ...
# Time step 15: process inputs[15] -> (128, 100) -> update 128 states -> final gradients
```
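
The `scan` above can be imitated in plain JAX to see how per-step gradients are summed in the carry. This is a simplified sketch under stated assumptions: a leaky linear recurrence (decay 0.9) replaces the LIF network, softmax cross-entropy is written out by hand, and the recurrent state travels in the `scan` carry next to the gradient sum. Because `jax.value_and_grad` only differentiates with respect to `weights`, each step's gradient ignores how earlier steps shaped the state; in BrainScale, carrying that temporal credit into each per-step gradient is the job of ES_D_RTRL's eligibility traces, which this sketch does not implement.

```python
import jax
import jax.numpy as jnp

T, B, D, M, C = 16, 8, 6, 5, 3
decay = 0.9
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
weights = {'w': jax.random.normal(k1, (D, M)) * 0.1,
           'w_out': jax.random.normal(k2, (M, C)) * 0.1}
inputs = jax.random.normal(k3, (T, B, D))  # (T, B, D), time-major
targets = jnp.arange(B) % C                # (B,) integer labels


def step_loss(weights, v, x):
    """Single-step forward pass and mean cross-entropy loss."""
    v_new = decay * v + x @ weights['w']   # (B, M); v is not differentiated
    logits = v_new @ weights['w_out']       # (B, C)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss = -jnp.take_along_axis(log_probs, targets[:, None], axis=1).mean()
    return loss, v_new


def etrace_step(carry, x):
    v, prev_grads = carry
    (loss, v_new), cur_grads = jax.value_and_grad(step_loss, has_aux=True)(weights, v, x)
    grads = jax.tree_util.tree_map(jnp.add, prev_grads, cur_grads)  # running sum
    return (v_new, grads), loss


zero_grads = jax.tree_util.tree_map(jnp.zeros_like, weights)
(v_final, grads), losses = jax.lax.scan(
    etrace_step, (jnp.zeros((B, M)), zero_grads), inputs)

print(losses.shape)                            # (16,)
print(grads['w'].shape, grads['w_out'].shape)  # (6, 5) (5, 3)
```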


## Automatic Batching

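Automatic batching is built on `vmap`-style vectorization from `brainstate.transform`: the model's single-sample update function is vectorized over the batch dimension. The example below uses two helpers for this, `brainstate.transform.vmap_new_states` for state initialization and `brainstate.nn.Vmap` for the forward pass.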

### Core Features

Automatic batching vectorizes the computation with `vmap`:

1. **State Management**: Model states keep the single-sample shape `(M,)`
2. **Automatic Vectorization**: `vmap` maps the computation over the batch dimension
3. **Code Simplicity**: No need to handle the batch dimension by hand
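
For intuition, the same division of labour can be seen with plain `jax.vmap` (this is not BrainScale's `vmap_new_states`/`Vmap` machinery). The step function below is written for one sample only: state `(M,)`, input `(D,)`. Passing `in_axes=(None, 0, 0)` shares the weights across the batch while mapping over the leading axis of the state and the input. The leaky update, the decay value, and the name `single_step` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

B, D, M = 128, 100, 200
decay = 0.9  # illustrative leak factor


def single_step(w, v, x):
    """Single-sample logic: v has shape (M,), x has shape (D,)."""
    return decay * v + x @ w


w = jax.random.normal(jax.random.PRNGKey(0), (D, M)) * 0.01
batched_step = jax.vmap(single_step, in_axes=(None, 0, 0))  # share w; map v and x

v_batch = jnp.zeros((B, M))  # one state row per sample
x_batch = jnp.ones((B, D))
v_next = batched_step(w, v_batch, x_batch)
print(v_next.shape)  # (128, 200)
```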

### Specific Example


```python
class TrainerAutoBatching:
    """Automatic Batching Trainer"""

    def __init__(self, n_in, n_hidden, n_out):
        self.model = LIFNet(n_in, n_hidden, n_out)
        self.optimizer = brainstate.optim.Adam(lr=1e-3)
        self.optimizer.register_trainable_weights(self.model.states(brainstate.ParamState))

    @brainstate.transform.jit(static_argnums=0)
    def train(self, inputs, targets):
        """
        Single training step (automatic batching version)

        Args:
            inputs: Input sequence shape: (T, B, D)
            targets: Target labels shape: (B,)
        """
        # Step 1: Create online learning algorithm instance (no manual batching mode needed)
        model = brainscale.ES_D_RTRL(self.model, decay_or_rank=0.9)

        # Step 2: Use vmap to create batch state initialization function
        @brainstate.transform.vmap_new_states(
            axis_size=inputs.shape[1],   # Batch size
            state_tag='new',             # State tag (for distinguishing different state groups)
        )
        def init():
            """Initialize single sample model states"""
            brainstate.nn.init_all_states(self.model)
            model.compile_graph(inputs[0, 0])   # Compile graph with single sample

        # Execute batch initialization
        init()

        # Step 3: Create vectorized model wrapper
        vmap_model = brainstate.nn.Vmap(
            model,
            vmap_states='new'   # Specify state group to vectorize
        )

        # Step 4: Get trainable parameters
        weights = self.model.states(brainstate.ParamState)

        def _etrace_grad(inp):
            """Calculate single-step loss and gradients (auto-vectorized version)"""
            out = vmap_model(inp)   # Automatically handle batch dimensions
            loss = braintools.metric.softmax_cross_entropy_with_integer_labels(
                out, targets
            ).mean()
            return loss, out

        def _etrace_step(prev_grads, x):
            """Eligibility trace gradient accumulation step"""
            f_grad = brainstate.augment.grad(
                _etrace_grad, weights,
                has_aux=True, return_value=True
            )
            cur_grads, local_loss, out = f_grad(x)
            next_grads = jax.tree.map(lambda a, b: a + b, prev_grads, cur_grads)
            return next_grads, (out, local_loss)

        # Step 5: Temporal forward propagation and gradient accumulation
        grads = jax.tree.map(u.math.zeros_like, weights.to_dict_values())
        grads, (outs, losses) = brainstate.compile.scan(_etrace_step, grads, inputs)

        # Step 6: Gradient clipping and parameter update
        grads = brainstate.functional.clip_grad_norm(grads, 1.0)
        self.optimizer.update(grads)

        return losses.mean()

    def f_train(self, n_epochs, inputs, targets):
        """Complete training process"""
        for epoch in range(n_epochs):
            loss = self.train(inputs, targets)
            print(f'Epoch {epoch + 1}/{n_epochs}, Loss: {loss:.4f}')
```


```python
# Create trainer and start training
trainer_auto = TrainerAutoBatching(n_in, n_hidden, n_out)
trainer_auto.f_train(10, xs, ys)
```

```
Epoch 1/10, Loss: 4.0168
Epoch 2/10, Loss: 3.1031
Epoch 3/10, Loss: 2.8214
Epoch 4/10, Loss: 2.5791
Epoch 5/10, Loss: 2.6134
Epoch 6/10, Loss: 2.5754
Epoch 7/10, Loss: 2.4770
Epoch 8/10, Loss: 2.3772
Epoch 9/10, Loss: 2.3254
Epoch 10/10, Loss: 2.2870
```


### Code Explanation

**Core Concept Comparison**

Before diving into analysis, let's understand the fundamental differences between automatic and manual batching:

| Dimension | Manual Batching | Automatic Batching |
|-----------|-----------------|-------------------|
| **State Management** | Explicit batch states `(B, M)` | Single-sample states `(M,)` + automatic vectorization |
| **Computation Method** | Direct batch computation | Functional mapping with `vmap` |
| **Code Complexity** | Batch dimensions handled by hand | Batch details abstracted away |

```python
# Conceptual comparison
# Manual batching:    create 128 copies of the neuron state and compute on all of them at once
# Automatic batching: define the computation for one sample; vmap replicates it 128 times
```
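
The two approaches compute the same thing, which a short plain-JAX experiment can confirm. Below, the "manual" version carries a `(B, M)` state through `jax.lax.scan`, while the "automatic" version describes the whole sequence for one sample and uses `jax.vmap` with `in_axes=1` to map over the batch axis of the time-major `(T, B, D)` input. The leaky recurrence, sizes, and the names `step` and `run_single` are illustrative, not BrainScale code.

```python
import jax
import jax.numpy as jnp

T, B, D, M = 16, 8, 4, 6
decay = 0.9
w = jax.random.normal(jax.random.PRNGKey(0), (D, M))
xs = jax.random.normal(jax.random.PRNGKey(1), (T, B, D))  # time-major, like this tutorial


def step(v, x):
    """Leaky update; works for v (M,) with x (D,) or v (B, M) with x (B, D)."""
    v = decay * v + x @ w
    return v, v


# Manual batching: the state explicitly carries the batch axis, shape (B, M)
_, manual = jax.lax.scan(step, jnp.zeros((B, M)), xs)  # (T, B, M)


# Automatic batching: describe one sample's whole sequence...
def run_single(x_seq):                                 # x_seq: (T, D)
    return jax.lax.scan(step, jnp.zeros(M), x_seq)[1]  # (T, M)


# ...then let vmap map over axis 1 (the batch axis) of xs
auto = jax.vmap(run_single, in_axes=1, out_axes=1)(xs)  # (T, B, M)

print(manual.shape, auto.shape)               # (16, 8, 6) (16, 8, 6)
print(jnp.allclose(manual, auto, atol=1e-5))  # True
```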

**Detailed Analysis of Key Steps**

**Step 1: Algorithm Instance Creation (Simplified Mode)**

```python
# Step 1: Create online learning algorithm instance (no manual batching mode needed)
model = brainscale.ES_D_RTRL(self.model, decay_or_rank=0.9)
```

**Automatic Batching Key Point:**

Notice here we do **NOT** use `mode=brainstate.mixin.Batching()`:

```python
# Manual batching version
model = brainscale.ES_D_RTRL(
    self.model,
    decay_or_rank=0.9,
    mode=brainstate.mixin.Batching()  # Explicitly enable batching
)

# Automatic batching version
model = brainscale.ES_D_RTRL(self.model, decay_or_rank=0.9)
# Algorithm thinks it's processing a single sample!
```

**Step 2: vmap State Initialization (Core Mechanism) - The Magic of State Vectorization**

```python
# Step 2: Use vmap to create batch state initialization function
@brainstate.transform.vmap_new_states(
    axis_size=inputs.shape[1],                   # Batch size
    state_tag='new',                            # State tag (for distinguishing different state groups)
)
def init():
    """Initialize single sample model states"""
    brainstate.nn.init_all_states(self.model)   # Initializes the states of ONE sample
    model.compile_graph(inputs[0, 0])           # Compile graph with single sample

# Execute batch initialization
init()
```

Key points:
- The function `init()` only knows how to initialize the states of **one sample**
- `vmap_new_states` vectorizes it over a new batch axis of size `inputs.shape[1]` (128 here), which behaves as if the function were run once per sample
- Result: 128 independent state copies, even though the code describes only a single sample
- The state tag `'new'` marks the states created this way, so they can be selected later (as `brainstate.nn.Vmap(..., vmap_states='new')` does in Step 3)
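
Plain JAX has no `vmap_new_states`, but the core idea can be approximated: write an initializer for one sample and let `jax.vmap` add the batch axis. The sketch maps over a batch of PRNG keys so that every sample gets its own initial state; `init_single`, the state names `'V'` and `'rate'`, and the random initialization are hypothetical. Unlike BrainScale's decorator, it simply returns the batched pytree and does not create or tag any `State` objects.

```python
import jax
import jax.numpy as jnp

B, M, n_out = 128, 200, 10


def init_single(key):
    """Hypothetical single-sample initializer: 'V' has shape (M,), 'rate' has shape (n_out,)."""
    return {'V': jax.random.uniform(key, (M,), minval=-1.0, maxval=0.0),
            'rate': jnp.zeros(n_out)}


keys = jax.random.split(jax.random.PRNGKey(0), B)  # one key per sample
batch_state = jax.vmap(init_single)(keys)          # every leaf gains a leading axis of size B

print(batch_state['V'].shape, batch_state['rate'].shape)  # (128, 200) (128, 10)
```

Leaves that do not depend on the key, such as `'rate'` here, are broadcast by `vmap` so that they also gain the batch axis.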

**Step 3: Vmap Model Wrapper - Automatic Single Sample→Batch Conversion**

```python
vmap_model = brainstate.nn.Vmap(model, vmap_states='new')

# Internal flow when called:
# Input: (128, 100)
#   ↓ vmap splits the batch axis
# 128 per-sample computations: each maps (100,) → (10,)
#   ↓ vmap stacks the results
# Output: (128, 10)
```

When you call `vmap_model(inp)`:

```python
# inp.shape = (128, 100)

# Step 1: vmap decomposes input
sample_0 = inp[0]   # (100,)
sample_1 = inp[1]   # (100,)
...
sample_127 = inp[127] # (100,)

# Step 2: Apply the model to each sample with its own state (conceptually a loop; vmap actually runs one vectorized computation)
result_0 = model_with_state_0(sample_0)   # (10,)
result_1 = model_with_state_1(sample_1)   # (10,)
...
result_127 = model_with_state_127(sample_127) # (10,)

# Step 3: vmap combines output
output = stack([result_0, result_1, ..., result_127])  # (128, 10)
```
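
The conceptual loop above can be checked directly with `jax.vmap`. In this sketch `model_single` is a hypothetical stateless stand-in for the wrapped model (a `tanh` readout with illustrative weights); it is written for one `(100,)` sample, and `vmap` produces the same `(128, 10)` result as stacking 128 separate calls, but as a single vectorized computation. The BrainScale wrapper additionally maps the per-sample states, which this sketch leaves out.

```python
import jax
import jax.numpy as jnp

B, D, C = 128, 100, 10
w = jax.random.normal(jax.random.PRNGKey(0), (D, C)) * 0.1  # illustrative readout weights


def model_single(x):
    """Written for one sample: x has shape (D,), output has shape (C,)."""
    return jnp.tanh(x @ w)


inp = jax.random.normal(jax.random.PRNGKey(1), (B, D))  # (128, 100)

out_vmap = jax.vmap(model_single)(inp)                          # one vectorized call
out_loop = jnp.stack([model_single(inp[i]) for i in range(B)])  # the conceptual loop

print(out_vmap.shape)                              # (128, 10)
print(jnp.allclose(out_vmap, out_loop, atol=1e-5)) # True
```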

**Key Insight:** The model function always behaves as if it were processing a single sample; it is unaware of batching.

The benefit of this design is that you can write the network logic in its simplest, most intuitive form (one sample) and still get efficient batched execution.



## Summary

BrainScale's two batching strategies each have their advantages. Manual batching gives finer control over the batch dimension and can be faster, which suits large-scale production training; automatic batching keeps the model code in single-sample form behind a cleaner API, which makes it convenient for research and prototyping.

The right choice depends on the application, the performance requirements, and the development effort you can afford. A practical approach is to start with automatic batching to validate ideas quickly, then consider migrating to manual batching when optimizing performance.