In [None]:
import import_ipynb
from search_env_base import SearchEnv

In [None]:
class EightTilesEnv(SearchEnv):
    """
    Environment for the 8-tiles sliding puzzle (3x3).

    State is a tuple of 9 integers (0 for the blank, 1-8 for tiles), row-major order.
    Actions are moving the blank (0) up, down, left, or right by swapping with an adjacent tile.
    The cost of each move is 1.
    The goal state is (1,2,3,4,5,6,7,8,0).
    """

    # Seed used for random board generation when the caller does not supply one,
    # so that EightTilesEnv() yields a reproducible puzzle.
    DEFAULT_SEED = 42

    def __init__(self, start_state=None, seed=None):
        """
        Create the puzzle environment.

        Args:
            start_state: Initial board as a sequence of 9 ints (a permutation of
                0-8, 0 being the blank). If None, a random solvable board is
                generated.
            seed: Seed for the random board generator. It is only used when
                start_state is None, and defaults to DEFAULT_SEED. The seed
                actually used is stored in ``self.seed``.

        Raises:
            ValueError: If start_state is not a permutation of 0-8.
        """
        if start_state is None:
            import random
            # The original code tested `not seed is None` and assigned 42 to
            # self.seed, never to the generator. Apply the default to the seed
            # that is actually fed to the RNG instead.
            if seed is None:
                seed = self.DEFAULT_SEED
            self.seed = seed
            rng = random.Random(seed)
            tiles = list(range(9))
            # Only half of all permutations can reach the goal, so resample
            # until the shuffled board is solvable.
            while True:
                rng.shuffle(tiles)
                if self._is_solvable(tiles):
                    break
            start_state = tuple(tiles)
        else:
            # Normalise to a tuple so the state is hashable and compares equal
            # to the tuples produced by get_reachable_states and to goal_state.
            start_state = tuple(start_state)
            if sorted(start_state) != list(range(9)):
                raise ValueError(
                    "start_state must be a permutation of the integers 0-8, "
                    f"got {start_state!r}"
                )
        goal_state = (1, 2, 3, 4, 5, 6, 7, 8, 0)
        super().__init__(start_state, goal_state)
        self.state = start_state

    @staticmethod
    def _is_solvable(state):
        """
        Return True if ``state`` can reach the goal.

        On a board of odd width (3x3), a configuration is solvable exactly when
        the number of inversions among the non-blank tiles is even.
        """
        tiles = [t for t in state if t != 0]
        inversions = sum(
            1
            for i, a in enumerate(tiles)
            for b in tiles[i + 1:]
            if a > b
        )
        return inversions % 2 == 0

    def get_reachable_states(self, state=None):
        """
        Return the states reachable in one move by sliding the blank.

        Args:
            state: Board to expand. Defaults to the environment's current state.

        Returns:
            A list of successor states (tuples). The blank moves up, down, left,
            then right; moves that would leave the 3x3 board are skipped.
        """
        if state is None:
            state = self.state
        blank = state.index(0)
        row, col = divmod(blank, 3)
        successors = []
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):  # up, down, left, right
            r, c = row + dr, col + dc
            if 0 <= r < 3 and 0 <= c < 3:
                target = r * 3 + c
                board = list(state)
                # Swap the blank with the adjacent tile.
                board[blank], board[target] = board[target], board[blank]
                successors.append(tuple(board))
        return successors

    def step(self, action):
        """
        Move to ``action``, which is the next state and must be in get_reachable_states().

        Returns:
            (next_state, reward, done, info).
            - Invalid action: the state is unchanged, reward is -1.0, done is
              False and info is {'invalid': True}.
            - Valid action: reward is 1.0 if the goal is reached (else 0.0),
              done reports whether the goal was reached, and info['reachable']
              lists the successors of the new state.
        """
        if action not in self.get_reachable_states(self.state):
            return self.state, -1.0, False, {'invalid': True}
        self.state = action
        done = self.is_goal(self.state)
        reward = 1.0 if done else 0.0
        info = {'reachable': self.get_reachable_states(self.state)}
        return self.state, reward, done, info

    def reset(self):
        """Restore the start state and return it."""
        self.state = self.start_state
        return self.state

    def render(self, mode='human'):
        """Print the current board as three space-separated rows (``mode`` is ignored)."""
        s = self.state
        for start in range(0, 9, 3):
            print(" ".join(str(tile) for tile in s[start:start + 3]))

    def cost(self, from_state, to_state):
        """
        Return the cost of moving from from_state to to_state.

        The cost is always 1.0. Move validity is not checked here.
        """
        return 1.0

    def is_goal(self, state):
        """Return True if the given state is the goal state."""
        return state == self.goal_state