In [1]:
import gym
import numpy as np

In [2]:
# Custom 8x8 FrozenLake layout: 'S' = start, 'H' = hole, 'G' = goal,
# '-' = ordinary frozen tile.
MAP = [
    'S-------',
    '--------',
    '---H----',
    '-----H--',
    '---H----',
    '-HH---H-',
    '-H--H-H-',
    '---H---G',
]
# Derive the grid shape from the layout itself so the two can never disagree.
MAP_SIZE = (len(MAP), len(MAP[0]))
assert all(len(row) == MAP_SIZE[1] for row in MAP), 'MAP rows must all have the same length'
MAP_STRING = ''.join(MAP)
# FrozenLake action ids -> arrows used when printing a policy.
ACTION_MAPPING = {0: '←', 1: '↓', 2: '→', 3: '↑'}

# Hand the custom layout to the env explicitly via `desc`, instead of creating a
# throwaway env first and relying on a patched global registry.
ENV = gym.make('FrozenLake8x8-v0', desc=MAP, is_slippery=False)
# Also patch gym's built-in '8x8' layout (the module is loaded by the make call
# above) so any later gym.make('FrozenLake8x8-v0') in this notebook sees the same map.
gym.envs.toy_text.frozen_lake.MAPS['8x8'] = MAP

ALL_STATE = range(ENV.nS)
ALL_ACTION = range(ENV.nA)
assert ENV.nS == len(MAP_STRING), 'env state count does not match MAP'

In [3]:
def print_state_value_func(V: np.ndarray, precision=2, shape=None):
    """Print a state-value function laid out on the map grid.

    Args:
        V: state values, one entry per state; its size must equal the
            product of ``shape``.
        precision: number of decimals to round to before printing.
        shape: grid shape to reshape ``V`` into; defaults to ``MAP_SIZE``.
    """
    if shape is None:
        shape = MAP_SIZE
    # np.round_ is a deprecated alias (removed in NumPy 2.0); np.round is equivalent.
    rounded = np.round(V, precision).reshape(shape)
    print(' V(s):\n', rounded, '\n')


def print_policy(policy: np.ndarray):
    """Print the greedy action of every cell, laid out on the map grid.

    Cells marked 'H' in MAP_STRING are drawn as '□'; every other cell shows
    the ACTION_MAPPING arrow for the argmax action of its policy row.
    """
    best_actions = np.argmax(policy, axis=1)
    cells = [
        '□' if MAP_STRING[state] == 'H' else ACTION_MAPPING[act]
        for state, act in enumerate(best_actions)
    ]
    grid = np.array(cells).reshape(MAP_SIZE)
    print(' Policy:\n', grid, '\n')

In [4]:
def evaluate_policy(env,
                    policy,
                    discount_factor = 0.9,
                    theta = 1e-6,
                    max_iteration = 9999,
                    verbose: bool = True) -> np.ndarray:
    """Iterative policy evaluation: estimate V^pi for a tabular environment.

    Sweeps over all states, applying the Bellman expectation backup in place,
    until no state's value changes by more than ``theta`` in a sweep, or
    ``max_iteration`` sweeps have run.

    Args:
        env: environment exposing ``nS``, ``nA`` and the transition table
            ``P[state][action] -> [(prob, next_state, reward, done), ...]``.
        policy: array indexed as ``policy[state][action]`` giving the
            probability of taking ``action`` in ``state``.
        discount_factor: gamma used in the backup.
        theta: convergence threshold on the per-state value change.
        max_iteration: upper bound on the number of sweeps.
        verbose: when True (default), print progress and the final V.

    Returns:
        1-D array of length ``env.nS`` with the estimated state values.
    """
    if verbose:
        print('     POLICY EVALUATION: START!\n')
    V = np.zeros(env.nS)
    n_sweeps = 0  # stays defined even if max_iteration < 1

    for n_sweeps in range(1, max_iteration + 1):
        is_converged = True

        # Iterate the passed env's state space (not a global), so any env works.
        for state in range(env.nS):
            state_value = 0.0
            # Expectation over the policy's actions and the env's transitions.
            for action, action_prob in enumerate(policy[state]):
                for state_prob, next_state, reward, _done in env.P[state][action]:
                    state_value += action_prob * state_prob * (reward + discount_factor * V[next_state])

            if abs(V[state] - state_value) > theta:
                is_converged = False
            # In-place update: later states in the same sweep see the new value.
            V[state] = state_value

        if is_converged:
            break

    if verbose:
        print_state_value_func(V)
        print(f'     POLICY EVALUATION: finished after ({n_sweeps}) iterations\n')
    return V

In [5]:
def policy_iteration(env,
                     discount_factor = 0.9,
                     max_iteration = 9999,
                     theta = 1e-6) -> (np.ndarray, np.ndarray):
    """Policy iteration for a tabular environment.

    Starts from the uniform random policy and alternates
      1. policy evaluation (``evaluate_policy``), and
      2. greedy policy improvement with respect to the evaluated V,
    until no state's greedy action changes or ``max_iteration`` rounds have run.

    Args:
        env: environment exposing ``nS``, ``nA`` and ``P[state][action]``.
        discount_factor: gamma, used for BOTH evaluation and improvement.
        max_iteration: upper bound on evaluation/improvement rounds.
        theta: convergence threshold forwarded to ``evaluate_policy``.

    Returns:
        (policy, V): the policy array of shape (nS, nA) (one-hot per state
        once at least one round has run) and its state-value estimate.
    """
    print('POLICY ITERATION: START!\n')
    # Init policy with equal prob for all actions
    policy = np.ones([env.nS, env.nA]) / env.nA
    V = np.zeros(env.nS)  # returned unchanged if no round runs
    i = 0

    for i in range(1, max_iteration + 1):
        is_stable = True

        # 1) Evaluate the current policy with the same gamma used for improvement.
        V = evaluate_policy(env, policy, discount_factor=discount_factor, theta=theta)

        # 2) Improve: make the policy greedy w.r.t. the one-step lookahead Q(s, a).
        for state in range(env.nS):
            current_action = np.argmax(policy[state])

            Q = np.zeros(env.nA)
            for action in range(env.nA):
                for prob, next_state, reward, _done in env.P[state][action]:
                    Q[action] += prob * (reward + discount_factor * V[next_state])

            best_action = np.argmax(Q)
            if current_action != best_action:
                is_stable = False

            # Deterministic (one-hot) greedy policy for this state.
            policy[state] = 0.0
            policy[state, best_action] = 1.0

        print_policy(policy)
        print('======================================================\n')

        if is_stable:
            break
    print(f'POLICY ITERATION: finished after ({i}) iterations\n')
    return policy, V

In [8]:
# Run policy iteration on the custom deterministic map; keep the final policy and its V.
policy, V = policy_iteration(env=ENV)

POLICY ITERATION: START!

     POLICY EVALUATION: START!

 V(s):
 [[0.   0.   0.   0.   0.   0.   0.   0.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.01 0.02]
 [0.   0.   0.   0.   0.   0.01 0.01 0.04]
 [0.   0.   0.   0.   0.   0.01 0.   0.12]
 [0.   0.   0.   0.   0.   0.03 0.   0.36]
 [0.   0.   0.   0.   0.04 0.12 0.36 0.  ]] 

     POLICY EVALUATION: finished after (35) iterations

 Policy:
 [['→' '→' '→' '→' '→' '→' '↓' '↓']
 ['→' '→' '→' '→' '→' '→' '↓' '↓']
 ['→' '↑' '↑' '□' '→' '→' '↓' '↓']
 ['→' '→' '→' '→' '↓' '□' '→' '↓']
 ['↑' '↑' '↑' '□' '→' '→' '→' '↓']
 ['↑' '□' '□' '→' '→' '↓' '□' '↓']
 ['↑' '□' '→' '↑' '□' '↓' '□' '↓']
 ['→' '→' '↑' '□' '→' '→' '→' '←']] 


     POLICY EVALUATION: START!

 V(s):
 [[0.25 0.28 0.31 0.35 0.39 0.43 0.48 0.53]
 [0.28 0.31 0.35 0.39 0.43 0.48 0.53 0.59]
 [0.25 0.28 0.31 0.   0.48 0.53 0.59 0.66]
 [0.35 0.39 0.43 0.48 0.53 0.   0.66 0.73]
 [0.31 0.35 0.39 0.   0.5