Credit for the Gridworld Environment goes to Denny Britz: [code here](https://github.com/dennybritz/reinforcement-learning/blob/master/lib/envs/gridworld.py).

In [1]:
from britz_gridworld import GridworldEnv
import numpy as np
import io
import sys
from gym.envs.toy_text import discrete

In [2]:
# Build the environment with its default constructor arguments; every later
# cell reads the transition model and the agent position from this instance.
gridworld = GridworldEnv()

In [3]:
class policy_iteration:
    """Tabular policy iteration for a discrete environment.

    The environment must expose ``nS`` (number of states), ``nA`` (number of
    actions) and ``P``, where ``P[state][action]`` is an iterable of
    ``(probability, next_state, reward, done)`` tuples.
    """

    def __init__(self, environment, discount, threshold, maxiter):
        """Store the environment and the solver settings.

        Args:
            environment: Object providing ``nS``, ``nA`` and ``P``.
            discount: Factor applied to the value of successor states.
            threshold: Evaluation stops once the largest per-state change in
                a sweep falls below this value.
            maxiter: Upper bound on evaluation sweeps and on improvement
                rounds.
        """
        self.environment = environment
        self.discount = discount
        self.num_states = environment.nS
        self.num_actions = environment.nA
        self.threshold = threshold
        self.maxiter = maxiter

    def evaluation(self, policy):
        """Iteratively estimate the state values of ``policy``.

        Args:
            policy: Array-like indexed as ``policy[state][action]`` giving the
                probability of taking ``action`` in ``state``.

        Returns:
            A 1-D NumPy array with one value per state.
        """
        values = np.zeros(self.num_states)

        for _ in range(self.maxiter):
            delta = 0
            for state in range(self.num_states):
                new_value = 0
                for act, actprob in enumerate(policy[state]):
                    for prob, next_state, r, _over in self.environment.P[state][act]:
                        new_value += actprob * prob * (r + self.discount * values[next_state])

                delta = max(delta, np.abs(new_value - values[state]))
                # Updated in place, so later states in this sweep already see
                # the new value.
                values[state] = new_value

            if delta < self.threshold:
                break

        return values

    def lookahead(self, current_state, V):
        """Return the expected one-step return of every action.

        Args:
            current_state: Index of the state to expand.
            V: Per-state value estimates, indexable by state.

        Returns:
            A 1-D NumPy array of length ``num_actions``.
        """
        A = np.zeros(self.num_actions)

        for acts in range(self.num_actions):
            for prob, next_state, r, _over in self.environment.P[current_state][acts]:
                A[acts] += prob * (r + self.discount * V[next_state])
        return A

    def improvement(self, evaluation_function):
        """Alternate evaluation and greedy improvement until the policy is stable.

        Starts from the uniform random policy. Each round evaluates the
        current policy and then makes every state greedy with respect to the
        resulting values. The loop ends as soon as a round changes nothing,
        or after ``maxiter`` rounds.

        Args:
            evaluation_function: Callable taking a policy array and returning
                per-state values (normally ``self.evaluation``).

        Returns:
            ``(policy, V_func)``: the one-hot policy array of shape
            ``(num_states, num_actions)`` and the values from the last
            evaluation. If ``maxiter`` is 0 the uniform policy and all-zero
            values are returned.
        """
        policy = np.ones([self.num_states, self.num_actions]) / self.num_actions
        # Bound up front so the final return is valid even when maxiter is 0.
        V_func = np.zeros(self.num_states)
        one_hot = np.eye(self.num_actions)

        for _ in range(self.maxiter):
            V_func = evaluation_function(policy)
            stable = True

            for state in range(self.num_states):
                act_vals = self.lookahead(state, V_func)
                greedy_row = one_hot[np.argmax(act_vals)]

                # Compare by value. An identity test (`is not`) between two
                # NumPy scalars is always true, so the loop could never
                # report stability. Comparing whole rows also keeps the
                # uniform start policy from counting as stable, which would
                # return values belonging to a different policy.
                if not np.array_equal(policy[state], greedy_row):
                    stable = False

                policy[state] = greedy_row

            if stable:
                return policy, V_func

        return policy, V_func
        

In [4]:
# Undiscounted problem, evaluation tolerance of 0.01, at most 10000 rounds.
pol_iter = policy_iteration(
    environment=gridworld,
    discount=1.0,
    threshold=1e-2,
    maxiter=10000,
)
# improvement() receives the evaluation routine and returns both results.
solution = pol_iter.improvement(evaluation_function=pol_iter.evaluation)
policy, value_function = solution

In [5]:
# Display the policy returned by policy iteration: one row per state, with a
# one-hot entry marking the selected action.
print('The policy')
policy

The policy


array([[1., 0., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 1., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.]])

In [6]:
# Collapse each one-hot row to its action index, then lay the indices out in
# the grid's own shape so every cell shows the move chosen there.
print('Policy reshaped onto gridworld')
print('0 -> up; 1 -> right; 2 -> down; 3 -> left')
np.argmax(policy, axis=1).reshape(gridworld.shape)

Policy reshaped onto gridworld
0 -> up; 1 -> right; 2 -> down; 3 -> left


array([[0, 3, 3, 2],
       [0, 0, 0, 2],
       [0, 0, 1, 2],
       [0, 1, 1, 0]])

In [7]:
# Display the per-state values returned alongside the policy (flat, indexed
# by state number).
print('Value function')
value_function

Value function


array([ 0., -1., -2., -3., -1., -2., -3., -2., -2., -3., -2., -1., -3.,
       -2., -1.,  0.])

In [8]:
# Same values, arranged in the grid's shape so each entry sits at the
# position of its state.
print('Value function reshaped onto gridworld')
value_function.reshape(gridworld.shape)

Value function reshaped onto gridworld


array([[ 0., -1., -2., -3.],
       [-1., -2., -3., -2.],
       [-2., -3., -2., -1.],
       [-3., -2., -1.,  0.]])

In [31]:
import time
from IPython.display import clear_output

# Animate one episode: follow the computed policy from a fresh start state,
# redrawing the grid after every move.
gridworld.reset()

# The episode ends in the first or the last state (0 and 15 on this grid).
# NOTE(review): assumes the environment keeps its terminals in those two
# corners for other grid shapes too - confirm against GridworldEnv.
terminal_states = (0, gridworld.nS - 1)

TERMINAL = False
while not TERMINAL:
    # _render() draws the grid itself and returns None, so printing its
    # result only added a stray "None" line under the grid.
    gridworld._render()
    time.sleep(1.5)
    state = gridworld.s
    if state in terminal_states:
        print('TERMINAL STATE REACHED')
        TERMINAL = True
    else:
        action = np.argmax(policy[state])
        gridworld.step(action)
        # Wipe the previous frame only once the next one is ready to draw.
        clear_output(wait=True)
    

T  o  o  o
o  o  o  o
o  o  o  o
o  o  o  x
None
TERMINAL STATE REACHED
