Skip to content

Commit

Permalink
Integrated aggregation functions in Cypher
Browse files Browse the repository at this point in the history
Added the infrastructure to be able to call and run user-defined aggregation
functions in Cypher.
  • Loading branch information
pontusmelke committed Dec 12, 2016
1 parent 207ff6a commit e0dae39
Show file tree
Hide file tree
Showing 41 changed files with 1,147 additions and 74 deletions.
@@ -0,0 +1,110 @@
/*
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.internal.cypher.acceptance

import java.util

import org.neo4j.kernel.api.proc.Neo4jTypes

/**
 * Acceptance tests for invoking a user-defined aggregation function from Cypher.
 *
 * Every test registers `my.first.value` (via `registerUserAggregationFunction`) with a fixed
 * Java value and then checks what comes back when the function is called through `graph.execute`.
 */
class AggregationFunctionCallSupportAcceptanceTest extends ProcedureCallAcceptanceTest {

  /** Builds a fresh mutable Java list holding the two names used by several tests. */
  private def norrisAndStrange() = {
    val names = new util.ArrayList[Any]()
    names.add("Norris")
    names.add("Strange")
    names
  }

  /** Runs `query` through graph.execute (to get a Java value) and returns the `out` column of its first row. */
  private def firstOutValue(query: String) =
    graph.execute(query).next().get("out")

  test("should return correctly typed map result (even if converting to and from scala representation internally)") {
    val expected = new util.HashMap[String, Any]()
    expected.put("name", "Cypher")
    expected.put("level", 9001)

    registerUserAggregationFunction(expected)

    // Using graph execute to get a Java value
    val rows = graph.execute("RETURN my.first.value()").stream().toArray.toList
    rows should equal(List(util.Collections.singletonMap("my.first.value()", expected)))
  }

  test("should return correctly typed list result (even if converting to and from scala representation internally)") {
    val expected = norrisAndStrange()

    registerUserAggregationFunction(expected)

    // Using graph execute to get a Java value
    val rows = graph.execute("RETURN my.first.value() AS out").stream().toArray.toList
    rows should equal(List(util.Collections.singletonMap("out", expected)))
  }

  test("should return correctly typed stream result (even if converting to and from scala representation internally)") {
    val names = norrisAndStrange()
    val expected = names.stream()

    registerUserAggregationFunction(expected)

    // Using graph execute to get a Java value
    val rows = graph.execute("RETURN my.first.value() AS out").stream().toArray.toList
    rows should equal(List(util.Collections.singletonMap("out", expected)))
  }

  test("should not copy lists unnecessarily") {
    val expected = norrisAndStrange()

    registerUserAggregationFunction(expected)

    val actual = firstOutValue("RETURN my.first.value() AS out")

    actual shouldBe an [util.ArrayList[_]]
    actual shouldBe expected
  }

  test("should not copy unnecessarily with nested types") {
    val nested = new util.ArrayList[Any]()
    val expected = new util.ArrayList[Any]()
    expected.add("Norris")
    expected.add(nested)

    registerUserAggregationFunction(expected)

    val actual = firstOutValue("RETURN my.first.value() AS out")

    actual shouldBe an [util.ArrayList[_]]
    actual shouldBe expected
  }

  test("should handle interacting with list") {
    val numbers = new util.ArrayList[Integer]()
    numbers.add(1)
    numbers.add(3)

    registerUserAggregationFunction(numbers, Neo4jTypes.NTList(Neo4jTypes.NTInteger))

    val actual = firstOutValue("WITH my.first.value() AS list RETURN list[0] + list[1] AS out")

    actual should equal(4)
  }
}
Expand Up @@ -23,6 +23,7 @@ import org.neo4j.collection.RawIterator
import org.neo4j.cypher._
import org.neo4j.kernel.api.exceptions.ProcedureException
import org.neo4j.kernel.api.proc.CallableProcedure.BasicProcedure
import org.neo4j.kernel.api.proc.CallableUserAggregationFunction.{Aggregator, BasicUserAggregationFunction}
import org.neo4j.kernel.api.proc.CallableUserFunction.BasicUserFunction
import org.neo4j.kernel.api.proc.ProcedureSignature.procedureSignature
import org.neo4j.kernel.api.proc.UserFunctionSignature.functionSignature
Expand Down Expand Up @@ -66,6 +67,23 @@ abstract class ProcedureCallAcceptanceTest extends ExecutionEngineFunSuite {
}
}

/**
 * Registers a user-defined aggregation function under the name `my.first.value`.
 *
 * The aggregator created by the registered function does nothing on `update` and always
 * returns `value` from `result()`.
 *
 * @param value the object handed back, as-is, from the aggregator's `result()`
 * @param typ   the output type declared on the function signature (defaults to `Neo4jTypes.NTAny`)
 */
protected def registerUserAggregationFunction(value: AnyRef, typ: Neo4jTypes.AnyType = Neo4jTypes.NTAny) =
  registerUserDefinedAggregationFunction("my.first.value") { _ =>
    // The signature is built from scratch here; the argument passed to this callback is not used,
    // so it is left unnamed instead of being shadowed by a local with the same name.
    val signature = functionSignature(Array("my", "first"), "value")
    signature.out(typ)

    new BasicUserAggregationFunction(signature.build) {

      override def create(ctx: Context): Aggregator = new Aggregator {

        // Always yields the fixed value supplied at registration time.
        override def result(): AnyRef = value

        // Input rows are ignored.
        override def update(input: Array[AnyRef]): Unit = ()
      }
    }
  }


protected def registerVoidProcedure() =
registerProcedure("dbms.do_nothing") { builder =>
builder.out(ProcedureSignature.VOID)
Expand Down
Expand Up @@ -24,6 +24,7 @@ import org.neo4j.cypher.internal.frontend.v3_2.SemanticCheckResult._
import org.neo4j.cypher.internal.frontend.v3_2._
import org.neo4j.cypher.internal.frontend.v3_2.ast.Expression.SemanticContext
import org.neo4j.cypher.internal.frontend.v3_2.ast._
import org.neo4j.cypher.internal.frontend.v3_2.ast.functions.UserDefinedFunctionInvocation

object ResolvedFunctionInvocation {

Expand All @@ -49,7 +50,7 @@ case class ResolvedFunctionInvocation(qualifiedName: QualifiedName,
fcnSignature: Option[UserFunctionSignature],
callArguments: IndexedSeq[Expression])
(val position: InputPosition)
extends Expression {
extends Expression with UserDefinedFunctionInvocation {

def coerceArguments: ResolvedFunctionInvocation = fcnSignature match {
case Some(signature) =>
Expand Down Expand Up @@ -90,4 +91,6 @@ case class ResolvedFunctionInvocation(qualifiedName: QualifiedName,
|meaning that it expects $expectedNumArgs $msg""".stripMargin, position))
}
}

// Without a resolved signature the invocation is not considered an aggregation.
override def isAggregate: Boolean = fcnSignature match {
  case Some(signature) => signature.isAggregate
  case None => false
}
}
Expand Up @@ -295,7 +295,9 @@ object ExpressionConverters {
val callArgumentCommands = e.callArguments.map(Some(_)).zipAll(e.fcnSignature.get.inputSignature.map(_.default.map(_.value)), None, None).map {
case (given, default) => given.map(toCommandExpression).getOrElse(commandexpressions.Literal(default.get))
}
commandexpressions.FunctionInvocation(e.fcnSignature.get, callArgumentCommands)
val signature = e.fcnSignature.get
if (signature.isAggregate) commandexpressions.AggregationFunctionInvocation(signature, callArgumentCommands)
else commandexpressions.FunctionInvocation(signature, callArgumentCommands)
case e: ast.MapProjection => throw new InternalException("should have been rewritten away")
case e: NestedPlanExpression => commandexpressions.NestedPlanExpression(e.plan)
case _ =>
Expand Down
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.compiler.v3_2.commands.expressions

import org.neo4j.cypher.internal.compiler.v3_2.ExecutionContext
import org.neo4j.cypher.internal.compiler.v3_2.helpers.{RuntimeJavaValueConverter, RuntimeScalaValueConverter}
import org.neo4j.cypher.internal.compiler.v3_2.pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_2.pipes.aggregation.AggregationFunction
import org.neo4j.cypher.internal.compiler.v3_2.spi.UserFunctionSignature

/**
 * Command expression that invokes a user-defined aggregation function.
 *
 * @param signature signature of the function; its `name` and `allowed` values are used to look up
 *                  the aggregator from the query context
 * @param arguments expressions evaluated against each incoming row and passed to the aggregator
 */
case class AggregationFunctionInvocation(signature: UserFunctionSignature, arguments: IndexedSeq[Expression])
  extends AggregationExpression {

  override def createAggregationFunction: AggregationFunction = new AggregationFunction {
    // Looked up on first use, because the lookup needs the QueryState that is only passed per call.
    private var cachedAggregator: Option[UserDefinedAggregator] = None

    /** Returns the aggregator's result, deep-converted with a [[RuntimeScalaValueConverter]]. */
    override def result(implicit state: QueryState): Any = {
      val isGraphKernelResultValue = state.query.isGraphKernelResultValue _
      val toScala = new RuntimeScalaValueConverter(isGraphKernelResultValue, state.typeConverter.asPrivateType)
      toScala.asDeepScalaValue(aggregator.result)
    }

    /** Evaluates every argument for `data`, deep-converts it to a Java value and updates the aggregator. */
    override def apply(data: ExecutionContext)
                      (implicit state: QueryState): Unit = {
      val toJava = new RuntimeJavaValueConverter(state.query.isGraphKernelResultValue, state.typeConverter.asPublicType)
      val javaArguments = arguments.map(argument => toJava.asDeepJavaValue(argument(data)(state)))
      aggregator.update(javaArguments)
    }

    // Returns the cached aggregator, asking the query context for one the first time it is needed.
    private def aggregator(implicit state: QueryState): UserDefinedAggregator =
      cachedAggregator.getOrElse {
        val created = state.query.aggregateFunction(signature.name, signature.allowed)
        // Option(...) keeps a null lookup result uncached, so the lookup is retried on the next call.
        cachedAggregator = Option(created)
        created
      }
  }

  override def rewrite(f: (Expression) => Expression): Expression =
    f(AggregationFunctionInvocation(signature, arguments.map(argument => argument.rewrite(f))))

  override def symbolTableDependencies: Set[String] =
    arguments.flatMap(argument => argument.symbolTableDependencies).toSet
}
Expand Up @@ -47,4 +47,4 @@ case class FunctionInvocation(signature: UserFunctionSignature, arguments: Index
// Union of the symbol table dependencies of all argument expressions.
override def symbolTableDependencies = arguments.flatMap(argument => argument.symbolTableDependencies).toSet

// Renders as name(arg1,arg2,...), with the arguments separated by commas only.
override def toString = {
  val renderedArguments = arguments.mkString(",")
  s"${signature.name}($renderedArguments)"
}
}
}
@@ -0,0 +1,25 @@
/*
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.compiler.v3_2.commands.expressions

/**
 * Aggregator for a user-defined aggregation function: accepts argument values through
 * `update` and exposes the aggregated value through `result`.
 */
trait UserDefinedAggregator {
  /**
   * Feeds one set of argument values into the aggregation.
   *
   * @param args the argument values for this update
   */
  def update(args: IndexedSeq[Any]): Unit

  /** The result of the aggregation. */
  def result: Any
}
Expand Up @@ -37,6 +37,7 @@ case class EagerAggregationPipe(source: Pipe, keyExpressions: Set[String], aggre
protected def internalCreateResults(input: Iterator[ExecutionContext], state: QueryState) = {
//register as parent so that stats are associated with this pipe
state.decorator.registerParentPipe(this)
implicit val s = state

// This is the temporary storage used while the aggregation is going on
val result = MutableMap[Equals, Seq[AggregationFunction]]()
Expand Down
Expand Up @@ -20,7 +20,7 @@
package org.neo4j.cypher.internal.compiler.v3_2.pipes.aggregation

import org.neo4j.cypher.internal.compiler.v3_2._
import pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_2.pipes.QueryState

/**
* Base class for aggregation functions. The function is stateful
Expand All @@ -36,6 +36,6 @@ abstract class AggregationFunction {
/**
* The aggregated result.
*/
def result: Any
def result(implicit state:QueryState): Any
}

Expand Up @@ -20,9 +20,9 @@
package org.neo4j.cypher.internal.compiler.v3_2.pipes.aggregation

import org.neo4j.cypher.internal.compiler.v3_2._
import commands.expressions.Expression
import org.neo4j.cypher.internal.compiler.v3_2.commands.expressions.Expression
import org.neo4j.cypher.internal.compiler.v3_2.helpers.TypeSafeMathSupport
import pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_2.pipes.QueryState

/**
* AVG computation is calculated using cumulative moving average approach:
Expand All @@ -38,7 +38,7 @@ class AvgFunction(val value: Expression)
private var count: Long = 0L
private var sum: OverflowAwareSum[_] = OverflowAwareSum(0)

def result =
def result(implicit state: QueryState) =
if (count > 0) {
sum.value
} else {
Expand Down
Expand Up @@ -20,9 +20,10 @@
package org.neo4j.cypher.internal.compiler.v3_2.pipes.aggregation

import org.neo4j.cypher.internal.compiler.v3_2._
import commands.expressions.Expression
import pipes.QueryState
import collection.mutable.ListBuffer
import org.neo4j.cypher.internal.compiler.v3_2.commands.expressions.Expression
import org.neo4j.cypher.internal.compiler.v3_2.pipes.QueryState

import scala.collection.mutable.ListBuffer

class CollectFunction(value:Expression) extends AggregationFunction {
val collection = new ListBuffer[Any]()
Expand All @@ -34,5 +35,5 @@ class CollectFunction(value:Expression) extends AggregationFunction {
}
}

def result: Any = collection.toIndexedSeq
def result(implicit state: QueryState): Any = collection.toIndexedSeq
}
Expand Up @@ -20,8 +20,8 @@
package org.neo4j.cypher.internal.compiler.v3_2.pipes.aggregation

import org.neo4j.cypher.internal.compiler.v3_2._
import commands.expressions.Expression
import pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_2.commands.expressions.Expression
import org.neo4j.cypher.internal.compiler.v3_2.pipes.QueryState

class CountFunction(value: Expression) extends AggregationFunction {
var count: Long = 0
Expand All @@ -33,5 +33,5 @@ class CountFunction(value: Expression) extends AggregationFunction {
}
}

def result: Long = count
def result(implicit state: QueryState): Long = count
}
Expand Up @@ -20,7 +20,7 @@
package org.neo4j.cypher.internal.compiler.v3_2.pipes.aggregation

import org.neo4j.cypher.internal.compiler.v3_2._
import pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_2.pipes.QueryState

class CountStarFunction extends AggregationFunction {
var count:Long = 0
Expand All @@ -29,6 +29,6 @@ class CountStarFunction extends AggregationFunction {
count += 1
}

def result: Long = count
def result(implicit state: QueryState): Long = count
}

Expand Up @@ -45,5 +45,5 @@ class DistinctFunction(value: Expression, inner: AggregationFunction) extends Ag
}
}

override def result = inner.result
override def result(implicit state: QueryState) = inner.result
}
Expand Up @@ -30,7 +30,7 @@ trait MinMax extends AggregationFunction with Comparer {

private var biggestSeen: Any = null

def result: Any = biggestSeen
def result(implicit state: QueryState): Any = biggestSeen

def apply(data: ExecutionContext)(implicit state: QueryState) {
value(data) match {
Expand Down

0 comments on commit e0dae39

Please sign in to comment.