Skip to content

Commit

Permalink
Merge pull request #4472 from luk-f-a/sum_with_dtype
Browse files Browse the repository at this point in the history
Allow dtype input argument in np.sum
  • Loading branch information
stuartarchibald committed Sep 30, 2019
2 parents 4587a70 + a3074b4 commit 6359d3b
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 74 deletions.
21 changes: 17 additions & 4 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,26 @@ The following methods of Numpy arrays are supported:
* :meth:`~numpy.ndarray.repeat` (no axis argument)
* :meth:`~numpy.ndarray.reshape` (only the 1-argument form)
* :meth:`~numpy.ndarray.sort` (without arguments)
* :meth:`~numpy.ndarray.sum` (with or without the ``axis`` argument.
``axis`` only supports ``integer`` values)
* :meth:`~numpy.ndarray.sum` (with or without the ``axis`` and/or ``dtype``
arguments.)

* If the ``axis`` argument is a compile-time constant, all valid values are supported.
* ``axis`` only supports ``integer`` values.
* If the ``axis`` argument is a compile-time constant, all valid values
are supported.
An out-of-range value will result in a ``LoweringError`` at compile-time.
* If the ``axis`` argument is not a compile-time constant, only values from 0 to 3 are supported.
* If the ``axis`` argument is not a compile-time constant, only values
from 0 to 3 are supported.
An out-of-range value will result in a runtime exception.
* All numeric ``dtypes`` are supported in the ``dtype`` parameter.
``timedelta`` arrays can be used as input arrays but ``timedelta`` is not
supported as ``dtype`` parameter.
* When a ``dtype`` is given, it determines the type of the internal
accumulator. When it is not, the selection is made automatically based on
the input array's ``dtype``, mostly following the same rules as NumPy.
However, on 64-bit Windows, Numba uses a 64-bit accumulator for integer
inputs (``int64`` for ``int32`` inputs and ``uint64`` for ``uint32``
inputs), while NumPy would use a 32-bit accumulator in those cases.


* :meth:`~numpy.ndarray.transpose`
* :meth:`~numpy.ndarray.view` (only the 1-argument form)
Expand Down
155 changes: 112 additions & 43 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,50 +178,20 @@ def _array_sum_axis_nop(arr, v):
return arr


@lower_builtin(np.sum, types.Array, types.intp)
@lower_builtin(np.sum, types.Array, types.IntegerLiteral)
@lower_builtin("array.sum", types.Array, types.intp)
@lower_builtin("array.sum", types.Array, types.IntegerLiteral)
def array_sum_axis(context, builder, sig, args):
"""
The third parameter to gen_index_tuple that generates the indexing
tuples has to be a const so we can't just pass "axis" through since
that isn't const. We can check for specific values and have
different instances that do take consts. Supporting axis summation
only up to the fourth dimension for now.
"""
# typing/arraydecl.py:sum_expand defines the return type for sum with axis.
# It is one dimension less than the input array.
def gen_sum_axis_impl(is_axis_const, const_axis_val, op, zero):
def inner(arr, axis):
"""
function that performs sums over one specific axis
retty = sig.return_type
zero = getattr(retty, 'dtype', retty)(0)
# if the return is scalar in type then "take" the 0th element of the
# 0d array accumulator as the return value
if getattr(retty, 'ndim', None) is None:
op = np.take
else:
op = _array_sum_axis_nop
[ty_array, ty_axis] = sig.args
is_axis_const = False
const_axis_val = 0
if isinstance(ty_axis, types.Literal):
# this special-cases for constant axis
const_axis_val = ty_axis.literal_value
# fix negative axis
if const_axis_val < 0:
const_axis_val = ty_array.ndim + const_axis_val
if const_axis_val < 0 or const_axis_val > ty_array.ndim:
raise ValueError("'axis' entry is out of bounds")
The third parameter to gen_index_tuple that generates the indexing
tuples has to be a const so we can't just pass "axis" through since
that isn't const. We can check for specific values and have
different instances that do take consts. Supporting axis summation
only up to the fourth dimension for now.
ty_axis = context.typing_context.resolve_value_type(const_axis_val)
axis_val = context.get_constant(ty_axis, const_axis_val)
# rewrite arguments
args = args[0], axis_val
# rewrite sig
sig = sig.replace(args=[ty_array, ty_axis])
is_axis_const = True

def array_sum_impl_axis(arr, axis):
typing/arraydecl.py:sum_expand defines the return type for sum with axis.
It is one dimension less than the input array.
"""
ndim = arr.ndim

if not is_axis_const:
Expand Down Expand Up @@ -269,8 +239,107 @@ def array_sum_impl_axis(arr, axis):
elif axis == 3:
index_tuple4 = _gen_index_tuple(arr.shape, axis_index, 3)
result += arr[index_tuple4]

return op(result, 0)
return inner


@lower_builtin(np.sum, types.Array, types.intp, types.DTypeSpec)
@lower_builtin(np.sum, types.Array, types.IntegerLiteral, types.DTypeSpec)
@lower_builtin("array.sum", types.Array, types.intp, types.DTypeSpec)
@lower_builtin("array.sum", types.Array, types.IntegerLiteral, types.DTypeSpec)
def array_sum_axis_dtype(context, builder, sig, args):
retty = sig.return_type
zero = getattr(retty, 'dtype', retty)(0)
# if the return is scalar in type then "take" the 0th element of the
# 0d array accumulator as the return value
if getattr(retty, 'ndim', None) is None:
op = np.take
else:
op = _array_sum_axis_nop
[ty_array, ty_axis, ty_dtype] = sig.args
is_axis_const = False
const_axis_val = 0
if isinstance(ty_axis, types.Literal):
# this special-cases for constant axis
const_axis_val = ty_axis.literal_value
# fix negative axis
if const_axis_val < 0:
const_axis_val = ty_array.ndim + const_axis_val
if const_axis_val < 0 or const_axis_val > ty_array.ndim:
raise ValueError("'axis' entry is out of bounds")

ty_axis = context.typing_context.resolve_value_type(const_axis_val)
axis_val = context.get_constant(ty_axis, const_axis_val)
# rewrite arguments
args = args[0], axis_val, args[2]
# rewrite sig
sig = sig.replace(args=[ty_array, ty_axis, ty_dtype])
is_axis_const = True

gen_impl = gen_sum_axis_impl(is_axis_const, const_axis_val, op, zero)
compiled = register_jitable(gen_impl)

def array_sum_impl_axis(arr, axis, dtype):
return compiled(arr, axis)

res = context.compile_internal(builder, array_sum_impl_axis, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)


@lower_builtin(np.sum, types.Array, types.DTypeSpec)
@lower_builtin("array.sum", types.Array, types.DTypeSpec)
def array_sum_dtype(context, builder, sig, args):
zero = sig.return_type(0)

def array_sum_impl(arr, dtype):
c = zero
for v in np.nditer(arr):
c += v.item()
return c

res = context.compile_internal(builder, array_sum_impl, sig, args,
locals=dict(c=sig.return_type))
return impl_ret_borrowed(context, builder, sig.return_type, res)


@lower_builtin(np.sum, types.Array, types.intp)
@lower_builtin(np.sum, types.Array, types.IntegerLiteral)
@lower_builtin("array.sum", types.Array, types.intp)
@lower_builtin("array.sum", types.Array, types.IntegerLiteral)
def array_sum_axis(context, builder, sig, args):
retty = sig.return_type
zero = getattr(retty, 'dtype', retty)(0)
# if the return is scalar in type then "take" the 0th element of the
# 0d array accumulator as the return value
if getattr(retty, 'ndim', None) is None:
op = np.take
else:
op = _array_sum_axis_nop
[ty_array, ty_axis] = sig.args
is_axis_const = False
const_axis_val = 0
if isinstance(ty_axis, types.Literal):
# this special-cases for constant axis
const_axis_val = ty_axis.literal_value
# fix negative axis
if const_axis_val < 0:
const_axis_val = ty_array.ndim + const_axis_val
if const_axis_val < 0 or const_axis_val > ty_array.ndim:
raise ValueError("'axis' entry is out of bounds")

ty_axis = context.typing_context.resolve_value_type(const_axis_val)
axis_val = context.get_constant(ty_axis, const_axis_val)
# rewrite arguments
args = args[0], axis_val
# rewrite sig
sig = sig.replace(args=[ty_array, ty_axis])
is_axis_const = True

gen_impl = gen_sum_axis_impl(is_axis_const, const_axis_val, op, zero)
compiled = register_jitable(gen_impl)

def array_sum_impl_axis(arr, axis):
return compiled(arr, axis)

res = context.compile_internal(builder, array_sum_impl_axis, sig, args)
return impl_ret_new_ref(context, builder, sig.return_type, res)
Expand Down

0 comments on commit 6359d3b

Please sign in to comment.