Skip to content

Commit

Permalink
Key generation with proper equivalence semantics
Browse files Browse the repository at this point in the history
Hash keys should be generated according to proper Cypher
equivalence semantics
  • Loading branch information
pontusmelke committed Dec 7, 2016
1 parent be0778d commit e4a3938
Show file tree
Hide file tree
Showing 15 changed files with 938 additions and 182 deletions.
Expand Up @@ -103,7 +103,7 @@ object LogicalPlanConverter {
val produceResultOpName = context.registerOperator(produceResults)
val projections = (produceResults.lhs.get match {
// if lhs is projection than we can simply load things that it projected
case _: plans.Projection => produceResults.columns.map(c => c -> LoadVariable(context.getVariable(c)))
case _: plans.Projection => produceResults.columns.map(c => c -> LoadVariable(context.getProjection(c)))
// else we have to evaluate all expressions ourselves
case _ => produceResults.columns.map(c => c -> ExpressionConverter.createExpressionForVariable(c)(context))
}).toMap
Expand Down
Expand Up @@ -22,51 +22,91 @@ package org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.expressions
import org.neo4j.cypher.internal.compiler.v3_2.codegen.ir.Instruction
import org.neo4j.cypher.internal.compiler.v3_2.codegen.{CodeGenContext, MethodStructure, Variable}

trait AggregateExpression {
abstract class AggregateExpression(expression: CodeGenExpression, distinct: Boolean) {

def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit

def update[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit

def continuation(instruction: Instruction): Instruction = instruction
}

def distinctCondition[E](value: E, valueType: CodeGenType, structure: MethodStructure[E])(block: MethodStructure[E] => Unit)
(implicit context: CodeGenContext)

protected def ifNotNull[E](structure: MethodStructure[E])(block: MethodStructure[E] => Unit)
(implicit context: CodeGenContext) = {
expression match {
case NodeExpression(v) => primitiveIfNot(v, structure)(block(_))
case NodeProjection(v) => primitiveIfNot(v, structure)(block(_))
case RelationshipExpression(v) => primitiveIfNot(v, structure)(block(_))
case RelationshipProjection(v) => primitiveIfNot(v, structure)(block(_))
case _ =>
val tmpName = context.namer.newVarName()
structure.assign(tmpName, expression.codeGenType, expression.generateExpression(structure))
structure.ifNonNullStatement(structure.loadVariable(tmpName)) { body =>
if (distinct) {
distinctCondition(structure.loadVariable(tmpName),expression.codeGenType, body) { inner =>
block(inner)
}
}
else block(body)
}
}
}

private def primitiveIfNot[E](v: Variable, structure: MethodStructure[E])(block: MethodStructure[E] => Unit)
(implicit context: CodeGenContext) = {
structure.ifNotStatement(structure.equalityExpression(structure.loadVariable(v.name),
structure.constantExpression(Long.box(-1)),
CodeGenType.primitiveInt)) { body =>
if (distinct) {
distinctCondition(structure.loadVariable(v.name), CodeGenType.primitiveInt, body) { inner =>
block(inner)
}
}
else block(body)
}
}
}

case class SimpleCount(variable: Variable, expression: CodeGenExpression, distinct: Boolean) extends AggregateExpression {
case class SimpleCount(variable: Variable, expression: CodeGenExpression, distinct: Boolean)
extends AggregateExpression(expression, distinct) {

def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = {
expression.init(generator)
generator.assign(variable.name, CodeGenType.primitiveInt, generator.constantExpression(Long.box(0L)))
if (distinct) generator.newSet(setName(variable))
if (distinct) {
generator.newSet(setName(variable))
}
}

def update[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = {
val tmpName = context.namer.newVarName()
structure.assign(tmpName, CodeGenType.Any, expression.generateExpression(structure))
structure.ifNonNullStatement(structure.loadVariable(tmpName)) { body =>
condition(tmpName, body) { inner =>
inner.incrementInteger(variable.name)
}
ifNotNull(structure) { inner =>
inner.incrementInteger(variable.name)
}
}

private def condition[E](name: String, structure: MethodStructure[E])(block: MethodStructure[E] => Unit) = {
if (distinct) {
structure.ifNotStatement(structure.setContains(setName(variable), structure.loadVariable(name))) { inner =>
inner.addToSet(setName(variable), inner.loadVariable(name))
block(inner)
}
} else block(structure)
def distinctCondition[E](value: E, valueType: CodeGenType, structure: MethodStructure[E])
(block: MethodStructure[E] => Unit)
(implicit context: CodeGenContext) = {
val tmpName = context.namer.newVarName()
structure.newUniqueAggregationKey(tmpName, Map(typeName(variable) -> (valueType -> value)))
structure.ifNotStatement(structure.setContains(setName(variable), structure.loadVariable(tmpName))) { inner =>
inner.addToSet(setName(variable), inner.loadVariable(tmpName))
block(inner)
}
}

private def setName(variable: Variable) = variable.name + "Set"

private def typeName(variable: Variable) = variable.name + "Type"
}

class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpression,
groupingKey: Iterable[Variable], distinct: Boolean) extends AggregateExpression {
groupingKey: Iterable[Variable], distinct: Boolean) extends AggregateExpression(expression, distinct) {

private var mapName: String = null
private var keyVar: String = null

override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = {
expression.init(generator)
Expand All @@ -75,31 +115,24 @@ class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpres
}

override def update[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = {
val localVar = context.namer.newVarName()
structure.aggregationMapGet(mapName, localVar, createKey(structure))
condition(structure, expression.generateExpression(structure)) { inner =>
inner.incrementInteger(localVar)
keyVar = context.namer.newVarName()
val valueVar = context.namer.newVarName()
structure.aggregationMapGet(mapName, valueVar, createKey(structure), keyVar)
ifNotNull(structure) { inner =>
inner.incrementInteger(valueVar)
}
structure.aggregationMapPut(mapName, createKey(structure),
structure.loadVariable(localVar))
structure.aggregationMapPut(mapName, createKey(structure), keyVar, structure.loadVariable(valueVar))
}

private def condition[E](structure: MethodStructure[E], value: E)(block: MethodStructure[E] => Unit)(implicit context: CodeGenContext) = {
if (distinct) {
structure.ifNonNullStatement(value) { i1 =>
i1.checkDistinct(mapName, createKey(structure), value) { i2 =>
block(i2)
}
}
} else {
structure.ifNonNullStatement(value) { inner =>
block(inner)
}
def distinctCondition[E](value: E, valueType: CodeGenType, structure: MethodStructure[E])(block: MethodStructure[E] => Unit)
(implicit context: CodeGenContext) = {
structure.checkDistinct(mapName, createKey(structure), keyVar, value, expression.codeGenType) { inner =>
block(inner)
}
}

private def createKey[E](body: MethodStructure[E])(implicit context: CodeGenContext): IndexedSeq[(CodeGenType, E)] = {
groupingKey.map(e => (e.codeGenType, body.loadVariable(e.name))).toIndexedSeq
private def createKey[E](body: MethodStructure[E])(implicit context: CodeGenContext) = {
groupingKey.map(e => e.name -> (e.codeGenType -> body.loadVariable(e.name))).toMap
}

override def continuation(instruction: Instruction): Instruction = new Instruction {
Expand All @@ -110,7 +143,7 @@ class DynamicCount(opName: String, variable: Variable, expression: CodeGenExpres

override def body[E](generator: MethodStructure[E])(implicit context: CodeGenContext): Unit = {
generator.trace(opName) { body =>
val keyArg = groupingKey.map(k => k.name -> k.codeGenType).toIndexedSeq
val keyArg = groupingKey.map(k => k.name -> k.codeGenType).toMap
body.aggregationMapIterate(mapName, keyArg, variable.name) { inner =>
instruction.body(inner)
}
Expand Down
Expand Up @@ -86,10 +86,14 @@ object ExpressionConverter {

expression match {
case node@ast.Variable(name) if context.semanticTable.isNode(node) =>
NodeProjection(context.getVariable(name))
val variable = context.getProjection(name)
if (variable.codeGenType.isPrimitive) NodeProjection(variable)
else LoadVariable(variable)

case rel@ast.Variable(name) if context.semanticTable.isRelationship(rel) =>
RelationshipProjection(context.getVariable(name))
val variable = context.getProjection(name)
if (variable.codeGenType.isPrimitive) RelationshipProjection(variable)
else LoadVariable(variable)

case e => expressionConverter(e, createProjection)
}
Expand Down Expand Up @@ -167,6 +171,8 @@ object ExpressionConverter {

case f: ast.FunctionInvocation => functionConverter(f, callback)

case ast.Variable(name) => LoadVariable(context.getProjection(name))

case other => throw new CantCompileQueryException(s"Expression of $other not yet supported")
}
}
Expand Down
Expand Up @@ -25,6 +25,7 @@ import org.neo4j.cypher.internal.frontend.v3_2.symbols
import org.neo4j.cypher.internal.frontend.v3_2.symbols._

case class NodeExpression(nodeIdVar: Variable) extends CodeGenExpression {

assert(nodeIdVar.codeGenType.ct == symbols.CTNode)

override def init[E](generator: MethodStructure[E])(implicit context: CodeGenContext) = {}
Expand All @@ -38,5 +39,5 @@ case class NodeExpression(nodeIdVar: Variable) extends CodeGenExpression {

override def nullable(implicit context: CodeGenContext) = nodeIdVar.nullable

override def codeGenType(implicit context: CodeGenContext) = if (nullable) CodeGenType(CTNode, ReferenceType) else CodeGenType.primitiveNode
}
override def codeGenType(implicit context: CodeGenContext) = CodeGenType(CTNode, ReferenceType)
}
Expand Up @@ -39,7 +39,5 @@ case class RelationshipExpression(relId: Variable) extends CodeGenExpression {

override def nullable(implicit context: CodeGenContext) = relId.nullable

override def codeGenType(implicit context: CodeGenContext) =
if (nullable) CodeGenType(CTRelationship, ReferenceType)
else CodeGenType.primitiveRel
override def codeGenType(implicit context: CodeGenContext) = CodeGenType(CTRelationship, ReferenceType)
}
Expand Up @@ -67,11 +67,12 @@ trait MethodStructure[E] {
def newSet(name: String)
def setContains(name: String, value: E): E
def addToSet(name: String, value: E): Unit
def newUniqueAggregationKey(varName: String, structure: Map[String, (CodeGenType,E)]): Unit
def newAggregationMap(name: String, keyTypes: IndexedSeq[CodeGenType], distinct: Boolean)
def aggregationMapGet(name: String, varName: String, key: IndexedSeq[(CodeGenType,E)])
def aggregationMapPut(name: String, key: IndexedSeq[(CodeGenType,E)], value: E): Unit
def aggregationMapIterate(name: String, key: IndexedSeq[(String,CodeGenType)], valueVar: String)(block: MethodStructure[E] => Unit): Unit
def checkDistinct(name: String, key: IndexedSeq[(CodeGenType, E)], value: E)(block: MethodStructure[E] => Unit)
def aggregationMapGet(name: String, varName: String, key: Map[String,(CodeGenType,E)], keyVar: String)
def aggregationMapPut(name: String, key: Map[String,(CodeGenType,E)], keyVar: String, value: E): Unit
def aggregationMapIterate(name: String, key: Map[String,CodeGenType], valueVar: String)(block: MethodStructure[E] => Unit): Unit
def checkDistinct(name: String, key: Map[String,(CodeGenType, E)], keyVar: String, value: E, valueType: CodeGenType)(block: MethodStructure[E] => Unit)

def castToCollection(value: E): E

Expand Down
Expand Up @@ -81,6 +81,8 @@ class Equivalent(protected val eagerizedValue: Any, val originalValue: Any) exte
length * (31 * hashCode(n.head) + hashCode(n(length / 2)) * 31 + hashCode(n.last))
else
EMPTY_LIST
case m: Map[_,_] =>
m.hashCode()
case x => x.hashCode()
}
}
Expand Down
Expand Up @@ -25,9 +25,11 @@ import java.util.Collections.singletonMap

import org.neo4j.cypher.internal.compiler.v3_2.{CRS, GeographicPoint}
import org.neo4j.cypher.internal.frontend.v3_2.test_helpers.CypherFunSuite
import org.neo4j.graphdb.spatial.{Coordinate, Point, CRS => JavaCRS}
import org.neo4j.graphdb.spatial.{CRS => JavaCRS, Coordinate, Point}

class EquivalentTest extends CypherFunSuite {
shouldNotMatch(23.toByte, 23.5)

shouldMatch(1.0, 1L)
shouldMatch(1.0, 1)
shouldMatch(1.0, 1.0)
Expand Down Expand Up @@ -58,7 +60,6 @@ class EquivalentTest extends CypherFunSuite {
shouldMatch(43.toByte, 43.toLong)
shouldMatch(23.toByte, 23.0d)
shouldMatch(23.toByte, 23.0f)
shouldNotMatch(23.toByte, 23.5)
shouldNotMatch(23.toByte, 23.5f)
shouldMatch(11.toShort, 11.toByte)
shouldMatch(42.toShort, 42.toShort)
Expand Down
Expand Up @@ -23,14 +23,11 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import org.neo4j.cypher.internal.frontend.v3_2.CypherTypeException;
import org.neo4j.cypher.internal.frontend.v3_2.IncomparableValuesException;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Relationship;
import org.neo4j.helpers.MathUtil;

// Class with static methods used by compiled execution plans
public abstract class CompiledConversionUtils
Expand Down Expand Up @@ -117,100 +114,7 @@ public static Boolean equals( Object lhs, Object rhs )
throw new IncomparableValuesException( lhs.getClass().getSimpleName(), rhs.getClass().getSimpleName() );
}

//if floats compare float values if integer types,
//compare long values
if ( lhs instanceof Number && rhs instanceof Number )
{
if ( (lhs instanceof Double || lhs instanceof Float)
&& (rhs instanceof Double || rhs instanceof Float) )
{
double left = ((Number) lhs).doubleValue();
double right = ((Number) rhs).doubleValue();
return left == right;
}
else if ( (lhs instanceof Double || lhs instanceof Float) )
{
double left = ((Number) lhs).doubleValue();
long right = ((Number) rhs).longValue();
return MathUtil.numbersEqual( left, right );
}
else if ( (rhs instanceof Double || rhs instanceof Float) )
{
long left = ((Number) lhs).longValue();
double right = ((Number) rhs).doubleValue();
return MathUtil.numbersEqual( right, left );
}

//evertyhing else is long from cyphers point-of-view
long left = ((Number) lhs).longValue();
long right = ((Number) rhs).longValue();
return left == right;
}
else if (lhs.getClass().isArray() && rhs.getClass().isArray() )
{
int length = Array.getLength( lhs );
if ( length != Array.getLength( rhs ) )
{
return false;
}
for ( int i = 0; i < length; i++ )
{
if (!equals( Array.get( lhs, i ), Array.get(rhs, i) ))
{
return false;
}
}
return true;
}
else if (lhs.getClass().isArray() && rhs instanceof List<?> )
{
return compareArrayAndList( lhs, (List<?>) rhs );
}
else if (lhs instanceof List<?> && rhs.getClass().isArray())
{
return compareArrayAndList( rhs, (List<?>) lhs );
}
else if (lhs instanceof List<?> && rhs instanceof List<?>)
{
List<?> lhsList = (List<?>) lhs;
List<?> rhsList = (List<?>) rhs;
if (lhsList.size() != rhsList.size())
{
return false;
}
Iterator<?> lhsIterator = lhsList.iterator();
Iterator<?> rhsIterator = rhsList.iterator();
while (lhsIterator.hasNext())
{
if (!equals( lhsIterator.next(), rhsIterator.next() ))
{
return false;
}
}
return true;
}

//for everything else call equals
return lhs.equals( rhs );
}

private static Boolean compareArrayAndList(Object array, List<?> list)
{
int length = Array.getLength( array );
if ( length != list.size() )
{
return false;
}

int i = 0;
for ( Object o : list )
{
if (!equals( o, Array.get(array, i++) ))
{
return false;
}
}
return true;
return CompiledEquivalenceUtils.equals( lhs, rhs );
}

public static Boolean or( Object lhs, Object rhs )
Expand Down

0 comments on commit e4a3938

Please sign in to comment.