Skip to content

Commit

Permalink
Fix Issue 18280 - std.algorithm.comparison.cmp for non-strings should…
Browse files Browse the repository at this point in the history
… call opCmp only once per item pair

split cmp into two overloads per @andralex

#6056 (review)

Minor adjustments, again

cmp should return auto and let opCmp drive

#6056 (comment)

Fix Issue 18285 - std.algorithm.comparison.cmp for strings with custom predicate compares lengths wrong

Test std.algorithm.comparison.cmp when opCmp returns float

Promotions should not use cast

Optimize cmp's endgame

There are some redundant tests when the end of the ranges is reached. Eliminated that, and improved threeWayByPred.

Fix Issue 18286 - std.algorithm.comparison.cmp for string with custom predicate fails if distinct chars can compare equal

Fix Issue 18288 - std.algorithm.comparison.cmp for wide strings should be @safe

re-apply remove cast in promotions
  • Loading branch information
n8sh committed Jan 24, 2018
1 parent d11a3d0 commit 7f59e5a
Showing 1 changed file with 206 additions and 42 deletions.
248 changes: 206 additions & 42 deletions std/algorithm/comparison.d
Expand Up @@ -580,79 +580,93 @@ do

// cmp
/**********************************
Performs three-way lexicographical comparison on two
$(REF_ALTTEXT input ranges, isInputRange, std,range,primitives)
according to predicate `pred`. Iterating `r1` and `r2` in
lockstep, `cmp` compares each element `e1` of `r1` with the
corresponding element `e2` in `r2`. If one of the ranges has been
finished, `cmp` returns a negative value if `r1` has fewer
elements than `r2`, a positive value if `r1` has more elements
than `r2`, and `0` if the ranges have the same number of
elements.
Performs a lexicographical comparison on two
$(REF_ALTTEXT input ranges, isInputRange, std,range,primitives).
Iterating `r1` and `r2` in lockstep, `cmp` compares each element
`e1` of `r1` with the corresponding element `e2` in `r2`. If one
of the ranges has been finished, `cmp` returns a negative value
if `r1` has fewer elements than `r2`, a positive value if `r1`
has more elements than `r2`, and `0` if the ranges have the same
number of elements.
If the ranges are strings, `cmp` performs UTF decoding
appropriately and compares the ranges one code point at a time.
A custom predicate may be specified, in which case `cmp` performs
a three-way lexicographical comparison using `pred`. Otherwise
the elements are compared using `opCmp`.
Params:
pred = The predicate used for comparison.
pred = Predicate used for comparison. Without a predicate
specified the ordering implied by `opCmp` is used.
r1 = The first range.
r2 = The second range.
Returns:
0 if both ranges compare equal. -1 if the first differing element of $(D
r1) is less than the corresponding element of `r2` according to $(D
pred). 1 if the first differing element of `r2` is less than the
corresponding element of `r1` according to `pred`.
`0` if the ranges compare equal. A negative value if `r1` is a prefix of `r2` or
the first differing element of `r1` is less than the corresponding element of `r2`
according to `pred`. A positive value if `r2` is a prefix of `r1` or the first
differing element of `r2` is less than the corresponding element of `r1`
according to `pred`.
Note:
An earlier version of the documentation incorrectly stated that `-1` is the
only negative value returned and `1` is the only positive value returned.
Whether that is true depends on the types being compared.
*/
int cmp(alias pred = "a < b", R1, R2)(R1 r1, R2 r2)
auto cmp(R1, R2)(R1 r1, R2 r2)
if (isInputRange!R1 && isInputRange!R2)
{
static if (!(isSomeString!R1 && isSomeString!R2))
{
for (;; r1.popFront(), r2.popFront())
{
if (r1.empty) return -cast(int)!r2.empty;
if (r2.empty) return !r1.empty;
auto a = r1.front, b = r2.front;
if (binaryFun!pred(a, b)) return -1;
if (binaryFun!pred(b, a)) return 1;
static if (is(typeof(r1.front.opCmp(r2.front)) R))
alias Result = R;
else
alias Result = int;
if (r2.empty) return Result(!r1.empty);
if (r1.empty) return Result(-1);
static if (is(typeof(r1.front.opCmp(r2.front))))
{
auto c = r1.front.opCmp(r2.front);
if (c != 0) return c;
}
else
{
auto a = r1.front, b = r2.front;
if (a < b) return -1;
if (b < a) return 1;
}
}
}
else
{
import core.stdc.string : memcmp;
import std.utf : decode;

static if (is(typeof(pred) : string))
enum isLessThan = pred == "a < b";
else
enum isLessThan = false;

// For speed only
static int threeWay(size_t a, size_t b)
{
static if (size_t.sizeof == int.sizeof && isLessThan)
static if (size_t.sizeof == int.sizeof)
return a - b;
else
return binaryFun!pred(b, a) ? 1 : binaryFun!pred(a, b) ? -1 : 0;
// Faster than return b < a ? 1 : a < b ? -1 : 0;
return (a > b) - (a < b);
}
// For speed only
// @@@BUG@@@ overloading should be allowed for nested functions
static int threeWayInt(int a, int b)
{
static if (isLessThan)
return a - b;
else
return binaryFun!pred(b, a) ? 1 : binaryFun!pred(a, b) ? -1 : 0;
return a - b;
}

static if (typeof(r1[0]).sizeof == typeof(r2[0]).sizeof && isLessThan)
static if (typeof(r1[0]).sizeof == typeof(r2[0]).sizeof)
{
static if (typeof(r1[0]).sizeof == 1)
{
immutable len = min(r1.length, r2.length);
immutable result = __ctfe ?
int result = __ctfe ?
{
foreach (i; 0 .. len)
{
Expand All @@ -663,17 +677,21 @@ if (isInputRange!R1 && isInputRange!R2)
}()
: () @trusted { return memcmp(r1.ptr, r2.ptr, len); }();
if (result) return result;
return threeWay(r1.length, r2.length);
}
else
{
auto p1 = r1.ptr, p2 = r2.ptr,
pEnd = p1 + min(r1.length, r2.length);
for (; p1 != pEnd; ++p1, ++p2)
return () @trusted
{
if (*p1 != *p2) return threeWayInt(cast(int) *p1, cast(int) *p2);
}
auto p1 = r1.ptr, p2 = r2.ptr,
pEnd = p1 + min(r1.length, r2.length);
for (; p1 != pEnd; ++p1, ++p2)
{
if (*p1 != *p2) return threeWayInt(int(*p1), int(*p2));
}
return threeWay(r1.length, r2.length);
}();
}
return threeWay(r1.length, r2.length);
}
else
{
Expand All @@ -683,14 +701,58 @@ if (isInputRange!R1 && isInputRange!R2)
if (i2 == r2.length) return threeWay(r1.length, i1);
immutable c1 = decode(r1, i1),
c2 = decode(r2, i2);
if (c1 != c2) return threeWayInt(cast(int) c1, cast(int) c2);
if (c1 != c2) return threeWayInt(int(c1), int(c2));
}
}
}
}

/// ditto
int cmp(alias pred, R1, R2)(R1 r1, R2 r2)
if (isInputRange!R1 && isInputRange!R2)
{
static if (!(isSomeString!R1 && isSomeString!R2))
{
for (;; r1.popFront(), r2.popFront())
{
if (r2.empty) return !r1.empty;
if (r1.empty) return -1;
auto a = r1.front, b = r2.front;
if (binaryFun!pred(a, b)) return -1;
if (binaryFun!pred(b, a)) return 1;
}
}
else
{
import std.utf : decode;

// For speed only
static int threeWayCompareLength(size_t a, size_t b)
{
static if (size_t.sizeof == int.sizeof)
return a - b;
else
// Faster than return b < a ? 1 : a < b ? -1 : 0;
return (a > b) - (a < b);
}

for (size_t i1, i2;;)
{
if (i1 == r1.length) return threeWayCompareLength(i2, r2.length);
if (i2 == r2.length) return threeWayCompareLength(r1.length, i1);
immutable c1 = decode(r1, i1),
c2 = decode(r2, i2);
if (c1 != c2)
{
if (binaryFun!pred(c2, c1)) return 1;
if (binaryFun!pred(c1, c2)) return -1;
}
}
}
}

///
@safe unittest
pure @safe unittest
{
int result;

Expand All @@ -712,6 +774,8 @@ if (isInputRange!R1 && isInputRange!R2)
assert(result > 0);
result = cmp("aaa", "aaa"d);
assert(result == 0);
result = cmp("aaa"d, "aaa"d);
assert(result == 0);
result = cmp(cast(int[])[], cast(int[])[]);
assert(result == 0);
result = cmp([1, 2, 3], [1, 2, 3]);
Expand All @@ -724,6 +788,106 @@ if (isInputRange!R1 && isInputRange!R2)
assert(result > 0);
}

/// Example predicate that compares individual elements in reverse lexical order
pure @safe unittest
{
int result;

result = cmp!"a > b"("abc", "abc");
assert(result == 0);
result = cmp!"a > b"("", "");
assert(result == 0);
result = cmp!"a > b"("abc", "abcd");
assert(result < 0);
result = cmp!"a > b"("abcd", "abc");
assert(result > 0);
result = cmp!"a > b"("abc"d, "abd");
assert(result > 0);
result = cmp!"a > b"("bbc", "abc"w);
assert(result < 0);
result = cmp!"a > b"("aaa", "aaaa"d);
assert(result < 0);
result = cmp!"a > b"("aaaa", "aaa"d);
assert(result > 0);
result = cmp!"a > b"("aaa", "aaa"d);
assert(result == 0);
result = cmp("aaa"d, "aaa"d);
assert(result == 0);
result = cmp!"a > b"(cast(int[])[], cast(int[])[]);
assert(result == 0);
result = cmp!"a > b"([1, 2, 3], [1, 2, 3]);
assert(result == 0);
result = cmp!"a > b"([1, 3, 2], [1, 2, 3]);
assert(result < 0);
result = cmp!"a > b"([1, 2, 3], [1L, 2, 3, 4]);
assert(result < 0);
result = cmp!"a > b"([1L, 2, 3], [1, 2]);
assert(result > 0);
}

@nogc nothrow pure @safe unittest
{
// Issue 18286: cmp for string with custom predicate fails if distinct chars can compare equal
static bool ltCi(dchar a, dchar b)// less than, case insensitive
{
import std.ascii : toUpper;
return toUpper(a) < toUpper(b);
}
static assert(cmp!ltCi("apple2", "APPLE1") > 0);
static assert(cmp!ltCi("apple1", "APPLE2") < 0);
static assert(cmp!ltCi("apple", "APPLE1") < 0);
static assert(cmp!ltCi("APPLE", "apple1") < 0);
static assert(cmp!ltCi("apple", "APPLE") == 0);
}

@nogc nothrow @safe unittest
{
// Issue 18280: for non-string ranges check that opCmp is evaluated only once per pair.
static int ctr = 0;
struct S
{
int opCmp(ref const S rhs) const
{
++ctr;
return 0;
}
}
immutable S[4] a;
immutable S[4] b;
immutable result = cmp(a[], b[]);
assert(result == 0, "neither should compare greater than the other!");
assert(ctr == a.length, "opCmp should be called exactly once per pair of items!");
}

nothrow pure @safe unittest
{
// Test cmp when opCmp returns float.
struct F
{
float value;
float opCmp(const ref F rhs) const
{
return value - rhs.value;
}
}
auto result = cmp([F(1), F(2), F(3)], [F(1), F(2), F(3)]);
assert(result == 0);
assert(is(typeof(result) == float));
result = cmp([F(1), F(3), F(2)], [F(1), F(2), F(3)]);
assert(result > 0);
result = cmp([F(1), F(2), F(3)], [F(1), F(2), F(3), F(4)]);
assert(result < 0);
result = cmp([F(1), F(2), F(3)], [F(1), F(2)]);
assert(result > 0);
}

nothrow pure @safe unittest
{
// Parallelism (was broken by inferred return type "immutable int")
import std.parallelism : task;
auto t = task!cmp("foo", "bar");
}

// equal
/**
Compares two ranges for equality, as defined by predicate `pred`
Expand Down

0 comments on commit 7f59e5a

Please sign in to comment.