# Assignment #1: Implement 8-puzzle solver (using A* search)

### Why `Graph Search`, not `Tree Search`?

The 8-puzzle state space is a graph, not a tree, because the same board state can be reached through multiple different sequences of moves. In practice, after some number N of moves we can arrive back at the initial state, or at any other state we have already been in, which means we have traversed a cycle (*which is impossible in a tree by definition*).
- **tree search** has no memory of visited states, so it can revisit the same board
  indefinitely, looping forever and/or wasting time on paths it has already explored.
- **graph search** maintains a `visited` set (closed set), so each board state is expanded at most once.
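
The cycles described above are easy to reproduce. The following is a minimal sketch, assuming boards are stored as flat tuples with `0` as the blank; `slide` and `applyMoves` are hypothetical helpers for illustration, not part of the notebook's classes, and directions name the movement of the blank.

```python
# Offsets for moving the blank; directions refer to the blank, not to a tile.
OFFSETS = {"UP": (-1, 0), "DOWN": (1, 0), "LEFT": (0, -1), "RIGHT": (0, 1)}

def slide(state, direction, size=3):
    """Return the state after moving the blank one step (hypothetical helper)."""
    blank = state.index(0)
    row, col = divmod(blank, size)
    dRow, dCol = OFFSETS[direction]
    newRow, newCol = row + dRow, col + dCol
    if not (0 <= newRow < size and 0 <= newCol < size):
        raise ValueError(f"blank cannot move {direction} from ({row}, {col})")
    target = newRow * size + newCol
    cells = list(state)
    cells[blank], cells[target] = cells[target], cells[blank]
    return tuple(cells)

def applyMoves(state, moves):
    """Apply a sequence of blank moves and return the final state."""
    for move in moves:
        state = slide(state, move)
    return state

start = (1, 2, 3,
         4, 0, 5,
         7, 8, 6)

# shortest possible cycle: a move followed by its inverse
print(applyMoves(start, ["UP", "DOWN"]) == start)

# a longer cycle: walking the blank around a 2x2 block three times
print(applyMoves(start, ["UP", "LEFT", "DOWN", "RIGHT"] * 3) == start)
```

Both checks print `True`: the empty sequence, the 2-move sequence and the 12-move sequence all end in the same state, which is exactly the situation a `visited` set is meant to detect.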

### Heuristic Functions

#### Hamming distance
Count how many tiles are not in their goal position (blank excluded).

#### Manhattan distance
For each tile, add the row distance and the column distance to its goal position, then sum these values over all tiles (blank excluded).
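
As a worked example of the two heuristics above, here is a small standalone sketch. It assumes the board is a 2D list with `0` as the blank and the goal `1..8` in row-major order with the blank last; the free functions `hamming` and `manhattan` and the helper `goalPosition` are illustrative names, separate from the `Board` properties defined later.

```python
def goalPosition(tile, size):
    """Row and column where `tile` belongs in the goal state."""
    return divmod(tile - 1, size)

def hamming(state):
    """Number of non-blank tiles that are not on their goal cell."""
    size = len(state)
    return sum(
        1
        for row in range(size)
        for col in range(size)
        if state[row][col] != 0
        and (row, col) != goalPosition(state[row][col], size)
    )

def manhattan(state):
    """Sum over non-blank tiles of |row - goalRow| + |col - goalCol|."""
    size = len(state)
    total = 0
    for row in range(size):
        for col in range(size):
            tile = state[row][col]
            if tile != 0:
                goalRow, goalCol = goalPosition(tile, size)
                total += abs(row - goalRow) + abs(col - goalCol)
    return total

board = [[1, 2, 5],
         [3, 4, 0],
         [6, 7, 8]]

# tiles 3, 4, 5, 6, 7, 8 are misplaced -> hamming = 6
# their distances are 3, 1, 2, 3, 1, 1 -> manhattan = 11
print(hamming(board), manhattan(board))  # 6 11
```

On this board Hamming only reports that six tiles are misplaced, while Manhattan also reflects how far each of them has to travel.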

## Pre-Implementation: Imports and setup

In [38]:
import heapq
import copy
import time
import random
from enum import Enum

## Pre-Implementation: Utils and Log helpers

All output goes through `log()` so indentation and spacing are consistent everywhere.
The `after=True` flag prints an empty line after the log output, so we don't have to call `print()` after every `log()` call.

In [39]:
class HeuristicType(Enum):
    HAMMING   = 'hamming'
    MANHATTAN = 'manhattan'

LOG_INDENTATION = "  "

def log(*args, indent: int = 1, after: bool = False, **kwargs):
    print(LOG_INDENTATION * indent + " ".join(str(a) for a in args), **kwargs)
    if after:
        print()

def logSection(title: str, after: bool = False):
    log("=" * 50, indent=0)
    log(title, indent=1)
    log("=" * 50, indent=0, after=after)

## Implementation: `Board` Class

The `Board` class captures the state of the board at a given step.

#### Handles all board related logic: 
- heuristics
- neighbor generation
- solvability checking
- board state display
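
Of the responsibilities listed above, solvability checking is the least obvious, so here is a standalone sketch of the inversion-count rule. It assumes a 2D list with `0` as the blank and a goal with the blank in the bottom-right corner; `countInversions` and `isSolvable` are free functions written for illustration.

```python
def countInversions(state):
    """Count pairs of tiles (blank excluded) that appear in the wrong order."""
    cells = [tile for row in state for tile in row if tile != 0]
    return sum(
        1
        for i in range(len(cells))
        for j in range(i + 1, len(cells))
        if cells[i] > cells[j]
    )

def isSolvable(state):
    size = len(state)
    inversions = countInversions(state)
    if size % 2 == 1:
        # odd width: a move never changes the inversion parity
        return inversions % 2 == 0
    # even width: a vertical move changes the inversion parity, so the
    # blank's row (counted from the bottom, starting at 1) must be included
    blankRow = next(r for r, row in enumerate(state) if 0 in row)
    blankRowFromBottom = size - blankRow
    return (inversions + blankRowFromBottom) % 2 == 1

print(countInversions([[1, 0, 3], [4, 2, 5], [7, 8, 6]]))  # 4 -> solvable
print(countInversions([[1, 2, 3], [4, 5, 6], [8, 7, 0]]))  # 1 -> unsolvable
```

The second board differs from the goal only by the swapped pair 8 and 7, which is a single inversion and therefore unreachable by legal moves.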

In [40]:
class Board:
    def __init__(self, state: list[list[int]], size: int = 3):
        self.state = state
        self.size = size
        self.blankCellPosition: tuple[int, int] = self._findBlankCellPosition()

    @property
    def manhattan(self) -> int:
        '''
        manhattan distance:
        - for each tile, the number of steps (horizontal + vertical)
        it needs to travel to reach its goal position
        - the resulting h score is the sum of those distances over all tiles

        compared to hamming, manhattan carries more information per tile
        (how far it is from its goal, not just whether it is misplaced),
        so a* search usually needs to visit fewer nodes to find the best path
        '''
        distance = 0
        for row in range(self.size):
            for column in range(self.size):
                cell = self.state[row][column]
                if cell != 0:
                    # cell value (number) gets mapped to a 0-based index (thus cell-1)
                    # and divmod gives the goal row and column
                    # example: cell=8, size=3; index=(8-1)=7
                    # -> (7//3, 7%3) -> (row: 2, column: 1) <-- goal position for cell 8
                    goalRow, goalColumn = divmod(cell-1, self.size)
                    distance += abs(row - goalRow) + abs(column - goalColumn)
        return distance

    @property
    def hamming(self) -> int:
        '''
        hamming distance - number of positions where the values differ.
        so the algorithm for case of board cells/tiles (where board - 2D list of integers) is:
        - for position [row, column] compare current state cell with goal state cell at given position
        - - board[row][column] != goal[row][column]
        - if mismatch -> increase counter
        - the resulting counter will be the h score for our search algorithm
        '''
        count = 0
        for row in range(self.size):
            for column in range(self.size):
                cell = self.state[row][column]
                if cell != 0:  # skip blank
                    goalRow, goalCol = divmod(cell - 1, self.size)
                    if row != goalRow or column != goalCol:
                        count += 1
        return count

    @property
    def neighbors(self) -> list:
        '''
        return list of all valid (Board, actionLog) pairs reachable in one move

        #NOTE why deepcopy:
        - python lists are references - without copying, swapping tiles on
        newState would also mutate self.state (same object in memory)
        - deepcopy creates a fully independent 2D list so the original is safe

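        #NOTE on actionLog:
        - it does not play any role in the implementation of the solution
        - included purely for logging (so we can verify the algorithm is working properly)
        - the direction in the log is the direction the blank moves;
        the named tile slides the opposite way into the blank's old cell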
        '''
        blankCellRow, blankCellColumn = self.blankCellPosition
        results = []
        directions = {
            'UP':    (blankCellRow - 1, blankCellColumn),
            'DOWN':  (blankCellRow + 1, blankCellColumn),
            'LEFT':  (blankCellRow, blankCellColumn - 1),
            'RIGHT': (blankCellRow, blankCellColumn + 1),
        }
        for direction, (newRow, newColumn) in directions.items():
            if 0 <= newRow < self.size and 0 <= newColumn < self.size:
                newState = copy.deepcopy(self.state)
                # swap blank with adjacent tile
                newState[blankCellRow][blankCellColumn], newState[newRow][newColumn] = \
                    newState[newRow][newColumn], newState[blankCellRow][blankCellColumn]
                movedCell = self.state[newRow][newColumn]
                logAction = f"move tile {movedCell} {direction}"
                results.append((Board(newState, self.size), logAction))
        return results

    @property
    def isSolved(self) -> bool:
        '''
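        goal state: cells read row by row are 1 .. size^2 - 1 in ascending order,
        with 0 (free cell) at the end
        example: 3x3 solved = [1, 2, 3, 4, 5, 6, 7, 8, 0]

        #NOTE: cell values are compared instead of the stringified board, because
        concatenated digits are ambiguous once tiles have two digits (size >= 4)
        '''
        N = self.size * self.size
        cells = [cell for row in self.state for cell in row]
        return cells == list(range(1, N)) + [0]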

    @property
    def isSolvable(self) -> bool:
        '''
        solvability check based on the inversion count

        - for odd board sizes (size % 2 == 1):
        - - even number of inversions -> solvable
        - - odd number of inversions -> unsolvable

        - for even board sizes (size % 2 == 0):
        - - the row of the blank (counted from the bottom, starting at 1) also matters:
        inversions + blankRowFromBottom must be odd for the board to be solvable
        - - inversions EVEN and blank on ODD row from bottom -> solvable
        - - inversions ODD and blank on EVEN row from bottom -> solvable
        - - else (both even or both odd) -> unsolvable

        #NOTE: the sum() block is a compact double loop:
          for every pair (i,j) where j > i, count +1 if cellList[i] > cellList[j]
        '''
        cellList = [tile for row in self.state for tile in row if tile != 0]
        inversions = sum(
            1
            for i in range(len(cellList))
            for j in range(i + 1, len(cellList))
            if cellList[i] > cellList[j]
        )
        if self.size % 2 == 1:
            return inversions % 2 == 0
        else:
            blankRowFromBottom = self.size - self.blankCellPosition[0]
            return (inversions + blankRowFromBottom) % 2 == 1

    def _findBlankCellPosition(self) -> tuple[int, int] | None:
        for row in range(self.size):
            for column in range(self.size):
                if self.state[row][column] == 0:
                    return (row, column)

    def print(self, label: str = "", indent: int = 1):
        pad = LOG_INDENTATION * indent
        cellWidth = max(len(str(self.size * self.size)), 2)
        
        top = "┌" + ("─" * (cellWidth + 2) + "┬") * (self.size - 1) + "─" * (cellWidth + 2) + "┐"
        mid = "├" + ("─" * (cellWidth + 2) + "┼") * (self.size - 1) + "─" * (cellWidth + 2) + "┤"
        bottom = "└" + ("─" * (cellWidth + 2) + "┴") * (self.size - 1) + "─" * (cellWidth + 2) + "┘"
        
        if label: 
            print(pad + f"[{label}]")
        print(pad + top)
        for i, row in enumerate(self.state):
            cells = "│".join(
                f" {str(cell).center(cellWidth)} " if cell != 0
                else f" {'*'.center(cellWidth)} "
                for cell in row
            )
            print(pad + "│" + cells + "│")
            print(pad + (mid if i < self.size - 1 else bottom))

    def __str__(self):
        return ''.join(str(cell) for row in self.state for cell in row)
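
The `neighbors` property relies on `copy.deepcopy` to keep the parent state intact. One possible alternative, shown here only as a sketch and not as what the notebook does, is to store the board as an immutable flat tuple, so a new state has to be built explicitly and the parent cannot be mutated by accident. The generator `neighborStates` is a hypothetical name.

```python
def neighborStates(state, size=3):
    """Yield (newState, movedTile) for every legal move of the blank.

    `state` is a flat tuple in row-major order with 0 as the blank.
    """
    blank = state.index(0)
    row, col = divmod(blank, size)
    for dRow, dCol in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        newRow, newCol = row + dRow, col + dCol
        if 0 <= newRow < size and 0 <= newCol < size:
            target = newRow * size + newCol
            cells = list(state)            # shallow copy is enough for a flat list
            cells[blank], cells[target] = cells[target], cells[blank]
            yield tuple(cells), state[target]

center = (1, 2, 3, 4, 0, 5, 7, 8, 6)
corner = (1, 2, 3, 4, 5, 6, 7, 8, 0)
edge = (1, 0, 3, 4, 2, 5, 7, 8, 6)

# the number of neighbors depends on where the blank is
print(len(list(neighborStates(center))),
      len(list(neighborStates(edge))),
      len(list(neighborStates(corner))))  # 4 3 2
```

Tuples are also hashable, so the same value can be used directly as a key in a visited set.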

## Implementation: `Node` Class

The `Node` class is the main entity in the implementation of the search algorithm: the min-heap stores `Node` instances, and the visited set records the board states of nodes that were already expanded. Each `Node` object stores its own snapshot of a `Board`, which is how the game state is embedded into the logic of the A* search algorithm.

Each node represents one state in the search tree and stores:
- `g` - how many moves were made to get here
- `h` - the heuristic estimate from here to goal
- `f = g + h` - total priority score that will be used in minheap (lower = popped first / explored first)
- `parent` - the node we came from (used to reconstruct the solution path)
- `actionLog` - what move created this node (for display / log purposes only)
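
To see how `f` drives the order in which nodes are explored, here is a minimal sketch with a stripped-down stand-in class. `MiniNode` is hypothetical and carries only `g`, `h` and `f`; like the real `Node`, it defines `__lt__` on `f`, which is all `heapq` needs to order its entries.

```python
import heapq

class MiniNode:
    """Hypothetical stand-in for Node: just a name and the three scores."""
    def __init__(self, name, g, h):
        self.name = name
        self.g = g
        self.h = h
        self.f = g + h

    def __lt__(self, other):
        # heapq compares entries with <, so this defines the priority
        return self.f < other.f

heap = []
for node in (MiniNode("a", g=2, h=5), MiniNode("b", g=1, h=3), MiniNode("c", g=4, h=2)):
    heapq.heappush(heap, node)

# f values are a=7, b=4, c=6, so the pop order is b, c, a
order = [heapq.heappop(heap).name for _ in range(3)]
print(order)  # ['b', 'c', 'a']
```

When two nodes share the same `f`, this comparison leaves their relative order to the heap's internal layout; breaking ties on `h` is a common refinement, but it is not something the notebook does.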

In [41]:
class Node:
    def __init__(
        self,
        board,
        parent=None,
        actionLog=None,
        heuristic: HeuristicType = HeuristicType.MANHATTAN
    ):
        self.board = board
        self.parent: Node = parent
        self.actionLog: str = actionLog
        self.g = 0 if parent is None else parent.g + 1
        self.h = board.manhattan if heuristic == HeuristicType.MANHATTAN else board.hamming
        self.f = self.g + self.h

    @property
    def path(self) -> list:
        '''
        reconstruct path from start to this node:
        - walk backwards through parent pointers
        - reverse the collected list
        '''
        node, pathList = self, []
        while node:
            pathList.append(node)
            node = node.parent
        return list(reversed(pathList))

    @property
    def isSolved(self) -> bool:
        return self.board.isSolved

    def __lt__(self, other):
        '''
        less-than (<) operator to compare nodes
        (for choosing cheapest path in a* search / storing in min-heap)
        '''
        return self.f < other.f

## Implementation: `SolutionResult` Class

A simple container returned by `solve()`. It keeps the return type clean.

#### Additionally: `display()` handles all formatted output:
- solution steps can be shown or hidden (shown by default: `showSteps: bool = True`)

In [42]:
class SolutionResult:
    def __init__(self,
        solvable: bool = False,
        visitedNodeCount: int = 0,
        timeMs: float = 0,
        pathList: list = None,
    ):
        self.solvable = solvable
        self.solutionPath: list = pathList or []
        self.visitedNodeCount = visitedNodeCount
        self.timeMs = timeMs

    def display(self, showSteps: bool = True):
        if not self.solvable:
            log(f"[UNSOLVABLE] - impossible to reach the goal", after=True)
            return

        moves = len(self.solutionPath) - 1
        log(f"[SOLVED] in {moves} moves | {self.visitedNodeCount} nodes visited | {self.timeMs:.3f}ms", after=True)

        if not showSteps: return

        for i, node in enumerate(self.solutionPath):
            if i == 0:
                label = "initial state"
            elif i == moves:
                label = f"GOAL STATE: step {i} - {node.actionLog}"
            else:
                label = f"step {i} - {node.actionLog}"
            node.board.print(label)
            log(f"g={node.g}  h={node.h}  f={node.f}", indent=3, after=True)

## Implementation: A* Search algorithm

- check solvability; skip the search if the board is unsolvable
- push the start node onto the min-heap (priority queue ordered by f)
- main loop:
- - pop the node with the lowest f = g + h
- - if its board is already in visited -> skip (the same state was already expanded via a path that is at least as cheap)
- - add it to visited
- - if it is the goal -> reconstruct and return the path (trace parent pointers back) -> done
- - generate neighbors, compute their f, push them onto the heap
- if the heap empties with no goal found -> unsolvable
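
The steps above can be condensed into a compact standalone sketch. It is a simplification under stated assumptions rather than the notebook's implementation: states are flat tuples, the board is fixed at 3x3, the solvability check is omitted (so the start is assumed to be solvable), and the names `astar`, `manhattanFlat`, `neighborTuples` and `GOAL` are made up for this example.

```python
import heapq
from itertools import count

GOAL = (1, 2, 3, 4, 5, 6, 7, 8, 0)

def manhattanFlat(state, size=3):
    total = 0
    for index, tile in enumerate(state):
        if tile != 0:
            row, col = divmod(index, size)
            goalRow, goalCol = divmod(tile - 1, size)
            total += abs(row - goalRow) + abs(col - goalCol)
    return total

def neighborTuples(state, size=3):
    blank = state.index(0)
    row, col = divmod(blank, size)
    for dRow, dCol in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        newRow, newCol = row + dRow, col + dCol
        if 0 <= newRow < size and 0 <= newCol < size:
            target = newRow * size + newCol
            cells = list(state)
            cells[blank], cells[target] = cells[target], cells[blank]
            yield tuple(cells)

def astar(start):
    """Return the list of states from start to GOAL, or None if the heap empties."""
    tieBreaker = count()  # unique counter so tuples never compare past this field
    heap = [(manhattanFlat(start), next(tieBreaker), 0, start, None)]
    parents = {}          # closed set: expanded state -> parent state
    while heap:
        _, _, g, state, parent = heapq.heappop(heap)
        if state in parents:
            continue      # already expanded via a path at least as cheap
        parents[state] = parent
        if state == GOAL:
            path = []
            while state is not None:   # walk parent pointers back to the start
                path.append(state)
                state = parents[state]
            return path[::-1]
        for nextState in neighborTuples(state):
            if nextState not in parents:
                f = g + 1 + manhattanFlat(nextState)
                heapq.heappush(heap, (f, next(tieBreaker), g + 1, nextState, state))
    return None

path = astar((1, 0, 3, 4, 2, 5, 7, 8, 6))
print(len(path) - 1)  # 3
```

The heap entries are plain tuples here instead of `Node` objects; the counter in the second position plays the role that `__lt__` plays in the notebook, keeping entries comparable when `f` values tie.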

In [43]:
def solve(board: Board, heuristic: HeuristicType = HeuristicType.MANHATTAN) -> SolutionResult:
    """
    main method to solve the 8 puzzle using a* graph search
    """
    def _now() -> float:
        return time.perf_counter()

    def _ms(t: float, precision: int = 3) -> float:
        return round(t * 1000, precision)

    if not board.isSolvable:
        return SolutionResult()

    startNode = Node(board, heuristic=heuristic)
    nodeMinHeap: list[Node] = []
    heapq.heappush(nodeMinHeap, startNode)

    # graph search: visited set prevents re-processing the same board state
    visited = set()
    visitedCount = 0
    t0 = _now()

    while nodeMinHeap:
        current: Node = heapq.heappop(nodeMinHeap)

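        # key the visited set on the cell values themselves: str(board) concatenates
        # digits without a separator, which is ambiguous once tiles reach two digits
        # (size >= 4), e.g. cells 1, 0, 10 and cells 10, 1, 0 both give "1010"
        currentKey = tuple(map(tuple, current.board.state))
        if currentKey in visited:
            continue
        visited.add(currentKey)
        visitedCount += 1

        if current.isSolved:
            return SolutionResult(
                solvable=True,
                visitedNodeCount=visitedCount,
                timeMs=_ms(_now() - t0),
                pathList=current.path
            )

        for possibleState, actionLog in current.board.neighbors:
            if tuple(map(tuple, possibleState.state)) not in visited:
                child = Node(possibleState, parent=current,
                             actionLog=actionLog, heuristic=heuristic)
                heapq.heappush(nodeMinHeap, child)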

    return SolutionResult(visitedNodeCount=visitedCount, timeMs=_ms(_now() - t0))
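
A visited-set key has to map every distinct board to a distinct value. Joining cell values without a separator, as `Board.__str__` does, is unambiguous on a 3x3 board where every cell is a single digit, but not on 4x4 and larger boards. The sketch below demonstrates the collision; `concatKey` and `tupleKey` are illustrative helper names.

```python
def concatKey(state):
    """Same scheme as Board.__str__: digits joined without a separator."""
    return ''.join(str(cell) for row in state for cell in row)

def tupleKey(state):
    """Flat tuple of cell values: hashable and unambiguous."""
    return tuple(cell for row in state for cell in row)

# two different 4x4 boards: only the first three cells differ
boardA = [[1, 0, 10, 2], [3, 4, 5, 6], [7, 8, 9, 11], [12, 13, 14, 15]]
boardB = [[10, 1, 0, 2], [3, 4, 5, 6], [7, 8, 9, 11], [12, 13, 14, 15]]

print(boardA != boardB)                          # True: the boards differ
print(concatKey(boardA) == concatKey(boardB))    # True: "1"+"0"+"10" == "10"+"1"+"0"
print(tupleKey(boardA) != tupleKey(boardB))      # True: tuple keys stay distinct
```

With colliding keys a graph search can in principle skip a state it has never expanded, so a tuple-based key (or a string with a separator) is the safer choice for the larger test boards.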

## Testing: `Tester` Class

Handles test case generation and execution.
- **`randomSolvableBoard`** - starts from the goal state and applies random legal moves.
This guarantees solvability: legal moves can never produce an unsolvable state.
- **`randomUnsolvableBoard`** - takes a solvable board and swaps two non-blank tiles.
Swapping any two tiles flips the parity of the inversion count while the blank stays where it is, so a solvable board always becomes unsolvable.
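
The parity claim can be checked empirically. The sketch below works on the flat tile sequence with the blank already removed, which is the sequence the inversion count is computed on; `inversionCount` and `swapFlipsParity` are hypothetical helpers for this check.

```python
import random

def inversionCount(cells):
    """Number of pairs (i, j) with i < j and cells[i] > cells[j]."""
    return sum(
        1
        for i in range(len(cells))
        for j in range(i + 1, len(cells))
        if cells[i] > cells[j]
    )

def swapFlipsParity(cells, i, j):
    """True if swapping positions i and j changes the inversion parity."""
    swapped = list(cells)
    swapped[i], swapped[j] = swapped[j], swapped[i]
    return inversionCount(cells) % 2 != inversionCount(swapped) % 2

tiles = list(range(1, 9))   # the eight tiles of a 3x3 board, blank excluded
random.shuffle(tiles)

# every possible pair swap flips the parity, not just adjacent ones
allFlip = all(
    swapFlipsParity(tiles, i, j)
    for i in range(len(tiles))
    for j in range(i + 1, len(tiles))
)
print(allFlip)  # True
```

This holds for any arrangement and any pair of positions, which is why it does not matter which two non-blank tiles are swapped.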

In [44]:
class Tester:
    @staticmethod
    def randomSolvableBoard(size: int = 3, randMoveCount: int = 200) -> Board:
        """
        generate a random solvable board:
        - starting from the goal state
        - making random valid moves
        this way we can guarantee solvability - we can never shuffle 
        into an unsolvable state by making valid/legal moves
        """
        state = [[(row * size) + col + 1 for col in range(size)]
                 for row in range(size)]
        state[size - 1][size - 1] = 0
        board = Board(state, size)
        for _ in range(randMoveCount):
            board, _ = random.choice(board.neighbors)
        return board

    @staticmethod
    def randomUnsolvableBoard(size: int = 3) -> Board:
        """
        generate a guaranteed unsolvable board by:
        - generating a solvable board
        - swapping two non-blank tiles, which flips the inversion parity -> unsolvable
        - - the blank does not move, so the parity checked by isSolvable
        is now the opposite of the solvable board's parity
        """
        board = Tester.randomSolvableBoard(size)
        state = copy.deepcopy(board.state)
        nonBlanks = [
            (row, col)
            for row in range(size)
            for col in range(size)
            if state[row][col] != 0
        ]
        (r1, c1), (r2, c2) = nonBlanks[0], nonBlanks[1]
        state[r1][c1], state[r2][c2] = state[r2][c2], state[r1][c1]
        return Board(state, size)

    @staticmethod
    def runCase(
        label: str,
        board: Board,
        heuristic: HeuristicType = HeuristicType.MANHATTAN,
        showSteps: bool = True
    ):
        logSection(f"{label}  [{heuristic.value}]")
        board.print("initial board", indent=1)
        solve(board, heuristic).display(showSteps=showSteps)

    @staticmethod
    def runAll():
        logSection("[TEST] random 3x3 test cases", after=True)
        for i in range(3):
            Tester.runCase(f"random solvable #{i+1}",   Tester.randomSolvableBoard(size=3))
        for i in range(3):
            Tester.runCase(f"random unsolvable #{i+1}", Tester.randomUnsolvableBoard(size=3))

        logSection("bigger board tests (steps hidden for brevity)", after=True)
        for size in [4, 5]:
            for i in range(3):
                label = f"random solvable {size}x{size} #{i+1}"
                board = Tester.randomSolvableBoard(size=size, randMoveCount=30)
                Tester.runCase(label, board, showSteps=False)
                Tester.runCase(label, board, heuristic=HeuristicType.HAMMING, showSteps=False)
            for i in range(3):
                label = f"random unsolvable {size}x{size} #{i+1}"
                board = Tester.randomUnsolvableBoard(size=size)
                Tester.runCase(label, board, showSteps=False)

        logSection("[RESULTS] heuristics comparison (same 3x3 puzzle)")
        testBoard = Board([[1, 2, 5], [3, 4, 0], [6, 7, 8]])
        for h in HeuristicType:
            result = solve(testBoard, heuristic=h)
            log(f"{h.value:<12} -> {result.visitedNodeCount:>5} nodes | {result.timeMs:.3f}ms")

## Test: Fixed test cases

Runs three solvable and three unsolvable boards (the boards shown in the slides).

In [45]:
# 3 solvables
Tester.runCase("solvable #1", Board([[1, 0, 3], [4, 2, 5], [7, 8, 6]]))
Tester.runCase("solvable #2", Board([[1, 2, 5], [3, 4, 0], [6, 7, 8]]))
Tester.runCase("solvable #3", Board([[0, 1, 3], [4, 2, 5], [7, 8, 6]]))

# 3 unsolvables
Tester.runCase("unsolvable #1", Board([[1, 2, 3], [4, 5, 6], [8, 7, 0]]))
Tester.runCase("unsolvable #2", Board([[1, 2, 3], [4, 5, 6], [8, 0, 7]]))
Tester.runCase("unsolvable #3", Board([[2, 1, 3], [4, 5, 6], [7, 8, 0]]))

  solvable #1  [manhattan]
  [initial board]
  ┌────┬────┬────┐
  │ 1  │ *  │ 3  │
  ├────┼────┼────┤
  │ 4  │ 2  │ 5  │
  ├────┼────┼────┤
  │ 7  │ 8  │ 6  │
  └────┴────┴────┘
  [SOLVED] in 3 moves | 4 nodes visited | 0.428ms

  [initial state]
  ┌────┬────┬────┐
  │ 1  │ *  │ 3  │
  ├────┼────┼────┤
  │ 4  │ 2  │ 5  │
  ├────┼────┼────┤
  │ 7  │ 8  │ 6  │
  └────┴────┴────┘
      g=0  h=3  f=3

  [step 1 - move tile 2 DOWN]
  ┌────┬────┬────┐
  │ 1  │ 2  │ 3  │
  ├────┼────┼────┤
  │ 4  │ *  │ 5  │
  ├────┼────┼────┤
  │ 7  │ 8  │ 6  │
  └────┴────┴────┘
      g=1  h=2  f=3

  [step 2 - move tile 5 RIGHT]
  ┌────┬────┬────┐
  │ 1  │ 2  │ 3  │
  ├────┼────┼────┤
  │ 4  │ 5  │ *  │
  ├────┼────┼────┤
  │ 7  │ 8  │ 6  │
  └────┴────┴────┘
      g=2  h=1  f=3

  [GOAL STATE: step 3 - move tile 6 DOWN]
  ┌────┬────┬────┐
  │ 1  │ 2  │ 3  │
  ├────┼────┼────┤
  │ 4  │ 5  │ 6  │
  ├────┼────┼────┤
  │ 7  │ 8  │ *  │
  └────┴────┴────┘
      g=3  h=0  f=3

  solvable #2  [manhattan]
  [init

## Test: Random Cases + Heuristic Comparison

Runs random cases and compares the performance of the two heuristics.

In [46]:
Tester.runAll()

  [TEST] random 3x3 test cases

  random solvable #1  [manhattan]
  [initial board]
  ┌────┬────┬────┐
  │ 7  │ 3  │ 4  │
  ├────┼────┼────┤
  │ 5  │ *  │ 6  │
  ├────┼────┼────┤
  │ 2  │ 8  │ 1  │
  └────┴────┴────┘
  [SOLVED] in 22 moves | 545 nodes visited | 34.679ms

  [initial state]
  ┌────┬────┬────┐
  │ 7  │ 3  │ 4  │
  ├────┼────┼────┤
  │ 5  │ *  │ 6  │
  ├────┼────┼────┤
  │ 2  │ 8  │ 1  │
  └────┴────┴────┘
      g=0  h=14  f=14

  [step 1 - move tile 3 UP]
  ┌────┬────┬────┐
  │ 7  │ *  │ 4  │
  ├────┼────┼────┤
  │ 5  │ 3  │ 6  │
  ├────┼────┼────┤
  │ 2  │ 8  │ 1  │
  └────┴────┴────┘
      g=1  h=15  f=16

  [step 2 - move tile 4 RIGHT]
  ┌────┬────┬────┐
  │ 7  │ 4  │ *  │
  ├────┼────┼────┤
  │ 5  │ 3  │ 6  │
  ├────┼────┼────┤
  │ 2  │ 8  │ 1  │
  └────┴────┴────┘
      g=2  h=14  f=16

  [step 3 - move tile 6 DOWN]
  ┌────┬────┬────┐
  │ 7  │ 4  │ 6  │
  ├────┼────┼────┤
  │ 5  │ 3  │ *  │
  ├────┼────┼────┤
  │ 2  │ 8  │ 1  │
  └────┴────┴────┘
      g=3  h=15  f=1