In [None]:
%load_ext wurlitzer

In [None]:
import python_example

In [None]:
import numpy as np
from numba import jit

In [None]:
def _fma_no_jit(
    out: 'np.ndarray[np.float64]',
    weights: 'np.ndarray[np.float64]',
    *arrays: 'np.ndarray[np.float64]',
):
    """Simple fused multiply–add, compiled to avoid Python memory implications.

    :param out: must be zero array in the same shape of each in `arrays`

    If not compiled, a lot of Python objects will be created,
    and as the Python garbage collector is inefficient,
    it would have larger memory footprints.
    """
    for weight, array in zip(weights, arrays):
        out += weight * array

In [None]:
@jit(nopython=True, nogil=True, cache=False)
def _fma(
    out: 'np.ndarray[np.float64]',
    weights: 'np.ndarray[np.float64]',
    *arrays: 'np.ndarray[np.float64]',
):
    """Accumulate a weighted sum of ``arrays`` into ``out``, compiled by Numba.

    Performs ``out += weights[i] * arrays[i]`` for every ``i``, modifying
    ``out`` in place. The function is compiled in nopython mode with the GIL
    released, so the loop does not create intermediate Python objects.

    :param out: accumulator, updated in place; pass a zero array to obtain
        the plain weighted sum
    :param weights: 1-D array holding one weight per entry of ``arrays``
    :param arrays: the arrays to scale and add into ``out``
    :raises ValueError: if the number of weights differs from the number of
        arrays
    :return: ``None``; the result is left in ``out``
    """
    # zip() stops at the shorter input, so a mismatch would otherwise yield a
    # silently truncated sum. The message is a constant string so the raise
    # compiles in nopython mode.
    if len(weights) != len(arrays):
        raise ValueError("weights and arrays must have the same length")
    for weight, array in zip(weights, arrays):
        out += weight * array

In [None]:
@jit(nopython=True, nogil=True, parallel=True, cache=False)
def _fma_parallel(
    out: 'np.ndarray[np.float64]',
    weights: 'np.ndarray[np.float64]',
    *arrays: 'np.ndarray[np.float64]',
):
    """Accumulate a weighted sum of ``arrays`` into ``out``, compiled by Numba
    with ``parallel=True``.

    Performs ``out += weights[i] * arrays[i]`` for every ``i``, modifying
    ``out`` in place. Identical to the serial compiled variant except that
    Numba's automatic parallelization is requested through the decorator.

    :param out: accumulator, updated in place; pass a zero array to obtain
        the plain weighted sum
    :param weights: 1-D array holding one weight per entry of ``arrays``
    :param arrays: the arrays to scale and add into ``out``
    :raises ValueError: if the number of weights differs from the number of
        arrays
    :return: ``None``; the result is left in ``out``
    """
    # zip() stops at the shorter input, so a mismatch would otherwise yield a
    # silently truncated sum. The message is a constant string so the raise
    # compiles in nopython mode.
    if len(weights) != len(arrays):
        raise ValueError("weights and arrays must have the same length")
    for weight, array in zip(weights, arrays):
        out += weight * array

# Test

In [None]:
# Fixture for the single-array checks: three independent standard-normal
# vectors of the same length, drawn in the order array, weights, out_original.
shape = (10,)
array, weights, out_original = (
    np.random.standard_normal(shape) for _ in range(3)
)

In [None]:
# Element-wise weights: the extension is expected to leave
# out_original + weights * array in `out`.
out = out_original.copy()
python_example.fma(out, weights, array)
# Compare with a tight tolerance instead of bit-for-bit: a compiled extension
# may contract multiply+add into one hardware FMA (single rounding), which can
# differ from NumPy's two-step result in the last bit.
np.testing.assert_allclose(
    out, out_original + weights * array, rtol=1e-12, atol=1e-12
)

In [None]:
# Scalar weight: the extension is expected to leave
# out_original + weight * array in `out`.
out = out_original.copy()
weight = weights[0]
python_example.fma_scalar_weight(out, weight, array)
# Compare with a tight tolerance instead of bit-for-bit: a compiled extension
# may contract multiply+add into one hardware FMA (single rounding), which can
# differ from NumPy's two-step result in the last bit.
np.testing.assert_allclose(
    out, out_original + weight * array, rtol=1e-12, atol=1e-12
)

In [None]:
# Cross-check the scalar-weight extension call against the Numba kernel fed a
# one-element weight vector; both start from the same accumulator.
out = out_original.copy()
weight = weights[0]
python_example.fma_scalar_weight(out, weight, array)
out2 = out_original.copy()
_fma(out2, np.array([weight]), array)
# Two independently compiled implementations need not round identically
# (e.g. hardware FMA contraction), so allow last-bit differences.
np.testing.assert_allclose(out, out2, rtol=1e-12, atol=1e-12)

In [None]:
# Fixture for the multi-array checks: n standard-normal vectors, one weight
# per vector, and a random starting accumulator.
n, shape = 3, (10,)
arrays = [np.random.standard_normal(shape) for _ in range(n)]
weights = np.random.standard_normal(n)
out_original = np.random.standard_normal(shape)

In [None]:
# Run every implementation on its own copy of the same starting accumulator.
out = out_original.copy()
python_example.fma_vector_weights(out, weights, *arrays)
out2 = out_original.copy()
_fma(out2, weights, *arrays)
out3 = out_original.copy()
_fma_parallel(out3, weights, *arrays)
out4 = out_original.copy()
python_example.fma_vector_weights_arrays(out4, weights, *arrays)
out5 = out_original.copy()
_fma_no_jit(out5, weights, *arrays)
# Independently compiled implementations need not round identically (e.g.
# hardware FMA contraction in the extension), so compare against the first
# result with a tight tolerance instead of requiring bit-for-bit equality.
np.testing.assert_allclose(out, out2, rtol=1e-12, atol=1e-12)
np.testing.assert_allclose(out, out3, rtol=1e-12, atol=1e-12)
np.testing.assert_allclose(out, out4, rtol=1e-12, atol=1e-12)
np.testing.assert_allclose(out, out5, rtol=1e-12, atol=1e-12)

# Benchmark

In [None]:
%%time
n = 68
shape = (10000000,)
arrays = [np.random.randn(*shape) for _ in range(n)]
weights = np.random.randn(n)

In [None]:
# Pass the shape tuple itself: np.zeros(*shape) only works for 1-D shapes,
# because a second unpacked element would be taken as the dtype argument.
out = np.zeros(shape)

In [None]:
# Time the extension's vector-weights kernel on the benchmark inputs. Every
# repetition reuses the same `out`, so values keep accumulating across runs.
%timeit python_example.fma_vector_weights(out, weights, *arrays)

In [None]:
# Fresh accumulator for the next timing run. Pass the shape tuple itself:
# np.zeros(*shape) only works for 1-D shapes, because a second unpacked
# element would be taken as the dtype argument.
out = np.zeros(shape)

In [None]:
# Time the extension's `fma_vector_weights_arrays` variant on the same inputs.
# Every repetition reuses the same `out`, so values keep accumulating.
%timeit python_example.fma_vector_weights_arrays(out, weights, *arrays)