In [1]:
class GridWorldMDP:
    """Grid-world Markov decision process over ``(row, col)`` states.

    A grid cell equal to ``'o'`` is a state; any other cell value is not.
    Rows of the grid may have different lengths.

    Args:
        grid: Sequence of rows, each a sequence of cell markers.
        terminals: Iterable of ``(row, col)`` states that are absorbing.
        transition_probs: Mapping from action name to the probability that
            the move succeeds. A falsy value (``None`` or an empty mapping)
            selects probability 1 for 'up', 'down', 'left' and 'right'.
        rewards: Mapping from state to the reward for arriving in it; a
            falsy value selects an empty mapping (every reward is 0).
        discount: Discount factor, stored on the instance for callers.
    """

    # (row delta, col delta) applied by each action.
    _MOVEMENTS = {'up': (-1, 0), 'down': (1, 0), 'left': (0, -1), 'right': (0, 1)}

    def __init__(self, grid, terminals, transition_probs=None, rewards=None, discount=0.9):
        self.grid = grid
        self.terminals = set(terminals)
        self.transition_probs = transition_probs or {a: 1 for a in ['up', 'down', 'left', 'right']}
        self.rewards = rewards or {}
        self.discount = discount
        self.states = set()
        self._init_states()

    def _init_states(self):
        """Add every ``'o'`` cell of the grid to ``self.states``."""
        for row, cells in enumerate(self.grid):
            for col, cell in enumerate(cells):
                if cell == 'o':
                    self.states.add((row, col))

    def _is_valid_state(self, state):
        """Return True if ``state`` lies inside the grid on an ``'o'`` cell."""
        row, col = state
        if not 0 <= row < len(self.grid):
            return False
        # Bound the column by the target row's own length (not row 0's) so
        # ragged grids are handled the same way _init_states handles them.
        cells = self.grid[row]
        return 0 <= col < len(cells) and cells[col] == 'o'

    def get_transition_states_and_probs(self, state, action):
        """Return the outcome distribution of taking ``action`` in ``state``.

        Args:
            state: Current ``(row, col)`` state.
            action: One of 'up', 'down', 'left', 'right'.

        Returns:
            A list of ``(next_state, probability)`` pairs:

            * terminal state, or a move into a wall / off the grid:
              ``[(state, 1.0)]``;
            * valid move with success probability ``p``:
              ``[(next_state, p)]``, followed by ``(state, 1 - p)`` when
              ``p < 1`` so that the probabilities sum to 1.

        Raises:
            KeyError: If ``action`` is not a known action name (for a
                non-terminal ``state``).
        """
        if state in self.terminals:
            return [(state, 1.0)]
        next_state = tuple(map(sum, zip(state, self._MOVEMENTS[action])))
        if not self._is_valid_state(next_state):
            return [(state, 1.0)]
        success_prob = self.transition_probs[action]
        outcomes = [(next_state, success_prob)]
        # A failed move leaves the agent where it is; without this entry the
        # distribution would lose probability mass whenever success_prob < 1.
        stay_prob = 1 - success_prob
        if stay_prob > 0:
            outcomes.append((state, stay_prob))
        return outcomes

    def get_reward(self, state, action, next_state):
        """Return the reward for arriving in ``next_state`` (0 if unlisted).

        ``state`` and ``action`` are accepted but do not affect the result.
        """
        return self.rewards.get(next_state, 0)


In [2]:
# 3x4 layout written one string per row; only 'o' cells become states.
grid = [list(row) for row in ("oooo", "oxox", "oooo")]
# Single absorbing state in the bottom-right corner.
terminals = [(2, 3)]
mdp = GridWorldMDP(grid=grid, terminals=terminals)

In [3]:
# Probe a single transition: move 'up' from (1, 2) and look up its reward.
test_state, test_action = (1, 2), 'up'
transitions = mdp.get_transition_states_and_probs(state=test_state, action=test_action)
reward = mdp.get_reward(
    state=test_state,
    action=test_action,
    next_state=transitions[0][0],  # state of the first listed outcome
)
print("Transitions:", transitions)
print("Reward:", reward)

Transitions: [((0, 2), 1)]
Reward: 0
