# Value iteration on gridworld
This notebook shows how to use value iteration to compute the optimal value function and the corresponding greedy policy for the gridworld environment.

In [1]:
# Set relative path to parent directory.
# Prepends the parent directory to the module search path so that the
# `environments` package imported in the next cell can be found.
# NOTE: '..' is resolved against the kernel's current working directory
# (normally the notebook's own folder), not against this file's location.
import sys, os
sys.path.insert(0, os.path.abspath('..'))

In [2]:
# Import and set up environment.
# `env` is the MDP passed to value_iteration and policy further below.
from environments.gridWorld import gridWorld
env = gridWorld()

## Value iteration algorithm

In [3]:
import copy
import numpy as np

def value_iteration(mdp, epsilon, gamma):
    """Compute the optimal state-value function of an MDP by value iteration.

    Repeatedly applies the Bellman optimality update

        V(s) <- R(s) + gamma * max_a sum_{s'} P(s' | s, a) * V(s')

    as a synchronous sweep: every update in a sweep reads the values of the
    previous sweep. Iteration stops when the largest per-state change
    ``delta`` in a sweep drops below a threshold. States with no available
    actions get ``V(s) = R(s)``.

    Args:
        mdp: environment exposing ``states()``, ``actions(s)``,
            ``transition_probability(s_next, s, a)`` and ``reward(s)``.
        epsilon: error tolerance. The stopping rule depends on ``gamma``:
            - ``0 < gamma < 1``: stop once ``delta < epsilon * (1 - gamma) / gamma``.
            - ``gamma == 1``: stop once ``delta < epsilon``. NOTE: the loop
              only terminates if the values actually converge; this is not
              checked here.
            - ``gamma == 0``: a single sweep is already exact.
        gamma: discount factor in ``[0, 1]``.

    Returns:
        dict mapping every state returned by ``mdp.states()`` to its value.

    Raises:
        ValueError: if ``gamma`` is outside ``[0, 1]`` (the stopping
            threshold would be negative and the loop would never end).
    """
    if not 0 <= gamma <= 1:
        raise ValueError('gamma must be in [0, 1], got {!r}'.format(gamma))

    # Convergence threshold on the max change between two sweeps.
    if gamma == 0:
        # V = R after one sweep; also avoids dividing by zero below.
        threshold = float('inf')
    elif gamma == 1:
        threshold = epsilon
    else:
        threshold = epsilon * (1 - gamma) / gamma

    # Initialize utilities to zero -- from the mdp argument, not a global.
    V = {s: 0.0 for s in mdp.states()}
    while True:
        V_prev = dict(V)  # values are plain floats, so a shallow copy suffices
        delta = 0
        for s in mdp.states():
            # Expected next-state value for each available action.
            q_values = [
                sum(mdp.transition_probability(s_next, s, a) * V_prev[s_next]
                    for s_next in mdp.states())
                for a in mdp.actions(s)
            ]
            V[s] = mdp.reward(s) + (gamma * np.max(q_values) if q_values else 0)
            delta = max(delta, np.abs(V[s] - V_prev[s]))
        if delta < threshold:
            return V

## Finding the policy given the value function

In [4]:
def policy(mdp, V):
    """Extract the greedy policy with respect to a value function.

    For every state that has at least one available action, choose the
    action maximising the expected successor value
    sum_{s'} P(s' | s, a) * V[s']. Ties go to the action listed first by
    ``mdp.actions(s)``, because ``np.argmax`` returns the first maximum.

    Args:
        mdp: environment exposing ``states()``, ``actions(s)`` and
            ``transition_probability(s_next, s, a)``.
        V: mapping from state to value, e.g. the result of value_iteration.

    Returns:
        dict mapping each state that has actions to its greedy action.
        States without actions are left out.
    """
    greedy = {}
    for state in mdp.states():
        candidates = list(mdp.actions(state))
        if not candidates:
            continue
        expected = [
            sum(mdp.transition_probability(nxt, state, act) * V[nxt]
                for nxt in mdp.states())
            for act in candidates
        ]
        greedy[state] = candidates[int(np.argmax(expected))]
    return greedy

## Calling the function
We now call the value iteration function and the policy-extraction function, and print the optimal value function and policy as grids (cells marked `x` have no value or no action).

In [5]:
# Tolerance and discount factor for value iteration (gamma = 1: no discounting).
EPSILON = 1e-3
GAMMA = 1
V = value_iteration(env, EPSILON, GAMMA)

# Print the value function as a grid, one board row per line.
# Board cells with no entry in V are printed as 'x' (in the output below,
# the blocked cell). Membership is tested explicitly so that genuine errors
# are not hidden by a bare `except`.
for y in range(env.board_mask.shape[0]):
    for x in range(env.board_mask.shape[1]):
        if (y, x) in V:
            print('{0:.3f}'.format(V[(y, x)]), end='\t')
        else:
            print('x', end='\t')
    print('')

0.812	0.868	0.918	1.000	
0.762	x	0.660	-1.000	
0.705	0.655	0.611	0.387	


In [6]:
P = policy(env, V)

# Print the greedy policy as a grid, one board row per line.
# Cells with no entry in P are printed as 'x'. These are board cells that are
# not states, plus states for which policy() found no actions (in the output
# below, presumably the two terminal cells with values +1 / -1).
# dict.get replaces the former bare `except`, which would also have swallowed
# unrelated errors.
for y in range(env.board_mask.shape[0]):
    for x in range(env.board_mask.shape[1]):
        print(P.get((y, x), 'x'), end='\t')
    print('')

R	R	R	x	
U	x	U	x	
U	L	L	L	
