
numba-mpi


Overview

numba-mpi provides Python wrappers to the C MPI API callable from within Numba JIT-compiled code (@njit mode).

Support is provided for a subset of MPI routines covering: size/rank, send/recv, allreduce, bcast, scatter/gather & allgather, barrier, wtime, basic asynchronous communication with isend/irecv (only for contiguous arrays), and request handling with wait/waitall/waitany and test/testall/testany.
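To make the semantics of the collective operations listed above concrete, here is a serial, pure-Python model in which each simulated rank is just an index into a list. The function names mirror the MPI operations, but their signatures are hypothetical and are not numba-mpi's API. In real MPI, each process calls the collective with its own buffers and no single process sees the whole list.

```python
# A serial, single-process model of what the MPI collectives compute.
# Each "rank" is an index into a Python list; no MPI is involved, and these
# hypothetical functions do NOT mirror numba-mpi's signatures.
from functools import reduce
import operator


def bcast(values, root):
    # Every rank ends up with the root's value.
    return [values[root] for _ in values]


def scatter(data, size):
    # The root's data is split into `size` equal contiguous chunks;
    # rank r receives chunk r.
    n = len(data) // size
    assert n * size == len(data), "this sketch requires len(data) divisible by size"
    return [data[r * n:(r + 1) * n] for r in range(size)]


def gather(values, root):
    # Only the root receives the full list; other ranks receive nothing (None).
    return [list(values) if r == root else None for r in range(len(values))]


def allgather(values):
    # Like gather, but every rank receives the full list.
    return [list(values) for _ in values]


def allreduce(values, op=operator.add):
    # Every rank receives the reduction (by default: the sum) over all ranks.
    total = reduce(op, values)
    return [total for _ in values]


if __name__ == "__main__":
    per_rank = [10, 20, 30, 40]  # one value per simulated rank (size = 4)
    print(bcast(per_rank, root=0))       # [10, 10, 10, 10]
    print(scatter(list(range(8)), 4))    # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(gather(per_rank, root=0))      # [[10, 20, 30, 40], None, None, None]
    print(allreduce(per_rank))           # [100, 100, 100, 100]
```

The key difference between the gather and allgather (or reduce and allreduce) flavours is only who receives the result: the root alone, or every rank.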

The API uses NumPy arrays and supports both numeric and character datatypes (e.g., in bcast). Auto-generated docstring-based API docs are published on the web: https://numba-mpi.github.io/numba-mpi

Packages can be obtained from PyPI, Conda Forge, Arch Linux or by invoking pip install git+https://github.com/numba-mpi/numba-mpi.git.

numba-mpi is a pure-Python package. The codebase includes a test suite that is run via GitHub Actions workflows (thanks to mpi4py's setup-mpi!) for automated testing on Linux (MPICH, OpenMPI & Intel MPI), macOS (MPICH & OpenMPI) and Windows (MS MPI).

Features that are not implemented yet include (help welcome!):

  • support for non-default communicators
  • support for MPI_IN_PLACE in [all]gather/scatter and allreduce
  • support for MPI_Type_create_struct (NumPy structured arrays)
  • ...

Hello world send/recv example:

import numba, numba_mpi, numpy

@numba.njit()
def hello():
    src = numpy.array([1., 2., 3., 4., 5.])
    dst_tst = numpy.empty_like(src)

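    # rank 0 sends the array to rank 1; run with at least 2 processes,
    # e.g. mpiexec -n 2 python hello.py
    if numba_mpi.rank() == 0:
        numba_mpi.send(src, dest=1, tag=11)
    elif numba_mpi.rank() == 1:
        numba_mpi.recv(dst_tst, source=0, tag=11)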

hello()
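The hello example relies on MPI's message matching: the recv on rank 1 completes only with a message whose source (0) and tag (11) match its arguments. The toy, single-process Python model below sketches that matching rule with two threads standing in for two ranks. The `ToyComm` class and its signatures are hypothetical and are not how MPI or numba-mpi is implemented; in particular, its send never blocks, whereas a real MPI send may wait until the matching receive is posted.

```python
# Toy model of blocking point-to-point messaging (NOT MPI): two threads play
# rank 0 and rank 1, and a message is delivered only to a recv whose
# (source, dest, tag) envelope matches.
import threading
from collections import defaultdict, deque


class ToyComm:
    """Hypothetical in-process stand-in for an MPI communicator."""

    def __init__(self):
        self._cond = threading.Condition()
        # Pending messages keyed by (source, dest, tag), kept in send order.
        self._queues = defaultdict(deque)

    def send(self, data, source, dest, tag):
        with self._cond:
            # Copy the data so later changes by the sender do not affect it.
            self._queues[(source, dest, tag)].append(list(data))
            self._cond.notify_all()

    def recv(self, buf, source, dest, tag):
        with self._cond:
            key = (source, dest, tag)
            # Block until a message with a matching envelope has arrived.
            self._cond.wait_for(lambda: self._queues[key])
            buf[:] = self._queues[key].popleft()


def hello(comm, rank, src, dst):
    # Same control flow as the numba-mpi example, with the rank passed in.
    if rank == 0:
        comm.send(src, source=0, dest=1, tag=11)
    elif rank == 1:
        comm.recv(dst, source=0, dest=1, tag=11)


if __name__ == "__main__":
    comm = ToyComm()
    src = [1.0, 2.0, 3.0, 4.0, 5.0]
    dst = [0.0] * len(src)
    threads = [threading.Thread(target=hello, args=(comm, r, src, dst)) for r in (0, 1)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(dst)  # [1.0, 2.0, 3.0, 4.0, 5.0]
```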

Example comparing numba-mpi vs. mpi4py performance:

The example below compares the performance of Numba + mpi4py with that of Numba + numba-mpi. The sample code estimates $\pi$ by numerically integrating $4/(1+x^2)$ over $[0, 1]$ (since $\int_0^1 \frac{4}{1+x^2}\,dx = \pi$) with the midpoint rule, using n_intervals intervals that are distributed round-robin across the MPI processes; the partial results are then summed using allreduce. The computation is carried out in a JIT-compiled function and is repeated N_TIMES. For mpi4py, the repetition loop and the MPI reduction are done outside the JIT-compiled code, whereas for numba-mpi they are done inside it. Timing is repeated N_REPEAT times and the minimum time is reported. The generated plot shown below depicts the speedup obtained by replacing mpi4py with numba_mpi as a function of n_intervals: the more often communication is needed (smaller n_intervals), the larger the expected speedup.

import timeit, mpi4py.MPI, numba, numpy as np, numba_mpi

N_TIMES = 10000
N_REPEAT = 10
RTOL = 1e-3

@numba.njit
def get_pi_part(out, n_intervals, rank, size):
    # midpoint rule: interval i (1..n_intervals) is handled by rank (i - 1) % size
    h = 1 / n_intervals
    partial_sum = 0.0
    for i in range(rank + 1, n_intervals + 1, size):
        x = h * (i - 0.5)
        partial_sum += 4 / (1 + x**2)
    out[0] = h * partial_sum

@numba.njit
def pi_numba_mpi(n_intervals):
    # the N_TIMES loop and the allreduce both run inside JIT-compiled code
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        get_pi_part(part, n_intervals, numba_mpi.rank(), numba_mpi.size())
        numba_mpi.allreduce(part, pi, numba_mpi.Operator.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

def pi_mpi4py(n_intervals):
    # the N_TIMES loop and the Allreduce run in the Python interpreter;
    # only get_pi_part is JIT-compiled
    pi = np.array([0.])
    part = np.empty_like(pi)
    for _ in range(N_TIMES):
        get_pi_part(part, n_intervals, mpi4py.MPI.COMM_WORLD.rank, mpi4py.MPI.COMM_WORLD.size)
        mpi4py.MPI.COMM_WORLD.Allreduce(part, (pi, mpi4py.MPI.DOUBLE), op=mpi4py.MPI.SUM)
        assert abs(pi[0] - np.pi) / np.pi < RTOL

plot_x = [1000 * k for k in range(1, 11)]
plot_y = {'numba_mpi': [], 'mpi4py': []}
for n_intervals in plot_x:
    for impl in plot_y:
        plot_y[impl].append(min(timeit.repeat(
            f"pi_{impl}({n_intervals})",
            globals=locals(),
            number=1,
            repeat=N_REPEAT
        )))

if numba_mpi.rank() == 0:
    from matplotlib import pyplot
    pyplot.figure(figsize=(8.3, 3.5), tight_layout=True)
    pyplot.plot(plot_x, np.array(plot_y['mpi4py'])/np.array(plot_y['numba_mpi']), marker='o')
    pyplot.xlabel('n_intervals (workload in between communication)')
    pyplot.ylabel('wall time ratio (mpi4py / numba_mpi)')
    pyplot.title(f'mpiexec -np {numba_mpi.size()}')
    pyplot.grid()
    pyplot.savefig('readme_plot.png')
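The work decomposition used in the benchmark can be checked without MPI. The sketch below is a serial, pure-Python stand-in for the JIT-compiled code: a return-value variant of get_pi_part applies the midpoint rule to $4/(1+x^2)$ on $[0, 1]$, dealing intervals 1 through n_intervals out round-robin to simulated ranks, and the hypothetical helper `pi_all_ranks` sums the per-rank parts the way an allreduce with the SUM operator would.

```python
# Serial sketch (no MPI) of the benchmark's decomposition: midpoint rule for
# pi = integral of 4/(1+x^2) over [0, 1], interval i handled by rank (i-1) % size.
import math


def get_pi_part(n_intervals, rank, size):
    h = 1 / n_intervals
    partial_sum = 0.0
    # Intervals 1..n_intervals, dealt out round-robin starting at rank + 1.
    for i in range(rank + 1, n_intervals + 1, size):
        x = h * (i - 0.5)  # midpoint of interval i
        partial_sum += 4 / (1 + x**2)
    return h * partial_sum


def pi_all_ranks(n_intervals, size):
    # Hypothetical helper: the value every rank would hold after a SUM allreduce.
    return sum(get_pi_part(n_intervals, rank, size) for rank in range(size))


if __name__ == "__main__":
    for size in (1, 2, 4):
        estimate = pi_all_ranks(1000, size)
        print(size, estimate, abs(estimate - math.pi) / math.pi)
```

The estimate should be essentially independent of the number of simulated ranks (up to floating-point summation order), which is why a single tolerance check against $\pi$ works for any process count.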

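[Figure: wall time ratio (mpi4py / numba_mpi) vs. n_intervals (workload in between communication), saved by the script above as readme_plot.png]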
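The expected trend can be rationalized with a simple per-iteration cost model, offered here as an illustrative assumption rather than a fit to measured data. Let $n$ be n_intervals, $p$ the number of MPI processes, $a$ the cost of one integrand evaluation, $c$ the cost of the allreduce itself, and $o$ the extra per-iteration overhead of the mpi4py variant (returning from JIT-compiled code to the Python loop and calling Allreduce from Python). Then

$$
\frac{T_{\text{mpi4py}}}{T_{\text{numba-mpi}}} \approx \frac{a\,n/p + c + o}{a\,n/p + c} = 1 + \frac{o}{a\,n/p + c},
$$

so under this model the ratio exceeds 1 and decays toward 1 as $n$ grows, i.e., as more work is done between consecutive communications.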

MPI resources on the web:

Acknowledgements:

Development of numba-mpi has been supported by the Polish National Science Centre (grant no. 2020/39/D/ST10/01220).