In [1]:
import numpy as np
import pprint
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv

In [2]:
pp = pprint.PrettyPrinter(indent=2)
env = GridworldEnv()

In [19]:
def value_iteration(env, theta=0.0001, discount_factor=1.0):
    """
    Value Iteration Algorithm.
    
    Args:
        env: OpenAI environment. env.P represents the transition probabilities of the environment.
        theta: Stopping threshold. If the value of all states changes less than theta
            in one iteration we are done.
        discount_factor: lambda time discount factor.
        
    Returns:
        A tuple (policy, V) of the optimal policy and the optimal value function.        
    """
    
    
    V = np.zeros(env.nS)
    
    while True:
        delta = 0
        for s in range(env.nS):
            old_v = V[s]
            
            action_values = np.zeros(env.nA)
            for a in range(env.nA):
                for prob, reward, next_state, done in env.P[s][a]:                    
                    # full backup
                    action_values[a] = prob * (reward + discount_factor * V[next_state])
            print(action_values)
            # iterate value function
            V[s] = np.max(action_values)            
            delta = max(delta, np.abs(V[s] - old_v))
        
        if delta < theta:
            print(V)
            # ignoring ties, iterate policy greedily
            break
    
    policy = np.zeros([env.nS, env.nA])
    
    for s in range(env.nS):
        action_values = np.zeros(env.nA)
        for a in range(env.nA):
            for prob, reward, next_state, done in env.P[s][a]:                    
                # full backup
                action_values[a] = prob * (reward + discount_factor * V[next_state])        
        
        best_a = np.argmax(action_values)
        policy[s] = np.eye(env.nA)[best_a]
    # Implement!
    return policy, V

In [20]:
policy, v = value_iteration(env)

print("Policy Probability Distribution:")
print(policy)
print("")

print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
print(np.reshape(np.argmax(policy, axis=1), env.shape))
print("")

print("Value Function:")
print(v)
print("")

print("Reshaped Grid Value Function:")
print(v.reshape(env.shape))
print("")

[ 0.  0.  0.  0.]
[ 1.  2.  5.  0.]
[ 2.  3.  6.  1.]
[ 3.  3.  7.  2.]
[ 0.  5.  8.  4.]
[ 1.  6.  9.  4.]
[  2.   7.  10.   5.]
[  3.   7.  11.   6.]
[  4.   9.  12.   8.]
[  5.  10.  13.   8.]
[  6.  11.  14.   9.]
[  7.  11.  15.  10.]
[  8.  13.  12.  12.]
[  9.  14.  13.  12.]
[ 10.  15.  14.  13.]
[ 15.  15.  15.  15.]
[ 0.  0.  0.  0.]
[ 16.  17.  20.  15.]
[ 17.  18.  21.  16.]
[ 18.  18.  22.  17.]
[ 15.  20.  23.  19.]
[ 16.  21.  24.  19.]
[ 17.  22.  25.  20.]
[ 18.  22.  26.  21.]
[ 19.  24.  27.  23.]
[ 20.  25.  28.  23.]
[ 21.  26.  29.  24.]
[ 22.  26.  30.  25.]
[ 23.  28.  27.  27.]
[ 24.  29.  28.  27.]
[ 25.  30.  29.  28.]
[ 15.  15.  15.  15.]
[ 0.  0.  0.  0.]
[ 16.  17.  20.  15.]
[ 17.  18.  21.  16.]
[ 18.  18.  22.  17.]
[ 15.  20.  23.  19.]
[ 16.  21.  24.  19.]
[ 17.  22.  25.  20.]
[ 18.  22.  26.  21.]
[ 19.  24.  27.  23.]
[ 20.  25.  28.  23.]
[ 21.  26.  29.  24.]
[ 22.  26.  30.  25.]
[ 23.  28.  27.  27.]
[ 24.  29.  28.  27.]
[ 25.  30.  29.  28.



In [13]:
# Test the value function
expected_v = np.array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1,  0])
np.testing.assert_array_almost_equal(v, expected_v, decimal=2)

AssertionError: 
Arrays are not almost equal to 2 decimals

(mismatch 93.75%)
 x: array([  0.,  20.,  21.,  22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,
        30.,  28.,  29.,  30.,  15.])
 y: array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1,  0])