# Program Compilation

``BrainState`` enables cross-hardware compilation and deployment through ``brainstate.compile``, allowing Python code to be transformed into an intermediate representation (IR) and subsequently compiled and optimized for various hardware platforms (CPU, GPU and TPU).

The compilation facilities of ``BrainState`` are collected in ``brainstate.compile``. They cover a range of syntactic constructs, including:

1. Just-in-Time Compilation (JIT): compiles functions just in time to improve computational efficiency and performance.
2. Conditional Statements: support ``if``/``else`` logic, so that different computations can run depending on conditions.
3. Loop Statements: support ``for`` and ``while`` loops for repetitive computations.

The JAX framework offers powerful transformations for optimizing program performance ([Transformations](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations)), including automatic differentiation, vectorization, parallelization, and JIT compilation. A significant portion of the computations in ``BrainState`` relies on JAX. However, JAX transformations are function-based rather than state-based: they operate on ``jax.Array`` values within ``pytrees`` and follow value semantics, so every input must be passed in as an argument and every result returned. As a result, JAX cannot directly track or optimize ``State`` objects, which are created and mutated dynamically at runtime.

To address this limitation, ``BrainState`` introduces ``brainstate.compile``, which compiles the ``State`` objects a program reads and writes into the computational graph so that they can be optimized. These extensions enhance JAX transformations by allowing ``State`` objects to be passed into and out of transformations.

A key feature of ``BrainState``'s compilation approach is state awareness: any ``State`` instance encountered during program execution is compiled into the computational graph, which can then run on any supported hardware platform. This lets users define arbitrarily complex programs, with the compiler optimizing based on the branches actually taken at runtime, which improves computational efficiency. The state-aware model also lets users express program logic more flexibly, without being constrained by concepts such as ``PyGraph`` or ``PyTree``.

Below, we will delve into the specific usage of ``brainstate.compile``.


```python
import jax.numpy as jnp
import jax
import brainstate
```

## Just-in-Time Compilation (JIT)

``brainstate.compile`` supports JIT compilation, enhancing computational efficiency and performance. Users can leverage ``brainstate.compile.jit`` to compile functions just-in-time for improved execution efficiency. Here is an example:

```python
# A global State that lives outside the compiled function.
a = brainstate.State(brainstate.random.randn(10))

@brainstate.compile.jit
def fun1(inp):
    a.value += inp

    # A local State created inside the compiled function.
    b = brainstate.State(brainstate.random.randn(1))

    def inner_fun(x):
        b.value += x

    # Accumulate 100 random numbers into b.
    brainstate.compile.for_loop(inner_fun, brainstate.random.randn(100))

    return a.value + b.value

x = brainstate.random.randn(10)
print(fun1(x))
# Inspect which States the compiled program depends on.
key = fun1.stateful_fun.get_arg_cache_key(x)
fun1.stateful_fun.get_states(key)
```

```
[ 6.094308   6.739417  11.3118305 12.238316   9.976001   9.109862
  9.419138  10.82519   11.687967   7.8044806]

(State(
   value=Array([-3.4689972 , -2.8238876 ,  1.7485251 ,  2.6750102 ,  0.41269594,
          -0.45344275, -0.14416695,  1.2618849 ,  2.1246629 , -1.7588243 ],      dtype=float32)
 ),
 RandomState([1043340322 1926620362]))
```

In the code above, ``fun1`` takes an input ``inp`` and adds it to the global ``State`` ``a``. It then creates a local ``State`` ``b`` and, through ``brainstate.compile.for_loop``, repeatedly applies the inner function ``inner_fun``, which adds each loop element to ``b``. Finally, ``fun1`` returns the sum of the two ``State`` values.

Next, we call ``fun1`` with a random input ``x``. We then use ``fun1.stateful_fun.get_arg_cache_key`` to retrieve the cache key for this call and ``fun1.stateful_fun.get_states`` to list the ``State`` variables that the compiled program depends on. Notably, the local state ``b`` is not among them: it is optimized out when the computational graph is built. Only the external variable ``a`` and the global random state (``RandomState``, used by ``brainstate.random``) are recorded, which reduces memory usage and improves computational efficiency.

Note that ``brainstate.compile.jit`` caches the compiled program. The function is traced and compiled on the first call, and subsequent calls with compatible arguments reuse the cached version directly. If recompilation is needed, the cache can be cleared with ``fun1.stateful_fun.clear_cache()``.
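To make the caching behavior concrete, the pure-Python sketch below models a compile cache keyed by argument shapes and dtypes. It is a hypothetical model, not ``BrainState``'s implementation: the class ``ShapeKeyedCache`` and its ``trace_count`` counter are illustrative names, and the sketch assumes (as in JAX) that a new compilation is triggered only when the abstract signature of the inputs changes, not when their values change.

```python
import numpy as np


class ShapeKeyedCache:
    """Hypothetical model of a JIT cache keyed by argument shapes and dtypes."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = {}        # maps a cache key to a "compiled" callable
        self.trace_count = 0   # number of simulated compilations

    def cache_key(self, *args):
        # Only the abstract signature (shape, dtype) matters, not the values.
        return tuple((np.shape(a), np.asarray(a).dtype.str) for a in args)

    def __call__(self, *args):
        key = self.cache_key(*args)
        if key not in self.cache:
            # Stand-in for tracing and compiling: happens once per new key.
            self.trace_count += 1
            self.cache[key] = self.fn
        return self.cache[key](*args)

    def clear_cache(self):
        # After clearing, the next call compiles again.
        self.cache.clear()


double = ShapeKeyedCache(lambda x: 2 * x)
double(np.ones(10, dtype=np.float32))
double(np.zeros(10, dtype=np.float32))  # same shape and dtype: cache hit
double(np.ones(3, dtype=np.float32))    # new shape: second compilation
print(double.trace_count)  # 2
```

Calling ``double.clear_cache()`` and then calling ``double`` again increments ``trace_count``, mirroring the effect of clearing a real compilation cache.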

## Conditional Statements

``brainstate.compile`` supports conditional statements, enabling users to execute different computational flows based on varying conditions. It provides three functions for compiling conditional statements: ``brainstate.compile.cond``, ``brainstate.compile.switch``, and ``brainstate.compile.ifelse``. Below, we explore each of them in turn.

### Using ``brainstate.compile.cond``

The usage of ``brainstate.compile.cond`` is as follows:

```python
st1 = brainstate.State(brainstate.random.rand(10))
st2 = brainstate.State(brainstate.random.rand(2))
st3 = brainstate.State(brainstate.random.rand(5))
st4 = brainstate.State(brainstate.random.rand(2, 10))

def true_fun(x):
    # (2,) @ (2, 10) -> (10,), matching the shape of st1
    st1.value = st2.value @ st4.value + x

def false_fun(x):
    st3.value = (st3.value + 1.) * x

# pred is True, so only true_fun runs: st1 is updated and st3 is left untouched.
brainstate.compile.cond(True, true_fun, false_fun, 2.)
```

When the arguments have valid types, ``cond()`` behaves like a Python ``if``/``else``: it calls ``true_fun(*operands)`` if ``pred`` is true and ``false_fun(*operands)`` otherwise. ``pred`` must be a boolean scalar.
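The Python reading of ``cond()`` can be written out as the small reference function below, which mirrors the reference semantics documented for ``jax.lax.cond``. ``cond_reference`` and the toy branches are hypothetical illustrations; the compiled version additionally accepts a traced scalar ``pred`` and handles the ``State`` objects touched by the branches.

```python
def cond_reference(pred, true_fun, false_fun, *operands):
    """Pure-Python reading of cond(): exactly one branch is called."""
    if pred:
        return true_fun(*operands)
    else:
        return false_fun(*operands)


calls = []  # records which branch ran


def true_fun(x):
    calls.append("true")
    return x + 1


def false_fun(x):
    calls.append("false")
    return x * 10


print(cond_reference(True, true_fun, false_fun, 2.))   # 3.0
print(cond_reference(False, true_fun, false_fun, 2.))  # 20.0
print(calls)                                           # ['true', 'false']
```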

### Using ``brainstate.compile.switch``

Next, here is the usage of ``brainstate.compile.switch``:

```python
branches = [jax.lax.add, jax.lax.mul]

def cfun(x):
    return brainstate.compile.switch(x, branches, x, x)

print(cfun(2), cfun(-1))
```

```
4 -2
```


In the code above, ``cfun`` uses its input ``x`` both as the branch index and as the two operands. ``switch`` clamps the index to the valid range ``[0, len(branches) - 1]``, so ``cfun(2)`` selects ``branches[1]`` (``jax.lax.mul``) and returns ``2 * 2 = 4``, while ``cfun(-1)`` selects ``branches[0]`` (``jax.lax.add``) and returns ``-1 + -1 = -2``. Only the result of the selected branch is returned.
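The index clamping explains why ``-1`` does not wrap around to the last branch as Python list indexing would. The sketch below is a pure-Python reading of ``switch()``, modeled on the reference semantics documented for ``jax.lax.switch``; ``switch_reference`` is a hypothetical name, and ``operator.add``/``operator.mul`` stand in for ``jax.lax.add``/``jax.lax.mul``.

```python
import operator


def switch_reference(index, branches, *operands):
    """Pure-Python reading of switch(): clamp the index, then call one branch."""
    # Out-of-range indices are clamped to [0, len(branches) - 1].
    index = min(max(index, 0), len(branches) - 1)
    return branches[index](*operands)


branches = [operator.add, operator.mul]
print(switch_reference(2, branches, 2, 2), switch_reference(-1, branches, -1, -1))  # 4 -2
```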

### Using ``brainstate.compile.ifelse``

The usage of ``brainstate.compile.ifelse`` is as follows:

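```python
def f(a):
    # Combine comparisons with `&` rather than `and`: `and` needs a concrete
    # Python bool and fails when `a` is a traced JAX value (e.g. inside jit).
    return brainstate.compile.ifelse(conditions=[a < 0,
                                                 (a >= 0) & (a < 2),
                                                 (a >= 2) & (a < 5),
                                                 (a >= 5) & (a < 10),
                                                 a >= 10],
                                     branches=[lambda: 1,
                                               lambda: 2,
                                               lambda: 3,
                                               lambda: 4,
                                               lambda: 5])

assert f(3) == 3
assert f(1) == 2
assert f(-1) == 1
```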

In the code above, ``f`` selects a branch according to the value of ``a`` and returns the result of that branch. The five conditions are mutually exclusive and together cover every real number, so exactly one branch applies to any input: ``f(3)`` returns ``3``, ``f(1)`` returns ``2``, and ``f(-1)`` returns ``1``.
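Because the conditions in this example are mutually exclusive, ``ifelse`` can be read as an ``if``/``elif`` chain. The sketch below is a hypothetical pure-Python model (``ifelse_reference`` is not a ``BrainState`` function) that calls the branch of the first true condition; it makes no claim about how the compiled ``ifelse`` behaves when several conditions, or none, are true.

```python
def ifelse_reference(conditions, branches):
    """Pure-Python model: call the branch paired with the first true condition."""
    for condition, branch in zip(conditions, branches):
        if condition:
            return branch()
    raise ValueError("no condition is true")


def f(a):
    return ifelse_reference(
        conditions=[a < 0, (a >= 0) & (a < 2), (a >= 2) & (a < 5),
                    (a >= 5) & (a < 10), a >= 10],
        branches=[lambda: 1, lambda: 2, lambda: 3, lambda: 4, lambda: 5],
    )


print(f(3), f(1), f(-1), f(7), f(12))  # 3 2 1 4 5
```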

Note that the conditional functions do not cache their compiled results on their own: each time a function that uses them is called, the branches are traced and compiled again. If such a function will be called many times, wrap it in ``brainstate.compile.jit`` so that the compiled program is cached.

## Loop Statements

``brainstate.compile`` supports loops, allowing users to repeatedly execute the same computation. It provides two functions for loop statements: ``brainstate.compile.for_loop`` and ``brainstate.compile.while_loop``. Below, we explore each of them.

### Using ``brainstate.compile.for_loop``

The usage of ``brainstate.compile.for_loop`` is as follows:

```python
a = brainstate.ShortTermState(0.)
b = brainstate.ShortTermState(0.)

def f(i):
    a.value += (1 + b.value)
    return a.value

n_iter = 10
ops = jnp.arange(n_iter)
r = brainstate.compile.for_loop(f, ops)

print(a)
print(b)
```

```
ShortTermState(
  value=Array(10., dtype=float32, weak_type=True)
)
ShortTermState(
  value=0.0
)
```


In the code above, ``f`` receives the loop element ``i`` (unused here), adds ``1 + b.value`` to ``a``, and returns the new value of ``a``. Passing ``jnp.arange(n_iter)`` to ``for_loop`` runs ``f`` once per element, that is, ``n_iter`` times. Afterwards, ``a`` holds its initial value plus ``n_iter * (1 + b.value)``, which here is ``0 + 10 * (1 + 0) = 10``, while ``b`` is unchanged because it is only read. ``r`` collects the value returned by ``f`` at each iteration.
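The loop above can be mimicked in plain Python to check the arithmetic. In this sketch, ``ToyState`` and ``for_loop_reference`` are hypothetical stand-ins, not ``BrainState`` APIs, and it assumes that, as with ``scan``, the per-iteration outputs are stacked along a new leading axis.

```python
import numpy as np


class ToyState:
    """Hypothetical stand-in for a State: a box holding a mutable value."""

    def __init__(self, value):
        self.value = value


def for_loop_reference(f, xs):
    """Call f on each element of xs and stack the per-iteration outputs."""
    return np.stack([f(x) for x in xs])


a = ToyState(0.)
b = ToyState(0.)


def f(i):
    a.value += (1 + b.value)
    return a.value


r = for_loop_reference(f, np.arange(10))
print(a.value, b.value)  # 10.0 0.0
# r holds the value of a after each iteration: 1.0, 2.0, ..., 10.0
```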

### Using ``brainstate.compile.while_loop``

The usage of ``brainstate.compile.while_loop`` is as follows:

```python
a = brainstate.State(1.)
b = brainstate.State(20.)

def cond(x):
    return a.value < b.value

def body(x):
    a.value += x
    return x

r = brainstate.compile.while_loop(cond, body, 1.)

print(a.value, b.value, r)
```

```
20.0 20.0 1.0
```
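The ``while_loop`` call above can be read as the pure-Python loop below, which follows the reference semantics documented for ``jax.lax.while_loop``. ``ToyState`` and ``while_loop_reference`` are hypothetical stand-ins for illustration, not ``BrainState`` APIs. Note that ``cond`` ignores the carry ``x`` and reads the ``State`` ``a`` instead, and that ``body`` returns ``x`` unchanged, so the loop runs 19 times (until ``a`` reaches ``20.0``) while the returned carry stays ``1.0``.

```python
class ToyState:
    """Hypothetical stand-in for a State: a box holding a mutable value."""

    def __init__(self, value):
        self.value = value


def while_loop_reference(cond_fun, body_fun, init_val):
    """Pure-Python reading of while_loop(): apply body_fun while cond_fun holds."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


a = ToyState(1.)
b = ToyState(20.)
steps = ToyState(0)  # counts loop iterations


def cond(x):
    return a.value < b.value


def body(x):
    steps.value += 1
    a.value += x
    return x


r = while_loop_reference(cond, body, 1.)
print(a.value, b.value, r, steps.value)  # 20.0 20.0 1.0 19
```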


In addition to the two loop functions above, another commonly used one is ``brainstate.compile.scan``. Its semantics can be understood through the following Python version:
```python
import numpy as np

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
```

The main differences from the Python version are:

1. Support for ``pytrees``: both ``xs`` and ``ys`` in ``scan()`` can be arbitrary ``pytree`` values, so multiple arrays can be scanned at once and multiple output arrays produced. ``None`` is a special case, representing an empty ``pytree``.
2. Lowering to a ``WhileOp``: ``scan()`` is a JAX primitive that is lowered to a single XLA ``WhileOp``, which makes it useful for reducing JIT compilation time. By contrast, a native Python loop inside a ``jit()`` function is unrolled during tracing, so the size of the XLA computation grows with the number of iterations.
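The pytree point can be illustrated by extending the Python version above to a tuple of input arrays that produces a tuple of stacked outputs. This is a hypothetical sketch (``scan_tuple_reference`` is not a ``BrainState`` function) that only handles flat tuples, whereas the real ``scan()`` accepts arbitrary pytrees such as nested dicts and lists.

```python
import numpy as np


def scan_tuple_reference(f, init, xs):
    """Python scan over a tuple of equal-length arrays, returning a tuple of outputs."""
    carry = init
    ys = []
    for x in zip(*xs):          # one slice from every input array per step
        carry, y = f(carry, x)
        ys.append(y)
    # Stack each output component separately, mirroring a tuple-shaped pytree.
    return carry, tuple(np.stack(col) for col in zip(*ys))


def step(carry, x):
    total, prod = carry
    u, v = x
    total, prod = total + u, prod * v
    # New carry and per-step outputs: running sum of u and running product of v.
    return (total, prod), (total, prod)


carry, (sums, prods) = scan_tuple_reference(
    step, (0, 1), (np.array([1, 2, 3]), np.array([2, 3, 4])))
```

Here ``sums`` holds the running sums ``1, 3, 6`` and ``prods`` the running products ``2, 6, 24``.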

As with the conditional functions, the loop functions, including ``scan()``, do not cache their compiled results on their own: each call traces and compiles the loop again. If the function will be called many times, wrap it in ``brainstate.compile.jit``.

## Nested Usage of Program Compilation

A distinctive feature of ``BrainState`` compilation is that JAX's function-based transformations and ``BrainState``'s state-aware compilation functions can be nested inside one another. Any ``State`` created in an intermediate step exists only as a local variable and is optimized out of the final program, which lowers memory usage and speeds up execution. Here is an example:

```python
b = brainstate.State(0.)


def add(i):
    c = brainstate.State(0.)

    def cond(j):
        return j <= i

    def body(j):
        c.value += 1.
        return j + 1

    brainstate.compile.while_loop(cond, body, 0.)

    b.value += c.value


brainstate.compile.for_loop(add, jnp.arange(10))

print(b.value)
```

```
55.0
```


In the program above, ``add`` is the body of a ``for_loop`` over ``i = 0, 1, ..., 9``. Each call to ``add`` creates a local ``State`` ``c`` and runs a ``while_loop`` whose body increments ``c`` once for every ``j`` from ``0`` to ``i``, so ``c`` ends at ``i + 1`` and ``b`` accumulates ``1 + 2 + ... + 10 = 55``. Throughout the program, ``c`` remains a local variable and is not cached during compilation. This nesting keeps the program flexible without sacrificing performance.
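The expected result can be checked with the equivalent plain-Python nested loop below. It is only an arithmetic check: ordinary floats stand in for the ``State`` objects ``b`` and ``c``.

```python
# Each outer step i runs the inner while loop for j = 0, 1, ..., i,
# so c reaches i + 1 and b accumulates (0 + 1) + (1 + 1) + ... + (9 + 1).
b = 0.0
for i in range(10):
    c = 0.0      # plays the role of the local State c
    j = 0.0
    while j <= i:
        c += 1.0
        j += 1
    b += c
print(b)  # 55.0
```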

JAX transformations can also be nested with ``BrainState``'s compilation functions, making full use of JAX's optimization capabilities. Because ``BrainState``'s compilation functions are state-aware, they must discover and track ``State`` objects during compilation, which adds a small overhead; they may therefore be slightly slower than the corresponding JAX transformations. In practice, choose the compilation function that best fits the requirements of your application.