In [44]:
from pprint import pprint
import numpy as np
import gym
from gym.envs.toy_text import frozen_lake

In [45]:
ENV = gym.make('FrozenLake-v0')

MAP =  ['S-H-',
        '----',
        '--H-',
        'H--G']
frozen_lake.MAPS['4x4'] = MAP
ENV = gym.make('FrozenLake-v0', is_slippery=False)

ACTION_MAPPING = { 0: '←', 1: '↓', 2: '→', 3: '↑'}

In [46]:
def print_state_value_func(V):
    print(' V(s):')
    print(np.round_(V, 2).reshape(8, 8), '\n')

def print_policy(policy: np.ndarray):
    print(' POLICY: ')
    temp_policy = np.argmax(policy, axis = 1)
    temp_map = ''.join(MAP)
    string_map = list()
    for idx, action in enumerate(temp_policy):
        if temp_map[idx] == 'H':
            string_map.append('□')
        else:
            string_map.append(ACTION_MAPPING[action])
    string_map = np.array(string_map).reshape((4, 4))
    print(string_map, '\n')

In [47]:
def sample_episode(gym_env, policy):
    episode = list()
    state = gym_env.reset()
    terminated = False
    # Sample an episode
    while not terminated:
        action_list_prob = policy[state].tolist()
        action = np.random.choice(range(gym_env.nA), p=action_list_prob)

        next_state, reward, terminated, info = gym_env.step(action)
        episode.append((state, action, reward))
        state = next_state
    return episode

In [55]:
def evaluate_policy_using_Monte_Carlo(gym_env,
                                      policy,
                                      MC_type = 'first_visit',
                                      discount_factor = 1.0,
                                      incremental = False,
                                      max_iter = 29999):
    print('START EVALUATING POLICY: ')
    Q = np.zeros((gym_env.nS, gym_env.nA))
    N = np.zeros_like(Q)
    G = np.zeros_like(Q)

    for _ in range(max_iter):
        episode = sample_episode(gym_env, policy)
        
        # Calulate G and N for Monte Carlo first visit
        if MC_type == 'first_visit':
            visited = np.zeros_like(Q, dtype=int)
            for idx, (state, action, _) in enumerate(episode):
                if visited[state][action] == 1:
                    continue
                N[state][action] += 1
                for i, (_, _, reward) in enumerate(episode[idx:]):
                    G[state][action] += (discount_factor**i) * reward
                visited[state][action] == 1
                
        # Calulate G and N for Monte Carlo every visit
        elif MC_type == 'every_visit':
            for idx, (state, action, _) in enumerate(episode):
                N[state][action] += 1
                for i, (_, _, reward) in enumerate(episode[idx:]):
                    G[state][action] += (discount_factor**i) * reward
        else:
            assert False

    if incremental:
        alpha = np.divide(1, N, out = np.zeros_like(N), where = N!=0)
        Q += alpha * (G - V)
    else:
        Q = np.divide(G, N, out = np.zeros_like(N), where = N!=0)

    return Q

In [56]:
def policy_improvement_loop(gym_env,
                            MC_type = 'first_visit',
                            discount_factor = 1.0,
                            incremental = False,
                            max_iter = 9999):
    
    policy = np.ones([gym_env.nS, gym_env.nA]) / gym_env.nA
    policy_greedy = np.copy(policy)
    
    for k in range(1, max_iter + 1):
        is_stable = True
        Q = evaluate_policy_using_Monte_Carlo(gym_env, policy, MC_type, discount_factor, incremental)
        # print('Q: \n', Q)

        eps_greedy = 1.0 / k
        
        for state in range(gym_env.nS):
            current_action_prob = policy[state]
            
            if np.random.rand() < eps_greedy:
                policy[state] = policy_greedy[state]
            else:
                best_action = np.argmax(Q[state])
                policy[state] = np.zeros(gym_env.nA)
                policy[state][best_action] = 1.0
                
            if not np.all(np.equal(policy[state], current_action_prob)):
                    is_stable = False

        print_policy(policy)
#         if is_stable:
#                 print(f'POLICY ITERATION: converged after {i} iterations\n')
#                 return policy, Q
    
    print(f'POLICY ITERATION:  reached max number of iterations ({max_iter})\n')
    return policy, Q

In [57]:
policy, Q = policy_improvement_loop(ENV,
                                    MC_type = 'first_visit',
                                    incremental = False,
                                    max_iter = 20)
# print_policy(policy)

START EVALUATING POLICY: 
 POLICY: 
[['←' '←' '□' '←']
 ['←' '←' '←' '←']
 ['←' '←' '□' '←']
 ['□' '←' '←' '←']] 

START EVALUATING POLICY: 
 POLICY: 
[['↓' '↓' '□' '←']
 ['→' '←' '←' '←']
 ['←' '↓' '□' '←']
 ['□' '←' '←' '←']] 

START EVALUATING POLICY: 
 POLICY: 
[['↓' '↓' '□' '↓']
 ['←' '↓' '→' '↓']
 ['←' '↓' '□' '←']
 ['□' '→' '→' '←']] 

START EVALUATING POLICY: 
 POLICY: 
[['←' '←' '□' '←']
 ['←' '↓' '←' '←']
 ['←' '↓' '□' '←']
 ['□' '←' '→' '←']] 

START EVALUATING POLICY: 
 POLICY: 
[['↓' '←' '□' '←']
 ['←' '↓' '←' '←']
 ['←' '←' '□' '←']
 ['□' '→' '←' '←']] 

START EVALUATING POLICY: 
 POLICY: 
[['↓' '←' '□' '←']
 ['→' '↓' '←' '←']
 ['←' '↓' '□' '←']
 ['□' '→' '→' '←']] 

START EVALUATING POLICY: 
 POLICY: 
[['↓' '←' '□' '←']
 ['←' '↓' '←' '←']
 ['←' '↓' '□' '←']
 ['□' '←' '→' '←']] 

START EVALUATING POLICY: 
 POLICY: 
[['↓' '←' '□' '←']
 ['→' '↓' '←' '←']
 ['←' '↓' '□' '←']
 ['□' '→' '→' '←']] 

START EVALUATING POLICY: 
 POLICY: 
[['↓' '←' '□' '←']
 ['→' '↓' '←' '←']
 ['←' 