
# Numba: Practical Introduction (CPU and CUDA) with one caveat...

**Updated:** 2025-11-11

This notebook is a short, hands-on introduction to [Numba](https://numba.pydata.org/) - you can follow this link
It covers:
1. What Numba is and when to use it  
2. `@njit` (nopython) basics and compilation overhead  
3. Elementwise ufuncs via `@vectorize`  
4. Parallel loops with `prange`  
5. Numerical flags like `fastmath`  
6. Common pitfalls & debugging patterns  
7. Mini-exercises **with solutions**  
8. **CUDA with Numba**: a short, practical intro (SAXPY and tiled matmul)

> Target audience: Python users comfortable with NumPy who want to speed up loops, elementwise math, and try simple GPU kernels without switching languages.



## 0) Setting the scene

We'll use a lightweight helper function to time the execution. Remember: the **first** call to a Numba-compiled function includes compilation time.


In [1]:
"""
   First calls often include a large overhead (Numba JIT compilation,
   CPU cache warmup, disk/page faults). Running the function a few times
   before the actual measururements gives you more stable performance.
   For Numba specifically it is important - similar for realistic CUDA
   kernels.
"""
from time import perf_counter
import numpy as np

def clock(fn, *args, repeat=10, warmup=1, **kwargs):
    for _ in range(max(0, warmup)):
        fn(*args, **kwargs)
    best = float('inf')
    out = None
    for _ in range(repeat):
        t0 = perf_counter()
        out = fn(*args, **kwargs)
        dt = perf_counter() - t0
        if dt < best:
            best = dt
    return best, out


## 1) How to use Numba?

Numba is a **just-in-time (JIT) compiler** that turns a well-behaved subset of Python and NumPy into fast machine code. You’ll see the biggest wins when you have **numeric loops over arrays** that may be hard to handle with pure NumPy.

### Great use cases
- **Loops with heavey computations** written in Python:
  - Example patterns: accumulate sums/products, distance computations, histogramming, simple state machines, dynamic programming, finite-difference stencils.
- **Elementwise math** and simple transforms:
  - Turn a scalar function into a ufunc via `@vectorize` (broadcasts like NumPy), note!! ufunc (universal function) is a NumPy-style function that can operate elementwise on arrays of any compatible shape, using **broadcasting** rules, without you writing loops
- **Structured loops over arrays** that don’t map cleanly to vectorization:
  For instance nested loops where each iteration may be cheap but there are lots of them.
- **Parallelizable loops** with little cross-iteration dependency:
  Use `@njit(parallel=True)` and `prange` for row/column-wise operations, reductions, or embarrassingly parallel sweeps.
- **Performance glue** around library calls:
  If 80% is already fast NumPy/BLAS, but 20% is Python loops, move that 20% into Numba.

### When Numba won’t help much
- **Code dominated by Python objects** (strings, dicts of objects, arbitrary classes).
- **I/O bound tasks** (reading files, network requests).
- **Already-vectorized NumPy/BLAS** doing the heavy lifting (e.g., large `A @ B` matrix multiply is already in optimized C/Fortran).
- **Very small inputs** where JIT overhead dwarfs execution time.

### Mental model
- **`@njit`** compiles a function at **first call** for the specific input types. Subsequent calls with the same dtypes are fast.
- **`nopython` mode** (what `@njit` enforces) keeps execution out of the Python interpreter—this is where the speedups come from.
- **Types matter.** Numba likes concrete dtypes: `float32`, `float64`, `int64`, etc. Passing different dtypes may trigger new compilations.

### What “good” Numba code looks like
- Uses **NumPy arrays** with explicit dtypes.
- Prefers **for-loops** and simple indexing inside kernels (don’t fear loops—Numba compiles them).
- Avoids Python containers that change type/size (e.g., `list.append` of mixed types). If sizes are known, **preallocate arrays**.
- Keeps control flow clear and data-dependent branches predictable.

### Parallelism with `prange`
- Replace `range` with **`prange`** in the **outermost** loop and add `@njit(parallel=True)`.
- Works best when iterations are independent or use simple reductions (`+=`, `min`, `max`).
- Start parallelizing only after the scalar kernel is correct and fast; then measure.

### Fast math & numerics
- `@njit(fastmath=True)` allows more aggressive FP optimizations (FMA, reassociation).
- Expect tiny numerical differences; **validate tolerances** for your domain.

### Benchmarking tips
- **Warm up** before timing: the first call includes JIT compilation.
- Time with a helper (e.g., `clock(fn, ...)`) and compare against a **NumPy baseline** for correctness (`np.allclose`).
- Use **representative problem sizes**; too small hides benefits, too large can hit cache/memory limits.

**Rule of thumb:** If you catch yourself writing a performance-critical `for` loop in Python over numeric arrays, try moving that loop into an `@njit` function first. If iterations are independent, try `prange`. If it’s elementwise, try `@vectorize`. Measure, then iterate.



## 2) `@njit` (nopython) Basics
If Numba can't compile a line, it errors out instead of silently falling back to Python


In [2]:

from numba import njit

def py_sum_squares(x):
    s = 0.0
    for v in x:
        s += v * v
    return s

@njit
def nb_sum_squares(x):
    s = 0.0
    for v in x:
        s += v * v
    return s

x = np.random.randn(10000000).astype(np.float64)

t_py, ref = clock(py_sum_squares, x, repeat=1, warmup=0)
t_nb_first, _ = clock(nb_sum_squares, x, repeat=1, warmup=0)
t_nb_next, out = clock(nb_sum_squares, x, repeat=3, warmup=1)

print(f"Python loop:                 {t_py:.4f} s")
print(f"Numba first call (compile):  {t_nb_first:.4f} s")
print(f"Numba subsequent best:       {t_nb_next:.4f} s")
print("Results equal", np.allclose(ref, out))


Python loop:                 1.6043 s
Numba first call (compile):  0.9891 s
Numba subsequent best:       0.0142 s
Results equal True



**What to note:** first-call overhead vs stable execution. For fair benchmarks, always warm up first.



## 3) Elementwise UFuncs with `@vectorize`

`@vectorize` turns a scalar function into a NumPy-style ufunc, supporting broadcasting and array inputs.


In [3]:

from numba import vectorize, float64

@vectorize([float64(float64, float64)], target='parallel')
def nb_hypot(a, b):
    return (a*a + b*b) ** 0.5

N = 10000000
a = np.random.rand(N)
b = np.random.rand(N)

def np_hypot(a, b):
    return np.sqrt(a*a + b*b)

t_np, y_np = clock(np_hypot, a, b, repeat=10, warmup=1)
t_nb, y_nb = clock(nb_hypot, a, b, repeat=10, warmup=1)

print(f"NumPy hypot-like:  {t_np:.4f} s, value: ", y_np)
print(f"Numba @vectorize:  {t_nb:.4f} s, value: ", y_np)
print("Results equal", np.allclose(y_np, y_nb))


NumPy hypot-like:  0.0862 s, value:  [0.83993094 0.57611054 0.58145888 ... 1.02556774 0.93929247 0.74168338]
Numba @vectorize:  0.0253 s, value:  [0.83993094 0.57611054 0.58145888 ... 1.02556774 0.93929247 0.74168338]
Results equal True



**Tip:** You can choose targets: `'cpu'` (default), `'parallel'` (multicore), or `'cuda'` (GPU).



## 4) Parallel Loops with `prange`
Use `@njit(parallel=True)` and `prange` for independent iterations (e.g., over rows).


In [4]:

from numba import prange

@njit(parallel=True)
def rowwise_dot(A, B):
    M, K = A.shape
    C = np.empty(M, dtype=A.dtype)
    for i in prange(M):
        s = 0.0
        for k in range(K):
            s += A[i, k] * B[i, k]
        C[i] = s
    return C

M, K = 20000, 256
A = np.random.randn(M, K).astype(np.float64)
B = np.random.randn(M, K).astype(np.float64)

def np_rowwise_dot(A, B):
    return np.einsum('ij,ij->i', A, B)

t_np, ref = clock(np_rowwise_dot, A, B, repeat=3, warmup=1)
t_nb, out = clock(rowwise_dot, A, B, repeat=3, warmup=1)

print(f"NumPy einsum: {t_np:.4f} s")
print(f"Numba prange: {t_nb:.4f} s")
print("Results equal", np.allclose(ref, out))


NumPy einsum: 0.0069 s
Numba prange: 0.0061 s
Results equal True



## 5) Numerical Flags: `fastmath=True`
Lets the compiler apply more aggressive floating-point optimizations (e.g., FMA). Validate numerical tolerances for your domain.


In [5]:

@njit(fastmath=True)
def poly_eval_fast(a, x):
    acc = 0.0
    for c in a:
        acc = acc * x + c
    return acc

@njit
def poly_eval_strict(a, x):
    acc = 0.0
    for c in a:
        acc = acc * x + c
    return acc

a = np.random.randn(10000000).astype(np.float64)
x = 1.2345

t_strict, y_strict = clock(poly_eval_strict, a, x, repeat=10, warmup=1)
t_fast,   y_fast   = clock(poly_eval_fast,   a, x, repeat=10, warmup=1)

print(f"Strict:   {t_strict:.6f} s")
print(f"Fastmath: {t_fast:.6f} s")
print("Results equal", np.allclose(y_strict, y_fast, rtol=1e-12, atol=1e-12))


Strict:   0.025671 s
Fastmath: 0.013887 s
Results equal True


## Fast math in Numba — what it is, why it’s faster, and when to use it

`fastmath=True` lets Numba apply **aggressive floating-point optimizations** that C/LLVM compilers commonly use for speed.  
You trade a bit of IEEE-754 strictness for performance.

### What optimizations does it unlock (conceptually)?
- **Reassociation of operations:**  
  `(a + b) + c` may become `a + (b + c)` to enable vectorization/unrolling.  
  *Effect:* tiny numerical differences because floating-point addition is not strictly associative.
- **Fused multiply-add (FMA):**  
  `a*b + c` may be computed as a single instruction (`fma`) with one rounding instead of two.  
  *Effect:* usually **more** accurate and faster, but different from strict separate operations.
- **Reciprocal/sqrt transforms & approximations:**  
  Replace `x / y` with `x * (1/y)` or use fast `rsqrt` internally.  
  *Effect:* small accuracy changes, potential speedup.
- **Assume no NaN/Inf / signed-zero is interchangeable:**  
  Allows more reordering and simplification.  
  *Effect:* code that relies on distinguishing `+0.0` vs `-0.0`, or propagating `NaN`s in specific ways, may behave differently.

### Why it can be faster
- Gives LLVM more freedom to **vectorize**, **unroll**, and **reorder** operations to use the CPU’s SIMD and FMA units efficiently.
- Reduces dependency chains (e.g., `1/y` hoisted/reused) and enables fewer, tighter loops.

### When to use it
- **Numerical kernels** where tiny last-bit differences are acceptable: filters, stencils, many ML kernels, signal/image transforms, Monte Carlo inner loops.
- **Performance-critical** paths where you’ve confirmed the numerical tolerance (e.g., `rtol=1e-8` is fine for your downstream use).

### When to avoid it (or double-check)
- **Strict reproducibility** across machines/compilers is required (e.g., bit-for-bit equality, scientific audits).
- **Edge-case semantics matter:** code that relies on precise IEEE behavior of `NaN`, `Inf`, `-0.0`, or on **exact** operation order (e.g., Kahan summation).
- **Ill-conditioned problems** where tiny rounding differences blow up.


In [6]:
"""
   A 'texbook' example (taken from the web as is)
"""
import numpy as np
from numba import njit
from time import perf_counter

@njit
def poly_strict(a, x):
    # Horner's method (strict)
    acc = 0.0
    for c in a:
        acc = acc * x + c
    return acc

@njit(fastmath=True)
def poly_fast(a, x):
    acc = 0.0
    for c in a:
        acc = acc * x + c
    return acc

# Data
a = np.random.randn(2000).astype(np.float64)
x = 1.23456789

def bench(fn, *args, repeat=5):
    best = 1e9
    out = None
    for _ in range(repeat):
        t0 = perf_counter()
        out = fn(*args)
        dt = perf_counter() - t0
        if dt < best:
            best = dt
    return best, out

t_s, y_s = bench(poly_strict, a, x)
t_f, y_f = bench(poly_fast,   a, x)

print(f"Strict:   {t_s:.8f} s -> {y_s:.16e}")
print(f"Fastmath: {t_f:.8f} s -> {y_f:.16e}")
print("allclose?", np.allclose(y_s, y_f, rtol=1e-12, atol=1e-12))
print("abs diff:", abs(y_s - y_f))


Strict:   0.00000532 s -> -1.5279206297798073e+183
Fastmath: 0.00000280 s -> -1.5279206297798073e+183
allclose? True
abs diff: 0.0



## 6) Nice, but maybe troublesome

1. **Python objects in `@njit`:** Avoid dynamic lists/dicts/sets; prefer NumPy arrays with explicit dtype.  
2. **Hidden object mode:** `@jit` without `nopython=True` can fall back to Python. Prefer `@njit`.  
3. **Unsupported NumPy corners:** If compilation fails, simplify the kernel.  
4. **Specialization:** Different dtypes trigger separate compilations. Warm up with representative inputs.  
5. **Parallel overhead:** `prange` pays off only for sufficiently large loops.



### Example: Why Python lists can break `@njit`


In [None]:
"""
   Using a plain Python dict with a string key -> TypingError
   in nopython mode
"""
from numba import njit

try:
    @njit
    def bad_dict():
        d = {}          # Python dict literal
        d['a'] = 1      # string key (object) -> unsupported in nopython mode
        return d

    # Trigger compilation
    print("bad_dict() ->", bad_dict())
except Exception as e:
    print("[Expected compile failure] bad_dict:", type(e).__name__, "-", e)


bad_dict() -> {a: 1}


In [None]:

# Fix: preallocate a NumPy array
from numba import njit
@njit
def good_array_sum(n):
    xs = np.empty(n, dtype=np.int64)
    for i in range(n):
        xs[i] = i
    s = 0
    for v in xs:
        s += v
    return s

print("Sum 0..9 =", good_array_sum(10))


Sum 0..9 = 45



## 7) Mini-exercises and solutions

1. **1D stencil:** `y[i] = 0.25*x[i-1] + 0.5*x[i] + 0.25*x[i+1]` for `i=1..N-2`.  
2. **Parallel reduction:** Column-wise sums of a 2D array with `prange`.  
3. **Ufunc practice:** Stable logistic `σ(x) = 1 / (1 + exp(-x))`.  
4. **Fastmath check:** Accumulate many small values into a large one; compare strict vs fastmath.



### Solutions



#### 7.1) 1D Stencil Solution


In [None]:

from numba import njit

@njit
def stencil1d_numba(x):
    n = x.size
    y = np.empty_like(x)
    if n == 0:
        return y
    y[0] = x[0]
    for i in range(1, n-1):
        y[i] = 0.25*x[i-1] + 0.5*x[i] + 0.25*x[i+1]
    if n > 1:
        y[n-1] = x[n-1]
    return y

def stencil1d_numpy(x):
    y = np.empty_like(x)
    if x.size == 0:
        return y
    y[0] = x[0]
    y[-1] = x[-1]
    y[1:-1] = 0.25*x[:-2] + 0.5*x[1:-1] + 0.25*x[2:]
    return y

x = np.random.randn(5000000).astype(np.float64)
t_np, y_np = clock(stencil1d_numpy, x, repeat=3, warmup=1)
t_nb, y_nb = clock(stencil1d_numba,  x, repeat=3, warmup=1)
print(f"Stencil NumPy: {t_np:.4f} s")
print(f"Stencil Numba: {t_nb:.4f} s")
print("Close?", np.allclose(y_np, y_nb))


Stencil NumPy: 0.0542 s
Stencil Numba: 0.0191 s
Close? True



#### 7.2) Column-wise Reduction Solution


In [None]:

from numba import prange, njit

@njit(parallel=True)
def colsum_parallel(A):
    m, n = A.shape
    out = np.zeros(n, dtype=A.dtype)
    for i in prange(m):
        for j in range(n):
            out[j] += A[i, j]
    return out

def colsum_numpy(A):
    return A.sum(axis=0)

A = np.random.randn(40_000, 256).astype(np.float64)
t_np, ref = clock(colsum_numpy, A, repeat=3, warmup=1)
t_nb, out = clock(colsum_parallel, A, repeat=3, warmup=1)
print(f"ColSum NumPy: {t_np:.4f} s")
print(f"ColSum Numba prange: {t_nb:.4f} s")
print('Close?', np.allclose(ref, out))


ColSum NumPy: 0.0087 s
ColSum Numba prange: 0.0067 s
Close? False



#### 7.3) Stable Logistic UFunc Solution


In [None]:

from numba import vectorize, float64

@vectorize([float64(float64)])
def logistic_ufunc(x):
    if x >= 0.0:
        z = np.exp(-x)
        return 1.0 / (1.0 + z)
    else:
        z = np.exp(x)
        return z / (1.0 + z)

def logistic_numpy(x):
    out = np.empty_like(x)
    pos = x >= 0
    z = np.empty_like(x)
    z[pos] = np.exp(-x[pos])
    out[pos] = 1.0 / (1.0 + z[pos])
    z[~pos] = np.exp(x[~pos])
    out[~pos] = z[~pos] / (1.0 + z[~pos])
    return out

x = (np.random.randn(10_000_000)*8).astype(np.float64)
t_np, y_np = clock(logistic_numpy, x, repeat=3, warmup=1)
t_nb, y_nb = clock(logistic_ufunc,  x, repeat=3, warmup=1)
print(f"Logistic NumPy: {t_np:.4f} s")
print(f"Logistic @vectorize: {t_nb:.4f} s")
print("Close?", np.allclose(y_np, y_nb, rtol=1e-12, atol=1e-12))


Logistic NumPy: 0.8163 s
Logistic @vectorize: 0.1394 s
Close? True



#### 7.4) Fastmath Accumulation Solution


In [None]:

from numba import njit

@njit
def accumulate_strict(big, small, n):
    acc = big
    for _ in range(n):
        acc += small
    return acc

@njit(fastmath=True)
def accumulate_fast(big, small, n):
    acc = big
    for _ in range(n):
        acc += small
    return acc

big  = 1e6
small = 1e-8
n = 10_000_000  # Adjust if needed for your machine
t_strict, v_strict = clock(accumulate_strict, big, small, n, repeat=3, warmup=1)
t_fast,   v_fast   = clock(accumulate_fast,   big, small, n, repeat=3, warmup=1)
print(f"Strict accumulate:   {t_strict:.4f} s -> {v_strict:.10f}")
print(f"Fastmath accumulate: {t_fast:.4f} s -> {v_fast:.10f}")
print("Absolute diff:", abs(v_strict - v_fast))


Strict accumulate:   0.0126 s -> 1000000.1001171768
Fastmath accumulate: 0.0009 s -> 1000000.1000073235
Absolute diff: 0.00010985322296619415



## 8) CUDA with Numba (Short Intro)

This section shows how to write and launch simple GPU kernels using `numba.cuda`.

**Prerequisites:** an NVIDIA GPU + CUDA driver/runtime, and `numba` built with CUDA support.  
We'll:
- Check device availability
- Implement **SAXPY** (`y = a*x + y`) as a 1D kernel
- Implement a **tiled matrix multiplication** with shared memory
- Compare to NumPy for correctness and basic timing

> If CUDA isn't available, the cells will detect this and print a helpful message.


In [7]:

import numpy as np
try:
    from numba import cuda
    cuda_available = cuda.is_available()
except Exception:
    cuda_available = False

print("CUDA available:", cuda_available)
if cuda_available:
    dev = cuda.get_current_device()
    print("Device:", dev.name.decode() if hasattr(dev.name, 'decode') else dev.name)
    print("Compute capability:", getattr(dev, 'compute_capability', 'n/a'))


CUDA available: True
Device: Tesla T4
Compute capability: (7, 5)



### 8.1) 1D Kernel: SAXPY

We launch a 1D grid of threads; each thread computes one element: `y[i] = a*x[i] + y[i]`.


In [8]:

import math
if cuda_available:
    @cuda.jit
    def saxpy_kernel(a, x, y):
        i = cuda.grid(1)  # global thread index
        if i < x.size:
            y[i] = a * x[i] + y[i]

    N = 5_000_000
    a = 2.0
    x = np.random.randn(N).astype(np.float32)
    y = np.random.randn(N).astype(np.float32)
    y_ref = a*x + y

    # Copy to device
    d_x = cuda.to_device(x)
    d_y = cuda.to_device(y)

    threads_per_block = 256
    blocks = math.ceil(N / threads_per_block)

    # Time with CUDA events for kernel-only timing
    start = cuda.event()
    stop = cuda.event()
    start.record()
    saxpy_kernel[blocks, threads_per_block](a, d_x, d_y)
    stop.record()
    stop.synchronize()
    ms = cuda.event_elapsed_time(start, stop)  # milliseconds

    y_out = d_y.copy_to_host()

    print(f"SAXPY kernel time: {ms:.3f} ms")
    print("Close to NumPy ref?", np.allclose(y_ref, y_out, rtol=1e-6, atol=1e-6))
else:
    print("CUDA not available — skipping SAXPY demo.")


ERROR:numba.cuda.cudadrv.driver:Call to cuLinkAddData results in CUDA_ERROR_UNSUPPORTED_PTX_VERSION


LinkerError: [222] Call to cuLinkAddData results in CUDA_ERROR_UNSUPPORTED_PTX_VERSION
ptxas application ptx input, line 9; fatal   : Unsupported .version 8.8; current version is '8.4'


### 8.2) Tiled Matrix Multiplication (Shared Memory)

Each block computes one tile of `C = A @ B` using shared memory to reduce global memory traffic.

- Choose `TILE = 16` or `32` depending on your GPU.
- Grid dims cover the output matrix in tiles.
- Within a block, threads cooperatively load tiles of `A` and `B`, then compute partial sums.


In [None]:

if cuda_available:
    TILE = 16

    @cuda.jit
    def matmul_tiled(A, B, C, n):
        # Shared memory tiles
        sA = cuda.shared.array(shape=(TILE, TILE), dtype=cuda.float32)
        sB = cuda.shared.array(shape=(TILE, TILE), dtype=cuda.float32)

        tx = cuda.threadIdx.x
        ty = cuda.threadIdx.y
        row = cuda.blockIdx.y * TILE + ty
        col = cuda.blockIdx.x * TILE + tx

        acc = 0.0
        # Loop over tiles
        for t in range((n + TILE - 1) // TILE):
            # Load A tile if within bounds
            a_col = t * TILE + tx
            b_row = t * TILE + ty
            if row < n and a_col < n:
                sA[ty, tx] = A[row, a_col]
            else:
                sA[ty, tx] = 0.0
            if b_row < n and col < n:
                sB[ty, tx] = B[b_row, col]
            else:
                sB[ty, tx] = 0.0
            cuda.syncthreads()

            # Compute partial product for this tile
            for k in range(TILE):
                acc += sA[ty, k] * sB[k, tx]
            cuda.syncthreads()

        if row < n and col < n:
            C[row, col] = acc

    n = 1024  # adjust for your GPU
    A = np.random.randn(n, n).astype(np.float32)
    B = np.random.randn(n, n).astype(np.float32)
    C = np.zeros((n, n), dtype=np.float32)

    d_A = cuda.to_device(A)
    d_B = cuda.to_device(B)
    d_C = cuda.to_device(C)

    threads = (TILE, TILE)
    blocks = (math.ceil(n / TILE), math.ceil(n / TILE))

    start = cuda.event(); stop = cuda.event()
    start.record()
    matmul_tiled[blocks, threads](d_A, d_B, d_C, n)
    stop.record(); stop.synchronize()
    ms = cuda.event_elapsed_time(start, stop)

    C_out = d_C.copy_to_host()
    ref = A @ B

    print(f"Tiled matmul kernel: {ms:.1f} ms for {n}x{n}")
    print("Close to NumPy ref?", np.allclose(ref, C_out, rtol=1e-3, atol=1e-3))
else:
    print("CUDA not available — skipping tiled matmul demo.")


AttributeError: module 'numba.cuda' has no attribute 'float32'


**Notes on CUDA performance:**
- Tune `TILE`, grid/block sizes, and consider `float32` vs `float64` depending on GPU throughput.  
- For very large matrices, GPU memory bandwidth and PCIe transfer time can dominate; reuse device arrays to amortize copies.  
- Use CUDA events (as shown) for kernel timing; place data transfers outside timed regions when you want kernel-only numbers.
