Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Min max on iterables #3820

Merged
merged 12 commits into from
Mar 14, 2019
38 changes: 32 additions & 6 deletions numba/targets/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,22 +446,48 @@ def impl_index_value(context, builder, sig, args):
index_value.value = value
return index_value._getvalue()


@overload(min)
def indval_min(indval1, indval2):
def indval_min(indval1, indval2=None):

if isinstance(indval1, IndexValueType) and \
isinstance(indval2, IndexValueType):
def min_impl(indval1, indval2):
def impl(indval1, indval2=None):
if indval1.value > indval2.value:
return indval2
return indval1
return min_impl
return impl

if indval2 is None:
if isinstance(indval1, types.IterableType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if it's possible to pull these out of this scope which has a non-intuitive signature? Would defining another overload like:

@overload(min)
def iterable_min(iterable):
    if isinstance(iterable, types.IterableType):
        def impl(iterable):
            <details>

work? Could also close over the <, > operator to save duplication too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

def impl(indval1, indval2=None):
it = iter(indval1)
return_val = next(it)
for val in it:
if val < return_val:
return_val = val
return return_val
return impl


@overload(max)
def indval_max(indval1, indval2):
def indval_max(indval1, indval2=None):

if isinstance(indval1, IndexValueType) and \
isinstance(indval2, IndexValueType):
def max_impl(indval1, indval2):
def impl(indval1, indval2=None):
if indval2.value > indval1.value:
return indval2
return indval1
return max_impl
return impl

if indval2 is None:
if isinstance(indval1, types.IterableType):
def impl(indval1, indval2=None):
it = iter(indval1)
return_val = next(it)
for val in it:
if val > return_val:
return_val = val
return return_val
return impl
56 changes: 55 additions & 1 deletion numba/tests/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numba.unittest_support as unittest
from numba.compiler import compile_isolated, Flags
from numba import jit, typeof, errors, types, utils, config
from numba import jit, typeof, errors, types, utils, config, njit
from .support import TestCase, tag


Expand Down Expand Up @@ -982,6 +982,60 @@ def test_pow_usecase(self):
r = cres.entry_point(x, y)
self.assertPreciseEqual(r, pow_usecase(x, y))

def _check_min_max(self, pyfunc):
cfunc = njit()(pyfunc)
expected = pyfunc()
got = cfunc()
self.assertPreciseEqual(expected, got)

def test_min_max_iterable_input(self):

@njit
def frange(start, stop, step):
i = start
while i < stop:
yield i
i += step

def sample_functions(op):
yield lambda: op(range(10))
yield lambda: op(range(4, 12))
yield lambda: op(range(-4, -15, -1))
yield lambda: op([6.6, 5.5, 7.7])
yield lambda: op([(3, 4), (1, 2)])
yield lambda: op(frange(1.1, 3.3, 0.1))
yield lambda: op([np.nan, -np.inf, np.inf, np.nan])
yield lambda: op([(1,), (1,), (1,)])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps vary the tuple value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


for fn in sample_functions(op=min):
self._check_min_max(fn)

for fn in sample_functions(op=max):
self._check_min_max(fn)

def test_min_max_supplemental(self):
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved
# check the parallel implementations of
# argmin and argmax, which depend on
# min and max builtins, respectively

def argmin(arr):
return np.argmin(arr)

def argmax(arr):
return np.argmax(arr)

def arrays():
yield np.arange(10, 2, -1)
yield np.ones(365)
yield self.random.randn(10)

for pyfunc in argmin, argmax:
cfunc = njit(parallel=True)(pyfunc)
for arr in arrays():
expected = pyfunc(arr)
got = cfunc(arr)
self.assertPreciseEqual(expected, got)


if __name__ == '__main__':
unittest.main()