diff --git a/community/values/src/main/java/org/neo4j/values/utils/ValueMath.java b/community/values/src/main/java/org/neo4j/values/utils/ValueMath.java index 87eb732e49bec..441700a1314ec 100644 --- a/community/values/src/main/java/org/neo4j/values/utils/ValueMath.java +++ b/community/values/src/main/java/org/neo4j/values/utils/ValueMath.java @@ -19,6 +19,8 @@ */ package org.neo4j.values.utils; +import org.neo4j.values.storable.DoubleValue; +import org.neo4j.values.storable.LongValue; import org.neo4j.values.storable.NumberValue; import org.neo4j.values.storable.Values; @@ -35,6 +37,18 @@ private ValueMath() throw new UnsupportedOperationException( "Do not instantiate" ); } + /** + * Overflow safe addition of two longs + * + * @param a left-hand operand + * @param b right-hand operand + * @return a + b + */ + public static LongValue add( long a, long b ) + { + return longValue( Math.addExact( a, b ) ); + } + /** * Overflow safe addition of two longs *

@@ -44,7 +58,7 @@ private ValueMath() * @param b right-hand operand * @return a + b */ - public static NumberValue add( long a, long b ) + public static NumberValue overflowSafeAdd( long a, long b ) { long r = a + b; //Check if result overflows @@ -62,11 +76,23 @@ public static NumberValue add( long a, long b ) * @param b right-hand operand * @return a + b */ - public static NumberValue add( double a, double b ) + public static DoubleValue add( double a, double b ) { return Values.doubleValue( a + b ); } + /** + * Overflow safe subtraction of two longs + * + * @param a left-hand operand + * @param b right-hand operand + * @return a - b + */ + public static LongValue subtract( long a, long b ) + { + return longValue( Math.subtractExact( a, b ) ); + } + /** * Overflow safe subtraction of two longs *

@@ -76,7 +102,7 @@ public static NumberValue add( double a, double b ) * @param b right-hand operand * @return a + b */ - public static NumberValue subtract( long a, long b ) + public static NumberValue overflowSafeSubtract( long a, long b ) { long r = a - b; //Check if result overflows @@ -94,11 +120,23 @@ public static NumberValue subtract( long a, long b ) * @param b right-hand operand * @return a - b */ - public static NumberValue subtract( double a, double b ) + public static DoubleValue subtract( double a, double b ) { return Values.doubleValue( a - b ); } + /** + * Overflow safe multiplication of two longs + * + * @param a left-hand operand + * @param b right-hand operand + * @return a * b + */ + public static LongValue multiply( long a, long b ) + { + return longValue( Math.multiplyExact( a, b ) ); + } + /** * Overflow safe multiplication of two longs *

@@ -108,7 +146,7 @@ public static NumberValue subtract( double a, double b ) * @param b right-hand operand * @return a * b */ - public static NumberValue multiply( long a, long b ) + public static NumberValue overflowSafeMultiply( long a, long b ) { long r = a * b; //Check if result overflows @@ -131,7 +169,7 @@ public static NumberValue multiply( long a, long b ) * @param b right-hand operand * @return a * b */ - public static NumberValue multiply( double a, double b ) + public static DoubleValue multiply( double a, double b ) { return Values.doubleValue( a * b ); } diff --git a/community/values/src/test/java/org/neo4j/values/storable/NumberValueMathTest.java b/community/values/src/test/java/org/neo4j/values/storable/NumberValueMathTest.java index 7908bb27f23e3..9b81d2c195baf 100644 --- a/community/values/src/test/java/org/neo4j/values/storable/NumberValueMathTest.java +++ b/community/values/src/test/java/org/neo4j/values/storable/NumberValueMathTest.java @@ -19,7 +19,9 @@ */ package org.neo4j.values.storable; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; @@ -29,9 +31,15 @@ import static org.neo4j.values.storable.Values.intValue; import static org.neo4j.values.storable.Values.longValue; import static org.neo4j.values.storable.Values.shortValue; +import static org.neo4j.values.utils.ValueMath.overflowSafeAdd; +import static org.neo4j.values.utils.ValueMath.overflowSafeMultiply; +import static org.neo4j.values.utils.ValueMath.overflowSafeSubtract; public class NumberValueMathTest { + @Rule + public ExpectedException exception = ExpectedException.none(); + @Test public void shouldAddSimpleIntegers() { @@ -86,7 +94,7 @@ public void shouldAddSimpleFloats() NumberValue[] integers = new NumberValue[]{byteValue( (byte) 42 ), shortValue( (short) 42 ), intValue( 42 ), longValue( 42 )}; NumberValue[] floats = - new NumberValue[]{floatValue( 42 ), doubleValue( 42 ) }; + new NumberValue[]{floatValue( 42 ), doubleValue( 42 )}; for ( NumberValue a : integers ) { @@ -104,7 +112,7 @@ public void shouldSubtractSimpleFloats() NumberValue[] integers = new NumberValue[]{byteValue( (byte) 42 ), shortValue( (short) 42 ), intValue( 42 ), longValue( 42 )}; NumberValue[] floats = - new NumberValue[]{floatValue( 42 ), doubleValue( 42 ) }; + new NumberValue[]{floatValue( 42 ), doubleValue( 42 )}; for ( NumberValue a : integers ) { @@ -122,33 +130,63 @@ public void shouldMultiplySimpleFloats() NumberValue[] integers = new NumberValue[]{byteValue( (byte) 42 ), shortValue( (short) 42 ), intValue( 42 ), longValue( 42 )}; NumberValue[] floats = - new NumberValue[]{floatValue( 42 ), doubleValue( 42 ) }; + new NumberValue[]{floatValue( 42 ), doubleValue( 42 )}; for ( NumberValue a : integers ) { for ( NumberValue b : floats ) { - assertThat( a.times( b ), equalTo( doubleValue( 42 * 42) ) ); + assertThat( a.times( b ), equalTo( doubleValue( 42 * 42 ) ) ); assertThat( b.times( a ), equalTo( doubleValue( 42 * 42 ) ) ); } } } @Test - public void shouldNotOverflowOnAddition() + public void shouldFailOnOverflowingAdd() + { + //Expect + exception.expect( ArithmeticException.class ); + + //WHEN + longValue( Long.MAX_VALUE ).plus( longValue( 1 ) ); + } + + @Test + public void shouldFailOnOverflowingSubtraction() + { + //Expect + exception.expect( ArithmeticException.class ); + + //WHEN + longValue( Long.MAX_VALUE ).minus( longValue( -1 ) ); + } + + @Test + public void shouldFailOnOverflowingMultiplication() + { + //Expect + exception.expect( ArithmeticException.class ); + + //When + longValue( Long.MAX_VALUE ).times( 2 ); + } + + @Test + public void shouldNotOverflowOnSafeAddition() { - assertThat( longValue( Long.MAX_VALUE ).plus( longValue( 1 ) ), equalTo( doubleValue( (double) Long.MAX_VALUE + 1 ) ) ); + assertThat( overflowSafeAdd( Long.MAX_VALUE, 1 ), equalTo( doubleValue( (double) Long.MAX_VALUE + 1 ) ) ); } @Test - public void shouldNotOverflowOnSubtraction() + public void shouldNotOverflowOnSafeSubtraction() { - assertThat( longValue( Long.MAX_VALUE ).minus( longValue( -1 ) ), equalTo( doubleValue( (double) Long.MAX_VALUE + 1 ) ) ); + assertThat( overflowSafeSubtract( Long.MAX_VALUE, -1 ), equalTo( doubleValue( (double) Long.MAX_VALUE + 1 ) ) ); } @Test public void shouldNotOverflowOnMultiplication() { - assertThat( longValue( Long.MAX_VALUE ).times( 2 ), equalTo( doubleValue( (double) Long.MAX_VALUE * 2 ) ) ); + assertThat( overflowSafeMultiply( Long.MAX_VALUE, 2 ), equalTo( doubleValue( (double) Long.MAX_VALUE * 2 ) ) ); } }