# `ETraceState`: Online Learning State Management

In the `brainscale` framework, the `ETraceState` class family provides state management designed for **online learning based on eligibility traces**. An eligibility trace is a decaying record of recent neuron and synapse activity. The concept originates in reinforcement learning. It lets a learning rule assign credit for the current error to past activity without storing the full activity history, which is what makes step-by-step (online) parameter updates possible.
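As a rough illustration of the idea (not brainscale's actual trace computation), an eligibility trace can be kept as an exponentially decaying running sum of activity, $e_t = \lambda e_{t-1} + x_t$. The function name `eligibility_trace`, the decay factor `lam`, and the spike train below are hypothetical choices for this sketch.

```python
import numpy as np

def eligibility_trace(inputs, lam):
    """Hypothetical helper: return e_t = lam * e_{t-1} + x_t for every step, with e_{-1} = 0."""
    trace = np.zeros_like(inputs[0], dtype=float)
    history = []
    for x in inputs:
        trace = lam * trace + x  # decay the old activity, then add the new activity
        history.append(trace.copy())
    return np.stack(history)

# One presynaptic spike at t = 0, followed by silence
spikes = np.array([[1.0], [0.0], [0.0], [0.0]])
print(eligibility_trace(spikes, lam=0.5).ravel().tolist())  # [1.0, 0.5, 0.25, 0.125]
```

The trace remembers the spike for a few steps with geometrically shrinking weight, so a learning signal arriving later can still be credited to it.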

## Core Features

- **State Tracking**: Records the dynamical states of neurons and synapses at every time step
- **Online Learning**: Supports online parameter updates based on eligibility traces
- **Flexible Architecture**: Applies to single neurons, multi-compartment models, and complex tree-structured states
- **High-Performance Computing**: Built on JAX for fast numerical computation


```python
import brainscale
import brainstate
import brainunit as u
import jax.numpy as jnp
```

## `brainscale.ETraceState` Class: Single State Management

The [`brainscale.ETraceState`](../apis/generated/brainscale.ETraceState.rst) class is a subclass of `brainstate.HiddenState` designed to manage a single state variable of a neuron or synapse. Each instance represents exactly one state variable, which keeps state management explicit and easy to control.


### Practical Application: GIF Neuron Model

The **Generalized Integrate-and-Fire (GIF) neuron** extends the integrate-and-fire model with adaptation currents and a dynamic threshold. Its dynamics are:

$$
\begin{aligned}
\frac{\mathrm{d} I_1}{\mathrm{d} t} &= - k_1 I_1 \quad \text{(Adaptation current 1)} \\
\frac{\mathrm{d} I_2}{\mathrm{d} t} &= - k_2 I_2 \quad \text{(Adaptation current 2)} \\
\tau \frac{\mathrm{d} V}{\mathrm{d} t} &= - (V - V_{\mathrm{rest}}) + R\sum_{j}I_j + RI \quad \text{(Membrane potential)} \\
\frac{\mathrm{d} V_{\mathrm{th}}}{\mathrm{d} t} &= a(V - V_{\mathrm{rest}}) - b(V_{\mathrm{th}} - V_{\mathrm{th}\infty}) \quad \text{(Dynamic threshold)}
\end{aligned}
$$

When $V > V_{\mathrm{th}}$, the neuron fires and its state is reset:

$$
\begin{aligned}
I_1 &\leftarrow R_1 I_1 + A_1 \\
I_2 &\leftarrow R_2 I_2 + A_2 \\
V &\leftarrow V_{\mathrm{reset}} \\
V_{\mathrm{th}} &\leftarrow \max(V_{\mathrm{th,reset}}, V_{\mathrm{th}})
\end{aligned}
$$

where $V$ is the membrane potential, $\tau$ is the membrane time constant, $V_{\mathrm{rest}}$ is the resting potential, $R$ is the membrane resistance, $I$ is the external input current, $V_{\mathrm{th}}$ is the threshold potential, $V_{\mathrm{th}\infty}$ is the resting threshold potential, and $a$ and $b$ are threshold dynamics parameters. The $I_j$ are internal adaptation currents (the model allows any number of them; two are used here), $k_j$ are their decay rates, $R_j$ are the multiplicative factors applied to $I_j$ at a spike, and $A_j$ are the spike-triggered increments of $I_j$. Finally, $V_{\mathrm{reset}}$ is the reset potential and $V_{\mathrm{th,reset}}$ is the threshold reset potential.
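To make the update and reset rules concrete, here is a minimal forward-Euler sketch of one GIF time step in plain NumPy. It is not brainscale's or brainstate's integrator: units are dropped, $R$ is set to 1 so that $RI$ is in mV, and the function name `gif_step` and all parameter values (including `V_th_reset`) are hypothetical.

```python
import numpy as np

def gif_step(I1, I2, V, V_th, I_ext, dt, p):
    """Hypothetical sketch: one forward-Euler step of the GIF equations, then the spike reset."""
    # Right-hand sides of the four ODEs, all evaluated at the current state
    dI1 = -p["k1"] * I1
    dI2 = -p["k2"] * I2
    dV = (-(V - p["V_rest"]) + p["R"] * (I1 + I2) + p["R"] * I_ext) / p["tau"]
    dVth = p["a"] * (V - p["V_rest"]) - p["b"] * (V_th - p["V_th_inf"])
    I1, I2 = I1 + dt * dI1, I2 + dt * dI2
    V, V_th = V + dt * dV, V_th + dt * dVth

    # Neurons with V > V_th fire; apply the reset rules element-wise
    spike = V > V_th
    I1 = np.where(spike, p["R1"] * I1 + p["A1"], I1)
    I2 = np.where(spike, p["R2"] * I2 + p["A2"], I2)
    V = np.where(spike, p["V_reset"], V)
    V_th = np.where(spike, np.maximum(p["V_th_reset"], V_th), V_th)
    return I1, I2, V, V_th, spike

# Hypothetical, unitless parameters (mV, ms, nA with R = 1)
params = dict(tau=20.0, R=1.0, V_rest=-70.0, V_reset=-80.0, V_th_inf=-50.0,
              V_th_reset=-60.0, k1=0.1, k2=0.05, R1=0.9, R2=0.8,
              A1=10.0, A2=5.0, a=0.1, b=0.02)

# Neuron 0 rests at V_rest; neuron 1 starts above threshold and should fire
I1, I2, V, V_th, spike = gif_step(
    I1=np.zeros(2), I2=np.zeros(2), V=np.array([-70.0, -40.0]),
    V_th=np.full(2, -50.0), I_ext=np.zeros(2), dt=0.1, p=params,
)
print(spike.tolist(), V.tolist())  # [False, True] [-70.0, -80.0]
```

The spiking neuron has its adaptation currents bumped to $A_1$ and $A_2$ and its potential set to $V_{\mathrm{reset}}$, while its threshold keeps the larger of the integrated value and $V_{\mathrm{th,reset}}$.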

The GIF model therefore needs four `brainscale.ETraceState` instances, one for each of its state variables:



```python
class GIF(brainstate.nn.Neuron):
    """
    Generalized Integrate-and-Fire Neuron Model

    Implements a neuron model with two adaptation currents and a dynamic threshold.
    Only the parameters and states are defined here; the update rule is omitted.
    """

    def __init__(self, size, **kwargs):
        super().__init__(size, **kwargs)

        # Model parameters
        self.tau = 20.0 * u.ms  # Membrane time constant
        self.R = 100.0 * u.ohm  # Membrane resistance
        self.V_rest = -70.0 * u.mV  # Resting potential
        self.V_reset = -80.0 * u.mV  # Reset potential
        self.V_th_inf = -50.0 * u.mV  # Resting threshold

        # Adaptation parameters
        self.k1 = 0.1 / u.ms  # Adaptation current 1 decay rate
        self.k2 = 0.05 / u.ms  # Adaptation current 2 decay rate
        self.R1, self.R2 = 0.9, 0.8  # Multiplicative factors applied to I1, I2 at a spike
        self.A1 = 10.0 * u.nA  # Spike-triggered increment of I1
        self.A2 = 5.0 * u.nA  # Spike-triggered increment of I2

        # Threshold dynamics parameters
        self.a = 0.1 / u.ms  # Coupling of the threshold to V
        self.b = 0.02 / u.ms  # Rate of threshold recovery toward V_th_inf

    def init_state(self, *args, **kwargs):
        # Adaptation currents (initialized to zero)
        self.I1 = brainscale.ETraceState(jnp.zeros(self.varshape) * u.nA)
        self.I2 = brainscale.ETraceState(jnp.zeros(self.varshape) * u.nA)

        # Membrane potential: Gaussian around V_rest with a 2 mV standard deviation
        # (normal takes loc, scale, size, so the shape must be passed as size)
        self.V = brainscale.ETraceState(
            brainstate.random.normal(0.0, 2.0, size=self.varshape) * u.mV + self.V_rest
        )

        # Dynamic threshold: drawn uniformly from [-52, -48) mV
        self.V_th = brainscale.ETraceState(
            brainstate.random.uniform(-52.0, -48.0, self.varshape) * u.mV
        )
```

```python
# Create and initialize a population of 100 GIF neurons
gif_neurons = GIF(size=100)
gif_neurons.init_state()
```

Each `brainscale.ETraceState` instance holds one state variable, corresponding to $I_1$, $I_2$, $V$, and $V_{\mathrm{th}}$ of the GIF model respectively. An instance never holds more than one state of a neuron or synapse, which is reflected in its `num_state` attribute:

```python
print(f"Dimensions managed by each state variable: {gif_neurons.I1.num_state}")
print(f"Neuron population shape: {gif_neurons.I1.varshape}")
```

```
Dimensions managed by each state variable: 1
Neuron population shape: (100,)
```
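The two attributes printed above can be mimicked with a tiny stand-in class. `SingleStateSketch` is hypothetical and only reproduces the shape bookkeeping suggested by the output (the whole array is one state, so `num_state` is 1 and `varshape` is the array's shape); it is not how `brainscale.ETraceState` is implemented.

```python
import numpy as np

class SingleStateSketch:
    """Hypothetical stand-in mimicking how a single-variable trace state reports its shape."""

    def __init__(self, value):
        self.value = np.asarray(value)

    @property
    def varshape(self):
        return self.value.shape  # the whole array is one state variable

    @property
    def num_state(self):
        return 1  # exactly one state per instance

V_state = SingleStateSketch(np.full(100, -70.0))
print(V_state.num_state, V_state.varshape)  # 1 (100,)
```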


### Key Properties

- **Independence**: Each `ETraceState` instance manages exactly one state variable
- **Unit Awareness**: Values can carry physical units from `brainunit`, so dimensional mismatches are caught during computation
- **Trace Support**: The variable is marked as a hidden state for which `brainscale` computes eligibility traces

## `brainscale.ETraceGroupState` Class: Group State Management

The [`brainscale.ETraceGroupState`](../apis/generated/brainscale.ETraceGroupState.rst) class manages a group of related states of a neuron or synapse population (for example, the membrane potential of every compartment) in a single object. It is a subclass of `brainscale.ETraceState` and inherits all of its attributes and methods.

In a multi-compartment neuron model, a variable such as the membrane potential has one value per compartment. Defining each compartment's potential as a separate `brainscale.ETraceState` would require one state object per compartment, which quickly makes the code verbose and hard to maintain. `brainscale.ETraceGroupState` instead combines these values into a single state variable.
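The contrast can be sketched with plain NumPy arrays standing in for state objects. The layout with compartments on the last axis is an assumption here, chosen because it matches the shapes printed for the example in this section (population shape `(10,)`, 6 compartments); the resting value of -65 is arbitrary.

```python
import numpy as np

pop_size, n_comp = 10, 6

# Verbose alternative: one separate array (one state object) per compartment
V_separate = [np.full(pop_size, -65.0) for _ in range(n_comp)]

# Grouped alternative: a single array whose last axis indexes compartments
V_grouped = np.stack(V_separate, axis=-1)
print(V_grouped.shape)  # (10, 6)

# Any single compartment is still available by slicing the last axis
soma_V = V_grouped[..., 0]
print(soma_V.shape)  # (10,)
```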

In the following example, we use `brainscale.ETraceGroupState` to define the membrane potential of a neuron built from three sections (soma, axon, and dendrite).

### Multi-Compartment Neuron Modeling

```python
import braincell

# Instantiate a Morphology object
morphology = braincell.Morphology()

# Create individual sections using the creation methods
morphology.add_cylinder_section('soma', length=20 * u.um, diam=10 * u.um, nseg=1)  # Soma section
morphology.add_cylinder_section('axon', length=100 * u.um, diam=1 * u.um, nseg=2)  # Axon section
morphology.add_point_section(
    'dendrite',
    positions=[[0, 0, 0], [100, 0, 0], [200, 0, 0]] * u.um,
    diams=[2, 3, 2] * u.um,
    nseg=3
)  # Dendrite section with explicit points and diameters

# Connect the sections: axon and dendrite are both attached to the soma
morphology.connect('axon', 'soma', parent_loc=1.0)  # Axon connects to the end of the soma
morphology.connect('dendrite', 'soma', parent_loc=1.0)  # Dendrite connects to the end of the soma

# Display a summary of the morphology
morphology
```

```
Morphology(
  sections={
    'soma': Section<name='soma', nseg=1, points=2, Ra=100.0 * ohm * cmetre, cm=1.0 * ufarad / cmeter2, parent=None, parent_loc=None>,
    'axon': Section<name='axon', nseg=2, points=2, Ra=100.0 * ohm * cmetre, cm=1.0 * ufarad / cmeter2, parent='soma', parent_loc=1.0>,
    'dendrite': Section<name='dendrite', nseg=3, points=3, Ra=100.0 * ohm * cmetre, cm=1.0 * ufarad / cmeter2, parent='soma', parent_loc=1.0>
  }
)
```

```python
class ThreeCompartmentNeuron(braincell.MultiCompartment):
    def __init__(self, pop_size, morphology):
        super().__init__(pop_size, morphology=morphology)

    def init_state(self, *args, **kwargs):
        # One grouped state holding the membrane potential of every compartment
        self.V = brainscale.ETraceGroupState(jnp.zeros(self.varshape) * u.mV)
```

```python
multi_neuron = ThreeCompartmentNeuron(10, morphology)
multi_neuron.init_state()
```

The `brainscale.ETraceGroupState` instance holds the membrane potential of every compartment, so its `num_state` counts compartments rather than being fixed at 1. This example defines only $V$, but other per-compartment variables, such as adaptation currents $I_j$, can be defined the same way.

```python
print(f"Total number of compartments: {multi_neuron.V.num_state}")
print(f"State shape: {multi_neuron.V.varshape}")
print(f"Neuron population size: {multi_neuron.pop_size}")
```

```
Total number of compartments: 6
State shape: (10,)
Neuron population size: (10,)
```
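The value 6 is the total number of segments in the morphology: `nseg` is 1 for the soma, 2 for the axon, and 3 for the dendrite. Assuming, consistently with the output above, that the grouped value has shape `(pop_size, n_compartments)` and that the last axis counts the states, the printed numbers decompose as follows. This is a plain-Python sketch, not braincell's or brainscale's internal bookkeeping.

```python
# nseg values passed to add_cylinder_section / add_point_section above
nseg = {"soma": 1, "axon": 2, "dendrite": 3}
n_compartments = sum(nseg.values())

pop_size = 10
value_shape = (pop_size, n_compartments)  # assumed layout of the grouped value

# Leading axes give the population shape; the last axis counts grouped states
varshape, num_state = value_shape[:-1], value_shape[-1]
print(n_compartments, varshape, num_state)  # 6 (10,) 6
```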


### Advantages

- **Unified Management**: A single state object holds all compartments, simplifying code structure
- **Spatial Consistency**: All compartments of a variable live in one array with a fixed ordering, so their spatial relationships stay aligned
- **Efficient Computation**: Updates act on all compartments at once through vectorized operations
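A minimal NumPy sketch of the vectorization point: with all compartments in one `(population, compartments)` array, a leak update is a single array expression, and it matches a per-compartment loop of the kind separate state objects would require. The leak-only dynamics and the values of `V_rest`, `tau`, and `dt` are hypothetical, and inter-compartment coupling is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
V = rng.normal(-65.0, 2.0, size=(10, 6))  # (population, compartments)
V_rest, tau, dt = -65.0, 10.0, 0.1

# One vectorized leak step over every neuron and compartment at once
V_vec = V + dt * (-(V - V_rest) / tau)

# Equivalent per-compartment loop
V_loop = V.copy()
for c in range(V.shape[-1]):
    V_loop[:, c] = V[:, c] + dt * (-(V[:, c] - V_rest) / tau)

print(np.allclose(V_vec, V_loop))  # True
```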

## `brainscale.ETraceTreeState` Class: Tree Structure State

[`brainscale.ETraceTreeState`](../apis/generated/brainscale.ETraceTreeState.rst) is the most flexible option: its value can be a **PyTree** (for example, a nested dictionary of arrays), which suits models with complex hierarchical state. It is a subclass of `brainscale.ETraceState` and inherits all of its attributes and methods.

The following example defines the states of the GIF model again, this time as a single tree-structured `brainscale.ETraceTreeState`.

### Advanced Application Example

```python
class GIF_tree(brainstate.nn.Neuron):
    def init_state(self, *args, **kwargs):
        # All four GIF states stored as leaves of one dictionary
        self.state = brainscale.ETraceTreeState(
            {
                'I1': jnp.zeros(self.varshape) * u.mA,
                'I2': jnp.zeros(self.varshape) * u.mA,
                'V': brainstate.random.random(self.varshape) * u.mV,
                'Vth': brainstate.random.uniform(1, 2, self.varshape) * u.mV
            }
        )
```

```python
gif_tree = GIF_tree(5)
gif_tree.init_state()
```

A `brainscale.ETraceTreeState` instance holds a tree of sub-states. Here the tree is a dictionary with four leaves, $I_1$, $I_2$, $V$, and $V_{\mathrm{th}}$, and `num_state` counts them:

```python
print(f"Number of independent states in state tree: {gif_tree.state.num_state}")
```

```
Number of independent states in state tree: 4
```
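The printed count matches the number of leaves in the dictionary. The sketch below counts leaves with a small hypothetical helper, `count_leaves`, over nested dicts of NumPy arrays; in JAX code the same flattening is normally done with `jax.tree_util.tree_leaves`. It also shows that regrouping the leaves into a deeper hierarchy does not change the count.

```python
import numpy as np

def count_leaves(tree):
    """Hypothetical helper: count array leaves in a nested dict."""
    if isinstance(tree, dict):
        return sum(count_leaves(v) for v in tree.values())
    return 1

flat = {"I1": np.zeros(5), "I2": np.zeros(5), "V": np.zeros(5), "Vth": np.ones(5)}

# The same four leaves organized hierarchically
nested = {"currents": {"I1": np.zeros(5), "I2": np.zeros(5)},
          "V": np.zeros(5), "Vth": np.ones(5)}

print(count_leaves(flat), count_leaves(nested))  # 4 4
```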


### Advantages of Tree Structure

- **Hierarchical Organization**: Logically clear organization of complex state variables
- **Flexible Access**: Support for nested access and partial updates
- **Extensibility**: New state variables and functional modules are easy to add
- **Type Diversity**: Support for different types and shapes of state variables
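A plain-dict sketch of the partial-update and type-diversity points, in the functional style common in JAX code: build a new tree that replaces one leaf and reuses the rest. The leaf `adapt` with a different shape is hypothetical, and this illustrates the idea only; it does not show how `ETraceTreeState` itself exposes its value for reading and writing.

```python
import numpy as np

state = {
    "I1": np.zeros(5),
    "I2": np.zeros(5),
    "V": np.full(5, -70.0),
    "Vth": np.full(5, -50.0),
    "adapt": np.zeros((5, 2)),  # hypothetical leaf with a different shape
}

# Update only the membrane potential; every other leaf is reused unchanged
new_state = {**state, "V": state["V"] + 1.0}

print(new_state["V"][0], new_state["I1"] is state["I1"])  # -69.0 True
```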


## Summary

The `ETraceState` class family in `brainscale` covers three levels of state organization for neural network modeling:

| Type | Applicable Scenarios | Advantages | Typical Applications |
|------|---------------------|------------|---------------------|
| `ETraceState` | Single state variable | Simple and intuitive, type-safe | Basic neuron models |
| `ETraceGroupState` | Homogeneous multi-state | Unified management, efficient computation | Multi-compartment neurons |
| `ETraceTreeState` | Complex hierarchical structures | Flexible organization, easy to extend | Advanced neural network models |

Choosing the type that matches the structure of a model's state keeps the code readable and maintainable, and lets related states be stored and updated together, which provides a solid foundation for building complex neural network models.