Skip to content

Commit

Permalink
changes after second review
Browse files Browse the repository at this point in the history
  • Loading branch information
luk-f-a committed Sep 25, 2019
1 parent 8959801 commit c25d0b5
Showing 1 changed file with 16 additions and 29 deletions.
45 changes: 16 additions & 29 deletions numba/tests/test_array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,9 @@ def check_err(a):
check_err(np.array([]))

def test_sum(self):
"""
test sum - basic
"""
pyfunc = array_sum
cfunc = jit(nopython=True)(pyfunc)
# OK
Expand All @@ -788,10 +791,12 @@ def test_sum(self):
self.assertPreciseEqual(pyfunc(a, 0), cfunc(a, 0))

def test_sum2(self):
""" test sum over a whole range of dtypes, no axis or dtype parameter
"""
pyfunc = array_sum
cfunc = jit(nopython=True)(pyfunc)
all_dtypes = [np.float64, np.float32, np.int64, np.int32,
np.complex64, np.uint32, np.uint64, np.timedelta64]
np.complex64, np.complex128, np.uint32, np.uint64, np.timedelta64]
all_test_arrays = [
[np.ones((7, 6, 5, 4, 3), arr_dtype),
np.ones(1, arr_dtype),
Expand All @@ -804,6 +809,7 @@ def test_sum2(self):
self.assertPreciseEqual(pyfunc(arr), cfunc(arr))

def test_sum_axis_kws(self):
""" test sum with axis parameter - basic """
pyfunc = array_sum_axis_kws
cfunc = jit(nopython=True)(pyfunc)
# OK
Expand All @@ -813,6 +819,7 @@ def test_sum_axis_kws(self):
self.assertPreciseEqual(pyfunc(a, axis=2), cfunc(a, axis=2))

def test_sum_axis_kws2(self):
""" test sum with axis parameter over a whole range of dtypes """
pyfunc = array_sum_axis_kws
cfunc = jit(nopython=True)(pyfunc)
all_dtypes = [np.float64, np.float32, np.int64, np.uint64, np.complex64,
Expand All @@ -836,13 +843,13 @@ def test_sum_axis_kws2(self):
cfunc(arr, axis=axis))

def test_sum_axis_kws3(self):
""" uint32 and int32 must be tested separately because Numpy's current
""" testing uint32 and int32 separately
uint32 and int32 must be tested separately because Numpy's current
behaviour is different in 64bits Windows (accumulates as int32)
and 64bits Linux (accumulates as int64), while Numba has decided to always
accumulate as int64, when the OS is 64bits. No testing has been done
for behaviours in 32 bits platforms.
:return:
"""
pyfunc = array_sum_axis_kws
cfunc = jit(nopython=True)(pyfunc)
Expand Down Expand Up @@ -876,6 +883,7 @@ def test_sum_axis_kws3(self):
self.assertEqual(npy_res, numba_res)

def test_sum_dtype_kws(self):
""" test sum with dtype parameter over a whole range of dtypes """
pyfunc = array_sum_dtype_kws
cfunc = jit(nopython=True)(pyfunc)
all_dtypes = [np.float64, np.float32, np.int64, np.int32, np.uint32,
Expand All @@ -894,7 +902,7 @@ def test_sum_dtype_kws(self):
np.dtype('int32'): [np.float64, np.int64, np.float32, np.int32],
np.dtype('uint32'): [np.float64, np.int64, np.float32],
np.dtype('uint64'): [np.float64, np.int64],
np.dtype('complex64'): [np.complex64],
np.dtype('complex64'): [np.complex64, np.complex128],
np.dtype('complex128'): [np.complex128],
np.dtype('timedelta64'): [np.timedelta64]}

Expand All @@ -907,26 +915,8 @@ def test_sum_dtype_kws(self):
self.assertPreciseEqual(pyfunc(arr, dtype=out_dtype),
cfunc(arr, dtype=out_dtype))

def test_sum_dtype_kws_negative(self):
pyfunc = array_sum_dtype_kws
cfunc = jit(nopython=True)(pyfunc)
dtype = np.float64
# OK
a = np.ones((7, 6, 5, 4, 3))
self.assertFalse(type(pyfunc(a, dtype=np.int32)) == cfunc(a, dtype=dtype))

def test_sum_axis_dtype_kws(self):
pyfunc = array_sum_axis_dtype_kws
cfunc = jit(nopython=True)(pyfunc)
dtype = np.float64
# OK
a = np.ones((7, 6, 5, 4, 3))
self.assertPreciseEqual(pyfunc(a, axis=1, dtype=dtype),
cfunc(a, axis=1, dtype=dtype))
self.assertPreciseEqual(pyfunc(a, axis=2, dtype=dtype),
cfunc(a, axis=2, dtype=dtype))

def test_sum_axis_dtype_kws2(self):
""" test sum with axis and dtype parameters over a whole range of dtypes """
pyfunc = array_sum_axis_dtype_kws
cfunc = jit(nopython=True)(pyfunc)
all_dtypes = [np.float64, np.float32, np.int64, np.int32, np.uint32,
Expand All @@ -943,7 +933,7 @@ def test_sum_axis_dtype_kws2(self):
np.dtype('int32'): [np.float64, np.int64, np.float32, np.int32],
np.dtype('uint32'): [np.float64, np.int64, np.float32],
np.dtype('uint64'): [np.float64, np.uint64],
np.dtype('complex64'): [np.complex64],
np.dtype('complex64'): [np.complex64, np.complex128],
np.dtype('complex128'): [np.complex128],
np.dtype('timedelta64'): [np.timedelta64]}

Expand All @@ -961,10 +951,7 @@ def test_sum_axis_dtype_kws2(self):
self.assertPreciseEqual(py_res, nb_res)

def test_sum_axis_dtype_pos_arg(self):
"""
testing that axis and dtype inputs work when passed as positional
:return:
"""
""" testing that axis and dtype inputs work when passed as positional """
pyfunc = array_sum_axis_dtype_pos
cfunc = jit(nopython=True)(pyfunc)
dtype = np.float64
Expand Down

0 comments on commit c25d0b5

Please sign in to comment.