diff --git a/community/codegen/src/main/java/org/neo4j/codegen/Expression.java b/community/codegen/src/main/java/org/neo4j/codegen/Expression.java index cea04457a4de3..5c7e3458a4674 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/Expression.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/Expression.java @@ -161,38 +161,20 @@ public void accept( ExpressionVisitor visitor ) }; } - public static Expression subtractInts( final Expression lhs, final Expression rhs ) + public static Expression subtract( final Expression lhs, final Expression rhs ) { - return new Expression( INT ) - { - @Override - public void accept( ExpressionVisitor visitor ) - { - visitor.subtractInts( lhs, rhs ); - } - }; - } - - public static Expression subtractLongs( final Expression lhs, final Expression rhs ) - { - return new Expression( LONG ) + if ( !lhs.type.equals( rhs.type ) ) { - @Override - public void accept( ExpressionVisitor visitor ) - { - visitor.subtractLongs( lhs, rhs ); - } - }; - } - - public static Expression subtractDoubles( final Expression lhs, final Expression rhs ) - { - return new Expression( DOUBLE ) + throw new IllegalArgumentException( + String.format( "Cannot subtract variable with different types. LHS %s, RHS %s", lhs.type.simpleName(), + rhs.type.simpleName() )); + } + return new Expression( lhs.type ) { @Override public void accept( ExpressionVisitor visitor ) { - visitor.subtractDoubles( lhs, rhs ); + visitor.subtract( lhs, rhs ); } }; } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java index 4167f340d96f7..819afffad7943 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java @@ -228,24 +228,7 @@ public void lte( Expression lhs, Expression rhs ) } @Override - public void subtractInts( Expression lhs, Expression rhs ) - { - sub( lhs, rhs ); - } - - @Override - public void subtractLongs( Expression lhs, Expression rhs ) - { - sub( lhs, rhs ); - } - - @Override - public void subtractDoubles( Expression lhs, Expression rhs ) - { - sub( lhs, rhs ); - } - - private void sub( Expression lhs, Expression rhs ) + public void subtract( Expression lhs, Expression rhs ) { result.append( "sub(" ); lhs.accept( this ); diff --git a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java index 3353fe817c96b..811064f166773 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java @@ -61,11 +61,7 @@ public interface ExpressionVisitor void lte( Expression lhs, Expression rhs ); - void subtractInts( Expression lhs, Expression rhs ); - - void subtractLongs( Expression lhs, Expression rhs ); - - void subtractDoubles( Expression lhs, Expression rhs ); + void subtract( Expression lhs, Expression rhs ); void multiplyLongs( Expression lhs, Expression rhs ); diff --git a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java index e978aa24bc641..fe77bef0152a6 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java @@ -55,6 +55,7 @@ import static org.objectweb.asm.Opcodes.FCMPG; import static org.objectweb.asm.Opcodes.FCMPL; import static org.objectweb.asm.Opcodes.FLOAD; +import static org.objectweb.asm.Opcodes.FSUB; import static org.objectweb.asm.Opcodes.GETFIELD; import static org.objectweb.asm.Opcodes.GETSTATIC; import static org.objectweb.asm.Opcodes.GOTO; @@ -424,7 +425,7 @@ public void lt( Expression lhs, Expression rhs ) } @Override - public void lte( Expression lhs, Expression rhs) + public void lte( Expression lhs, Expression rhs ) { assertSameType( lhs, rhs, "compare" ); numberOperation( lhs.type(), @@ -436,27 +437,15 @@ public void lte( Expression lhs, Expression rhs) } @Override - public void subtractInts( Expression lhs, Expression rhs ) + public void subtract( Expression lhs, Expression rhs ) { lhs.accept( this ); rhs.accept( this ); - methodVisitor.visitInsn( ISUB ); - } - - @Override - public void subtractLongs( Expression lhs, Expression rhs ) - { - lhs.accept( this ); - rhs.accept( this ); - methodVisitor.visitInsn( LSUB ); - } - - @Override - public void subtractDoubles( Expression lhs, Expression rhs ) - { - lhs.accept( this ); - rhs.accept( this ); - methodVisitor.visitInsn( DSUB ); + numberOperation( lhs.type(), + () -> methodVisitor.visitInsn( ISUB ), + () -> methodVisitor.visitInsn( LSUB ), + () -> methodVisitor.visitInsn( FSUB ), + () -> methodVisitor.visitInsn( DSUB ) ); } @Override @@ -680,7 +669,7 @@ private void assertSameType( Expression lhs, Expression rhs, String operation ) { if ( !lhs.type().equals( rhs.type() ) ) { - throw new IllegalArgumentException( String.format( "Can only %s values of the same type", operation )); + throw new IllegalArgumentException( String.format( "Can only %s values of the same type", operation ) ); } } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java b/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java index 9a15df25be675..901e0c9bd04df 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java @@ -418,24 +418,7 @@ public void lte( Expression lhs, Expression rhs ) } @Override - public void subtractInts( Expression lhs, Expression rhs ) - { - sub( lhs, rhs); - } - - @Override - public void subtractLongs( Expression lhs, Expression rhs ) - { - sub( lhs, rhs); - } - - @Override - public void subtractDoubles( Expression lhs, Expression rhs ) - { - sub( lhs, rhs); - } - - private void sub( Expression lhs, Expression rhs ) + public void subtract( Expression lhs, Expression rhs ) { lhs.accept( this ); append( " - " ); diff --git a/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java b/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java index eea097d3fc59e..15f9e9c9fde1c 100644 --- a/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java +++ b/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java @@ -64,8 +64,7 @@ import static org.neo4j.codegen.Expression.newInstance; import static org.neo4j.codegen.Expression.not; import static org.neo4j.codegen.Expression.or; -import static org.neo4j.codegen.Expression.subtractDoubles; -import static org.neo4j.codegen.Expression.subtractLongs; +import static org.neo4j.codegen.Expression.subtract; import static org.neo4j.codegen.Expression.ternary; import static org.neo4j.codegen.ExpressionTemplate.cast; import static org.neo4j.codegen.ExpressionTemplate.load; @@ -1387,18 +1386,7 @@ private T subtractForType( Class clazz, T lhs, T rhs ) throws Throwable try ( CodeBlock block = simple.generateMethod( clazz, "sub", param( clazz, "a" ), param( clazz, "b" ) ) ) { - if ( clazz == long.class ) - { - block.returns( subtractLongs( block.load( "a" ), block.load( "b" ) ) ); - } - else if ( clazz == double.class ) - { - block.returns( subtractDoubles( block.load( "a" ), block.load( "b" ) ) ); - } - else - { - fail( "adding " + clazz.getSimpleName() + " is not supported" ); - } + block.returns( subtract( block.load( "a" ), block.load( "b" ) ) ); } handle = simple.handle(); diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/CodeStructure.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/CodeStructure.scala index c26f3af489d39..1aed0efde7ef3 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/CodeStructure.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/CodeStructure.scala @@ -91,11 +91,11 @@ trait MethodStructure[E] { def loadVariable(varName: String): E // arithmetic - def add(lhs: E, rhs: E): E - def subtract(lhs: E, rhs: E): E - def multiply(lhs: E, rhs: E): E - def divide(lhs: E, rhs: E): E - def modulus(lhs: E, rhs: E): E + def addExpression(lhs: E, rhs: E): E + def subtractExpression(lhs: E, rhs: E): E + def multiplyExpression(lhs: E, rhs: E): E + def divideExpression(lhs: E, rhs: E): E + def modulusExpression(lhs: E, rhs: E): E // predicates def threeValuedNotExpression(value: E): E diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Addition.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Addition.scala index 6b856fa767e85..7872978313f66 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Addition.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Addition.scala @@ -24,7 +24,7 @@ import org.neo4j.cypher.internal.frontend.v3_2.symbols._ case class Addition(lhs: CodeGenExpression, rhs: CodeGenExpression) extends CodeGenExpression with BinaryOperator { - override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.add + override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.addExpression override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Division.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Division.scala index cb55e5702e125..be8aa20096c22 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Division.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Division.scala @@ -25,7 +25,7 @@ case class Division(lhs: CodeGenExpression, rhs: CodeGenExpression) extends CodeGenExpression with BinaryOperator{ override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = - structure.divide + structure.divideExpression override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Modulo.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Modulo.scala index 51d9628a5ab31..6812a4ffddadf 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Modulo.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Modulo.scala @@ -24,7 +24,7 @@ import org.neo4j.cypher.internal.frontend.v3_2.symbols._ case class Modulo(lhs: CodeGenExpression, rhs: CodeGenExpression) extends CodeGenExpression with BinaryOperator { - override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.modulus + override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.modulusExpression override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable override def codeGenType(implicit context: CodeGenContext) = CodeGenType(CTFloat, ReferenceType) diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Multiplication.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Multiplication.scala index ddbce1198e5be..7e9a880b1c29d 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Multiplication.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Multiplication.scala @@ -26,7 +26,7 @@ case class Multiplication(lhs: CodeGenExpression, rhs: CodeGenExpression) override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable - override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.multiply + override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.multiplyExpression override def name: String = "multiply" } diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Subtraction.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Subtraction.scala index 0692ecc0d5631..6a9a5508acb95 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Subtraction.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/Subtraction.scala @@ -25,7 +25,7 @@ case class Subtraction(lhs: CodeGenExpression, rhs: CodeGenExpression) extends CodeGenExpression with BinaryOperator { override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = - structure.subtract + structure.subtractExpression override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala index 940d9309310a1..416afb7a2c154 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala @@ -189,8 +189,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def decreaseCounterAndCheckForZero(name: String): Expression = { val local = locals(name) - generator.assign(local, subtractInts(local, constant(1))) - + generator.assign(local, subtract(local, constant(1))) equal(constant(0), local) } @@ -468,8 +467,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A invoke(generator.load(tableVar), countingTablePut, generator.load(keyVar), ternary( equal(generator.load(countName), get(staticField[LongKeyIntValueTable, Int]("NULL"))), - constant(1), - Expression.add(generator.load(countName), constant(1)))))) + constant(1), add(generator.load(countName), constant(1)))))) case LongsToCountTable => val countName = context.namer.newVarName() @@ -489,7 +487,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A ternaryOnNull(generator.load(countName), invoke(boxInteger, constant(1)), invoke(boxInteger, - Expression.add( + add( invoke(generator.load(countName), unboxInteger), constant(1))))))) @@ -504,7 +502,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A generator.assign(times, invoke(generator.load(tableVar), countingTableGet, generator.load(keyVar))) using(generator.whileLoop(gt(times, constant(0)))) { body => block(copy(generator = body)) - body.assign(times, subtractInts(times, constant(1))) + body.assign(times, subtract(times, constant(1))) } case LongsToCountTable => val times = generator.declare(typeRef[Int], context.namer.newVarName()) @@ -524,7 +522,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A using(generator.whileLoop(gt(times, constant(0)))) { body => block(copy(generator = body)) - body.assign(times, subtractInts(times, constant(1))) + body.assign(times, subtract(times, constant(1))) } case tableType@LongToListTable(structure, localVars) => diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala index 2feebc1b6cc22..6be8f5ad28383 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala @@ -189,7 +189,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def decrementCounter(name: String) = { val local = locals(name) - generator.assign(local, subtractInts(local, constant(1))) + generator.assign(local, subtract(local, constant(1))) } override def checkCounter(name: String, comparator: Comparator, value: Int): Expression = { @@ -374,15 +374,15 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def loadVariable(varName: String) = generator.load(varName) - override def add(lhs: Expression, rhs: Expression) = math(Methods.mathAdd, lhs, rhs) + override def addExpression(lhs: Expression, rhs: Expression) = math(Methods.mathAdd, lhs, rhs) - override def subtract(lhs: Expression, rhs: Expression) = math(Methods.mathSub, lhs, rhs) + override def subtractExpression(lhs: Expression, rhs: Expression) = math(Methods.mathSub, lhs, rhs) - override def multiply(lhs: Expression, rhs: Expression) = math(Methods.mathMul, lhs, rhs) + override def multiplyExpression(lhs: Expression, rhs: Expression) = math(Methods.mathMul, lhs, rhs) - override def divide(lhs: Expression, rhs: Expression) = math(Methods.mathDiv, lhs, rhs) + override def divideExpression(lhs: Expression, rhs: Expression) = math(Methods.mathDiv, lhs, rhs) - override def modulus(lhs: Expression, rhs: Expression) = math(Methods.mathMod, lhs, rhs) + override def modulusExpression(lhs: Expression, rhs: Expression) = math(Methods.mathMod, lhs, rhs) private def math(method: MethodReference, lhs: Expression, rhs: Expression): Expression = invoke(method, lhs, rhs) @@ -472,8 +472,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A invoke(generator.load(tableVar), countingTablePut, generator.load(keyVar), ternary( equal(generator.load(countName), get(staticField[LongKeyIntValueTable, Int]("NULL"))), - constant(1), - addInts(generator.load(countName), constant(1)))))) + constant(1), add(generator.load(countName), constant(1)))))) case LongsToCountTable => val countName = context.namer.newVarName() @@ -493,7 +492,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A ternaryOnNull(generator.load(countName), invoke(boxInteger, constant(1)), invoke(boxInteger, - addInts( + add( invoke(generator.load(countName), unboxInteger), constant(1))))))) @@ -508,7 +507,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A generator.assign(times, invoke(generator.load(tableVar), countingTableGet, generator.load(keyVar))) using(generator.whileLoop(gt(times, constant(0)))) { body => block(copy(generator = body)) - body.assign(times, subtractInts(times, constant(1))) + body.assign(times, subtract(times, constant(1))) } case LongsToCountTable => val times = generator.declare(typeRef[Int], context.namer.newVarName()) @@ -528,7 +527,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A using(generator.whileLoop(gt(times, constant(0)))) { body => block(copy(generator = body)) - body.assign(times, subtractInts(times, constant(1))) + body.assign(times, subtract(times, constant(1))) } case tableType@LongToListTable(structure, localVars) =>