Skip to content

Commit

Permalink
Merge pull request #7992 from boggle/3.0-type-checking
Browse files Browse the repository at this point in the history
Fix type checking of literal list elements + some renaming
  • Loading branch information
Stefan Plantikow committed Sep 23, 2016
2 parents a54e796 + b8cc6bc commit f3746e7
Show file tree
Hide file tree
Showing 122 changed files with 432 additions and 558 deletions.
Expand Up @@ -19,7 +19,7 @@
*/
package org.neo4j.internal.cypher.acceptance

import org.neo4j.cypher.internal.compiler.v3_0.helpers.CollectionSupport
import org.neo4j.cypher.internal.compiler.v3_0.helpers.ListSupport
import org.neo4j.cypher.{CypherException, ExecutionEngineFunSuite, QueryStatisticsTestSupport}
import org.neo4j.graphdb.Node
import org.neo4j.kernel.impl.storageengine.impl.recordstorage.RecordStorageEngine
Expand All @@ -28,7 +28,7 @@ import org.scalatest.Assertions
import org.scalautils.LegacyTripleEquals

class LabelsAcceptanceTest extends ExecutionEngineFunSuite
with QueryStatisticsTestSupport with Assertions with CollectionSupport with LegacyTripleEquals {
with QueryStatisticsTestSupport with Assertions with ListSupport with LegacyTripleEquals {

test("Adding_single_literal_label") {
assertThat("create (n {}) set n:FOO", List("FOO"))
Expand Down
Expand Up @@ -20,7 +20,7 @@
package org.neo4j.cypher.internal.compiler.v3_0

import org.neo4j.cypher.internal.compiler.v3_0.commands.values.KeyToken
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{IsCollection, IsMap}
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{IsList, IsMap}
import org.neo4j.cypher.internal.compiler.v3_0.pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_0.spi.QueryContext
import org.neo4j.graphdb.{Node, PropertyContainer, Relationship}
Expand All @@ -46,7 +46,7 @@ trait CypherSerializer {
case x: Node => x.toString + serializeProperties(x, qtx)
case x: Relationship => ":" + x.getType.name() + "[" + x.getId + "]" + serializeProperties(x, qtx)
case IsMap(m) => makeString(m, qtx)
case IsCollection(coll) => coll.map(elem => serialize(elem, qtx)).mkString("[", ",", "]")
case IsList(coll) => coll.map(elem => serialize(elem, qtx)).mkString("[", ",", "]")
case x: String => "\"" + x + "\""
case v: KeyToken => v.name
case Some(x) => x.toString
Expand Down
Expand Up @@ -23,7 +23,7 @@ import java.io.PrintWriter
import java.util

import org.neo4j.cypher.internal.compiler.v3_0.executionplan.{InternalExecutionResult, InternalQueryType}
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{CollectionSupport, RuntimeJavaValueConverter}
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{ListSupport, RuntimeJavaValueConverter}
import org.neo4j.cypher.internal.compiler.v3_0.pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_0.planDescription.InternalPlanDescription
import org.neo4j.cypher.internal.compiler.v3_0.spi.{InternalResultVisitor, QueryContext}
Expand All @@ -41,7 +41,7 @@ class PipeExecutionResult(val result: ResultIterator,
val executionMode: ExecutionMode,
val executionType: InternalQueryType)
extends InternalExecutionResult
with CollectionSupport {
with ListSupport {

self =>

Expand Down
Expand Up @@ -210,7 +210,7 @@ object ExpressionConverters {
else
command
case Tail =>
commandexpressions.CollectionSliceExpression(
commandexpressions.ListSlice(
toCommandExpression(invocation.arguments.head),
Some(commandexpressions.Literal(1)),
None
Expand Down Expand Up @@ -267,17 +267,17 @@ object ExpressionConverters {
case e: ast.PatternExpression => commands.PathExpression(e.pattern.asLegacyPatterns)
case e: ast.ShortestPathExpression => commandexpressions.ShortestPathExpression(e.pattern.asLegacyPatterns(None).head)
case e: ast.HasLabels => hasLabels(e)
case e: ast.Collection => commandexpressions.Collection(toCommandExpression(e.expressions): _*)
case e: ast.ListLiteral => commandexpressions.ListLiteral(toCommandExpression(e.expressions): _*)
case e: ast.MapExpression => mapExpression(e)
case e: ast.CollectionSlice => commandexpressions.CollectionSliceExpression(toCommandExpression(e.list), toCommandExpression(e.from), toCommandExpression(e.to))
case e: ast.ListSlice => commandexpressions.ListSlice(toCommandExpression(e.list), toCommandExpression(e.from), toCommandExpression(e.to))
case e: ast.ContainerIndex => commandexpressions.ContainerIndex(toCommandExpression(e.expr), toCommandExpression(e.idx))
case e: ast.FilterExpression => commandexpressions.FilterFunction(toCommandExpression(e.expression), e.variable.name, e.innerPredicate.map(toCommandPredicate).getOrElse(predicates.True()))
case e: ast.ExtractExpression => commandexpressions.ExtractFunction(toCommandExpression(e.expression), e.variable.name, toCommandExpression(e.scope.extractExpression.get))
case e: ast.ListComprehension => listComprehension(e)
case e: ast.AllIterablePredicate => commands.AllInCollection(toCommandExpression(e.expression), e.variable.name, e.innerPredicate.map(toCommandPredicate).getOrElse(predicates.True()))
case e: ast.AnyIterablePredicate => commands.AnyInCollection(toCommandExpression(e.expression), e.variable.name, e.innerPredicate.map(toCommandPredicate).getOrElse(predicates.True()))
case e: ast.NoneIterablePredicate => commands.NoneInCollection(toCommandExpression(e.expression), e.variable.name, e.innerPredicate.map(toCommandPredicate).getOrElse(predicates.True()))
case e: ast.SingleIterablePredicate => commands.SingleInCollection(toCommandExpression(e.expression), e.variable.name, e.innerPredicate.map(toCommandPredicate).getOrElse(predicates.True()))
case e: ast.AllIterablePredicate => commands.AllInList(toCommandExpression(e.expression), e.variable.name, e.innerPredicate.map(toCommandPredicate).getOrElse(predicates.True()))
case e: ast.AnyIterablePredicate => commands.AnyInList(toCommandExpression(e.expression), e.variable.name, e.innerPredicate.map(toCommandPredicate).getOrElse(predicates.True()))
case e: ast.NoneIterablePredicate => commands.NoneInList(toCommandExpression(e.expression), e.variable.name, e.innerPredicate.map(toCommandPredicate).getOrElse(predicates.True()))
case e: ast.SingleIterablePredicate => commands.SingleInList(toCommandExpression(e.expression), e.variable.name, e.innerPredicate.map(toCommandPredicate).getOrElse(predicates.True()))
case e: ast.ReduceExpression => commandexpressions.ReduceFunction(toCommandExpression(e.list), e.variable.name, toCommandExpression(e.expression), e.accumulator.name, toCommandExpression(e.init))
case e: ast.PathExpression => toCommandProjectedPath(e)
case e: NestedPipeExpression => commandexpressions.NestedPipeExpression(e.pipe, toCommandProjectedPath(e.path))
Expand Down Expand Up @@ -336,10 +336,10 @@ object ExpressionConverters {
case value: Parameter =>
predicates.ConstantCachedIn(toCommandExpression(e.lhs), toCommandExpression(value))

case value@Collection(expressions) if expressions.isEmpty =>
case value@ListLiteral(expressions) if expressions.isEmpty =>
predicates.Not(predicates.True())

case value@Collection(expressions) if expressions.forall(_.isInstanceOf[Literal]) =>
case value@ListLiteral(expressions) if expressions.forall(_.isInstanceOf[Literal]) =>
predicates.ConstantCachedIn(toCommandExpression(e.lhs), toCommandExpression(value))

case _ =>
Expand Down
Expand Up @@ -242,7 +242,7 @@ object ClauseConverters {
}

private def toPropertySelection(identifier: Variable, map:Map[PropertyKeyName, Expression]): Seq[Expression] = map.map {
case (k, e) => In(Property(identifier, k)(k.position), Collection(Seq(e))(e.position))(identifier.position)
case (k, e) => In(Property(identifier, k)(k.position), ListLiteral(Seq(e))(e.position))(identifier.position)
}.toSeq

private def toSetPattern(semanticTable: SemanticTable)(setItem: SetItem): SetMutatingPattern = setItem match {
Expand Down
Expand Up @@ -19,13 +19,13 @@
*/
package org.neo4j.cypher.internal.compiler.v3_0.ast.convert.plannerQuery

import org.neo4j.cypher.internal.compiler.v3_0.helpers.CollectionSupport
import org.neo4j.cypher.internal.compiler.v3_0.helpers.ListSupport
import org.neo4j.cypher.internal.compiler.v3_0.planner._
import org.neo4j.cypher.internal.compiler.v3_0.planner.logical.plans.IdName
import org.neo4j.cypher.internal.frontend.v3_0.SemanticTable

case class PlannerQueryBuilder(private val q: PlannerQuery, semanticTable: SemanticTable, returns: Seq[IdName] = Seq.empty)
extends CollectionSupport {
extends ListSupport {

def withReturns(returns: Seq[IdName]): PlannerQueryBuilder = copy(returns = returns)

Expand Down
Expand Up @@ -28,7 +28,7 @@ import scala.collection.immutable.Iterable
This class merges multiple IN predicates into larger ones.
These can later be turned into index lookups or node-by-id ops
*/
case object collapseInCollections extends Rewriter {
case object collapseMultipleInPredicates extends Rewriter {

override def apply(that: AnyRef) = instance(that)

Expand All @@ -38,13 +38,13 @@ case object collapseInCollections extends Rewriter {
case predicate@Ors(exprs) =>
// Find all the expressions we want to rewrite
val (const: List[Expression], nonRewritable: List[Expression]) = exprs.toList.partition {
case in@In(_, rhs: Collection) => true
case in@In(_, rhs: ListLiteral) => true
case _ => false
}

// For each expression on the RHS of any IN, produce a InValue place holder
val ins: List[InValue] = const.flatMap {
case In(lhs, rhs: Collection) =>
case In(lhs, rhs: ListLiteral) =>
rhs.expressions.map(expr => InValue(lhs, expr))
}

Expand All @@ -53,7 +53,7 @@ case object collapseInCollections extends Rewriter {
val flattenConst: Iterable[In] = groupedINPredicates.map {
case (lhs, values) =>
val pos = lhs.position
In(lhs, Collection(values.map(_.expr).toSeq)(pos))(pos)
In(lhs, ListLiteral(values.map(_.expr).toSeq)(pos))(pos)
}

// Return the original non-rewritten predicates with our new ones
Expand Down
Expand Up @@ -72,7 +72,7 @@ object literalReplacement {
acc =>
val parameter = ast.Parameter(s" AUTOBOOL${acc.size}", CTBoolean)(l.position)
(acc + (l -> LiteralReplacement(parameter, l.value)), None)
case l: ast.Collection if l.expressions.forall(_.isInstanceOf[Literal])=>
case l: ast.ListLiteral if l.expressions.forall(_.isInstanceOf[Literal])=>
acc =>
val parameter = ast.Parameter(s" AUTOLIST${acc.size}", CTList(CTAny))(l.position)
val values: Seq[AnyRef] = l.expressions.map(_.asInstanceOf[Literal].value)
Expand Down
Expand Up @@ -26,22 +26,22 @@ import org.neo4j.cypher.internal.frontend.v3_0.{Rewriter, bottomUp}
This class rewrites equality predicates into IN comparisons which can then be turned into
either index lookup or node-by-id operations
*/
case object rewriteEqualityToInCollection extends Rewriter {
case object rewriteEqualityToInPredicate extends Rewriter {

override def apply(that: AnyRef) = instance(that)

private val instance: Rewriter = bottomUp(Rewriter.lift {
// id(a) = value => id(a) IN [value]
case predicate@Equals(func@FunctionInvocation(_, _, IndexedSeq(idExpr)), idValueExpr)
if func.function.contains(functions.Id) =>
In(func, Collection(Seq(idValueExpr))(idValueExpr.position))(predicate.position)
In(func, ListLiteral(Seq(idValueExpr))(idValueExpr.position))(predicate.position)

// Equality between two property lookups should not be rewritten
case predicate@Equals(_:Property, _:Property) =>
predicate

// a.prop = value => a.prop IN [value]
case predicate@Equals(prop@Property(id: Variable, propKeyName), idValueExpr) =>
In(prop, Collection(Seq(idValueExpr))(idValueExpr.position))(predicate.position)
In(prop, ListLiteral(Seq(idValueExpr))(idValueExpr.position))(predicate.position)
})
}
Expand Up @@ -22,13 +22,13 @@ package org.neo4j.cypher.internal.compiler.v3_0.commands
import org.neo4j.cypher.internal.compiler.v3_0._
import expressions.{Closure, Expression}
import org.neo4j.cypher.internal.compiler.v3_0.commands.predicates.Predicate
import org.neo4j.cypher.internal.compiler.v3_0.helpers.CollectionSupport
import org.neo4j.cypher.internal.compiler.v3_0.helpers.ListSupport
import pipes.QueryState
import collection.Seq

abstract class InCollection(collectionExpression: Expression, id: String, predicate: Predicate)
abstract class InList(collectionExpression: Expression, id: String, predicate: Predicate)
extends Predicate
with CollectionSupport
with ListSupport
with Closure {

type CollectionPredicate[U] = ((U) => Option[Boolean]) => Option[Boolean]
Expand Down Expand Up @@ -59,8 +59,8 @@ abstract class InCollection(collectionExpression: Expression, id: String, predic
def symbolTableDependencies = symbolTableDependencies(collectionExpression, predicate, id)
}

case class AllInCollection(collection: Expression, symbolName: String, inner: Predicate)
extends InCollection(collection, symbolName, inner) {
case class AllInList(collection: Expression, symbolName: String, inner: Predicate)
extends InList(collection, symbolName, inner) {

private def forAll[U](collectionValue: Seq[U])(predicate: (U => Option[Boolean])): Option[Boolean] = {
var result: Option[Boolean] = Some(true)
Expand All @@ -80,14 +80,14 @@ case class AllInCollection(collection: Expression, symbolName: String, inner: Pr
def name = "all"

def rewrite(f: (Expression) => Expression) =
f(AllInCollection(
f(AllInList(
collection = collection.rewrite(f),
symbolName = symbolName,
inner = inner.rewriteAsPredicate(f)))
}

case class AnyInCollection(collection: Expression, symbolName: String, inner: Predicate)
extends InCollection(collection, symbolName, inner) {
case class AnyInList(collection: Expression, symbolName: String, inner: Predicate)
extends InList(collection, symbolName, inner) {

private def exists[U](collectionValue: Seq[U])(predicate: (U => Option[Boolean])): Option[Boolean] = {
var result: Option[Boolean] = Some(false)
Expand All @@ -108,14 +108,14 @@ case class AnyInCollection(collection: Expression, symbolName: String, inner: Pr
def name = "any"

def rewrite(f: (Expression) => Expression) =
f(AnyInCollection(
f(AnyInList(
collection = collection.rewrite(f),
symbolName = symbolName,
inner = inner.rewriteAsPredicate(f)))
}

case class NoneInCollection(collection: Expression, symbolName: String, inner: Predicate)
extends InCollection(collection, symbolName, inner) {
case class NoneInList(collection: Expression, symbolName: String, inner: Predicate)
extends InList(collection, symbolName, inner) {

private def none[U](collectionValue: Seq[U])(predicate: (U => Option[Boolean])): Option[Boolean] = {
var result: Option[Boolean] = Some(true)
Expand All @@ -136,14 +136,14 @@ case class NoneInCollection(collection: Expression, symbolName: String, inner: P
def name = "none"

def rewrite(f: (Expression) => Expression) =
f(NoneInCollection(
f(NoneInList(
collection = collection.rewrite(f),
symbolName = symbolName,
inner = inner.rewriteAsPredicate(f)))
}

case class SingleInCollection(collection: Expression, symbolName: String, inner: Predicate)
extends InCollection(collection, symbolName, inner) {
case class SingleInList(collection: Expression, symbolName: String, inner: Predicate)
extends InList(collection, symbolName, inner) {

private def single[U](collectionValue: Seq[U])(predicate: (U => Option[Boolean])): Option[Boolean] = {
var matched = false
Expand All @@ -165,7 +165,7 @@ case class SingleInCollection(collection: Expression, symbolName: String, inner:
def name = "single"

def rewrite(f: (Expression) => Expression) =
f(SingleInCollection(
f(SingleInList(
collection = collection.rewrite(f),
symbolName = symbolName,
inner = inner.rewriteAsPredicate(f)))
Expand Down
Expand Up @@ -23,7 +23,7 @@ import org.neo4j.cypher.internal.compiler.v3_0._
import org.neo4j.cypher.internal.compiler.v3_0.commands.expressions.Expression
import org.neo4j.cypher.internal.compiler.v3_0.commands.values.KeyToken
import org.neo4j.cypher.internal.compiler.v3_0.executionplan.{SetLabel, Effect, Effects}
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{CastSupport, CollectionSupport}
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{CastSupport, ListSupport}
import org.neo4j.cypher.internal.compiler.v3_0.mutation.SetAction
import org.neo4j.cypher.internal.compiler.v3_0.pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_0.symbols.SymbolTable
Expand All @@ -36,7 +36,7 @@ case object LabelRemoveOp extends LabelOp

//TODO: Should take single label
case class LabelAction(entity: Expression, labelOp: LabelOp, labels: Seq[KeyToken])
extends SetAction with CollectionSupport {
extends SetAction with ListSupport {

def localEffects(ignored: SymbolTable) = Effects(labels.map(l => SetLabel(l.name)).toSet[Effect])

Expand Down
Expand Up @@ -20,7 +20,7 @@
package org.neo4j.cypher.internal.compiler.v3_0.commands

import org.neo4j.cypher.internal.compiler.v3_0.{Geometry, Point}
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{IsCollection, IsMap}
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{IsList, IsMap}
import org.neo4j.cypher.internal.compiler.v3_0.spi.QueryContext
import org.neo4j.cypher.internal.frontend.v3_0.CypherTypeException
import org.neo4j.cypher.internal.frontend.v3_0.symbols._
Expand All @@ -45,8 +45,8 @@ object coerce {
case t: ListType => value match {
case p: Path if t.innerType == CTNode => throw cantCoerce(value, typ)
case p: Path if t.innerType == CTRelationship => throw cantCoerce(value, typ)
case IsCollection(coll) if t.innerType == CTAny => coll
case IsCollection(coll) => coll.map(coerce(_, t.innerType))
case IsList(coll) if t.innerType == CTAny => coll
case IsList(coll) => coll.map(coerce(_, t.innerType))
case _ => throw cantCoerce(value, typ)
}
case CTBoolean => value.asInstanceOf[Boolean]
Expand Down
Expand Up @@ -20,7 +20,7 @@
package org.neo4j.cypher.internal.compiler.v3_0.commands.expressions

import org.neo4j.cypher.internal.compiler.v3_0._
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{IsCollection, TypeSafeMathSupport}
import org.neo4j.cypher.internal.compiler.v3_0.helpers.{IsList, TypeSafeMathSupport}
import org.neo4j.cypher.internal.compiler.v3_0.pipes.QueryState
import org.neo4j.cypher.internal.compiler.v3_0.symbols.SymbolTable
import org.neo4j.cypher.internal.frontend.v3_0.CypherTypeException
Expand All @@ -36,9 +36,9 @@ case class Add(a: Expression, b: Expression) extends Expression with TypeSafeMat
case (_, null) => null
case (x: Number, y: Number) => plus(x,y)
case (x: String, y: String) => x + y
case (IsCollection(x), IsCollection(y)) => x ++ y
case (IsCollection(x), y) => x ++ Seq(y)
case (x, IsCollection(y)) => Seq(x) ++ y
case (IsList(x), IsList(y)) => x ++ y
case (IsList(x), y) => x ++ Seq(y)
case (x, IsList(y)) => Seq(x) ++ y
case (x: String, y: Number) => x + y.toString
case (x: Number, y: String) => x.toString + y
case _ => throw new CypherTypeException("Don't know how to add `" + aVal.toString + "` and `" + bVal.toString + "`")
Expand Down

0 comments on commit f3746e7

Please sign in to comment.