From 4e4b518fc96116123577bc0039393822f7776539 Mon Sep 17 00:00:00 2001 From: Tobias Lindaaker Date: Tue, 21 Feb 2017 19:08:32 +0100 Subject: [PATCH] Simplify the code generation API Remove specialized methods for different types of comparisons, instead make the implementation perform the specialization. --- .../java/org/neo4j/codegen/CodeBlock.java | 33 +- .../java/org/neo4j/codegen/Expression.java | 326 +++++++++++++++--- .../org/neo4j/codegen/ExpressionToString.java | 58 ++-- .../org/neo4j/codegen/ExpressionVisitor.java | 12 +- .../java/org/neo4j/codegen/InvalidState.java | 22 +- .../java/org/neo4j/codegen/MethodEmitter.java | 24 +- .../java/org/neo4j/codegen/TypeReference.java | 13 + .../bytecode/ByteCodeExpressionVisitor.java | 92 +++-- .../java/org/neo4j/codegen/bytecode/If.java | 8 +- .../neo4j/codegen/bytecode/JumpVisitor.java | 244 +++++++++++++ .../bytecode/MethodByteCodeEmitter.java | 84 ++--- .../org/neo4j/codegen/bytecode/While.java | 14 +- .../codegen/source/MethodSourceWriter.java | 94 ++--- .../org/neo4j/codegen/CodeGenerationTest.java | 74 ++-- .../org/neo4j/codegen/ExpressionTest.java | 178 ++++++++++ .../codegen/GeneratedMethodStructure.scala | 30 +- .../spi/v3_2/codegen/AuxGenerator.scala | 6 +- .../codegen/GeneratedMethodStructure.scala | 72 ++-- .../internal/spi/v3_2/codegen/Templates.scala | 2 +- 19 files changed, 983 insertions(+), 403 deletions(-) create mode 100644 community/codegen/src/main/java/org/neo4j/codegen/bytecode/JumpVisitor.java create mode 100644 community/codegen/src/test/java/org/neo4j/codegen/ExpressionTest.java diff --git a/community/codegen/src/main/java/org/neo4j/codegen/CodeBlock.java b/community/codegen/src/main/java/org/neo4j/codegen/CodeBlock.java index 7daa0b37e698..5e95f8a4697b 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/CodeBlock.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/CodeBlock.java @@ -20,11 +20,7 @@ package org.neo4j.codegen; import java.util.Iterator; -import java.util.PrimitiveIterator; import java.util.function.Consumer; -import java.util.stream.DoubleStream; -import java.util.stream.IntStream; -import java.util.stream.LongStream; import static org.neo4j.codegen.LocalVariables.copy; import static org.neo4j.codegen.MethodReference.methodReference; @@ -163,33 +159,15 @@ public CodeBlock forEach( Parameter local, Expression iterable ) return block; } - public CodeBlock whileLoop( Expression...tests ) + public CodeBlock whileLoop( Expression test ) { - emitter.beginWhile( tests ); + emitter.beginWhile( test ); return new CodeBlock( this ); } - public CodeBlock ifStatement( Expression...tests ) + public CodeBlock ifStatement( Expression test ) { - emitter.beginIf( tests ); - return new CodeBlock( this ); - } - - public CodeBlock ifNotStatement( Expression...tests ) - { - emitter.beginIfNot( tests ); - return new CodeBlock( this ); - } - - public CodeBlock ifNullStatement( Expression...tests ) - { - emitter.beginIfNull( tests ); - return new CodeBlock( this ); - } - - public CodeBlock ifNonNullStatement( Expression...tests ) - { - emitter.beginIfNonNull( tests ); + emitter.beginIf( test ); return new CodeBlock( this ); } @@ -201,8 +179,7 @@ public CodeBlock block() public void tryCatch( Consumer body, Consumer onError, Parameter exception ) { - emitter.tryCatchBlock( body, onError, localVariables.createNew( exception.type(), exception.name() ), - this ); + emitter.tryCatchBlock( body, onError, localVariables.createNew( exception.type(), exception.name() ), this ); } public void returns() diff --git a/community/codegen/src/main/java/org/neo4j/codegen/Expression.java b/community/codegen/src/main/java/org/neo4j/codegen/Expression.java index a4fd53078ced..2c3daee91769 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/Expression.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/Expression.java @@ -19,6 +19,8 @@ */ package org.neo4j.codegen; +import java.util.Arrays; + import static org.neo4j.codegen.TypeReference.BOOLEAN; import static org.neo4j.codegen.TypeReference.DOUBLE; import static org.neo4j.codegen.TypeReference.INT; @@ -29,6 +31,22 @@ public abstract class Expression extends ExpressionTemplate { + public static final Expression TRUE = new Constant( BOOLEAN, Boolean.TRUE ) + { + @Override + Expression not() + { + return FALSE; + } + }, FALSE = new Constant( BOOLEAN, Boolean.FALSE ) + { + @Override + Expression not() + { + return TRUE; + } + }, NULL = new Constant( OBJECT, null ); + protected Expression( TypeReference type ) { super( type ); @@ -54,6 +72,12 @@ public void accept( ExpressionVisitor visitor ) { visitor.gt( lhs, rhs ); } + + @Override + Expression not() + { + return lte( lhs, rhs ); + } }; } @@ -66,6 +90,12 @@ public void accept( ExpressionVisitor visitor ) { visitor.gte( lhs, rhs ); } + + @Override + Expression not() + { + return lt( lhs, rhs ); + } }; } @@ -78,6 +108,12 @@ public void accept( ExpressionVisitor visitor ) { visitor.lt( lhs, rhs ); } + + @Override + Expression not() + { + return gte( lhs, rhs ); + } }; } @@ -90,41 +126,228 @@ public void accept( ExpressionVisitor visitor ) { visitor.lte( lhs, rhs ); } + + @Override + Expression not() + { + return gt( lhs, rhs ); + } }; } public static Expression and( final Expression lhs, final Expression rhs ) { + if ( lhs == FALSE || rhs == FALSE ) + { + return FALSE; + } + if ( lhs == TRUE ) + { + return rhs; + } + if ( rhs == TRUE ) + { + return lhs; + } + Expression[] expressions; + if ( lhs instanceof And ) + { + if ( rhs instanceof And ) + { + expressions = expressions( ((And) lhs).expressions, ((And) rhs).expressions ); + } + else + { + expressions = expressions( ((And) lhs).expressions, rhs ); + } + } + else if ( rhs instanceof And ) + { + expressions = expressions( lhs, ((And) rhs).expressions ); + } + else + { + expressions = new Expression[] {lhs, rhs}; + } + return new And( expressions ); + } + + public static Expression or( final Expression lhs, final Expression rhs ) + { + if ( lhs == TRUE || rhs == TRUE ) + { + return TRUE; + } + if ( lhs == FALSE ) + { + return rhs; + } + if ( rhs == FALSE ) + { + return lhs; + } + Expression[] expressions; + if ( lhs instanceof Or ) + { + if ( rhs instanceof Or ) + { + expressions = expressions( ((Or) lhs).expressions, ((Or) rhs).expressions ); + } + else + { + expressions = expressions( ((Or) lhs).expressions, rhs ); + } + } + else if ( rhs instanceof Or ) + { + expressions = expressions( lhs, ((Or) rhs).expressions ); + } + else + { + expressions = new Expression[] {lhs, rhs}; + } + return new Or( expressions ); + } + + private static class And extends Expression + { + private final Expression[] expressions; + + And( Expression[] expressions ) + { + super( BOOLEAN ); + this.expressions = expressions; + } + + @Override + public void accept( ExpressionVisitor visitor ) + { + visitor.and( expressions ); + } + } + + private static class Or extends Expression + { + private final Expression[] expressions; + + Or( Expression[] expressions ) + { + super( BOOLEAN ); + this.expressions = expressions; + } + + @Override + public void accept( ExpressionVisitor visitor ) + { + visitor.or( expressions ); + } + } + + private static Expression[] expressions( Expression[] some, Expression[] more ) + { + Expression[] result = Arrays.copyOf( some, some.length + more.length ); + System.arraycopy( more, 0, result, some.length, more.length ); + return result; + } + + private static Expression[] expressions( Expression[] some, Expression last ) + { + Expression[] result = Arrays.copyOf( some, some.length + 1 ); + result[some.length] = last; + return result; + } + + private static Expression[] expressions( Expression first, Expression[] more ) + { + Expression[] result = new Expression[more.length]; + result[0] = first; + System.arraycopy( more, 0, result, 1, more.length ); + return result; + } + + public static Expression equal( final Expression lhs, final Expression rhs ) + { + if ( lhs == NULL ) + { + if ( rhs == NULL ) + { + return constant( true ); + } + else + { + return isNull( rhs ); + } + } + else if ( rhs == NULL ) + { + return isNull( lhs ); + } return new Expression( BOOLEAN ) { @Override public void accept( ExpressionVisitor visitor ) { - visitor.and( lhs, rhs ); + visitor.equal( lhs, rhs ); + } + + @Override + Expression not() + { + return notEqual( lhs, rhs ); } }; } - public static Expression or( final Expression lhs, final Expression rhs ) + public static Expression isNull( final Expression expression ) { return new Expression( BOOLEAN ) { @Override public void accept( ExpressionVisitor visitor ) { - visitor.or( lhs, rhs ); + visitor.isNull( expression ); + } + + @Override + Expression not() + { + return notNull( expression ); } }; } - public static Expression equal( final Expression lhs, final Expression rhs ) + public static Expression notNull( final Expression expression ) { return new Expression( BOOLEAN ) { @Override public void accept( ExpressionVisitor visitor ) { - visitor.equal( lhs, rhs ); + visitor.notNull( expression ); + } + + @Override + Expression not() + { + return isNull( expression ); + } + }; + } + + public static Expression notEqual( final Expression lhs, final Expression rhs ) + { + return new Expression( BOOLEAN ) + { + @Override + public void accept( ExpressionVisitor visitor ) + { + visitor.notEqual( lhs, rhs ); + } + + @Override + Expression not() + { + return equal( lhs, rhs ); } }; } @@ -147,7 +370,7 @@ public static Expression add( final Expression lhs, final Expression rhs ) { throw new IllegalArgumentException( String.format( "Cannot add variables with different types. LHS %s, RHS %s", lhs.type.simpleName(), - rhs.type.simpleName() )); + rhs.type.simpleName() ) ); } return new Expression( lhs.type ) @@ -165,8 +388,10 @@ public static Expression subtract( final Expression lhs, final Expression rhs ) if ( !lhs.type.equals( rhs.type ) ) { throw new IllegalArgumentException( - String.format( "Cannot subtract variables with different types. LHS %s, RHS %s", lhs.type.simpleName(), - rhs.type.simpleName() )); + String.format( + "Cannot subtract variables with different types. LHS %s, RHS %s", + lhs.type.simpleName(), + rhs.type.simpleName() ) ); } return new Expression( lhs.type ) { @@ -183,8 +408,10 @@ public static Expression multiply( final Expression lhs, final Expression rhs ) if ( !lhs.type.equals( rhs.type ) ) { throw new IllegalArgumentException( - String.format( "Cannot multiply variables with different types. LHS %s, RHS %s", lhs.type.simpleName(), - rhs.type.simpleName() )); + String.format( + "Cannot multiply variables with different types. LHS %s, RHS %s", + lhs.type.simpleName(), + rhs.type.simpleName() ) ); } return new Expression( lhs.type ) { @@ -198,11 +425,10 @@ public void accept( ExpressionVisitor visitor ) public static Expression constant( final Object value ) { - TypeReference reference; if ( value == null ) { - reference = OBJECT; + return NULL; } else if ( value instanceof String ) { @@ -222,7 +448,7 @@ else if ( value instanceof Double ) } else if ( value instanceof Boolean ) { - reference = BOOLEAN; + return (Boolean) value ? TRUE : FALSE; } else { @@ -239,6 +465,23 @@ public void accept( ExpressionVisitor visitor ) }; } + private static class Constant extends Expression + { + private final Object value; + + Constant( TypeReference type, Object value ) + { + super( type ); + this.value = value; + } + + @Override + public void accept( ExpressionVisitor visitor ) + { + visitor.constant( value ); + } + } + //TODO deduce type from constants public static Expression newArray( TypeReference baseType, Expression... constants ) { @@ -266,7 +509,7 @@ public void accept( ExpressionVisitor visitor ) } /** box expression */ - public static Expression box( final Expression expression) + public static Expression box( final Expression expression ) { TypeReference type; switch ( expression.type.simpleName() ) @@ -303,13 +546,13 @@ public static Expression box( final Expression expression) @Override public void accept( ExpressionVisitor visitor ) { - visitor.box(expression); + visitor.box( expression ); } }; } /** unbox expression */ - public static Expression unbox( final Expression expression) + public static Expression unbox( final Expression expression ) { TypeReference type; switch ( expression.type.name() ) @@ -346,13 +589,13 @@ public static Expression unbox( final Expression expression) @Override public void accept( ExpressionVisitor visitor ) { - visitor.unbox(expression); + visitor.unbox( expression ); } }; } /** get static field */ - public static Expression getStatic(final FieldReference field ) + public static Expression getStatic( final FieldReference field ) { return new Expression( field.type() ) { @@ -364,33 +607,6 @@ public void accept( ExpressionVisitor visitor ) }; } - public static Expression ternaryOnNull( final Expression test, final Expression onTrue, final Expression onFalse ) - { - TypeReference reference = onTrue.type.equals( onFalse.type ) ? onTrue.type : OBJECT; - return new Expression( reference ) - { - @Override - public void accept( ExpressionVisitor visitor ) - { - visitor.ternaryOnNull( test, onTrue, onFalse ); - } - }; - } - - public static Expression ternaryOnNonNull( final Expression test, final Expression onTrue, - final Expression onFalse ) - { - TypeReference reference = onTrue.type.equals( onFalse.type ) ? onTrue.type : OBJECT; - return new Expression( reference ) - { - @Override - public void accept( ExpressionVisitor visitor ) - { - visitor.ternaryOnNonNull( test, onTrue, onFalse ); - } - }; - } - public static Expression ternary( final Expression test, final Expression onTrue, final Expression onFalse ) { TypeReference reference = onTrue.type.equals( onFalse.type ) ? onTrue.type : OBJECT; @@ -404,7 +620,8 @@ public void accept( ExpressionVisitor visitor ) }; } - public static Expression invoke( final Expression target, final MethodReference method, + public static Expression invoke( + final Expression target, final MethodReference method, final Expression... arguments ) { return new Expression( method.returns() ) @@ -465,6 +682,17 @@ public void accept( ExpressionVisitor visitor ) public static Expression not( final Expression expression ) { + return expression.not(); + } + + Expression not() + { + return notExpr( this ); + } + + private static Expression notExpr( final Expression expression ) + { + assert expression.type == BOOLEAN : "Can only apply not() to boolean expressions"; return new Expression( BOOLEAN ) { @Override @@ -472,6 +700,12 @@ public void accept( ExpressionVisitor visitor ) { visitor.not( expression ); } + + @Override + Expression not() + { + return expression; + } }; } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java index e381fda97327..7daed5752830 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java @@ -132,21 +132,6 @@ public void ternary( Expression test, Expression onTrue, Expression onFalse ) result.append( "}" ); } - @Override - public void ternaryOnNull( Expression test, Expression onTrue, Expression onFalse ) - { - ternary( Expression.equal( test, Expression.constant( null ) ), - onTrue, onFalse ); - } - - @Override - public void ternaryOnNonNull( Expression test, Expression onTrue, Expression onFalse ) - { - ternary( Expression.not( - Expression.equal( test, Expression.constant( null ) ) ), - onTrue, onFalse ); - } - @Override public void equal( Expression lhs, Expression rhs ) { @@ -158,9 +143,9 @@ public void equal( Expression lhs, Expression rhs ) } @Override - public void or( Expression lhs, Expression rhs ) + public void notEqual( Expression lhs, Expression rhs ) { - result.append( "or(" ); + result.append( "notEqual(" ); lhs.accept( this ); result.append( ", " ); rhs.accept( this ); @@ -168,12 +153,41 @@ public void or( Expression lhs, Expression rhs ) } @Override - public void and( Expression lhs, Expression rhs ) + public void isNull( Expression expression ) { - result.append( "and(" ); - lhs.accept( this ); - result.append( ", " ); - rhs.accept( this ); + result.append( "isNull(" ); + expression.accept( this ); + result.append( ")" ); + } + + @Override + public void notNull( Expression expression ) + { + result.append( "notNull(" ); + expression.accept( this ); + result.append( ")" ); + } + + @Override + public void or( Expression... expressions ) + { + boolOp( "or(", expressions ); + } + + @Override + public void and( Expression... expressions ) + { + boolOp( "and(", expressions ); + } + + private void boolOp( String sep, Expression[] expressions ) + { + for ( Expression expression : expressions ) + { + result.append( sep ); + expression.accept( this ); + sep = ", "; + } result.append( ")" ); } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java index d59e0da35a9c..1e5757809468 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java @@ -41,15 +41,17 @@ public interface ExpressionVisitor void ternary( Expression test, Expression onTrue, Expression onFalse ); - void ternaryOnNull( Expression test, Expression onTrue, Expression onFalse ); + void equal( Expression lhs, Expression rhs); - void ternaryOnNonNull( Expression test, Expression onTrue, Expression onFalse ); + void notEqual( Expression lhs, Expression rhs ); - void equal( Expression lhs, Expression rhs); + void isNull( Expression expression ); + + void notNull( Expression expression ); - void or( Expression lhs, Expression rhs ); + void or( Expression... expressions ); - void and( Expression lhs, Expression rhs ); + void and( Expression... expressions ); void add( Expression lhs, Expression rhs ); diff --git a/community/codegen/src/main/java/org/neo4j/codegen/InvalidState.java b/community/codegen/src/main/java/org/neo4j/codegen/InvalidState.java index 3e75d2078c09..bd26ccdae1af 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/InvalidState.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/InvalidState.java @@ -89,31 +89,13 @@ public void assign( LocalVariable variable, Expression value ) } @Override - public void beginWhile( Expression...tests ) + public void beginWhile( Expression test ) { throw new IllegalStateException( reason ); } @Override - public void beginIf( Expression...tests ) - { - throw new IllegalStateException( reason ); - } - - @Override - public void beginIfNot( Expression...tests ) - { - throw new IllegalStateException( reason ); - } - - @Override - public void beginIfNull( Expression...tests ) - { - throw new IllegalStateException( reason ); - } - - @Override - public void beginIfNonNull( Expression...tests ) + public void beginIf( Expression test ) { throw new IllegalStateException( reason ); } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/MethodEmitter.java b/community/codegen/src/main/java/org/neo4j/codegen/MethodEmitter.java index 458a61352ff9..ea001b707d78 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/MethodEmitter.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/MethodEmitter.java @@ -35,29 +35,9 @@ public interface MethodEmitter void assign( LocalVariable local, Expression value ); - /** - * Begin a while block, - * - * while (tests[0] && tests[1]...) - * ... - * - */ - void beginWhile( Expression...tests ); + void beginWhile( Expression test ); - /** - * Begin an if block, - * - * if (tests[0] && tests[1]...) - * ... - * - */ - void beginIf( Expression... tests ); - - void beginIfNot( Expression...tests ); - - void beginIfNull( Expression...tests ); - - void beginIfNonNull( Expression...tests ); + void beginIf( Expression test ); void beginBlock(); diff --git a/community/codegen/src/main/java/org/neo4j/codegen/TypeReference.java b/community/codegen/src/main/java/org/neo4j/codegen/TypeReference.java index d9ba4e1d33f3..62365e985bb2 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/TypeReference.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/TypeReference.java @@ -69,6 +69,19 @@ public static TypeReference typeReference( Class type ) else if (type.isPrimitive()) { simpleName = type.getName(); + switch ( simpleName ) + { + case "boolean": + return BOOLEAN; + case "int": + return INT; + case "long": + return LONG; + case "double": + return DOUBLE; + default: + // continue through the normal path + } } else { diff --git a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java index 3e3082dc197d..d6e244e8ed14 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java @@ -19,11 +19,11 @@ */ package org.neo4j.codegen.bytecode; +import java.lang.reflect.Modifier; + import org.objectweb.asm.Label; import org.objectweb.asm.MethodVisitor; -import java.lang.reflect.Modifier; - import org.neo4j.codegen.Expression; import org.neo4j.codegen.ExpressionVisitor; import org.neo4j.codegen.FieldReference; @@ -72,7 +72,9 @@ import static org.objectweb.asm.Opcodes.IFNE; import static org.objectweb.asm.Opcodes.IFNONNULL; import static org.objectweb.asm.Opcodes.IFNULL; +import static org.objectweb.asm.Opcodes.IF_ACMPEQ; import static org.objectweb.asm.Opcodes.IF_ACMPNE; +import static org.objectweb.asm.Opcodes.IF_ICMPEQ; import static org.objectweb.asm.Opcodes.IF_ICMPGE; import static org.objectweb.asm.Opcodes.IF_ICMPGT; import static org.objectweb.asm.Opcodes.IF_ICMPLE; @@ -258,37 +260,63 @@ public void newInstance( TypeReference type ) @Override public void not( Expression expression ) { - expression.accept( this ); - Label l0 = new Label(); - methodVisitor.visitJumpInsn( IFNE, l0 ); - methodVisitor.visitInsn( ICONST_1 ); - Label l1 = new Label(); - methodVisitor.visitJumpInsn( GOTO, l1 ); - methodVisitor.visitLabel( l0 ); - methodVisitor.visitInsn( ICONST_0 ); - methodVisitor.visitLabel( l1 ); + test( IFNE, expression, Expression.TRUE, Expression.FALSE ); } @Override - public void ternary( Expression test, Expression onTrue, Expression onFalse ) + public void isNull( Expression expression ) { - ternaryExpression( IFEQ, test, onTrue, onFalse ); + test( IFNONNULL, expression, Expression.TRUE, Expression.FALSE ); } @Override - public void ternaryOnNull( Expression test, Expression onTrue, Expression onFalse ) + public void notNull( Expression expression ) { - ternaryExpression( IFNONNULL, test, onTrue, onFalse ); + test( IFNULL, expression, Expression.TRUE, Expression.FALSE ); } @Override + public void ternary( Expression test, Expression onTrue, Expression onFalse ) + { + test( IFEQ, test, onTrue, onFalse ); + } + + public void ternaryOnNull( Expression test, Expression onTrue, Expression onFalse ) + { + test( IFNONNULL, test, onTrue, onFalse ); + } + public void ternaryOnNonNull( Expression test, Expression onTrue, Expression onFalse ) { - ternaryExpression( IFNULL, test, onTrue, onFalse ); + test( IFNULL, test, onTrue, onFalse ); + } + + private void test( int test, Expression predicate, Expression onTrue, Expression onFalse ) + { + predicate.accept( this ); + Label isFalse = new Label(); + methodVisitor.visitJumpInsn( test, isFalse ); + onTrue.accept( this ); + Label after = new Label(); + methodVisitor.visitJumpInsn( GOTO, after ); + methodVisitor.visitLabel( isFalse ); + onFalse.accept( this ); + methodVisitor.visitLabel( after ); } @Override public void equal( Expression lhs, Expression rhs ) + { + equal( lhs, rhs, true ); + } + + @Override + public void notEqual( Expression lhs, Expression rhs ) + { + equal( lhs, rhs, false ); + } + + private void equal( Expression lhs, Expression rhs, boolean equal ) { assertSameType( lhs, rhs, "compare" ); switch ( lhs.type().simpleName() ) @@ -298,25 +326,27 @@ public void equal( Expression lhs, Expression rhs ) case "short": case "char": case "boolean": - compareIntOrReferenceType( lhs, rhs, IF_ICMPNE ); + compareIntOrReferenceType( lhs, rhs, equal ? IF_ICMPNE : IF_ICMPEQ ); break; case "long": - compareLongOrFloatType( lhs, rhs, LCMP, IFNE ); + compareLongOrFloatType( lhs, rhs, LCMP, equal ? IFNE : IFEQ ); break; case "float": - compareLongOrFloatType( lhs, rhs, FCMPL, IFNE ); + compareLongOrFloatType( lhs, rhs, FCMPL, equal ? IFNE : IFEQ ); break; case "double": - compareLongOrFloatType( lhs, rhs, DCMPL, IFNE ); + compareLongOrFloatType( lhs, rhs, DCMPL, equal ? IFNE : IFEQ ); break; default: - compareIntOrReferenceType( lhs, rhs, IF_ACMPNE ); + compareIntOrReferenceType( lhs, rhs, equal ? IF_ACMPNE : IF_ACMPEQ ); } } @Override - public void or( Expression lhs, Expression rhs ) + public void or( Expression... expressions ) { + assert expressions.length == 2 : "only supports or(lhs, rhs)"; + Expression lhs = expressions[0], rhs = expressions[1]; /* * something like: * @@ -346,12 +376,13 @@ public void or( Expression lhs, Expression rhs ) methodVisitor.visitLabel( l1 ); methodVisitor.visitInsn( ICONST_0 ); methodVisitor.visitLabel( l2 ); - } @Override - public void and( Expression lhs, Expression rhs ) + public void and( Expression... expressions ) { + assert expressions.length == 2 : "only supports and(lhs, rhs)"; + Expression lhs = expressions[0], rhs = expressions[1]; /* * something like: * @@ -723,19 +754,6 @@ private void arrayStore( TypeReference reference ) } } - private void ternaryExpression( int op, Expression test, Expression onTrue, Expression onFalse ) - { - test.accept( this ); - Label l0 = new Label(); - methodVisitor.visitJumpInsn( op, l0 ); - onTrue.accept( this ); - Label l1 = new Label(); - methodVisitor.visitJumpInsn( GOTO, l1 ); - methodVisitor.visitLabel( l0 ); - onFalse.accept( this ); - methodVisitor.visitLabel( l1 ); - } - private void numberOperation( TypeReference type, Runnable onInt, Runnable onLong, Runnable onFloat, Runnable onDouble ) { diff --git a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/If.java b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/If.java index e4eab7b25a16..4885fdb4286a 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/If.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/If.java @@ -26,17 +26,17 @@ public class If implements Block { private final MethodVisitor methodVisitor; - private final Label l0; + private final Label after; - public If( MethodVisitor methodVisitor, Label l0) + public If( MethodVisitor methodVisitor, Label after ) { this.methodVisitor = methodVisitor; - this.l0 = l0; + this.after = after; } @Override public void endBlock() { - methodVisitor.visitLabel( l0 ); + methodVisitor.visitLabel( after ); } } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/JumpVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/JumpVisitor.java new file mode 100644 index 000000000000..613179eac24c --- /dev/null +++ b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/JumpVisitor.java @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.codegen.bytecode; + +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; + +import org.neo4j.codegen.Expression; +import org.neo4j.codegen.ExpressionVisitor; +import org.neo4j.codegen.FieldReference; +import org.neo4j.codegen.LocalVariable; +import org.neo4j.codegen.MethodReference; +import org.neo4j.codegen.TypeReference; + +import static org.objectweb.asm.Opcodes.IFEQ; +import static org.objectweb.asm.Opcodes.IFNE; +import static org.objectweb.asm.Opcodes.IFNONNULL; +import static org.objectweb.asm.Opcodes.IFNULL; + +class JumpVisitor implements ExpressionVisitor +{ + private final ExpressionVisitor eval; + private final MethodVisitor methodVisitor; + private final Label target; + + JumpVisitor( ExpressionVisitor eval, MethodVisitor methodVisitor, Label target ) + { + this.eval = eval; + this.methodVisitor = methodVisitor; + this.target = target; + } + + @Override + public void invoke( Expression target, MethodReference method, Expression[] arguments ) + { + eval.invoke( target, method, arguments ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void invoke( MethodReference method, Expression[] arguments ) + { + eval.invoke( method, arguments ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void load( LocalVariable variable ) + { + eval.load( variable ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void getField( Expression target, FieldReference field ) + { + eval.getField( target, field ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void constant( Object value ) + { + eval.constant( value ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void getStatic( FieldReference field ) + { + eval.getStatic( field ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void not( Expression expression ) + { + expression.accept( eval ); + methodVisitor.visitJumpInsn( IFNE, this.target ); + } + + @Override + public void ternary( Expression test, Expression onTrue, Expression onFalse ) + { + eval.ternary( test, onTrue, onFalse ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void equal( Expression lhs, Expression rhs ) + { + eval.equal( lhs, rhs ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void notEqual( Expression lhs, Expression rhs ) + { + eval.equal( lhs, rhs ); + methodVisitor.visitJumpInsn( IFNE, this.target ); + } + + @Override + public void isNull( Expression expression ) + { + expression.accept( eval ); + methodVisitor.visitJumpInsn( IFNONNULL, this.target ); + } + + @Override + public void notNull( Expression expression ) + { + expression.accept( eval ); + methodVisitor.visitJumpInsn( IFNULL, this.target ); + } + + @Override + public void or( Expression... expressions ) + { + eval.or( expressions ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void and( Expression... expressions ) + { + for ( Expression expression : expressions ) + { + expression.accept( this ); + } + } + + @Override + public void gt( Expression lhs, Expression rhs ) + { + eval.gt( lhs, rhs ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void gte( Expression lhs, Expression rhs ) + { + eval.gte( lhs, rhs ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void lt( Expression lhs, Expression rhs ) + { + eval.lt( lhs, rhs ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void lte( Expression lhs, Expression rhs ) + { + eval.lte( lhs, rhs ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void unbox( Expression expression ) + { + eval.unbox( expression ); + methodVisitor.visitJumpInsn( IFEQ, this.target ); + } + + @Override + public void loadThis( String sourceName ) + { + throw new IllegalArgumentException( "'" + sourceName + "' is not a boolean expression" ); + } + + @Override + public void newInstance( TypeReference type ) + { + throw new IllegalArgumentException( "'new' is not a boolean expression" ); + } + + @Override + public void add( Expression lhs, Expression rhs ) + { + throw new IllegalArgumentException( "'+' is not a boolean expression" ); + } + + @Override + public void subtract( Expression lhs, Expression rhs ) + { + throw new IllegalArgumentException( "'-' is not a boolean expression" ); + } + + @Override + public void multiply( Expression lhs, Expression rhs ) + { + throw new IllegalArgumentException( "'*' is not a boolean expression" ); + } + + @Override + public void cast( TypeReference type, Expression expression ) + { + throw new IllegalArgumentException( "cast is not a boolean expression" ); + } + + @Override + public void newArray( TypeReference type, Expression... constants ) + { + throw new IllegalArgumentException( "'new' (array) is not a boolean expression" ); + } + + @Override + public void longToDouble( Expression expression ) + { + throw new IllegalArgumentException( "cast is not a boolean expression" ); + } + + @Override + public void pop( Expression expression ) + { + throw new IllegalArgumentException( "pop is not a boolean expression" ); + } + + @Override + public void box( Expression expression ) + { + throw new IllegalArgumentException( "box is not a boolean expression" ); + } +} diff --git a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/MethodByteCodeEmitter.java b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/MethodByteCodeEmitter.java index 968defd2e9fa..1e73bfb959af 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/MethodByteCodeEmitter.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/MethodByteCodeEmitter.java @@ -19,15 +19,16 @@ */ package org.neo4j.codegen.bytecode; -import org.objectweb.asm.ClassWriter; -import org.objectweb.asm.Label; -import org.objectweb.asm.MethodVisitor; - import java.util.Deque; import java.util.LinkedList; import java.util.function.Consumer; +import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; + import org.neo4j.codegen.Expression; +import org.neo4j.codegen.ExpressionVisitor; import org.neo4j.codegen.FieldReference; import org.neo4j.codegen.LocalVariable; import org.neo4j.codegen.MethodDeclaration; @@ -51,9 +52,6 @@ import static org.objectweb.asm.Opcodes.FSTORE; import static org.objectweb.asm.Opcodes.GOTO; import static org.objectweb.asm.Opcodes.IFEQ; -import static org.objectweb.asm.Opcodes.IFNE; -import static org.objectweb.asm.Opcodes.IFNONNULL; -import static org.objectweb.asm.Opcodes.IFNULL; import static org.objectweb.asm.Opcodes.IRETURN; import static org.objectweb.asm.Opcodes.ISTORE; import static org.objectweb.asm.Opcodes.LRETURN; @@ -65,7 +63,7 @@ class MethodByteCodeEmitter implements MethodEmitter { private final MethodVisitor methodVisitor; private final MethodDeclaration declaration; - private final ByteCodeExpressionVisitor expressionVisitor; + private final ExpressionVisitor expressionVisitor; private final TypeReference base; private Deque stateStack = new LinkedList<>(); @@ -171,42 +169,21 @@ public void assign( LocalVariable variable, Expression value ) } @Override - public void beginWhile( Expression...tests ) + public void beginWhile( Expression test ) { - Label l0 = new Label(); - methodVisitor.visitLabel( l0 ); - Label l1 = new Label(); - for ( Expression test : tests ) - { - test.accept( expressionVisitor ); - methodVisitor.visitJumpInsn( IFEQ, l1 ); - } + Label repeat = new Label(), done = new Label(); + methodVisitor.visitLabel( repeat ); + test.accept( new JumpVisitor( expressionVisitor, methodVisitor, done ) ); - stateStack.push( new While( methodVisitor, l0, l1 ) ); - } - - @Override - public void beginIf( Expression... tests ) - { - beginConditional( IFEQ, tests ); + stateStack.push( new While( methodVisitor, repeat, done ) ); } @Override - public void beginIfNot( Expression... tests ) + public void beginIf( Expression test ) { - beginConditional( IFNE, tests ); - } - - @Override - public void beginIfNull( Expression...tests ) - { - beginConditional( IFNONNULL, tests ); - } - - @Override - public void beginIfNonNull( Expression...tests ) - { - beginConditional( IFNULL, tests ); + Label after = new Label(); + test.accept( new JumpVisitor( expressionVisitor, methodVisitor, after ) ); + stateStack.push( new If( methodVisitor, after ) ); } @Override @@ -228,22 +205,22 @@ public void endBlock() @Override public void tryCatchBlock( Consumer body, Consumer handler, LocalVariable exception, T block ) { - Label l0 = new Label(); - Label l1 = new Label(); - Label l2 = new Label(); - methodVisitor.visitTryCatchBlock( l0, l1, l2, + Label start = new Label(); + Label end = new Label(); + Label handle = new Label(); + Label after = new Label(); + methodVisitor.visitTryCatchBlock( start, end, handle, byteCodeName( exception.type() ) ); - methodVisitor.visitLabel( l0 ); + methodVisitor.visitLabel( start ); body.accept( block ); - methodVisitor.visitLabel( l1 ); - Label l3 = new Label(); - methodVisitor.visitJumpInsn( GOTO, l3 ); + methodVisitor.visitLabel( end ); + methodVisitor.visitJumpInsn( GOTO, after ); //handle catch - methodVisitor.visitLabel( l2 ); + methodVisitor.visitLabel( handle ); methodVisitor.visitVarInsn( ASTORE, exception.index() ); handler.accept( block ); - methodVisitor.visitLabel( l3 ); + methodVisitor.visitLabel( after ); } @Override @@ -265,15 +242,4 @@ public void assignVariableInScope( LocalVariable local, Expression value ) //these are equivalent when it comes to bytecode assign( local, value ); } - - private void beginConditional(int op, Expression[] tests) - { - Label l0 = new Label(); - for ( Expression test : tests ) - { - test.accept( expressionVisitor ); - methodVisitor.visitJumpInsn( op, l0 ); - } - stateStack.push(new If(methodVisitor, l0)); - } } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/While.java b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/While.java index b1d123611a31..cebd2d6c7229 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/While.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/While.java @@ -28,20 +28,20 @@ public class While implements Block { private final MethodVisitor methodVisitor; - private final Label l0; - private final Label l1; + private final Label repeat; + private final Label done; - public While( MethodVisitor methodVisitor, Label l0, Label l1 ) + public While( MethodVisitor methodVisitor, Label repeat, Label done ) { this.methodVisitor = methodVisitor; - this.l0 = l0; - this.l1 = l1; + this.repeat = repeat; + this.done = done; } @Override public void endBlock() { - methodVisitor.visitJumpInsn( GOTO, l0 ); - methodVisitor.visitLabel( l1 ); + methodVisitor.visitJumpInsn( GOTO, repeat ); + methodVisitor.visitLabel( done ); } } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java b/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java index ab229c0afa4c..7c6c3ddbc326 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java @@ -133,70 +133,25 @@ public void assign( LocalVariable variable, Expression value ) } @Override - public void beginWhile( Expression...tests ) + public void beginWhile( Expression test ) { indent().append( "while( " ); - String sep = ""; - for (Expression test: tests) - { - append( sep ); - test.accept( this ); - sep = " && "; - } + test.accept( this ); append( " )\n" ); indent().append( "{\n" ); level.push( LEVEL ); } @Override - public void beginIf( Expression...tests ) + public void beginIf( Expression test ) { indent().append( "if ( " ); - String sep = ""; - for (Expression test: tests) - { - append( sep ); - test.accept( this ); - sep = " && "; - } + test.accept( this ); append( " )\n" ); indent().append( "{\n" ); level.push( LEVEL ); } - @Override - public void beginIfNot( Expression...tests ) - { - Expression[] nots = new Expression[tests.length]; - for ( int i = 0; i < tests.length; i++ ) - { - nots[i] = Expression.not( tests[i] ); - } - beginIf( nots ); - } - - @Override - public void beginIfNull( Expression...tests ) - { - Expression[] nulls = new Expression[tests.length]; - for ( int i = 0; i < tests.length; i++ ) - { - nulls[i] = Expression.equal(tests[i], Expression.constant( null ) ); - } - beginIf(nulls); - } - - @Override - public void beginIfNonNull( Expression...tests ) - { - Expression[] notNulls = new Expression[tests.length]; - for ( int i = 0; i < tests.length; i++ ) - { - notNulls[i] = Expression.not(Expression.equal(tests[i], Expression.constant( null ) )); - } - beginIf(notNulls); - } - @Override public void beginBlock() { @@ -361,37 +316,52 @@ public void ternary( Expression test, Expression onTrue, Expression onFalse ) } @Override - public void ternaryOnNull( Expression test, Expression onTrue, Expression onFalse ) + public void equal( Expression lhs, Expression rhs ) { - ternary( Expression.equal( test, Expression.constant( null ) ), - onTrue, onFalse ); + binaryOperation( lhs, rhs, " == " ); } @Override - public void ternaryOnNonNull( Expression test, Expression onTrue, Expression onFalse ) + public void notEqual( Expression lhs, Expression rhs ) { - ternary( Expression.not( - Expression.equal( test, Expression.constant( null ) )), - onTrue, onFalse ); + binaryOperation( lhs, rhs, " != " ); } @Override - public void equal( Expression lhs, Expression rhs ) + public void isNull( Expression expression ) { - binaryOperation( lhs, rhs, " == " ); + expression.accept( this ); + append( " == null" ); } @Override - public void or( Expression lhs, Expression rhs ) + public void notNull( Expression expression ) { - binaryOperation( lhs, rhs, " || " ); + expression.accept( this ); + append( " != null" ); + } + + @Override + public void or( Expression... expressions ) + { + boolOp( expressions, " || "); } @Override - public void and( Expression lhs, Expression rhs ) + public void and( Expression... expressions ) { - binaryOperation( lhs, rhs, " && " ); + boolOp( expressions, " && "); + } + private void boolOp( Expression[] expressions, String op ) + { + String sep = ""; + for ( Expression expression : expressions ) + { + append( sep ); + expression.accept( this ); + sep = op; + } } @Override diff --git a/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java b/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java index 49fa39c8fc41..9d0ff012b352 100644 --- a/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java +++ b/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java @@ -19,12 +19,6 @@ */ package org.neo4j.codegen; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.mockito.InOrder; - import java.io.IOException; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; @@ -34,6 +28,13 @@ import java.util.Iterator; import java.util.List; import java.util.function.BiFunction; +import java.util.function.Function; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.mockito.InOrder; import org.neo4j.codegen.source.Configuration; import org.neo4j.codegen.source.SourceCode; @@ -59,11 +60,14 @@ import static org.neo4j.codegen.Expression.add; import static org.neo4j.codegen.Expression.and; import static org.neo4j.codegen.Expression.constant; +import static org.neo4j.codegen.Expression.equal; import static org.neo4j.codegen.Expression.invoke; +import static org.neo4j.codegen.Expression.isNull; import static org.neo4j.codegen.Expression.multiply; import static org.neo4j.codegen.Expression.newArray; import static org.neo4j.codegen.Expression.newInstance; import static org.neo4j.codegen.Expression.not; +import static org.neo4j.codegen.Expression.notNull; import static org.neo4j.codegen.Expression.or; import static org.neo4j.codegen.Expression.subtract; import static org.neo4j.codegen.Expression.ternary; @@ -491,7 +495,7 @@ public void shouldGenerateWhileLoopWithMultipleTestExpressions() throws Throwabl try ( CodeBlock callEach = simple.generateMethod( void.class, "check", param( boolean.class, "a" ), param( boolean.class, "b" ), param( Runnable.class, "runner" ) ) ) { - try ( CodeBlock loop = callEach.whileLoop( callEach.load( "a" ), callEach.load( "b" ) ) ) + try ( CodeBlock loop = callEach.whileLoop( and( callEach.load( "a" ), callEach.load( "b" ) ) ) ) { loop.expression( invoke( loop.load( "runner" ), @@ -639,18 +643,16 @@ public void shouldGenerateIfStatement() throws Throwable } @Test - public void shouldGenerateIfWithMultipleTestsStatement() throws Throwable + public void shouldGenerateIfEqualsStatement() throws Throwable { // given ClassHandle handle; try ( ClassGenerator simple = generateClass( "SimpleClass" ) ) { try ( CodeBlock conditional = simple.generateMethod( void.class, "conditional", - param( boolean.class, "test1" ), param( boolean.class, "test2" ), - param( Runnable.class, "runner" ) ) ) + param( Object.class, "lhs" ), param( Object.class, "rhs" ), param( Runnable.class, "runner" ) ) ) { - try ( CodeBlock doStuff = conditional - .ifStatement( conditional.load( "test1" ), conditional.load( "test2" ) ) ) + try ( CodeBlock doStuff = conditional.ifStatement( equal( conditional.load( "lhs" ), conditional.load( "rhs" ) ) ) ) { doStuff.expression( invoke( doStuff.load( "runner" ), RUN ) ); @@ -662,35 +664,29 @@ public void shouldGenerateIfWithMultipleTestsStatement() throws Throwable Runnable runner1 = mock( Runnable.class ); Runnable runner2 = mock( Runnable.class ); - Runnable runner3 = mock( Runnable.class ); - Runnable runner4 = mock( Runnable.class ); + Object a = "a", b = "b"; // when - MethodHandle conditional = - instanceMethod( handle.newInstance(), "conditional", boolean.class, boolean.class, Runnable.class ); - conditional.invoke( true, true, runner1 ); - conditional.invoke( false, true, runner2 ); - conditional.invoke( true, false, runner3 ); - conditional.invoke( false, false, runner4 ); + MethodHandle conditional = instanceMethod( handle.newInstance(), "conditional", Object.class, Object.class, Runnable.class ); + conditional.invoke( a, b, runner1 ); + conditional.invoke( a, a, runner2 ); // then - verify( runner1 ).run(); - verifyZeroInteractions( runner2 ); - verifyZeroInteractions( runner3 ); - verifyZeroInteractions( runner4 ); + verify( runner2 ).run(); + verifyZeroInteractions( runner1 ); } @Test - public void shouldGenerateIfNotExpressionStatement() throws Throwable + public void shouldGenerateIfNotEqualsStatement() throws Throwable { // given ClassHandle handle; try ( ClassGenerator simple = generateClass( "SimpleClass" ) ) { try ( CodeBlock conditional = simple.generateMethod( void.class, "conditional", - param( boolean.class, "test" ), param( Runnable.class, "runner" ) ) ) + param( Object.class, "lhs" ), param( Object.class, "rhs" ), param( Runnable.class, "runner" ) ) ) { - try ( CodeBlock doStuff = conditional.ifStatement( not( conditional.load( "test" ) ) ) ) + try ( CodeBlock doStuff = conditional.ifStatement( not( equal( conditional.load( "lhs" ), conditional.load( "rhs" ) ) ) ) ) { doStuff.expression( invoke( doStuff.load( "runner" ), RUN ) ); @@ -702,11 +698,12 @@ public void shouldGenerateIfNotExpressionStatement() throws Throwable Runnable runner1 = mock( Runnable.class ); Runnable runner2 = mock( Runnable.class ); + Object a = "a", b = "b"; // when - MethodHandle conditional = instanceMethod( handle.newInstance(), "conditional", boolean.class, Runnable.class ); - conditional.invoke( true, runner1 ); - conditional.invoke( false, runner2 ); + MethodHandle conditional = instanceMethod( handle.newInstance(), "conditional", Object.class, Object.class, Runnable.class ); + conditional.invoke( a, a, runner1 ); + conditional.invoke( a, b, runner2 ); // then verify( runner2 ).run(); @@ -714,7 +711,7 @@ public void shouldGenerateIfNotExpressionStatement() throws Throwable } @Test - public void shouldGenerateIfNotStatement() throws Throwable + public void shouldGenerateIfNotExpressionStatement() throws Throwable { // given ClassHandle handle; @@ -723,7 +720,7 @@ public void shouldGenerateIfNotStatement() throws Throwable try ( CodeBlock conditional = simple.generateMethod( void.class, "conditional", param( boolean.class, "test" ), param( Runnable.class, "runner" ) ) ) { - try ( CodeBlock doStuff = conditional.ifNotStatement( conditional.load( "test" ) ) ) + try ( CodeBlock doStuff = conditional.ifStatement( not( conditional.load( "test" ) ) ) ) { doStuff.expression( invoke( doStuff.load( "runner" ), RUN ) ); @@ -756,7 +753,7 @@ public void shouldGenerateIfNullStatement() throws Throwable try ( CodeBlock conditional = simple.generateMethod( void.class, "conditional", param( Object.class, "test" ), param( Runnable.class, "runner" ) ) ) { - try ( CodeBlock doStuff = conditional.ifNullStatement( conditional.load( "test" ) ) ) + try ( CodeBlock doStuff = conditional.ifStatement( isNull(conditional.load( "test" ) ) ) ) { doStuff.expression( invoke( doStuff.load( "runner" ), RUN ) ); @@ -789,7 +786,7 @@ public void shouldGenerateIfNonNullStatement() throws Throwable try ( CodeBlock conditional = simple.generateMethod( void.class, "conditional", param( Object.class, "test" ), param( Runnable.class, "runner" ) ) ) { - try ( CodeBlock doStuff = conditional.ifNonNullStatement( conditional.load( "test" ) ) ) + try ( CodeBlock doStuff = conditional.ifStatement( notNull( conditional.load( "test" ) ) ) ) { doStuff.expression( invoke( doStuff.load( "runner" ), RUN ) ); @@ -1128,7 +1125,7 @@ public void shouldHandleTernaryOnNullOperator() throws Throwable param( Object.class, "test" ), param( TernaryChecker.class, "check" ) ) ) { ternaryBlock.returns( - Expression.ternaryOnNull( ternaryBlock.load( "test" ), + ternary( isNull( ternaryBlock.load( "test" ) ), invoke( ternaryBlock.load( "check" ), methodReference( TernaryChecker.class, String.class, "onTrue" ) ), invoke( ternaryBlock.load( "check" ), @@ -1165,7 +1162,7 @@ public void shouldHandleTernaryOnNonNullOperator() throws Throwable param( Object.class, "test" ), param( TernaryChecker.class, "check" ) ) ) { ternaryBlock.returns( - Expression.ternaryOnNonNull( ternaryBlock.load( "test" ), + ternary( notNull( ternaryBlock.load( "test" ) ), invoke( ternaryBlock.load( "check" ), methodReference( TernaryChecker.class, String.class, "onTrue" ) ), invoke( ternaryBlock.load( "check" ), @@ -1743,6 +1740,11 @@ private Object boxTest(Class unboxedType, T value) return instanceMethod( handle.newInstance(), "box", unboxedType ).invoke( value ); } + private MethodHandle conditional(Function test, Parameter... params) + { + throw new UnsupportedOperationException( "not implemented" ); + } + static MethodHandle method( Class target, String name, Class... parameters ) throws Exception { return MethodHandles.lookup().unreflect( target.getMethod( name, parameters ) ); diff --git a/community/codegen/src/test/java/org/neo4j/codegen/ExpressionTest.java b/community/codegen/src/test/java/org/neo4j/codegen/ExpressionTest.java new file mode 100644 index 000000000000..3ce8f37e267a --- /dev/null +++ b/community/codegen/src/test/java/org/neo4j/codegen/ExpressionTest.java @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.codegen; + +import org.junit.Test; + +import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.neo4j.codegen.Expression.FALSE; +import static org.neo4j.codegen.Expression.NULL; +import static org.neo4j.codegen.Expression.TRUE; +import static org.neo4j.codegen.Expression.and; +import static org.neo4j.codegen.Expression.equal; +import static org.neo4j.codegen.Expression.gt; +import static org.neo4j.codegen.Expression.gte; +import static org.neo4j.codegen.Expression.invoke; +import static org.neo4j.codegen.Expression.lt; +import static org.neo4j.codegen.Expression.lte; +import static org.neo4j.codegen.Expression.not; +import static org.neo4j.codegen.Expression.notEqual; +import static org.neo4j.codegen.Expression.or; +import static org.neo4j.codegen.MethodReference.methodReference; + +public class ExpressionTest +{ + @Test + public void shouldNegateTrueToFalse() throws Exception + { + assertSame( FALSE, not( TRUE ) ); + assertSame( TRUE, not( FALSE ) ); + } + + @Test + public void shouldRemoveDoubleNegation() throws Exception + { + Expression expression = invoke( methodReference( getClass(), boolean.class, "TRUE" ) ); + assertSame( expression, not( not( expression ) ) ); + } + + @Test + public void shouldOptimizeNullChecks() throws Exception + { + // given + ExpressionVisitor visitor = mock( ExpressionVisitor.class ); + Expression expression = invoke( methodReference( getClass(), Object.class, "value" ) ); + + // when + equal( expression, NULL ).accept( visitor ); + + // then + verify( visitor ).isNull( expression ); + + reset( visitor ); // next + + // when + equal( NULL, expression ).accept( visitor ); + + // then + verify( visitor ).isNull( expression ); + + reset( visitor ); // next + + // when + not( equal( expression, NULL ) ).accept( visitor ); + + // then + verify( visitor ).notNull( expression ); + + reset( visitor ); // next + + // when + not( equal( NULL, expression ) ).accept( visitor ); + + // then + verify( visitor ).notNull( expression ); + } + + @Test + public void shouldOptimizeNegatedInequalities() throws Exception + { + // given + ExpressionVisitor visitor = mock( ExpressionVisitor.class ); + Expression expression = invoke( methodReference( getClass(), Object.class, "value" ) ); + + // when + not( gt( expression, expression ) ).accept( visitor ); + + // then + verify( visitor ).lte( expression, expression ); + + reset( visitor ); // next + + // when + not( gte( expression, expression ) ).accept( visitor ); + + // then + verify( visitor ).lt( expression, expression ); + + reset( visitor ); // next + + // when + not( lt( expression, expression ) ).accept( visitor ); + + // then + verify( visitor ).gte( expression, expression ); + + reset( visitor ); // next + + // when + not( lte( expression, expression ) ).accept( visitor ); + + // then + verify( visitor ).gt( expression, expression ); + + reset( visitor ); // next + + // when + not( equal( expression, expression ) ).accept( visitor ); + + // then + verify( visitor ).notEqual( expression, expression ); + + reset( visitor ); // next + + // when + not( notEqual( expression, expression ) ).accept( visitor ); + + // then + verify( visitor ).equal( expression, expression ); + } + + @Test + public void shouldOptimizeBooleanCombinationsWithConstants() throws Exception + { + // given + Expression expression = invoke( methodReference( getClass(), boolean.class, "TRUE" ) ); + + // then + assertSame( expression, and( expression, TRUE ) ); + assertSame( expression, and( TRUE, expression ) ); + assertSame( FALSE, and( expression, FALSE ) ); + assertSame( FALSE, and( FALSE, expression ) ); + + assertSame( expression, or( expression, FALSE ) ); + assertSame( expression, or( FALSE, expression ) ); + assertSame( TRUE, or( expression, TRUE ) ); + assertSame( TRUE, or( TRUE, expression ) ); + } + + public static boolean TRUE() + { + return true; + } + + public static Object value() + { + return null; + } +} diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala index 82698d2b474b..51a512a42ddc 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala @@ -161,7 +161,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A } override def ifNotStatement(test: Expression)(block: (MethodStructure[Expression]) => Unit) = { - using(generator.ifNotStatement(test)) { body => + using(generator.ifStatement(not(test))) { body => block(copy(generator = body)) } } @@ -208,9 +208,9 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def visitorAccept() = tryCatch(generator) { onSuccess => using( - onSuccess.ifNotStatement( + onSuccess.ifStatement(not( invoke(onSuccess.load("visitor"), - visit, onSuccess.load("row")))) { body => + visit, onSuccess.load("row"))))) { body => // NOTE: we are in this if-block if the visitor decided to terminate early (by returning false) //close all outstanding events for (event <- events) { @@ -238,7 +238,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A equal(nullValue(codeGenType), generator.load(varName)), nullValue(codeGenType), onSuccess) - case _ => ternaryOnNull(generator.load(varName), constant(null), onSuccess) + case _ => ternary(Expression.isNull(generator.load(varName)), constant(null), onSuccess) } override def nullableReference(varName: String, codeGenType: CodeGenType, onSuccess: Expression) = codeGenType match { @@ -247,7 +247,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A equal(nullValue(codeGenType), generator.load(varName)), constant(null), onSuccess) - case _ => ternaryOnNull(generator.load(varName), constant(null), onSuccess) + case _ => ternary(Expression.isNull(generator.load(varName)), constant(null), onSuccess) } override def materializeRelationship(relIdVar: String) = @@ -277,7 +277,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def expectParameter(key: String, variableName: String) = { using( - generator.ifNotStatement(invoke(params, mapContains, constant(key)))) { block => + generator.ifStatement(not(invoke(params, mapContains, constant(key))))) { block => block.throwException(parameterNotFoundException(key)) } generator.assign(typeRef[Object], variableName, invoke(loadParameter, @@ -489,7 +489,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A pop( invoke(generator.load(tableVar), countingTableCompositeKeyPut, generator.load(keyName), - ternaryOnNull(generator.load(countName), + ternary(Expression.isNull(generator.load(countName)), invoke(boxInteger, constant(1)), invoke(boxInteger, add( @@ -520,8 +520,8 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A newArray(typeRef[Long], keyVars.map(generator.load): _*))))) generator.assign(times, - ternaryOnNull( - intermediate, + ternary( + Expression.isNull(intermediate), constant(-1), invoke(intermediate, unboxInteger))) @@ -539,7 +539,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A val list = generator.declare(hashTable.listType, context.namer.newVarName()) val elementName = context.namer.newVarName() generator.assign(list, invoke(generator.load(tableVar), hashTable.get, generator.load(keyVar))) - using(generator.ifNonNullStatement(list)) { onTrue => + using(generator.ifStatement(Expression.notNull(list))) { onTrue => using(onTrue.forEach(Parameter.param(hashTable.valueType, elementName), list)) { forEach => localVars.foreach { case (l, f) => @@ -563,7 +563,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A newArray(typeRef[Long], keyVars.map(generator.load): _*)) ))) - using(generator.ifNonNullStatement(list)) { onTrue => + using(generator.ifStatement(Expression.notNull(list))) { onTrue => using(onTrue.forEach(Parameter.param(hashTable.valueType, elementName), list)) { forEach => localVars.foreach { case (l, f) => @@ -598,7 +598,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A invoke( generator.load(tableVar), hashTable.get, generator.load(keyVar)))) - using(generator.ifNullStatement(list)) { onTrue => // if (null == list) + using(generator.ifStatement(Expression.isNull(list))) { onTrue => // if (null == list) // list = new ListType(); onTrue.assign(list, createNewInstance(hashTable.listType)) onTrue.expression( @@ -626,7 +626,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A generator.assign(list, cast(hashTable.listType, invoke(generator.load(tableVar), hashTable.get, generator.load(keyName)))) - using(generator.ifNullStatement(generator.load(listName))) { onTrue => // if (null == list) + using(generator.ifStatement(Expression.isNull(generator.load(listName)))) { onTrue => // if (null == list) // list = new ListType(); onTrue.assign(list, createNewInstance(hashTable.listType)) // tableVar.put(keyVar, list); @@ -710,10 +710,10 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def nodeIdSeek(nodeIdVar: String, expression: Expression)(block: MethodStructure[Expression] => Unit) = { generator.assign(typeRef[Long], nodeIdVar, invoke(Methods.mathCastToLong, expression)) - using(generator.ifStatement( + using(generator.ifStatement(and( gt(generator.load(nodeIdVar), constant(-1L)), invoke(readOperations, nodeExists, generator.load(nodeIdVar)) - )) { ifBody => + ))) { ifBody => block(copy(generator = ifBody)) } } diff --git a/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/AuxGenerator.scala b/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/AuxGenerator.scala index accd745d36ce..747746e6185b 100644 --- a/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/AuxGenerator.scala +++ b/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/AuxGenerator.scala @@ -189,7 +189,7 @@ class AuxGenerator(val packageName: String, val generator: CodeGenerator) { l2.assign(compareResult, Expression.invoke(method[java.lang.Double, Int]("compare", typeRef[Double], typeRef[Double]), lhs, rhs)) - using(l2.ifNotStatement(Expression.equal(compareResult, Expression.constant(0)))) { l3 => + using(l2.ifStatement(Expression.notEqual(compareResult, Expression.constant(0)))) { l3 => l3.returns(compareResult) } } @@ -206,7 +206,7 @@ class AuxGenerator(val packageName: String, val generator: CodeGenerator) { */ val (thisValueVariable, otherValueVariable) = assignComparatorVariablesFor(l2, reprType) - using(l2.ifNotStatement(Expression.equal(thisValueVariable, otherValueVariable))) { l3 => + using(l2.ifStatement(Expression.notEqual(thisValueVariable, otherValueVariable))) { l3 => l3.returns(Expression.ternary(thisValueVariable, Expression.constant(greaterThanSortResult(sortOrder)), Expression.constant(lessThanSortResult(sortOrder)))) @@ -225,7 +225,7 @@ class AuxGenerator(val packageName: String, val generator: CodeGenerator) { l2.assign(compareResult, Expression.invoke(method[CompiledOrderabilityUtils, Int]("compare", typeRef[Object], typeRef[Object]), lhs, rhs)) - using(l2.ifNotStatement(Expression.equal(compareResult, Expression.constant(0)))) { l3 => + using(l2.ifStatement(Expression.notEqual(compareResult, Expression.constant(0)))) { l3 => l3.returns(compareResult) } } diff --git a/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala b/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala index b502321befca..ee4d7aa4b07c 100644 --- a/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala +++ b/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala @@ -182,13 +182,13 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux } override def ifNotStatement(test: Expression)(block: (MethodStructure[Expression]) => Unit) = { - using(generator.ifNotStatement(test)) { body => + using(generator.ifStatement(not(test))) { body => block(copy(generator = body)) } } override def ifNonNullStatement(test: Expression)(block: (MethodStructure[Expression]) => Unit) = { - using(generator.ifNonNullStatement(test)) { body => + using(generator.ifStatement(Expression.notNull(test))) { body => block(copy(generator = body)) } } @@ -239,9 +239,9 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux override def visitorAccept() = tryCatch(generator) { onSuccess => using( - onSuccess.ifNotStatement( - invoke(onSuccess.load("visitor"), - visit, onSuccess.load("row")))) { body => + onSuccess.ifStatement( + not(invoke(onSuccess.load("visitor"), + visit, onSuccess.load("row"))))) { body => // NOTE: we are in this if-block if the visitor decided to terminate early (by returning false) //close all outstanding events for (event <- events) { @@ -278,7 +278,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux equal(nullValue(codeGenType), generator.load(varName)), nullValue(codeGenType), onSuccess) - case _ => ternaryOnNull(generator.load(varName), constant(null), onSuccess) + case _ => ternary(Expression.isNull(generator.load(varName)), constant(null), onSuccess) } override def nullableReference(varName: String, codeGenType: CodeGenType, onSuccess: Expression) = codeGenType match { @@ -287,7 +287,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux equal(nullValue(codeGenType), generator.load(varName)), constant(null), onSuccess) - case _ => ternaryOnNull(generator.load(varName), constant(null), onSuccess) + case _ => ternary(Expression.isNull(generator.load(varName)), constant(null), onSuccess) } override def materializeRelationship(relIdVar: String, codeGenType: CodeGenType) = @@ -326,7 +326,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux override def expectParameter(key: String, variableName: String, codeGenType: CodeGenType) = { using( - generator.ifNotStatement(invoke(params, mapContains, constant(key)))) { block => + generator.ifStatement(not(invoke(params, mapContains, constant(key))))) { block => block.throwException(parameterNotFoundException(key)) } val invokeLoadParameter = invoke(loadParameter, invoke(params, mapGet, constantExpression(key))) @@ -622,15 +622,15 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux (block: MethodStructure[Expression] => Unit) = { if (structure.size == 1 && structure.head._2._1.repr == IntType) { val (_, (_, value)) = structure.head - using(generator.ifNotStatement(invoke(generator.load(name), - method[PrimitiveLongSet, Boolean]("contains", typeRef[Long]), value))) { body => + using(generator.ifStatement(not(invoke(generator.load(name), + method[PrimitiveLongSet, Boolean]("contains", typeRef[Long]), value)))) { body => body.expression(pop(invoke(generator.load(name), method[PrimitiveLongSet, Boolean]("add", typeRef[Long]), value))) block(copy(generator = body)) } } else { val tmpName = context.namer.newVarName() newUniqueAggregationKey(tmpName, structure) - using(generator.ifNotStatement(invoke(generator.load(name), Methods.setContains, generator.load(tmpName)))) { body => + using(generator.ifStatement(not(invoke(generator.load(name), Methods.setContains, generator.load(tmpName))))) { body => body.expression(pop(invoke(loadVariable(name), Methods.setAdd, generator.load(tmpName)))) block(copy(generator = body)) } @@ -812,16 +812,16 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux keyExpression))) - using(generator.ifNullStatement(generator.load(tmp))) { inner => + using(generator.ifStatement(Expression.isNull(generator.load(tmp)))) { inner => inner.assign(localVariable, invoke(method[Primitive, PrimitiveLongSet]("longSet"))) inner.expression(pop(invoke(generator.load(name), method[PrimitiveLongObjectMap[Object], Object]("put", typeRef[Long], typeRef[Object]), keyExpression, inner.load(tmp)))) } - using(generator.ifNotStatement(invoke(generator.load(tmp), - method[PrimitiveLongSet, Boolean]("contains", typeRef[Long]), - value))) { inner => + using(generator.ifStatement(not(invoke(generator.load(tmp), + method[PrimitiveLongSet, Boolean]("contains", typeRef[Long]), + value)))) { inner => block(copy(generator = inner)) } generator.expression(pop(invoke(generator.load(tmp), @@ -833,16 +833,16 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux invoke(generator.load(name), method[PrimitiveLongObjectMap[Object], Object]("get", typeRef[Long]), keyExpression))) - using(generator.ifNullStatement(generator.load(tmp))) { inner => + using(generator.ifStatement(Expression.isNull(generator.load(tmp)))) { inner => inner.assign(localVariable, createNewInstance(typeRef[util.HashSet[Object]])) inner.expression(pop(invoke(generator.load(name), method[PrimitiveLongObjectMap[Object], Object]("put", typeRef[Long], typeRef[Object]), keyExpression, inner.load(tmp)))) } - using(generator.ifNotStatement(invoke(generator.load(tmp), - method[util.HashSet[Object], Boolean]("contains", typeRef[Object]), - value))) { inner => + using(generator.ifStatement(not(invoke(generator.load(tmp), + method[util.HashSet[Object], Boolean]("contains", typeRef[Object]), + value)))) { inner => block(copy(generator = inner)) } generator.expression(pop(invoke(generator.load(tmp), @@ -859,7 +859,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux invoke(generator.load(name), method[util.HashMap[Object, PrimitiveLongSet], Object]("get", typeRef[Object]), generator.load(keyVar)))) - using(generator.ifNullStatement(generator.load(setVar))) { inner => + using(generator.ifStatement(Expression.isNull(generator.load(setVar)))) { inner => inner.assign(localVariable, invoke(method[Primitive, PrimitiveLongSet]("longSet"))) inner.expression(pop(invoke(generator.load(name), @@ -868,9 +868,9 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux generator.load(keyVar), inner.load(setVar)))) } - using(generator.ifNotStatement(invoke(generator.load(setVar), - method[PrimitiveLongSet, Boolean]("contains", typeRef[Long]), - value))) { inner => + using(generator.ifStatement(not(invoke(generator.load(setVar), + method[PrimitiveLongSet, Boolean]("contains", typeRef[Long]), + value)))) { inner => block(copy(generator = inner)) inner.expression(pop(invoke(generator.load(setVar), method[PrimitiveLongSet, Boolean]("add", typeRef[Long]), @@ -885,7 +885,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux invoke(generator.load(name), method[util.HashMap[Object, util.HashSet[Object]], Object]("get", typeRef[Object]), generator.load(keyVar)))) - using(generator.ifNullStatement(generator.load(setVar))) { inner => + using(generator.ifStatement(Expression.isNull(generator.load(setVar)))) { inner => inner.assign(localVariable, createNewInstance(typeRef[util.HashSet[Object]])) inner.expression(pop(invoke(generator.load(name), @@ -896,9 +896,9 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux val valueVar = context.namer.newVarName() newUniqueAggregationKey(valueVar, Map(context.namer.newVarName() -> (valueType -> value))) - using(generator.ifNotStatement(invoke(generator.load(setVar), - method[util.HashSet[Object], Boolean]("contains", typeRef[Object]), - generator.load(valueVar)))) { inner => + using(generator.ifStatement(not(invoke(generator.load(setVar), + method[util.HashSet[Object], Boolean]("contains", typeRef[Object]), + generator.load(valueVar))))) { inner => block(copy(generator = inner)) inner.expression(pop(invoke(generator.load(setVar), method[util.HashSet[Object], Boolean]("add", typeRef[Object]), @@ -1064,7 +1064,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux pop( invoke(generator.load(tableVar), countingTableCompositeKeyPut, generator.load(keyName), - ternaryOnNull(generator.load(countName), + ternary(Expression.isNull(generator.load(countName)), box(constant(1)), box(add(invoke(generator.load(countName), unboxInteger), constant(1))))))) } @@ -1090,8 +1090,8 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux newArray(typeRef[Long], keyVars.map(generator.load): _*))))) generator.assign(times, - ternaryOnNull( - intermediate, + ternary( + Expression.isNull(intermediate), constant(-1), invoke(intermediate, unboxInteger))) @@ -1109,7 +1109,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux val list = generator.declare(hashTable.listType, context.namer.newVarName()) val elementName = context.namer.newVarName() generator.assign(list, invoke(generator.load(tableVar), hashTable.get, generator.load(keyVar))) - using(generator.ifNonNullStatement(list)) { onTrue => + using(generator.ifStatement(Expression.notNull(list))) { onTrue => using(onTrue.forEach(Parameter.param(hashTable.valueType, elementName), list)) { forEach => localVars.foreach { case (l, f) => @@ -1133,7 +1133,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux newArray(typeRef[Long], keyVars.map(generator.load): _*)) ))) - using(generator.ifNonNullStatement(list)) { onTrue => + using(generator.ifStatement(Expression.notNull(list))) { onTrue => using(onTrue.forEach(Parameter.param(hashTable.valueType, elementName), list)) { forEach => localVars.foreach { case (l, f) => @@ -1170,7 +1170,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux invoke( generator.load(tableVar), hashTable.get, generator.load(keyVar)))) - using(generator.ifNullStatement(list)) { onTrue => // if (null == list) + using(generator.ifStatement(Expression.isNull(list))) { onTrue => // if (null == list) // list = new ListType(); onTrue.assign(list, createNewInstance(hashTable.listType)) onTrue.expression( @@ -1198,7 +1198,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux generator.assign(list, cast(hashTable.listType, invoke(generator.load(tableVar), hashTable.get, generator.load(keyName)))) - using(generator.ifNullStatement(generator.load(listName))) { onTrue => // if (null == list) + using(generator.ifStatement(Expression.isNull(generator.load(listName)))) { onTrue => // if (null == list) // list = new ListType(); onTrue.assign(list, createNewInstance(hashTable.listType)) // tableVar.put(keyVar, list); @@ -1302,10 +1302,10 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux case _ => throw new IllegalArgumentException(s"CodeGenType $codeGenType can not be converted to long") } - using(generator.ifStatement( + using(generator.ifStatement(and( gt(generator.load(nodeIdVar), constant(-1L)), invoke(readOperations, nodeExists, generator.load(nodeIdVar)) - )) { ifBody => + ))) { ifBody => block(copy(generator = ifBody)) } } diff --git a/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Templates.scala b/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Templates.scala index 0b7edff8fe10..da6c09ed7de0 100644 --- a/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Templates.scala +++ b/enterprise/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Templates.scala @@ -152,7 +152,7 @@ object Templates { val methodBuilder: Builder = MethodDeclaration.method(typeRef[ReadOperations], "getOrLoadReadOperations") using(clazz.generate(methodBuilder)) { generate => val ro = Expression.get(generate.self(), fields.ro) - using(generate.ifNullStatement(ro)) { block => + using(generate.ifStatement(Expression.isNull(ro))) { block => val transactionalContext: MethodReference = method[QueryContext, QueryTransactionalContext]("transactionalContext") val readOperations: MethodReference = method[QueryTransactionalContext, Object]("readOperations") val queryContext = Expression.get(block.self(), fields.queryContext)