diff --git a/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/ResolvedCall.scala b/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/ResolvedCall.scala index 05ebe04162a3..4e9ec7bd993c 100644 --- a/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/ResolvedCall.scala +++ b/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/ResolvedCall.scala @@ -31,14 +31,12 @@ object ResolvedCall { val UnresolvedCall(_, _, declaredArguments, declaredResults) = unresolved val position = unresolved.position val signature = signatureLookup(QualifiedProcedureName(unresolved)) - val callArguments = declaredArguments.getOrElse(signatureArguments(signature, position)) + val nonDefaults = signature.inputSignature.flatMap(s => if (s.default.isDefined) None else Some(Parameter(s.name, CTAny)(position))) + val callArguments = declaredArguments.getOrElse(nonDefaults) val callResults = declaredResults.getOrElse(signatureResults(signature, position)) ResolvedCall(signature, callArguments, callResults, declaredArguments.nonEmpty, declaredResults.nonEmpty)(position) } - private def signatureArguments(signature: ProcedureSignature, position: InputPosition): Seq[Parameter] = - signature.inputSignature.map { field => Parameter(field.name, CTAny)(position) } - private def signatureResults(signature: ProcedureSignature, position: InputPosition): Seq[ProcedureResultItem] = signature.outputSignature.getOrElse(Seq.empty).map { field => ProcedureResultItem(Variable(field.name)(position))(position) } } @@ -90,10 +88,13 @@ case class ResolvedCall(signature: ProcedureSignature, private def argumentCheck: SemanticCheck = { val expectedNumArgs = signature.inputSignature.length - val actualNumArgs = callArguments.length + val defaultArgs = signature.inputSignature.flatMap(_.default).drop(callArguments.length) + val actualNumArgs = callArguments.length + defaultArgs.length if (declaredArguments) { if (expectedNumArgs == actualNumArgs) { + //this zip is fine since it will only verify provided args in callArguments + //default values are checked at load time signature.inputSignature.zip(callArguments).map { case (field, arg) => arg.semanticCheck(SemanticContext.Results) chain arg.expectType(field.typ.covariant) diff --git a/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/executionplan/procs/ProcedureCallExecutionPlan.scala b/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/executionplan/procs/ProcedureCallExecutionPlan.scala index dcd57c62bc4f..8fb34a0d3833 100644 --- a/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/executionplan/procs/ProcedureCallExecutionPlan.scala +++ b/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/executionplan/procs/ProcedureCallExecutionPlan.scala @@ -20,6 +20,8 @@ package org.neo4j.cypher.internal.compiler.v3_1.executionplan.procs import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.ExpressionConverters._ +import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions +import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions.Literal import org.neo4j.cypher.internal.compiler.v3_1.executionplan.{ExecutionPlan, InternalExecutionResult, ProcedureCallMode, READ_ONLY} import org.neo4j.cypher.internal.compiler.v3_1.helpers.{Counter, RuntimeJavaValueConverter} import org.neo4j.cypher.internal.compiler.v3_1.pipes.{ExternalCSVResource, QueryState} @@ -47,7 +49,8 @@ case class ProcedureCallExecutionPlan(signature: ProcedureSignature, publicTypeConverter: Any => Any) extends ExecutionPlan { - private val argExprCommands = argExprs.map(toCommandExpression) + private val argExprCommands: Seq[expressions.Expression] = argExprs.map(toCommandExpression) ++ + signature.inputSignature.drop(argExprs.size).flatMap(_.default).map(Literal(_)) override def run(ctx: QueryContext, planType: ExecutionMode, params: Map[String, Any]): InternalExecutionResult = { val input = evaluateArguments(ctx, params) diff --git a/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/planner/execution/PipeExecutionPlanBuilder.scala b/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/planner/execution/PipeExecutionPlanBuilder.scala index 8ad785dd0fe4..84ada8172507 100644 --- a/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/planner/execution/PipeExecutionPlanBuilder.scala +++ b/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/planner/execution/PipeExecutionPlanBuilder.scala @@ -27,7 +27,7 @@ import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.PatternConve import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.StatementConverters import org.neo4j.cypher.internal.compiler.v3_1.ast.rewriters.projectNamedPaths import org.neo4j.cypher.internal.compiler.v3_1.commands.EntityProducerFactory -import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions.{AggregationExpression, Expression => CommandExpression} +import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions.{AggregationExpression, Expression => CommandExpression, Literal} import org.neo4j.cypher.internal.compiler.v3_1.commands.predicates.{True, _} import org.neo4j.cypher.internal.compiler.v3_1.executionplan._ import org.neo4j.cypher.internal.compiler.v3_1.executionplan.builders.prepare.KeyTokenResolver @@ -314,7 +314,9 @@ case class ActualPipeBuilder(monitors: Monitors, recurse: LogicalPlan => Pipe, r case ProcedureCall(_, call@ResolvedCall(signature, callArguments, callResults, _, _)) => val callMode = ProcedureCallMode.fromAccessMode(signature.accessMode) - val callArgumentCommands = callArguments.map(toCommandExpression) + val callArgumentCommands = callArguments.map(Some(_)).zipAll(signature.inputSignature.map(_.default), None, None).map { + case (given, default) => given.map(toCommandExpression).getOrElse(Literal(default.get)) + } val rowProcessing = ProcedureCallRowProcessing(signature) ProcedureCallPipe(source, signature.name, callMode, callArgumentCommands, rowProcessing, call.callResultTypes, call.callResultIndices)() diff --git a/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/spi/ProcedureSignature.scala b/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/spi/ProcedureSignature.scala index 62d0e51be36a..b759249434f4 100644 --- a/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/spi/ProcedureSignature.scala +++ b/community/cypher/cypher-compiler-3.1/src/main/scala/org/neo4j/cypher/internal/compiler/v3_1/spi/ProcedureSignature.scala @@ -23,8 +23,8 @@ import org.neo4j.cypher.internal.frontend.v3_1.ast.UnresolvedCall import org.neo4j.cypher.internal.frontend.v3_1.symbols.CypherType case class ProcedureSignature(name: QualifiedProcedureName, - inputSignature: Seq[FieldSignature], - outputSignature: Option[Seq[FieldSignature]], + inputSignature: IndexedSeq[FieldSignature], + outputSignature: Option[IndexedSeq[FieldSignature]], accessMode: ProcedureAccessMode = ProcedureReadOnlyAccess) { def outputFields = outputSignature.getOrElse(Seq.empty) @@ -41,7 +41,7 @@ case class QualifiedProcedureName(namespace: Seq[String], name: String) { override def toString = s"""${namespace.mkString(".")}.$name""" } -case class FieldSignature(name: String, typ: CypherType) +case class FieldSignature(name: String, typ: CypherType, default: Option[AnyRef] = None) sealed trait ProcedureAccessMode diff --git a/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/RewriteProcedureCallsTest.scala b/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/RewriteProcedureCallsTest.scala index 9928e061286b..ddb904a91039 100644 --- a/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/RewriteProcedureCallsTest.scala +++ b/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/RewriteProcedureCallsTest.scala @@ -30,8 +30,8 @@ class RewriteProcedureCallsTest extends CypherFunSuite with AstConstructionTestS val ns = ProcedureNamespace(List("my", "proc"))(pos) val name = ProcedureName("foo")(pos) val qualifiedName = QualifiedProcedureName(ns.parts, name.name) - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val lookup: (QualifiedProcedureName) => ProcedureSignature = _ => signature diff --git a/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/CallClauseTest.scala b/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/CallClauseTest.scala index 570901944fe4..f429d3922578 100644 --- a/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/CallClauseTest.scala +++ b/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/CallClauseTest.scala @@ -33,8 +33,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should resolve CALL my.proc.foo") { val unresolved = UnresolvedCall(ns, name, None, None)(pos) - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = Seq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos)) @@ -58,7 +58,7 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { test("should resolve void CALL my.proc.foo") { val unresolved = UnresolvedCall(ns, name, None, None)(pos) - val signatureInputs = Seq(FieldSignature("a", CTInteger)) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = None val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(Parameter("a", CTAny)(pos)) @@ -82,8 +82,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("should resolve CALL my.proc.foo YIELD x, y") { - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = Seq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos)) @@ -107,8 +107,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("should resolve CALL my.proc.foo(a)") { - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = Seq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos)) @@ -132,7 +132,7 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("should resolve void CALL my.proc.foo(a)") { - val signatureInputs = Seq(FieldSignature("a", CTInteger)) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) val signatureOutputs = None val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(Parameter("a", CTAny)(pos)) @@ -157,8 +157,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("should resolve CALL my.proc.foo(a) YIELD x, y AS z") { - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = Seq( @@ -183,7 +183,7 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("pretends to be based on user-declared arguments and results upon request") { - val signature = ProcedureSignature(qualifiedName, Seq.empty, Some(Seq.empty), ProcedureReadOnlyAccess) + val signature = ProcedureSignature(qualifiedName, IndexedSeq.empty, Some(IndexedSeq.empty), ProcedureReadOnlyAccess) val call = ResolvedCall(signature, null, null, declaredArguments = false, declaredResults = false)(pos) call.withFakedFullDeclarations.declaredArguments should be(true) @@ -191,8 +191,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("adds coercion of arguments to signature types upon request") { - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = Seq( @@ -218,8 +218,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("should verify number of arguments during semantic checking of resolved calls") { - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq.empty val callResults = Seq( @@ -235,8 +235,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("should verify that result variables are unique during semantic checking of resolved calls") { - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = Seq( @@ -252,8 +252,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("should verify that output field names are correct during semantic checking of resolved calls") { - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(Parameter("a", CTAny)(pos)) val callResults = Seq( @@ -269,8 +269,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport { } test("should verify result types during semantic checking of resolved calls") { - val signatureInputs = Seq(FieldSignature("a", CTInteger)) - val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) + val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger)) + val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode)))) val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess) val callArguments = Seq(StringLiteral("nope")(pos)) val callResults = Seq( diff --git a/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/convert/plannerQuery/StatementConvertersTest.scala b/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/convert/plannerQuery/StatementConvertersTest.scala index e7491f7bbc95..a6270f4ce37b 100644 --- a/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/convert/plannerQuery/StatementConvertersTest.scala +++ b/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/ast/convert/plannerQuery/StatementConvertersTest.scala @@ -879,8 +879,8 @@ class StatementConvertersTest extends CypherFunSuite with LogicalPlanningTestSup test("CALL foo() YIELD all RETURN all") { val signature = ProcedureSignature( QualifiedProcedureName(Seq.empty, "foo"), - inputSignature = Seq.empty, - outputSignature = Some(Seq(FieldSignature("all", CTInteger))) + inputSignature = IndexedSeq.empty, + outputSignature = Some(IndexedSeq(FieldSignature("all", CTInteger))) ) val query = buildPlannerQuery("CALL foo() YIELD all RETURN all", Some(_ => signature)) diff --git a/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/executionplan/procs/ProcedureCallExecutionPlanTest.scala b/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/executionplan/procs/ProcedureCallExecutionPlanTest.scala index 1da72c1b8c23..5eb0c52a6e2b 100644 --- a/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/executionplan/procs/ProcedureCallExecutionPlanTest.scala +++ b/community/cypher/cypher-compiler-3.1/src/test/scala/org/neo4j/cypher/internal/compiler/v3_1/executionplan/procs/ProcedureCallExecutionPlanTest.scala @@ -82,16 +82,16 @@ class ProcedureCallExecutionPlanTest extends CypherFunSuite { def string(s: String): Expression = StringLiteral(s)(pos) private val readSignature = ProcedureSignature( - QualifiedProcedureName(Seq.empty, "foo"), - Seq(FieldSignature("a", symbols.CTInteger)), - Some(Seq(FieldSignature("b", symbols.CTInteger))), + QualifiedProcedureName(IndexedSeq.empty, "foo"), + IndexedSeq(FieldSignature("a", symbols.CTInteger)), + Some(IndexedSeq(FieldSignature("b", symbols.CTInteger))), ProcedureReadOnlyAccess ) private val writeSignature = ProcedureSignature( QualifiedProcedureName(Seq.empty, "foo"), - Seq(FieldSignature("a", symbols.CTInteger)), - Some(Seq(FieldSignature("b", symbols.CTInteger))), + IndexedSeq(FieldSignature("a", symbols.CTInteger)), + Some(IndexedSeq(FieldSignature("b", symbols.CTInteger))), ProcedureReadWriteAccess ) diff --git a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/TransactionBoundPlanContext.scala b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/TransactionBoundPlanContext.scala index 03d5750f3c4d..fea3615616cb 100644 --- a/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/TransactionBoundPlanContext.scala +++ b/community/cypher/cypher/src/main/scala/org/neo4j/cypher/internal/spi/v3_1/TransactionBoundPlanContext.scala @@ -19,6 +19,8 @@ */ package org.neo4j.cypher.internal.spi.v3_1 +import java.util.Optional + import org.neo4j.cypher.MissingIndexException import org.neo4j.cypher.internal.LastCommittedTxIdProvider import org.neo4j.cypher.internal.compiler.v3_1.pipes.EntityProducer @@ -130,13 +132,15 @@ class TransactionBoundPlanContext(tc: TransactionalContextWrapperv3_1) override def procedureSignature(name: QualifiedProcedureName) = { val kn = new KernelProcedureSignature.ProcedureName(name.namespace.asJava, name.name) val ks = tc.statement.readOperations().procedureGet(kn) - val input = ks.inputSignature().asScala.map(s => FieldSignature(s.name(), asCypherType(s.neo4jType()))) - val output = if (ks.isVoid) None else Some(ks.outputSignature().asScala.map(s => FieldSignature(s.name(), asCypherType(s.neo4jType())))) + val input = ks.inputSignature().asScala.map(s => FieldSignature(s.name(), asCypherType(s.neo4jType()), asOption(s.defaultValue()))).toIndexedSeq + val output = if (ks.isVoid) None else Some(ks.outputSignature().asScala.map(s => FieldSignature(s.name(), asCypherType(s.neo4jType()))).toIndexedSeq) val mode = asCypherProcMode(ks.mode()) ProcedureSignature(name, input, output, mode) } + private def asOption[T](optional: Optional[T]): Option[T] = if (optional.isPresent) Some(optional.get()) else None + private def asCypherProcMode(mode: KernelProcedureSignature.Mode): ProcedureAccessMode = mode match { case KernelProcedureSignature.Mode.READ_ONLY => ProcedureReadOnlyAccess case KernelProcedureSignature.Mode.READ_WRITE => ProcedureReadWriteAccess diff --git a/community/kernel/src/main/java/org/neo4j/kernel/api/proc/ProcedureSignature.java b/community/kernel/src/main/java/org/neo4j/kernel/api/proc/ProcedureSignature.java index 0d4f9bf21afd..c13bcb17feed 100644 --- a/community/kernel/src/main/java/org/neo4j/kernel/api/proc/ProcedureSignature.java +++ b/community/kernel/src/main/java/org/neo4j/kernel/api/proc/ProcedureSignature.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.LinkedList; import java.util.List; +import java.util.Optional; import org.neo4j.helpers.collection.Iterables; import org.neo4j.kernel.api.proc.Neo4jTypes.AnyType; @@ -96,11 +97,18 @@ public static class FieldSignature { private final String name; private final AnyType type; + private final Optional defaultValue; - public FieldSignature( String name, AnyType type ) + public FieldSignature( String name, AnyType type) + { + this(name, type, Optional.empty()); + } + + public FieldSignature( String name, AnyType type, Optional defaultValue ) { this.name = name; this.type = type; + this.defaultValue = defaultValue; } public String name() @@ -113,6 +121,11 @@ public AnyType neo4jType() return type; } + public Optional defaultValue() + { + return defaultValue; + } + @Override public String toString() { @@ -253,7 +266,7 @@ public Builder mode( Mode mode ) /** Define an input field */ public Builder in( String name, AnyType type ) { - inputSignature.add( new FieldSignature( name, type ) ); + inputSignature.add( new FieldSignature( name, type) ); return this; } diff --git a/community/kernel/src/main/java/org/neo4j/kernel/impl/factory/DataSourceModule.java b/community/kernel/src/main/java/org/neo4j/kernel/impl/factory/DataSourceModule.java index 07c754c4d32d..6ee450e794c8 100644 --- a/community/kernel/src/main/java/org/neo4j/kernel/impl/factory/DataSourceModule.java +++ b/community/kernel/src/main/java/org/neo4j/kernel/impl/factory/DataSourceModule.java @@ -354,9 +354,9 @@ private Procedures setupProcedures( PlatformModule platform, EditionModule editi platform.life.add( procedures ); platform.dependencies.satisfyDependency( procedures ); - procedures.registerType( Node.class, new SimpleConverter( NTNode, Node.class ) ); - procedures.registerType( Relationship.class, new SimpleConverter( NTRelationship, Relationship.class ) ); - procedures.registerType( Path.class, new SimpleConverter( NTPath, Path.class ) ); + procedures.registerType( Node.class, new SimpleConverter( NTNode, Node.class) ); + procedures.registerType( Relationship.class, new SimpleConverter( NTRelationship, Relationship.class) ); + procedures.registerType( Path.class, new SimpleConverter( NTPath, Path.class) ); // Register injected public API components Log proceduresLog = platform.logging.getUserLog( Procedures.class ); diff --git a/community/kernel/src/main/java/org/neo4j/kernel/impl/proc/MethodSignatureCompiler.java b/community/kernel/src/main/java/org/neo4j/kernel/impl/proc/MethodSignatureCompiler.java index 883f7eb80580..d9f07884fe60 100644 --- a/community/kernel/src/main/java/org/neo4j/kernel/impl/proc/MethodSignatureCompiler.java +++ b/community/kernel/src/main/java/org/neo4j/kernel/impl/proc/MethodSignatureCompiler.java @@ -24,10 +24,12 @@ import java.lang.reflect.Type; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import org.neo4j.kernel.api.exceptions.ProcedureException; import org.neo4j.kernel.api.exceptions.Status; import org.neo4j.kernel.api.proc.ProcedureSignature.FieldSignature; +import org.neo4j.kernel.impl.proc.TypeMappers.NeoValueConverter; import org.neo4j.procedure.Name; /** @@ -48,6 +50,7 @@ public List signatureFor( Method method ) throws ProcedureExcept Parameter[] params = method.getParameters(); Type[] types = method.getGenericParameterTypes(); List signature = new ArrayList<>(params.length); + boolean seenDefault = false; for ( int i = 0; i < params.length; i++ ) { Parameter param = params[i]; @@ -60,7 +63,8 @@ public List signatureFor( Method method ) throws ProcedureExcept "Please add the annotation, recompile the class and try again.", i, method.getName(), Name.class.getSimpleName() ); } - String name = param.getAnnotation( Name.class ).value(); + Name parameter = param.getAnnotation( Name.class ); + String name = parameter.value(); if( name.trim().length() == 0 ) { @@ -72,7 +76,20 @@ public List signatureFor( Method method ) throws ProcedureExcept try { - signature.add(new FieldSignature( name, typeMappers.neoTypeFor( type ) )); + NeoValueConverter valueConverter = typeMappers.converterFor( type ); + Optional defaultValue = valueConverter.defaultValue( parameter ); + //it is not allowed to have holes in default values + if (seenDefault && !defaultValue.isPresent()) + { + throw new ProcedureException( Status.Procedure.ProcedureRegistrationFailed, + "Non-default argument at position %d with name %s in method %s follows default argument. " + + "Add a default value or rearrange arguments so that the non-default values comes first.", + i, parameter.value(), method.getName() ); + } + + seenDefault = defaultValue.isPresent(); + signature.add( new FieldSignature( name, valueConverter.type(), + defaultValue ) ); } catch ( ProcedureException e ) { diff --git a/community/kernel/src/main/java/org/neo4j/kernel/impl/proc/TypeMappers.java b/community/kernel/src/main/java/org/neo4j/kernel/impl/proc/TypeMappers.java index aa9b48ff107d..680cb7574899 100644 --- a/community/kernel/src/main/java/org/neo4j/kernel/impl/proc/TypeMappers.java +++ b/community/kernel/src/main/java/org/neo4j/kernel/impl/proc/TypeMappers.java @@ -24,11 +24,14 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.function.Function; import org.neo4j.helpers.collection.Iterables; import org.neo4j.kernel.api.exceptions.ProcedureException; import org.neo4j.kernel.api.exceptions.Status; import org.neo4j.kernel.api.proc.Neo4jTypes.AnyType; +import org.neo4j.procedure.Name; import static org.neo4j.kernel.api.proc.Neo4jTypes.NTAny; import static org.neo4j.kernel.api.proc.Neo4jTypes.NTBoolean; @@ -50,6 +53,7 @@ interface NeoValueConverter { AnyType type(); Object toNeoValue( Object javaValue ) throws ProcedureException; + Optional defaultValue(Name parameter); } private final Map javaToNeo = new HashMap<>(); @@ -125,17 +129,19 @@ public void registerType( Class javaClass, NeoValueConverter toNeo ) } private final NeoValueConverter TO_ANY = new SimpleConverter( NTAny, Object.class ); - private final NeoValueConverter TO_STRING = new SimpleConverter( NTString, String.class ); - private final NeoValueConverter TO_INTEGER = new SimpleConverter( NTInteger, Long.class ); - private final NeoValueConverter TO_FLOAT = new SimpleConverter( NTFloat, Double.class ); - private final NeoValueConverter TO_NUMBER = new SimpleConverter( NTNumber, Number.class ); - private final NeoValueConverter TO_BOOLEAN = new SimpleConverter( NTBoolean, Boolean.class ); - private final NeoValueConverter TO_MAP = new SimpleConverter( NTMap, Map.class ); + private final NeoValueConverter TO_STRING = new SimpleConverter( NTString, String.class, s -> s); + private final NeoValueConverter TO_INTEGER = new SimpleConverter( NTInteger, Long.class, Long::parseLong ); + private final NeoValueConverter TO_FLOAT = new SimpleConverter( NTFloat, Double.class, Double::parseDouble); + private final NeoValueConverter TO_NUMBER = new SimpleConverter( NTNumber, Number.class, Boolean::parseBoolean); + private final NeoValueConverter TO_BOOLEAN = new SimpleConverter( NTBoolean, Boolean.class, Boolean::parseBoolean); + private final NeoValueConverter TO_MAP = new SimpleConverter( NTMap, Map.class); private final NeoValueConverter TO_LIST = toList( TO_ANY ); private NeoValueConverter toList( NeoValueConverter inner ) { - return new SimpleConverter( NTList( inner.type() ), List.class ); + return new SimpleConverter( NTList( inner.type() ), List.class,s -> { + throw new UnsupportedOperationException("Default values for type List is not supported" ); + } ); } private ProcedureException javaToNeoMappingError( Type cls ) @@ -153,11 +159,33 @@ public static class SimpleConverter implements NeoValueConverter { private final AnyType type; private final Class javaClass; + private final Function defaultConverter; - public SimpleConverter( AnyType type, Class javaClass ) + public SimpleConverter( AnyType type, Class javaClass) + { + this( type, javaClass, s -> { + throw new UnsupportedOperationException( String.format("Default values for type %s is not supported", javaClass.getSimpleName() )); + } ); + } + + public SimpleConverter( AnyType type, Class javaClass, Function defaultConverter ) { this.type = type; this.javaClass = javaClass; + this.defaultConverter = defaultConverter; + } + + public Optional defaultValue(Name parameter) + { + String defaultValue = parameter.defaultValue(); + if ( defaultValue.equals( Name.DEFAULT_VALUE ) ) + { + return Optional.empty(); + } + else + { + return Optional.of( defaultConverter.apply( defaultValue ) ); + } } @Override diff --git a/community/kernel/src/main/java/org/neo4j/procedure/Name.java b/community/kernel/src/main/java/org/neo4j/procedure/Name.java index 200e5e241421..7b45471d1bd8 100644 --- a/community/kernel/src/main/java/org/neo4j/procedure/Name.java +++ b/community/kernel/src/main/java/org/neo4j/procedure/Name.java @@ -37,4 +37,13 @@ * @return the name of this input argument. */ String value(); + + String defaultValue() default DEFAULT_VALUE; + + /* + * Defaults in annotation requires compile time constants, the only way + * to check if a returned defaultValue() is a default is to use a constant + * that is highly unlikely to be used in real code. + */ + String DEFAULT_VALUE = " <[6795b15e-8693-4a21-b57a-4a7b87f09a5a]> "; } diff --git a/community/kernel/src/test/java/org/neo4j/kernel/impl/proc/ReflectiveProcedureWithArgumentsTest.java b/community/kernel/src/test/java/org/neo4j/kernel/impl/proc/ReflectiveProcedureWithArgumentsTest.java index e483cf31cd1e..8d5db33ed4b5 100644 --- a/community/kernel/src/test/java/org/neo4j/kernel/impl/proc/ReflectiveProcedureWithArgumentsTest.java +++ b/community/kernel/src/test/java/org/neo4j/kernel/impl/proc/ReflectiveProcedureWithArgumentsTest.java @@ -110,6 +110,19 @@ public void shouldFailIfMissingAnnotations() throws Throwable compile( ClassWithProcedureWithoutAnnotatedArgs.class ); } + @Test + public void shouldFailIfMisplacedDefaultValue() throws Throwable + { + // Expect + exception.expect( ProcedureException.class ); + exception.expectMessage( + "Non-default argument at position 2 with name c in method defaultValues follows default argument. " + + "Add a default value or rearrange arguments so that the non-default values comes first." ); + + // When + compile( ClassWithProcedureWithMisplacedDefault.class ); + } + public static class MyOutputRecord { public String name; @@ -156,6 +169,25 @@ public Stream listCoolPeople( String name, int age ) } } + public static class ClassWithProcedureWithDefaults + { + @Procedure + public Stream defaultValues( @Name( value = "a", defaultValue = "a") String a , @Name( value = "b", defaultValue = "42") long b, @Name( value = "c", + defaultValue = "3.14") double c) + { + return Stream.empty(); + } + } + + public static class ClassWithProcedureWithMisplacedDefault + { + @Procedure + public Stream defaultValues( @Name( "a" ) String a , @Name( value = "b", defaultValue = "42") long b, @Name( "c" ) Object c) + { + return Stream.empty(); + } + } + private List compile( Class clazz ) throws KernelException { return new ReflectiveProcedureCompiler( new TypeMappers(), new ComponentRegistry() ).compile( clazz ); diff --git a/integrationtests/src/test/java/org/neo4j/procedure/ProcedureIT.java b/integrationtests/src/test/java/org/neo4j/procedure/ProcedureIT.java index 21a816515e92..eab039729e3e 100644 --- a/integrationtests/src/test/java/org/neo4j/procedure/ProcedureIT.java +++ b/integrationtests/src/test/java/org/neo4j/procedure/ProcedureIT.java @@ -104,6 +104,84 @@ public void shouldCallProcedureWithParameterMap() throws Throwable } } + @Test + public void shouldCallProcedureWithDefaultArgument() throws Throwable + { + //Given/When + Result res = db.execute( "CALL org.neo4j.procedure.simpleArgumentWithDefault" ); + + // Then + assertThat( res.next(), equalTo( map( "someVal", 42L ) ) ); + assertFalse( res.hasNext() ); + } + + @Test + public void shouldCallYieldProcedureWithDefaultArgument() throws Throwable + { + // Given/When + Result res = db.execute( + "CALL org.neo4j.procedure.simpleArgumentWithDefault() YIELD someVal as n RETURN n + 1295 as val" ); + + // Then + assertThat( res.next(), equalTo( map( "val", 1337L ) ) ); + assertFalse( res.hasNext() ); + } + + @Test + public void shouldCallProcedureWithAllDefaultArgument() throws Throwable + { + //Given/When + Result res = db.execute( "CALL org.neo4j.procedure.defaultValues" ); + + // Then + assertThat( res.next(), equalTo( map( "string", "a string", "integer", 42L, "aFloat", 3.14, "aBoolean", true ) ) ); + assertFalse( res.hasNext() ); + } + + @Test + public void shouldCallProcedureWithOneProvidedRestDefaultArgument() throws Throwable + { + //Given/When + Result res = db.execute( "CALL org.neo4j.procedure.defaultValues('another string')"); + + // Then + assertThat( res.next(), equalTo( map( "string", "another string", "integer", 42L, "aFloat", 3.14, "aBoolean", true ) ) ); + assertFalse( res.hasNext() ); + } + + @Test + public void shouldCallProcedureWithTwoProvidedRestDefaultArgument() throws Throwable + { + //Given/When + Result res = db.execute( "CALL org.neo4j.procedure.defaultValues('another string', 1337)"); + + // Then + assertThat( res.next(), equalTo( map( "string", "another string", "integer", 1337L, "aFloat", 3.14, "aBoolean", true ) ) ); + assertFalse( res.hasNext() ); + } + + @Test + public void shouldCallProcedureWithThreeProvidedRestDefaultArgument() throws Throwable + { + //Given/When + Result res = db.execute( "CALL org.neo4j.procedure.defaultValues('another string', 1337, 2.718281828)"); + + // Then + assertThat( res.next(), equalTo( map( "string", "another string", "integer", 1337L, "aFloat", 2.718281828, "aBoolean", true ) ) ); + assertFalse( res.hasNext() ); + } + + @Test + public void shouldCallProcedureWithFourProvidedRestDefaultArgument() throws Throwable + { + //Given/When + Result res = db.execute( "CALL org.neo4j.procedure.defaultValues('another string', 1337, 2.718281828, false)"); + + // Then + assertThat( res.next(), equalTo( map( "string", "another string", "integer", 1337L, "aFloat", 2.718281828, "aBoolean", false ) ) ); + assertFalse( res.hasNext() ); + } + @Test public void shouldGiveNiceErrorMessageOnWrongStaticType() throws Throwable { @@ -897,6 +975,22 @@ public Output( long someVal ) } } + public static class PrimitiveOutput + { + public String string; + public long integer; + public double aFloat; + public boolean aBoolean; + + public PrimitiveOutput( String string, long integer, double aFloat, boolean aBoolean ) + { + this.string = string; + this.integer = integer; + this.aFloat = aFloat; + this.aBoolean = aBoolean; + } + } + public static class DoubleOutput { public double result = 0.0d; @@ -976,6 +1070,23 @@ public Stream simpleArgument( @Name( "name" ) long someValue ) return Stream.of( new Output( someValue ) ); } + @Procedure + public Stream simpleArgumentWithDefault( @Name( value = "name", defaultValue = "42") long someValue ) + { + return Stream.of( new Output( someValue ) ); + } + + @Procedure + public Stream defaultValues( + @Name( value = "string", defaultValue = "a string") String string, + @Name( value = "integer", defaultValue = "42") long integer, + @Name( value = "float", defaultValue = "3.14") double aFloat, + @Name( value = "boolean", defaultValue = "true") boolean aBoolean + ) + { + return Stream.of( new PrimitiveOutput( string, integer, aFloat, aBoolean ) ); + } + @Procedure public Stream nodeListArgument( @Name( "nodes" ) List nodes ) {