Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
244 lines (205 sloc) 7.08 KB
import collections
import numpy as np
from numba.core import types
QuicksortImplementation = collections.namedtuple(
'QuicksortImplementation',
(# The compile function itself
'compile',
# All subroutines exercised by test_sort
'partition', 'partition3', 'insertion_sort',
# The top-level function
'run_quicksort',
))
Partition = collections.namedtuple('Partition', ('start', 'stop'))
# Under this size, switch to a simple insertion sort
SMALL_QUICKSORT = 15
MAX_STACK = 100
def make_quicksort_impl(wrap, lt=None, is_argsort=False, is_list=False):
intp = types.intp
zero = intp(0)
# Two subroutines to make the core algorithm generic wrt. argsort
# or normal sorting. Note the genericity may make basic sort()
# slightly slower (~5%)
if is_argsort:
if is_list:
@wrap
def make_res(A):
return [x for x in range(len(A))]
else:
@wrap
def make_res(A):
return np.arange(A.size)
@wrap
def GET(A, idx_or_val):
return A[idx_or_val]
else:
@wrap
def make_res(A):
return A
@wrap
def GET(A, idx_or_val):
return idx_or_val
def default_lt(a, b):
"""
Trivial comparison function between two keys.
"""
return a < b
LT = wrap(lt if lt is not None else default_lt)
@wrap
def insertion_sort(A, R, low, high):
"""
Insertion sort A[low:high + 1]. Note the inclusive bounds.
"""
assert low >= 0
if high <= low:
return
for i in range(low + 1, high + 1):
k = R[i]
v = GET(A, k)
# Insert v into A[low:i]
j = i
while j > low and LT(v, GET(A, R[j - 1])):
# Make place for moving A[i] downwards
R[j] = R[j - 1]
j -= 1
R[j] = k
@wrap
def partition(A, R, low, high):
"""
Partition A[low:high + 1] around a chosen pivot. The pivot's index
is returned.
"""
assert low >= 0
assert high > low
mid = (low + high) >> 1
# NOTE: the pattern of swaps below for the pivot choice and the
# partitioning gives good results (i.e. regular O(n log n))
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
# risk breaking this property.
# median of three {low, middle, high}
if LT(GET(A, R[mid]), GET(A, R[low])):
R[low], R[mid] = R[mid], R[low]
if LT(GET(A, R[high]), GET(A, R[mid])):
R[high], R[mid] = R[mid], R[high]
if LT(GET(A, R[mid]), GET(A, R[low])):
R[low], R[mid] = R[mid], R[low]
pivot = GET(A, R[mid])
# Temporarily stash the pivot at the end
R[high], R[mid] = R[mid], R[high]
i = low
j = high - 1
while True:
while i < high and LT(GET(A, R[i]), pivot):
i += 1
while j >= low and LT(pivot, GET(A, R[j])):
j -= 1
if i >= j:
break
R[i], R[j] = R[j], R[i]
i += 1
j -= 1
# Put the pivot back in its final place (all items before `i`
# are smaller than the pivot, all items at/after `i` are larger)
R[i], R[high] = R[high], R[i]
return i
@wrap
def partition3(A, low, high):
"""
Three-way partition [low, high) around a chosen pivot.
A tuple (lt, gt) is returned such that:
- all elements in [low, lt) are < pivot
- all elements in [lt, gt] are == pivot
- all elements in (gt, high] are > pivot
"""
mid = (low + high) >> 1
# median of three {low, middle, high}
if LT(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
if LT(A[high], A[mid]):
A[high], A[mid] = A[mid], A[high]
if LT(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]
A[low], A[mid] = A[mid], A[low]
lt = low
gt = high
i = low + 1
while i <= gt:
if LT(A[i], pivot):
A[lt], A[i] = A[i], A[lt]
lt += 1
i += 1
elif LT(pivot, A[i]):
A[gt], A[i] = A[i], A[gt]
gt -= 1
else:
i += 1
return lt, gt
@wrap
def run_quicksort(A):
R = make_res(A)
if len(A) < 2:
return R
stack = [Partition(zero, zero)] * MAX_STACK
stack[0] = Partition(zero, len(A) - 1)
n = 1
while n > 0:
n -= 1
low, high = stack[n]
# Partition until it becomes more efficient to do an insertion sort
while high - low >= SMALL_QUICKSORT:
assert n < MAX_STACK
i = partition(A, R, low, high)
# Push largest partition on the stack
if high - i > i - low:
# Right is larger
if high > i:
stack[n] = Partition(i + 1, high)
n += 1
high = i - 1
else:
if i > low:
stack[n] = Partition(low, i - 1)
n += 1
low = i + 1
insertion_sort(A, R, low, high)
return R
# Unused quicksort implementation based on 3-way partitioning; the
# partitioning scheme turns out exhibiting bad behaviour on sorted arrays.
@wrap
def _run_quicksort(A):
stack = [Partition(zero, zero)] * 100
stack[0] = Partition(zero, len(A) - 1)
n = 1
while n > 0:
n -= 1
low, high = stack[n]
# Partition until it becomes more efficient to do an insertion sort
while high - low >= SMALL_QUICKSORT:
assert n < MAX_STACK
l, r = partition3(A, low, high)
# One trivial (empty) partition => iterate on the other
if r == high:
high = l - 1
elif l == low:
low = r + 1
# Push largest partition on the stack
elif high - r > l - low:
# Right is larger
stack[n] = Partition(r + 1, high)
n += 1
high = l - 1
else:
stack[n] = Partition(low, l - 1)
n += 1
low = r + 1
insertion_sort(A, low, high)
return QuicksortImplementation(wrap,
partition, partition3, insertion_sort,
run_quicksort)
def make_py_quicksort(*args, **kwargs):
return make_quicksort_impl((lambda f: f), *args, **kwargs)
def make_jit_quicksort(*args, **kwargs):
from numba.core.extending import register_jitable
return make_quicksort_impl((lambda f: register_jitable(f)),
*args, **kwargs)
You can’t perform that action at this time.