Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement sort/sorted with key. #4977

Merged
merged 2 commits into from
Dec 18, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
121 changes: 75 additions & 46 deletions numba/targets/listobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import operator

from llvmlite import ir
from numba import types, cgutils, typing
from numba import types, cgutils, typing, errors
from numba.targets.imputils import (lower_builtin, lower_cast,
iternext_impl, impl_ret_borrowed,
impl_ret_new_ref, impl_ret_untracked,
RefType)
from numba.extending import overload_method, overload
from numba.utils import cached_property
from . import quicksort, slicing

Expand Down Expand Up @@ -1031,57 +1032,85 @@ def list_reverse_impl(lst):
# -----------------------------------------------------------------------------
# Sorting

_sorting_init = False

def load_sorts():
"""
Load quicksort lazily, to avoid circular imports across the jit() global.
"""
g = globals()
if g['_sorting_init']:
return

def gt(a, b):
return a > b

default_sort = quicksort.make_jit_quicksort()
reversed_sort = quicksort.make_jit_quicksort(lt=gt)
g['run_default_sort'] = default_sort.run_quicksort
g['run_reversed_sort'] = reversed_sort.run_quicksort
g['_sorting_init'] = True


@lower_builtin("list.sort", types.List)
@lower_builtin("list.sort", types.List, types.Boolean)
def list_sort(context, builder, sig, args):
load_sorts()

if len(args) == 1:
sig = typing.signature(sig.return_type, *sig.args + (types.boolean,))
args = tuple(args) + (cgutils.false_bit,)

def list_sort_impl(lst, reverse):
if reverse:
run_reversed_sort(lst)
def gt(a, b):
    """Return True when ``a`` is strictly greater than ``b``.

    Passed as the ``lt`` comparison to ``quicksort.make_jit_quicksort`` to
    build the reversed sorters.
    """
    return a > b

# Sorters used when no key function is supplied. The "backwards" variant
# passes `gt` as the `lt` comparison (presumably yielding descending order).
sort_forwards = quicksort.make_jit_quicksort().run_quicksort
sort_backwards = quicksort.make_jit_quicksort(lt=gt).run_quicksort

# Argsort counterparts used when a key function is supplied. They are built
# with is_argsort=True and is_list=True; the caller treats their result as a
# sequence of indices into the original list.
arg_sort_forwards = quicksort.make_jit_quicksort(is_argsort=True,
is_list=True).run_quicksort
arg_sort_backwards = quicksort.make_jit_quicksort(is_argsort=True, lt=gt,
is_list=True).run_quicksort


def _sort_check_reverse(reverse):
    """Validate the ``reverse`` argument of ``sort``/``sorted`` while typing.

    ``types.Omitted`` is unwrapped to its ``.value`` and ``types.Optional``
    to its ``.type``; anything else is checked as given.

    Returns the unwrapped type (or default value, for an omitted argument).

    Raises ``errors.TypingError`` if the unwrapped object is neither a
    ``types.Boolean``/``types.Integer`` instance nor a Python ``int``/``bool``.
    """
    if isinstance(reverse, types.Omitted):
        rty = reverse.value
    elif isinstance(reverse, types.Optional):
        rty = reverse.type
    else:
        rty = reverse

    if isinstance(rty, (types.Boolean, types.Integer, int, bool)):
        return rty

    # Format through a 1-tuple so the operand is always treated as a single
    # value by the % operator.
    msg = "an integer is required for 'reverse' (got type %s)" % (reverse,)
    raise errors.TypingError(msg)


def _sort_check_key(key):
    """Validate the ``key`` argument of ``sort``/``sorted`` while typing.

    Accepted: Python ``None``, a ``types.NoneType`` or ``types.Omitted``
    instance, or a ``types.Dispatcher`` instance.

    Raises ``errors.TypingError`` for a ``types.Optional`` key (with a
    dedicated message) and for any other unsupported key.
    """
    # An Optional key is rejected up front with its own, more specific message.
    if isinstance(key, types.Optional):
        msg = ("Key must concretely be None or a Numba JIT compiled function, "
               "an Optional (union of None and a value) was found")
        raise errors.TypingError(msg)

    none_like = (key is None or
                 isinstance(key, (types.NoneType, types.Omitted)))
    if none_like or isinstance(key, types.Dispatcher):
        return

    msg = "Key must be None or a Numba JIT compiled function"
    raise errors.TypingError(msg)


@overload_method(types.List, "sort")
def ol_list_sort(lst, key=None, reverse=False):

_sort_check_key(key)
_sort_check_reverse(reverse)

if (isinstance(key, (types.NoneType, types.Omitted)) or key is None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use `cgutils.is_nonelike()` instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this in cgutils? It's typing not codegen?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. There's no typing util either; we'll need to sort that out later.

KEY = False
sort_f = sort_forwards
sort_b = sort_backwards
elif isinstance(key, types.Dispatcher):
KEY = True
sort_f = arg_sort_forwards
sort_b = arg_sort_backwards

def impl(lst, key=None, reverse=False):
if KEY is True:
_lst = [key(x) for x in lst]
else:
_lst = lst
if reverse is False or reverse == 0:
tmp = sort_f(_lst)
else:
run_default_sort(lst)
tmp = sort_b(_lst)
if KEY is True:
lst[:] = [lst[i] for i in tmp]
return impl

return context.compile_internal(builder, list_sort_impl, sig, args)

@lower_builtin(sorted, types.IterableType)
@lower_builtin(sorted, types.IterableType, types.Boolean)
def sorted_impl(context, builder, sig, args):
if len(args) == 1:
sig = typing.signature(sig.return_type, *sig.args + (types.boolean,))
args = tuple(args) + (cgutils.false_bit,)
@overload(sorted)
def ol_sorted(iterable, key=None, reverse=False):

def sorted_impl(it, reverse):
lst = list(it)
lst.sort(reverse=reverse)
return lst
if not isinstance(iterable, types.IterableType):
return False

return context.compile_internal(builder, sorted_impl, sig, args)
_sort_check_key(key)
_sort_check_reverse(reverse)

def impl(iterable, key=None, reverse=False):
lst = list(iterable)
lst.sort(key=key, reverse=reverse)
return lst
return impl

# -----------------------------------------------------------------------------
# Implicit casting
Expand Down
13 changes: 9 additions & 4 deletions numba/targets/quicksort.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
MAX_STACK = 100


def make_quicksort_impl(wrap, lt=None, is_argsort=False):
def make_quicksort_impl(wrap, lt=None, is_argsort=False, is_list=False):

intp = types.intp
zero = intp(0)
Expand All @@ -36,9 +36,14 @@ def make_quicksort_impl(wrap, lt=None, is_argsort=False):
# or normal sorting. Note the genericity may make basic sort()
# slightly slower (~5%)
if is_argsort:
@wrap
def make_res(A):
return np.arange(A.size)
if is_list:
@wrap
def make_res(A):
return [x for x in range(len(A))]
else:
@wrap
def make_res(A):
return np.arange(A.size)

@wrap
def GET(A, idx_or_val):
Expand Down
152 changes: 150 additions & 2 deletions numba/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,20 @@
import numpy as np

from numba.compiler import compile_isolated, Flags
from numba import jit, types, utils, njit
from numba import jit, types, utils, njit, errors
import numba.unittest_support as unittest
from numba import testing
from numba.six import PY2
from .support import TestCase, MemoryLeakMixin, tag

from numba.targets.quicksort import make_py_quicksort, make_jit_quicksort
from numba.targets.mergesort import make_jit_mergesort
from .timsort import make_py_timsort, make_jit_timsort, MergeRun


# Test decorator: skips the decorated test when PY2 is true.
skip_py_27 = unittest.skipIf(PY2, "Not supported on Python 2")


def make_temp_list(keys, n):
    """Return a new list of length ``n`` filled with ``keys[0]``.

    ``keys`` must be non-empty, because its first element is always read.
    """
    return [keys[0]] * n

Expand Down Expand Up @@ -898,7 +902,7 @@ def test_sorted_reverse(self):
self.assertNotEqual(list(orig), got) # sanity check


class TestMergeSort(unittest.TestCase):
class TestMergeSort(TestCase):
def setUp(self):
np.random.seed(321)

Expand All @@ -923,5 +927,149 @@ def test_argsort_stable(self):
self.check_argsort_stable(sorter, *args)


def nop_compiler(x):
    """Identity stand-in for a compiler decorator: return ``x`` unchanged."""
    return x


class TestSortSlashSortedWithKey(MemoryLeakMixin, TestCase):

def test_01(self):

a = [3, 1, 4, 1, 5, 9]

@njit
def external_key(z):
return 1. / z

@njit
def foo(x, key=None):
return sorted(x[:], key=key), x[:].sort(key=key)
sklam marked this conversation as resolved.
Show resolved Hide resolved

self.assertPreciseEqual(foo(a[:]), foo.py_func(a[:]))
self.assertPreciseEqual(foo(a[:], external_key),
foo.py_func(a[:], external_key))

def test_02(self):

a = [3, 1, 4, 1, 5, 9]

@njit
def foo(x):
def closure_key(z):
return 1. / z
return sorted(x[:], key=closure_key), x[:].sort(key=closure_key)
sklam marked this conversation as resolved.
Show resolved Hide resolved

self.assertPreciseEqual(foo(a[:]), foo.py_func(a[:]))

def test_03(self):

a = [3, 1, 4, 1, 5, 9]

def gen(compiler):

@compiler
def bar(x, func):
return sorted(x[:], key=func), x[:].sort(key=func)
sklam marked this conversation as resolved.
Show resolved Hide resolved

@compiler
def foo(x):
def closure_escapee_key(z):
return 1. / z
return bar(x, closure_escapee_key)

return foo

self.assertPreciseEqual(gen(njit)(a[:]), gen(nop_compiler)(a[:]))

@skip_py_27
def test_04(self):

a = ['a','b','B','b','C','A']

@njit
def external_key(z):
return z.upper()

@njit
def foo(x, key=None):
return sorted(x[:], key=key), x[:].sort(key=key)
sklam marked this conversation as resolved.
Show resolved Hide resolved

self.assertPreciseEqual(foo(a[:]), foo.py_func(a[:]))
self.assertPreciseEqual(foo(a[:], external_key),
foo.py_func(a[:], external_key))

@skip_py_27
def test_05(self):

a = ['a','b','B','b','C','A']

@njit
def external_key(z):
return z.upper()

@njit
def foo(x, key=None, reverse=False):
return (sorted(x[:], key=key, reverse=reverse),
x[:].sort(key=key, reverse=reverse))
sklam marked this conversation as resolved.
Show resolved Hide resolved

for key, rev in itertools.product((None, external_key),
(True, False, 1, -12, 0)):
self.assertPreciseEqual(foo(a[:], key, rev),
foo.py_func(a[:], key, rev))

def test_optional_on_key(self):
a = [3, 1, 4, 1, 5, 9]

@njit
def foo(x, predicate):
if predicate:
def closure_key(z):
return 1. / z
else:
closure_key = None

return (sorted(x[:], key=closure_key),
x[:].sort(key=closure_key))
sklam marked this conversation as resolved.
Show resolved Hide resolved

with self.assertRaises(errors.TypingError) as raises:
TF = True
foo(a[:], TF)

msg = "Key must concretely be None or a Numba JIT compiled function"
self.assertIn(msg, str(raises.exception))

def test_exceptions_sorted(self):

@njit
def foo_sorted(x, key=None, reverse=False):
return sorted(x[:], key=key, reverse=reverse)

@njit
def foo_sort(x, key=None, reverse=False):
return x[:].sort(key=key, reverse=reverse)
sklam marked this conversation as resolved.
Show resolved Hide resolved

@njit
def external_key(z):
return 1. / z

a = [3, 1, 4, 1, 5, 9]

for impl in (foo_sort, foo_sorted):

# check illegal key
with self.assertRaises(errors.TypingError) as raises:
impl(a, key="illegal")

expect = "Key must be None or a Numba JIT compiled function"
self.assertIn(expect, str(raises.exception))

# check illegal reverse
with self.assertRaises(errors.TypingError) as raises:
impl(a, key=external_key, reverse="go backwards")

expect = "an integer is required for 'reverse'"
self.assertIn(expect, str(raises.exception))



if __name__ == '__main__':
unittest.main()
27 changes: 0 additions & 27 deletions numba/typing/listdecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,6 @@ def generic(self, args, kws):
return signature(types.List(types.undefined))


@infer_global(sorted)
class SortedBuiltin(CallableTemplate):

def generic(self):
def typer(iterable, reverse=None):
if not isinstance(iterable, types.IterableType):
return
if (reverse is not None and
not isinstance(reverse, types.Boolean)):
return
return types.List(iterable.iterator_type.yield_type)

return typer


@infer_getattr
class ListAttribute(AttributeTemplate):
key = types.List
Expand Down Expand Up @@ -138,18 +123,6 @@ def resolve_reverse(self, list, args, kws):
assert not kws
return signature(types.none)

def resolve_sort(self, list):
def typer(reverse=None):
if (reverse is not None and
not isinstance(reverse, types.Boolean)):
return
return types.none

return types.BoundFunction(make_callable_template(key="list.sort",
typer=typer,
recvr=list),
list)


@infer_global(operator.add)
class AddList(AbstractTemplate):
Expand Down