In [1]:
import numpy as np
from frozenlake import FrozenLakeEnv

In [2]:
# init the environment
env = FrozenLakeEnv(map_name="8x8",is_slippery=True)

In [3]:
# check number of states and actions
nb_states = env.observation_space.n
nb_actions = env.action_space.n

In [4]:
def compute_value(env, policy, discount_factor=1.0, threshold=1e-20):
    
    # initialize value table randomly
    value_table = np.zeros((nb_states, 1))
    
    while True:
        
        new_value_table = np.copy(value_table)
        for state in range(nb_states):
            
            action = int(policy[state])
            for next_state_parameters in env.P[state][action]:
                
                transition_prob, next_state, reward_prob, _ = next_state_parameters
                # apply Bellman equation
                value_table[state] = transition_prob * (reward_prob + discount_factor * new_value_table[next_state])
                
        if (np.sum(np.fabs(new_value_table - value_table)) <= threshold): break
    
    return value_table

In [5]:
def extract_policy(value_table, discount_factor=1.0):
    
    # initialize policy randomly
    policy = np.zeros(env.observation_space.n)
    for state in range(env.observation_space.n):
        
        Q_table = np.zeros(env.action_space.n)
        for action in range(env.action_space.n):
            
            for next_state_parameters in env.P[state][action]:
                transition_prob, next_state, reward_prob, _ = next_state_parameters
                # apply Bellman equation
                Q_table[action] += (transition_prob * (reward_prob + discount_factor * value_table[next_state]))
        
        policy[state] = np.argmax(Q_table)
        
    return policy

In [6]:
NB_ITERATIONS =  200000

In [7]:
random_policy = np.zeros((nb_states, 1))

from tqdm import tqdm_notebook
for i in tqdm_notebook(range(NB_ITERATIONS)):
    
    new_value_table = compute_value(env, random_policy)
    new_policy = extract_policy(new_value_table)
    random_policy = new_policy

HBox(children=(IntProgress(value=0, max=200000), HTML(value='')))




In [8]:
print('Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):')
print(np.reshape(new_policy, (env.action_space.n, -1)))

Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0.]]


---