# Event-driven Computations

Event-driven computation is an important characteristic that distinguishes brain-simulation computing models from traditional computing models. Neurons in the brain only become active, and only compute, after they receive input. The set of neurons that are active at each time step can therefore be regarded as a sparse vector. Multiplying such a vector by a weight matrix with traditional dense matrix multiplication spends most of the work on zero entries, so it is poorly suited to computing how neuron spikes propagate.

![](../_static/dense-mv-vs-event-spmv.png)
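The cost argument can be checked with a few lines of plain JAX. This is a minimal illustration, not BrainState code: a dense product performs `n_pre * n_post` multiply-adds, while the event-driven version only sums the weight rows of the neurons that actually fired, yet both give the same result.

```python
import jax
import jax.numpy as jnp

key_s, key_w = jax.random.split(jax.random.PRNGKey(0))
n_pre, n_post = 1000, 100

# boolean spike vector: each presynaptic neuron fires with probability 0.01
spikes = jax.random.uniform(key_s, (n_pre,)) < 0.01
weight = jax.random.uniform(key_w, (n_pre, n_post))

# dense approach: n_pre * n_post multiply-adds, most of them with zero
dense_out = spikes.astype(weight.dtype) @ weight

# event-driven approach: sum only the rows of neurons that fired
active = jnp.flatnonzero(spikes)  # data-dependent size, so this line is not jit-compatible
event_out = weight[active].sum(axis=0)

print(active.size, jnp.allclose(dense_out, event_out, atol=1e-5))
```

With about 10 active neurons, the event-driven path touches roughly 1% of the weights.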

`BrainState` provides specialized operators optimized for event-driven processing, which let a model use fewer computing resources and run faster on sparse data. `BrainState` offers several connection types for synapses between neurons; some examples are given below.



```python
import jax
import jax.numpy as jnp

import brainstate
```

### 1. FixedProb

Sometimes the exact connections between neurons are unknown, but we know that each presynaptic neuron connects to the neurons of a postsynaptic population with a fixed probability.

`BrainState` provides the class `brainstate.event.FixedProb` to define this kind of connection.

The class `brainstate.event.FixedProb` accepts the following parameters:

- `n_pre`: The number of presynaptic neurons.

- `n_post`: The number of postsynaptic neurons.

- `prob`: The fixed probability of each presynaptic neuron connecting to a postsynaptic neuron.

- `weight`: The maximum synaptic conductance, which can accept a float, an array, or a function, supporting scalars and physical quantities.

- `allow_multi_conn`: Whether a presynaptic neuron can have multiple connections with the same postsynaptic neuron. The default is `True`.

- `seed`: The random generation seed. The default is `None`.

- `name`: The module name. The default is `None`.

- `grad_mode`: The automatic differentiation method. The default is `vjp`.
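To make the connection scheme concrete, here is a minimal NumPy sketch of fixed-probability connectivity followed by event-driven propagation. It is not BrainState's implementation: the helpers `fixed_prob_targets` and `event_fixed_prob_mv` are hypothetical, and the sketch assumes each presynaptic neuron receives exactly `round(prob * n_post)` targets, which makes the role of `allow_multi_conn` (sampling targets with or without replacement) easy to see.

```python
import numpy as np


def fixed_prob_targets(n_pre, n_post, prob, allow_multi_conn=True, seed=None):
    """Hypothetical helper: postsynaptic targets of each presynaptic neuron.

    Assumption: every presynaptic neuron gets round(prob * n_post) targets.
    """
    rng = np.random.default_rng(seed)
    n_conn = int(round(prob * n_post))
    if allow_multi_conn:
        # sampling with replacement: the same target may be drawn twice
        return rng.integers(0, n_post, size=(n_pre, n_conn))
    # sampling without replacement, independently for each presynaptic neuron
    return np.stack([rng.choice(n_post, size=n_conn, replace=False) for _ in range(n_pre)])


def event_fixed_prob_mv(spikes, targets, weight, n_post):
    """Hypothetical event-driven propagation with a single shared (scalar) weight."""
    out = np.zeros(n_post)
    # only the rows of presynaptic neurons that fired are visited
    for i in np.flatnonzero(spikes):
        # np.add.at accumulates correctly even when a target appears twice
        np.add.at(out, targets[i], weight)
    return out


targets = fixed_prob_targets(100, 50, 0.1, seed=0)  # 5 targets per neuron
spikes = np.zeros(100, dtype=bool)
spikes[[3, 10, 42]] = True
print(event_fixed_prob_mv(spikes, targets, 0.5, 50).sum())  # 3 spikes * 5 targets * 0.5
```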


```python
# presynaptic spikes: 10000 neurons, each firing with probability 0.01 (~1% active)
pre_spikes = brainstate.random.rand(10000) < 0.01
```

```python
# dense weight matrix, 10000x1000, ~1% of entries are 1 and the rest are 0
dense_w = (brainstate.random.rand(10000, 1000) < 0.01).astype(float)

# event-driven operator, 10000x1000, connection probability 0.01, weight 1.0
fp = jax.jit(brainstate.event.FixedProb(10000, 1000, 0.01, 1.))
```

```python
# compile and run once first, then measure the average running time
jnp.dot(pre_spikes, dense_w)
%timeit -n 100 -r 10 jnp.dot(pre_spikes, dense_w)

fp(pre_spikes)
%timeit -n 100 -r 10 fp(pre_spikes)
```

```
1.24 ms ± 178 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)
421 μs ± 6.09 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)
```
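JAX dispatches work asynchronously, so a call can return before the device has finished computing; timing the bare call may therefore partly measure dispatch overhead. The following is a minimal sketch of a timing helper that waits for every result with `jax.block_until_ready`. The name `time_jax` is hypothetical, and the usage times only a plain jitted dense product, not the BrainState operators.

```python
import time

import jax
import jax.numpy as jnp


def time_jax(fn, *args, n_loops=100):
    """Hypothetical helper: mean wall-clock seconds per call of fn(*args)."""
    # warm-up: the first call triggers compilation, which should not be timed
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(n_loops):
        # wait for each result so the measurement covers the computation itself,
        # not only the asynchronous dispatch
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_loops


# usage with a jitted dense product of the same sizes as above
key_s, key_w = jax.random.split(jax.random.PRNGKey(0))
spikes = jax.random.uniform(key_s, (10000,)) < 0.01
weight = (jax.random.uniform(key_w, (10000, 1000)) < 0.01).astype(jnp.float32)
dense_mv = jax.jit(jnp.dot)
print(f"{time_jax(dense_mv, spikes, weight) * 1e6:.1f} us per call")
```

Inside IPython, the same effect can be obtained by appending `.block_until_ready()` to the timed expression, assuming the timed function returns a single JAX array. Timings measured this way may differ from those reported above.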


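### 2. Linear

In this case, the synaptic connections between neurons form a dense matrix. Even so, only presynaptic neurons that spike contribute to the output, so the computation can still be event-driven: only the weight rows of the active neurons need to be accumulated.

`BrainState` provides the class `brainstate.event.Linear` to define this kind of connection.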

The class `brainstate.event.Linear` accepts the following parameters:

- `n_pre`: The number of presynaptic neurons.

- `n_post`: The number of postsynaptic neurons.

- `weight`: The maximum synaptic conductance, which can accept a float, an array, or a function, supporting scalars and physical quantities.

- `name`: The module name. The default is `None`.

- `grad_mode`: The automatic differentiation method. The default is `vjp`.
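Under `jax.jit`, array shapes must be static, so the set of active neurons cannot be extracted with a data-dependent size. One common workaround, shown in this minimal sketch, is to pad the event list to a fixed `max_events` using `jnp.nonzero(..., size=..., fill_value=...)` and mask the padding. The function `event_dense_mv` and its `max_events` parameter are hypothetical and do not describe how `brainstate.event.Linear` is implemented; if more than `max_events` neurons fire, the extra events are silently dropped.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="max_events")
def event_dense_mv(spikes, weight, max_events):
    """Hypothetical event-driven spikes @ weight for a dense weight matrix."""
    # indices of presynaptic neurons that fired; jit needs a static length,
    # so the result is padded with index 0 up to max_events
    (idx,) = jnp.nonzero(spikes, size=max_events, fill_value=0)
    # mark which entries are real events rather than padding
    valid = jnp.arange(max_events) < jnp.sum(spikes)
    # gather only the rows of firing neurons and sum them
    rows = weight[idx]
    return jnp.sum(jnp.where(valid[:, None], rows, 0.0), axis=0)


key_s, key_w = jax.random.split(jax.random.PRNGKey(0))
spikes = jax.random.uniform(key_s, (1000,)) < 0.02
weight = jax.random.uniform(key_w, (1000, 50))
out = event_dense_mv(spikes, weight, max_events=200)
print(jnp.allclose(out, spikes.astype(weight.dtype) @ weight, atol=1e-4))
```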


```python
# dense weight matrix, 10000x1000
dense_w = brainstate.random.rand(10000, 1000).astype(float)

# event-driven operator using the same dense weights, 10000x1000
fp = jax.jit(brainstate.event.Linear(10000, 1000, dense_w))
```

```python
# compile and run once first, then measure the average running time
jnp.dot(pre_spikes, dense_w)
%timeit -n 100 -r 10 jnp.dot(pre_spikes, dense_w)

fp(pre_spikes)
%timeit -n 100 -r 10 fp(pre_spikes)
```

```
1.26 ms ± 161 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)
400 μs ± 37.5 μs per loop (mean ± std. dev. of 10 runs, 100 loops each)
```


### 3. CSRLinear

In most cases, the synaptic connections between neurons form a sparse matrix. The CSR (Compressed Sparse Row) format is a common way to store sparse matrices; its representation is shown in the figure below:

![](../_static/csr_matrix.png)

The CSR format consists of three NumPy arrays: `indices`, `indptr`, `values`:

- `indices`: Records the column indices of each non-zero element in the matrix, sorted first by row and then by column.

- `indptr`: Has a length of `n_rows + 1`. The non-zero elements of row `i` occupy positions `indptr[i]` up to (but not including) `indptr[i+1]` in `indices` and `values`.

- `values`: Records the value of each non-zero element, in the same order as `indices`.
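A small worked example makes the three arrays concrete. It uses SciPy's `csr_matrix`, where the array called `values` above is named `data`.

```python
import numpy as np
from scipy.sparse import csr_matrix

dense = np.array([[0., 2., 0., 1.],
                  [0., 0., 0., 0.],
                  [3., 0., 4., 0.]])
csr = csr_matrix(dense)

print(csr.indptr)   # [0 2 2 4]  (n_rows + 1 = 4 entries; row 1 is empty)
print(csr.indices)  # [1 3 0 2]  column of each non-zero, row by row
print(csr.data)     # [2. 1. 3. 4.]  the matching values

# the non-zeros of row i live in the slice indptr[i]:indptr[i + 1]
i = 2
row = slice(csr.indptr[i], csr.indptr[i + 1])
print(csr.indices[row], csr.data[row])  # [0 2] [3. 4.]
```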

`BrainState` provides the class `brainstate.event.CSRLinear` to define this type of connection.

The class `brainstate.event.CSRLinear` accepts the following parameters:

- `n_pre`: The number of presynaptic neurons.

- `n_post`: The number of postsynaptic neurons.

- `indptr`: The `indptr` array of the sparse matrix in CSR format.

- `indices`: The `indices` array of the sparse matrix in CSR format.

- `weight`: The maximum synaptic conductance, which can accept a float, an array, or a function, supporting scalars and physical quantities.

- `name`: The module name. The default is `None`.

- `grad_mode`: The automatic differentiation method. The default is `vjp`.
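Combining CSR storage with event-driven processing means that, for each presynaptic neuron that fires, only its own row slice `indptr[i]:indptr[i+1]` is read and scatter-added into the output. The following NumPy sketch illustrates that idea; `event_csr_mv` is a hypothetical helper and not the kernel BrainState uses.

```python
import numpy as np
from scipy.sparse import csr_matrix


def event_csr_mv(spikes, indptr, indices, data, n_post):
    """Hypothetical event-driven spikes @ W with W stored in CSR form."""
    out = np.zeros(n_post, dtype=data.dtype)
    # visit only the rows of presynaptic neurons that fired
    for i in np.flatnonzero(spikes):
        start, end = indptr[i], indptr[i + 1]
        # scatter-add this row's weights into its postsynaptic targets
        np.add.at(out, indices[start:end], data[start:end])
    return out


# usage: compare against the dense product on a small random matrix
rng = np.random.default_rng(0)
dense = (rng.random((200, 300)) < 0.05) * rng.random((200, 300))
spikes = rng.random(200) < 0.1
csr = csr_matrix(dense)
out = event_csr_mv(spikes, csr.indptr, csr.indices, csr.data, 300)
print(np.allclose(out, spikes @ dense))
```

The work is proportional to the number of non-zeros in the rows of active neurons, rather than to the full `n_pre * n_post` matrix.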


```python
# dense weight matrix, 10000x10000, ~0.01% of entries are nonzero
dense_w = (brainstate.random.rand(10000, 10000) < 0.0001).astype(float)

# event-driven operator, 10000x10000, built from the CSR form of the same matrix
from scipy.sparse import csr_matrix

csr = csr_matrix(dense_w)
fp = jax.jit(brainstate.event.CSRLinear(10000, 10000, csr.indptr, csr.indices, csr.data))
```

```python
# compile and run once first, then measure the average running time
jnp.dot(pre_spikes, dense_w)
%timeit -n 100 -r 10 jnp.dot(pre_spikes, dense_w)

fp(pre_spikes)
%timeit -n 100 -r 10 fp(pre_spikes)
```

```
13.5 ms ± 1.18 ms per loop (mean ± std. dev. of 10 runs, 100 loops each)
3.12 ms ± 1.07 ms per loop (mean ± std. dev. of 10 runs, 100 loops each)
```


### 4. Synaptic Weights


When modeling synaptic connections, the `weight` parameter accepts a float, an array, or a function, and supports both scalars and physical quantities.

If `weight` is a float, all synapses share the same weight; storing a single float instead of one value per synapse saves a great deal of memory.
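A back-of-the-envelope comparison shows the size of the saving. It uses the 10000 × 1000 FixedProb example above and assumes float32 storage and exactly `round(0.01 * 1000) = 10` targets per presynaptic neuron; the actual storage layout inside BrainState may differ.

```python
import numpy as np

n_pre, n_post, prob = 10000, 1000, 0.01
# assumption: a fixed number of targets per presynaptic neuron
n_syn = n_pre * int(round(prob * n_post))

scalar_w = np.float32(1.0)                           # one shared weight
per_syn_w = np.ones(n_syn, dtype=np.float32)          # one weight per synapse
dense_w = np.ones((n_pre, n_post), dtype=np.float32)  # full dense matrix

print(scalar_w.nbytes, per_syn_w.nbytes, dense_w.nbytes)  # 4 400000 40000000
```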

If `weight` is an array, its shape must align with the connection matrix: the array must either have the same size as the connection matrix, or have a shape that can be broadcast (repeated) to fill it.

If `weight` is a function, it is called during the initialization phase to generate a weight array with the same size as the connection matrix.
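The three accepted forms can be summarized with a small sketch. The helper `resolve_weight` is hypothetical and written in plain NumPy; it assumes that arrays must broadcast to the connection shape and that a function is called once with that shape. It illustrates the rules described above and is not BrainState's own initialization code.

```python
import numpy as np


def resolve_weight(weight, shape):
    """Hypothetical helper illustrating the three accepted forms of `weight`."""
    if callable(weight):
        # function (initializer): called once at init time with the target shape
        return np.asarray(weight(shape))
    weight = np.asarray(weight)
    if weight.ndim == 0:
        # scalar: a single shared value, kept as-is to save memory
        return weight
    # array: must broadcast to the connection shape, otherwise NumPy raises ValueError
    return np.broadcast_to(weight, shape)


rng = np.random.default_rng(0)
shape = (4, 3)
print(resolve_weight(0.5, shape).shape)                                # ()
print(resolve_weight(np.array([1., 2., 3.]), shape).shape)             # (4, 3)
print(resolve_weight(lambda s: rng.normal(0.5, 0.1, s), shape).shape)  # (4, 3)
```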

When `BrainUnit` is imported, `BrainState` supports physical quantities (values with units) both as inputs and in computations.