diff --git a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/GroupStepFilters.scala b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/GroupStepFilters.scala index 88f23838..99387b7e 100644 --- a/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/GroupStepFilters.scala +++ b/translation/src/main/scala/org/opencypher/gremlin/translation/ir/rewrite/GroupStepFilters.scala @@ -99,6 +99,9 @@ object GroupStepFilters extends GremlinRewriter { case ChooseT3(Seq(Constant(value)), _, _) :: Is(_) :: As(_) :: SelectK(stepLabel) :: ChooseP2(_, Seq(Id)) :: Is(_) :: WhereP( _: Eq) :: Nil => (stepLabel, HasP(T.id.getAccessor, Eq(value))) :: Nil + case ChooseT3(Seq(Constant(value)), _, _) :: Is(_) :: As(_) :: SelectK(stepLabel) :: ChooseP2(_, Seq(Id)) :: Is(_) :: WhereP( + _: Within) :: Nil => + (stepLabel, HasP(T.id.getAccessor, Within(value))) :: Nil case SelectK(stepLabel) :: rest if rest.forall(_.isInstanceOf[HasLabel]) => rest.map((stepLabel, _)) case _ => @@ -116,6 +119,9 @@ object GroupStepFilters extends GremlinRewriter { case ChooseT3(Seq(Constant(_)), _, _) :: Is(_) :: As(_) :: SelectK(alias) :: ChooseP2(_, Seq(Id)) :: Is(_) :: WhereP( _: Eq) :: Nil if aliases.contains(alias) => None + case ChooseT3(Seq(Constant(_)), _, _) :: Is(_) :: As(_) :: SelectK(alias) :: ChooseP2(_, Seq(Id)) :: Is(_) :: WhereP( + _: Within) :: Nil if aliases.contains(alias) => + None case SelectK(alias) :: rest if aliases.contains(alias) && rest.forall(_.isInstanceOf[HasLabel]) => None case other => diff --git a/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/GroupStepFiltersTest.scala b/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/GroupStepFiltersTest.scala index ce86c877..0a893a06 100644 --- a/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/GroupStepFiltersTest.scala +++ b/translation/src/test/scala/org/opencypher/gremlin/translation/ir/rewrite/GroupStepFiltersTest.scala @@ -21,6 +21,7 @@ import org.junit.Test import org.opencypher.gremlin.translation.CypherAst.parse import org.opencypher.gremlin.translation.Tokens import org.opencypher.gremlin.translation.Tokens.{GENERATED, NULL, UNNAMED} +import org.opencypher.gremlin.translation.ir.builder.IRGremlinBindings import org.opencypher.gremlin.translation.ir.helpers.CypherAstAssert.{P, __} import org.opencypher.gremlin.translation.ir.helpers.CypherAstAssertions.assertThat import org.opencypher.gremlin.translation.ir.model.GremlinBinding @@ -283,4 +284,17 @@ class GroupStepFiltersTest { .rewritingWith(GroupStepFilters) .keeps(__.addV().as("a").property(single, "x", __.constant(1))) } + + @Test + def collectionOfParameters(): Unit = { + val ids = new IRGremlinBindings().bind("ids", 1) + + assertThat(parse("MATCH (p:Person) WHERE id(p) in {ids} RETURN p.name")) + .withFlavor(flavor) + .rewritingWith(GroupStepFilters) + .removes(__.choose(__.constant(ids), __.constant(ids), __.constant(NULL))) + .removes(__.where(P.within(GENERATED + "1"))) + .adds(__.has("~id", P.within(ids))) + .debug() + } }