From 9a2636202d4b30dc68266e41d47a7f36a644c52e Mon Sep 17 00:00:00 2001 From: Pontus Melke Date: Thu, 27 Oct 2016 08:49:46 +0200 Subject: [PATCH] Removed explicit type argument from comparisons Now that expression contains a type there is no longer any reason to send in explicit types to `equal`, `lt`, `lte`, `gt`, and `gte` --- .../java/org/neo4j/codegen/Expression.java | 20 ++-- .../org/neo4j/codegen/ExpressionTemplate.java | 5 + .../org/neo4j/codegen/ExpressionToString.java | 36 +++---- .../org/neo4j/codegen/ExpressionVisitor.java | 12 +-- .../bytecode/ByteCodeExpressionVisitor.java | 44 +++++--- .../codegen/source/MethodSourceWriter.java | 18 ++-- .../org/neo4j/codegen/CodeGenerationTest.java | 102 +++++++++--------- .../codegen/GeneratedMethodStructure.scala | 22 ++-- .../codegen/GeneratedMethodStructure.scala | 29 +++-- .../codegen/GeneratedQueryStructure.scala | 12 +-- .../internal/spi/v3_2/codegen/Templates.scala | 47 ++++---- 11 files changed, 185 insertions(+), 162 deletions(-) diff --git a/community/codegen/src/main/java/org/neo4j/codegen/Expression.java b/community/codegen/src/main/java/org/neo4j/codegen/Expression.java index 80a58c4e56211..df8bd2c91ff0f 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/Expression.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/Expression.java @@ -46,50 +46,50 @@ public void accept( ExpressionVisitor visitor ) } }; - public static Expression gt( final Expression lhs, final Expression rhs, TypeReference argType ) + public static Expression gt( final Expression lhs, final Expression rhs ) { return new Expression( BOOLEAN ) { @Override public void accept( ExpressionVisitor visitor ) { - visitor.gt( lhs, rhs, argType ); + visitor.gt( lhs, rhs ); } }; } - public static Expression gte( final Expression lhs, final Expression rhs, TypeReference argType ) + public static Expression gte( final Expression lhs, final Expression rhs ) { return new Expression( BOOLEAN ) { @Override public void accept( ExpressionVisitor visitor ) { - visitor.gte( lhs, rhs, argType ); + visitor.gte( lhs, rhs ); } }; } - public static Expression lt( final Expression lhs, final Expression rhs, TypeReference argType ) + public static Expression lt( final Expression lhs, final Expression rhs) { return new Expression( BOOLEAN ) { @Override public void accept( ExpressionVisitor visitor ) { - visitor.lt( lhs, rhs, argType ); + visitor.lt( lhs, rhs); } }; } - public static Expression lte( final Expression lhs, final Expression rhs, TypeReference argType ) + public static Expression lte( final Expression lhs, final Expression rhs ) { return new Expression( BOOLEAN ) { @Override public void accept( ExpressionVisitor visitor ) { - visitor.lte( lhs, rhs, argType ); + visitor.lte( lhs, rhs ); } }; } @@ -118,14 +118,14 @@ public void accept( ExpressionVisitor visitor ) }; } - public static Expression equal( final Expression lhs, final Expression rhs, TypeReference argType ) + public static Expression equal( final Expression lhs, final Expression rhs ) { return new Expression( BOOLEAN ) { @Override public void accept( ExpressionVisitor visitor ) { - visitor.equal( lhs, rhs, argType ); + visitor.equal( lhs, rhs ); } }; } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionTemplate.java b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionTemplate.java index 22fcaf9cccb01..e651122b42283 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionTemplate.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionTemplate.java @@ -241,4 +241,9 @@ void templateAccept( CodeBlock method, ExpressionVisitor visitor ) } }; } + + public TypeReference type() + { + return type; + } } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java index bb9711d937e0b..32145dc0d4622 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java @@ -63,7 +63,7 @@ public void invoke( MethodReference method, Expression[] arguments ) } @Override - public void load( LocalVariable variable) + public void load( LocalVariable variable ) { result.append( "load{type=" ); if ( variable.type() == null ) @@ -125,9 +125,9 @@ public void ternary( Expression test, Expression onTrue, Expression onFalse ) { result.append( "ternary{test=" ); test.accept( this ); - result.append(", onTrue="); + result.append( ", onTrue=" ); onTrue.accept( this ); - result.append(", onFalse="); + result.append( ", onFalse=" ); onFalse.accept( this ); result.append( "}" ); } @@ -135,7 +135,7 @@ public void ternary( Expression test, Expression onTrue, Expression onFalse ) @Override public void ternaryOnNull( Expression test, Expression onTrue, Expression onFalse ) { - ternary( Expression.equal( test, Expression.constant( null ), TypeReference.OBJECT ), + ternary( Expression.equal( test, Expression.constant( null ) ), onTrue, onFalse ); } @@ -143,12 +143,12 @@ public void ternaryOnNull( Expression test, Expression onTrue, Expression onFals public void ternaryOnNonNull( Expression test, Expression onTrue, Expression onFalse ) { ternary( Expression.not( - Expression.equal( test, Expression.constant( null ), TypeReference.OBJECT )), + Expression.equal( test, Expression.constant( null ) ) ), onTrue, onFalse ); } @Override - public void equal( Expression lhs, Expression rhs, TypeReference ignored ) + public void equal( Expression lhs, Expression rhs ) { result.append( "equal(" ); lhs.accept( this ); @@ -180,19 +180,19 @@ public void and( Expression lhs, Expression rhs ) @Override public void addInts( Expression lhs, Expression rhs ) { - add(lhs, rhs); + add( lhs, rhs ); } @Override public void addLongs( Expression lhs, Expression rhs ) { - add(lhs, rhs); + add( lhs, rhs ); } @Override public void addDoubles( Expression lhs, Expression rhs ) { - add(lhs, rhs); + add( lhs, rhs ); } private void add( Expression lhs, Expression rhs ) @@ -205,7 +205,7 @@ private void add( Expression lhs, Expression rhs ) } @Override - public void gt( Expression lhs, Expression rhs, TypeReference ignored ) + public void gt( Expression lhs, Expression rhs ) { result.append( "gt(" ); lhs.accept( this ); @@ -215,7 +215,7 @@ public void gt( Expression lhs, Expression rhs, TypeReference ignored ) } @Override - public void gte( Expression lhs, Expression rhs, TypeReference ignored ) + public void gte( Expression lhs, Expression rhs ) { result.append( "gt(" ); lhs.accept( this ); @@ -225,7 +225,7 @@ public void gte( Expression lhs, Expression rhs, TypeReference ignored ) } @Override - public void lt( Expression lhs, Expression rhs, TypeReference ignored ) + public void lt( Expression lhs, Expression rhs ) { result.append( "lt(" ); lhs.accept( this ); @@ -235,7 +235,7 @@ public void lt( Expression lhs, Expression rhs, TypeReference ignored ) } @Override - public void lte( Expression lhs, Expression rhs, TypeReference ignored ) + public void lte( Expression lhs, Expression rhs ) { result.append( "gt(" ); lhs.accept( this ); @@ -247,19 +247,19 @@ public void lte( Expression lhs, Expression rhs, TypeReference ignored ) @Override public void subtractInts( Expression lhs, Expression rhs ) { - sub( lhs, rhs); + sub( lhs, rhs ); } @Override public void subtractLongs( Expression lhs, Expression rhs ) { - sub( lhs, rhs); + sub( lhs, rhs ); } @Override public void subtractDoubles( Expression lhs, Expression rhs ) { - sub( lhs, rhs); + sub( lhs, rhs ); } private void sub( Expression lhs, Expression rhs ) @@ -274,13 +274,13 @@ private void sub( Expression lhs, Expression rhs ) @Override public void multiplyLongs( Expression lhs, Expression rhs ) { - mul( lhs, rhs); + mul( lhs, rhs ); } @Override public void multiplyDoubles( Expression lhs, Expression rhs ) { - mul( lhs, rhs); + mul( lhs, rhs ); } private void mul( Expression lhs, Expression rhs ) diff --git a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java index faf31a20140e8..c853f9f2c5a23 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java @@ -25,7 +25,7 @@ public interface ExpressionVisitor void invoke( MethodReference method, Expression[] arguments ); - void load( LocalVariable variable); + void load( LocalVariable variable ); void getField( Expression target, FieldReference field ); @@ -45,7 +45,7 @@ public interface ExpressionVisitor void ternaryOnNonNull( Expression test, Expression onTrue, Expression onFalse ); - void equal( Expression lhs, Expression rhs, TypeReference type ); + void equal( Expression lhs, Expression rhs); void or( Expression lhs, Expression rhs ); @@ -57,13 +57,13 @@ public interface ExpressionVisitor void addDoubles( Expression lhs, Expression rhs ); - void gt( Expression lhs, Expression rhs, TypeReference type ); + void gt( Expression lhs, Expression rhs ); - void gte( Expression lhs, Expression rhs, TypeReference type ); + void gte( Expression lhs, Expression rhs ); - void lt( Expression lhs, Expression rhs, TypeReference type ); + void lt( Expression lhs, Expression rhs ); - void lte( Expression lhs, Expression rhs, TypeReference type ); + void lte( Expression lhs, Expression rhs ); void subtractInts( Expression lhs, Expression rhs ); diff --git a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java index d122b01a94408..7f79d8a510fcd 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java @@ -184,7 +184,7 @@ public void getField( Expression target, FieldReference field ) target.accept( this ); methodVisitor .visitFieldInsn( GETFIELD, byteCodeName( field.owner() ), field.name(), typeName( field.type() ) ); -} + } @Override public void constant( Object value ) @@ -271,7 +271,7 @@ public void ternary( Expression test, Expression onTrue, Expression onFalse ) @Override public void ternaryOnNull( Expression test, Expression onTrue, Expression onFalse ) { - ternaryExpression( IFNONNULL, test, onTrue, onFalse ); + ternaryExpression( IFNONNULL, test, onTrue, onFalse ); } @Override @@ -281,9 +281,10 @@ public void ternaryOnNonNull( Expression test, Expression onTrue, Expression onF } @Override - public void equal( Expression lhs, Expression rhs, TypeReference type ) + public void equal( Expression lhs, Expression rhs ) { - switch ( type.simpleName() ) + assertSameType( lhs, rhs ); + switch ( lhs.type().simpleName() ) { case "int": case "byte": @@ -396,20 +397,22 @@ public void addDoubles( Expression lhs, Expression rhs ) } @Override - public void gt( Expression lhs, Expression rhs, TypeReference type ) + public void gt( Expression lhs, Expression rhs ) { - numberOperation( type, + assertSameType( lhs, rhs ); + numberOperation( lhs.type(), () -> compareIntOrReferenceType( lhs, rhs, IF_ICMPLE ), () -> compareLongOrFloatType( lhs, rhs, LCMP, IFLE ), () -> compareLongOrFloatType( lhs, rhs, FCMPL, IFLE ), () -> compareLongOrFloatType( lhs, rhs, DCMPL, IFLE ) - ); + ); } @Override - public void gte( Expression lhs, Expression rhs, TypeReference type ) + public void gte( Expression lhs, Expression rhs ) { - numberOperation( type, + assertSameType( lhs, rhs ); + numberOperation( lhs.type(), () -> compareIntOrReferenceType( lhs, rhs, IF_ICMPLT ), () -> compareLongOrFloatType( lhs, rhs, LCMP, IFLT ), () -> compareLongOrFloatType( lhs, rhs, FCMPL, IFLT ), @@ -418,9 +421,10 @@ public void gte( Expression lhs, Expression rhs, TypeReference type ) } @Override - public void lt( Expression lhs, Expression rhs, TypeReference type ) + public void lt( Expression lhs, Expression rhs ) { - numberOperation( type, + assertSameType( lhs, rhs ); + numberOperation( lhs.type(), () -> compareIntOrReferenceType( lhs, rhs, IF_ICMPGE ), () -> compareLongOrFloatType( lhs, rhs, LCMP, IFGE ), () -> compareLongOrFloatType( lhs, rhs, FCMPG, IFGE ), @@ -429,9 +433,10 @@ public void lt( Expression lhs, Expression rhs, TypeReference type ) } @Override - public void lte( Expression lhs, Expression rhs, TypeReference type ) + public void lte( Expression lhs, Expression rhs) { - numberOperation( type, + assertSameType( lhs, rhs ); + numberOperation( lhs.type(), () -> compareIntOrReferenceType( lhs, rhs, IF_ICMPGT ), () -> compareLongOrFloatType( lhs, rhs, LCMP, IFGT ), () -> compareLongOrFloatType( lhs, rhs, FCMPG, IFGT ), @@ -640,7 +645,7 @@ private void arrayStore( TypeReference reference ) } } - private void ternaryExpression(int op, Expression test, Expression onTrue, Expression onFalse) + private void ternaryExpression( int op, Expression test, Expression onTrue, Expression onFalse ) { test.accept( this ); Label l0 = new Label(); @@ -653,7 +658,8 @@ private void ternaryExpression(int op, Expression test, Expression onTrue, Expre methodVisitor.visitLabel( l1 ); } - private void numberOperation( TypeReference type, Runnable onInt, Runnable onLong, Runnable onFloat, Runnable onDouble ) + private void numberOperation( TypeReference type, Runnable onInt, Runnable onLong, Runnable onFloat, + Runnable onDouble ) { switch ( type.simpleName() ) @@ -679,4 +685,12 @@ private void numberOperation( TypeReference type, Runnable onInt, Runnable onLon } } + private void assertSameType( Expression lhs, Expression rhs ) + { + if ( !lhs.type().equals( rhs.type() ) ) + { + throw new IllegalArgumentException( "Can only compare values of the same type" ); + } + } + } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java b/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java index f7034dfecea09..d0d04e9bd4f5c 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java @@ -181,7 +181,7 @@ public void beginIfNull( Expression...tests ) Expression[] nulls = new Expression[tests.length]; for ( int i = 0; i < tests.length; i++ ) { - nulls[i] = Expression.equal(tests[i], Expression.constant( null ), TypeReference.OBJECT); + nulls[i] = Expression.equal(tests[i], Expression.constant( null ) ); } beginIf(nulls); } @@ -192,7 +192,7 @@ public void beginIfNonNull( Expression...tests ) Expression[] notNulls = new Expression[tests.length]; for ( int i = 0; i < tests.length; i++ ) { - notNulls[i] = Expression.not(Expression.equal(tests[i], Expression.constant( null ), TypeReference.OBJECT)); + notNulls[i] = Expression.not(Expression.equal(tests[i], Expression.constant( null ) )); } beginIf(notNulls); } @@ -356,7 +356,7 @@ public void ternary( Expression test, Expression onTrue, Expression onFalse ) @Override public void ternaryOnNull( Expression test, Expression onTrue, Expression onFalse ) { - ternary( Expression.equal( test, Expression.constant( null ), TypeReference.OBJECT ), + ternary( Expression.equal( test, Expression.constant( null ) ), onTrue, onFalse ); } @@ -364,12 +364,12 @@ public void ternaryOnNull( Expression test, Expression onTrue, Expression onFals public void ternaryOnNonNull( Expression test, Expression onTrue, Expression onFalse ) { ternary( Expression.not( - Expression.equal( test, Expression.constant( null ), TypeReference.OBJECT )), + Expression.equal( test, Expression.constant( null ) )), onTrue, onFalse ); } @Override - public void equal( Expression lhs, Expression rhs, TypeReference ignored ) + public void equal( Expression lhs, Expression rhs ) { binaryOperation( lhs, rhs, " == " ); } @@ -411,25 +411,25 @@ private void add( Expression lhs, Expression rhs ) } @Override - public void gt( Expression lhs, Expression rhs, TypeReference ignored ) + public void gt( Expression lhs, Expression rhs ) { binaryOperation( lhs, rhs, " > " ); } @Override - public void gte( Expression lhs, Expression rhs, TypeReference ignored ) + public void gte( Expression lhs, Expression rhs ) { binaryOperation( lhs, rhs, " >= " ); } @Override - public void lt( Expression lhs, Expression rhs, TypeReference ignored ) + public void lt( Expression lhs, Expression rhs ) { binaryOperation( lhs, rhs, " < " ); } @Override - public void lte( Expression lhs, Expression rhs, TypeReference ignored ) + public void lte( Expression lhs, Expression rhs ) { binaryOperation( lhs, rhs, " <= " ); } diff --git a/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java b/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java index e2e9e33204379..712f02f1a39ca 100644 --- a/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java +++ b/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java @@ -1193,144 +1193,144 @@ public void shouldHandleEquality() throws Throwable { // boolean assertTrue( compareForType( boolean.class, true, true, - ( a, b ) -> Expression.equal( a, b, typeReference( boolean.class ) ) ) ); + Expression::equal ) ); assertTrue( compareForType( boolean.class, false, false, - ( a, b ) -> Expression.equal( a, b, typeReference( boolean.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( boolean.class, true, false, - ( a, b ) -> Expression.equal( a, b, typeReference( boolean.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( boolean.class, false, true, - ( a, b ) -> Expression.equal( a, b, typeReference( boolean.class ) ) ) ); + Expression::equal ) ); // byte assertTrue( compareForType( byte.class, (byte) 42, (byte) 42, - ( a, b ) -> Expression.equal( a, b, typeReference( byte.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( byte.class, (byte) 43, (byte) 42, - ( a, b ) -> Expression.equal( a, b, typeReference( byte.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( byte.class, (byte) 42, (byte) 43, - ( a, b ) -> Expression.equal( a, b, typeReference( byte.class ) ) ) ); + Expression::equal ) ); // short assertTrue( compareForType( short.class, (short) 42, (short) 42, - ( a, b ) -> Expression.equal( a, b, typeReference( short.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( short.class, (short) 43, (short) 42, - ( a, b ) -> Expression.equal( a, b, typeReference( short.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( short.class, (short) 42, (short) 43, - ( a, b ) -> Expression.equal( a, b, typeReference( short.class ) ) ) ); + Expression::equal ) ); // char assertTrue( compareForType( char.class, (char) 42, (char) 42, - ( a, b ) -> Expression.equal( a, b, typeReference( char.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( char.class, (char) 43, (char) 42, - ( a, b ) -> Expression.equal( a, b, typeReference( char.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( char.class, (char) 42, (char) 43, - ( a, b ) -> Expression.equal( a, b, typeReference( char.class ) ) ) ); + Expression::equal ) ); //int assertTrue( compareForType( int.class, 42, 42, - ( a, b ) -> Expression.equal( a, b, typeReference( int.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( int.class, 43, 42, - ( a, b ) -> Expression.equal( a, b, typeReference( int.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( int.class, 42, 43, - ( a, b ) -> Expression.equal( a, b, typeReference( int.class ) ) ) ); + Expression::equal ) ); //long assertTrue( compareForType( long.class, 42L, 42L, - ( a, b ) -> Expression.equal( a, b, typeReference( long.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( long.class, 43L, 42L, - ( a, b ) -> Expression.equal( a, b, typeReference( long.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( long.class, 42L, 43L, - ( a, b ) -> Expression.equal( a, b, typeReference( long.class ) ) ) ); + Expression::equal ) ); //float assertTrue( compareForType( float.class, 42F, 42F, - ( a, b ) -> Expression.equal( a, b, typeReference( float.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( float.class, 43F, 42F, - ( a, b ) -> Expression.equal( a, b, typeReference( float.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( float.class, 42F, 43F, - ( a, b ) -> Expression.equal( a, b, typeReference( float.class ) ) ) ); + Expression::equal ) ); //double assertTrue( compareForType( double.class, 42D, 42D, - ( a, b ) -> Expression.equal( a, b, typeReference( double.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( double.class, 43D, 42D, - ( a, b ) -> Expression.equal( a, b, typeReference( double.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( double.class, 42D, 43D, - ( a, b ) -> Expression.equal( a, b, typeReference( double.class ) ) ) ); + Expression::equal ) ); //reference Object obj1 = new Object(); Object obj2 = new Object(); assertTrue( compareForType( Object.class, obj1, obj1, - ( a, b ) -> Expression.equal( a, b, typeReference( Object.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( Object.class, obj1, obj2, - ( a, b ) -> Expression.equal( a, b, typeReference( Object.class ) ) ) ); + Expression::equal ) ); assertFalse( compareForType( Object.class, obj2, obj1, - ( a, b ) -> Expression.equal( a, b, typeReference( Object.class ) ) ) ); + Expression::equal ) ); } @Test public void shouldHandleGreaterThan() throws Throwable { assertTrue( compareForType( float.class, 43F, 42F, - ( a, b ) -> Expression.gt( a, b, typeReference( float.class ) ) ) ); + Expression::gt ) ); assertTrue( compareForType( long.class, 43L, 42L, - ( a, b ) -> Expression.gt( a, b, typeReference( long.class ) ) ) ); + Expression::gt ) ); // byte assertTrue( compareForType( byte.class, (byte) 43, (byte) 42, - ( a, b ) -> Expression.gt( a, b, typeReference( byte.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( byte.class, (byte) 42, (byte) 42, - ( a, b ) -> Expression.gt( a, b, typeReference( byte.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( byte.class, (byte) 42, (byte) 43, - ( a, b ) -> Expression.gt( a, b, typeReference( byte.class ) ) ) ); + Expression::gt ) ); // short assertTrue( compareForType( short.class, (short) 43, (short) 42, - ( a, b ) -> Expression.gt( a, b, typeReference( short.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( short.class, (short) 42, (short) 42, - ( a, b ) -> Expression.gt( a, b, typeReference( short.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( short.class, (short) 42, (short) 43, - ( a, b ) -> Expression.gt( a, b, typeReference( short.class ) ) ) ); + Expression::gt ) ); // char assertTrue( compareForType( char.class, (char) 43, (char) 42, - ( a, b ) -> Expression.gt( a, b, typeReference( char.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( char.class, (char) 42, (char) 42, - ( a, b ) -> Expression.gt( a, b, typeReference( char.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( char.class, (char) 42, (char) 43, - ( a, b ) -> Expression.gt( a, b, typeReference( char.class ) ) ) ); + Expression::gt ) ); //int assertTrue( - compareForType( int.class, 43, 42, ( a, b ) -> Expression.gt( a, b, typeReference( int.class ) ) ) ); + compareForType( int.class, 43, 42, Expression::gt ) ); assertFalse( - compareForType( int.class, 42, 42, ( a, b ) -> Expression.gt( a, b, typeReference( int.class ) ) ) ); + compareForType( int.class, 42, 42, Expression::gt ) ); assertFalse( - compareForType( int.class, 42, 43, ( a, b ) -> Expression.gt( a, b, typeReference( int.class ) ) ) ); + compareForType( int.class, 42, 43, Expression::gt ) ); //long assertTrue( compareForType( long.class, 43L, 42L, - ( a, b ) -> Expression.gt( a, b, typeReference( long.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( long.class, 42L, 42L, - ( a, b ) -> Expression.gt( a, b, typeReference( long.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( long.class, 42L, 43L, - ( a, b ) -> Expression.gt( a, b, typeReference( long.class ) ) ) ); + Expression::gt ) ); //float assertTrue( compareForType( float.class, 43F, 42F, - ( a, b ) -> Expression.gt( a, b, typeReference( float.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( float.class, 42F, 42F, - ( a, b ) -> Expression.gt( a, b, typeReference( float.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( float.class, 42F, 43F, - ( a, b ) -> Expression.gt( a, b, typeReference( float.class ) ) ) ); + Expression::gt ) ); //double assertTrue( compareForType( double.class, 43D, 42D, - ( a, b ) -> Expression.gt( a, b, typeReference( double.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( double.class, 42D, 42D, - ( a, b ) -> Expression.gt( a, b, typeReference( double.class ) ) ) ); + Expression::gt ) ); assertFalse( compareForType( double.class, 42D, 43D, - ( a, b ) -> Expression.gt( a, b, typeReference( double.class ) ) ) ); + Expression::gt ) ); } @Test diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala index 9bf071832f07f..5b9169f8d0cd3 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/GeneratedMethodStructure.scala @@ -106,7 +106,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A generator.assign(typeRef[Long], toNodeVar, toGraphDb(direction) match { case Direction.INCOMING => start case Direction.OUTGOING => end - case Direction.BOTH => ternary(equal(start, generator.load(fromNodeVar), typeRef[Long]), end, start) + case Direction.BOTH => ternary(equal(start, generator.load(fromNodeVar)), end, start) }) generator.assign(typeRef[Long], relVar, invoke(generator.load(extractor), getRelationship)) } @@ -191,12 +191,12 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A val local = locals(name) generator.assign(local, subtractInts(local, constant(1))) - equal(constant(0), local, typeRef[Int]) + equal(constant(0), local) } override def counterEqualsZero(name: String): Expression = { val local = locals(name) - equal(constant(0), local, typeRef[Int]) + equal(constant(0), local) } override def setInRow(column: String, value: Expression) = @@ -231,7 +231,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def nullablePrimitive(varName: String, codeGenType: CodeGenType, onSuccess: Expression) = codeGenType match { case CodeGenType(CTNode, IntType) | CodeGenType(CTRelationship, IntType) => ternary( - equal(nullValue(codeGenType), generator.load(varName), lowerType(codeGenType)), + equal(nullValue(codeGenType), generator.load(varName)), nullValue(codeGenType), onSuccess) case _ => ternaryOnNull(generator.load(varName), constant(null), onSuccess) @@ -240,7 +240,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def nullableReference(varName: String, codeGenType: CodeGenType, onSuccess: Expression) = codeGenType match { case CodeGenType(CTNode, IntType) | CodeGenType(CTRelationship, IntType) => ternary( - equal(nullValue(codeGenType), generator.load(varName), lowerType(codeGenType)), + equal(nullValue(codeGenType), generator.load(varName)), constant(null), onSuccess) case _ => ternaryOnNull(generator.load(varName), constant(null), onSuccess) @@ -296,7 +296,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def threeValuedEqualsExpression(lhs: Expression, rhs: Expression) = invoke(Methods.ternaryEquals, lhs, rhs) override def equalityExpression(lhs: Expression, rhs: Expression, codeGenType: CodeGenType) = - if (codeGenType.isPrimitive) equal(lhs, rhs, lowerType(codeGenType)) + if (codeGenType.isPrimitive) equal(lhs, rhs) else invoke(lhs, Methods.equals, rhs) override def orExpression(lhs: Expression, rhs: Expression) = or(lhs, rhs) @@ -308,7 +308,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def isNull(varName: String, codeGenType: CodeGenType) = - equal(nullValue(codeGenType), generator.load(varName), lowerType(codeGenType)) + equal(nullValue(codeGenType), generator.load(varName)) override def notNull(varName: String, codeGenType: CodeGenType) = not(isNull(varName, codeGenType)) @@ -467,7 +467,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A pop( invoke(generator.load(tableVar), countingTablePut, generator.load(keyVar), ternary( - equal(generator.load(countName), get(staticField[LongKeyIntValueTable, Int]("NULL")), typeRef[Int]), + equal(generator.load(countName), get(staticField[LongKeyIntValueTable, Int]("NULL"))), constant(1), addInts(generator.load(countName), constant(1)))))) @@ -502,7 +502,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A val keyVar = keyVars.head val times = generator.declare(typeRef[Int], context.namer.newVarName()) generator.assign(times, invoke(generator.load(tableVar), countingTableGet, generator.load(keyVar))) - using(generator.whileLoop(gt(times, constant(0), typeRef[Int]))) { body => + using(generator.whileLoop(gt(times, constant(0)))) { body => block(copy(generator = body)) body.assign(times, subtractInts(times, constant(1))) } @@ -522,7 +522,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A constant(-1), invoke(intermediate, unboxInteger))) - using(generator.whileLoop(gt(times, constant(0), typeRef[Int]))) { body => + using(generator.whileLoop(gt(times, constant(0)))) { body => block(copy(generator = body)) body.assign(times, subtractInts(times, constant(1))) } @@ -708,7 +708,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def nodeIdSeek(nodeIdVar: String, expression: Expression)(block: MethodStructure[Expression] => Unit) = { generator.assign(typeRef[Long], nodeIdVar, invoke(Methods.mathCastToLong, expression)) using(generator.ifStatement( - gt(generator.load(nodeIdVar), constant(-1L), typeRef[Long]), + gt(generator.load(nodeIdVar), constant(-1L)), invoke(readOperations, nodeExists, generator.load(nodeIdVar)) )) { ifBody => block(copy(generator = ifBody)) diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala index 5d0f30f61351b..2feebc1b6cc22 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala @@ -106,7 +106,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A generator.assign(typeRef[Long], toNodeVar, toGraphDb(direction) match { case Direction.INCOMING => start case Direction.OUTGOING => end - case Direction.BOTH => ternary(equal(start, generator.load(fromNodeVar), typeRef[Long]), end, start) + case Direction.BOTH => ternary(equal(start, generator.load(fromNodeVar)), end, start) }) generator.assign(typeRef[Long], relVar, invoke(generator.load(extractor), getRelationship)) } @@ -187,7 +187,6 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A generator.assign(variable, invoke(mathCastToInt, initialValue)) } - override def decrementCounter(name: String) = { val local = locals(name) generator.assign(local, subtractInts(local, constant(1))) @@ -196,11 +195,11 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def checkCounter(name: String, comparator: Comparator, value: Int): Expression = { val local = locals(name) comparator match { - case Equal => equal(local, constant(value), typeRef[Int]) - case LessThan => lt(local, constant(value), typeRef[Int]) - case LessThanEqual => lte(local, constant(value), typeRef[Int]) - case GreaterThan => gt(local, constant(value), typeRef[Int]) - case GreaterThanEqual => gte(local, constant(value), typeRef[Int]) + case Equal => equal(local, constant(value)) + case LessThan => lt(local, constant(value)) + case LessThanEqual => lte(local, constant(value)) + case GreaterThan => gt(local, constant(value)) + case GreaterThanEqual => gte(local, constant(value)) } } @@ -236,7 +235,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def nullablePrimitive(varName: String, codeGenType: CodeGenType, onSuccess: Expression) = codeGenType match { case CodeGenType(CTNode, IntType) | CodeGenType(CTRelationship, IntType) => ternary( - equal(nullValue(codeGenType), generator.load(varName), lowerType(codeGenType)), + equal(nullValue(codeGenType), generator.load(varName)), nullValue(codeGenType), onSuccess) case _ => ternaryOnNull(generator.load(varName), constant(null), onSuccess) @@ -245,7 +244,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def nullableReference(varName: String, codeGenType: CodeGenType, onSuccess: Expression) = codeGenType match { case CodeGenType(CTNode, IntType) | CodeGenType(CTRelationship, IntType) => ternary( - equal(nullValue(codeGenType), generator.load(varName), lowerType(codeGenType)), + equal(nullValue(codeGenType), generator.load(varName)), constant(null), onSuccess) case _ => ternaryOnNull(generator.load(varName), constant(null), onSuccess) @@ -301,7 +300,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def threeValuedEqualsExpression(lhs: Expression, rhs: Expression) = invoke(Methods.ternaryEquals, lhs, rhs) override def equalityExpression(lhs: Expression, rhs: Expression, codeGenType: CodeGenType) = - if (codeGenType.isPrimitive) equal(lhs, rhs, lowerType(codeGenType)) + if (codeGenType.isPrimitive) equal(lhs, rhs) else invoke(lhs, Methods.equals, rhs) override def orExpression(lhs: Expression, rhs: Expression) = or(lhs, rhs) @@ -313,7 +312,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def isNull(varName: String, codeGenType: CodeGenType) = - equal(nullValue(codeGenType), generator.load(varName), lowerType(codeGenType)) + equal(nullValue(codeGenType), generator.load(varName)) override def notNull(varName: String, codeGenType: CodeGenType) = not(isNull(varName, codeGenType)) @@ -472,7 +471,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A pop( invoke(generator.load(tableVar), countingTablePut, generator.load(keyVar), ternary( - equal(generator.load(countName), get(staticField[LongKeyIntValueTable, Int]("NULL")), typeRef[Int]), + equal(generator.load(countName), get(staticField[LongKeyIntValueTable, Int]("NULL"))), constant(1), addInts(generator.load(countName), constant(1)))))) @@ -507,7 +506,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A val keyVar = keyVars.head val times = generator.declare(typeRef[Int], context.namer.newVarName()) generator.assign(times, invoke(generator.load(tableVar), countingTableGet, generator.load(keyVar))) - using(generator.whileLoop(gt(times, constant(0), typeRef[Int]))) { body => + using(generator.whileLoop(gt(times, constant(0)))) { body => block(copy(generator = body)) body.assign(times, subtractInts(times, constant(1))) } @@ -527,7 +526,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A constant(-1), invoke(intermediate, unboxInteger))) - using(generator.whileLoop(gt(times, constant(0), typeRef[Int]))) { body => + using(generator.whileLoop(gt(times, constant(0)))) { body => block(copy(generator = body)) body.assign(times, subtractInts(times, constant(1))) } @@ -713,7 +712,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A override def nodeIdSeek(nodeIdVar: String, expression: Expression)(block: MethodStructure[Expression] => Unit) = { generator.assign(typeRef[Long], nodeIdVar, invoke(Methods.mathCastToLong, expression)) using(generator.ifStatement( - gt(generator.load(nodeIdVar), constant(-1L), typeRef[Long]), + gt(generator.load(nodeIdVar), constant(-1L)), invoke(readOperations, nodeExists, generator.load(nodeIdVar)) )) { ifBody => block(copy(generator = ifBody)) diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedQueryStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedQueryStructure.scala index de98d89493c68..d6e9cb07c094c 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedQueryStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedQueryStructure.scala @@ -94,8 +94,8 @@ object GeneratedQueryStructure extends CodeStructure[GeneratedQuery] { tracer = clazz.field(typeRef[QueryExecutionTracer], "tracer"), params = clazz.field(typeRef[util.Map[String, Object]], "params"), closeable = clazz.field(typeRef[SuccessfulCloseable], "closeable"), - success = clazz.generate(Templates.SUCCESS), - close = clazz.generate(Templates.CLOSE)) + success = clazz.generate(Templates.success(clazz.handle())), + close = clazz.generate(Templates.close(clazz.handle()))) // the "COLUMNS" static field clazz.staticField(typeRef[util.List[String]], "COLUMNS", Templates.asList[String]( columns.map(key => constant(key)))) @@ -106,10 +106,10 @@ object GeneratedQueryStructure extends CodeStructure[GeneratedQuery] { } // simple methods - clazz.generate(Templates.CONSTRUCTOR) - clazz.generate(Templates.SET_SUCCESSFUL_CLOSEABLE) - clazz.generate(Templates.EXECUTION_MODE) - clazz.generate(Templates.EXECUTION_PLAN_DESCRIPTION) + clazz.generate(Templates.constructor(clazz.handle())) + clazz.generate(Templates.setSuccessfulCloseable(clazz.handle())) + clazz.generate(Templates.executionMode(clazz.handle())) + clazz.generate(Templates.executionPlanDescription(clazz.handle())) clazz.generate(Templates.JAVA_COLUMNS) using(clazz.generate(MethodDeclaration.method(typeRef[Unit], "accept", diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Templates.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Templates.scala index eebc93f5f4474..4d08d6c4f0dc4 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Templates.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Templates.scala @@ -104,7 +104,7 @@ object Templates { .invoke(Expression.newInstance(typeRef[RelationshipDataExtractor]), MethodReference.constructorReference(typeRef[RelationshipDataExtractor])) - val CONSTRUCTOR = MethodTemplate.constructor( + def constructor(classHandle: ClassHandle) = MethodTemplate.constructor( param[TaskCloser]("closer"), param[QueryContext]("queryContext"), param[ExecutionMode]("executionMode"), @@ -113,40 +113,45 @@ object Templates { param[util.Map[String, Object]]("params")). invokeSuper(). - put(self(), typeRef[TaskCloser], "closer", load("closer")). - put(self(), typeRef[ReadOperations], "ro", + put(self(classHandle), typeRef[TaskCloser], "closer", load("closer", typeRef[TaskCloser])). + put(self(classHandle), typeRef[ReadOperations], "ro", cast(classOf[ReadOperations], invoke( - invoke(load("queryContext"), method[QueryContext, QueryTransactionalContext]("transactionalContext")), + invoke(load("queryContext", typeRef[QueryContext]), method[QueryContext, QueryTransactionalContext]("transactionalContext")), method[QueryTransactionalContext, Object]("readOperations")))). - put(self(), typeRef[ExecutionMode], "executionMode", load("executionMode")). - put(self(), typeRef[Provider[InternalPlanDescription]], "description", load("description")). - put(self(), typeRef[QueryExecutionTracer], "tracer", load("tracer")). - put(self(), typeRef[util.Map[String, Object]], "params", load("params")). - put(self(), typeRef[NodeManager], "nodeManager", + put(self(classHandle), typeRef[ExecutionMode], "executionMode", load("executionMode", typeRef[ExecutionMode])). + put(self(classHandle), typeRef[Provider[InternalPlanDescription]], "description", load("description", typeRef[InternalPlanDescription])). + put(self(classHandle), typeRef[QueryExecutionTracer], "tracer", load("tracer", typeRef[QueryExecutionTracer])). + put(self(classHandle), typeRef[util.Map[String, Object]], "params", load("params", typeRef[util.Map[String, Object]])). + put(self(classHandle), typeRef[NodeManager], "nodeManager", cast(typeRef[NodeManager], - invoke(load("queryContext"), method[QueryContext, Object]("entityAccessor")))). + invoke(load("queryContext", typeRef[QueryContext]), method[QueryContext, Object]("entityAccessor")))). build() - val SET_SUCCESSFUL_CLOSEABLE = MethodTemplate.method(typeRef[Unit], "setSuccessfulCloseable", - param[SuccessfulCloseable]("closeable")). - put(self(), typeRef[SuccessfulCloseable], "closeable", load("closeable")). + def setSuccessfulCloseable(classHandle: ClassHandle) = MethodTemplate.method(typeRef[Unit], "setSuccessfulCloseable", + param[SuccessfulCloseable]("closeable")). + put(self(classHandle), typeRef[SuccessfulCloseable], "closeable", load("closeable", typeRef[SuccessfulCloseable])). build() - val SUCCESS = MethodTemplate.method(typeRef[Unit], "success"). + + def success(classHandle: ClassHandle) = MethodTemplate.method(typeRef[Unit], "success"). expression( - invoke(get(self(), typeRef[SuccessfulCloseable], "closeable"), method[SuccessfulCloseable, Unit]("success"))). + invoke(get(self(classHandle), typeRef[SuccessfulCloseable], "closeable"), method[SuccessfulCloseable, Unit]("success"))). build() - val CLOSE = MethodTemplate.method(typeRef[Unit], "close"). + + def close(classHandle: ClassHandle) = MethodTemplate.method(typeRef[Unit], "close"). expression( - invoke(get(self(), typeRef[SuccessfulCloseable], "closeable"), method[SuccessfulCloseable, Unit]("close"))). + invoke(get(self(classHandle), typeRef[SuccessfulCloseable], "closeable"), method[SuccessfulCloseable, Unit]("close"))). build() - val EXECUTION_MODE = MethodTemplate.method(typeRef[ExecutionMode], "executionMode"). - returns(get(self(), typeRef[ExecutionMode], "executionMode")). + + def executionMode(classHandle: ClassHandle) = MethodTemplate.method(typeRef[ExecutionMode], "executionMode"). + returns(get(self(classHandle), typeRef[ExecutionMode], "executionMode")). build() - val EXECUTION_PLAN_DESCRIPTION = MethodTemplate.method(typeRef[InternalPlanDescription], "executionPlanDescription"). + + def executionPlanDescription(classHandle: ClassHandle) = MethodTemplate.method(typeRef[InternalPlanDescription], "executionPlanDescription"). returns(cast( typeRef[InternalPlanDescription], - invoke(get(self(), typeRef[Provider[InternalPlanDescription]], "description"), + invoke(get(self(classHandle), typeRef[Provider[InternalPlanDescription]], "description"), method[Provider[InternalPlanDescription], Object]("get")))). build() + val JAVA_COLUMNS = MethodTemplate.method(typeRef[util.List[String]], "javaColumns"). returns(get(typeRef[util.List[String]], "COLUMNS")). build()