Permalink
Browse files

Fixing Reduce's calculateType

  • Loading branch information...
1 parent bd2f520 commit 3f58cf88e023af5c657adb728ea1c8dccb8586ea @freeeve committed Oct 5, 2012
@@ -34,7 +34,7 @@ case class ReduceFunction(collection: Expression, id: String, expression: Expres
computedMap(acc)
}
- def rewrite(f: (Expression) => Expression) = f(ReduceFunction(collection.rewrite(f), id, expression.rewrite(f), acc, init))
+ def rewrite(f: (Expression) => Expression) = f(ReduceFunction(collection.rewrite(f), id, expression.rewrite(f), acc, init.rewrite(f)))
def filter(f: (Expression) => Boolean) = if (f(this))
Seq(this) ++ collection.filter(f)
@@ -43,7 +43,13 @@ case class ReduceFunction(collection: Expression, id: String, expression: Expres
def identifierDependencies(expectedType: CypherType) = AnyType
- def calculateType(symbols: SymbolTable) = collection.evaluateType(AnyCollectionType(), symbols).iteratedType
+ def calculateType(symbols: SymbolTable) = {
+ val iteratorType = collection.evaluateType(AnyCollectionType(), symbols).iteratedType
+ var innerSymbols = symbols.add(acc, init.evaluateType(AnyType(), symbols))
+ innerSymbols = innerSymbols.add(id, iteratorType)
+ // return expressions's type as the end result for reduce
+ expression.evaluateType(AnyType(), innerSymbols)
+ }
- def symbolTableDependencies = collection.symbolTableDependencies
+ def symbolTableDependencies = (collection.symbolTableDependencies ++ expression.symbolTableDependencies ++ init.symbolTableDependencies) - id - acc
}
@@ -20,6 +20,7 @@
package org.neo4j.cypher.internal.commands
import expressions.{ReduceFunction, Identifier, LengthFunction, Add, Literal}
+import org.neo4j.cypher.internal.symbols.{SymbolTable, StringType, NumberType, AnyCollectionType}
import org.scalatest.Assertions
import org.junit.Test
@@ -44,4 +45,34 @@ class ReduceTest extends Assertions {
assert(reduce(m) === null)
}
+
+ @Test def reduce_has_the_expected_type_string() {
+ val expression = Add(Identifier("acc"), Identifier("n"))
+ val collection = Literal(Seq(1,2,3))
+
+ val reduce = ReduceFunction(collection, "n", expression, "acc", Literal(""))
+ val typ = reduce.calculateType(new SymbolTable())
+
+ assert(typ === StringType())
+ }
+
+ @Test def reduce_has_the_expected_type_number() {
+ val expression = Add(Identifier("acc"), Identifier("n"))
+ val collection = Literal(Seq(1,2,3))
+
+ val reduce = ReduceFunction(collection, "n", expression, "acc", Literal(0))
+ val typ = reduce.calculateType(new SymbolTable())
+
+ assert(typ === NumberType())
+ }
+
+ @Test def reduce_has_the_expected_type_array() {
+ val expression = Add(Identifier("acc"), Identifier("n"))
+ val collection = Literal(Seq(1,2,3))
+
+ val reduce = ReduceFunction(collection, "n", expression, "acc", Literal(Seq(1,2)))
+ val typ = reduce.calculateType(new SymbolTable())
+
+ assert(typ === AnyCollectionType())
+ }
}

0 comments on commit 3f58cf8

Please sign in to comment.