Skip to content

Commit

Permalink
Added @name to functions and reverted name change
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke committed Sep 9, 2016
1 parent 308a55b commit c71dfc4
Show file tree
Hide file tree
Showing 73 changed files with 289 additions and 219 deletions.
Expand Up @@ -34,7 +34,7 @@ import org.neo4j.cypher.internal.compiler.v3_1.planner.logical.plans.rewriter.Lo
import org.neo4j.cypher.internal.compiler.v3_1.planner.logical.{CachedMetricsFactory, DefaultQueryPlanner, SimpleMetricsFactory}
import org.neo4j.cypher.internal.compiler.v3_1.spi.{PlanContext, ProcedureSignature}
import org.neo4j.cypher.internal.compiler.v3_1.tracing.rewriters.RewriterStepSequencer
import org.neo4j.cypher.internal.frontend.v3_1.ast.{FunctionName, Statement, UserFunctionInvocation}
import org.neo4j.cypher.internal.frontend.v3_1.ast.{FunctionInvocation, FunctionName, Statement}
import org.neo4j.cypher.internal.frontend.v3_1.notification.{DeprecatedFunctionNotification, DeprecatedProcedureNotification, InternalNotification}
import org.neo4j.cypher.internal.frontend.v3_1.parser.CypherParser
import org.neo4j.cypher.internal.frontend.v3_1.{InputPosition, SemanticTable, inSequence}
Expand Down Expand Up @@ -230,7 +230,7 @@ case class CypherCompiler(parser: CypherParser,

private def syntaxDeprecationNotifications(statement: Statement): Set[InternalNotification] =
statement.treeFold(Set.empty[InternalNotification]) {
case f@UserFunctionInvocation(_, FunctionName(name), _, _) if aliases.get(name).nonEmpty =>
case f@FunctionInvocation(_, FunctionName(name), _, _) if aliases.get(name).nonEmpty =>
(seq) => (seq + DeprecatedFunctionNotification(f.position, name, aliases(name)), None)
}

Expand Down
Expand Up @@ -300,9 +300,9 @@ object QueryTagger extends QueryTagger[String] {

// functions
lift[ASTNode] {
case f: UserFunctionInvocation if mathFunctions contains f.function => Set(MathFunctionTag)
case f: UserFunctionInvocation if stringFunctions contains f.function => Set(StringFunctionTag)
case f: UserFunctionInvocation if isAggregation(f.function) => Set(AggregationTag)
case f: FunctionInvocation if mathFunctions contains f.function => Set(MathFunctionTag)
case f: FunctionInvocation if stringFunctions contains f.function => Set(StringFunctionTag)
case f: FunctionInvocation if isAggregation(f.function) => Set(AggregationTag)
}
))

Expand Down
Expand Up @@ -27,7 +27,7 @@ import org.neo4j.cypher.internal.frontend.v3_1.ast._

object ResolvedFunctionInvocation {

def apply(signatureLookup: QualifiedName => Option[UserFunctionSignature])(unresolved: UserFunctionInvocation): ResolvedFunctionInvocation = {
def apply(signatureLookup: QualifiedName => Option[UserFunctionSignature])(unresolved: FunctionInvocation): ResolvedFunctionInvocation = {
val position = unresolved.position
val name = QualifiedName(unresolved)
val signature = signatureLookup(name)
Expand Down Expand Up @@ -59,7 +59,7 @@ case class ResolvedFunctionInvocation(qualifiedName: QualifiedName,
.zip(optInputFields)
.map {
case (arg, optField) =>
optField.map { field => CoerceTo(arg, field) }.getOrElse(arg)
optField.map { field => CoerceTo(arg, field.typ) }.getOrElse(arg)
}
copy(callArguments = coercedArguments)(position)
case None => this
Expand All @@ -69,17 +69,18 @@ case class ResolvedFunctionInvocation(qualifiedName: QualifiedName,
case None => SemanticError(s"Unknown function '$qualifiedName'", position)
case Some(signature) =>
val expectedNumArgs = signature.inputSignature.length
val actualNumArgs = callArguments.length
val usedDefaultArgs = signature.inputSignature.drop(callArguments.length).flatMap(_.default)
val actualNumArgs = callArguments.length + usedDefaultArgs.length

if (expectedNumArgs == actualNumArgs) {
signature.inputSignature.zip(callArguments).map {
case (field, arg) =>
arg.semanticCheck(SemanticContext.Results) chain arg.expectType(field.covariant)
}.foldLeft(success)(_ chain _)
} else {
error(_: SemanticState,
SemanticError(s"Function call does not provide the required number of arguments ($expectedNumArgs)",
position))
}
if (expectedNumArgs == actualNumArgs) {
//this zip is fine since it will only verify provided args in callArguments
//default values are checked at load time
signature.inputSignature.zip(callArguments).map {
case (field, arg) =>
arg.semanticCheck(SemanticContext.Results) chain arg.expectType(field.typ.covariant)
}.foldLeft(success)(_ chain _)
} else {
error(_: SemanticState, SemanticError(s"Function call does not provide the required number of arguments ($expectedNumArgs)", position))
}
}
}
Expand Up @@ -20,23 +20,23 @@
package org.neo4j.cypher.internal.compiler.v3_1.ast.conditions

import org.neo4j.cypher.internal.compiler.v3_1.tracing.rewriters.Condition
import org.neo4j.cypher.internal.frontend.v3_1.ast.{Equals, Expression, Property, UserFunctionInvocation, functions}
import org.neo4j.cypher.internal.frontend.v3_1.ast.{Equals, Expression, FunctionInvocation, Property, functions}

case object normalizedEqualsArguments extends Condition {
def apply(that: Any): Seq[String] = {
val equals = collectNodesOfType[Equals].apply(that)
equals.collect {
case eq@Equals(expr, Property(_,_)) if !expr.isInstanceOf[Property] && notIdFunction(expr) =>
s"Equals at ${eq.position} is not normalized: $eq"
case eq@Equals(expr, func@UserFunctionInvocation(_, _, _, _)) if isIdFunction(func) && notIdFunction(expr) =>
case eq@Equals(expr, func@FunctionInvocation(_, _, _, _)) if isIdFunction(func) && notIdFunction(expr) =>
s"Equals at ${eq.position} is not normalized: $eq"
}
}

private def isIdFunction(func: UserFunctionInvocation) = func.function == functions.Id
private def isIdFunction(func: FunctionInvocation) = func.function == functions.Id

private def notIdFunction(expr: Expression) =
!expr.isInstanceOf[UserFunctionInvocation] || !isIdFunction(expr.asInstanceOf[UserFunctionInvocation])
!expr.isInstanceOf[FunctionInvocation] || !isIdFunction(expr.asInstanceOf[FunctionInvocation])

override def name: String = productPrefix
}
Expand Up @@ -36,7 +36,7 @@ import org.neo4j.cypher.internal.frontend.v3_1.{InternalException, SemanticDirec
import org.neo4j.graphdb.Direction

object ExpressionConverters {
def toCommandExpression(expression: ast.Function, invocation: ast.UserFunctionInvocation): CommandExpression =
def toCommandExpression(expression: ast.Function, invocation: ast.FunctionInvocation): CommandExpression =
expression match {
case Abs => commandexpressions.AbsFunction(toCommandExpression(invocation.arguments.head))
case Acos => commandexpressions.AcosFunction(toCommandExpression(invocation.arguments.head))
Expand Down Expand Up @@ -259,7 +259,7 @@ object ExpressionConverters {
case e: ast.Divide => commandexpressions.Divide(toCommandExpression(e.lhs), toCommandExpression(e.rhs))
case e: ast.Modulo => commandexpressions.Modulo(toCommandExpression(e.lhs), toCommandExpression(e.rhs))
case e: ast.Pow => commandexpressions.Pow(toCommandExpression(e.lhs), toCommandExpression(e.rhs))
case e: ast.UserFunctionInvocation => toCommandExpression(e.function, e)
case e: ast.FunctionInvocation => toCommandExpression(e.function, e)
case e: ast.CountStar => commandexpressions.CountStar()
case e: ast.Property => toCommandProperty(e)
case e: ast.Parameter => toCommandParameter(e)
Expand Down Expand Up @@ -289,7 +289,11 @@ object ExpressionConverters {
case e: InequalitySeekRangeWrapper => InequalitySeekRangeExpression(e.range.mapBounds(toCommandExpression))
case e: ast.AndedPropertyInequalities => predicates.AndedPropertyComparablePredicates(variable(e.variable), toCommandProperty(e.property), e.inequalities.map(inequalityExpression))
case e: DesugaredMapProjection => commandexpressions.DesugaredMapProjection(e.name.name, e.includeAllProps, mapProjectionItems(e.items))
case e: ResolvedFunctionInvocation => commandexpressions.FunctionInvocation(e.fcnSignature.get, e.callArguments.map(toCommandExpression))
case e: ResolvedFunctionInvocation =>
val callArgumentCommands = e.callArguments.map(Some(_)).zipAll(e.fcnSignature.get.inputSignature.map(_.default.map(_.value)), None, None).map {
case (given, default) => given.map(toCommandExpression).getOrElse(commandexpressions.Literal(default.get))
}
commandexpressions.FunctionInvocation(e.fcnSignature.get, callArgumentCommands)
case e: ast.MapProjection => throw new InternalException("should have been rewritten away")
case _ =>
throw new InternalException(s"Unknown expression type during transformation (${expression.getClass})")
Expand Down
Expand Up @@ -33,10 +33,10 @@ case object normalizeArgumentOrder extends Rewriter {
private val instance: Rewriter = topDown(Rewriter.lift {

// move id(n) on equals to the left
case predicate @ Equals(func@UserFunctionInvocation(_, _, _, _), _) if func.function == functions.Id =>
case predicate @ Equals(func@FunctionInvocation(_, _, _, _), _) if func.function == functions.Id =>
predicate

case predicate @ Equals(lhs, rhs @ UserFunctionInvocation(_, _, _, _)) if rhs.function == functions.Id =>
case predicate @ Equals(lhs, rhs @ FunctionInvocation(_, _, _, _)) if rhs.function == functions.Id =>
predicate.copy(lhs = rhs, rhs = lhs)(predicate.position)

// move n.prop on equals to the left
Expand Down
Expand Up @@ -37,7 +37,7 @@ case object replaceAliasedFunctionInvocations extends Rewriter {
"rels" -> "relationships")(CaseInsensitiveOrdered)

val instance: Rewriter = bottomUp(Rewriter.lift {
case func@UserFunctionInvocation(_, f@FunctionName(name), _, _) if aliases.get(name).nonEmpty =>
case func@FunctionInvocation(_, f@FunctionName(name), _, _) if aliases.get(name).nonEmpty =>
func.copy(functionName = FunctionName(aliases(name))(f.position))(func.position)
})

Expand Down
Expand Up @@ -32,7 +32,7 @@ case object rewriteEqualityToInCollection extends Rewriter {

private val instance: Rewriter = bottomUp(Rewriter.lift {
// id(a) = value => id(a) IN [value]
case predicate@Equals(func@UserFunctionInvocation(_, _, _, IndexedSeq(idExpr)), idValueExpr)
case predicate@Equals(func@FunctionInvocation(_, _, _, IndexedSeq(idExpr)), idValueExpr)
if func.function == functions.Id =>
In(func, Collection(Seq(idValueExpr))(idValueExpr.position))(predicate.position)

Expand Down
Expand Up @@ -158,7 +158,7 @@ object ExpressionConverter {

case ast.Not(inner) => Not(callback(inner))

case f: ast.UserFunctionInvocation => functionConverter(f, callback)
case f: ast.FunctionInvocation => functionConverter(f, callback)

case other => throw new CantCompileQueryException(s"Expression of $other not yet supported")
}
Expand Down
Expand Up @@ -26,7 +26,7 @@ import org.neo4j.cypher.internal.frontend.v3_1.ast

object functionConverter {

def apply(fcn: ast.UserFunctionInvocation, callback: ast.Expression => CodeGenExpression)
def apply(fcn: ast.FunctionInvocation, callback: ast.Expression => CodeGenExpression)
(implicit context: CodeGenContext): CodeGenExpression = fcn.function match {

// id(n)
Expand Down
Expand Up @@ -23,7 +23,7 @@ import org.neo4j.cypher.internal.compiler.v3_1.ExecutionContext
import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.ExpressionConverters
import org.neo4j.cypher.internal.compiler.v3_1.pipes.{NullPipeDecorator, QueryState}
import org.neo4j.cypher.internal.frontend.v3_1.ast.functions.{Rand, Timestamp}
import org.neo4j.cypher.internal.frontend.v3_1.ast.{Expression, Parameter, UserFunctionInvocation}
import org.neo4j.cypher.internal.frontend.v3_1.ast.{Expression, FunctionInvocation, Parameter}
import org.neo4j.cypher.internal.frontend.v3_1.{CypherException => InternalCypherException}

import scala.collection.mutable
Expand All @@ -38,8 +38,8 @@ object simpleExpressionEvaluator {

def isNonDeterministic(expr: Expression): Boolean =
expr.inputs.exists {
case (func@UserFunctionInvocation(_, _, _, _), _) if func.function == Rand => true
case (func@UserFunctionInvocation(_, _, _, _), _) if func.function == Timestamp => true
case (func@FunctionInvocation(_, _, _, _), _) if func.function == Rand => true
case (func@FunctionInvocation(_, _, _, _), _) if func.function == Timestamp => true
case _ => false
}

Expand Down
Expand Up @@ -83,16 +83,16 @@ object expandSolverStep {
(variable, innerPredicate) -> all
//MATCH p = ... WHERE all(n in nodes(p)... or all(r in relationships(p)
case all@AllIterablePredicate(FilterScope(variable, Some(innerPredicate)),
UserFunctionInvocation(_, FunctionName(fname), false,
Seq(PathExpression(
FunctionInvocation(_, FunctionName(fname), false,
Seq(PathExpression(
NodePathStep(startNode, MultiRelationshipPathStep(rel, _, NilPathStep) ))) ))
if (fname == "nodes" || fname == "relationships") && startNode.name == nodeId.name && rel.name == patternRel.name.name =>
(variable, innerPredicate) -> all

//MATCH p = ... WHERE all(n in nodes(p)... or all(r in relationships(p)
case none@NoneIterablePredicate(FilterScope(variable, Some(innerPredicate)),
UserFunctionInvocation(_, FunctionName(fname), false,
Seq(PathExpression(
FunctionInvocation(_, FunctionName(fname), false,
Seq(PathExpression(
NodePathStep(startNode, MultiRelationshipPathStep(rel, _, NilPathStep) ))) ))
if (fname == "nodes" || fname == "relationships") && startNode.name == nodeId.name && rel.name == patternRel.name.name =>
(variable, Not(innerPredicate)(innerPredicate.position)) -> none
Expand Down
Expand Up @@ -35,7 +35,7 @@ object AsDynamicPropertyNonSeekable {
object AsDynamicPropertyNonScannable {
def unapply(v: Any) = v match {

case func@UserFunctionInvocation(_, _, _, IndexedSeq(ContainerIndex(variable: Variable, _)))
case func@FunctionInvocation(_, _, _, IndexedSeq(ContainerIndex(variable: Variable, _)))
if func.function == functions.Exists =>
Some(variable)

Expand Down
Expand Up @@ -38,7 +38,7 @@ object WithSeekableArgs {

object AsIdSeekable {
def unapply(v: Any) = v match {
case WithSeekableArgs(func@UserFunctionInvocation(_, _, _, IndexedSeq(ident: Variable)), rhs)
case WithSeekableArgs(func@FunctionInvocation(_, _, _, IndexedSeq(ident: Variable)), rhs)
if func.function == functions.Id && !rhs.dependencies(ident) =>
Some(IdSeekable(func, ident, rhs))
case _ =>
Expand All @@ -59,7 +59,7 @@ object AsPropertySeekable {
object AsPropertyScannable {
def unapply(v: Any): Option[Scannable[Expression]] = v match {

case func@UserFunctionInvocation(_, _, _, IndexedSeq(property@Property(ident: Variable, _)))
case func@FunctionInvocation(_, _, _, IndexedSeq(property@Property(ident: Variable, _)))
if func.function == functions.Exists =>
Some(ExplicitlyPropertyScannable(func, ident, property))

Expand All @@ -85,7 +85,7 @@ object AsPropertyScannable {
private def partialPropertyPredicate[P <: Expression](predicate: P, lhs: Expression) = lhs match {
case property@Property(ident: Variable, _) =>
PartialPredicate.ifNotEqual(
UserFunctionInvocation(FunctionName(functions.Exists.name)(predicate.position), property)(predicate.position),
FunctionInvocation(FunctionName(functions.Exists.name)(predicate.position), property)(predicate.position),
predicate
).map(ImplicitlyPropertyScannable(_, ident, property))

Expand Down Expand Up @@ -130,8 +130,8 @@ sealed trait EqualitySeekable[T <: Expression] extends Seekable[T] {
def args: SeekableArgs
}

case class IdSeekable(expr: UserFunctionInvocation, ident: Variable, args: SeekableArgs)
extends EqualitySeekable[UserFunctionInvocation] {
case class IdSeekable(expr: FunctionInvocation, ident: Variable, args: SeekableArgs)
extends EqualitySeekable[FunctionInvocation] {

def dependencies = args.dependencies
}
Expand Down Expand Up @@ -182,8 +182,8 @@ sealed trait Scannable[+T <: Expression] extends Sargable[T] {
def propertyKey = property.propertyKey
}

case class ExplicitlyPropertyScannable(expr: UserFunctionInvocation, ident: Variable, property: Property)
extends Scannable[UserFunctionInvocation]
case class ExplicitlyPropertyScannable(expr: FunctionInvocation, ident: Variable, property: Property)
extends Scannable[FunctionInvocation]

case class ImplicitlyPropertyScannable[+T <: Expression](expr: PartialPredicate[T], ident: Variable, property: Property)
extends Scannable[PartialPredicate[T]]
Expand Down
Expand Up @@ -249,7 +249,7 @@ case class CollectionSubQueryExpressionSolver[T <: Expression](namer: T => (T, M
topDown(inner, stopper = {
case _: PatternComprehension => false
case _: ScopeExpression | _: CaseExpression => true
case f: UserFunctionInvocation => f.function == Exists
case f: FunctionInvocation => f.function == Exists
case _ => false
})
}
Expand Down
Expand Up @@ -54,15 +54,15 @@ case object countStorePlanner {
argumentIds: Set[IdName], selections: Selections)(implicit context: LogicalPlanningContext): Option[LogicalPlan] =
exp match {
case // COUNT(<id>)
func@UserFunctionInvocation(_, _, false, Vector(Variable(variableName))) if func.function == functions.Count =>
func@FunctionInvocation(_, _, false, Vector(Variable(variableName))) if func.function == functions.Count =>
trySolveNodeAggregation(query, columnName, Some(variableName), patternRelationships, patternNodes, argumentIds, selections)

case // COUNT(*)
CountStar() =>
trySolveNodeAggregation(query, columnName, None, patternRelationships, patternNodes, argumentIds, selections)

case // COUNT(n.prop)
func@UserFunctionInvocation(_, _, false, Vector(Property(Variable(variableName), PropertyKeyName(propKeyName))))
func@FunctionInvocation(_, _, false, Vector(Property(Variable(variableName), PropertyKeyName(propKeyName))))
if func.function == functions.Count =>
val labelCheck: Option[LabelName] => (Option[LogicalPlan] => Option[LogicalPlan]) = {
case None => _ => None
Expand Down

0 comments on commit c71dfc4

Please sign in to comment.