Skip to content

Commit

Permalink
Improved test coverage for triu_indices
Browse files Browse the repository at this point in the history
 * Add explicit test for `n, m` arguments
 * Factored out test code for `triu_indices`
 * Fixed spacing for flake8
  • Loading branch information
EPronovost committed Sep 26, 2019
1 parent 36df641 commit 2993966
Showing 1 changed file with 49 additions and 25 deletions.
74 changes: 49 additions & 25 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,42 +124,63 @@ def tril_m(m):
def tril_m_k(m, k=0):
return np.tril(m, k)


def tril_indices_n(n):
return np.tril_indices(n)


def tril_indices_n_k(n, k=0):
return np.tril_indices(n, k)


def tril_indices_n_m(n, m=None):
return np.tril_indices(n, m=m)


def tril_indices_n_k_m(n, k=0, m=None):
return np.tril_indices(n, k, m)


def tril_indices_from_arr(arr):
return np.tril_indices_from(arr)


def tril_indices_from_arr_k(arr, k=0):
return np.tril_indices_from(arr, k)


def triu_m(m):
return np.triu(m)


def triu_m_k(m, k=0):
return np.triu(m, k)


def triu_indices_n(n):
return np.triu_indices(n)


def triu_indices_n_k(n, k=0):
return np.triu_indices(n, k)


def triu_indices_n_m(n, m=None):
return np.triu_indices(n, m=m)


def triu_indices_n_k_m(n, k=0, m=None):
return np.triu_indices(n, k, m)


def triu_indices_from_arr(arr):
return np.triu_indices_from(arr)


def triu_indices_from_arr_k(arr, k=0):
return np.triu_indices_from(arr, k)


def vander(x, N=None, increasing=False):
return np.vander(x, N, increasing)

Expand Down Expand Up @@ -1093,41 +1114,40 @@ def _triangular_matrix_exceptions(self, pyfunc):
cfunc(a, k=1.5)
assert "k must be an integer" in str(raises.exception)

def _triangular_indices_tests_n(self, pyfunc):
def _triangular_indices_tests_base(self, pyfunc, args):
cfunc = jit(nopython=True)(pyfunc)

for n in range(10):
expected = pyfunc(n)
got = cfunc(n)
for x in args:
expected = pyfunc(*x)
got = cfunc(*x)
self.assertEqual(type(expected), type(got))
self.assertEqual(len(expected), len(got))
for e, g in zip(expected, got):
np.testing.assert_array_equal(e, g)

def _triangular_indices_tests_n(self, pyfunc):
self._triangular_indices_tests_base(
pyfunc,
[[n] for n in range(10)]
)

def _triangular_indices_tests_n_k(self, pyfunc):
cfunc = jit(nopython=True)(pyfunc)
self._triangular_indices_tests_base(
pyfunc,
[[n, k] for n in range(10) for k in range(-n - 1, n + 2)]
)

for n in range(10):
for k in range(-n - 1, n + 2):
expected = pyfunc(n, k)
got = cfunc(n, k)
self.assertEqual(type(expected), type(got))
self.assertEqual(len(expected), len(got))
for e, g in zip(expected, got):
np.testing.assert_array_equal(e, g)
def _triangular_indices_tests_n_m(self, pyfunc):
self._triangular_indices_tests_base(
pyfunc,
[[n, m] for n in range(10) for m in range(2 * n)]
)

def _triangular_indices_tests_n_k_m(self, pyfunc):
cfunc = jit(nopython=True)(pyfunc)

for n in range(10):
for m in range(2 * n):
for k in range(-n - 1, n + 2):
expected = pyfunc(n, k, m)
got = cfunc(n, k, m)
self.assertEqual(type(expected), type(got))
self.assertEqual(len(expected), len(got))
for e, g in zip(expected, got):
np.testing.assert_array_equal(e, g)
self._triangular_indices_tests_base(
pyfunc,
[[n, k, m] for n in range(10) for k in range(-n - 1, n + 2) for m in range(2 * n)]
)

def _triangular_indices_from_tests_arr(self, pyfunc):
cfunc = jit(nopython=True)(pyfunc)
Expand Down Expand Up @@ -1196,10 +1216,12 @@ def test_tril_exceptions(self):
def test_tril_indices(self):
self._triangular_indices_tests_n(tril_indices_n)
self._triangular_indices_tests_n_k(tril_indices_n_k)
self._triangular_indices_tests_n_m(tril_indices_n_m)
self._triangular_indices_tests_n_k_m(tril_indices_n_k_m)
self._triangular_indices_exceptions(tril_indices_n_k)
self._triangular_indices_exceptions(tril_indices_n_m)
self._triangular_indices_exceptions(tril_indices_n_k_m)

def test_tril_indices_from(self):
self._triangular_indices_from_tests_arr(tril_indices_from_arr)
self._triangular_indices_from_tests_arr_k(tril_indices_from_arr_k)
Expand All @@ -1216,8 +1238,10 @@ def test_triu_exceptions(self):
def test_triu_indices(self):
self._triangular_indices_tests_n(triu_indices_n)
self._triangular_indices_tests_n_k(triu_indices_n_k)
self._triangular_indices_tests_n_m(triu_indices_n_m)
self._triangular_indices_tests_n_k_m(triu_indices_n_k_m)
self._triangular_indices_exceptions(triu_indices_n_k)
self._triangular_indices_exceptions(triu_indices_n_m)
self._triangular_indices_exceptions(triu_indices_n_k_m)

def test_triu_indices_from(self):
Expand Down

0 comments on commit 2993966

Please sign in to comment.