# An Overview of the A* Algorithm
A* is a shortest path algorithm for a weighted graph that uses a heuristic function to guide its search; it was first published in 1968. Another rather famous shortest path algorithm, Dijkstra's algorithm, is the special case of A* in which the heuristic is zero everywhere, so the search simply expands outward in order of accumulated cost. In search theory, A* is what is known as a "best-first search" -- it uses problem-specific information (the heuristic) to decide which node to explore next, guiding the search from a start node to an end node.

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline

In [None]:
# Visualization Utility

def draw_graph(graph, start, end, path=None):
    """Draw a weighted, undirected graph and optionally highlight a path.

    Input:
        - graph: dictionary of node connections and weights
          (tuple of two node-name strings -> float weight)
        - start: start node (string), drawn enlarged in red
        - end: end node (string), drawn enlarged in red
        - [optional] path: ordered sequence of node names from start to end,
          e.g. ["A", "C", "E"]. Each consecutive pair of names is drawn as a
          dotted red edge. ``None`` (e.g. when no path was found) draws no path.
    """
    G = nx.Graph()

    # Create the nodes and edges in the graph
    for edge, weight in graph.items():
        G.add_edge(edge[0], edge[1], weight=weight)

    # Sets a layout for the graph (fixed seed so the layout is reproducible)
    pos = nx.spring_layout(G, seed=7)

    # Draw the basic graph
    nx.draw_networkx_nodes(G, pos)
    nx.draw_networkx_edges(G, pos, width=6)

    # Highlight the start and end nodes
    nx.draw_networkx_nodes(G, pos, nodelist=[start, end], node_size=500, node_color="red")

    # Draw the path between the start and end nodes, if available.
    # ``path`` is a sequence of node names, so pair each node with its successor.
    if path is not None:
        edgelist = list(zip(path, path[1:]))
        nx.draw_networkx_edges(G, pos, edgelist=edgelist, width=6, edge_color="r", style="dotted")

    # Add labels to everything
    nx.draw_networkx_labels(G, pos, font_size=20, font_family="sans-serif")
    edge_labels = nx.get_edge_attributes(G, "weight")
    nx.draw_networkx_edge_labels(G, pos, edge_labels)

    # Draw the graph
    plt.show()


graph_example = {
    ("A", "B"): 0.5,
    ("A", "C"): 0.1,
    ("C", "D"): 0.4,
    ("A", "D"): 0.3,
    ("C", "E"): 0.5,
    ("B", "D"): 0.8,
}
# Show the example graph with the route A -> C -> E (total weight 0.6) highlighted.
draw_graph(graph_example, "A", "E", path=("A", "C", "E"))

## Algorithm Overview -- A Connected Graph Example
Given a graph, we'd like to systematically find the shortest path between two nodes. To do this, we're going to introduce a heuristic to explore our graph:

$$f = g + h$$

where $g$ is the cumulative cost of the path so far, from the start node to a new node (the "cost to come"), and $h$ is a heuristic estimate of the remaining cost from the new node to the goal (the "cost to go"). We use this function to assign values to nodes and paths through the graph, and to select which new paths/nodes to explore. In general, the algorithm takes the following form:

*Initialize* an open list with the start node in it, and an empty closed list.

*While* the open list _is not empty_ do the following:
1. Select the node from the open list with the lowest $f$ value and set as the "current node"
2. If the current node is the goal node, search is over! Return the path to the goal.
3. Otherwise, move the current node to the closed list, then for each neighbor of the current node:
    1. If the neighbor is already in the closed list, skip it.
    2. If the neighbor is not in the open list, compute its $f$ value and add it to the open list; set the parent of the neighbor to the current node.
    3. If the neighbor is already in the open list and its new $f$ value is lower than its previous value, update its $f$ value and set its parent to the current node.
4. If the loop ends because the open list is empty, there is no path from the start to the goal.

In [None]:
class Node(object):
    """Creates a node class to track values in the A* algorithm.

    Two nodes compare equal (and hash equally) when they share the same
    ``name``, regardless of their parent or cost values; ``<`` and ``>``
    instead compare by the total estimated cost ``f``.
    """

    def __init__(self, name=None, parent=None):
        """Initialize the node.

        Args:
            name: identifier of the location this node represents (a string
                in the graph example, a (row, col) tuple in the maze example).
            parent: the Node this one was reached from, or None for the start.
        """
        self.name = name
        self.parent = parent
        self.g = 0  # cost accumulated from the start node to this node
        self.h = 0  # heuristic estimate of the remaining cost to the goal
        self.f = 0  # g + h, used to choose which node to expand next

    def __eq__(self, other):
        # Equality is by location only, so a membership test finds an existing
        # entry for the same location even when its costs differ.
        if not isinstance(other, Node):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash by name so equal
        # nodes hash equally and nodes can be stored in sets and dicts.
        return hash(self.name)

    def __repr__(self):
        return f"{self.name} - g: {self.g}, h: {self.h}, f: {self.f}"

    def __lt__(self, other):
        return self.f < other.f

    def __gt__(self, other):
        return self.f > other.f

In [None]:
def return_path(current_node):
    """Traces a path from a given Node.

    Follows ``parent`` links from ``current_node`` back to the node that has
    no parent, then prints and returns the visited names ordered from that
    root to ``current_node``. Passing None yields (and prints) an empty list.
    """
    def _ancestry(node):
        # Yield names from the given node back towards the root.
        while node is not None:
            yield node.name
            node = node.parent

    ordered = list(_ancestry(current_node))[::-1]
    print(ordered)
    return ordered

In [None]:
def astar(graph, start, end, heuristic=None):
    """From a graph, returns a path from start to end.

    Args:
        graph: dictionary mapping node pairs to edge weights, e.g.
            ``{("A", "B"): 0.5}``. Edges are undirected: each key connects
            its two nodes in both directions.
        start: name of the start node.
        end: name of the goal node.
        heuristic: optional callable ``heuristic(name, end) -> float`` that
            estimates the remaining cost from node ``name`` to ``end``. It
            should be consistent (never overestimate, and never drop by more
            than an edge's weight between neighbors). Expanded nodes are not
            reopened, so otherwise the returned path may not be the shortest.
            Defaults to 0 for every node, which makes the search equivalent
            to Dijkstra's algorithm.

    Returns:
        The list of node names from ``start`` to ``end`` (also printed by
        ``return_path``), or None when ``end`` is unreachable from ``start``.
    """
    def estimate(name):
        # Zero by default: a constant heuristic does not change which node is
        # expanded next, and h must be 0 at the goal itself.
        return 0 if heuristic is None else heuristic(name, end)

    # Build neighbor lists once instead of rescanning every edge on each
    # expansion; each node's neighbors keep the order of ``graph``'s items.
    adjacency = {}
    for edge, weight in graph.items():
        a, b = edge[0], edge[1]
        adjacency.setdefault(a, []).append((b, weight))
        if b != a:  # a self-loop contributes a single neighbor entry
            adjacency.setdefault(b, []).append((a, weight))

    start_node = Node(name=start, parent=None)
    start_node.h = estimate(start)
    start_node.f = start_node.h

    # start the open list and the set of already-expanded node names
    open_list = [start_node]
    closed = set()

    # run the loop
    while open_list:
        # get the current node (smallest f value; the earliest entry wins ties)
        best = min(range(len(open_list)), key=lambda i: open_list[i].f)
        current_node = open_list.pop(best)

        # return the path if current node is the end node
        if current_node.name == end:
            return return_path(current_node)

        closed.add(current_node.name)
        for name, weight in adjacency.get(current_node.name, ()):
            if name in closed:
                continue
            neighbor = Node(name=name, parent=current_node)
            neighbor.g = current_node.g + weight
            neighbor.h = estimate(name)
            neighbor.f = neighbor.g + neighbor.h

            # Keep at most one open entry per node: replace it in place only
            # when the new route is cheaper, otherwise enqueue the new node.
            for i, queued in enumerate(open_list):
                if queued.name == name:
                    if neighbor.f < queued.f:
                        open_list[i] = neighbor
                    break
            else:
                open_list.append(neighbor)

    # open list exhausted without reaching the goal: no path exists
    return None


graph_example = {
    ("A", "B"): 0.5,
    ("A", "C"): 1.6,
    ("C", "D"): 0.4,
    ("C", "B"): 0.6,
    ("A", "D"): 0.3,
    ("C", "E"): 0.5,
    ("B", "D"): 0.8,
}
start_node, end_node = "A", "E"
# Search for the route first, then draw the graph with that route highlighted.
path = astar(graph_example, start_node, end_node)
draw_graph(graph_example, start_node, end_node, path)

## A* -- In a Maze; Leveraging The Heuristic

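In the connected graph example, it is tricky to think about a well-grounded heuristic to help guide search. But in navigation problems, a commonly used heuristic is the straight-line distance between a node and the goal. This biases the search towards locations that are physically closer to the goal point. For A* to be guaranteed to return a shortest path, the heuristic must never overestimate the true remaining cost (it must be *admissible*). When moves are restricted to up/down/left/right steps of cost 1, the Manhattan distance $|r - r_{goal}| + |c - c_{goal}|$ is a common admissible choice. Let's see how this manifests in a maze world. Note that this is adapted from: https://gist.github.com/ryancollingwood/32446307e976a11a1185a5394d6657bc.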

In [None]:
def astar_maze(maze, start, end):
    """From a maze grid, returns a path from start to end.

    Args:
        maze: list of rows; a cell holding 0 is open, any other value is a wall.
        start: (row, col) of the start cell.
        end: (row, col) of the goal cell.

    Returns:
        The list of (row, col) tuples from ``start`` to ``end`` (also printed
        by ``return_path``), or None when the goal cannot be reached.

    Moves go to the four orthogonal neighbors and each costs 1. The heuristic
    is the Manhattan distance to the goal. It never overestimates the number
    of remaining moves and is consistent, so the returned path is a shortest
    one. (A squared-distance heuristic overestimates, which loses that
    guarantee.)
    """
    # Normalize to tuples so positions compare equal to generated neighbor
    # positions and can be stored in a set.
    start = tuple(start)
    end = tuple(end)
    end_row, end_col = end

    def manhattan(position):
        return abs(position[0] - end_row) + abs(position[1] - end_col)

    start_node = Node(name=start, parent=None)
    start_node.h = manhattan(start)
    start_node.f = start_node.h

    # start the open list and the set of already-expanded positions
    open_list = [start_node]
    closed = set()

    # set the grid locations to search for in the maze (adjacency)
    adjacent_squares = ((0, -1), (0, 1), (-1, 0), (1, 0),)

    # run the loop
    while open_list:
        # get the current node (smallest f value; the earliest entry wins ties)
        best = min(range(len(open_list)), key=lambda i: open_list[i].f)
        current_node = open_list.pop(best)

        # return the path if current node is the end node
        if current_node.name == end:
            return return_path(current_node)

        closed.add(current_node.name)
        row, col = current_node.name
        for d_row, d_col in adjacent_squares:
            r, c = row + d_row, col + d_col
            # Skip cells outside the grid (each row's own length bounds its
            # columns) and walls.
            if not (0 <= r < len(maze) and 0 <= c < len(maze[r])):
                continue
            if maze[r][c] != 0:
                continue
            position = (r, c)
            if position in closed:
                continue
            neighbor = Node(name=position, parent=current_node)
            neighbor.g = current_node.g + 1
            neighbor.h = manhattan(position)
            neighbor.f = neighbor.g + neighbor.h

            # Keep at most one open entry per cell: replace it in place only
            # when the new route is cheaper, otherwise enqueue the new node.
            for i, queued in enumerate(open_list):
                if queued.name == position:
                    if neighbor.f < queued.f:
                        open_list[i] = neighbor
                    break
            else:
                open_list.append(neighbor)

    # open list exhausted without reaching the goal: no path exists
    return None

def example(print_maze=True):
    """Solve a demo maze with astar_maze and optionally print the result.

    The search runs from the top-left cell to the bottom-right cell. When
    ``print_maze`` is true, the maze is printed with walls as full blocks,
    open cells as spaces and path cells as dots. If no path exists, a message
    is printed and the unmarked maze is shown instead of crashing.
    """
    # Each row is a 30-column pattern repeated twice (60 columns in total).
    maze = [[0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,] * 2,
            [0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,] * 2,
            [0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,] * 2,
            [0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,] * 2,
            [0,0,0,1,1,0,0,1,1,1,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,] * 2,
            [0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,] * 2,
            [0,0,0,1,0,1,1,1,1,0,1,1,0,0,1,1,1,0,0,0,1,1,1,1,1,1,1,0,0,0,] * 2,
            [0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,0,1,1,0,1,0,0,0,0,0,0,1,1,1,0,] * 2,
            [0,0,0,1,0,1,1,0,1,1,0,1,1,1,0,0,0,0,0,1,0,0,1,1,1,1,1,0,0,0,] * 2,
            [0,0,0,1,0,1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,1,0,1,0,1,1,] * 2,
            [0,0,0,1,0,1,0,1,1,0,1,1,1,1,0,0,1,1,1,1,1,1,1,0,1,0,1,0,0,0,] * 2,
            [0,0,0,1,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,1,1,1,0,] * 2,
            [0,0,0,1,0,1,1,1,1,0,1,0,0,1,1,1,0,1,1,1,1,0,1,1,1,0,1,0,0,0,] * 2,
            [0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,1,1,] * 2,
            [0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,] * 2,
            [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,] * 2,]

    start = (0, 0)
    end = (len(maze) - 1, len(maze[0]) - 1)

    path = astar_maze(maze, start, end)

    if print_maze:
        if path is None:
            # astar_maze returns None when the goal is unreachable.
            print("No path found.")
            path = []

        # Mark the path cells so they render as dots.
        for row, col in path:
            maze[row][col] = 2

        # Cell value -> printed character; unknown values print nothing.
        glyphs = {1: "\u2588", 0: " ", 2: "."}
        for row in maze:
            print("".join(glyphs.get(cell, "") for cell in row))

example(print_maze=True)