Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for np.correlate and np.convolve #2777

Merged
merged 2 commits into from
Apr 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ The following top-level functions are supported:
* :func:`numpy.bincount` (only the 2 first arguments)
* :func:`numpy.column_stack`
* :func:`numpy.concatenate`
* :func:`numpy.convolve` (only the 2 first arguments)
* :func:`numpy.copy` (only the first argument)
* :func:`numpy.correlate` (only the 2 first arguments)
* :func:`numpy.diag`
* :func:`numpy.digitize`
* :func:`numpy.dstack`
Expand Down
183 changes: 183 additions & 0 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import math
from collections import namedtuple
from enum import IntEnum

import numpy as np

Expand All @@ -21,10 +22,20 @@
impl_ret_new_ref, impl_ret_untracked)
from numba.typing import signature
from .arrayobj import make_array, load_item, store_item, _empty_nd_impl
from .linalg import ensure_blas

from numba.extending import intrinsic
from numba.errors import RequireConstValue, TypingError

def _check_blas():
# Checks if a BLAS is available so e.g. dot will work
try:
ensure_blas()
except ImportError:
return False
return True

_HAVE_BLAS = _check_blas()

@intrinsic
def _create_tuple_result_shape(tyctx, shape_list, shape_tuple):
Expand Down Expand Up @@ -1450,3 +1461,175 @@ def impl(arg):

generate_xinfo(np.finfo, finfo, _finfo_supported)
generate_xinfo(np.iinfo, iinfo, _iinfo_supported)

def _get_inner_prod(dta, dtb):
# gets an inner product implementation, if both types are float then
# BLAS is used else a local function

@register_jitable
def _innerprod(a, b):
acc = 0
for i in range(len(a)):
acc = acc + a[i] * b[i]
return acc

# no BLAS... use local function regardless
if not _HAVE_BLAS:
return _innerprod

flty = types.real_domain | types.complex_domain
floats = dta in flty and dtb in flty
if not floats:
return _innerprod
else:
a_dt = as_dtype(dta)
b_dt = as_dtype(dtb)
dt = np.promote_types(a_dt, b_dt)

@register_jitable
def _dot_wrap(a, b):
return np.dot(a.astype(dt), b.astype(dt))
return _dot_wrap

def _assert_1d(a, func_name):
if isinstance(a, types.Array):
if not a.ndim <= 1:
raise TypingError("%s() only supported on 1D arrays " % func_name)

def _np_correlate_core(ap1, ap2, mode, direction):
pass


class _corr_conv_Mode(IntEnum):
"""
Enumerated modes for correlate/convolve as per:
https://github.com/numpy/numpy/blob/ac6b1a902b99e340cf7eeeeb7392c91e38db9dd8/numpy/core/numeric.py#L862-L870
"""
VALID = 0
SAME = 1
FULL = 2


@overload(_np_correlate_core)
def _np_correlate_core_impl(ap1, ap2, mode, direction):
a_dt = as_dtype(ap1.dtype)
b_dt = as_dtype(ap2.dtype)
dt = np.promote_types(a_dt, b_dt)
innerprod = _get_inner_prod(ap1.dtype, ap2.dtype)

Mode = _corr_conv_Mode

def impl(ap1, ap2, mode, direction):
# Implementation loosely based on `_pyarray_correlate` from
# https://github.com/numpy/numpy/blob/3bce2be74f228684ca2895ad02b63953f37e2a9d/numpy/core/src/multiarray/multiarraymodule.c#L1191
# For "Mode":
# Convolve uses 'full' by default, this is denoted by the number 2
# Correlate uses 'valid' by default, this is denoted by the number 0
# For "direction", +1 to write the return values out in order 0->N
# -1 to write them out N->0.

if not (mode == Mode.VALID or mode == Mode.FULL):
raise ValueError("Invalid mode")

n1 = len(ap1)
n2 = len(ap2)
length = n1
n = n2
if mode == Mode.VALID: # mode == valid == 0, correlate default
length = length - n + 1
n_left = 0
n_right = 0
elif mode == Mode.FULL: # mode == full == 2, convolve default
n_right = n - 1
n_left = n - 1
length = length + n - 1
else:
raise ValueError("Invalid mode")

ret = np.zeros(length, dt)
n = n - n_left

if direction == 1:
idx = 0
inc = 1
elif direction == -1:
idx = length - 1
inc = -1
else:
raise ValueError("Invalid direction")

for i in range(n_left):
ret[idx] = innerprod(ap1[:idx + 1], ap2[-(idx + 1):])
idx = idx + inc

for i in range(n1 - n2 + 1):
ret[idx] = innerprod(ap1[i : i + n2], ap2)
idx = idx + inc

for i in range(n_right, 0, -1):
ret[idx] = innerprod(ap1[-i:], ap2[:i])
idx = idx + inc
return ret

return impl

@overload(np.correlate)
def _np_correlate(a, v):
_assert_1d(a, 'np.correlate')
_assert_1d(v, 'np.correlate')

@register_jitable
def op_conj(x):
return np.conj(x)

@register_jitable
def op_nop(x):
return x

Mode = _corr_conv_Mode

if a.dtype in types.complex_domain:
if v.dtype in types.complex_domain:
a_op = op_nop
b_op = op_conj
else:
a_op = op_nop
b_op = op_nop
else:
if v.dtype in types.complex_domain:
a_op = op_nop
b_op = op_conj
else:
a_op = op_conj
b_op = op_nop

def impl(a, v):
if len(a) < len(v):
return _np_correlate_core(b_op(v), a_op(a), Mode.VALID, -1)
else:
return _np_correlate_core(a_op(a), b_op(v), Mode.VALID, 1)

return impl

@overload(np.convolve)
def np_convolve(a, v):
_assert_1d(a, 'np.convolve')
_assert_1d(v, 'np.convolve')

Mode = _corr_conv_Mode

def impl(a, v):
la = len(a)
lv = len(v)

if la == 0:
raise ValueError("'a' cannot be empty")
if lv == 0:
raise ValueError("'v' cannot be empty")

if la < lv:
return _np_correlate_core(v, a[::-1], Mode.FULL, 1)
else:
return _np_correlate_core(a, v[::-1], Mode.FULL, 1)

return impl
63 changes: 62 additions & 1 deletion numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numba.compiler import compile_isolated, Flags, utils
from numba import jit, typeof, types
from numba.numpy_support import version as np_version
from numba.errors import UntypedAttributeError
from numba.errors import UntypedAttributeError, TypingError
from .support import TestCase, CompilationCache

no_pyobj_flags = Flags()
Expand Down Expand Up @@ -60,6 +60,12 @@ def finfo(*args):
def finfo_machar(*args):
return np.finfo(*args).machar

def correlate(a, v):
return np.correlate(a, v)

def convolve(a, v):
return np.convolve(a, v)


class TestNPFunctions(TestCase):
"""
Expand Down Expand Up @@ -424,6 +430,61 @@ def check_values(values):

check_values(values)

def _test_correlate_convolve(self, pyfunc):
cfunc = jit(nopython=True)(pyfunc)
# only 1d arrays are accepted, test varying lengths
# and varying dtype
lengths = (1, 2, 3, 7)
dts = [np.int8, np.int32, np.int64, np.float32, np.float64,
np.complex64, np.complex128]

for dt1, dt2, n, m in itertools.product(dts, dts, lengths, lengths):
a = np.arange(n, dtype=dt1)
v = np.arange(m, dtype=dt2)

if np.issubdtype(dt1, np.complexfloating):
a = (a + 1j * a).astype(dt1)
if np.issubdtype(dt2, np.complexfloating):
v = (v + 1j * v).astype(dt2)

expected = pyfunc(a, v)
got = cfunc(a, v)
self.assertPreciseEqual(expected, got)

_a = np.arange(12).reshape(4, 3)
_b = np.arange(12)
for x, y in [(_a, _b), (_b, _a)]:
with self.assertRaises(TypingError) as raises:
cfunc(x, y)
msg = 'only supported on 1D arrays'
self.assertIn(msg, str(raises.exception))

def test_correlate(self):
self._test_correlate_convolve(correlate)
# correlate supports 0 dimension arrays
_a = np.ones(shape=(0,))
_b = np.arange(5)
cfunc = jit(nopython=True)(correlate)
for x, y in [(_a, _b), (_b, _a), (_a, _a)]:
expected = correlate(x, y)
got = cfunc(x, y)
self.assertPreciseEqual(expected, got)

def test_convolve(self):
self._test_correlate_convolve(convolve)
# convolve raises if either array has a 0 dimension
_a = np.ones(shape=(0,))
_b = np.arange(5)
cfunc = jit(nopython=True)(convolve)
for x, y in [(_a, _b), (_b, _a)]:
with self.assertRaises(ValueError) as raises:
cfunc(x, y)
if len(x) == 0:
self.assertIn("'a' cannot be empty", str(raises.exception))
else:
self.assertIn("'v' cannot be empty", str(raises.exception))


class TestNPMachineParameters(TestCase):
# tests np.finfo, np.iinfo, np.MachAr

Expand Down