In [10]:
import numpy as np

def policy_evaluation(S, A, SR, P, pi, theta=1e-6, gamma=1.0):
    V = np.zeros((GRID_SIZE, GRID_SIZE), dtype=float)

    while True:
        delta = 0.0
        for s in S:
            v = V[s]
            V[s] = sum(pi(s, a) * sum(P(s, a, s_prime, r) * (r + gamma * V[s_prime]) for s_prime, r in SR(s, a)) for a in A)
            delta = max(delta, abs(v - V[s]))

        if delta < theta: break

    return V


def pi(s, a):
    return .25

def get_next_state_and_reward(s, a):
    if any((s[0] == t[0] and s[1] == t[1]) for t in TERMINAL_STATES):
        return s, 0
    
    next_state = (s[0] + a[0], s[1] + a[1])
    if next_state[0] < 0 or next_state[0] >= GRID_SIZE or next_state[1] < 0 or next_state[1] >= GRID_SIZE:
        next_state = s
    
    return next_state, -1

def transition_prob(s, a, s_prime, r):
    actual_next_state, actual_reward = get_next_state_and_reward(s, a)
    return 1.0 if (s_prime == actual_next_state and r == actual_reward) else 0.0

def SR(s, a):
    next_state, reward = get_next_state_and_reward(s, a)
    return [(next_state, reward)]

GRID_SIZE = 4
TERMINAL_STATES = [(0, 0), (3, 3)]
A = [
    (0, 1),
    (0, -1),
    (1, 0),
    (-1, 0)
]
S = [
    (i, j)
    for i in range(GRID_SIZE)
    for j in range(GRID_SIZE)
]


V = policy_evaluation(S, A, SR, P=transition_prob, pi=pi, theta=1e-6, gamma=1.0).round(2)
print(V)

[[  0. -14. -20. -22.]
 [-14. -18. -20. -20.]
 [-20. -20. -18. -14.]
 [-22. -20. -14.   0.]]


In [11]:
# what is q_pi(11, down)?
# 11 is state (2, 3), down is action (1, 0), reward is -1
state = (2, 3)
action = (1, 0)
s_prime, reward = get_next_state_and_reward(state, action)
gamma = 1.0
action_value = reward + gamma * V[s_prime[0], s_prime[1]]
print(action_value)

-1.0


In [12]:
# what is q_pi(7, down)?
# 7 is state (1, 3), down is action (1, 0), reward is -1
state = (1, 3)
action = (1, 0)
s_prime, reward = get_next_state_and_reward(state, action)
gamma = 1.0
action_value = reward + gamma * V[s_prime[0], s_prime[1]]
print(action_value)

-15.0
