Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ The `astar` library only requires the following property from these objects:
For the default implementation of `is_goal_reached`, the objects must be
comparable for same-ness (i.e. implement `__eq__`).

A simple way to achieve this, is to use simple objects based on strings,
A simple way to achieve this is to use simple objects based on strings,
floats, integers, tuples.
[`dataclass`](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass)
objects declared with `@dataclass(frozen=True)` directly implement `__hash__`
Expand All @@ -54,7 +54,14 @@ For a given node, returns (or yields) the list of its neighbors.
This is the method that one would provide in order to give to the
algorithm the description of the graph to use during for computation.

This method must be implemented in a subclass.
Alternately, your override method may be named "path\_neighbors". Instead of
your node, this method receives a "SearchNode" object whose "came_from"
attribute points to the previous node; your node is in its "data" attribute.
You might want to use this if your path is directional, like the track of a
train that can't do 90° turns.

One of these methods must be implemented in a subclass.


distance\_between
~~~~~~~~~~~~~~~~~
Expand All @@ -68,7 +75,14 @@ Gives the real distance/cost between two adjacent nodes n1 and n2 (i.e
n2 belongs to the list of n1's neighbors). n2 is guaranteed to belong to
the list returned by a call to neighbors(n1).

This method must be implemented in a subclass.
Alternately, you may override "path\_distance\_between". The arguments
will be a "SearchNode", as in "path\_neighbors". You might want to use this
if your distance measure should include the path's attainable speed, the
kind and number of turns on it, or similar. You can use the nodes' "cache"
attributes to store some data, to speed up calculation.

One of these methods must be implemented in a subclass.


heuristic\_cost\_estimate
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -82,7 +96,7 @@ Computes the estimated (rough) distance/cost between a node and the
goal. The first argument is the start node, or any node that have been
returned by a call to the neighbors() method.

This method is used to give to the algorithm an hint about the node he
This method is used to give to the algorithm an hint about the node it
may try next during search.

This method must be implemented in a subclass.
Expand All @@ -92,7 +106,6 @@ is\_goal\_reached

.. code:: py


def is_goal_reached(self, current, goal)

This method shall return a truthy value when the goal is 'reached'. By
Expand Down
54 changes: 38 additions & 16 deletions astar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class SearchNode(Generic[T]):
"""Representation of a search node"""

__slots__ = ("data", "gscore", "fscore", "closed", "came_from", "in_openset")
__slots__ = ("data", "gscore", "fscore", "closed", "came_from", "in_openset", "cache")

def __init__(
self, data: T, gscore: float = infinity, fscore: float = infinity
Expand All @@ -26,6 +26,7 @@ def __init__(
self.closed = False
self.in_openset = False
self.came_from: Union[None, SearchNode[T]] = None
self.cache: Any = None

def __lt__(self, b: "SearchNode[T]") -> bool:
"""Natural order is based on the fscore value & is used by heapq operations"""
Expand Down Expand Up @@ -78,29 +79,48 @@ def heuristic_cost_estimate(self, current: T, goal: T) -> float:
"""
Computes the estimated (rough) distance between a node and the goal.
The second parameter is always the goal.

This method must be implemented in a subclass.
"""
raise NotImplementedError

@abstractmethod
def distance_between(self, n1: T, n2: T) -> float:
"""
Gives the real distance between two adjacent nodes n1 and n2 (i.e n2
belongs to the list of n1's neighbors).
n2 is guaranteed to belong to the list returned by the call to neighbors(n1).
This method must be implemented in a subclass.

This method (or "path_distance_between") must be implemented in a subclass.
"""
raise NotImplementedError

def path_distance_between(self, n1: SearchNode[T], n2: SearchNode[T]) -> float:
"""
Gives the real distance between the node n1 and its neighbor n2.
n2 is guaranteed to belong to the list returned by the call to
path_neighbors(n1).

Calls "distance_between"`by default.
"""
return self.distance_between(n1.data, n2.data)

@abstractmethod
def neighbors(self, node: T) -> Iterable[T]:
"""
For a given node, returns (or yields) the list of its neighbors.
This method must be implemented in a subclass.

This method (or "path_neighbors") must be implemented in a subclass.
"""
raise NotImplementedError

def path_neighbors(self, node: SearchNode[T]) -> Iterable[T]:
"""
For a given node, returns (or yields) the list of its reachable neighbors.
Calls "neighbors" by default.
"""
return self.neighbors(node.data)

def _neighbors(self, current: SearchNode[T], search_nodes: SearchNodeDict[T]) -> Iterable[SearchNode]:
return (search_nodes[n] for n in self.neighbors(current.data))
return (search_nodes[n] for n in self.path_neighbors(current))

def is_goal_reached(self, current: T, goal: T) -> bool:
"""
Expand Down Expand Up @@ -147,25 +167,27 @@ def astar(
if neighbor.closed:
continue

tentative_gscore = current.gscore + self.distance_between(
current.data, neighbor.data
)
gscore = current.gscore + self.path_distance_between(current, neighbor)

if tentative_gscore >= neighbor.gscore:
if gscore >= neighbor.gscore:
continue

neighbor_from_openset = neighbor.in_openset
fscore = gscore + self.heuristic_cost_estimate(
neighbor.data, goal
)

if neighbor.in_openset:
if neighbor.fscore < fscore:
# the new path to this node isn't better
continue

if neighbor_from_openset:
# we have to remove the item from the heap, as its score has changed
openSet.remove(neighbor)

# update the node
neighbor.came_from = current
neighbor.gscore = tentative_gscore
neighbor.fscore = tentative_gscore + self.heuristic_cost_estimate(
neighbor.data, goal
)
neighbor.gscore = gscore
neighbor.fscore = fscore

openSet.push(neighbor)

Expand Down