# Numba Advanced Features

### 🧰 What we highlight
- 🏗️ Defining `jitclass` structures for object-like workflows.
- 📚 Working with `typed.List` and `typed.Dict` to keep dynamic containers JIT-friendly.
- 🔂 Creating callbacks with `cfunc` and wiring them into SciPy-style integrators.
- 🧵 Combining parallelism, fastmath, and random states for stochastic simulations.

## Setup


In [None]:
!pip install numba

In [None]:
import numpy as np
from numba import njit, prange
from numba.typed import List, Dict
from numba import types
from numba.experimental import jitclass
from numba import cfunc
import math


### 🧱 `jitclass`: struct-like objects
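Use `jitclass` (imported from `numba.experimental`) when you need stateful structures that stay in nopython mode. This particle example performs one semi-implicit Euler step: `integrate` first updates the velocity from the applied force, then advances the position using the new velocity.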


In [None]:
spec = [
    ('position', types.float64[:]),
    ('velocity', types.float64[:]),
    ('mass', types.float64),
]

@jitclass(spec)
class Particle:
    def __init__(self, position, velocity, mass):
        self.position = position
        self.velocity = velocity
        self.mass = mass

    def kinetic_energy(self):
        vx, vy = self.velocity
        return 0.5 * self.mass * (vx * vx + vy * vy)

    def integrate(self, force, dt):
        ax = force[0] / self.mass
        ay = force[1] / self.mass
        self.velocity[0] += ax * dt
        self.velocity[1] += ay * dt
        self.position[0] += self.velocity[0] * dt
        self.position[1] += self.velocity[1] * dt

p = Particle(np.array([0.0, 0.0], dtype=np.float64),
             np.array([1.0, 0.5], dtype=np.float64),
             2.0)
print('Initial KE:', p.kinetic_energy())
p.integrate(np.array([0.0, -9.81 * p.mass], dtype=np.float64), 0.1)
print('Updated position:', p.position)


### 📦 Typed containers in action
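`numba.typed.Dict` and `numba.typed.List` are homogeneously typed containers that compiled code can create, grow, and return to Python. Below, a typed `Dict` counts integer occurrences and a typed `List` of 2-element arrays feeds a polyline-length computation.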


In [None]:
@njit
def histogram(values):
    # int64 matches the type produced by int(v), so inserts need no narrowing cast
    counts = Dict.empty(key_type=types.int64, value_type=types.int64)
    for v in values:
        key = int(v)
        if key in counts:
            counts[key] += 1
        else:
            counts[key] = 1
    return counts

@njit
def polyline_length(points):
    total = 0.0
    for i in range(1, len(points)):
        dx = points[i][0] - points[i - 1][0]
        dy = points[i][1] - points[i - 1][1]
        total += math.hypot(dx, dy)
    return total

values = np.array([0, 1, 1, 2, 0, 4, 4, 4, 2])
print('Histogram:', histogram(values))

poly = List()
poly.append(np.array([0.0, 0.0]))
poly.append(np.array([1.0, 0.0]))
poly.append(np.array([1.0, 1.0]))
print('Polyline length:', polyline_length(poly))


### 🔁 Callbacks with `cfunc`
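`cfunc` compiles a function against an explicit C signature and exposes it as a native function pointer through `.address` (`.ctypes` gives a ctypes wrapper around the same function). That pointer can be passed to C libraries or custom extension hooks, or wrapped in `scipy.LowLevelCallable` for routines such as `scipy.integrate.quad`.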


In [None]:
@cfunc(types.double(types.double))
def gaussian(x):
    return math.exp(-0.5 * x * x)

print('Callable address:', hex(gaussian.address))


### 🧪 Parallel + RNG example
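Combine `prange`, `fastmath`, and NumPy's random functions to run tens of thousands of independent random walks in parallel, a common Monte Carlo pattern. Numba keeps its own generator state per thread, separate from NumPy's global state, so calling `np.random.seed` from plain Python has no effect on these draws. For ±1 steps the mean displacement should land near 0 and the standard deviation near sqrt(steps), about 44.7 for 2,000 steps.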


In [None]:
@njit(parallel=True, fastmath=True)
def random_walks(steps, walkers):
    positions = np.zeros(walkers, dtype=np.float64)
    for i in prange(walkers):
        pos = 0.0
        for _ in range(steps):
            pos += 1 if np.random.rand() > 0.5 else -1
        positions[i] = pos
    return positions

trails = random_walks(2_000, 50_000)
print('Mean displacement:', trails.mean())
print('Std displacement:', trails.std())


### 🧮 Reusable solver class example
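Wrap the physics update in a `jitclass` so the same integrator can be reused across larger simulations. Each call into a jitclass method from Python still crosses the interpreter boundary, so for long runs keep the time-stepping loop inside compiled code as well.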


In [None]:
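spec_solver = [
    ('dt', types.float64),
    ('drag', types.float64),
    ('gravity', types.float64),
]

@jitclass(spec_solver)
class ProjectileSolver:
    def __init__(self, dt, drag, gravity):
        self.dt = dt
        self.drag = drag
        self.gravity = gravity

    def step(self, position, velocity):
        # Linear drag on both axes plus constant gravity on y.
        ax = -self.drag * velocity[0]
        ay = -self.gravity - self.drag * velocity[1]
        # Semi-implicit Euler: update velocity, then position (arrays mutated in place).
        velocity[0] += ax * self.dt
        velocity[1] += ay * self.dt
        position[0] += velocity[0] * self.dt
        position[1] += velocity[1] * self.dt
        if position[1] < 0.0:
            # Clamp to the ground once the projectile drops below y = 0.
            position[1] = 0.0
            velocity[1] = 0.0
        return position, velocity

    def simulate(self, position, velocity, max_steps):
        # Run the whole time loop in compiled code and stop at touchdown.
        for n in range(max_steps):
            self.step(position, velocity)
            if position[1] <= 0.0:
                return n + 1
        return max_steps

solver = ProjectileSolver(0.01, 0.02, 9.81)
pos = np.array([0.0, 0.0])
vel = np.array([30.0, 30.0])
# The flight lasts roughly 2 * vy / g, about 6 s, so step until touchdown
# instead of for a fixed count.
n_steps = solver.simulate(pos, vel, 10_000)
print('Landed after', n_steps, 'steps (t =', n_steps * solver.dt, 's)')
print('Approximate landing position:', pos)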
