Skip to content

Commit

Permalink
better type hinting using generics
Browse files Browse the repository at this point in the history
  • Loading branch information
Julien Rialland committed Jan 9, 2023
1 parent 98895e9 commit a5ce186
Showing 1 changed file with 34 additions and 46 deletions.
80 changes: 34 additions & 46 deletions astar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,38 @@

from abc import ABC, abstractmethod
from heapq import heappush, heappop
from typing import Iterable, Union
from typing import Iterable, Union, TypeVar, Generic

__author__ = "Julien Rialland"
__copyright__ = "Copyright 2012-2022, J.Rialland"
__copyright__ = "Copyright 2012-2023, J.Rialland"
__license__ = "BSD"
__version__ = "0.94"
__version__ = "0.95"
__maintainer__ = __author__
__email__ = "".join(
map(
chr,
[
106,
117,
108,
105,
101,
110,
46,
114,
105,
97,
108,
108,
97,
110,
100,
64,
103,
109,
97,
105,
108,
46,
99,
111,
109,
],

__email__ = "nc"

try:
from cryptography.fernet import Fernet

__email__ = (
Fernet(b"nSR9r0WZ-SdMuyPGfLFyHT4Vu0pdfMTudDOcb9FrubE=")
.decrypt(
b"gAAAAABjvCkTri-mqyMjuR41C4WmejRO9ohHVX_6dXdva98OAIiHC8yRc4RDAXCZcoxz-Nnv12qO9BV4nQdJlx-gcK-eOws9hBXyKeFWUGkhp2YEaeYjrFs="
)
.decode("utf-8")
)
)
except Exception as e:
pass

__status__ = "Production"

# infinity as a constant
Infinite = float("inf")

# introduce generic type
T = TypeVar("T")

class AStar(ABC):
class AStar(ABC, Generic[T]):

__slots__ = ()

Expand All @@ -58,7 +44,7 @@ class SearchNode:
__slots__ = ("data", "gscore", "fscore", "closed", "came_from", "out_openset")

def __init__(
self, data, gscore: float = Infinite, fscore: float = Infinite
self, data: T, gscore: float = Infinite, fscore: float = Infinite
) -> None:
self.data = data
self.gscore = gscore
Expand All @@ -77,7 +63,7 @@ def __missing__(self, k):
return v

@abstractmethod
def heuristic_cost_estimate(self, current, goal):
def heuristic_cost_estimate(self, current: T, goal: T) -> float:
"""
Computes the estimated (rough) distance between a node and the goal.
The second parameter is always the goal.
Expand All @@ -86,7 +72,7 @@ def heuristic_cost_estimate(self, current, goal):
raise NotImplementedError

@abstractmethod
def distance_between(self, n1, n2):
def distance_between(self, n1: T, n2: T) -> float:
"""
Gives the real distance between two adjacent nodes n1 and n2 (i.e n2
belongs to the list of n1's neighbors).
Expand All @@ -95,22 +81,22 @@ def distance_between(self, n1, n2):
"""

@abstractmethod
def neighbors(self, node):
def neighbors(self, node: T) -> Iterable[T]:
"""
For a given node, returns (or yields) the list of its neighbors.
This method must be implemented in a subclass.
"""
raise NotImplementedError

def is_goal_reached(self, current, goal):
def is_goal_reached(self, current: T, goal: T) -> bool:
"""
Returns true when we can consider that 'current' is the goal.
The default implementation simply compares `current == goal`, but this
method can be overwritten in a subclass to provide more refined checks.
"""
return current == goal

def reconstruct_path(self, last, reversePath=False)->Iterable:
def reconstruct_path(self, last: SearchNode, reversePath=False) -> Iterable[T]:
def _gen():
current = last
while current:
Expand All @@ -122,14 +108,16 @@ def _gen():
else:
return reversed(list(_gen()))

def astar(self, start, goal, reversePath: bool = False) -> Union[Iterable, None]:
def astar(
self, start: T, goal: T, reversePath: bool = False
) -> Union[Iterable[T], None]:
if self.is_goal_reached(start, goal):
return [start]
searchNodes = AStar.SearchNodeDict()
startNode = searchNodes[start] = AStar.SearchNode(
start, gscore=0.0, fscore=self.heuristic_cost_estimate(start, goal)
)
openSet:list = []
openSet: list = []
heappush(openSet, startNode)
while openSet:
current = heappop(openSet)
Expand Down Expand Up @@ -161,14 +149,14 @@ def astar(self, start, goal, reversePath: bool = False) -> Union[Iterable, None]


def find_path(
start,
goal,
start: T,
goal: T,
neighbors_fnct,
reversePath=False,
heuristic_cost_estimate_fnct=lambda a, b: Infinite,
distance_between_fnct=lambda a, b: 1.0,
is_goal_reached_fnct=lambda a, b: a == b,
):
) -> Union[Iterable[T], None]:
"""A non-class version of the path finding algorithm"""

class FindPath(AStar):
Expand Down

0 comments on commit a5ce186

Please sign in to comment.