Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow dtype input argument in np.sum #4472

Merged
merged 29 commits into from
Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9e5f6b7
added dtype kwarg (doesn't do anything) + tests
luk-f-a Aug 21, 2019
cf899bf
fixed impl - basic implementation passes tests
luk-f-a Aug 21, 2019
de3d1c0
format fixes and test removal
luk-f-a Aug 21, 2019
ee11c3d
added dtype-axis combination
luk-f-a Aug 27, 2019
4d674fa
added dtype-axis tests
luk-f-a Aug 27, 2019
bbde75f
adjusting arraydecl (1 of 2)
luk-f-a Aug 27, 2019
66d525c
adjusting arraydecl (2 of 2)
luk-f-a Aug 27, 2019
bd5014a
flake8 fixes
luk-f-a Aug 28, 2019
7c3d9bb
factoring out common parts of sum implementations
luk-f-a Aug 28, 2019
3abb4a8
added tests
luk-f-a Aug 31, 2019
496eeca
fixed issue with np.int32 arrays
luk-f-a Sep 6, 2019
affa1ec
disabled timedelta test
luk-f-a Sep 6, 2019
b63d8d6
flake8 fix
luk-f-a Sep 7, 2019
5a3e219
updated documentation
luk-f-a Sep 17, 2019
e0370f7
split out tests for int32 and uint32 due to special numpy behaviour
luk-f-a Sep 17, 2019
bebddf2
Merge branch 'master' into sum_with_dtype
luk-f-a Sep 17, 2019
26ab11f
Merge branch 'master' into sum_with_dtype
luk-f-a Sep 17, 2019
d236a98
test fix for 32-bit machines
luk-f-a Sep 18, 2019
0ffe381
fix for scalar results
luk-f-a Sep 18, 2019
18be7fb
fixes to address review
luk-f-a Sep 25, 2019
8959801
fix docs char limit
luk-f-a Sep 25, 2019
c25d0b5
changes after second review
luk-f-a Sep 25, 2019
0726ed1
removing redundant tests
luk-f-a Sep 26, 2019
50ac590
Word wrap docs
stuartarchibald Sep 26, 2019
f3a333b
Merge pull request #2 from stuartarchibald/pr_4472
luk-f-a Sep 26, 2019
6bc30ed
removing redundant tests
luk-f-a Sep 26, 2019
0d642a9
Merge remote-tracking branch 'origin/sum_with_dtype' into sum_with_dtype
luk-f-a Sep 26, 2019
80feaa6
reverting addition of mean tests
luk-f-a Sep 26, 2019
a3074b4
reverting addition of mean tests2
luk-f-a Sep 26, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 10 additions & 2 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,21 @@ The following methods of Numpy arrays are supported:
* :meth:`~numpy.ndarray.repeat` (no axis argument)
* :meth:`~numpy.ndarray.reshape` (only the 1-argument form)
* :meth:`~numpy.ndarray.sort` (without arguments)
* :meth:`~numpy.ndarray.sum` (with or without the ``axis`` argument.
``axis`` only supports ``integer`` values)
* :meth:`~numpy.ndarray.sum` (with or without the ``axis`` and/or ``dtype`` arguments.)

* ``axis`` only supports ``integer`` values.
* If the ``axis`` argument is a compile-time constant, all valid values are supported.
An out-of-range value will result in a ``LoweringError`` at compile-time.
* If the ``axis`` argument is not a compile-time constant, only values from 0 to 3 are supported.
An out-of-range value will result in a runtime exception.
* All numeric ``dtypes`` are supported as ``dtype`` parameter. ``timedelta`` arrays can
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this section all needs line wrapping to 80 chars

Copy link
Contributor Author

@luk-f-a luk-f-a Sep 25, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, there's one left with 81 characters, is that ok?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

err no :) I can go fix it though so don't worry!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol, I didn't know it had to be strictly 80. I thought it was like a pep8 thing, where there's a bit of flexibility.

be used as input arrays but ``timedelta`` is not supported as ``dtype`` parameter.
* When ``dtype`` is given, it determines the type of the internal accumulator. When it is not,
the selection is made automatically based on the input array's ``dtype``, mostly following the same
rules as NumPy. However, on 64-bit Windows, Numba uses a 64-bit accumulator for integer inputs
(``int64`` for ``int32`` inputs and ``uint64`` for ``uint32`` inputs), while NumPy would use a 32-bit
accumulator in those cases.


* :meth:`~numpy.ndarray.transpose`
* :meth:`~numpy.ndarray.view` (only the 1-argument form)
Expand Down
155 changes: 112 additions & 43 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,50 +177,20 @@ def _array_sum_axis_nop(arr, v):
return arr


@lower_builtin(np.sum, types.Array, types.intp)
@lower_builtin(np.sum, types.Array, types.IntegerLiteral)
@lower_builtin("array.sum", types.Array, types.intp)
@lower_builtin("array.sum", types.Array, types.IntegerLiteral)
def array_sum_axis(context, builder, sig, args):
"""
The third parameter to gen_index_tuple that generates the indexing
tuples has to be a const so we can't just pass "axis" through since
that isn't const. We can check for specific values and have
different instances that do take consts. Supporting axis summation
only up to the fourth dimension for now.
"""
# typing/arraydecl.py:sum_expand defines the return type for sum with axis.
# It is one dimension less than the input array.
def gen_sum_axis_impl(is_axis_const, const_axis_val, op, zero):
def inner(arr, axis):
"""
function that performs sums over one specific axis

retty = sig.return_type
zero = getattr(retty, 'dtype', retty)(0)
# if the return is scalar in type then "take" the 0th element of the
# 0d array accumulator as the return value
if getattr(retty, 'ndim', None) is None:
op = np.take
else:
op = _array_sum_axis_nop
[ty_array, ty_axis] = sig.args
is_axis_const = False
const_axis_val = 0
if isinstance(ty_axis, types.Literal):
# this special-cases for constant axis
const_axis_val = ty_axis.literal_value
# fix negative axis
if const_axis_val < 0:
const_axis_val = ty_array.ndim + const_axis_val
if const_axis_val < 0 or const_axis_val > ty_array.ndim:
raise ValueError("'axis' entry is out of bounds")
The third parameter to gen_index_tuple that generates the indexing
tuples has to be a const so we can't just pass "axis" through since
that isn't const. We can check for specific values and have
different instances that do take consts. Supporting axis summation
only up to the fourth dimension for now.

ty_axis = context.typing_context.resolve_value_type(const_axis_val)
axis_val = context.get_constant(ty_axis, const_axis_val)
# rewrite arguments
args = args[0], axis_val
# rewrite sig
sig = sig.replace(args=[ty_array, ty_axis])
is_axis_const = True

def array_sum_impl_axis(arr, axis):
typing/arraydecl.py:sum_expand defines the return type for sum with axis.
It is one dimension less than the input array.
"""
ndim = arr.ndim

if not is_axis_const:
Expand Down Expand Up @@ -268,8 +238,107 @@ def array_sum_impl_axis(arr, axis):
elif axis == 3:
index_tuple4 = _gen_index_tuple(arr.shape, axis_index, 3)
result += arr[index_tuple4]

return op(result, 0)
return inner


@lower_builtin(np.sum, types.Array, types.intp, types.DTypeSpec)
@lower_builtin(np.sum, types.Array, types.IntegerLiteral, types.DTypeSpec)
@lower_builtin("array.sum", types.Array, types.intp, types.DTypeSpec)
@lower_builtin("array.sum", types.Array, types.IntegerLiteral, types.DTypeSpec)
def array_sum_axis_dtype(context, builder, sig, args):
retty = sig.return_type
zero = getattr(retty, 'dtype', retty)(0)
# if the return is scalar in type then "take" the 0th element of the
# 0d array accumulator as the return value
if getattr(retty, 'ndim', None) is None:
op = np.take
else:
op = _array_sum_axis_nop
[ty_array, ty_axis, ty_dtype] = sig.args
is_axis_const = False
const_axis_val = 0
if isinstance(ty_axis, types.Literal):
# this special-cases for constant axis
const_axis_val = ty_axis.literal_value
# fix negative axis
if const_axis_val < 0:
const_axis_val = ty_array.ndim + const_axis_val
if const_axis_val < 0 or const_axis_val > ty_array.ndim:
raise ValueError("'axis' entry is out of bounds")

ty_axis = context.typing_context.resolve_value_type(const_axis_val)
axis_val = context.get_constant(ty_axis, const_axis_val)
# rewrite arguments
args = args[0], axis_val, args[2]
# rewrite sig
sig = sig.replace(args=[ty_array, ty_axis, ty_dtype])
is_axis_const = True

gen_impl = gen_sum_axis_impl(is_axis_const, const_axis_val, op, zero)
compiled = register_jitable(gen_impl)

def array_sum_impl_axis(arr, axis, dtype):
return compiled(arr, axis)

res = context.compile_internal(builder, array_sum_impl_axis, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)


@lower_builtin(np.sum, types.Array, types.DTypeSpec)
luk-f-a marked this conversation as resolved.
Show resolved Hide resolved
@lower_builtin("array.sum", types.Array, types.DTypeSpec)
def array_sum_dtype(context, builder, sig, args):
zero = sig.return_type(0)

def array_sum_impl(arr, dtype):
c = zero
for v in np.nditer(arr):
c += v.item()
return c

res = context.compile_internal(builder, array_sum_impl, sig, args,
locals=dict(c=sig.return_type))
return impl_ret_borrowed(context, builder, sig.return_type, res)


@lower_builtin(np.sum, types.Array, types.intp)
@lower_builtin(np.sum, types.Array, types.IntegerLiteral)
@lower_builtin("array.sum", types.Array, types.intp)
@lower_builtin("array.sum", types.Array, types.IntegerLiteral)
def array_sum_axis(context, builder, sig, args):
retty = sig.return_type
zero = getattr(retty, 'dtype', retty)(0)
# if the return is scalar in type then "take" the 0th element of the
# 0d array accumulator as the return value
if getattr(retty, 'ndim', None) is None:
op = np.take
else:
op = _array_sum_axis_nop
[ty_array, ty_axis] = sig.args
is_axis_const = False
const_axis_val = 0
if isinstance(ty_axis, types.Literal):
# this special-cases for constant axis
const_axis_val = ty_axis.literal_value
# fix negative axis
if const_axis_val < 0:
const_axis_val = ty_array.ndim + const_axis_val
if const_axis_val < 0 or const_axis_val > ty_array.ndim:
raise ValueError("'axis' entry is out of bounds")

ty_axis = context.typing_context.resolve_value_type(const_axis_val)
axis_val = context.get_constant(ty_axis, const_axis_val)
# rewrite arguments
args = args[0], axis_val
# rewrite sig
sig = sig.replace(args=[ty_array, ty_axis])
is_axis_const = True

gen_impl = gen_sum_axis_impl(is_axis_const, const_axis_val, op, zero)
compiled = register_jitable(gen_impl)

def array_sum_impl_axis(arr, axis):
return compiled(arr, axis)

res = context.compile_internal(builder, array_sum_impl_axis, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)
Expand Down