# Exercise 1.1 - Using the jit decorator

## Objectives

- Use the jit decorator and observe its behaviour with respect to input types
- Understand how to specify the expected types ahead-of-time
- See how to call other compiled functions
- Understand the overheads in dispatching a jitted function

## A first use of the jit decorator

Define a compiled addition function like:

In [22]:
from numba import jit

@jit
def add(x, y):
    # A somewhat trivial example
    return x + y

Now try calling the function with:

In [23]:
add(1, 2)

3

and then:

In [24]:
add(1j, 2)

(2+1j)

Notice what happens on each invocation - the type of the result depends on the types of the arguments. Although this is unsurprising for Python code, Numba has generated two separate implementations of the `add` function, one for each combination of argument types.

You can also explicitly specify the function signature that you are expecting, which will cause Numba to compile the function only for the given signature, at the time of declaration:

In [25]:
from numba import int32

@jit(int32(int32, int32))
def add_int32(x, y):
    return x + y

Now try making these calls to the new function:

In [26]:
add_int32(1, 2)

3

and then:

In [27]:
add_int32(1j, 2)

TypeError: No matching definition for argument type(s) complex128, int64

This time the outcome is different - Numba will only permit the function to execute with the specified types.

## Calling other compiled functions

Execute the following:

In [28]:
import math

def square(x):
    return x * x

@jit
def hypot(x, y):
    return math.sqrt(square(x) + square(y))

Time the execution of `hypot(3, 4)`:

In [29]:
%timeit hypot(3, 4)

The slowest run took 74162.06 times longer than the fastest. This could mean that an intermediate result is being cached 
1000000 loops, best of 3: 601 ns per loop


Now let's add a `@jit` decorator to the `square` function:

In [30]:
@jit
def square(x):
    return x * x

@jit
def hypot(x, y):
    return math.sqrt(square(x) + square(y))

and time the execution of `hypot(3, 4)` again:

In [31]:
%timeit hypot(3, 4)

The slowest run took 273622.44 times longer than the fastest. This could mean that an intermediate result is being cached 
10000000 loops, best of 3: 188 ns per loop


There are two things to note here:
- First, the execution time drops once both functions are jitted. Jitted functions can call other jitted functions, and doing so is faster than calling a normal Python function.
- Secondly, even though we did not change `hypot`, we needed to redefine it after changing the `square` function. This is because Numba resolved the call to `square` when it compiled `hypot`, not at the time `square` is called.

## Numba overheads

Let’s define a Python function that adds two numbers (to complement our add function from above):

In [32]:
def add_python(x, y):
    return x + y

Now try benchmarking it against the Numba-compiled function:

In [33]:
%timeit add(1, 2)

The slowest run took 20.72 times longer than the fastest. This could mean that an intermediate result is being cached 
1000000 loops, best of 3: 209 ns per loop


In [34]:
%timeit add_python(1, 2)

The slowest run took 12.14 times longer than the fastest. This could mean that an intermediate result is being cached 
10000000 loops, best of 3: 98 ns per loop


The Numba-compiled code takes longer than the Python code! This illustrates that calling a Numba-compiled function carries some overhead. A function must do enough work that the speedup from compilation outweighs the cost of dispatching the call.

Let's try a function which performs more computation. Define the normal and jitted versions:

In [35]:
def clip(x, lim):
    for i in range(len(x)):
        if x[i] > lim:
            x[i] = lim


@jit
def clip_jit(x, lim):
    for i in range(len(x)):
        if x[i] > lim:
            x[i] = lim

Now let's set up some input data:

In [36]:
import numpy as np

a1 = np.arange(1000)
a2 = np.arange(1000)

And let's benchmark these two implementations:

In [37]:
%timeit clip(a1, 100)

10000 loops, best of 3: 191 µs per loop


In [38]:
%timeit clip_jit(a2, 100)

The slowest run took 71949.10 times longer than the fastest. This could mean that an intermediate result is being cached 
1000000 loops, best of 3: 818 ns per loop


You should see a significant speedup from `clip_jit` compared to `clip`. (On my laptop, the jitted version is about 200 times faster than the Python implementation.)

## Summary

- You use the jit decorator to instruct Numba to compile a function.
- Numba infers the types of arguments, and specialises the compiled function to the argument types.
- You can also specify the types of the arguments ahead of time, but this will prevent specialisations for other types being compiled.
- Jitted functions can call other jitted functions (and it is faster to do so).
- Jitted function calls have overhead - you must make sure they do enough computation to get an overall speedup.