In [2]:
# Load IPython's autoreload extension and select its mode 2, so that edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import sys
if 'google.colab'in sys.modules:
    !pip install scikit-decide[all]

In [4]:
# Remove any previous checkout, then clone a fresh copy of the scikit-maze repository.
!(rm -rf scikit-maze/; git clone https://github.com/galleon/scikit-maze.git)

Cloning into 'scikit-maze'...
remote: Enumerating objects: 93, done.[K
remote: Counting objects: 100% (93/93), done.[K
remote: Compressing objects: 100% (65/65), done.[K
remote: Total 93 (delta 35), reused 78 (delta 20), pack-reused 0[K
Unpacking objects: 100% (93/93), done.


In [5]:
# Put the cloned repository first on the import path so that the later
# `from utils import Maze` can find its module (presumably the repo's utils.py).
# The guard keeps the cell idempotent: re-running it does not stack duplicates.
if sys.path[:1] != ['./scikit-maze']:
    sys.path.insert(0, './scikit-maze')

In [6]:
from enum import Enum
from typing import Any, List, NamedTuple, Optional

from skdecide import DeterministicPlanningDomain, Space, Value
from skdecide.builders.domain import UnrestrictedActions, Renderable
from skdecide.utils import rollout, match_solvers, load_registered_solver
from skdecide.hub.space.gym import ListSpace, EnumSpace, MultiDiscreteSpace
from skdecide.hub.solver.lazy_astar import LazyAstar

from utils import Maze

In [7]:
from PIL import Image

In [8]:
import io
import ipywidgets as widgets

In [9]:
#buff = io.BytesIO()
#img = Image.fromarray(im_array)
#img.save(buff, format='png')

In [10]:
# widgets.Image(value=buff.getvalue(), format='png', layout=widgets.Layout(width='200px'))
#widgets.Image(value=buff.getvalue(), layout=widgets.Layout(width='200px'))

## Define Action & State spaces

In [11]:
class State(NamedTuple):
    """Position of the agent in the maze grid (immutable and hashable)."""
    # Coordinate changed by the up/down actions; the goal's x is derived from the maze height.
    x: int
    # Coordinate changed by the left/right actions; the goal's y is derived from the maze width.
    y: int

class Action(Enum):
    """The four elementary moves available to the agent."""
    up = 0  # decrements State.x
    down = 1  # increments State.x
    left = 2  # decrements State.y
    right = 3  # increments State.y

## Define a base domain

In [12]:
class D(DeterministicPlanningDomain, UnrestrictedActions, Renderable):
    """Base domain type for the maze.

    Combines skdecide's DeterministicPlanningDomain, UnrestrictedActions and
    Renderable, and fixes the type parameters used by the concrete maze domain.
    """
    T_state = State  # Type of states
    T_observation = T_state  # Type of observations (same as states)
    T_event = Action  # Type of events
    T_value = float  # Type of transition values (rewards or costs)
    T_predicate = bool  # Type of logical checks
    T_info = None  # Type of additional information in environment outcome

## Create the maze domain

In [18]:
from math import sqrt

class MazeDomain(D):
    """Maze navigation domain: reach ``end`` from ``start`` one cell at a time.

    States are ``State(x, y)`` cells. The up/down actions change ``x`` and the
    left/right actions change ``y``. A move into a cell that the maze does not
    report as empty leaves the agent where it is, at a higher cost.

    Args:
        start: Initial state of every episode.
        end: Goal state; reaching it ends the episode.
        maze: Maze object. This class uses its ``is_an_empty_cell(x, y)``,
            ``height``, ``width`` and ``get_image(x, y)`` members.
        image_widget: Optional widget whose ``value`` attribute receives the
            PNG bytes of each rendered frame. When ``None`` (the default),
            rendering does nothing.
    """

    def __init__(self, start, end, maze, image_widget=None):
        self.start = start
        self.end = end
        self.maze = maze
        self.image_widget = image_widget

    def _get_next_state(self, memory: D.T_state, action: D.T_event) -> D.T_state:
        """Return the cell reached by applying ``action`` in ``memory``.

        The agent stays in ``memory`` when the target cell is not empty.
        """
        next_x, next_y = memory.x, memory.y
        # The four actions are mutually exclusive, hence the elif chain.
        if action == Action.up:
            next_x -= 1
        elif action == Action.down:
            next_x += 1
        elif action == Action.left:
            next_y -= 1
        elif action == Action.right:
            next_y += 1
        if self.maze.is_an_empty_cell(next_x, next_y):
            return State(next_x, next_y)
        return memory

    def _get_transition_value(self, memory: D.T_state, action: D.T_event,
                              next_state: Optional[D.T_state] = None) -> Value[D.T_value]:
        """Return a cost of 1 for an actual move and 2 when the agent did not move.

        Not moving means the agent bumped into a wall (damage cost); moving
        costs energy only.
        """
        return Value(cost=1 if next_state != memory else 2)

    def _get_initial_state_(self) -> D.T_state:
        """Return the start position as the initial state."""
        return self.start

    def _get_goals_(self) -> Space[D.T_observation]:
        """Return the goal space, which contains only the end position."""
        return ListSpace([self.end])

    def _is_terminal(self, state: D.T_state) -> D.T_agent[D.T_predicate]:
        """Return whether ``state`` is terminal; episodes stop only at the goal."""
        return self._is_goal(state)

    def _get_action_space_(self) -> Space[D.T_event]:
        """Return the action space: every member of ``Action``."""
        return EnumSpace(Action)

    def _get_observation_space_(self) -> Space[D.T_observation]:
        """Return the observation space, sized by the maze height and width."""
        num_rows = self.maze.height
        num_cols = self.maze.width
        return MultiDiscreteSpace([num_rows, num_cols])

    def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any:
        """Draw the maze with the agent at ``memory`` into the image widget.

        Does nothing when the domain was created without an ``image_widget``,
        so widget-less domains (e.g. the ones built for the solver) can be
        rendered safely.
        """
        if self.image_widget is None:
            return
        # Repeat the image 4 times along axis 0 and axis 1 to enlarge each cell
        # (get_image presumably returns a NumPy array - confirm against Maze).
        frame = self.maze.get_image(memory.x, memory.y).repeat(4, 0).repeat(4, 1)
        with io.BytesIO() as buff:
            Image.fromarray(frame).save(buff, format='png')
            self.image_widget.value = buff.getvalue()

    def heuristic(self, s: D.T_state) -> Value:
        """Return the straight-line (Euclidean) distance from ``s`` to the goal as a cost."""
        return Value(cost=sqrt((self.end.x - s.x)**2 + (self.end.y - s.y)**2))

    def state_features(self, s: D.T_state) -> List[float]:
        """Return the coordinates of ``s`` as a feature list ``[x, y]``."""
        return [s.x, s.y]

## Define a Maze

In [19]:
height, width = 50, 50

# Build the maze once and share it between every domain the factory creates.
# Constructing Maze(width, height) inside the factory would create a new maze
# on each call; if maze generation is randomized (presumably - confirm against
# the Maze class), separately built domains would not describe the same layout.
maze = Maze(width, height)


def domain_factory():
    """Return a new MazeDomain over the shared maze, from (1, 1) to (height-2, width-2)."""
    return MazeDomain(State(1, 1), State(height-2, width-2), maze=maze)

## Let's solve with A*

Let's try a first solver named A\*. A\* (pronounced "A-star") is a graph traversal and path search algorithm, which is widely used in computer science due to its completeness, optimality, and optimal efficiency.

One major practical drawback is its $O(b^d)$ space complexity (where $b$ is the branching factor and $d$ the depth of the solution), as it stores all generated nodes in memory.

In [20]:
# Create the A* solver with its default arguments.
solver = LazyAstar()
# NOTE(review): _initialize() is an underscore-prefixed method called directly;
# confirm that it is actually required before solving.
solver._initialize()
# Solve the domain built by domain_factory. The call is the last expression of
# the cell, so its return value (the solver, per the recorded output) is displayed.
MazeDomain.solve_with(solver, domain_factory)

<skdecide.hub.solver.lazy_astar.lazy_astar.LazyAstar at 0x7f0aa51cb100>

## Rollout

In [21]:
# Upper bound on the number of rollout steps: (height-2)*(width-2)
# (presumably the number of cells inside the outer border - confirm).
max_steps = (height-2)*(width-2)

# Image widget (PNG format) that will display the rendered maze frames.
output = widgets.Image(format='png')

# The widget starts empty; MazeDomain._render_from fills in its value.

# Last expression of the cell: shows the widget in the notebook.
output

Image(value=b'')

In [22]:
# Build the rollout domain with the same factory that was handed to the solver,
# so start, goal and maze arguments are defined in a single place, then attach
# the widget that receives the rendered frames.
domain = domain_factory()
domain.image_widget = output
# NOTE(review): the recorded run of this cell failed with
# "KeyError: State(x=1, y=1)", i.e. a lookup keyed by the start state found
# nothing (presumably in the solver's policy). Check that the solver found a
# plan for this maze - goal cell empty and reachable - before rolling out.
rollout(domain, solver, max_steps=max_steps, max_framerate=80, verbose=False)

KeyError: State(x=1, y=1)