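# Policy Iteration
## Solution

Policy iteration on a grid world works as follows:
- Start from a random policy.
- Repeatedly alternate *policy evaluation* (estimate $V^\pi(s)$ for the current policy $\pi$) with *policy improvement* (act greedily with respect to $r + \gamma V^\pi(s')$).
- Stop when the policy no longer changes.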

In [1]:
# Adapted from: https://github.com/lazyprogrammer/machine_learning_examples/tree/master/rl
# and then from: https://github.com/omerbsezer/Reinforcement_learning_tutorial_with_demo

In [2]:
import numpy as np
from gridWorldGame import standard_grid, negative_grid, print_values, print_policy

# Fixed seed: the initial policy (np.random.choice) and initial values
# (np.random.random) are drawn from numpy's global RNG, so seeding here makes
# Restart & Run All reproduce the same outputs.
SEED = 42
np.random.seed(SEED)

SMALL_ENOUGH = 1e-3  # policy evaluation stops once a full sweep changes no V(s) by this much or more
GAMMA = 0.9          # discount factor applied to the value of the next state
ALL_POSSIBLE_ACTIONS = ('U', 'D', 'L', 'R')

# negative_grid() assigns a small negative reward (-0.1, see the printed rewards)
# to the non-terminal cells, so every extra step costs something and the
# resulting policy prefers the shortest path to the +1 goal.
grid = negative_grid()
print("rewards:")
print_values(grid.rewards, grid)

rewards:
---------------------------
-0.10|-0.10|-0.10| 1.00|
---------------------------
-0.10| 0.00|-0.10|-1.00|
---------------------------
-0.10|-0.10|-0.10|-0.10|


In [3]:
# Arbitrary starting policy (state -> action): every state that has available
# actions in grid.actions (i.e. every non-terminal state) gets a random action.
policy = {state: np.random.choice(ALL_POSSIBLE_ACTIONS) for state in grid.actions.keys()}

In [4]:
# Display the random starting policy so its improvement can be followed below.
caption = "initial policy:"
print(caption)
print_policy(policy, grid)

initial policy:
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  D  |     |
---------------------------
  D  |  L  |  R  |  D  |


In [5]:
# Value table V(s): states that have actions start from a random guess in [0, 1);
# terminal states (no entry in grid.actions) start at 0.
states = grid.all_states()
V = {s: (np.random.random() if s in grid.actions else 0) for s in states}

# Show the initial value estimate for every cell of the grid.
print_values(V, grid)

---------------------------
 0.21| 0.38| 0.20| 0.00|
---------------------------
 0.21| 0.00| 0.22| 0.00|
---------------------------
 0.13| 0.67| 0.53| 0.66|


In [6]:
def one_step_lookahead(grid, V, s, a, gamma=GAMMA):
  """Return r + gamma * V[s'] for taking action `a` from state `s`.

  The grid is positioned at `s` with set_state(), moved with `a`, and the
  resulting state is read back, so the transition is treated as deterministic.
  Side effect: `grid` is left positioned at s'.
  """
  grid.set_state(s)
  r = grid.move(a)  # reward returned by the move
  return r + gamma * V[grid.current_state()]


def evaluate_policy(policy, V, grid, states, gamma=GAMMA, tol=SMALL_ENOUGH):
  """Iterative policy evaluation, updating `V` in place.

  Repeatedly sweeps `states`, setting V[s] = r + gamma * V[s'] for the action
  `policy[s]`. It stops when the largest change in one sweep is below `tol`.
  States with no policy entry (terminal states) are left untouched.
  """
  while True:
    biggest_change = 0
    for s in states:
      if s not in policy:
        continue  # terminal state: its value is never updated
      old_v = V[s]
      V[s] = one_step_lookahead(grid, V, s, policy[s], gamma)
      biggest_change = max(biggest_change, abs(old_v - V[s]))
    if biggest_change < tol:
      return


def improve_policy(policy, V, grid, states, gamma=GAMMA, tol=SMALL_ENOUGH):
  """Make `policy` greedy with respect to `V`, in place.

  Returns True when no action changed, meaning the policy is stable.

  The current action is replaced only when the greedy alternative is better by
  more than `tol`. Always switching to the argmax can make policy iteration flip
  forever between equally good actions, e.g. two routes of equal length to the
  goal. This happens because the approximate values from evaluate_policy() can
  differ by tiny amounts from one iteration to the next.
  """
  stable = True
  for s in states:
    if s not in policy:
      continue  # terminal states have no action to choose
    q = {a: one_step_lookahead(grid, V, s, a, gamma) for a in ALL_POSSIBLE_ACTIONS}
    best_a = max(ALL_POSSIBLE_ACTIONS, key=q.get)  # ties -> first action in ALL_POSSIBLE_ACTIONS
    if q[best_a] > q.get(policy[s], float('-inf')) + tol:
      policy[s] = best_a
      stable = False
  return stable


# Alternate evaluation and greedy improvement until the policy stops changing.
iteration = 0
while True:
  iteration += 1
  print("values %d: " % iteration)
  print_values(V, grid)
  print("policy %d: " % iteration)
  print_policy(policy, grid)
  print('\n\n')

  evaluate_policy(policy, V, grid, states)
  if improve_policy(policy, V, grid, states):
    break

values 1: 
---------------------------
 0.21| 0.38| 0.20| 0.00|
---------------------------
 0.21| 0.00| 0.22| 0.00|
---------------------------
 0.13| 0.67| 0.53| 0.66|
policy 1: 
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  D  |     |
---------------------------
  D  |  L  |  R  |  D  |



values 2: 
---------------------------
 0.62| 0.80| 1.00| 0.00|
---------------------------
 0.46| 0.00|-0.99| 0.00|
---------------------------
-0.99|-0.99|-0.99|-0.99|
policy 2: 
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  R  |  R  |  D  |



values 3: 
---------------------------
 0.62| 0.80| 1.00| 0.00|
---------------------------
 0.46| 0.00| 0.80| 0.00|
---------------------------
 0.31|-0.99|-0.99|-0.99|
policy 3: 
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------

In [7]:
# Converged result: the value table and the greedy policy derived from it.
for label, show, table in (("final values:", print_values, V),
                           ("final policy:", print_policy, policy)):
  print(label)
  show(table, grid)

final values:
---------------------------
 0.62| 0.80| 1.00| 0.00|
---------------------------
 0.46| 0.00| 0.80| 0.00|
---------------------------
 0.31| 0.46| 0.62| 0.46|
final policy:
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  R  |  U  |  L  |
