# DAT257x: Reinforcement Learning Explained

## Lab 4: Dynamic Programming

### Exercise 4.2 Policy Iteration

Implement the algorithm for Policy Iteration in the code cell below.  

Note a subtle difference: the standalone Policy Evaluation algorithm assumes a stochastic policy, whereas the Policy Evaluation step inside Policy Iteration assumes a deterministic policy that maps each state to a single action.  You therefore cannot call your previous code directly, but you can reuse large pieces of it for the Policy Evaluation step.
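
The difference shows up in the backup for a single state. The sketch below contrasts the two cases, assuming a hypothetical `get_policy(s)` that returns `(action, probability)` pairs for the stochastic case and a list `pi` holding one action per state for the deterministic case. The helper names and the one-state toy MDP are illustrative only and are not part of the lab code.

```python
def backup_stochastic(s, V, gamma, get_policy, get_transitions):
    """One-step expected return at s under a stochastic policy.

    Assumption: get_policy(s) returns a list of (action, probability) pairs.
    """
    return sum(
        action_prob * prob * (reward + gamma * V[next_state])
        for action, action_prob in get_policy(s)
        for next_state, reward, prob in get_transitions(s, action)
    )


def backup_deterministic(s, V, gamma, pi, get_transitions):
    """Same backup when pi[s] is a single action: the sum over actions disappears."""
    return sum(
        prob * (reward + gamma * V[next_state])
        for next_state, reward, prob in get_transitions(s, pi[s])
    )


# Hypothetical one-state MDP, used only to exercise the two helpers.
def toy_transitions(s, a):
    # "stay" pays a reward of 1, any other action pays 0; both return to state 0.
    return [(0, 1.0, 1.0)] if a == "stay" else [(0, 0.0, 1.0)]


V = [10.0]
print(backup_deterministic(0, V, 0.9, ["stay"], toy_transitions))
print(backup_stochastic(0, V, 0.9,
                        lambda s: [("stay", 0.5), ("move", 0.5)],
                        toy_transitions))
```

The deterministic backup is the stochastic one with all probability mass placed on `pi[s]`, which is why most of the earlier evaluation code carries over.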


In [1]:
import tester       # required for testing and grading your code

def policy_iteration(state_count, gamma, theta, get_available_actions, get_transitions):
    """
    This function computes the optimal value function and policy for the specified MDP, using the Policy Iteration algorithm.
    'state_count' is the total number of states in the MDP. States are represented as 0-relative numbers.
    'gamma' is the MDP discount factor for rewards.
    'theta' is the small number threshold to signal convergence of the value function (see Iterative Policy Evaluation algorithm).
    'get_available_actions' returns a list of the MDP available actions for the specified state parameter.
    'get_transitions' is the MDP state / reward transition function.  It accepts two parameters, state and action, and returns
        a list of tuples, where each tuple is of the form: (next_state, reward, probability).
    """
    pi = state_count*[0]
    
    # init with a policy with first avail action for each state
    for s in range(state_count):
        avail_actions = get_available_actions(s)
        pi[s] = avail_actions[0][0]
        
    # insert code here to iterate using policy evaluation and policy improvement (see Policy Iteration algorithm)
    
    # Outer loop: alternate evaluation and improvement until the policy is stable
    while True:
        V = state_count*[0]                # init all state value estimates to 0
        # Policy evaluation: sweep until the value estimates converge
        while True:
            Delta = 0

            # Loop over each state in S
            for s in range(state_count):
                # Expected return of the policy's single action in state s
                v = 0
                # The policy is deterministic, so there is no sum over actions
                a = pi[s]
                # Sum over (next_state, reward, probability) transitions
                for next_state,reward,prob_s_sprim in get_transitions(s,a):
                    v += prob_s_sprim * (reward + gamma * V[next_state])
                # Track the largest change seen in this sweep
                Delta = max(Delta,abs(v-V[s]))
                # Update in place so later states in this sweep already see the new value, which speeds convergence
                V[s] = v

            # loop until delta threshold is reached
            if Delta<theta:
                break
            
        
        
        # Policy improvement: make the policy greedy with respect to V
        policy_stable = True

        for s in range(state_count):
            old_action = pi[s]
            actions = get_available_actions(s)   # actions valid in this state
            action_vals = len(actions)*[0]

            # Expected return of each action, summed over its transitions
            for indx,a in enumerate(actions):
                for next_state,reward,prob_s_sprim in get_transitions(s,a):
                    action_vals[indx] += prob_s_sprim * (reward + gamma * V[next_state])

            # Greedy choice; ties go to the first maximizing action
            pi[s] = actions[action_vals.index(max(action_vals))]

            # Compare by value, not by object identity
            if old_action != pi[s]:
                policy_stable = False

        # Stop once an improvement sweep changes no action
        if policy_stable:
            break
    
    
    
    return (V, pi)        # return both the final value function and the final policy

tester.policy_iteration_test(policy_iteration)  


Testing: Policy Iteration
passed test: return value is tuple
passed test: length of tuple = 2
passed test: v is list of length=15
passed test: values of v elements
passed test: pi is list of length=15
passed test: values of pi elements
PASSED: Policy Iteration passcode = 9970-010
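
To experiment outside the grader, the two callbacks can be mocked with a small hand-built MDP. The following is a minimal sketch and not the lab's reference solution: `policy_iteration_sketch`, the two-state MDP, and the action names `"stay"` and `"go"` are all assumptions made for illustration.

```python
def policy_iteration_sketch(state_count, gamma, theta,
                            get_available_actions, get_transitions):
    """Compact policy iteration with the same callback interface as the lab."""

    def q(s, a, V):
        # Expected return of taking action a in state s, then following V.
        return sum(p * (r + gamma * V[s2]) for s2, r, p in get_transitions(s, a))

    # Start from the first available action in every state.
    pi = [get_available_actions(s)[0] for s in range(state_count)]
    while True:
        # Policy evaluation for the current deterministic policy.
        V = state_count * [0.0]
        while True:
            delta = 0.0
            for s in range(state_count):
                v = q(s, pi[s], V)
                delta = max(delta, abs(v - V[s]))
                V[s] = v
            if delta < theta:
                break
        # Policy improvement: act greedily with respect to V.
        new_pi = [max(get_available_actions(s), key=lambda a: q(s, a, V))
                  for s in range(state_count)]
        if new_pi == pi:
            return V, pi
        pi = new_pi


# Hypothetical two-state MDP: state 1 is absorbing and pays nothing.
def toy_actions(s):
    return ["stay", "go"] if s == 0 else ["stay"]


def toy_transitions(s, a):
    if s == 0 and a == "go":
        return [(1, 1.0, 1.0)]   # move to state 1 and collect a reward of 1
    return [(s, 0.0, 1.0)]       # stay in place with no reward


V, pi = policy_iteration_sketch(2, 0.9, 1e-6, toy_actions, toy_transitions)
print(V, pi)
```

In this toy problem the only reward comes from leaving state 0, so the greedy policy settles on `go` there after a single improvement step.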
