# Overview

@[Chaoming Wang](https://github.com/chaoming0625)
@[Xiaoyu Chen](mailto:c-xy17@tsinghua.org.cn)

The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation turns your Python code into machine code "just in time", right before it is executed, so the transformed code can run at native machine-code speed!

Python already has excellent JIT compilers, such as [JAX](https://github.com/google/jax) and [Numba](https://github.com/numba/numba). However, they are designed to work only on [pure Python functions](https://en.wikipedia.org/wiki/Pure_function), while most computational neuroscience models have too many parameters and variables to manage with functions alone. In contrast, object-oriented programming (OOP) based on Python ``class``es makes code more readable, controllable, flexible, and modular. It is therefore necessary to support JIT compilation on class objects for brain modeling.
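Why the restriction to pure functions matters can be seen in a few lines of plain JAX (this is a `jax.jit` sketch, not BrainPy code). A value read from outside the function is frozen into the compiled code when the function is traced, so later changes to that external state are silently ignored until something forces a retrace:

```python
import jax
import jax.numpy as jnp

offset = 1.0  # external state, not an argument of the function


@jax.jit
def add_offset(x):
    # `offset` is a Python global: jit reads it only once, while tracing.
    return x + offset


first = add_offset(jnp.array(1.0))  # traced and compiled with offset = 1.0
offset = 10.0  # update the external state
second = add_offset(jnp.array(1.0))  # cached compiled code still uses offset = 1.0
third = add_offset(jnp.array([1.0, 1.0]))  # new shape forces a retrace, which sees 10.0

print(first, second, third)
```

Model parameters and state stored as attributes of a class behave like `offset` here, which is roughly why JIT support for class objects needs extra machinery.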

To provide **a platform that satisfies the needs of brain dynamics programming**, we provide the [brainpy.math](../apis/math.rst) module.

```python
import brainpy as bp
import brainpy.math as bm
import numpy as np

bp.math.set_platform('cpu')
```

## Why use ``brainpy.math``?

Specifically, ``brainpy.math`` makes the following contributions:

### 1. **NumPy-like ndarray**.
Python users are familiar with [NumPy](https://numpy.org/), especially its [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). JAX provides a similar ``ndarray`` structure and similar operations, but several basic features differ fundamentally from the NumPy ndarray. For example, JAX arrays are immutable and do not support in-place updates such as ``x[i] += y``. To overcome these drawbacks, ``brainpy.math`` provides ``JaxArray``, which can be used in the same way as a NumPy ndarray.
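The immutability of JAX arrays is easy to check directly. In the minimal sketch below (plain `jax.numpy`, independent of BrainPy), the in-place update raises a `TypeError`, and the functional `.at[...]` syntax returns a new array instead of modifying the original:

```python
import jax.numpy as jnp

x = jnp.arange(5)

# In-place item assignment fails because JAX arrays are immutable.
try:
    x[0] += 5
    mutated = True
except TypeError:
    mutated = False

# The functional alternative returns a *new* array and leaves `x` untouched.
y = x.at[0].add(5)

print(mutated)  # False
print(x)        # [0 1 2 3 4]
print(y)        # [5 1 2 3 4]
```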

```python
# ndarray in "numpy"
a = np.arange(5)
a
```

```
array([0, 1, 2, 3, 4])
```

```python
a[0] += 5
a
```

```
array([5, 1, 2, 3, 4])
```

```python
# ndarray in "brainpy.math"
b = bm.arange(5)
b
```

```
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
```

```python
b[0] += 5
b
```

```
JaxArray(DeviceArray([5, 1, 2, 3, 4], dtype=int32))
```

For more details, please see the [Tensors](./tensors.ipynb) tutorial.

### 2. **NumPy-like random sampling**.
JAX has its own way of generating random numbers, based on explicit PRNG keys, which is very different from NumPy. To provide a consistent experience, ``brainpy.math`` provides the ``brainpy.math.random`` module, which performs random sampling just like ``numpy.random``. For example:

```python
# random sampling in "numpy"
np.random.seed(12345)
```

```python
np.random.random(5)
```

```
array([0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503])
```

```python
np.random.normal(0., 2., 5)
```

```
array([0.90110884, 0.18534658, 2.49626568, 1.53620142, 2.4976073 ])
```

```python
# random sampling in "brainpy.math.random"
bm.random.seed(12345)
```

```python
bm.random.random(5)
```

```
JaxArray(DeviceArray([0.47887695, 0.5548092 , 0.8850775 , 0.30382073, 0.6007602 ],
                     dtype=float32))
```

```python
bm.random.normal(0., 2., 5)
```

```
JaxArray(DeviceArray([-1.5375284 , -0.59702027, -2.2728395 ,  3.2330806 ,
                      -0.27385947], dtype=float32))
```
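For comparison, the sketch below draws similar samples with raw `jax.random` (plain JAX, not part of `brainpy.math`). Every call takes an explicit PRNG key that you must split and thread through your code yourself; this is the bookkeeping that the NumPy-style `bm.random.seed(...)` interface spares you. The values will not match the numbers printed above.

```python
import jax

key = jax.random.PRNGKey(12345)

# Each draw consumes a key; split it to obtain a fresh, independent subkey.
key, subkey = jax.random.split(key)
uniform = jax.random.uniform(subkey, (5,))  # like np.random.random(5)

key, subkey = jax.random.split(key)
normal = 2.0 * jax.random.normal(subkey, (5,))  # like np.random.normal(0., 2., 5)

# Reusing a key reproduces exactly the same numbers: there is no hidden global state.
repeat = 2.0 * jax.random.normal(subkey, (5,))

print(uniform)
print(normal)
```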

For more details, please see the [Tensors](./tensors.ipynb) tutorial.

### 3. **JAX transformations on class objects**. 
OOP is at the heart of Python. However, JAX's excellent transformations (like JIT compilation) only support [pure functions](https://en.wikipedia.org/wiki/Pure_function). To make them work with object-oriented code in brain dynamics programming, ``brainpy.math`` extends JAX transformations to Python classes.

Example 1: JIT compilation performed on class objects.

```python
class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters
        self.dimension = dimension

        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        self.w[:] = self.w - u


num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
```

```python
lr1 = LogisticRegression(num_dim)

%timeit lr1(points, labels)
```

```
230 ms ± 25.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```


```python
lr2 = bm.jit(LogisticRegression(num_dim))

%timeit lr2(points, labels)
```

```
142 ms ± 9.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
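Conceptually, JIT-compiling a class object means turning a method that mutates `self.w` into a pure function of the weights, compiling that function, and writing the result back to the object. The sketch below does this bookkeeping by hand in plain JAX; it only illustrates the idea and is not how BrainPy implements `bm.jit`. The names `logistic_step` and `JitLogisticRegression` are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def logistic_step(w, X, Y):
    """Pure update: old weights in, new weights out (no hidden state)."""
    u = jnp.dot(((1.0 / (1.0 + jnp.exp(-Y * jnp.dot(X, w))) - 1.0) * Y), X)
    return w - u


class JitLogisticRegression:
    """Hypothetical class: stores `w` as state, delegates math to a jitted pure function."""

    def __init__(self, dimension):
        self.w = 2.0 * jnp.ones(dimension) - 1.3

    def __call__(self, X, Y):
        # Pass the state in explicitly, then write the returned value back.
        self.w = logistic_step(self.w, X, Y)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.uniform(key_x, (1000, 10))
Y = jax.random.uniform(key_y, (1000,))

model = JitLogisticRegression(10)
w_before = model.w
model(X, Y)
print(model.w.shape)  # (10,)
```

`bm.jit` has to automate exactly this: discover which `bm.Variable`s a method reads and writes, and thread them in and out of the compiled function.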


Example 2: Autograd performed on variables of a class object.

```python
class Linear(bp.Base):
  def __init__(self, num_hidden, num_input, **kwargs):
    super(Linear, self).__init__(**kwargs)

    # parameters
    self.num_input = num_input
    self.num_hidden = num_hidden

    # variables
    self.w = bm.random.random((num_input, num_hidden))
    self.b = bm.zeros((num_hidden,))

  def __call__(self, x):
    r = x @ self.w + self.b
    return r.mean()
```

```python
l = Linear(num_hidden=3, num_input=2)
```

```python
bm.grad(l, grad_vars=(l.w, l.b))(bm.random.random([5, 2]))
```

```
(JaxArray(DeviceArray([[0.14844148, 0.14844148, 0.14844148],
                       [0.2177031 , 0.2177031 , 0.2177031 ]], dtype=float32)),
 JaxArray(DeviceArray([0.33333334, 0.33333334, 0.33333334], dtype=float32)))
```
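Conceptually, `grad_vars=(l.w, l.b)` asks for derivatives with respect to object attributes as if they were function arguments. The plain-JAX sketch below makes that explicit by passing `w` and `b` to a hypothetical pure function `linear_mean` and using `jax.grad` with `argnums`. It also explains the structure of the output above: the result is a mean over `batch × num_hidden` entries, so every entry of the bias gradient is `1 / num_hidden = 1/3`, and `grad_w[i, j]` is the batch mean of `x[:, i]` divided by 3 (hence identical columns).

```python
import jax
import jax.numpy as jnp


def linear_mean(w, b, x):
    # Same computation as Linear.__call__, with the variables as explicit arguments.
    return (x @ w + b).mean()


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.uniform(key_w, (2, 3))  # (num_input, num_hidden)
b = jnp.zeros((3,))
x = jax.random.uniform(key_x, (5, 2))  # batch of 5 inputs

# argnums=(0, 1): differentiate with respect to w and b, but not x.
grad_w, grad_b = jax.grad(linear_mean, argnums=(0, 1))(w, b, x)

print(grad_w)
print(grad_b)
```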

## What is the difference between ``brainpy.math`` and other frameworks?

``brainpy.math`` is not intended to reimplement the API of any other framework. All we are trying to do is to build **a better brain dynamics programming framework for Python users**.

However, there are important differences between ``brainpy.math`` and other frameworks. As stated above, JAX and many JAX-based frameworks follow a functional programming paradigm. Applied to brain dynamics models, this coding style quickly becomes unmanageable because of the overwhelmingly large number of parameters and variables. In contrast, ``brainpy.math`` supports an object-oriented programming paradigm, which is much more Pythonic. The most similar framework is [Objax](https://github.com/google/objax), which also supports OOP on top of JAX; however, it is designed for deep learning and cannot be used directly for brain dynamics programming.
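To make the contrast concrete, here is a toy leaky-integrator update written in the purely functional style with plain JAX. The model, the function name `leaky_step`, and the parameter values are hypothetical and chosen only for illustration. Even for this single population, every parameter and state variable has to be passed in and returned explicitly; a realistic network multiplies this bookkeeping across many neuron and synapse groups.

```python
import jax
import jax.numpy as jnp


@jax.jit
def leaky_step(V, I, V_rest, R, tau, dt):
    # Forward-Euler step of tau * dV/dt = -(V - V_rest) + R * I.
    # Every parameter and state variable is an explicit argument.
    dV = (-(V - V_rest) + R * I) / tau
    return V + dt * dV


params = dict(V_rest=-65.0, R=1.0, tau=10.0, dt=0.2)  # hypothetical values
I = jnp.array([0.0, 10.0, 20.0])  # constant input to three neurons
V = jnp.full(3, params["V_rest"])  # state must be threaded by hand

for _ in range(1000):  # 200 time units, i.e. 20 membrane time constants
    V = leaky_step(V, I, **params)

# The membrane potential relaxes to its steady state V_rest + R * I.
print(V)
```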

## How to interoperate `brainpy.math` with other JAX frameworks?

`brainpy.math` can be easily interoperated with other JAX frameworks. 

### 1. Data are exchangeable between frameworks.
This works because a ``JaxArray`` can be directly converted to a JAX ndarray or a NumPy ndarray, and vice versa.

```python
b
```

```
JaxArray(DeviceArray([5, 1, 2, 3, 4], dtype=int32))
```

Convert a ``JaxArray`` into a JAX ndarray. 

```python
# JaxArray.value is a JAX ndarray
b.value
```

```
DeviceArray([5, 1, 2, 3, 4], dtype=int32)
```

Convert a ``JaxArray`` into a numpy ndarray.

```python
# JaxArray can be easily converted to a numpy ndarray
np.asarray(b)
```

```
array([5, 1, 2, 3, 4])
```

Convert a numpy ndarray into a ``JaxArray``. 

```python
bm.asarray(np.arange(5))
```

```
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
```

Convert a JAX ndarray into a ``JaxArray``.

```python
import jax.numpy as jnp
bm.asarray(jnp.arange(5))
```

```
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
```

```python
bm.JaxArray(jnp.arange(5))
```

```
JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
```

### 2. Transformations in ``brainpy.math`` also work on functions.
APIs from other JAX frameworks can be naturally integrated into BrainPy. Let's take the gradient-based optimization library [Optax](https://github.com/deepmind/optax) as an example of how to use other JAX frameworks in BrainPy.

```python
import optax
```

```python
# First create several useful functions.

# Apply the linear model to every example: params are shared, x is mapped over axis 0.
network = bm.vmap(lambda params, x: bm.dot(params, x), in_axes=(None, 0))

def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = bm.mean(optax.l2_loss(y_pred, y))
  return loss

@bm.jit
def train(params, opt_state, xs, ys):
  grads = bm.grad(compute_loss)(params, xs.value, ys)
  # `optimizer` is the global Optax optimizer created in the initialization step.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state
```

```python
# Generate some data

bm.random.seed(42)
target_params = 0.5
xs = bm.random.normal(size=(16, 2))
ys = bm.sum(xs * target_params, axis=-1)
```

```python
# Initialize parameters of the model + optimizer

params = bm.array([0.0, 0.0])
optimizer = optax.adam(learning_rate=1e-1)
opt_state = optimizer.init(params)
```

```python
# A simple update loop

for _ in range(1000):
  params, opt_state = train(params, opt_state, xs, ys)

assert bm.allclose(params, target_params), \
  'Optimization should retrieve the target params used to generate the data.'
```
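For readers who want to see the pure-JAX counterpart of this workflow, the sketch below fits the same toy linear model with `jax.vmap`, `jax.grad`, and `jax.jit`. As a deliberate swap to keep the example dependency-free, it uses a hand-written gradient-descent step (with an assumed learning rate of 0.1) in place of Optax's Adam, so it is not a line-by-line translation of the code above.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value for plain gradient descent

# Per-example linear model, vectorized over the batch axis of x (params shared).
network = jax.vmap(lambda params, x: jnp.dot(params, x), in_axes=(None, 0))


def compute_loss(params, x, y):
    y_pred = network(params, x)
    # Same form as optax.l2_loss (0.5 * squared error), averaged over the batch.
    return jnp.mean(0.5 * (y_pred - y) ** 2)


@jax.jit
def train(params, xs, ys):
    grads = jax.grad(compute_loss)(params, xs, ys)
    return params - LEARNING_RATE * grads  # plain gradient-descent update


target_params = 0.5
xs = jax.random.normal(jax.random.PRNGKey(42), (16, 2))
ys = jnp.sum(xs * target_params, axis=-1)

params = jnp.array([0.0, 0.0])
for _ in range(1000):
    params = train(params, xs, ys)

print(params)  # both entries should be close to 0.5
```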