# Just-In-Time Compilation

@[Chaoming Wang](https://github.com/chaoming0625)

One of the core ideas of BrainPy is Just-In-Time (JIT) compilation. JIT compilation turns your Python code into machine code "just in time", that is, when the code is first executed. Later calls then run at native machine-code speed. It is therefore important to understand how to write code that is compatible with the JIT environment.

For more details, please see the tutorials in "Math Foundation". 
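BrainPy runs on top of JAX (note the ``DeviceArray`` in the outputs below), so the compile-once behaviour can be illustrated with plain ``jax.jit``. The sketch below uses plain JAX, not BrainPy's API, and assumes ``jax`` is installed. The counter ``trace_count`` is a hypothetical name used only for illustration. Python code inside the function runs only while JAX traces it, and the compiled result is reused for later calls with the same input shape and dtype.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: how many times Python traces `double`

def double(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return x * 2.0

double_jit = jax.jit(double)

double_jit(jnp.ones(3))  # first call: trace + compile
double_jit(jnp.ones(3))  # same shape/dtype: cached machine code is reused
double_jit(jnp.ones(4))  # new shape: traced and compiled again
print(trace_count)  # → 2
```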

```python
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
```

## JIT Compilation for Functions

To take advantage of JIT compilation, users only need to wrap their custom *functions* or *objects* with [bm.jit()](../apis/math/generated/brainpy.math.jit.jit.rst), which instructs BrainPy to transform the code into machine code.


Take **pure functions** as an example. Here we implement the Gaussian Error Linear Unit (GELU), using its tanh approximation:

```python
def gelu(x):
  sqrt = bm.sqrt(2 / bm.pi)
  cdf = 0.5 * (1.0 + bm.tanh(sqrt * (x + 0.044715 * (x ** 3))))
  y = x * cdf
  return y
```
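As a sanity check, this tanh approximation can be compared with JAX's built-in ``jax.nn.gelu``, whose default ``approximate=True`` uses the same formula. This sketch uses ``jax.numpy`` in place of ``brainpy.math`` so it runs with JAX alone:

```python
import jax
import jax.numpy as jnp

def gelu(x):
    # same tanh approximation as above, written with jax.numpy
    sqrt = jnp.sqrt(2 / jnp.pi)
    cdf = 0.5 * (1.0 + jnp.tanh(sqrt * (x + 0.044715 * (x ** 3))))
    return x * cdf

x = jnp.linspace(-5.0, 5.0, 101)
ours = gelu(x)
ref = jax.nn.gelu(x, approximate=True)
print(bool(jnp.allclose(ours, ref, atol=1e-5)))  # → True
```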

Let's first try to run the function without JIT.

```python
x = bm.random.random(100000)
%timeit gelu(x)
```

```
298 µs ± 4.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```


After JIT compilation, the function runs roughly 4.5 times faster (65.6 µs vs. 298 µs per loop).

```python
gelu_jit = bm.jit(gelu)
%timeit gelu_jit(x)
```

```
65.6 µs ± 80.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
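Keep two JAX details in mind when reading such timings. First, the first call of a jitted function also pays for tracing and compilation. Second, JAX dispatches work asynchronously, so a timer that does not wait for the result may measure only the dispatch. Below is a minimal plain-JAX sketch of a fairer measurement. ``time_call`` is a hypothetical helper, and ``jax.numpy`` stands in for ``brainpy.math``:

```python
import time
import jax
import jax.numpy as jnp

def gelu(x):
    sqrt = jnp.sqrt(2 / jnp.pi)
    cdf = 0.5 * (1.0 + jnp.tanh(sqrt * (x + 0.044715 * (x ** 3))))
    return x * cdf

def time_call(fn, x, n=100):
    """Hypothetical helper: average wall time per call, in seconds."""
    fn(x).block_until_ready()  # warm-up: triggers tracing/compilation for jitted fns
    t0 = time.perf_counter()
    for _ in range(n):
        fn(x).block_until_ready()  # wait for the asynchronously dispatched work
    return (time.perf_counter() - t0) / n

x = jax.random.uniform(jax.random.PRNGKey(0), (100000,))
t_eager = time_call(gelu, x)
t_jit = time_call(jax.jit(gelu), x)
print(f"eager: {t_eager * 1e6:.1f} us/call, jit: {t_jit * 1e6:.1f} us/call")
```

The actual numbers depend on the hardware, so none are claimed here.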


## JIT Compilation for Objects

In BrainPy, JIT compilation can also be applied to **class objects**, provided that:

1. The class is a subclass of [brainpy.Base](../tutorial_math/base.ipynb).

2. Variables that change dynamically are labeled as [brainpy.math.Variable](../tutorial_math/variables.ipynb).

3. Variables are updated through [in-place operations](../tutorial_math/variables.ipynb).


Below is a simple example of a logistic regression classifier. Each call performs one gradient-descent update of the weights.

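```python
class LogisticRegression(bp.Base):
    def __init__(self, dimension):
        super(LogisticRegression, self).__init__()

        # parameters (treated as static constants during compilation)
        self.dimension = dimension

        # variables (updated during training)
        self.w = bm.Variable(2.0 * bm.ones(dimension) - 1.3)

    def __call__(self, X, Y):
        # gradient of the logistic loss w.r.t. the weights (labels assumed to be ±1)
        u = bm.dot(((1.0 / (1.0 + bm.exp(-Y * bm.dot(X, self.w))) - 1.0) * Y), X)
        # in-place update of the Variable
        self.w.value = self.w - u
```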

In this example, the model weights (``self.w``) will be modified during training, so it is marked as ``bm.Variable``. 
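For intuition, the same update can be written as a *pure* function in plain JAX. Here the weights are passed in and returned rather than stored on ``self``. Keeping the state in a ``bm.Variable`` spares you this bookkeeping. The update ``u`` is the gradient of the logistic loss ``sum_i log(1 + exp(-y_i * (x_i · w)))`` for labels ``y_i`` in ``{-1, +1}``, and the sketch checks this with ``jax.grad``. The names ``step`` and ``loss`` are hypothetical. This is an illustration, not BrainPy's internal implementation:

```python
import jax
import jax.numpy as jnp

def step(w, X, Y):
    # same update rule as LogisticRegression.__call__, but state in/out is explicit
    u = jnp.dot(((1.0 / (1.0 + jnp.exp(-Y * jnp.dot(X, w))) - 1.0) * Y), X)
    return w - u

def loss(w, X, Y):
    # logistic loss for labels in {-1, +1}
    return jnp.sum(jnp.log1p(jnp.exp(-Y * jnp.dot(X, w))))

step_jit = jax.jit(step)

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (50, 3))
Y = jnp.where(jax.random.uniform(key_y, (50,)) > 0.5, 1.0, -1.0)
w0 = 2.0 * jnp.ones(3) - 1.3

w1 = step_jit(w0, X, Y)
# the update equals one gradient-descent step (learning rate 1) on `loss`
print(bool(jnp.allclose(w1, w0 - jax.grad(loss)(w0, X, Y), atol=1e-4)))  # → True
```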

```python
import time

def benchmark(model, points, labels, num_iter=30, name=''):
    t0 = time.time()
    for _ in range(num_iter):
        model(points, labels)

    print(f'{name} used time {time.time() - t0} s')
```

```python
num_dim, num_points = 10, 20000000
points = bm.random.random((num_points, num_dim))
labels = bm.random.random(num_points)
```

```python
# without JIT
lr1 = LogisticRegression(num_dim)

benchmark(lr1, points, labels, name='Logistic Regression (without jit)')
```

```
Logistic Regression (without jit) used time 10.001078605651855 s
```

```python
# with JIT
lr2 = LogisticRegression(num_dim)
lr2 = bm.jit(lr2)

benchmark(lr2, points, labels, name='Logistic Regression (with jit)')
```

```
Logistic Regression (with jit) used time 4.872019052505493 s
```


Note that in the above ``LogisticRegression`` model, the dynamically changed variable (``self.w``) is marked as ``bm.Variable``. Otherwise it would be compiled as a static constant. During compilation, every attribute accessed through ``self.`` that is not a ``bm.Variable`` instance is treated as a fixed constant, so later changes to it are not reflected in the compiled code.
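The same pitfall can be reproduced in plain JAX. In this minimal sketch, ``Scaler`` is a hypothetical class and not a ``bp.Base`` subclass. Its attribute ``scale`` is an ordinary Python float, so it is baked into the compiled function as a constant, and changing it afterwards has no effect on the cached compiled code:

```python
import jax
import jax.numpy as jnp

class Scaler:
    # hypothetical class: `scale` is a plain Python attribute, not a Variable
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return x * self.scale

model = Scaler(2.0)
f = jax.jit(lambda x: model(x))

before = f(jnp.ones(2))  # traced with scale = 2.0
model.scale = 3.0
after = f(jnp.ones(2))   # cache hit: `scale` was captured as a constant at trace time
print(before, after)     # both are [2. 2.]
```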

### In-place operators

Variables should be updated in place. Several commonly used in-place operations are listed below.

```python
v = bm.Variable(bm.arange(10))

v
```

```
Variable(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
```

1. **Indexing and slicing**, which includes:
   - Indexing: ``v[i] = a``
   - Slicing: ``v[i:j] = b``
   - Indexing specific positions with a list of indices: ``v[[1, 3]] = c``
   - Slicing all values: ``v[:] = d``, ``v[...] = e``

For more details, please refer to [Array Objects Indexing](https://numpy.org/doc/stable/reference/arrays.indexing.html).
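For comparison, plain JAX arrays are immutable. Plain JAX therefore expresses these indexing updates functionally with the ``.at[...]`` property, which returns a *new* array and leaves the original untouched. BrainPy's ``Variable`` lets you write the NumPy-style in-place form instead. Plain JAX sketch:

```python
import jax.numpy as jnp

a = jnp.arange(10)
b = a.at[0].set(2)                   # like v[0] = 2
c = a.at[2:4].set(0)                 # like v[2:4] = 0
d = a.at[jnp.array([1, 3])].set(-1)  # like v[[1, 3]] = -1
print(a)  # [0 1 2 3 4 5 6 7 8 9]  (unchanged)
print(b)  # [2 1 2 3 4 5 6 7 8 9]
```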

```python
v[0] = 2.

v
```

```
Variable(DeviceArray([2, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
```

2. **Augmented assignment**. All augmented assignments are in-place operations:
   - ``+=`` (add)
   - ``-=`` (subtract)
   - ``/=`` (divide)
   - ``*=`` (multiply)
   - ``//=`` (floor divide)
   - ``%=`` (modulo)
   - ``**=`` (power)
   - ``&=`` (and)
   - ``|=`` (or)
   - ``^=`` (xor)
   - ``<<=`` (left shift)
   - ``>>=`` (right shift)

```python
v += 1

v
```

```
Variable(DeviceArray([ 3,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32))
```
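Why does "in place" matter? Roughly speaking, a compiled object tracks its state through the container objects it already holds. An update must therefore change a container's contents rather than rebind the name to a new object. The toy ``Box`` class below is hypothetical, written in plain Python purely to illustrate this distinction. It is not how ``bm.Variable`` is implemented:

```python
class Box:
    """Hypothetical mutable container, loosely mimicking a Variable."""

    def __init__(self, value):
        self.value = value

    def __iadd__(self, other):
        # in-place: replace the stored data, keep the same container object
        self.value = self.value + other
        return self

    def __add__(self, other):
        # out-of-place: build a new container
        return Box(self.value + other)

registry = {'v': Box(1)}    # stands in for the state a compiled object tracks
v = registry['v']

v += 1                      # in-place: the tracked container sees the update
print(registry['v'].value)  # 2

v = v + 1                   # rebinding: `v` now names a brand-new Box
print(registry['v'].value)  # still 2
print(v.value)              # 3
```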

3. ``.value`` **assignment**, which directly replaces the data stored in the ``JaxArray``.

```python
v.value = bm.arange(10)

v
```

```
Variable(DeviceArray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))
```

4. The ``.update()`` **method**, which replaces the Variable's data with the given value.

```python
v.update(bm.ones_like(v))

v
```

```
Variable(DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32))
```