# DAT257x: Reinforcement Learning Explained

## Lab: Dynamic Programming

### Policy Evaluation

Policy evaluation computes the value function of a given policy, using the full definition of the associated Markov Decision Process (MDP). That definition consists of the set of states, the set of actions available in each state, the set of rewards, the discount factor, and the state/reward transition function.
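
For reference, the value function being computed satisfies the Bellman expectation equation. The notation below is the standard textbook form and is an assumption here, not something this lab defines: $\pi(a \mid s)$ is the probability that the policy picks action $a$ in state $s$, $p(s', r \mid s, a)$ is the state/reward transition function, and $\gamma$ is the discount factor.

$$
v_\pi(s) = \sum_{a} \pi(a \mid s) \sum_{s', r} p(s', r \mid s, a) \left[ r + \gamma \, v_\pi(s') \right]
$$

Iterative policy evaluation turns this equation into an update rule and applies it to every state repeatedly until the values stop changing by more than a small threshold.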

In [13]:
import numpy as np
import gridworld_mdp as gw   # defines the MDP for a 4x4 gridworld

The gridworld MDP defines the probability of state transitions for our 4x4 gridworld using a `get_transitions()` function.

Let's try it out now, with state=5 and all defined actions.

In [14]:
# try out the gw.get_transitions(state, action) function

state = 5
actions = gw.get_available_actions(state)

for action in actions:
    transitions = gw.get_transitions(state=state, action=action)

    # examine each returned transition (only 1 per call for this MDP)
    for next_state, reward, probability in transitions:    # unpack each tuple
        print("transition(" + str(state) + ", " + action + "):",
              "next_state=", next_state, ", reward=", reward, ", probability=", probability)

transition(5, up): next_state= 1 , reward= -1 , probability= 1
transition(5, down): next_state= 9 , reward= -1 , probability= 1
transition(5, left): next_state= 4 , reward= -1 , probability= 1
transition(5, right): next_state= 6 , reward= -1 , probability= 1


In [15]:
gw.get_transitions(state=5, action='down')[0][0]

9

**Implement the algorithm for Iterative Policy Evaluation using the in-place approach**. In the in-place approach, a single array holds the value estimates for all states. Each update reads from and writes to that same array, so states updated later in a sweep already use the new values computed earlier in that sweep.

In [16]:
def policy_eval_in_place(state_count, gamma, theta, get_policy, get_transitions):
    """
    This function uses the in-place approach to evaluate the specified policy for the specified MDP:
    
    'state_count' is the total number of states in the MDP. States are represented as 0-relative numbers.
    
    'gamma' is the MDP discount factor for rewards.
    
    'theta' is the small number threshold to signal convergence of the value function (see Iterative Policy Evaluation algorithm).
    
    'get_policy' is the stochastic policy function - it takes a state parameter and returns a sequence of tuples,
        where each tuple is of the form: (action, probability).  It represents the policy being evaluated.
        
    'get_transitions' is the state/reward transition function.  It accepts two parameters, state and action, and returns
        a list of tuples, where each tuple is of the form: (next_state, reward, probability).
         
    """
    V = [0] * state_count    # value estimates, initialized to zero
    while True:
        delta = 0    # largest change to any state's value in this sweep
        for s in range(state_count):
            v = V[s]    # remember the old estimate to measure the change
            tmp_value = 0
            # expected return over the policy's actions and each action's outcomes
            for action, action_probability in get_policy(s):
                for next_state, reward, probability in get_transitions(state=s, action=action):
                    tmp_value += action_probability*probability*(reward + gamma*V[next_state])
            V[s] = tmp_value    # in-place: later states in this sweep see the new value
            delta = max(delta, abs(v - V[s]))
        if delta < theta:
            break
    return V
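
For contrast with the in-place function above, here is a minimal sketch of the two-array (synchronous) variant, in which every update in a sweep reads only the previous sweep's values. The function name `policy_eval_two_array` and the 3-state chain used to exercise it (`toy_policy`, `toy_transitions`) are hypothetical and are not part of `gridworld_mdp`.

```python
def policy_eval_two_array(state_count, gamma, theta, get_policy, get_transitions):
    """Synchronous variant: each sweep reads only the previous sweep's values."""
    V = [0.0] * state_count
    while True:
        V_new = [0.0] * state_count    # second array, filled from the old one
        for s in range(state_count):
            for action, action_probability in get_policy(s):
                for next_state, reward, probability in get_transitions(state=s, action=action):
                    V_new[s] += action_probability * probability * (reward + gamma * V[next_state])
        # largest change to any state's value in this sweep
        delta = max(abs(new - old) for new, old in zip(V_new, V))
        V = V_new
        if delta < theta:
            return V


# Hypothetical 3-state chain for illustration (not the lab's gridworld):
# state 0 is terminal; states 1 and 2 step toward it with reward -1.
def toy_policy(state):
    return [("left", 1.0)]


def toy_transitions(state, action):
    if state == 0:
        return [(0, 0, 1.0)]          # terminal: stay put, no reward
    return [(state - 1, -1, 1.0)]     # deterministic move toward state 0


print(policy_eval_two_array(3, 0.9, 1e-6, toy_policy, toy_transitions))
```

Both variants converge to the same value function. The in-place version typically needs fewer sweeps because new information propagates within a sweep, while the two-array version costs a second array but makes each sweep independent of the order in which states are visited.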

In [19]:
def get_equal_policy(state):
    # build a simple policy where all 4 actions have the same probability, ignoring the specified state
    policy = ( ("up", .25), ("right", .25), ("down", .25), ("left", .25))
    return policy

In [20]:
n_states = gw.get_state_count()

# test our function
values = policy_eval_in_place(state_count=n_states, gamma=.9, theta=.001, get_policy=get_equal_policy,
    get_transitions=gw.get_transitions)

In [21]:
# values has one entry fewer than the 16 grid cells; append a 0 for the last (bottom-right) cell
a = np.append(values, 0)
np.reshape(a, (4,4))

array([[ 0.        , -5.27590649, -7.12580367, -7.64772992],
       [-5.27590649, -6.60421391, -7.17850791, -7.12638424],
       [-7.12580367, -7.17850791, -6.60467837, -5.27666399],
       [-7.64772992, -7.12638424, -5.27666399,  0.        ]])
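
Because the value function of a fixed policy is the solution of a linear system, the iterative result above can be cross-checked with a direct solve. The sketch below is a self-contained illustration that does not use `gridworld_mdp`. Its dynamics are assumptions chosen to resemble the lab's gridworld: 16 cells numbered row by row, both corner cells terminal, a reward of -1 per step, and off-grid moves leaving the state unchanged. The names `step`, `P`, `R`, and `V_exact` are hypothetical.

```python
import numpy as np

N = 4                         # grid side length (assumed 4x4 layout)
TERMINALS = {0, N * N - 1}    # assumption: both corner cells are terminal
GAMMA = 0.9
MOVES = {"up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}


def step(state, move):
    """Assumed dynamics: a move that would leave the grid keeps the state unchanged."""
    row, col = divmod(state, N)
    d_row, d_col = MOVES[move]
    new_row, new_col = row + d_row, col + d_col
    if 0 <= new_row < N and 0 <= new_col < N:
        return new_row * N + new_col
    return state


n = N * N
P = np.zeros((n, n))    # P[s, s2]: transition probability under the equiprobable policy
R = np.zeros(n)         # expected one-step reward under that policy
for s in range(n):
    if s in TERMINALS:
        continue        # terminal rows stay zero, so their value is 0
    for move in MOVES:
        P[s, step(s, move)] += 0.25
        R[s] += 0.25 * -1.0

# The values satisfy V = R + gamma * P V, i.e. (I - gamma * P) V = R
V_exact = np.linalg.solve(np.eye(n) - GAMMA * P, R)
print(np.round(V_exact.reshape(N, N), 3))
```

If the assumed dynamics match the lab's MDP, this grid should agree with the iterative result to within roughly the convergence tolerance. A direct solve is practical for small state spaces like this one; iterative evaluation remains the usual choice when the number of states is large.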