# `Numba` Tutorial

In this tutorial, we will learn how to use Numba to speed up Python loops.

To install Numba with conda, type `conda install numba`.

In [1]:
import numpy as np
import numba
from numba import jit, njit, prange, set_num_threads

### 1. Numba's JIT decorator, `@jit`

First, let's consider a nested loop in Python.

Nested loops are very common in computational physics problems (e.g., the acceleration calculation in the N-body problem).

In [2]:
def native_python(N):
    value = 0
    for _ in range(N):
        for _ in range(N):
            # some physical calculations, such as acceleration. 
            value += np.tanh(123)
    return value

In [3]:
test_size = 3000

In [4]:
%timeit ans = native_python(N=test_size)

17.3 s ± 763 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
ans1 = native_python(N=test_size)
print(ans1)

9000000.0


On Kuo-Chuan's desktop computer, the function above takes ~6.43 s with `N=3000` (the run shown above took 17.3 s; timings depend on the machine).

In the example above, the calculation simply adds `np.tanh(123)` N² times. This is equivalent to

In [6]:
ans2 = np.sum(np.tanh(123)*np.ones(test_size*test_size))

In [7]:
print(ans1==ans2)

True
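The exact `==` comparison works here only because `np.tanh(123)` rounds to exactly `1.0` in double precision, so every partial sum is an exactly representable integer. The following minimal check also shows `np.isclose`, which is the safer way to compare floating-point results in general:

```python
import numpy as np

x = np.tanh(123)
print(x == 1.0)  # True: 1 - tanh(123) is far below double-precision resolution

# Summing 3000 * 3000 ones stays exact, since every partial sum is an integer below 2**53
total = np.sum(x * np.ones(3000 * 3000))
print(total == 9_000_000.0)  # True

# For general floating-point results, compare with a tolerance instead of ==
print(0.1 + 0.2 == 0.3)            # False
print(np.isclose(0.1 + 0.2, 0.3))  # True
```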


In [8]:
%timeit np.sum(np.tanh(123)*np.ones(test_size**2))

12.3 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


The same calculation takes only 19.5 ms with `np.sum()` (~330× speedup; 12.3 ms in the run above).

In earlier lectures, we learned that we should use `numpy` and `scipy` to avoid loops in native Python. However, the calculations inside a loop may have no counterpart in `numpy` or `scipy` (or the counterpart may not be straightforward).

Numba's just-in-time (JIT) decorators are one good solution.


In [9]:
@jit(nopython=True)  # the only new line: compile this function in nopython mode
def numba_jit(N):
    value = 0
    for _ in range(N):
        for _ in range(N):
            value += np.tanh(123)
    return value

In [10]:
ans3 = numba_jit(N=test_size)
print(ans3)
print(ans1==ans3)

9000000.0
True


In [11]:
%timeit ans = numba_jit(N=test_size)

18.5 ms ± 64.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


With `jit`, it now takes 6.47 ms (18.5 ms in the run above) just by adding one line of code!
Note that performance can still be a bottleneck when `test_size` is large.

In [12]:
%timeit ans = numba_jit(N=(test_size*10))

1.87 s ± 22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


The calculation time grows as N²: a 10× larger `N` takes about 100× longer (1.87 s vs. 18.5 ms).
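A quick check of the N² scaling from the two `numba_jit` timings above, assuming the run time follows t ∝ N^p:

```python
import math

# Timings shown above for numba_jit
t_small = 18.5e-3  # s, N = 3000
t_large = 1.87     # s, N = 30000

# If t ~ N**p, then p = log(t_large / t_small) / log(N_large / N_small)
p = math.log(t_large / t_small) / math.log(30000 / 3000)
print(f"p = {p:.2f}")  # p = 2.00
```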

We can improve it further with `njit` (shorthand for `jit(nopython=True)`) and `prange`.

In [13]:
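@njit(parallel=True)
def numba_njit_parallel(N):
    value = 0
    for i in prange(N):      # prange is a parallel range: this outer loop is split across threads
        for j in prange(N):  # a prange nested inside another prange runs serially, like range
            value += np.tanh(123)  # Numba treats `value` as a sum reduction
    return value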

Note that in the example above we could not use `for _ in prange(N)`, because `_` is not recognized by Numba in parallel loops.

Use parallel loops for heavy calculations, but be aware that starting threads and communicating between cores also take time. If the problem is small, do not use parallel calculations.

In [14]:
ans4 = numba_njit_parallel(N=test_size)
print(ans1==ans4)

True


In [19]:
set_num_threads(8)  # number of threads (cores) Numba uses for parallel loops

In [20]:
%timeit ans = numba_njit_parallel(N=(test_size*10))

295 ms ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


With 4 threads, it took 161 ms (~4× speedup). In the run above, 8 threads give 295 ms, about 6× faster than the serial `jit` version (1.87 s).

# Exercises

## Exercise 1: Use Numba's `jit` and `njit` to speed up the π calculation

Compare your solutions with `numpy`.
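The loop in `cal_norm` below estimates π from the area of a quarter unit circle. Written out, with sample points $x_i = i/N$ and width $\Delta x = 1/N$:

$$
\pi = 4\int_0^1 \sqrt{1-x^2}\,dx \;\approx\; 4\sum_{i=0}^{N} \Delta x\,\sqrt{1-x_i^2}.
$$

The $i = N$ term is zero, so this is a left Riemann sum of a decreasing function. It therefore overestimates π by roughly $4\cdot\frac{\Delta x}{2}\,[f(0)-f(1)] = 2/N$. For $N = 10^5$, this gives about $2\times10^{-5}$, which is consistent with the printed value 3.141612616401957.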

In [24]:
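def cal_norm(N: int):
    # Estimate pi as 4 x (area of a quarter unit circle), using N + 1 sample points
    dx = 1/N
    area = 0

    for i in range(0, N + 1, 1):
        h = np.sqrt(1 - (i/N)**2)  # height of the circle at x = i/N
        area += dx * h
    Area = area * 4

    return Area

%timeit cal_norm(N=int(10e4))
print(cal_norm(N=int(10e4)))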

167 ms ± 4.86 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.141612616401957


In [27]:
@jit(nopython=True)
def cal_jit(N: int):
    # Same loop as cal_norm, compiled by Numba
    dx = 1/N
    area = 0

    for i in range(0, N + 1, 1):
        h = np.sqrt(1 - (i/N)**2)
        area += dx * h
    Area = area * 4

    return Area

%timeit cal_jit(N=int(10e4))
print(cal_jit(N=int(10e4)))

659 µs ± 143 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
3.141612616401957


In [28]:
@njit(parallel=True)
def cal_njit(N: int):
    dx = 1/N
    area = 0

    for i in prange(0, N + 1, 1):  # prange (not range) marks this loop for parallel execution
        h = np.sqrt(1 - (i/N)**2)
        area += dx * h             # `area` is combined across threads as a sum reduction
    Area = area * 4

    return Area

%timeit cal_njit(N=int(10e4))
print(cal_njit(N=int(10e4)))

101 µs ± 13.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.141612616401991
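`cal_njit` prints 3.141612616401991 instead of ...957. This is because the parallel reduction adds the terms in a different order, which changes the last digits. The `numpy` comparison the exercise asks for can be sketched as follows (`cal_numpy` is a hypothetical name). `np.sum` also uses its own summation order, so expect agreement to about 12 significant digits, not bit-for-bit:

```python
import numpy as np


def cal_numpy(N: int):
    # Same N + 1 sample points x_i = i / N (i = 0, ..., N) as cal_norm, without a Python loop
    x = np.arange(N + 1) / N
    dx = 1 / N
    return 4 * np.sum(dx * np.sqrt(1 - x**2))


print(cal_numpy(N=int(10e4)))
print(abs(cal_numpy(N=int(10e4)) - np.pi))  # roughly 2e-5, the discretization error
```

In the notebook, `%timeit cal_numpy(N=int(10e4))` times it in the same way as the loop versions.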


## Exercise 2: Speed up your N-body solver

Now go back to `2_nbody.ipynb` and speed up our `nbody.py` solver with Numba.