<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Check-numba-version" data-toc-modified-id="Check-numba-version-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Check <code>numba</code> version</a></span></li><li><span><a href="#First-numba-example" data-toc-modified-id="First-numba-example-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>First <code>numba</code> example</a></span></li><li><span><a href="#Second-example" data-toc-modified-id="Second-example-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Second example</a></span></li></ul></div>

# Check `numba` version


In [1]:
import numba as nb

# Record the installed numba version so the timings below can be tied to a release.
numba_version = nb.__version__
print(numba_version)

0.46.0


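# First `numba` example

Three implementations of $\sum_i \tanh(a_{ii})$ are timed on random $n \times n$ matrices:
- a pure-Python loop,
- a vectorised NumPy expression,
- the same loop compiled with `numba`.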

In [2]:
import numpy as np
import numba as nb

def trace_python(a):
    """Sum of ``tanh`` over the main diagonal of ``a``, using a plain Python loop.

    Parameters
    ----------
    a : array indexable as ``a[k, k]`` with a ``shape`` attribute
        Matrix whose diagonal entries ``a[0, 0] .. a[n-1, n-1]`` are reduced,
        where ``n = a.shape[0]``.

    Returns
    -------
    The accumulated sum, or ``0`` when ``a.shape[0] == 0``.
    """
    return sum(np.tanh(a[k, k]) for k in range(a.shape[0]))

@nb.jit(nopython=True)
def trace_numba(a):
    '''
    Sum of tanh over the main diagonal of a 2-D array a[N, N], compiled by
    numba in nopython mode.

    The loop visits a[k, k] for k in 0 .. a.shape[0]-1 and returns the
    accumulated sum.
    '''
    n_rows = a.shape[0]
    acc = 0
    for k in range(n_rows):
        acc = acc + np.tanh(a[k, k])
    return acc

def trace_numpy(a):
    """Vectorised tanh-trace: sum of ``tanh`` over the diagonal of ``a``.

    ``np.diag`` extracts the diagonal of a 2-D array. The ``tanh`` and the
    reduction then run inside NumPy, with no Python-level loop.
    """
    diagonal = np.diag(a)
    return np.sum(np.tanh(diagonal))



for n in [100,1000, 5000]:
    print("n=",n)
    a=np.random.rand(n,n)
    %timeit trace_python(a)
    %timeit trace_numpy(a)
    %timeit trace_numba(a)



n= 100
131 µs ± 891 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
6.28 µs ± 21.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
947 ns ± 7.24 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
n= 1000
1.35 ms ± 27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
18.7 µs ± 186 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
8.62 µs ± 264 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
n= 5000
6.95 ms ± 84.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
81.5 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
68.6 µs ± 2.18 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


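# Second example

This example sums the elements of `v` selected by an index array `idx`, i.e. $\sum_k v_{\mathrm{idx}_k}$. It compares a pure-Python loop, NumPy fancy indexing, and the same loop compiled with `numba`.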

In [4]:
import numpy as np
import numba as nb

def sum_py(v, idx):
    """Sum the elements of ``v`` selected by the integer index array ``idx``.

    Loop form of ``np.sum(v[idx])``. It is written as an explicit index loop
    so the same function can also be compiled with numba (see ``sum_jit``).
    Repeated indices contribute once per occurrence; an empty ``idx`` gives ``0``.
    """
    sm = 0
    for i in range(idx.shape[0]):
        # Gather through idx: add the element at position idx[i], not v[i].
        sm += v[idx[i]]
    return sm

# numba-compiled version of sum_py; nb.njit is numba's shorthand for
# nb.jit(nopython=True).
sum_jit = nb.njit(sum_py)

def sum_numpy(v, idx):
    """Vectorised reference: gather ``v`` at ``idx`` via fancy indexing, then sum."""
    selected = v[idx]
    return selected.sum()

a=np.random.randn(1000)
idx=np.random.randint(1000,size=100)

print("pure python:")
%timeit sum(a,idx)
print("numpy:")
%timeit sum_numpy(a,idx)
print("python jit:")
%timeit sum_jit(a,idx)


pure python:
1.13 ms ± 6.94 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
numpy:
4.93 µs ± 60.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
python jit:
485 ns ± 1.67 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
