# Just-In-Time Compilation

The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation translates your Python code into machine code "just in time", right before it is executed, so that the transformed code can run at native machine-code speed!

Excellent JIT compilers such as [JAX](https://github.com/google/jax) and [Numba](https://github.com/numba/numba) are available in Python. However, they are designed to work only on pure functions: all input data is passed in through the function parameters, and all results are returned as the function outputs. Yet much of Python programming is object-oriented programming (OOP) based on ``class``, which makes code more flexible, modular, and reusable. BrainPy relaxes these constraints and enables users to JIT compile a class object without loss of performance.
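To make the "pure function" constraint concrete, here is a minimal plain-Python sketch (no JIT involved). The names `LeakyIntegrator` and `leaky_step` and the leaky-integrator update rule are hypothetical, chosen only for illustration: the class stores its state in `self.v` and mutates it, whereas the pure function receives the state as an argument and returns the new state, which is the shape of code that JAX and Numba are built to compile.

```python
class LeakyIntegrator:
    """OOP style (hypothetical example): state lives on the object and is updated in place."""

    def __init__(self, tau=10.0):
        self.tau = tau
        self.v = 0.0  # internal state, mutated by update()

    def update(self, inp, dt=0.1):
        # side effect: the new value is written back to self.v
        self.v = self.v + dt * (-self.v + inp) / self.tau
        return self.v


def leaky_step(v, inp, tau=10.0, dt=0.1):
    """Pure-function style (hypothetical example): state comes in as arguments, goes out as results."""
    return v + dt * (-v + inp) / tau


# Both styles compute the same trajectory; only the pure one matches the
# "inputs in, results out" model that function-level JIT compilers expect.
neuron = LeakyIntegrator()
v = 0.0
for _ in range(100):
    neuron.update(1.0)
    v = leaky_step(v, 1.0)
```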

In this section, we will talk about JIT compilation in BrainPy.

In [1]:
import sys
sys.path.append('../../')

import brainpy as bp



## JIT in Numba and JAX

[Numba](https://github.com/numba/numba) specializes in optimizing native Python code, such as functions, NumPy arrays, loops, and conditionals. It is a cross-platform library that runs on Windows, Linux, macOS, etc. Its most attractive feature is that it can JIT compile native Python loops (``for`` or ``while`` syntax) and conditionals (``if ... else ...``), so you can keep writing intuitive Python code.
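As a minimal sketch, the loop-and-branch code used later in this section can be compiled directly with Numba's `njit` decorator. The function name `add_pattern` is hypothetical, and the `ImportError` fallback is an assumption added only so the snippet still runs (uncompiled) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without Numba (no compilation then)
    def njit(func):
        return func


@njit
def add_pattern(arr):
    # plain Python loop and if/else; Numba compiles them on the first call
    for i in range(arr.shape[0]):
        arr[i] += 2.0
        if i % 2 == 0:
            arr[i] += 1.0
    return arr


result = add_pattern(np.zeros(5))
```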

However, Numba is a lightweight JIT compiler and is best suited to small network models; for large networks, its parallel performance is poor. Furthermore, Numba does not support "write once, run on multiple devices": the same code cannot run unchanged on GPU targets.

[JAX](https://github.com/google/jax) is a rising star among JIT compilers for scientific computing in Python. It uses [XLA](https://www.tensorflow.org/xla) to JIT compile and run your NumPy programs, and the same code can be deployed onto CPUs, GPUs, and TPUs. Moreover, JAX supports automatic differentiation, which means you can train models through back-propagation. JAX favors large network models and has excellent parallel performance.
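A short sketch of both features, using the public `jax.jit` and `jax.grad` transformations directly rather than BrainPy. The toy function `loss` and its inputs are hypothetical; the point is that the same `jax.numpy` code is compiled by XLA for whichever default device is available and can also be differentiated.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # a small NumPy-style function written with jax.numpy (hypothetical example)
    return jnp.sum(jnp.tanh(w * x) ** 2)


loss_jit = jax.jit(loss)            # XLA compiles it for the default backend (CPU, GPU or TPU)
grad_jit = jax.jit(jax.grad(loss))  # gradient with respect to the first argument `w`

x = jnp.linspace(-1.0, 1.0, 5)
w = 0.5
value = loss_jit(w, x)
dw = grad_jit(w, x)
```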

However, JAX has intrinsic overhead (tracing, compilation, and per-call dispatch), which makes it less suitable for small networks. Moreover, at the time of writing, JAX only supports Linux and macOS; Windows users must install JAX on [WSL](https://docs.microsoft.com/en-us/windows/wsl/about) or compile JAX from source. Further, coding in JAX is not very intuitive. For example,

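- Doesn't support in-place mutating updates of arrays like ``x[i] += y``; instead you must write the functional form ``x = x.at[i].add(y)`` (older JAX versions used ``x = jax.ops.index_add(x, i, y)``).
- Doesn't compile native Python control flow: an ``if`` on a traced value raises an error, and Python loops are unrolled at trace time. For example, the code
```python
import numpy as np

arr = np.zeros(5)
for i in range(arr.shape[0]):
    arr[i] += 2.
    if i % 2 == 0:
        arr[i] += 1.
```
has to be rewritten with JAX's structured control-flow primitives:
```python
import jax
import jax.numpy as jnp

arr = jnp.zeros(5)

def loop_body(i, acc_arr):
    # functional update instead of acc_arr[i] += 2.
    arr1 = acc_arr.at[i].add(2.)
    # lax.cond replaces the Python `if`
    return jax.lax.cond(i % 2 == 0,
                        lambda a: a.at[i].add(1.),
                        lambda a: a,
                        arr1)

# lax.fori_loop replaces the Python `for` loop
arr = jax.lax.fori_loop(0, arr.shape[0], loop_body, arr)
```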
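To see the first restriction in action, here is a small sketch using plain `jax.numpy` (assuming a JAX release recent enough to provide the `.at[]` syntax): NumPy-style item assignment is rejected because JAX arrays are immutable, while `x.at[i].add(y)` returns a new, updated array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

error_message = None
try:
    x[0] += 1.0  # NumPy-style in-place update
except TypeError as err:  # JAX arrays are immutable, so item assignment is rejected
    error_message = str(err)

# functional update: returns a *new* array, `x` itself is unchanged
y = x.at[0].add(1.0)
```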

What's more, both frameworks have poor support for class objects.

## BrainPy `math` module

To obtain an *intuitive*, *flexible*, and *high-performance* framework for brain modeling, [BrainPy](https://github.com/PKU-NIP-Lab/BrainPy) aims to combine the advantages of both compilers and to overcome the gotchas of each as far as possible (this work is still in progress). Specifically, we provide the [BrainPy math module](../apis/math.rst) for:

- flexible switching between NumPy and JAX backends
- unified NumPy-like array operations
- a unified ``ndarray`` data structure that supports in-place updates
- unified ``random`` APIs
- powerful ``jit()`` compilation that supports both functions and class objects

Users can switch to different backends by using ``brainpy.math.use_backend``:

In [2]:
# switch to NumPy backend
bp.math.use_backend('numpy')

bp.math.get_backend_name()

'numpy'

In [3]:
# switch to JAX backend
bp.math.use_backend('jax')

bp.math.get_backend_name()

'jax'

After the backend switch, the APIs in ``brainpy.math`` are very similar to those in the original ``numpy``. For a detailed comparison, please see the [Comparison Table](../apis/math/comparison.rst).

For example, the **array creation** APIs:

In [4]:
bp.math.zeros((10, 3))

JaxArray(DeviceArray([[0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.]], dtype=float32))

In [5]:
bp.math.arange(10)

JaxArray(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))

In [6]:
x = bp.math.array([[1,2], [3,4]])

x

JaxArray(DeviceArray([[1, 2],
                      [3, 4]], dtype=int32))

The **array manipulation** APIs:

In [7]:
bp.math.max(x)

DeviceArray(4, dtype=int32)

In [8]:
bp.math.repeat(x, 2)

JaxArray(DeviceArray([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32))

In [9]:
bp.math.repeat(x, 2, axis=1)

JaxArray(DeviceArray([[1, 1, 2, 2],
                      [3, 3, 4, 4]], dtype=int32))

The **random numbers** generation functions:

In [10]:
bp.math.random.random((3, 5))

JaxArray(DeviceArray([[0.15193117, 0.7006848 , 0.75320196, 0.29045963, 0.7425157 ],
                      [0.18510342, 0.9365095 , 0.02459204, 0.2899201 , 0.1062901 ],
                      [0.31897604, 0.14713216, 0.0075345 , 0.60187805, 0.293056  ]],            dtype=float32))

In [11]:
y = bp.math.random.normal(loc=0.0, scale=2.0, size=(2, 5))

y

JaxArray(DeviceArray([[-2.5946703 , -0.44657612,  1.4826825 , -3.1162384 ,
                        0.60915095],
                      [-0.6821795 , -0.7344547 ,  0.24855301, -1.627654  ,
                        1.7101754 ]], dtype=float32))

The **linear algebra** functions:

In [12]:
bp.math.dot(x, y)

JaxArray(DeviceArray([[ -3.9590292,  -1.9154855,   1.9797885,  -6.3715463,
                         4.029502 ],
                      [-10.512729 ,  -4.277547 ,   5.44226  , -15.859331 ,
                         8.668155 ]], dtype=float32))

In [13]:
bp.math.linalg.eig(x)

(JaxArray(DeviceArray([-0.37228107+0.j,  5.3722816 +0.j], dtype=complex64)),
 JaxArray(DeviceArray([[-0.8245648 +0.j, -0.41597357+0.j],
                       [ 0.56576747+0.j, -0.9093767 +0.j]], dtype=complex64)))
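The eigenvalues above can be checked by hand. For $x = \begin{pmatrix}1 & 2\\ 3 & 4\end{pmatrix}$, the characteristic polynomial gives

$$
\det(x - \lambda I) = (1-\lambda)(4-\lambda) - 2\cdot 3 = \lambda^2 - 5\lambda - 2 = 0
\quad\Longrightarrow\quad
\lambda = \frac{5 \pm \sqrt{33}}{2} \approx 5.3723,\ -0.3723,
$$

which matches the two values returned by `bp.math.linalg.eig`, up to float32 rounding.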

The **Discrete Fourier Transform** functions:

In [14]:
bp.math.fft.fft(bp.math.exp(2j * bp.math.pi * bp.math.arange(8) / 8))

JaxArray(DeviceArray([ 3.2584137e-07+3.1786513e-08j,  8.0000000e+00+4.8023384e-07j,
                      -3.2584137e-07+3.1786513e-08j, -1.6858739e-07+3.1786506e-08j,
                      -3.8941437e-07-2.0663207e-07j,  2.3841858e-07-1.9411573e-07j,
                       3.8941437e-07-2.0663207e-07j,  1.6858739e-07+3.1786506e-08j],            dtype=complex64))
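The output is 8 at index 1 and zero elsewhere, up to float32 rounding errors of order $10^{-7}$. With $x_n = e^{2\pi i n/8}$ and the forward DFT convention $X_k = \sum_n x_n e^{-2\pi i k n/N}$ used by NumPy's `fft`,

$$
X_k = \sum_{n=0}^{7} e^{2\pi i n/8}\, e^{-2\pi i k n/8} = \sum_{n=0}^{7} e^{2\pi i (1-k) n/8} =
\begin{cases} 8, & k = 1,\\ 0, & k \neq 1, \end{cases}
\qquad k = 0,\dots,7,
$$

since a full sum of 8th roots of unity vanishes unless the exponent is a multiple of $2\pi i$.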

In [15]:
bp.math.fft.ifft(bp.math.array([0, 4, 0, 0]))

JaxArray(DeviceArray([ 1.+0.j,  0.+1.j, -1.+0.j,  0.-1.j], dtype=complex64))
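Likewise, with the inverse convention $x_n = \frac{1}{N}\sum_k X_k e^{2\pi i k n/N}$ and $X = (0, 4, 0, 0)$, only $k = 1$ contributes:

$$
x_n = \frac{1}{4}\cdot 4\, e^{2\pi i n/4} = i^{\,n}, \qquad n = 0, 1, 2, 3,
$$

that is, $(1, i, -1, -i)$, as shown in the output above.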

For the full list of implemented APIs, please see the [Comparison Table](../apis/math/comparison.rst).

## JIT compilation in BrainPy

As with Numba and JAX, BrainPy supports JIT compilation of **functions**.

For example, with the JAX backend, we implement a ``selu`` function:

In [16]:
bp.math.use_backend('jax')

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * bp.math.where(x > 0, x, alpha * bp.math.exp(x) - alpha)
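For reference, the function above implements the scaled exponential linear unit (SELU),

$$
\operatorname{selu}(x) = \lambda \begin{cases} x, & x > 0,\\ \alpha\,(e^{x} - 1), & x \le 0, \end{cases}
$$

where `alpha` and `lmbda` play the roles of $\alpha$ and $\lambda$. The defaults 1.67 and 1.05 are rounded versions of the self-normalizing constants usually quoted as $\alpha \approx 1.6733$ and $\lambda \approx 1.0507$.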

All you need to do is pass the function to [bp.math.jit()](../apis/math/generated/brainpy.math.jax.jit.rst):

In [17]:
selu_jit = bp.math.jit(selu)

Then, let's compare their running times:

In [19]:
x = bp.math.random.random((1000000,))

In [20]:
%timeit selu(x)

2.69 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
%timeit selu_jit(x)

335 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
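One caveat when reading such numbers: JAX dispatches work asynchronously, so a timing loop can return before the computation has finished, and the first call of a jitted function includes compilation. The sketch below is written with plain `jax.numpy` and `jax.jit` instead of `bp.math` (an assumption made so it runs without BrainPy); it warms up the compiled function once and calls `block_until_ready()` so that each measurement covers the full computation. The speedup you observe will depend on your hardware.

```python
import timeit

import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu_jit = jax.jit(selu)

x = jax.random.uniform(jax.random.PRNGKey(0), (1_000_000,))

# warm-up: the first call traces and compiles the function
selu_jit(x).block_until_ready()

# block_until_ready() waits for the asynchronous computation to finish
t_plain = timeit.timeit(lambda: selu(x).block_until_ready(), number=20) / 20
t_jit = timeit.timeit(lambda: selu_jit(x).block_until_ready(), number=20) / 20
print(f"eager: {t_plain * 1e3:.2f} ms, jitted: {t_jit * 1e3:.2f} ms")
```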


Similarly, with the NumPy backend, we can compare the non-jitted ``selu`` function with the jitted one wrapped by [bp.math.jit()](../apis/math/generated/brainpy.math.numpy.jit.rst):

In [36]:
bp.math.use_backend('numpy')

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * bp.math.where(x > 0, x, alpha * bp.math.exp(x) - alpha)

selu_jit = bp.math.jit(selu)

x = bp.math.random.random((1000000,))

In [37]:
%timeit selu(x)

9.26 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [38]:
selu_jit(x)  # warm-up call, so that one-time compilation is excluded from the timing

%timeit selu_jit(x)

8.64 ms ± 117 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Moreover, unlike Numba and JAX, BrainPy can also apply JIT compilation to class objects.