# Key Concepts


Welcome to `brainscale`!

BrainScale is a Python library designed for implementing online learning in neural network models with dynamics. Online learning represents a learning paradigm that enables continuous parameter updates as models receive new data streams. This approach proves particularly valuable in real-world applications, including robotic control systems, agent decision-making processes, and large-scale data stream processing.


This section introduces the key concepts needed to understand and use the online learning methods defined in ``brainscale``. These concepts include:

- Concepts for building high-level neural networks.
- Concepts for customizing neural network modules: ``ETraceState`` for hidden states, ``ETraceParam`` for weight parameters, and ``ETraceOp`` for input-to-hidden transitions.
- Concepts for online learning algorithms: ``ETraceAlgorithm``.

``brainscale`` is seamlessly integrated in the [brain dynamics programming ecosystem](https://ecosystem-for-brain-dynamics.readthedocs.io/) centred on ``brainstate``. We strongly recommend that you first familiarise yourself with [basic usage of ``brainstate``](https://brainstate.readthedocs.io/), as this will help you better understand how ``brainscale`` works.


In [1]:
import brainstate as bst
import brainunit as u

import brainscale

## 1. Dynamical Models Supported in ``brainscale``

``BrainScale`` does not support online learning for arbitrary dynamical models. The dynamical models currently supported by ``BrainScale`` exhibit a specific architectural constraint, as illustrated in the figure below, wherein the "dynamics" and "interactions between dynamics" are strictly segregated. Models adhering to this architecture can be decomposed into two primary components:

- **Dynamics**: This component characterizes the intrinsic dynamics of neurons, encompassing models such as the Leaky Integrate-and-Fire (LIF) neuron model, the FitzHugh-Nagumo model, and Long Short-Term Memory (LSTM) networks. The update of dynamics (hidden states) is implemented through strictly element-wise operations, although the model may incorporate multiple hidden states.

- **Interaction between Dynamics**: This component defines the interactions between neurons, implemented through weight matrices or connectivity matrices. The interactions between model dynamics can be realized through standard matrix multiplication, convolutional operations, or sparse operations.

![](../_static/model-dynamics-supported.png)




To elucidate the class of dynamical models supported by BrainScale, let us examine a fundamental Leaky Integrate-and-Fire (LIF) neural network model. The dynamics of this network are governed by the following differential equations:

$$
\begin{aligned}
\tau \frac{dv_i}{dt} &= -v_i + I_{\text{ext}} + s_i \\
\tau_s \frac{ds_i}{dt} &= -s_i + \sum_{j} w_{ij} \delta(t - t_j)
\end{aligned}
$$

Here, $v_i$ represents the membrane potential of the neuron. When this potential exceeds a threshold value $v_{th}$, the neuron generates an action potential and its membrane potential is reset to $v_{\text{reset}}$, as described by:

$$
\begin{aligned}
z_i & = \mathcal{H}(v_i-v_{th}) \\
v_i & \leftarrow v_{\text{reset}}
\end{aligned}
$$

Additionally, $s_i$ denotes the postsynaptic current, $I_{\text{ext}}$ represents the external input current, $w_{ij}$ is the synaptic weight from neuron $j$ to neuron $i$, and $\delta(t - t_j)$ is the Dirac delta function marking a synaptic event received from neuron $j$ at time $t_j$. The time constants $\tau$ and $\tau_s$ characterize the temporal evolution of the membrane potential and postsynaptic current, respectively.

Through numerical integration, we discretize the above differential equations and express them in vector form, yielding the following update rules:

$$
\begin{aligned}
\mathbf{v}^{t+1} &= \mathbf{v}^{t} + \frac{\Delta t}{\tau} (-\mathbf{v}^{t} + \mathbf{I}_{\text{ext}} + \mathbf{s}^t) \\
\mathbf{s}^{t+1} &= \mathbf{s}^{t} + \frac{\Delta t}{\tau_s} (-\mathbf{s}^{t} + \underbrace{  W \mathbf{z}^t  } _ {\text{neuronal interaction}} )
\end{aligned}
$$

Notably, the dynamics of the LIF neurons are updated through element-wise operations, while the interaction component is implemented via matrix multiplication. All dynamical models supported by BrainScale can be decomposed into similar `dynamics` and `interaction` components. It is particularly worth noting that this architecture encompasses the majority of recurrent neural network models, thus enabling BrainScale to support online learning for a wide range of recurrent neural network architectures.
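The split between element-wise dynamics and matrix-based interaction can be made concrete with a small NumPy sketch of one discretized step. This is only an illustration of the update rules above: the function name `lif_step`, the parameter values, and the hard reset are assumptions chosen for the example, and none of it is BrainScale code.

```python
import numpy as np


def lif_step(v, s, z, W, I_ext, dt=0.1, tau=5.0, tau_s=10.0, v_th=1.0, v_reset=0.0):
    """One forward-Euler step of the LIF network (illustrative sketch)."""
    # Interaction between dynamics: the only place where neurons are mixed.
    syn_input = W @ z
    # Dynamics: strictly element-wise updates of the two hidden states.
    v_new = v + dt / tau * (-v + I_ext + s)
    s_new = s + dt / tau_s * (-s + syn_input)
    # Spike generation and hard reset, also element-wise.
    z_new = (v_new > v_th).astype(v.dtype)
    v_new = np.where(z_new > 0, v_reset, v_new)
    return v_new, s_new, z_new


rng = np.random.default_rng(0)
n = 4
W = rng.normal(size=(n, n))
v, s, z = np.zeros(n), np.zeros(n), np.zeros(n)
for _ in range(100):
    v, s, z = lif_step(v, s, z, W, I_ext=1.5)
```

Only the line computing `syn_input` couples different neurons. Every other line acts on each neuron independently, which is the structural property the supported model class relies on.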

## 2. `brainscale.nn`: Constructing Neural Networks with Online Learning Support

The construction of neural network models supporting online learning in BrainScale follows identical conventions as those employed in `brainstate`. For comprehensive tutorials, please refer to the documentation on [Building Artificial Neural Networks](https://brainstate.readthedocs.io/en/latest/tutorials/artificial_neural_networks-en.html) and [Building Spiking Neural Networks](https://brainstate.readthedocs.io/en/latest/tutorials/spiking_neural_networks-en.html).

The sole distinction is that models must be built from components in the [`brainscale.nn` module](../apis/nn.rst). These components extend the `brainstate.nn` module and are specifically engineered to provide modular units with online learning capabilities.

Below, we present a basic implementation demonstrating the construction of a Leaky Integrate-and-Fire (LIF) neural network using the `brainscale.nn` module.

In [2]:
class LIF_Delta_Net(bst.nn.Module):
    def __init__(
        self,
        n_in,
        n_rec,
        tau_mem=5. * u.ms,
        V_th=1. * u.mV,
        spk_fun=bst.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        rec_scale: float = 1.,
        ff_scale: float = 1.,
    ):
        super().__init__()

        # Using the LIF model in brainscale.nn
        self.neu = brainscale.nn.LIF(n_rec, tau=tau_mem, spk_fun=spk_fun, spk_reset=spk_reset, V_th=V_th)

        # Constructing input and recurrent connection weights
        rec_init = bst.init.KaimingNormal(rec_scale, unit=u.mV)
        ff_init = bst.init.KaimingNormal(ff_scale, unit=u.mV)
        w_init = u.math.concatenate([ff_init([n_in, n_rec]), rec_init([n_rec, n_rec])], axis=0)

        # Using delta synaptic projections to construct input and recurrent connections
        self.syn = bst.nn.DeltaProj(
            # Using the Linear model in brainscale.nn
            comm=brainscale.nn.Linear(n_in + n_rec, n_rec, w_init=w_init, b_init=bst.init.ZeroInit(unit=u.mV)),
            post=self.neu
        )

    def update(self, spk):
        inp = u.math.concatenate([spk, self.neu.get_spike()], axis=-1)
        self.syn(inp)
        self.neu()
        return self.neu.get_spike()

In this exemplar implementation, we define a `LIF_Delta_Net` class that inherits from `bst.nn.Module`. The architecture incorporates two primary components: a Leaky Integrate-and-Fire (LIF) neuron model implemented as `self.neu`, and a `DeltaProj` module designated as `self.syn` which manages both input and recurrent connectivity.

Next, we construct a multi-layer Gated Recurrent Unit (GRU) neural network model, whose depth is set by the `n_layer` argument:

In [3]:
class GRU_Net(bst.nn.Module):
    def __init__(
        self,
        n_in: int,
        n_rec: int,
        n_out: int,
        n_layer: int,
    ):
        super().__init__()

        # Building the GRU Layer
        self.layers = []
        for i in range(n_layer - 1):
            # Using the GRUCell model within brainscale.nn
            self.layers.append(brainscale.nn.GRUCell(n_in, n_rec))
            n_in = n_rec
        self.layers.append(brainscale.nn.GRUCell(n_in, n_out))

    def update(self, x):
        # Updating the GRU Layer
        for layer in self.layers:
            x = layer(x)
        return x

As demonstrated, constructing neural network models with the [`brainscale.nn` module](../apis/nn.rst) follows the same procedure as with the [`brainstate.nn` module](https://brainstate.readthedocs.io/en/latest/apis/nn.html). As a result, existing `brainstate` tutorials can be used directly when developing neural network models with online learning capabilities.

## 3. `ETraceState`, `ETraceParam`, and `ETraceOp`: Customizing Network Modules

While the `brainscale.nn` module provides fundamental network components, it does not encompass all possible network dynamics. Consequently, BrainScale implements a mechanism for customizing module development through three primary classes: `ETraceState`, `ETraceParam`, and `ETraceOp`.

**Core Components**

- **`brainscale.ETraceState`**: Represents the hidden states $\mathbf{h}$ within modules, defining dynamical states such as membrane potentials in LIF neurons or postsynaptic conductances in exponential synaptic models.

- **`brainscale.ETraceParam`**: Corresponds to model parameters $\theta$ within modules, encompassing elements such as weight matrices for linear matrix multiplication and adaptive time constants in LIF neurons. All parameters requiring gradient updates during training must be defined within `ETraceParam`.

- **`brainscale.ETraceOp`**: Describes the operations that transform input data into postsynaptic currents based on model parameters. Supports various computational operations including linear matrix multiplication, sparse matrix operations, and convolution operations.

**Foundational Framework**

These three components—`ETraceState`, `ETraceParam`, and `ETraceOp`—constitute the fundamental conceptual framework underlying neural network models with online learning capabilities in BrainScale.

In the following sections, we will present a series of illustrative examples demonstrating the practical implementation of custom network modules using `ETraceState`, `ETraceParam`, and `ETraceOp`.

### 3.1 `ETraceState`: Model States

Let us first examine a fundamental Leaky Integrate-and-Fire (LIF) neuron model, whose dynamics are characterized by the following differential equations:

$$
\begin{aligned}
\tau \frac{dv_i}{dt} &= -v_i + I_{\text{ext}} + v_\text{rest} \\
z_i & = \mathcal{H}(v_i-v_{th}) \\
v_i & \leftarrow v_{\text{reset}} \quad \text{if} \quad z_i > 0
\end{aligned}
$$

Here, $v_i$ represents the membrane potential of the neuron. When this potential exceeds a threshold value $v_{th}$, the neuron generates an action potential, and its membrane potential is reset to $v_{\text{reset}}$. The Heaviside function $\mathcal{H}$ characterizes the action potential generation, while $I_{\text{ext}}$ denotes the external input current. The temporal evolution of the membrane potential is governed by the time constant $\tau$, and $v_\text{rest}$ represents the resting membrane potential.

In [4]:
import jax
from typing import Callable


class LIF(bst.nn.Neuron):
    """
    Leaky integrate-and-fire neuron model.
    """

    def __init__(
        self,
        size: bst.typing.Size,
        R: bst.typing.ArrayLike = 1. * u.ohm,
        tau: bst.typing.ArrayLike = 5. * u.ms,
        V_th: bst.typing.ArrayLike = 1. * u.mV,
        V_reset: bst.typing.ArrayLike = 0. * u.mV,
        V_rest: bst.typing.ArrayLike = 0. * u.mV,
        V_initializer: Callable = bst.init.Constant(0. * u.mV),
        spk_fun: Callable = bst.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str = None,
    ):
        super().__init__(size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters
        self.R = bst.init.param(R, self.varshape)
        self.tau = bst.init.param(tau, self.varshape)
        self.V_th = bst.init.param(V_th, self.varshape)
        self.V_rest = bst.init.param(V_rest, self.varshape)
        self.V_reset = bst.init.param(V_reset, self.varshape)
        self.V_initializer = V_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        # The critical step: wrap the membrane potential in an ETraceState instance
        # so that it is marked as a dynamical state for online learning
        self.V = brainscale.ETraceState(bst.init.param(self.V_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.V.value = bst.init.param(self.V_initializer, self.varshape, batch_size)

    def get_spike(self, V=None):
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        last_v = self.V.value
        lst_spk = self.get_spike(last_v)
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v)
        V = last_v - (V_th - self.V_reset) * lst_spk
        # membrane potential
        dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau
        V = bst.nn.exp_euler_step(dv, V)
        V = self.sum_delta_inputs(V)
        self.V.value = V
        return self.get_spike(V)

In the code above, we implement the `LIF` model through inheritance from `brainstate.nn.Neuron`. The class holds an `ETraceState` instance, `self.V`, that characterizes the dynamical state of the membrane potential. The `init_state` method establishes the initial conditions for the membrane potential dynamics, while the `update` method implements the temporal evolution of these dynamics.

This implementation maintains substantial structural similarity with the `LIF` class definition in `brainstate`, with one crucial distinction: whereas `brainstate` employs `brainstate.HiddenState` to represent the membrane potential dynamics, `brainscale` utilizes `brainscale.ETraceState` to explicitly designate this dynamical state for online learning applications.

`brainscale.ETraceState` can therefore be regarded as the counterpart to `brainstate.HiddenState`, specifically designed for defining model states that require eligibility trace updates. If model states are defined using `brainstate.HiddenState` rather than `brainscale.ETraceState`, the online learning compiler in `brainscale` will not recognize them. The compiled online learning rules are then ineffective for the affected states, potentially leading to erroneous or omitted gradient updates in the model.

In summary, the two state types differ as follows:
- `brainscale.ETraceState`: Explicitly marks states for eligibility trace computation
- `brainstate.HiddenState`: Standard state representation within ``brainstate`` without online learning capabilities

### 3.2 `ETraceParam`: Model Parameters

Let us examine a fundamental matrix multiplication operator, defined by the following mathematical expression:

$$
y = W x + b
$$

where $W$ represents the weight matrix, $x$ denotes the input vector, and $b$ is the bias vector. These model parameters can be defined using the `ETraceParam` class.

In [5]:
def generate_weight(
    n_in, n_out, init: Callable = bst.init.KaimingNormal()
) -> brainscale.ETraceParam:
    weight = init([n_in, n_out])
    bias = bst.init.ZeroInit()([n_out])
    
    # The crucial step: wrap the weight matrix and bias vector in an ETraceParam object
    return brainscale.ETraceParam({'weight': weight, 'bias': bias})

In the code above, we define a `generate_weight` function that produces weight matrices and bias vectors. This function returns an `ETraceParam` object that encapsulates these parameters.

`brainscale.ETraceParam` serves as the counterpart to `brainstate.ParamState`, specifically designed for model parameters requiring eligibility trace updates. When model parameters $\theta$ are defined using `brainscale.ETraceParam`, the online learning compiler in `brainscale` implements temporally-dependent gradient updates according to the following formula:

$$
\nabla_\theta \mathcal{L}=\sum_{t} \frac{\partial \mathcal{L}^{t}}{\partial \mathbf{h}^{t}} \sum_{k=1}^t \frac{\partial \mathbf{h}^t}{\partial \boldsymbol{\theta}^k},
$$

where $\boldsymbol{\theta}^k$ represents the weight $\boldsymbol{\theta}$ utilized at time step $k$.

Conversely, when model parameters $\theta$ are defined using `brainstate.ParamState`, the online learning compiler computes only the instantaneous gradient of the loss function with respect to the weights:

$$
\nabla_\theta \mathcal{L}=\sum_{t} \frac{\partial \mathcal{L}^{t}}{\partial \mathbf{h}^{t}} \frac{\partial \mathbf{h}^t}{\partial \boldsymbol{\theta}^t}.
$$

This implementation distinction signifies that in `brainscale`'s online learning framework, parameters defined as `brainstate.ParamState` are treated as those not requiring eligibility trace updates, thereby forfeiting the ability to compute gradients with temporal dependencies. This architectural design enhances the flexibility of parameter update patterns, thereby increasing the customizability of gradient computation mechanisms.
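The difference between the two formulas can be checked numerically on a scalar toy model. The sketch below assumes a linear recurrence $h^t = a\,h^{t-1} + w\,x^t$ with per-step loss $\mathcal{L}^t = \tfrac{1}{2}(h^t)^2$; the model, the names `run` and `loss`, and all values are hypothetical choices for illustration and do not use any BrainScale API.

```python
import numpy as np


def run(w, xs, a=0.9):
    """Scalar recurrent model h^t = a * h^{t-1} + w * x^t; returns all h^t."""
    h, hs = 0.0, []
    for x in xs:
        h = a * h + w * x
        hs.append(h)
    return np.array(hs)


def loss(w, xs, a=0.9):
    # L = sum_t L^t with L^t = 0.5 * (h^t)^2, so dL^t/dh^t = h^t.
    return 0.5 * np.sum(run(w, xs, a) ** 2)


xs = np.array([1.0, 0.5, -0.3, 0.8])
w, a = 0.7, 0.9
hs = run(w, xs, a)
T = len(xs)

# Temporal gradient: sum_t dL^t/dh^t * sum_{k<=t} dh^t/dw^k,
# where dh^t/dw^k = a^(t-k) * x^k for this model.
grad_temporal = sum(
    hs[t] * sum(a ** (t - k) * xs[k] for k in range(t + 1)) for t in range(T)
)

# Instantaneous gradient: only the k = t term is kept, dh^t/dw^t = x^t.
grad_instant = sum(hs[t] * xs[t] for t in range(T))

# Central finite difference of the full loss, used as a reference.
eps = 1e-6
grad_fd = (loss(w + eps, xs, a) - loss(w - eps, xs, a)) / (2 * eps)
```

In this toy setting `grad_temporal` agrees with the finite-difference reference, while `grad_instant` drops every contribution that reaches the loss through earlier hidden states and therefore differs from it.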

### 3.3 `ETraceOp`: Model Input-Output Functions

`ETraceOp` represents another fundamental concept in dynamics interaction. While `ETraceParam` characterizes the parameters utilized in dynamical interactions, `ETraceOp` defines the operational transformations of these interactions. It must be implemented as a function adhering to the following specification:

```python
def op(
    x: jax.Array, 
    param: brainscale.ETraceParam
) -> jax.Array:
    pass
```

`ETraceOp` describes the transformation of model inputs to outputs based on specified parameters. This framework supports various computational operations, including:
- Linear matrix multiplication
- Sparse matrix operations
- Convolution operations
- and more
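To make the `(x, param) -> y` contract concrete, the following sketch writes three such operations as plain NumPy functions over a dictionary of parameters. The function names, the dictionary keys, and the mask-based emulation of sparsity are assumptions made for illustration; the functions are not wrapped in `brainscale.ETraceOp`, and whether BrainScale accepts them in this exact form is not confirmed here.

```python
import numpy as np


def dense_op(x, param):
    # Linear matrix multiplication: y = x @ W + b.
    return x @ param['weight'] + param['bias']


def masked_op(x, param):
    # Sparse connectivity emulated with a fixed binary mask on a dense weight.
    return x @ (param['weight'] * param['mask'])


def conv1d_op(x, param):
    # 'valid' 1-D cross-correlation of a single-channel signal with one kernel.
    return np.correlate(x, param['kernel'], mode='valid')


x = np.array([1.0, 2.0, 3.0, 4.0])
dense = dense_op(x, {'weight': np.ones((4, 2)), 'bias': np.zeros(2)})
masked = masked_op(x, {'weight': np.ones((4, 2)), 'mask': np.eye(4, 2)})
conv = conv1d_op(x, {'kernel': np.array([1.0, 0.0, -1.0])})
```

All three share the same calling convention, so the surrounding dynamics code does not need to know which kind of interaction is in use.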

For the matrix multiplication operator discussed above, we can define an appropriate `ETraceOp` as follows:

In [6]:
@brainscale.ETraceOp
def matmul(x, w):
    weight = w['weight']
    bias = w['bias']
    return x @ weight + bias

### 3.4 `ETraceParamOp` = `ETraceParam` + `ETraceOp`: Integration of Model Parameters and Operations

`ETraceParam` and `ETraceOp` are inherently coupled concepts: a parameter has no effect until an operation consumes it, and an operation cannot transform inputs into outputs without its parameters. This intrinsic relationship motivates the introduction of a unified concept: `ETraceParamOp`.

`ETraceParamOp` encapsulates both model parameters and their associated operations, providing a comprehensive framework for describing parameter-operation combinations. During instantiation, it requires both a parameter object and an operation function as inputs.

Using the `ETraceParamOp` framework, we can reformulate the matrix multiplication operator discussed above as follows:

In [7]:
class Linear(bst.nn.Module):
    """
    Linear layer.
    """
    def __init__(
        self,
        in_size: bst.typing.Size,
        out_size: bst.typing.Size,
        w_init: Callable = bst.init.KaimingNormal(),
    ):
        super().__init__()

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size

        # weights
        weight = bst.init.param(w_init, [self.in_size[-1], self.out_size[-1]], allow_none=False)
        
        # operation
        op = lambda x, w: u.math.matmul(x, w)
        
        # The crucial step: bundle the weight matrix and its operation in an ETraceParamOp object
        self.weight_op = brainscale.ETraceParamOp(weight, op)

    def update(self, x):
        # Operation of ETraceParamOp
        return self.weight_op.execute(x)

As demonstrated in the code above, `ETraceParamOp` represents an integrated construct that unifies model parameters and their associated operations. This unified framework requires both a parameter object and an operation function for instantiation. The `execute` method of `ETraceParamOp` implements the operational transformation, converting input data into output data according to the specified parameters and operations.
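The pairing of a parameter with its operation can be pictured with a minimal stand-in class. `ToyParamOp` below is a hypothetical illustration of the idea only: it mirrors the constructor arguments and the `execute` method described above, but it is not BrainScale's implementation and carries none of its eligibility trace machinery.

```python
import numpy as np


class ToyParamOp:
    """Hypothetical stand-in that pairs a parameter with the op that consumes it."""

    def __init__(self, param, op):
        self.value = param  # the trainable parameter (array or dict of arrays)
        self.op = op        # callable with signature op(x, param) -> y

    def execute(self, x):
        # Apply the stored operation to the input using the stored parameter.
        return self.op(x, self.value)


weight = np.arange(6.0).reshape(3, 2)
linear = ToyParamOp(weight, lambda x, w: x @ w)
y = linear.execute(np.array([1.0, 0.0, 2.0]))
```

Keeping the parameter and the operation in one object means a learning algorithm can inspect both together, which is what allows it to relate inputs, weights, and outputs when building traces.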

Through this implementation, we have successfully defined a fundamental `Linear` layer module, encompassing a weight matrix and its corresponding matrix multiplication operation. This module serves as a building block for constructing neural network models with online learning capabilities.

## 4. `ETraceAlgorithm`: Online Learning Algorithms

`ETraceAlgorithm` represents another fundamental concept in the BrainScale framework, defining both the eligibility trace update process during model state evolution and the gradient update rules for model parameters. Implemented as an abstract class, `ETraceAlgorithm` serves as a specialized framework for describing various forms of online learning algorithms within BrainScale.

The algorithmic support provided by `brainscale.ETraceAlgorithm` is founded upon the three fundamental concepts previously introduced: `ETraceState`, `ETraceParam`, and `ETraceOp`. 

`brainscale.ETraceAlgorithm` implements a flexible online learning compiler that enables online learning capabilities for any neural network model constructed using these three foundational concepts.

Specifically, BrainScale currently supports the following online learning algorithms:

1. `brainscale.DiagIODimAlgorithm`
    - Implements the ES-D-RTRL algorithm for online learning
    - Achieves $O(N)$ memory and computational complexity for online gradient computation
    - Optimized for large-scale spiking neural network models
    - Detailed algorithm specifications are available in [our paper](https://doi.org/10.1101/2024.09.24.614728)

2. `brainscale.DiagParamDimAlgorithm`
    - Utilizes the D-RTRL algorithm for online learning
    - Features $O(N^2)$ memory and computational complexity for online gradient computation
    - Applicable to both recurrent neural networks and spiking neural network models
    - Complete algorithmic details are documented in [our paper](https://doi.org/10.1101/2024.09.24.614728)

3. `brainscale.DiagHybridDimAlgorithm`
    - Implements selective application of ES-D-RTRL or D-RTRL algorithms for parameter updates
    - Preferentially employs D-RTRL for convolutional layers and highly sparse connections
    - Optimizes memory and computational complexity of parameter updates through adaptive algorithm selection
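
The gap between the two complexity classes can be illustrated by counting trace values for a single weight matrix. The sketch assumes that the $O(N)$ variant keeps one input-side and one output-side vector per weight (consistent with the `etrace_xs` and `etrace_dfs` shapes printed by the demonstration later in this section) and that the $O(N^2)$ variant keeps one value per weight entry. The helper `trace_sizes` is hypothetical and ignores any additional bookkeeping the real algorithms may need.

```python
def trace_sizes(n_in: int, n_out: int) -> dict:
    """Number of trace values kept for one weight matrix of shape (n_in, n_out)."""
    return {
        'io_dim': n_in + n_out,     # one input-side and one output-side vector: O(N)
        'param_dim': n_in * n_out,  # one value per weight entry: O(N^2)
    }


small = trace_sizes(20, 10)      # the 20-to-10 connection used in this section
large = trace_sizes(4096, 4096)  # a larger, hypothetical layer
```

For the small network the two counts are of similar order, but for the larger layer the parameter-shaped trace is thousands of times bigger, which is the motivation for choosing the algorithm per layer.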

The framework is designed for extensibility, with ongoing development to support additional online learning algorithms for diverse application scenarios.

In the following demonstration, we will illustrate the process of constructing a neural network model with online learning capabilities using `brainscale.ETraceAlgorithm`. This example will serve to exemplify the practical implementation of the concepts discussed above.

In [8]:
with bst.environ.context(dt=0.1 * u.ms):

    # Define a simple recurrent neural network composed of LIF neurons
    model = LIF_Delta_Net(10, 10)
    bst.nn.init_all_states(model)
    
    # The model is fed into an online learning algorithm with a view to online learning
    model = brainscale.DiagIODimAlgorithm(model, decay_or_rank=0.99)
    
    # Compile the model's eligibility trace based on one input data. 
    # Thereafter, the model is called to update not only the state 
    # of the model, but also the model's eligibility trace
    example_input = bst.random.random(10) < 0.1
    model.compile_graph(example_input)

In essence, the user-defined neural network model solely specifies how the model states $\mathbf{h}$ evolve forward in time as a function of inputs. The compiled `ETraceAlgorithm`, in contrast, defines the update dynamics of the model's eligibility traces $\boldsymbol{\epsilon}$ in relation to state updates. Consequently, subsequent model invocations result in concurrent updates of both model states and their corresponding eligibility traces.
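The idea of advancing states and traces together can be sketched on a deliberately simple model. The example assumes $\mathbf{h}^t = a\,\mathbf{h}^{t-1} + W\mathbf{x}^t$ with a constant leak $a$ and per-step loss $\mathcal{L}^t = \tfrac{1}{2}\lVert\mathbf{h}^t\rVert^2$, for which a parameter-shaped trace obeys an exact recursion. The model and all names are hypothetical; BrainScale's algorithms define their own traces and approximations, which this sketch does not reproduce.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, T, a = 3, 2, 5, 0.9
W = rng.normal(size=(n_out, n_in))
xs = rng.normal(size=(T, n_in))


def total_loss(W):
    h, L = np.zeros(n_out), 0.0
    for x in xs:
        h = a * h + W @ x          # element-wise leak plus a linear interaction
        L += 0.5 * np.sum(h ** 2)  # per-step loss L^t
    return L


# Online pass: state and eligibility trace are advanced together, step by step.
h = np.zeros(n_out)
trace = np.zeros((n_out, n_in))     # trace[i, j] = d h_i^t / d W_ij
grad = np.zeros_like(W)
for x in xs:
    h = a * h + W @ x
    trace = a * trace + x[None, :]  # same leak as the state, driven by the input
    grad += h[:, None] * trace      # dL^t/dh_i^t * d h_i^t / d W_ij, no history stored

# Reference: central finite difference on one weight entry.
eps = 1e-6
dW = np.zeros_like(W)
dW[1, 2] = eps
fd = (total_loss(W + dW) - total_loss(W - dW)) / (2 * eps)
```

The loop never revisits past inputs: everything needed for the gradient at step $t$ is carried in `h` and `trace`, which is the property that makes such learning rules usable online.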

In [9]:
with bst.environ.context(dt=0.1 * u.ms):
    
    out = model(example_input)

# The eligibility traces of the pre-synaptic (input-side) activity
# are available through the "model.etrace_xs" attribute
bst.util.PrettyMapping(model.etrace_xs)

{
  Var(id=2878074780864):float32[20]: ShortTermState(
    value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.], dtype=float32)
  )
}

In [10]:
# The eligibility traces of the post-synaptic (hidden-state-side) activity
# are available through the "model.etrace_dfs" attribute
bst.util.PrettyMapping(model.etrace_dfs)

{
  (Var(id=2878074780928):float32[10], ('neu', 'V')): ShortTermState(
    value=Array([0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01],      dtype=float32)
  )
}

## 5. Conclusion

`BrainScale` provides a comprehensive and elegant framework for online learning, with its core concepts organized into the following hierarchical layers:

1. Infrastructure Layer
    - Supports specific dynamical model architectures with strict separation between "dynamics" and "interactions"
    - Built upon the `BrainState` ecosystem, maintaining full compatibility with its programming paradigm
    - Provides ready-to-use neural network components through the `brainscale.nn` module

2. Core Concepts Layer
    - `ETraceState`: Designates hidden states requiring eligibility trace updates
    - `ETraceParam`: Identifies model parameters requiring eligibility trace updates
    - `ETraceOp`: Defines specific operations for dynamical interactions
    - `ETraceParamOp`: Provides unified interface for parameter-operation combinations

3. Algorithm Implementation Layer
    - `DiagIODimAlgorithm`: Implements ES-D-RTRL algorithm with $O(N)$ complexity
    - `DiagParamDimAlgorithm`: Implements D-RTRL algorithm with $O(N^2)$ complexity
    - `DiagHybridDimAlgorithm`: Adaptive hybrid approach, selecting between $O(N)$ and $O(N^2)$ complexity algorithms based on network architecture

4. Operational Workflow
    - Construction of neural networks using foundational components
    - Selection and application of appropriate online learning algorithms
    - Model compilation generating eligibility trace computation graphs
    - Concurrent updates of model states and eligibility traces through forward propagation

BrainScale's architecture encapsulates complex online learning algorithms behind concise interfaces while providing flexible customization mechanisms. This design philosophy aims to combine:
- High performance and usability
- Seamless integration with the existing BrainState ecosystem
- Intuitive and efficient construction of online learning neural networks

Through this architectural approach, BrainScale transforms the development and training of online learning neural networks into an intuitive and efficient process, providing a powerful toolkit for neural computation research.