In [1]:
from env import GridWorld
env = GridWorld()

In [11]:
V_init = {s: 0.0 for s in env.get_states()}  # Value function
policy_init = {s: 'R' for s in env.get_states()}  # Initial policy

def policy_evaluation(policy, V, in_place, max_sweep, threshold=0.001):
    """
    Policy Evaluation
    
    Args:
        in_place (bool) : True = asyncronous update, False : syncronous update
        max_sweep (bool) : True = sweeps until convergence, False : one sweep
        threshold (float) : convergence threshold
    """
    while True:
        
        if not in_place:
            V_old = V.copy()

        delta = 0
        for state in env.get_states():
            if env.is_terminal(state):
                continue
            
            action = policy[state]
            next_state = env.get_next_state(state, action)
            reward = env.get_reward(state, action)

            if in_place:
                # Asynchronous: Ï¶âÏãú ÏóÖÎç∞Ïù¥Ìä∏ (ÌòÑÏû¨ V ÏÇ¨Ïö©)
                v_old = V[state]
                V[state] = reward + env.gamma * V[next_state]
            else:
                # Synchronous: old V ÏÇ¨Ïö©Ìï¥ÏÑú new V Í≥ÑÏÇ∞
                v_old = V_old[state]
                V[state] = reward + env.gamma * V_old[next_state]

            delta = max(delta, abs(v_old - V[state]))
                
        # Ï¢ÖÎ£å Ï°∞Í±¥
        if max_sweep:
            if delta < threshold:
                break
        else:
            break
    return V

def policy_improvement(policy, V):
    """Policy Improvement"""
    is_convergent = True
    
    for state in env.get_states():
        if env.is_terminal(state):
            continue
        
        old_action = policy[state]
        
        # Î™®Îì† ÌñâÎèôÏóê ÎåÄÌï¥ QÍ∞í Í≥ÑÏÇ∞
        action_values = {}
        for action in env.actions:
            next_state = env.get_next_state(state, action)
            reward = env.get_reward(state, action)
            action_values[action] = reward + env.gamma * V[next_state]
        
        # ÏµúÏÑ†Ïùò ÌñâÎèô ÏÑ†ÌÉù
        policy[state] = max(action_values, key=action_values.get)
        
        if old_action != policy[state]:
            is_convergent = False
    
    return policy, is_convergent


def print_results(policy, V):
    """Í≤∞Í≥º Ï∂úÎ†•"""
    print("\nValue Function:")
    for r in range(env.rows - 1, -1, -1):  # 2, 1, 0 ÏàúÏÑú (ÏÉÅÌïòÎ∞òÏ†Ñ)
        row_values = []
        for c in range(env.cols):
            if (r, c) == env.wall:
                row_values.append("  WALL ")
            else:
                row_values.append(f"{V[(r,c)]:6.2f}")
        print("  ".join(row_values))
    
    print("Policy:")
    for r in range(env.rows - 1, -1, -1):  # 2, 1, 0 ÏàúÏÑú (ÏÉÅÌïòÎ∞òÏ†Ñ)
        row_policy = []
        for c in range(env.cols):
            if (r, c) == env.wall:
                row_policy.append(" W ")
            elif env.is_terminal((r, c)):
                row_policy.append(" T ")
            else:
                row_policy.append(f" {policy[(r,c)]} ")
        print("  ".join(row_policy))


def policy_iteration(policy, V, in_place, max_sweep):
    """Policy Iteration"""
    iteration = 0

    print(f"\n{'='*60}")
    print(f"üéØ POLICY ITERATION  ‚îÇ  {"Asyncronous" if in_place else "Syncronous"} + {"Full sweep" if max_sweep else "single sweep"}")
    print(f"{'='*60}\n")
    
    while True:
        iteration += 1
        
        # IterationÏùÄ Í∞ÑÏÜåÌôî
        print(f"\n[Iteration {iteration}]")
        
        V = policy_evaluation(policy, V, in_place=in_place, max_sweep=max_sweep)
        policy, is_convergent = policy_improvement(policy, V)
        
        print_results(policy, V)
        
        if is_convergent:
            print(f"\n‚úÖ Policy converged after {iteration} iterations!")
            break


In [12]:
# syncronous, full sweep
policy_iteration(policy_init, V_init, in_place=True, max_sweep=True)


üéØ POLICY ITERATION  ‚îÇ  Asyncronous + Full sweep


[Iteration 1]

Value Function:
  0.62    0.80    1.00    0.00
 -0.99    WALL    -1.00    0.00
 -0.99   -0.99   -0.99   -0.99
Policy:
 R    R    R    T 
 U    W    U    T 
 U    U    D    D 

[Iteration 2]

Value Function:
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31   -0.99   -0.99   -0.99
Policy:
 R    R    R    T 
 U    W    U    T 
 U    L    U    D 

[Iteration 3]

Value Function:
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31    0.18    0.62   -0.99
Policy:
 R    R    R    T 
 U    W    U    T 
 U    R    U    L 

[Iteration 4]

Value Function:
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31    0.46    0.62    0.46
Policy:
 R    R    R    T 
 U    W    U    T 
 U    R    U    L 

‚úÖ Policy converged after 4 iterations!


In [13]:
# asyncronous, full sweep
policy_iteration(policy_init, V_init, in_place=False, max_sweep=True)


üéØ POLICY ITERATION  ‚îÇ  Syncronous + Full sweep


[Iteration 1]

Value Function:
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31    0.46    0.62    0.46
Policy:
 R    R    R    T 
 U    W    U    T 
 U    R    U    L 

‚úÖ Policy converged after 1 iterations!


In [14]:
# syncronous, single sweep
policy_iteration(policy_init, V_init, in_place=True, max_sweep=False)


üéØ POLICY ITERATION  ‚îÇ  Asyncronous + single sweep


[Iteration 1]

Value Function:
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31    0.46    0.62    0.46
Policy:
 R    R    R    T 
 U    W    U    T 
 U    R    U    L 

‚úÖ Policy converged after 1 iterations!


In [15]:
# asyncronous, single sweep
policy_iteration(policy_init, V_init, in_place=False, max_sweep=False)


üéØ POLICY ITERATION  ‚îÇ  Syncronous + single sweep


[Iteration 1]

Value Function:
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31    0.46    0.62    0.46
Policy:
 R    R    R    T 
 U    W    U    T 
 U    R    U    L 

‚úÖ Policy converged after 1 iterations!
