# Using `jit`

We know how to find hotspots now; how do we improve their performance?

We `jit` them!

We'll start with a trivial example but get to some more realistic applications shortly.

## Array sum

The function below is a naive `sum` function that adds up all the elements of a 2D array.

```python
def sum_array(inp):
    rows, cols = inp.shape

    mysum = 0
    # Visit every element, one column at a time
    for j in range(cols):
        for i in range(rows):
            mysum += inp[i, j]

    return mysum
```

```python
import numpy

arr = numpy.random.random((300, 300))
```

```python
sum_array(arr)
```

```
44842.32891326613
```

```python
%timeit sum_array(arr)
```

```
10 loops, best of 3: 74.1 ms per loop
```


```python
from numba import jit
```

### As a function call

```python
sum_array_numba = jit()(sum_array)
```

```python
sum_array_numba(arr)
```

```
44842.32891326613
```

```python
%timeit sum_array_numba(arr)
```

```
1000 loops, best of 3: 230 µs per loop
```


### As a decorator (the more common form)

```python
@jit
def sum_array(inp):
    rows, cols = inp.shape

    mysum = 0
    # Visit every element, one column at a time
    for j in range(cols):
        for i in range(rows):
            mysum += inp[i, j]

    return mysum
```
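Decorator syntax is plain Python, not a Numba feature. Writing `@jit` above `def sum_array` has the same effect as writing `sum_array = jit(sum_array)` after the definition. Numba's `jit` accepts both `jit(f)` and `jit()(f)`. The toy decorator `shout` below is hypothetical and unrelated to Numba. It only illustrates that the two forms are equivalent.

```python
def shout(func):
    """Hypothetical decorator: upper-cases the string that func returns."""
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs).upper()
    return wrapper


# Decorator syntax...
@shout
def greet(name):
    return f"hello, {name}"


# ...is the same as passing the function to the decorator explicitly
def greet_plain(name):
    return f"hello, {name}"


greet_plain = shout(greet_plain)

print(greet("numba"))        # HELLO, NUMBA
print(greet_plain("numba"))  # HELLO, NUMBA
```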

```python
sum_array(arr)
```

```
44842.32891326613
```

```python
%timeit sum_array(arr)
```

```
1000 loops, best of 3: 230 µs per loop
```


### How does this compare to NumPy?

```python
%timeit arr.sum()
```

```
The slowest run took 4.19 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 102 µs per loop
```


## Your turn!

Everyone likes fractals! (right...?)

Use `jit` (either in function or decorator form) to speed up the Mandelbrot code below.

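**Note**: the call to `create_fractal` is commented out because, in its current pure-Python form, it takes around 15 s to run on a fairly recent i7. The plotting cells at the end use its result, `image`, so uncomment and run that call before plotting.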

```python
def mandel(x, y, max_iters):
    """
    Given the real and imaginary parts of a complex number,
    determine if it is a candidate for membership in the Mandelbrot
    set given a fixed number of iterations.

    Returns the iteration at which the point escapes, or 255 if it
    does not escape within max_iters iterations.
    """
    c = complex(x, y)
    z = 0.0j
    for i in range(max_iters):
        z = z*z + c
        if (z.real*z.real + z.imag*z.imag) >= 4:
            return i

    return 255

def create_fractal(min_x, max_x, min_y, max_y, image, iters):
    height = image.shape[0]
    width = image.shape[1]

    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height
    for x in range(width):
        real = min_x + x * pixel_size_x
        for y in range(height):
            imag = min_y + y * pixel_size_y
            color = mandel(real, imag, iters)
            image[y, x] = color

    return image

# Uncomment these to run
# image = numpy.zeros((500 * 2, 750 * 2), dtype=numpy.uint8)
# image = create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20)
```

```python
from matplotlib import pyplot, cm
%matplotlib inline
```

```python
pyplot.figure(figsize=(10, 8))
pyplot.imshow(image, cmap=cm.viridis);
```