Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support alias reuse in WHERE patterns #101

Merged
merged 4 commits into from
Jun 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,19 @@ public void pattern() {
.containsExactlyInAnyOrder("marko", "josh", "peter");
}

@Test
public void reversePattern() {
List<Map<String, Object>> results = submitAndGet(
"MATCH (n:person) " +
"WHERE (:software)<-[:created]-(n) " +
"RETURN n.name"
);

assertThat(results)
.extracting("n.name")
.containsExactlyInAnyOrder("marko", "josh", "peter");
}

/**
* Custom predicate deserialization is not implemented
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,27 @@ import org.opencypher.gremlin.translation.ir.model._
*/
object RemoveUselessSteps extends GremlinRewriter {
override def apply(steps: Seq[GremlinStep]): Seq[GremlinStep] = {
mapTraversals(replace({
// Remove `fold` and `unfold` pairs, since the former is an inverse of the latter.
case Fold :: Unfold :: rest =>
rest
case Unfold :: Fold :: rest =>
rest

// Remove unused projections
case Project(projectKey) :: By(Identity :: Nil, None) :: SelectK(selectKey) :: rest if projectKey == selectKey =>
rest
}))(steps)
mapTraversals(
firstPass
.andThen(secondPass)
)(steps)
}

private val firstPass: Seq[GremlinStep] => Seq[GremlinStep] = replace({
// Remove `fold` and `unfold` pairs, since the former is an inverse of the latter.
case Fold :: Unfold :: rest =>
rest
case Unfold :: Fold :: rest =>
rest

// Remove unused projections
case Project(projectKey) :: By(Identity :: Nil, None) :: SelectK(selectKey) :: rest if projectKey == selectKey =>
rest
})

private val secondPass: Seq[GremlinStep] => Seq[GremlinStep] = replace({
// Remove duplicate `as` steps
case As(stepLabel1) :: As(stepLabel2) :: rest if stepLabel1 == stepLabel2 =>
As(stepLabel1) :: rest
})
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2018 "Neo4j, Inc." [https://neo4j.com]
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opencypher.gremlin.translation.ir.rewrite

import org.opencypher.gremlin.translation.Tokens.NULL
import org.opencypher.gremlin.translation.ir.TraversalHelper._
import org.opencypher.gremlin.translation.ir.model._

/**
* Match patterns that start with renaming an existing alias
* can simply select the existing alias from the traversal.
*/
object SimplifyRenamedAliases extends GremlinRewriter {

override def apply(steps: Seq[GremlinStep]): Seq[GremlinStep] = {
splitAfter({
case MapT(Project(_*) :: _) => true
case Project(_*) => true
case _ => false
})(steps)
.flatMap(rewriteSegment)
}

private def rewriteSegment(steps: Seq[GremlinStep]): Seq[GremlinStep] = {
// Find all step labels of vertex start steps
val startLabels = extract({
case Vertex :: As(asLabel) :: WhereT(SelectK(selectLabel) :: WhereP(Eq(_)) :: Nil) :: _
if asLabel == selectLabel =>
None
case Vertex :: As(stepLabel) :: _ =>
Some(stepLabel)
})(steps).flatten.toSet

mapTraversals(replace({
case Vertex :: As(asLabel) :: WhereT(SelectK(selectLabel) :: WhereP(Eq(eqLabel: String)) :: Nil) :: rest
if asLabel == selectLabel =>
if (startLabels.contains(eqLabel)) {
// Vertex traverser should be non-null
SelectK(eqLabel) :: rest
} else {
SelectK(eqLabel) :: Is(Neq(NULL)) :: rest
}
}))(steps)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ object TranslatorFlavor {
rewriters = Seq(
InlineMapTraversal,
SimplifyPropertySetters,
SimplifyRenamedAliases,
GroupStepFilters,
SimplifySingleProjections,
RemoveImmediateReselect,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,17 +409,15 @@ private class ExpressionWalker[T, P](context: StatementContext[T, P], g: Gremlin
projection: Expression): String = {
val select = __
val contextWhere = context.copy()
PatternWalker.walkExpression(contextWhere, select, relationshipChain)
PatternWalker.walk(contextWhere, select, relationshipChain)
maybePredicate.foreach(WhereWalker.walk(contextWhere, select, _))

if (projection.isInstanceOf[PathExpression]) {
select.path()
}

val name = contextWhere.generateName()
g.sideEffect(
select.aggregate(name)
)
g.sideEffect(select.aggregate(name)).barrier()

name
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package org.opencypher.gremlin.translation.walker

import org.opencypher.gremlin.translation.Tokens.NULL
import org.opencypher.gremlin.translation.Tokens._
import org.opencypher.gremlin.translation._
import org.opencypher.gremlin.translation.context.StatementContext
import org.opencypher.gremlin.translation.walker.NodeUtils._
Expand Down Expand Up @@ -72,21 +72,14 @@ private class MatchWalker[T, P](context: StatementContext[T, P], g: GremlinSteps
def walkPatternParts(patternParts: Seq[PatternPart], whereOption: Option[Where]): Unit = {
patternParts.foreach {
case EveryPath(patternElement) =>
foldPatternElement(None, patternElement)
PatternWalker.walk(context, g, patternElement)
case NamedPatternPart(Variable(pathName), EveryPath(patternElement)) =>
foldPatternElement(Some(pathName), patternElement)
PatternWalker.walk(context, g, patternElement, Some(pathName))
g.path().as(pathName)
case n =>
context.unsupported("match pattern", n)
}

whereOption.foreach(WhereWalker.walk(context, g, _))
}

private def foldPatternElement(maybeName: Option[String], patternElement: PatternElement): Unit = {
context.markFirstStatement()
g.V()

PatternWalker.walkMatch(context, g, patternElement, maybeName)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,24 @@ import org.opencypher.v9_0.util.InputPosition.NONE
* of match pattern nodes of the Cypher AST.
*/
object PatternWalker {
def walkMatch[T, P](
def walk[T, P](
context: StatementContext[T, P],
g: GremlinSteps[T, P],
node: PatternElement,
pathName: Option[String]): Unit = {
pathName: Option[String] = None): Unit = {
new PatternWalker(context, g).walk(node, pathName)
}

def walkExpression[T, P](context: StatementContext[T, P], g: GremlinSteps[T, P], node: PatternElement): Unit = {
new PatternWalker(context, g).walk(node, pathName = None, selectFirst = true)
}
}

class PatternWalker[T, P](context: StatementContext[T, P], g: GremlinSteps[T, P]) {
def walk(node: PatternElement, pathName: Option[String], selectFirst: Boolean = false): Unit = {
def walk(node: PatternElement, pathName: Option[String]): Unit = {
context.markFirstStatement()
g.V()

val chain = flattenRelationshipChain(node)
var firstNode = true
chain.foreach {
case node: NodePattern =>
walkNode(node, selectFirst && firstNode)
firstNode = false
walkNode(node)
case relationship: RelationshipPattern =>
walkRelationship(pathName, relationship)
case n =>
Expand All @@ -66,15 +63,11 @@ class PatternWalker[T, P](context: StatementContext[T, P], g: GremlinSteps[T, P]
}
}

private def walkNode(node: NodePattern, select: Boolean): Unit = {
private def walkNode(node: NodePattern): Unit = {
val NodePattern(variableOption, labels, properties) = node
val variable @ Variable(name) = variableOption
.getOrElse(Variable(context.generateName())(NONE))
if (select) {
g.select(name)
} else {
asUniqueName(name, g, context)
}
asUniqueName(name, g, context)
g.map(hasLabels(labels))
properties.map(hasProperties(variable, _)).foreach(g.map)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,9 @@ private class ProjectionWalker[T, P](context: StatementContext[T, P], g: Gremlin
for (item <- items) {
val AliasedReturnItem(expression, Variable(alias)) = item

val (_, traversalUnfold) = pivot(alias, expression, finalize)
val (returnType, traversal) = pivot(alias, expression, finalize)

allCollector.put(alias, traversalUnfold)
allCollector.put(alias, traversal)

returnType match {
case Pivot => pivotCollector.put(alias, traversal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ private class WhereWalker[T, P](context: StatementContext[T, P], g: GremlinSteps

case PatternExpression(RelationshipsPattern(relationshipChain)) =>
val traversal = g.start()
PatternWalker.walkExpression(context, traversal, relationshipChain)
PatternWalker.walk(context, traversal, relationshipChain)
traversal

case l: Literal =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ public void adjacentMap() {
))
.withFlavor(TranslatorFlavor.empty())
.rewritingWith(InlineMapTraversal$.MODULE$)
.removes(__().select("n").map(__()).as(" cypher.path.start.GENERATED1").map(__().outE().inV()))
.adds(__().select("n").as(" cypher.path.start.GENERATED1").outE().inV());
.removes(__().map(__()))
.removes(__().as(" cypher.path.start.GENERATED2").map(__().outE().inV()))
.adds(__().as(" cypher.path.start.GENERATED2").outE().inV());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright (c) 2018 "Neo4j, Inc." [https://neo4j.com]
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opencypher.gremlin.translation.ir.rewrite;

import static org.opencypher.gremlin.translation.CypherAstWrapper.parse;
import static org.opencypher.gremlin.translation.Tokens.NULL;
import static org.opencypher.gremlin.translation.helpers.CypherAstAssert.P;
import static org.opencypher.gremlin.translation.helpers.CypherAstAssert.__;
import static org.opencypher.gremlin.translation.helpers.CypherAstAssertions.assertThat;
import static org.opencypher.gremlin.translation.helpers.ScalaHelpers.seq;

import org.junit.Test;
import org.opencypher.gremlin.translation.translator.TranslatorFlavor;

public class SimplifyRenamedAliases {

private final TranslatorFlavor flavor = new TranslatorFlavor(
seq(
InlineMapTraversal$.MODULE$
),
seq()
);

@Test
public void whereMatchPattern() {
assertThat(parse(
"MATCH (n) " +
"WHERE (n)-->(:L) " +
"RETURN n"
))
.withFlavor(flavor)
.rewritingWith(SimplifyRenamedAliases$.MODULE$)
.removes(
__().V().as(" GENERATED1")
.where(__().select(" GENERATED1").where(P.isEq("n")))
.as(" cypher.path.start.GENERATED2"))
.adds(
__().select("n")
.as(" cypher.path.start.GENERATED2")
);
}

@Test
public void whereNonStartMatchPattern() {
assertThat(parse(
"MATCH (n)-->(m) " +
"WHERE (m)-->(:L) " +
"RETURN m"
))
.withFlavor(flavor)
.rewritingWith(SimplifyRenamedAliases$.MODULE$)
.removes(
__().V().as(" GENERATED2")
.where(__().select(" GENERATED2").where(P.isEq("m")))
.as(" cypher.path.start.GENERATED3"))
.adds(
__().select("m").is(P.neq(NULL))
.as(" cypher.path.start.GENERATED3"));
}
}