Skip to content

Commit

Permalink
Merge pull request #4601 from EPronovost/epronovost/triangle-indices
Browse files Browse the repository at this point in the history
Add triangular indices functions
  • Loading branch information
stuartarchibald committed Oct 1, 2019
2 parents 7d0e3a6 + b04ba2d commit 9a49a69
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,11 @@ The following top-level functions are supported:
* :func:`numpy.trapz` (only the 3 first arguments)
* :func:`numpy.tri` (only the 3 first arguments; third argument ``k`` must be an integer)
* :func:`numpy.tril` (second argument ``k`` must be an integer)
* :func:`numpy.tril_indices` (all arguments must be integer)
* :func:`numpy.tril_indices_from` (second argument ``k`` must be an integer)
* :func:`numpy.triu` (second argument ``k`` must be an integer)
* :func:`numpy.triu_indices` (all arguments must be integer)
* :func:`numpy.triu_indices_from` (second argument ``k`` must be an integer)
* :func:`numpy.unique` (only the first argument)
* :func:`numpy.vander`
* :func:`numpy.vstack`
Expand Down
70 changes: 64 additions & 6 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,11 @@ def check_array(a):
raise ValueError('zero-size array to reduction operation not possible')


def _check_is_integer(v, name):
if not isinstance(v, (int, types.Integer)):
raise TypingError('{} must be an integer'.format(name))


def nan_min_max_factory(comparison_op, is_complex_dtype):
if is_complex_dtype:
def impl(a):
Expand Down Expand Up @@ -1422,8 +1427,7 @@ def _tri_impl(N, M, k):
def np_tri(N, M=None, k=0):

# we require k to be integer, unlike numpy
if not isinstance(k, (int, types.Integer)):
raise TypeError('k must be an integer')
_check_is_integer(k, 'k')

def tri_impl(N, M=None, k=0):
if M is None:
Expand Down Expand Up @@ -1460,8 +1464,7 @@ def np_tril_impl_2d(m, k=0):
def my_tril(m, k=0):

# we require k to be integer, unlike numpy
if not isinstance(k, (int, types.Integer)):
raise TypeError('k must be an integer')
_check_is_integer(k, 'k')

def np_tril_impl_1d(m, k=0):
m_2d = _make_square(m)
Expand All @@ -1484,6 +1487,34 @@ def np_tril_impl_multi(m, k=0):
return np_tril_impl_multi


@overload(np.tril_indices)
def np_tril_indices(n, k=0, m=None):

# we require integer arguments, unlike numpy
_check_is_integer(n, 'n')
_check_is_integer(k, 'k')
if not is_nonelike(m):
_check_is_integer(m, 'm')

def np_tril_indices_impl(n, k=0, m=None):
return np.nonzero(np.tri(n, m, k=k))
return np_tril_indices_impl


@overload(np.tril_indices_from)
def np_tril_indices_from(arr, k=0):

# we require k to be integer, unlike numpy
_check_is_integer(k, 'k')

if arr.ndim != 2:
raise TypingError("input array must be 2-d")

def np_tril_indices_from_impl(arr, k=0):
return np.tril_indices(arr.shape[0], k=k, m=arr.shape[1])
return np_tril_indices_from_impl


@register_jitable
def np_triu_impl_2d(m, k=0):
mask = np.tri(m.shape[-2], M=m.shape[-1], k=k - 1).astype(np.uint)
Expand All @@ -1493,8 +1524,7 @@ def np_triu_impl_2d(m, k=0):
@overload(np.triu)
def my_triu(m, k=0):
# we require k to be integer, unlike numpy
if not isinstance(k, (int, types.Integer)):
raise TypeError('k must be an integer')
_check_is_integer(k, 'k')

def np_triu_impl_1d(m, k=0):
m_2d = _make_square(m)
Expand All @@ -1517,6 +1547,34 @@ def np_triu_impl_multi(m, k=0):
return np_triu_impl_multi


@overload(np.triu_indices)
def np_triu_indices(n, k=0, m=None):

# we require integer arguments, unlike numpy
_check_is_integer(n, 'n')
_check_is_integer(k, 'k')
if not is_nonelike(m):
_check_is_integer(m, 'm')

def np_triu_indices_impl(n, k=0, m=None):
return np.nonzero(1 - np.tri(n, m, k=k - 1))
return np_triu_indices_impl


@overload(np.triu_indices_from)
def np_triu_indices_from(arr, k=0):

# we require k to be integer, unlike numpy
_check_is_integer(k, 'k')

if arr.ndim != 2:
raise TypingError("input array must be 2-d")

def np_triu_indices_from_impl(arr, k=0):
return np.triu_indices(arr.shape[0], k=k, m=arr.shape[1])
return np_triu_indices_from_impl


def _prepare_array(arr):
pass

Expand Down
182 changes: 181 additions & 1 deletion numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from numba.numpy_support import version as np_version
from numba.errors import TypingError
from numba.config import IS_WIN32, IS_32BITS
from numba.utils import pysignature
from .support import TestCase, CompilationCache, MemoryLeakMixin
from .matmul_usecase import needs_blas

Expand Down Expand Up @@ -133,6 +134,30 @@ def tril_m_k(m, k=0):
return np.tril(m, k)


def tril_indices_n(n):
return np.tril_indices(n)


def tril_indices_n_k(n, k=0):
return np.tril_indices(n, k)


def tril_indices_n_m(n, m=None):
return np.tril_indices(n, m=m)


def tril_indices_n_k_m(n, k=0, m=None):
return np.tril_indices(n, k, m)


def tril_indices_from_arr(arr):
return np.tril_indices_from(arr)


def tril_indices_from_arr_k(arr, k=0):
return np.tril_indices_from(arr, k)


def triu_m(m):
return np.triu(m)

Expand All @@ -141,6 +166,30 @@ def triu_m_k(m, k=0):
return np.triu(m, k)


def triu_indices_n(n):
return np.triu_indices(n)


def triu_indices_n_k(n, k=0):
return np.triu_indices(n, k)


def triu_indices_n_m(n, m=None):
return np.triu_indices(n, m=m)


def triu_indices_n_k_m(n, k=0, m=None):
return np.triu_indices(n, k, m)


def triu_indices_from_arr(arr):
return np.triu_indices_from(arr)


def triu_indices_from_arr_k(arr, k=0):
return np.triu_indices_from(arr, k)


def vander(x, N=None, increasing=False):
return np.vander(x, N, increasing)

Expand Down Expand Up @@ -1158,7 +1207,106 @@ def _triangular_matrix_exceptions(self, pyfunc):
a = np.ones((5, 6))
with self.assertTypingError() as raises:
cfunc(a, k=1.5)
assert "k must be an integer" in str(raises.exception)
self.assertIn("k must be an integer", str(raises.exception))

def _triangular_indices_tests_base(self, pyfunc, args):
cfunc = jit(nopython=True)(pyfunc)

for x in args:
expected = pyfunc(*x)
got = cfunc(*x)
self.assertEqual(type(expected), type(got))
self.assertEqual(len(expected), len(got))
for e, g in zip(expected, got):
np.testing.assert_array_equal(e, g)

def _triangular_indices_tests_n(self, pyfunc):
self._triangular_indices_tests_base(
pyfunc,
[[n] for n in range(10)]
)

def _triangular_indices_tests_n_k(self, pyfunc):
self._triangular_indices_tests_base(
pyfunc,
[[n, k] for n in range(10) for k in range(-n - 1, n + 2)]
)

def _triangular_indices_tests_n_m(self, pyfunc):
self._triangular_indices_tests_base(
pyfunc,
[[n, m] for n in range(10) for m in range(2 * n)]
)

def _triangular_indices_tests_n_k_m(self, pyfunc):
self._triangular_indices_tests_base(
pyfunc,
[[n, k, m] for n in range(10) for k in range(-n - 1, n + 2) for m in range(2 * n)]
)

# Check jitted version works with default values for kwargs
cfunc = jit(nopython=True)(pyfunc)
cfunc(1)

def _triangular_indices_from_tests_arr(self, pyfunc):
cfunc = jit(nopython=True)(pyfunc)

for dtype in [int, float, bool]:
for n,m in itertools.product(range(10), range(10)):
arr = np.ones((n, m), dtype)
expected = pyfunc(arr)
got = cfunc(arr)
self.assertEqual(type(expected), type(got))
self.assertEqual(len(expected), len(got))
for e, g in zip(expected, got):
np.testing.assert_array_equal(e, g)

def _triangular_indices_from_tests_arr_k(self, pyfunc):
cfunc = jit(nopython=True)(pyfunc)

for dtype in [int, float, bool]:
for n,m in itertools.product(range(10), range(10)):
arr = np.ones((n, m), dtype)
for k in range(-10, 10):
expected = pyfunc(arr)
got = cfunc(arr)
self.assertEqual(type(expected), type(got))
self.assertEqual(len(expected), len(got))
for e, g in zip(expected, got):
np.testing.assert_array_equal(e, g)

def _triangular_indices_exceptions(self, pyfunc):
cfunc = jit(nopython=True)(pyfunc)
parameters = pysignature(pyfunc).parameters

with self.assertTypingError() as raises:
cfunc(1.0)
self.assertIn("n must be an integer", str(raises.exception))

if 'k' in parameters:
with self.assertTypingError() as raises:
cfunc(1, k=1.0)
self.assertIn("k must be an integer", str(raises.exception))

if 'm' in parameters:
with self.assertTypingError() as raises:
cfunc(1, m=1.0)
self.assertIn("m must be an integer", str(raises.exception))

def _triangular_indices_from_exceptions(self, pyfunc, test_k=True):
cfunc = jit(nopython=True)(pyfunc)

for ndims in [0, 1, 3]:
a = np.ones([5] * ndims)
with self.assertTypingError() as raises:
cfunc(a)
self.assertIn("input array must be 2-d", str(raises.exception))

if test_k:
a = np.ones([5, 5])
with self.assertTypingError() as raises:
cfunc(a, k=0.5)
self.assertIn("k must be an integer", str(raises.exception))

def test_tril_basic(self):
self._triangular_matrix_tests_m(tril_m)
Expand All @@ -1167,13 +1315,45 @@ def test_tril_basic(self):
def test_tril_exceptions(self):
self._triangular_matrix_exceptions(tril_m_k)

def test_tril_indices(self):
self._triangular_indices_tests_n(tril_indices_n)
self._triangular_indices_tests_n_k(tril_indices_n_k)
self._triangular_indices_tests_n_m(tril_indices_n_m)
self._triangular_indices_tests_n_k_m(tril_indices_n_k_m)
self._triangular_indices_exceptions(tril_indices_n)
self._triangular_indices_exceptions(tril_indices_n_k)
self._triangular_indices_exceptions(tril_indices_n_m)
self._triangular_indices_exceptions(tril_indices_n_k_m)

def test_tril_indices_from(self):
self._triangular_indices_from_tests_arr(tril_indices_from_arr)
self._triangular_indices_from_tests_arr_k(tril_indices_from_arr_k)
self._triangular_indices_from_exceptions(tril_indices_from_arr, False)
self._triangular_indices_from_exceptions(tril_indices_from_arr_k, True)

def test_triu_basic(self):
self._triangular_matrix_tests_m(triu_m)
self._triangular_matrix_tests_m_k(triu_m_k)

def test_triu_exceptions(self):
self._triangular_matrix_exceptions(triu_m_k)

def test_triu_indices(self):
self._triangular_indices_tests_n(triu_indices_n)
self._triangular_indices_tests_n_k(triu_indices_n_k)
self._triangular_indices_tests_n_m(triu_indices_n_m)
self._triangular_indices_tests_n_k_m(triu_indices_n_k_m)
self._triangular_indices_exceptions(triu_indices_n)
self._triangular_indices_exceptions(triu_indices_n_k)
self._triangular_indices_exceptions(triu_indices_n_m)
self._triangular_indices_exceptions(triu_indices_n_k_m)

def test_triu_indices_from(self):
self._triangular_indices_from_tests_arr(triu_indices_from_arr)
self._triangular_indices_from_tests_arr_k(triu_indices_from_arr_k)
self._triangular_indices_from_exceptions(triu_indices_from_arr, False)
self._triangular_indices_from_exceptions(triu_indices_from_arr_k, True)

def partition_sanity_check(self, pyfunc, cfunc, a, kth):
# as NumPy uses a different algorithm, we do not expect to match outputs exactly...
expected = pyfunc(a, kth)
Expand Down

0 comments on commit 9a49a69

Please sign in to comment.