In [1]:
import numpy as np

def get_grid(grid):
    """Resolve *grid* to an ASCII grid layout.

    Parameters
    ----------
    grid : str
        Either a literal layout (recognised by containing a newline), which
        is returned unchanged, or the name of an entry in the module-level
        ``grids`` registry.

    Returns
    -------
    str
        The layout text. Layouts looked up by name have surrounding
        whitespace stripped; literal layouts are returned exactly as given.

    Raises
    ------
    ValueError
        If *grid* is not a literal layout and is not a known grid name.
    """
    # Anything containing a newline is treated as the layout itself.
    if '\n' in grid:
        return grid
    try:
        layout = grids[grid]
    except KeyError:
        known = ', '.join(repr(name) for name in grids)
        raise ValueError(f'unknown grid {grid!r}; available grids: {known}') from None
    return layout.strip()

# (row offset, column offset) for every action id. Row indices grow downward
# in the ASCII grid, so 0 = up, 1 = right, 2 = down, 3 = left.
MOVES = {0 : (-1,  0),
         1 : ( 0,  1),
         2 : ( 1,  0),
         3 : ( 0, -1)}

def move(state, action_id):
    """Return the (row, col) reached from *state* by taking *action_id*.

    No bounds or wall checking is done here; the result is simply *state*
    shifted by the offset registered for the action in ``MOVES``.

    Parameters
    ----------
    state : sequence
        Current position; only ``state[0]`` (row) and ``state[1]`` (column)
        are read.
    action_id : int
        One of the keys of ``MOVES`` (0, 1, 2 or 3).

    Returns
    -------
    tuple
        The shifted ``(row, col)`` pair.

    Raises
    ------
    ValueError
        If *action_id* is not a known action.
    """
    try:
        d_row, d_col = MOVES[action_id]
    except (KeyError, TypeError):
        # TypeError covers unhashable ids, so every bad id fails the same way.
        raise ValueError(f'unknown action id {action_id!r}; expected one of {list(MOVES)}') from None
    return state[0] + d_row, state[1] + d_col

class Base():
    """Mixin that records constructor arguments and derives a repr from them.

    A subclass typically calls ``super().__init__(params = locals())`` as the
    first statement of its own ``__init__``. Every recorded parameter is kept
    in ``self._params`` and also exposed as an instance attribute.
    """

    def __init__(self, *args, params = None, exceptions = (), **kwargs):
        """Store *params* on the instance and forward the rest up the MRO.

        Parameters
        ----------
        *args, **kwargs
            Passed unchanged to the next ``__init__`` in the MRO.
        params : dict, optional
            Mapping of parameter name to value, usually the caller's
            ``locals()``. Defaults to no parameters.
        exceptions : iterable of str, optional
            Names in *params* that must not be recorded. ``'self'`` and
            ``'__class__'`` are always excluded.
        """
        params       = {} if params is None else params
        excluded     = [*exceptions, 'self', '__class__']
        self._params = {name : value for name, value in params.items() if name not in excluded}
        self.__dict__.update(self._params)
        # locals() only carries '__class__' when the calling method uses
        # zero-argument super() (or names __class__); otherwise fall back to
        # the runtime type instead of failing with a KeyError.
        self.__name__ = params.get('__class__', type(self)).__name__
        super().__init__(*args, **kwargs)

    @staticmethod
    def _format_value(value):
        """Return the repr fragment for one parameter value."""
        if isinstance(value, str):
            return f'"{value}"'
        if isinstance(value, np.ndarray):
            # Arrays are summarised by shape rather than printed in full.
            return f'array{value.shape}'
        return value

    def __repr__(self):
        """Return ``Name(key = value, ...)`` built from the recorded parameters."""
        params = ', '.join(f'{key} = {self._format_value(value)}' for key, value in self._params.items())
        return f'{self.__name__}({params})'

class Mapper(Base, dict):
    """Dictionary that can also be looked up by value.

    After construction every original ``key -> value`` pair is accompanied by
    the reverse ``value -> key`` entry, so both directions live in the same
    dict. The textual representation only lists the original pairs.
    """

    def __init__(self, *args, **kwargs):
        # Keep this call first and free of extra local variables: locals() is
        # handed to Base, which records whatever names it finds in it.
        super().__init__(*args, params = locals(), exceptions = ['args', 'kwargs'], **kwargs)
        forward = list(self.items())
        # Snapshot the repr before the reverse entries are inserted.
        listing   = ', '.join(f'{key} : {value}' for key, value in forward)
        self.repr = 'Mapper({' + listing + '})'
        self.update({value : key for key, value in forward})

    def __getitem__(self, key):
        # Plain delegation to the underlying dict lookup.
        return super().__getitem__(key)

    def __repr__(self):
        # Return the representation captured at construction time.
        return self.repr

class GridWorld(Base):
    """Grid environment built from an ASCII layout.

    Cells containing a space are the valid states. They are numbered in
    row-major order and stored in ``self._states``, a ``Mapper`` that
    translates between ``(row, col)`` coordinates and state indices.
    Constructor arguments are recorded by ``Base`` and are available as
    attributes (``self.movement_cost``, ``self.terminal_locations``, ...).
    """

    def __init__(self, grid : str, start_locations : list, terminal_locations : dict, movement_cost : float = -1., p : float = 1.):
        # Must remain the first statement: locals() is captured here, and any
        # local defined earlier would be recorded as a constructor parameter.
        super().__init__(params = locals())
        self._grid = get_grid(grid)

        # Collect every free cell (a space character) in row-major order.
        free_cells = []
        for row_idx, row in enumerate(self._grid.split('\n')):
            for col_idx, cell in enumerate(row):
                if cell == ' ':
                    free_cells.append((row_idx, col_idx))
        self._states = Mapper({cell : index for index, cell in enumerate(free_cells)})

        # Integer start locations are translated through the mapper; any
        # other value is kept as given.
        self._start = []
        for loc in start_locations:
            if isinstance(loc, (int, np.int_)):
                loc = self._states[loc]
            self._start.append(loc)

        # One (still empty) dict per mapper key, built only when p == 1.
        # NOTE(review): the mapper holds both coordinates and indices, so both
        # kinds of key end up in P -- confirm this is intended.
        self.P = {state : {} for state in self._states} if p == 1 else {}

    @property
    def grid(self):
        # Prints the layout; the property itself evaluates to None.
        print(self._grid)

    def reset(self):
        """Pick a start location at random, clear the terminal flag and return the location."""
        pick          = np.random.choice(len(self._start))
        self.loc      = self._start[pick]
        self.terminal = False
        return self.loc

    def step(self, action_id : int):
        """Apply *action_id* and return ``(reward, state, terminal, info)``.

        ``terminal`` reports whether the location held *before* the move is in
        ``terminal_locations``; in that case the agent stays where it is. A
        move to a cell that is not a known state leaves the location
        unchanged. The reward is ``movement_cost`` plus the value registered
        for a terminal location that is newly entered. ``state`` is the
        mapper's translation of the resulting location and ``info`` is an
        empty dict.
        """
        previous = self.loc
        terminal = previous in self.terminal_locations
        if terminal:
            candidate = previous
        else:
            candidate = move(previous, action_id)

        # Only accept the move when the target is a known state.
        if candidate in self._states:
            self.loc = candidate

        reward = self.movement_cost
        if candidate != previous and candidate in self.terminal_locations:
            reward += self.terminal_locations[candidate]
        return reward, self._states[self.loc], terminal, {}

    def current_grid(self, mode = 'ascii'):
        """Print the layout with the current location marked by ``'A'``.

        *mode* is accepted but not used.
        """
        cells      = [list(row) for row in self._grid.split('\n')]
        row, col   = self.loc
        cells[row][col] = 'A'
        print('\n'.join(''.join(line) for line in cells))
        

# 13x13 layout: '#' marks a wall and ' ' a free cell. The leading and trailing
# newlines are part of the literal.
four_rooms = """
#############
#     #     #
#     #     #
#           #
#     #     #
#     #     #
### ##### ###
#     #     #
#     #     #
#           #
#     #     #
#     #     #
#############
"""

# Registry of named layouts, keyed by grid name.
grids = {'four_rooms' : four_rooms}


In [93]:
gridworld = GridWorld('four_rooms', np.array([0]), {})
gridworld

GridWorld(grid = "four_rooms", start_locations = array(1,), terminal_locations = {}, movement_cost = -1.0, p = 1.0)

In [94]:
gridworld.reset()

(1, 1)

In [95]:
gridworld.current_grid()

#############
#A    #     #
#     #     #
#           #
#     #     #
#     #     #
### ##### ###
#     #     #
#     #     #
#           #
#     #     #
#     #     #
#############


In [105]:
gridworld.step(0)

(-1.0, (1, 1), False, {})

In [79]:
dir(gridworld)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__name__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_grid',
 '_params',
 '_start',
 '_states',
 '_terminal',
 'current_grid',
 'grid',
 'loc',
 'reset',
 'step',
 'terminal']