High-performance and parallel computing for AI - Practical 2: JIT & Wrapping
============================================================================

IMPORTANT
=========

For these practicals we will be using a different `conda environment`. When opening a notebook or a terminal make sure you are using the **CuPy Kernel**!!!

Tutorial 1 - Numba and Just-in-time compiling
---------------------------------------------

For this part we will be using [numba](https://numba.pydata.org/). Numba is an open source just-in-time (JIT) compiler that translates a subset of Python and NumPy code into fast machine code. In this tutorial we investigate the use of JIT compilation. Note: there may be some overlap with other parts of the course, but we will be using numba later on in the GPU part.

**Warning:** For reasons that will soon become obvious, you should restart the Jupyter kernel before running the scripts in this tutorial (click on the circular arrow or the fast-forward sign above).

The easiest way to start working with numba is to import the numba JIT decorator:

In [1]:
from numba import jit

Consider the following script:

In [2]:
import numpy as np
from time import time

N = 10000
a = np.random.randn(N)
b = np.random.randn(N)

def f(a,b):
    out = a.copy()
    for i in range(N):
        out[i] = a[i] + b[i]

    return out

tic = time()
c = f(a,b)
toc = time()-tic
print(toc)

0.002525806427001953


The above is not the best way of working in Python: Python for loops are very slow. What we should do is use vectorization (see later).
However, another option is to use numba to compile the function:

In [3]:
@jit
def fjit(a,b):
    out = a.copy()
    for i in range(N):
        out[i] = a[i] + b[i]

    return out

# The first time the function is called, numba will use JIT to compile the function.
tic = time()
c = fjit(a,b)
toc = time()-tic
print("With compilation:", toc)

# The second time the function is called no compilation will occur so it will be faster!
tic = time()
c = fjit(a,b)
toc = time()-tic
print("Without compilation:", toc)

With compilation: 0.37271833419799805
Without compilation: 6.103515625e-05


The compiled code is now much, much faster (about 6e-5 s versus 2.5e-3 s for the pure Python loop)! However, you can see the big downside of JIT: compilation took a very long time (about 0.37 s), much longer than the pure Python for loop!

To make things worse, calling the function with a different argument type will trigger another JIT compilation:

In [4]:
# No JIT here, this was compiled before so it's fast
tic = time()
c = fjit(a,b)
toc = time()-tic
print("Already JIT-ed:", toc)

# If you change variable types: re-JIT
aa = a.astype(np.int64)
bb = b.astype(np.int64)
tic = time()
cc = fjit(aa,bb)
toc = time()-tic
print("Argh! Re-JIT:", toc)

# This was JIT-ed before so now we are fine!
aa = a.astype(np.int64)
bb = b.astype(np.int64)
tic = time()
cc = fjit(aa,bb)
toc = time()-tic
print("Already JIT-ed:", toc)

# Changing array size won't trigger recompilation though
tic = time()
cc = fjit(aa[:-1],bb[:-1])
toc = time()-tic
print("Already JIT-ed:", toc)

Already JIT-ed: 0.00021314620971679688
Argh! Re-JIT: 0.11827611923217773
Already JIT-ed: 6.246566772460938e-05
Already JIT-ed: 9.012222290039062e-05


You can prescribe input and output types if you want:

1. It may make compilation slightly faster.
2. It will throw an error if you try using the function with the wrong input/output type.

To prescribe types you need to provide a *function signature* to the decorator:
```python
    import numba as nb
    @jit(nb.float32(nb.float32, nb.float32))
```

The above means that the function being JIT-ed takes two single-precision variables as inputs and returns one single-precision variable. These types are for scalars. If you want to use vectors you will have to use, e.g., `nb.float32[:]`.

In [5]:
import numba as nb

@jit(nb.float64[:](nb.float64[:], nb.float64[:])) # This is called the function signature
def fjit(a,b):
    out = a.copy()
    for i in range(N):
        out[i] = a[i] + b[i]

    return out

c = fjit(a,b) # no problem

## Try uncommenting the lines below for fun
#fjit(aa,bb) # aa and bb are arrays of integers so numba will complain
#fjit(3.0,2.0) # these are scalars and not arrays so numba will complain

Important! Use `nb.void` for functions that do not return anything, e.g.,

In [6]:
@jit(nb.void(nb.float64))
def test(a):
    print(a)

test(4.0)

4.0


If you use nested functions, please **jit them all**! Note: this is numba-specific. In other libraries (notably JAX) you only need to JIT the outermost function.

In [7]:
@jit
def anincrediblefunction(a):
    return fjit(a,a) # fjit was already decorated with jit. This is good practice when using numba.

c = anincrediblefunction(a)

Note that numba has two compilation modes: a `nopython` mode which avoids interpreted Python calls (which are slow) and an `object` mode which allows interpreted Python calls and is compatible with Python objects that cannot be compiled. `object` mode is always very slow: good for debugging, but in general it should be avoided.

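By default, the `jit` decorator uses `nopython` mode (which is good!). Old versions of numba did not, and you had to use the `njit` decorator instead. If you want to use `object` mode, use `jit` with `nopython=False` instead (to my knowledge, numba 0.59 and later ignore `nopython=False` with a warning, and `forceobj=True` is the explicit switch there):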

In [8]:
@jit(nopython=False)
def fjit(a,b):
    out = a.copy()
    for i in range(N):
        out[i] = a[i] + b[i]

    return out

# Note: I suspect that for such a simple function numba will be fast anyways.

**Warning** Numba does not like a lot of things, including:
* **External libraries**. If you try to use anything other than numpy (or a few numba-supported libraries) with numba it won't work.
* **Python classes**. Classes won't work inside JIT-ed code *unless* you compile them with the experimental feature `@jitclass`.

This makes numba annoying at times. However, it offers a very easy way of compiling Python code and making it much, much faster. Additionally, it lets you do the same on GPUs, which is a great feature.

**A few curiosities:**
* numba uses the LLVM compiler under the hood.
* JAX also uses JIT, but with the XLA compiler. My take is that LLVM is faster and that XLA has a few quirks (it is slow for loops and does not like some operations, such as editing vector entries), but I am not an expert. JAX is designed with automatic differentiation and machine learning applications in mind rather than JIT alone. JAX does not use numpy itself (it uses its own numpy-like library instead, `jax.numpy`).
* The other AI libraries (PyTorch, TensorFlow) also use JIT.
* Julia is a programming language that tries to bridge the gap between compiled and interpreted languages. It is as easy as Python, but it JITs everything so as to achieve compiled-language speeds. The big catch, however, is the JIT compilation time.

In [9]:
# Here is a short jax example
import jax

# This is being unfair to JAX: you could just do a + b and Jax does not like for loops
# so you need this convoluted way.
@jax.jit
def jax_fjit(a, b):
    out = jax.numpy.zeros_like(a)
    
    def body_fun(i, result):
        return result.at[i].set(a[i] + b[i])  # Functional update
    
    return jax.lax.fori_loop(0, N, body_fun, out)

@jit
def numba_fjit(a,b):
    out = a.copy()
    for i in range(N):
        out[i] = a[i] + b[i]

    return out

# Numba again
print("Using numba\n")
tic = time()
c = numba_fjit(a,b)
toc = time()-tic
print("With compilation:", toc)

tic = time()
c = numba_fjit(a,b)
toc = time()-tic
print("Without compilation:", toc)

# Using JAX
print("\n\nUsing JAX\n")

# need to convert to jax.numpy arrays first
aa = jax.numpy.array(a)
bb = jax.numpy.array(b)

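# JAX dispatches work asynchronously, so wait for the result before stopping the clock
tic = time()
cc = jax_fjit(aa,bb).block_until_ready()
toc = time()-tic
print("With compilation:", toc)

tic = time()
cc = jax_fjit(aa,bb).block_until_ready()
toc = time()-tic
print("Without compilation:", toc)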



Using numba

With compilation: 0.09668779373168945
Without compilation: 7.152557373046875e-05


Using JAX

With compilation: 0.04004979133605957
Without compilation: 0.0004780292510986328


In [10]:
# Just so you know, numpy code IS COMPILED so using pure numpy will already be quite fast:

tic = time()
c = a+b
toc = time()-tic
print("Pure numpy:", toc)

Pure numpy: 0.00028896331787109375


Note that this example is unfair to JAX: you could just do `a + b` as in numpy, which would work and be faster, and for loops are especially slow in JAX.
Nevertheless, it shows that each alternative has its own pros and cons. numba is designed specifically for JIT-ing so I expect it to be fast in general. For things that you cannot use numpy (or jax.numpy) for, numba is likely to be a great option.
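
For reference, the vectorized JAX version mentioned above could look like the sketch below. The function name `jax_add` is hypothetical; `block_until_ready()` is used so that the timings include the actual computation, since JAX dispatches work asynchronously.

```python
import numpy as np
import jax
import jax.numpy as jnp
from time import time

@jax.jit
def jax_add(a, b):
    # No explicit loop: the whole addition is a single array operation
    return a + b

a_np = np.random.randn(10000)
b_np = np.random.randn(10000)
a_jax = jnp.asarray(a_np)   # JAX uses single precision by default
b_jax = jnp.asarray(b_np)

tic = time()
c_jax = jax_add(a_jax, b_jax).block_until_ready()
print("With compilation:", time() - tic)

tic = time()
c_jax = jax_add(a_jax, b_jax).block_until_ready()
print("Without compilation:", time() - tic)
```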

**Conclusions**
This short tutorial was just to introduce you to the basics of JIT and to prepare you for the GPU tutorial, in which we may be using numba. If you want to read more about numba, please check out its official [documentation](https://numba.readthedocs.io/en/stable/index.html).

Problem 1
---------

Write a Python function that uses for loops to compute $\lVert \sin(Ab) \rVert_2$, where $A$ is a random square matrix of size $N=1024$ and $b$ is a random vector of compatible size. Compile it with numba and compute the timings of: 1) The pure Python code. 2) The first time the numba-jitted code is run. 3) The second time the numba-jitted code is run. Check that they are consistent with what we observed in the practical so far.

Now, try using `numpy.linalg.norm` to compute the Euclidean norm and try to compile the code with numba. What happens? Try to understand the error message, and note that numba does not give very clear messages when it fails to compile. Now circumvent the problem by passing `nopython=False` to the `jit` decorator so that numba uses object mode. Does it work now? Compare the timings of using `nopython=False` with those of your previous code. What do you observe? Think about why using object mode is always a bad idea and why we have been lucky here.


Hint: You may have to restart the Jupyter kernel to make sure JIT compiling happens.
Hint: Always write the pure Python code first and check that it runs before jitting it.

In [11]:
import numpy as np
from time import time  # needed again if the kernel was restarted
from numba import jit

N = 1024
A = np.random.randn(N,N)
b = np.random.randn(N)

def myfun(A,b):
    c = np.zeros_like(b)
    for i in range(N):
        for j in range(N):
            c[i] += A[i,j]*b[j]

    for i in range(N):
        c[i] = np.sin(c[i])**2

    out = 0
    for i in range(N):
        out += c[i]

    out = np.sqrt(out)
    return out

# Can also jit this way
jit_myfun = jit(myfun)

tic = time()
cc = myfun(A,b)
toc = time()-tic
print("Pure Python:", toc)

tic = time()
cc = jit_myfun(A,b)
toc = time()-tic
print("With compilation:", toc)

tic = time()
cc = jit_myfun(A,b)
toc = time()-tic
print("Without compilation:", toc)

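# Note: Using np.linalg.norm throws a compilation error in which numba asks you to use object mode.
#       This happens when numba cannot compile numpy.linalg calls (its numpy.linalg support is partial and,
#       according to the numba documentation, requires SciPy to be installed).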

@jit(nopython=False)
def jit_mybadfun(A,b):
    c = np.zeros_like(b)
    for i in range(N):
        for j in range(N):
            c[i] += A[i,j]*b[j]

    for i in range(N):
        c[i] = np.sin(c[i])

    out = np.linalg.norm(c)
    return out

tic = time()
cc = jit_mybadfun(A,b)
toc = time()-tic
print("With compilation:", toc)

tic = time()
cc = jit_mybadfun(A,b)
toc = time()-tic
print("Without compilation:", toc)

# Using object mode is a bad idea since it will use pure Python to evaluate code it is not able to compile.
# However here we are lucky since np.linalg.norm invokes compiled code under the hood.

Pure Python: 0.32219743728637695
With compilation: 0.2579164505004883
Without compilation: 0.0009684562683105469
With compilation: 0.22426271438598633
Without compilation: 0.0009686946868896484
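
Before timing anything, it is worth checking that the loop-based code computes the right quantity. The sketch below compares a loop implementation against the vectorized numpy expression on a small problem; the function name `norm_sin_loops`, the fixed seed and the size 64 are arbitrary choices made for this example.

```python
import numpy as np

def norm_sin_loops(A, b):
    # Loop-based evaluation of || sin(A b) ||_2
    n = b.shape[0]
    total = 0.0
    for i in range(n):
        row = 0.0
        for j in range(n):
            row += A[i, j] * b[j]
        total += np.sin(row) ** 2
    return np.sqrt(total)

rng = np.random.default_rng(0)
A_small = rng.standard_normal((64, 64))
b_small = rng.standard_normal(64)

loop_value = norm_sin_loops(A_small, b_small)
reference = np.linalg.norm(np.sin(A_small @ b_small))
print(loop_value, reference, np.isclose(loop_value, reference))
```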


Tutorial 2 - Quick wrapper example
-----------------------------------

An alternative to JIT compiling is to write code in a *compiled* language (e.g., C/C++, Fortran, Rust, etc.), compile it, and then call it from Python.

However, I cannot assume you know anything other than Python so I am only going to show you a small example using CFFI - The Common Foreign Function Interface.

Frankly, I do not advise you to learn CFFI: there are better packages, such as (for C/C++):
* [Pybind11](https://github.com/pybind/pybind11).
* [Nanobind](https://github.com/wjakob/nanobind).

For other languages there are likely other very good options.

Nevertheless, CFFI allows me to show you an example without ever leaving Python. Using anything else would require writing C/C++ code, compiling it on the side, etc., which is a bit tricky with Jupyter notebooks. So here it is:

In [12]:
import numpy as np
import cffi
from time import time
# Note: N is reused from the previous cell (N = 1024)

ffi = cffi.FFI()
# Tell CFFI to expect the add_arrays function
ffi.cdef("void add_arrays(const double* a, const double* b, double* out);")
# define source code: a simple function to add 2 arrays. The output array out must be passed as input. It will be modified.
C_SOURCE = "#define N %d\n" % N + '''
void add_arrays(const double* a, const double* b, double* out) {
    for (int i = 0; i < N; ++i) {
        out[i] = a[i] + b[i];
    }
}
'''

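# Compiles the code. This is very similar to JIT, but normally it is done offline and separately and only once.
# Note: ffi.verify() is deprecated in recent CFFI versions in favour of ffi.set_source() followed by ffi.compile().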
C = ffi.verify(
    C_SOURCE,
    extra_compile_args=["-O3", "-ffast-math"]
)

# Create input arrays and output buffer
a = np.random.randn(N)
b = np.random.randn(N)
out = np.empty_like(a)

# Execute C function with array buffers
C.add_arrays(
    ffi.cast("const double*", a.ctypes.data),
    ffi.cast("const double*", b.ctypes.data),
    ffi.cast("double*", out.ctypes.data)
)

print(np.linalg.norm(out))

out = np.empty_like(a)
tic = time()
C.add_arrays(
    ffi.cast("const double*", a.ctypes.data),
    ffi.cast("const double*", b.ctypes.data),
    ffi.cast("double*", out.ctypes.data)
)
toc = time()-tic
print("With wrapper:", toc)

44.58124424024471
With wrapper: 8.058547973632812e-05
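
The calls above pass raw addresses obtained from `a.ctypes.data`. An alternative that needs no C compiler to try out is `ffi.from_buffer`, which exposes the memory of a numpy array to CFFI directly. The sketch below is only an illustration with a hypothetical array `arr`; it shows that the resulting pointer reads and writes the array's own memory without copying.

```python
import numpy as np
import cffi

ffi = cffi.FFI()

arr = np.arange(4, dtype=np.float64)

# from_buffer wraps the array's memory; keep `buf` alive while `ptr` is in use
buf = ffi.from_buffer(arr)
ptr = ffi.cast("double *", buf)

first = ptr[0]       # reads arr[0]
ptr[1] = 10.0        # writes straight into arr[1]

print(first, arr)
```

A pointer obtained this way can be passed to a wrapped C function in place of the `ffi.cast(..., a.ctypes.data)` expressions.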


The above is fast, but not as fast as numba: note that here N = 1024, whereas the numba timings in Tutorial 1 were for N = 10000. I have never used CFFI before, but in my experience the other wrappers are faster.

Problem 2 (bonus)
-----------------

This is a bonus problem. Only work on it if you are done for the day. You can come back to it on a later day if you have time.

Read the numba [jitclass documentation](https://numba.readthedocs.io/en/stable/user/jitclass.html). Create a new "CrazyNumberArray" class which stores a numpy array and overloads its `__mul__` and `__add__` operations so that the result of each operation is perturbed by an independent uniform random variable in $[-10^{-6}, 10^{-6}]$. Use the `jitclass` decorator (imported from `numba.experimental`) to JIT-compile the class. Check that it becomes faster.

Hint: you can generate uniform random variables by using `numpy.random.rand`. However, these are standard uniforms and they lie in $[0,1)$. Think about how you can map them to the desired interval.
