-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add triangular indices functions #4601
Changes from 5 commits
5614074
36df641
2993966
8130b1b
b2843d2
e7d6a09
2b6ced5
7a8524b
82ed8d0
277aa55
b04ba2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -614,6 +614,12 @@ def check_array(a): | |
raise ValueError('zero-size array to reduction operation not possible') | ||
|
||
|
||
@register_jitable | ||
def _check_is_integer(v, name): | ||
if not isinstance(v, (int, types.Integer)): | ||
raise TypingError('{} must be an integer'.format(name)) | ||
|
||
|
||
def nan_min_max_factory(comparison_op, is_complex_dtype): | ||
if is_complex_dtype: | ||
def impl(a): | ||
|
@@ -1352,8 +1358,7 @@ def _tri_impl(N, M, k): | |
def np_tri(N, M=None, k=0): | ||
|
||
# we require k to be integer, unlike numpy | ||
if not isinstance(k, (int, types.Integer)): | ||
raise TypeError('k must be an integer') | ||
_check_is_integer(k, 'k') | ||
|
||
def tri_impl(N, M=None, k=0): | ||
if M is None: | ||
|
@@ -1390,8 +1395,7 @@ def np_tril_impl_2d(m, k=0): | |
def my_tril(m, k=0): | ||
|
||
# we require k to be integer, unlike numpy | ||
if not isinstance(k, (int, types.Integer)): | ||
raise TypeError('k must be an integer') | ||
_check_is_integer(k, 'k') | ||
|
||
def np_tril_impl_1d(m, k=0): | ||
m_2d = _make_square(m) | ||
|
@@ -1414,6 +1418,34 @@ def np_tril_impl_multi(m, k=0): | |
return np_tril_impl_multi | ||
|
||
|
||
@overload(np.tril_indices) | ||
def np_tril_indices(n, k=0, m=None): | ||
|
||
# we require integer arguments, unlike numpy | ||
esc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_check_is_integer(n, 'n') | ||
_check_is_integer(k, 'k') | ||
if m is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a This ought to work: from numba import njit
import numpy as np
@njit
def triu_indices_n(n, k=0, m=None):
return np.triu_indices(n, k, m)
print(triu_indices_n(2))
print(triu_indices_n.py_func(2)) something like (but with the import at the top of file) will likely sort it: from numba.numpy_support import is_nonelike
if not is_nonelike(m):
_check_is_integer(m, 'm') same argument applies to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, the tests don't actually capture this case.. |
||
_check_is_integer(m, 'm') | ||
|
||
def np_tril_indices_impl(n, k=0, m=None): | ||
return np.nonzero(np.tri(n, m, k=k)) | ||
EPronovost marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return np_tril_indices_impl | ||
|
||
|
||
@overload(np.tril_indices_from) | ||
def np_tril_indices_from(arr, k=0): | ||
|
||
# we require k to be integer, unlike numpy | ||
_check_is_integer(k, 'k') | ||
|
||
if arr.ndim != 2: | ||
raise TypingError("input array must be 2-d") | ||
|
||
def np_tril_indices_from_impl(arr, k=0): | ||
return np.tril_indices(arr.shape[0], k=k, m=arr.shape[1]) | ||
return np_tril_indices_from_impl | ||
|
||
|
||
@register_jitable | ||
def np_triu_impl_2d(m, k=0): | ||
mask = np.tri(m.shape[-2], M=m.shape[-1], k=k - 1).astype(np.uint) | ||
|
@@ -1423,8 +1455,7 @@ def np_triu_impl_2d(m, k=0): | |
@overload(np.triu) | ||
def my_triu(m, k=0): | ||
# we require k to be integer, unlike numpy | ||
if not isinstance(k, (int, types.Integer)): | ||
raise TypeError('k must be an integer') | ||
_check_is_integer(k, 'k') | ||
|
||
def np_triu_impl_1d(m, k=0): | ||
m_2d = _make_square(m) | ||
|
@@ -1447,6 +1478,34 @@ def np_triu_impl_multi(m, k=0): | |
return np_triu_impl_multi | ||
|
||
|
||
@overload(np.triu_indices) | ||
def np_triu_indices(n, k=0, m=None): | ||
|
||
# we require integer arguments, unlike numpy | ||
_check_is_integer(n, 'n') | ||
_check_is_integer(k, 'k') | ||
if m is not None: | ||
_check_is_integer(m, 'm') | ||
|
||
def np_triu_indices_impl(n, k=0, m=None): | ||
return np.nonzero(1 - np.tri(n, m, k=k - 1)) | ||
return np_triu_indices_impl | ||
|
||
|
||
@overload(np.triu_indices_from) | ||
def np_triu_indices_from(arr, k=0): | ||
|
||
# we require k to be integer, unlike numpy | ||
_check_is_integer(k, 'k') | ||
|
||
if arr.ndim != 2: | ||
raise TypingError("input array must be 2-d") | ||
|
||
def np_triu_indices_from_impl(arr, k=0): | ||
return np.triu_indices(arr.shape[0], k=k, m=arr.shape[1]) | ||
return np_triu_indices_from_impl | ||
|
||
|
||
def _prepare_array(arr): | ||
pass | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -125,6 +125,30 @@ def tril_m_k(m, k=0): | |
return np.tril(m, k) | ||
|
||
|
||
def tril_indices_n(n): | ||
return np.tril_indices(n) | ||
|
||
|
||
def tril_indices_n_k(n, k=0): | ||
return np.tril_indices(n, k) | ||
|
||
|
||
def tril_indices_n_m(n, m=None): | ||
return np.tril_indices(n, m=m) | ||
|
||
|
||
def tril_indices_n_k_m(n, k=0, m=None): | ||
return np.tril_indices(n, k, m) | ||
|
||
esc marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is one last corner case that you should address:
I.e. supplying the arguments as keyword arguments. I recently discovered that this might lead to slightly different path and should be checked to be safe. |
||
|
||
def tril_indices_from_arr(arr): | ||
return np.tril_indices_from(arr) | ||
|
||
|
||
def tril_indices_from_arr_k(arr, k=0): | ||
return np.tril_indices_from(arr, k) | ||
|
||
|
||
def triu_m(m): | ||
return np.triu(m) | ||
|
||
|
@@ -133,6 +157,30 @@ def triu_m_k(m, k=0): | |
return np.triu(m, k) | ||
|
||
|
||
def triu_indices_n(n): | ||
return np.triu_indices(n) | ||
|
||
|
||
def triu_indices_n_k(n, k=0): | ||
return np.triu_indices(n, k) | ||
|
||
|
||
def triu_indices_n_m(n, m=None): | ||
return np.triu_indices(n, m=m) | ||
|
||
|
||
def triu_indices_n_k_m(n, k=0, m=None): | ||
return np.triu_indices(n, k, m) | ||
|
||
esc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def triu_indices_from_arr(arr): | ||
return np.triu_indices_from(arr) | ||
|
||
|
||
def triu_indices_from_arr_k(arr, k=0): | ||
return np.triu_indices_from(arr, k) | ||
|
||
|
||
def vander(x, N=None, increasing=False): | ||
return np.vander(x, N, increasing) | ||
|
||
|
@@ -1064,7 +1112,99 @@ def _triangular_matrix_exceptions(self, pyfunc): | |
a = np.ones((5, 6)) | ||
with self.assertTypingError() as raises: | ||
cfunc(a, k=1.5) | ||
assert "k must be an integer" in str(raises.exception) | ||
self.assertIn("k must be an integer", str(raises.exception)) | ||
|
||
def _triangular_indices_tests_base(self, pyfunc, args): | ||
cfunc = jit(nopython=True)(pyfunc) | ||
|
||
for x in args: | ||
expected = pyfunc(*x) | ||
got = cfunc(*x) | ||
self.assertEqual(type(expected), type(got)) | ||
self.assertEqual(len(expected), len(got)) | ||
for e, g in zip(expected, got): | ||
np.testing.assert_array_equal(e, g) | ||
|
||
def _triangular_indices_tests_n(self, pyfunc): | ||
self._triangular_indices_tests_base( | ||
pyfunc, | ||
[[n] for n in range(10)] | ||
) | ||
|
||
def _triangular_indices_tests_n_k(self, pyfunc): | ||
self._triangular_indices_tests_base( | ||
pyfunc, | ||
[[n, k] for n in range(10) for k in range(-n - 1, n + 2)] | ||
) | ||
|
||
def _triangular_indices_tests_n_m(self, pyfunc): | ||
self._triangular_indices_tests_base( | ||
pyfunc, | ||
[[n, m] for n in range(10) for m in range(2 * n)] | ||
) | ||
|
||
def _triangular_indices_tests_n_k_m(self, pyfunc): | ||
self._triangular_indices_tests_base( | ||
pyfunc, | ||
[[n, k, m] for n in range(10) for k in range(-n - 1, n + 2) for m in range(2 * n)] | ||
) | ||
|
||
esc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def _triangular_indices_from_tests_arr(self, pyfunc): | ||
cfunc = jit(nopython=True)(pyfunc) | ||
|
||
for dtype in [int, float, bool]: | ||
for n in range(10): | ||
for m in range(10): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure |
||
arr = np.ones((n, m), dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is just testing matrices that are full? given the alg, perhaps put in some zeros via a slice with step or similar?! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean randomly or with structure? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't think it matters, whatever is convenient. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually, ignore this, am remembering a different function! |
||
expected = pyfunc(arr) | ||
got = cfunc(arr) | ||
self.assertEqual(type(expected), type(got)) | ||
self.assertEqual(len(expected), len(got)) | ||
for e, g in zip(expected, got): | ||
np.testing.assert_array_equal(e, g) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd assume that the arrays returned would be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried this, and the strides turned out to be different:
|
||
|
||
def _triangular_indices_from_tests_arr_k(self, pyfunc): | ||
cfunc = jit(nopython=True)(pyfunc) | ||
|
||
for dtype in [int, float, bool]: | ||
for n in range(10): | ||
for m in range(10): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes |
||
arr = np.ones((n, m), dtype) | ||
for k in range(-10, 10): | ||
expected = pyfunc(arr) | ||
got = cfunc(arr) | ||
self.assertEqual(type(expected), type(got)) | ||
self.assertEqual(len(expected), len(got)) | ||
for e, g in zip(expected, got): | ||
np.testing.assert_array_equal(e, g) | ||
|
||
def _triangular_indices_exceptions(self, pyfunc): | ||
cfunc = jit(nopython=True)(pyfunc) | ||
|
||
# Exceptions leak references | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be resolvable at typing time? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indeed |
||
self.disable_leak_check() | ||
|
||
with self.assertTypingError() as raises: | ||
cfunc(4, k=1.5) | ||
self.assertIn("k must be an integer", str(raises.exception)) | ||
|
||
def _triangular_indices_from_exceptions(self, pyfunc, test_k=True): | ||
cfunc = jit(nopython=True)(pyfunc) | ||
|
||
# Exceptions leak references | ||
self.disable_leak_check() | ||
|
||
for ndims in [0, 1, 3]: | ||
a = np.ones([5] * ndims) | ||
with self.assertTypingError() as raises: | ||
cfunc(a) | ||
self.assertIn("input array must be 2-d", str(raises.exception)) | ||
|
||
if test_k: | ||
a = np.ones([5, 5]) | ||
with self.assertTypingError() as raises: | ||
cfunc(a, k=0.5) | ||
self.assertIn("k must be an integer", str(raises.exception)) | ||
|
||
def test_tril_basic(self): | ||
self._triangular_matrix_tests_m(tril_m) | ||
|
@@ -1073,13 +1213,43 @@ def test_tril_basic(self): | |
def test_tril_exceptions(self): | ||
self._triangular_matrix_exceptions(tril_m_k) | ||
|
||
def test_tril_indices(self): | ||
self._triangular_indices_tests_n(tril_indices_n) | ||
self._triangular_indices_tests_n_k(tril_indices_n_k) | ||
self._triangular_indices_tests_n_m(tril_indices_n_m) | ||
self._triangular_indices_tests_n_k_m(tril_indices_n_k_m) | ||
self._triangular_indices_exceptions(tril_indices_n_k) | ||
self._triangular_indices_exceptions(tril_indices_n_m) | ||
self._triangular_indices_exceptions(tril_indices_n_k_m) | ||
|
||
def test_tril_indices_from(self): | ||
self._triangular_indices_from_tests_arr(tril_indices_from_arr) | ||
self._triangular_indices_from_tests_arr_k(tril_indices_from_arr_k) | ||
self._triangular_indices_from_exceptions(tril_indices_from_arr, False) | ||
self._triangular_indices_from_exceptions(tril_indices_from_arr_k, True) | ||
|
||
def test_triu_basic(self): | ||
self._triangular_matrix_tests_m(triu_m) | ||
self._triangular_matrix_tests_m_k(triu_m_k) | ||
|
||
def test_triu_exceptions(self): | ||
self._triangular_matrix_exceptions(triu_m_k) | ||
|
||
def test_triu_indices(self): | ||
self._triangular_indices_tests_n(triu_indices_n) | ||
self._triangular_indices_tests_n_k(triu_indices_n_k) | ||
self._triangular_indices_tests_n_m(triu_indices_n_m) | ||
self._triangular_indices_tests_n_k_m(triu_indices_n_k_m) | ||
self._triangular_indices_exceptions(triu_indices_n_k) | ||
self._triangular_indices_exceptions(triu_indices_n_m) | ||
self._triangular_indices_exceptions(triu_indices_n_k_m) | ||
|
||
def test_triu_indices_from(self): | ||
self._triangular_indices_from_tests_arr(triu_indices_from_arr) | ||
self._triangular_indices_from_tests_arr_k(triu_indices_from_arr_k) | ||
self._triangular_indices_from_exceptions(triu_indices_from_arr, False) | ||
self._triangular_indices_from_exceptions(triu_indices_from_arr_k, True) | ||
|
||
def partition_sanity_check(self, pyfunc, cfunc, a, kth): | ||
# as NumPy uses a different algorithm, we do not expect to match outputs exactly... | ||
expected = pyfunc(a, kth) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this being compiled? the type is known at typing time?