In [1]:
import numpy as np

In [2]:
N = 4  # grid side length; the grid has N * N cells / states
grid = np.zeros((N, N))
S = np.arange(N**2)  # state indices 0 .. N*N - 1
# Terminal states are the first and last cells (top-left and bottom-right
# corners). Deriving the last one from N keeps it valid if N is changed;
# a hard-coded 15 is only correct for N == 4.
terminals = [0, N**2 - 1]
# Actions as (row, col) displacements: +1 row, -1 row, +1 col, -1 col.
A = [np.array(step) for step in [(+1, 0), (-1, 0), (0, +1), (0, -1)]]
gamma = 1  # undiscounted

# transform state to grid cell
def s2g(s):
    """Map a state index to its grid cell.

    States are numbered row by row, so the cell is [s // N, s % N].
    Asserts that `s` is one of the states in S.
    """
    assert s in S, 'Invalid state'
    row, col = divmod(s, N)
    return np.array([row, col])

# transform grid cell to state
def g2s(g):
    """Inverse of s2g: map a grid cell g = (row, col) to its state index."""
    row, col = g[0], g[1]
    return row * N + col

def pi_equiprobable_random(a, s):
    """Uniform random policy: every action in A has probability 1 / len(A).

    The action `a` and state `s` are accepted for interface compatibility
    with other policies but ignored.
    """
    n_actions = len(A)
    return 1 / n_actions

In [3]:
def P(s_, s, a):
    """Transition probability p(s_ | s, a) of the deterministic gridworld.

    A move that would leave the grid is clipped to the border, so the agent
    stays in place. Terminal states are absorbing: every action keeps the
    agent in the same terminal state.

    Returns a truthy value (probability 1) iff taking action `a` in state
    `s` lands in `s_`, and a falsy value (probability 0) otherwise. Exactly
    one `s_` in S is reached for each (s, a), so the probabilities sum to 1.
    (Previously both terminal cells were reported as successors at once,
    which made the sum 2 for moves into or out of a terminal.)
    """
    if s in terminals:
        # Absorbing state: its only successor is itself.
        return s_ == s
    # Clip the displaced cell to the grid so off-grid moves leave s unchanged.
    next_s = g2s(np.clip(s2g(s) + a, 0, N - 1))
    return next_s == s_

def P_pi(s_, s, pi):
    """Probability of moving from s to s_ in one step under policy pi.

    Marginalises P over the actions in A, weighting each by pi(a, s).
    """
    total = 0
    for a in A:
        total += pi(a, s) * P(s_, s, a)
    return total

In [4]:
def R(s, a):
    """Immediate reward for action a in state s: 0 in terminals, -1 otherwise."""
    if s in terminals:
        return 0
    return -1

def R_pi(s, pi):
    """Expected immediate reward in state s when actions are drawn from pi."""
    expected = 0
    for a in A:
        expected += pi(a, s) * R(s, a)
    return expected

In [6]:
def policy_evaluation(pi, n_updates=10, inplace=False):
    """Run a fixed number of iterative policy-evaluation sweeps.

    Starts from an all-zero value function and, `n_updates` times, replaces
    every state's value with R_pi(s) + gamma * sum_s' P_pi(s'|s) * V(s').

    Args:
        pi: policy, called as pi(a, s) -> probability of action a in state s.
        n_updates: number of full sweeps over S.
        inplace: if True, each sweep overwrites V as it goes, so states later
            in the sweep already see this sweep's updated values; otherwise
            every state is backed up from the previous sweep's values only.

    Returns:
        numpy float array of state values, indexed by state.
    """
    V = np.zeros_like(S, dtype=float)
    for _ in range(n_updates):
        # In-place mode writes straight into V; otherwise into a fresh copy.
        V_next = V if inplace else V.copy()
        for s in S:
            backup = sum(P_pi(s_, s, pi) * V[s_] for s_ in S)
            V_next[s] = R_pi(s, pi) + gamma * backup
        V = V_next
    return V

In [8]:
# Both runs do the same 200 sweeps. The in-place run reuses values already
# updated earlier in the same sweep and lands closer to the fixed point,
# i.e. it converges faster.
print(policy_evaluation(pi_equiprobable_random, inplace=False, n_updates=200))
print(policy_evaluation(pi_equiprobable_random, inplace=True, n_updates=200))

[  0.         -13.99975741 -19.99964052 -21.99959772 -13.99975741
 -17.99968332 -19.99964293 -19.99964052 -19.99964052 -19.99964293
 -17.99968332 -13.99975741 -21.99959772 -19.99964052 -13.99975741
   0.        ]
[  0.         -13.99999963 -19.99999947 -21.99999941 -13.99999963
 -17.99999955 -19.99999951 -19.99999951 -19.99999947 -19.99999951
 -17.99999958 -13.99999969 -21.99999941 -19.99999951 -13.99999969
   0.        ]


In [9]:
def policy_evaluation_threshold(pi, threshold, inplace=False, return_V=False,
                                max_updates=None):
    """Iterative policy evaluation until values change by less than `threshold`.

    Sweeps over all states are repeated until the largest absolute change of
    any state's value within one sweep falls below `threshold`.

    Args:
        pi: policy, called as pi(a, s) -> probability of action a in state s.
        threshold: strictly positive convergence tolerance.
        inplace: if True, each sweep overwrites V as it goes (later states see
            values already updated in the same sweep); otherwise each sweep
            reads only the previous sweep's values.
        return_V: if True, also return the final value function.
        max_updates: optional cap on the number of sweeps. With gamma == 1 a
            policy that never reaches a terminal state does not converge, so
            this guards against an endless loop. None means no cap.

    Returns:
        The number of sweeps performed, or (n_updates, V) when return_V is True.

    Raises:
        ValueError: if threshold is not positive. max_diff is never negative,
            so such a threshold could never be met and the loop would not end.
        RuntimeError: if max_updates sweeps run without converging.
    """
    # `not x > 0` also rejects NaN, which would likewise never be met.
    if not threshold > 0:
        raise ValueError('threshold must be positive')
    V = np.zeros_like(S, dtype=float)
    n_updates = 0
    while True:
        max_diff = 0
        V_ = V if inplace else np.copy(V)
        for s in S:
            v_old = V_[s]
            V_[s] = R_pi(s, pi) + gamma * sum(P_pi(s_, s, pi) * V[s_] for s_ in S)
            max_diff = max(max_diff, abs(V_[s] - v_old))
        n_updates += 1
        # Adopt this sweep before testing, so the final sweep's values are
        # kept (the synchronous version used to discard them on break).
        V = V_
        if max_diff < threshold:
            break
        if max_updates is not None and n_updates >= max_updates:
            raise RuntimeError(
                'policy evaluation did not converge within %d updates' % max_updates)
    if return_V:
        return n_updates, V
    return n_updates

In [10]:
# Same stopping tolerance for both: the in-place version needs fewer sweeps
# before the largest per-sweep change drops below it.
print(policy_evaluation_threshold(pi_equiprobable_random, inplace=False, threshold=1e-4))
print(policy_evaluation_threshold(pi_equiprobable_random, inplace=True, threshold=1e-4))

173
114
