# Understanding Numba

* https://numba.pydata.org/

In [48]:
from numba import njit
from numba import prange
import numpy as np

# 10x10 integer matrix holding 0..99; passed to go / go_fast as the benchmark input.
x = np.arange(100).reshape(10, 10)

In [2]:
def go(a):
    """Shift every element of ``a`` by the sum of tanh over its main diagonal.

    Pure-Python/NumPy reference version (no JIT).

    Args:
        a: 2-D array-like supporting ``a.shape`` and ``a[i, i]`` indexing.
            One diagonal entry is read per row, i.e. ``a.shape[0]`` entries.

    Returns:
        ``a + s`` (broadcast), where ``s`` is the accumulated
        ``tanh(a[i, i])`` over all row indices ``i``.
    """
    n_rows = a.shape[0]
    # Built-in sum() starts from the int 0 and adds left to right, so the
    # accumulation order matches an explicit running total.
    diag_tanh_total = sum(np.tanh(a[k, k]) for k in range(n_rows))
    shifted = a + diag_tanh_total
    return shifted

In [3]:
@njit
def go_fast(a):
    """JIT-compiled twin of ``go``: add the summed tanh of the diagonal to ``a``.

    Numba compiles this function to machine code the first time it is
    called, so the first call is slower than subsequent ones.

    Args:
        a: 2-D array indexed as ``a[i, i]`` for each row index ``i``.

    Returns:
        ``a`` with the accumulated ``tanh(a[i, i])`` broadcast-added to
        every element.
    """
    diag_tanh_total = 0
    n_rows = a.shape[0]
    # Explicit loop on purpose: Numba compiles plain loops efficiently.
    for k in range(n_rows):
        diag_entry = a[k, k]
        diag_tanh_total += np.tanh(diag_entry)  # NumPy ufunc is supported in nopython mode
    # Scalar + array broadcasting is also handled by Numba.
    return a + diag_tanh_total

In [4]:
# Baseline: time the pure-Python version.
%timeit go(x)

# Call once before timing so the one-off JIT compilation (triggered by the
# first call) is not included in the measurement below.
go_fast(x) # force compilation...
%timeit go_fast(x)

28.4 µs ± 535 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
948 ns ± 54.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [5]:
def is_prime(n):
    """Return True if ``n`` is prime, using naive trial division.

    The O(n) scan over every candidate divisor is kept deliberately simple:
    this function serves as the pure-Python baseline for the Numba version.

    Args:
        n: Integer to test.

    Returns:
        bool: True when ``n >= 2`` and no ``i`` in ``range(2, n)`` divides it;
        False otherwise (including 0, 1 and negative numbers).
    """
    # For n < 2 the range below is empty, so without this guard the loop
    # would be skipped and 0, 1 and negatives would be reported as prime.
    if n < 2:
        return False
    for i in range(2, n):
        if n % i == 0:
            return False
    return True
[i for i in range(50) if is_prime(i)]

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]

In [6]:
@njit
def is_prime_fast(n):
    """JIT-compiled naive primality test (trial division by every i in [2, n)).

    Same algorithm as the pure-Python ``is_prime`` so the two can be timed
    against each other.

    Args:
        n: Integer to test.

    Returns:
        bool: True when ``n >= 2`` and no ``i`` in ``range(2, n)`` divides it;
        False otherwise (including 0, 1 and negative numbers).
    """
    # For n < 2 the range below is empty, so without this guard the loop
    # would be skipped and 0, 1 and negatives would be reported as prime.
    if n < 2:
        return False
    for i in range(2, n):
        if n % i == 0:
            return False
    return True
[i for i in range(50) if is_prime_fast(i)]

[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]

In [7]:
# Time both versions on the same input. is_prime_fast has already been called
# with integers in the cell that defines it, so it is compiled before timing.
%timeit is_prime(199999)
%timeit is_prime_fast(199999)

14.6 ms ± 443 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.05 ms ± 9.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [8]:
from numba import prange
@njit(parallel=True)
def count_prime_fast(N):
    """Count the primes strictly below ``N`` with a parallel outer loop.

    Each candidate ``n`` is tested by naive trial division; the candidates
    are distributed over threads via ``prange`` and ``c`` is accumulated as
    a reduction variable.

    Args:
        N: Exclusive upper bound of the candidates.

    Returns:
        int: Number of primes ``p`` with ``2 <= p < N`` (0 when ``N <= 2``).
    """
    c = 0
    # Start at 2: for n = 0 and n = 1 the inner range(2, n) is empty, so the
    # for/else would reach the else branch and wrongly count them as primes.
    for n in prange(2, N):
        for i in range(2, n):
            if n % i == 0:
                break
        else:
            # Inner loop finished without finding a divisor: n is prime.
            c += 1
    return c
count_prime_fast(100000)

9592

In [98]:
def incr(n):
    """Count to ``n**3`` by adding 1.0 once per iteration of a triple loop.

    Pure-Python baseline for the JIT-compiled variants; the work is
    intentionally left as three nested loops.

    Args:
        n: Number of iterations at each of the three nesting levels.

    Returns:
        float: The accumulated total, i.e. ``float(n**3)`` for ``n >= 0``.
    """
    total = 0.0
    steps = range(n)  # a range can be iterated repeatedly, so all levels share it
    for _outer in steps:
        for _middle in steps:
            for _inner in steps:
                total = total + 1.0
    return total
incr(300)

27000000.0

In [99]:
@njit
def incr_fast(n):
    """JIT-compiled triple loop that adds 1.0 a total of ``n**3`` times.

    Same algorithm as the pure-Python ``incr``; only the ``@njit``
    decorator differs, so the two can be benchmarked against each other.

    Args:
        n: Number of iterations at each of the three nesting levels.

    Returns:
        float: The accumulated total.
    """
    total = 0.0
    for outer in range(n):
        for middle in range(n):
            for inner in range(n):
                total += 1.0  # one unit of work per innermost iteration
    return total
incr_fast(300)

27000000.0

In [100]:
@njit(parallel=True)
def incr_fast_par(n):
    """Parallel JIT-compiled triple loop that adds 1.0 a total of ``n**3`` times.

    The outermost loop uses ``prange`` so its iterations can be spread over
    threads; ``c`` is updated with ``+=`` and acts as a reduction variable.

    Args:
        n: Number of iterations at each of the three nesting levels.

    Returns:
        float: The accumulated total.
    """
    # Start from a float so the accumulator has one type throughout,
    # matching incr / incr_fast (it is only ever updated with 1.0).
    c = 0.0
    for i in prange(n):
        for j in range(n):
            for k in range(n):
                c += 1.0
    return c
incr_fast_par(300)

27000000.0

In [101]:
# Compare the three variants on the same problem size
# (n**3 = 27,000,000 additions each).
n = 300
%timeit incr(n)
%timeit incr_fast(n)
%timeit incr_fast_par(n)

880 ms ± 5.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
28.5 ms ± 403 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.71 ms ± 41.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [102]:
# Larger problem size (n**3 = 10**9 additions); only the two compiled
# variants are run in this cell.
n = 10**3
# Sanity check: both variants should print the same total.
print(incr_fast(n))
print(incr_fast_par(n))

%timeit incr_fast(n)
%timeit incr_fast_par(n)

1000000000.0
1000000000.0
1.06 s ± 7.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
81.6 ms ± 3.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
