In [23]:
import gym
import numpy as np
from IPython import display
import matplotlib.pyplot as plt
import time
%matplotlib inline

In [24]:
env = gym.make('FrozenLake-v0')
env = env.unwrapped
print(env)

<FrozenLakeEnv<FrozenLake-v0>>


In [25]:
print("number of states: ", env.observation_space.n)
print("number of actions: ", env.action_space.n)

number of states:  16
number of actions:  4


In [26]:
env.render()


[41mS[0mFFF
FHFH
FFFH
HFFG


In [27]:
def compute_value_function(policy, gamma=1.0):
    
    #초기 value table zero
    value_table = np.zeros(env.nS)
    
    threshold = 1e-10
    
    while True:
        
        updated_value_table = np.copy(value_table)
        
        #각 state에 대해서 policy에 따라 action 선택
        # 그리고 value table 계산
        for state in range(env.nS):
            action = policy[state]
            
            # 선택된 action에 따라 value table 생성
            for trans_prob, next_state, reward_prob, _ in env.P[state][action]:
                value_table[state] = sum([trans_prob * (reward_prob + gamma * updated_value_table[next_state])])
                
                
        if(np.sum(np.fabs(updated_value_table - value_table)) <= threshold):
            break
    
    return value_table
            

In [28]:
def extract_policy(value_table, gamma = 1.0):
    
    # 초기 policy 설정
    policy = np.zeros(env.observation_space.n, np.int32)
    
    for state in range(env.observation_space.n):
        
        # Q table 초기화
        Q_table = np.zeros(env.action_space.n)
        
        #Q value 계산 (모든 state에 대해서)
        for action in range(env.action_space.n):
            for next_sr in env.P[state][action]:
                trans_prob, next_state, reward_prob, _ = next_sr
                Q_table[action] += (trans_prob * (reward_prob + gamma * value_table[next_state]))
                
        # Maximum Q value 선택 (최적의 state에서)
        policy[state] = np.argmax(Q_table)
    
    return policy

In [29]:
def policy_iteration(env,gamma = 1.0):
    
    # 초기 policy
    old_policy = np.zeros(env.observation_space.n, np.int32)
    no_of_iterations = 200000
    
    for i in range(no_of_iterations):
    
        # value function 계산
        new_value_function = compute_value_function(old_policy, gamma)
        
        # new policy 추출(value function으로 부터)
        new_policy = extract_policy(new_value_function, gamma)
        
        if(np.all(old_policy == new_policy)):
            print('Policy-iteration converge at step.'%(i+1))
            break
            
        old_policy = new_policy
        
    return new_policy

In [30]:
def play(env, optimal_policy, max_step=1000):
    state = 0
    for i in range(max_step):
        env.render()
        time.sleep(1)
        display.clear_output(wait=True)
        display.display(plt.gcf())
        state, _, done, _ = env.step(optimal_policy[state])
        
        if done:
            env.render()
            break;

In [31]:
optimal_policy = policy_iteration(env)
print(optimal_policy)

KeyboardInterrupt: 

In [None]:
play(env,optimal_policy)