In [1]:
from env import create_standard_grid, create_custom_grid_1, create_custom_grid_2
from algorithms import iterative_policy_evaluation
from algorithms import compute_policy_from_values


In [2]:
# Grid-world environment shared by every cell below.  Later cells write state
# values into this object in place (via gw.set_value), so what it holds depends
# on which cells have already been executed.
# NOTE(review): presumably create_custom_grid_1() returns a fresh grid on each
# call, so re-running this cell resets that state -- confirm against env.py.
gw = create_custom_grid_1()

In [3]:
# Initial policy: maps each grid state, a 2-tuple of indices, to an action name.
# The empty string marks states that have no action assigned.
# NOTE(review): judging by the recorded print_policy output, the first index
# selects the printed row (index 0 is printed last) and the second the column
# -- confirm against env.py.
policy = {
    (row, col): action
    for row, actions in enumerate([
        ('right', 'right', 'up', 'up'),
        ('up', '', '', ''),
        ('right', 'right', 'down', 'down'),
        ('right', 'right', 'down', 'down'),
    ])
    for col, action in enumerate(actions)
}

In [34]:
# Policy iteration, following the pseudocode on p. 80 of Sutton & Barto,
# "Reinforcement Learning: An Introduction" (2nd ed.).
def policy_iteration(gw, policy, gamma=0.9, epsilon=0.001):
    """Alternate policy evaluation and policy improvement until the policy is stable.

    Each round evaluates the current policy with
    ``iterative_policy_evaluation`` and then derives a new policy from the
    resulting values with ``compute_policy_from_values``.  The loop ends when
    the new policy assigns the same action as the previous one to every state
    key of the previous policy.

    Args:
        gw: grid-world object; it is handed unchanged to both helpers.
        policy: mapping from state to action used as the starting policy.
            This function only rebinds its local name, so it does not itself
            modify the mapping passed in.
        gamma: discount factor forwarded to both helpers.
        epsilon: convergence threshold forwarded to the evaluation step.

    Returns:
        The stable policy produced by the final improvement step.  Earlier
        versions returned ``None``; callers that ignore the result are
        unaffected.
    """
    while True:
        # Policy evaluation: refresh the values for the current policy.
        iterative_policy_evaluation(gw, policy, gamma, epsilon)
        # Policy improvement: derive a policy from the refreshed values.
        new_policy = compute_policy_from_values(gw, gamma)
        # Compare over the previous policy's keys only, as before.  all() is
        # True for an empty policy, where the old flag loop left
        # policy_stable unbound and raised UnboundLocalError.
        policy_stable = all(
            policy[state] == new_policy[state] for state in policy
        )
        policy = new_policy
        if policy_stable:
            return policy

In [35]:
def iterative_policy_evaluation(gw, policy, gamma=0, epsilon=0.001):
    """Evaluate ``policy`` on ``gw`` by sweeping over its states until values settle.

    Every sweep iterates over the nodes of ``gw``.  For each state that is
    neither terminal nor a barrier, the stored value is replaced with
    ``reward + gamma * value_at_destination`` for the action the policy
    prescribes.  Values are written back as soon as they are computed; no
    separate copy of the previous sweep's values is kept.

    This definition shadows the function of the same name imported from
    ``algorithms`` at the top of the notebook.

    Args:
        gw: grid-world object.  It must be iterable over nodes exposing
            ``.state`` and provide ``is_terminal``, ``is_barrier``,
            ``get_value``, ``set_value``, ``get_reward_for_action`` and
            ``get_value_at_destination``.
        policy: mapping from state to action.
        gamma: discount factor applied to the destination state's value.
            NOTE(review): the default of 0 ignores the destination value
            entirely, whereas ``policy_iteration`` defaults to 0.9 --
            confirm which default is intended.
        epsilon: sweeping stops once the largest absolute change made during
            a sweep is below this threshold.

    Returns:
        None.  Results are stored on ``gw`` through ``set_value``.
    """
    converged = False
    while not converged:
        largest_delta = 0
        for cell in gw:
            s = cell.state
            # Terminal and barrier states keep whatever value they have.
            if gw.is_terminal(s) or gw.is_barrier(s):
                continue
            previous = gw.get_value(s)
            chosen = policy[s]
            immediate = gw.get_reward_for_action(s, chosen)
            downstream = gw.get_value_at_destination(s, chosen)
            updated = immediate + gamma * downstream
            gw.set_value(s, updated)
            # Track the largest change seen in this sweep.
            delta = abs(updated - previous)
            if delta > largest_delta:
                largest_delta = delta
        # A full sweep is done; stop once nothing moved by epsilon or more.
        converged = largest_delta < epsilon

In [36]:
# Discount factor used for every evaluation and improvement step in this cell.
# It is passed explicitly because the iterative_policy_evaluation defined in
# this notebook defaults to gamma=0, which would report undiscounted values
# that disagree with policy_iteration's gamma=0.9.
GAMMA = 0.9

print("Initial Policy")
gw.print_policy(policy)
print("")

# note: this execution of iterative policy evaluation is not part
# of the policy iteration algorithm.  It is for the purpose of
# displaying the values associated with the input policy.
# It writes values into gw in place, so re-running this cell starts
# from whatever values gw already holds.
iterative_policy_evaluation(gw, policy, GAMMA)
print("Initial Policy Values")
gw.print_values()

# run policy iteration algorithm (this updates the values stored on gw)
policy_iteration(gw, policy, GAMMA)
# compute policy from the resulting values, using the same discount factor
# that policy_iteration passed to this helper
new_policy = compute_policy_from_values(gw, GAMMA)

# print new policy and values
print("")
print("New Policy")
gw.print_policy(new_policy)
print("")
print("New Policy Values")
gw.print_values()
print("")

Initial Policy
-------------------------------------
|  Right |  Right |   Down |   Down |
-------------------------------------
|  Right |  Right |   Down |   Down |
-------------------------------------
|     Up |        |        |        |
-------------------------------------
|  Right |  Right |     Up |     Up |
-------------------------------------

Initial Policy Values
-------------------------------------
|   0.73 |   0.81 |   0.90 |  -0.90 |
-------------------------------------
|   0.81 |   0.90 |   1.00 |  -1.00 |
-------------------------------------
|   0.73 |   0.00 |   0.00 |   0.00 |
-------------------------------------
|   0.81 |   0.90 |   1.00 |  -1.00 |
-------------------------------------

New Policy
-------------------------------------
|  Right |  Right |   Down |   Left |
-------------------------------------
|  Right |  Right |   Down |   Left |
-------------------------------------
|   Down |        |        |        |
-------------------------------------
