# Policy Iteration

Once a policy has been improved using its value function to yield a better policy, we can then compute a new value function and improve again to yield an even better policy. We can thus obtain a sequence of monotonically improving policies and value functions.

We need the policy evaluation and policy improvement functions:

In [1]:
import numpy as np

def policy_evaluation(pi, P, gamma=1.0, theta=1e-10):
    prev_V = np.zeros(len(P))
    while True:
        V = np.zeros(len(P))
        for s in range(len(P)):
            for prob, next_state, reward, done in P[s][pi(s)]:
                V[s] += prob * (reward + gamma * prev_V[next_state] * (not done))
        if np.max(np.abs(prev_V - V)) < theta:
            break
        prev_V = V.copy()
    return V

In [2]:
def policy_improvement(V, P, gamma=1.0):
    Q = np.zeros((len(P), len(P[0])))
    for s in range(len(P)):
        for a in range(len(P[s])):
            for prob, next_state, reward, done in P[s][a]:
                Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
    new_pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]
    return new_pi

In [3]:
def policy_iteration(P, gamma=1.0, theta=1e-10):
    
    # create a random policy: create a list of random actions 
    # and then map them to states
    random_actions = np.random.choice(tuple(P[0].keys()), len(P))
    pi = lambda s: {s:a for s, a in enumerate(random_actions)}[s]
    
    while True:
        # keep a copy of the policy before modify it
        old_pi = {s:pi(s) for s in range(len(P))}
        
        # get the state-value function of the policy
        V = policy_evaluation(pi, P, gamma, theta)
        
        # get an improved policy
        pi = policy_improvement(V, P, gamma)
        
        # if it’s different, we do it all over again
        if old_pi == {s:pi(s) for s in range(len(P))}:
            break
    
    # if it’s not, we break out of the loop and return 
    # the optimal policy and the optimal state-value function 
    return V, pi

Let’s try it on the taxi environment.

In [4]:
import gym

env = gym.make('Taxi-v3')
P = env.env.P
init_state, _ = env.reset()

In [5]:
V_best_p, pi_best_p = policy_iteration(P, gamma=0.99)

We can print the state value function of the policy:

In [6]:
def print_state_value_function(V, P, n_cols=5, prec=3, title='State-value function:'):
    print(title)
    for s in range(len(P)):
        v = V[s]
        print("| ", end="")
        if np.all([done for action in P[s].values() for _, _, _, done in action]):
            print("".rjust(9), end=" ")
        else:
            print(str(s).zfill(3), '{}'.format(np.round(v, prec)).rjust(6), end=" ")
        if (s + 1) % n_cols == 0: print("|")

In [7]:
print_state_value_function(V_best_p, P, prec=4)

State-value function:
| 000   18.8 | 001 9.6221 | 002 14.1188 | 003 10.7294 | 004 1.1532 |
| 005 9.6221 | 006 1.1532 | 007 4.2495 | 008 9.6221 | 009 5.3025 |
| 010 14.1188 | 011 6.3662 | 012  3.207 | 013 5.3025 | 014  3.207 |
| 015 10.7294 | 016   20.0 | 017 10.7294 | 018 15.2715 | 019 11.8478 |
| 020 17.612 | 021 8.5258 | 022 12.9776 | 023 9.6221 | 024 2.1749 |
| 025 10.7294 | 026 2.1749 | 027 5.3025 | 028 8.5258 | 029 4.2495 |
| 030 12.9776 | 031 5.3025 | 032 4.2495 | 033 6.3662 | 034 4.2495 |
| 035 11.8478 | 036   18.8 | 037 11.8478 | 038 14.1188 | 039 12.9776 |
| 040 11.8478 | 041  3.207 | 042 7.4406 | 043 4.2495 | 044 7.4406 |
| 045 16.4359 | 046 7.4406 | 047 10.7294 | 048 7.4406 | 049  3.207 |
| 050 11.8478 | 051 4.2495 | 052 5.3025 | 053 7.4406 | 054 5.3025 |
| 055 12.9776 | 056 12.9776 | 057 17.612 | 058 12.9776 | 059 14.1188 |
| 060 10.7294 | 061 2.1749 | 062 6.3662 | 063  3.207 | 064 8.5258 |
| 065 17.612 | 066 8.5258 | 067 11.8478 | 068 6.3662 | 069 2.1749 |
| 070 10.7294 | 

We can also print the policy, the probability of success and the mean return:

In [8]:
def print_policy(pi, P, action_symbols=('<', 'v', '>', '^', 'P', 'D'), n_cols=5, title='Policy:'):
    print(title)
    arrs = {k:v for k,v in enumerate(action_symbols)}
    for s in range(len(P)):
        a = pi(s)
        print("| ", end="")
        if np.all([done for action in P[s].values() for _, _, _, done in action]):
            print("".rjust(9), end=" ")
        else:
            print(str(s).zfill(3), arrs[a].rjust(6), end=" ")
        if (s + 1) % n_cols == 0: print("|")

In [9]:
import random

def probability_success(env, pi, n_episodes=100, max_steps=200):
    random.seed(123); np.random.seed(123) ; # env.seed(123)
    results = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, steps = False, 0
        while not done and steps < max_steps:
            state, _, done, _, h = env.step(pi(state))
            steps += 1
        results.append(done)
    return np.sum(results)/len(results)

In [10]:
def mean_return(env, pi, n_episodes=100, max_steps=200):
    random.seed(123); np.random.seed(123) ; # env.seed(123)
    results = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, steps = False, 0
        results.append(0.0)
        while not done and steps < max_steps:
            state, reward, done, _, _ = env.step(pi(state))
            results[-1] += reward
            steps += 1
    return np.mean(results)

In [11]:
print_policy(pi_best_p, P)

ps = probability_success(env, pi_best_p)*100
mr = mean_return(env, pi_best_p)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 000      P | 001      P | 002      P | 003      P | 004      < |
| 005      < | 006      < | 007      < | 008      < | 009      < |
| 010      < | 011      < | 012      < | 013      < | 014      < |
| 015      < | 016      D | 017      < | 018      < | 019      < |
| 020      ^ | 021      ^ | 022      ^ | 023      ^ | 024      < |
| 025      < | 026      < | 027      < | 028      < | 029      < |
| 030      < | 031      < | 032      < | 033      < | 034      < |
| 035      < | 036      ^ | 037      < | 038      < | 039      < |
| 040      < | 041      < | 042      < | 043      < | 044      > |
| 045      > | 046      > | 047      > | 048      < | 049      < |
| 050      < | 051      < | 052      < | 053      < | 054      < |
| 055      < | 056      < | 057      > | 058      < | 059      < |
| 060      < | 061      < | 062      < | 063      < | 064      > |
| 065      > | 066      > | 067      > | 068      < | 069      < |
| 070      < | 071      < | 072      < | 073      < | 