Skip to content

Commit

Permalink
Support avg in morsel runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke committed Feb 25, 2018
1 parent 92388e7 commit b8d88d7
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 6 deletions.
Expand Up @@ -66,7 +66,7 @@ public final class Values
public static final Value MIN_NUMBER = Values.doubleValue( Double.NEGATIVE_INFINITY );
public static final Value MAX_NUMBER = Values.doubleValue( Double.NaN );
public static final Value ZERO_FLOAT = Values.doubleValue( 0.0 );
public static final Value ZERO_INT = Values.longValue( 0 );
public static final IntegralValue ZERO_INT = Values.longValue( 0 );
public static final Value MIN_STRING = StringValue.EMTPY;
public static final Value MAX_STRING = Values.booleanValue( false );
public static final BooleanValue TRUE = Values.booleanValue( true );
Expand Down
Expand Up @@ -134,6 +134,36 @@ class MorselRuntimeAcceptanceTest extends ExecutionEngineFunSuite {
result.getExecutionPlanDescription.getArguments.get("runtime") should equal("MORSEL")
}

test("should support average") {
//Given
10 to 100 by 10 foreach(i => createNode("prop" -> i))

//When
val result = graph.execute("CYPHER runtime=morsel MATCH (n) RETURN avg(n.prop)")

//Then
result.next().get("avg(n.prop)") should equal(55.0)
result.getExecutionPlanDescription.getArguments.get("runtime") should equal("MORSEL")
}

test("should support average with grouping") {
//Given
10 to 100 by 10 foreach(i => createNode("prop" -> i, "group" -> (if (i > 50) "FOO" else "BAR")))

//When
val result = graph.execute("CYPHER runtime=morsel MATCH (n) RETURN n.group, avg(n.prop)")

//Then
val first = result.next()
first.get("n.group") should equal("FOO")
first.get("avg(n.prop)") should equal(80.0)
val second = result.next()
second.get("n.group") should equal("BAR")
second.get("avg(n.prop)") should equal(30.0)

result.getExecutionPlanDescription.getArguments.get("runtime") should equal("MORSEL")
}

//we use a ridiculously small morsel size in order to trigger as many morsel overflows as possible
override def databaseConfig(): Map[Setting[_], String] = Map(GraphDatabaseSettings.cypher_morsel_size -> "4")
}
@@ -0,0 +1,78 @@
/*
* Copyright (c) 2002-2018 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.runtime.vectorized.expressions

import org.neo4j.cypher.internal.runtime.interpreted.commands.expressions.Expression
import org.neo4j.cypher.internal.runtime.interpreted.pipes.{QueryState => OldQueryState}
import org.neo4j.cypher.internal.runtime.vectorized._
import org.neo4j.cypher.internal.util.v3_4.symbols.CTAny
import org.neo4j.values.AnyValue
import org.neo4j.values.storable.Values.longValue
import org.neo4j.values.storable.{LongValue, NumberValue, Values}
import org.neo4j.values.virtual.{ListValue, VirtualValues}

/*
Vectorized version of the average aggregation function
*/
case class AvgOperatorExpression(anInner: Expression)
extends AggregationExpressionOperatorWithInnerExpression(anInner) {

override def expectedInnerType = CTAny

override def rewrite(f: (Expression) => Expression): Expression = f(CountOperatorExpression(anInner.rewrite(f)))

override def createAggregationMapper: AggregationMapper = new AvgMapper(anInner)

override def createAggregationReducer: AggregationReducer = new AvgReducer
}

class AvgMapper(value: Expression) extends AggregationMapper {

private var count: Long = 0L
private var sum: NumberValue = Values.ZERO_INT

override def result: AnyValue = VirtualValues.list(longValue(count), sum)

override def map(data: MorselExecutionContext,
state: OldQueryState): Unit = value(data, state) match {
case Values.NO_VALUE =>
case number: NumberValue =>
count += 1
sum = sum.plus(number)
}
}

class AvgReducer extends AggregationReducer {

private var count: Long = 0L
private var sum: NumberValue = longValue(0L)

override def result: AnyValue = sum.times(1.0 / count.toDouble)

override def reduce(value: AnyValue): Unit = value match {
case l: ListValue =>
count += l.value(0).asInstanceOf[LongValue].longValue()
sum = sum.plus(l.value(1).asInstanceOf[NumberValue])
case _ =>
}
}



Expand Up @@ -26,6 +26,9 @@ import org.neo4j.cypher.internal.util.v3_4.symbols.CTAny
import org.neo4j.values.AnyValue
import org.neo4j.values.storable.{NumberValue, Values}

/*
Vectorized version of the count aggregation function
*/
case class CountOperatorExpression(anInner: Expression) extends AggregationExpressionOperatorWithInnerExpression(anInner) {

override def expectedInnerType = CTAny
Expand All @@ -37,7 +40,7 @@ case class CountOperatorExpression(anInner: Expression) extends AggregationExpre
override def createAggregationReducer: AggregationReducer = new CountReducer
}

class CountMapper(value: Expression) extends AggregationMapper {
private class CountMapper(value: Expression) extends AggregationMapper {
private var count: Long = 0L

override def result: AnyValue = Values.longValue(count)
Expand All @@ -48,7 +51,7 @@ class CountMapper(value: Expression) extends AggregationMapper {
}
}

class CountReducer extends AggregationReducer {
private class CountReducer extends AggregationReducer {
private var count: Long = 0L

override def result: AnyValue = Values.longValue(count)
Expand Down
Expand Up @@ -34,8 +34,9 @@ object MorselExpressionConverters extends ExpressionConverter {
self: ExpressionConverters): Option[Expression] = expression match {

case c: FunctionInvocation if c.function == functions.Count =>
val inner = self.toCommandExpression(c.arguments.head)
Some(CountOperatorExpression(inner))
Some(CountOperatorExpression(self.toCommandExpression(c.arguments.head)))
case c: FunctionInvocation if c.function == functions.Avg =>
Some(AvgOperatorExpression(self.toCommandExpression(c.arguments.head)))
case f: FunctionInvocation if f.function.isInstanceOf[AggregatingFunction] => throw new CantCompileQueryException()
case _ => None
}
Expand Down
Expand Up @@ -21,7 +21,6 @@ package org.neo4j.cypher.internal.runtime.vectorized.operators

import org.neo4j.cypher.internal.compatibility.v3_4.runtime.SlotConfiguration
import org.neo4j.cypher.internal.runtime.QueryContext
import org.neo4j.cypher.internal.runtime.interpreted.pipes.{QueryState => OldQueryState}
import org.neo4j.cypher.internal.runtime.vectorized._

/*
Expand Down

0 comments on commit b8d88d7

Please sign in to comment.