# Value Iteration

We can merge a truncated policy-evaluation step and a policy improvement into the same equation. Instead of improving the policy (by taking the argmax to get a better policy and then evaluating this improved policy to obtain a value function again), we directly calculate the maximum (max, instead of argmax) value across the actions to be used for the next sweep over the states

In [1]:
import numpy as np

def value_iteration(P, gamma=1.0, theta=1e-10):
    
    # first thing is to initialize a state-value function.
    V = np.zeros(len(P))
    
    # We get in this loop and initialize a Q-function to zero.
    while True:
        Q = np.zeros((len(P), len(P[0])))
        
        # then, for every transition of every action in every 
        # state, we...
        for s in range(len(P)):
            for a in range(len(P[s])):
                for prob, next_state, reward, done in P[s][a]:
                    
                    # ...calculate the action-value function
                    # notice, using V, which is the old truncated estimate
                    Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
        
        # After each sweep over the state space, make sure 
        # the state-value function keeps changing. 
        # Otherwise, we found the optimal V-function and should break out
        if np.max(np.abs(V - np.max(Q, axis=1))) < theta:
            break
            
        # we don’t need a separate policy-improvement phase. 
        # It’s not a direct replacement, but instead a combination of 
        # improvement and evaluation    
        V = np.max(Q, axis=1)
    
    # only at the end do we extract the optimal policy and return it 
    # along with the optimal state-value function.
    pi = lambda s: {s:a for s, a in enumerate(np.argmax(Q, axis=1))}[s]
    return V, pi

And again, we can solve the Frozen Lake problem:

In [2]:
import gym

env = gym.make('Taxi-v3')
P = env.env.P
init_state, _ = env.reset()

In [3]:
V_best, pi_best = value_iteration(env.env.P, gamma=0.99)

And print policy, state value function and other stuff:

In [11]:
def print_policy(pi, P, action_symbols=('<', 'v', '>', '^', 'P', 'D'), n_cols=5, title='Policy:'):
    print(title)
    arrs = {k:v for k,v in enumerate(action_symbols)}
    for s in range(len(P)):
        a = pi(s)
        print("| ", end="")
        if np.all([done for action in P[s].values() for _, _, _, done in action]):
            print("".rjust(9), end=" ")
        else:
            print(str(s).zfill(3), arrs[a].rjust(6), end=" ")
        if (s + 1) % n_cols == 0: print("|")

In [5]:
import random

def probability_success(env, pi, n_episodes=100, max_steps=200):
    random.seed(123); np.random.seed(123) ; # env.seed(123)
    results = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, steps = False, 0
        while not done and steps < max_steps:
            state, _, done, _, h = env.step(pi(state))
            steps += 1
        results.append(done)
    return np.sum(results)/len(results)

In [6]:
def mean_return(env, pi, n_episodes=100, max_steps=200):
    random.seed(123); np.random.seed(123) ; # env.seed(123)
    results = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, steps = False, 0
        results.append(0.0)
        while not done and steps < max_steps:
            state, reward, done, _, _ = env.step(pi(state))
            results[-1] += reward
            steps += 1
    return np.mean(results)

In [12]:
print_policy(pi_best, P)

ps = probability_success(env, pi_best)*100
mr = mean_return(env, pi_best)

print('Reaches goal {:.2f}%. Obtains an average undiscounted return of {:.4f}.'.format(ps,mr))

Policy:
| 000      P | 001      P | 002      P | 003      P | 004      < |
| 005      < | 006      < | 007      < | 008      < | 009      < |
| 010      < | 011      < | 012      < | 013      < | 014      < |
| 015      < | 016      D | 017      < | 018      < | 019      < |
| 020      ^ | 021      ^ | 022      ^ | 023      ^ | 024      < |
| 025      < | 026      < | 027      < | 028      < | 029      < |
| 030      < | 031      < | 032      < | 033      < | 034      < |
| 035      < | 036      ^ | 037      < | 038      < | 039      < |
| 040      < | 041      < | 042      < | 043      < | 044      > |
| 045      > | 046      > | 047      > | 048      < | 049      < |
| 050      < | 051      < | 052      < | 053      < | 054      < |
| 055      < | 056      < | 057      > | 058      < | 059      < |
| 060      < | 061      < | 062      < | 063      < | 064      > |
| 065      > | 066      > | 067      > | 068      < | 069      < |
| 070      < | 071      < | 072      < | 073      < | 

In [15]:
def print_state_value_function(V, P, n_cols=5, prec=3, title='State-value function:'):
    print(title)
    for s in range(len(P)):
        v = V[s]
        print("| ", end="")
        if np.all([done for action in P[s].values() for _, _, _, done in action]):
            print("".rjust(9), end=" ")
        else:
            print(str(s).zfill(3), '{}'.format(np.round(v, prec)).rjust(6), end=" ")
        if (s + 1) % n_cols == 0: print("|")

In [18]:
print_state_value_function(V_best, P, prec=3)

State-value function:
| 000   18.8 | 001  9.622 | 002 14.119 | 003 10.729 | 004  1.153 |
| 005  9.622 | 006  1.153 | 007  4.249 | 008  9.622 | 009  5.303 |
| 010 14.119 | 011  6.366 | 012  3.207 | 013  5.303 | 014  3.207 |
| 015 10.729 | 016   20.0 | 017 10.729 | 018 15.272 | 019 11.848 |
| 020 17.612 | 021  8.526 | 022 12.978 | 023  9.622 | 024  2.175 |
| 025 10.729 | 026  2.175 | 027  5.303 | 028  8.526 | 029  4.249 |
| 030 12.978 | 031  5.303 | 032  4.249 | 033  6.366 | 034  4.249 |
| 035 11.848 | 036   18.8 | 037 11.848 | 038 14.119 | 039 12.978 |
| 040 11.848 | 041  3.207 | 042  7.441 | 043  4.249 | 044  7.441 |
| 045 16.436 | 046  7.441 | 047 10.729 | 048  7.441 | 049  3.207 |
| 050 11.848 | 051  4.249 | 052  5.303 | 053  7.441 | 054  5.303 |
| 055 12.978 | 056 12.978 | 057 17.612 | 058 12.978 | 059 14.119 |
| 060 10.729 | 061  2.175 | 062  6.366 | 063  3.207 | 064  8.526 |
| 065 17.612 | 066  8.526 | 067 11.848 | 068  6.366 | 069  2.175 |
| 070 10.729 | 071  3.207 | 072  6.366 |

In [24]:
taxi_row, taxi_col, pass_loc, dest_idx = env.decode(499)
print(f'Row: {taxi_row}, Col: {taxi_col}, Passenger location: {pass_loc}, Destination index: {dest_idx}')

Row: 4, Col: 4, Passenger location: 4, Destination index: 3
