# Overview

@[Chaoming Wang](https://github.com/chaoming0625)

The core idea behind BrainPy is Just-In-Time (JIT) compilation. JIT compilation compiles your Python code into machine code "just in time" for execution, so the transformed code can then run at native machine-code speed.

Python has excellent JIT compilers such as [JAX](https://github.com/google/jax) and [Numba](https://github.com/numba/numba). However, they are designed primarily to work on pure Python functions. In computational neuroscience, most models have many parameters and variables, and it is hard to manage and control model logic using functions alone. In contrast, object-oriented programming (OOP) based on ``class`` in Python makes code more readable, controllable, flexible, and modular. It is therefore necessary to support JIT compilation on class objects for brain modeling.
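To make the pure-function restriction concrete, here is a minimal sketch in plain JAX (not BrainPy). The ``Scaler`` class and the ``scale_fn`` function are hypothetical names used only for illustration. The sketch shows that an attribute read from ``self`` while the function is first traced is compiled in as a constant, so a later change to the attribute is silently ignored by the compiled code.

```python
import jax
import jax.numpy as jnp


class Scaler:
    """Hypothetical stateful object, used only for illustration."""

    def __init__(self, scale):
        self.scale = scale  # plain Python attribute, not tracked by jax.jit

    def __call__(self, x):
        return x * self.scale


scaler = Scaler(2.0)
jitted = jax.jit(scaler.__call__)

x = jnp.ones(3)
first = jitted(x)     # traced here: scale=2.0 becomes a compiled-in constant

scaler.scale = 5.0    # mutate the object's state
second = jitted(x)    # the cached compilation is reused and still scales by 2.0
eager = scaler(x)     # the uncompiled call sees the new value


# The pure-function alternative: pass the state explicitly as an argument.
@jax.jit
def scale_fn(scale, x):
    return x * scale


explicit = scale_fn(5.0, x)
```

With only a handful of values, passing state explicitly is manageable. With the many variables of a typical neuron or network model, it quickly becomes tedious, which is the gap that class-aware transformations aim to close.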

To provide **a platform that satisfies the needs of brain dynamics programming**, we provide the [brainpy.math](../apis/math.rst) module.

In [1]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

In [2]:
import numpy as np

## Why do you need the ``brainpy.math`` module?

Specifically, ``brainpy.math`` makes the following contributions:

- **NumPy-like ndarray**. Python users are familiar with [NumPy](https://numpy.org/), especially its [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). JAX has a similar ``ndarray`` with similar operations. However, several basic features differ fundamentally from the NumPy ndarray. For example, a JAX ndarray does not support in-place mutating updates such as ``x[i] += y``. To overcome these gotchas, ``brainpy.math`` provides ``JaxArray``, which can be used in the same way as a NumPy ndarray.

In [3]:
# ndarray in "numpy"

a = np.arange(5)
a

array([0, 1, 2, 3, 4])

In [4]:
a[0] += 10
a

array([10,  1,  2,  3,  4])

In [5]:
# ndarray in "brainpy.math"

b = bm.arange(5)
b

JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))

In [6]:
b[0] += 10
b

JaxArray(DeviceArray([10,  1,  2,  3,  4], dtype=int32))

For more details, please see the [Tensors tutorial](./tensors.ipynb).
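For comparison, the sketch below shows the same update on a plain JAX array, using only ``jax.numpy``. Item assignment raises an error because JAX arrays are immutable, and the functional ``.at[...]`` syntax returns a new array instead.

```python
import jax.numpy as jnp

x = jnp.arange(5)

# In-place item assignment is rejected on a plain JAX array.
try:
    x[0] += 10
    mutated_in_place = True
except TypeError:
    mutated_in_place = False

# The functional alternative returns a new array and leaves `x` unchanged.
y = x.at[0].add(10)
```

``JaxArray`` hides this difference, which is why ``b[0] += 10`` works in the cells above.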

- **NumPy-like random sampling**. JAX has its own style of generating random numbers, which is very different from NumPy's. To provide a consistent experience, ``brainpy.math`` offers the same programming style for random sampling, so the gap between ``brainpy.math.random`` and ``numpy.random`` is minimal. For example:

In [7]:
# random sampling in "numpy"

np.random.seed(12345)

In [8]:
np.random.random(10)

array([0.92961609, 0.31637555, 0.18391881, 0.20456028, 0.56772503,
       0.5955447 , 0.96451452, 0.6531771 , 0.74890664, 0.65356987])

In [9]:
np.random.normal(0., 2., 10)

array([ 2.01437872, -2.59244222,  0.54998327,  0.45782576,  2.70583367,
        1.77285868, -4.00327462, -0.74368507,  3.33805062, -0.87713947])

In [10]:
# random sampling in "brainpy.math.random"

bm.random.seed(12345)

In [11]:
bm.random.random(10)

JaxArray(DeviceArray([0.69164526, 0.21911383, 0.36975408, 0.3115406 , 0.7697315 ,
                      0.01930189, 0.43002808, 0.7519586 , 0.01569903, 0.30688643],            dtype=float32))

In [12]:
bm.random.normal(0., 2., 10)

JaxArray(DeviceArray([ 2.9769602 , -3.2109535 ,  1.465871  ,  0.9005953 ,
                       1.7705722 , -2.477971  , -1.8175219 ,  2.1754556 ,
                      -0.25283262,  2.0693657 ], dtype=float32))

For more details, please see the [Tensors tutorial](./tensors.ipynb).
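For reference, the JAX style that ``brainpy.math.random`` wraps looks roughly like the sketch below, written in plain JAX. Random state is an explicit key that the caller must split and pass to every sampling call, and the variable names are illustrative only.

```python
import jax

key = jax.random.PRNGKey(12345)              # explicit random state
key, sub1, sub2 = jax.random.split(key, 3)   # derive fresh keys before sampling

u = jax.random.uniform(sub1, (10,))          # analogous to np.random.random(10)
n = 2.0 * jax.random.normal(sub2, (10,))     # analogous to np.random.normal(0., 2., 10)

# Reusing a key reproduces exactly the same numbers.
u_again = jax.random.uniform(sub1, (10,))
```

Because there is no hidden global state, forgetting to split a key repeats the same samples. The NumPy-like interface above avoids that bookkeeping.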

- **Transformation on class objects**. OOP is the essence of Python. However, JAX's excellent transformations (like JIT compilation) only support [pure functions](https://en.wikipedia.org/wiki/Pure_function). To make them work with the object-oriented code used in brain dynamics programming, ``brainpy.math`` extends JAX's transformations so that they can operate on Python classes.

Example 1: JIT compilation performed on class objects.

In [13]:
class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters    
        self.dimension = dimension
    
        # variables
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

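    def __call__(self, X, Y):
        # gradient of the logistic loss with respect to w, summed over all points
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        # update the Variable in place
        self.w[:] = self.w - u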
        
num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)

In [14]:
lr1 = LogisticRegression(num_dim)

%timeit lr1(points, labels)

230 ms ± 25.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
lr2 = bm.jit(LogisticRegression(num_dim))

%timeit lr2(points, labels)

142 ms ± 9.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Example 2: Autograd performed on variables of a class object.

In [16]:
class Linear(bp.Base):
  def __init__(self, num_hidden, num_input, **kwargs):
    super(Linear, self).__init__(**kwargs)

    # parameters
    self.num_input = num_input
    self.num_hidden = num_hidden

    # variables
    self.w = bm.random.random((num_input, num_hidden))
    self.b = bm.zeros((num_hidden,))

  def __call__(self, x):
    r = x @ self.w + self.b
    return r.mean()

In [17]:
l = Linear(num_hidden=3, num_input=2)

In [18]:
bm.grad(l, grad_vars=(l.w, l.b))(bm.random.random([5, 2]))

(JaxArray(DeviceArray([[0.14844148, 0.14844148, 0.14844148],
                       [0.2177031 , 0.2177031 , 0.2177031 ]], dtype=float32)),
 JaxArray(DeviceArray([0.33333334, 0.33333334, 0.33333334], dtype=float32)))
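As a cross-check, the same gradient can be sketched with plain ``jax.grad`` on a pure function that receives the weights explicitly. The function name ``linear_mean`` and the random inputs are illustrative, so the weight gradient will not match the numbers above. The analytic gradients explain the pattern in that output: the mean runs over 5 × 3 = 15 values, so every bias entry receives 5/15 = 1/3, and each row of the weight gradient is constant.

```python
import jax
import jax.numpy as jnp


def linear_mean(w, b, x):
    # Same computation as Linear.__call__, with the state passed explicitly.
    return (x @ w + b).mean()


num_input, num_hidden = 2, 3
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.uniform(key_w, (num_input, num_hidden))
b = jnp.zeros((num_hidden,))
x = jax.random.uniform(key_x, (5, num_input))

# Differentiate with respect to the first two arguments (w and b).
grad_w, grad_b = jax.grad(linear_mean, argnums=(0, 1))(w, b, x)

# Analytic gradients of the mean over 5 * 3 = 15 output values.
expected_w = jnp.tile(x.sum(axis=0)[:, None] / 15.0, (1, num_hidden))
expected_b = jnp.full((num_hidden,), 5.0 / 15.0)
```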

## What's the difference between ``brainpy.math`` and other frameworks?

``brainpy.math`` is not intended to be a reimplementation of the API of any other framework. Our goal is simply to make **a better brain dynamics programming framework for Python users**.

However, there are important differences between ``brainpy.math`` and other frameworks. 

Essentially, JAX itself and many other JAX frameworks follow a functional programming paradigm. Applying this coding style to brain dynamics models is difficult, because the many variables and parameters in a model quickly make the code hard to manage.
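The sketch below illustrates what the functional style can look like for a simple neuron model in plain JAX. The model (a leaky integrate-and-fire neuron with Euler integration), the function name ``lif_step``, and all parameter values are assumptions chosen for illustration and are not taken from BrainPy. Every variable and parameter has to be threaded through the function signature and returned explicitly.

```python
import jax
import jax.numpy as jnp


@jax.jit
def lif_step(state, params, current):
    """One Euler step of a hypothetical leaky integrate-and-fire neuron."""
    v = state["v"]
    dv = (-(v - params["v_rest"]) + params["r"] * current) / params["tau"]
    v = v + params["dt"] * dv
    spike = v >= params["v_th"]
    v = jnp.where(spike, params["v_reset"], v)
    # Nothing is mutated: the new state is handed back to the caller.
    return {"v": v, "spike": spike}


# Illustrative parameter values (assumptions, not BrainPy defaults).
params = {"tau": 10.0, "v_rest": -65.0, "v_reset": -70.0,
          "v_th": -50.0, "r": 1.0, "dt": 0.1}
state = {"v": jnp.full((4,), -65.0), "spike": jnp.zeros((4,), dtype=bool)}

for _ in range(100):
    state = lif_step(state, params, 20.0)
```

Even this single-neuron example carries two dictionaries through every call. A network with several populations and synapses multiplies that bookkeeping.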

In contrast, ``brainpy.math`` allows an object-oriented programming paradigm (which is much more Pythonic). The most similar framework is [Objax](https://github.com/google/objax), which also supports OOP based on JAX. However, Objax targets the deep learning domain and cannot be used directly for brain dynamics programming.

## How to interoperate with other JAX frameworks?

BrainPy interoperates easily with other JAX frameworks.

- First, **data is exchangeable**. This is because a ``JaxArray`` can be directly converted to a JAX ndarray or a NumPy ndarray.

In [19]:
b

JaxArray(DeviceArray([10,  1,  2,  3,  4], dtype=int32))

Convert a ``JaxArray`` into a JAX ndarray.

In [20]:
# JaxArray.value is a JAX ndarray

b.value

DeviceArray([10,  1,  2,  3,  4], dtype=int32)

Convert a ``JaxArray`` into a numpy ndarray.

In [21]:
# JaxArray can be easily converted to a numpy ndarray

np.asarray(b)

array([10,  1,  2,  3,  4])

Convert a numpy ndarray into a ``JaxArray``. 

In [22]:
bm.asarray(np.arange(5))

JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))

Convert a JAX ndarray into a ``JaxArray``.

In [23]:
import jax.numpy as jnp

In [24]:
bm.asarray(jnp.arange(5))

JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))

In [25]:
bm.JaxArray(jnp.arange(5))

JaxArray(DeviceArray([0, 1, 2, 3, 4], dtype=int32))
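Once ``b.value`` has produced a plain JAX array, the usual JAX and NumPy conversions apply. The sketch below uses only ``numpy`` and ``jax.numpy`` to show the round trip. By default JAX works with 32-bit types, which is consistent with the ``dtype=int32`` shown in the outputs above.

```python
import numpy as np
import jax.numpy as jnp

a_np = np.arange(5)
a_jax = jnp.asarray(a_np)   # NumPy -> JAX
back = np.asarray(a_jax)    # JAX -> NumPy array on the host
```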

- Second, **transformations in ``brainpy.math`` also work on functions**. This means that APIs from other JAX frameworks can be naturally integrated into BrainPy. Let's take the gradient-based optimization library [Optax](https://github.com/deepmind/optax) as an example to illustrate how to use another JAX framework together with BrainPy.

In [26]:
import optax

In [27]:
# First create several useful functions. 

network = bm.vmap(lambda params, x: bm.dot(params, x), in_axes=(None, 0))

def compute_loss(params, x, y):
  y_pred = network(params, x)
  loss = bm.mean(optax.l2_loss(y_pred, y))
  return loss

@bm.jit
def train(params, opt_state, xs, ys):
  grads = bm.grad(compute_loss)(params, xs.value, ys)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

In [28]:
# Generate some data

bm.random.seed(42)
target_params = 0.5
xs = bm.random.normal(size=(16, 2))
ys = bm.sum(xs * target_params, axis=-1)

In [29]:
# Initialize parameters of the model + optimizer

params = bm.array([0.0, 0.0])
optimizer = optax.adam(learning_rate=1e-1)
opt_state = optimizer.init(params)

In [30]:
# A simple update loop

for _ in range(1000):
  params, opt_state = train(params, opt_state, xs, ys)

assert bm.allclose(params, target_params), \
  'Optimization should retrieve the target params used to generate the data.'
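The ``in_axes=(None, 0)`` argument used when building ``network`` above deserves a short illustration. The sketch below uses plain ``jax.vmap`` with hypothetical names (``predict``, ``batched_predict``) and hand-picked inputs. ``None`` means the parameters are shared across the batch, while ``0`` means the function is mapped over the leading axis of the inputs.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # Single-sample model: dot product of a weight vector and one input vector.
    return jnp.dot(params, x)


# in_axes=(None, 0): keep `params` as is, map over axis 0 of `x`.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

params = jnp.array([0.5, 0.5])
xs = jnp.arange(8.0).reshape(4, 2)   # 4 samples with 2 features each

ys = batched_predict(params, xs)     # one prediction per sample, shape (4,)
ys_matmul = xs @ params              # the equivalent matrix-vector product
```

With both weights equal to 0.5, each prediction is half the sum of the sample's features, which is the same rule used to generate ``ys`` from ``target_params`` in the training data above.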