In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import gym

import gym_gridworld
import dynamic_programming

In [21]:
def print_policy(policy, terminal_states):
    """Print a flattened square-grid policy as rows of arrow glyphs.

    ``policy`` is read row-major as an n-by-n grid with
    ``n = int(sqrt(len(policy)))``; entries past ``n * n`` are not printed.
    A cell whose flat index is in ``terminal_states`` is drawn as a filled
    square, every other cell as the arrow for its action code
    (0 up, 1 right, 2 down, 3 left).

    Raises:
        KeyError: if a non-terminal cell holds an action code other than 0-3.
    """
    glyph_for_action = {0: '\u2191', 1: '\u2192', 2: '\u2193', 3: '\u2190'}
    terminal_glyph = '\u25A0'
    side = int(np.sqrt(len(policy)))
    for row in range(side):
        first_cell = row * side
        # One print per row, so rows already emitted stay on screen if a
        # later row hits an unknown action code.
        print("".join(
            terminal_glyph if cell in terminal_states else glyph_for_action[policy[cell]]
            for cell in range(first_cell, first_cell + side)
        ))
    
def print_V_PI(V, PI, terminal_states):
    """Print the value function as a square grid, then the policy.

    ``V`` is rounded to one decimal and reshaped to n-by-n with
    ``n = int(sqrt(len(V)))``; the reshape raises if ``len(V)`` is not a
    perfect square. ``PI`` and ``terminal_states`` are handed to
    ``print_policy`` unchanged.
    """
    side = int(np.sqrt(len(V)))
    value_grid = np.array(V).round(1).reshape(side, side)
    print("\nV\n", value_grid)
    print("\nPI")
    print_policy(PI, terminal_states)
    
def compute_policy_iteration(env, dicount_factor, modified_max_k=np.inf):
    """Run policy iteration on ``env`` and print a report of the result.

    Args:
        env: environment forwarded to ``dynamic_programming.policy_iteration``;
            its ``terminal_states`` attribute is used when printing the policy.
        dicount_factor: discount factor, forwarded unchanged to the solver.
        modified_max_k: forwarded as the solver's ``modified_max_k`` keyword.
            The default (infinity) is reported as "POLICY ITERATION", any
            other value as "MODIFIED POLICY ITERATION".
            NOTE(review): presumably caps the sweeps per policy evaluation;
            confirm in ``dynamic_programming``.

    Prints the heading, the number of outer iterations (``len(k)``) together
    with the per-evaluation counts ``k`` returned by the solver, and then the
    value function and policy via ``print_V_PI``.
    """
    V, PI, k = dynamic_programming.policy_iteration(
        env, dicount_factor, modified_max_k=modified_max_k
    )
    # Use np.inf: the capitalised np.Inf alias was removed in NumPy 2.0 and
    # would raise AttributeError as soon as this function is defined.
    is_unbounded = modified_max_k == np.inf
    print("POLICY ITERATION" if is_unbounded else "MODIFIED POLICY ITERATION")
    print(f"Policy found in {len(k)} iterations, where each policy evaluation lasted for k = {k}")
    print_V_PI(V, PI, env.terminal_states)
    
def compute_value_iteration(env, dicount_factor):
    """Run value iteration on ``env`` and print a report of the result.

    Calls ``dynamic_programming.value_iteration(env, dicount_factor)``, prints
    a heading and the iteration count it returns, then prints the value
    function and policy via ``print_V_PI`` using ``env.terminal_states``.
    """
    values, policy, n_iterations = dynamic_programming.value_iteration(env, dicount_factor)
    print("VALUE ITERATION")
    print(f"Policy found in {n_iterations} iterations")
    print_V_PI(values, policy, env.terminal_states)


# Build one environment and one discount factor, shared by all three runs below.
# NOTE(review): `n` and `p_action_works` are interpreted by the gym_gridworld
# package -- presumably grid side length and probability that an action
# succeeds; confirm against that package.
env = gym.make("gridworld-v0", n=10, p_action_works=0.9)
dicount_factor = 0.9

# Each call prints its own report. The first uses the default modified_max_k,
# the second passes modified_max_k=2, the third runs value iteration.
compute_policy_iteration(env, dicount_factor)
compute_policy_iteration(env, dicount_factor, modified_max_k=2)
compute_value_iteration(env, dicount_factor)

POLICY ITERATION
Policy found in 11 iterations, where each policy evaluation lasted for k = [67, 17, 15, 13, 12, 7, 2, 1, 1, 1, 1]

V
 [[0.  4.2 4.8 5.5 6.1 6.9 7.8 8.8 9.9 0. ]
 [3.4 3.9 4.4 4.9 5.5 6.2 7.  7.8 8.8 9.9]
 [3.1 3.5 3.9 4.4 4.9 5.5 6.2 7.  7.8 8.8]
 [2.8 3.1 3.5 3.9 4.4 4.9 5.5 6.2 7.  7.8]
 [2.5 2.8 3.1 3.5 3.9 4.4 4.9 5.5 6.2 6.9]
 [2.2 2.5 2.8 3.1 3.5 3.9 4.4 4.9 5.5 6.1]
 [2.  2.2 2.5 2.8 3.1 3.5 3.9 4.4 4.9 5.5]
 [1.8 2.  2.2 2.5 2.8 3.1 3.5 3.9 4.4 4.8]
 [1.6 1.7 2.  2.2 2.5 2.8 3.1 3.5 3.9 4.3]
 [1.4 1.6 1.8 2.  2.2 2.5 2.8 3.1 3.5 3.8]]

PI
■→→→→→→→→■
→→→→→→→→→↑
→→→→→→→→↑↑
↑→→→→→→↑↑↑
↑↑→→→→↑↑↑↑
↑↑↑↑→↑↑↑↑↑
↑↑↑→→↑↑↑↑↑
↑↑→→→↑↑↑↑↑
↑→→→→→↑↑↑↑
→→→→→→→→↑↑
MODIFIED POLICY ITERATION
Policy found in 11 iterations, where each policy evaluation lasted for k = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

V
 [[0.  4.2 4.8 5.5 6.1 6.9 7.8 8.8 9.9 0. ]
 [3.4 3.9 4.4 4.9 5.5 6.2 7.  7.8 8.8 9.9]
 [3.1 3.5 3.9 4.4 4.9 5.5 6.2 7.  7.8 8.8]
 [2.8 3.1 3.5 3.9 4.4 4.9 5.5 6.2 7.  7.8]
 [2.5 2.8