# Policy Iteration

Policy Iteration is a DP Algorithm to find the optimal policy.

In [1]:
import numpy as np
import matplotlib.pyplot as plt

## Grid World

In [4]:
m, n = 4, 4
r, c = 0, 0
S = [(r, c) for i in range(m) for c in range(n)]
A = ['north', 'south', 'east', 'west']
acount = len(A)
terminal = [(0,0),(3,3)]
gamma = 1

## Transitions

In [5]:
def transition(s, a, s_next):
    r, c = s
    if a == 'north':
        r = r if r == 0 else r - 1
    if a == 'south':
        r = r if r == m - 1 else r + 1
    if a == 'west':
        c = c if c == 0 else c - 1
    if a == 'east':
        c = c if c == n - 1 else c + 1
    prob = 1 if (r, c) == s_next else 0
    reward = -1
    return prob, reward

We are using equiprobable policy. 

## Policy

In [6]:
pi = {s: {a: 0.25 for a in A} for s in S}
pi[(0, 2)]['east'] ## Example to show the equiprobable policy.

0.25

We will look at two kinds of updates - synchronous and in-place.

## Synchronous Policy Evaluation

In [13]:
def to_numpy(v, m, n):
    vnum = np.zeros((m, n))
    for s in v:
        vnum[s] = v[s]
    return vnum

theta = 1e-3
delta = theta + 1
v = {s: 0 for s in S}
iter = 0

while delta >= theta:
    iter += 1
    v_next = {s: 0 for s in S}
    delta = 0
    for s in S:
        if s in terminal:
            continue
        for a in A:
            for s_next in S:
                prob, reward = transition(s, a, s_next)
                v_next[s] += pi[s][a] * prob * (reward + gamma * v[s_next])
        delta = max(delta, abs(v[s] - v_next[s]))
    v = v_next.copy()
v = {s: np.round(v[s]) for s in S}

print(f'Number of Iterations for Convergence: {iter}')
print(f'Value Function Rounded Off to Nearest Integer')
to_numpy(v, m, n)

Number of Iterations for Convergence: 295
Value Function Rounded Off to Nearest Integer


array([[ 0., nan, nan, nan],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.]])

## In-Place Policy Evaluation

In [14]:
def evaluate(pi):
    theta = 1e-3
    delta = theta + 1
    v = {s: 0 for s in S}
    iter = 0
    while delta >= theta:
        iter += 1
        delta = 0
        for s in S:
            vs = 0
        if s in terminal:
            continue
        for a in A:
            for s_next in S:
                prob, reward = transition(s, a, s_next)
                vs += pi[s][a] * prob * (reward + gamma * v[s_next])
        delta = max(delta, abs(v[s] - vs))
        v[s] = vs
    return v, iter
v, iter = evaluate(pi)
v = {s: np.round(v[s]) for s in S}
print(f'Number of Iterations for Convergence: {iter}')
print("Value Funciton Rounded Off to Nearest Integer")
to_numpy(v, m, n)

Number of Iterations for Convergence: 1024
Value Funciton Rounded Off to Nearest Integer


array([[ 0.,  0.,  0., nan],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.]])

## Policy Improvement

In [16]:
def improve(v):
    pi_next = {s: {a: 0 for a in A} for s in S}
    q = {s: {a: 0 for a in A} for s in S}
    for s in S:
        for a in A:
            for s_next in S:
                prob, reward = transition(s, a, s_next)
                q[s][a] += prob * (reward + gamma * v[s_next])
        amax = max([(a, q[s][a]) for a in A], key = lambda x: x[1])[0]
        pi_next[s][amax] = 1
    return pi_next
pi_next = improve(v)

## Policy Iteration

In [None]:
pi_next = {s: {a: 1/acount for a in A} for s in S}
pi = {s: {a: 1 if a == 'north' else 0 for a in A} for s in S}

iter = 0
while not np.array_equal(pi, pi_next):
    pi = pi_next.copy()
    v, _ = evaluate(pi)
    pi_next = improve(v)
    iter += 1

print(f'Number of Iterations for Convergence: {iter}')
to_numpy(v, m, n)

## Optimal Policies

In [None]:
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [6, 6]
ext = 0.1
arrow_map = {'north': (0, ext), 'south': (0, -ext), 'east': (ext, 0), 'west': (-ext, 0)}
q = {s: {a: 0 for a in A} for s in S}
for s in S:
    if s in terminal:
        continue
    for a in A:
        for s_next in S:
            prob, reward = transition(s, a, s_next)
            q[s][a] += prob * (reward + gamma * v[s_next])
    maxq = max([q[s][a] for a in A])
    opt_acts = [a for a in A if q[s][a] == maxq]
    for a in opt_acts:
        r, c = s
        x, y = c + 0.5, m - 1 - r + 0.5,
        plt.arrow(x, y, arrow_map[a][0], arrow_map[a][1], head_width = 0.1, head_length = 0.1)
        plt.text(x - 0.4, y - 0.4, f'${v[s]}$')

plt.grid(True)
plt.xlim([0, n])
plt.ylim([0, m])
plt.xticks(np.arrange(0, n, 1), np.arrange(1, n+1))
plt.yticks(np.arrange(0, m, 1), np.arrange(1, m+1))