# Solving GridWorld using Dynamic Programming

In [2]:
import numpy as np
import sys
from env.gridworld import GridworldEnv
import matplotlib.pyplot as plt

In [3]:
# Instantiate the GridWorld environment with its default constructor arguments.
# The DP functions below read env.nS, env.nA, env.P and env.shape from it.
env = GridworldEnv()

In [8]:
# Show the size of the action and observation spaces of the environment.
print(f"Action Space for GridWorld : {env.action_space}")
print(f"Number of States in GridWorld : {env.observation_space}")

Action Space for GridWorld : Discrete(4)
Number of States in GridWorld : Discrete(16)


## Policy iteration 

In [9]:
def policy_eval(policy, env, discount_factor=1.0, theta=1e-4):
    """
    Iterative policy evaluation: estimate the state-value function of `policy`.

    input:
        policy: indexed as policy[s][a] -> probability of taking action a
                in state s
        env: gym-style env exposing nS and env.P[s][a] -> list of
             (prob, next_state, reward, done) transitions
        discount_factor: gamma applied to the successor state's value
        theta: stop once the summed absolute change of V over one full
               sweep of the states drops below this threshold
    output: numpy array of shape (nS,) with the estimated value of each state
    """
    values = np.zeros(env.nS)  # every state starts with value 0

    while True:
        sweep_delta = 0
        for state in range(env.nS):
            # Bellman expectation backup: average over the policy's actions
            # and over each action's possible transitions
            backup = sum(
                action_prob * prob * (reward + discount_factor * values[next_state])
                for action, action_prob in enumerate(policy[state])
                for prob, next_state, reward, _ in env.P[state][action]
            )
            sweep_delta += abs(values[state] - backup)
            # written in place, so later states in this sweep see the new value
            values[state] = backup

        if sweep_delta < theta:
            return values.copy()

In [10]:
# evaluate a uniformly random policy in gridworld: each action gets 1/nA
random_policy = np.full((env.nS, env.nA), 1.0 / env.nA)  # 25% in all directions
v = policy_eval(random_policy, env)
v.reshape(env.shape)

array([[  0.        , -13.99994072, -19.99991477, -21.99990619],
       [-13.99994072, -17.99992725, -19.99992101, -19.99992192],
       [-19.99991477, -19.99992101, -17.99993335, -13.99995024],
       [-21.99990619, -19.99992192, -13.99995024,   0.        ]])

In [11]:
def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):
    """
    Policy iteration: alternate policy evaluation and greedy improvement
    until the policy stops changing.

    input:
        env: gym-style env exposing nS, nA and env.P[s][a] -> list of
             (prob, next_state, reward, done) transitions
        policy_eval_fn: callable(policy, env, discount_factor) -> value array;
             used for every evaluation step
        discount_factor: gamma, used both for evaluation and for the
             one-step lookahead so the two steps agree
    output: (policy, V) -- a one-hot policy of shape (nS, nA) and the value
            function policy_eval_fn computed for exactly that policy
    """

    def one_step_lookahead(state, v):
        # expected return of each action from `state`, bootstrapping from v
        value = np.zeros(env.nA)
        for action in range(env.nA):
            for prob, next_state, reward, done in env.P[state][action]:
                value[action] += prob * (reward + v[next_state] * discount_factor)
        return value

    policy = np.ones([env.nS, env.nA]) / env.nA  # start with uniform random policy

    i = 0
    while True:
        i += 1
        stable = True
        # use the caller-supplied evaluator with the same gamma as the lookahead
        value = policy_eval_fn(policy, env, discount_factor)
        for s in range(env.nS):
            a_greedy = np.argmax(one_step_lookahead(s, value))
            new_row = np.zeros(env.nA)
            new_row[a_greedy] = 1.0
            # Compare whole rows, not just argmax: a stochastic row (such as the
            # uniform start) whose argmax happens to match the greedy action is
            # still a change of policy and must be re-evaluated.
            if not np.array_equal(policy[s], new_row):
                stable = False
            policy[s] = new_row

        if stable:
            break

    print("policy learnt in {} policy iterations".format(i))

    # The policy was unchanged in the last iteration, so `value` is already
    # the value function of the returned policy; no need to evaluate again.
    return policy, value

In [15]:
policy, v = policy_improvement(env)

# show the chosen action for every grid cell, then every cell's value
print("")
print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
policy_grid = np.argmax(policy, axis=1).reshape(env.shape)
print(policy_grid)
print("")

print("Reshaped Grid Value Function:")
value_grid = v.reshape(env.shape)
print(value_grid)
print("")
print("")

policy learnt in 3 policy iterations

Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):
[[0 3 3 2]
 [0 0 0 2]
 [0 0 1 2]
 [0 1 1 0]]

Reshaped Grid Value Function:
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]



## Value iteration 

In [11]:
def value_iteration(env, discount_factor=1.0, theta=1e-4):
    """
    Value iteration: repeatedly apply the Bellman optimality backup until the
    value function converges, then act greedily with respect to it.

    input:
        env: gym-style env exposing nS, nA and env.P[s][a] -> list of
             (prob, next_state, reward, done) transitions
        discount_factor: gamma applied to the successor state's value
        theta: stop once the summed absolute change of V over one sweep
               drops below this threshold (same criterion as policy_eval)
    output: (policy, value) -- a one-hot greedy policy of shape (nS, nA)
            and the converged value function of shape (nS,)
    """

    def one_step_lookahead(state, v):
        # expected return of each action from `state`, bootstrapping from v
        value = np.zeros(env.nA)
        for action in range(env.nA):
            for prob, next_state, reward, done in env.P[state][action]:
                value[action] += prob * (reward + v[next_state] * discount_factor)
        return value

    value = np.zeros(env.nS)
    policy = np.zeros([env.nS, env.nA])

    while True:
        change = 0
        for s in range(env.nS):
            v = np.max(one_step_lookahead(s, value))
            change += abs(v - value[s])
            # in-place update: later states in this sweep use the new value
            value[s] = v
        if change < theta:
            break

    # extract the deterministic greedy policy from the converged values
    for s in range(env.nS):
        a_greedy = np.argmax(one_step_lookahead(s, value))
        policy[s][a_greedy] = 1.0

    return policy, value

In [12]:
%%time
policy, v = value_iteration(env)
print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
print(np.argmax(policy, axis=1).reshape(env.shape))
print("")

print("Reshaped Grid Value Function:")
print(v.reshape(env.shape))
print("")


Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):
[[0 3 3 2]
 [0 0 0 2]
 [0 0 1 2]
 [0 1 1 0]]

Reshaped Grid Value Function:
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]

CPU times: user 22.1 ms, sys: 2.68 ms, total: 24.8 ms
Wall time: 20.9 ms
