In [None]:
import random
from operator import itemgetter
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.table import Table

# Policy Iteration

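This notebook implements policy iteration (alternating policy evaluation and policy improvement) for a deterministic policy on an N×N grid world. The top-left and bottom-right corners are terminal states, and every step from a non-terminal state earns a reward of -1.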

In [None]:
def is_outside(i, grid_size=None):
    """Check whether a row or column index falls outside the grid.

    Arguments:
        i {int} -- Index of the row or column

    Keyword Arguments:
        grid_size {int} -- Side length of the square grid. When omitted, the
            module-level ``GRID_SIZE`` is used.

    Returns:
        bool -- True if the index takes us outside the grid, else False
    """
    if grid_size is None:
        grid_size = GRID_SIZE
    # Explicit bool in both cases. The previous version fell off the end and
    # returned None for in-range indices.
    return not 0 <= i < grid_size

def is_terminal(state):
    """Tell whether ``state`` is one of the two absorbing corner cells.

    The terminal cells are the top-left corner (0, 0) and the bottom-right
    corner (GRID_SIZE-1, GRID_SIZE-1).

    Arguments:
        state {tuple} -- A tuple (x, y) containing the position in the grid

    Returns:
        bool -- True if state is terminal, else False
    """
    row, col = state
    last = GRID_SIZE - 1
    at_top_left = row == 0 and col == 0
    at_bottom_right = row == last and col == last
    return at_top_left or at_bottom_right

def take_action(state, action):
    """Apply ``action`` in ``state`` and return the resulting transition.

    Terminal states are absorbing: the agent stays put and receives 0.
    Any other move yields ``REWARD``. If a move would leave the grid along
    an axis, the agent keeps its old coordinate on that axis.

    ``ACTIONS[action]`` is unpacked as ``(dy, dx)``: ``dy`` shifts the second
    coordinate and ``dx`` shifts the first.

    Arguments:
        state {tuple} -- A tuple (x, y) containing the position in the grid
        action {str} -- An action from a possible set of actions

    Returns:
        (state, reward) -- A tuple (x, y) denoting the new state and a reward
    """
    first, second = state
    if is_terminal(state):
        return (first, second), 0

    d_second, d_first = ACTIONS[action]
    next_first, next_second = first + d_first, second + d_second
    # Bumping into a wall leaves that coordinate unchanged.
    if is_outside(next_first):
        next_first = first
    if is_outside(next_second):
        next_second = second
    return (next_first, next_second), REWARD

def draw_grid(grid):
    """Draw a 2-D array as a matplotlib table, one cell per element.

    Rows are labelled 1..nrows on the left and columns 1..ncols on top.

    Arguments:
        grid {np.ndarray} -- A 2-D grid (e.g. state values or policy arrows)
            to be displayed

    Returns:
        fig -- The figure containing the table
    """
    fig, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = grid.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # One bordered cell per grid element.
    for (i, j), val in np.ndenumerate(grid):
        tb.add_cell(i, j, width, height, text=val,
                    loc='center', facecolor='white')

    # Borderless row labels (column -1) and column labels (row -1). Each axis
    # is iterated over its own length so non-square grids are labelled
    # correctly; the old code used the row count for both.
    for i in range(nrows):
        tb.add_cell(i, -1, width, height, text=i+1, loc='right',
                    edgecolor='none', facecolor='none')
    for j in range(ncols):
        tb.add_cell(-1, j, width, height/2, text=j+1, loc='center',
                    edgecolor='none', facecolor='none')
    ax.add_table(tb)
    return fig

def init_policy(GRID_SIZE, rng=None):
    """Generate a random deterministic policy for an NxN grid world.

    Every cell gets an arrow drawn uniformly from '←', '→', '↑', '↓'. The two
    terminal corners (0, 0) and (N-1, N-1) are then marked 'x'.

    Arguments:
        GRID_SIZE {int} -- Size of the NxN grid world. This parameter shadows
            the module-level constant of the same name; the name is kept so
            existing keyword callers keep working.

    Keyword Arguments:
        rng {np.random.Generator} -- Optional random generator, for
            reproducible policies. When omitted, numpy's legacy global random
            state is used (seed it with ``np.random.seed``).

    Returns:
        policy -- A (GRID_SIZE, GRID_SIZE) array of action symbols
    """
    arrows = ['←', '→', '↑', '↓']
    sampler = np.random if rng is None else rng
    policy = sampler.choice(arrows, size=(GRID_SIZE, GRID_SIZE))
    policy[0, 0] = 'x'
    policy[GRID_SIZE-1, GRID_SIZE-1] = 'x'
    return policy

## Define the grid-world parameters and initialize a random policy

In [None]:
GRID_SIZE = 10
GRID = np.zeros((GRID_SIZE, GRID_SIZE))  # state values, updated in place by policy evaluation
# Action symbol -> displacement. take_action unpacks each value as (dy, dx):
# dy shifts the second state index and dx the first, so e.g. "↑" = [0, -1]
# decrements the first (row) index.
ACTIONS = {
    "←": [-1, 0],
    "↑": [0, -1],
    "→": [1, 0],
    "↓": [0, 1]
}
REWARD = -1  # reward for every step taken from a non-terminal state
GAMMA = 1    # no discounting

# init_policy samples from numpy's global RNG. Seed it so Restart & Run All
# reproduces the same initial policy (and therefore the same saved figures).
SEED = 0
np.random.seed(SEED)
POLICY = init_policy(GRID_SIZE)

### Display the initial random policy

In [None]:
fig = draw_grid(POLICY)
# savefig does not create missing directories, so make sure images/ exists.
# This keeps the notebook runnable from a fresh checkout.
Path('images').mkdir(exist_ok=True)
fig.savefig('./images/before_iteration.png')

# Policy Evaluation
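Given a policy $\pi$, evaluate how good it is by computing its value function $v_\pi$. Because both the policy and the environment are deterministic, each sweep applies the backup $v_\pi(s) \leftarrow r + \gamma\, v_\pi(s')$. Here $s'$ and $r$ are the next state and reward obtained by taking action $\pi(s)$ in state $s$.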

In [None]:
def policy_evaluation(grid, policy, gamma=1.0, theta=1e-6, max_iter=1000):
    """Evaluate a deterministic policy, writing the state values into ``grid``.

    Each sweep visits every cell ``s`` and applies the one-step backup
    ``v(s) <- r + gamma * v(s')``, where ``(s', r) = take_action(s, policy[s])``.
    New values are written into ``grid`` during the sweep, so later cells in
    the same sweep already see the updated values of earlier cells.

    Arguments:
        grid {np.ndarray} -- The grid containing the value of the states. It
            is both the starting estimate and the output (modified in place).
        policy {np.ndarray} -- The deterministic policy which shows actions to take in each state
        gamma {float} -- The discounting factor

    Keyword Arguments:
        theta {float} -- Stop once the largest change in a sweep is below this.
        max_iter {int} -- Maximum number of sweeps. If the policy never reaches
            a terminal state and gamma is 1 (with a negative step reward), the
            values keep falling every sweep, so this cap is what guarantees
            termination.
    """
    n_rows, n_cols = grid.shape
    for _ in range(max_iter):
        delta = 0.0
        for i in range(n_rows):
            # Use the column count here; the old code reused shape[0].
            for j in range(n_cols):
                state = (i, j)
                next_state, reward = take_action(state, policy[state])
                # Bootstrap from the grid being evaluated, not a module-level global.
                new_value = reward + gamma * grid[next_state]
                delta = max(delta, abs(grid[state] - new_value))
                grid[state] = new_value
        if delta < theta:
            return

# Policy Improvement
Given a policy $\pi$, improve the policy using the Policy Improvement Theorem.

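For each state $s$, try every action $a$ and then follow $\pi$ afterwards. That is, compute $q_\pi(s, a) = r + \gamma\, v_\pi(s')$, where $s'$ and $r$ are the next state and reward from taking $a$ in $s$.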

Take the action which gives the maximum value from the above computation.

If the new maximum value is greater than $v_{\pi}(s)$, update the policy $\pi(s)$ with the action that gave the maximum value.

In [None]:
def policy_improvement(grid, policy, actions, gamma):
    """Greedily improve the policy in place using the policy improvement theorem.

    For every cell ``s`` this computes the one-step look-ahead value
    ``q(s, a) = r + gamma * grid[s']`` for each action ``a``, where
    ``(s', r) = take_action(s, a)``. ``policy[s]`` switches to the best
    action only when its value strictly exceeds the current state value
    ``grid[s]``. If ``grid`` holds the converged values of the current policy,
    the current action therefore wins ties, which stops the caller's loop from
    flip-flopping between equally good actions.

    Arguments:
        grid {np.ndarray} -- The grid containing the value of the states
        policy {np.ndarray} -- The deterministic policy which shows actions to
            take in each state (modified in place)
        actions {dict} -- A dictionary showing how to change the state based on
            the action (only its keys are iterated here)
        gamma {float} -- The discounting factor

    Returns:
        bool -- True if no action changed (the policy is stable), else False
    """
    policy_stable = True
    n_rows, n_cols = grid.shape
    for i in range(n_rows):
        # Use the column count here; the old code reused shape[0].
        for j in range(n_cols):
            state = (i, j)
            old_action = policy[state]

            q_values = []
            for action in actions:
                next_state, reward = take_action(state, action)
                q_values.append((action, reward + gamma * grid[next_state]))
            # On ties, max() returns the first action in the iteration order of ``actions``.
            best_action, best_value = max(q_values, key=itemgetter(1))

            if best_value > grid[state]:
                policy[state] = best_action
            if policy[state] != old_action:
                policy_stable = False
    return policy_stable

# Perform Policy Iteration

Evaluate -> Improve -> Evaluate -> Improve ...

In [None]:
# Alternate full policy evaluation with greedy improvement until one
# improvement pass leaves every action unchanged.
max_iter = 10000
converged = False
for i in range(max_iter):
    policy_evaluation(GRID, POLICY, gamma=GAMMA)
    converged = policy_improvement(GRID, POLICY, ACTIONS, gamma=GAMMA)
    if converged:
        break

if converged:
    print(f'Stable Policy Found! Took {i+1} iterations.')
else:
    print('Maximum iterations exceeded')

In [None]:
fig = draw_grid(POLICY)
# savefig does not create missing directories, so make sure images/ exists
# before writing the final policy figure.
Path('images').mkdir(exist_ok=True)
fig.savefig('images/after_iteration.png')