Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[A-star] Added expansion pruning via cutoff if cutoff is provided #7073

Merged
31 changes: 28 additions & 3 deletions networkx/algorithms/shortest_paths/astar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@nx._dispatch(edge_attrs="weight", preserve_node_attrs="heuristic")
def astar_path(G, source, target, heuristic=None, weight="weight"):
def astar_path(G, source, target, heuristic=None, weight="weight", *, cutoff=None):
"""Returns a list of nodes in a shortest path between source and target
using the A* ("A-star") algorithm.

Expand Down Expand Up @@ -49,6 +49,15 @@ def astar_path(G, source, target, heuristic=None, weight="weight"):
dictionary of edge attributes for that edge. The function must
return a number or None to indicate a hidden edge.

cutoff : float, optional
If this is provided, the search will be bounded to this value. I.e. if
the evaluation function surpasses this value for a node n, the node will not
be expanded further and will be ignored. More formally, let h'(n) be the
heuristic function, and g(n) be the cost of reaching n from the source node. Then,
if g(n) + h'(n) > cutoff, the node will not be explored further.
Note that if the heuristic is inadmissible, it is possible that paths
are ignored even though they satisfy the cutoff.

Raises
------
NetworkXNoPath
Expand Down Expand Up @@ -152,14 +161,20 @@ def heuristic(u, v):
continue
else:
h = heuristic(neighbor, target)

if cutoff and ncost + h > cutoff:
continue

enqueued[neighbor] = ncost, h
push(queue, (ncost + h, next(c), neighbor, ncost, curnode))

raise nx.NetworkXNoPath(f"Node {target} not reachable from {source}")


@nx._dispatch(edge_attrs="weight", preserve_node_attrs="heuristic")
def astar_path_length(G, source, target, heuristic=None, weight="weight"):
def astar_path_length(
G, source, target, heuristic=None, weight="weight", *, cutoff=None
):
"""Returns the length of the shortest path between source and target using
the A* ("A-star") algorithm.

Expand Down Expand Up @@ -195,6 +210,16 @@ def astar_path_length(G, source, target, heuristic=None, weight="weight"):
positional arguments: the two endpoints of an edge and the
dictionary of edge attributes for that edge. The function must
return a number or None to indicate a hidden edge.

cutoff : float, optional
If this is provided, the search will be bounded to this value. I.e. if
the evaluation function surpasses this value for a node n, the node will not
be expanded further and will be ignored. More formally, let h'(n) be the
heuristic function, and g(n) be the cost of reaching n from the source node. Then,
if g(n) + h'(n) > cutoff, the node will not be explored further.
Note that if the heuristic is inadmissible, it is possible that paths
are ignored even though they satisfy the cutoff.

Raises
------
NetworkXNoPath
Expand All @@ -210,5 +235,5 @@ def astar_path_length(G, source, target, heuristic=None, weight="weight"):
raise nx.NodeNotFound(msg)

weight = _weight_function(G, weight)
path = astar_path(G, source, target, heuristic, weight)
path = astar_path(G, source, target, heuristic, weight, cutoff=cutoff)
return sum(weight(u, v, G[u][v]) for u, v in zip(path[:-1], path[1:]))
38 changes: 38 additions & 0 deletions networkx/algorithms/shortest_paths/tests/test_astar.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,44 @@ def test_astar_nopath(self):
with pytest.raises(nx.NodeNotFound):
nx.astar_path(self.XG, "s", "moon")

def test_astar_cutoff(self):
with pytest.raises(nx.NetworkXNoPath):
# optimal path_length in XG is 9
nx.astar_path(self.XG, "s", "v", cutoff=8.0)
anders-rydbirk marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(nx.NetworkXNoPath):
nx.astar_path_length(self.XG, "s", "v", cutoff=8.0)

def test_astar_admissible_heuristic_with_cutoff(self):
heuristic_values = {"s": 36, "y": 4, "x": 0, "u": 0, "v": 0}

def h(u, v):
return heuristic_values[u]

assert nx.astar_path_length(self.XG, "s", "v") == 9
assert nx.astar_path_length(self.XG, "s", "v", heuristic=h) == 9
assert nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=12) == 9
assert nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=9) == 9
with pytest.raises(nx.NetworkXNoPath):
nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=8)

def test_astar_inadmissible_heuristic_with_cutoff(self):
heuristic_values = {"s": 36, "y": 14, "x": 10, "u": 10, "v": 0}

def h(u, v):
return heuristic_values[u]

# optimal path_length in XG is 9. This heuristic gives over-estimate.
assert nx.astar_path_length(self.XG, "s", "v", heuristic=h) == 10
assert nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=15) == 10
with pytest.raises(nx.NetworkXNoPath):
nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=9)
with pytest.raises(nx.NetworkXNoPath):
nx.astar_path_length(self.XG, "s", "v", heuristic=h, cutoff=12)

def test_astar_cutoff2(self):
assert nx.astar_path(self.XG, "s", "v", cutoff=10.0) == ["s", "x", "u", "v"]
assert nx.astar_path_length(self.XG, "s", "v") == 9

def test_cycle(self):
C = nx.cycle_graph(7)
assert nx.astar_path(C, 0, 3) == [0, 1, 2, 3]
Expand Down