# Informed Search

In [1]:
from heapq import heappush, heappop, heapify
from typing import List, Tuple, Dict, Any, Callable, Optional

from state import State
from node import Node, find_state

In [None]:
def a_star_graph_search(
    initial_state: State, h: Callable[[State], float]
) -> Tuple[Optional[Node], int]:
    """Search for a goal state with A* graph search.

    Nodes are expanded in order of f = g + h, where g is the accumulated path
    cost and h is the heuristic estimate. Expanded states go into a closed set
    and are never re-opened, so the returned path is guaranteed optimal only
    when ``h`` is consistent (monotone).

    Args:
        initial_state: Start state. States must be hashable (they are stored
            in a set) and provide ``is_goal()`` and ``successors()``, the
            latter yielding ``(child_state, step_cost)`` pairs.
        h: Heuristic called on a state; its value is added to the path cost.

    Returns:
        ``(goal_node, n_visits)``. ``goal_node`` is the first goal node popped
        from the frontier, or ``None`` if the frontier is exhausted.
        ``n_visits`` is the number of nodes popped from the frontier.
    """
    initial_node = Node(initial_state, None, 0, 0)
    # Frontier entries are (f, node) pairs kept as a binary heap, so the
    # lowest f is always at index 0; find_state() searches this same list.
    # NOTE(review): entries with equal f fall back to comparing Node objects,
    # so this presumably relies on Node defining an ordering — confirm.
    frontier: List[Tuple[float, Node]] = [(0 + h(initial_state), initial_node)]
    explored = set()
    n_visits = 0

    while frontier:
        n_visits += 1
        _, node = heappop(frontier)
        explored.add(node.state)
        # Goal test on expansion (not on generation), as A* requires for optimality.
        if node.state.is_goal():
            return node, n_visits

        for child_state, step_cost in node.state.successors():
            if child_state in explored:
                continue
            idx, existing_node = find_state(child_state, frontier)
            child_cost = node.path_cost + step_cost
            # The frontier already holds an equal-or-cheaper path to this
            # state: skip it without building a Node or evaluating h.
            if existing_node is not None and child_cost >= existing_node.path_cost:
                continue
            child_node = Node(child_state, node, child_cost, node.depth + 1)
            entry = (child_node.path_cost + h(child_state), child_node)
            if existing_node is None:
                heappush(frontier, entry)
            else:
                # Replace the costlier entry in place, then restore the heap invariant.
                frontier[idx] = entry
                heapify(frontier)

    # Frontier exhausted without reaching a goal.
    return None, n_visits


In [3]:
from eight_puzzle import EightPuzzleState

def h_misplaced_tiles(state: EightPuzzleState) -> int:
    """Return how many non-blank tiles sit outside their goal cell.

    The goal layout assumed here is tiles 1, 2, ..., n in row-major order
    with the blank (0) in the last cell. So the cell at 1-based row-major
    position p should hold tile p. The blank is never counted.
    """
    misplaced = 0
    position = 0
    for row in state.board:
        for tile in row:
            position += 1
            # `position` doubles as the tile number that belongs in this cell.
            if tile != 0 and tile != position:
                misplaced += 1
    return misplaced

In [4]:
# Scrambled start board used for the search below. The (1, 1) argument
# matches the row/column of the 0 in this board; presumably it is the blank's
# position — confirm against EightPuzzleState.
initial_state = EightPuzzleState(
    [[7, 2, 4], [5, 0, 6], [8, 3, 1]],
    (1, 1),
)
print(
    "Heuristic of the initial state =",
    h_misplaced_tiles(initial_state),
)

Heuristic of the initial state = 6


In [9]:
# Run A* with the misplaced-tiles heuristic and report the result.
goal_node, n_visits = a_star_graph_search(initial_state, h_misplaced_tiles)
if goal_node is None:
    print("No solution found.")
else:
    # Follow parent links from the goal back to the root (whose parent is
    # None), then reverse so the path reads start -> goal. Compare against
    # None explicitly rather than relying on Node truthiness.
    path = []
    node = goal_node
    while node is not None:
        path.append(node.state)
        node = node.parent
    path.reverse()
    print("Path:")
    for i, state in enumerate(path):
        print(f"Step {i}:")
        print(state)
        print()
    print(f"Path cost: {goal_node.path_cost}")
# Reported in both cases; on success this is the same final line as before.
print(f"Number of nodes visited: {n_visits}")

Path:
Step 0:
7 2 4
5 0 6
8 3 1

Step 1:
7 2 4
5 3 6
8 0 1

Step 2:
7 2 4
5 3 6
8 1 0

Step 3:
7 2 4
5 3 0
8 1 6

Step 4:
7 2 4
5 0 3
8 1 6

Step 5:
7 2 4
0 5 3
8 1 6

Step 6:
0 2 4
7 5 3
8 1 6

Step 7:
2 0 4
7 5 3
8 1 6

Step 8:
2 4 0
7 5 3
8 1 6

Step 9:
2 4 3
7 5 0
8 1 6

Step 10:
2 4 3
7 0 5
8 1 6

Step 11:
2 4 3
7 1 5
8 0 6

Step 12:
2 4 3
7 1 5
0 8 6

Step 13:
2 4 3
0 1 5
7 8 6

Step 14:
2 4 3
1 0 5
7 8 6

Step 15:
2 0 3
1 4 5
7 8 6

Step 16:
0 2 3
1 4 5
7 8 6

Step 17:
1 2 3
0 4 5
7 8 6

Step 18:
1 2 3
4 0 5
7 8 6

Step 19:
1 2 3
4 5 0
7 8 6

Step 20:
1 2 3
4 5 6
7 8 0

Path cost: 20
Number of nodes visited: 3667
