From 2a501828c95b795dc8793dcb2e38f6a0a67f5451 Mon Sep 17 00:00:00 2001 From: Petra Selmer Date: Wed, 2 Dec 2015 14:27:19 +0000 Subject: [PATCH] ALL and NONE predicates evaluated during shortest path traversal so that branches are pruned earlier --- .../ShortestPathAcceptanceTest.scala | 172 +++++++++++++++++- .../expressions/ShortestPathExpression.scala | 81 ++++++--- .../graphalgo/impl/path/ShortestPath.java | 49 ++--- .../graphalgo/impl/path/TestShortestPath.java | 24 +-- 4 files changed, 267 insertions(+), 59 deletions(-) diff --git a/community/cypher/acceptance/src/test/scala/org/neo4j/internal/cypher/acceptance/ShortestPathAcceptanceTest.scala b/community/cypher/acceptance/src/test/scala/org/neo4j/internal/cypher/acceptance/ShortestPathAcceptanceTest.scala index d89442cfa5201..6bebe23ffe96f 100644 --- a/community/cypher/acceptance/src/test/scala/org/neo4j/internal/cypher/acceptance/ShortestPathAcceptanceTest.scala +++ b/community/cypher/acceptance/src/test/scala/org/neo4j/internal/cypher/acceptance/ShortestPathAcceptanceTest.scala @@ -357,7 +357,42 @@ class ShortestPathAcceptanceTest extends ExecutionEngineFunSuite with NewPlanner result.toList should equal(List(Map("nodes" -> List(nodes("source"), nodes("node3"), nodes("node4"), nodes("target"))))) } - test("shortest path should work with multiple expressions and predicates") { + test("shortest path should work with predicates that can be applied to node expanders") { + val nodes = largerShortestPathModel() + + val query = """PROFILE CYPHER + |MATCH (a:A), (b:D) + |MATCH p = shortestPath((a)-[rs:REL*]->(b)) + |WHERE ALL(n in nodes(p) WHERE NOT exists(n.blocked)) + |RETURN nodes(p) as nodes + """.stripMargin + + val result = executeWithAllPlanners(query) + + result.toList should equal(List(Map("nodes" -> List(nodes("Donald"), nodes("Huey"), nodes("Dewey"), nodes("Louie"), nodes("Daisy"))))) + } + + test("shortest path should work with predicates that can be applied to both relationship and node expanders") { + val nodes = largerShortestPathModel() + + val query = """PROFILE CYPHER + |MATCH (a:A), (b:D) + |MATCH p = shortestPath((a)-[rs:REL*]->(b)) + |WHERE ALL(n in nodes(p) WHERE exists(n.name) OR exists(n.age)) + |AND ALL(r in rels(p) WHERE r.likesLevel > 10) + |RETURN nodes(p) as nodes + """.stripMargin + + val result = executeWithAllPlanners(query) + + result.toList should equal(List( + Map("nodes" -> List(nodes("Donald"), nodes("Huey"), nodes("Dewey"), nodes("Louie"), nodes("Daisy"))), + Map("nodes" -> List(nodes("Mickey"), nodes("Minnie"), nodes("Daisy"))), + Map("nodes" -> List(nodes("Minnie"), nodes("Daisy"))) + )) + } + + test("shortest path should work with multiple expressions and predicates - relationship expander") { val nodes = shortestPathModel() val query = """PROFILE CYPHER @@ -371,10 +406,45 @@ class ShortestPathAcceptanceTest extends ExecutionEngineFunSuite with NewPlanner val result = executeWithCostPlannerOnly(query) result.toList should equal(List(Map("nodes1" -> List(nodes("source"), nodes("node3"), nodes("node4"), nodes("target")), - "nodes2" -> List(nodes("source"), nodes("target"))))) + "nodes2" -> List(nodes("source"), nodes("target"))))) + } + + test("shortest path should work with multiple expressions and predicates - node expander") { + val nodes = largerShortestPathModel() + + val query = """PROFILE CYPHER + |MATCH (a:A), (b:D) + |MATCH p1 = shortestPath((a)-[rs1:REL*]->(b)) + |MATCH p2 = shortestPath((a)-[rs2:REL*]->(b)) + |WHERE ALL(n in nodes(p2) WHERE exists(n.name) or n.age > 50) + |RETURN nodes(p1) AS nodes1, nodes(p2) as nodes2 + """.stripMargin + + val result = executeWithCostPlannerOnly(query) + + result.toList should equal(List(Map("nodes1" -> List(nodes("Donald"), nodes("Goofy"), nodes("Daisy")), + "nodes2" -> List(nodes("Donald"), nodes("Huey"), nodes("Dewey"), nodes("Louie"), nodes("Daisy"))))) } - test("shortest path should work with predicates that depend on the path expression") { + test("shortest path should work with multiple expressions and predicates - relationship and node expander") { + val nodes = largerShortestPathModel() + + val query = """PROFILE CYPHER + |MATCH (a:A), (b:D) + |MATCH p1 = shortestPath((a)-[rs1:REL*]->(b)) + |MATCH p2 = shortestPath((a)-[rs2:REL*]->(b)) + |WHERE ALL(n in nodes(p2) WHERE exists(n.name) or n.age > 50) + |AND NONE(r in rels(p1) WHERE exists(r.blocked) OR NOT exists(r.likesLevel)) + |RETURN nodes(p1) AS nodes1, nodes(p2) as nodes2 + """.stripMargin + + val result = executeWithCostPlannerOnly(query) + + result.toList should equal(List(Map("nodes1" -> List(nodes("Donald"), nodes("Huey"), nodes("Dewey"), nodes("Louie"), nodes("Daisy")), + "nodes2" -> List(nodes("Donald"), nodes("Huey"), nodes("Dewey"), nodes("Louie"), nodes("Daisy"))))) + } + + test("shortest path should work with predicates that depend on the path expression (relationships)") { val nodes = shortestPathModel() val query = """PROFILE CYPHER @@ -389,13 +459,48 @@ class ShortestPathAcceptanceTest extends ExecutionEngineFunSuite with NewPlanner result.toList should equal(List(Map("nodes" -> List(nodes("source"), nodes("node3"), nodes("node4"), nodes("target"))))) } + test("shortest path should work with predicates that depend on the path expression (nodes)") { + val nodes = largerShortestPathModel() + + val query = """PROFILE CYPHER + |MATCH (a:A), (b:A) + |MATCH p = shortestPath((a)-[r:REL*]->(b)) + |WHERE ALL(n in nodes(p) WHERE labels(n) = labels(nodes(p)[0]) AND exists(n.age)) + |RETURN nodes(p) as nodes + """.stripMargin + + val result = executeWithAllPlanners(query) + + result.toList should equal(List(Map("nodes" -> List(nodes("Donald"), nodes("Mickey"))), + Map("nodes" -> List(nodes("Donald"), nodes("Mickey"), nodes("Minnie"))), + Map("nodes" -> List(nodes("Mickey"), nodes("Minnie"))))) + } + + test("shortest path should work with predicates that depend on the path expression (relationships and nodes)") { + val nodes = largerShortestPathModel() + + val query = """PROFILE CYPHER + |MATCH (a:A), (b:A) + |MATCH p = shortestPath((a)-[r:REL*]->(b)) + |WHERE ALL(n in nodes(p) WHERE labels(n) = labels(nodes(p)[0]) AND exists(n.age)) + |AND ALL(r in rels(p) WHERE type(r) = type(rels(p)[0]) AND exists(r.likesLevel)) + |RETURN nodes(p) as nodes + """.stripMargin + + val result = executeWithAllPlanners(query) + + result.toList should equal(List(Map("nodes" -> List(nodes("Donald"), nodes("Mickey"))), + Map("nodes" -> List(nodes("Donald"), nodes("Mickey"), nodes("Minnie"))), + Map("nodes" -> List(nodes("Mickey"), nodes("Minnie"))))) + } + test("shortest path should work with predicates that can be applied to relationship expanders and include dependencies on execution context") { val nodes = shortestPathModel() val query = """PROFILE CYPHER |MATCH (a:X), (b:Y) |MATCH p = shortestPath((a)-[rs:REL*]->(b)) - |WHERE ALL(r in rels(p) WHERE NOT exists(r.blocked) AND a:X) AND NOT has(b.property) + |WHERE ALL(r in rels(p) WHERE NOT exists(r.blocked) AND a:X) AND NOT exists(b.property) |RETURN nodes(p) AS nodes """.stripMargin @@ -404,6 +509,38 @@ class ShortestPathAcceptanceTest extends ExecutionEngineFunSuite with NewPlanner result.toList should equal(List(Map("nodes" -> List(nodes("source"), nodes("node3"), nodes("node4"), nodes("target"))))) } + test("shortest path should work with predicates that can be applied to node expanders and include dependencies on execution context") { + val nodes = largerShortestPathModel() + + val query = """PROFILE CYPHER + |MATCH (a:A), (b:D) + |MATCH p = shortestPath((a)-[rs:REL*]->(b)) + |WHERE ALL(n in nodes(p) WHERE exists(n.name) AND a.name = 'Donald Duck') AND b.name = 'Daisy Duck' + |RETURN nodes(p) AS nodes + """.stripMargin + + val result = executeWithAllPlanners(query) + + result.toList should equal(List(Map("nodes" -> List(nodes("Donald"), nodes("Huey"), nodes("Dewey"), nodes("Louie"), nodes("Daisy"))))) + } + + test("shortest path should work with predicates that can be applied to relationship and node expanders and include dependencies on execution context") { + val nodes = largerShortestPathModel() + + val query = """PROFILE CYPHER + |MATCH (a:A), (b:D) + |MATCH p = shortestPath((a)-[rs:REL*]->(b)) + |WHERE ALL(n in nodes(p) WHERE (exists(n.name) OR exists(n.age)) AND a.name = 'Donald Duck') + |AND ALL(r in rels(p) WHERE r.likesLevel > 10) + |AND b.name = 'Daisy Duck' + |RETURN nodes(p) AS nodes + """.stripMargin + + val result = executeWithAllPlanners(query) + + result.toList should equal(List(Map("nodes" -> List(nodes("Donald"), nodes("Huey"), nodes("Dewey"), nodes("Louie"), nodes("Daisy"))))) + } + test("shortest path should work with predicates that reference shortestPath relationship identifier") { val nodes = shortestPathModel() @@ -467,4 +604,31 @@ class ShortestPathAcceptanceTest extends ExecutionEngineFunSuite with NewPlanner nodes } + + def largerShortestPathModel(): Map[String, Node] = { + val nodes = Map[String, Node]( + "Donald" -> createLabeledNode(Map("id" -> "Donald", "name" -> "Donald Duck", "age" -> 15), "A"), + "Daisy" -> createLabeledNode(Map("id" -> "Daisy", "name" -> "Daisy Duck"), "D"), + "Huey" -> createLabeledNode(Map("id" -> "Huey", "name" -> "Huey Duck"), "B"), + "Dewey" -> createLabeledNode(Map("id" -> "Dewey", "name" -> "Dewey Duck"), "B"), + "Louie" -> createLabeledNode(Map("id" -> "Louie", "name" -> "Louie Duck"), "B"), + "Goofy" -> createLabeledNode(Map("id" -> "Goofy", "blocked" -> true), "C"), + "Mickey" -> createLabeledNode(Map("id" -> "Mickey", "age" -> 10), "A"), + "Minnie" -> createLabeledNode(Map("id" -> "Minnie", "age" -> 20, "blocked" -> true), "A"), + "Pluto" -> createLabeledNode(Map("id" -> "Pluto", "age" -> 2), "E")) + + relate(nodes("Donald"), nodes("Goofy"), "REL", Map("blocked" -> true)) + relate(nodes("Donald"), nodes("Huey"), "REL", Map("likesLevel" -> 20)) + relate(nodes("Huey"), nodes("Dewey"), "REL", Map("likesLevel" -> 11)) + relate(nodes("Dewey"), nodes("Louie"), "REL", Map("likesLevel" -> 13)) + relate(nodes("Louie"), nodes("Daisy"), "REL", Map("likesLevel" -> 26)) + relate(nodes("Goofy"), nodes("Daisy"), "REL", Map("likesLevel" -> 45)) + relate(nodes("Donald"), nodes("Mickey"), "REL", Map("blocked" -> true, "likesLevel" -> 2)) + relate(nodes("Mickey"), nodes("Minnie"), "REL", Map("likesLevel" -> 25)) + relate(nodes("Minnie"), nodes("Daisy"), "REL", Map("likesLevel" -> 20)) + relate(nodes("Donald"), nodes("Pluto")) + relate(nodes("Pluto"), nodes("Minnie")) + + nodes + } } diff --git a/community/cypher/cypher-compiler-2.3/src/main/scala/org/neo4j/cypher/internal/compiler/v2_3/commands/expressions/ShortestPathExpression.scala b/community/cypher/cypher-compiler-2.3/src/main/scala/org/neo4j/cypher/internal/compiler/v2_3/commands/expressions/ShortestPathExpression.scala index 214ee63c21505..b243eccf3da54 100644 --- a/community/cypher/cypher-compiler-2.3/src/main/scala/org/neo4j/cypher/internal/compiler/v2_3/commands/expressions/ShortestPathExpression.scala +++ b/community/cypher/cypher-compiler-2.3/src/main/scala/org/neo4j/cypher/internal/compiler/v2_3/commands/expressions/ShortestPathExpression.scala @@ -32,13 +32,14 @@ import org.neo4j.cypher.internal.frontend.v2_3.symbols._ import org.neo4j.graphalgo.GraphAlgoFactory import org.neo4j.graphalgo.impl.path.ShortestPath.ShortestPathPredicate import org.neo4j.graphdb._ +import org.neo4j.function.{Predicate => KernelPredicate} import org.neo4j.kernel.Traversal import scala.collection.JavaConverters._ import scala.collection.Map case class ShortestPathExpression(shortestPathPattern: ShortestPath, predicates: Seq[Predicate] = Seq.empty) extends Expression with PathExtractor { - val pathPattern:Seq[Pattern] = Seq(shortestPathPattern) + val pathPattern: Seq[Pattern] = Seq(shortestPathPattern) val pathIdentifiers = Set(shortestPathPattern.pathName, shortestPathPattern.relIterator.getOrElse("")) @@ -53,10 +54,10 @@ case class ShortestPathExpression(shortestPathPattern: ShortestPath, predicates: private def getMatches(ctx: ExecutionContext)(implicit state: QueryState): Any = { val start = getEndPoint(ctx, shortestPathPattern.left) val end = getEndPoint(ctx, shortestPathPattern.right) - val expander: Expander = addPredicates(ctx, makeRelationshipTypeExpander()) + val (expander, nodePredicates) = addPredicates(ctx, makeRelationshipTypeExpander()) val shortestPathPredicate = createShortestPathPredicate(ctx, predicates) val shortestPathStrategy = if (shortestPathPattern.single) - new SingleShortestPathStrategy(expander, shortestPathPattern.allowZeroLength, shortestPathPattern.maxDepth.getOrElse(Int.MaxValue), shortestPathPredicate) + new SingleShortestPathStrategy(expander, shortestPathPattern.allowZeroLength, shortestPathPattern.maxDepth.getOrElse(Int.MaxValue), shortestPathPredicate, nodePredicates) else new AllShortestPathsStrategy(expander, shortestPathPattern.allowZeroLength, shortestPathPattern.maxDepth.getOrElse(Int.MaxValue), shortestPathPredicate) @@ -67,7 +68,8 @@ case class ShortestPathExpression(shortestPathPattern: ShortestPath, predicates: * accepting or disqualifying it as appropriate. */ private def createShortestPathPredicate(incomingCtx: ExecutionContext, predicates: Seq[Predicate])(implicit state: QueryState): ShortestPathPredicate = new ShortestPathPredicate { - override def test(path: Path): Boolean = if (predicates.isEmpty) true else { + override def test(path: Path): Boolean = if (predicates.isEmpty) true + else { incomingCtx += shortestPathPattern.pathName -> path incomingCtx += shortestPathPattern.relIterator.get -> path.relationships() @@ -88,29 +90,29 @@ case class ShortestPathExpression(shortestPathPattern: ShortestPath, predicates: def rewrite(f: (Expression) => Expression): Expression = f(ShortestPathExpression(shortestPathPattern.rewrite(f))) - def calculateType(symbols: SymbolTable) = if (shortestPathPattern.single) CTPath else CTCollection(CTPath) + def calculateType(symbols: SymbolTable) = if (shortestPathPattern.single) CTPath else CTCollection(CTPath) def symbolTableDependencies = shortestPathPattern.symbolTableDependencies + shortestPathPattern.left.name + shortestPathPattern.right.name - private def propertyExistsExpander(name: String) = new org.neo4j.function.Predicate[PropertyContainer] { + private def propertyExistsExpander(name: String) = new KernelPredicate[PropertyContainer] { override def test(t: PropertyContainer): Boolean = { t.hasProperty(name) } } - private def propertyNotExistsExpander(name: String) = new org.neo4j.function.Predicate[PropertyContainer] { + private def propertyNotExistsExpander(name: String) = new KernelPredicate[PropertyContainer] { override def test(t: PropertyContainer): Boolean = { !t.hasProperty(name) } } - private def cypherPositivePredicatesAsExpander(incomingCtx: ExecutionContext, name: String, predicate: Predicate)(implicit state: QueryState) = new org.neo4j.function.Predicate[PropertyContainer] { + private def cypherPositivePredicatesAsExpander(incomingCtx: ExecutionContext, name: String, predicate: Predicate)(implicit state: QueryState) = new KernelPredicate[PropertyContainer] { override def test(t: PropertyContainer): Boolean = { predicate.isTrue(incomingCtx += (name -> t)) } } - private def cypherNegativePredicatesAsExpander(incomingCtx: ExecutionContext, name: String, predicate: Predicate)(implicit state: QueryState) = new org.neo4j.function.Predicate[PropertyContainer] { + private def cypherNegativePredicatesAsExpander(incomingCtx: ExecutionContext, name: String, predicate: Predicate)(implicit state: QueryState) = new KernelPredicate[PropertyContainer] { override def test(t: PropertyContainer): Boolean = { !predicate.isTrue(incomingCtx += (name -> t)) } @@ -132,6 +134,24 @@ case class ShortestPathExpression(shortestPathPattern: ShortestPath, predicates: } } + private def addAllOrNoneNodeExpander(ctx: ExecutionContext, currentExpander: Expander, all: Boolean, + predicate: Predicate, relName: String, + currentNodePredicates: Seq[KernelPredicate[PropertyContainer]]) + (implicit state: QueryState): (Expander, Seq[KernelPredicate[PropertyContainer]]) = { + val filter = predicate match { + case PropertyExists(_, propertyKey) => + if (all) propertyExistsExpander(propertyKey.name) + else propertyNotExistsExpander(propertyKey.name) + case Not(PropertyExists(_, propertyKey)) => + if (all) propertyNotExistsExpander(propertyKey.name) + else propertyExistsExpander(propertyKey.name) + case _ => + if (all) cypherPositivePredicatesAsExpander(ctx, relName, predicate) + else cypherNegativePredicatesAsExpander(ctx, relName, predicate) + } + (currentExpander.addNodeFilter(filter), currentNodePredicates :+ filter) + } + private def makeRelationshipTypeExpander() = if (shortestPathPattern.relTypes.isEmpty) { Traversal.expanderForAllTypes(toGraphDb(shortestPathPattern.dir)) } else { @@ -140,18 +160,24 @@ case class ShortestPathExpression(shortestPathPattern: ShortestPath, predicates: } } - private def addPredicates(ctx: ExecutionContext, relTypeAndDirExpander: Expander)(implicit state: QueryState): Expander = if (predicates.isEmpty) relTypeAndDirExpander - else - predicates.foldLeft(relTypeAndDirExpander) { - case (currentExpander, predicate) => - predicate match { - case NoneInCollection(RelationshipFunction(_), symbolName, innerPredicate) if doesNotDependOnFullPath(innerPredicate) => - addAllOrNoneRelationshipExpander(ctx, currentExpander, all = false, innerPredicate, symbolName) - case AllInCollection(RelationshipFunction(_), symbolName, innerPredicate) if doesNotDependOnFullPath(innerPredicate) => - addAllOrNoneRelationshipExpander(ctx, currentExpander, all = true, innerPredicate, symbolName) - case _ => currentExpander - } - } + private def addPredicates(ctx: ExecutionContext, relTypeAndDirExpander: Expander)(implicit state: QueryState): + (Expander, Seq[KernelPredicate[PropertyContainer]]) = + if (predicates.isEmpty) (relTypeAndDirExpander, Seq()) + else + predicates.foldLeft((relTypeAndDirExpander, Seq[KernelPredicate[PropertyContainer]]())) { + case ((currentExpander, currentNodePredicates: Seq[KernelPredicate[PropertyContainer]]), predicate) => + predicate match { + case NoneInCollection(RelationshipFunction(_), symbolName, innerPredicate) if doesNotDependOnFullPath(innerPredicate) => + (addAllOrNoneRelationshipExpander(ctx, currentExpander, all = false, innerPredicate, symbolName), currentNodePredicates) + case AllInCollection(RelationshipFunction(_), symbolName, innerPredicate) if doesNotDependOnFullPath(innerPredicate) => + (addAllOrNoneRelationshipExpander(ctx, currentExpander, all = true, innerPredicate, symbolName), currentNodePredicates) + case NoneInCollection(NodesFunction(_), symbolName, innerPredicate) if doesNotDependOnFullPath(innerPredicate) => + addAllOrNoneNodeExpander(ctx, currentExpander, all = false, innerPredicate, symbolName, currentNodePredicates) + case AllInCollection(NodesFunction(_), symbolName, innerPredicate) if doesNotDependOnFullPath(innerPredicate) => + addAllOrNoneNodeExpander(ctx, currentExpander, all = true, innerPredicate, symbolName, currentNodePredicates) + case _ => (currentExpander, currentNodePredicates) + } + } private def doesNotDependOnFullPath(predicate: Predicate): Boolean = { (predicate.symbolTableDependencies intersect pathIdentifiers).isEmpty @@ -164,8 +190,17 @@ trait ShortestPathStrategy { def findResult(start: Node, end: Node): Any } -class SingleShortestPathStrategy(expander: Expander, allowZeroLength: Boolean, depth: Int, predicate: ShortestPathPredicate) extends ShortestPathStrategy { - private val finder = GraphAlgoFactory.shortestPath(expander, depth, predicate) +class SingleShortestPathStrategy(expander: Expander, allowZeroLength: Boolean, depth: Int, predicate: ShortestPathPredicate, + filters: Seq[KernelPredicate[PropertyContainer]]) extends ShortestPathStrategy { + private val finder = new org.neo4j.graphalgo.impl.path.ShortestPath(depth, expander, predicate) { + protected override def filterNextLevelNodes(nextNode: Node): Node = { + if (filters.isEmpty) + nextNode + else + if (filters.forall(filter => filter test nextNode)) nextNode + else null + } + } def findResult(start: Node, end: Node): Path = { val result = finder.findSinglePath(start, end) diff --git a/community/graph-algo/src/main/java/org/neo4j/graphalgo/impl/path/ShortestPath.java b/community/graph-algo/src/main/java/org/neo4j/graphalgo/impl/path/ShortestPath.java index 623f2ea2a8c6d..0fce2164a4626 100644 --- a/community/graph-algo/src/main/java/org/neo4j/graphalgo/impl/path/ShortestPath.java +++ b/community/graph-algo/src/main/java/org/neo4j/graphalgo/impl/path/ShortestPath.java @@ -316,7 +316,7 @@ private class DirectionData extends PrefetchingIterator private void prepareNextLevel() { - Collection nodesToIterate = new ArrayList( filterNextLevelNodes( this.nextNodes ) ); + Collection nodesToIterate = new ArrayList<>( this.nextNodes ); this.nextNodes.clear(); this.lastPath.setLength( currentDepth ); this.nextRelationships = new NestingIterator( nodesToIterate.iterator() ) @@ -342,27 +342,32 @@ protected Node fetchNextOrNull() { return null; } - lastMetadata.rels++; Node result = nextRel.getOtherNode( this.lastPath.endNode() ); - LevelData levelData = this.visitedNodes.get( result ); - boolean createdLevelData = false; - if ( levelData == null ) - { - levelData = new LevelData( nextRel, this.currentDepth ); - this.visitedNodes.put( result, levelData ); - createdLevelData = true; - } - if ( this.currentDepth == levelData.depth && !createdLevelData ) - { - levelData.addRel( nextRel ); - } - // Was this level data created right now, i.e. have we visited this node before? - // In that case don't add it as next node to traverse - if ( createdLevelData ) + + if ( filterNextLevelNodes( result ) != null ) { - this.nextNodes.add( result ); - return result; + lastMetadata.rels++; + + LevelData levelData = this.visitedNodes.get( result ); + boolean createdLevelData = false; + if ( levelData == null ) + { + levelData = new LevelData( nextRel, this.currentDepth ); + this.visitedNodes.put( result, levelData ); + createdLevelData = true; + } + if ( this.currentDepth == levelData.depth && !createdLevelData ) + { + levelData.addRel( nextRel ); + } + // Was this level data created right now, i.e. have we visited this node before? + // In that case don't add it as next node to traverse + if ( createdLevelData ) + { + this.nextNodes.add( result ); + return result; + } } } } @@ -475,9 +480,11 @@ public Iterator iterator() } } - protected Collection filterNextLevelNodes( Collection nextNodes ) + protected Node filterNextLevelNodes( Node nextNode ) { - return nextNodes; + // We need to be able to override this method from Cypher, so it must exist in this concrete class. + // And we also need it to do nothing but still work when not overridden. + return nextNode; } // Many long-lived instances diff --git a/community/graph-algo/src/test/java/org/neo4j/graphalgo/impl/path/TestShortestPath.java b/community/graph-algo/src/test/java/org/neo4j/graphalgo/impl/path/TestShortestPath.java index bd2e0daf6ad5d..96200600c044f 100644 --- a/community/graph-algo/src/test/java/org/neo4j/graphalgo/impl/path/TestShortestPath.java +++ b/community/graph-algo/src/test/java/org/neo4j/graphalgo/impl/path/TestShortestPath.java @@ -24,8 +24,8 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Set; @@ -48,12 +48,11 @@ import org.neo4j.kernel.Traversal; import static common.Neo4jAlgoTestCase.MyRelTypes.R1; -import static common.SimpleGraphBuilder.KEY_ID; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; import static java.util.Arrays.asList; @@ -340,19 +339,22 @@ public void makeSureDescentStopsWhenPathIsFound() throws Exception final PathFinder finder = new ShortestPath( 100, Traversal.expanderForAllTypes( Direction.OUTGOING ) ) { @Override - protected Collection filterNextLevelNodes( Collection nextNodes ) + protected Node filterNextLevelNodes( Node nextNode ) { - for ( final Node node : nextNodes ) + if ( !allowedNodes.contains( nextNode ) ) { - if ( !allowedNodes.contains( node ) ) - { - fail( "Node " + node.getProperty( KEY_ID ) + " shouldn't be expanded" ); - } + return null; } - return nextNodes; + return nextNode; } }; - finder.findAllPaths( a, c ); + Iterator paths = finder.findAllPaths( a, c ).iterator(); + for ( int i = 0; i < 4; i++ ) + { + Path aToBToC = paths.next(); + assertPath( aToBToC, a, b, c ); + } + assertFalse( "should only have contained four paths", paths.hasNext() ); } @Test