From 4e8e93a2b62ed4f3171203412df9b39b5399335d Mon Sep 17 00:00:00 2001 From: Pontus Melke Date: Wed, 13 Apr 2016 15:51:58 +0200 Subject: [PATCH] Changed greater than to consider type --- .../neo4j/codegen/BaseExpressionVisitor.java | 2 +- .../java/org/neo4j/codegen/Expression.java | 4 +- .../org/neo4j/codegen/ExpressionToString.java | 2 +- .../org/neo4j/codegen/ExpressionVisitor.java | 2 +- .../bytecode/ByteCodeExpressionVisitor.java | 13 +--- .../codegen/source/MethodSourceWriter.java | 2 +- .../org/neo4j/codegen/CodeGenerationTest.java | 69 ++++++++++++------- .../spi/v2_3/GeneratedQueryStructure.scala | 4 +- .../spi/v3_0/GeneratedQueryStructure.scala | 4 +- .../spi/v3_1/GeneratedQueryStructure.scala | 4 +- 10 files changed, 60 insertions(+), 46 deletions(-) diff --git a/community/codegen/src/main/java/org/neo4j/codegen/BaseExpressionVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/BaseExpressionVisitor.java index ca510c6fec60e..6081040f93890 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/BaseExpressionVisitor.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/BaseExpressionVisitor.java @@ -115,7 +115,7 @@ public void addDoubles( Expression lhs, Expression rhs ) } @Override - public void gt( Expression lhs, Expression rhs ) + public void gt( Expression lhs, Expression rhs, TypeReference type ) { } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/Expression.java b/community/codegen/src/main/java/org/neo4j/codegen/Expression.java index fa79d05a652f0..559b6fe202d11 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/Expression.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/Expression.java @@ -34,14 +34,14 @@ public void accept( ExpressionVisitor visitor ) public abstract void accept( ExpressionVisitor visitor ); - public static Expression gt( final Expression lhs, final Expression rhs ) + public static Expression gt( final Expression lhs, final Expression rhs, TypeReference type ) { return new Expression() { @Override public void accept( ExpressionVisitor visitor ) { - visitor.gt( lhs, rhs ); + visitor.gt( lhs, rhs, type ); } }; } diff --git a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java index 19341b972e1c7..f34cee684e663 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionToString.java @@ -179,7 +179,7 @@ private void add( Expression lhs, Expression rhs ) } @Override - public void gt( Expression lhs, Expression rhs ) + public void gt( Expression lhs, Expression rhs, TypeReference ignored ) { result.append( "gt(" ); lhs.accept( this ); diff --git a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java index 655641e064837..528577df5f758 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/ExpressionVisitor.java @@ -51,7 +51,7 @@ public interface ExpressionVisitor void addDoubles( Expression lhs, Expression rhs ); - void gt( Expression lhs, Expression rhs ); + void gt( Expression lhs, Expression rhs, TypeReference type ); void subtractInts( Expression lhs, Expression rhs ); diff --git a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java index 76ef8a723c263..b528f2549c955 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/bytecode/ByteCodeExpressionVisitor.java @@ -226,7 +226,6 @@ public void ternary( Expression test, Expression onTrue, Expression onFalse ) @Override public void eq( Expression lhs, Expression rhs, TypeReference type ) { - switch ( type.simpleName() ) { case "int": @@ -310,18 +309,10 @@ public void addDoubles( Expression lhs, Expression rhs ) } @Override - public void gt( Expression lhs, Expression rhs ) + public void gt( Expression lhs, Expression rhs, TypeReference type ) { - TypeReference lhsType = findType( lhs ); - TypeReference rhsType = findType( rhs ); - - if ( !lhsType.equals( rhsType ) ) - { - throw new IllegalStateException( "Cannot compare values of different types" ); - } - - switch ( lhsType.simpleName() ) + switch ( type.simpleName() ) { case "int": case "byte": diff --git a/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java b/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java index 9e0be7fa58df2..289c567bd686c 100644 --- a/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java +++ b/community/codegen/src/main/java/org/neo4j/codegen/source/MethodSourceWriter.java @@ -436,7 +436,7 @@ private void add( Expression lhs, Expression rhs ) } @Override - public void gt( Expression lhs, Expression rhs ) + public void gt( Expression lhs, Expression rhs, TypeReference ignored ) { lhs.accept( this ); append( " > " ); diff --git a/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java b/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java index 882742bf68bcd..7004f31db2ee9 100644 --- a/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java +++ b/community/codegen/src/test/java/org/neo4j/codegen/CodeGenerationTest.java @@ -1057,43 +1057,66 @@ public void shouldHandleEquality() throws Throwable @Test public void shouldHandleGreaterThan() throws Throwable { - assertTrue( compareForType( float.class, 43F, 42F, Expression::gt ) ); - assertTrue( compareForType( long.class, 43L, 42L, Expression::gt ) ); + assertTrue( compareForType( float.class, 43F, 42F, + ( a, b ) -> Expression.gt( a, b, typeReference( float.class ) ) ) ); + assertTrue( compareForType( long.class, 43L, 42L, + ( a, b ) -> Expression.gt( a, b, typeReference( long.class ) ) ) ); // byte - assertTrue( compareForType( byte.class, (byte) 43, (byte) 42, Expression::gt ) ); - assertFalse( compareForType( byte.class, (byte) 42, (byte) 42, Expression::gt ) ); - assertFalse( compareForType( byte.class, (byte) 42, (byte) 43, Expression::gt ) ); + assertTrue( compareForType( byte.class, (byte) 43, (byte) 42, + ( a, b ) -> Expression.gt( a, b, typeReference( byte.class ) ) ) ); + assertFalse( compareForType( byte.class, (byte) 42, (byte) 42, + ( a, b ) -> Expression.gt( a, b, typeReference( byte.class ) ) ) ); + assertFalse( compareForType( byte.class, (byte) 42, (byte) 43, + ( a, b ) -> Expression.gt( a, b, typeReference( byte.class ) ) ) ); // short - assertTrue( compareForType( short.class, (short) 43, (short) 42, Expression::gt ) ); - assertFalse( compareForType( short.class, (short) 42, (short) 42, Expression::gt ) ); - assertFalse( compareForType( short.class, (short) 42, (short) 43, Expression::gt ) ); + assertTrue( compareForType( short.class, (short) 43, (short) 42, + ( a, b ) -> Expression.gt( a, b, typeReference( short.class ) ) ) ); + assertFalse( compareForType( short.class, (short) 42, (short) 42, + ( a, b ) -> Expression.gt( a, b, typeReference( short.class ) ) ) ); + assertFalse( compareForType( short.class, (short) 42, (short) 43, + ( a, b ) -> Expression.gt( a, b, typeReference( short.class ) ) ) ); // char - assertTrue( compareForType( char.class, (char) 43, (char) 42, Expression::gt ) ); - assertFalse( compareForType( char.class, (char) 42, (char) 42, Expression::gt ) ); - assertFalse( compareForType( char.class, (char) 42, (char) 43, Expression::gt ) ); + assertTrue( compareForType( char.class, (char) 43, (char) 42, + ( a, b ) -> Expression.gt( a, b, typeReference( char.class ) ) ) ); + assertFalse( compareForType( char.class, (char) 42, (char) 42, + ( a, b ) -> Expression.gt( a, b, typeReference( char.class ) ) ) ); + assertFalse( compareForType( char.class, (char) 42, (char) 43, + ( a, b ) -> Expression.gt( a, b, typeReference( char.class ) ) ) ); //int - assertTrue( compareForType( int.class, 43, 42, Expression::gt ) ); - assertFalse( compareForType( int.class, 42, 42, Expression::gt ) ); - assertFalse( compareForType( int.class, 42, 43, Expression::gt ) ); + assertTrue( + compareForType( int.class, 43, 42, ( a, b ) -> Expression.gt( a, b, typeReference( int.class ) ) ) ); + assertFalse( + compareForType( int.class, 42, 42, ( a, b ) -> Expression.gt( a, b, typeReference( int.class ) ) ) ); + assertFalse( + compareForType( int.class, 42, 43, ( a, b ) -> Expression.gt( a, b, typeReference( int.class ) ) ) ); //long - assertTrue( compareForType( long.class, 43L, 42L, Expression::gt ) ); - assertFalse( compareForType( long.class, 42L, 42L, Expression::gt ) ); - assertFalse( compareForType( long.class, 42L, 43L, Expression::gt ) ); + assertTrue( compareForType( long.class, 43L, 42L, + ( a, b ) -> Expression.gt( a, b, typeReference( long.class ) ) ) ); + assertFalse( compareForType( long.class, 42L, 42L, + ( a, b ) -> Expression.gt( a, b, typeReference( long.class ) ) ) ); + assertFalse( compareForType( long.class, 42L, 43L, + ( a, b ) -> Expression.gt( a, b, typeReference( long.class ) ) ) ); //float - assertTrue( compareForType( float.class, 43F, 42F, Expression::gt ) ); - assertFalse( compareForType( float.class, 42F, 42F, Expression::gt ) ); - assertFalse( compareForType( float.class, 42F, 43F, Expression::gt ) ); + assertTrue( compareForType( float.class, 43F, 42F, + ( a, b ) -> Expression.gt( a, b, typeReference( float.class ) ) ) ); + assertFalse( compareForType( float.class, 42F, 42F, + ( a, b ) -> Expression.gt( a, b, typeReference( float.class ) ) ) ); + assertFalse( compareForType( float.class, 42F, 43F, + ( a, b ) -> Expression.gt( a, b, typeReference( float.class ) ) ) ); //double - assertTrue( compareForType( double.class, 43D, 42D, Expression::gt ) ); - assertFalse( compareForType( double.class, 42D, 42D, Expression::gt ) ); - assertFalse( compareForType( double.class, 42D, 43D, Expression::gt ) ); + assertTrue( compareForType( double.class, 43D, 42D, + ( a, b ) -> Expression.gt( a, b, typeReference( double.class ) ) ) ); + assertFalse( compareForType( double.class, 42D, 42D, + ( a, b ) -> Expression.gt( a, b, typeReference( double.class ) ) ) ); + assertFalse( compareForType( double.class, 42D, 43D, + ( a, b ) -> Expression.gt( a, b, typeReference( double.class ) ) ) ); } @Test diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v2_3/GeneratedQueryStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v2_3/GeneratedQueryStructure.scala index b16fd671e2643..576fac4bd4858 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v2_3/GeneratedQueryStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v2_3/GeneratedQueryStructure.scala @@ -540,7 +540,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator val keyVar = keyVars.head val times = generator.declare(typeRef[Int], context.namer.newVarName()) generator.assign(times, Expression.invoke(generator.load(tableVar), Methods.countingTableGet, generator.load(keyVar))) - using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => + using(generator.whileLoop(Expression.gt(times, Expression.constant(0), typeRef[Int]))) { body => block(copy(generator=body)) body.assign(times, Expression.subtractInts(times, Expression.constant(1))) } @@ -555,7 +555,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator Expression.eq(generator.load(intermediate.name()), Expression.constant(null)), Expression.constant(-1), generator.load(intermediate.name()))) - using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => + using(generator.whileLoop(Expression.gt(times, Expression.constant(0), typeRef[Int]))) { body => block(copy(generator=body)) body.assign(times, Expression.subtractInts(times, Expression.constant(1))) } diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_0/GeneratedQueryStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_0/GeneratedQueryStructure.scala index b44faf740ead7..136bfb6d49ac7 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_0/GeneratedQueryStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_0/GeneratedQueryStructure.scala @@ -554,7 +554,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator val keyVar = keyVars.head val times = generator.declare(typeRef[Int], context.namer.newVarName()) generator.assign(times, Expression.invoke(generator.load(tableVar), Methods.countingTableGet, generator.load(keyVar))) - using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => + using(generator.whileLoop(Expression.gt(times, Expression.constant(0), typeRef[Int]))) { body => block(copy(generator=body)) body.assign(times, Expression.subtractInts(times, Expression.constant(1))) } @@ -569,7 +569,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator Expression.eq(generator.load(intermediate.name()), Expression.constant(null)), Expression.constant(-1), generator.load(intermediate.name()))) - using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => + using(generator.whileLoop(Expression.gt(times, Expression.constant(0), typeRef[Int]))) { body => block(copy(generator=body)) body.assign(times, Expression.subtractInts(times, Expression.constant(1))) } diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/GeneratedQueryStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/GeneratedQueryStructure.scala index 641ba80a917ff..edae8f576fbe3 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/GeneratedQueryStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/GeneratedQueryStructure.scala @@ -581,7 +581,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator val times = generator.declare(typeRef[Int], context.namer.newVarName()) generator.assign(times, Expression .invoke(generator.load(tableVar), Methods.countingTableGet, generator.load(keyVar))) - using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => + using(generator.whileLoop(Expression.gt(times, Expression.constant(0), typeRef[Int]))) { body => block(copy(generator=body)) body.assign(times, Expression.subtractInts(times, Expression.constant(1))) } @@ -596,7 +596,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator Expression.eq(generator.load(intermediate.name()), Expression.constant(null)), Expression.constant(-1), generator.load(intermediate.name()))) - using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => + using(generator.whileLoop(Expression.gt(times, Expression.constant(0), typeRef[Int]))) { body => block(copy(generator=body)) body.assign(times, Expression.subtractInts(times, Expression.constant(1))) }