## Policy Iteration

This notebook previews how to find an optimal policy using Policy Iteration. The Agent is on a 4x4 grid, and its goal is to reach the terminal state, which is marked with a solid black fill.
![title](images/gridworld.png)

1. The Agent can take an action in any of four directions (UP=0, RIGHT=1, DOWN=2, LEFT=3).
2. Any action that would take the Agent beyond the grid leaves the Agent in the same state.
3. The Agent receives a reward of -1 at each step until it reaches the terminal state.

Let us find a policy that takes the Agent to the terminal state, and compute the value function for that policy, using the Policy Iteration method. A subsequent module covers this in detail; the demo is provided now to illustrate how an RL problem can be solved.
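The three rules above can be encoded as a transition table. The sketch below is a hypothetical reconstruction, not the actual `gridWorld.GridWorld` implementation: `build_model`, the row-major state numbering, and the choice of terminal states (corners 0 and 15, suggested by the zeros in the value function printed at the end of this notebook) are all assumptions. It only mirrors the `model[state][action]` layout of `(prob, next_state, reward, done)` tuples that the code in this notebook reads.

```python
# Hypothetical reconstruction of the grid dynamics; the real GridWorld class may differ.
UP, RIGHT, DOWN, LEFT = 0, 1, 2, 3
MOVES = {UP: (-1, 0), RIGHT: (0, 1), DOWN: (1, 0), LEFT: (0, -1)}


def build_model(shape=(4, 4), terminal_states=(0, 15)):
    """Return model[state][action] -> [(prob, next_state, reward, done)]."""
    rows, cols = shape
    model = {}
    for state in range(rows * cols):
        # Assumption: states are numbered row by row, starting at the top-left.
        row, col = divmod(state, cols)
        model[state] = {}
        for action, (d_row, d_col) in MOVES.items():
            if state in terminal_states:
                # Assumption: terminal states are absorbing and give zero reward.
                model[state][action] = [(1.0, state, 0.0, True)]
                continue
            # Rule 2: clamp the move so the Agent never leaves the grid.
            new_row = min(max(row + d_row, 0), rows - 1)
            new_col = min(max(col + d_col, 0), cols - 1)
            next_state = new_row * cols + new_col
            # Rule 3: every step costs -1.
            done = next_state in terminal_states
            model[state][action] = [(1.0, next_state, -1.0, done)]
    return model


model = build_model()
print(model[1][UP])      # bumping into the top wall leaves the Agent in state 1
print(model[14][RIGHT])  # stepping from state 14 into terminal state 15
```

Every transition here is deterministic, so each list holds a single tuple with probability 1.0; a stochastic environment would list several tuples per action.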

In [7]:
import numpy as np
from gridWorld import GridWorld

Given a policy, policy evaluation finds the worth (value) of each state. The value of every state is initialized to zero.
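As a rough guide to the evaluation code, the update applied to each state can be written as the Bellman expectation backup. The symbols are notation introduced here, not taken from the notebook: $\pi(a \mid s)$ is the policy's probability of action $a$ in state $s$ (`action_prob`), $p$ is the transition probability (`prob`), $r$ is the reward, $s'$ is the next state, and $\gamma$ is `discount_factor`.

$$
V(s) \leftarrow \sum_{a} \pi(a \mid s) \sum_{(p,\, s',\, r)} p \,\bigl[\, r + \gamma\, V(s') \,\bigr]
$$

The sweep over all states is repeated until the largest change in $V(s)$ during a sweep falls below the threshold `theta`.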

In [8]:
def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    """
    Evaluate a policy given an environment and a full description of the environment's dynamics.

    Arguments:
        policy: [S, A] shaped matrix representing the policy.
        env: environment object (here, GridWorld) that exposes:
            env.numStates: number of states in the environment
            env.numActions: number of actions in the environment
            env.model[state][action]: list of transition tuples (prob, next_state, reward, done)
        theta: We stop evaluation once the value function change is less than theta for all states.
        discount_factor: Gamma discount factor.

    Returns:
        Vector of length env.numStates representing the value function.
    """
    # Start with an all-zero value function
    value_fn = np.zeros(env.numStates)
    while True:
        delta = 0
        # For each state, perform a "full backup"
        for state in range(env.numStates):
            state_value = 0
            # Look at the possible next actions
            for action, action_prob in enumerate(policy[state]):
                # For each action, look at the possible next states
                for prob, next_state, reward, done in env.model[state][action]:
                    # Accumulate the expected value
                    state_value += action_prob * prob * (reward + discount_factor * value_fn[next_state])
            # Track the largest change in the value function across all states
            delta = max(delta, np.abs(state_value - value_fn[state]))
            value_fn[state] = state_value
        # Stop evaluating once the value function change is below the threshold
        if delta < theta:
            break
    return value_fn
            

In [9]:
def policy_iteration(env, policy_eval_fn=policy_eval, discount_factor=1.0):
    """
    Iteratively evaluate and improve a policy until an optimal policy is found.

    Arguments:
        env: The environment (here, GridWorld).
        policy_eval_fn: Policy evaluation function that takes three arguments:
            policy, env, discount_factor.
        discount_factor: Gamma discount factor.

    Returns:
        A tuple (policy, value_fn).
        policy is the optimal policy, a matrix of shape [S, A] where each state s
        contains a valid probability distribution over actions.
        value_fn is the value function for the optimal policy.
    """

    def compute_value_fn_update(state, value_fn):
        # One-step lookahead: expected value of each action taken from `state`
        value_fn_update = np.zeros(env.numActions)
        for action in range(env.numActions):
            for prob, next_state, reward, done in env.model[state][action]:
                value_fn_update[action] += prob * (reward + discount_factor * value_fn[next_state])
        return value_fn_update

    # Start with a uniform random policy (every action equally likely)
    policy = np.ones([env.numStates, env.numActions]) / env.numActions

    while True:
        # Policy evaluation: compute the value function of the current policy
        value_fn = policy_eval_fn(policy, env, discount_factor)

        policy_stable = True

        # Policy improvement
        for state in range(env.numStates):
            # The action we would take under the current policy
            chosen_a = np.argmax(policy[state])

            # Find the best action by one-step lookahead
            # np.argmax resolves ties by picking the first maximum
            action_values = compute_value_fn_update(state, value_fn)
            best_a = np.argmax(action_values)

            # Greedily update the policy
            if chosen_a != best_a:
                policy_stable = False
            policy[state] = np.eye(env.numActions)[best_a]

        # If the policy is stable we've found an optimal policy. Return it
        if policy_stable:
            return policy, value_fn
    

Subsequent modules cover Policy Iteration in detail. For now, observe that Policy Iteration learns a policy that takes the Agent to the terminal state from any starting state.

In [10]:
env = GridWorld()
policy, v = policy_iteration(env)


In [11]:
print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):")
print(np.reshape(np.argmax(policy, axis=1), env.shape))
print("")

Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):
[[0 3 3 2]
 [0 0 0 2]
 [0 0 1 2]
 [0 1 1 0]]
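One way to sanity-check the printed policy is to follow its arrows from every cell and count the steps until the Agent stops. This is a minimal sketch rather than part of the original notebook: `steps_to_terminal` is a hypothetical helper, the greedy actions are copied by hand from the output above, and the two corner cells are assumed to be terminal (the notebook's `GridWorld` class is not consulted).

```python
import numpy as np

# Greedy actions copied from the printed policy grid above.
greedy_actions = np.array([
    [0, 3, 3, 2],
    [0, 0, 0, 2],
    [0, 0, 1, 2],
    [0, 1, 1, 0],
])
# (row, column) offsets for UP=0, RIGHT=1, DOWN=2, LEFT=3.
MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}
# Assumption: the top-left and bottom-right corners are terminal.
TERMINALS = {(0, 0), (3, 3)}


def steps_to_terminal(start, max_steps=20):
    """Follow the greedy policy from `start`; return the step count,
    or None if no terminal cell is reached within `max_steps` steps."""
    row, col = start
    for step in range(max_steps + 1):
        if (row, col) in TERMINALS:
            return step
        d_row, d_col = MOVES[int(greedy_actions[row, col])]
        # Moves that would leave the grid keep the Agent in place.
        row = min(max(row + d_row, 0), 3)
        col = min(max(col + d_col, 0), 3)
    return None


steps = np.array([[steps_to_terminal((r, c)) for c in range(4)] for r in range(4)])
print(steps)
```

Under these assumptions every cell reaches a terminal cell in at most three steps, so no entry of `steps` is `None`.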



We also obtain the value function. Because the reward is -1 for each step, the value of each state is the negative of the number of steps the Agent needs to reach the terminal state from that state.

In [12]:
print("Value Function:")
print("Reshaped Grid Value Function:")
print(v.reshape(env.shape))
print("")

Value Function:
Reshaped Grid Value Function:
[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]
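With a reward of -1 per step and no discounting, the optimal value of a cell should equal minus the length of the shortest path to a terminal cell, which on an obstacle-free grid is the Manhattan distance. The sketch below checks the printed values against that distance. The terminal positions (the two corners where the printed value is 0) are an assumption, and the values are copied by hand from the output above.

```python
import numpy as np

# Value function copied from the printed output above.
printed_values = np.array([
    [0., -1., -2., -3.],
    [-1., -2., -3., -2.],
    [-2., -3., -2., -1.],
    [-3., -2., -1., 0.],
])
# Assumption: the cells whose printed value is 0 are the terminal cells.
terminals = [(0, 0), (3, 3)]

# Row and column index of every cell in the 4x4 grid.
rows, cols = np.indices((4, 4))
# Manhattan distance from each cell to its nearest terminal cell.
distance = np.min(
    [np.abs(rows - r) + np.abs(cols - c) for r, c in terminals], axis=0
)
# True when each printed value is minus the shortest-path length.
print(np.array_equal(printed_values, -distance))
```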

