In [1]:
from pprint import pprint
import numpy as np
import gym
from gym.envs.toy_text import frozen_lake

In [12]:
# Custom 4x4 layout used throughout the notebook:
#   'S' = start tile, 'H' = hole, 'G' = goal, '-' = ordinary tile
#   (presumably treated by gym like a frozen 'F' tile -- confirm against
#   the installed gym version).
MAP =  ['S-H-',
        '----',
        '-H--',
        'H--G']

# Replace gym's built-in '4x4' layout *before* the environment is created;
# an environment built earlier would be discarded anyway, so only one
# gym.make call is needed.
frozen_lake.MAPS['4x4'] = MAP
ENV = gym.make('FrozenLake-v0', is_slippery=False)

# Action index -> arrow glyph used by print_policy when rendering a policy.
ACTION_MAPPING = { 0: '←', 1: '↓', 2: '→', 3: '↑'}

In [13]:
def print_policy(policy: np.ndarray, grid_map=None, action_symbols=None):
    """Print the highest-probability action of every tile as a grid of symbols.

    Args:
        policy: 2-D array with one row per state and one column per action;
            the arg-max of each row is the action that gets displayed.
        grid_map: Sequence of equal-length row strings describing the grid.
            Tiles marked 'H' are drawn as '□' instead of an action symbol.
            Defaults to the module-level ``MAP``.
        action_symbols: Mapping from action index to the symbol to print.
            Defaults to the module-level ``ACTION_MAPPING``.
    """
    if grid_map is None:
        grid_map = MAP
    if action_symbols is None:
        action_symbols = ACTION_MAPPING

    print(' POLICY: ')
    greedy_actions = np.argmax(policy, axis=1)
    tiles = ''.join(grid_map)
    symbols = [
        '□' if tiles[idx] == 'H' else action_symbols[action]
        for idx, action in enumerate(greedy_actions)
    ]
    # Derive the grid shape from the map itself rather than assuming 4x4.
    grid = np.array(symbols).reshape((len(grid_map), -1))
    print(grid, '\n')

In [34]:
def policy_improvement_loop(env,
                            alpha = 0.5,
                            discount_factor = 1.0,
                            max_iter = 9999,
                            seed = None):
    """Learn a tabular policy with on-policy TD control (SARSA).

    Each episode follows the current stochastic policy. After every step the
    action value of the visited (state, action) pair is moved towards
    ``reward + discount_factor * Q[next_state, next_action]`` and the policy
    row of the visited state is refreshed: with probability ``1 / (k + 1)``
    (``k`` = 1-based episode number) it is reset to uniform, otherwise it
    becomes greedy w.r.t. ``Q`` once that state has a positive value sum.

    Args:
        env: Environment exposing ``nS`` / ``nA`` (state and action counts),
            ``reset()`` returning the initial state index, and ``step(action)``
            returning ``(next_state, reward, done, info)``.
        alpha: Step size of the TD update.
        discount_factor: Discount applied to the bootstrapped next value.
        max_iter: Number of episodes to run.
        seed: Optional seed for action sampling and exploration draws. When
            ``None`` the global ``np.random`` state is used. The environment
            itself is not seeded here.

    Returns:
        Tuple ``(policy, Q)``, both arrays of shape ``(env.nS, env.nA)``.
    """
    print('     POLICY ITERATION: TEMPORAL DIFFERENCE')

    # A private generator makes runs reproducible without touching the
    # global NumPy random state; fall back to the global stream otherwise.
    rng = np.random if seed is None else np.random.RandomState(seed)

    n_states, n_actions = env.nS, env.nA

    # Start from equal probability for all actions in every state.
    uniform_policy = np.ones([n_states, n_actions]) / n_actions
    policy = np.copy(uniform_policy)
    one_hot = np.eye(n_actions)

    Q = np.zeros((n_states, n_actions))

    def choose_action(state):
        # Sample from the *current* policy row; `policy` is mutated in place
        # below, so later calls see the improved distribution.
        return rng.choice(n_actions, p=policy[state])

    for k in range(1, max_iter + 1):
        epsilon = 1.0 / (k + 1)
        state = env.reset()
        action = choose_action(state)
        terminated = False

        while not terminated:
            next_state, reward, terminated, _ = env.step(action)
            next_action = choose_action(next_state)

            td_target = reward + discount_factor * Q[next_state, next_action]
            Q[state, action] += alpha * (td_target - Q[state, action])

            if rng.rand() < epsilon:
                # Exploration: forget the greedy choice for this state.
                policy[state] = uniform_policy[state]
            elif Q[state].sum() > 0:
                # NOTE(review): the positive-sum guard presumably assumes
                # non-negative rewards -- confirm before using other envs.
                policy[state] = one_hot[np.argmax(Q[state])]

            state, action = next_state, next_action

    return policy, Q

In [35]:
# Train on the custom FrozenLake environment, then show the learned
# policy as an arrow grid followed by the raw action-value table.
policy, Q = policy_improvement_loop(ENV, alpha=0.8, max_iter=99999)

print_policy(policy)
print(Q)

     POLICY ITERATION: TEMPORAL DIFFERENCE
 POLICY: 
[['→' '↓' '□' '↑']
 ['↓' '→' '→' '↓']
 ['↑' '□' '←' '↓']
 ['□' '←' '→' '←']] 

[[0.11270269 0.13188468 1.         0.16674304]
 [0.27685638 1.         0.         0.13857343]
 [0.         0.         0.         0.        ]
 [0.         0.24926189 0.09819022 0.35596515]
 [0.21558285 0.22938956 0.21064738 0.20233738]
 [0.25487798 0.         1.         0.16340782]
 [0.09583959 0.0688883  1.         0.        ]
 [0.37013619 1.         0.03815121 0.48328533]
 [0.         0.         0.         0.23682306]
 [0.         0.         0.         0.        ]
 [0.         0.2        0.13521778 0.34444148]
 [0.67608889 1.         0.         0.        ]
 [0.         0.         0.         0.        ]
 [0.         0.         0.         0.        ]
 [0.         0.8        1.         0.407296  ]
 [0.         0.         0.         0.        ]]
