In [113]:
from collections import namedtuple
from random import choice
from tqdm.auto import tqdm
import numpy as np
import heapq
from collections import namedtuple

In [114]:
PUZZLE_DIM = 3
action = namedtuple('Action', ['pos1', 'pos2'])

## Helper functions

In [115]:
def available_actions(state: np.ndarray) -> list['Action']:
    x, y = [int(_[0]) for _ in np.where(state == 0)]
    actions = list()
    if x > 0:
        actions.append(action((x, y), (x - 1, y)))
    if x < PUZZLE_DIM - 1:
        actions.append(action((x, y), (x + 1, y)))
    if y > 0:
        actions.append(action((x, y), (x, y - 1)))
    if y < PUZZLE_DIM - 1:
        actions.append(action((x, y), (x, y + 1)))
    return actions



def do_action(state: np.ndarray, action: 'Action') -> np.ndarray:
    new_state = state.copy()
    new_state[action.pos1], new_state[action.pos2] = new_state[action.pos2], new_state[action.pos1]
    return new_state

def manhattan_distance(state, goal):
    return sum(abs(x1 - x2) + abs(y1 - y2)
               for num in range(1, PUZZLE_DIM**2)
               for x1, y1 in [np.where(state == num)]
               for x2, y2 in [np.where(goal == num)])


def linear_conflict(state, goal):
    conflict = 0

    # Check row conflicts
    for row in range(PUZZLE_DIM):
        row_values = state[row, :]
        goal_row_values = goal[row, :]
        for i in range(PUZZLE_DIM):
            if row_values[i] in goal_row_values:
                for j in range(i + 1, PUZZLE_DIM):
                    if (row_values[j] in goal_row_values and
                        goal_row_values.tolist().index(row_values[i]) >
                        goal_row_values.tolist().index(row_values[j])):
                        conflict += 2

    # Check column conflicts
    for col in range(PUZZLE_DIM):
        col_values = state[:, col]
        goal_col_values = goal[:, col]
        for i in range(PUZZLE_DIM):
            if col_values[i] in goal_col_values:
                for j in range(i + 1, PUZZLE_DIM):
                    if (col_values[j] in goal_col_values and
                        goal_col_values.tolist().index(col_values[i]) >
                        goal_col_values.tolist().index(col_values[j])):
                        conflict += 2

    return conflict + manhattan_distance(state, goal)


def reconstruct_path(node):
    path = []
    while node:
        path.append(np.array(node.state))  # Convert tuples in array numpy
        node = node.parent
    return path[::-1]  # Return inverted path


## Initial generator

In [116]:
RANDOMIZE_STEPS = 100_000
state = np.array([i for i in range(1, PUZZLE_DIM**2)] + [0]).reshape((PUZZLE_DIM, PUZZLE_DIM))
for r in tqdm(range(RANDOMIZE_STEPS), desc='Randomizing'):
    state = do_action(state, choice(available_actions(state)))
state

Randomizing: 100%|██████████| 100000/100000 [00:00<00:00, 174975.77it/s]


array([[8, 2, 3],
       [4, 5, 7],
       [6, 1, 0]])

In [117]:
goal_state = np.array([i for i in range(1, PUZZLE_DIM**2)] + [0]).reshape((PUZZLE_DIM, PUZZLE_DIM))

class Node:
    def __init__(self, state, g, h, parent=None):
        self.state = state  # Actual state
        self.g = g  # Actual cost
        self.h = h  # heuristic value
        self.parent = parent  # Parent node

    def __lt__(self, other):
        return (self.g + self.h) < (other.g + other.h)      # Compare nodes by f = g + h


In [118]:
def enhanced_heuristic(state, goal):
    return manhattan_distance(state, goal) + linear_conflict(state, goal)

def a_star(initial_state, goal_state):
    calls = 0
    open_list = []
    closed_set = {}  # To memorize states already visited
    h = enhanced_heuristic(initial_state, goal_state)
    root = Node(state=tuple(map(tuple, initial_state)), g=0, h=h, parent=None)

    heapq.heappush(open_list, (root.g + root.h, root))  # Push root node in open_list

    while open_list:
        _, current = heapq.heappop(open_list)

        if current.state == tuple(map(tuple, goal_state)):  # Goal state reached
            print(f"\nNumber of calls to heuristic function: {calls}")
            return reconstruct_path(current)

        if current.state in closed_set and closed_set[current.state] <= current.g:
            continue  # Skip if state already visited with lower cost

        closed_set[current.state] = current.g 

        current_state_ndarray = np.array(current.state)
        for act in available_actions(current_state_ndarray):
            new_state = do_action(current_state_ndarray, act)
            new_state_tuple = tuple(map(tuple, new_state))

            if new_state_tuple in closed_set and closed_set[new_state_tuple] <= current.g + 1:
                continue  # Skip worst states

            g = current.g + 1
            h = enhanced_heuristic(new_state, goal_state)
            calls += 1
            new_node = Node(state=new_state_tuple, g=g, h=h, parent=current)

            # Avoid to add a node with the same state and a higher cost
            if any(node.state == new_state_tuple and node.g <= g for _, node in open_list):
                continue

            heapq.heappush(open_list, (g + h, new_node))

    return None


In [119]:
random_result = a_star(state, goal_state)
if random_result:
    print(f"\nPath found! Length: {len(random_result) - 1}")
    for step in random_result:
        print("\n", step)
else:
    print("\nNo solution found.")


Number of calls to heuristic function: 352

Path found! Length: 26

 [[8 2 3]
 [4 5 7]
 [6 1 0]]

 [[8 2 3]
 [4 5 0]
 [6 1 7]]

 [[8 2 3]
 [4 0 5]
 [6 1 7]]

 [[8 2 3]
 [4 1 5]
 [6 0 7]]

 [[8 2 3]
 [4 1 5]
 [0 6 7]]

 [[8 2 3]
 [0 1 5]
 [4 6 7]]

 [[8 2 3]
 [1 0 5]
 [4 6 7]]

 [[8 2 3]
 [1 5 0]
 [4 6 7]]

 [[8 2 0]
 [1 5 3]
 [4 6 7]]

 [[8 0 2]
 [1 5 3]
 [4 6 7]]

 [[0 8 2]
 [1 5 3]
 [4 6 7]]

 [[1 8 2]
 [0 5 3]
 [4 6 7]]

 [[1 8 2]
 [5 0 3]
 [4 6 7]]

 [[1 0 2]
 [5 8 3]
 [4 6 7]]

 [[1 2 0]
 [5 8 3]
 [4 6 7]]

 [[1 2 3]
 [5 8 0]
 [4 6 7]]

 [[1 2 3]
 [5 8 7]
 [4 6 0]]

 [[1 2 3]
 [5 8 7]
 [4 0 6]]

 [[1 2 3]
 [5 0 7]
 [4 8 6]]

 [[1 2 3]
 [5 7 0]
 [4 8 6]]

 [[1 2 3]
 [5 7 6]
 [4 8 0]]

 [[1 2 3]
 [5 7 6]
 [4 0 8]]

 [[1 2 3]
 [5 0 6]
 [4 7 8]]

 [[1 2 3]
 [0 5 6]
 [4 7 8]]

 [[1 2 3]
 [4 5 6]
 [0 7 8]]

 [[1 2 3]
 [4 5 6]
 [7 0 8]]

 [[1 2 3]
 [4 5 6]
 [7 8 0]]
