Skip to content

Commit

Permalink
Fix Issue 20147 - Enable comparison (==, >, >=, <=, <) between std.bi…
Browse files Browse the repository at this point in the history
…gint.BigInt and floating point numbers
  • Loading branch information
n8sh committed Sep 4, 2019
1 parent aee204b commit f1e4a0e
Showing 1 changed file with 222 additions and 31 deletions.
253 changes: 222 additions & 31 deletions std/bigint.d
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import std.format : FormatSpec, FormatException;
import std.internal.math.biguintcore;
import std.range.primitives;
import std.traits;
import std.math : FloatingPointControl;

/** A struct representing an arbitrary precision integer.
*
Expand Down Expand Up @@ -644,22 +643,30 @@ public:

/**
Implements `BigInt` equality test with other `BigInt`'s and built-in
integer types.
numeric types.
*/
bool opEquals()(auto ref const BigInt y) const pure @nogc
{
return sign == y.sign && y.data == data;
}

/// ditto
bool opEquals(T)(T y) const pure nothrow @nogc if (isIntegral!T)
bool opEquals(T)(const T y) const pure nothrow @nogc if (isIntegral!T)
{
if (sign != (y<0))
return 0;
return data.opEquals(cast(ulong) absUnsign(y));
}

///
/// ditto
bool opEquals(T)(const T y) const nothrow @nogc if (isFloatingPoint!T)
{
// This is a separate function from the isIntegral!T case
// due to the impurity of std.math.scalbn which is used
// for 80 bit floats.
return 0 == opCmp(y);
}

@system unittest
{
auto x = BigInt("12345");
Expand All @@ -674,6 +681,31 @@ public:
assert(x != w);
}

@system unittest
{
import std.math : nextDown, nextUp;

const x = BigInt("0x1abc_de80_0000_0000_0000_0000_0000_0000");
BigInt x1 = x + 1;
BigInt x2 = x - 1;

const d = 0x1.abcde8p124;
assert(x == d);
assert(x1 != d);
assert(x2 != d);
assert(x != nextUp(d));
assert(x != nextDown(d));
assert(x != double.nan);

const dL = 0x1.abcde8p124L;
assert(x == dL);
assert(x1 != dL);
assert(x2 != dL);
assert(x != nextUp(dL));
assert(x != nextDown(dL));
assert(x != real.nan);
}

/**
Implements casting to `bool`.
*/
Expand Down Expand Up @@ -780,18 +812,32 @@ public:
/**
Implements casting to floating point types.
*/
T opCast(T, FloatingPointControl.RoundingMode roundingMode = FloatingPointControl.roundToNearest)()
@safe nothrow @nogc const if (isFloatingPoint!T && (roundingMode == FloatingPointControl.roundToZero ||
roundingMode == FloatingPointControl.roundToNearest))
T opCast(T)() @safe nothrow @nogc const if (isFloatingPoint!T)
{
return _toFloat!(T, "defaultRounding");
}

// Common code use in opCast!Floating and opCmp!Floating.
private T _toFloat(T, string roundingMode)() @safe nothrow @nogc const
if (__traits(isFloating, T) && (roundingMode == "defaultRounding" || roundingMode == "truncate"))
{
import core.bitop : bsr;
enum bool performRounding = (roundingMode != FloatingPointControl.roundToZero);
enum performRounding = roundingMode == "defaultRounding";
enum performTruncation = roundingMode == "truncate";
static assert(performRounding || performTruncation);
enum int totalNeededBits = T.mant_dig + int(performRounding);

static if (totalNeededBits <= 64)
version (linux)
enum forceSlowPath = false;
else version (Android)
enum forceSlowPath = false;
else
// Win_32 and FreeBSD_32 on the test suite but
// maybe others as well.
enum forceSlowPath = size_t.sizeof == 4;
static if (totalNeededBits <= 64 && !forceSlowPath)
{
// We need to examine the top two 64-bit words, not just the top one,
// since the top word could have just a single bit set.
// since the top word could have just a single significant bit.
const ulongLength = data.ulongLength;
const ulong w1 = data.peekUlong(ulongLength - 1);
const ulong w2 = ulongLength < 2 ? 0 : data.peekUlong(ulongLength - 2);
Expand All @@ -800,9 +846,20 @@ public:
size_t exponent = (ulongLength - 1) * 64 + w1BitCount + 1;
static if (performRounding)
{
sansExponent += 1UL << (64 - totalNeededBits);
// We know the high bit is 1 so we can detect overflow by testing
// if it is zero post addition.
// If we wanted to round ties away from 0 we could just add
// roundUpInc to sansExponent. The more complicated logic is
// so ties are rounded to even, which matches the default
// rounding behavior when casting int/long to float/double.
enum roundUpInc = 1UL << (64 - totalNeededBits);
if (sansExponent & (roundUpInc - 1))
sansExponent += roundUpInc;
else if (0 != (sansExponent & (roundUpInc << 1)))
foreach (i; 0 .. ulongLength - 2)
if (data.peekUlong(i))
{
sansExponent += roundUpInc;
break;
}
if (0 <= cast(long) sansExponent)
{
// Don't bother filling in the high bit with 1.
Expand Down Expand Up @@ -836,11 +893,23 @@ public:
cast(int) exponent - 65);
}
}
else static if (performRounding)
{
// FIXME: this is a naive implementation.
import std.math : scalbn;
const ulongLength = data.ulongLength;
if ((ulongLength - 1) * 64L > int.max)
return isNegative ? -T.infinity : T.infinity;
real acc = data.peekUlong(0);
foreach (int i; 1 .. cast(int) ulongLength)
acc += scalbn(data.peekUlong(i), i * 64);
if (isNegative)
acc = -acc;
return acc;
}
else
{
// Quadruple-precision or greater floating point.
import std.math : scalbn;

const ulongLength = data.ulongLength;
if ((ulongLength - 1) * 64L > int.max)
return isNegative ? -T.infinity : T.infinity;
Expand All @@ -851,26 +920,17 @@ public:
for (ptrdiff_t i = ulongLength - 2; i >= 0 && bitsStillNeeded > 0; i--)
{
ulong w = data.peekUlong(i);
static if (performRounding)
// To round towards zero we must make sure not to use too many bits.
if (bitsStillNeeded >= 64)
{
acc += scalbn(w, scale -= 64);
bitsStillNeeded -= 64;
}
else
{
// To round towards zero we must make sure not to use too many bits.
if (bitsStillNeeded >= 64)
{
acc += scalbn(w, scale -= 64);
bitsStillNeeded -= 64;
}
else
{

w = (w >>> (64 - bitsStillNeeded)) << (64 - bitsStillNeeded);
acc += scalbn(w, scale -= 64);
break;
}
w = (w >>> (64 - bitsStillNeeded)) << (64 - bitsStillNeeded);
acc += scalbn(w, scale -= 64);
break;
}
}
if (isNegative)
Expand Down Expand Up @@ -906,6 +966,67 @@ public:
assert(-123.0f == cast(float) BigInt(-123));
}

@system unittest
{
// Test that casts from BigInt to float/double are rounded the
// same as casts from uint/ulong to float/double.
foreach (x; [123456789, 123456432])
assert(cast(float) BigInt(x) == cast(float) uint(x));
foreach (x; [0x1111_1111_1111_117UL, 0x1111_1111_1111_118UL])
assert(cast(double) BigInt(x) == cast(double) ulong(x));
// Test that casting to float/double/real rounds towards the
// nearest value.
import std.meta : AliasSeq;
bool[3] expectedRounding;
bool first = true;
static foreach(F; AliasSeq!(float, double, real))
{{
BigInt a = BigInt(1UL) << (F.mant_dig + 1);
BigInt b = a + 1;
BigInt c = a + 2;
BigInt d = a + 3;
assert(cast(F) a <= cast(F) b);
assert(cast(F) b <= cast(F) c);
assert(cast(F) c <= cast(F) d);
assert(cast(F) a < cast(F) d);
// Verify that ties are rounded consistently
// for float/double/real.
if (first)
{
expectedRounding[0] = (cast(F) a < cast(F) b);
expectedRounding[1] = (cast(F) b < cast(F) c);
expectedRounding[2] = (cast(F) c < cast(F) d);
first = false;
}
else
{
assert(expectedRounding[0] == (cast(F) a < cast(F) b));
assert(expectedRounding[1] == (cast(F) b < cast(F) c));
assert(expectedRounding[2] == (cast(F) c < cast(F) d));
}
// Verify that casting from BigInt produces the
// same result as casting from ulong if the BigInt
// is exactly representable as a ulong.
if (d <= ulong.max)
{
assert(cast(F) a == cast(F) (cast(ulong) a));
assert(cast(F) b == cast(F) (cast(ulong) b));
assert(cast(F) c == cast(F) (cast(ulong) c));
assert(cast(F) d == cast(F) (cast(ulong) d));
}
// An earlier version of the code special-cased values of
// of magnitude less than 2^^64. Check that values of
// larger magnitude have the same rounding rules.
BigInt a2 = a << 64;
BigInt b2 = b << 64;
BigInt c2 = c << 64;
BigInt d2 = d << 64;
assert(expectedRounding[0] == (cast(F) a2 < cast(F) b2));
assert(expectedRounding[1] == (cast(F) b2 < cast(F) c2));
assert(expectedRounding[2] == (cast(F) c2 < cast(F) d2));
}}
}

/**
Implements casting to/from qualified `BigInt`'s.
Expand Down Expand Up @@ -940,14 +1061,40 @@ public:
}

/// ditto
int opCmp(T)(T y) pure nothrow @nogc const if (isIntegral!T)
int opCmp(T)(const T y) pure nothrow @nogc const if (isIntegral!T)
{
if (sign != (y<0) )
return sign ? -1 : 1;
int cmp = data.opCmp(cast(ulong) absUnsign(y));
return sign? -cmp: cmp;
}
/// ditto
int opCmp(T)(const T y) nothrow @nogc const if (isFloatingPoint!T)
{
import core.bitop : bsr;
import std.math : cmp;

const asFloat = _toFloat!(T, "truncate");
if (const c = cmp(asFloat, y)) // Handles +/- infinity and NaN.
return c;
const ulongLength = data.ulongLength;
const w1 = data.peekUlong(ulongLength - 1);
const numSignificantBits = (ulongLength - 1) * 64 + bsr(w1) + 1;
for (ptrdiff_t bitsRemainingToCheck = numSignificantBits - T.mant_dig, i = 0;
bitsRemainingToCheck > 0; i++, bitsRemainingToCheck -= 64)
{
auto word = data.peekUlong(i);
if (word == 0)
continue;
// Make sure we're only checking digits that are beyond
// the precision of `y`.
if (bitsRemainingToCheck < 64 && (word << (64 - bitsRemainingToCheck)) == 0)
break; // This can only happen on the last loop iteration.
return isNegative ? -1 : 1;
}
return 0;
}
/// ditto
int opCmp(T:BigInt)(const T y) pure nothrow @nogc const
{
if (sign != y.sign)
Expand All @@ -970,6 +1117,50 @@ public:
assert(x < w);
}

///
@system unittest
{
auto x = BigInt("0x1abc_de80_0000_0000_0000_0000_0000_0000");
BigInt y = x - 1;
BigInt z = x + 1;

double d = 0x1.abcde8p124;
assert(y < d);
assert(z > d);
assert(x >= d && x <= d);
}

@system unittest
{
// Test that rounding of opCast!real is as opCmp!real expects.
auto x = BigInt("0x1abc_de80_0000_0000_0000_0000_0000_0000");
BigInt y = x - 1;
BigInt z = x + 1;

real d = 0x1.abcde8p124;
assert(y < d);
assert(z > d);
assert(x >= d && x <= d);

// Test comparison for numbers of 64 bits or fewer.
auto w1 = BigInt(0x1abc_de80_0000_0000);
auto w2 = w1 - 1;
auto w3 = w1 + 1;
assert(w1.ulongLength == 1);
assert(w2.ulongLength == 1);
assert(w3.ulongLength == 1);

double e = 0x1.abcde8p+60;
assert(w1 >= e && w1 <= e);
assert(w2 < e);
assert(w3 > e);

real eL = 0x1.abcde8p+60;
assert(w1 >= eL && w1 <= eL);
assert(w2 < eL);
assert(w3 > eL);
}

/**
Returns: The value of this `BigInt` as a `long`, or `long.max`/`long.min`
if outside the representable range.
Expand Down

0 comments on commit f1e4a0e

Please sign in to comment.