In [1]:
import numpy as np
import math

In [2]:
def compute_next_state(current, action, dim):
    """Return the cell reached by taking ``action`` from ``current`` on a ``dim`` x ``dim`` grid.

    Moves that would step off the grid leave the agent where it is.

    Args:
        current: ``(row, col)`` of the starting cell (expected to be inside the grid).
        action: one of ``'up'``, ``'down'``, ``'left'``, ``'right'``.
        dim: side length of the square grid.

    Returns:
        ``(row, col)`` of the resulting cell.

    Raises:
        ValueError: if ``action`` is not one of the four supported moves.
    """
    # Row/column offset for each action; 'up' decreases the row index.
    offsets = {
        'up': (-1, 0),
        'down': (1, 0),
        'left': (0, -1),
        'right': (0, 1),
    }
    if action not in offsets:
        # Fail loudly here rather than returning None, which callers would only
        # notice later as an unhelpful "cannot unpack NoneType" error.
        raise ValueError(f"unknown action {action!r}; expected one of {sorted(offsets)}")

    x, y = current
    dx, dy = offsets[action]
    new_x, new_y = x + dx, y + dy
    if 0 <= new_x < dim and 0 <= new_y < dim:
        return (new_x, new_y)
    # Bumping into the wall keeps the agent in place.
    return (x, y)

In [3]:
'''
Iterative Policy Evaluation: the algorithm from Chapter 4.1 (Policy Evaluation)
of the book, applied to the 4x4 gridworld.
'''
def eval_policy(grid, grid_policy, theta, terminal_states=None, reward=-1.0, gamma=1.0):
    """Iterative policy evaluation with in-place sweeps over a square gridworld.

    Repeatedly sweeps every non-terminal cell, replacing its value with the
    expected one-step return under ``grid_policy``::

        V(s) = sum_a pi(a|s) * (reward + gamma * V(s'))

    where ``s'`` is the single (deterministic) successor from
    ``compute_next_state``. Sweeping stops once the largest change made during
    a sweep is below ``theta``.

    Note: with ``gamma == 1`` the policy must reach a terminal cell from every
    cell; otherwise values decrease without bound and the loop never ends.

    Args:
        grid: 2-D float array of state values; updated in place.
        grid_policy: dict mapping ``(row, col)`` -> ``{action: probability}``.
        theta: positive convergence threshold on the per-sweep max change.
        terminal_states: cells whose value is never updated. Defaults to the
            top-left and bottom-right corners of ``grid``.
        reward: reward received on every transition.
        gamma: discount factor applied to the successor's value.

    Returns:
        ``grid`` (the same array object, updated in place).

    Raises:
        ValueError: if ``theta`` is not positive (the loop could never stop).
    """
    if theta <= 0:
        # delta can reach exactly 0 once values converge, and 0 < 0 is never true.
        raise ValueError(f"theta must be positive, got {theta!r}")

    n_rows, n_cols = grid.shape
    if terminal_states is None:
        terminal_states = [(0, 0), (n_rows - 1, n_cols - 1)]
    # compute_next_state takes a single side length, so the grid is treated as square.
    dim = n_rows

    while True:
        delta = 0.0
        for state in np.ndindex(grid.shape):
            if state in terminal_states:
                continue

            old_value = float(grid[state])
            value = 0.0
            for action, prob in grid_policy[state].items():
                next_x, next_y = compute_next_state(state, action, dim)
                # Deterministic moves: the successor has probability 1.
                value += prob * (reward + gamma * grid[next_x, next_y])

            delta = max(delta, abs(value - old_value))
            # In-place update: cells later in this sweep already see the new value.
            grid[state] = value

        if delta < theta:
            break
    return grid

In [4]:
def compute_max_action(action_value_map, abs_tol=1e-9):
    """Build a greedy policy that splits probability evenly among the best actions.

    Args:
        action_value_map: dict mapping action -> estimated action value.
            It is not modified.
        abs_tol: an action whose value is within ``abs_tol`` of the maximum is
            counted as tied for best. This keeps float noise from approximate
            evaluation from breaking a genuine tie. Pass ``0.0`` for exact equality.

    Returns:
        A new dict with the same keys: ``1.0 / k`` for each of the ``k`` best
        actions and ``0`` for every other action.

    Raises:
        ValueError: if ``action_value_map`` is empty (raised by ``max``).
    """
    max_value = max(action_value_map.values())

    best_actions = {
        action
        for action, value in action_value_map.items()
        if math.isclose(value, max_value, rel_tol=0.0, abs_tol=abs_tol)
    }
    # An empty set only happens with NaN values; give every action 0 rather than dividing by zero.
    share = 1.0 / len(best_actions) if best_actions else 0

    return {
        action: (share if action in best_actions else 0)
        for action in action_value_map
    }

In [5]:
def printif(state, msg):
    """Debug helper: print ``msg`` only when ``state`` is the watched cell (1, 2)."""
    watched_state = (1, 2)
    if state != watched_state:
        return
    print(msg)

In [6]:
def improve_policy(grid, grid_policy, terminal_states=None, reward=-1.0, gamma=1.0):
    """Policy-improvement step: make every non-terminal cell greedy w.r.t. ``grid``.

    For each non-terminal cell, every action in its current policy is scored by
    the one-step lookahead ``reward + gamma * V(s')``. ``compute_max_action``
    then spreads probability evenly over the best-scoring actions. All cells are
    improved in one pass, as in standard policy iteration, before the caller
    re-evaluates.

    Args:
        grid: 2-D array of state values (e.g. the result of ``eval_policy``); only read.
        grid_policy: dict mapping ``(row, col)`` -> ``{action: probability}``.
            Entries whose policy changes are replaced in place.
        terminal_states: cells to skip. Defaults to the top-left and
            bottom-right corners of ``grid``.
        reward: reward received on every transition.
        gamma: discount factor applied to the successor's value.

    Returns:
        True if no cell's policy changed (the policy is stable), else False.
    """
    n_rows, n_cols = grid.shape
    if terminal_states is None:
        terminal_states = [(0, 0), (n_rows - 1, n_cols - 1)]
    # compute_next_state takes a single side length, so the grid is treated as square.
    dim = n_rows

    policy_stable = True
    for state in np.ndindex(grid.shape):
        if state in terminal_states:
            continue

        old_policy = grid_policy[state]

        # One-step lookahead value of each action available in this cell.
        action_value_map = {}
        for action in old_policy:
            next_x, next_y = compute_next_state(state, action, dim)
            action_value_map[action] = reward + gamma * grid[next_x, next_y]

        new_policy = compute_max_action(action_value_map)

        # This differs slightly from the book, which compares the old action with the
        # argmax action. Policies here are stochastic (ties share probability), so the
        # whole per-state distribution is compared instead.
        if old_policy != new_policy:
            grid_policy[state] = new_policy
            # Keep sweeping so every cell is improved before the next evaluation,
            # instead of stopping at the first change.
            policy_stable = False

    return policy_stable

In [7]:
# Gridworld setup: a 4x4 grid of state values, all starting at 0.
grid = np.zeros((4, 4))
# Policy evaluation stops once a full sweep changes no value by more than this.
theta = 0.05

# Equiprobable random policy: the starting point for every cell.
policy = {
    'up': 0.25,
    'down': 0.25,
    'left': 0.25,
    'right': 0.25
}

# Give every cell its OWN copy of the starting policy. Assigning the same dict
# object to every key would alias them, so an in-place edit to one cell's
# policy would silently change every cell.
grid_policy = {state: dict(policy) for state in np.ndindex(grid.shape)}

# Policy iteration: alternate evaluation and greedy improvement until the
# improvement step reports that no cell's policy changed.
policy_stable = False
while not policy_stable:
    grid = eval_policy(grid, grid_policy, theta)
    policy_stable = improve_policy(grid, grid_policy)

# Sanity check: every cell's action probabilities still form a distribution.
assert all(abs(sum(p.values()) - 1.0) < 1e-9 for p in grid_policy.values())

In [8]:
# Display the final (stable) policy: for each (row, col) cell, the probability of each action.
grid_policy

{(0, 0): {'up': 0.25, 'down': 0.25, 'left': 0.25, 'right': 0.25},
 (0, 1): {'up': 0, 'down': 0, 'left': 1.0, 'right': 0},
 (0, 2): {'up': 0, 'down': 0, 'left': 1.0, 'right': 0},
 (0, 3): {'up': 0, 'down': 0.5, 'left': 0.5, 'right': 0},
 (1, 0): {'up': 1.0, 'down': 0, 'left': 0, 'right': 0},
 (1, 1): {'up': 0.5, 'down': 0, 'left': 0.5, 'right': 0},
 (1, 2): {'up': 0.25, 'down': 0.25, 'left': 0.25, 'right': 0.25},
 (1, 3): {'up': 0, 'down': 1.0, 'left': 0, 'right': 0},
 (2, 0): {'up': 1.0, 'down': 0, 'left': 0, 'right': 0},
 (2, 1): {'up': 0.25, 'down': 0.25, 'left': 0.25, 'right': 0.25},
 (2, 2): {'up': 0, 'down': 0.5, 'left': 0, 'right': 0.5},
 (2, 3): {'up': 0, 'down': 1.0, 'left': 0, 'right': 0},
 (3, 0): {'up': 0.5, 'down': 0, 'left': 0, 'right': 0.5},
 (3, 1): {'up': 0, 'down': 0, 'left': 0, 'right': 1.0},
 (3, 2): {'up': 0, 'down': 0, 'left': 0, 'right': 1.0},
 (3, 3): {'up': 0.25, 'down': 0.25, 'left': 0.25, 'right': 0.25}}