In [1]:
from pprint import pprint
import numpy as np
import gym
from gym.envs.toy_text import frozen_lake

In [28]:
# Custom FrozenLake layout: 'S' start, 'G' goal, 'H' hole; '-' marks an ordinary frozen tile.
MAP = ['S-------',
       '--------',
       '--------',
       '--------',
       '----H---',
       '--------',
       '-H------',
       '---H---G']

# Hand the custom layout to the environment via `desc` instead of overwriting the
# library-global frozen_lake.MAPS['8x8'] (which would leak into any other 8x8 env).
ENV = gym.make('FrozenLake8x8-v0', desc=MAP, is_slippery=False)

# FrozenLake action index -> arrow used when printing a policy.
ACTION_MAPPING = {0: '←', 1: '↓', 2: '→', 3: '↑'}

# Action sampling below uses NumPy's global RNG; seed it so the Monte Carlo
# estimates are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

In [29]:
def print_state_value_func(V, shape=None):
    """Print state-value estimates V(s) laid out as a grid, rounded to 2 decimals.

    Args:
        V: 1-D array-like of state values indexed by state number (row-major).
        shape: (rows, cols) used to lay out V. Defaults to a square grid inferred
            from len(V) (8x8 for a 64-state map).

    Raises:
        ValueError: if `shape` is omitted and len(V) is not a perfect square,
            or if `shape` does not match len(V).
    """
    values = np.asarray(V)
    if shape is None:
        side = int(round(np.sqrt(values.size)))
        if side * side != values.size:
            raise ValueError(
                f'cannot infer a square grid for {values.size} states; pass shape=')
        shape = (side, side)
    print(' V(s):')
    # np.round instead of np.round_, which is deprecated and removed in NumPy 2.0.
    print(np.round(values, 2).reshape(shape), '\n')

def print_policy(policy: np.ndarray, grid=None, action_symbols=None):
    """Print the greedy action of `policy` for every tile of the map.

    Args:
        policy: array of shape (nS, nA); row s holds the action probabilities
            (or scores) for state s. The displayed action is each row's argmax.
        grid: list of equal-length row strings describing the map; 'H' tiles are
            drawn as '□'. Defaults to the notebook-level MAP.
        action_symbols: mapping from action index to display symbol.
            Defaults to the notebook-level ACTION_MAPPING.

    Raises:
        ValueError: if the number of policy rows differs from the number of tiles.
    """
    if grid is None:
        grid = MAP
    if action_symbols is None:
        action_symbols = ACTION_MAPPING

    print(' POLICY: ')
    greedy_actions = np.argmax(policy, axis=1)
    flat_map = ''.join(grid)
    if len(flat_map) != len(greedy_actions):
        raise ValueError(
            f'policy has {len(greedy_actions)} states but the map has {len(flat_map)} tiles')
    # Holes end the episode, so their action is meaningless; draw a box instead.
    symbols = ['□' if tile == 'H' else action_symbols[action]
               for tile, action in zip(flat_map, greedy_actions)]
    # Derive the display shape from the map itself rather than assuming 8x8.
    print(np.array(symbols).reshape((len(grid), len(grid[0]))), '\n')

In [44]:
def evaluate_policy_MC_first_visit(gym_env,
                                   discount_factor = 1.0,
                                   max_iter = 9999,
                                   incremental = False,
                                   policy = None):
    """Estimate the state-value function V(s) with first-visit Monte Carlo.

    Samples `max_iter` episodes in `gym_env` following `policy`. For each state,
    it averages the discounted return observed from the state's first
    occurrence in each episode.

    Args:
        gym_env: environment exposing `nS`, `nA`, a `reset()` that returns a
            state index, and `step(action)` returning
            (next_state, reward, done, info).
        discount_factor: gamma used to discount future rewards.
        max_iter: number of episodes to sample.
        incremental: if True, update V after every first visit with the
            running mean V <- V + (G - V) / N. Otherwise, sum the returns and
            divide once at the end. Both yield the same estimate up to
            floating-point rounding.
        policy: array of shape (nS, nA) of action probabilities per state.
            Defaults to the equiprobable random policy.

    Returns:
        np.ndarray of shape (nS,). States never visited keep the value 0.
    """
    print('     POLICY EVALUATION: MONTE CARLO FIRST VISIT!\n')
    n_states, n_actions = gym_env.nS, gym_env.nA
    if policy is None:
        # Equiprobable random policy.
        policy = np.ones([n_states, n_actions]) / n_actions

    V = np.zeros(n_states)
    N = np.zeros(n_states, dtype=int)
    G_sum = np.zeros(n_states)

    for _ in range(max_iter):
        # Roll out one episode as (state, reward-received-after-acting) pairs.
        episode = []
        state = gym_env.reset()
        terminated = False
        while not terminated:
            action = np.random.choice(n_actions, p=policy[state])
            next_state, reward, terminated, _info = gym_env.step(action)
            episode.append((state, reward))
            state = next_state

        # Time step of each state's first occurrence in this episode.
        first_idx = {}
        for t, (s, _) in enumerate(episode):
            first_idx.setdefault(s, t)

        # Walk backwards so G_t = R_{t+1} + gamma * G_{t+1} costs O(T), not O(T^2).
        G = 0.0
        for t in reversed(range(len(episode))):
            s, reward = episode[t]
            G = reward + discount_factor * G
            if first_idx[s] != t:
                continue  # not the first visit: ignored by first-visit MC
            N[s] += 1
            if incremental:
                V[s] += (G - V[s]) / N[s]
            else:
                G_sum[s] += G

    if not incremental:
        visited = N > 0
        V[visited] = G_sum[visited] / N[visited]
    return V

In [45]:
def evaluate_policy_MC_every_visit(gym_env,
                                   discount_factor = 1.0,
                                   max_iter = 9999,
                                   incremental = False,
                                   policy = None):
    """Estimate the state-value function V(s) with every-visit Monte Carlo.

    Samples `max_iter` episodes in `gym_env` following `policy`. For each state,
    it averages the discounted return observed after every occurrence of that
    state.

    Args:
        gym_env: environment exposing `nS`, `nA`, a `reset()` that returns a
            state index, and `step(action)` returning
            (next_state, reward, done, info).
        discount_factor: gamma used to discount future rewards.
        max_iter: number of episodes to sample.
        incremental: if True, update V after every visit with the running
            mean V <- V + (G - V) / N. Otherwise, sum the returns and divide
            once at the end. Both yield the same estimate up to floating-point
            rounding.
        policy: array of shape (nS, nA) of action probabilities per state.
            Defaults to the equiprobable random policy.

    Returns:
        np.ndarray of shape (nS,). States never visited keep the value 0.
    """
    print('     POLICY EVALUATION: MONTE CARLO EVERY VISIT!\n')
    n_states, n_actions = gym_env.nS, gym_env.nA
    if policy is None:
        # Equiprobable random policy.
        policy = np.ones([n_states, n_actions]) / n_actions

    V = np.zeros(n_states)
    N = np.zeros(n_states, dtype=int)
    G_sum = np.zeros(n_states)

    for _ in range(max_iter):
        # Roll out one episode as (state, reward-received-after-acting) pairs.
        episode = []
        state = gym_env.reset()
        terminated = False
        while not terminated:
            action = np.random.choice(n_actions, p=policy[state])
            next_state, reward, terminated, _info = gym_env.step(action)
            episode.append((state, reward))
            state = next_state

        # Walk backwards so G_t = R_{t+1} + gamma * G_{t+1} costs O(T), not O(T^2).
        G = 0.0
        for s, reward in reversed(episode):
            G = reward + discount_factor * G
            N[s] += 1
            if incremental:
                V[s] += (G - V[s]) / N[s]
            else:
                G_sum[s] += G

    if not incremental:
        visited = N > 0
        V[visited] = G_sum[visited] / N[visited]
    return V

In [46]:
# First-visit MC, batch average of returns (uniform random policy).
V1 = evaluate_policy_MC_first_visit(gym_env=ENV)
print_state_value_func(V=V1)

     POLICY EVALUATION: MONTE CARLO FIRST VISIT!

 V(s):
[[0.07 0.08 0.08 0.09 0.1  0.12 0.13 0.14]
 [0.07 0.07 0.08 0.09 0.1  0.13 0.15 0.15]
 [0.06 0.07 0.07 0.08 0.1  0.13 0.16 0.18]
 [0.05 0.06 0.06 0.07 0.07 0.14 0.19 0.22]
 [0.04 0.04 0.05 0.05 0.   0.15 0.25 0.31]
 [0.02 0.03 0.04 0.07 0.12 0.23 0.35 0.43]
 [0.01 0.   0.03 0.06 0.16 0.31 0.47 0.63]
 [0.01 0.   0.01 0.   0.16 0.36 0.62 0.  ]] 



In [47]:
# Every-visit MC, batch average of returns (uniform random policy).
V2 = evaluate_policy_MC_every_visit(gym_env=ENV)
print_state_value_func(V=V2)

     POLICY EVALUATION: MONTE CARLO EVERY VISIT!

 V(s):
[[0.07 0.07 0.08 0.09 0.1  0.11 0.12 0.13]
 [0.06 0.07 0.07 0.08 0.09 0.11 0.14 0.15]
 [0.06 0.06 0.06 0.07 0.08 0.12 0.16 0.18]
 [0.05 0.05 0.05 0.05 0.07 0.13 0.18 0.22]
 [0.04 0.04 0.04 0.04 0.   0.15 0.23 0.29]
 [0.02 0.03 0.04 0.06 0.11 0.23 0.33 0.4 ]
 [0.01 0.   0.03 0.06 0.16 0.31 0.46 0.61]
 [0.02 0.02 0.02 0.   0.17 0.39 0.65 0.  ]] 



In [48]:
# First-visit MC with the incremental-update flag enabled.
V3 = evaluate_policy_MC_first_visit(gym_env=ENV, incremental=True)
print_state_value_func(V=V3)

     POLICY EVALUATION: MONTE CARLO FIRST VISIT!

 V(s):
[[0.07 0.07 0.08 0.09 0.11 0.12 0.13 0.14]
 [0.07 0.07 0.08 0.09 0.1  0.12 0.14 0.16]
 [0.06 0.06 0.07 0.08 0.09 0.12 0.16 0.18]
 [0.05 0.05 0.06 0.06 0.07 0.13 0.19 0.23]
 [0.03 0.04 0.05 0.04 0.   0.14 0.25 0.3 ]
 [0.02 0.02 0.04 0.06 0.1  0.23 0.35 0.44]
 [0.01 0.   0.03 0.06 0.15 0.31 0.48 0.64]
 [0.01 0.01 0.01 0.   0.17 0.36 0.62 0.  ]] 



In [49]:
# Every-visit MC with the incremental-update flag enabled.
V4 = evaluate_policy_MC_every_visit(gym_env=ENV, incremental=True)
print_state_value_func(V=V4)

     POLICY EVALUATION: MONTE CARLO EVERY VISIT!

 V(s):
[[0.07 0.07 0.07 0.08 0.1  0.11 0.12 0.14]
 [0.06 0.07 0.07 0.08 0.1  0.11 0.14 0.15]
 [0.06 0.06 0.07 0.08 0.09 0.12 0.15 0.18]
 [0.05 0.05 0.06 0.06 0.08 0.13 0.18 0.22]
 [0.04 0.04 0.04 0.04 0.   0.16 0.26 0.29]
 [0.02 0.03 0.04 0.07 0.1  0.23 0.36 0.42]
 [0.01 0.   0.03 0.07 0.17 0.31 0.48 0.63]
 [0.01 0.   0.01 0.   0.16 0.34 0.58 0.  ]] 

