diff --git a/sql/parser/builtins.go b/sql/parser/builtins.go index 1c66efc6f3f5..6b56a8ce831f 100644 --- a/sql/parser/builtins.go +++ b/sql/parser/builtins.go @@ -46,12 +46,13 @@ import ( ) var ( - errEmptyInputString = errors.New("the input string must not be empty") - errAbsOfMinInt64 = errors.New("abs of min integer value (-9223372036854775808) not defined") - errRoundNumberDigits = errors.New("number of digits must be greater than 0") - errSqrtOfNegNumber = errors.New("cannot take square root of a negative number") - errLogOfNegNumber = errors.New("cannot take logarithm of a negative number") - errLogOfZero = errors.New("cannot take logarithm of zero") + errEmptyInputString = errors.New("the input string must not be empty") + errAbsOfMinInt64 = errors.New("abs of min integer value (-9223372036854775808) not defined") + errRoundTooLow = errors.New("rounding would extend value by more than 2000 decimal digits") + errArgTooBig = errors.New("argument value is too large") + errSqrtOfNegNumber = errors.New("cannot take square root of a negative number") + errLogOfNegNumber = errors.New("cannot take logarithm of a negative number") + errLogOfZero = errors.New("cannot take logarithm of zero") ) const ( @@ -815,9 +816,7 @@ var Builtins = map[string][]Builtin{ return NewDFloat(DFloat(math.Exp(x))), nil }), decimalBuiltin1(func(x *inf.Dec) (Datum, error) { - dd := &DDecimal{} - decimal.Exp(&dd.Dec, x, decimal.Precision) - return dd, nil + return expDecimal(x) }), }, @@ -896,9 +895,7 @@ var Builtins = map[string][]Builtin{ return round(x, 0) }), decimalBuiltin1(func(x *inf.Dec) (Datum, error) { - dd := &DDecimal{} - dd.Round(x, 0, inf.RoundHalfUp) - return dd, nil + return roundDecimal(x, 0) }), Builtin{ Types: ArgTypes{TypeFloat, TypeInt}, @@ -911,10 +908,8 @@ var Builtins = map[string][]Builtin{ Types: ArgTypes{TypeDecimal, TypeInt}, ReturnType: TypeDecimal, fn: func(_ *EvalContext, args DTuple) (Datum, error) { - dec := &args[0].(*DDecimal).Dec - dd := &DDecimal{} - dd.Round(dec, inf.Scale(*args[1].(*DInt)), inf.RoundHalfUp) - return dd, nil + scale := int64(*args[1].(*DInt)) + return roundDecimal(&args[0].(*DDecimal).Dec, scale) }, }, }, @@ -1405,32 +1400,83 @@ func overlay(s, to string, pos, size int) (Datum, error) { } func round(x float64, n int64) (Datum, error) { - switch { - case n < 0: - return DNull, errRoundNumberDigits - case n > 323: - // When rounding to more than 323 digits after the decimal - // point, the original number is returned, because rounding has - // no effect at scales smaller than 1e-323. - // - // 323 is the sum of - // - // 15, the maximum number of significant digits in a decimal - // string that can be converted to the IEEE 754 double precision - // representation and back to a string that matches the - // original; and - // - // 308, the largest exponent. The significant digits can be - // right shifted by 308 positions at most, by setting the - // exponent to -308. + pow := math.Pow(10, float64(n)) + + if pow == 0 { + // Rounding to so many digits on the left that we're underflowing. + // Avoid a NaN below. + return NewDFloat(DFloat(0)), nil + } + if math.Abs(x*pow) > 1e17 { + // Rounding touches decimals below float precision; the operation + // is a no-op. return NewDFloat(DFloat(x)), nil } - const b = 64 - y, err := strconv.ParseFloat(strconv.FormatFloat(x, 'f', int(n), b), b) - if err != nil { - panic(fmt.Sprintf("parsing a float that was just formatted failed: %s", err)) + + v, frac := math.Modf(x * pow) + // The following computation implements unbiased rounding, also + // called bankers' rounding. It ensures that values that fall + // exactly between two integers get equal chance to be rounded up or + // down. + if x > 0.0 { + if frac > 0.5 || (frac == 0.5 && uint64(v)%2 != 0) { + v += 1.0 + } + } else { + if frac < -0.5 || (frac == -0.5 && uint64(v)%2 != 0) { + v -= 1.0 + } + } + + return NewDFloat(DFloat(v / pow)), nil +} + +const ( + scaleRatio = math.Ln2 / math.Ln10 +) + +func roundDecimal(x *inf.Dec, n int64) (Datum, error) { + curScale := int64(x.Scale()) + + if n > curScale+2000 { + // If we let the decimal value grow too many decimals, the server + // could explode (#8633). + return nil, errRoundTooLow + } + + dd := &DDecimal{} + + // We use WordLen(Bits())*8 instead of UnscaledBig().BitLen() here + // as this is faster and we do not need an exact value for the + // optimization below. + upperCurDigits := encoding.WordLen(x.UnscaledBig().Bits()) * 8 + upperDigitsLeft := float64(curScale) - float64(upperCurDigits)*scaleRatio + if n < int64(upperDigitsLeft)-1 { + // This is an optimization. When the rounding scale is definitely + // larger than the number, the result is 0, so we avoid + // spending a lot of time in the division for nothing. + return dd, nil + } + dd.Round(x, inf.Scale(n), inf.RoundHalfEven) + return dd, nil +} + +func expDecimal(x *inf.Dec) (Datum, error) { + // The computation of Exp is separated in the decimal module by + // computing the exponents on the left and right of the decimal + // separator. The computation on the right is bounded by + // decimal.Precision already; however if the value is too large on + // the left the decimal value can grow too large in memory and slow + // down / crash the entire server. So we prevent this from happening + // and limit the argument to be ~1000 or less. + curDigits := x.UnscaledBig().BitLen() + binDigitsLeft := curDigits - int(float64(x.Scale())/scaleRatio) + if binDigitsLeft > 10 /* 1024 */ { + return nil, errArgTooBig } - return NewDFloat(DFloat(y)), nil + dd := &DDecimal{} + decimal.Exp(&dd.Dec, x, decimal.Precision) + return dd, nil } // Pick the greatest (or least value) from a tuple. diff --git a/sql/testdata/builtin_function b/sql/testdata/builtin_function index 13205442c26b..8df5e4f837b5 100644 --- a/sql/testdata/builtin_function +++ b/sql/testdata/builtin_function @@ -479,6 +479,9 @@ SELECT exp(-1.0::float), exp(1.0::float), exp(2.0::decimal) ---- 0.36787944117144233 2.718281828459045 7.3890560989306502 +query error argument value is too large +SELECT exp(2000::decimal) + query RRR SELECT floor(-1.5::float), floor(1.5::float), floor(9.123456789::decimal) ---- @@ -602,14 +605,8 @@ SELECT radians(-45.0), radians(45.0) ---- -0.7853981633974483 0.7853981633974483 -# Our implementation of round is not fully Postgres-compatible because we do -# not allow negative numbers of digits. - -query error round: number of digits must be greater than 0 -SELECT round(41.2::float, -1) - query RRR -SELECT round(4.2::float, 0), round(4.2::float, 50), round(4.22222222::decimal, 3) +SELECT round(4.2::float, 0), round(4.2::float, 10), round(4.22222222::decimal, 3) ---- 4 4.2 4.222 @@ -633,12 +630,17 @@ SELECT round(-2.5::float), round(-1.5::float), round(-0.0::float), round(0.0::fl query RRRR SELECT round(-2.5::decimal, 0), round(-1.5::decimal, 0), round(1.5::decimal, 0), round(2.5::decimal, 0) ---- --3 -2 2 3 +-2 -2 2 2 + +query RRRRR +SELECT round(-2.5::decimal, 3), round(-1.5::decimal, 3), round(0.0::decimal, 3), round(1.5::decimal, 3), round(2.5::decimal, 3) +---- +-2.500 -1.500 0.000 1.500 2.500 query RRRRR SELECT round(-2.5::decimal), round(-1.5::decimal), round(0.0::decimal), round(1.5::decimal), round(2.5::decimal) ---- --3 -2 0 2 3 +-2 -2 0 2 2 # Test rounding to 14 digits, because the logic test itself # formats floats rounded to 15 digits behind the decimal point. @@ -653,6 +655,39 @@ SELECT round(-1.7976931348623157e+308::float, 1), round(1.7976931348623157e+308: ---- -1.7976931348623157e+308 1.7976931348623157e+308 +query RR +SELECT round(-1.7976931348623157e+308::float, -303), round(1.7976931348623157e+308::float, -303) +---- +-1.797690000000001e+308 1.797690000000001e+308 + +query RR +SELECT round(-1.23456789e+308::float, -308), round(1.23456789e+308::float, -308) +---- +-1.0000000000000006e+308 1.0000000000000006e+308 + +query RR +SELECT round(-1.7976931348623157e-308::float, 1), round(1.7976931348623157e-308::float, 1) +---- +-0 0 + +query RRRR +SELECT 1.234567890123456789::float, round(1.234567890123456789::float,15), round(1.234567890123456789::float,16), round(1.234567890123456789::float,17); +---- +1.2345678901234567 1.234567890123457 1.2345678901234567 1.2345678901234567 + +query RRRR +SELECT round(123.456::float, -1), round(123.456::float, -2), round(123.456::float, -3), round(123.456::float, -2438602134409251682) +---- +120 100 0 0 + +query RRRR +SELECT round(123.456::decimal, -1), round(123.456::decimal, -2), round(123.456::decimal, -3), round(123.456::decimal, -2438602134409251682) +---- +120 100 0 0 + +query error rounding would extend value by more than 2000 decimal digits +SELECT round(1::decimal, 3000) + query III SELECT sign(-2), sign(0), sign(2) ---- diff --git a/util/encoding/decimal.go b/util/encoding/decimal.go index 8fcee282fa95..b438337b02f4 100644 --- a/util/encoding/decimal.go +++ b/util/encoding/decimal.go @@ -713,8 +713,8 @@ func UpperBoundNonsortingDecimalSize(d *inf.Dec) int { // Makeup of upper bound size: // - 1 byte for the prefix // - maxVarintSize for the exponent - // - wordLen for the big.Int bytes - return 1 + maxVarintSize + wordLen(d.UnscaledBig().Bits()) + // - WordLen for the big.Int bytes + return 1 + maxVarintSize + WordLen(d.UnscaledBig().Bits()) } // upperBoundNonsortingDecimalUnscaledSize is the same as @@ -737,7 +737,8 @@ func upperBoundNonsortingDecimalUnscaledSize(unscaledLen int) int { // Taken from math/big/arith.go. const bigWordSize = int(unsafe.Sizeof(big.Word(0))) -func wordLen(nat []big.Word) int { +// WordLen returns the size in bytes of the given array of Words. +func WordLen(nat []big.Word) int { return len(nat) * bigWordSize }