Skip to content

Commit

Permalink
Switch to use strict math by default
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke committed Feb 25, 2018
1 parent eb10b3e commit 2950115
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 15 deletions.
Expand Up @@ -19,6 +19,8 @@
*/ */
package org.neo4j.values.utils; package org.neo4j.values.utils;


import org.neo4j.values.storable.DoubleValue;
import org.neo4j.values.storable.LongValue;
import org.neo4j.values.storable.NumberValue; import org.neo4j.values.storable.NumberValue;
import org.neo4j.values.storable.Values; import org.neo4j.values.storable.Values;


Expand All @@ -35,6 +37,18 @@ private ValueMath()
throw new UnsupportedOperationException( "Do not instantiate" ); throw new UnsupportedOperationException( "Do not instantiate" );
} }


/**
* Overflow safe addition of two longs
*
* @param a left-hand operand
* @param b right-hand operand
* @return a + b
*/
public static LongValue add( long a, long b )
{
return longValue( Math.addExact( a, b ) );
}

/** /**
* Overflow safe addition of two longs * Overflow safe addition of two longs
* <p> * <p>
Expand All @@ -44,7 +58,7 @@ private ValueMath()
* @param b right-hand operand * @param b right-hand operand
* @return a + b * @return a + b
*/ */
public static NumberValue add( long a, long b ) public static NumberValue overflowSafeAdd( long a, long b )
{ {
long r = a + b; long r = a + b;
//Check if result overflows //Check if result overflows
Expand All @@ -62,11 +76,23 @@ public static NumberValue add( long a, long b )
* @param b right-hand operand * @param b right-hand operand
* @return a + b * @return a + b
*/ */
public static NumberValue add( double a, double b ) public static DoubleValue add( double a, double b )
{ {
return Values.doubleValue( a + b ); return Values.doubleValue( a + b );
} }


/**
* Overflow safe subtraction of two longs
*
* @param a left-hand operand
* @param b right-hand operand
* @return a - b
*/
public static LongValue subtract( long a, long b )
{
return longValue( Math.subtractExact( a, b ) );
}

/** /**
* Overflow safe subtraction of two longs * Overflow safe subtraction of two longs
* <p> * <p>
Expand All @@ -76,7 +102,7 @@ public static NumberValue add( double a, double b )
* @param b right-hand operand * @param b right-hand operand
* @return a + b * @return a + b
*/ */
public static NumberValue subtract( long a, long b ) public static NumberValue overflowSafeSubtract( long a, long b )
{ {
long r = a - b; long r = a - b;
//Check if result overflows //Check if result overflows
Expand All @@ -94,11 +120,23 @@ public static NumberValue subtract( long a, long b )
* @param b right-hand operand * @param b right-hand operand
* @return a - b * @return a - b
*/ */
public static NumberValue subtract( double a, double b ) public static DoubleValue subtract( double a, double b )
{ {
return Values.doubleValue( a - b ); return Values.doubleValue( a - b );
} }


/**
* Overflow safe multiplication of two longs
*
* @param a left-hand operand
* @param b right-hand operand
* @return a * b
*/
public static LongValue multiply( long a, long b )
{
return longValue( Math.multiplyExact( a, b ) );
}

/** /**
* Overflow safe multiplication of two longs * Overflow safe multiplication of two longs
* <p> * <p>
Expand All @@ -108,7 +146,7 @@ public static NumberValue subtract( double a, double b )
* @param b right-hand operand * @param b right-hand operand
* @return a * b * @return a * b
*/ */
public static NumberValue multiply( long a, long b ) public static NumberValue overflowSafeMultiply( long a, long b )
{ {
long r = a * b; long r = a * b;
//Check if result overflows //Check if result overflows
Expand All @@ -131,7 +169,7 @@ public static NumberValue multiply( long a, long b )
* @param b right-hand operand * @param b right-hand operand
* @return a * b * @return a * b
*/ */
public static NumberValue multiply( double a, double b ) public static DoubleValue multiply( double a, double b )
{ {
return Values.doubleValue( a * b ); return Values.doubleValue( a * b );
} }
Expand Down
Expand Up @@ -19,7 +19,9 @@
*/ */
package org.neo4j.values.storable; package org.neo4j.values.storable;


import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;


import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -29,9 +31,15 @@
import static org.neo4j.values.storable.Values.intValue; import static org.neo4j.values.storable.Values.intValue;
import static org.neo4j.values.storable.Values.longValue; import static org.neo4j.values.storable.Values.longValue;
import static org.neo4j.values.storable.Values.shortValue; import static org.neo4j.values.storable.Values.shortValue;
import static org.neo4j.values.utils.ValueMath.overflowSafeAdd;
import static org.neo4j.values.utils.ValueMath.overflowSafeMultiply;
import static org.neo4j.values.utils.ValueMath.overflowSafeSubtract;


public class NumberValueMathTest public class NumberValueMathTest
{ {
@Rule
public ExpectedException exception = ExpectedException.none();

@Test @Test
public void shouldAddSimpleIntegers() public void shouldAddSimpleIntegers()
{ {
Expand Down Expand Up @@ -86,7 +94,7 @@ public void shouldAddSimpleFloats()
NumberValue[] integers = NumberValue[] integers =
new NumberValue[]{byteValue( (byte) 42 ), shortValue( (short) 42 ), intValue( 42 ), longValue( 42 )}; new NumberValue[]{byteValue( (byte) 42 ), shortValue( (short) 42 ), intValue( 42 ), longValue( 42 )};
NumberValue[] floats = NumberValue[] floats =
new NumberValue[]{floatValue( 42 ), doubleValue( 42 ) }; new NumberValue[]{floatValue( 42 ), doubleValue( 42 )};


for ( NumberValue a : integers ) for ( NumberValue a : integers )
{ {
Expand All @@ -104,7 +112,7 @@ public void shouldSubtractSimpleFloats()
NumberValue[] integers = NumberValue[] integers =
new NumberValue[]{byteValue( (byte) 42 ), shortValue( (short) 42 ), intValue( 42 ), longValue( 42 )}; new NumberValue[]{byteValue( (byte) 42 ), shortValue( (short) 42 ), intValue( 42 ), longValue( 42 )};
NumberValue[] floats = NumberValue[] floats =
new NumberValue[]{floatValue( 42 ), doubleValue( 42 ) }; new NumberValue[]{floatValue( 42 ), doubleValue( 42 )};


for ( NumberValue a : integers ) for ( NumberValue a : integers )
{ {
Expand All @@ -122,33 +130,63 @@ public void shouldMultiplySimpleFloats()
NumberValue[] integers = NumberValue[] integers =
new NumberValue[]{byteValue( (byte) 42 ), shortValue( (short) 42 ), intValue( 42 ), longValue( 42 )}; new NumberValue[]{byteValue( (byte) 42 ), shortValue( (short) 42 ), intValue( 42 ), longValue( 42 )};
NumberValue[] floats = NumberValue[] floats =
new NumberValue[]{floatValue( 42 ), doubleValue( 42 ) }; new NumberValue[]{floatValue( 42 ), doubleValue( 42 )};


for ( NumberValue a : integers ) for ( NumberValue a : integers )
{ {
for ( NumberValue b : floats ) for ( NumberValue b : floats )
{ {
assertThat( a.times( b ), equalTo( doubleValue( 42 * 42) ) ); assertThat( a.times( b ), equalTo( doubleValue( 42 * 42 ) ) );
assertThat( b.times( a ), equalTo( doubleValue( 42 * 42 ) ) ); assertThat( b.times( a ), equalTo( doubleValue( 42 * 42 ) ) );
} }
} }
} }


@Test @Test
public void shouldNotOverflowOnAddition() public void shouldFailOnOverflowingAdd()
{
//Expect
exception.expect( ArithmeticException.class );

//WHEN
longValue( Long.MAX_VALUE ).plus( longValue( 1 ) );
}

@Test
public void shouldFailOnOverflowingSubtraction()
{
//Expect
exception.expect( ArithmeticException.class );

//WHEN
longValue( Long.MAX_VALUE ).minus( longValue( -1 ) );
}

@Test
public void shouldFailOnOverflowingMultiplication()
{
//Expect
exception.expect( ArithmeticException.class );

//When
longValue( Long.MAX_VALUE ).times( 2 );
}

@Test
public void shouldNotOverflowOnSafeAddition()
{ {
assertThat( longValue( Long.MAX_VALUE ).plus( longValue( 1 ) ), equalTo( doubleValue( (double) Long.MAX_VALUE + 1 ) ) ); assertThat( overflowSafeAdd( Long.MAX_VALUE, 1 ), equalTo( doubleValue( (double) Long.MAX_VALUE + 1 ) ) );
} }


@Test @Test
public void shouldNotOverflowOnSubtraction() public void shouldNotOverflowOnSafeSubtraction()
{ {
assertThat( longValue( Long.MAX_VALUE ).minus( longValue( -1 ) ), equalTo( doubleValue( (double) Long.MAX_VALUE + 1 ) ) ); assertThat( overflowSafeSubtract( Long.MAX_VALUE, -1 ), equalTo( doubleValue( (double) Long.MAX_VALUE + 1 ) ) );
} }


@Test @Test
public void shouldNotOverflowOnMultiplication() public void shouldNotOverflowOnMultiplication()
{ {
assertThat( longValue( Long.MAX_VALUE ).times( 2 ), equalTo( doubleValue( (double) Long.MAX_VALUE * 2 ) ) ); assertThat( overflowSafeMultiply( Long.MAX_VALUE, 2 ), equalTo( doubleValue( (double) Long.MAX_VALUE * 2 ) ) );
} }
} }

0 comments on commit 2950115

Please sign in to comment.