Skip to content

Commit

Permalink
Added @name to functions and reverted name change
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke committed Sep 9, 2016
1 parent 308a55b commit c71dfc4
Show file tree
Hide file tree
Showing 73 changed files with 289 additions and 219 deletions.
Expand Up @@ -34,7 +34,7 @@ import org.neo4j.cypher.internal.compiler.v3_1.planner.logical.plans.rewriter.Lo
import org.neo4j.cypher.internal.compiler.v3_1.planner.logical.{CachedMetricsFactory, DefaultQueryPlanner, SimpleMetricsFactory} import org.neo4j.cypher.internal.compiler.v3_1.planner.logical.{CachedMetricsFactory, DefaultQueryPlanner, SimpleMetricsFactory}
import org.neo4j.cypher.internal.compiler.v3_1.spi.{PlanContext, ProcedureSignature} import org.neo4j.cypher.internal.compiler.v3_1.spi.{PlanContext, ProcedureSignature}
import org.neo4j.cypher.internal.compiler.v3_1.tracing.rewriters.RewriterStepSequencer import org.neo4j.cypher.internal.compiler.v3_1.tracing.rewriters.RewriterStepSequencer
import org.neo4j.cypher.internal.frontend.v3_1.ast.{FunctionName, Statement, UserFunctionInvocation} import org.neo4j.cypher.internal.frontend.v3_1.ast.{FunctionInvocation, FunctionName, Statement}
import org.neo4j.cypher.internal.frontend.v3_1.notification.{DeprecatedFunctionNotification, DeprecatedProcedureNotification, InternalNotification} import org.neo4j.cypher.internal.frontend.v3_1.notification.{DeprecatedFunctionNotification, DeprecatedProcedureNotification, InternalNotification}
import org.neo4j.cypher.internal.frontend.v3_1.parser.CypherParser import org.neo4j.cypher.internal.frontend.v3_1.parser.CypherParser
import org.neo4j.cypher.internal.frontend.v3_1.{InputPosition, SemanticTable, inSequence} import org.neo4j.cypher.internal.frontend.v3_1.{InputPosition, SemanticTable, inSequence}
Expand Down Expand Up @@ -230,7 +230,7 @@ case class CypherCompiler(parser: CypherParser,


private def syntaxDeprecationNotifications(statement: Statement): Set[InternalNotification] = private def syntaxDeprecationNotifications(statement: Statement): Set[InternalNotification] =
statement.treeFold(Set.empty[InternalNotification]) { statement.treeFold(Set.empty[InternalNotification]) {
case f@UserFunctionInvocation(_, FunctionName(name), _, _) if aliases.get(name).nonEmpty => case f@FunctionInvocation(_, FunctionName(name), _, _) if aliases.get(name).nonEmpty =>
(seq) => (seq + DeprecatedFunctionNotification(f.position, name, aliases(name)), None) (seq) => (seq + DeprecatedFunctionNotification(f.position, name, aliases(name)), None)
} }


Expand Down
Expand Up @@ -300,9 +300,9 @@ object QueryTagger extends QueryTagger[String] {


// functions // functions
lift[ASTNode] { lift[ASTNode] {
case f: UserFunctionInvocation if mathFunctions contains f.function => Set(MathFunctionTag) case f: FunctionInvocation if mathFunctions contains f.function => Set(MathFunctionTag)
case f: UserFunctionInvocation if stringFunctions contains f.function => Set(StringFunctionTag) case f: FunctionInvocation if stringFunctions contains f.function => Set(StringFunctionTag)
case f: UserFunctionInvocation if isAggregation(f.function) => Set(AggregationTag) case f: FunctionInvocation if isAggregation(f.function) => Set(AggregationTag)
} }
)) ))


Expand Down
Expand Up @@ -27,7 +27,7 @@ import org.neo4j.cypher.internal.frontend.v3_1.ast._


object ResolvedFunctionInvocation { object ResolvedFunctionInvocation {


def apply(signatureLookup: QualifiedName => Option[UserFunctionSignature])(unresolved: UserFunctionInvocation): ResolvedFunctionInvocation = { def apply(signatureLookup: QualifiedName => Option[UserFunctionSignature])(unresolved: FunctionInvocation): ResolvedFunctionInvocation = {
val position = unresolved.position val position = unresolved.position
val name = QualifiedName(unresolved) val name = QualifiedName(unresolved)
val signature = signatureLookup(name) val signature = signatureLookup(name)
Expand Down Expand Up @@ -59,7 +59,7 @@ case class ResolvedFunctionInvocation(qualifiedName: QualifiedName,
.zip(optInputFields) .zip(optInputFields)
.map { .map {
case (arg, optField) => case (arg, optField) =>
optField.map { field => CoerceTo(arg, field) }.getOrElse(arg) optField.map { field => CoerceTo(arg, field.typ) }.getOrElse(arg)
} }
copy(callArguments = coercedArguments)(position) copy(callArguments = coercedArguments)(position)
case None => this case None => this
Expand All @@ -69,17 +69,18 @@ case class ResolvedFunctionInvocation(qualifiedName: QualifiedName,
case None => SemanticError(s"Unknown function '$qualifiedName'", position) case None => SemanticError(s"Unknown function '$qualifiedName'", position)
case Some(signature) => case Some(signature) =>
val expectedNumArgs = signature.inputSignature.length val expectedNumArgs = signature.inputSignature.length
val actualNumArgs = callArguments.length val usedDefaultArgs = signature.inputSignature.drop(callArguments.length).flatMap(_.default)
val actualNumArgs = callArguments.length + usedDefaultArgs.length


if (expectedNumArgs == actualNumArgs) { if (expectedNumArgs == actualNumArgs) {
signature.inputSignature.zip(callArguments).map { //this zip is fine since it will only verify provided args in callArguments
case (field, arg) => //default values are checked at load time
arg.semanticCheck(SemanticContext.Results) chain arg.expectType(field.covariant) signature.inputSignature.zip(callArguments).map {
}.foldLeft(success)(_ chain _) case (field, arg) =>
} else { arg.semanticCheck(SemanticContext.Results) chain arg.expectType(field.typ.covariant)
error(_: SemanticState, }.foldLeft(success)(_ chain _)
SemanticError(s"Function call does not provide the required number of arguments ($expectedNumArgs)", } else {
position)) error(_: SemanticState, SemanticError(s"Function call does not provide the required number of arguments ($expectedNumArgs)", position))
} }
} }
} }
Expand Up @@ -20,23 +20,23 @@
package org.neo4j.cypher.internal.compiler.v3_1.ast.conditions package org.neo4j.cypher.internal.compiler.v3_1.ast.conditions


import org.neo4j.cypher.internal.compiler.v3_1.tracing.rewriters.Condition import org.neo4j.cypher.internal.compiler.v3_1.tracing.rewriters.Condition
import org.neo4j.cypher.internal.frontend.v3_1.ast.{Equals, Expression, Property, UserFunctionInvocation, functions} import org.neo4j.cypher.internal.frontend.v3_1.ast.{Equals, Expression, FunctionInvocation, Property, functions}


case object normalizedEqualsArguments extends Condition { case object normalizedEqualsArguments extends Condition {
def apply(that: Any): Seq[String] = { def apply(that: Any): Seq[String] = {
val equals = collectNodesOfType[Equals].apply(that) val equals = collectNodesOfType[Equals].apply(that)
equals.collect { equals.collect {
case eq@Equals(expr, Property(_,_)) if !expr.isInstanceOf[Property] && notIdFunction(expr) => case eq@Equals(expr, Property(_,_)) if !expr.isInstanceOf[Property] && notIdFunction(expr) =>
s"Equals at ${eq.position} is not normalized: $eq" s"Equals at ${eq.position} is not normalized: $eq"
case eq@Equals(expr, func@UserFunctionInvocation(_, _, _, _)) if isIdFunction(func) && notIdFunction(expr) => case eq@Equals(expr, func@FunctionInvocation(_, _, _, _)) if isIdFunction(func) && notIdFunction(expr) =>
s"Equals at ${eq.position} is not normalized: $eq" s"Equals at ${eq.position} is not normalized: $eq"
} }
} }


private def isIdFunction(func: UserFunctionInvocation) = func.function == functions.Id private def isIdFunction(func: FunctionInvocation) = func.function == functions.Id


private def notIdFunction(expr: Expression) = private def notIdFunction(expr: Expression) =
!expr.isInstanceOf[UserFunctionInvocation] || !isIdFunction(expr.asInstanceOf[UserFunctionInvocation]) !expr.isInstanceOf[FunctionInvocation] || !isIdFunction(expr.asInstanceOf[FunctionInvocation])


override def name: String = productPrefix override def name: String = productPrefix
} }
Expand Up @@ -36,7 +36,7 @@ import org.neo4j.cypher.internal.frontend.v3_1.{InternalException, SemanticDirec
import org.neo4j.graphdb.Direction import org.neo4j.graphdb.Direction


object ExpressionConverters { object ExpressionConverters {
def toCommandExpression(expression: ast.Function, invocation: ast.UserFunctionInvocation): CommandExpression = def toCommandExpression(expression: ast.Function, invocation: ast.FunctionInvocation): CommandExpression =
expression match { expression match {
case Abs => commandexpressions.AbsFunction(toCommandExpression(invocation.arguments.head)) case Abs => commandexpressions.AbsFunction(toCommandExpression(invocation.arguments.head))
case Acos => commandexpressions.AcosFunction(toCommandExpression(invocation.arguments.head)) case Acos => commandexpressions.AcosFunction(toCommandExpression(invocation.arguments.head))
Expand Down Expand Up @@ -259,7 +259,7 @@ object ExpressionConverters {
case e: ast.Divide => commandexpressions.Divide(toCommandExpression(e.lhs), toCommandExpression(e.rhs)) case e: ast.Divide => commandexpressions.Divide(toCommandExpression(e.lhs), toCommandExpression(e.rhs))
case e: ast.Modulo => commandexpressions.Modulo(toCommandExpression(e.lhs), toCommandExpression(e.rhs)) case e: ast.Modulo => commandexpressions.Modulo(toCommandExpression(e.lhs), toCommandExpression(e.rhs))
case e: ast.Pow => commandexpressions.Pow(toCommandExpression(e.lhs), toCommandExpression(e.rhs)) case e: ast.Pow => commandexpressions.Pow(toCommandExpression(e.lhs), toCommandExpression(e.rhs))
case e: ast.UserFunctionInvocation => toCommandExpression(e.function, e) case e: ast.FunctionInvocation => toCommandExpression(e.function, e)
case e: ast.CountStar => commandexpressions.CountStar() case e: ast.CountStar => commandexpressions.CountStar()
case e: ast.Property => toCommandProperty(e) case e: ast.Property => toCommandProperty(e)
case e: ast.Parameter => toCommandParameter(e) case e: ast.Parameter => toCommandParameter(e)
Expand Down Expand Up @@ -289,7 +289,11 @@ object ExpressionConverters {
case e: InequalitySeekRangeWrapper => InequalitySeekRangeExpression(e.range.mapBounds(toCommandExpression)) case e: InequalitySeekRangeWrapper => InequalitySeekRangeExpression(e.range.mapBounds(toCommandExpression))
case e: ast.AndedPropertyInequalities => predicates.AndedPropertyComparablePredicates(variable(e.variable), toCommandProperty(e.property), e.inequalities.map(inequalityExpression)) case e: ast.AndedPropertyInequalities => predicates.AndedPropertyComparablePredicates(variable(e.variable), toCommandProperty(e.property), e.inequalities.map(inequalityExpression))
case e: DesugaredMapProjection => commandexpressions.DesugaredMapProjection(e.name.name, e.includeAllProps, mapProjectionItems(e.items)) case e: DesugaredMapProjection => commandexpressions.DesugaredMapProjection(e.name.name, e.includeAllProps, mapProjectionItems(e.items))
case e: ResolvedFunctionInvocation => commandexpressions.FunctionInvocation(e.fcnSignature.get, e.callArguments.map(toCommandExpression)) case e: ResolvedFunctionInvocation =>
val callArgumentCommands = e.callArguments.map(Some(_)).zipAll(e.fcnSignature.get.inputSignature.map(_.default.map(_.value)), None, None).map {
case (given, default) => given.map(toCommandExpression).getOrElse(commandexpressions.Literal(default.get))
}
commandexpressions.FunctionInvocation(e.fcnSignature.get, callArgumentCommands)
case e: ast.MapProjection => throw new InternalException("should have been rewritten away") case e: ast.MapProjection => throw new InternalException("should have been rewritten away")
case _ => case _ =>
throw new InternalException(s"Unknown expression type during transformation (${expression.getClass})") throw new InternalException(s"Unknown expression type during transformation (${expression.getClass})")
Expand Down
Expand Up @@ -33,10 +33,10 @@ case object normalizeArgumentOrder extends Rewriter {
private val instance: Rewriter = topDown(Rewriter.lift { private val instance: Rewriter = topDown(Rewriter.lift {


// move id(n) on equals to the left // move id(n) on equals to the left
case predicate @ Equals(func@UserFunctionInvocation(_, _, _, _), _) if func.function == functions.Id => case predicate @ Equals(func@FunctionInvocation(_, _, _, _), _) if func.function == functions.Id =>
predicate predicate


case predicate @ Equals(lhs, rhs @ UserFunctionInvocation(_, _, _, _)) if rhs.function == functions.Id => case predicate @ Equals(lhs, rhs @ FunctionInvocation(_, _, _, _)) if rhs.function == functions.Id =>
predicate.copy(lhs = rhs, rhs = lhs)(predicate.position) predicate.copy(lhs = rhs, rhs = lhs)(predicate.position)


// move n.prop on equals to the left // move n.prop on equals to the left
Expand Down
Expand Up @@ -37,7 +37,7 @@ case object replaceAliasedFunctionInvocations extends Rewriter {
"rels" -> "relationships")(CaseInsensitiveOrdered) "rels" -> "relationships")(CaseInsensitiveOrdered)


val instance: Rewriter = bottomUp(Rewriter.lift { val instance: Rewriter = bottomUp(Rewriter.lift {
case func@UserFunctionInvocation(_, f@FunctionName(name), _, _) if aliases.get(name).nonEmpty => case func@FunctionInvocation(_, f@FunctionName(name), _, _) if aliases.get(name).nonEmpty =>
func.copy(functionName = FunctionName(aliases(name))(f.position))(func.position) func.copy(functionName = FunctionName(aliases(name))(f.position))(func.position)
}) })


Expand Down
Expand Up @@ -32,7 +32,7 @@ case object rewriteEqualityToInCollection extends Rewriter {


private val instance: Rewriter = bottomUp(Rewriter.lift { private val instance: Rewriter = bottomUp(Rewriter.lift {
// id(a) = value => id(a) IN [value] // id(a) = value => id(a) IN [value]
case predicate@Equals(func@UserFunctionInvocation(_, _, _, IndexedSeq(idExpr)), idValueExpr) case predicate@Equals(func@FunctionInvocation(_, _, _, IndexedSeq(idExpr)), idValueExpr)
if func.function == functions.Id => if func.function == functions.Id =>
In(func, Collection(Seq(idValueExpr))(idValueExpr.position))(predicate.position) In(func, Collection(Seq(idValueExpr))(idValueExpr.position))(predicate.position)


Expand Down
Expand Up @@ -158,7 +158,7 @@ object ExpressionConverter {


case ast.Not(inner) => Not(callback(inner)) case ast.Not(inner) => Not(callback(inner))


case f: ast.UserFunctionInvocation => functionConverter(f, callback) case f: ast.FunctionInvocation => functionConverter(f, callback)


case other => throw new CantCompileQueryException(s"Expression of $other not yet supported") case other => throw new CantCompileQueryException(s"Expression of $other not yet supported")
} }
Expand Down
Expand Up @@ -26,7 +26,7 @@ import org.neo4j.cypher.internal.frontend.v3_1.ast


object functionConverter { object functionConverter {


def apply(fcn: ast.UserFunctionInvocation, callback: ast.Expression => CodeGenExpression) def apply(fcn: ast.FunctionInvocation, callback: ast.Expression => CodeGenExpression)
(implicit context: CodeGenContext): CodeGenExpression = fcn.function match { (implicit context: CodeGenContext): CodeGenExpression = fcn.function match {


// id(n) // id(n)
Expand Down
Expand Up @@ -23,7 +23,7 @@ import org.neo4j.cypher.internal.compiler.v3_1.ExecutionContext
import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.ExpressionConverters import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.ExpressionConverters
import org.neo4j.cypher.internal.compiler.v3_1.pipes.{NullPipeDecorator, QueryState} import org.neo4j.cypher.internal.compiler.v3_1.pipes.{NullPipeDecorator, QueryState}
import org.neo4j.cypher.internal.frontend.v3_1.ast.functions.{Rand, Timestamp} import org.neo4j.cypher.internal.frontend.v3_1.ast.functions.{Rand, Timestamp}
import org.neo4j.cypher.internal.frontend.v3_1.ast.{Expression, Parameter, UserFunctionInvocation} import org.neo4j.cypher.internal.frontend.v3_1.ast.{Expression, FunctionInvocation, Parameter}
import org.neo4j.cypher.internal.frontend.v3_1.{CypherException => InternalCypherException} import org.neo4j.cypher.internal.frontend.v3_1.{CypherException => InternalCypherException}


import scala.collection.mutable import scala.collection.mutable
Expand All @@ -38,8 +38,8 @@ object simpleExpressionEvaluator {


def isNonDeterministic(expr: Expression): Boolean = def isNonDeterministic(expr: Expression): Boolean =
expr.inputs.exists { expr.inputs.exists {
case (func@UserFunctionInvocation(_, _, _, _), _) if func.function == Rand => true case (func@FunctionInvocation(_, _, _, _), _) if func.function == Rand => true
case (func@UserFunctionInvocation(_, _, _, _), _) if func.function == Timestamp => true case (func@FunctionInvocation(_, _, _, _), _) if func.function == Timestamp => true
case _ => false case _ => false
} }


Expand Down
Expand Up @@ -83,16 +83,16 @@ object expandSolverStep {
(variable, innerPredicate) -> all (variable, innerPredicate) -> all
//MATCH p = ... WHERE all(n in nodes(p)... or all(r in relationships(p) //MATCH p = ... WHERE all(n in nodes(p)... or all(r in relationships(p)
case all@AllIterablePredicate(FilterScope(variable, Some(innerPredicate)), case all@AllIterablePredicate(FilterScope(variable, Some(innerPredicate)),
UserFunctionInvocation(_, FunctionName(fname), false, FunctionInvocation(_, FunctionName(fname), false,
Seq(PathExpression( Seq(PathExpression(
NodePathStep(startNode, MultiRelationshipPathStep(rel, _, NilPathStep) ))) )) NodePathStep(startNode, MultiRelationshipPathStep(rel, _, NilPathStep) ))) ))
if (fname == "nodes" || fname == "relationships") && startNode.name == nodeId.name && rel.name == patternRel.name.name => if (fname == "nodes" || fname == "relationships") && startNode.name == nodeId.name && rel.name == patternRel.name.name =>
(variable, innerPredicate) -> all (variable, innerPredicate) -> all


//MATCH p = ... WHERE all(n in nodes(p)... or all(r in relationships(p) //MATCH p = ... WHERE all(n in nodes(p)... or all(r in relationships(p)
case none@NoneIterablePredicate(FilterScope(variable, Some(innerPredicate)), case none@NoneIterablePredicate(FilterScope(variable, Some(innerPredicate)),
UserFunctionInvocation(_, FunctionName(fname), false, FunctionInvocation(_, FunctionName(fname), false,
Seq(PathExpression( Seq(PathExpression(
NodePathStep(startNode, MultiRelationshipPathStep(rel, _, NilPathStep) ))) )) NodePathStep(startNode, MultiRelationshipPathStep(rel, _, NilPathStep) ))) ))
if (fname == "nodes" || fname == "relationships") && startNode.name == nodeId.name && rel.name == patternRel.name.name => if (fname == "nodes" || fname == "relationships") && startNode.name == nodeId.name && rel.name == patternRel.name.name =>
(variable, Not(innerPredicate)(innerPredicate.position)) -> none (variable, Not(innerPredicate)(innerPredicate.position)) -> none
Expand Down
Expand Up @@ -35,7 +35,7 @@ object AsDynamicPropertyNonSeekable {
object AsDynamicPropertyNonScannable { object AsDynamicPropertyNonScannable {
def unapply(v: Any) = v match { def unapply(v: Any) = v match {


case func@UserFunctionInvocation(_, _, _, IndexedSeq(ContainerIndex(variable: Variable, _))) case func@FunctionInvocation(_, _, _, IndexedSeq(ContainerIndex(variable: Variable, _)))
if func.function == functions.Exists => if func.function == functions.Exists =>
Some(variable) Some(variable)


Expand Down
Expand Up @@ -38,7 +38,7 @@ object WithSeekableArgs {


object AsIdSeekable { object AsIdSeekable {
def unapply(v: Any) = v match { def unapply(v: Any) = v match {
case WithSeekableArgs(func@UserFunctionInvocation(_, _, _, IndexedSeq(ident: Variable)), rhs) case WithSeekableArgs(func@FunctionInvocation(_, _, _, IndexedSeq(ident: Variable)), rhs)
if func.function == functions.Id && !rhs.dependencies(ident) => if func.function == functions.Id && !rhs.dependencies(ident) =>
Some(IdSeekable(func, ident, rhs)) Some(IdSeekable(func, ident, rhs))
case _ => case _ =>
Expand All @@ -59,7 +59,7 @@ object AsPropertySeekable {
object AsPropertyScannable { object AsPropertyScannable {
def unapply(v: Any): Option[Scannable[Expression]] = v match { def unapply(v: Any): Option[Scannable[Expression]] = v match {


case func@UserFunctionInvocation(_, _, _, IndexedSeq(property@Property(ident: Variable, _))) case func@FunctionInvocation(_, _, _, IndexedSeq(property@Property(ident: Variable, _)))
if func.function == functions.Exists => if func.function == functions.Exists =>
Some(ExplicitlyPropertyScannable(func, ident, property)) Some(ExplicitlyPropertyScannable(func, ident, property))


Expand All @@ -85,7 +85,7 @@ object AsPropertyScannable {
private def partialPropertyPredicate[P <: Expression](predicate: P, lhs: Expression) = lhs match { private def partialPropertyPredicate[P <: Expression](predicate: P, lhs: Expression) = lhs match {
case property@Property(ident: Variable, _) => case property@Property(ident: Variable, _) =>
PartialPredicate.ifNotEqual( PartialPredicate.ifNotEqual(
UserFunctionInvocation(FunctionName(functions.Exists.name)(predicate.position), property)(predicate.position), FunctionInvocation(FunctionName(functions.Exists.name)(predicate.position), property)(predicate.position),
predicate predicate
).map(ImplicitlyPropertyScannable(_, ident, property)) ).map(ImplicitlyPropertyScannable(_, ident, property))


Expand Down Expand Up @@ -130,8 +130,8 @@ sealed trait EqualitySeekable[T <: Expression] extends Seekable[T] {
def args: SeekableArgs def args: SeekableArgs
} }


case class IdSeekable(expr: UserFunctionInvocation, ident: Variable, args: SeekableArgs) case class IdSeekable(expr: FunctionInvocation, ident: Variable, args: SeekableArgs)
extends EqualitySeekable[UserFunctionInvocation] { extends EqualitySeekable[FunctionInvocation] {


def dependencies = args.dependencies def dependencies = args.dependencies
} }
Expand Down Expand Up @@ -182,8 +182,8 @@ sealed trait Scannable[+T <: Expression] extends Sargable[T] {
def propertyKey = property.propertyKey def propertyKey = property.propertyKey
} }


case class ExplicitlyPropertyScannable(expr: UserFunctionInvocation, ident: Variable, property: Property) case class ExplicitlyPropertyScannable(expr: FunctionInvocation, ident: Variable, property: Property)
extends Scannable[UserFunctionInvocation] extends Scannable[FunctionInvocation]


case class ImplicitlyPropertyScannable[+T <: Expression](expr: PartialPredicate[T], ident: Variable, property: Property) case class ImplicitlyPropertyScannable[+T <: Expression](expr: PartialPredicate[T], ident: Variable, property: Property)
extends Scannable[PartialPredicate[T]] extends Scannable[PartialPredicate[T]]
Expand Down
Expand Up @@ -249,7 +249,7 @@ case class CollectionSubQueryExpressionSolver[T <: Expression](namer: T => (T, M
topDown(inner, stopper = { topDown(inner, stopper = {
case _: PatternComprehension => false case _: PatternComprehension => false
case _: ScopeExpression | _: CaseExpression => true case _: ScopeExpression | _: CaseExpression => true
case f: UserFunctionInvocation => f.function == Exists case f: FunctionInvocation => f.function == Exists
case _ => false case _ => false
}) })
} }
Expand Down
Expand Up @@ -54,15 +54,15 @@ case object countStorePlanner {
argumentIds: Set[IdName], selections: Selections)(implicit context: LogicalPlanningContext): Option[LogicalPlan] = argumentIds: Set[IdName], selections: Selections)(implicit context: LogicalPlanningContext): Option[LogicalPlan] =
exp match { exp match {
case // COUNT(<id>) case // COUNT(<id>)
func@UserFunctionInvocation(_, _, false, Vector(Variable(variableName))) if func.function == functions.Count => func@FunctionInvocation(_, _, false, Vector(Variable(variableName))) if func.function == functions.Count =>
trySolveNodeAggregation(query, columnName, Some(variableName), patternRelationships, patternNodes, argumentIds, selections) trySolveNodeAggregation(query, columnName, Some(variableName), patternRelationships, patternNodes, argumentIds, selections)


case // COUNT(*) case // COUNT(*)
CountStar() => CountStar() =>
trySolveNodeAggregation(query, columnName, None, patternRelationships, patternNodes, argumentIds, selections) trySolveNodeAggregation(query, columnName, None, patternRelationships, patternNodes, argumentIds, selections)


case // COUNT(n.prop) case // COUNT(n.prop)
func@UserFunctionInvocation(_, _, false, Vector(Property(Variable(variableName), PropertyKeyName(propKeyName)))) func@FunctionInvocation(_, _, false, Vector(Property(Variable(variableName), PropertyKeyName(propKeyName))))
if func.function == functions.Count => if func.function == functions.Count =>
val labelCheck: Option[LabelName] => (Option[LogicalPlan] => Option[LogicalPlan]) = { val labelCheck: Option[LabelName] => (Option[LogicalPlan] => Option[LogicalPlan]) = {
case None => _ => None case None => _ => None
Expand Down

0 comments on commit c71dfc4

Please sign in to comment.