In [21]:
import numpy as np
import sys

if "../" not in sys.path:
  sys.path.append("../") 
from env.gridworld import GridworldEnv

In [28]:
env = GridworldEnv()

# We assume complete knowledge of the MDP dynamics:
# P[s][a] is a list of (prob, next_state, reward, done) tuples.
P = env.P

# State-value estimate, one entry per state, starting at zero.
V = np.zeros(env.nS)

# Uniform random starting policy: every action has probability 1/nA in every state.
policy = np.full((env.nS, env.nA), 1.0 / env.nA)

# Stop policy evaluation once no state changes by more than this in a sweep.
theta = 1e-6
# Discount factor (gamma); 1.0 means future rewards are not discounted.
discount = 1.0

In [29]:
def evaluate_policy(policy, V, transitions=None, discount_factor=None,
                    tolerance=None, max_iterations=None):
    """
    Iterative policy evaluation using in-place sweeps.

    Repeatedly applies the Bellman expectation backup to every state. A value
    updated earlier in a sweep is used immediately by later states in that
    sweep. Iteration stops when the largest change in a sweep drops below
    `tolerance`.

    Args:
        policy: policy[s][a] is the probability of taking action a in state s
            (an array of shape (nS, nA) in this notebook).
        V: Initial value estimate, one entry per state. It is updated IN PLACE
            and also returned, so callers can warm-start from a previous estimate.
        transitions: MDP model where transitions[s][a] is a list of
            (prob, next_state, reward, done) tuples. Defaults to the
            notebook-level `P`.
        discount_factor: Gamma. Defaults to the notebook-level `discount`.
        tolerance: Convergence threshold. Defaults to the notebook-level `theta`.
        max_iterations: Optional cap on the number of sweeps. When
            discount_factor == 1.0, a policy that never reaches a terminal state
            makes the values diverge, and without a cap the loop would never
            end. None (the default) means no cap.

    Returns:
        V, the same array, now holding the evaluated state values.
    """
    # Resolve defaults at call time so later edits to the globals are honored.
    if transitions is None:
        transitions = P
    if discount_factor is None:
        discount_factor = discount
    if tolerance is None:
        tolerance = theta

    sweeps = 0
    while True:
        delta = 0

        # One sweep over every state (V holds one entry per state).
        for s in range(len(V)):
            v = 0
            for a, action_prob in enumerate(policy[s]):
                for prob, next_state, reward, _done in transitions[s][a]:
                    v += action_prob * prob * (reward + discount_factor * V[next_state])
            delta = max(delta, np.abs(v - V[s]))
            V[s] = v
        sweeps += 1

        # Converged: the largest update in this sweep is small enough.
        if delta < tolerance:
            break
        # Safety valve against non-converging (e.g. undiscounted, improper) policies.
        if max_iterations is not None and sweeps >= max_iterations:
            break
    return V

In [30]:
def one_step_lookahead(state, V, transitions=None, discount_factor=None):
    """
    Compute the expected one-step return of every action in a given state.

    Args:
        state: The state to consider (int).
        V: Value estimate to bootstrap from, one entry per state.
        transitions: MDP model where transitions[s][a] is a list of
            (prob, next_state, reward, done) tuples. Defaults to `env.P`.
        discount_factor: Gamma. Defaults to the notebook-level `discount`.

    Returns:
        A vector with one entry per action in `state`, holding that action's
        expected value: sum over outcomes of prob * (reward + gamma * V[next_state]).
    """
    # Resolve defaults at call time so the function follows the notebook globals.
    if transitions is None:
        transitions = env.P
    if discount_factor is None:
        discount_factor = discount

    # assumes transitions[state] has one entry per action (env.nA for GridworldEnv)
    n_actions = len(transitions[state])
    A = np.zeros(n_actions)
    for a in range(n_actions):
        for prob, next_state, reward, _done in transitions[state][a]:
            A[a] += prob * (reward + discount_factor * V[next_state])
    return A

# Policy iteration: alternate full policy evaluation with greedy policy
# improvement until no state's action can be strictly improved.
while True:

    # Policy evaluation (warm-started from the previous V).
    V = evaluate_policy(policy, V)
    stable = True

    # Policy improvement: make the policy greedy with respect to V.
    for s in range(env.nS):
        action_values = one_step_lookahead(s, V)
        best_a = np.argmax(action_values)

        # One-step value of following the CURRENT (possibly stochastic) policy.
        # Stability is judged by strict value improvement, not by comparing
        # argmax indices. The index comparison can loop forever switching
        # between equally good actions whose values differ only by evaluation
        # noise. It can also wrongly report "stable" on the first pass, because
        # argmax of the uniform policy row is 0 even though the row changes.
        current_value = np.dot(policy[s], action_values)
        if action_values[best_a] > current_value + theta:
            stable = False
        policy[s] = np.eye(env.nA)[best_a]

    if stable:
        break

policy

array([[1., 0., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 1., 0.],
       [1., 0., 0., 0.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 1., 0., 0.],
       [1., 0., 0., 0.]])