Skip to content

Commit

Permalink
Add support for np.correlate and np.convolve
Browse files Browse the repository at this point in the history
As per the title. The lack of FFT support prevents the implementation of
an efficient convolution algorithm.

Closes #2500
  • Loading branch information
stuartarchibald committed Mar 1, 2018
1 parent cd68b54 commit 6ac35c0
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ The following top-level functions are supported:
* :func:`numpy.bincount` (only the 2 first arguments)
* :func:`numpy.column_stack`
* :func:`numpy.concatenate`
* :func:`numpy.convolve` (only the 2 first arguments)
* :func:`numpy.copy` (only the first argument)
* :func:`numpy.correlate` (only the 2 first arguments)
* :func:`numpy.diag`
* :func:`numpy.digitize`
* :func:`numpy.dstack`
Expand Down
165 changes: 165 additions & 0 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,20 @@
impl_ret_new_ref, impl_ret_untracked)
from numba.typing import signature
from .arrayobj import make_array, load_item, store_item, _empty_nd_impl
from .linalg import ensure_blas

from numba.extending import intrinsic
from numba.errors import RequireConstValue, TypingError

def _check_blas():
    """Report whether a BLAS backend can be loaded.

    Returns True when ``ensure_blas()`` completes without raising
    ``ImportError`` and False otherwise, so that callers know whether
    BLAS-backed routines (e.g. ``np.dot``) will work.
    """
    try:
        ensure_blas()
    except ImportError:
        blas_available = False
    else:
        blas_available = True
    return blas_available

# Evaluated once at import time; _get_inner_prod consults this flag to decide
# whether the np.dot based inner product may be used.
_HAVE_BLAS = _check_blas()

@intrinsic
def _create_tuple_result_shape(tyctx, shape_list, shape_tuple):
Expand Down Expand Up @@ -1450,3 +1460,158 @@ def impl(arg):

# NOTE(review): generate_xinfo and its finfo/iinfo/_*_supported arguments are
# defined outside this excerpt; presumably these calls set up np.finfo and
# np.iinfo support -- confirm against the helper's definition.
generate_xinfo(np.finfo, finfo, _finfo_supported)
generate_xinfo(np.iinfo, iinfo, _iinfo_supported)

def _get_inner_prod(dta, dtb):
    """Select the inner-product kernel for a pair of element types.

    ``np.dot`` is used only when a BLAS is available and both ``dta`` and
    ``dtb`` are members of ``types.real_domain | types.complex_domain``;
    in every other case a plain multiply-accumulate loop is returned.
    """

    @register_jitable
    def _innerprod(a, b):
        # Generic fallback: explicit accumulation over len(a) elements.
        total = 0
        for k in range(len(a)):
            total = total + a[k] * b[k]
        return total

    float_like = types.real_domain | types.complex_domain
    use_blas = _HAVE_BLAS and dta in float_like and dtb in float_like
    if not use_blas:
        # No BLAS, or at least one non-float operand: use the local loop.
        return _innerprod

    # Both operands are cast to their promoted dtype before np.dot is called.
    dt = np.promote_types(as_dtype(dta), as_dtype(dtb))

    @register_jitable
    def _dot_wrap(a, b):
        return np.dot(a.astype(dt), b.astype(dt))

    return _dot_wrap

def _assert_1d(a, func_name):
    """Reject array types with more than one dimension.

    Arguments that are not ``types.Array`` instances are not checked. An
    array type whose ``ndim`` exceeds 1 raises ``TypingError`` with a
    message that names ``func_name``.
    """
    if isinstance(a, types.Array) and a.ndim > 1:
        raise TypingError("%s() only supported on 1D arrays " % func_name)

def _np_correlate_core(ap1, ap2, mode, direction):
pass

@overload(_np_correlate_core)
def _np_correlate_core_impl(ap1, ap2, mode, direction):
    """Build the kernel shared by the correlate and convolve overloads.

    The output dtype is the NumPy promotion of the two input dtypes and the
    inner-product routine is picked once, here, by ``_get_inner_prod``.

    The returned ``impl(ap1, ap2, mode, direction)``:

    * ``mode == 0`` ('valid', the correlate default) produces
      ``len(ap1) - len(ap2) + 1`` values;
    * ``mode == 2`` ('full', the convolve default) produces
      ``len(ap1) + len(ap2) - 1`` values;
    * ``direction == 1`` writes results in order 0->N, ``direction == -1``
      writes them N->0;
    * any other ``mode`` / ``direction`` raises ``ValueError``.
    """
    a_dt = as_dtype(ap1.dtype)
    b_dt = as_dtype(ap2.dtype)
    dt = np.promote_types(a_dt, b_dt)
    innerprod = _get_inner_prod(ap1.dtype, ap2.dtype)

    def impl(ap1, ap2, mode, direction):
        # Implementation loosely based on `_pyarray_correlate` from
        # https://github.com/numpy/numpy/blob/3bce2be74f228684ca2895ad02b63953f37e2a9d/numpy/core/src/multiarray/multiarraymodule.c#L1191
        # The callers in this module always pass the longer array as `ap1`.

        n1 = len(ap1)
        n2 = len(ap2)

        # n_left / n_right count the partially-overlapping positions that
        # are evaluated before / after the fully-overlapping ones.
        if mode == 0:  # 'valid', correlate default
            n_left = 0
            n_right = 0
            length = n1 - n2 + 1
        elif mode == 2:  # 'full', convolve default
            n_left = n2 - 1
            n_right = n2 - 1
            length = n1 + n2 - 1
        else:
            raise ValueError("Invalid mode")

        ret = np.zeros(length, dt)

        if direction == 1:
            idx = 0
            inc = 1
        elif direction == -1:
            idx = length - 1
            inc = -1
        else:
            raise ValueError("Invalid direction")

        # Left edge: the window grows from 1 to n2 - 1 elements. Slice with
        # the loop counter, not the output index, so that the overlap is
        # also correct when results are written back to front.
        for i in range(n_left):
            ret[idx] = innerprod(ap1[:i + 1], ap2[-(i + 1):])
            idx = idx + inc

        # Fully-overlapping positions.
        for i in range(n1 - n2 + 1):
            ret[idx] = innerprod(ap1[i : i + n2], ap2)
            idx = idx + inc

        # Right edge: the window shrinks from n2 - 1 elements down to 1.
        for i in range(n_right, 0, -1):
            ret[idx] = innerprod(ap1[-i:], ap2[:i])
            idx = idx + inc
        return ret

    return impl

@overload(np.correlate)
def _np_correlate(a, v):
    """Overload of ``np.correlate(a, v)`` (two-argument form only).

    Both arguments are checked by ``_assert_1d``, which raises
    ``TypingError`` for arrays with more than one dimension. The work is
    delegated to ``_np_correlate_core`` with mode 0; when ``a`` is shorter
    than ``v`` the operands are swapped and the result is written in
    reverse order.
    """
    _assert_1d(a, 'np.correlate')
    _assert_1d(v, 'np.correlate')

    @register_jitable
    def op_conj(x):
        return np.conj(x)

    @register_jitable
    def op_nop(x):
        return x

    # Only the second operand is ever conjugated, and only when it is
    # complex. Conjugating a real array is an identity that merely costs a
    # copy, so `a` is always passed through untouched.
    if v.dtype in types.complex_domain:
        v_op = op_conj
    else:
        v_op = op_nop

    def impl(a, v):
        if len(a) < len(v):
            return _np_correlate_core(v_op(v), a, 0, -1)
        else:
            return _np_correlate_core(a, v_op(v), 0, 1)

    return impl

@overload(np.convolve)
def np_convolve(a, v):
    """Overload of ``np.convolve(a, v)`` (two-argument form only).

    Both arguments are checked by ``_assert_1d``. The returned
    implementation raises ``ValueError`` for an empty ``a`` or ``v`` and
    otherwise calls ``_np_correlate_core`` in mode 2 with the shorter
    operand reversed.
    """
    _assert_1d(a, 'np.convolve')
    _assert_1d(v, 'np.convolve')

    def impl(a, v):
        if len(a) == 0:
            raise ValueError("'a' cannot be empty")
        if len(v) == 0:
            raise ValueError("'v' cannot be empty")

        # The longer input goes first; the other one is passed reversed.
        if len(a) >= len(v):
            return _np_correlate_core(a, v[::-1], 2, 1)
        return _np_correlate_core(v, a[::-1], 2, 1)

    return impl
63 changes: 62 additions & 1 deletion numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numba.compiler import compile_isolated, Flags, utils
from numba import jit, typeof, types
from numba.numpy_support import version as np_version
from numba.errors import UntypedAttributeError
from numba.errors import UntypedAttributeError, TypingError
from .support import TestCase, CompilationCache

no_pyobj_flags = Flags()
Expand Down Expand Up @@ -60,6 +60,12 @@ def finfo(*args):
def finfo_machar(*args):
return np.finfo(*args).machar

def correlate(a, v):
    """Call two-argument ``np.correlate``; used as the function under test."""
    result = np.correlate(a, v)
    return result

def convolve(a, v):
    """Call two-argument ``np.convolve``; used as the function under test."""
    result = np.convolve(a, v)
    return result


class TestNPFunctions(TestCase):
"""
Expand Down Expand Up @@ -424,6 +430,61 @@ def check_values(values):

check_values(values)

def _test_correlate_convolve(self, pyfunc):
    """Check the nopython-jitted ``pyfunc`` against the plain Python call.

    Every combination of the listed dtypes and lengths is compared with
    ``assertPreciseEqual``; 2D inputs must raise ``TypingError``.
    """
    cfunc = jit(nopython=True)(pyfunc)

    # Only 1d arrays are accepted: vary the lengths and the dtypes.
    sizes = (1, 2, 3, 7)
    dtypes = [np.int8, np.int32, np.int64, np.float32, np.float64,
              np.complex64, np.complex128]

    def make_input(size, dtype):
        # Complex inputs get a non-zero imaginary part.
        arr = np.arange(size, dtype=dtype)
        if np.issubdtype(dtype, np.complexfloating):
            arr = (arr + 1j * arr).astype(dtype)
        return arr

    for dt1, dt2, n, m in itertools.product(dtypes, dtypes, sizes, sizes):
        a = make_input(n, dt1)
        v = make_input(m, dt2)
        self.assertPreciseEqual(pyfunc(a, v), cfunc(a, v))

    two_d = np.arange(12).reshape(4, 3)
    one_d = np.arange(12)
    for x, y in [(two_d, one_d), (one_d, two_d)]:
        with self.assertRaises(TypingError) as raises:
            cfunc(x, y)
        self.assertIn('only supported on 1D arrays', str(raises.exception))

def test_correlate(self):
    """Run the shared checks, then compare results for zero-length inputs."""
    self._test_correlate_convolve(correlate)

    # correlate accepts arrays of length zero
    empty = np.ones(shape=(0,))
    five = np.arange(5)
    cfunc = jit(nopython=True)(correlate)
    for x, y in [(empty, five), (five, empty), (empty, empty)]:
        self.assertPreciseEqual(correlate(x, y), cfunc(x, y))

def test_convolve(self):
    """Run the shared checks, then verify that empty inputs raise."""
    self._test_correlate_convolve(convolve)

    # convolve raises if either array has length zero
    empty = np.ones(shape=(0,))
    five = np.arange(5)
    cfunc = jit(nopython=True)(convolve)
    for x, y in [(empty, five), (five, empty)]:
        with self.assertRaises(ValueError) as raises:
            cfunc(x, y)
        # The message names whichever argument was empty.
        if len(x) == 0:
            expected_msg = "'a' cannot be empty"
        else:
            expected_msg = "'v' cannot be empty"
        self.assertIn(expected_msg, str(raises.exception))


class TestNPMachineParameters(TestCase):
# tests np.finfo, np.iinfo, np.MachAr

Expand Down

0 comments on commit 6ac35c0

Please sign in to comment.