Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Min max on iterables #3820

Merged
merged 12 commits into from
Mar 14, 2019
31 changes: 30 additions & 1 deletion numba/targets/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def lower_get_type_max_value(context, builder, sig, args):
# -----------------------------------------------------------------------------

from numba.typing.builtins import IndexValue, IndexValueType
from numba.extending import overload
from numba.extending import overload, register_jitable

@lower_builtin(IndexValue, types.intp, types.Type)
@lower_builtin(IndexValue, types.uintp, types.Type)
Expand All @@ -447,6 +447,7 @@ def impl_index_value(context, builder, sig, args):
index_value.value = value
return index_value._getvalue()


@overload(min)
def indval_min(indval1, indval2):
if isinstance(indval1, IndexValueType) and \
Expand All @@ -457,6 +458,7 @@ def min_impl(indval1, indval2):
return indval1
return min_impl


@overload(max)
def indval_max(indval1, indval2):
if isinstance(indval1, IndexValueType) and \
Expand All @@ -466,3 +468,30 @@ def max_impl(indval1, indval2):
return indval2
return indval1
return max_impl


greater_than = register_jitable(lambda a, b: a > b)
less_than = register_jitable(lambda a, b: a < b)


@register_jitable
def min_max_impl(iterable, op):
if isinstance(iterable, types.IterableType):
def impl(iterable):
it = iter(iterable)
return_val = next(it)
for val in it:
if op(val, return_val):
return_val = val
return return_val
return impl


@overload(min)
def iterable_min(iterable):
return min_max_impl(iterable, less_than)


@overload(max)
def iterable_max(iterable):
return min_max_impl(iterable, greater_than)
5 changes: 2 additions & 3 deletions numba/targets/heapq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import print_function, absolute_import, division

import heapq as hq
import numpy as np

from numba import types
from numba.errors import TypingError
Expand Down Expand Up @@ -209,7 +208,7 @@ def hq_nsmallest_impl(n, iterable):
if n == 0:
return [iterable[0] for _ in range(0)]
elif n == 1:
out = np.min(np.asarray(iterable))
out = min(iterable)
return [out]

size = len(iterable)
Expand Down Expand Up @@ -243,7 +242,7 @@ def hq_nlargest_impl(n, iterable):
if n == 0:
return [iterable[0] for _ in range(0)]
elif n == 1:
out = np.max(np.asarray(iterable))
out = max(iterable)
return [out]

size = len(iterable)
Expand Down
33 changes: 32 additions & 1 deletion numba/tests/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import numba.unittest_support as unittest
from numba.compiler import compile_isolated, Flags
from numba import jit, typeof, errors, types, utils, config
from numba import jit, typeof, errors, types, utils, config, njit
from .support import TestCase, tag


Expand Down Expand Up @@ -982,6 +982,37 @@ def test_pow_usecase(self):
r = cres.entry_point(x, y)
self.assertPreciseEqual(r, pow_usecase(x, y))

def _check_min_max(self, pyfunc):
cfunc = njit()(pyfunc)
expected = pyfunc()
got = cfunc()
self.assertPreciseEqual(expected, got)

def test_min_max_iterable_input(self):

@njit
def frange(start, stop, step):
i = start
while i < stop:
yield i
i += step

def sample_functions(op):
yield lambda: op(range(10))
yield lambda: op(range(4, 12))
yield lambda: op(range(-4, -15, -1))
yield lambda: op([6.6, 5.5, 7.7])
yield lambda: op([(3, 4), (1, 2)])
yield lambda: op(frange(1.1, 3.3, 0.1))
yield lambda: op([np.nan, -np.inf, np.inf, np.nan])
yield lambda: op([(3,), (1,), (2,)])

for fn in sample_functions(op=min):
self._check_min_max(fn)

for fn in sample_functions(op=max):
self._check_min_max(fn)


if __name__ == '__main__':
unittest.main()