Skip to content

Commit

Permalink
Add default values to procedures
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke committed Jul 12, 2016
1 parent 24c9880 commit b469fd9
Show file tree
Hide file tree
Showing 16 changed files with 278 additions and 58 deletions.
Expand Up @@ -31,14 +31,12 @@ object ResolvedCall {
val UnresolvedCall(_, _, declaredArguments, declaredResults) = unresolved
val position = unresolved.position
val signature = signatureLookup(QualifiedProcedureName(unresolved))
val callArguments = declaredArguments.getOrElse(signatureArguments(signature, position))
val nonDefaults = signature.inputSignature.flatMap(s => if (s.default.isDefined) None else Some(Parameter(s.name, CTAny)(position)))
val callArguments = declaredArguments.getOrElse(nonDefaults)
val callResults = declaredResults.getOrElse(signatureResults(signature, position))
ResolvedCall(signature, callArguments, callResults, declaredArguments.nonEmpty, declaredResults.nonEmpty)(position)
}

private def signatureArguments(signature: ProcedureSignature, position: InputPosition): Seq[Parameter] =
signature.inputSignature.map { field => Parameter(field.name, CTAny)(position) }

private def signatureResults(signature: ProcedureSignature, position: InputPosition): Seq[ProcedureResultItem] =
signature.outputSignature.getOrElse(Seq.empty).map { field => ProcedureResultItem(Variable(field.name)(position))(position) }
}
Expand Down Expand Up @@ -90,10 +88,13 @@ case class ResolvedCall(signature: ProcedureSignature,

private def argumentCheck: SemanticCheck = {
val expectedNumArgs = signature.inputSignature.length
val actualNumArgs = callArguments.length
val defaultArgs = signature.inputSignature.flatMap(_.default).drop(callArguments.length)
val actualNumArgs = callArguments.length + defaultArgs.length

if (declaredArguments) {
if (expectedNumArgs == actualNumArgs) {
//this zip is fine since it will only verify provided args in callArguments
//default values are checked at load time
signature.inputSignature.zip(callArguments).map {
case (field, arg) =>
arg.semanticCheck(SemanticContext.Results) chain arg.expectType(field.typ.covariant)
Expand Down
Expand Up @@ -20,6 +20,8 @@
package org.neo4j.cypher.internal.compiler.v3_1.executionplan.procs

import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.ExpressionConverters._
import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions
import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions.Literal
import org.neo4j.cypher.internal.compiler.v3_1.executionplan.{ExecutionPlan, InternalExecutionResult, ProcedureCallMode, READ_ONLY}
import org.neo4j.cypher.internal.compiler.v3_1.helpers.{Counter, RuntimeJavaValueConverter}
import org.neo4j.cypher.internal.compiler.v3_1.pipes.{ExternalCSVResource, QueryState}
Expand Down Expand Up @@ -47,7 +49,8 @@ case class ProcedureCallExecutionPlan(signature: ProcedureSignature,
publicTypeConverter: Any => Any)
extends ExecutionPlan {

private val argExprCommands = argExprs.map(toCommandExpression)
private val argExprCommands: Seq[expressions.Expression] = argExprs.map(toCommandExpression) ++
signature.inputSignature.drop(argExprs.size).flatMap(_.default).map(Literal(_))

override def run(ctx: QueryContext, planType: ExecutionMode, params: Map[String, Any]): InternalExecutionResult = {
val input = evaluateArguments(ctx, params)
Expand Down
Expand Up @@ -27,7 +27,7 @@ import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.PatternConve
import org.neo4j.cypher.internal.compiler.v3_1.ast.convert.commands.StatementConverters
import org.neo4j.cypher.internal.compiler.v3_1.ast.rewriters.projectNamedPaths
import org.neo4j.cypher.internal.compiler.v3_1.commands.EntityProducerFactory
import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions.{AggregationExpression, Expression => CommandExpression}
import org.neo4j.cypher.internal.compiler.v3_1.commands.expressions.{AggregationExpression, Expression => CommandExpression, Literal}
import org.neo4j.cypher.internal.compiler.v3_1.commands.predicates.{True, _}
import org.neo4j.cypher.internal.compiler.v3_1.executionplan._
import org.neo4j.cypher.internal.compiler.v3_1.executionplan.builders.prepare.KeyTokenResolver
Expand Down Expand Up @@ -314,7 +314,9 @@ case class ActualPipeBuilder(monitors: Monitors, recurse: LogicalPlan => Pipe, r

case ProcedureCall(_, call@ResolvedCall(signature, callArguments, callResults, _, _)) =>
val callMode = ProcedureCallMode.fromAccessMode(signature.accessMode)
val callArgumentCommands = callArguments.map(toCommandExpression)
val callArgumentCommands = callArguments.map(Some(_)).zipAll(signature.inputSignature.map(_.default), None, None).map {
case (given, default) => given.map(toCommandExpression).getOrElse(Literal(default.get))
}
val rowProcessing = ProcedureCallRowProcessing(signature)
ProcedureCallPipe(source, signature.name, callMode, callArgumentCommands, rowProcessing, call.callResultTypes, call.callResultIndices)()

Expand Down
Expand Up @@ -23,8 +23,8 @@ import org.neo4j.cypher.internal.frontend.v3_1.ast.UnresolvedCall
import org.neo4j.cypher.internal.frontend.v3_1.symbols.CypherType

case class ProcedureSignature(name: QualifiedProcedureName,
inputSignature: Seq[FieldSignature],
outputSignature: Option[Seq[FieldSignature]],
inputSignature: IndexedSeq[FieldSignature],
outputSignature: Option[IndexedSeq[FieldSignature]],
accessMode: ProcedureAccessMode = ProcedureReadOnlyAccess) {

def outputFields = outputSignature.getOrElse(Seq.empty)
Expand All @@ -41,7 +41,7 @@ case class QualifiedProcedureName(namespace: Seq[String], name: String) {
override def toString = s"""${namespace.mkString(".")}.$name"""
}

case class FieldSignature(name: String, typ: CypherType)
case class FieldSignature(name: String, typ: CypherType, default: Option[AnyRef] = None)

sealed trait ProcedureAccessMode

Expand Down
Expand Up @@ -30,8 +30,8 @@ class RewriteProcedureCallsTest extends CypherFunSuite with AstConstructionTestS
val ns = ProcedureNamespace(List("my", "proc"))(pos)
val name = ProcedureName("foo")(pos)
val qualifiedName = QualifiedProcedureName(ns.parts, name.name)
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val lookup: (QualifiedProcedureName) => ProcedureSignature = _ => signature

Expand Down
Expand Up @@ -33,8 +33,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {

test("should resolve CALL my.proc.foo") {
val unresolved = UnresolvedCall(ns, name, None, None)(pos)
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(Parameter("a", CTAny)(pos))
val callResults = Seq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos))
Expand All @@ -58,7 +58,7 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {

test("should resolve void CALL my.proc.foo") {
val unresolved = UnresolvedCall(ns, name, None, None)(pos)
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = None
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(Parameter("a", CTAny)(pos))
Expand All @@ -82,8 +82,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {
}

test("should resolve CALL my.proc.foo YIELD x, y") {
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(Parameter("a", CTAny)(pos))
val callResults = Seq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos))
Expand All @@ -107,8 +107,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {
}

test("should resolve CALL my.proc.foo(a)") {
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(Parameter("a", CTAny)(pos))
val callResults = Seq(ProcedureResultItem(varFor("x"))(pos), ProcedureResultItem(varFor("y"))(pos))
Expand All @@ -132,7 +132,7 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {
}

test("should resolve void CALL my.proc.foo(a)") {
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = None
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(Parameter("a", CTAny)(pos))
Expand All @@ -157,8 +157,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {
}

test("should resolve CALL my.proc.foo(a) YIELD x, y AS z") {
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(Parameter("a", CTAny)(pos))
val callResults = Seq(
Expand All @@ -183,16 +183,16 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {
}

test("pretends to be based on user-declared arguments and results upon request") {
val signature = ProcedureSignature(qualifiedName, Seq.empty, Some(Seq.empty), ProcedureReadOnlyAccess)
val signature = ProcedureSignature(qualifiedName, IndexedSeq.empty, Some(IndexedSeq.empty), ProcedureReadOnlyAccess)
val call = ResolvedCall(signature, null, null, declaredArguments = false, declaredResults = false)(pos)

call.withFakedFullDeclarations.declaredArguments should be(true)
call.withFakedFullDeclarations.declaredResults should be(true)
}

test("adds coercion of arguments to signature types upon request") {
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(Parameter("a", CTAny)(pos))
val callResults = Seq(
Expand All @@ -218,8 +218,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {
}

test("should verify number of arguments during semantic checking of resolved calls") {
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq.empty
val callResults = Seq(
Expand All @@ -235,8 +235,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {
}

test("should verify that result variables are unique during semantic checking of resolved calls") {
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(Parameter("a", CTAny)(pos))
val callResults = Seq(
Expand All @@ -252,8 +252,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {
}

test("should verify that output field names are correct during semantic checking of resolved calls") {
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(Parameter("a", CTAny)(pos))
val callResults = Seq(
Expand All @@ -269,8 +269,8 @@ class CallClauseTest extends CypherFunSuite with AstConstructionTestSupport {
}

test("should verify result types during semantic checking of resolved calls") {
val signatureInputs = Seq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(Seq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signatureInputs = IndexedSeq(FieldSignature("a", CTInteger))
val signatureOutputs = Some(IndexedSeq(FieldSignature("x", CTInteger), FieldSignature("y", CTList(CTNode))))
val signature = ProcedureSignature(qualifiedName, signatureInputs, signatureOutputs, ProcedureReadOnlyAccess)
val callArguments = Seq(StringLiteral("nope")(pos))
val callResults = Seq(
Expand Down
Expand Up @@ -879,8 +879,8 @@ class StatementConvertersTest extends CypherFunSuite with LogicalPlanningTestSup
test("CALL foo() YIELD all RETURN all") {
val signature = ProcedureSignature(
QualifiedProcedureName(Seq.empty, "foo"),
inputSignature = Seq.empty,
outputSignature = Some(Seq(FieldSignature("all", CTInteger)))
inputSignature = IndexedSeq.empty,
outputSignature = Some(IndexedSeq(FieldSignature("all", CTInteger)))
)
val query = buildPlannerQuery("CALL foo() YIELD all RETURN all", Some(_ => signature))

Expand Down
Expand Up @@ -82,16 +82,16 @@ class ProcedureCallExecutionPlanTest extends CypherFunSuite {
def string(s: String): Expression = StringLiteral(s)(pos)

private val readSignature = ProcedureSignature(
QualifiedProcedureName(Seq.empty, "foo"),
Seq(FieldSignature("a", symbols.CTInteger)),
Some(Seq(FieldSignature("b", symbols.CTInteger))),
QualifiedProcedureName(IndexedSeq.empty, "foo"),
IndexedSeq(FieldSignature("a", symbols.CTInteger)),
Some(IndexedSeq(FieldSignature("b", symbols.CTInteger))),
ProcedureReadOnlyAccess
)

private val writeSignature = ProcedureSignature(
QualifiedProcedureName(Seq.empty, "foo"),
Seq(FieldSignature("a", symbols.CTInteger)),
Some(Seq(FieldSignature("b", symbols.CTInteger))),
IndexedSeq(FieldSignature("a", symbols.CTInteger)),
Some(IndexedSeq(FieldSignature("b", symbols.CTInteger))),
ProcedureReadWriteAccess
)

Expand Down
Expand Up @@ -19,6 +19,8 @@
*/
package org.neo4j.cypher.internal.spi.v3_1

import java.util.Optional

import org.neo4j.cypher.MissingIndexException
import org.neo4j.cypher.internal.LastCommittedTxIdProvider
import org.neo4j.cypher.internal.compiler.v3_1.pipes.EntityProducer
Expand Down Expand Up @@ -130,13 +132,15 @@ class TransactionBoundPlanContext(tc: TransactionalContextWrapperv3_1)
override def procedureSignature(name: QualifiedProcedureName) = {
val kn = new KernelProcedureSignature.ProcedureName(name.namespace.asJava, name.name)
val ks = tc.statement.readOperations().procedureGet(kn)
val input = ks.inputSignature().asScala.map(s => FieldSignature(s.name(), asCypherType(s.neo4jType())))
val output = if (ks.isVoid) None else Some(ks.outputSignature().asScala.map(s => FieldSignature(s.name(), asCypherType(s.neo4jType()))))
val input = ks.inputSignature().asScala.map(s => FieldSignature(s.name(), asCypherType(s.neo4jType()), asOption(s.defaultValue()))).toIndexedSeq
val output = if (ks.isVoid) None else Some(ks.outputSignature().asScala.map(s => FieldSignature(s.name(), asCypherType(s.neo4jType()))).toIndexedSeq)
val mode = asCypherProcMode(ks.mode())

ProcedureSignature(name, input, output, mode)
}

private def asOption[T](optional: Optional[T]): Option[T] = if (optional.isPresent) Some(optional.get()) else None

private def asCypherProcMode(mode: KernelProcedureSignature.Mode): ProcedureAccessMode = mode match {
case KernelProcedureSignature.Mode.READ_ONLY => ProcedureReadOnlyAccess
case KernelProcedureSignature.Mode.READ_WRITE => ProcedureReadWriteAccess
Expand Down
Expand Up @@ -24,6 +24,7 @@
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;

import org.neo4j.helpers.collection.Iterables;
import org.neo4j.kernel.api.proc.Neo4jTypes.AnyType;
Expand Down Expand Up @@ -96,11 +97,18 @@ public static class FieldSignature
{
private final String name;
private final AnyType type;
private final Optional<Object> defaultValue;

public FieldSignature( String name, AnyType type )
public FieldSignature( String name, AnyType type)
{
this(name, type, Optional.empty());
}

public FieldSignature( String name, AnyType type, Optional<Object> defaultValue )
{
this.name = name;
this.type = type;
this.defaultValue = defaultValue;
}

public String name()
Expand All @@ -113,6 +121,11 @@ public AnyType neo4jType()
return type;
}

public Optional<Object> defaultValue()
{
return defaultValue;
}

@Override
public String toString()
{
Expand Down Expand Up @@ -253,7 +266,7 @@ public Builder mode( Mode mode )
/** Define an input field */
public Builder in( String name, AnyType type )
{
inputSignature.add( new FieldSignature( name, type ) );
inputSignature.add( new FieldSignature( name, type) );
return this;
}

Expand Down

0 comments on commit b469fd9

Please sign in to comment.