In [1]:
from pprint import pprint
import numpy as np
import gym
from gym.envs.toy_text import frozen_lake

In [2]:
ENV = gym.make('FrozenLake-v0')

MAP =  ['S-H-',
        '----',
        '-H--',
        'H--G']
frozen_lake.MAPS['4x4'] = MAP
ENV = gym.make('FrozenLake-v0', is_slippery=False)

ACTION_MAPPING = { 0: '←', 1: '↓', 2: '→', 3: '↑'}

In [3]:
def print_policy(policy: np.ndarray):
    print(' POLICY: ')
    temp_policy = np.argmax(policy, axis = 1)
    temp_map = ''.join(MAP)
    string_map = list()
    for idx, action in enumerate(temp_policy):
        if temp_map[idx] == 'H':
            string_map.append('□')
        else:
            string_map.append(ACTION_MAPPING[action])
    string_map = np.array(string_map).reshape((4, 4))
    print(string_map, '\n')

In [12]:
def policy_improvement_loop(env,
                            alpha = 0.5,
                            discount_factor = 1.0,
                            max_iter = 9999):
    print('     POLICY ITERATION: TEMPORAL DIFFERENCE')
    # Init policy with equal prob for all actions
    policy = np.ones([env.nS, env.nA]) / env.nA
    policy_greedy = np.copy(policy)
    policy_eye = np.eye(env.nA)

    Q1 = np.zeros((env.nS, env.nA))
    Q2 = np.copy(Q1)
    
    def choose_action(state, env = env, policy = policy):
            return np.random.choice(
                range(env.nA),
                p=policy[state].tolist()
            )

    for k in range(1, max_iter + 1):
        eps_greedy = 1.0 / (k + 1)
        state = env.reset()
        terminated = False
        
        while not terminated:
            action = choose_action(state)
            next_state, reward, terminated, _ = env.step(action)
            
            if np.random.rand() < 0.5:
                Q = Q1
            else:
                Q = Q2
            TD_target = reward + discount_factor * np.max(Q[next_state])
            Q[state, action] += alpha * (TD_target - Q[state, action])
            
            if np.random.rand() < eps_greedy:
                policy[state] = policy_greedy[state]

            elif Q1[state].sum() != 0 and Q2[state].sum() != 0:
                best_action = np.argmax(Q1[state] + Q2[state])
                policy[state] = policy_eye[best_action]
            
            state = next_state
                        
    return policy, Q1, Q2

In [14]:
policy, Q1, Q2 = policy_improvement_loop(ENV,
                                    alpha = 0.8,
                                    max_iter = 99999)
print_policy(policy)

     POLICY ITERATION: TEMPORAL DIFFERENCE
 POLICY: 
[['→' '↓' '□' '←']
 ['←' '→' '↓' '←']
 ['←' '□' '→' '↓']
 ['□' '←' '←' '←']] 



In [15]:
print(Q1, '\n')
print(Q2)

[[0.58982157 0.58971439 1.         0.58958947]
 [0.98358239 1.         0.         0.89434164]
 [0.         0.         0.         0.        ]
 [0.         0.         0.         0.        ]
 [0.54037442 0.45283988 0.91777606 0.98358925]
 [0.95712881 0.         1.         0.98334377]
 [0.         1.         0.64       0.        ]
 [0.8        0.         0.         0.        ]
 [0.43486544 0.         0.         0.82495398]
 [0.         0.         0.         0.        ]
 [0.         0.         1.         0.8       ]
 [0.992      1.         0.96       0.        ]
 [0.         0.         0.         0.        ]
 [0.         0.         0.         0.        ]
 [0.         0.         0.         0.        ]
 [0.         0.         0.         0.        ]] 

[[0.992 0.    1.    0.8  ]
 [0.8   1.    0.    0.96 ]
 [0.    0.    0.    0.   ]
 [0.    0.    0.    0.   ]
 [0.    0.    0.    0.   ]
 [0.    0.    1.    0.   ]
 [0.8   1.    0.    0.   ]
 [0.8   0.    0.    0.   ]
 [0.    0.    0.    0.   ]
 [