# In [None]:  (Jupyter cell marker left over from the notebook export)
# This class models the maze environment as a grid, where a pirate searches for treasure.

import numpy as np

# Gray levels used for marking cells: visited squares (~85% gray) and the
# pirate's current position (~45% gray).
VISITED_FLAG, PIRATE_FLAG = 0.85, 0.45

# Action codes available to the agent, numbered consecutively from zero.
MOVE_LEFT, MOVE_UP, MOVE_RIGHT, MOVE_DOWN = range(4)

class PirateTreasureMaze:
    """
    Grid maze in which a pirate walks towards a treasure.

    The maze is stored as a 2D numpy array with float values:
    - 1.0 → open cell
    - 0.0 → blocked cell
    The treasure always sits in the bottom-right cell.
    Pirate starting position is given as (row, col), defaults to (0,0).

    Attributes:
        goal: (row, col) of the treasure cell.
        open_cells: list of every open (row, col) except the treasure cell.
        pirate: (row, col) the pirate was placed on by the last reset().
        grid: working copy of the maze for the current episode.
        state: (row, col, status); status is 'start', 'valid', 'invalid'
            or 'blocked'.
        min_reward: total-reward threshold below which the episode is lost.
        total_reward: sum of the rewards handed out by act() since reset().
        visited: set of (row, col) cells the pirate occupied when an action
            was applied.
    """

    def __init__(self, grid, start=(0,0)):
        """
        Build the maze and place the pirate on `start`.

        Args:
            grid: 2D array-like of 1.0 (open) / 0.0 (blocked) cells.
            start: (row, col) of the pirate's first cell.

        Raises:
            ValueError: if the grid is not a non-empty 2D array, the treasure
                cell is blocked, or `start` is not an open, non-treasure cell.
        """
        # Force a float dtype: with an integer grid, writing PIRATE_FLAG (0.45)
        # would be truncated to 0 and silently turn the pirate's cell into a wall.
        self._grid = np.array(grid, dtype=float)
        if self._grid.ndim != 2 or self._grid.size == 0:
            raise ValueError("Invalid grid: expected a non-empty 2D array.")

        rows, cols = self._grid.shape
        self.goal = (rows-1, cols-1)   # Treasure located in the bottom-right corner

        if self._grid[self.goal] == 0.0:
            raise ValueError("Invalid grid: treasure cell is blocked.")

        # All open cells except the treasure cell, in row-major order.
        self.open_cells = [
            (r, c)
            for r in range(rows)
            for c in range(cols)
            if self._grid[r, c] == 1.0 and (r, c) != self.goal
        ]

        # Accept any (row, col) pair (tuple, list, ...) as the start position.
        start = tuple(start)
        if start not in self.open_cells:
            raise ValueError("Invalid start: pirate must begin in an open cell.")

        self.reset(start)

    def reset(self, start):
        """
        Start a new episode with the pirate at `start`.

        Restores the working grid from the pristine copy and clears the
        accumulated reward and the visited set.

        Args:
            start: (row, col) where the pirate is placed. Not validated here.
        """
        row, col = start
        self.pirate = (row, col)
        self.grid = np.copy(self._grid)
        self.grid[row, col] = PIRATE_FLAG
        self.state = (row, col, 'start')
        self.min_reward = -0.5 * self.grid.size  # cap on how negative rewards can go
        self.total_reward = 0
        self.visited = set()

    def update_state(self, action):
        """
        Apply `action` to the pirate and record the outcome in `self.state`.

        The resulting status is:
        - 'blocked' when no move at all is possible from the current cell,
        - 'valid'   when the action is allowed (the pirate moves),
        - 'invalid' when the action is not allowed (the pirate stays put).

        Args:
            action: one of MOVE_LEFT, MOVE_UP, MOVE_RIGHT, MOVE_DOWN.
        """
        row, col, _ = self.state

        # The cell the pirate is leaving (or stuck on) counts as visited.
        if self.grid[row, col] > 0.0:
            self.visited.add((row, col))

        allowed = self.valid_moves()

        if not allowed:
            status = 'blocked'
        elif action in allowed:
            status = 'valid'
            if action == MOVE_LEFT:
                col -= 1
            elif action == MOVE_UP:
                row -= 1
            elif action == MOVE_RIGHT:
                col += 1
            elif action == MOVE_DOWN:
                row += 1
        else:
            status = 'invalid'  # attempt to move somewhere impossible

        self.state = (row, col, status)

    def get_reward(self):
        """
        Return the reward for the current state.

        Returns:
            1.0 on the treasure cell; `min_reward - 1` when the pirate is
            blocked; -0.75 after an invalid move; -0.25 after moving onto an
            already visited cell; -0.05 after moving onto a fresh cell; 0.0
            when no action has been applied yet.
        """
        row, col, status = self.state

        if (row, col) == self.goal:   # reached treasure
            return 1.0
        if status == 'blocked':
            return self.min_reward - 1
        # Must be tested before the visited check: an invalid move leaves the
        # pirate on its current cell, which update_state() has just marked as
        # visited, so the -0.75 penalty would otherwise never be applied.
        if status == 'invalid':
            return -0.75
        if (row, col) in self.visited:
            return -0.25
        if status == 'valid':
            return -0.05
        return 0.0   # 'start': nothing has happened yet

    def act(self, action):
        """
        Perform an action and return the environment's feedback.

        Args:
            action: one of MOVE_LEFT, MOVE_UP, MOVE_RIGHT, MOVE_DOWN.

        Returns:
            (view, reward, outcome): the flattened observation, the reward for
            this step, and 'win', 'lose' or 'running'.
        """
        self.update_state(action)
        reward = self.get_reward()
        self.total_reward += reward
        outcome = self.check_status()
        view = self.observe()
        return view, reward, outcome

    def observe(self):
        """Return the rendered maze flattened to shape (1, rows * cols)."""
        return self.render().reshape((1, -1))

    def render(self):
        """Return visual grid: open = white, blocked = black, pirate = gray."""
        canvas = np.copy(self.grid)

        # Clear every earlier mark: any non-blocked cell is drawn as open.
        canvas[canvas > 0.0] = 1.0

        # draw pirate
        row, col, _ = self.state
        canvas[row, col] = PIRATE_FLAG
        return canvas

    def check_status(self):
        """
        Check if the game is ongoing, won, or lost.

        Returns:
            'lose' when the total reward dropped below `min_reward` (checked
            first), 'win' when the pirate stands on the treasure, otherwise
            'running'.
        """
        if self.total_reward < self.min_reward:
            return 'lose'
        row, col, _ = self.state
        if (row, col) == self.goal:
            return 'win'
        return 'running'

    def valid_moves(self, position=None):
        """
        Return the list of possible moves from a position.

        A move is possible when the target cell is inside the grid and not
        blocked (0.0).

        Args:
            position: (row, col) to inspect; defaults to the pirate's
                current cell.

        Returns:
            The allowed actions, ordered LEFT, UP, RIGHT, DOWN.
        """
        if position is None:
            row, col, _ = self.state
        else:
            row, col = position

        rows, cols = self.grid.shape
        grid = self.grid

        actions = []
        if col > 0 and grid[row, col-1] != 0.0:
            actions.append(MOVE_LEFT)
        if row > 0 and grid[row-1, col] != 0.0:
            actions.append(MOVE_UP)
        if col < cols-1 and grid[row, col+1] != 0.0:
            actions.append(MOVE_RIGHT)
        if row < rows-1 and grid[row+1, col] != 0.0:
            actions.append(MOVE_DOWN)
        return actions
