# Imports

In [1]:
from collections import deque
from copy import deepcopy
from graphviz import Digraph
import heapq
import random

# Class Puzzle

In [2]:
class Puzzle:
    """Search node for a sliding-tile puzzle whose blank is the value ``0``.

    Args:
        state: Grid given as a list of row lists; ``0`` marks the blank.
        action: Move label ('L', 'R', 'U', 'D') that produced this node, or
            None for the root.
        parent: Node this one was generated from, or None for the root.
        g: Path cost from the root.
        h: Heuristic estimate to a goal.

    Nodes compare equal and hash by grid contents, and order by
    ``f = g + h`` so they can be stored directly in a ``heapq`` heap.
    """

    def __init__(self, state, action=None, parent=None, g=0, h=0):
        self.state = state
        self.id = str(self.state)  # string key used for sets and graph node names
        self.action = action
        self.parent = parent
        self.g = g
        self.h = h
        self.f = g + h

    def __eq__(self, other):
        # Return NotImplemented instead of reading ``other.state``, so that
        # comparing with None or another type evaluates to False rather
        # than raising AttributeError.
        if not isinstance(other, Puzzle):
            return NotImplemented
        return self.state == other.state

    def __hash__(self):
        return hash(str(self.state))

    def __lt__(self, other):
        return self.f < other.f

    def __str__(self):
        return "\n".join("".join(str(e) for e in r) for r in self.state)

    @staticmethod
    def get_pos(state, val):
        """Return ``(row, col)`` of the first cell equal to ``val`` in
        row-major order, or None if ``val`` is absent. Works for any grid size."""
        for i, row in enumerate(state):
            for j, cell in enumerate(row):
                if cell == val:
                    return i, j
        return None

    @staticmethod
    def check_neighbor(state, a, b):
        """Return True if tiles ``a`` and ``b`` are orthogonally adjacent."""
        pos_a, pos_b = Puzzle.get_pos(state, a), Puzzle.get_pos(state, b)
        if pos_a is None or pos_b is None:
            return False
        # Orthogonal adjacency is exactly Manhattan distance 1.
        return abs(pos_a[0] - pos_b[0]) + abs(pos_a[1] - pos_b[1]) == 1

    @staticmethod
    def swap(state, a, b):
        """Swap tiles ``a`` and ``b`` in ``state`` in place (both must exist)."""
        a_i, a_j = Puzzle.get_pos(state, a)
        b_i, b_j = Puzzle.get_pos(state, b)
        state[a_i][a_j], state[b_i][b_j] = state[b_i][b_j], state[a_i][a_j]

    def get_dest_pos(self, action, pi, pj):
        """Return the cell the blank at ``(pi, pj)`` moves into for ``action``.

        'L' -> column + 1, 'R' -> column - 1, 'U' -> row + 1, 'D' -> row - 1.
        Unknown actions leave the position unchanged.
        """
        return {
            'L': (pi, pj + 1),
            'R': (pi, pj - 1),
            'U': (pi + 1, pj),
            'D': (pi - 1, pj),
        }.get(action, (pi, pj))

    def get_successor(self, action, state):
        """Apply ``action`` to ``state`` IN PLACE.

        Returns the mutated ``state``, or None (leaving ``state`` untouched)
        if the blank would leave the grid.
        """
        pi, pj = Puzzle.get_pos(state, 0)
        ni, nj = self.get_dest_pos(action, pi, pj)
        # Bounds come from the grid itself; the row is checked before indexing it.
        if 0 <= ni < len(state) and 0 <= nj < len(state[ni]):
            state[pi][pj], state[ni][nj] = state[ni][nj], 0
            return state
        return None

    def get_successors(self):
        """Return child nodes for every legal blank move, in order L, R, U, D.

        House rule: if a move makes tiles 1 and 3 (or 2 and 4) adjacent
        when they were not adjacent in the parent, those two tiles swap.
        Children carry ``action`` and ``parent`` only; the caller sets g/h/f.
        """
        was_13 = Puzzle.check_neighbor(self.state, 1, 3)
        was_24 = Puzzle.check_neighbor(self.state, 2, 4)
        successors = []

        for act in ['L', 'R', 'U', 'D']:
            # Deep copy so the parent's grid is never mutated.
            new_state = self.get_successor(act, deepcopy(self.state))
            if new_state is None:
                continue
            if Puzzle.check_neighbor(new_state, 1, 3) and not was_13:
                Puzzle.swap(new_state, 1, 3)
            if Puzzle.check_neighbor(new_state, 2, 4) and not was_24:
                Puzzle.swap(new_state, 2, 4)
            successors.append(Puzzle(new_state, act, self))

        return successors

    def get_id(self):
        return self.id

    def get_action(self):
        return self.action

    def get_solution_path(self):
        """Return the list of actions leading from the root to this node."""
        path, node = [], self
        while node.parent is not None:
            path.append(node.action)
            node = node.parent
        return path[::-1]

    def draw(self, dot):
        """Add this node, rendered as an HTML table, to ``dot``.

        If there is a parent, also add an edge from the parent labelled with
        the action. The blank is shown as a space. Works for any grid size;
        for 3x3 the label text is identical to the previous fixed template.
        """
        indent = " " * 12

        def tile(x):
            return " " if x == 0 else str(x)

        rows = "\n".join(
            indent + "<TR>" + "".join(f"<TD>{tile(x)}</TD>" for x in row) + "</TR>"
            for row in self.state
        )
        table = (
            "<\n"
            + indent + '<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">\n'
            + rows + "\n"
            + indent + "</TABLE>>"
        )
        dot.node(self.get_id(), table, shape="plaintext")
        if self.parent:
            dot.edge(self.parent.get_id(), self.get_id(), label=self.get_action())

In [3]:
class PuzzleAgent:
    """A* search driver over ``Puzzle`` nodes."""

    @classmethod
    def solve(cls, initial_state, goal_states, heuristic_func, graph_depth=20):
        """Run A* from ``initial_state`` until a state in ``goal_states`` is popped.

        Args:
            initial_state: Start grid (list of row lists, ``0`` = blank).
            goal_states: Non-empty list of acceptable goal grids.
            heuristic_func: Callable ``(state, goal) -> number``. The minimum
                over all goals is used as ``h``.
            graph_depth: Expanded nodes are drawn only while ``g`` is below
                this value. The final solution path is always drawn.

        Returns:
            ``(result, dot)``. ``result`` is a dict with keys ``goal_node``,
            ``cost`` and ``actions``, or None when the reachable space is
            exhausted without hitting a goal. ``dot`` is the ``Digraph`` of
            drawn nodes.

        Raises:
            ValueError: if ``goal_states`` is empty.
        """
        if not goal_states:
            raise ValueError("goal_states must contain at least one goal")

        def estimate(state):
            return min(heuristic_func(state, goal) for goal in goal_states)

        dot = Digraph()
        explored = set()
        drawn = set()

        start = Puzzle(initial_state, g=0, h=estimate(initial_state))
        open_set = [start]  # a one-element list is already a valid heap
        # Cheapest g seen so far per state id. This replaces the old linear
        # scan of open_set, and lets a cheaper path to an already-queued
        # state be pushed. Superseded heap entries are skipped when popped.
        best_g = {start.get_id(): 0}

        while open_set:
            curr = heapq.heappop(open_set)
            curr_id = curr.get_id()
            if curr_id in explored:
                continue  # stale entry superseded by a cheaper one

            if curr.state in goal_states:
                # Always draw the whole solution path, regardless of graph_depth.
                node = curr
                while node is not None:
                    node_id = node.get_id()
                    if node_id not in drawn:
                        node.draw(dot)
                        drawn.add(node_id)
                    node = node.parent
                return {"goal_node": curr, "cost": curr.g, "actions": curr.get_solution_path()}, dot

            explored.add(curr_id)
            if curr.g < graph_depth and curr_id not in drawn:
                curr.draw(dot)
                drawn.add(curr_id)

            for n in curr.get_successors():
                n_id = n.get_id()
                if n_id in explored:
                    continue
                g = curr.g + 1
                if g >= best_g.get(n_id, float("inf")):
                    continue  # already queued via an equal or cheaper path
                best_g[n_id] = g
                n.g = g
                n.h = estimate(n.state)
                n.f = n.g + n.h
                heapq.heappush(open_set, n)

        return None, dot

In [4]:
def h_manhattan(state, goal):
    """Return the summed Manhattan distance of each non-blank tile to its goal cell.

    Tiles are the non-zero values of ``goal``. A tile missing from ``state``
    is skipped. Works for any grid size; for a 3x3 grid holding 0..8 this
    equals summing over tiles 1..8.
    """
    def positions(grid):
        # One pass per grid instead of a full scan per tile; keep the first
        # occurrence in row-major order.
        where = {}
        for i, row in enumerate(grid):
            for j, val in enumerate(row):
                where.setdefault(val, (i, j))
        return where

    here, there = positions(state), positions(goal)
    return sum(
        abs(here[t][0] - there[t][0]) + abs(here[t][1] - there[t][1])
        for t in there
        if t != 0 and t in here
    )

In [5]:
def h_near_goal(state, goal, n=2):
    """Return how many non-blank tiles lie MORE than ``n`` Manhattan steps from their goal cell.

    A* minimizes h, so h must shrink as the state approaches the goal.
    Counting the tiles that are *within* ``n`` did the opposite: it
    rewarded moving tiles away and gave h = 8 at the goal itself. This
    version is 0 at the goal.

    Tiles are the non-zero values of ``goal``; a tile missing from
    ``state`` is skipped. Works for any grid size.
    """
    def positions(grid):
        # Keep the first occurrence in row-major order.
        where = {}
        for i, row in enumerate(grid):
            for j, val in enumerate(row):
                where.setdefault(val, (i, j))
        return where

    here, there = positions(state), positions(goal)
    return sum(
        1
        for t in there
        if t != 0
        and t in here
        and abs(here[t][0] - there[t][0]) + abs(here[t][1] - there[t][1]) > n
    )

In [10]:
# Accepted goal configurations (0 marks the blank), written row-major and
# chunked into 3x3 grids of row lists.
goal_states = [
    [list(flat[r:r + 3]) for r in range(0, 9, 3)]
    for flat in (
        (1, 2, 3, 4, 5, 6, 7, 8, 0),
        (8, 7, 6, 5, 4, 3, 2, 1, 0),
        (1, 2, 0, 3, 4, 5, 6, 7, 8),
        (8, 7, 0, 6, 5, 4, 3, 2, 1),
    )
]

In [None]:
# Build a random 3x3 start grid (0 = blank) and search for any goal in
# goal_states. NOTE(review): a random permutation may not be able to reach
# any goal; solve then exhausts the reachable space and returns None.
tiles = list(range(9))
random.shuffle(tiles)
initial_state = [tiles[i:i + 3] for i in range(0, 9, 3)]

result, dot = PuzzleAgent.solve(initial_state, goal_states, h_near_goal)
if result:
    print("Found goal!")
    print("Initial:")
    for row in initial_state:
        print(row)
    print("Cost:", result["cost"])
    print("Path:", result["actions"])
else:
    print("No solution found.")

# Jupyter renders only the value of a cell's final top-level expression. A
# bare `dot` nested in the `if` body is evaluated and discarded, so the
# search graph never appeared. Evaluating it here renders it; None renders
# nothing.
dot if result else None
