Skip to content

Commit

Permalink
Make function be callable from Cypher
Browse files Browse the repository at this point in the history
- Introduce the necessary SPIs for calling functions
- Introduce command for executing functions in Cypher
  • Loading branch information
pontusmelke committed Sep 9, 2016
1 parent f403800 commit f0ae0ca
Show file tree
Hide file tree
Showing 38 changed files with 413 additions and 180 deletions.
Expand Up @@ -21,26 +21,45 @@ package org.neo4j.internal.cypher.acceptance

import java.util

import org.neo4j.cypher._

class FunctionCallSupportAcceptanceTest extends ProcedureCallAcceptanceTest {

ignore("should fail if calling procedure via rule planner") {
an [InternalException] shouldBe thrownBy(execute(
"CYPHER planner=rule CALL db.labels() YIELD label RETURN *"
))
}

ignore("should return correctly typed map result (even if converting to and from scala representation internally)") {
test("should return correctly typed map result (even if converting to and from scala representation internally)") {
val value = new util.HashMap[String, Any]()
value.put("name", "Cypher")
value.put("level", 9001)

registerFunctionReturningSingleValue(value)
registerUserFunction(value)

// Using graph execute to get a Java value
graph.execute("RETURN my.first.value()").stream().toArray.toList should equal(List(
java.util.Collections.singletonMap("my.first.value()", value)
))
}

test("should return correctly typed list result (even if converting to and from scala representation internally)") {
val value = new util.ArrayList[Any]()
value.add("Norris")
value.add("Strange")

registerUserFunction(value)

// Using graph execute to get a Java value
graph.execute("RETURN my.first.value() AS out").stream().toArray.toList should equal(List(
java.util.Collections.singletonMap("out", value)
))
}

test("should return correctly typed stream result (even if converting to and from scala representation internally)") {
val value = new util.ArrayList[Any]()
value.add("Norris")
value.add("Strange")
val stream = value.stream()

registerUserFunction(stream)

// Using graph execute to get a Java value
graph.execute("RETURN my.first.value() AS out").stream().toArray.toList should equal(List(
java.util.Collections.singletonMap("out", stream)
))
}
}
Expand Up @@ -24,10 +24,9 @@ import org.neo4j.cypher._
import org.neo4j.kernel.api.exceptions.ProcedureException
import org.neo4j.kernel.api.proc.CallableFunction.BasicFunction
import org.neo4j.kernel.api.proc.CallableProcedure.BasicProcedure
import org.neo4j.kernel.api.proc.FunctionSignature._
import org.neo4j.kernel.api.proc.ProcedureSignature._
import org.neo4j.kernel.api.proc._
import FunctionSignature._
import org.neo4j.kernel.api.proc.{Context, Neo4jTypes, ProcedureSignature, FunctionSignature}
import org.neo4j.kernel.api.proc.{Context, Neo4jTypes, ProcedureSignature}

abstract class ProcedureCallAcceptanceTest extends ExecutionEngineFunSuite {

Expand Down Expand Up @@ -57,7 +56,7 @@ abstract class ProcedureCallAcceptanceTest extends ExecutionEngineFunSuite {
}
}

protected def registerFunctionReturningSingleValue(value: AnyRef) =
protected def registerUserFunction(value: AnyRef) =
registerFunction("my.first.value") { builder =>
val builder = functionSignature(Array("my", "first"), "value")
builder.out("out", Neo4jTypes.NTAny)
Expand Down
Expand Up @@ -19,18 +19,18 @@
*/
package org.neo4j.cypher.internal.compiler.v3_1.ast

import org.neo4j.cypher.internal.compiler.v3_1.spi.{ProcedureReadOnlyAccess, ProcedureSignature, QualifiedProcedureName}
import org.neo4j.cypher.internal.compiler.v3_1.spi.{ProcedureReadOnlyAccess, ProcedureSignature, QualifiedName}
import org.neo4j.cypher.internal.frontend.v3_1.SemanticCheckResult._
import org.neo4j.cypher.internal.frontend.v3_1._
import org.neo4j.cypher.internal.frontend.v3_1.ast.Expression.SemanticContext
import org.neo4j.cypher.internal.frontend.v3_1.ast._
import org.neo4j.cypher.internal.frontend.v3_1.symbols.{CypherType, _}

object ResolvedCall {
def apply(signatureLookup: QualifiedProcedureName => ProcedureSignature)(unresolved: UnresolvedCall): ResolvedCall = {
def apply(signatureLookup: QualifiedName => ProcedureSignature)(unresolved: UnresolvedCall): ResolvedCall = {
val UnresolvedCall(_, _, declaredArguments, declaredResults) = unresolved
val position = unresolved.position
val signature = signatureLookup(QualifiedProcedureName(unresolved))
val signature = signatureLookup(QualifiedName(unresolved))
val nonDefaults = signature.inputSignature.flatMap(s => if (s.default.isDefined) None else Some(Parameter(s.name, CTAny)(position)))
val callArguments = declaredArguments.getOrElse(nonDefaults)
val callResults = declaredResults.getOrElse(signatureResults(signature, position))
Expand All @@ -51,7 +51,7 @@ case class ResolvedCall(signature: ProcedureSignature,
(val position: InputPosition)
extends CallClause {

def qualifiedName: QualifiedProcedureName = signature.name
def qualifiedName: QualifiedName = signature.name

def fullyDeclared: Boolean = declaredArguments && declaredResults

Expand Down
Expand Up @@ -19,39 +19,39 @@
*/
package org.neo4j.cypher.internal.compiler.v3_1.ast

import org.neo4j.cypher.internal.compiler.v3_1.spi.{UserDefinedFunctionSignature, ProcedureReadOnlyAccess, ProcedureSignature, QualifiedProcedureName}
import org.neo4j.cypher.internal.compiler.v3_1.spi._
import org.neo4j.cypher.internal.frontend.v3_1.SemanticCheckResult._
import org.neo4j.cypher.internal.frontend.v3_1._
import org.neo4j.cypher.internal.frontend.v3_1.ast.Expression.SemanticContext
import org.neo4j.cypher.internal.frontend.v3_1.ast._
import org.neo4j.cypher.internal.frontend.v3_1.ast.functions.UnresolvedFunction
import org.neo4j.cypher.internal.frontend.v3_1.symbols._

object ResolvedUserDefinedFunctionInvocation {
def apply(signatureLookup: QualifiedProcedureName => Option[UserDefinedFunctionSignature])(unresolved: FunctionInvocation): ResolvedUserDefinedFunctionInvocation = {
object ResolvedFunctionInvocation {

def apply(signatureLookup: QualifiedName => Option[UserDefinedFunctionSignature])(unresolved: FunctionInvocation): ResolvedFunctionInvocation = {
val position = unresolved.position
val name = QualifiedProcedureName(unresolved)
val name = QualifiedName(unresolved)
val signature = signatureLookup(name)
ResolvedUserDefinedFunctionInvocation(name, signature, unresolved.args)(position)
ResolvedFunctionInvocation(name, signature, unresolved.args)(position)
}
}

/**
* A ResolvedUserDefinedInvocation is a user-defined function where the signature
* has been resolve, i.e. verified that it exists in the database
*
* @param qualifiedName The qualified name of the function.
* @param fcnSignature Either `Some(signature)` if the signature was resolved, or
* `None` if the function didn't exist
* @param callArguments The argument list to the function
* @param position The position in the original query string.
*/
case class ResolvedUserDefinedFunctionInvocation(qualifiedName: QualifiedProcedureName,
fcnSignature: Option[UserDefinedFunctionSignature],
callArguments: IndexedSeq[Expression])
(val position: InputPosition)
case class ResolvedFunctionInvocation(qualifiedName: QualifiedName,
fcnSignature: Option[UserDefinedFunctionSignature],
callArguments: IndexedSeq[Expression])
(val position: InputPosition)
extends Expression with UserDefined {

def coerceArguments: ResolvedUserDefinedFunctionInvocation = fcnSignature match {
def coerceArguments: ResolvedFunctionInvocation = fcnSignature match {
case Some(signature) =>
val optInputFields = signature.inputSignature.map(Some(_)).toStream ++ Stream.continually(None)
val coercedArguments =
Expand Down
Expand Up @@ -20,15 +20,15 @@
package org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands

import org.neo4j.cypher.internal.compiler.v3_1._
import org.neo4j.cypher.internal.compiler.v3_1.ast._
import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.PatternConverters._
import org.neo4j.cypher.internal.compiler.v3_1.ast.rewriters.DesugaredMapProjection
import org.neo4j.cypher.internal.compiler.v3_1.ast.{InequalitySeekRangeWrapper, NestedPipeExpression, PrefixSeekRangeWrapper}
import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions.ProjectedPath._
import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions.{InequalitySeekRangeExpression, ProjectedPath, Expression => CommandExpression}
import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions.{Expression => CommandExpression, InequalitySeekRangeExpression, ProjectedPath}
import org.neo4j.cypher.internal.compiler.v3_1.commands.predicates.Predicate
import org.neo4j.cypher.internal.compiler.v3_1.commands.values.TokenType._
import org.neo4j.cypher.internal.compiler.v3_1.commands.values.UnresolvedRelType
import org.neo4j.cypher.internal.compiler.v3_1.commands.{PathExtractorExpression, predicates, expressions => commandexpressions, values => commandvalues}
import org.neo4j.cypher.internal.compiler.v3_1.commands.{PathExtractorExpression, expressions => commandexpressions, predicates, values => commandvalues}
import org.neo4j.cypher.internal.frontend.v3_1.ast._
import org.neo4j.cypher.internal.frontend.v3_1.ast.functions._
import org.neo4j.cypher.internal.frontend.v3_1.helpers.NonEmptyList
Expand Down Expand Up @@ -289,6 +289,7 @@ object ExpressionConverters {
case e: InequalitySeekRangeWrapper => InequalitySeekRangeExpression(e.range.mapBounds(toCommandExpression))
case e: ast.AndedPropertyInequalities => predicates.AndedPropertyComparablePredicates(variable(e.variable), toCommandProperty(e.property), e.inequalities.map(inequalityExpression))
case e: DesugaredMapProjection => commandexpressions.DesugaredMapProjection(e.name.name, e.includeAllProps, mapProjectionItems(e.items))
case e: ResolvedFunctionInvocation => commandexpressions.FunctionInvocation(e.fcnSignature.get, e.callArguments.map(toCommandExpression))
case e: ast.MapProjection => throw new InternalException("should have been rewritten away")
case _ =>
throw new InternalException(s"Unknown expression type during transformation (${expression.getClass})")
Expand Down
@@ -0,0 +1,53 @@
/*
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.compiler.v3_1.commands.expressions

import org.neo4j.cypher.internal.compiler.v3_1._
import org.neo4j.cypher.internal.compiler.v3_1.executionplan.ProcedureCallMode
import org.neo4j.cypher.internal.compiler.v3_1.helpers.RuntimeScalaValueConverter
import org.neo4j.cypher.internal.compiler.v3_1.mutation.GraphElementPropertyFunctions
import org.neo4j.cypher.internal.compiler.v3_1.pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_1.spi.UserDefinedFunctionSignature
import org.neo4j.cypher.internal.compiler.v3_1.symbols.SymbolTable

case class FunctionInvocation(signature: UserDefinedFunctionSignature, arguments: IndexedSeq[Expression])
extends Expression with GraphElementPropertyFunctions {

private val callMode = ProcedureCallMode.fromAccessMode(signature.accessMode)

override def apply(ctx: ExecutionContext)(implicit state: QueryState): Any = {
val query = state.query

val result = callMode.callFunction(query, signature.name, arguments)

val isGraphKernelResultValue = query.isGraphKernelResultValue _
val scalaValues = new RuntimeScalaValueConverter(isGraphKernelResultValue, state.typeConverter.asPrivateType)
scalaValues.asDeepScalaValue(result)
}

override def rewrite(f: (Expression) => Expression) =
f(FunctionInvocation(signature, arguments.map(a => a.rewrite(f))))

override def calculateType(symbols: SymbolTable) = signature.outputField.typ

override def symbolTableDependencies = arguments.flatMap(_.symbolTableDependencies).toSet

override def toString = s"${signature.name}(${arguments.mkString(",")})"
}
Expand Up @@ -34,53 +34,67 @@ object ProcedureCallMode {
sealed trait ProcedureCallMode {
val queryType: InternalQueryType

def call(ctx: QueryContext, name: QualifiedProcedureName, args: Seq[Any]): Iterator[Array[AnyRef]]
def callProcedure(ctx: QueryContext, name: QualifiedName, args: Seq[Any]): Iterator[Array[AnyRef]]

def callFunction(ctx: QueryContext, name: QualifiedName, args: Seq[Any]): AnyRef

val allowed: Array[String]
}

case class LazyReadOnlyCallMode(allowed: Array[String]) extends ProcedureCallMode {
override val queryType: InternalQueryType = READ_ONLY

override def call(ctx: QueryContext, name: QualifiedProcedureName, args: Seq[Any]): Iterator[Array[AnyRef]] =
override def callProcedure(ctx: QueryContext, name: QualifiedName, args: Seq[Any]): Iterator[Array[AnyRef]] =
ctx.callReadOnlyProcedure(name, args, allowed)

override def callFunction(ctx: QueryContext, name: QualifiedName, args: Seq[Any]) =
ctx.callReadOnlyFunction(name, args, allowed)
}

case class EagerReadWriteCallMode(allowed: Array[String]) extends ProcedureCallMode {
override val queryType: InternalQueryType = READ_WRITE

override def call(ctx: QueryContext, name: QualifiedProcedureName, args: Seq[Any]): Iterator[Array[AnyRef]] = {
override def callProcedure(ctx: QueryContext, name: QualifiedName, args: Seq[Any]): Iterator[Array[AnyRef]] = {
val builder = ArrayBuffer.newBuilder[Array[AnyRef]]
val iterator = ctx.callReadWriteProcedure(name, args, allowed)
while (iterator.hasNext) {
builder += iterator.next()
}
builder.result().iterator
}

override def callFunction(ctx: QueryContext, name: QualifiedName, args: Seq[Any]) =
ctx.callReadWriteFunction(name, args, allowed)
}

case class SchemaWriteCallMode(allowed: Array[String]) extends ProcedureCallMode {
override val queryType: InternalQueryType = SCHEMA_WRITE

override def call(ctx: QueryContext, name: QualifiedProcedureName, args: Seq[Any]): Iterator[Array[AnyRef]] = {
override def callProcedure(ctx: QueryContext, name: QualifiedName, args: Seq[Any]): Iterator[Array[AnyRef]] = {
val builder = ArrayBuffer.newBuilder[Array[AnyRef]]
val iterator = ctx.callSchemaWriteProcedure(name, args, allowed)
while (iterator.hasNext) {
builder += iterator.next()
}
builder.result().iterator
}

override def callFunction(ctx: QueryContext, name: QualifiedName, args: Seq[Any]) =
ctx.callSchemaWriteFunction(name, args, allowed)
}

case class DbmsCallMode(allowed: Array[String]) extends ProcedureCallMode {
override val queryType: InternalQueryType = DBMS

override def call(ctx: QueryContext, name: QualifiedProcedureName, args: Seq[Any]): Iterator[Array[AnyRef]] = {
override def callProcedure(ctx: QueryContext, name: QualifiedName, args: Seq[Any]): Iterator[Array[AnyRef]] = {
val builder = ArrayBuffer.newBuilder[Array[AnyRef]]
val iterator = ctx.callDbmsProcedure(name, args, allowed)
while (iterator.hasNext) {
builder += iterator.next()
}
builder.result().iterator
}

override def callFunction(ctx: QueryContext, name: QualifiedName, args: Seq[Any]) =
ctx.callDbmsFunction(name, args, allowed)
}
Expand Up @@ -24,7 +24,7 @@ import java.util
import org.neo4j.cypher.internal.compiler.v3_1.codegen.ResultRowImpl
import org.neo4j.cypher.internal.compiler.v3_1.executionplan.{InternalQueryType, ProcedureCallMode, StandardInternalExecutionResult}
import org.neo4j.cypher.internal.compiler.v3_1.planDescription.InternalPlanDescription
import org.neo4j.cypher.internal.compiler.v3_1.spi.{QualifiedProcedureName, InternalResultVisitor, QueryContext}
import org.neo4j.cypher.internal.compiler.v3_1.spi.{InternalResultVisitor, QualifiedName, QueryContext}
import org.neo4j.cypher.internal.compiler.v3_1.{ExecutionMode, InternalQueryStatistics, ProfileMode, TaskCloser}
import org.neo4j.cypher.internal.frontend.v3_1.ProfilerStatisticsNotReadyException

Expand All @@ -42,7 +42,7 @@ import org.neo4j.cypher.internal.frontend.v3_1.ProfilerStatisticsNotReadyExcepti
*/
class ProcedureExecutionResult[E <: Exception](context: QueryContext,
taskCloser: TaskCloser,
name: QualifiedProcedureName,
name: QualifiedName,
callMode: ProcedureCallMode,
args: Seq[Any],
indexResultNameMappings: Seq[(Int, String)],
Expand All @@ -55,7 +55,7 @@ class ProcedureExecutionResult[E <: Exception](context: QueryContext,
private final val executionResults = executeCall

// The signature mode is taking care of eagerization
protected def executeCall = callMode.call(context, name, args)
protected def executeCall = callMode.callProcedure(context, name, args)

override protected def createInner = new util.Iterator[util.Map[String, Any]]() {
override def next(): util.Map[String, Any] =
Expand Down

0 comments on commit f0ae0ca

Please sign in to comment.