3-2 FrozenLake에서 정책 반복

In [1]:
import gymnasium as gym
import numpy as np

In [None]:
def policy_iteration(env, gamma=0.9, theta=1e-8):
    """Solve a tabular MDP by policy iteration (evaluation + greedy improvement).

    The environment must expose a gymnasium-style transition table
    ``P[state][action] -> [(prob, next_state, reward, done), ...]``
    (as ``FrozenLake-v1`` does), plus ``observation_space.n`` and
    ``action_space.n``. All listed transitions are weighted by their
    probability, so both deterministic and slippery dynamics are handled.

    Args:
        env: Environment, possibly wrapped by ``gym.make``; the table ``P`` is
            read from ``env.unwrapped`` when that attribute exists.
        gamma: Discount factor.
        theta: Stop policy evaluation once the largest per-sweep value change
            is below this; also the tolerance for deciding that a new action
            is genuinely better than the current one.

    Returns:
        (V, pi): float array of state values and int array holding the
        greedy action for each state.
    """
    # gym.make returns a wrapper; forwarding attribute access (env.P) through
    # wrappers is deprecated in gymnasium, so use the base env when available.
    P = getattr(env, "unwrapped", env).P
    n_states = env.observation_space.n
    n_actions = env.action_space.n

    V = np.zeros(n_states)                # value function, all zeros initially
    pi = np.zeros(n_states, dtype=int)    # start with action 0 in every state

    def q_value(state, action):
        # Expected one-step return over every possible outcome; transitions
        # that end the episode do not bootstrap from the next state's value.
        return sum(
            prob * (reward + gamma * V[next_state] * (not done))
            for prob, next_state, reward, done in P[state][action]
        )

    while True:
        # Policy evaluation: in-place sweeps until values stop changing.
        while True:
            delta = 0.0
            for state in range(n_states):
                v = q_value(state, pi[state])
                delta = max(delta, abs(v - V[state]))
                V[state] = v
            if delta < theta:
                break

        # Policy improvement: act greedily with respect to V.
        policy_stable = True
        for state in range(n_states):
            old_action = pi[state]
            q = np.array([q_value(state, a) for a in range(n_actions)])
            pi[state] = np.argmax(q)
            # Only a strictly better action counts as a change; flipping
            # between numerically tied actions could otherwise cycle forever.
            if q[pi[state]] > q[old_action] + theta:
                policy_stable = False
        if policy_stable:
            break
    return V, pi

env = gym.make('FrozenLake-v1', render_mode='ansi', is_slippery=False)
V, pi = policy_iteration(env)
# Take the grid shape from the map itself instead of hard-coding 4x4, so the
# same printout works for other layouts (e.g. map_name='8x8').
grid_shape = env.unwrapped.desc.shape
print('최적정책\n', pi.astype(int).reshape(grid_shape))
print('최적 가치함수\n', np.round(V.reshape(grid_shape), 4))


최적정책
 [[1 2 1 0]
 [1 0 1 0]
 [2 1 1 0]
 [0 2 2 0]]
최적 가치함수
[[0.5905, 0.6561, 0.729 , 0.6561],
 [0.6561, 0.    , 0.81  , 0.    ],
 [0.729 , 0.81  , 0.9   , 0.    ],
 [0.    , 0.9   , 1.    , 0.    ]]
