## Policy Improvement

It takes the state-value function of the policy we want to improve, the MDP and gamma (optionally)

In [26]:
import numpy as np

def policy_improvement(V, P, gamma=1.0):
    
    # initialize the Q-function to zero (technically, we
    # can initialize these randomly, but let’s keep things simple).
    Q = np.zeros((len(P), len(P[0])))
    
    # loop through the states, actions, and transitions.
    for s in range(len(P)):
        for a in range(len(P[s])):
            for prob, next_state, reward, done in P[s][a]:
                
                # use those values to calculate the Q-function.
                Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
                
    # obtain a new, greedy policy by taking the argmax of the Q-function            
    new_pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]
    
    return new_pi

We can consider the "careful" policy on the frozen-lake environment and its state value function:

In [27]:
import gym

env = gym.make('FrozenLake-v1')
P = env.env.P
init_state, _ = env.reset()
goal_state = 15

In [28]:
LEFT, DOWN, RIGHT, UP = range(4)

careful_pi = lambda s: {
    0:LEFT, 1:UP, 2:UP, 3:UP,
    4:LEFT, 5:LEFT, 6:UP, 7:LEFT,
    8:UP, 9:DOWN, 10:LEFT, 11:LEFT,
    12:LEFT, 13:RIGHT, 14:RIGHT, 15:LEFT
}[s]

In [29]:
import numpy as np

def policy_evaluation(pi, P, gamma=1.0, theta=1e-10):
    prev_V = np.zeros(len(P))
    while True:
        V = np.zeros(len(P))
        for s in range(len(P)):
            for prob, next_state, reward, done in P[s][pi(s)]:
                V[s] += prob * (reward + gamma * prev_V[next_state] * (not done))
        if np.max(np.abs(prev_V - V)) < theta:
            break
        prev_V = V.copy()
    return V

In [30]:
V = policy_evaluation(careful_pi, P, gamma=0.99)

Now we can try to improve the policy:

In [31]:
careful_plus_pi = policy_improvement(V, P, gamma=0.99)

We can show the improved policy, its probability of success and its main return using simulation:

In [32]:
def print_policy(pi, P, action_symbols=('<', 'v', '>', '^'), n_cols=4, title='Policy:'):
    print(title)
    arrs = {k:v for k,v in enumerate(action_symbols)}
    for s in range(len(P)):
        a = pi(s)
        print("| ", end="")
        if np.all([done for action in P[s].values() for _, _, _, done in action]):
            print("".rjust(9), end=" ")
        else:
            print(str(s).zfill(2), arrs[a].rjust(6), end=" ")
        if (s + 1) % n_cols == 0: print("|")

In [33]:
import random

def probability_success(env, pi, goal_state, n_episodes=100, max_steps=200):
    random.seed(123); np.random.seed(123) ; # env.seed(123)
    results = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, steps = False, 0
        while not done and steps < max_steps:
            state, _, done, _, h = env.step(pi(state))
            steps += 1
        results.append(state == goal_state)
    return np.sum(results)/len(results)

In [34]:
def mean_return(env, pi, n_episodes=100, max_steps=200):
    random.seed(123); np.random.seed(123) ; # env.seed(123)
    results = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, steps = False, 0
        results.append(0.0)
        while not done and steps < max_steps:
            state, reward, done, _, _ = env.step(pi(state))
            results[-1] += reward
            steps += 1
    return np.mean(results)

In [35]:
print_policy(careful_plus_pi, P)

ps = probability_success(env, careful_plus_pi, goal_state=goal_state)*100
mr = mean_return(env, careful_plus_pi)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 00      < | 01      ^ | 02      ^ | 03      ^ |
| 04      < |           | 06      < |           |
| 08      ^ | 09      v | 10      < |           |
|           | 13      > | 14      v |           |
Reaches goal 85.00%. Obtains an average undiscounted return of 0.7900.


The new policy is better than the original policy. This is great! 

Is there a better policy than this one? We can try to improve the careful-plus policy:

In [36]:
V = policy_evaluation(careful_plus_pi, P, gamma=0.99)
careful_plus_plus_pi = policy_improvement(V, P, gamma=0.99)

In [37]:
print_policy(careful_plus_plus_pi, P)

ps = probability_success(env, careful_plus_plus_pi, goal_state=goal_state)*100
mr = mean_return(env, careful_plus_plus_pi)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 00      < | 01      ^ | 02      ^ | 03      ^ |
| 04      < |           | 06      < |           |
| 08      ^ | 09      v | 10      < |           |
|           | 13      > | 14      v |           |
Reaches goal 80.00%. Obtains an average undiscounted return of 0.8100.


There’s no improvement this time. The careful-plus policy is an optimal policy of the frozen-lake environment.

Even if we start with an **adversarial policy** (designed to perform poorly), alternating policy evaluation and improvement would still end up with an optimal policy:

In [38]:
adversarial_pi = lambda s: {
    0:UP, 1:UP, 2:UP, 3:UP,
    4:UP, 5:LEFT, 6:UP, 7:LEFT,
    8:LEFT, 9:LEFT, 10:LEFT, 11:LEFT,
    12:LEFT, 13:LEFT, 14:LEFT, 15:LEFT
}[s]

In [39]:
print_policy(adversarial_pi, P)

ps = probability_success(env, adversarial_pi, goal_state=goal_state)*100
mr = mean_return(env, adversarial_pi)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 00      ^ | 01      ^ | 02      ^ | 03      ^ |
| 04      ^ |           | 06      ^ |           |
| 08      < | 09      < | 10      < |           |
|           | 13      < | 14      < |           |
Reaches goal 0.00%. Obtains an average undiscounted return of 0.0000.


In [40]:
V = policy_evaluation(adversarial_pi, P, gamma=0.99)
adversarial_pi_2 = policy_improvement(V, P, gamma=0.99)

In [41]:
print_policy(adversarial_pi_2, P)

ps = probability_success(env, adversarial_pi_2, goal_state=goal_state)*100
mr = mean_return(env, adversarial_pi_2)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 00      < | 01      < | 02      < | 03      < |
| 04      < |           | 06      < |           |
| 08      < | 09      < | 10      < |           |
|           | 13      < | 14      v |           |
Reaches goal 0.00%. Obtains an average undiscounted return of 0.0000.


In [42]:
V = policy_evaluation(adversarial_pi_2, P, gamma=0.99)
adversarial_pi_3 = policy_improvement(V, P, gamma=0.99)

In [43]:
print_policy(adversarial_pi_3, P)

ps = probability_success(env, adversarial_pi_3, goal_state=goal_state)*100
mr = mean_return(env, adversarial_pi_3)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 00      < | 01      v | 02      > | 03      ^ |
| 04      < |           | 06      < |           |
| 08      < | 09      v | 10      < |           |
|           | 13      v | 14      > |           |
Reaches goal 0.00%. Obtains an average undiscounted return of 0.0000.


In [44]:
V = policy_evaluation(adversarial_pi_3, P, gamma=0.99)
adversarial_pi_4 = policy_improvement(V, P, gamma=0.99)

In [45]:
print_policy(adversarial_pi_4, P)

ps = probability_success(env, adversarial_pi_4, goal_state=goal_state)*100
mr = mean_return(env, adversarial_pi_4)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 00      v | 01      > | 02      > | 03      ^ |
| 04      < |           | 06      < |           |
| 08      v | 09      v | 10      < |           |
|           | 13      > | 14      > |           |
Reaches goal 16.00%. Obtains an average undiscounted return of 0.2200.


In [46]:
V = policy_evaluation(adversarial_pi_4, P, gamma=0.99)
adversarial_pi_5 = policy_improvement(V, P, gamma=0.99)

In [47]:
print_policy(adversarial_pi_5, P)

ps = probability_success(env, adversarial_pi_5, goal_state=goal_state)*100
mr = mean_return(env, adversarial_pi_5)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 00      < | 01      ^ | 02      > | 03      ^ |
| 04      < |           | 06      < |           |
| 08      ^ | 09      v | 10      < |           |
|           | 13      > | 14      v |           |
Reaches goal 71.00%. Obtains an average undiscounted return of 0.8000.
