From c50ae5e2b805505182a89056f45250a7992aa19c Mon Sep 17 00:00:00 2001 From: Pontus Melke Date: Tue, 6 Mar 2018 11:30:00 +0100 Subject: [PATCH] Support calling by name as well as by id In order to support older versions of Cypher (3.3) we need to keep the infrastructure for calling procedures et al by name as well. --- .../logical/plans/ProcedureSignature.scala | 8 +- .../v3_4/phases/DeprecationWarnings.scala | 6 +- .../v3_4/RewriteProcedureCallsTest.scala | 2 +- .../compiler/v3_4/ast/CallClauseTest.scala | 36 ++-- .../StatementConvertersTest.scala | 1 - .../planner/LogicalPlanningTestSupport.scala | 1 - .../logical/PlanEventHorizonTest.scala | 2 +- .../v3_3/LogicalPlanConverter.scala | 32 ++- .../ExceptionTranslatingQueryContext.scala | 19 ++ .../procs/ProcedureExecutionResult.scala | 6 +- .../v3_3/LogicalPlanConverterTest.scala | 4 +- .../ProcedureCallExecutionPlanTest.scala | 2 - .../interpreted/DelegatingQueryContext.scala | 20 ++ .../TransactionBoundPlanContext.scala | 6 +- .../TransactionBoundQueryContext.scala | 94 +++++++-- .../CommunityExpressionConverter.scala | 15 +- .../AggregationFunctionInvocation.scala | 27 ++- .../expressions/FunctionInvocation.scala | 34 ++- .../interpreted/pipes/ProcedureCallPipe.scala | 13 +- .../interpreted/QueryContextAdaptation.scala | 12 ++ .../pipes/ProcedureCallPipeTest.scala | 3 +- .../internal/runtime/ProcedureCallMode.scala | 37 +++- .../internal/runtime/QueryContext.scala | 7 + .../neo4j/internal/kernel/api/Procedures.java | 92 ++++++++- .../kernel/impl/newapi/AllStoreHolder.java | 195 ++++++++++++++++-- .../neo4j/kernel/impl/newapi/MockStore.java | 74 ++++++- 26 files changed, 647 insertions(+), 101 deletions(-) diff --git a/community/cypher/cypher-logical-plans-3.4/src/main/scala/org/neo4j/cypher/internal/v3_4/logical/plans/ProcedureSignature.scala b/community/cypher/cypher-logical-plans-3.4/src/main/scala/org/neo4j/cypher/internal/v3_4/logical/plans/ProcedureSignature.scala index 716b921242896..0b8f6d46a0f16 100644 --- a/community/cypher/cypher-logical-plans-3.4/src/main/scala/org/neo4j/cypher/internal/v3_4/logical/plans/ProcedureSignature.scala +++ b/community/cypher/cypher-logical-plans-3.4/src/main/scala/org/neo4j/cypher/internal/v3_4/logical/plans/ProcedureSignature.scala @@ -24,13 +24,13 @@ import org.neo4j.cypher.internal.util.v3_4.symbols.CypherType import org.neo4j.cypher.internal.v3_4.expressions.FunctionInvocation case class ProcedureSignature(name: QualifiedName, - id: Int, inputSignature: IndexedSeq[FieldSignature], outputSignature: Option[IndexedSeq[FieldSignature]], deprecationInfo: Option[String], accessMode: ProcedureAccessMode, description: Option[String] = None, - warning: Option[String] = None) { + warning: Option[String] = None, + id: Option[Int] = None) { def outputFields = outputSignature.getOrElse(Seq.empty) @@ -43,13 +43,13 @@ case class ProcedureSignature(name: QualifiedName, } case class UserFunctionSignature(name: QualifiedName, - id: Int, inputSignature: IndexedSeq[FieldSignature], outputType: CypherType, deprecationInfo: Option[String], allowed: Array[String], description: Option[String], - isAggregate: Boolean) { + isAggregate: Boolean, + id: Option[Int] = None) { override def toString = s"$name(${inputSignature.mkString(", ")}) :: ${outputType.toNeoTypeString}" } diff --git a/community/cypher/cypher-planner-3.4/src/main/scala/org/neo4j/cypher/internal/compiler/v3_4/phases/DeprecationWarnings.scala b/community/cypher/cypher-planner-3.4/src/main/scala/org/neo4j/cypher/internal/compiler/v3_4/phases/DeprecationWarnings.scala index 1818db0bea02c..33a32f7c1b706 100644 --- a/community/cypher/cypher-planner-3.4/src/main/scala/org/neo4j/cypher/internal/compiler/v3_4/phases/DeprecationWarnings.scala +++ b/community/cypher/cypher-planner-3.4/src/main/scala/org/neo4j/cypher/internal/compiler/v3_4/phases/DeprecationWarnings.scala @@ -35,7 +35,7 @@ object ProcedureDeprecationWarnings extends VisitorPhase[BaseContext, BaseState] private def findDeprecations(statement: Statement): Set[InternalNotification] = statement.treeFold(Set.empty[InternalNotification]) { - case f@ResolvedCall(ProcedureSignature(name, _, _, _, Some(deprecatedBy), _, _, _), _, _, _, _) => + case f@ResolvedCall(ProcedureSignature(name, _, _, Some(deprecatedBy), _, _, _, _), _, _, _, _) => (seq) => (seq + DeprecatedProcedureNotification(f.position, name.toString, deprecatedBy), None) case _:UnresolvedCall => throw new InternalException("Expected procedures to have been resolved already") @@ -55,9 +55,9 @@ object ProcedureWarnings extends VisitorPhase[BaseContext, BaseState] { private def findWarnings(statement: Statement): Set[InternalNotification] = statement.treeFold(Set.empty[InternalNotification]) { - case f@ResolvedCall(ProcedureSignature(name, _,_, _, _, _, _, Some(warning)), _, _, _, _) => + case f@ResolvedCall(ProcedureSignature(name, _, _, _, _, _, Some(warning),_), _, _, _, _) => (seq) => (seq + ProcedureWarningNotification(f.position, name.toString, warning), None) - case ResolvedCall(ProcedureSignature(name, _, _, Some(output), None, _, _, _), _, results, _, _) + case ResolvedCall(ProcedureSignature(name, _, Some(output), None, _, _, _, _), _, results, _, _) if output.exists(_.deprecated) => (set) => (set ++ usedDeprecatedFields(name.toString, results, output), None) case _:UnresolvedCall => throw new InternalException("Expected procedures to have been resolved already") diff --git a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/RewriteProcedureCallsTest.scala b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/RewriteProcedureCallsTest.scala index b357869b447e5..b1476586da655 100644 --- a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/RewriteProcedureCallsTest.scala +++ b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/RewriteProcedureCallsTest.scala @@ -34,7 +34,7 @@ class RewriteProcedureCallsTest extends CypherFunSuite with AstConstructionTestS val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, 42, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty[String])) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty[String])) val procLookup: (QualifiedName) => ProcedureSignature = _ => signature val fcnLookup: (QualifiedName) => Option[UserFunctionSignature] = _ => None diff --git a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/ast/CallClauseTest.scala b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/ast/CallClauseTest.scala index a90d8ca4d5809..71a81e3b4d3c4 100644 --- a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/ast/CallClauseTest.scala +++ b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/ast/CallClauseTest.scala @@ -37,7 +37,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { val unresolved = UnresolvedCall(ns, name, None, None)(pos) val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = IndexedSeq(Parameter("a", CTAny)(pos)) val callResults = IndexedSeq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos)) @@ -62,7 +63,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { val unresolved = UnresolvedCall(ns, name, None, None)(pos) val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = None - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = IndexedSeq.empty @@ -86,7 +88,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should resolve CALL my.proc.foo YIELD x, y") { val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = IndexedSeq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos)) val unresolved = UnresolvedCall(ns, name, None, Some(ProcedureResult(callResults)(pos)))(pos) @@ -111,7 +114,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should resolve CALL my.proc.foo(a)") { val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = IndexedSeq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos)) val unresolved = UnresolvedCall(ns, name, Some(callArguments), None)(pos) @@ -136,7 +140,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should resolve void CALL my.proc.foo(a)") { val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = None - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = IndexedSeq.empty val unresolved = UnresolvedCall(ns, name, Some(callArguments), None)(pos) @@ -161,7 +166,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should resolve CALL my.proc.foo(a) YIELD x, y AS z") { val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = IndexedSeq( ProcedureResultItem(varFor("x"))(pos), @@ -185,7 +191,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("pretends to be based on user-declared arguments and results upon request") { - val signature = ProcedureSignature(qualifiedName, ID, IndexedSeq.empty, Some(IndexedSeq.empty), None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, IndexedSeq.empty, Some(IndexedSeq.empty), None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val call = ResolvedCall(signature, null, null, declaredArguments = false, declaredResults = false)(pos) call.withFakedFullDeclarations.declaredArguments should be(true) @@ -195,7 +202,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("adds coercion of arguments to signature types upon request") { val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = IndexedSeq( ProcedureResultItem(varFor("x"))(pos), @@ -222,7 +230,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should verify number of arguments during semantic checking of resolved calls") { val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq.empty val callResults = IndexedSeq( ProcedureResultItem(varFor("x"))(pos), @@ -243,7 +252,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should verify that result variables are unique during semantic checking of resolved calls") { val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = IndexedSeq( ProcedureResultItem(varFor("x"))(pos), @@ -260,7 +270,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should verify that output field names are correct during semantic checking of resolved calls") { val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = IndexedSeq( ProcedureResultItem(varFor("x"))(pos), @@ -277,7 +288,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should verify result types during semantic checking of resolved calls") { val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, ID, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, + ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val callArguments = Seq(StringLiteral("nope")(pos)) val callResults = IndexedSeq( ProcedureResultItem(varFor("x"))(pos), diff --git a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/ast/convert/plannerQuery/StatementConvertersTest.scala b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/ast/convert/plannerQuery/StatementConvertersTest.scala index e754dbd8b15b0..423798aa3ddca 100644 --- a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/ast/convert/plannerQuery/StatementConvertersTest.scala +++ b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/ast/convert/plannerQuery/StatementConvertersTest.scala @@ -860,7 +860,6 @@ class StatementConvertersTest extends CypherFunSuite with LogicalPlanningTestSup test("CALL foo() YIELD all RETURN all") { val signature = ProcedureSignature( QualifiedName(Seq.empty, "foo"), - 42, inputSignature = IndexedSeq.empty, deprecationInfo = None, outputSignature = Some(IndexedSeq(FieldSignature("all", CTInteger))), diff --git a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/planner/LogicalPlanningTestSupport.scala b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/planner/LogicalPlanningTestSupport.scala index a3ce88b8e5111..bb4510b52753c 100644 --- a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/planner/LogicalPlanningTestSupport.scala +++ b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/planner/LogicalPlanningTestSupport.scala @@ -264,7 +264,6 @@ trait LogicalPlanningTestSupport extends CypherTestSupport with AstConstructionT fcnLookup: Option[QualifiedName => Option[UserFunctionSignature]] = None) = { val signature = ProcedureSignature( QualifiedName(Seq.empty, "foo"), - 42, inputSignature = IndexedSeq.empty, deprecationInfo = None, outputSignature = Some(IndexedSeq(FieldSignature("all", CTInteger))), diff --git a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/planner/logical/PlanEventHorizonTest.scala b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/planner/logical/PlanEventHorizonTest.scala index bac71a4f59bc9..dadf24e00e249 100644 --- a/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/planner/logical/PlanEventHorizonTest.scala +++ b/community/cypher/cypher-planner-3.4/src/test/scala/org/neo4j/cypher/internal/compiler/v3_4/planner/logical/PlanEventHorizonTest.scala @@ -52,7 +52,7 @@ class PlanEventHorizonTest extends CypherFunSuite with LogicalPlanningTestSuppor val qualifiedName = QualifiedName(ns.parts, name.name) val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) - val signature = ProcedureSignature(qualifiedName, 42, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, None, ProcedureReadOnlyAccess(Array.empty)) val callResults = IndexedSeq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos)) val call = ResolvedCall(signature, Seq.empty, callResults)(pos) diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_3/LogicalPlanConverter.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_3/LogicalPlanConverter.scala index 60c4668d1866c..d12be2c4c0882 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_3/LogicalPlanConverter.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_3/LogicalPlanConverter.scala @@ -22,22 +22,23 @@ package org.neo4j.cypher.internal.compatibility.v3_3 import java.lang.reflect.Constructor import org.neo4j.cypher.internal.compatibility.v3_3.SemanticTableConverter.ExpressionMapping3To4 -import org.neo4j.cypher.internal.planner.v3_4.spi.PlanningAttributes.{Cardinalities, Solveds} +import org.neo4j.cypher.internal.compiler.{v3_3 => compilerV3_3} import org.neo4j.cypher.internal.frontend.v3_3.ast.{Expression => ExpressionV3_3} import org.neo4j.cypher.internal.frontend.v3_3.{InputPosition => InputPositionV3_3, SemanticDirection => SemanticDirectionV3_3, ast => astV3_3, symbols => symbolsV3_3} import org.neo4j.cypher.internal.frontend.{v3_3 => frontendV3_3} import org.neo4j.cypher.internal.ir.{v3_3 => irV3_3, v3_4 => irV3_4} +import org.neo4j.cypher.internal.planner.v3_4.spi.PlanningAttributes.{Cardinalities, Solveds} import org.neo4j.cypher.internal.util.v3_4.Rewritable.RewritableAny +import org.neo4j.cypher.internal.util.v3_4.attribution.{Id, IdGen, SameId, SequentialIdGen} +import org.neo4j.cypher.internal.util.v3_4.symbols.CypherType import org.neo4j.cypher.internal.util.v3_4.{symbols => symbolsV3_4, _} import org.neo4j.cypher.internal.util.{v3_4 => utilV3_4} import org.neo4j.cypher.internal.v3_3.logical.plans.{LogicalPlan => LogicalPlanV3_3} import org.neo4j.cypher.internal.v3_3.logical.{plans => plansV3_3} import org.neo4j.cypher.internal.v3_4.expressions.{Expression => ExpressionV3_4} -import org.neo4j.cypher.internal.v3_4.logical.plans.{LogicalPlan => LogicalPlanV3_4} +import org.neo4j.cypher.internal.v3_4.logical.plans.{FieldSignature, ProcedureAccessMode, QualifiedName, LogicalPlan => LogicalPlanV3_4} import org.neo4j.cypher.internal.v3_4.logical.{plans => plansV3_4} import org.neo4j.cypher.internal.v3_4.{expressions => expressionsV3_4} -import org.neo4j.cypher.internal.compiler.{v3_3 => compilerV3_3} -import org.neo4j.cypher.internal.util.v3_4.attribution.{Id, IdGen, SameId, SequentialIdGen} import scala.collection.mutable import scala.collection.mutable.{HashMap => MutableHashMap} @@ -142,12 +143,31 @@ object LogicalPlanConverter { convertVersion("frontend.v3_3", "util.v3_4", nameId, children) case (frontendV3_3.helpers.Fby(head, tail), children: Seq[AnyRef]) => utilV3_4.Fby(children(0), children(1).asInstanceOf[utilV3_4.NonEmptyList[_]]) case (frontendV3_3.helpers.Last(head), children: Seq[AnyRef]) => utilV3_4.Last(children(0)) + + case ( _:plansV3_3.ProcedureSignature, children: Seq[AnyRef]) => + plansV3_4.ProcedureSignature(children(0).asInstanceOf[QualifiedName], + children(1).asInstanceOf[IndexedSeq[FieldSignature]], + children(2).asInstanceOf[Option[IndexedSeq[FieldSignature]]], + children(3).asInstanceOf[Option[String]], + children(4).asInstanceOf[ProcedureAccessMode], + children(5).asInstanceOf[Option[String]], + children(5).asInstanceOf[Option[String]], + None) + + case ( _:plansV3_3.UserFunctionSignature, children: Seq[AnyRef]) => + plansV3_4.UserFunctionSignature(children(0).asInstanceOf[QualifiedName], + children(1).asInstanceOf[IndexedSeq[FieldSignature]], + children(2).asInstanceOf[CypherType], + children(3).asInstanceOf[Option[String]], + children(4).asInstanceOf[Array[String]], + children(5).asInstanceOf[Option[String]], + children(5).asInstanceOf[Boolean], + None) + case (item@(_: plansV3_3.CypherValue | _: plansV3_3.QualifiedName | _: plansV3_3.FieldSignature | _: plansV3_3.ProcedureAccessMode | - _: plansV3_3.ProcedureSignature | - _: plansV3_3.UserFunctionSignature | _: plansV3_3.QueryExpression[_] | _: plansV3_3.SeekableArgs | _: irV3_3.PatternRelationship | diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_4/ExceptionTranslatingQueryContext.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_4/ExceptionTranslatingQueryContext.scala index 4175a68f19fec..a642860262259 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_4/ExceptionTranslatingQueryContext.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_4/ExceptionTranslatingQueryContext.scala @@ -26,6 +26,7 @@ import org.neo4j.cypher.internal.planner.v3_4.spi.IndexDescriptor import org.neo4j.cypher.internal.runtime._ import org.neo4j.cypher.internal.runtime.interpreted.{DelegatingOperations, DelegatingQueryTransactionalContext} import org.neo4j.cypher.internal.v3_4.expressions.SemanticDirection +import org.neo4j.cypher.internal.v3_4.logical.plans.QualifiedName import org.neo4j.graphdb.{Node, Path, PropertyContainer} import org.neo4j.internal.kernel.api.IndexReference import org.neo4j.internal.kernel.api.helpers.RelationshipSelectionCursor @@ -153,14 +154,32 @@ class ExceptionTranslatingQueryContext(val inner: QueryContext) extends QueryCon override def callDbmsProcedure(id: Int, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] = translateIterator(inner.callDbmsProcedure(id, args, allowed)) + override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] = + translateIterator(inner.callReadOnlyProcedure(name, args, allowed)) + + override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] = + translateIterator(inner.callReadWriteProcedure(name, args, allowed)) + + override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] = + translateIterator(inner.callSchemaWriteProcedure(name, args, allowed)) + + override def callDbmsProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] = + translateIterator(inner.callDbmsProcedure(name, args, allowed)) + override def callFunction(id: Int, args: Seq[AnyValue], allowed: Array[String]) = translateException(inner.callFunction(id, args, allowed)) + override def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]) = + translateException(inner.callFunction(name, args, allowed)) override def aggregateFunction(id: Int, allowed: Array[String]): UserDefinedAggregator = translateException(inner.aggregateFunction(id, allowed)) + override def aggregateFunction(name: QualifiedName, + allowed: Array[String]): UserDefinedAggregator = + translateException(inner.aggregateFunction(name, allowed)) + override def isGraphKernelResultValue(v: Any): Boolean = translateException(inner.isGraphKernelResultValue(v)) diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_4/runtime/executionplan/procs/ProcedureExecutionResult.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_4/runtime/executionplan/procs/ProcedureExecutionResult.scala index 67992edce604e..f9797cce53bb0 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_4/runtime/executionplan/procs/ProcedureExecutionResult.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/compatibility/v3_4/runtime/executionplan/procs/ProcedureExecutionResult.scala @@ -56,7 +56,7 @@ import org.neo4j.values.storable._ class ProcedureExecutionResult(context: QueryContext, taskCloser: TaskCloser, name: QualifiedName, - id: Int, + id: Option[Int], callMode: ProcedureCallMode, args: Seq[Any], indexResultNameMappings: IndexedSeq[(Int, String, CypherType)], @@ -69,7 +69,9 @@ class ProcedureExecutionResult(context: QueryContext, private final val executionResults = executeCall // The signature mode is taking care of eagerization - protected def executeCall: Iterator[Array[AnyRef]] = callMode.callProcedure(context, id, args) + protected def executeCall: Iterator[Array[AnyRef]] = + if (id.nonEmpty) callMode.callProcedure(context, id.get, args) + else callMode.callProcedure(context, name, args) override protected def createInner = new util.Iterator[util.Map[String, Any]]() { override def next(): util.Map[String, Any] = diff --git a/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/compatibility/v3_3/LogicalPlanConverterTest.scala b/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/compatibility/v3_3/LogicalPlanConverterTest.scala index 3dcff7ab48094..48dff071351ef 100644 --- a/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/compatibility/v3_3/LogicalPlanConverterTest.scala +++ b/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/compatibility/v3_3/LogicalPlanConverterTest.scala @@ -339,7 +339,7 @@ class LogicalPlanConverterTest extends FunSuite with Matchers { val var3_4 = expressionsV3_4.Variable("n")(pos3_4) val a3_4 = plansV3_4.AllNodesScan("n", Set.empty) val inputv3_4 = plansV3_4.FieldSignature("d", symbolsV3_4.CTString, Some(plansV3_4.CypherValue("e", symbolsV3_4.CTString))) - val sigv3_4 = plansV3_4.ProcedureSignature(plansV3_4.QualifiedName(Seq("a", "b"), "c"), 42, IndexedSeq(inputv3_4), None, None, plansV3_4.ProcedureReadWriteAccess(Array("foo", "bar"))) + val sigv3_4 = plansV3_4.ProcedureSignature(plansV3_4.QualifiedName(Seq("a", "b"), "c"), IndexedSeq(inputv3_4), None, None, plansV3_4.ProcedureReadWriteAccess(Array("foo", "bar"))) val pres3_4 = astV3_4.ProcedureResultItem(Some(expressionsV3_4.ProcedureOutput("f")(pos3_4)), var3_4)(pos3_4) val rc3_4 = plansV3_4.ResolvedCall(sigv3_4, Seq(var3_4), IndexedSeq(pres3_4))(pos3_4) val pc3_4 = plansV3_4.ProcedureCall(a3_4, rc3_4) @@ -423,7 +423,7 @@ class LogicalPlanConverterTest extends FunSuite with Matchers { val name3_4 = plansV3_4.QualifiedName(Seq.empty, "foo") val call3_4 = plansV3_4.ResolvedFunctionInvocation(name3_4, - Some(plansV3_4.UserFunctionSignature(name3_4, 0,Vector(plansV3_4.FieldSignature("input", symbolsV3_4.CTAny, + Some(plansV3_4.UserFunctionSignature(name3_4, Vector(plansV3_4.FieldSignature("input", symbolsV3_4.CTAny, default = Some(plansV3_4.CypherValue(null, symbolsV3_4.CTAny)))), symbolsV3_4.CTAny, None, allowed, None, isAggregate = false)), Vector())(InputPosition(1, 2, 3)) diff --git a/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/compatibility/v3_4/runtime/executionplan/procs/ProcedureCallExecutionPlanTest.scala b/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/compatibility/v3_4/runtime/executionplan/procs/ProcedureCallExecutionPlanTest.scala index afef7c9292ca4..37fb5111bf539 100644 --- a/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/compatibility/v3_4/runtime/executionplan/procs/ProcedureCallExecutionPlanTest.scala +++ b/community/cypher/cypher/src/test/scala/org/neo4j/cypher/internal/compatibility/v3_4/runtime/executionplan/procs/ProcedureCallExecutionPlanTest.scala @@ -85,7 +85,6 @@ class ProcedureCallExecutionPlanTest extends CypherFunSuite { private val readSignature = ProcedureSignature( QualifiedName(IndexedSeq.empty, "foo"), - 42, IndexedSeq(FieldSignature("a", CTInteger)), Some(IndexedSeq(FieldSignature("b", CTInteger))), None, @@ -94,7 +93,6 @@ class ProcedureCallExecutionPlanTest extends CypherFunSuite { private val writeSignature = ProcedureSignature( QualifiedName(Seq.empty, "foo"), - 42, IndexedSeq(FieldSignature("a", CTInteger)), Some(IndexedSeq(FieldSignature("b", CTInteger))), None, diff --git a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/DelegatingQueryContext.scala b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/DelegatingQueryContext.scala index be09485e27f90..173f6bd591c06 100644 --- a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/DelegatingQueryContext.scala +++ b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/DelegatingQueryContext.scala @@ -25,6 +25,7 @@ import org.neo4j.collection.primitive.PrimitiveLongIterator import org.neo4j.cypher.internal.planner.v3_4.spi.{IndexDescriptor, KernelStatisticProvider} import org.neo4j.cypher.internal.runtime._ import org.neo4j.cypher.internal.v3_4.expressions.SemanticDirection +import org.neo4j.cypher.internal.v3_4.logical.plans.QualifiedName import org.neo4j.graphdb.{Node, Path, PropertyContainer} import org.neo4j.internal.kernel.api.helpers.RelationshipSelectionCursor import org.neo4j.internal.kernel.api.{CursorFactory, IndexReference, Read, Write, _} @@ -224,13 +225,32 @@ abstract class DelegatingQueryContext(val inner: QueryContext) extends QueryCont override def callDbmsProcedure(id: Int, args: Seq[Any], allowed: Array[String]) = inner.callDbmsProcedure(id, args, allowed) + override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = + singleDbHit(inner.callReadOnlyProcedure(name, args, allowed)) + + override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = + singleDbHit(inner.callReadWriteProcedure(name, args, allowed)) + + override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = + singleDbHit(inner.callSchemaWriteProcedure(name, args, allowed)) + + override def callDbmsProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = + inner.callDbmsProcedure(name, args, allowed) + override def callFunction(id: Int, args: Seq[AnyValue], allowed: Array[String]) = singleDbHit(inner.callFunction(id, args, allowed)) + override def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]) = + singleDbHit(inner.callFunction(name, args, allowed)) + override def aggregateFunction(id: Int, allowed: Array[String]): UserDefinedAggregator = singleDbHit(inner.aggregateFunction(id, allowed)) + override def aggregateFunction(name: QualifiedName, + allowed: Array[String]): UserDefinedAggregator = + singleDbHit(inner.aggregateFunction(name, allowed)) + override def isGraphKernelResultValue(v: Any): Boolean = inner.isGraphKernelResultValue(v) diff --git a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/TransactionBoundPlanContext.scala b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/TransactionBoundPlanContext.scala index a2f2cb9dc0463..3775f3030d6bc 100644 --- a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/TransactionBoundPlanContext.scala +++ b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/TransactionBoundPlanContext.scala @@ -144,7 +144,7 @@ class TransactionBoundPlanContext(tc: TransactionalContextWrapper, logger: Inter val description = asOption(signature.description()) val warning = asOption(signature.warning()) - ProcedureSignature(name, handle.id(), input, output, deprecationInfo, mode, description, warning) + ProcedureSignature(name, input, output, deprecationInfo, mode, description, warning, Some(handle.id())) } override def functionSignature(name: QualifiedName): Option[UserFunctionSignature] = { @@ -165,8 +165,8 @@ class TransactionBoundPlanContext(tc: TransactionalContextWrapper, logger: Inter val description = asOption(signature.description()) - Some(UserFunctionSignature(name, fcn.id(), input, output, deprecationInfo, - signature.allowed(), description, isAggregate = aggregation)) + Some(UserFunctionSignature(name, input, output, deprecationInfo, + signature.allowed(), description, isAggregate = aggregation, id = Some(fcn.id()))) } } diff --git a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/TransactionBoundQueryContext.scala b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/TransactionBoundQueryContext.scala index 8c96704a79eac..0f7c60a35c97c 100644 --- a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/TransactionBoundQueryContext.scala +++ b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/TransactionBoundQueryContext.scala @@ -46,6 +46,7 @@ import org.neo4j.internal.kernel.api._ import org.neo4j.internal.kernel.api.exceptions.ProcedureException import org.neo4j.internal.kernel.api.helpers.RelationshipSelections.{allCursor, incomingCursor, outgoingCursor} import org.neo4j.internal.kernel.api.helpers._ +import org.neo4j.internal.kernel.api.procs.{UserAggregator, QualifiedName => KernelQualifiedName} import org.neo4j.kernel.GraphDatabaseQueryService import org.neo4j.kernel.api._ import org.neo4j.kernel.api.exceptions.schema.{AlreadyConstrainedException, AlreadyIndexedException} @@ -961,7 +962,7 @@ sealed class TransactionBoundQueryContext(val transactionalContext: Transactiona .asScala } - type KernelProcedureCall = (Int, Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException] + type KernelProcedureCall = (Array[AnyRef]) => RawIterator[Array[AnyRef], ProcedureException] private def shouldElevate(allowed: Array[String]): Boolean = { // We have to be careful with elevation, since we cannot elevate permissions in a nested procedure call @@ -973,42 +974,82 @@ sealed class TransactionBoundQueryContext(val transactionalContext: Transactiona override def callReadOnlyProcedure(id: Int, args: Seq[Any], allowed: Array[String]) = { val call: KernelProcedureCall = if (shouldElevate(allowed)) - transactionalContext.tc.kernelTransaction().procedures().procedureCallReadOverride(_, _) + transactionalContext.tc.kernelTransaction().procedures().procedureCallReadOverride(id, _) else - transactionalContext.tc.kernelTransaction().procedures().procedureCallRead(_, _) + transactionalContext.tc.kernelTransaction().procedures().procedureCallRead(id, _) - callProcedure(id, args, call) + callProcedure(args, call) } override def callReadWriteProcedure(id: Int, args: Seq[Any], allowed: Array[String]) = { val call: KernelProcedureCall = if (shouldElevate(allowed)) - transactionalContext.tc.kernelTransaction().procedures().procedureCallWriteOverride(_, _) + transactionalContext.tc.kernelTransaction().procedures().procedureCallWriteOverride(id, _) else - transactionalContext.tc.kernelTransaction().procedures().procedureCallWrite(_, _) - callProcedure(id, args, call) + transactionalContext.tc.kernelTransaction().procedures().procedureCallWrite(id, _) + callProcedure(args, call) } override def callSchemaWriteProcedure(id: Int, args: Seq[Any], allowed: Array[String]) = { val call: KernelProcedureCall = if (shouldElevate(allowed)) - transactionalContext.tc.kernelTransaction().procedures().procedureCallSchemaOverride(_, _) + transactionalContext.tc.kernelTransaction().procedures().procedureCallSchemaOverride(id, _) else - transactionalContext.tc.kernelTransaction().procedures().procedureCallSchema(_, _) - callProcedure(id, args, call) + transactionalContext.tc.kernelTransaction().procedures().procedureCallSchema(id, _) + callProcedure(args, call) } override def callDbmsProcedure(id: Int, args: Seq[Any], allowed: Array[String]) = { - callProcedure(id, args, - transactionalContext.dbmsOperations.procedureCallDbms(_, + callProcedure(args, + transactionalContext.dbmsOperations.procedureCallDbms(id, _, transactionalContext.securityContext, transactionalContext.resourceTracker)) } - private def callProcedure(id: Int, args: Seq[Any], call: KernelProcedureCall) = { + override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { + val kn = new KernelQualifiedName(name.namespace.asJava, name.name) + val call: KernelProcedureCall = + if (shouldElevate(allowed)) + transactionalContext.tc.kernelTransaction().procedures().procedureCallReadOverride(kn, _) + else + transactionalContext.tc.kernelTransaction().procedures().procedureCallRead(kn, _) + + callProcedure(args, call) + } + + override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { + val kn = new KernelQualifiedName(name.namespace.asJava, name.name) + val call: KernelProcedureCall = + if (shouldElevate(allowed)) + transactionalContext.tc.kernelTransaction().procedures().procedureCallWriteOverride(kn, _) + else + transactionalContext.tc.kernelTransaction().procedures().procedureCallWrite(kn, _) + callProcedure(args, call) + } + + override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { + val kn = new KernelQualifiedName(name.namespace.asJava, name.name) + val call: KernelProcedureCall = + if (shouldElevate(allowed)) + transactionalContext.tc.kernelTransaction().procedures().procedureCallSchemaOverride(kn, _) + else + transactionalContext.tc.kernelTransaction().procedures().procedureCallSchema(kn, _) + callProcedure(args, call) + } + + override def callDbmsProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]) = { + val kn = new KernelQualifiedName(name.namespace.asJava, name.name) + callProcedure(args, + transactionalContext.dbmsOperations.procedureCallDbms(kn, + _, + transactionalContext.securityContext, + transactionalContext.resourceTracker)) + } + + private def callProcedure(args: Seq[Any], call: KernelProcedureCall) = { val toArray = args.map(_.asInstanceOf[AnyRef]).toArray - val read = call(id, toArray) + val read = call(toArray) new scala.Iterator[Array[AnyRef]] { override def hasNext: Boolean = read.hasNext @@ -1023,13 +1064,36 @@ sealed class TransactionBoundQueryContext(val transactionalContext: Transactiona transactionalContext.tc.kernelTransaction().procedures().functionCall(id, args.toArray) } + override def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]) = { + val kn = new KernelQualifiedName(name.namespace.asJava, name.name) + if (shouldElevate(allowed)) + transactionalContext.tc.kernelTransaction().procedures().functionCallOverride(kn, args.toArray) + else + transactionalContext.tc.kernelTransaction().procedures().functionCall(kn, args.toArray) + } + override def aggregateFunction(id: Int, allowed: Array[String]) = { - val aggregator = + val aggregator: UserAggregator = if (shouldElevate(allowed)) transactionalContext.tc.kernelTransaction().procedures().aggregationFunctionOverride(id) else transactionalContext.tc.kernelTransaction().procedures().aggregationFunction(id) + userDefinedAggregator(aggregator) + } + + override def aggregateFunction(name: QualifiedName, allowed: Array[String]) = { + val kn = new KernelQualifiedName(name.namespace.asJava, name.name) + val aggregator: UserAggregator = + if (shouldElevate(allowed)) + transactionalContext.tc.kernelTransaction().procedures().aggregationFunctionOverride(kn) + else + transactionalContext.tc.kernelTransaction().procedures().aggregationFunction(kn) + + userDefinedAggregator(aggregator) + } + + private def userDefinedAggregator(aggregator: UserAggregator) = { new UserDefinedAggregator { override def result: AnyRef = aggregator.result() diff --git a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/convert/CommunityExpressionConverter.scala b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/convert/CommunityExpressionConverter.scala index 1e5bdba37a80a..1d0cb40b4aa79 100644 --- a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/convert/CommunityExpressionConverter.scala +++ b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/convert/CommunityExpressionConverter.scala @@ -19,18 +19,17 @@ */ package org.neo4j.cypher.internal.runtime.interpreted.commands.convert -import org.neo4j.cypher.internal.util.v3_4.{InternalException, NonEmptyList} +import org.neo4j.cypher.internal.frontend.v3_4.ast.rewriters.DesugaredMapProjection import org.neo4j.cypher.internal.runtime.interpreted._ import org.neo4j.cypher.internal.runtime.interpreted.commands.expressions.{InequalitySeekRangeExpression, PointDistanceSeekRangeExpression, Expression => CommandExpression} import org.neo4j.cypher.internal.runtime.interpreted.commands.predicates.Predicate import org.neo4j.cypher.internal.runtime.interpreted.commands.values.TokenType.PropertyKey import org.neo4j.cypher.internal.runtime.interpreted.commands.values.UnresolvedRelType import org.neo4j.cypher.internal.runtime.interpreted.commands.{PathExtractorExpression, predicates, expressions => commandexpressions, values => commandvalues} +import org.neo4j.cypher.internal.util.v3_4.{InternalException, NonEmptyList} import org.neo4j.cypher.internal.v3_4.functions._ -import org.neo4j.cypher.internal.v3_4.functions -import org.neo4j.cypher.internal.v3_4.{expressions => ast} -import org.neo4j.cypher.internal.frontend.v3_4.ast.rewriters.DesugaredMapProjection import org.neo4j.cypher.internal.v3_4.logical.plans._ +import org.neo4j.cypher.internal.v3_4.{functions, expressions => ast} object CommunityExpressionConverter extends ExpressionConverter { @@ -104,8 +103,12 @@ object CommunityExpressionConverter extends ExpressionConverter { case (given, default) => given.map(self.toCommandExpression).getOrElse(commandexpressions.Literal(default.get)) } val signature = e.fcnSignature.get - if (signature.isAggregate) commandexpressions.AggregationFunctionInvocation(signature, callArgumentCommands) - else commandexpressions.FunctionInvocation(signature, callArgumentCommands) + (signature.isAggregate, signature.id) match { + case (true, Some(_)) => commandexpressions.AggregationFunctionInvocationById(signature, callArgumentCommands) + case (true, None) => commandexpressions.AggregationFunctionInvocationByName(signature, callArgumentCommands) + case (false, Some(_)) => commandexpressions.FunctionInvocationById(signature, callArgumentCommands) + case (false, None) => commandexpressions.FunctionInvocationByName(signature, callArgumentCommands) + } case e: ast.MapProjection => throw new InternalException("should have been rewritten away") case e: NestedPlanExpression => commandexpressions.NestedPlanExpression(e.plan) case _ => null diff --git a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/AggregationFunctionInvocation.scala b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/AggregationFunctionInvocation.scala index 62d765b9a812d..ec54b65eb09c5 100644 --- a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/AggregationFunctionInvocation.scala +++ b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/AggregationFunctionInvocation.scala @@ -20,13 +20,13 @@ package org.neo4j.cypher.internal.runtime.interpreted.commands.expressions import org.neo4j.cypher.internal.runtime.UserDefinedAggregator -import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, ValueConversion} import org.neo4j.cypher.internal.runtime.interpreted.pipes.QueryState import org.neo4j.cypher.internal.runtime.interpreted.pipes.aggregation.AggregationFunction +import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, ValueConversion} import org.neo4j.cypher.internal.v3_4.logical.plans.UserFunctionSignature import org.neo4j.values.AnyValue -case class AggregationFunctionInvocation(signature: UserFunctionSignature, arguments: IndexedSeq[Expression]) +abstract class AggregationFunctionInvocation(signature: UserFunctionSignature, arguments: IndexedSeq[Expression]) extends AggregationExpression { private val valueConverter = ValueConversion.getValueConverter(signature.outputType) @@ -46,14 +46,31 @@ case class AggregationFunctionInvocation(signature: UserFunctionSignature, argum private def aggregator(state: QueryState) = { if (inner == null) { - inner = state.query.aggregateFunction(signature.id, signature.allowed) + inner = call(state) } inner } } + override def symbolTableDependencies: Set[String] = arguments.flatMap(_.symbolTableDependencies).toSet + + protected def call(state: QueryState): UserDefinedAggregator +} + +case class AggregationFunctionInvocationById(signature: UserFunctionSignature, arguments: IndexedSeq[Expression]) + extends AggregationFunctionInvocation(signature, arguments) +{ + protected def call(state: QueryState) = {state.query.aggregateFunction(signature.id.get, signature.allowed)} + override def rewrite(f: (Expression) => Expression): Expression = f( - AggregationFunctionInvocation(signature, arguments.map(a => a.rewrite(f)))) + AggregationFunctionInvocationById(signature, arguments.map(a => a.rewrite(f)))) +} - override def symbolTableDependencies: Set[String] = arguments.flatMap(_.symbolTableDependencies).toSet +case class AggregationFunctionInvocationByName(signature: UserFunctionSignature, arguments: IndexedSeq[Expression]) + extends AggregationFunctionInvocation(signature, arguments) +{ + protected def call(state: QueryState) = {state.query.aggregateFunction(signature.name, signature.allowed)} + + override def rewrite(f: (Expression) => Expression): Expression = f( + AggregationFunctionInvocationByName(signature, arguments.map(a => a.rewrite(f)))) } diff --git a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/FunctionInvocation.scala b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/FunctionInvocation.scala index c7ebace6a6418..b7a32b8b3cada 100644 --- a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/FunctionInvocation.scala +++ b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/commands/expressions/FunctionInvocation.scala @@ -19,12 +19,13 @@ */ package org.neo4j.cypher.internal.runtime.interpreted.commands.expressions +import org.neo4j.cypher.internal.runtime.QueryContext import org.neo4j.cypher.internal.runtime.interpreted.pipes.QueryState import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, GraphElementPropertyFunctions} import org.neo4j.cypher.internal.v3_4.logical.plans.UserFunctionSignature import org.neo4j.values._ -case class FunctionInvocation(signature: UserFunctionSignature, arguments: IndexedSeq[Expression]) +abstract class FunctionInvocation(signature: UserFunctionSignature, arguments: IndexedSeq[Expression]) extends Expression with GraphElementPropertyFunctions { override def apply(ctx: ExecutionContext, state: QueryState): AnyValue = { @@ -32,13 +33,38 @@ case class FunctionInvocation(signature: UserFunctionSignature, arguments: Index val argValues = arguments.map(arg => { arg(ctx, state) }) - query.callFunction(signature.id, argValues, signature.allowed) + call(query, argValues) } - override def rewrite(f: (Expression) => Expression) = - f(FunctionInvocation(signature, arguments.map(a => a.rewrite(f)))) + protected def call(query: QueryContext, + argValues: IndexedSeq[AnyValue]): AnyValue + override def symbolTableDependencies = arguments.flatMap(_.symbolTableDependencies).toSet override def toString = s"${signature.name}(${arguments.mkString(",")})" } + +case class FunctionInvocationById(signature: UserFunctionSignature, arguments: IndexedSeq[Expression]) + extends FunctionInvocation(signature, arguments) { + + protected def call(query: QueryContext, + argValues: IndexedSeq[AnyValue]): AnyValue = { + query.callFunction(signature.id.get, argValues, signature.allowed) + } + + override def rewrite(f: (Expression) => Expression) = + f(FunctionInvocationById(signature, arguments.map(a => a.rewrite(f)))) +} + +case class FunctionInvocationByName(signature: UserFunctionSignature, arguments: IndexedSeq[Expression]) + extends FunctionInvocation(signature, arguments) { + + protected def call(query: QueryContext, + argValues: IndexedSeq[AnyValue]): AnyValue = { + query.callFunction(signature.name, argValues, signature.allowed) + } + + override def rewrite(f: (Expression) => Expression) = + f(FunctionInvocationById(signature, arguments.map(a => a.rewrite(f)))) +} diff --git a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/pipes/ProcedureCallPipe.scala b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/pipes/ProcedureCallPipe.scala index 15b45adc9060f..3270d3984f418 100644 --- a/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/pipes/ProcedureCallPipe.scala +++ b/community/cypher/interpreted-runtime/src/main/scala/org/neo4j/cypher/internal/runtime/interpreted/pipes/ProcedureCallPipe.scala @@ -19,9 +19,9 @@ */ package org.neo4j.cypher.internal.runtime.interpreted.pipes -import org.neo4j.cypher.internal.runtime.ProcedureCallMode -import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, ValueConversion} import org.neo4j.cypher.internal.runtime.interpreted.commands.expressions.Expression +import org.neo4j.cypher.internal.runtime.interpreted.{ExecutionContext, ValueConversion} +import org.neo4j.cypher.internal.runtime.{ProcedureCallMode, QueryContext} import org.neo4j.cypher.internal.util.v3_4.attribution.Id import org.neo4j.cypher.internal.util.v3_4.symbols.CypherType import org.neo4j.cypher.internal.v3_4.logical.plans.ProcedureSignature @@ -64,7 +64,7 @@ case class ProcedureCallPipe(source: Pipe, builder.sizeHint(resultIndices.length) input flatMap { input => val argValues = argExprs.map(arg => qtx.asObject(arg(input, state))) - val results = callMode.callProcedure(qtx, signature.id, argValues) + val results = call(qtx, argValues) results map { resultValues => resultIndices foreach { case (k, v) => val javaValue = maybeConverter.get(k)(resultValues(k)) @@ -78,11 +78,16 @@ case class ProcedureCallPipe(source: Pipe, } } + private def call(qtx: QueryContext, + argValues: Seq[Any]) = + if (signature.id.nonEmpty) callMode.callProcedure(qtx, signature.id.get, argValues) + else callMode.callProcedure(qtx, signature.name, argValues) + private def internalCreateResultsByPassingThrough(input: Iterator[ExecutionContext], state: QueryState): Iterator[ExecutionContext] = { val qtx = state.query input map { input => val argValues = argExprs.map(arg => qtx.asObject(arg(input, state))) - val results = callMode.callProcedure(qtx, signature.id, argValues) + val results = call(qtx, argValues) // the iterator here should be empty; we'll drain just in case while (results.hasNext) results.next() input diff --git a/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/QueryContextAdaptation.scala b/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/QueryContextAdaptation.scala index 03e916e0b1e93..96a4922886868 100644 --- a/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/QueryContextAdaptation.scala +++ b/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/QueryContextAdaptation.scala @@ -25,6 +25,7 @@ import org.neo4j.collection.primitive.PrimitiveLongIterator import org.neo4j.cypher.internal.planner.v3_4.spi.{IdempotentResult, IndexDescriptor} import org.neo4j.cypher.internal.runtime._ import org.neo4j.cypher.internal.v3_4.expressions.SemanticDirection +import org.neo4j.cypher.internal.v3_4.logical.plans.QualifiedName import org.neo4j.graphdb.{Node, Path, PropertyContainer} import org.neo4j.internal.kernel.api.IndexReference import org.neo4j.internal.kernel.api.helpers.RelationshipSelectionCursor @@ -155,12 +156,23 @@ trait QueryContextAdaptation { override def callDbmsProcedure(id: Int, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] = ??? + override def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): scala.Iterator[Array[AnyRef]] = ??? + + override def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): scala.Iterator[Array[AnyRef]] = ??? + + override def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] = ??? + + override def callDbmsProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] = ??? override def callFunction(id: Int, args: Seq[AnyValue], allowed: Array[String]): AnyValue = ??? override def aggregateFunction(id: Int, allowed: Array[String]): UserDefinedAggregator = ??? + override def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]): AnyValue = ??? + + override def aggregateFunction(name: QualifiedName, allowed: Array[String]): UserDefinedAggregator = ??? + override def getOrCreateFromSchemaState[K, V](key: K, creator: => V): V = ??? override def removeLabelsFromNode(node: Long, labelIds: scala.Iterator[Int]): Int = ??? diff --git a/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/pipes/ProcedureCallPipeTest.scala b/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/pipes/ProcedureCallPipeTest.scala index 33ba1c162bbca..703e680cfec5b 100644 --- a/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/pipes/ProcedureCallPipeTest.scala +++ b/community/cypher/interpreted-runtime/src/test/scala/org/neo4j/cypher/internal/runtime/interpreted/pipes/ProcedureCallPipeTest.scala @@ -36,7 +36,8 @@ class ProcedureCallPipeTest val ID = 42 val procedureName = QualifiedName(List.empty, "foo") - val signature = ProcedureSignature(procedureName, ID, IndexedSeq.empty, Some(IndexedSeq(FieldSignature("foo", CTAny))), None, ProcedureReadOnlyAccess(Array.empty)) + val signature = ProcedureSignature(procedureName, IndexedSeq.empty, Some(IndexedSeq(FieldSignature("foo", CTAny))), + None, ProcedureReadOnlyAccess(Array.empty), id = Some(ID)) val emptyStringArray = Array.empty[String] test("should execute read-only procedure calls") { diff --git a/community/cypher/runtime-util/src/main/scala/org/neo4j/cypher/internal/runtime/ProcedureCallMode.scala b/community/cypher/runtime-util/src/main/scala/org/neo4j/cypher/internal/runtime/ProcedureCallMode.scala index ee7a982db3bb2..c2adb905f3eb1 100644 --- a/community/cypher/runtime-util/src/main/scala/org/neo4j/cypher/internal/runtime/ProcedureCallMode.scala +++ b/community/cypher/runtime-util/src/main/scala/org/neo4j/cypher/internal/runtime/ProcedureCallMode.scala @@ -35,6 +35,7 @@ sealed trait ProcedureCallMode { val queryType: InternalQueryType def callProcedure(ctx: QueryContext, id: Int, args: Seq[Any]): Iterator[Array[AnyRef]] + def callProcedure(ctx: QueryContext, name: QualifiedName, args: Seq[Any]): Iterator[Array[AnyRef]] val allowed: Array[String] } @@ -44,40 +45,64 @@ case class LazyReadOnlyCallMode(allowed: Array[String]) extends ProcedureCallMod override def callProcedure(ctx: QueryContext, id: Int, args: Seq[Any]): Iterator[Array[AnyRef]] = ctx.callReadOnlyProcedure(id, args, allowed) + + override def callProcedure(ctx: QueryContext, name: QualifiedName, args: Seq[Any]): Iterator[Array[AnyRef]] = + ctx.callReadOnlyProcedure(name, args, allowed) } case class EagerReadWriteCallMode(allowed: Array[String]) extends ProcedureCallMode { override val queryType: InternalQueryType = READ_WRITE - override def callProcedure(ctx: QueryContext, id: Int, args: Seq[Any]): Iterator[Array[AnyRef]] = { + private def call(iterator: Iterator[Array[AnyRef]]) = { val builder = ArrayBuffer.newBuilder[Array[AnyRef]] - val iterator = ctx.callReadWriteProcedure(id, args, allowed) while (iterator.hasNext) { builder += iterator.next() } builder.result().iterator } + + override def callProcedure(ctx: QueryContext, id: Int, args: Seq[Any]): Iterator[Array[AnyRef]] = call(ctx.callReadWriteProcedure(id, args, allowed)) + + override def callProcedure(ctx: QueryContext, + name: QualifiedName, + args: Seq[Any]): Iterator[Array[AnyRef]] = call(ctx.callReadWriteProcedure(name, args, allowed)) } case class SchemaWriteCallMode(allowed: Array[String]) extends ProcedureCallMode { override val queryType: InternalQueryType = SCHEMA_WRITE - override def callProcedure(ctx: QueryContext, id: Int, args: Seq[Any]): Iterator[Array[AnyRef]] = { + private def call(iterator: Iterator[Array[AnyRef]]) = { val builder = ArrayBuffer.newBuilder[Array[AnyRef]] - val iterator = ctx.callSchemaWriteProcedure(id, args, allowed) while (iterator.hasNext) { builder += iterator.next() } builder.result().iterator } + + override def callProcedure(ctx: QueryContext, id: Int, args: Seq[Any]): Iterator[Array[AnyRef]] = call(ctx + .callSchemaWriteProcedure( + id, args, + allowed)) + + override def callProcedure(ctx: QueryContext, + name: QualifiedName, + args: Seq[Any]): Iterator[Array[AnyRef]] = call(ctx.callSchemaWriteProcedure(name, args, allowed)) } case class DbmsCallMode(allowed: Array[String]) extends ProcedureCallMode { override val queryType: InternalQueryType = DBMS - override def callProcedure(ctx: QueryContext, id: Int, args: Seq[Any]): Iterator[Array[AnyRef]] = { + override def callProcedure(ctx: QueryContext, id: Int, args: Seq[Any]): Iterator[Array[AnyRef]] = + call(ctx.callDbmsProcedure(id, args, allowed)) + + override def callProcedure(ctx: QueryContext, + name: QualifiedName, + args: Seq[Any]): Iterator[Array[AnyRef]] = + call(ctx.callDbmsProcedure(name, args, allowed)) + + + private def call(iterator: Iterator[Array[AnyRef]]) = { val builder = ArrayBuffer.newBuilder[Array[AnyRef]] - val iterator = ctx.callDbmsProcedure(id, args, allowed) while (iterator.hasNext) { builder += iterator.next() } diff --git a/community/cypher/runtime-util/src/main/scala/org/neo4j/cypher/internal/runtime/QueryContext.scala b/community/cypher/runtime-util/src/main/scala/org/neo4j/cypher/internal/runtime/QueryContext.scala index 223e4f60f7b65..851da362bf2b1 100644 --- a/community/cypher/runtime-util/src/main/scala/org/neo4j/cypher/internal/runtime/QueryContext.scala +++ b/community/cypher/runtime-util/src/main/scala/org/neo4j/cypher/internal/runtime/QueryContext.scala @@ -24,6 +24,7 @@ import java.net.URL import org.neo4j.collection.primitive.PrimitiveLongIterator import org.neo4j.cypher.internal.planner.v3_4.spi.{IdempotentResult, IndexDescriptor, KernelStatisticProvider, TokenContext} import org.neo4j.cypher.internal.v3_4.expressions.SemanticDirection +import org.neo4j.cypher.internal.v3_4.logical.plans.QualifiedName import org.neo4j.graphdb.{Node, Path, PropertyContainer} import org.neo4j.internal.kernel.api.helpers.RelationshipSelectionCursor import org.neo4j.internal.kernel.api.{CursorFactory, IndexReference, Read, Write, _} @@ -185,16 +186,22 @@ trait QueryContext extends TokenContext { def lockRelationships(relIds: Long*) def callReadOnlyProcedure(id: Int, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] + def callReadOnlyProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] def callReadWriteProcedure(id: Int, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] + def callReadWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] def callSchemaWriteProcedure(id: Int, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] + def callSchemaWriteProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] def callDbmsProcedure(id: Int, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] + def callDbmsProcedure(name: QualifiedName, args: Seq[Any], allowed: Array[String]): Iterator[Array[AnyRef]] def callFunction(id: Int, args: Seq[AnyValue], allowed: Array[String]): AnyValue + def callFunction(name: QualifiedName, args: Seq[AnyValue], allowed: Array[String]): AnyValue def aggregateFunction(id: Int, allowed: Array[String]): UserDefinedAggregator + def aggregateFunction(name: QualifiedName, allowed: Array[String]): UserDefinedAggregator // Check if a runtime value is a node, relationship, path or some such value returned from // other query context values by calling down to the underlying database diff --git a/community/kernel-api/src/main/java/org/neo4j/internal/kernel/api/Procedures.java b/community/kernel-api/src/main/java/org/neo4j/internal/kernel/api/Procedures.java index 15c12c48a7420..c61d3d69e95c4 100644 --- a/community/kernel-api/src/main/java/org/neo4j/internal/kernel/api/Procedures.java +++ b/community/kernel-api/src/main/java/org/neo4j/internal/kernel/api/Procedures.java @@ -53,7 +53,7 @@ public interface Procedures ProcedureHandle procedureGet( QualifiedName name ) throws ProcedureException; /** - * Invoke a read-only procedure by name. + * Invoke a read-only procedure by id. * @param id the id of the procedure. * @param arguments the procedure arguments. * @return an iterator containing the procedure results. @@ -113,6 +113,67 @@ RawIterator procedureCallSchema( int id, Object[] RawIterator procedureCallSchemaOverride( int id, Object[] arguments ) throws ProcedureException; + /** + * Invoke a read-only procedure by name. + * @param name the name of the procedure. + * @param arguments the procedure arguments. + * @return an iterator containing the procedure results. + * @throws ProcedureException if there was an exception thrown during procedure execution. + */ + RawIterator procedureCallRead( QualifiedName name, Object[] arguments ) + throws ProcedureException; + + /** + * Invoke a read-only procedure by name, and set the transaction's access mode to + * {@link AccessMode.Static#READ READ} for the duration of the procedure execution. + * @param name the name of the procedure. + * @param arguments the procedure arguments. + * @return an iterator containing the procedure results. + * @throws ProcedureException if there was an exception thrown during procedure execution. + */ + RawIterator procedureCallReadOverride( QualifiedName name, Object[] arguments ) + throws ProcedureException; + + /** + * Invoke a read/write procedure by name. + * @param name the name of the procedure. + * @param arguments the procedure arguments. + * @return an iterator containing the procedure results. + * @throws ProcedureException if there was an exception thrown during procedure execution. + */ + RawIterator procedureCallWrite( QualifiedName name, Object[] arguments ) + throws ProcedureException; + /** + * Invoke a read/write procedure by name, and set the transaction's access mode to + * {@link AccessMode.Static#WRITE WRITE} for the duration of the procedure execution. + * @param name the name of the procedure. + * @param arguments the procedure arguments. + * @return an iterator containing the procedure results. + * @throws ProcedureException if there was an exception thrown during procedure execution. + */ + RawIterator procedureCallWriteOverride( QualifiedName name, Object[] arguments ) + throws ProcedureException; + + /** + * Invoke a schema write procedure by name. + * @param name the name of the procedure. + * @param arguments the procedure arguments. + * @return an iterator containing the procedure results. + * @throws ProcedureException if there was an exception thrown during procedure execution. + */ + RawIterator procedureCallSchema( QualifiedName name, Object[] arguments ) + throws ProcedureException; + /** + * Invoke a schema write procedure by name, and set the transaction's access mode to + * {@link AccessMode.Static#FULL FULL} for the duration of the procedure execution. + * @param name the name of the procedure. + * @param arguments the procedure arguments. + * @return an iterator containing the procedure results. + * @throws ProcedureException if there was an exception thrown during procedure execution. + */ + RawIterator procedureCallSchemaOverride( QualifiedName name, Object[] arguments ) + throws ProcedureException; + /** Invoke a read-only function by id * @param id the id of the function. * @param arguments the function arguments. @@ -120,6 +181,13 @@ RawIterator procedureCallSchemaOverride( int id, O */ AnyValue functionCall( int id, AnyValue[] arguments ) throws ProcedureException; + /** Invoke a read-only function by name + * @param name the name of the function. + * @param arguments the function arguments. + * @throws ProcedureException if there was an exception thrown during function execution. + */ + AnyValue functionCall( QualifiedName name, AnyValue[] arguments ) throws ProcedureException; + /** Invoke a read-only function by id, and set the transaction's access mode to * {@link AccessMode.Static#READ READ} for the duration of the function execution. * @param id the id of the function. @@ -128,6 +196,14 @@ RawIterator procedureCallSchemaOverride( int id, O */ AnyValue functionCallOverride( int id, AnyValue[] arguments ) throws ProcedureException; + /** Invoke a read-only function by name, and set the transaction's access mode to + * {@link AccessMode.Static#READ READ} for the duration of the function execution. + * @param name the name of the function. + * @param arguments the function arguments. + * @throws ProcedureException if there was an exception thrown during function execution. + */ + AnyValue functionCallOverride( QualifiedName name, AnyValue[] arguments ) throws ProcedureException; + /** * Create a read-only aggregation function by id * @param id the id of the function @@ -136,6 +212,14 @@ RawIterator procedureCallSchemaOverride( int id, O */ UserAggregator aggregationFunction( int id ) throws ProcedureException; + /** + * Create a read-only aggregation function by name + * @param name the name of the function + * @return the aggregation function + * @throws ProcedureException + */ + UserAggregator aggregationFunction( QualifiedName name ) throws ProcedureException; + /** Invoke a read-only aggregation function by id, and set the transaction's access mode to * {@link AccessMode.Static#READ READ} for the duration of the function execution. * @param id the id of the function. @@ -143,4 +227,10 @@ RawIterator procedureCallSchemaOverride( int id, O */ UserAggregator aggregationFunctionOverride( int id ) throws ProcedureException; + /** Invoke a read-only aggregation function by name, and set the transaction's access mode to + * {@link AccessMode.Static#READ READ} for the duration of the function execution. + * @param name the name of the function. + * @throws ProcedureException if there was an exception thrown during function execution. + */ + UserAggregator aggregationFunctionOverride( QualifiedName name ) throws ProcedureException; } diff --git a/community/kernel/src/main/java/org/neo4j/kernel/impl/newapi/AllStoreHolder.java b/community/kernel/src/main/java/org/neo4j/kernel/impl/newapi/AllStoreHolder.java index 90436e183b213..fd164555764f9 100644 --- a/community/kernel/src/main/java/org/neo4j/kernel/impl/newapi/AllStoreHolder.java +++ b/community/kernel/src/main/java/org/neo4j/kernel/impl/newapi/AllStoreHolder.java @@ -619,6 +619,71 @@ public RawIterator procedureCallSchemaOverride( int new OverriddenAccessMode( ktx.securityContext().mode(), AccessMode.Static.FULL ) ); } + @Override + public RawIterator procedureCallRead( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + AccessMode accessMode = ktx.securityContext().mode(); + if ( !accessMode.allowsReads() ) + { + throw accessMode.onViolation( format( "Read operations are not allowed for %s.", + ktx.securityContext().description() ) ); + } + return callProcedure( name, arguments, new RestrictedAccessMode( ktx.securityContext().mode(), AccessMode.Static + .READ ) ); + } + + @Override + public RawIterator procedureCallReadOverride( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + return callProcedure( name, arguments, + new OverriddenAccessMode( ktx.securityContext().mode(), AccessMode.Static.READ ) ); + } + + @Override + public RawIterator procedureCallWrite( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + AccessMode accessMode = ktx.securityContext().mode(); + if ( !accessMode.allowsWrites() ) + { + throw accessMode.onViolation( format( "Write operations are not allowed for %s.", + ktx.securityContext().description() ) ); + } + return callProcedure( name, arguments, new RestrictedAccessMode( ktx.securityContext().mode(), AccessMode.Static.TOKEN_WRITE ) ); + } + + @Override + public RawIterator procedureCallWriteOverride( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + return callProcedure( name, arguments, new OverriddenAccessMode( ktx.securityContext().mode(), AccessMode.Static.TOKEN_WRITE ) ); + + } + + @Override + public RawIterator procedureCallSchema( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + AccessMode accessMode = ktx.securityContext().mode(); + if ( !accessMode.allowsSchemaWrites() ) + { + throw accessMode.onViolation( format( "Schema operations are not allowed for %s.", + ktx.securityContext().description() ) ); + } + return callProcedure( name, arguments, + new RestrictedAccessMode( ktx.securityContext().mode(), AccessMode.Static.FULL ) ); + } + + @Override + public RawIterator procedureCallSchemaOverride( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + return callProcedure( name, arguments, + new OverriddenAccessMode( ktx.securityContext().mode(), AccessMode.Static.FULL ) ); + } + @Override public AnyValue functionCall( int id, AnyValue[] arguments ) throws ProcedureException { @@ -631,12 +696,30 @@ public AnyValue functionCall( int id, AnyValue[] arguments ) throws ProcedureExc new RestrictedAccessMode( ktx.securityContext().mode(), AccessMode.Static.READ ) ); } + @Override + public AnyValue functionCall( QualifiedName name, AnyValue[] arguments ) throws ProcedureException + { + if ( !ktx.securityContext().mode().allowsReads() ) + { + throw ktx.securityContext().mode().onViolation( + format( "Read operations are not allowed for %s.", ktx.securityContext().description() ) ); + } + return callFunction( name, arguments, + new RestrictedAccessMode( ktx.securityContext().mode(), AccessMode.Static.READ ) ); + } + @Override public AnyValue functionCallOverride( int id, AnyValue[] arguments ) throws ProcedureException { return callFunction( id, arguments, new OverriddenAccessMode( ktx.securityContext().mode(), AccessMode.Static.READ ) ); + } + @Override + public AnyValue functionCallOverride( QualifiedName name, AnyValue[] arguments ) throws ProcedureException + { + return callFunction( name, arguments, + new OverriddenAccessMode( ktx.securityContext().mode(), AccessMode.Static.READ ) ); } @Override @@ -650,12 +733,31 @@ public UserAggregator aggregationFunction( int id ) throws ProcedureException return aggregationFunction( id, new RestrictedAccessMode( ktx.securityContext().mode(), AccessMode.Static.READ ) ); } + @Override + public UserAggregator aggregationFunction( QualifiedName name ) throws ProcedureException + { + if ( !ktx.securityContext().mode().allowsReads() ) + { + throw ktx.securityContext().mode().onViolation( + format( "Read operations are not allowed for %s.", ktx.securityContext().description() ) ); + } + return aggregationFunction( name, new RestrictedAccessMode( ktx.securityContext().mode(), AccessMode.Static.READ ) ); + } + @Override public UserAggregator aggregationFunctionOverride( int id ) throws ProcedureException { return aggregationFunction( id, new OverriddenAccessMode( ktx.securityContext().mode(), AccessMode.Static.READ ) ); } + + @Override + public UserAggregator aggregationFunctionOverride( QualifiedName name ) throws ProcedureException + { + return aggregationFunction( name, + new OverriddenAccessMode( ktx.securityContext().mode(), AccessMode.Static.READ ) ); + } + private RawIterator callProcedure( int id, Object[] input, final AccessMode override ) throws ProcedureException @@ -667,12 +769,30 @@ private RawIterator callProcedure( try ( KernelTransaction.Revertable ignore = ktx.overrideWith( procedureSecurityContext ); Statement statement = ktx.acquireStatement() ) { - BasicContext ctx = new BasicContext(); - ctx.put( Context.KERNEL_TRANSACTION, ktx ); - ctx.put( Context.THREAD, Thread.currentThread() ); - ctx.put( Context.SECURITY_CONTEXT, procedureSecurityContext ); - procedureCall = procedures.callProcedure( ctx, id, input, statement ); + procedureCall = procedures.callProcedure( populateProcedureContext( procedureSecurityContext ), id, input, statement ); + } + return createIterator( procedureSecurityContext, procedureCall ); + } + + private RawIterator callProcedure( + QualifiedName name, Object[] input, final AccessMode override ) + throws ProcedureException + { + ktx.assertOpen(); + + final SecurityContext procedureSecurityContext = ktx.securityContext().withMode( override ); + final RawIterator procedureCall; + try ( KernelTransaction.Revertable ignore = ktx.overrideWith( procedureSecurityContext ); + Statement statement = ktx.acquireStatement() ) + { + procedureCall = procedures.callProcedure( populateProcedureContext( procedureSecurityContext ), name, input, statement ); } + return createIterator( procedureSecurityContext, procedureCall ); + } + + private RawIterator createIterator( SecurityContext procedureSecurityContext, + RawIterator procedureCall ) + { return new RawIterator() { @Override @@ -694,20 +814,24 @@ public Object[] next() throws ProcedureException } }; } + private AnyValue callFunction( int id, AnyValue[] input, final AccessMode mode ) throws ProcedureException { ktx.assertOpen(); try ( KernelTransaction.Revertable ignore = ktx.overrideWith( ktx.securityContext().withMode( mode ) ) ) { - BasicContext ctx = new BasicContext(); - ctx.put( Context.KERNEL_TRANSACTION, ktx ); - ctx.put( Context.THREAD, Thread.currentThread() ); - ClockContext clocks = ktx.clocks(); - ctx.put( Context.SYSTEM_CLOCK, clocks.systemClock() ); - ctx.put( Context.STATEMENT_CLOCK, clocks.statementClock() ); - ctx.put( Context.TRANSACTION_CLOCK, clocks.transactionClock() ); - return procedures.callFunction( ctx, id, input ); + return procedures.callFunction( populateFunctionContext(), id, input ); + } + } + + private AnyValue callFunction( QualifiedName name, AnyValue[] input, final AccessMode mode ) throws ProcedureException + { + ktx.assertOpen(); + + try ( KernelTransaction.Revertable ignore = ktx.overrideWith( ktx.securityContext().withMode( mode ) ) ) + { + return procedures.callFunction( populateFunctionContext(), name, input ); } } @@ -718,10 +842,47 @@ private UserAggregator aggregationFunction( int id, final AccessMode mode ) try ( KernelTransaction.Revertable ignore = ktx.overrideWith( ktx.securityContext().withMode( mode ) ) ) { - BasicContext ctx = new BasicContext(); - ctx.put( Context.KERNEL_TRANSACTION, ktx ); - ctx.put( Context.THREAD, Thread.currentThread() ); - return procedures.createAggregationFunction( ctx, id ); + return procedures.createAggregationFunction( populateAggregationContext(), id ); } } + + private UserAggregator aggregationFunction( QualifiedName name, final AccessMode mode ) + throws ProcedureException + { + ktx.assertOpen(); + + try ( KernelTransaction.Revertable ignore = ktx.overrideWith( ktx.securityContext().withMode( mode ) ) ) + { + return procedures.createAggregationFunction( populateAggregationContext(), name ); + } + } + + private BasicContext populateFunctionContext() + { + BasicContext ctx = new BasicContext(); + ctx.put( Context.KERNEL_TRANSACTION, ktx ); + ctx.put( Context.THREAD, Thread.currentThread() ); + ClockContext clocks = ktx.clocks(); + ctx.put( Context.SYSTEM_CLOCK, clocks.systemClock() ); + ctx.put( Context.STATEMENT_CLOCK, clocks.statementClock() ); + ctx.put( Context.TRANSACTION_CLOCK, clocks.transactionClock() ); + return ctx; + } + + private BasicContext populateAggregationContext() + { + BasicContext ctx = new BasicContext(); + ctx.put( Context.KERNEL_TRANSACTION, ktx ); + ctx.put( Context.THREAD, Thread.currentThread() ); + return ctx; + } + + private BasicContext populateProcedureContext( SecurityContext procedureSecurityContext ) + { + BasicContext ctx = new BasicContext(); + ctx.put( Context.KERNEL_TRANSACTION, ktx ); + ctx.put( Context.THREAD, Thread.currentThread() ); + ctx.put( Context.SECURITY_CONTEXT, procedureSecurityContext ); + return ctx; + } } diff --git a/community/kernel/src/test/java/org/neo4j/kernel/impl/newapi/MockStore.java b/community/kernel/src/test/java/org/neo4j/kernel/impl/newapi/MockStore.java index 2a2eccf51ddfb..f8b6106ebd9b7 100644 --- a/community/kernel/src/test/java/org/neo4j/kernel/impl/newapi/MockStore.java +++ b/community/kernel/src/test/java/org/neo4j/kernel/impl/newapi/MockStore.java @@ -296,29 +296,95 @@ public RawIterator procedureCallSchemaOverride( int } @Override - public AnyValue functionCall( int id, AnyValue[] arguments ) throws ProcedureException + public RawIterator procedureCallRead( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + throw new UnsupportedOperationException( "not implemented" ); + } + + @Override + public RawIterator procedureCallReadOverride( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + throw new UnsupportedOperationException( "not implemented" ); + } + + @Override + public RawIterator procedureCallWrite( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + throw new UnsupportedOperationException( "not implemented" ); + } + + @Override + public RawIterator procedureCallWriteOverride( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + throw new UnsupportedOperationException( "not implemented" ); + } + + @Override + public RawIterator procedureCallSchema( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + throw new UnsupportedOperationException( "not implemented" ); + } + + @Override + public RawIterator procedureCallSchemaOverride( QualifiedName name, Object[] arguments ) + throws ProcedureException + { + throw new UnsupportedOperationException( "not implemented" ); + } + + @Override + public AnyValue functionCall( QualifiedName name, AnyValue[] arguments ) throws ProcedureException { throw new UnsupportedOperationException(); } @Override - public AnyValue functionCallOverride( int id, AnyValue[] arguments ) throws ProcedureException + public AnyValue functionCallOverride( QualifiedName name, AnyValue[] arguments ) throws ProcedureException { throw new UnsupportedOperationException(); } @Override - public UserAggregator aggregationFunction( int id ) throws ProcedureException + public UserAggregator aggregationFunction( QualifiedName name ) throws ProcedureException { throw new UnsupportedOperationException(); } @Override - public UserAggregator aggregationFunctionOverride( int id ) throws ProcedureException + public UserAggregator aggregationFunctionOverride( QualifiedName name ) throws ProcedureException { throw new UnsupportedOperationException(); } + @Override + public AnyValue functionCall( int id, AnyValue[] arguments ) throws ProcedureException + { + throw new UnsupportedOperationException(); + } + + @Override + public AnyValue functionCallOverride( int id, AnyValue[] arguments ) throws ProcedureException + { + throw new UnsupportedOperationException(); + } + + @Override + public UserAggregator aggregationFunction( int id ) throws ProcedureException + { + throw new UnsupportedOperationException(); + } + + @Override + public UserAggregator aggregationFunctionOverride( int id ) throws ProcedureException + { + throw new UnsupportedOperationException(); + } + private abstract static class Record { abstract void initialize( R record );