Skip to content

Commit

Permalink
Key generation with proper equivalence semantics
Browse files Browse the repository at this point in the history
Hash keys should be generated according to proper Cypher
equivalence semantics
  • Loading branch information
pontusmelke committed Dec 7, 2016
1 parent be0778d commit e4a3938
Show file tree
Hide file tree
Showing 15 changed files with 938 additions and 182 deletions.
Expand Up @@ -103,7 +103,7 @@ object LogicalPlanConverter {
val produceResultOpName = context.registerOperator(produceResults) val produceResultOpName = context.registerOperator(produceResults)
val projections = (produceResults.lhs.get match { val projections = (produceResults.lhs.get match {
// if lhs is projection than we can simply load things that it projected // if lhs is projection than we can simply load things that it projected
case _: plans.Projection => produceResults.columns.map(c => c -> LoadVariable(context.getVariable(c))) case _: plans.Projection => produceResults.columns.map(c => c -> LoadVariable(context.getProjection(c)))
// else we have to evaluate all expressions ourselves // else we have to evaluate all expressions ourselves
case _ => produceResults.columns.map(c => c -> ExpressionConverter.createExpressionForVariable(c)(context)) case _ => produceResults.columns.map(c => c -> ExpressionConverter.createExpressionForVariable(c)(context))
}).toMap }).toMap
Expand Down
Expand Up @@ -22,51 +22,91 @@ package org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.expressions
import org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.Instruction import org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.Instruction
import org.neo4j.cypher.internal.compiler.v3_2.codegen.{CodeGenContext, MethodStructure, Variable} import org.neo4j.cypher.internal.compiler.v3_2.codegen.{CodeGenContext, MethodStructure, Variable}


trait AggregateExpression { abstract class AggregateExpression(expression: CodeGenExpression, distinct: Boolean) {


def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit


def update[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit def update[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit


def continuation(instruction: Instruction): Instruction = instruction def continuation(instruction: Instruction): Instruction = instruction
}


def distinctCondition[E](value: E, valueType: CodeGenType, structure: MethodStructure[E])(block: MethodStructure[E] => Unit)
(implicit context: CodeGenContext)

protected def ifNotNull[E](structure: MethodStructure[E])(block: MethodStructure[E] => Unit)
(implicit context: CodeGenContext) = {
expression match {
case NodeExpression(v) => primitiveIfNot(v, structure)(block(_))
case NodeProjection(v) => primitiveIfNot(v, structure)(block(_))
case RelationshipExpression(v) => primitiveIfNot(v, structure)(block(_))
case RelationshipProjection(v) => primitiveIfNot(v, structure)(block(_))
case _ =>
val tmpName = context.namer.newVarName()
structure.assign(tmpName, expression.codeGenType, expression.generateExpression(structure))
structure.ifNonNullStatement(structure.loadVariable(tmpName)) { body =>
if (distinct) {
distinctCondition(structure.loadVariable(tmpName),expression.codeGenType, body) { inner =>
block(inner)
}
}
else block(body)
}
}
}


private def primitiveIfNot[E](v: Variable, structure: MethodStructure[E])(block: MethodStructure[E] => Unit)
(implicit context: CodeGenContext) = {
structure.ifNotStatement(structure.equalityExpression(structure.loadVariable(v.name),
structure.constantExpression(Long.box(-1)),
CodeGenType.primitiveInt)) { body =>
if (distinct) {
distinctCondition(structure.loadVariable(v.name), CodeGenType.primitiveInt, body) { inner =>
block(inner)
}
}
else block(body)
}
}
}


case class SimpleCount(variable: Variable, expression: CodeGenExpression, distinct: Boolean) extends AggregateExpression { case class SimpleCount(variable: Variable, expression: CodeGenExpression, distinct: Boolean)
extends AggregateExpression(expression, distinct) {


def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = { def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = {
expression.init(generator) expression.init(generator)
generator.assign(variable.name, CodeGenType.primitiveInt, generator.constantExpression(Long.box(0L))) generator.assign(variable.name, CodeGenType.primitiveInt, generator.constantExpression(Long.box(0L)))
if (distinct) generator.newSet(setName(variable)) if (distinct) {
generator.newSet(setName(variable))
}
} }


def update[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = { def update[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = {
val tmpName = context.namer.newVarName() ifNotNull(structure) { inner =>
structure.assign(tmpName, CodeGenType.Any, expression.generateExpression(structure)) inner.incrementInteger(variable.name)
structure.ifNonNullStatement(structure.loadVariable(tmpName)) { body =>
condition(tmpName, body) { inner =>
inner.incrementInteger(variable.name)
}
} }
} }


private def condition[E](name: String, structure: MethodStructure[E])(block: MethodStructure[E] => Unit) = { def distinctCondition[E](value: E, valueType: CodeGenType, structure: MethodStructure[E])
if (distinct) { (block: MethodStructure[E] => Unit)
structure.ifNotStatement(structure.setContains(setName(variable), structure.loadVariable(name))) { inner => (implicit context: CodeGenContext) = {
inner.addToSet(setName(variable), inner.loadVariable(name)) val tmpName = context.namer.newVarName()
block(inner) structure.newUniqueAggregationKey(tmpName, Map(typeName(variable) -> (valueType -> value)))
} structure.ifNotStatement(structure.setContains(setName(variable), structure.loadVariable(tmpName))) { inner =>
} else block(structure) inner.addToSet(setName(variable), inner.loadVariable(tmpName))
block(inner)
}
} }


private def setName(variable: Variable) = variable.name + "Set" private def setName(variable: Variable) = variable.name + "Set"

private def typeName(variable: Variable) = variable.name + "Type"
} }


class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpression, class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpression,
groupingKey: Iterable[Variable], distinct: Boolean) extends AggregateExpression { groupingKey: Iterable[Variable], distinct: Boolean) extends AggregateExpression(expression, distinct) {


private var mapName: String = null private var mapName: String = null
private var keyVar: String = null


override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = { override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = {
expression.init(generator) expression.init(generator)
Expand All @@ -75,31 +115,24 @@ class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpres
} }


override def update[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = { override def update[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = {
val localVar = context.namer.newVarName() keyVar = context.namer.newVarName()
structure.aggregationMapGet(mapName, localVar, createKey(structure)) val valueVar = context.namer.newVarName()
condition(structure, expression.generateExpression(structure)) { inner => structure.aggregationMapGet(mapName, valueVar, createKey(structure), keyVar)
inner.incrementInteger(localVar) ifNotNull(structure) { inner =>
inner.incrementInteger(valueVar)
} }
structure.aggregationMapPut(mapName, createKey(structure), structure.aggregationMapPut(mapName, createKey(structure), keyVar, structure.loadVariable(valueVar))
structure.loadVariable(localVar))
} }


private def condition[E](structure: MethodStructure[E], value: E)(block: MethodStructure[E] => Unit)(implicit context: CodeGenContext) = { def distinctCondition[E](value: E, valueType: CodeGenType, structure: MethodStructure[E])(block: MethodStructure[E] => Unit)
if (distinct) { (implicit context: CodeGenContext) = {
structure.ifNonNullStatement(value) { i1 => structure.checkDistinct(mapName, createKey(structure), keyVar, value, expression.codeGenType) { inner =>
i1.checkDistinct(mapName, createKey(structure), value) { i2 => block(inner)
block(i2)
}
}
} else {
structure.ifNonNullStatement(value) { inner =>
block(inner)
}
} }
} }


private def createKey[E](body: MethodStructure[E])(implicit context: CodeGenContext): IndexedSeq[(CodeGenType, E)] = { private def createKey[E](body: MethodStructure[E])(implicit context: CodeGenContext) = {
groupingKey.map(e => (e.codeGenType, body.loadVariable(e.name))).toIndexedSeq groupingKey.map(e => e.name -> (e.codeGenType -> body.loadVariable(e.name))).toMap
} }


override def continuation(instruction: Instruction): Instruction = new Instruction { override def continuation(instruction: Instruction): Instruction = new Instruction {
Expand All @@ -110,7 +143,7 @@ class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpres


override def body[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit = { override def body[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit = {
generator.trace(opName) { body => generator.trace(opName) { body =>
val keyArg = groupingKey.map(k => k.name -> k.codeGenType).toIndexedSeq val keyArg = groupingKey.map(k => k.name -> k.codeGenType).toMap
body.aggregationMapIterate(mapName, keyArg, variable.name) { inner => body.aggregationMapIterate(mapName, keyArg, variable.name) { inner =>
instruction.body(inner) instruction.body(inner)
} }
Expand Down
Expand Up @@ -86,10 +86,14 @@ object ExpressionConverter {


expression match { expression match {
case node@ast.Variable(name) if context.semanticTable.isNode(node) => case node@ast.Variable(name) if context.semanticTable.isNode(node) =>
NodeProjection(context.getVariable(name)) val variable = context.getProjection(name)
if (variable.codeGenType.isPrimitive) NodeProjection(variable)
else LoadVariable(variable)


case rel@ast.Variable(name) if context.semanticTable.isRelationship(rel) => case rel@ast.Variable(name) if context.semanticTable.isRelationship(rel) =>
RelationshipProjection(context.getVariable(name)) val variable = context.getProjection(name)
if (variable.codeGenType.isPrimitive) RelationshipProjection(variable)
else LoadVariable(variable)


case e => expressionConverter(e, createProjection) case e => expressionConverter(e, createProjection)
} }
Expand Down Expand Up @@ -167,6 +171,8 @@ object ExpressionConverter {


case f: ast.FunctionInvocation => functionConverter(f, callback) case f: ast.FunctionInvocation => functionConverter(f, callback)


case ast.Variable(name) => LoadVariable(context.getProjection(name))

case other => throw new CantCompileQueryException(s"Expression of $other not yet supported") case other => throw new CantCompileQueryException(s"Expression of $other not yet supported")
} }
} }
Expand Down
Expand Up @@ -25,6 +25,7 @@ import org.neo4j.cypher.internal.frontend.v3_2.symbols
import org.neo4j.cypher.internal.frontend.v3_2.symbols._ import org.neo4j.cypher.internal.frontend.v3_2.symbols._


case class NodeExpression(nodeIdVar: Variable) extends CodeGenExpression { case class NodeExpression(nodeIdVar: Variable) extends CodeGenExpression {

assert(nodeIdVar.codeGenType.ct == symbols.CTNode) assert(nodeIdVar.codeGenType.ct == symbols.CTNode)


override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = {} override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = {}
Expand All @@ -38,5 +39,5 @@ case class NodeExpression(nodeIdVar: Variable) extends CodeGenExpression {


override def nullable(implicit context: CodeGenContext) = nodeIdVar.nullable override def nullable(implicit context: CodeGenContext) = nodeIdVar.nullable


override def codeGenType(implicit context: CodeGenContext) = if (nullable) CodeGenType(CTNode, ReferenceType) else CodeGenType.primitiveNode override def codeGenType(implicit context: CodeGenContext) = CodeGenType(CTNode, ReferenceType)
} }
Expand Up @@ -39,7 +39,5 @@ case class RelationshipExpression(relId: Variable) extends CodeGenExpression {


override def nullable(implicit context: CodeGenContext) = relId.nullable override def nullable(implicit context: CodeGenContext) = relId.nullable


override def codeGenType(implicit context: CodeGenContext) = override def codeGenType(implicit context: CodeGenContext) = CodeGenType(CTRelationship, ReferenceType)
if (nullable) CodeGenType(CTRelationship, ReferenceType)
else CodeGenType.primitiveRel
} }
Expand Up @@ -67,11 +67,12 @@ trait MethodStructure[E] {
def newSet(name: String) def newSet(name: String)
def setContains(name: String, value: E): E def setContains(name: String, value: E): E
def addToSet(name: String, value: E): Unit def addToSet(name: String, value: E): Unit
def newUniqueAggregationKey(varName: String, structure: Map[String, (CodeGenType,E)]): Unit
def newAggregationMap(name: String, keyTypes: IndexedSeq[CodeGenType], distinct: Boolean) def newAggregationMap(name: String, keyTypes: IndexedSeq[CodeGenType], distinct: Boolean)
def aggregationMapGet(name: String, varName: String, key: IndexedSeq[(CodeGenType,E)]) def aggregationMapGet(name: String, varName: String, key: Map[String,(CodeGenType,E)], keyVar: String)
def aggregationMapPut(name: String, key: IndexedSeq[(CodeGenType,E)], value: E): Unit def aggregationMapPut(name: String, key: Map[String,(CodeGenType,E)], keyVar: String, value: E): Unit
def aggregationMapIterate(name: String, key: IndexedSeq[(String,CodeGenType)], valueVar: String)(block: MethodStructure[E] => Unit): Unit def aggregationMapIterate(name: String, key: Map[String,CodeGenType], valueVar: String)(block: MethodStructure[E] => Unit): Unit
def checkDistinct(name: String, key: IndexedSeq[(CodeGenType, E)], value: E)(block: MethodStructure[E] => Unit) def checkDistinct(name: String, key: Map[String,(CodeGenType, E)], keyVar: String, value: E, valueType: CodeGenType)(block: MethodStructure[E] => Unit)


def castToCollection(value: E): E def castToCollection(value: E): E


Expand Down
Expand Up @@ -81,6 +81,8 @@ class Equivalent(protected val eagerizedValue: Any, val originalValue: Any) exte
length * (31 * hashCode(n.head) + hashCode(n(length / 2)) * 31 + hashCode(n.last)) length * (31 * hashCode(n.head) + hashCode(n(length / 2)) * 31 + hashCode(n.last))
else else
EMPTY_LIST EMPTY_LIST
case m: Map[_,_] =>
m.hashCode()
case x => x.hashCode() case x => x.hashCode()
} }
} }
Expand Down
Expand Up @@ -25,9 +25,11 @@ import java.util.Collections.singletonMap


import org.neo4j.cypher.internal.compiler.v3_2.{CRS, GeographicPoint} import org.neo4j.cypher.internal.compiler.v3_2.{CRS, GeographicPoint}
import org.neo4j.cypher.internal.frontend.v3_2.test_helpers.CypherFunSuite import org.neo4j.cypher.internal.frontend.v3_2.test_helpers.CypherFunSuite
import org.neo4j.graphdb.spatial.{Coordinate, Point, CRS => JavaCRS} import org.neo4j.graphdb.spatial.{CRS => JavaCRS, Coordinate, Point}


class EquivalentTest extends CypherFunSuite { class EquivalentTest extends CypherFunSuite {
shouldNotMatch(23.toByte, 23.5)

shouldMatch(1.0, 1L) shouldMatch(1.0, 1L)
shouldMatch(1.0, 1) shouldMatch(1.0, 1)
shouldMatch(1.0, 1.0) shouldMatch(1.0, 1.0)
Expand Down Expand Up @@ -58,7 +60,6 @@ class EquivalentTest extends CypherFunSuite {
shouldMatch(43.toByte, 43.toLong) shouldMatch(43.toByte, 43.toLong)
shouldMatch(23.toByte, 23.0d) shouldMatch(23.toByte, 23.0d)
shouldMatch(23.toByte, 23.0f) shouldMatch(23.toByte, 23.0f)
shouldNotMatch(23.toByte, 23.5)
shouldNotMatch(23.toByte, 23.5f) shouldNotMatch(23.toByte, 23.5f)
shouldMatch(11.toShort, 11.toByte) shouldMatch(11.toShort, 11.toByte)
shouldMatch(42.toShort, 42.toShort) shouldMatch(42.toShort, 42.toShort)
Expand Down
Expand Up @@ -23,14 +23,11 @@
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator;
import java.util.List;


import org.neo4j.cypher.internal.frontend.v3_2.CypherTypeException; import org.neo4j.cypher.internal.frontend.v3_2.CypherTypeException;
import org.neo4j.cypher.internal.frontend.v3_2.IncomparableValuesException; import org.neo4j.cypher.internal.frontend.v3_2.IncomparableValuesException;
import org.neo4j.graphdb.Node; import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Relationship; import org.neo4j.graphdb.Relationship;
import org.neo4j.helpers.MathUtil;


// Class with static methods used by compiled execution plans // Class with static methods used by compiled execution plans
public abstract class CompiledConversionUtils public abstract class CompiledConversionUtils
Expand Down Expand Up @@ -117,100 +114,7 @@ public static Boolean equals( Object lhs, Object rhs )
throw new IncomparableValuesException( lhs.getClass().getSimpleName(), rhs.getClass().getSimpleName() ); throw new IncomparableValuesException( lhs.getClass().getSimpleName(), rhs.getClass().getSimpleName() );
} }


//if floats compare float values if integer types, return CompiledEquivalenceUtils.equals( lhs, rhs );
//compare long values
if ( lhs instanceof Number && rhs instanceof Number )
{
if ( (lhs instanceof Double || lhs instanceof Float)
&& (rhs instanceof Double || rhs instanceof Float) )
{
double left = ((Number) lhs).doubleValue();
double right = ((Number) rhs).doubleValue();
return left == right;
}
else if ( (lhs instanceof Double || lhs instanceof Float) )
{
double left = ((Number) lhs).doubleValue();
long right = ((Number) rhs).longValue();
return MathUtil.numbersEqual( left, right );
}
else if ( (rhs instanceof Double || rhs instanceof Float) )
{
long left = ((Number) lhs).longValue();
double right = ((Number) rhs).doubleValue();
return MathUtil.numbersEqual( right, left );
}

//evertyhing else is long from cyphers point-of-view
long left = ((Number) lhs).longValue();
long right = ((Number) rhs).longValue();
return left == right;
}
else if (lhs.getClass().isArray() && rhs.getClass().isArray() )
{
int length = Array.getLength( lhs );
if ( length != Array.getLength( rhs ) )
{
return false;
}
for ( int i = 0; i < length; i++ )
{
if (!equals( Array.get( lhs, i ), Array.get(rhs, i) ))
{
return false;
}
}
return true;
}
else if (lhs.getClass().isArray() && rhs instanceof List<?> )
{
return compareArrayAndList( lhs, (List<?>) rhs );
}
else if (lhs instanceof List<?> && rhs.getClass().isArray())
{
return compareArrayAndList( rhs, (List<?>) lhs );
}
else if (lhs instanceof List<?> && rhs instanceof List<?>)
{
List<?> lhsList = (List<?>) lhs;
List<?> rhsList = (List<?>) rhs;
if (lhsList.size() != rhsList.size())
{
return false;
}
Iterator<?> lhsIterator = lhsList.iterator();
Iterator<?> rhsIterator = rhsList.iterator();
while (lhsIterator.hasNext())
{
if (!equals( lhsIterator.next(), rhsIterator.next() ))
{
return false;
}
}
return true;
}

//for everything else call equals
return lhs.equals( rhs );
}

private static Boolean compareArrayAndList(Object array, List<?> list)
{
int length = Array.getLength( array );
if ( length != list.size() )
{
return false;
}

int i = 0;
for ( Object o : list )
{
if (!equals( o, Array.get(array, i++) ))
{
return false;
}
}
return true;
} }


public static Boolean or( Object lhs, Object rhs ) public static Boolean or( Object lhs, Object rhs )
Expand Down

0 comments on commit e4a3938

Please sign in to comment.