# Numba Typing Exercises

This notebook provides some exposition of Numba's typing mechanisms, and how to deal with some of the issues you may encounter with typing. It covers:

* How to display the typing of functions and understand the output,
* Examination of different typings of the same function,
* How to understand and fix typing errors,
* Some CUDA-specific issues related to performance and occupancy.

This notebook, as published in the Git repository, includes all the output from a previous run. This is because some of the output (e.g. temporary variable names, register counts, etc.) may vary slightly between versions of Numba or CUDA toolkits. We suggest clearing all the output and working through the notebook yourself, keeping the version with output for reference in case Numba's output appears to differ from the description given in the text.

We'll begin by importing some required packages. We use the `@njit` decorator for CPU-targeted examples; it is shorthand for `@jit(nopython=True)`. Typing in nopython mode has stricter requirements than typing in object mode, and those requirements are what allow it to deliver better performance, so nopython mode is the better choice for learning about typing. The `@cuda.jit` decorator is used for the CUDA-targeted examples.

In [None]:
from numba import njit, cuda
import numpy as np

## Inspecting the typing

Throughout this notebook we will use `inspect_types()` extensively to inspect the results of the typing algorithm. We'll start with a very simple example:

In [None]:
@njit
def f(a, b):
    return a + b

Let's call this function with a pair of `float32`s to force a typing:

In [None]:
f(np.float32(1), np.float32(2))

Now we'll inspect the typing for this call:

In [None]:
f.inspect_types()

The output of `inspect_types()` is a printout of the function's source code annotated with the Numba IR for each line, and the type of each IR node. Note that `del` nodes have no type, as they simply delete an existing variable.

Types are separated from the IR by a double colon. Take this line from the output above:

```
$8return_value.3 = cast(value=$6binary_add.2)  :: float32
```

the type of `$8return_value.3` (which is also the type returned by `cast(value=$6binary_add.2)`) is `float32`.

## An example with branching

When a variable takes a value from multiple different control flow paths (i.e. *branches*), a unification is needed to determine a type that is suitable for representing the types across all the different control flow paths. We can explore unification using a simple function with a branch in it:

In [None]:
@njit
def select(a, b, c):
    if c:
        ret = a
    else:
        ret = b
    return ret

We'll start by calling the function with `a` and `b` both as `float32` for a first example:

In [None]:
select(np.float32(1), np.float32(2), True)

If we inspect the typing we get:

In [None]:
select.inspect_types()

We see that where the value of a variable can come from two separate branches, there is a *phi node*: `ret.2 = phi(incoming_values=...)`. The `incoming_values` track the different sources of this variable - in this example, `ret` from the `if` side of the branch, and `ret.1` from the `else` side of the branch.

The type of the phi node (`float32` in this case) is the type resulting from unification of the types of all the incoming values.

### Another typing of the branching function

If we call the function with a `float32` and a `float64`, we get another typing:

In [None]:
select(np.float32(1), np.float64(2), True)

Now let's inspect types again. This time, there will be two sets of typings - one for the `(float32, float32, boolean)` call earlier, and another for the `(float32, float64, boolean)` call in the previous cell. If we call `inspect_types()` with no arguments, it will print out the typings for all sets of argument types that have been seen so far. In order to focus on just the case we are interested in, we can pass the `signature` keyword argument with a tuple of Numba types to get the typing for a specific set of argument types. Numba types are imported from `numba` - for a comprehensive list of them, see [Types and signatures](http://numba.pydata.org/numba-doc/latest/reference/types.html#types-and-signatures) in the Numba documentation.

In [None]:
from numba import float32, float64, boolean
select.inspect_types(signature=(float32, float64, boolean))

Here we see the types from each branch (`ret = a  :: float32` and `ret.1 = b  :: float64`) have unified to `float64` at the phi node.

### Failing unification

Sometimes unification can fail. If we try to choose between a tuple and a scalar:

In [None]:
select((1, 2), 3.0, False)

The typing fails at unification: Numba reports that it cannot unify `UniTuple(int64 x 2)` and `float64` for `ret.2`.

When a typing error occurs, we can debug the propagation of type information by setting the environment variable `NUMBA_DEBUG_TYPEINFER` to `1`, or setting `numba.config.DEBUG_TYPEINFER` to `True`. It helps to also dump the Numba IR to understand the results of propagation better, so we should also set `numba.config.DUMP_IR` to `True` (or use the corresponding environment variable `NUMBA_DUMP_IR`). The debug output won't appear in the Jupyter notebook, but we can get the output by re-running this example as an external script:

In [None]:
%%script python
from numba import njit
from numba import config


@njit
def select(a, b, c):
    if c:
        ret = a
    else:
        ret = b
    return ret

config.DEBUG_TYPEINFER = True
config.DUMP_IR = True
select((1, 2), 3.0, False)

Numba has printed a dump of the types of all variables after each propagation step. Type inference operates on IR in [Static Single Assignment (SSA) form](https://en.wikipedia.org/wiki/Static_single_assignment_form), so after propagation the variable names carry a "version number" (e.g. `ret`, `ret.1`, `ret.2`, etc.). However, the IR dump does not currently print these version numbers, so it can be a little tricky to work out which variable in the dump each versioned variable refers to.

The different versions of the variable make up the set that is being unified, so we can see that the variable `ret` has a set of `{UniTuple(int64 x 2), float64, float64}` from its versions `ret`, `ret.1` and `ret.2`.

A general strategy for debugging typing issues is to examine the changes in the types of variables at each propagate step, to determine how a typing error is occurring.

### Exercises

Execute the code in the following cell, and try to locate the typing of `x` in the output. Try to understand the message accompanying the `TypingError` (which begins with `Invalid use of Function(...`). You may find it easier to run this example in a terminal, to avoid a lot of scrolling through the output frame in the notebook.

In [None]:
%%script python
from numba import njit
from numba import config
import numpy as np
config.DEBUG_TYPEINFER = True
config.DUMP_IR = True

@njit
def array_vs_scalar():
    x = np.zeros(20)
    x[0] = 10
    x[0, 1] = 20

array_vs_scalar()

The next example calls a function that is unsupported on the CUDA target. Numba tries to implement `x.sum()` using the internal function `array_sum_impl`, which you will see in the output. Try to determine which function is unsupported (from the message beginning with `Use of unsupported NumPy function...`) and locate the call to it in the IR for `array_sum_impl`.

In [None]:
%%script python
from numba import cuda
from numba import config
import numpy as np
config.DEBUG_TYPEINFER = True
config.DUMP_IR = True

@cuda.jit
def sum_reduce(x):
    x[0] = x.sum()

x = np.ones(10)
sum_reduce(x)

## Branch Elimination

Sometimes Numba can eliminate code from dead branches, if it can determine that the branch will never run for a given set of argument types - this can avoid a unification error that would otherwise have occurred if Numba could not eliminate these dead branches. The next example demonstrates this capability when it does work, and also when it doesn't.

In [None]:
@njit
def branch_elim_example(a, b, cond):
    if cond is None:
        return a
    else:
        return b

This call, where `cond` is `None`, succeeds due to the elision of the `else` branch:

In [None]:
branch_elim_example(1, (1, 2), None)

In the following call, branch elimination fails, forcing an attempt to unify two types that cannot be unified:

In [None]:
branch_elim_example(1, (1, 2), True)

The following cell contains the same function and call, but run with `%%script` so you can inspect the IR and typing if you wish.

In [None]:
%%script python
from numba import njit
from numba import config
import numpy as np
config.DEBUG_TYPEINFER = True
config.DUMP_IR = True

@njit
def branch_elim_example(a, b, cond):
    if cond is None:
        return a
    else:
        return b
    
branch_elim_example(1, (1, 2), True)

### General summary of Branch Elimination

* Branch elimination can sometimes remove dead code and prevent unification errors.
* In practice, if some calls fail to unify while calls with other argument types succeed, branch elimination may be involved.

# CUDA-specific issues

This section looks at a few issues where performance on CUDA can be impacted due to the typing. These are:

* Widening unification
* Widening arithmetic, and its propagation
* The typing of integer arithmetic
* Register usage control

## Widening unification

Unification of types can result in a type that is larger than any of the types from the set that was unified. This first example uses the CPU target because it makes for a simpler example, but the general idea of widening unification applies to the CUDA target as well.
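As far as we understand Numba's implementation, unifying two numeric types follows NumPy's type promotion rules (`np.promote_types`). The NumPy-only sketch below (it does not call Numba) shows two pairs for which the promoted type is wider than either input, because neither input type can represent every value of the other:

```python
import numpy as np

# int8 cannot hold 128..255 and uint8 cannot hold negatives, so both widen to int16
print(np.promote_types(np.int8, np.uint8))      # int16

# No integer type covers both int64 and uint64, so promotion falls back to float64
print(np.promote_types(np.int64, np.uint64))    # float64
```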

In [None]:
@njit
def select(a, b, threshold, value):
    if threshold < value:
        r = a
    else:
        r = b
    return r

a = np.float32(1)
b = np.int32(2)
select(a, b, 10, 11)  # Call with (float32, int32, int64, int64)

After the call, we can inspect the typing:

In [None]:
select.inspect_types()

### Exercises

Using the typing output above, answer the following:

* What is the return type?
* Why was this type chosen instead of one of the types in the set being unified?
* Fix the above code so that the return type is no wider than any of the input types.

## Width of constants

The default width of constants, and the propagation of that width through arithmetic, can produce a typing that results in slower code, because of the use of double-precision units and increased register usage. We will build up an example step by step to see the impact on the propagated types and the knock-on effects on the LLVM IR and PTX code.

We begin with a very simple example, where we assign a constant to an array element:

In [None]:
from numba import void

@cuda.jit(void(float32[:]))
def assign_constant(x):
    x[0] = 2.0

Now let's see the typing:

In [None]:
assign_constant.inspect_types()

The constant has a type of `float64`. Now let's look at what LLVM does with that, by viewing the LLVM IR after LLVM optimizations:

In [None]:
print(assign_constant.inspect_llvm())

It turns out that the LLVM optimizer was able to convert this back to a 32-bit constant: `store float 2.000000e+00, float* %arg.x.4, align 4`.

We see a similar width in the PTX:

In [None]:
print(assign_constant.inspect_asm())

Correspondingly, we have `mov.u32 	%r1, 1073741824;`, where 1073741824 (`0x40000000`) is the bit pattern of the 32-bit float value 2.0. So far, so good.
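To confirm that the integer in the PTX really is a 32-bit float, we can reinterpret its bytes with the standard library. The second half also decodes the `0f40000000` form, which is how PTX writes single-precision immediates (the `0f` prefix followed by eight hex digits):

```python
import struct

# Reinterpret the 4 bytes of the float32 value 2.0 as an unsigned 32-bit integer
bits = struct.unpack("<I", struct.pack("<f", 2.0))[0]
print(bits, hex(bits))   # 1073741824 0x40000000

# Decode the hex digits after PTX's 0f prefix back into a float32 value
value = struct.unpack(">f", bytes.fromhex("40000000"))[0]
print(value)             # 2.0
```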

### Increasing complexity slightly - in-place addition

Now let's build up the example a little - instead of assigning a constant, we add a constant to the array element:

In [None]:
@cuda.jit(void(float32[:]))
def add_constant(x):
    x[0] += 2.0

If we inspect the types, we see:

In [None]:
add_constant.inspect_types()

Again the typing of the constant is `float64`, and also the addition of the `float32` and `float64` (`$8binary_subscr.4` plus `$const10.5` stored in `$12inplace_add.6`) results in a `float64`.

But does the addition result in a 64-bit operation in the LLVM IR?

In [None]:
print(add_constant.inspect_llvm())

No! Again the LLVM optimizer has managed to reduce this to a 32-bit operation: `fadd float %.4957, 2.000000e+00`.

The PTX corresponds:

In [None]:
print(add_constant.inspect_asm())

As expected, we see `add.f32 	%f2, %f1, 0f40000000;`.

### Bringing in another addition

As well as adding a constant, we'll now add another array element:

In [None]:
@cuda.jit(void(float32[:], float32[:]))
def add_constant_2(x, y):
    x[0] += y[0] + 2.0

We would expect the IR to contain more `float64` operations:

In [None]:
add_constant_2.inspect_types()

What happens this time in the LLVM IR? Let's see:

In [None]:
print(add_constant_2.inspect_llvm())

Instead of operations on 32-bit floats, we now see casts (`fpext` / `fptrunc`) between 32- and 64-bit values, and operations on 64-bit values (`fadd double`). This time, the optimizer couldn't save us!

NVVM doesn't help us in this case either:

In [None]:
print(add_constant_2.inspect_asm())

Similarly we see casts (e.g. `cvt.f64.f32`) and operations on 64-bit values (e.g. `add.f64`).

### Exercise

* Fix the typing of the `add_constant_2` function with an appropriate cast.
* Re-run the inspection of the typing, LLVM, and PTX to verify that the width of operations is reduced.

## Register usage

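We can find out the register usage of the kernel from the `regs` attribute of its compiled function (note that `_func` is a private attribute, so this may differ between Numba versions):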

In [None]:
add_constant_2._func.get().attrs.regs

With the original typing, this gives 8 registers on my setup. With the "corrected" typing, fewer registers are needed: 6 in my case. Reducing the width of operations generally reduces register usage, which can increase occupancy.

## Controlling register usage by parameter

The `max_registers` keyword argument of the `@cuda.jit` decorator can also be used to limit register usage, which can be helpful once you have reached the limit of what code changes can achieve.

This only has an effect for kernels of a minimum level of complexity - the following is about the size of the simplest example for which it can be seen to take effect:

In [None]:
@cuda.jit
def busy_arithmetic(x, y, a):
    a = y[0]
    b = 2.0
    c = y[1] / 6
    d = y[2] % 8
    e = y[3] * y[4]
    for i in range(a):
        a += 2
        b -= c
        e *= d
        x[0] += a * b + c * d - e

x = np.empty(32, dtype=np.float32)
y = np.empty(32, dtype=np.float32)
kernel = busy_arithmetic.specialize(x, y, 5)

Note here we used the `specialize()` function of the CUDA-jitted kernel - this can be used to give us a compiled kernel with a typing for a particular set of arguments without launching a kernel. This is convenient when we only want to experiment with a particular typing of a function.

Let's examine the register usage of the kernel:

In [None]:
kernel._func.get().attrs.regs

Now if we redefine the kernel with the `max_registers` keyword argument and inspect the register usage:

In [None]:
@cuda.jit(max_registers=24)
def busy_arithmetic_maxreg_24(x, y, a):
    a = y[0]
    b = 2.0
    c = y[1] / 6
    d = y[2] % 8
    e = y[3] * y[4]
    for i in range(a):
        a += 2
        b -= c
        e *= d
        x[0] += a * b + c * d - e
        
kernel_maxreg_24 = busy_arithmetic_maxreg_24.specialize(x, y, 5)
kernel_maxreg_24._func.get().attrs.regs

We see that the register usage is reduced to the level we requested. However, `max_registers` is a request that the compiler is not obliged to meet, so it may not be honored. For example:

In [None]:
@cuda.jit(max_registers=20)
def busy_arithmetic_maxreg_20(x, y, a):
    a = y[0]
    b = 2.0
    c = y[1] / 6
    d = y[2] % 8
    e = y[3] * y[4]
    for i in range(a):
        a += 2
        b -= c
        e *= d
        x[0] += a * b + c * d - e
        
kernel_maxreg_20 = busy_arithmetic_maxreg_20.specialize(x, y, 5)
kernel_maxreg_20._func.get().attrs.regs

The register usage was reduced, but only to 24, which appears to be the minimum the compiler could achieve for this kernel.

## Integer arithmetic width

Numba strongly prefers using `int64` values for all integer arithmetic. Let's consider an example:

In [None]:
from numba import int32

@cuda.jit
def index_computation(x):
    i = cuda.grid(1)                     # int32

    if i < x.shape[0]:                   # x.shape[0] will be int64
        for j in range(3):               # range_iter_int64
            x[i, j] = (i * 2) + (j * 3)  # int64 computations

x = np.zeros((1024, 3), dtype=np.int32)
kernel = index_computation.specialize(x)

Now if we inspect the typing:

In [None]:
kernel.inspect_types()

We see that most of the arithmetic happens using `int64` values, and the range iterates over `int64` (the `range_iter_int64` type).

We can attempt to reduce the width of arithmetic operations using casts, but it requires a lot of casts:

In [None]:
@cuda.jit
def index_computation_int32(x):
    i = cuda.grid(1)                     # int32

    if i < int32(x.shape[0]):            # Attempt to compare using int32 arithmetic
        for j in range(int32(3)):        # Force iteration over int32 - a range_iter_int32
            x[i, j] = int32(int32(int32(i) * int32(2))
                            + int32(int32(j) * int32(3)))
                                         # Attempt to make all constants and operations int32

kernel_int32 = index_computation_int32.specialize(x)

If we have been successful, we should see a reduced register usage for the `index_computation_int32` kernel:

In [None]:
kernel._func.get().attrs.regs

In [None]:
kernel_int32._func.get().attrs.regs

We have actually made things worse! It is often better not to try to narrow `int64` operations, because doing so results in a mix of `int32` and `int64` values, which ends up requiring more registers.

### Exercise

* Inspect the IR, LLVM, and PTX to see where `int64` computations remain in `index_computation_int32`.

# Summary

Throughout the course of this notebook, we have:

* Seen how to use `inspect_types()` to view the typing of jitted functions
* Examined *phi nodes* and looked at the unification of types at phi nodes.
* Seen how calls with different argument types result in different specialisations of a function, that have different typings.
* Examined typing errors:
  * Unification failures, and how to determine what failed to unify
  * Use of a variable with inconsistent typing throughout the function (e.g. 1D array vs. 2D array)
  * Use of unsupported functions, or functions implemented using unsupported functions in the CUDA target.
* Seen an example of branch elimination, and how it sometimes succeeds in allowing typings with arguments that could otherwise have resulted in unification errors.
* Looked at CUDA-specific issues, mainly related to register usage:
  * When widening unification occurs, and how to prevent it.
  * When widening arithmetic occurs, and how to avoid it for floating point types.
  * How integer arithmetic strongly prefers `int64`, and how it can be counterproductive to try to reduce it to `int32` and narrower types.
* Seen how to control register usage using the `max_registers` keyword argument.