Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP

Loading…

Able to pass in InitialBranchState to TraversalAStar #604

Merged
merged 1 commit into from

3 participants

@tinwelint
Collaborator

No description provided.

@jexp
Collaborator

ci-bot please retest

@simpsonjulian
@jexp jexp merged commit f78135f into neo4j:master
@tinwelint tinwelint deleted the tinwelint:stateful-astar branch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
This page is out of date. Refresh to see the latest.
View
29 community/graph-algo/src/main/java/org/neo4j/graphalgo/impl/path/TraversalAStar.java
@@ -20,7 +20,9 @@
package org.neo4j.graphalgo.impl.path;
import static org.neo4j.graphdb.traversal.Evaluators.includeWhereEndNodeIs;
+import static org.neo4j.graphdb.traversal.InitialBranchState.NO_STATE;
import static org.neo4j.helpers.collection.IteratorUtil.firstOrNull;
+import static org.neo4j.kernel.StandardExpander.toPathExpander;
import static org.neo4j.kernel.Traversal.traversal;
import java.util.Iterator;
@@ -33,7 +35,9 @@
import org.neo4j.graphalgo.impl.util.StopAfterWeightIterator;
import org.neo4j.graphdb.Direction;
import org.neo4j.graphdb.Node;
+import org.neo4j.graphdb.PathExpander;
import org.neo4j.graphdb.RelationshipExpander;
+import org.neo4j.graphdb.traversal.InitialBranchState;
import org.neo4j.graphdb.traversal.TraversalBranch;
import org.neo4j.graphdb.traversal.TraversalDescription;
import org.neo4j.graphdb.traversal.TraversalMetadata;
@@ -55,21 +59,36 @@
private final EstimateEvaluator<Double> estimateEvaluator;
- public TraversalAStar( RelationshipExpander expander, CostEvaluator<Double> costEvaluator,
- EstimateEvaluator<Double> estimateEvaluator )
+ @SuppressWarnings( "unchecked" )
+ public <T> TraversalAStar( PathExpander<T> expander,
+ CostEvaluator<Double> costEvaluator, EstimateEvaluator<Double> estimateEvaluator )
+ {
+ this( expander, NO_STATE, costEvaluator, estimateEvaluator );
+ }
+
+ public <T> TraversalAStar( PathExpander<T> expander, InitialBranchState<T> initialState,
+ CostEvaluator<Double> costEvaluator, EstimateEvaluator<Double> estimateEvaluator )
{
- this.traversalDescription = traversal().uniqueness(
- Uniqueness.NONE ).expand( expander );
this.costEvaluator = costEvaluator;
this.estimateEvaluator = estimateEvaluator;
+ this.traversalDescription = traversal().uniqueness( Uniqueness.NONE ).expand( expander, initialState );
+ }
+
+ @SuppressWarnings( "unchecked" )
+ public TraversalAStar( RelationshipExpander expander, CostEvaluator<Double> costEvaluator,
+ EstimateEvaluator<Double> estimateEvaluator )
+ {
+ this( toPathExpander( expander ), costEvaluator, estimateEvaluator );
}
+ @Override
public Iterable<WeightedPath> findAllPaths( Node start, final Node end )
{
lastTraverser = traversalDescription.order(
new SelectorFactory( end ) ).evaluator( includeWhereEndNodeIs( end ) ).traverse( start );
return new Iterable<WeightedPath>()
{
+ @Override
public Iterator<WeightedPath> iterator()
{
return new StopAfterWeightIterator( lastTraverser.iterator(), costEvaluator );
@@ -77,6 +96,7 @@ public TraversalAStar( RelationshipExpander expander, CostEvaluator<Double> cost
};
}
+ @Override
public WeightedPath findSinglePath( Node start, Node end )
{
return firstOrNull( findAllPaths( start, end ) );
@@ -104,6 +124,7 @@ Double f()
return this.estimateH + this.wayLengthG;
}
+ @Override
public int compareTo( PositionData o )
{
return f().compareTo( o.f() );
View
176 community/graph-algo/src/test/java/org/neo4j/graphalgo/path/TestAStar.java
@@ -23,46 +23,92 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+import static org.neo4j.graphalgo.CommonEvaluators.doubleCostEvaluator;
+import static org.neo4j.graphdb.Direction.OUTGOING;
+import static org.neo4j.kernel.Traversal.expanderForAllTypes;
+import static org.neo4j.kernel.Traversal.pathExpanderForAllTypes;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.Map;
import java.util.Set;
import org.junit.Ignore;
import org.junit.Test;
-import org.neo4j.graphalgo.CommonEvaluators;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
import org.neo4j.graphalgo.EstimateEvaluator;
import org.neo4j.graphalgo.GraphAlgoFactory;
import org.neo4j.graphalgo.PathFinder;
import org.neo4j.graphalgo.WeightedPath;
+import org.neo4j.graphalgo.impl.path.TraversalAStar;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Path;
+import org.neo4j.graphdb.PathExpander;
import org.neo4j.graphdb.Relationship;
-import org.neo4j.kernel.Traversal;
+import org.neo4j.graphdb.traversal.BranchState;
+import org.neo4j.graphdb.traversal.InitialBranchState;
+import org.neo4j.helpers.collection.MapUtil;
import common.Neo4jAlgoTestCase;
+@RunWith( Parameterized.class )
public class TestAStar extends Neo4jAlgoTestCase
{
- static EstimateEvaluator<Double> ESTIMATE_EVALUATOR = new EstimateEvaluator<Double>()
+ @Test
+ public void wikipediaExample() throws Exception
{
- public Double getCost( Node node, Node goal )
- {
- double dx = (Double) node.getProperty( "x" )
- - (Double) goal.getProperty( "x" );
- double dy = (Double) node.getProperty( "y" )
- - (Double) goal.getProperty( "y" );
- double result = Math.sqrt( Math.pow( dx, 2 ) + Math.pow( dy, 2 ) );
- return result;
- }
- };
+ /* GIVEN
+ *
+ * (start)---2--->(d)
+ * \ \
+ * 1.5 .\
+ * v 3
+ * (a)-\ v
+ * -2-\ (e)
+ * ->(b) \
+ * / \
+ * /-- 2
+ * /-3- v
+ * v --4------->(end)
+ * (c)------/
+ */
+ Node start = graph.makeNode( "start", "x", 0d, "y", 0d );
+ graph.makeNode( "a", "x", 0.3d, "y", 1d );
+ graph.makeNode( "b", "x", 2d, "y", 2d );
+ graph.makeNode( "c", "x", 0d, "y", 3d );
+ graph.makeNode( "d", "x", 2d, "y", 0d );
+ graph.makeNode( "e", "x", 3d, "y", 1.5d );
+ Node end = graph.makeNode( "end", "x", 3.3d, "y", 2.8d );
+ graph.makeEdge( "start", "a", "length", 1.5d );
+ graph.makeEdge( "a", "b", "length", 2d );
+ graph.makeEdge( "b", "c", "length", 3d );
+ graph.makeEdge( "c", "end", "length", 4d );
+ graph.makeEdge( "start", "d", "length", 2d );
+ graph.makeEdge( "d", "e", "length", 3d );
+ graph.makeEdge( "e", "end", "length", 2d );
- private PathFinder<WeightedPath> newFinder()
- {
- return GraphAlgoFactory.aStar( Traversal.expanderForAllTypes(),
- CommonEvaluators.doubleCostEvaluator( "length" ), ESTIMATE_EVALUATOR );
+ // WHEN
+ WeightedPath path = finder.findSinglePath( start, end );
+
+ // THEN
+ assertPathDef( path, "start", "d", "e", "end" );
}
+ /**
+ * <pre>
+ * 01234567
+ * +-------->x A - C: 10
+ * 0|A C A - B: 2 (x2)
+ * 1| B B - C: 6
+ * V
+ * y
+ * </pre>
+ */
@Test
public void testSimplest()
{
@@ -74,15 +120,14 @@ public void testSimplest()
Relationship relBC = graph.makeEdge( "B", "C", "length", 3d );
Relationship relAC = graph.makeEdge( "A", "C", "length", 10d );
- PathFinder<WeightedPath> astar = newFinder();
int counter = 0;
- for ( WeightedPath path : astar.findAllPaths( nodeA, nodeC ) )
+ for ( WeightedPath path : finder.findAllPaths( nodeA, nodeC ) )
{
assertEquals( (Double)5d, (Double)path.weight() );
assertPath( path, nodeA, nodeB, nodeC );
counter++;
}
-// assertEquals( 2, counter );
+ assertEquals( 1, counter );
}
/**
@@ -108,8 +153,7 @@ public void canGetMultiplePathsInTriangleGraph() throws Exception
Relationship expectedSecond = graph.makeEdge( "B", "C", "length", 6d );
graph.makeEdge( "A", "C", "length", 10d );
- PathFinder<WeightedPath> algo = newFinder();
- Iterator<WeightedPath> paths = algo.findAllPaths( nodeA, nodeC ).iterator();
+ Iterator<WeightedPath> paths = finder.findAllPaths( nodeA, nodeC ).iterator();
for ( int foundCount = 0; foundCount < 2; foundCount++ )
{
assertTrue( "expected more paths (found: " + foundCount + ")", paths.hasNext() );
@@ -163,13 +207,11 @@ public void canGetMultiplePathsInASmallRoadNetwork() throws Exception
graph.makeEdge( "C", "F", "length", 12d );
graph.makeEdge( "A", "F", "length", 25d );
- PathFinder<WeightedPath> algo = newFinder();
-
// Try the search in both directions.
for ( Node[] nodes : new Node[][] { { nodeA, nodeF }, { nodeF, nodeA } } )
{
int found = 0;
- Iterator<WeightedPath> paths = algo.findAllPaths( nodes[0], nodes[1] ).iterator();
+ Iterator<WeightedPath> paths = finder.findAllPaths( nodes[0], nodes[1] ).iterator();
for ( int foundCount = 0; foundCount < 2; foundCount++ )
{
assertTrue( "expected more paths (found: " + foundCount + ")", paths.hasNext() );
@@ -192,4 +234,88 @@ else if ( path.length() != found && path.length() == 4 )
assertFalse( "expected at most two paths", paths.hasNext() );
}
}
+
+ @SuppressWarnings( { "rawtypes", "unchecked" } )
+ @Test
+ public void canUseBranchState() throws Exception
+ {
+ // This test doesn't use the predefined finder, which only means an unnecessary instantiation
+ // if such an object. And this test will be run twice (once for each finder type in data()).
+
+ Node nodeA = graph.makeNode( "A", "x", 0d, "y", 0d );
+ Node nodeB = graph.makeNode( "B", "x", 2d, "y", 1d );
+ Node nodeC = graph.makeNode( "C", "x", 7d, "y", 0d );
+ graph.makeEdge( "A", "B", "length", 2d );
+ graph.makeEdge( "A", "B", "length", 2d );
+ graph.makeEdge( "B", "C", "length", 3d );
+ graph.makeEdge( "A", "C", "length", 10d );
+
+ final Map<Node, Double> seenBranchStates = new HashMap<Node, Double>();
+ PathExpander<Double> expander = new PathExpander<Double>()
+ {
+ @Override
+ public Iterable<Relationship> expand( Path path, BranchState<Double> state )
+ {
+ double newState = state.getState();
+ if ( path.length() > 0 )
+ {
+ newState += (Double) path.lastRelationship().getProperty( "length" );
+ state.setState( newState );
+ }
+ seenBranchStates.put( path.endNode(), newState );
+
+ return path.endNode().getRelationships( OUTGOING );
+ }
+
+ @Override
+ public PathExpander<Double> reverse()
+ {
+ throw new UnsupportedOperationException();
+ }
+ };
+
+ double initialStateValue = 0D;
+ PathFinder<WeightedPath> traversalFinder = new TraversalAStar( expander,
+ new InitialBranchState.State( initialStateValue, initialStateValue ),
+ doubleCostEvaluator( "length" ), ESTIMATE_EVALUATOR );
+ WeightedPath path = traversalFinder.findSinglePath( nodeA, nodeC );
+ assertEquals( (Double) 5.0D, (Double) path.weight() );
+ assertPathDef( path, "A", "B", "C" );
+ assertEquals( MapUtil.<Node,Double>genericMap( nodeA, 0D, nodeB, 2D, nodeC, 5D ), seenBranchStates );
+ }
+
+ static EstimateEvaluator<Double> ESTIMATE_EVALUATOR = new EstimateEvaluator<Double>()
+ {
+ @Override
+ public Double getCost( Node node, Node goal )
+ {
+ double dx = (Double) node.getProperty( "x" )
+ - (Double) goal.getProperty( "x" );
+ double dy = (Double) node.getProperty( "y" )
+ - (Double) goal.getProperty( "y" );
+ double result = Math.sqrt( Math.pow( dx, 2 ) + Math.pow( dy, 2 ) );
+ return result;
+ }
+ };
+
+ @Parameters
+ public static Collection<Object[]> data()
+ {
+ return Arrays.asList( new Object[][]
+ {
+ {
+ GraphAlgoFactory.aStar( expanderForAllTypes(), doubleCostEvaluator( "length" ), ESTIMATE_EVALUATOR )
+ },
+ {
+ new TraversalAStar( pathExpanderForAllTypes(), doubleCostEvaluator( "length" ), ESTIMATE_EVALUATOR )
+ }
+ } );
+ }
+
+ private final PathFinder<WeightedPath> finder;
+
+ public TestAStar( PathFinder<WeightedPath> finder )
+ {
+ this.finder = finder;
+ }
}
Something went wrong with that request. Please try again.