# Just-in-time compilation with Numba

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import numba

In [None]:
from time import time
from contextlib import contextmanager

@contextmanager
def cpu_timer():
    """Print the elapsed time spent inside the ``with`` block.

    Despite the name, this reports wall-clock elapsed time, not CPU time.
    It uses ``time.perf_counter``, a monotonic high-resolution clock, so the
    measurement is not affected by system clock adjustments the way
    ``time.time`` is. The elapsed time is printed even if the body raises.
    """
    from time import perf_counter  # local import keeps this block self-contained

    start = perf_counter()
    try:
        yield
    finally:
        # Report in `finally` so a failing benchmark still shows how long it ran.
        print(f'Elapsed time: {perf_counter() - start} s')

## Using njit

In [None]:
# Sample points c = x + iy over [-2, 1] x [-1.25, 1.25], 10000 per axis.
# X varies along columns and Y along rows (meshgrid's default 'xy' indexing).
# copy=False returns broadcast views instead of two dense 10000x10000 float64
# copies (~1.6 GB); the Mandelbrot kernels here only read X and Y.
X, Y = np.meshgrid(
    np.linspace(-2.0, 1, 10000),
    np.linspace(-1.25, 1.25, 10000),
    copy=False,
)

def _escape_count(cx, cy, itermax):
    """Iterate z -> z**2 + c from z = 0; return the step count at which
    |z|**2 reached 4, or ``itermax`` if it never did."""
    re_z = 0.0
    im_z = 0.0
    steps = 0
    while re_z * re_z + im_z * im_z < 4.0 and steps < itermax:
        re_z, im_z = re_z * re_z - im_z * im_z + cx, 2.0 * re_z * im_z + cy
        steps += 1
    return steps


def mandelbrot(X, Y, itermax):
    """Compute Mandelbrot escape-time counts on a grid, in pure Python.

    Parameters
    ----------
    X, Y : 2-D arrays
        Real and imaginary parts of the sample points c; Y is indexed with
        X's bounds, so it must be at least as large as X.
    itermax : int
        Maximum number of iterations per point.

    Returns
    -------
    numpy.ndarray of int32 with X's shape
        Iterations taken before |z|**2 reached 4, capped at ``itermax``.
    """
    counts = np.empty(shape=X.shape, dtype=np.int32)
    n_rows = X.shape[0]
    for row in range(n_rows):
        for col in range(X.shape[1]):
            counts[row, col] = _escape_count(X[row, col], Y[row, col], itermax)
    return counts

In [None]:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)

# NOTE(review): on the 10000 x 10000 grid this pure-Python loop visits 10**8
# points with up to 100 iterations each -- expect a very long run; consider a
# coarser grid for the pure-Python timing.
with cpu_timer():
    m = mandelbrot(X, Y, 100)

# Row 0 of `m` holds the smallest imaginary part (-1.25), so draw it at the
# bottom. imshow's default origin='upper' would put it at the top and mirror
# the image relative to the axis labels implied by `extent`.
ax.imshow(np.log(1 + m), extent=[-2.0, 1, -1.25, 1.25], origin='lower');

In [None]:
from numba import prange
@numba.njit(parallel=True)
def mandelbrot_jitted(X, Y, itermax):
    """Numba-compiled, parallel counterpart of ``mandelbrot``.

    Parameters
    ----------
    X, Y : 2-D arrays
        Real and imaginary parts of the sample points c; Y is indexed with
        X's bounds.
    itermax : int
        Maximum number of iterations per point.

    Returns
    -------
    numpy.ndarray of int32 with X's shape
        Iterations taken before |z|**2 reached 4, capped at ``itermax`` --
        the same counts as the pure-Python ``mandelbrot``.
    """
    mandel = np.empty(shape=X.shape, dtype=np.int32)
    # prange marks the row loop for parallel execution; each iteration writes
    # only its own row of `mandel`.
    for i in prange(X.shape[0]):
        for j in range(X.shape[1]):
            it = 0
            cx = X[i, j]
            cy = Y[i, j]
            # Start from z0 = 0 like the pure-Python version. Starting at z = c
            # would skip the first step and undercount escaping points by one.
            x = 0.0
            y = 0.0
            while x * x + y * y < 4.0 and it < itermax:
                x, y = x * x - y * y + cx, 2.0 * x * y + cy
                it += 1
            mandel[i, j] = it

    return mandel

In [None]:
fig = plt.figure(figsize=(20, 20))
ax = fig.add_subplot(111)

# Warm-up: the first call of an @njit function compiles it for the argument
# types, and that compile time would otherwise be included in the timing.
# itermax=0 keeps the run cheap while using exactly the same argument types
# (same X, Y arrays and an int) as the timed call below.
mandelbrot_jitted(X, Y, 0)

with cpu_timer():
    m = mandelbrot_jitted(X, Y, 100)

# Row 0 of `m` holds the smallest imaginary part (-1.25); origin='lower' keeps
# the image consistent with the axis labels implied by `extent`.
ax.imshow(np.log(1 + m), extent=[-2.0, 1, -1.25, 1.25], origin='lower');
# Sum of all escape counts: a quick checksum of the result.
print(m.sum())

## Using vectorize

In [None]:
from math import sin
from numba import float64, int64

def my_numpy_sin(a, b):
    """Return ``sin(a) + sin(b)`` element-wise using NumPy ufuncs."""
    return np.add(np.sin(a), np.sin(b))

@np.vectorize
def my_sin(a, b):
    """Element-wise ``sin(a) + sin(b)`` using ``math.sin`` on each scalar pair."""
    sin_a = sin(a)
    sin_b = sin(b)
    return sin_a + sin_b

@numba.vectorize([float64(int64, int64), float64(float64, float64)], target='parallel')
def my_sin_numba(a, b):
    """Parallel Numba ufunc computing ``sin(a) + sin(b)`` element-wise.

    Both signatures return float64: sin(a) + sin(b) is fractional, so an
    integer return type (e.g. ``int64(int64, int64)``) would truncate the
    result to -1, 0 or 1.

    The integer signature is listed first because NumPy dispatches to the
    first loop whose inputs the arguments can be safely cast to. With the
    float64 loop first, int64 inputs would always be cast to float64 and the
    integer loop would never be used.
    """
    return np.sin(a) + np.sin(b)

In [None]:
# Benchmark inputs: two 9,000,000-element random integer arrays from randint(0, 100).
x = np.random.randint(0, 100, size=9000000)
y = np.random.randint(0, 100, size=9000000)
# No dtype was passed to randint, so show which integer dtype it produced.
print(y.dtype)

# IPython %time runs each statement once and reports its duration:
# NumPy ufuncs, np.vectorize over math.sin, then the parallel Numba ufunc.
%time _ = my_numpy_sin(x, y)
%time _ = my_sin(x, y)
%time _ = my_sin_numba(x, y)