From f3f67ba22603fd004c62a80a119fa72c9c8e1946 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sat, 18 Jan 2025 11:19:15 +0100 Subject: [PATCH 1/3] Add path-based overrides. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sometimes neighbor algorithms depend on where you're coming from. Trains are not able to do 90° turns. Straight roads or rails make for more speed. Thus this patch allows overriding neighbor and cost estimates that take the path-so-far into account. TODO: add testcases ... --- astar/__init__.py | 58 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/astar/__init__.py b/astar/__init__.py index 5c878bc..6b0dee5 100644 --- a/astar/__init__.py +++ b/astar/__init__.py @@ -63,6 +63,13 @@ def remove(self, item: SNType) -> None: self.sortedlist.remove(item) item.in_openset = False + item = self.heap.pop() + if idx < len(self.heap): + self.heap[idx] = item + # Fix heap invariants + heapq._siftup(self.heap, idx) + heapq._siftdown(self.heap, 0, idx) + def __len__(self) -> int: return len(self.sortedlist) @@ -78,29 +85,48 @@ def heuristic_cost_estimate(self, current: T, goal: T) -> float: """ Computes the estimated (rough) distance between a node and the goal. The second parameter is always the goal. + This method must be implemented in a subclass. """ raise NotImplementedError - @abstractmethod def distance_between(self, n1: T, n2: T) -> float: """ Gives the real distance between two adjacent nodes n1 and n2 (i.e n2 belongs to the list of n1's neighbors). n2 is guaranteed to belong to the list returned by the call to neighbors(n1). - This method must be implemented in a subclass. + + This method (or "path_distance_between") must be implemented in a subclass. """ + raise NotImplementedError + + def path_distance_between(self, n1: SearchNode[T], n2: T) -> float: + """ + Gives the real distance between the node n1 and its neighbor n2. + n2 is guaranteed to belong to the list returned by the call to + path_neighbors(n1). + + Calls "distance_between"`by default. + """ + return self.distance_between(n1.data, n2) - @abstractmethod def neighbors(self, node: T) -> Iterable[T]: """ For a given node, returns (or yields) the list of its neighbors. - This method must be implemented in a subclass. + + This method (or "path_neighbors") must be implemented in a subclass. """ raise NotImplementedError + def path_neighbors(self, node: SearchNode[T]) -> Iterable[T]: + """ + For a given node, returns (or yields) the list of its reachable neighbors. + Calls "neighbors" by default. + """ + return self.neighbors(node.data) + def _neighbors(self, current: SearchNode[T], search_nodes: SearchNodeDict[T]) -> Iterable[SearchNode]: - return (search_nodes[n] for n in self.neighbors(current.data)) + return (search_nodes[n] for n in self.path_neighbors(current)) def is_goal_reached(self, current: T, goal: T) -> bool: """ @@ -147,25 +173,29 @@ def astar( if neighbor.closed: continue - tentative_gscore = current.gscore + self.distance_between( - current.data, neighbor.data + gscore = current.gscore + self.path_distance_between( + current, neighbor.data ) - if tentative_gscore >= neighbor.gscore: + if gscore >= neighbor.gscore: continue - neighbor_from_openset = neighbor.in_openset + fscore = gscore + self.heuristic_cost_estimate( + neighbor.data, goal + ) + + if neighbor.in_openset: + if neighbor.fscore < fscore: + # the new path to this node isn't better + continue - if neighbor_from_openset: # we have to remove the item from the heap, as its score has changed openSet.remove(neighbor) # update the node neighbor.came_from = current - neighbor.gscore = tentative_gscore - neighbor.fscore = tentative_gscore + self.heuristic_cost_estimate( - neighbor.data, goal - ) + neighbor.gscore = gscore + neighbor.fscore = fscore openSet.push(neighbor) From ab01fc84b7b8cca96ff4f891568f77ad53b74d78 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sat, 18 Jan 2025 16:53:03 +0100 Subject: [PATCH 2/3] Use both nodes in path_neighbor, add a cache It's more regular and the cache can be used to directly store relevant information in the destination node which otherwise would have to be recalculated. --- astar/__init__.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/astar/__init__.py b/astar/__init__.py index 6b0dee5..8559dc6 100644 --- a/astar/__init__.py +++ b/astar/__init__.py @@ -15,7 +15,7 @@ class SearchNode(Generic[T]): """Representation of a search node""" - __slots__ = ("data", "gscore", "fscore", "closed", "came_from", "in_openset") + __slots__ = ("data", "gscore", "fscore", "closed", "came_from", "in_openset", "cache") def __init__( self, data: T, gscore: float = infinity, fscore: float = infinity @@ -26,6 +26,7 @@ def __init__( self.closed = False self.in_openset = False self.came_from: Union[None, SearchNode[T]] = None + self.cache: Any = None def __lt__(self, b: "SearchNode[T]") -> bool: """Natural order is based on the fscore value & is used by heapq operations""" @@ -63,13 +64,6 @@ def remove(self, item: SNType) -> None: self.sortedlist.remove(item) item.in_openset = False - item = self.heap.pop() - if idx < len(self.heap): - self.heap[idx] = item - # Fix heap invariants - heapq._siftup(self.heap, idx) - heapq._siftdown(self.heap, 0, idx) - def __len__(self) -> int: return len(self.sortedlist) @@ -100,7 +94,7 @@ def distance_between(self, n1: T, n2: T) -> float: """ raise NotImplementedError - def path_distance_between(self, n1: SearchNode[T], n2: T) -> float: + def path_distance_between(self, n1: SearchNode[T], n2: SearchNode[T]) -> float: """ Gives the real distance between the node n1 and its neighbor n2. n2 is guaranteed to belong to the list returned by the call to @@ -108,7 +102,7 @@ def path_distance_between(self, n1: SearchNode[T], n2: T) -> float: Calls "distance_between"`by default. """ - return self.distance_between(n1.data, n2) + return self.distance_between(n1.data, n2.data) def neighbors(self, node: T) -> Iterable[T]: """ @@ -173,9 +167,7 @@ def astar( if neighbor.closed: continue - gscore = current.gscore + self.path_distance_between( - current, neighbor.data - ) + gscore = current.gscore + self.path_distance_between(current, neighbor) if gscore >= neighbor.gscore: continue From ad394616b1efa9e7f24f4c89fda5e751ad7c0b27 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Sat, 18 Jan 2025 16:54:26 +0100 Subject: [PATCH 3/3] README: Document path_* methods and the cache. --- README.rst | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/README.rst b/README.rst index 47f17bd..a2956ad 100644 --- a/README.rst +++ b/README.rst @@ -34,7 +34,7 @@ The `astar` library only requires the following property from these objects: For the default implementation of `is_goal_reached`, the objects must be comparable for same-ness (i.e. implement `__eq__`). -A simple way to achieve this, is to use simple objects based on strings, +A simple way to achieve this is to use simple objects based on strings, floats, integers, tuples. [`dataclass`](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass) objects declared with `@dataclass(frozen=True)` directly implement `__hash__` @@ -54,7 +54,14 @@ For a given node, returns (or yields) the list of its neighbors. This is the method that one would provide in order to give to the algorithm the description of the graph to use during for computation. -This method must be implemented in a subclass. +Alternately, your override method may be named "path\_neighbors". Instead of +your node, this method receives a "SearchNode" object whose "came_from" +attribute points to the previous node; your node is in its "data" attribute. +You might want to use this if your path is directional, like the track of a +train that can't do 90° turns. + +One of these methods must be implemented in a subclass. + distance\_between ~~~~~~~~~~~~~~~~~ @@ -68,7 +75,14 @@ Gives the real distance/cost between two adjacent nodes n1 and n2 (i.e n2 belongs to the list of n1's neighbors). n2 is guaranteed to belong to the list returned by a call to neighbors(n1). -This method must be implemented in a subclass. +Alternately, you may override "path\_distance\_between". The arguments +will be a "SearchNode", as in "path\_neighbors". You might want to use this +if your distance measure should include the path's attainable speed, the +kind and number of turns on it, or similar. You can use the nodes' "cache" +attributes to store some data, to speed up calculation. + +One of these methods must be implemented in a subclass. + heuristic\_cost\_estimate ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -82,7 +96,7 @@ Computes the estimated (rough) distance/cost between a node and the goal. The first argument is the start node, or any node that have been returned by a call to the neighbors() method. -This method is used to give to the algorithm an hint about the node he +This method is used to give to the algorithm an hint about the node it may try next during search. This method must be implemented in a subclass. @@ -92,7 +106,6 @@ is\_goal\_reached .. code:: py - def is_goal_reached(self, current, goal) This method shall return a truthy value when the goal is 'reached'. By