In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so edits to local source files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
if 'google.colab'in sys.modules:
  !pip install scikit-decide[all]

In [None]:
# Start from a fresh checkout: delete any previous clone, then clone the
# scikit-maze repository into ./scikit-maze (the directory that a later cell
# puts on sys.path).
!(rm -rf scikit-maze/; git clone https://github.com/galleon/scikit-maze.git)

In [None]:
# Put the cloned repository first on the import path; presumably this is what
# lets the later `from utils import Maze` resolve to the clone -- TODO confirm.
sys.path.insert(0,'./scikit-maze')

In [None]:
from enum import Enum
from typing import Any, List, NamedTuple, Optional

from skdecide import DeterministicPlanningDomain, Space, Value
from skdecide.builders.domain import UnrestrictedActions, Renderable
from skdecide.utils import rollout, match_solvers, load_registered_solver
from skdecide.hub.space.gym import ListSpace, EnumSpace, MultiDiscreteSpace
from skdecide.hub.solver.lazy_astar import LazyAstar

from utils import Maze

In [None]:
from PIL import Image

In [None]:
import io
import ipywidgets as widgets

## Define Action & State spaces

In [None]:
class State(NamedTuple):
  """Position of the agent in the maze grid."""
  # Row index: Action.up decrements it, Action.down increments it.
  x: int
  # Column index: Action.left decrements it, Action.right increments it.
  y: int

class Action(Enum):
  """The four moves available to the agent, each shifting its State by one cell."""
  up = 0  # x - 1
  down = 1  # x + 1
  left = 2  # y - 1
  right = 3  # y + 1

## Define a base domain

In [None]:
class D(DeterministicPlanningDomain, UnrestrictedActions, Renderable):
  """Base domain type for the maze.

  Combines a deterministic planning domain with unrestricted actions and
  rendering support, and binds the concrete types used by subclasses.
  """
  T_state = State  # Type of states
  T_observation = T_state  # Type of observations
  T_event = Action  # Type of events
  T_value = float  # Type of transition values (rewards or costs)
  T_predicate = bool  # Type of logical checks
  T_info = None  # Type of additional information in environment outcome

## Create the maze domain

In [None]:
from math import sqrt

class MazeDomain(D):
  """Maze navigation expressed as a deterministic planning domain.

  The agent occupies one cell of ``maze`` and moves one cell per ``Action``.
  A move into a cell that is not empty leaves the agent where it is. The only
  goal (and the only terminal state) is ``end``.
  """

  # (row delta, column delta) applied to a State for each action.
  _MOVES = {
    Action.up: (-1, 0),
    Action.down: (1, 0),
    Action.left: (0, -1),
    Action.right: (0, 1),
  }

  # Every maze cell is enlarged to a _RENDER_SCALE x _RENDER_SCALE pixel square.
  _RENDER_SCALE = 4

  def __init__(self, start: State, end: State, maze: "Maze", image_widget=None):
    """Store the problem definition.

    Args:
      start: initial position of the agent.
      end: goal position.
      maze: maze object; this class uses its ``height``, ``width``,
        ``is_an_empty_cell(x, y)`` and ``get_image(x, y)`` members.
      image_widget: optional widget kept on the instance; not used by the
        methods of this class.
    """
    self.start = start
    self.end = end
    self.maze = maze
    self.image_widget = image_widget

  def _get_next_state(self, memory: State, action: Action) -> State:
    """Return the state reached from ``memory`` with ``action``.

    The agent stays on ``memory`` when the target cell is not empty (wall).
    """
    # An unknown action maps to "no movement", as in a chain of independent ifs.
    d_row, d_col = self._MOVES.get(action, (0, 0))
    next_x, next_y = memory.x + d_row, memory.y + d_col
    if self.maze.is_an_empty_cell(next_x, next_y):
      return State(next_x, next_y)
    return memory

  def _get_transition_value(self, memory: State, action: Action, next_state: Optional[State] = None) -> \
      Value[D.T_value]:
    """Return the cost of a transition.

    A real move costs 1 (energy); bumping into a wall costs 2 (damage).
    When ``next_state`` is omitted it is derived from ``memory`` and
    ``action`` so that a bump is still detected.
    """
    if next_state is None:
      next_state = self._get_next_state(memory, action)
    bumped = next_state == memory
    return Value(cost=2 if bumped else 1)

  def _get_initial_state_(self) -> State:
    """Return the start position as the initial state."""
    return self.start

  def _get_goals_(self) -> Space[D.T_observation]:
    """Return the goal space, which contains only the end position."""
    return ListSpace([self.end])

  def _is_terminal(self, state: State) -> D.T_agent[D.T_predicate]:
    """Tell whether ``state`` ends an episode, i.e. whether it is a goal."""
    return self._is_goal(state)

  def _get_action_space_(self) -> Space[D.T_event]:
    """Return the action space: every member of ``Action``."""
    return EnumSpace(Action)

  def _get_observation_space_(self) -> Space[D.T_observation]:
    """Return the observation space: (row, column) pairs within the maze size."""
    num_rows = self.maze.height
    num_cols = self.maze.width
    return MultiDiscreteSpace([num_rows, num_cols])

  def _render_from(self, memory: State, **kwargs: Any) -> Any:
    """Return PNG-encoded bytes of the maze image for position ``memory``."""
    # Presumably get_image returns a numpy array (it is upscaled with
    # ``repeat`` along both axes) -- TODO confirm against utils.Maze.
    scale = self._RENDER_SCALE
    pixels = self.maze.get_image(memory.x, memory.y).repeat(scale, 0).repeat(scale, 1)
    buff = io.BytesIO()
    Image.fromarray(pixels).save(buff, format='png')
    return buff.getvalue()

  def heuristic(self, s: State) -> Value:
    """Return the straight-line (Euclidean) distance from ``s`` to the goal as a cost.

    Each action costs at least 1 and moves the agent by at most one cell, so
    this value never exceeds the true remaining cost.
    """
    return Value(cost=sqrt((self.end.x - s.x)**2 + (self.end.y - s.y)**2))

  def state_features(self, s: State) -> List[float]:
    """Return the coordinates of ``s`` as a feature list ``[x, y]``."""
    return [s.x, s.y]

## Define a Maze

In [None]:
# Maze dimensions, in cells.
height, width = 50, 50

maze = Maze(width, height)


def domain_factory():
  """Build a fresh MazeDomain on the shared maze, from cell (1, 1) to the far corner.

  State.x is the row index (the domain bounds it by the maze height) and
  State.y the column index (bounded by the maze width), so the goal is
  (height - 1, width - 1). For a square maze this is the same cell as before.
  """
  return MazeDomain(State(1, 1), State(height - 1, width - 1), maze)

## Render the maze

In [None]:
# Preview the maze with the agent on the start cell. The cell is read from the
# domain instead of being repeated here, so it cannot drift from the factory.
preview_domain = domain_factory()
widgets.Image(value=preview_domain._render_from(preview_domain.start))

In [None]:
# Preview the maze with the agent on the goal cell. The cell is read from the
# domain so the picture stays correct when the maze size changes.
preview_domain = domain_factory()
widgets.Image(value=preview_domain._render_from(preview_domain.end))

## Let's solve with A*

Let's try a first solver named A\*. A\* (pronounced "A-star") is a graph traversal and path search algorithm that is widely used in computer science because of its completeness, optimality, and optimal efficiency.

One major practical drawback is its $O(b^d)$ space complexity (where $b$ is the branching factor and $d$ the depth of the solution), as it stores all generated nodes in memory.

In [None]:
import time

# Widgets: the image shows the current frame, the output area collects messages.
out = widgets.Output()
img = widgets.Image(format='png', layout=widgets.Layout(max_width='300px'))
display(widgets.VBox([img, out]))

domain = domain_factory()

# Check that we can solve the Maze with LazyAstar
assert LazyAstar.check_domain(domain)

# All good, let's use LazyAstar
with LazyAstar() as solver:
  # Let's solve the domain
  MazeDomain.solve_with(solver, domain_factory)

  # Now let's see the solution
  for i_episode in range(1):
    # Initialize episode
    solver.reset()
    observation = domain.reset()
    can_render = isinstance(domain, Renderable)

    # Show the initial position before the first move
    if can_render:
      img.value = domain._render_from(observation)

    # `step` is the 1-based index of the move being played
    step, max_steps = 1, 1000

    while step <= max_steps:
      action = solver.sample_action(observation)

      outcome = domain.step(action)
      observation = outcome.observation

      # Render before testing for termination so the final (goal) frame is shown too
      if can_render:
        img.value = domain._render_from(observation)

      if domain._is_terminal(observation):
        with out:
          # `step` moves have been played at this point (it starts at 1)
          print(f'Episode {i_episode + 1} terminated after {step} steps.')
        break

      # Slow the animation down so each move is visible
      time.sleep(0.1)
      step += 1
  with out:
    print(f'The goal was{"" if domain.is_goal(observation) else " not"} reached in episode {i_episode + 1}.')