# Numba Cheat Sheet

### ⚡️ Why Numba rocks for Python numerics
- 🚀 JIT-compile hot loops to native code with a single decorator.
- 🧮 Stay in pure Python syntax while getting C/Fortran-like performance.
- ♻️ Reuse the same function across NumPy arrays, scalars, and typed containers.
- 🧭 Pick the right tool for each workload: `njit` for scalar loops, `vectorize` for element-wise ufuncs, `prange` for parallel loops, or the GPU backends.

Run through the sections below whenever you need a quick refresher during the workshop.

## Installation & imports


In [None]:
!pip install numba

In [None]:
import math
import numpy as np
from numba import njit, prange, vectorize, guvectorize, float64, int32
from numba import types
from numba.typed import List, Dict

import numba

print('Numba version:', numba.__version__)


### 🔑 Core decorators at a glance
Use this as the menu of tools to reach for when a loop starts feeling slow or repetitive—the bullets remind you which decorator unlocks which hardware feature.
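- `@njit` / `@jit`: compile a function the first time it is called with a given set of argument types (or at decoration time when you pass an explicit signature); use `nogil=True`, `fastmath=True`, or explicit signatures for extra control.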
- `prange`: drop-in replacement for `range` inside `@njit(parallel=True)` loops.
- `@vectorize` & `@guvectorize`: build universal functions (ufuncs) that broadcast like NumPy.
- `typed.List`, `typed.Dict`: Numba-friendly containers for dynamic data.


#### ✅ Scalar JIT example
A minimal Euclidean-distance loop shows how plain Python code benefits from `@njit` right away: no refactor, just annotate and go.


In [None]:
@njit(cache=True)
def pairwise_distance(x, y):
    total = 0.0
    for i in range(x.shape[0]):
        diff = x[i] - y[i]
        total += diff * diff
    return math.sqrt(total)

vec_a = np.random.rand(1_000).astype(np.float64)
vec_b = np.random.rand(1_000).astype(np.float64)
print('Distance:', pairwise_distance(vec_a, vec_b))


#### 🎯 Controlling signatures & options
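Pass an explicit signature to lock in dtypes (Numba then compiles at decoration time), and add `fastmath=True` when you can trade strict IEEE 754 semantics for faster, fused instructions. Note that `fastmath` makes results less reproducible across machines, not more. The AXPY demo mirrors the classic BLAS call.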


In [None]:
@njit('(float64, float64[:], float64[:])', fastmath=True)
def axpy(a, x, y):
    for i in range(x.shape[0]):
        y[i] = a * x[i] + y[i]

x = np.linspace(0, 1, 8, dtype=np.float64)
y = np.ones_like(x)
axpy(2.0, x, y)
print('AXPY result:', y)


#### 🧵 Parallel loops with `prange`
When iterations are independent, swap `range` for `prange` to fan work across CPU cores. This row-sum kernel mirrors a real profiling hot spot.


In [None]:
@njit(parallel=True)
def row_sums(matrix):
    out = np.empty(matrix.shape[0], dtype=np.float64)
    for i in prange(matrix.shape[0]):
        s = 0.0
        for j in range(matrix.shape[1]):
            s += matrix[i, j]
        out[i] = s
    return out

mat = np.random.rand(4_000, 512)
print('Row sums shape:', row_sums(mat).shape)


#### 🧮 Vectorized ufuncs
Need NumPy-style broadcasting but want native speed? `@vectorize` wraps your scalar logic into a drop-in ufunc.


In [None]:
@vectorize([float64(float64)], target='cpu', nopython=True)
def smooth_relu(x):
    return math.log1p(math.exp(x))

inp = np.linspace(-3, 3, 7)
print('Smooth ReLU:', smooth_relu(inp))


#### 🔄 Generalized ufunc (`guvectorize`)
GUfuncs let you express small batched kernels—think per-row dot products or custom reductions—while keeping NumPy broadcasting semantics.


In [None]:
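# Each signature is a tuple of argument types; the scalar output '()' is passed as a 1-element array.
@guvectorize([(float64[:], float64[:], float64[:])], '(n),(n)->()', nopython=True)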
def dot_prod(x, y, out):
    acc = 0.0
    for i in range(x.shape[0]):
        acc += x[i] * y[i]
    out[0] = acc

block = np.arange(12, dtype=np.float64).reshape(3, 4)
print('Dot products vs reversed:', dot_prod(block, block[:, ::-1]))


### 📦 Typed containers
Typed lists and dicts are essential when you need dynamic data structures without falling back to Python objects. The example shows a typed list running in pure nopython mode.


In [None]:
@njit
def rolling_average(data):
    window = List()
    result = List()
    for value in data:
        window.append(value)
        if len(window) > 4:
            window.pop(0)
        result.append(sum(window) / len(window))
    return result

values = List([1.0, 2.0, 4.0, 8.0, 16.0])
print('Rolling averages:', list(rolling_average(values)))


### 🧪 Inspect typing & lowering
After a successful compile, peek at the resolved signatures and scan the type annotations for `pyobject`, which marks values that still fall back to Python objects. A function compiled with `@njit` should show none.


In [None]:
sig = pairwise_distance.signatures[0]
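print('Compiled signature:', sig)
# inspect_types() prints the annotated source itself and returns None,
# so call it directly instead of wrapping it in print().
pairwise_distance.inspect_types()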


### 🧭 Diagnostics & debugging
Flip these switches when something refuses to compile; they surface typing decisions, HTML reports, and runtime warnings that speed up debugging.
