From e4a39389ebf882c815368337cf72f4aa27aa17b2 Mon Sep 17 00:00:00 2001 From: Pontus Melke Date: Wed, 23 Nov 2016 09:58:21 +0100 Subject: [PATCH] Key generation with proper equivalence semantics Hash keys should be generated according to proper Cypher equivalence semantics --- .../v3_2/codegen/LogicalPlanConverter.scala | 2 +- .../ir/expressions/AggregateExpression.scala | 109 ++-- .../ir/expressions/ExpressionConverter.scala | 10 +- .../ir/expressions/NodeExpression.scala | 5 +- .../expressions/RelationshipExpression.scala | 4 +- .../v3_2/codegen/spi/MethodStructure.scala | 9 +- .../v3_2/commands/predicates/Equivalent.scala | 2 + .../commands/predicates/EquivalentTest.scala | 5 +- .../codegen/CompiledConversionUtils.java | 98 +-- .../codegen/CompiledEquivalenceUtils.java | 559 +++++++++++++++++- .../internal/spi/v3_1/codegen/Methods.scala | 2 - .../spi/v3_2/codegen/AuxGenerator.scala | 47 +- .../codegen/GeneratedMethodStructure.scala | 82 ++- .../internal/spi/v3_2/codegen/Methods.scala | 2 + .../CompiledEquivalenceUtilsTest.scala | 184 +++++- 15 files changed, 938 insertions(+), 182 deletions(-) diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/LogicalPlanConverter.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/LogicalPlanConverter.scala index 95283342a3d2c..56e68049ac65d 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/LogicalPlanConverter.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/LogicalPlanConverter.scala @@ -103,7 +103,7 @@ object LogicalPlanConverter { val produceResultOpName = context.registerOperator(produceResults) val projections = (produceResults.lhs.get match { // if lhs is projection than we can simply load things that it projected - case _: plans.Projection => produceResults.columns.map(c => c -> LoadVariable(context.getVariable(c))) + case _: plans.Projection => produceResults.columns.map(c => c -> LoadVariable(context.getProjection(c))) // else we have to evaluate all expressions ourselves case _ => produceResults.columns.map(c => c -> ExpressionConverter.createExpressionForVariable(c)(context)) }).toMap diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/AggregateExpression.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/AggregateExpression.scala index 7bb42672c2098..bf3706c43017c 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/AggregateExpression.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/AggregateExpression.scala @@ -22,51 +22,91 @@ package org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.expressions import org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.Instruction import org.neo4j.cypher.internal.compiler.v3_2.codegen.{CodeGenContext, MethodStructure, Variable} -trait AggregateExpression { +abstract class AggregateExpression(expression: CodeGenExpression, distinct: Boolean) { def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit def update[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit def continuation(instruction: Instruction): Instruction = instruction -} + def distinctCondition[E](value: E, valueType: CodeGenType, structure: MethodStructure[E])(block: MethodStructure[E] => Unit) + (implicit context: CodeGenContext) + + protected def ifNotNull[E](structure: MethodStructure[E])(block: MethodStructure[E] => Unit) + (implicit context: CodeGenContext) = { + expression match { + case NodeExpression(v) => primitiveIfNot(v, structure)(block(_)) + case NodeProjection(v) => primitiveIfNot(v, structure)(block(_)) + case RelationshipExpression(v) => primitiveIfNot(v, structure)(block(_)) + case RelationshipProjection(v) => primitiveIfNot(v, structure)(block(_)) + case _ => + val tmpName = context.namer.newVarName() + structure.assign(tmpName, expression.codeGenType, expression.generateExpression(structure)) + structure.ifNonNullStatement(structure.loadVariable(tmpName)) { body => + if (distinct) { + distinctCondition(structure.loadVariable(tmpName),expression.codeGenType, body) { inner => + block(inner) + } + } + else block(body) + } + } + } + private def primitiveIfNot[E](v: Variable, structure: MethodStructure[E])(block: MethodStructure[E] => Unit) + (implicit context: CodeGenContext) = { + structure.ifNotStatement(structure.equalityExpression(structure.loadVariable(v.name), + structure.constantExpression(Long.box(-1)), + CodeGenType.primitiveInt)) { body => + if (distinct) { + distinctCondition(structure.loadVariable(v.name), CodeGenType.primitiveInt, body) { inner => + block(inner) + } + } + else block(body) + } + } +} -case class SimpleCount(variable: Variable, expression: CodeGenExpression, distinct: Boolean) extends AggregateExpression { +case class SimpleCount(variable: Variable, expression: CodeGenExpression, distinct: Boolean) + extends AggregateExpression(expression, distinct) { def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = { expression.init(generator) generator.assign(variable.name, CodeGenType.primitiveInt, generator.constantExpression(Long.box(0L))) - if (distinct) generator.newSet(setName(variable)) + if (distinct) { + generator.newSet(setName(variable)) + } } def update[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = { - val tmpName = context.namer.newVarName() - structure.assign(tmpName, CodeGenType.Any, expression.generateExpression(structure)) - structure.ifNonNullStatement(structure.loadVariable(tmpName)) { body => - condition(tmpName, body) { inner => - inner.incrementInteger(variable.name) - } + ifNotNull(structure) { inner => + inner.incrementInteger(variable.name) } } - private def condition[E](name: String, structure: MethodStructure[E])(block: MethodStructure[E] => Unit) = { - if (distinct) { - structure.ifNotStatement(structure.setContains(setName(variable), structure.loadVariable(name))) { inner => - inner.addToSet(setName(variable), inner.loadVariable(name)) - block(inner) - } - } else block(structure) + def distinctCondition[E](value: E, valueType: CodeGenType, structure: MethodStructure[E]) + (block: MethodStructure[E] => Unit) + (implicit context: CodeGenContext) = { + val tmpName = context.namer.newVarName() + structure.newUniqueAggregationKey(tmpName, Map(typeName(variable) -> (valueType -> value))) + structure.ifNotStatement(structure.setContains(setName(variable), structure.loadVariable(tmpName))) { inner => + inner.addToSet(setName(variable), inner.loadVariable(tmpName)) + block(inner) + } } private def setName(variable: Variable) = variable.name + "Set" + + private def typeName(variable: Variable) = variable.name + "Type" } class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpression, - groupingKey: Iterable[Variable], distinct: Boolean) extends AggregateExpression { + groupingKey: Iterable[Variable], distinct: Boolean) extends AggregateExpression(expression, distinct) { private var mapName: String = null + private var keyVar: String = null override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = { expression.init(generator) @@ -75,31 +115,24 @@ class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpres } override def update[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = { - val localVar = context.namer.newVarName() - structure.aggregationMapGet(mapName, localVar, createKey(structure)) - condition(structure, expression.generateExpression(structure)) { inner => - inner.incrementInteger(localVar) + keyVar = context.namer.newVarName() + val valueVar = context.namer.newVarName() + structure.aggregationMapGet(mapName, valueVar, createKey(structure), keyVar) + ifNotNull(structure) { inner => + inner.incrementInteger(valueVar) } - structure.aggregationMapPut(mapName, createKey(structure), - structure.loadVariable(localVar)) + structure.aggregationMapPut(mapName, createKey(structure), keyVar, structure.loadVariable(valueVar)) } - private def condition[E](structure: MethodStructure[E], value: E)(block: MethodStructure[E] => Unit)(implicit context: CodeGenContext) = { - if (distinct) { - structure.ifNonNullStatement(value) { i1 => - i1.checkDistinct(mapName, createKey(structure), value) { i2 => - block(i2) - } - } - } else { - structure.ifNonNullStatement(value) { inner => - block(inner) - } + def distinctCondition[E](value: E, valueType: CodeGenType, structure: MethodStructure[E])(block: MethodStructure[E] => Unit) + (implicit context: CodeGenContext) = { + structure.checkDistinct(mapName, createKey(structure), keyVar, value, expression.codeGenType) { inner => + block(inner) } } - private def createKey[E](body: MethodStructure[E])(implicit context: CodeGenContext): IndexedSeq[(CodeGenType, E)] = { - groupingKey.map(e => (e.codeGenType, body.loadVariable(e.name))).toIndexedSeq + private def createKey[E](body: MethodStructure[E])(implicit context: CodeGenContext) = { + groupingKey.map(e => e.name -> (e.codeGenType -> body.loadVariable(e.name))).toMap } override def continuation(instruction: Instruction): Instruction = new Instruction { @@ -110,7 +143,7 @@ class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpres override def body[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit = { generator.trace(opName) { body => - val keyArg = groupingKey.map(k => k.name -> k.codeGenType).toIndexedSeq + val keyArg = groupingKey.map(k => k.name -> k.codeGenType).toMap body.aggregationMapIterate(mapName, keyArg, variable.name) { inner => instruction.body(inner) } diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/ExpressionConverter.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/ExpressionConverter.scala index 5c7e2ca08f3cc..c2a0674963e36 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/ExpressionConverter.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/ExpressionConverter.scala @@ -86,10 +86,14 @@ object ExpressionConverter { expression match { case node@ast.Variable(name) if context.semanticTable.isNode(node) => - NodeProjection(context.getVariable(name)) + val variable = context.getProjection(name) + if (variable.codeGenType.isPrimitive) NodeProjection(variable) + else LoadVariable(variable) case rel@ast.Variable(name) if context.semanticTable.isRelationship(rel) => - RelationshipProjection(context.getVariable(name)) + val variable = context.getProjection(name) + if (variable.codeGenType.isPrimitive) RelationshipProjection(variable) + else LoadVariable(variable) case e => expressionConverter(e, createProjection) } @@ -167,6 +171,8 @@ object ExpressionConverter { case f: ast.FunctionInvocation => functionConverter(f, callback) + case ast.Variable(name) => LoadVariable(context.getProjection(name)) + case other => throw new CantCompileQueryException(s"Expression of $other not yet supported") } } diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/NodeExpression.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/NodeExpression.scala index 58d59e2fd4b4c..e9d325ba83ead 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/NodeExpression.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/NodeExpression.scala @@ -25,6 +25,7 @@ import org.neo4j.cypher.internal.frontend.v3_2.symbols import org.neo4j.cypher.internal.frontend.v3_2.symbols._ case class NodeExpression(nodeIdVar: Variable) extends CodeGenExpression { + assert(nodeIdVar.codeGenType.ct == symbols.CTNode) override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = {} @@ -38,5 +39,5 @@ case class NodeExpression(nodeIdVar: Variable) extends CodeGenExpression { override def nullable(implicit context: CodeGenContext) = nodeIdVar.nullable - override def codeGenType(implicit context: CodeGenContext) = if (nullable) CodeGenType(CTNode, ReferenceType) else CodeGenType.primitiveNode -} + override def codeGenType(implicit context: CodeGenContext) = CodeGenType(CTNode, ReferenceType) +} \ No newline at end of file diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/RelationshipExpression.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/RelationshipExpression.scala index 36b7a27f5dfa2..06717416e1e66 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/RelationshipExpression.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/ir/expressions/RelationshipExpression.scala @@ -39,7 +39,5 @@ case class RelationshipExpression(relId: Variable) extends CodeGenExpression { override def nullable(implicit context: CodeGenContext) = relId.nullable - override def codeGenType(implicit context: CodeGenContext) = - if (nullable) CodeGenType(CTRelationship, ReferenceType) - else CodeGenType.primitiveRel + override def codeGenType(implicit context: CodeGenContext) = CodeGenType(CTRelationship, ReferenceType) } diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/spi/MethodStructure.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/spi/MethodStructure.scala index 2a17ba06ab891..1eb565e64337d 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/spi/MethodStructure.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/codegen/spi/MethodStructure.scala @@ -67,11 +67,12 @@ trait MethodStructure[E] { def newSet(name: String) def setContains(name: String, value: E): E def addToSet(name: String, value: E): Unit + def newUniqueAggregationKey(varName: String, structure: Map[String, (CodeGenType,E)]): Unit def newAggregationMap(name: String, keyTypes: IndexedSeq[CodeGenType], distinct: Boolean) - def aggregationMapGet(name: String, varName: String, key: IndexedSeq[(CodeGenType,E)]) - def aggregationMapPut(name: String, key: IndexedSeq[(CodeGenType,E)], value: E): Unit - def aggregationMapIterate(name: String, key: IndexedSeq[(String,CodeGenType)], valueVar: String)(block: MethodStructure[E] => Unit): Unit - def checkDistinct(name: String, key: IndexedSeq[(CodeGenType, E)], value: E)(block: MethodStructure[E] => Unit) + def aggregationMapGet(name: String, varName: String, key: Map[String,(CodeGenType,E)], keyVar: String) + def aggregationMapPut(name: String, key: Map[String,(CodeGenType,E)], keyVar: String, value: E): Unit + def aggregationMapIterate(name: String, key: Map[String,CodeGenType], valueVar: String)(block: MethodStructure[E] => Unit): Unit + def checkDistinct(name: String, key: Map[String,(CodeGenType, E)], keyVar: String, value: E, valueType: CodeGenType)(block: MethodStructure[E] => Unit) def castToCollection(value: E): E diff --git a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/commands/predicates/Equivalent.scala b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/commands/predicates/Equivalent.scala index 3a929f44f5045..dd00bc6a3dbda 100644 --- a/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/commands/predicates/Equivalent.scala +++ b/community/cypher/cypher-compiler-3.2/src/main/scala/org/neo4j/cypher/internal/compiler/v3_2/commands/predicates/Equivalent.scala @@ -81,6 +81,8 @@ class Equivalent(protected val eagerizedValue: Any, val originalValue: Any) exte length * (31 * hashCode(n.head) + hashCode(n(length / 2)) * 31 + hashCode(n.last)) else EMPTY_LIST + case m: Map[_,_] => + m.hashCode() case x => x.hashCode() } } diff --git a/community/cypher/cypher-compiler-3.2/src/test/scala/org/neo4j/cypher/internal/compiler/v3_2/commands/predicates/EquivalentTest.scala b/community/cypher/cypher-compiler-3.2/src/test/scala/org/neo4j/cypher/internal/compiler/v3_2/commands/predicates/EquivalentTest.scala index 44873a907b01b..5a4127097c4d1 100644 --- a/community/cypher/cypher-compiler-3.2/src/test/scala/org/neo4j/cypher/internal/compiler/v3_2/commands/predicates/EquivalentTest.scala +++ b/community/cypher/cypher-compiler-3.2/src/test/scala/org/neo4j/cypher/internal/compiler/v3_2/commands/predicates/EquivalentTest.scala @@ -25,9 +25,11 @@ import java.util.Collections.singletonMap import org.neo4j.cypher.internal.compiler.v3_2.{CRS, GeographicPoint} import org.neo4j.cypher.internal.frontend.v3_2.test_helpers.CypherFunSuite -import org.neo4j.graphdb.spatial.{Coordinate, Point, CRS => JavaCRS} +import org.neo4j.graphdb.spatial.{CRS => JavaCRS, Coordinate, Point} class EquivalentTest extends CypherFunSuite { + shouldNotMatch(23.toByte, 23.5) + shouldMatch(1.0, 1L) shouldMatch(1.0, 1) shouldMatch(1.0, 1.0) @@ -58,7 +60,6 @@ class EquivalentTest extends CypherFunSuite { shouldMatch(43.toByte, 43.toLong) shouldMatch(23.toByte, 23.0d) shouldMatch(23.toByte, 23.0f) - shouldNotMatch(23.toByte, 23.5) shouldNotMatch(23.toByte, 23.5f) shouldMatch(11.toShort, 11.toByte) shouldMatch(42.toShort, 42.toShort) diff --git a/community/cypher/cypher/src/main/java/org/neo4j/cypher/internal/codegen/CompiledConversionUtils.java b/community/cypher/cypher/src/main/java/org/neo4j/cypher/internal/codegen/CompiledConversionUtils.java index 57f573eed4bbc..5bfec60274a32 100644 --- a/community/cypher/cypher/src/main/java/org/neo4j/cypher/internal/codegen/CompiledConversionUtils.java +++ b/community/cypher/cypher/src/main/java/org/neo4j/cypher/internal/codegen/CompiledConversionUtils.java @@ -23,14 +23,11 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.Iterator; -import java.util.List; import org.neo4j.cypher.internal.frontend.v3_2.CypherTypeException; import org.neo4j.cypher.internal.frontend.v3_2.IncomparableValuesException; import org.neo4j.graphdb.Node; import org.neo4j.graphdb.Relationship; -import org.neo4j.helpers.MathUtil; // Class with static methods used by compiled execution plans public abstract class CompiledConversionUtils @@ -117,100 +114,7 @@ public static Boolean equals( Object lhs, Object rhs ) throw new IncomparableValuesException( lhs.getClass().getSimpleName(), rhs.getClass().getSimpleName() ); } - //if floats compare float values if integer types, - //compare long values - if ( lhs instanceof Number && rhs instanceof Number ) - { - if ( (lhs instanceof Double || lhs instanceof Float) - && (rhs instanceof Double || rhs instanceof Float) ) - { - double left = ((Number) lhs).doubleValue(); - double right = ((Number) rhs).doubleValue(); - return left == right; - } - else if ( (lhs instanceof Double || lhs instanceof Float) ) - { - double left = ((Number) lhs).doubleValue(); - long right = ((Number) rhs).longValue(); - return MathUtil.numbersEqual( left, right ); - } - else if ( (rhs instanceof Double || rhs instanceof Float) ) - { - long left = ((Number) lhs).longValue(); - double right = ((Number) rhs).doubleValue(); - return MathUtil.numbersEqual( right, left ); - } - - //evertyhing else is long from cyphers point-of-view - long left = ((Number) lhs).longValue(); - long right = ((Number) rhs).longValue(); - return left == right; - } - else if (lhs.getClass().isArray() && rhs.getClass().isArray() ) - { - int length = Array.getLength( lhs ); - if ( length != Array.getLength( rhs ) ) - { - return false; - } - for ( int i = 0; i < length; i++ ) - { - if (!equals( Array.get( lhs, i ), Array.get(rhs, i) )) - { - return false; - } - } - return true; - } - else if (lhs.getClass().isArray() && rhs instanceof List ) - { - return compareArrayAndList( lhs, (List) rhs ); - } - else if (lhs instanceof List && rhs.getClass().isArray()) - { - return compareArrayAndList( rhs, (List) lhs ); - } - else if (lhs instanceof List && rhs instanceof List) - { - List lhsList = (List) lhs; - List rhsList = (List) rhs; - if (lhsList.size() != rhsList.size()) - { - return false; - } - Iterator lhsIterator = lhsList.iterator(); - Iterator rhsIterator = rhsList.iterator(); - while (lhsIterator.hasNext()) - { - if (!equals( lhsIterator.next(), rhsIterator.next() )) - { - return false; - } - } - return true; - } - - //for everything else call equals - return lhs.equals( rhs ); - } - - private static Boolean compareArrayAndList(Object array, List list) - { - int length = Array.getLength( array ); - if ( length != list.size() ) - { - return false; - } - - int i = 0; - for ( Object o : list ) - { - if (!equals( o, Array.get(array, i++) )) - { - return false; - } - } - return true; + return CompiledEquivalenceUtils.equals( lhs, rhs ); } public static Boolean or( Object lhs, Object rhs ) diff --git a/community/cypher/cypher/src/main/java/org/neo4j/cypher/internal/codegen/CompiledEquivalenceUtils.java b/community/cypher/cypher/src/main/java/org/neo4j/cypher/internal/codegen/CompiledEquivalenceUtils.java index 3d6d627e1dd9f..4dfda072c11f3 100644 --- a/community/cypher/cypher/src/main/java/org/neo4j/cypher/internal/codegen/CompiledEquivalenceUtils.java +++ b/community/cypher/cypher/src/main/java/org/neo4j/cypher/internal/codegen/CompiledEquivalenceUtils.java @@ -1,8 +1,563 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ package org.neo4j.cypher.internal.codegen; +import java.lang.reflect.Array; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.neo4j.helpers.MathUtil; + /** - * Created by pontusmelke on 2016-11-18. + * Helper class for dealing with equivalence an hash code in compiled code. + * + * Note this class contains a lot of duplicated code in order to minimize boxing. */ -public class CompiledEquivalenceUtils +public final class CompiledEquivalenceUtils { + /** + * Do not instantiate this class + */ + private CompiledEquivalenceUtils() + { + throw new UnsupportedOperationException( ); + } + + /** + * Checks if two objects are equal according to Cypher semantics + * @param lhs the left-hand side to check + * @param rhs the right-hand sid to check + * @return true if the two objects are equal otherwise false + */ + @SuppressWarnings( "unchecked" ) + public static boolean equals( Object lhs, Object rhs ) + { + if ( lhs == rhs ) + { + return true; + } + else if ( lhs == null || rhs == null ) + { + return false; + } + //if floats compare float values if integer types, + //compare long values + else if ( lhs instanceof Number && rhs instanceof Number ) + { + if ( lhs instanceof Double && rhs instanceof Float ) + { + return mixedFloatEquality( (Float) rhs, (Double) lhs ); + } + else if ( lhs instanceof Float && rhs instanceof Double ) + { + return mixedFloatEquality( (Float) lhs, (Double) rhs ); + } + else if ( (lhs instanceof Double || lhs instanceof Float) + && (rhs instanceof Double || rhs instanceof Float) ) + { + double left = ((Number) lhs).doubleValue(); + double right = ((Number) rhs).doubleValue(); + return left == right; + } + else if ( (lhs instanceof Double || lhs instanceof Float) ) + { + double left = ((Number) lhs).doubleValue(); + long right = ((Number) rhs).longValue(); + return MathUtil.numbersEqual( left, right ); + } + else if ( (rhs instanceof Double || rhs instanceof Float) ) + { + long left = ((Number) lhs).longValue(); + double right = ((Number) rhs).doubleValue(); + return MathUtil.numbersEqual( right, left ); + } + + //everything else is a long from cyphers point-of-view + long left = ((Number) lhs).longValue(); + long right = ((Number) rhs).longValue(); + return left == right; + } + else if ( lhs instanceof Character && rhs instanceof String ) + { + return lhs.toString().equals( rhs ); + } + else if ( lhs instanceof String && rhs instanceof Character ) + { + return lhs.equals( rhs.toString() ); + } + else if ( lhs.getClass().isArray() && rhs.getClass().isArray() ) + { + int length = Array.getLength( lhs ); + if ( length != Array.getLength( rhs ) ) + { + return false; + } + for ( int i = 0; i < length; i++ ) + { + if ( !equals( Array.get( lhs, i ), Array.get( rhs, i ) ) ) + { + return false; + } + } + return true; + } + else if ( lhs.getClass().isArray() && rhs instanceof List ) + { + return compareArrayAndList( lhs, (List) rhs ); + } + else if ( lhs instanceof List && rhs.getClass().isArray() ) + { + return compareArrayAndList( rhs, (List) lhs ); + } + else if ( lhs instanceof List && rhs instanceof List ) + { + List lhsList = (List) lhs; + List rhsList = (List) rhs; + if ( lhsList.size() != rhsList.size() ) + { + return false; + } + Iterator lhsIterator = lhsList.iterator(); + Iterator rhsIterator = rhsList.iterator(); + while ( lhsIterator.hasNext() ) + { + if ( !equals( lhsIterator.next(), rhsIterator.next() ) ) + { + return false; + } + } + return true; + } + else if (lhs instanceof Map && rhs instanceof Map) + { + Map rMap = (Map) rhs; + Map lMap = (Map) lhs; + if (rMap.size() != lMap.size()) + { + return false; + } + for ( Map.Entry e : rMap.entrySet() ) + { + String key = e.getKey(); + Object value = e.getValue(); + if ( value == null ) + { + if ( !(lMap.get( key ) == null && lMap.containsKey( key )) ) + { + return false; + } + } + else + { + if ( !equals( value, lMap.get( key ) ) ) + { + return false; + } + } + } + return true; + } + + //for everything else call equals + return lhs.equals( rhs ); + } + + /** + * Calculates hash code of a given object + * @param element the element to calculate hash code for + * @return the hash code of the given object + */ + @SuppressWarnings( "unchecked" ) + public static int hashCode( Object element ) + { + if ( element == null ) + { + return 0; + } + else if ( element instanceof Number ) + { + return hashCode( ((Number) element).longValue() ); + } + else if ( element instanceof Character ) + { + return hashCode( (char) element ); + } + else if ( element instanceof Boolean ) + { + return hashCode( (boolean) element ); + } + else if ( element instanceof List ) + { + return hashCode( (List) element ); + } + else if ( element instanceof Map ) + { + return hashCode( ((Map) element) ); + } + else if ( element instanceof Object[] ) + { + return hashCode( (Object[]) element ); + } + else if ( element instanceof byte[] ) + { + return hashCode( (byte[]) element ); + } + else if ( element instanceof short[] ) + { + return hashCode( (short[]) element ); + } + else if ( element instanceof int[] ) + { + return hashCode( (int[]) element ); + } + else if ( element instanceof long[] ) + { + return hashCode( (long[]) element ); + } + else if ( element instanceof char[] ) + { + return hashCode( (char[]) element ); + } + else if ( element instanceof float[] ) + { + return hashCode( (float[]) element ); + } + else if ( element instanceof double[] ) + { + return hashCode( (double[]) element ); + } + else if ( element instanceof boolean[] ) + { + return hashCode( (boolean[]) element ); + } + else + { + return element.hashCode(); + } + } + + /** + * Calculate hash code of a map + * @param map the element to calculate hash code for + * @return the hash code of the given map + */ + public static int hashCode( Map map ) + { + int h = 0; + for ( Map.Entry next : map.entrySet() ) + { + String k = next.getKey(); + Object v = next.getValue(); + h += (k == null ? 0 : k.hashCode()) ^ (v == null ? 0 : hashCode( v )); + } + return h; + } + + /** + * Calculate hash code of a long value + * @param value the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( long value ) + { + return Long.hashCode( value ); + } + + /** + * Calculate hash code of a boolean value + * @param value the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( boolean value ) + { + return Boolean.hashCode( value ); + } + + /** + * Calculate hash code of a char value + * @param value the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( char value ) + { + return Character.hashCode( value ); + } + + /** + * Calculate hash code of a char[] value + * @param array the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( char[] array ) + { + int len = array.length; + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( array[0] ); + case 2: + return 31 * hashCode( array[0] ) + hashCode( array[1] ); + case 3: + return (31 * hashCode( array[0] ) + hashCode( array[1] )) * 31 + hashCode( array[2] ); + default: + return len * (31 * hashCode( array[0] ) + hashCode( array[len / 2] ) * 31 + hashCode( array[len - 1] )); + } + } + + /** + * Calculate hash code of a list value + * @param list the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( List list ) + { + int len = list.size(); + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( list.get( 0 ) ); + case 2: + return 31 * hashCode( list.get( 0 ) ) + hashCode( list.get( 1 ) ); + case 3: + return (31 * hashCode( list.get( 0 ) ) + hashCode( list.get( 1 ) )) * 31 + hashCode( list.get( 2 ) ); + default: + return len * (31 * hashCode( list.get( 0 ) ) + hashCode( list.get( len / 2 ) ) * 31 + + hashCode( list.get( len - 1 ) )); + } + } + + /** + * Calculate hash code of a Object[] value + * @param array the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( Object[] array ) + { + int len = array.length; + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( array[0] ); + case 2: + return 31 * hashCode( array[0] ) + hashCode( array[1] ); + case 3: + return (31 * hashCode( array[0] ) + hashCode( array[1] )) * 31 + hashCode( array[2] ); + default: + return len * (31 * hashCode( array[0] ) + hashCode( array[len / 2] ) * 31 + hashCode( array[len - 1] )); + } + } + + /** + * Calculate hash code of a byte[] value + * @param array the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( byte[] array ) + { + int len = array.length; + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( array[0] ); + case 2: + return 31 * hashCode( array[0] ) + hashCode( array[1] ); + case 3: + return (31 * hashCode( array[0] ) + hashCode( array[1] )) * 31 + hashCode( array[2] ); + default: + return len * (31 * hashCode( array[0] ) + hashCode( array[len / 2] ) * 31 + hashCode( array[len - 1] )); + } + } + + /** + * Calculate hash code of a short[] value + * @param array the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( short[] array ) + { + int len = array.length; + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( array[0] ); + case 2: + return 31 * hashCode( array[0] ) + hashCode( array[1] ); + case 3: + return (31 * hashCode( array[0] ) + hashCode( array[1] )) * 31 + hashCode( array[2] ); + default: + return len * (31 * hashCode( array[0] ) + hashCode( array[len / 2] ) * 31 + hashCode( array[len - 1] )); + } + } + + /** + * Calculate hash code of a int[] value + * @param array the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( int[] array ) + { + int len = array.length; + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( array[0] ); + case 2: + return 31 * hashCode( array[0] ) + hashCode( array[1] ); + case 3: + return (31 * hashCode( array[0] ) + hashCode( array[1] )) * 31 + hashCode( array[2] ); + default: + return len * (31 * hashCode( array[0] ) + hashCode( array[len / 2] ) * 31 + hashCode( array[len - 1] )); + } + } + + /** + * Calculate hash code of a long[] value + * @param array the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( long[] array ) + { + int len = array.length; + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( array[0] ); + case 2: + return 31 * hashCode( array[0] ) + hashCode( array[1] ); + case 3: + return (31 * hashCode( array[0] ) + hashCode( array[1] )) * 31 + hashCode( array[2] ); + default: + return len * (31 * hashCode( array[0] ) + hashCode( array[len / 2] ) * 31 + hashCode( array[len - 1] )); + } + } + + /** + * Calculate hash code of a float[] value + * @param array the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( float[] array ) + { + int len = array.length; + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( (long) array[0] ); + case 2: + return 31 * hashCode( (long) array[0] ) + hashCode( (long) array[1] ); + case 3: + return (31 * hashCode( (long) array[0] ) + hashCode( (long) array[1] )) * 31 + hashCode( (long) array[2] ); + default: + return len * (31 * hashCode( (long) array[0] ) + hashCode( (long) array[len / 2] ) * 31 + + hashCode( (long) array[len - 1] )); + } + } + + /** + * Calculate hash code of a double[] value + * @param array the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( double[] array ) + { + int len = array.length; + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( (long) array[0] ); + case 2: + return 31 * hashCode( (long) array[0] ) + hashCode( (long) array[1] ); + case 3: + return (31 * hashCode( (long) array[0] ) + hashCode( (long) array[1] )) * 31 + hashCode( (long) array[2] ); + default: + return len * (31 * hashCode( (long) array[0] ) + hashCode( (long) array[len / 2] ) * 31 + + hashCode( (long) array[len - 1] )); + } + } + + /** + * Calculate hash code of a boolean[] value + * @param array the value to compute hash code for + * @return the hash code of the given value + */ + public static int hashCode( boolean[] array ) + { + int len = array.length; + switch ( len ) + { + case 0: + return 42; + case 1: + return hashCode( array[0] ); + case 2: + return 31 * hashCode( array[0] ) + hashCode( array[1] ); + case 3: + return (31 * hashCode( array[0] ) + hashCode( array[1] )) * 31 + hashCode( array[2] ); + default: + return len * (31 * hashCode( array[0] ) + hashCode( array[len / 2] ) * 31 + hashCode( array[len - 1] )); + } + } + + private static Boolean compareArrayAndList( Object array, List list ) + { + int length = Array.getLength( array ); + if ( length != list.size() ) + { + return false; + } + + int i = 0; + for ( Object o : list ) + { + if ( !equals( o, Array.get( array, i++ ) ) ) + { + return false; + } + } + return true; + } + + private static boolean mixedFloatEquality( Float a, Double b ) + { + return a.doubleValue() == b || ( + (long) Math.rint( a.doubleValue() ) == b.longValue() && + (long) Math.rint( b ) == a.longValue()); + } } + diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/Methods.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/Methods.scala index e48c33e1d9869..9c4419e57c0c4 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/Methods.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/codegen/Methods.scala @@ -66,8 +66,6 @@ object Methods { val mathCastToLong = method[CompiledMathHelper, Long]("transformToLong", typeRef[Object]) val mapGet = method[util.Map[String, Object], Object]("get", typeRef[Object]) val mapContains = method[util.Map[String, Object], Boolean]("containsKey", typeRef[Object]) - val setContains = method[util.Set[Object], Boolean]("contains", typeRef[Object]) - val setAdd = method[util.Set[Object], Boolean]("add", typeRef[Object]) val labelGetForName = method[ReadOperations, Int]("labelGetForName", typeRef[String]) val propertyKeyGetForName = method[ReadOperations, Int]("propertyKeyGetForName", typeRef[String]) val coerceToPredicate = method[CompiledConversionUtils, Boolean]("coerceToPredicate", typeRef[Object]) diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/AuxGenerator.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/AuxGenerator.scala index da0a8c05e6c53..1c205b2261ab1 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/AuxGenerator.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/AuxGenerator.scala @@ -19,7 +19,11 @@ */ package org.neo4j.cypher.internal.spi.v3_2.codegen -import org.neo4j.codegen.{CodeGenerator, TypeReference} + +import org.neo4j.codegen.FieldReference.field +import org.neo4j.codegen.Parameter.param +import org.neo4j.codegen._ +import org.neo4j.cypher.internal.codegen.CompiledEquivalenceUtils import org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.expressions.CodeGenType import org.neo4j.cypher.internal.compiler.v3_2.helpers._ @@ -27,14 +31,15 @@ import scala.collection.mutable class AuxGenerator(val packageName: String, val generator: CodeGenerator) { - import GeneratedQueryStructure.lowerType + import GeneratedQueryStructure.{lowerType, method, typeRef} + private val types: scala.collection.mutable.Map[Map[String, CodeGenType], TypeReference] = mutable.Map.empty private var nameId = 0 def typeReference(structure: Map[String, CodeGenType]): TypeReference = { - types.getOrElseUpdate(structure, using(generator.generateClass(packageName, newName())) { clazz => + types.getOrElseUpdate(structure, using(generator.generateClass(packageName, newValueTypeName())) { clazz => structure.foreach { case (fieldName, fieldType: CodeGenType) => clazz.field(lowerType(fieldType), fieldName) } @@ -42,9 +47,43 @@ class AuxGenerator(val packageName: String, val generator: CodeGenerator) { }) } - private def newName() = { + def hashKey(structure: Map[String, CodeGenType]): TypeReference = { + types.getOrElseUpdate(structure, using(generator.generateClass(packageName, newKeyTypeName())) { clazz => + structure.foreach { + case (fieldName, fieldType) => clazz.field(lowerType(fieldType), fieldName) + } + clazz.field(classOf[Int], "hashCode") + clazz.generate(MethodTemplate.method(classOf[Int], "hashCode") + .returns(ExpressionTemplate.get(ExpressionTemplate.self(clazz.handle()), classOf[Int], "hashCode")).build()) + + using(clazz.generateMethod(typeRef[Boolean], "equals", param(typeRef[Object], "other"))) {body => + val otherName = s"other$nameId" + body.assign(body.declare(clazz.handle(), otherName), Expression.cast(clazz.handle(), body.load("other"))) + + body.returns(structure.map { + case (fieldName, fieldType) => + val fieldReference = field(clazz.handle(), lowerType(fieldType), fieldName) + Expression.invoke(method[CompiledEquivalenceUtils, Boolean]("equals", typeRef[Object], typeRef[Object]), + + Expression.box( + Expression.get(body.self(), fieldReference)), + Expression.box( + Expression.get(body.load(otherName), fieldReference))) + }.reduceLeft(Expression.and)) + } + clazz.handle() + }) + } + + private def newValueTypeName() = { val name = "ValueType" + nameId nameId += 1 name } + + private def newKeyTypeName() = { + val name = "KeyType" + nameId + nameId += 1 + name + } } diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala index a0a4fb0095e09..d5a6813b8aa6a 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/GeneratedMethodStructure.scala @@ -25,7 +25,7 @@ import org.neo4j.codegen.Expression.{not, or, _} import org.neo4j.codegen.MethodReference.methodReference import org.neo4j.codegen._ import org.neo4j.collection.primitive.hopscotch.LongKeyIntValueTable -import org.neo4j.collection.primitive.{PrimitiveLongIntMap, PrimitiveLongIterator, PrimitiveLongObjectMap} +import org.neo4j.collection.primitive._ import org.neo4j.cypher.internal.codegen.CompiledConversionUtils.CompositeKey import org.neo4j.cypher.internal.codegen._ import org.neo4j.cypher.internal.compiler.v3_2.ast.convert.commands.DirectionConverter.toGraphDb @@ -459,6 +459,28 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux generator.expression(pop(invoke(loadVariable(name), Methods.setAdd, value))) } + override def newUniqueAggregationKey(varName: String, structure: Map[String, (CodeGenType, Expression)]) = { + val typ = aux.hashKey(structure.map { + case (n, (t,_)) => n -> t + }) + val local = generator.declare(typ, varName) + locals += varName -> local + generator.assign(local, createNewInstance(typ)) + structure.foreach { + case (n, (t, e)) => + val field = FieldReference.field(typ, lowerType(t), n) + generator.put(generator.load(varName), field, e) + } + if (structure.size == 1) { + generator.put(generator.load(varName), FieldReference.field(typ, typeRef[Int], "hashCode"), + invoke(method[CompiledEquivalenceUtils, Int]("hashCode", typeRef[Object]), + box(structure.values.head._2, structure.values.head._1))) + } else { + generator.put(generator.load(varName), FieldReference.field(typ, typeRef[Int], "hashCode"), + invoke(method[CompiledEquivalenceUtils, Int]("hashCode", typeRef[Array[Object]]), + newArray(typeRef[Object], structure.values.map(_._2).toSeq: _*))) + } + } override def newAggregationMap(name: String, keyTypes: IndexedSeq[CodeGenType], distinct: Boolean) = { if (keyTypes.size == 1) { @@ -471,6 +493,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux block.expression( invoke(block.load(name), method[PrimitiveLongLongMap, Unit]("close")))) if (distinct) { + generator.assign(generator.declare(typeRef[PrimitiveLongObjectMap[util.HashSet[Object]]], name + "Seen"), invoke(method[Primitive, PrimitiveLongObjectMap[util.HashSet[Object]]]("longObjectMap") )) } @@ -487,32 +510,33 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux } } - override def aggregationMapGet(mapName: String, varName: String, key: IndexedSeq[(CodeGenType, Expression)]) = { - val local = generator.declare(typeRef[Long], varName) - locals += varName -> local + override def aggregationMapGet(mapName: String, valueVarName: String, key: Map[String,(CodeGenType, Expression)], keyVar: String) = { + val local = generator.declare(typeRef[Long], valueVarName) + locals += valueVarName -> local if (key.size == 1) { - val (keyType, keyExpression) = key.head + val (_, (keyType, keyExpression)) = key.head keyType match { case CodeGenType(_, IntType) => generator.assign(local, invoke(generator.load(mapName), method[PrimitiveLongLongMap, Long]("get", typeRef[ Long]), keyExpression)) - using(generator.ifStatement(equal(generator.load(varName), constant(Long.box(-1L))))) { body => + using(generator.ifStatement(equal(generator.load(valueVarName), constant(Long.box(-1L))))) { body => body.assign(local, constant(Long.box(0L))) } case _ => + newUniqueAggregationKey(keyVar, key) generator.assign(local,unbox( cast(typeRef[java.lang.Long], invoke(generator.load(mapName), method[util.HashMap[Object, java.lang.Long], Object]("getOrDefault", typeRef[Object], typeRef[Object]), - keyExpression, box(constant(Long.box(0L)), CodeGenType.primitiveInt))), CodeGenType(symbols.CTInteger, ReferenceType))) + generator.load(keyVar), box(constant(Long.box(0L)), CodeGenType.primitiveInt))), CodeGenType(symbols.CTInteger, ReferenceType))) } } else { ??? } } - override def checkDistinct(name: String, key: IndexedSeq[(CodeGenType, Expression)], value: Expression)(block: MethodStructure[Expression] => Unit) = { + override def checkDistinct(name: String, key: Map[String, (CodeGenType, Expression)], keyVar: String, value: Expression, valueType: CodeGenType)(block: MethodStructure[Expression] => Unit) = { if (key.size == 1) { - val (keyType, keyExpression) = key.head + val (_, (keyType, keyExpression)) = key.head keyType match { case CodeGenType(_, IntType) => val tmp = context.namer.newVarName() @@ -536,51 +560,57 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux generator.expression(pop(invoke(generator.load(tmp), method[util.HashSet[Object], Boolean]("add", typeRef[Object]), value))) case _ => - val tmp = context.namer.newVarName() - val localVariable = generator.declare(typeRef[util.HashSet[Object]], tmp) + val setVar = context.namer.newVarName() + val localVariable = generator.declare(typeRef[util.HashSet[Object]], setVar) + if (!locals.contains(keyVar)) newUniqueAggregationKey(keyVar, key) + generator.assign(localVariable, cast(typeRef[util.HashSet[Object]], invoke(generator.load(name + "Seen"), - method[util.HashMap[Object, util.HashSet[Object]], Object]("get", typeRef[Object]), keyExpression))) - using(generator.ifNullStatement(generator.load(tmp))) { inner => + method[util.HashMap[Object, util.HashSet[Object]], Object]("get", typeRef[Object]), generator.load(keyVar)))) + using(generator.ifNullStatement(generator.load(setVar))) { inner => + inner.assign(localVariable, createNewInstance(typeRef[util.HashSet[Object]])) inner.expression(pop(invoke(generator.load(name + "Seen"), method[util.HashMap[Object, util.HashSet[Object]], Object]("put", typeRef[Object], typeRef[Object]), - keyExpression, inner.load(tmp)))) + generator.load(keyVar), inner.load(setVar)))) } + val valueVar = context.namer.newVarName() + newUniqueAggregationKey(valueVar, Map(context.namer.newVarName() -> (valueType -> value) )) - using(generator.ifNotStatement(invoke(generator.load(tmp), - method[util.HashSet[Object], Boolean]("contains", typeRef[Object]), value))) + using(generator.ifNotStatement(invoke(generator.load(setVar), + method[util.HashSet[Object], Boolean]("contains", typeRef[Object]), generator.load(valueVar)))) { inner => block(copy(generator = inner)) + inner.expression(pop(invoke(generator.load(setVar), + method[util.HashSet[Object], Boolean]("add", typeRef[Object]), generator.load(valueVar)))) } - generator.expression(pop(invoke(generator.load(tmp), - method[util.HashSet[Object], Boolean]("add", typeRef[Object]), value))) } } else ??? } - override def aggregationMapPut(name: String, key: IndexedSeq[(CodeGenType, Expression)], value: Expression) = { + override def aggregationMapPut(name: String, key: Map[String, (CodeGenType, Expression)], keyVar: String, value: Expression) = { if (key.size == 1) { - val (keyType, keyExpression) = key.head + val (_,(keyType, keyExpression)) = key.head keyType match { case CodeGenType(_, IntType) => generator.expression(pop(invoke(generator.load(name), method[PrimitiveLongLongMap, Long]("put", typeRef[Long], typeRef[Long]), keyExpression, value))) case _ => + if (!locals.contains(keyVar)) newUniqueAggregationKey(keyVar, key) generator.expression(pop(invoke(generator.load(name), method[util.HashMap[Object, java.lang.Long], Object]("put", typeRef[Object], typeRef[Object]), - keyExpression, box(value, CodeGenType.primitiveInt)))) + generator.load(keyVar), box(value, CodeGenType.primitiveInt)))) } } else { ??? } } - override def aggregationMapIterate(name: String, key: IndexedSeq[(String, CodeGenType)], valueVar: String) + override def aggregationMapIterate(name: String, key: Map[String, CodeGenType], valueVar: String) (block: (MethodStructure[Expression]) => Unit) = { if (key.size == 1) { val (keyName, keyType) = key.head @@ -603,6 +633,7 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux val localName = context.namer.newVarName() val next = context.namer.newVarName() val variable = generator.declare(typeRef[java.util.Iterator[java.util.Map.Entry[Object,java.lang.Long]]], localName) + val keyStruct = aux.hashKey(key) generator.assign(variable, invoke(invoke(generator.load(name), method[util.HashMap[Object, java.lang.Long], @@ -614,7 +645,10 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux cast(typeRef[util.Map.Entry[Object, java.lang.Long]], invoke(body.load(localName), method[java.util.Iterator[java.util.Map.Entry[Object,java.lang.Long]], Object]("next")) )) body.assign(body.declare(lowerType(keyType), keyName), - invoke(body.load(next), method[java.util.Map.Entry[Object,java.lang.Long], Object]("getKey"))) + Expression.get( + cast(keyStruct, + invoke(body.load(next), method[java.util.Map.Entry[Object,java.lang.Long], Object]("getKey"))), + FieldReference.field(keyStruct, lowerType(keyType), keyName))) body.assign(body.declare(typeRef[Long], valueVar), unbox(cast(typeRef[java.lang.Long], @@ -994,4 +1028,4 @@ class GeneratedMethodStructure(val fields: Fields, val generator: CodeBlock, aux private def field(structure: Map[String, CodeGenType], fieldType: CodeGenType, fieldName: String) = FieldReference.field(aux.typeReference(structure), lowerType(fieldType), fieldName) -} \ No newline at end of file +} diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Methods.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Methods.scala index cc8f28fcdce0b..76458220b06fa 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Methods.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_2/codegen/Methods.scala @@ -66,6 +66,8 @@ object Methods { val mathCastToLong = method[CompiledMathHelper, Long]("transformToLong", typeRef[Object]) val mapGet = method[util.Map[String, Object], Object]("get", typeRef[Object]) val mapContains = method[util.Map[String, Object], Boolean]("containsKey", typeRef[Object]) + val setContains = method[util.Set[Object], Boolean]("contains", typeRef[Object]) + val setAdd = method[util.Set[Object], Boolean]("add", typeRef[Object]) val labelGetForName = method[ReadOperations, Int]("labelGetForName", typeRef[String]) val propertyKeyGetForName = method[ReadOperations, Int]("propertyKeyGetForName", typeRef[String]) val coerceToPredicate = method[CompiledConversionUtils, Boolean]("coerceToPredicate", typeRef[Object]) diff --git a/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/codegen/CompiledEquivalenceUtilsTest.scala b/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/codegen/CompiledEquivalenceUtilsTest.scala index 64441736b5040..4763098db1a6b 100644 --- a/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/codegen/CompiledEquivalenceUtilsTest.scala +++ b/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/codegen/CompiledEquivalenceUtilsTest.scala @@ -1,8 +1,190 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ package org.neo4j.cypher.internal.codegen +import java.util +import java.util.Arrays._ +import java.util.Collections._ + +import org.neo4j.cypher.internal.frontend.v3_2.test_helpers.CypherFunSuite +import org.neo4j.graphdb.spatial.{CRS, Coordinate, Point} + +import scala.collection.JavaConverters._ + /** * Created by pontusmelke on 2016-11-19. */ -class CompiledEquivalenceUtilsTest { +class CompiledEquivalenceUtilsTest extends CypherFunSuite { + shouldMatch(1.0, 1L) + shouldMatch(1.0, 1) + shouldMatch(1.0, 1.0) + shouldMatch(0.9, 0.9) + shouldMatch(Math.PI, Math.PI) + shouldMatch(1.1, 1.1) + shouldMatch(0, 0) + shouldMatch(Integer.MAX_VALUE.toDouble, Integer.MAX_VALUE) + shouldMatch(Long.MaxValue.toDouble, Long.MaxValue) + shouldMatch(Int.MaxValue.toDouble + 1, Int.MaxValue.toLong + 1) + shouldMatch(Double.PositiveInfinity, Double.PositiveInfinity) + shouldMatch(Double.NegativeInfinity, Double.NegativeInfinity) + shouldMatch(true, true) + shouldMatch(false, false) + shouldNotMatch(true, false) + shouldNotMatch(false, true) + shouldNotMatch(true, 0) + shouldNotMatch(false, 0) + shouldNotMatch(true, 1) + shouldNotMatch(false, 1) + shouldNotMatch(false, "false") + shouldNotMatch(true, "true") + shouldMatch(42.toByte, 42.toByte) + shouldMatch(42.toByte, 42.toShort) + shouldNotMatch(42.toByte, 42 + 256) + shouldMatch(43.toByte, 43) + shouldMatch(43.toByte, 43.toLong) + shouldMatch(23.toByte, 23.0d) + shouldMatch(23.toByte, 23.0f) + shouldNotMatch(23.toByte, 23.5) + shouldNotMatch(23.toByte, 23.5f) + shouldMatch(11.toShort, 11.toByte) + shouldMatch(42.toShort, 42.toShort) + shouldNotMatch(42.toShort, 42 + 65536) + shouldMatch(43.toShort, 43) + shouldMatch(43.toShort, 43.toLong) + shouldMatch(23.toShort, 23.0f) + shouldMatch(23.toShort, 23.0d) + shouldNotMatch(23.toShort, 23.5) + shouldNotMatch(23.toShort, 23.5f) + shouldMatch(11, 11.toByte) + shouldMatch(42, 42.toShort) + shouldNotMatch(42, 42 + 4294967296L) + shouldMatch(43, 43) + shouldMatch(Integer.MAX_VALUE, Integer.MAX_VALUE) + shouldMatch(43, 43.toLong) + shouldMatch(23, 23.0) + shouldNotMatch(23, 23.5) + shouldNotMatch(23, 23.5f) + shouldMatch(11L, 11.toByte) + shouldMatch(42L, 42.toShort) + shouldMatch(43L, 43.toInt) + shouldMatch(43L, 43.toLong) + shouldMatch(87L, 87.toLong) + shouldMatch(Long.MaxValue, Long.MaxValue) + shouldMatch(Int.MaxValue, Int.MaxValue.toLong) + shouldMatch(23L, 23.0) + shouldNotMatch(23L, 23.5) + shouldNotMatch(23L, 23.5f) + shouldMatch(9007199254740992L, 9007199254740992D) + shouldNotMatch(4611686018427387905L, 4611686018427387900L) + shouldMatch(11f, 11.toByte) + shouldMatch(42f, 42.toShort) + shouldMatch(43f, 43) + shouldMatch(43f, 43.toLong) + shouldMatch(23f, 23.0) + shouldNotMatch(23f, 23.5) + shouldNotMatch(23f, 23.5f) + shouldMatch(3.14f, 3.14f) + shouldMatch(3.14f, 3.14d) + shouldMatch(11d, 11.toByte) + shouldMatch(42d, 42.toShort) + shouldMatch(43d, 43) + shouldMatch(43d, 43.toLong) + shouldMatch(23d, 23.0) + shouldNotMatch(23d, 23.5) + shouldNotMatch(23d, 23.5f) + shouldMatch(3.14d, 3.14f) + shouldMatch(3.14d, 3.14d) + shouldMatch("A", "A") + shouldMatch('A', 'A') + shouldMatch('A', "A") + shouldMatch("A", 'A') + shouldNotMatch("AA", 'A') + shouldNotMatch("a", "A") + shouldNotMatch("A", "a") + shouldNotMatch("0", 0) + shouldNotMatch('0', 0) + + // Lists and arrays + shouldMatch(Array[Int](1, 2, 3), Array[Int](1, 2, 3)) + shouldMatch(Array[Array[Int]](Array(1), Array(2, 2), Array(3, 3, 3)), Array[Array[Double]](Array(1.0), Array(2.0, 2.0), Array(3.0, 3.0, 3.0))) + shouldMatch(Array[Int](1, 2, 3), Array[Long](1, 2, 3)) + shouldMatch(Array[Int](1, 2, 3), Array[Double](1.0, 2.0, 3.0)) + + shouldMatch(Array[String]("A", "B", "C"), Array[String]("A", "B", "C")) + shouldMatch(Array[String]("A", "B", "C"), Array[Char]('A', 'B', 'C')) + shouldMatch(Array[Char]('A', 'B', 'C'), Array[String]("A", "B", "C")) + shouldMatch(Array[Int](1, 2, 3), asList(1, 2, 3)) + + shouldMatch(asList(1, 2, 3), asList(1L, 2L, 3L)) + shouldMatch(asList(1, 2, 3, null), asList(1L, 2L, 3L, null)) + shouldMatch(Array[Int](1, 2, 3), asList(1L, 2L, 3L)) + shouldMatch(Array[Int](1, 2, 3), asList(1.0D, 2.0D, 3.0D)) + shouldMatch(Array[Any](1, Array[Int](2, 2), 3), asList(1.0D, asList(2.0D, 2.0D), 3.0D)) + shouldMatch(Array[String]("A", "B", "C"), asList("A", "B", "C")) + shouldMatch(Array[String]("A", "B", "C"), asList('A', 'B', 'C')) + shouldMatch(Array[Char]('A', 'B', 'C'), asList("A", "B", "C")) + shouldMatch(new util.ArrayList[AnyRef](), Array.empty) + + // Maps + shouldMatch(Map("a" -> 42).asJava, Map("a" -> 42).asJava) + shouldMatch(Map("a" -> 42).asJava, Map("a" -> 42.0).asJava) + shouldMatch(Map("a" -> 42).asJava, singletonMap("a", 42.0)) + shouldMatch(singletonMap("a", asList(41.0, 42.0)), Map("a" -> List(41,42).asJava).asJava) + shouldMatch(Map("a" -> singletonMap("x", asList(41.0, 'c'.asInstanceOf[Character]))).asJava, singletonMap("a", Map("x" -> List(41, "c").asJava).asJava)) + + // Geographic Values + val crs = ImplementsJavaCRS("cartesian", "http://spatialreference.org/ref/sr-org/7203/", 7203) + shouldMatch(ImplementsJavaPoint(32, 43, crs), ImplementsJavaPoint(32.0, 43.0, crs)) + + private def shouldMatch(v1: Any, v2: Any) { + test(testName(v1, v2, "=")) { + CompiledEquivalenceUtils.equals(v1, v2) shouldBe true + CompiledEquivalenceUtils.equals(v2, v1) shouldBe true + CompiledEquivalenceUtils.hashCode(v1) should equal(CompiledEquivalenceUtils.hashCode(v2)) + } + } + private def shouldNotMatch(v1: Any, v2: Any) { + test(testName(v1, v2, "<>")) { + CompiledEquivalenceUtils.equals(v1, v2) shouldBe false + CompiledEquivalenceUtils.equals(v2, v1) shouldBe false + } + } + + private def testName(v1: Any, v2: Any, operator: String): String = { + s"$v1 (${v1.getClass.getSimpleName}) $operator $v2 (${v2.getClass.getSimpleName})\n" + } +} + +case class ImplementsJavaPoint(longitude: Double, latitude: Double, crs: CRS) extends Point { + override def getCRS = crs + + override def getCoordinates: util.List[Coordinate] = asList(new Coordinate(longitude, latitude)) + + override def getGeometryType: String = crs.getType } + +case class ImplementsJavaCRS(typ: String, href: String, code: Int) extends CRS { + override def getType: String = typ + + override def getHref: String = href + + override def getCode: Int = code +} \ No newline at end of file