# Utils

In [None]:
def vector_add(a, b):
    """Add two vectors component by component.

    If either argument is falsy (e.g. 0, None or an empty sequence) the
    other one is returned unchanged. Two iterables of equal length are added
    element-wise (recursively, so nested sequences work) and the result is a
    list; anything else is combined with ``+``.
    """
    if not a or not b:
        return a or b
    both_iterable = hasattr(a, '__iter__') and hasattr(b, '__iter__')
    if not both_iterable:
        return a + b
    assert len(a) == len(b)
    return [vector_add(left, right) for left, right in zip(a, b)]

def isnumber(x):
    """Return True if x can be treated as a number, i.e. it defines __int__."""
    return hasattr(x, '__int__')

def print_table(table, header=None, sep='', numfmt='{}'):
    """Print a list of lists as a table, so that columns line up nicely.

    Args:
        table: Sequence of rows; each row is a sequence of cell values.
            It is never modified.
        header: Optional row printed before the data rows.
        sep: Separator placed between columns.
        numfmt: Format applied to every numeric cell, e.g. '{:.2f}'.
            (If you want different formats in different columns,
            don't use print_table.)

    Numeric columns are right-justified and all others left-justified; the
    choice is made from the first data row (or from the header when there
    are no data rows). None cells are printed as "###". Nothing is printed
    when there is neither data nor a header.
    """
    rows = list(table)
    if not rows and not header:
        return

    # Justification follows the first *data* row, so a textual header does
    # not force numeric columns to be left-justified.
    reference_row = rows[0] if rows else header
    justs = ['rjust' if isnumber(x) else 'ljust' for x in reference_row]

    if header:
        # Build a new list instead of inserting into the caller's table.
        rows = [header] + rows

    cells = [[numfmt.format(x) if isnumber(x)
              else "###" if x is None
              else str(x) for x in row]
             for row in rows]
    # Each column is as wide as its longest rendered cell.
    sizes = [max(map(len, column)) for column in zip(*cells)]

    for row in cells:
        print(sep.join(getattr(cell, just)(size)
                       for (just, size, cell) in zip(justs, sizes, row)))

# MDP

In [None]:
class GridMDP:
    """
    A Markov Decision Process on a two-dimensional grid.

    States are (x, y) tuples: x is the column index and y is the row index
    counted from the bottom row of the grid as it is written.

    Attributes:
        grid (list of lists): Reward grid stored bottom row first; None indicates obstacles.
        terminals: Collection of terminal states, stored as passed in.
        init (tuple): Initial state.
        gamma (float): Discount factor (expected 0 < gamma <= 1; not validated).
        rows (int): Number of rows in the grid.
        cols (int): Number of columns in the grid (length of the first row).
        states (set): All non-obstacle (x, y) cells.
        reward (dict): Mapping from state to its reward.
        orientations (list): Unit direction vectors in counter-clockwise
            order: (east, north, west, south).
        turns (tuple): Index offsets into orientations for (left, right) turns.
        transitions (dict): Precomputed mapping state -> action -> list of
            (probability, next_state) pairs.
    """

    def __init__(self, grid, terminals, init=(1, 1), gamma=0.99):
        """
        Build the MDP from a reward grid.

        Args:
            grid (list of lists): Rewards written top row first; None marks obstacles.
            terminals: Terminal states as (x, y) tuples.
            init (tuple): Initial state.
            gamma (float): Discount factor.

        Raises:
            ValueError: If the grid is empty, if a row is shorter than the
                first row, or if init or any terminal is not a valid
                (non-obstacle) state.
        """
        if len(grid) == 0:
            raise ValueError("Grid must contain at least one row")
        self.rows = len(grid)
        self.cols = len(grid[0])
        if any(len(row) < self.cols for row in grid):
            raise ValueError(
                "Every grid row must have at least {} columns".format(self.cols))

        # Reverse grid for bottom-to-top indexing
        self.grid = grid[::-1]

        # Every non-None cell is a state whose reward is the cell value.
        self.states = set()
        self.reward = {}
        for y in range(self.rows):
            for x in range(self.cols):
                cell = self.grid[y][x]
                if cell is not None:
                    self.states.add((x, y))
                    self.reward[(x, y)] = cell

        if init not in self.states:
            raise ValueError("Invalid initial state: {}".format(init))
        if any(t not in self.states for t in terminals):
            raise ValueError("Invalid terminal states: {}".format(terminals))

        self.terminals = terminals
        self.init = init
        self.gamma = gamma
        # Counter-clockwise order (east, north, west, south), so stepping +1
        # through the list is a left turn and -1 is a right turn.
        self.orientations = [(1, 0), (0, 1), (-1, 0), (0, -1)]
        self.turns = (+1, -1)

        # Precompute transition probabilities for efficiency
        self.transitions = {s: self._calculate_T(s) for s in self.states}

    def _calculate_T(self, s):
        """
        Calculate transition probabilities for all actions from a state.

        The intended direction is taken with probability 0.8; each of the two
        perpendicular directions is taken with probability 0.1. A move into
        an obstacle or off the grid leaves the state unchanged (see _go).

        Args:
            s (tuple): Current state.

        Returns:
            dict: Mapping from action to list of (probability, next_state) pairs.
        """
        transitions = {}
        for action in self.orientations:
            transitions[action] = [
                (0.8, self._go(s, action)),
                (0.1, self._go(s, self._turn_direction(action, -1))),
                (0.1, self._go(s, self._turn_direction(action, +1))),
            ]
        return transitions

    def _turn_direction(self, direction, turn):
        """
        Turn the given direction by the specified amount.

        Args:
            direction (tuple): Current direction; must be one of orientations.
            turn (int): Offset in the counter-clockwise orientations list
                (left: +1, right: -1).

        Returns:
            tuple: New direction.
        """
        index = self.orientations.index(direction)
        return self.orientations[(index + turn) % len(self.orientations)]

    def _go(self, state, direction):
        """
        Move one step in the given direction, handling boundaries.

        Args:
            state (tuple): Current state.
            direction (tuple): Direction to move.

        Returns:
            tuple: New state, or the unchanged state when the target cell is
                not a valid state (obstacle or outside the grid).
        """
        x, y = state
        dx, dy = direction
        new_state = (x + dx, y + dy)
        return new_state if new_state in self.states else state

    def R(self, state):
        """
        Get the reward for a state.

        Args:
            state (tuple): State.

        Returns:
            float: Reward.
        """
        return self.reward[state]

    def T(self, state, action):
        """
        Get the transition probabilities for a state and action.

        Args:
            state (tuple): State.
            action (tuple): Action, or None (the only action offered in
                terminal states).

        Returns:
            list: List of (probability, next_state) pairs. For a None action
                this is [(0.0, state)], i.e. no future value is accumulated.
                Otherwise the precomputed list is returned as-is, so callers
                must not modify it.
        """
        if not action:
            return [(0.0, state)]
        return self.transitions[state][action]

    def actions(self, state):
        """
        Get the available actions in a state (always oriented actions).

        Args:
            state (tuple): State.

        Returns:
            list: [None] for a terminal state, otherwise the list of the
                four orientations.
        """
        if state in self.terminals:
            return [None]
        return self.orientations

    def to_grid(self, mapping):
        """
        Convert a mapping from (x, y) to values into a grid representation.

        Args:
            mapping (dict): Mapping from (x, y) to values.

        Returns:
            list of lists: Grid with the top row first; cells missing from
                the mapping are None.
        """
        bottom_up = [[mapping.get((x, y), None) for x in range(self.cols)]
                     for y in range(self.rows)]
        return bottom_up[::-1]

    def to_arrows(self, policy):
        """
        Convert a policy (mapping from state to action) into a grid showing corresponding arrow directions.

        Args:
            policy (dict): Mapping from state to action.

        Returns:
            list of lists: Grid representation with arrows (' G ' marks a
                None action).
        """
        chars = {(1, 0): " > ", (0, 1): ' ∧ ', (-1, 0): ' < ', (0, -1): ' ∨ ', None: ' G '}
        return self.to_grid({s: chars[a] for (s, a) in policy.items()})



# Environment

In [None]:
# Reward grid, written top row first. None marks an obstacle (wall); every
# other entry is the reward of that cell (-0.1 for ordinary cells, +5.0 and
# -5.0 for the two special cells).
grid = [
    [None, None, None, None, None, None, None, None, None, None, None],
    [None, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, None, +5.0, None],
    [None, -0.1, None, None, None, None, None, None, None, -0.1, None],
    [None, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, None],
    [None, -0.1, None, None, None, None, None, None, None, None, None],
    [None, -0.1, None, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, None],
    [None, -0.1, None, None, None, None, None, -0.1, -0.1, -0.1, None],
    [None, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, None],
    [None, None, None, None, None, -0.1, None, -0.1, -0.1, -0.1, None],
    [None, -5.0, -0.1, -0.1, -0.1, -0.1, None, -0.1, -0.1, -0.1, None],
    [None, None, None, None, None, None, None, None, None, None, None]
]
# States are (x, y) with y counted from the bottom row (GridMDP reverses the
# grid internally), so (9, 9) is the +5.0 cell in the second row above.
terminals = [(9, 9)]
# Uses GridMDP's defaults: init=(1, 1), which is the -5.0 cell in the
# second-to-last row above, and gamma=0.99.
maze = GridMDP(grid, terminals)

# Q-Values and Best Policy

In [None]:
def q_values(mdp, s, V):
    """Return the expected next-state value of every action available in s.

    Args:
        mdp: Object providing actions(s) and T(s, a), where T yields
            (probability, next_state) pairs.
        s: State to evaluate.
        V (dict): Mapping from state to its current value estimate.

    Returns:
        list: One expected value per action, in the order of mdp.actions(s).
            The immediate reward and the discount are not included.
    """
    expected = []
    for action in mdp.actions(s):
        outcomes = mdp.T(s, action)
        expected.append(sum(prob * V[nxt] for prob, nxt in outcomes))
    return expected


def best_policy(mdp, V):
    """Derive the greedy policy for an MDP from a value function.

    Args:
        mdp: The MDP; must provide states, terminals and actions(s), plus
            whatever q_values needs.
        V (dict): Mapping from state to value.

    Returns:
        dict: Mapping from state to action. Terminal states map to None;
            every other state maps to the first action with the highest
            q-value.
    """
    policy = {}
    for state in mdp.states:
        if state not in mdp.terminals:
            scores = q_values(mdp, state, V)
            available = mdp.actions(state)
            # list.index returns the first maximum, so ties go to the
            # earliest action in mdp.actions(state).
            policy[state] = available[scores.index(max(scores))]
        else:
            # Nothing to choose in a terminal state.
            policy[state] = None
    return policy

# Value Iteration

In [None]:
def value_iteration(self, iterations=20, epsilon=1e-3):
    """
    Solve an MDP with value iteration, updating the values in place.

    Args:
        self: The MDP to solve; must provide states, R(s) and gamma, plus
            whatever q_values needs.
        iterations (int): Maximum number of sweeps over all states.
        epsilon (float): Tolerance; sweeping stops early once the largest
            change in a sweep is at most epsilon * (1 - gamma) / gamma.
    Returns:
        dict: Mapping from state to value.
    """
    values = dict.fromkeys(self.states, 0)

    for _sweep in range(iterations):
        before = dict(values)
        largest_change = 0

        for state in self.states:
            # In-place update: states visited later in this sweep already
            # see the new value of this state.
            best_future = max(q_values(self, state, values))
            values[state] = self.R(state) + self.gamma * best_future

            change = abs(before[state] - values[state])
            if change > largest_change:
                largest_change = change

        tolerance = epsilon * (1 - self.gamma) / self.gamma
        if largest_change <= tolerance:
            break

    return values

# Run!

In [None]:
# Solve the maze with value iteration (default arguments), derive the greedy
# policy from the resulting values and print it as a grid of arrows.
V = value_iteration(maze)
pi = best_policy(maze, V)
print_table(maze.to_arrows(pi))

#################################
### ∨  <  <  <  <  <  < ### G ###
### ∨ ##################### ∧ ###
### >  >  >  >  >  >  >  >  ∧ ###
### ∧ ###########################
### ∧ ### >  >  >  >  ∨  <  < ###
### ∧ ############### ∨  <  < ###
### ∧  <  <  <  <  <  <  <  < ###
############### ∧ ### ∧  <  < ###
### >  >  >  >  ∧ ### ∧  <  < ###
#################################


# Policy Iteration

In [None]:
def policy_evaluation(mdp, pi, V, k=20):
    """Approximate the value of a fixed policy with k in-place sweeps.

    Args:
        mdp: The MDP; must provide states, R(s), gamma and T(s, a).
        pi (dict): Policy, mapping each state to the action taken there.
        V (dict): Current value estimates; updated in place.
        k (int): Number of sweeps over all states.

    Returns:
        dict: The same V object that was passed in, after the updates
            (modified policy iteration).
    """
    for _sweep in range(k):
        for state in mdp.states:
            immediate = mdp.R(state)
            discounted_future = mdp.gamma * sum(
                prob * V[nxt] for prob, nxt in mdp.T(state, pi[state]))
            V[state] = immediate + discounted_future
    return V


def policy_iteration(mdp, iterations=10):
    """Solve an MDP by policy iteration, starting from a random policy.

    Each round evaluates the current policy approximately (policy_evaluation)
    and then switches any state to its best action when that action has a
    strictly higher q-value than the current one.

    Args:
        mdp: The MDP; must provide states, actions(s) and T(s, a), plus
            whatever policy_evaluation and q_values need.
        iterations (int): Maximum number of evaluate/improve rounds.

    Returns:
        dict: Mapping from state to action. Because the initial policy is
            drawn with random.choice, the result can vary between runs.
    """
    import random

    V = dict.fromkeys(mdp.states, 0)
    pi = {}
    for state in mdp.states:
        pi[state] = random.choice(mdp.actions(state))

    for _round in range(iterations):
        V = policy_evaluation(mdp, pi, V)
        improved = False

        for state in mdp.states:
            candidates = q_values(mdp, state, V)
            best_q = max(candidates)
            best_action = mdp.actions(state)[candidates.index(best_q)]

            current_q = sum(p * V[nxt] for p, nxt in mdp.T(state, pi[state]))
            # Strict comparison: ties keep the current action, so a stable
            # policy ends the loop.
            if best_q > current_q:
                pi[state] = best_action
                improved = True

        if not improved:
            break

    return pi

# Run

In [None]:
# Solve the maze with policy iteration (random initial policy, so the
# printed arrows can differ between runs) and print the resulting policy.
pi = policy_iteration(maze)
print_table(maze.to_arrows(pi))

#################################
### ∨  <  <  <  <  <  < ### G ###
### ∨ ##################### ∧ ###
### >  >  >  >  >  >  >  >  ∧ ###
### ∧ ###########################
### ∧ ### >  >  >  >  ∨  ∨  < ###
### ∧ ############### ∨  <  < ###
### ∧  <  <  <  <  <  <  <  < ###
############### ∧ ### ∧  <  < ###
### >  >  >  >  ∧ ### ∧  <  < ###
#################################
