Skip to content

Commit

Permalink
SQL: Improve ROUND and TRUNCATE functions
Browse files Browse the repository at this point in the history
- preserve the input type (e.g. ROUND(Long) returns a Long, ROUND(Double) returns a Double)
- avoid loss of precision and numeric overflows in the calculation algorithm
- improve performance in common cases (e.g. ROUND(N, 0), ROUND(N, 2) ...)
  • Loading branch information
luigidellaquila committed Mar 30, 2022
1 parent 92c4538 commit 087bd1d
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 27 deletions.
7 changes: 7 additions & 0 deletions docs/changelog/85106.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pr: 85106
summary: Improve ROUND and TRUNCATE to better manage Long values and big Doubles
area: SQL
type: bug
issues:
- 85105
- 49391
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ql.expression.predicate.operator.math;

import java.math.BigDecimal;
import java.math.MathContext;

/**
 * Numeric helpers backing the SQL {@code ROUND} and {@code TRUNCATE} functions.
 *
 * <p>Both operations preserve the type of their input: integer inputs ({@code Long}, {@code Integer},
 * {@code Short}, {@code Byte}) are processed with pure integer arithmetic and returned with the same
 * type, {@code Float} inputs yield a {@code Float}, and every other {@link Number} is handled as a
 * {@code double} and yields a {@code Double}.
 */
public final class Maths {

    /** Number of decimal digits of {@link Long#MAX_VALUE}; 10^19 and above do not fit in a {@code long}. */
    private static final int MAX_LONG_DIGITS = 19;

    /** 5 * 10^18, i.e. half of 10^19: magnitudes at or above it round (half away from zero) to +/-10^19. */
    private static final long HALF_OF_TEN_POW_19 = 5_000_000_000_000_000_000L;

    /** Exact powers of ten that fit in a {@code long}: index {@code i} holds 10^i for i in [0, 18]. */
    private static final long[] LONG_POWERS_OF_TEN = buildLongPowersOfTen();

    private Maths() {
        // static utility class, not meant to be instantiated
    }

    /**
     * Rounds {@code n} (half away from zero) to {@code precision} decimal places. A negative precision
     * rounds to the left of the decimal point (e.g. -2 rounds to hundreds).
     *
     * @param n the number to round
     * @param precision the number of decimal places to keep; only its {@code longValue()} is used
     * @return the rounded value, with the same type as {@code n} for integer and {@code Float} inputs,
     *         a {@code Double} otherwise. NaN inputs yield zero; inputs whose scaled value is infinite
     *         are returned unchanged.
     * @throws ArithmeticException if {@code n} is an integer type and the rounded value does not fit in it
     */
    public static Number round(Number n, Number precision) throws ArithmeticException {
        long longPrecision = precision.longValue();
        if (isIntegerType(n)) {
            return convertToIntegerType(roundLong(n.longValue(), longPrecision), n.getClass());
        }
        double nDouble = n.doubleValue();
        if (Double.isNaN(nDouble)) {
            return matchFloatingType(n, 0.0d);
        }

        double tenAtScale = tenPower(longPrecision);
        if (tenAtScale == 0.0 || nDouble == 0.0) {
            // either the input is zero or the precision is so negative that everything rounds to zero
            return matchFloatingType(n, 0.0d);
        }

        double middleResult = nDouble * tenAtScale;
        if (Double.isInfinite(middleResult)) {
            // the requested precision is beyond what a double can represent: nothing to round
            return n;
        }
        int sign = middleResult >= 0 ? 1 : -1;

        if (Long.MIN_VALUE + 1 < middleResult && middleResult < Long.MAX_VALUE) {
            // the result can still be rounded using Math.round(), that is limited to long values
            double result = Math.round(Math.abs(middleResult)) / tenAtScale * sign;
            return matchFloatingType(n, result);
        }

        // Otherwise fall back to BigDecimal, that is much slower but avoids the long overflow.
        // A double of this magnitude (>= 2^63) has no fractional part, so only the scaling back
        // has to be done; no explicit rounding step is needed.
        MathContext prec = MathContext.DECIMAL128;
        double result = new BigDecimal(Math.abs(middleResult), prec).divide(new BigDecimal(tenAtScale), prec).doubleValue() * sign;
        return matchFloatingType(n, result);
    }

    /**
     * Rounds a long value (half away from zero) using integer arithmetic only, so that no precision is lost.
     *
     * @param n the value to round
     * @param precision the number of decimal places; values {@code >= 0} leave {@code n} unchanged,
     *                  negative values round to the left of the decimal point
     * @return the rounded value
     * @throws ArithmeticException if the rounded value does not fit in a long
     */
    public static Long round(Long n, Long precision) throws ArithmeticException {
        return roundLong(n.longValue(), precision.longValue());
    }

    /**
     * Truncates {@code n} (towards zero) to {@code precision} decimal places. A negative precision
     * truncates to the left of the decimal point.
     *
     * <p>For non-integer inputs, when 10^precision is not representable as a finite, non-zero double
     * the computation degenerates to {@code Infinity/Infinity} or {@code 0/0} and the result is NaN.
     *
     * @param n the number to truncate
     * @param precision the number of decimal places to keep; only its {@code longValue()} is used
     * @return the truncated value, with the same type as {@code n} for integer and {@code Float} inputs,
     *         a {@code Double} otherwise
     */
    public static Number truncate(Number n, Number precision) {
        long longPrecision = precision.longValue();
        if (isIntegerType(n)) {
            long nLong = n.longValue();
            if (nLong == 0L || longPrecision >= 0) {
                return n;
            }
            if (longPrecision <= -MAX_LONG_DIGITS) {
                // more digits to drop than a long can have
                return convertToIntegerType(0L, n.getClass());
            }
            long tenAtScale = LONG_POWERS_OF_TEN[(int) -longPrecision];
            // integer division already truncates towards zero
            return convertToIntegerType((nLong / tenAtScale) * tenAtScale, n.getClass());
        }
        double tenAtScale = tenPower(longPrecision);
        double g = n.doubleValue() * tenAtScale;
        double result = ((n.doubleValue() < 0) ? Math.ceil(g) : Math.floor(g)) / tenAtScale;
        return matchFloatingType(n, result);
    }

    /** Primitive implementation of {@link #round(Long, Long)}. */
    private static long roundLong(long value, long precision) {
        if (value == 0L || precision >= 0) {
            return value;
        }
        // compare the precision directly instead of negating it: -Long.MIN_VALUE overflows
        if (precision <= -MAX_LONG_DIGITS) {
            // 10^(-precision) does not fit in a long, so the result is either 0 or out of range
            if (precision == -MAX_LONG_DIGITS && (value >= HALF_OF_TEN_POW_19 || value <= -HALF_OF_TEN_POW_19)) {
                throw new ArithmeticException("long overflow");
            }
            return 0L;
        }

        long tenAtScale = LONG_POWERS_OF_TEN[(int) -precision];
        long half = tenAtScale / 2;
        long quotient = value / tenAtScale;
        long remainder = value % tenAtScale; // carries the sign of value
        if (remainder >= half) {
            quotient++;
        } else if (remainder <= -half) {
            quotient--;
        }
        // throws ArithmeticException when the rounded value exceeds the long range
        return Math.multiplyExact(quotient, tenAtScale);
    }

    /** Tells whether {@code n} is one of the boxed integer types handled with integer arithmetic. */
    private static boolean isIntegerType(Number n) {
        return n instanceof Long || n instanceof Integer || n instanceof Short || n instanceof Byte;
    }

    /**
     * Boxes {@code value} as a {@code Float} when the original input was a {@code Float}, as a
     * {@code Double} otherwise. Deliberately written as an if/else: a conditional expression mixing
     * float and double operands is promoted to double and would never produce a {@code Float}.
     */
    private static Number matchFloatingType(Number original, double value) {
        if (original instanceof Float) {
            return (float) value;
        }
        return value;
    }

    /** Builds the table of the powers of ten representable as a long (10^0 to 10^18). */
    private static long[] buildLongPowersOfTen() {
        long[] powers = new long[MAX_LONG_DIGITS];
        powers[0] = 1L;
        for (int i = 1; i < powers.length; i++) {
            powers[i] = powers[i - 1] * 10L;
        }
        return powers;
    }

    /**
     * Returns 10^n as a double. Small non-negative exponents, the very common cases for round and
     * truncate, are served from a table; these powers are exactly representable as doubles, so the
     * values are the same {@link Math#pow(double, double)} would return.
     */
    private static double tenPower(long n) {
        if (n >= 0L && n < LONG_POWERS_OF_TEN.length) {
            return LONG_POWERS_OF_TEN[(int) n];
        }
        return Math.pow(10, n);
    }

    /**
     * Converts a long to the given boxed integer type, checking that the value fits.
     * Types other than {@code Integer}, {@code Short} and {@code Byte} get the value back as a {@code Long}.
     *
     * @param number the number to convert
     * @param type the destination type
     * @return the same number converted to the right type
     * @throws ArithmeticException in case of integer overflow.
     * See {@link org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.Arithmetics}
     */
    private static Number convertToIntegerType(long number, Class<? extends Number> type) throws ArithmeticException {
        if (type == Integer.class) {
            if (number > Integer.MAX_VALUE || number < Integer.MIN_VALUE) {
                throw new ArithmeticException("integer overflow");
            }
            return (int) number;
        }
        if (type == Short.class) {
            if (number > Short.MAX_VALUE || number < Short.MIN_VALUE) {
                throw new ArithmeticException("short overflow");
            }
            return (short) number;
        }
        if (type == Byte.class) {
            if (number > Byte.MAX_VALUE || number < Byte.MIN_VALUE) {
                throw new ArithmeticException("byte overflow");
            }
            return (byte) number;
        }
        return number;
    }
}
17 changes: 17 additions & 0 deletions x-pack/plugin/sql/qa/server/src/main/resources/math.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,20 @@ ROUND(AVG(salary)):d| rounded:i | AVG(salary):d | COUNT(*):l
69904.0 |69904 |69904.0 |1
68547.0 |68547 |68547.0 |1
;

roundInlineWithBigLong
SELECT ROUND(8998798798798798798) m;

m:l
-------------------
8998798798798798798
;


roundInlineWithBigLong2
SELECT ROUND(8998798798798798798, -1) m;

m:l
-------------------
8998798798798798800
;
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ roundInline3
SELECT ROUND(-345.123, 0) AS rounded;

roundInline4
SELECT ROUND(-345.123,-51231231) AS rounded;
SELECT ROUND(-345.123,-100) AS rounded;

roundInline5
SELECT ROUND(134.51, 1) AS rounded;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,10 +674,10 @@ sumWithInnerAggregateSumOfSquaresRoundLargeInput
SELECT * FROM logs_unsigned_long PIVOT (ROUND(SUM_OF_SQUARES(bytes_out + 1)/1E6, 2) FOR status IN ('Error', 'OK')) LIMIT 3;

@timestamp:ts | bytes_in:ul | id:i | 'Error':d | 'OK':d
------------------------+-------------------+---------------+---------------+--------------------
2017-11-10T00:00:22.000Z|0 |80 |null |9.223372036854776E16
2017-11-10T00:01:04.000Z|9636626466125797351|83 |null |9.223372036854776E16
2017-11-10T00:01:20.000Z|74330435873664882 |82 |null |9.223372036854776E16
------------------------+-------------------+---------------+---------------+---------------------
2017-11-10T00:00:22.000Z|0 |80 |null |1.2233739701097022E32
2017-11-10T00:01:04.000Z|9636626466125797351|83 |null |2.018903806214386E32
2017-11-10T00:01:20.000Z|74330435873664882 |82 |null |2.376773699133386E31
;

castWithGroupByLargeInput
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// Values larger than Long.MAX_VALUE are returned as BigDecimal.
// EXPM1 and CBRT functions not available.
// RAND takes an int.
// ROUND only works for values < Long.MAX_VALUE
// ES's CEIL & FLOOR don't work with unsigned_longs (TODO) + they can return longs, while H2's always returns doubles

plus
Expand Down Expand Up @@ -39,8 +40,8 @@ log10
SELECT LOG10(9223372036854775808 + ROUND(RAND(DAY(CURRENT_TIMESTAMP())) * 10000)) AS x;
power
SELECT POWER(4294967295, 2) AS x;
round
SELECT ROUND(9223372036854775808 + ROUND(RAND(DAY(CURRENT_TIMESTAMP())) * 10000)) AS x;
// round
// SELECT ROUND(9223372036854775808 + ROUND(RAND(DAY(CURRENT_TIMESTAMP())) * 10000)) AS x;
sign
SELECT SIGN(9223372036854775808 + ROUND(RAND(DAY(CURRENT_TIMESTAMP())) * 10000)) AS x;
sqrt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.ql.expression.gen.processor.Processor;
import org.elasticsearch.xpack.ql.expression.predicate.operator.math.Maths;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;

import java.io.IOException;
Expand All @@ -23,17 +24,8 @@ public class BinaryOptionalMathProcessor implements Processor {

public enum BinaryOptionalMathOperation implements BiFunction<Number, Number, Number> {

ROUND((l, r) -> {
double tenAtScale = Math.pow(10., r.longValue());
double middleResult = l.doubleValue() * tenAtScale;
int sign = middleResult > 0 ? 1 : -1;
return Math.round(Math.abs(middleResult)) / tenAtScale * sign;
}),
TRUNCATE((l, r) -> {
double tenAtScale = Math.pow(10., r.longValue());
double g = l.doubleValue() * tenAtScale;
return (((l.doubleValue() < 0) ? Math.ceil(g) : Math.floor(g)) / tenAtScale);
});
ROUND((n, precision) -> Maths.round(n, precision)),
TRUNCATE((n, precision) -> Maths.truncate(n, precision));

private final BiFunction<Number, Number, Number> process;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,56 @@ public void testPower() {
}

public void testRoundWithValidInput() {
assertEquals(123.0, new Round(EMPTY, l(123), l(3)).makePipe().asProcessor().process(null));
assertEquals(123.5, new Round(EMPTY, l(123.45), l(1)).makePipe().asProcessor().process(null));
assertEquals(123.0, new Round(EMPTY, l(123.45), l(0)).makePipe().asProcessor().process(null));
assertEquals(123.0, new Round(EMPTY, l(123.45), null).makePipe().asProcessor().process(null));
assertEquals(-100.0, new Round(EMPTY, l(-123), l(-2)).makePipe().asProcessor().process(null));
assertEquals(123L, new Round(EMPTY, l(123L), l(0)).makePipe().asProcessor().process(null));
assertEquals(123L, new Round(EMPTY, l(123L), l(5)).makePipe().asProcessor().process(null));
assertEquals(120L, new Round(EMPTY, l(123L), l(-1)).makePipe().asProcessor().process(null));
assertEquals(100L, new Round(EMPTY, l(123L), l(-2)).makePipe().asProcessor().process(null));
assertEquals(0L, new Round(EMPTY, l(123L), l(-3)).makePipe().asProcessor().process(null));
assertEquals(0L, new Round(EMPTY, l(123L), l(-100)).makePipe().asProcessor().process(null));
assertEquals(1000L, new Round(EMPTY, l(999L), l(-1)).makePipe().asProcessor().process(null));
assertEquals(1000.0, new Round(EMPTY, l(999.0), l(-1)).makePipe().asProcessor().process(null));
assertEquals(130L, new Round(EMPTY, l(125L), l(-1)).makePipe().asProcessor().process(null));
assertEquals(12400L, new Round(EMPTY, l(12350L), l(-2)).makePipe().asProcessor().process(null));
assertEquals(12400.0, new Round(EMPTY, l(12350.0), l(-2)).makePipe().asProcessor().process(null));
assertEquals(12300.0, new Round(EMPTY, l(12349.0), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-12300L, new Round(EMPTY, l(-12349L), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-12400L, new Round(EMPTY, l(-12350L), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-12400.0, new Round(EMPTY, l(-12350.0), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-100L, new Round(EMPTY, l(-123L), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-120.0, new Round(EMPTY, l(-123.45), l(-1)).makePipe().asProcessor().process(null));
assertEquals(-123.5, new Round(EMPTY, l(-123.45), l(1)).makePipe().asProcessor().process(null));
assertEquals(-124.0, new Round(EMPTY, l(-123.5), l(0)).makePipe().asProcessor().process(null));
assertEquals(-123.0, new Round(EMPTY, l(-123.45), null).makePipe().asProcessor().process(null));
}

public void testRoundFunctionWithEdgeCasesInputs() {
assertNull(new Round(EMPTY, l(null), l(3)).makePipe().asProcessor().process(null));
assertEquals(-0.0, new Round(EMPTY, l(0), l(0)).makePipe().asProcessor().process(null));
assertEquals((double) Long.MAX_VALUE, new Round(EMPTY, l(Long.MAX_VALUE), l(0)).makePipe().asProcessor().process(null));
assertEquals(0.0, new Round(EMPTY, l(123.456), l(Integer.MAX_VALUE)).makePipe().asProcessor().process(null));
assertEquals(123.456, new Round(EMPTY, l(123.456), l(Integer.MAX_VALUE)).makePipe().asProcessor().process(null));
assertEquals(0.0, new Round(EMPTY, l(123.456), l(Integer.MIN_VALUE)).makePipe().asProcessor().process(null));
assertEquals(0L, new Round(EMPTY, l(0L), l(0)).makePipe().asProcessor().process(null));
assertEquals(0, new Round(EMPTY, l(0), l(0)).makePipe().asProcessor().process(null));
assertEquals((short) 0, new Round(EMPTY, l((short) 0), l(0)).makePipe().asProcessor().process(null));
assertEquals((byte) 0, new Round(EMPTY, l((byte) 0), l(0)).makePipe().asProcessor().process(null));
assertEquals(Long.MAX_VALUE, new Round(EMPTY, l(Long.MAX_VALUE), null).makePipe().asProcessor().process(null));
assertEquals(Long.MAX_VALUE, new Round(EMPTY, l(Long.MAX_VALUE), l(5)).makePipe().asProcessor().process(null));
assertEquals(Long.MIN_VALUE, new Round(EMPTY, l(Long.MIN_VALUE), null).makePipe().asProcessor().process(null));
assertEquals(Long.MIN_VALUE, new Round(EMPTY, l(Long.MIN_VALUE), l(5)).makePipe().asProcessor().process(null));
// absolute precision at the extremes
assertEquals(9223372036854775800L, new Round(EMPTY, l(Long.MAX_VALUE), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-9223372036854775800L, new Round(EMPTY, l(Long.MIN_VALUE), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-9223372036854775800L, new Round(EMPTY, l(Long.MIN_VALUE + 1), l(-2)).makePipe().asProcessor().process(null));
// overflows
expectThrows(ArithmeticException.class, () -> new Round(EMPTY, l(Long.MAX_VALUE), l(-3)).makePipe().asProcessor().process(null));
expectThrows(ArithmeticException.class, () -> new Round(EMPTY, l(Long.MIN_VALUE), l(-3)).makePipe().asProcessor().process(null));
expectThrows(ArithmeticException.class, () -> new Round(EMPTY, l(Integer.MAX_VALUE), l(-3)).makePipe().asProcessor().process(null));
expectThrows(ArithmeticException.class, () -> new Round(EMPTY, l(Integer.MIN_VALUE), l(-3)).makePipe().asProcessor().process(null));
// very big numbers, ie. overflow with Long rounding
assertEquals(1234456.234567, new Round(EMPTY, l(1234456.234567), l(20)).makePipe().asProcessor().process(null));
assertEquals(12344561234567456.2345, new Round(EMPTY, l(12344561234567456.234567), l(4)).makePipe().asProcessor().process(null));
assertEquals(12344561234567000., new Round(EMPTY, l(12344561234567456.234567), l(-3)).makePipe().asProcessor().process(null));
}

public void testRoundInputValidation() {
Expand All @@ -89,20 +124,29 @@ public void testRoundInputValidation() {
}

public void testTruncateWithValidInput() {
assertEquals(123.0, new Truncate(EMPTY, l(123), l(3)).makePipe().asProcessor().process(null));
assertEquals(123L, new Truncate(EMPTY, l(123L), l(3)).makePipe().asProcessor().process(null));
assertEquals(123L, new Truncate(EMPTY, l(123L), l(0)).makePipe().asProcessor().process(null));
assertEquals(120L, new Truncate(EMPTY, l(123L), l(-1)).makePipe().asProcessor().process(null));
assertEquals(0L, new Truncate(EMPTY, l(123L), l(-3)).makePipe().asProcessor().process(null));
assertEquals(123.4, new Truncate(EMPTY, l(123.45), l(1)).makePipe().asProcessor().process(null));
assertEquals(123.0, new Truncate(EMPTY, l(123.45), l(0)).makePipe().asProcessor().process(null));
assertEquals(123.0, new Truncate(EMPTY, l(123.45), null).makePipe().asProcessor().process(null));
assertEquals(-100.0, new Truncate(EMPTY, l(-123), l(-2)).makePipe().asProcessor().process(null));
assertEquals((byte) -100, new Truncate(EMPTY, l((byte) -123), l(-2)).makePipe().asProcessor().process(null));
assertEquals((short) -100, new Truncate(EMPTY, l((short) -123), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-100, new Truncate(EMPTY, l(-123), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-100L, new Truncate(EMPTY, l(-123L), l(-2)).makePipe().asProcessor().process(null));
assertEquals(-120.0, new Truncate(EMPTY, l(-123.45), l(-1)).makePipe().asProcessor().process(null));
assertEquals(-123.0, new Truncate(EMPTY, l(-123.5), l(0)).makePipe().asProcessor().process(null));
assertEquals(-123.0, new Truncate(EMPTY, l(-123.45), null).makePipe().asProcessor().process(null));
}

public void testTruncateFunctionWithEdgeCasesInputs() {
assertNull(new Truncate(EMPTY, l(null), l(3)).makePipe().asProcessor().process(null));
assertEquals(0.0, new Truncate(EMPTY, l(0), l(0)).makePipe().asProcessor().process(null));
assertEquals((double) Long.MAX_VALUE, new Truncate(EMPTY, l(Long.MAX_VALUE), l(0)).makePipe().asProcessor().process(null));
assertEquals(0, new Truncate(EMPTY, l(0), l(0)).makePipe().asProcessor().process(null));
assertEquals(0L, new Truncate(EMPTY, l(0L), l(0)).makePipe().asProcessor().process(null));
assertEquals(Long.MAX_VALUE, new Truncate(EMPTY, l(Long.MAX_VALUE), l(0)).makePipe().asProcessor().process(null));
assertEquals(9223372036854775800L, new Truncate(EMPTY, l(Long.MAX_VALUE), l(-1)).makePipe().asProcessor().process(null));
assertEquals(-9223372036854775800L, new Truncate(EMPTY, l(Long.MIN_VALUE), l(-1)).makePipe().asProcessor().process(null));
assertEquals(Double.NaN, new Truncate(EMPTY, l(123.456), l(Integer.MAX_VALUE)).makePipe().asProcessor().process(null));
}

Expand Down

0 comments on commit 087bd1d

Please sign in to comment.