# Key Concepts

Welcome to ``brainstate``!

This section provides a brief introduction to the key concepts of the ``brainstate`` framework.

``BrainState`` is a high-performance computing framework for modeling brain dynamics, built on top of [JAX](https://github.com/jax-ml/jax). It provides a comprehensive toolchain for neuroscientists, computational modelers, and brain-inspired computing researchers to build, optimize, and deploy a wide range of neural network models. It integrates modern hardware acceleration, automatic differentiation, and event-driven computing, tailored for neural networks and in particular for Spiking Neural Networks (SNNs). This tutorial walks through its core functionality and use cases to help you get started quickly with building and optimizing brain dynamics models in ``BrainState``.

In [1]:
import jax.numpy as jnp

import brainstate as bst

## Overview of Core Functions

The core functionalities of ``BrainState`` include the following components:

- **Program Compilation**: Supports [program compilation](../apis/compile.rst) using the [`State`-based syntax](../apis/brainstate.rst), enabling the deployment of computational models on various hardware platforms such as CPU, GPU, and TPU.
- **Program Functionality Augmentation**: Provides [program augmentation](../apis/augment.rst) transformations based on the [``PyGraph``-based syntax](../apis/graph.rst), such as automatic differentiation, batching, and parallelization, which simplify the construction of complex computational models.
- **Event-Driven Computation**: Supports operator optimization based on [event-driven computation](../apis/event.rst), significantly improving the efficiency and scalability of Spiking Neural Networks.
- **Additional Features**: Includes convenient auxiliary tools such as random number generation, surrogate gradient functions, and model parameter management, which make it easier to build a wide variety of models.

In the following sections, we will examine in detail the implementation methodologies and optimization strategies for each of these functionalities.

## 1. ``State`` Syntax

While JAX typically favors functional programming paradigms, this approach may not be sufficiently intuitive for complex computational tasks such as brain dynamics modeling. ``BrainState`` introduces the ``State`` syntax, a highly abstracted interface that enables users to define and manage computational states much more easily and intuitively. The core characteristics of ``State`` syntax include:

- All variables that need to be modified should be encapsulated in a ``State`` object, which allows users to track and debug the model's state.
- Any variable not encapsulated in a ``State`` is treated as immutable and cannot be modified after program compilation. The compilation functions provided by ``BrainState`` are documented in the [``brainstate.compile`` module](../apis/compile.rst).

This means that in `BrainState`, all variables requiring mutation must be encapsulated within ``State`` objects to ensure program correctness and maintainability.
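To see the constraint that ``State`` is designed to lift, consider plain JAX with no ``brainstate`` involved. A jitted function that reads an ordinary Python global captures its value as a constant when it is first traced, so later reassignments are silently ignored; the functional workaround is to pass the value in and return the updated value explicitly. This is a sketch of JAX's documented behavior, not of ``brainstate`` internals.

```python
import jax
import jax.numpy as jnp

counter = jnp.asarray(0.0)  # an ordinary (non-State) global variable


@jax.jit
def read_counter():
    # `counter` is baked into the compiled program as a constant at trace time
    return counter + 1.0


first = read_counter()       # traced and compiled here, using counter == 0.0
counter = jnp.asarray(5.0)   # reassign the global after compilation
second = read_counter()      # cached program: still computes 0.0 + 1.0


# Functional style: the state is an explicit input and an explicit output.
@jax.jit
def increment(count):
    return count + 1.0


count = jnp.asarray(0.0)
for _ in range(3):
    count = increment(count)  # the caller must thread the new value back in
```

Wrapping such a value in a ``State`` removes the need to thread it by hand, because the ``brainstate`` compilation functions track reads and writes to ``State`` objects.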

The ``State`` class can have various subclasses. For example, in ``BrainState``, ``ParamState`` is a subclass of ``State`` used to encapsulate model parameters, while ``RandomState`` is another subclass designed to encapsulate the state of random number generators. Users can also define their own ``State`` subclasses to meet other needs. For instance:

In [2]:
class Counter(bst.State):
    pass

In the example above, by inheriting from the ``State`` class, we define a ``Counter`` class that can be used to encapsulate the state of a counter. This approach allows users to more flexibly define and manage the model's state, enhancing the readability and maintainability of the code.

``State`` can wrap any Python data type, such as integers, floats, NumPy arrays, and ``jax.Array``, as well as any of these nested inside dictionaries, lists, or tuples. Users can read and modify the wrapped data through the ``State.value`` attribute. For example:

In [3]:
example = bst.State(jnp.ones(3))

example

State(
  value=Array([1., 1., 1.], dtype=float32)
)

In [4]:
example.value = bst.random.random(3)

example

State(
  value=Array([0.7541833, 0.5430715, 0.0052768], dtype=float32)
)

``State`` supports any [PyTree](https://jax.readthedocs.io/en/latest/working-with-pytrees.html), meaning that users can encapsulate any pytree data structure within a ``State`` object, which makes state management and computation convenient.

In [5]:
example2 = bst.State({'a': jnp.ones(3), 'b': jnp.zeros(4)})

example2

State(
  value={'a': Array([1., 1., 1.], dtype=float32), 'b': Array([0., 0., 0., 0.], dtype=float32)}
)

## 2. ``PyGraph`` Syntax

In JAX, a pytree (Python tree) is a versatile data structure used to represent nested, tree-like Python containers. It can combine containers such as lists, tuples, and dictionaries, and the leaves can be different kinds of data, such as NumPy arrays, JAX arrays, or custom objects. This flexibility makes pytrees highly useful for data processing and model construction; however, their expressive capability may be limited in complex scientific computing scenarios.
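The flattening that underlies all pytree operations can be seen with ``jax.tree_util``: a nested container is split into a list of leaves plus a tree definition, functions can be mapped over the leaves, and the original structure can be rebuilt. The nested dictionary below is an arbitrary example.

```python
import jax
import jax.numpy as jnp

# A nested pytree: a dict containing an array and a list that holds a tuple.
params = {'a': jnp.ones(3), 'b': [jnp.zeros(2), (jnp.arange(2.0),)]}

# Flatten into leaves (arrays) and a treedef describing the container structure.
leaves, treedef = jax.tree_util.tree_flatten(params)

# Apply a function to every leaf while preserving the structure.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# Rebuild the original structure from the leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Note that a pytree must be a tree: every leaf is reached through exactly one path, which is why shared or cyclic references need a different representation.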

In many scientific computations, we need to define intricate computation graphs that may include cyclic references, nested structures, and dynamically generated computational processes, all of which are difficult to represent with a pytree. To address this challenge, ``brainstate`` provides the ``PyGraph`` data structure, which offers a more intuitive and flexible way to define and manipulate complex computational models built from interconnected Python objects.

The design of ``PyGraph`` is derived from Flax's [nnx module](https://flax.readthedocs.io/), which ``brainstate`` extends and optimizes for ``State`` retrieval, management, and manipulation. A ``PyGraph`` is built from basic nodes of type ``brainstate.graph.Node``. Unlike a pytree, these nodes are not restricted to a tree: they can form general directed graphs, including graphs with shared or cyclic references between nodes, which makes it more natural to construct complex computational workflows.

Below is a simple code example.

In [6]:
class Linear(bst.graph.Node):
    def __init__(self, din: int, dout: int):
        self.din, self.dout = din, dout
        self.w = bst.ParamState(bst.random.rand(din, dout))
        self.b = bst.ParamState(jnp.zeros((dout,)))

    def __call__(self, x):
        return x @ self.w.value + self.b.value

In [7]:
model = Linear(2, 5)

model

Linear(
  din=2,
  dout=5,
  w=ParamState(
    value=Array([[0.94007385, 0.0341239 , 0.57106984, 0.8860879 , 0.1215055 ],
           [0.6579735 , 0.9333912 , 0.2877561 , 0.09438729, 0.44648933]],      dtype=float32)
  ),
  b=ParamState(
    value=Array([0., 0., 0., 0., 0.], dtype=float32)
  )
)

We can add a self-reference to the model to create a cyclic graph. PyGraph still handles this self-referential structure correctly; in the printed representation, the cycle is shown as `Linear(...)`.

In [8]:
model.self = model

model

Linear(
  din=2,
  dout=5,
  w=ParamState(
    value=Array([[0.94007385, 0.0341239 , 0.57106984, 0.8860879 , 0.1215055 ],
           [0.6579735 , 0.9333912 , 0.2877561 , 0.09438729, 0.44648933]],      dtype=float32)
  ),
  b=ParamState(
    value=Array([0., 0., 0., 0., 0.], dtype=float32)
  ),
  self=Linear(...)
)
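How can a traversal cope with `model.self = model`? A common technique, shown here as a toy pure-Python sketch, is to remember the `id` of every node already visited and skip it when it is reached again, so that each node is expanded exactly once. The `ToyNode` class and the `iter_leaves` function are hypothetical and are not ``brainstate``'s implementation.

```python
class ToyNode:
    """Hypothetical minimal node: attributes hold plain values or other nodes."""


def iter_leaves(node, path=(), seen=None):
    """Yield (path, value) for every non-node attribute reachable from `node`.

    Each ToyNode is expanded at most once (tracked by id), so cycles terminate.
    """
    if seen is None:
        seen = set()
    if id(node) in seen:
        return
    seen.add(id(node))
    for name, value in sorted(vars(node).items()):
        if isinstance(value, ToyNode):
            yield from iter_leaves(value, path + (name,), seen)
        else:
            yield path + (name,), value


root, child = ToyNode(), ToyNode()
root.w = 1.0
root.child = child
child.b = 2.0
child.parent = root   # cycle: child -> root
root.self = root      # self-reference, as in the example above

leaves = dict(iter_leaves(root))
```

Here `leaves` maps `('child', 'b')` to `2.0` and `('w',)` to `1.0`; the back-references `child.parent` and `root.self` are skipped instead of recursing forever.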

``brainstate.graph.Node`` objects can be freely composed and nested inside any pytree container, such as lists, dictionaries, and tuples. Below is an example of a Multi-Layer Perceptron (MLP) that keeps its hidden layers in a list.

In [9]:
class MLP(bst.graph.Node):
    def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
        self.input = bst.nn.Linear(din, dmid)
        self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
        self.output = bst.nn.Linear(dmid, dout)

    def __call__(self, x):
        x = bst.functional.relu(self.input(x))
        for layer in self.layers:
            x = bst.functional.relu(layer(x))
        return self.output(x)


model = MLP(2, 1, 3)

model

MLP(
  input=Linear(
    in_size=(2,),
    out_size=(1,),
    w_mask=None,
    weight=ParamState(
      value={'weight': Array([[-1.2849224 ],
             [ 0.79070586]], dtype=float32), 'bias': Array([0.], dtype=float32)}
    )
  ),
  layers=[Linear(
    in_size=(1,),
    out_size=(1,),
    w_mask=None,
    weight=ParamState(
      value={'weight': Array([[0.0800172]], dtype=float32), 'bias': Array([0.], dtype=float32)}
    )
  ), Linear(
    in_size=(1,),
    out_size=(1,),
    w_mask=None,
    weight=ParamState(
      value={'weight': Array([[-0.6626468]], dtype=float32), 'bias': Array([0.], dtype=float32)}
    )
  ), Linear(
    in_size=(1,),
    out_size=(1,),
    w_mask=None,
    weight=ParamState(
      value={'weight': Array([[0.20347145]], dtype=float32), 'bias': Array([0.], dtype=float32)}
    )
  )],
  output=Linear(
    in_size=(1,),
    out_size=(3,),
    w_mask=None,
    weight=ParamState(
      value={'weight': Array([[-1.8724355 ,  1.0703413 , -0.04334985]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
    )
  )
)

The `brainstate.graph` module provides a set of tools for creating, inspecting, and managing `PyGraph` objects. These tools let users build computational graphs modularly, which keeps complex workflows readable and maintainable. Users can rely on [simple APIs](../apis/graph.rst) to add nodes, define dependencies, and update node states dynamically. In addition, `PyGraph` exposes a structured representation of the computation graph, which helps users understand how their workflows are organized.

For instance, `brainstate.graph.states` can efficiently retrieve all `State` instances within the model:

In [10]:
states = bst.graph.states(model)

states

FlattedDict({
  ('input', 'weight'): ParamState(
    value={'weight': Array([[-1.2849224 ],
           [ 0.79070586]], dtype=float32), 'bias': Array([0.], dtype=float32)}
  ),
  ('layers', 0, 'weight'): ParamState(
    value={'weight': Array([[0.0800172]], dtype=float32), 'bias': Array([0.], dtype=float32)}
  ),
  ('layers', 1, 'weight'): ParamState(
    value={'weight': Array([[-0.6626468]], dtype=float32), 'bias': Array([0.], dtype=float32)}
  ),
  ('layers', 2, 'weight'): ParamState(
    value={'weight': Array([[0.20347145]], dtype=float32), 'bias': Array([0.], dtype=float32)}
  ),
  ('output', 'weight'): ParamState(
    value={'weight': Array([[-1.8724355 ,  1.0703413 , -0.04334985]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
  )
})

In [11]:
states.to_nest()

NestedDict({
  'input': {
    'weight': ParamState(
      value={'weight': Array([[-1.2849224 ],
             [ 0.79070586]], dtype=float32), 'bias': Array([0.], dtype=float32)}
    )
  },
  'layers': {
    0: {
      'weight': ParamState(
        value={'weight': Array([[0.0800172]], dtype=float32), 'bias': Array([0.], dtype=float32)}
      )
    },
    1: {
      'weight': ParamState(
        value={'weight': Array([[-0.6626468]], dtype=float32), 'bias': Array([0.], dtype=float32)}
      )
    },
    2: {
      'weight': ParamState(
        value={'weight': Array([[0.20347145]], dtype=float32), 'bias': Array([0.], dtype=float32)}
      )
    }
  },
  'output': {
    'weight': ParamState(
      value={'weight': Array([[-1.8724355 ,  1.0703413 , -0.04334985]], dtype=float32), 'bias': Array([0., 0., 0.], dtype=float32)}
    )
  }
})
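The relationship between the two outputs above is simple: `FlattedDict` keys are tuples of attribute names and list indices, and `to_nest()` turns each tuple into a path through nested dictionaries. The hypothetical helper `to_nested` below sketches that conversion for an ordinary `dict`; it illustrates the idea and is not ``brainstate``'s code.

```python
def to_nested(flat):
    """Convert {('a', 0, 'w'): v, ...} into {'a': {0: {'w': v}}, ...}."""
    nested = {}
    for path, value in flat.items():
        cursor = nested
        for key in path[:-1]:
            # create intermediate dictionaries on demand
            cursor = cursor.setdefault(key, {})
        cursor[path[-1]] = value
    return nested


flat = {
    ('input', 'weight'): 'w_in',
    ('layers', 0, 'weight'): 'w0',
    ('layers', 1, 'weight'): 'w1',
    ('output', 'weight'): 'w_out',
}
nested = to_nested(flat)
```

The placeholder strings stand in for the `ParamState` objects of the real model.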

Similarly, ``brainstate.graph.nodes`` retrieves all ``Node`` instances contained in the model:

In [12]:
nodes = bst.graph.nodes(model)

nodes

FlattedDict({
  ('input',): Linear(
    in_size=(2,),
    out_size=(1,),
    w_mask=None,
    weight=ParamState(
      value={'weight': Array([[-1.2849224 ],
             [ 0.79070586]], dtype=float32), 'bias': Array([0.], dtype=float32)}
    )
  ),
  ('layers', 0): Linear(
    in_size=(1,),
    out_size=(1,),
    w_mask=None,
    weight=ParamState(
      value={'weight': Array([[0.0800172]], dtype=float32), 'bias': Array([0.], dtype=float32)}
    )
  ),
  ('layers', 1): Linear(
    in_size=(1,),
    out_size=(1,),
    w_mask=None,
    weight=ParamState(
      value={'weight': Array([[-0.6626468]], dtype=float32), 'bias': Array([0.], dtype=float32)}
    )
  ),
  ('layers', 2): Linear(
    in_size=(1,),
    out_size=(1,),
    w_mask=None,
    weight=ParamState(
      value={'weight': Array([[0.20347145]], dtype=float32), 'bias': Array([0.], dtype=float32)}
    )
  ),
  ('output',): Linear(
    in_size=(1,),
    out_size=(3,),
    w_mask=None,
    weight=ParamState(
      value={'weight'

In summary, the `PyGraph` syntax offers strong support for complex models in scientific computing, allowing users to efficiently construct, manage, and optimize computation graphs. This capability enhances research and applications in brain dynamics modeling.

## 3. Program Compilation

In high-performance computing, hardware acceleration is essential for improving computational efficiency. The `BrainState` framework achieves hardware compilation and deployment through a `State`-based syntax. This provides a highly abstracted interface that allows users to write code once and generate an intermediate representation (IR) that can be compiled and optimized across various hardware platforms.
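To look at such an intermediate representation directly, plain JAX offers `jax.make_jaxpr`, which prints the traced program as a jaxpr, and `jax.jit(...).lower(...)`, which produces the lowered module text passed on to the XLA compiler (StableHLO in recent JAX versions). The function `f` is an arbitrary example of our own.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * 2.0


# Hardware-independent jaxpr produced by tracing f with a scalar float input.
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

# Lowered module text, which XLA then compiles for the active backend.
hlo_text = jax.jit(f).lower(1.0).as_text()
```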

The compilation capabilities of `brainstate` are primarily housed within the [``brainstate.compile`` module](../apis/compile.rst). These compilation APIs offer a comprehensive range of functionalities, including:

- **Conditional Statements**: Supports `if-else` logic, enabling users to execute different computational workflows based on varying conditions.
- **Loop Statements**: Supports `for` loops for repeated execution of identical computational operations.
- **While Statements**: Supports `while` loops for condition-based repetitive execution of computational tasks.
- **Just-In-Time (JIT) Compilation**: Enhances computational efficiency and performance through JIT compilation.

A notable feature of `brainstate` compilation is that it is driven by `State`. When a function is traced for compilation, every `State` instance it reads or writes is detected automatically and incorporated into the computation graph, which can then be executed on different hardware. Users can therefore define complex programs freely, while the compiler optimizes based on the paths the program actually executes. Furthermore, this `State`-sensitive compilation mode lets users express program logic without being constrained by structures such as `PyGraph` or `PyTree`.

Below is a simple example of the compilation process:

In [13]:
a = bst.State(1.)


def add(i):
    a.value += 1.


bst.compile.for_loop(add, jnp.arange(10))

print(a.value)

11.0


In this example, we define a simple for loop that increments the value of `a` by 1 in each iteration. Calling `bst.compile.for_loop` compiles this loop into a computation graph and executes it with JAX; starting from 1.0, ten increments give 11.0.
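What `bst.compile.for_loop` saves you from writing is the explicit carry threading of plain JAX. The sketch below rewrites the previous counter with `jax.lax.scan`, passing the value of `a` in and out by hand. It illustrates the functional pattern only and is not a claim about how ``brainstate`` implements `for_loop`.

```python
import jax
import jax.numpy as jnp


def add(a_value, i):
    # the value that the State `a` would hold is passed in and returned as the carry
    return a_value + 1.0, None


a_final, _ = jax.lax.scan(add, jnp.float32(1.0), jnp.arange(10))
```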

Another distinctive feature of `brainstate` compilation is that JAX's functional compilation functions and `brainstate`'s built-in `State`-aware compilation functions can be nested recursively inside each other. `State` objects that are created or used only within an intermediate step remain local to that step and are optimized out of the overall program, which reduces memory consumption and improves execution speed.

The following example nests a `while_loop` inside a `for_loop`; the local `State` `c` is created anew in every iteration of the outer loop:

In [14]:
b = bst.State(0.)


def add(i):
    c = bst.State(0.)

    def cond(j):
        return j <= i

    def body(j):
        c.value += 1.
        return j + 1

    bst.compile.while_loop(cond, body, 0.)

    b.value += c.value


bst.compile.for_loop(add, jnp.arange(10))

print(b.value)

55.0
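For comparison, a plain-JAX version of the same nested loop has to thread both counters explicitly: `c` becomes part of the `jax.lax.while_loop` carry and `b` becomes the `jax.lax.scan` carry. This is a sketch of a functional equivalent with our own naming, not ``brainstate``'s internal lowering.

```python
import jax
import jax.numpy as jnp


def add(b, i):
    # the local counter c lives only inside this step, as part of the while_loop carry
    def cond(carry):
        j, c = carry
        return j <= i

    def body(carry):
        j, c = carry
        return j + 1, c + 1.0

    _, c = jax.lax.while_loop(cond, body, (jnp.int32(0), jnp.float32(0.0)))
    return b + c, None


b_final, _ = jax.lax.scan(add, jnp.float32(0.0), jnp.arange(10))
```

Step `i` counts `j = 0, ..., i`, i.e. `i + 1` iterations, so the total is 1 + 2 + ... + 10 = 55.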


It is worth noting that ``brainstate`` compilation also supports JAX's debugging tools. For instance, users can print the values of intermediate states by calling ``jax.debug.print``, which helps when debugging and optimizing a program. The following example adds such debugging output to the program above. For more information on JAX debugging capabilities, please refer to the [JAX Debugging Documentation](https://jax.readthedocs.io/en/latest/debugging/index.html).

In [15]:
import jax

b = bst.State(0.)


def add(i):
    c = bst.State(0.)

    def cond(j):
        return j <= i

    def body(j):
        c.value += 1.
        return j + 1

    bst.compile.while_loop(cond, body, 0.)

    b.value += c.value
    jax.debug.print('b = {b}, c = {c}', b=b.value, c=c.value)


bst.compile.for_loop(add, jnp.arange(10))

b = 1.0, c = 1.0
b = 3.0, c = 2.0
b = 6.0, c = 3.0
b = 10.0, c = 4.0
b = 15.0, c = 5.0
b = 21.0, c = 6.0
b = 28.0, c = 7.0
b = 36.0, c = 8.0
b = 45.0, c = 9.0
b = 55.0, c = 10.0


``brainstate`` also supports compilation for different hardware platforms. Users can deploy their models on various hardware, including CPU, GPU, and TPU, simply by changing a parameter. Users can call the following at the beginning of their program:

```python
import brainstate

# Select ONE backend at the beginning of the program:
brainstate.environ.set(platform='cpu')    # CPU backend
# brainstate.environ.set(platform='gpu')  # GPU backend
# brainstate.environ.set(platform='tpu')  # TPU backend
```

Alternatively, users can use JAX's syntax:

```python
import jax

# Select ONE backend at the beginning of the program:
jax.config.update('jax_platform_name', 'cpu')    # CPU backend
# jax.config.update('jax_platform_name', 'gpu')  # GPU backend
# jax.config.update('jax_platform_name', 'tpu')  # TPU backend
```

This flexible compilation approach enables users to better leverage the advantages of different hardware, enhancing computational efficiency and performance.
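After selecting a platform, you can confirm which backend JAX actually uses. `jax.default_backend()` and `jax.devices()` are standard JAX calls; the printed values depend on your installation and hardware.

```python
import jax

backend = jax.default_backend()  # e.g. 'cpu', 'gpu', or 'tpu'
devices = jax.devices()          # devices available on the default backend
print(backend, devices)
```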

## 4. Program Augmentation

`BrainState` also provides a suite of program augmentation transformations. For example, while a program may be defined solely for forward inference, users can obtain gradient information through automatic differentiation transformations such as `grad`. These transformations make it more convenient to construct and optimize complex computational models.

However, augmenting a program requires a clear understanding of its structure and of what the augmentation should target; for example, computing gradients requires knowing which `State` objects are the parameters to differentiate. Users therefore need to know the program's architecture before compilation, and the `PyGraph` syntax makes it easier to define and inspect computational models for this purpose.

The various operations and management functions provided by `PyGraph` related to `State` and graph representations significantly simplify the construction of complex functional enhancement transformations. Key enhancement transformations available in `brainstate` include:

- **Automatic Differentiation**: Essential for model optimization, particularly in backpropagation and gradient descent algorithms.
- **Batch Processing**: Support for processing large-scale data in batches, which significantly enhances training speed and inference efficiency.
- **Multi-Device Parallelism**: Facilitates parallel computation across multiple devices, improving computational efficiency and overall model performance.
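In plain JAX, these transformations are composable function transformations. The sketch below applies `jax.grad` (automatic differentiation) and `jax.vmap` (batching) to a loss with explicit parameters; the model and data are our own toy choices, and multi-device parallelism (e.g. `jax.pmap` or sharding) is omitted.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # mean-squared error of a linear model without bias
    return jnp.mean((x @ w - y) ** 2)


w = jnp.zeros((2, 3))
x = jnp.ones((4, 2))
y = jnp.ones((4, 3))

# Automatic differentiation: gradient of the loss with respect to w (argument 0).
grad_w = jax.grad(loss)(w, x, y)

# Batching: per-example losses, vectorized over the leading axis of x and y.
per_example = jax.vmap(lambda xi, yi: loss(w, xi[None, :], yi[None, :]))(x, y)
```

With `w = 0`, every residual is -1, so each entry of `grad_w` is 2/12 * 4 * (-1) = -2/3 and every per-example loss is 1.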

Below is a simple example of automatic differentiation:

In [16]:
# <input, output> pair
x = jnp.ones((1, 2))
y = jnp.ones((1, 3))

# model
model = bst.nn.Linear(2, 3)


# loss function
def loss_fn(x, y):
    return jnp.mean((y - model(x)) ** 2)


prev_loss = loss_fn(x, y)

# gradients
weights = bst.graph.states(model)  # all State instances in the model
grads = bst.augment.grad(loss_fn, weights)(x, y)

# SGD update
for key, grad in grads.items():
    updates = jax.tree.map(lambda p, g: p - 0.1 * g, weights[key].value, grad)
    weights[key].value = updates

# loss evaluation
assert loss_fn(x, y) < prev_loss

In the example above, we define a simple linear model and its mean-squared-error loss. Calling ``bst.augment.grad`` returns the gradients of the loss with respect to the given ``State`` instances, which we then use for one step of gradient descent. Note that this transformation requires us to specify in advance which parameters need gradients; we therefore use the ``brainstate.graph.states`` function to retrieve all ``State`` instances within the model.
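For comparison, the same SGD step can be written in purely functional JAX, where the parameters live in an explicit pytree instead of `State` objects and gradients come from `jax.value_and_grad`. The parameter names, initial values, and learning rate below are our own choices for illustration.

```python
import jax
import jax.numpy as jnp

# Explicit parameter pytree (instead of ParamState objects).
params = {'weight': jnp.full((2, 3), 0.1), 'bias': jnp.zeros(3)}
x = jnp.ones((1, 2))
y = jnp.ones((1, 3))


def loss_fn(params, x, y):
    pred = x @ params['weight'] + params['bias']
    return jnp.mean((y - pred) ** 2)


# Loss and gradients with respect to the whole parameter pytree.
loss0, grads = jax.value_and_grad(loss_fn)(params, x, y)

# SGD update applied leaf by leaf; returns a new pytree instead of mutating state.
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
loss1 = loss_fn(params, x, y)
```

Here the initial prediction is 0.2 everywhere, so `loss0` is 0.64, and one step lowers it to about 0.41.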

Overall, program augmentation transformations offer a convenient and efficient way to construct and optimize computational models, enabling faster model training and inference. For more information, please refer to the [Program Augmentation](../tutorials/program_augmentation-en.ipynb) tutorial.

## 5. Event-Driven Computation

Event-driven computation represents a novel paradigm distinct from traditional computing models. Modern computers and compilers are primarily optimized for dense matrices, particularly for operations like matrix multiplication, which underpin contemporary artificial neural networks, as illustrated in Figure A below. However, this computation model does not align with the spiking dynamics of the real brain, where neurons are event-driven and activate only upon receiving spike events. Furthermore, the connections between neurons are sparse, as depicted in Figure B below.

![](../_static/dense-mv-vs-event-spmv.png)

`BrainState` offers operator optimizations that cater specifically to the characteristics of event-driven computing. This paradigm effectively minimizes unnecessary computations, significantly reducing resource consumption and enhancing speed, especially when dealing with sparse data. In Spiking Neural Network (SNN) models, neurons activate only upon receiving spike events, thereby avoiding the computational waste inherent in traditional neural networks, which continue calculations even in the absence of input. This characteristic renders SNNs particularly well-suited for applications that require low power consumption and high efficiency, such as edge computing and embedded systems.
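The saving described above can be illustrated in a few lines of `jax.numpy`: a dense matrix-vector product touches every row of the weight matrix, whereas an event-driven product only sums the rows of presynaptic neurons that actually spiked, and both give the same result. This is a conceptual sketch with arbitrary sizes and probabilities; it is not how the `bst.event` operators are implemented.

```python
import jax
import jax.numpy as jnp

n_pre, n_post = 1000, 100
k_spike, k_weight = jax.random.split(jax.random.PRNGKey(0))

# Binary spike vector (about 1% active) and a 0/1 connectivity matrix (about 10% connected).
spikes = jax.random.uniform(k_spike, (n_pre,)) < 0.01
weights = (jax.random.uniform(k_weight, (n_pre, n_post)) < 0.1).astype(jnp.float32)

# Dense product: all n_pre rows participate, whether or not the neuron spiked.
dense_out = spikes.astype(jnp.float32) @ weights

# Event-driven product: gather and sum only the rows of neurons that spiked.
active = jnp.flatnonzero(spikes)
event_out = weights[active].sum(axis=0)
rows_touched = int(active.size)
```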

**Example**: Defining an event-driven interaction between neurons:

In [17]:
# pre-synaptic spikes: 10000 neurons, each firing with probability 1%
pre_spikes = bst.random.rand(10000) < 0.01

# dense weight matrix, 10000x1000, connection probability 1% (zeros stored explicitly)
dense_w = (bst.random.rand(10000, 1000) < 0.01).astype(float)

# event-driven operator with fixed connection probability: 10000 -> 1000 neurons, p = 1%
fp = jax.jit(bst.event.FixedProb(10000, 1000, 0.01, 1.))

In [18]:
jnp.dot(pre_spikes, dense_w)
%timeit -n 100 -r 10 jnp.dot(pre_spikes, dense_w)

1.19 ms ± 104 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)


In [19]:
fp(pre_spikes)
%timeit -n 100 -r 10 fp(pre_spikes)

597 µs ± 14.7 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)
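In this run, the event-driven `FixedProb` operator took about half the time of the dense product (597 µs versus 1.19 ms). One way to see where such savings can come from is a sparse, per-neuron list of targets: with it, the work per step scales with the number of spikes times the fan-out rather than with the full matrix size. The storage format below is a hypothetical NumPy sketch, not the data structure `bst.event.FixedProb` actually uses.

```python
import numpy as np

rng = np.random.default_rng(42)
n_pre, n_post, prob, weight = 200, 50, 0.1, 1.0

# Hypothetical fixed-probability sparse format: for each presynaptic neuron,
# store only the indices of the postsynaptic neurons it connects to.
targets = [np.flatnonzero(rng.random(n_post) < prob) for _ in range(n_pre)]
spikes = rng.random(n_pre) < 0.05

# Event-driven accumulation: visit only neurons that spiked, and only their targets.
post = np.zeros(n_post)
for i in np.flatnonzero(spikes):
    post[targets[i]] += weight  # indices within targets[i] are unique

# Dense reference with the same connectivity, for comparison.
dense_w = np.zeros((n_pre, n_post))
for i, t in enumerate(targets):
    dense_w[i, t] = weight
dense_post = spikes.astype(float) @ dense_w
```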


## 6. Other Auxiliary Functions

In addition to the core functionalities mentioned above, `BrainState` offers a range of auxiliary features that simplify model construction and optimization for users. These features include, but are not limited to:

- **Random Number Generation**: Generates random numbers from various distributions, useful for simulating randomness or managing random variables.
- **Parameter Management**: Provides a straightforward interface for initializing, storing, and updating model parameters, accommodating complex model architectures and multi-layer networks.
- **Debugging Tools**: Help users monitor the states and computational results of individual layers during model development, making it easier to identify potential issues.
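In plain JAX, random numbers are drawn from explicit keys that must be split and passed around by hand; the `RandomState` class mentioned earlier encapsulates random-number-generator state so that users do not have to do this themselves. The sketch below shows only the explicit-key pattern with standard `jax.random` calls; how `RandomState` stores its key internally is not covered here.

```python
import jax

key = jax.random.PRNGKey(0)

# Each draw consumes a fresh subkey; the parent key is replaced by its successor.
key, sub = jax.random.split(key)
x1 = jax.random.uniform(sub, (3,))
key, sub = jax.random.split(key)
x2 = jax.random.uniform(sub, (3,))

# Re-running from the same seed reproduces the first draw exactly.
_, sub_again = jax.random.split(jax.random.PRNGKey(0))
x1_again = jax.random.uniform(sub_again, (3,))
```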

## Conclusion

`BrainState` is a powerful framework for brain dynamics modeling, providing cross-hardware compilation, program augmentation, event-driven computing, and a comprehensive suite of auxiliary tools. For users working on neuroscience, cognitive modeling, and SNN development, `BrainState` offers modular functionality that supports the rapid construction, optimization, and deployment of efficient brain dynamics models.

By thoroughly understanding and leveraging the features outlined above, you can easily create and optimize computational models that are well-suited for a variety of research tasks and hardware platforms.