# `ETraceOp`: Eligibility Trace Operator

In `brainscale`, the Eligibility Trace Operator (`ETraceOp`) connects neural populations and defines their synaptic interactions. Its primary responsibility is to compute the post-synaptic current from the model's inputs (pre-synaptic activity) and its parameters (e.g., synaptic weights). `ETraceOp` also natively supports learning based on eligibility traces, a key mechanism for temporal credit assignment in biological neural systems. This allows the model to update connection weights from delayed reward or error signals.

The design philosophy of `ETraceOp` is to decouple the computational logic (the operator itself) from the trainable parameters (`ETraceParam`), providing significant flexibility and extensibility.

In [1]:
import brainevent
import brainstate
import jax
import jax.numpy as jnp

import brainscale

## Built-in Eligibility Trace Operators

`brainscale` provides a set of built-in eligibility trace operators that cover the most common neural network modeling needs. Each operator is paired with the parameter container `brainscale.ETraceParam` to form a building block of a neural network.

The main built-in operators include:

  * [`brainscale.MatMulOp`](../apis/generated/brainscale.MatMulOp.rst): Implements standard matrix multiplication, serving as the foundation for fully-connected (Dense) layers.
  * [`brainscale.ConvOp`](../apis/generated/brainscale.ConvOp.rst): Implements convolution, supporting 1D, 2D, and 3D operations, which is core to building Convolutional Neural Networks (CNNs).
  * [`brainscale.SpMatMulOp`](../apis/generated/brainscale.SpMatMulOp.rst): Designed for sparse connectivity, this operator implements sparse matrix multiplication. It is particularly crucial in Graph Neural Networks (GNNs) and large-scale biophysical models that require efficient representation of sparse connections.
  * [`brainscale.ElemWiseOp`](../apis/generated/brainscale.ElemWiseOp.rst): Performs element-wise mathematical operations, often used to implement activation functions, scaling, or other custom element-by-element transformations.
  * [`brainscale.LoraOp`](../apis/generated/brainscale.LoraOp.rst): Implements Low-Rank Adaptation, an efficient technique for fine-tuning large pre-trained models.

### `brainscale.MatMulOp`: The Matrix Multiplication Operator

The [`brainscale.MatMulOp`](../apis/generated/brainscale.MatMulOp.rst) is one of the most fundamental operators, supporting matrix multiplication for scenarios like fully-connected layers.

**Core Operation**:

  * **Input**: A matrix $x \in \mathbb{R}^{B \times D_{in}}$
  * **Parameters**: A dictionary $w$ containing a weight matrix `weight` $\in \mathbb{R}^{D_{in} \times D_{out}}$ and a bias vector `bias` $\in \mathbb{R}^{D_{out}}$
  * **Output**: A matrix $y \in \mathbb{R}^{B \times D_{out}}$

**Supported Operation Types**:

**1.  Standard matrix multiplication**:

$$y = x \cdot \text{param['weight']} + \text{param['bias']}$$


In [2]:
# Standard matrix multiplication

brainscale.ETraceParam(
    {
        'weight': brainstate.random.rand(4, 5),
        'bias': brainstate.random.rand(5)
    },
    brainscale.MatMulOp()
)

ETraceParam(
  value={
    'bias': ShapedArray(float32[5]),
    'weight': ShapedArray(float32[4,5])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=MatMulOp(
    is_diagonal=False,
    weight_mask=None,
    weight_fn=<function MatMulOp.<lambda> at 0x000002A178B83B00>,
    apply_weight_fn_before_mask=False
  ),
  is_etrace=True
)


**2.  Masked operation**: The `weight_mask` parameter can be used to implement sparse connections, where only weights corresponding to `True` in the mask are active.

$$y = x \cdot (\text{param['weight']} \odot \text{mask}) + \text{param['bias']}$$


In [3]:
# Matrix multiplication with a mask to implement sparse connectivity

brainscale.ETraceParam(
    {
        'weight': brainstate.random.rand(4, 5),
        'bias': brainstate.random.rand(5)
    },
    brainscale.MatMulOp(
        weight_mask=brainstate.random.rand(4, 5) > 0.5
    )
)

ETraceParam(
  value={
    'bias': ShapedArray(float32[5]),
    'weight': ShapedArray(float32[4,5])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=MatMulOp(
    is_diagonal=False,
    weight_mask=Array([[ True, False, False, False,  True],
           [False,  True,  True,  True, False],
           [ True, False,  True,  True, False],
           [ True, False,  True, False, False]], dtype=bool),
    weight_fn=<function MatMulOp.<lambda> at 0x000002A178B83B00>,
    apply_weight_fn_before_mask=False
  ),
  is_etrace=True
)
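The masked formula can be checked by hand. The following is a minimal NumPy sketch of the arithmetic only, written without `brainscale`; the array names and the fixed seed are illustrative assumptions, not part of the library API.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((2, 4))            # batch of 2, D_in = 4
weight = rng.random((4, 5))       # D_in x D_out
bias = rng.random(5)              # D_out
mask = rng.random((4, 5)) > 0.5   # True marks an existing connection

# Masked entries are zeroed before the product, so they contribute nothing to y.
y = x @ (weight * mask) + bias

# Changing a masked-out weight leaves the output untouched.
perturbed = weight.copy()
perturbed[~mask] += 10.0
y_perturbed = x @ (perturbed * mask) + bias

print(y.shape)                        # (2, 5)
print(np.allclose(y, y_perturbed))    # True
```

Because masked-out weights never reach the output, they also receive no learning signal through it, which is what makes the mask behave like fixed sparse connectivity.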

**3.  Weight function**: The `weight_fn` parameter applies a function to the weight matrix before multiplication. For instance, `jnp.abs` keeps every effective weight non-negative (excitatory), which is one way to impose the sign constraint of Dale's Law.

$$y = x \cdot f(\text{param['weight']}) + \text{param['bias']}$$


In [4]:
# Apply a function to the weights

brainscale.ETraceParam(
    {
        'weight': brainstate.random.rand(4, 5),
        'bias': brainstate.random.rand(5)
    },
    brainscale.MatMulOp(
        weight_fn=jnp.abs   # Ensures weights are positive
    )
)

ETraceParam(
  value={
    'bias': ShapedArray(float32[5]),
    'weight': ShapedArray(float32[4,5])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=MatMulOp(
    is_diagonal=False,
    weight_mask=None,
    weight_fn=<PjitFunction of <function abs at 0x000002A14F860400>>,
    apply_weight_fn_before_mask=False
  ),
  is_etrace=True
)
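As a rough illustration of the weight-function formula, the NumPy sketch below separates the stored parameter from the effective weight that enters the product. It does not call `brainscale`; the variable names are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((2, 4))                  # non-negative inputs, e.g. firing rates
weight = rng.standard_normal((4, 5))    # raw parameter, free to take either sign
bias = np.zeros(5)

effective = np.abs(weight)              # f(param['weight']) with f = abs
y = x @ effective + bias

print((effective >= 0).all())           # True
print((y >= 0).all())                   # True
```

The raw parameter stays unconstrained, so an optimizer can update it freely while the connection it represents remains excitatory.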

**4.  Composition of masking and weight function**:

$$y = x \cdot f(\text{param['weight']} \odot \text{mask}) + \text{param['bias']}$$

In [5]:
# Use a mask and a weight function together
brainscale.ETraceParam(
    {
        'weight': brainstate.random.rand(4, 5),
        'bias': brainstate.random.rand(5)
    },
    brainscale.MatMulOp(
        weight_fn=jnp.abs,
        weight_mask=brainstate.random.rand(4, 5) > 0.5
    )
)

ETraceParam(
  value={
    'bias': ShapedArray(float32[5]),
    'weight': ShapedArray(float32[4,5])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=MatMulOp(
    is_diagonal=False,
    weight_mask=Array([[False, False, False, False,  True],
           [ True,  True,  True,  True,  True],
           [False, False, False, False,  True],
           [ True,  True,  True,  True,  True]], dtype=bool),
    weight_fn=<PjitFunction of <function abs at 0x000002A14F860400>>,
    apply_weight_fn_before_mask=False
  ),
  is_etrace=True
)
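The order of masking and function application matters whenever $f(0) \neq 0$. The small NumPy sketch below compares both orders for `abs` and `exp`; it illustrates the formula only and makes no claim about how `apply_weight_fn_before_mask` is implemented, for which the API reference is the authority.

```python
import numpy as np

weight = np.array([[-1.0, 2.0],
                   [0.5, -3.0]])
mask = np.array([[True, False],
                 [False, True]])

# f = abs maps 0 to 0, so both orders agree.
abs_of_masked = np.abs(weight * mask)
masked_abs = np.abs(weight) * mask
print(np.array_equal(abs_of_masked, masked_abs))   # True

# f = exp maps 0 to 1, so masking first leaves 1.0 at the masked-out entries.
exp_of_masked = np.exp(weight * mask)
masked_exp = np.exp(weight) * mask
print(exp_of_masked[0, 1], masked_exp[0, 1])       # 1.0 0.0
```

With a function such as `exp`, applying the mask first would silently reintroduce connections that the mask was meant to remove, so the choice of order deserves a quick check for any non-trivial `weight_fn`.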

### `brainscale.ConvOp`: The Convolution Operator

The [`brainscale.ConvOp`](../apis/generated/brainscale.ConvOp.rst) provides general-purpose convolution operations suitable for models like CNNs.

**Dimensionality Support**:
A key feature of `ConvOp` is its ability to adapt to different dimensions. By specifying the `xinfo` parameter (a `jax.ShapeDtypeStruct` object), it can automatically infer and execute 1D, 2D, or 3D convolutions:

  * **1D Convolution**: For input shape `(length, channels)`, e.g., `xinfo=jax.ShapeDtypeStruct((32, 3), ...)`.
  * **2D Convolution**: For input shape `(height, width, channels)`, e.g., `xinfo=jax.ShapeDtypeStruct((32, 32, 3), ...)`.
  * **3D Convolution**: For input shape `(depth, height, width, channels)`, e.g., `xinfo=jax.ShapeDtypeStruct((32, 32, 32, 3), ...)`.

**Supported Operation Types** (where $\star$ denotes convolution):

**1.  Standard convolution**:

$$y = x \star \text{param['weight']} + \text{param['bias']}$$

where $\star$ is the convolution operation.

In [6]:
# Example of a 2D convolution
brainscale.ETraceParam(
    {
        'weight': brainstate.random.rand(3, 3),
        'bias': jnp.zeros(16)
    },
    brainscale.ConvOp(
        xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),  # (height, width, channels)
        window_strides=[1, 1],
        padding='SAME',
    )
)

ETraceParam(
  value={
    'bias': ShapedArray(float32[16]),
    'weight': ShapedArray(float32[3,3])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=ConvOp(
    is_diagonal=False,
    window_strides=[
      1,
      1
    ],
    padding=SAME,
    lhs_dilation=None,
    rhs_dilation=None,
    feature_group_count=1,
    batch_group_count=1,
    dimension_numbers=None,
    weight_mask=None,
    weight_fn=<function ConvOp.<lambda> at 0x000002A178B83EC0>,
    xinfo=ShapeDtypeStruct(shape=(32, 3, 3), dtype=float32)
  ),
  is_etrace=True
)

Like `MatMulOp`, `ConvOp` also supports `weight_mask` and `weight_fn` for flexible and complex convolution definitions.
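To make the convolution formula concrete, here is a deliberately simple NumPy sketch of a stride-1, 1D convolution with `'SAME'`-style zero padding on a `(length, channels)` input. The helper `conv1d_same` and the kernel layout `(kernel_size, c_in, c_out)` are assumptions made for this illustration; they are not taken from `ConvOp`, which may use a different layout.

```python
import numpy as np

def conv1d_same(x, kernel, bias):
    """x: (length, c_in); kernel: (k, c_in, c_out); bias: (c_out,)."""
    k = kernel.shape[0]
    pad_left = (k - 1) // 2
    pad_right = (k - 1) - pad_left
    xp = np.pad(x, ((pad_left, pad_right), (0, 0)))   # zero-pad the length axis only
    out = np.empty((x.shape[0], kernel.shape[2]))
    for t in range(x.shape[0]):
        window = xp[t:t + k]                          # (k, c_in) slice centred on t
        out[t] = np.einsum('kc,kco->o', window, kernel) + bias
    return out

rng = np.random.default_rng(0)
x = rng.random((32, 3))               # (length, channels)
kernel = rng.random((3, 3, 16))       # 3 taps, 3 input channels, 16 output channels
bias = np.zeros(16)

y = conv1d_same(x, kernel, bias)
print(y.shape)                        # (32, 16)

# A mask that removes every kernel entry leaves only the bias.
all_off = np.zeros_like(kernel, dtype=bool)
y_masked = conv1d_same(x, kernel * all_off, bias)
print(np.allclose(y_masked, bias))    # True
```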

**2. Masking operation**:

$$y = x \star (\text{param['weight']} \odot \text{mask}) + \text{param['bias']}$$


In [7]:
brainscale.ETraceParam(
    {
        'weight': brainstate.random.rand(3, 3),
        'bias': jnp.zeros(16)
    },
    brainscale.ConvOp(
        xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),
        window_strides=[1, 1],
        padding='SAME',
        weight_mask=brainstate.random.rand(3, 3) > 0.5
    )
)

ETraceParam(
  value={
    'bias': ShapedArray(float32[16]),
    'weight': ShapedArray(float32[3,3])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=ConvOp(
    is_diagonal=False,
    window_strides=[
      1,
      1
    ],
    padding=SAME,
    lhs_dilation=None,
    rhs_dilation=None,
    feature_group_count=1,
    batch_group_count=1,
    dimension_numbers=None,
    weight_mask=Array([[ True, False,  True],
           [False,  True,  True],
           [False, False,  True]], dtype=bool),
    weight_fn=<function ConvOp.<lambda> at 0x000002A178B83EC0>,
    xinfo=ShapeDtypeStruct(shape=(32, 3, 3), dtype=float32)
  ),
  is_etrace=True
)

**3. Weight function**:

$$y = x \star f(\text{param['weight']}) + \text{param['bias']}$$

In [8]:
# Example of a 2D convolution
brainscale.ETraceParam(
    {
        'weight': brainstate.random.rand(3, 3),
        'bias': jnp.zeros(16)
    },
    brainscale.ConvOp(
        xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),
        window_strides=[1, 1],
        padding='SAME',
        weight_fn=jnp.abs
    )
)

ETraceParam(
  value={
    'bias': ShapedArray(float32[16]),
    'weight': ShapedArray(float32[3,3])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=ConvOp(
    is_diagonal=False,
    window_strides=[
      1,
      1
    ],
    padding=SAME,
    lhs_dilation=None,
    rhs_dilation=None,
    feature_group_count=1,
    batch_group_count=1,
    dimension_numbers=None,
    weight_mask=None,
    weight_fn=<PjitFunction of <function abs at 0x000002A14F860400>>,
    xinfo=ShapeDtypeStruct(shape=(32, 3, 3), dtype=float32)
  ),
  is_etrace=True
)

**4. Weight function + masking**:

$$y = x \star f(\text{param['weight']} \odot \text{mask}) + \text{param['bias']}$$

In [9]:
# Example of a 2D convolution
brainscale.ETraceParam(
    {
        'weight': brainstate.random.rand(3, 3),
        'bias': jnp.zeros(16)
    },
    brainscale.ConvOp(
        xinfo=jax.ShapeDtypeStruct((32, 3, 3), jnp.float32),
        window_strides=(1, 1),
        padding='SAME',
        weight_mask=brainstate.random.rand(3, 3) > 0.5,
        weight_fn=jnp.abs,
    )
)

ETraceParam(
  value={
    'bias': ShapedArray(float32[16]),
    'weight': ShapedArray(float32[3,3])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=ConvOp(
    is_diagonal=False,
    window_strides=(1, 1),
    padding=SAME,
    lhs_dilation=None,
    rhs_dilation=None,
    feature_group_count=1,
    batch_group_count=1,
    dimension_numbers=None,
    weight_mask=Array([[ True, False, False],
           [ True, False,  True],
           [False, False, False]], dtype=bool),
    weight_fn=<PjitFunction of <function abs at 0x000002A14F860400>>,
    xinfo=ShapeDtypeStruct(shape=(32, 3, 3), dtype=float32)
  ),
  is_etrace=True
)

### `brainscale.SpMatMulOp`: The Sparse Matrix Multiplication Operator

The [`brainscale.SpMatMulOp`](../apis/generated/brainscale.SpMatMulOp.rst) operator supports sparse matrix multiplication operations, suitable for scenarios like Graph Neural Networks (GNNs). It takes feature maps $x$ and parameters $w$ as input, and outputs the sparse matrix multiplication result $y$.

`brainscale.SpMatMulOp` performs similar operations to `brainscale.MatMulOp` for matrix multiplication:

$$y = x @ \text{param['weight']} + \text{param['bias']}$$

However, in this case `param['weight']` is a sparse matrix, typically one of the sparse formats provided by [`brainevent`](https://brainevent.readthedocs.io/), which keeps the computation efficient while significantly reducing memory consumption. These include:

- `brainevent.CSR`: Compressed Sparse Row matrix.
- `brainevent.CSC`: Compressed Sparse Column matrix.
- `brainevent.COO`: Coordinate Format sparse matrix.

`brainscale.SpMatMulOp` supports the following operations:

**1. Standard matrix multiplication**:

$$y = x @ \text{param['weight']} + \text{param['bias']}$$




In [10]:
data = jnp.where(
    brainstate.random.rand(100, 100) < 0.2,
    brainstate.random.rand(100, 100),
    0.
)
csr = brainevent.CSR.fromdense(data)

In [11]:
brainscale.ETraceParam(
    {'weight': brainstate.random.rand(100)},
    brainscale.SpMatMulOp(csr)
)

ETraceParam(
  value={
    'weight': ShapedArray(float32[100])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=SpMatMulOp(
    is_diagonal=False,
    sparse_mat=CSR(float32[100, 100], nse=1924),
    weight_fn=<function SpMatMulOp.<lambda> at 0x000002A178BA42C0>
  ),
  is_etrace=True
)

**2. Weight function**:

$$y = x @ f(\text{param['weight']}) + \text{param['bias']}$$


In [12]:
brainscale.ETraceParam(
    {'weight': brainstate.random.rand(100)},
    brainscale.SpMatMulOp(csr, weight_fn=jnp.abs)
)

ETraceParam(
  value={
    'weight': ShapedArray(float32[100])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=SpMatMulOp(
    is_diagonal=False,
    sparse_mat=CSR(float32[100, 100], nse=1924),
    weight_fn=<PjitFunction of <function abs at 0x000002A14F860400>>
  ),
  is_etrace=True
)
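The memory and compute savings come from storing and visiting only the nonzero entries. The sketch below builds the three CSR arrays by hand in NumPy and computes $x @ A$ from them. It does not use `brainevent`, and the helper `csr_left_matmul` is a hypothetical name introduced for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
dense = np.where(rng.random((100, 100)) < 0.2, rng.random((100, 100)), 0.0)

# CSR stores the nonzero values, their column indices, and row pointers that
# mark where each row starts inside those two arrays.
rows, cols = np.nonzero(dense)                     # returned in row-major order
data = dense[rows, cols]
indptr = np.concatenate(([0], np.cumsum(np.bincount(rows, minlength=100))))

def csr_left_matmul(x, data, cols, indptr, n_cols):
    """Compute x @ A for a CSR matrix A and a vector x with one entry per row of A."""
    y = np.zeros(n_cols)
    for i in range(len(indptr) - 1):
        start, end = indptr[i], indptr[i + 1]
        # Row i of A scatters x[i] * A[i, j] into output column j.
        y[cols[start:end]] += x[i] * data[start:end]
    return y

x = rng.random(100)
y = csr_left_matmul(x, data, cols, indptr, n_cols=100)

print(np.allclose(y, x @ dense))                   # True
print(data.size, "stored values instead of", dense.size)
```

The loop touches each stored value exactly once, so the cost scales with the number of nonzeros rather than with the full matrix size.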

### `brainscale.ElemWiseOp`: The Element-wise Operator

[`brainscale.ElemWiseOp`](../apis/generated/brainscale.ElemWiseOp.rst) provides a concise way to apply element-wise function transformations to parameters. It doesn't directly process pre-synaptic input $x$, but operates directly on its own parameters $w$.

**Core Operation**:

$$y = f(w)$$

This can be used to create learnable activation function parameters, neuron thresholds, time constants, etc.

Here are some typical examples of element-wise operation operators:


In [13]:
brainscale.ETraceParam(
    brainstate.random.rand(4),
    brainscale.ElemWiseOp(jnp.abs)  # Absolute value operation
)

ETraceParam(
  value=ShapedArray(float32[4]),
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=ElemWiseOp(
    fn=<PjitFunction of <function abs at 0x000002A14F860400>>,
    is_diagonal=True
  ),
  is_etrace=True
)

In [14]:
brainscale.ETraceParam(
    brainstate.random.rand(4),
    brainscale.ElemWiseOp(jnp.exp)  # Exponential operation
)

ETraceParam(
  value=ShapedArray(float32[4]),
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=ElemWiseOp(
    fn=<PjitFunction of <function exp at 0x000002A14F8054E0>>,
    is_diagonal=True
  ),
  is_etrace=True
)

In [15]:
# Using custom lambda function

brainscale.ETraceParam(
    brainstate.random.rand(4),
    brainscale.ElemWiseOp(lambda x: x ** 2 + 1.)  # Custom function
)

ETraceParam(
  value=ShapedArray(float32[4]),
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=ElemWiseOp(
    fn=<function <lambda> at 0x000002A17672A340>,
    is_diagonal=True
  ),
  is_etrace=True
)
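One common use of $y = f(w)$ is to keep a learnable quantity inside a valid range. The NumPy sketch below maps an unconstrained parameter to a strictly positive time constant with a softplus; the names `softplus`, `raw_tau`, and `tau` are illustrative assumptions. A `jax.numpy` version of the same function could presumably be passed to `ElemWiseOp` in the same way as the lambda above.

```python
import numpy as np

def softplus(w):
    """Smooth, strictly positive transform: log(1 + exp(w))."""
    return np.logaddexp(0.0, w)

raw_tau = np.array([-2.0, 0.0, 3.0])   # unconstrained parameter w
tau = softplus(raw_tau)                # y = f(w), always greater than 0

print((tau > 0).all())                 # True
```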

## Custom Eligibility Trace Operators

Although `brainscale` provides a comprehensive suite of built-in operators, research and applications often require exploration of novel neural network layers or synaptic plasticity rules. For this purpose, `brainscale` allows users to easily create custom operators by inheriting from the `brainscale.ETraceOp` base class.

Customizing an operator involves understanding and implementing two core methods: `xw_to_y` and `yw_to_w`.

1. **`xw_to_y(self, x, w)`: Define Forward Propagation**

   * **Purpose**: This method defines the core computational logic of the operator, i.e., how to compute the post-synaptic output `y` based on pre-synaptic input `x` and operator parameters `w`. It is functionally equivalent to the `forward` method of layers in standard deep learning frameworks.
   * **Mathematical representation**: $y = f(x, w)$.
   * **Parameters**:
     * `x`: Pre-synaptic neural activity (e.g., spikes, firing rates, or feature vectors).
     * `w`: A dictionary containing all learnable parameters of this operator (e.g., `{'weight': ..., 'bias': ...}`).

![](../_static/etraceop-xw2y.png)

2. **`yw_to_w(self, y, w)`: Define Gradient/Trace Propagation**

   * **Purpose**: This method is the core of the eligibility trace learning mechanism. It defines how "learning signals" from post-synaptic neurons (typically error gradients) "flow back" and influence the eligibility trace of each parameter. It answers the question: "How much does a change in output `y` affect parameter `w`?"
   * **Mathematical representation**: $w_{new} = g(y_{grad}, w)$. Here, $y_{grad}$ is an abstract "learning signal," and $w_{new}$ represents the update direction of parameters or their eligibility trace.
   * **Application scenario**: The computation result of this method will be directly used to update the eligibility trace $\boldsymbol{\epsilon}^{t}$. The final weight update $\Delta w$ will be the product of the learning signal (such as reward prediction error) and the eligibility trace, i.e., $\Delta w \propto \text{LearningSignal} \cdot \boldsymbol{\epsilon}^{t}$.
   * **Parameters**:
     * `y`: A vector representing the learning signal or gradient, with the same dimensions as the output of `xw_to_y`.
     * `w`: Current parameter dictionary.

![](../_static/etraceop-yw2w.png)
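The relation $\Delta w \propto \text{LearningSignal} \cdot \boldsymbol{\epsilon}^{t}$ can be illustrated with a toy, textbook-style three-factor rule in NumPy. This is only a sketch under simple assumptions (an exponentially decaying trace driven by pre/post coincidence, and a single scalar learning signal at the end); it is not the algorithm `brainscale` uses internally, and `decay`, `lr`, and the other names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, n_steps = 4, 5, 20
decay, lr = 0.9, 0.1

weight = rng.random((d_in, d_out))
trace = np.zeros_like(weight)          # eligibility trace, same shape as the weights

for t in range(n_steps):
    x = rng.random(d_in)               # pre-synaptic activity at step t
    y = x @ weight                     # post-synaptic response
    # Decay the old trace and add the current pre/post coincidence.
    trace = decay * trace + np.outer(x, y)

# A delayed scalar learning signal (e.g. a reward prediction error) arrives last.
learning_signal = -0.5
delta_w = lr * learning_signal * trace # update proportional to signal times trace
weight = weight + delta_w

print(trace.shape == weight.shape)     # True
```

The trace keeps a local memory of which synapses were recently active, so the later learning signal can assign credit without replaying the whole input history.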


### Example: Building `CustomizedMatMul` from Scratch

Let's demonstrate how to customize an operator that has the same functionality as `MatMulOp` through a concrete example.


In [16]:
class CustomizedMatMul(brainscale.ETraceOp):
    """
    A custom matrix multiplication eligibility trace operator.
    It implements the computation y = x @ w['weight'] + w['bias'].
    """
    def xw_to_y(self, x, w: dict):
        """
        Forward propagation: compute y = x @ weight + bias
        """
        return jnp.dot(x, w['weight']) + w['bias']

    def yw_to_w(self, y, w: dict):
        """
        Parameter update (for eligibility trace): compute how learning signals affect weights and biases.
        This is similar to computing gradients dL/dw = dL/dy * dy/dw.
        """

        # The full gradient of w['weight'] is the outer product of the input x and the
        # output gradient y. This method does not receive x, so it only expresses how
        # each output unit maps onto the parameters: y is expanded from (..., D_out)
        # to (..., 1, D_out) and broadcast against w['weight'] of shape (D_in, D_out).
        # The rules below are illustrative placeholders for that dependency.
        y_expanded = jnp.expand_dims(y, axis=-2) # Shape becomes (..., 1, D_out)
        return {
            'weight': y_expanded * w['weight'], # Example update rule
            'bias': y * w['bias'] # Example update rule
        }

**Code Analysis**:

* In `xw_to_y`, we implement the logic of standard matrix multiplication plus bias, which is very intuitive.
* In `yw_to_w`, we define the update rules. The keys and value shapes of the returned dictionary must match those of the original parameters `w`. `jnp.expand_dims` reshapes `y` from `(..., D_out)` to `(..., 1, D_out)` so that broadcasting applies each output unit's signal to the corresponding column of `w['weight']`.
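The broadcasting in `yw_to_w` is easy to verify in isolation. The NumPy sketch below reproduces only the shape arithmetic of the example, with made-up values: for an unbatched `y` the results have exactly the parameter shapes, while a leading batch axis on `y` carries through to the results.

```python
import numpy as np

weight = np.ones((4, 5))      # D_in = 4, D_out = 5
bias = np.ones(5)

# Unbatched learning signal: one value per output unit.
y = np.arange(5.0)
w_like = np.expand_dims(y, axis=-2) * weight   # (1, 5) * (4, 5)
b_like = y * bias
print(w_like.shape, b_like.shape)              # (4, 5) (5,)

# With a leading batch axis the results gain that axis too.
y_batched = np.ones((3, 5))
w_like_b = np.expand_dims(y_batched, axis=-2) * weight   # (3, 1, 5) * (4, 5)
print(w_like_b.shape)                                     # (3, 4, 5)
```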

### Using Custom Operators

Once defined, `CustomizedMatMul` can be used like any built-in operator, combined with `ETraceParam`, and seamlessly integrated into `brainscale`'s computational graph.


In [17]:
# 1. Instantiate custom operator
my_op = CustomizedMatMul()

# 2. Use ETraceParam to associate operator with specific parameters
param = brainscale.ETraceParam(
    {
        'weight': brainstate.random.rand(4, 5),  # D_in=4, D_out=5
        'bias': brainstate.random.rand(5)
    },
    my_op # Pass custom operator instance
)

# 3. Use in model (simulation)
# Create some mock input data
dummy_input = brainstate.random.rand(1, 4) # Batch=1, D_in=4

# brainscale's runner will automatically call op.xw_to_y(dummy_input, param.value)
# Here we call it manually to verify the result
output = my_op.xw_to_y(dummy_input, param.value)

print("Custom operator instantiation:")
print(param)
print("\nForward computation output:")
print(output)
print("Output shape:", output.shape)

Custom operator instantiation:
ETraceParam(
  value={
    'bias': ShapedArray(float32[5]),
    'weight': ShapedArray(float32[4,5])
  },
  gradient=<ETraceGrad.adaptive: 'adaptive'>,
  op=CustomizedMatMul(
    is_diagonal=False
  ),
  is_etrace=True
)

Forward computation output:
[[1.8759773 1.6249228 1.8983953 1.4766655 1.8770288]]
Output shape: (1, 5)


Through this approach, you can build highly customized components for novel neuron models, complex synaptic plasticity rules, or any scenario requiring differentiable parameters and custom computational logic.