In [1]:
from env import GridWorld
from agent import Agent

In [2]:
# Build the grid-world environment; the cells below read it through the
# module-level name `env`.
env = GridWorld()

In [3]:
# State-value table: one entry per state reported by the environment, all 0.0.
V = dict.fromkeys(env.get_states(), 0.0)
# Starting policy: every state is mapped to the action 'R'.
policy = dict.fromkeys(env.get_states(), 'R')

def policy_evaluation_async(threshold=0.001, max_sweeps=None):
    """Evaluate the current global ``policy`` in place (asynchronous sweeps).

    Each sweep visits every non-terminal state from ``env.get_states()`` and
    overwrites ``V[state]`` with ``reward + env.gamma * V[next_state]`` for
    the action the policy prescribes. Updates are written straight into the
    global ``V``, so states visited later in the same sweep already read the
    refreshed values.

    Args:
        threshold: Sweeping stops once the largest absolute change of any
            state value within one sweep is below this number.
        max_sweeps: Optional upper bound on the number of sweeps. ``None``
            (the default) means "sweep until the threshold is met", which
            never returns if the values do not converge. Pass an integer to
            guarantee termination; ``0`` performs no sweep at all.

    Returns:
        None. The global ``V`` is modified in place.
    """
    sweeps_done = 0
    while max_sweeps is None or sweeps_done < max_sweeps:
        sweeps_done += 1
        delta = 0  # largest value change seen during this sweep
        for state in env.get_states():
            if env.is_terminal(state):
                continue

            v = V[state]
            action = policy[state]

            # One-step backup for the action chosen by the policy.
            next_state = env.get_next_state(state, action)
            reward = env.get_reward(state, action)
            V[state] = reward + env.gamma * V[next_state]

            delta = max(delta, abs(v - V[state]))

        if delta < threshold:
            break

def policy_improvement():
    """Make the global ``policy`` greedy with respect to the current ``V``.

    For every non-terminal state, each action in ``env.actions`` is scored
    with a one-step lookahead, ``reward + env.gamma * V[next_state]``, and the
    state's ``policy`` entry is overwritten with the top-scoring action. When
    several actions share the top score, the one that comes first in
    ``env.actions`` wins.

    Returns:
        bool: True if no state's action changed during this pass, False
        otherwise.
    """
    def lookahead(s, a):
        # Ask the environment for the successor first, then the reward.
        successor = env.get_next_state(s, a)
        immediate = env.get_reward(s, a)
        return immediate + env.gamma * V[successor]

    changed = False
    for s in env.get_states():
        if not env.is_terminal(s):
            previous = policy[s]

            # Score every available action from this state.
            scores = {a: lookahead(s, a) for a in env.actions}

            # max() keeps the first maximal item, so ties go to the earliest action.
            best_action, _ = max(scores.items(), key=lambda item: item[1])
            policy[s] = best_action

            if previous != policy[s]:
                changed = True

    return not changed

def policy_iteration():
    """Alternate policy evaluation and greedy improvement until the policy is stable.

    Every round prints a banner with the round number, calls
    ``policy_evaluation_async()`` with its default arguments, calls
    ``policy_improvement()``, and then prints the current tables through
    ``print_results()``. The loop ends after the first round in which
    ``policy_improvement()`` reports that no action changed, and a final
    message with the number of rounds is printed.

    Returns:
        None. Results are left in the globals the helper functions update.
    """
    iteration = 0
    converged = False
    while not converged:
        iteration += 1
        # Three-line banner: blank line + rule, round number, rule.
        print(f"\n{'='*50}", f"Iteration {iteration}", f"{'='*50}", sep="\n")

        policy_evaluation_async()              # evaluate the current policy
        converged = policy_improvement()       # greedy update; truthy when stable

        print_results()

    print(f"\n✅ Policy converged after {iteration} iterations!")

def _print_grid(cell_text):
    """Print one line per grid row, highest row index first.

    Args:
        cell_text: Callable that takes a ``(row, col)`` tuple and returns the
            text to show for that cell. Cells of a row are joined with two
            spaces.
    """
    # Rows run from env.rows - 1 down to 0, so the printout is vertically
    # flipped relative to the row indices.
    for r in range(env.rows - 1, -1, -1):
        print("  ".join(cell_text((r, c)) for c in range(env.cols)))


def print_results():
    """Print the current value function and policy as two text grids.

    The value grid shows ``V[(row, col)]`` with two decimals in a
    6-character cell, and a ``WALL`` marker of the same width at
    ``env.wall`` so the columns stay aligned. The policy grid shows the
    action stored in ``policy`` for each cell, ``W`` for the wall and ``T``
    for terminal cells.

    Returns:
        None. Everything is written to standard output.
    """
    def value_cell(pos):
        # " WALL " is 6 characters, matching the width of the "6.2f" cells.
        if pos == env.wall:
            return " WALL "
        return f"{V[pos]:6.2f}"

    def policy_cell(pos):
        # The wall check comes first, then terminal cells, then regular cells.
        if pos == env.wall:
            return " W "
        if env.is_terminal(pos):
            return " T "
        return f" {policy[pos]} "

    print("\nValue Function:")
    _print_grid(value_cell)

    print("Policy:")
    _print_grid(policy_cell)


In [4]:
# Run policy iteration; the value function and policy are printed after
# every iteration, followed by a final convergence message.
policy_iteration()


Iteration 1

Value Function:
  0.62    0.80    1.00    0.00
 -0.99    WALL    -1.00    0.00
 -0.99   -0.99   -0.99   -0.99
Policy:
 R    R    R    T 
 U    W    U    T 
 U    U    D    D 

Iteration 2

Value Function:
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31   -0.99   -0.99   -0.99
Policy:
 R    R    R    T 
 U    W    U    T 
 U    L    U    D 

Iteration 3

Value Function:
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31    0.18    0.62   -0.99
Policy:
 R    R    R    T 
 U    W    U    T 
 U    R    U    L 

Iteration 4

Value Function:
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31    0.46    0.62    0.46
Policy:
 R    R    R    T 
 U    W    U    T 
 U    R    U    L 

✅ Policy converged after 4 iterations!
