In [1]:
from env import GridWorld
# Single environment instance; the functions in the following cell read it
# as a module-level global rather than taking it as a parameter.
env = GridWorld()

In [2]:
# Initial value table: every state reported by the environment starts at 0.0.
V_init = dict.fromkeys(env.get_states(), 0.0)

def _action_values(V, state):
    """Return the one-step backup value of every action available in ``state``.

    For each action in ``env.actions`` the value is
    ``reward + env.gamma * V[next_state]``, where ``next_state`` and
    ``reward`` come from ``env.get_next_state`` and ``env.get_reward``.
    The returned dict preserves the iteration order of ``env.actions`` so
    that callers taking ``max`` over it break ties the same way every time.
    """
    values = {}
    for action in env.actions:
        next_state = env.get_next_state(state, action)
        reward = env.get_reward(state, action)
        values[action] = reward + env.gamma * V[next_state]
    return values


def value_iteration(V, theta=0.001):
    """Run value iteration on the module-level ``env`` and derive a greedy policy.

    Each sweep visits every non-terminal state returned by
    ``env.get_states()`` and overwrites ``V[state]`` with the best one-step
    backup value. Updates are applied in place, so states visited later in a
    sweep already see the values written earlier in that same sweep.
    Progress (sweep number, largest change, value grid) is printed after
    every sweep, and the final policy is printed at the end.

    Args:
        V: dict mapping state -> value. It is mutated in place; pass a copy
            if the caller needs to keep the initial table. It must contain
            an entry for every state that ``env.get_next_state`` can return.
        theta: convergence threshold. Sweeping stops once the largest
            absolute value change within a sweep is below ``theta``.

    Returns:
        A ``(policy, V)`` tuple. ``policy`` maps each non-terminal state to
        the action with the highest backup value under the final ``V``
        (ties go to the action that comes first in ``env.actions``); ``V``
        is the same dict object that was passed in.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("Value Iteration")
    print(f"{banner}\n")

    iteration = 0
    while True:
        iteration += 1
        delta = 0.0  # largest absolute value change seen in this sweep

        for state in env.get_states():
            if env.is_terminal(state):
                continue
            v_old = V[state]
            V[state] = max(_action_values(V, state).values())
            delta = max(delta, abs(v_old - V[state]))

        # Per-iteration progress output.
        print(f"[Iteration {iteration}]  delta = {delta:.6f}")
        print_V(V)
        print()

        if delta < theta:
            break

    # Policy extraction: pick the highest-valued action in every
    # non-terminal state using the converged value table.
    policy = {}
    for state in env.get_states():
        if env.is_terminal(state):
            continue
        action2values = _action_values(V, state)
        policy[state] = max(action2values, key=action2values.get)

    # Final result.
    print(banner)
    print("Final Policy")
    print(banner)
    print_policy(policy)

    return policy, V


def print_V(V):
    """Print the value function as a grid, one line per row.

    Rows are printed from the highest row index (``env.rows - 1``) down to
    row 0. Every cell is 6 characters wide and cells are separated by two
    spaces; the cell at ``env.wall`` shows ``WALL`` instead of a number.

    Args:
        V: dict mapping ``(row, col)`` -> value. It must contain an entry
            for every grid position other than ``env.wall``.
    """
    for r in range(env.rows - 1, -1, -1):
        row = []
        for c in range(env.cols):
            if (r, c) == env.wall:
                # Keep the wall marker at the same 6-character width as the
                # ``6.2f`` numeric cells so the columns after it stay aligned.
                row.append(" WALL ")
            else:
                row.append(f"{V[(r, c)]:6.2f}")
        print("  ".join(row))


def print_policy(policy):
    """Print the policy as a grid, one line per row.

    Rows are printed from the highest row index down to row 0. The cell at
    ``env.wall`` is shown as ``W``, terminal cells as ``T``, and every other
    cell shows ``policy[(row, col)]``. Cells are padded with one space on
    each side and separated by two spaces.

    Args:
        policy: dict mapping ``(row, col)`` -> action for every position
            that is neither the wall nor a terminal state.
    """
    for row_idx in reversed(range(env.rows)):
        cells = []
        for col_idx in range(env.cols):
            pos = (row_idx, col_idx)
            if pos == env.wall:
                cell = " W "
            elif env.is_terminal(pos):
                cell = " T "
            else:
                cell = f" {policy[pos]} "
            cells.append(cell)
        print("  ".join(cells))


# Run: pass a copy so that V_init is not modified by the in-place updates.
policy, V = value_iteration(V_init.copy())


Value Iteration

[Iteration 1]  delta = 1.000000
 -0.10   -0.10    1.00    0.00
 -0.10    WALL    -0.10    0.00
 -0.10   -0.10   -0.10   -0.10

[Iteration 2]  delta = 0.900000
 -0.19    0.80    1.00    0.00
 -0.19    WALL     0.80    0.00
 -0.19   -0.19   -0.19   -0.19

[Iteration 3]  delta = 0.810000
  0.62    0.80    1.00    0.00
 -0.27    WALL     0.80    0.00
 -0.27   -0.27    0.62    0.46

[Iteration 4]  delta = 0.729000
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
 -0.34    0.46    0.62    0.46

[Iteration 5]  delta = 0.656100
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31    0.46    0.62    0.46

[Iteration 6]  delta = 0.000000
  0.62    0.80    1.00    0.00
  0.46    WALL     0.80    0.00
  0.31    0.46    0.62    0.46

Final Policy
 R    R    R    T 
 U    W    U    T 
 U    R    U    L 
