370 changes: 333 additions & 37 deletions std/algorithm/sorting.d
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ if (isRandomAccessRange!Range && hasLength!Range && hasSlicing!Range)
// Pivot at the front
r.swapAt(pivot, 0);

// Fork implemnentation depending on nothrow copy, assignment, and
// Fork implementation depending on nothrow copy, assignment, and
// comparison. If all of these are nothrow, use the specialized
// implementation discussed at https://youtube.com/watch?v=AxnotgLql0k.
static if (is(typeof(
Expand All @@ -640,8 +640,8 @@ if (isRandomAccessRange!Range && hasLength!Range && hasSlicing!Range)
version(unittest)
{
import std.algorithm.searching;
assert(r[0 .. lo].all!(x => x <= p));
assert(r[hi + 1 .. $].all!(x => x >= p));
assert(r[0 .. lo].all!(x => !lt(p, x)));
assert(r[hi + 1 .. $].all!(x => !lt(x, p)));
}
do ++lo; while (lt(r[lo], p));
r[hi] = r[lo];
Expand Down Expand Up @@ -742,9 +742,15 @@ if (isRandomAccessRange!Range && hasLength!Range && hasSlicing!Range)
assert(pivot == 0 || pivot == 1);
assert(a == [ 42, 42 ]);

import std.random : uniform;
import std.random;
import std.algorithm.iteration : map;
a = iota(0, uniform(1, 1000)).map!(_ => uniform(-1000, 1000)).array;
import std.stdio;
auto s = unpredictableSeed;
auto g = Random(s);
a = iota(0, uniform(1, 1000, g))
.map!(_ => uniform(-1000, 1000, g))
.array;
scope(failure) writeln("RNG seed was ", s);
pivot = pivotPartition!less(a, a.length / 2);
assert(a[0 .. pivot].all!(x => x <= a[pivot]));
assert(a[pivot .. $].all!(x => x >= a[pivot]));
Expand Down Expand Up @@ -2934,52 +2940,125 @@ auto topN(alias less = "a < b",
{
static assert(ss == SwapStrategy.unstable,
"Stable topN not yet implemented");

if (nth >= r.length) return r[0 .. r.length];

auto ret = r[0 .. nth];
if (false)
{
// Workaround for https://issues.dlang.org/show_bug.cgi?id=16528
// Safety checks: enumerate all potentially unsafe generic primitives
// then use a @trusted implementation.
r = r[0 .. $];
r = r[0 .. $ - 1];
auto b = binaryFun!less(r[0], r[$ - 1]);
import std.algorithm.mutation : swapAt;
r.swapAt(size_t(0), size_t(0));
auto len = r.length;
static assert(is(typeof(len) == size_t));
pivotPartition!less(r, 0);
}
bool useSampling = true;
topNImpl!(binaryFun!less)(r, nth, useSampling);
return ret;
}

private @trusted
void topNImpl(alias less, R)(R r, size_t n, ref bool useSampling)
{
for (;;)
{
assert(nth < r.length);
import std.algorithm.mutation : swap;
import std.algorithm.searching : minPos;
if (nth == 0)
import std.algorithm.mutation : swapAt;
assert(n < r.length);
size_t pivot = void;

// Decide strategy for partitioning
if (n == 0)
{
// Special-case "min"
swap(r.front, r.minPos!less.front);
break;
pivot = 0;
foreach (i; 1 .. r.length)
if (less(r[i], r[pivot])) pivot = i;
r.swapAt(n, pivot);
return;
}
if (nth + 1 == r.length)
if (n + 1 == r.length)
{
// Special-case "max"
swap(r.back, r.minPos!((a, b) => binaryFun!less(b, a)).front);
break;
pivot = 0;
foreach (i; 1 .. r.length)
if (less(r[pivot], r[i])) pivot = i;
r.swapAt(n, pivot);
return;
}
if (r.length <= 12)
{
pivot = pivotPartition!less(r, r.length / 2);
}
else if (n * 16 <= (r.length - 1) * 7)
{
pivot = topNPartitionOffMedian!(less, No.leanRight)
(r, n, useSampling);
// Quality check
if (useSampling)
{
if (pivot < n)
{
if (pivot * 4 < r.length)
{
useSampling = false;
}
}
else if ((r.length - pivot) * 8 < r.length * 3)
{
useSampling = false;
}
}
}
else if (n * 16 >= (r.length - 1) * 9)
{
pivot = topNPartitionOffMedian!(less, Yes.leanRight)
(r, n, useSampling);
// Quality check
if (useSampling)
{
if (pivot < n)
{
if (pivot * 8 < r.length * 3)
{
useSampling = false;
}
}
else if ((r.length - pivot) * 4 < r.length)
{
useSampling = false;
}
}
}
else
{
pivot = topNPartition!less(r, n, useSampling);
// Quality check
if (useSampling &&
(pivot * 9 < r.length * 2 || pivot * 9 > r.length * 7))
{
// Failed - abort sampling going forward
useSampling = false;
}
}

assert(pivot != size_t.max);
// See how the pivot fares
if (pivot == n)
{
return;
}
auto pivot = r.getPivot!less;
assert(!binaryFun!less(r[pivot], r[pivot]));
swap(r[pivot], r.back);
auto right = r.partition!(a => binaryFun!less(a, r.back), ss);
assert(right.length >= 1);
pivot = r.length - right.length;
if (pivot > nth)
if (pivot > n)
{
// We don't care to swap the pivot back, won't be visited anymore
assert(pivot < r.length);
r = r[0 .. pivot];
continue;
}
// Swap the pivot to where it should be
swap(right.front, r.back);
if (pivot == nth)
else
{
// Found Waldo!
break;
n -= pivot + 1;
r = r[pivot + 1 .. $];
}
++pivot; // skip the pivot
r = r[pivot .. r.length];
nth -= pivot;
}
return ret;
}

///
Expand All @@ -2993,6 +3072,223 @@ auto topN(alias less = "a < b",
assert(v[n] == 9);
}

/*
Partitions `r` around a pivot picked from a sample of `r.length / 9` elements
centered on the middle of the range, for use when `n` is close to the median.

The sample r[lo .. hi] is first run through `p3` (and, when `useSampling` is
false, the wider band r[lo - ninth .. hi + ninth] is run through `p3` before
that). `topNImpl` is then applied to the sample to select its element at the
rank proportional to `n`, and `expandPartition` extends that partition to the
whole of `r`.

Params:
    lp = "less than" predicate
    r = range to partition; must have at least 9 elements
    n = rank being searched for in `r`; must be less than `r.length`
    useSampling = when false, the extra `p3` pass over the wider band is run.
        It is received by value, so any change `topNImpl` makes to it through
        its `ref` parameter is not visible to the caller of this function.

Returns: the value returned by `expandPartition`, i.e. the pivot's new position
*/
private size_t topNPartition(alias lp, R)(R r, size_t n, bool useSampling)
{
    assert(r.length >= 9 && n < r.length);
    immutable ninth = r.length / 9;
    auto pivot = ninth / 2;
    // Position subrange r[lo .. hi] to have length equal to ninth and its upper
    // median r[lo .. hi][$ / 2] in exactly the same place as the upper median
    // of the entire range r[$ / 2]. This is to improve behavior for searching
    // the median in already sorted ranges.
    immutable lo = r.length / 2 - pivot, hi = lo + ninth;
    // We have either one straggler on the left, one on the right, or none.
    // (Unsigned arithmetic: whichever difference would be negative wraps to a
    // huge value, so only the non-negative one can satisfy the test.)
    assert(lo - (r.length - hi) <= 1 || (r.length - hi) - lo <= 1);
    assert(lo >= ninth * 4);
    assert(r.length - hi >= ninth * 4);

    // Partition in groups of 3, and the mid tertile again in groups of 3
    if (!useSampling)
        p3!lp(r, lo - ninth, hi + ninth);
    p3!lp(r, lo, hi);

    // Get the median of medians of medians
    // Map the full interval of n to the full interval of the ninth.
    // The product is formed in 64 bits: with a 32-bit size_t,
    // n * (ninth - 1) would wrap for ranges of a few hundred thousand
    // elements and yield a badly placed pivot. The quotient never exceeds
    // ninth - 1, so narrowing it back is lossless.
    pivot = cast(size_t) ((cast(ulong) n * (ninth - 1)) / (r.length - 1));
    topNImpl!lp(r[lo .. hi], pivot, useSampling);
    return expandPartition!lp(r, lo, pivot + lo, hi);
}

/*
For every index `mid` in [lo, hi), calls `medianOf!less` on the three indexes
`mid - stride`, `mid`, `mid + stride`, where `stride` is `hi - lo`.

Params:
    less = "less than" predicate forwarded to `medianOf`
    r = range being rearranged
    lo = first middle index
    hi = one past the last middle index; must be less than `r.length`

NOTE(review): presumably `medianOf` leaves the median of the three at the
middle index - confirm against its definition, which is not shown here.
*/
private void p3(alias less, Range)(Range r, size_t lo, immutable size_t hi)
{
    assert(lo <= hi && hi < r.length);
    immutable stride = hi - lo;
    foreach (immutable mid; lo .. hi)
    {
        // Both outer indexes of the triple must fall inside the range.
        assert(mid >= stride);
        assert(mid + stride < r.length);
        medianOf!less(r, mid - stride, mid, mid + stride);
    }
}

/*
For every index `i` in [lo, hi), calls `medianOf!(less, f)` on four indexes
spaced `hi - lo` apart. With `Yes.leanRight` two of them lie before `i` and one
after; with `No.leanRight` one lies before `i` and two after.

Params:
    less = "less than" predicate forwarded to `medianOf`
    f = which side the extra fourth index is taken from
    r = range being rearranged
    lo = first index visited
    hi = one past the last index visited; must be less than `r.length`
*/
private void p4(alias less, Flag!"leanRight" f, Range)
    (Range r, size_t lo, immutable size_t hi)
{
    assert(lo <= hi && hi < r.length);
    immutable step = hi - lo, twoSteps = step * 2;
    foreach (immutable i; lo .. hi)
    {
        assert(i >= step);
        assert(i + step < r.length);
        static if (f == Yes.leanRight)
        {
            medianOf!(less, f)(r, i - twoSteps, i - step, i, i + step);
        }
        else
        {
            medianOf!(less, f)(r, i - step, i, i + step, i + twoSteps);
        }
    }
}

/*
Partitions `r` around a pivot picked from a sample of `r.length / 12` elements
located in the second quarter of the range (`No.leanRight`) or in the third
quarter (`Yes.leanRight`).

When `useSampling` is false the chosen quarter is first run through `p4`. The
sample r[lo .. hi] is then run through `p3`, `topNImpl` selects the sample's
element at the rank proportional to `n`, and `expandPartition` extends that
partition to the whole of `r`.

Params:
    lp = "less than" predicate
    f = selects which quarter of `r` the sample is taken from
    r = range to partition; must have at least 12 elements
    n = rank being searched for in `r`; must be less than `r.length`
    useSampling = when false, the extra `p4` pass is run. It is received by
        value, so any change `topNImpl` makes to it through its `ref`
        parameter is not visible to the caller of this function.

Returns: the value returned by `expandPartition`, i.e. the pivot's new position
*/
private size_t topNPartitionOffMedian(alias lp, Flag!"leanRight" f, R)
    (R r, size_t n, bool useSampling)
{
    assert(r.length >= 12);
    assert(n < r.length);
    immutable _4 = r.length / 4;
    static if (f == Yes.leanRight)
        immutable leftLimit = 2 * _4;
    else
        immutable leftLimit = _4;
    // Partition in groups of 4, and the middle third of that quarter again
    // in groups of 3
    if (!useSampling)
    {
        p4!(lp, f)(r, leftLimit, leftLimit + _4);
    }
    immutable _12 = _4 / 3;
    immutable lo = leftLimit + _12, hi = lo + _12;
    p3!lp(r, lo, hi);

    // Get the median of medians of medians
    // Map the full interval of n to the full interval of the twelfth.
    // The product is formed in 64 bits: with a 32-bit size_t,
    // n * (_12 - 1) would wrap for ranges of a few hundred thousand elements
    // and yield a badly placed pivot. The quotient never exceeds _12 - 1, so
    // narrowing it back is lossless.
    immutable pivot =
        cast(size_t) ((cast(ulong) n * (_12 - 1)) / (r.length - 1));
    topNImpl!lp(r[lo .. hi], pivot, useSampling);
    return expandPartition!lp(r, lo, pivot + lo, hi);
}

/*
Extends a partition that already holds on the slice r[lo .. hi] to the whole
of r, around the same pivot value.

Params:
lp = "less than" predicate
r = range to partition
lo = start of the pre-partitioned slice: no element of r[lo .. pivot + 1] is
greater than r[pivot]
pivot = index of the pivot element, with lo <= pivot < hi
hi = end of the pre-partitioned slice: no element of r[pivot + 1 .. hi] is
less than r[pivot]
Returns: new position of pivot
*/
private
size_t expandPartition(alias lp, R)(R r, size_t lo, size_t pivot, size_t hi)
in
{
import std.algorithm.searching : all;
// Index ordering: lo <= pivot < hi <= r.length
assert(lo <= pivot);
assert(pivot < hi);
assert(hi <= r.length);
// The slice r[lo .. hi] must already be partitioned around r[pivot]
assert(r[lo .. pivot + 1].all!(x => !lp(r[pivot], x)));
assert(r[pivot + 1 .. hi].all!(x => !lp(x, r[pivot])));
}
out
{
import std.algorithm.searching : all;
// The whole range is partitioned around r[pivot]
assert(r[0 .. pivot + 1].all!(x => !lp(r[pivot], x)));
assert(r[pivot + 1 .. $].all!(x => !lp(x, r[pivot])));
}
body
{
import std.algorithm.mutation : swapAt;
import std.algorithm.searching : all;
// We work with closed intervals!
--hi;

// Phase 1: scan inward from both ends of r. `left` stops on an element that
// is not less than the pivot, `rite` on one that is not greater, and the two
// are swapped. The phase ends as soon as either cursor reaches the boundary
// of the pre-partitioned slice.
size_t left = 0, rite = r.length - 1;
loop: for (;; ++left, --rite)
{
for (;; ++left)
{
if (left == lo) break loop;
if (!lp(r[left], r[pivot])) break;
}
for (;; --rite)
{
if (rite == hi) break loop;
if (!lp(r[pivot], r[rite])) break;
}
r.swapAt(left, rite);
}

// The pre-partitioned slice is untouched, and everything the two cursors
// have passed is on the correct side of the pivot.
assert(r[lo .. pivot + 1].all!(x => !lp(r[pivot], x)));
assert(r[pivot + 1 .. hi + 1].all!(x => !lp(x, r[pivot])));
assert(r[0 .. left].all!(x => !lp(r[pivot], x)));
assert(r[rite + 1 .. $].all!(x => !lp(x, r[pivot])));

// The pivot value stays at oldPivot until the very end; from here on `pivot`
// tracks the slot the pivot value will finally be moved into.
immutable oldPivot = pivot;

// Phase 2a: phase 1 stopped because `rite` reached hi, so part of the left
// side is still unexamined. Every path through this branch ends in
// `goto done`.
if (left < lo)
{
// First loop: spend r[lo .. pivot]. An element greater than the pivot value
// found at `left` is exchanged with the element just below the pivot slot,
// which is known not to be greater; each exchange lowers `pivot` by one.
for (; lo < pivot; ++left)
{
if (left == lo) goto done;
if (!lp(r[oldPivot], r[left])) continue;
--pivot;
assert(!lp(r[oldPivot], r[pivot]));
r.swapAt(left, pivot);
}
// Second loop: make left and pivot meet. For each element at `left` that is
// greater than the pivot value, walk `pivot` downward until it finds an
// element less than the pivot value to exchange with.
for (;; ++left)
{
if (left == pivot) goto done;
if (!lp(r[oldPivot], r[left])) continue;
for (;;)
{
if (left == pivot) goto done;
--pivot;
if (lp(r[pivot], r[oldPivot]))
{
r.swapAt(left, pivot);
break;
}
}
}
}

// Phase 2b: reached only when `left` got to lo, i.e. the left side is done.
// Mirror image of phase 2a, working on the right side.
// First loop: spend r[pivot + 1 .. hi + 1]. An element less than the pivot
// value found at `rite` is exchanged with the element just above the pivot
// slot; each exchange raises `pivot` by one.
for (; hi != pivot; --rite)
{
if (rite == hi) goto done;
if (!lp(r[rite], r[oldPivot])) continue;
++pivot;
assert(!lp(r[pivot], r[oldPivot]));
r.swapAt(rite, pivot);
}
// Second loop: make rite and pivot meet. For each element at `rite` that is
// less than the pivot value, walk `pivot` upward until it finds an element
// greater than the pivot value to exchange with.
// (The `outer` label is not referenced anywhere in this function.)
outer: for (; rite > pivot; --rite)
{
if (!lp(r[rite], r[oldPivot])) continue;
while (rite > pivot)
{
++pivot;
if (lp(r[oldPivot], r[pivot]))
{
r.swapAt(rite, pivot);
break;
}
}
}

done:
// Move the pivot value from its original position into its final slot.
r.swapAt(oldPivot, pivot);
return pivot;
}

unittest
{
    // Fixed input: the slice [4 .. 6) is already partitioned around index 5,
    // and the pivot is expected to end up at index 9.
    auto data = [ 10, 5, 3, 4, 8, 11, 13, 3, 9, 4, 10 ];
    immutable landed = expandPartition!((x, y) => x < y)(data, 4, 5, 6);
    assert(landed == 9);

    // Random input: expand from the one-element slice at the middle; the
    // function's own contracts do the checking.
    data = randomArray;
    if (data.length != 0)
    {
        immutable mid = data.length / 2;
        expandPartition!((x, y) => x < y)(data, mid, mid, mid + 1);
    }
}

/*
Test helper: builds an array of random values using `std.random.uniform` with
its default random number generator.

Params:
    flag = with `Yes.exactSize` the result has exactly `maxSize` elements;
        with `No.exactSize` the length is itself drawn with
        `uniform(1, maxSize)`
    T = element type
    maxSize = requested length, or the upper bound passed to `uniform` when
        the length is random
    minValue = lower bound passed to `uniform` for each element
    maxValue = upper bound passed to `uniform` for each element

Returns: a freshly allocated array of random elements

NOTE(review): `uniform` rejects an empty interval, so `maxSize` must exceed 1
when the length is random and `maxValue` must exceed `minValue` - confirm
against the std.random documentation.
*/
version(unittest)
private T[] randomArray(Flag!"exactSize" flag = No.exactSize, T = int)(
    size_t maxSize = 1000,
    T minValue = 0, T maxValue = 255)
{
    import std.algorithm.iteration : map;
    import std.random : uniform;
    // `flag` is a template argument, so the choice is made at compile time.
    static if (flag == Yes.exactSize)
        immutable size_t size = maxSize;
    else
        immutable size_t size = uniform(1, maxSize);
    return iota(0, size).map!(_ => uniform(minValue, maxValue)).array;
}

@safe unittest
{
import std.algorithm.comparison : max, min;
Expand Down