In [0]:
          _____                _____                    _____                    _____                    _____                    _____          
         /\    \              /\    \                  /\    \                  /\    \                  /\    \                  /\    \         
        /::\    \            /::\    \                /::\    \                /::\    \                /::\    \                /::\    \        
       /::::\    \           \:::\    \              /::::\    \              /::::\    \              /::::\    \               \:::\    \       
      /::::::\    \           \:::\    \            /::::::\    \            /::::::\    \            /::::::\    \               \:::\    \      
     /:::/\:::\    \           \:::\    \          /:::/\:::\    \          /:::/\:::\    \          /:::/\:::\    \               \:::\    \     
    /:::/__\:::\    \           \:::\    \        /:::/__\:::\    \        /:::/__\:::\    \        /:::/__\:::\    \               \:::\    \    
    \:::\   \:::\    \          /::::\    \      /::::\   \:::\    \      /::::\   \:::\    \      /::::\   \:::\    \              /::::\    \   
  ___\:::\   \:::\    \        /::::::\    \    /::::::\   \:::\    \    /::::::\   \:::\    \    /::::::\   \:::\    \    ____    /::::::\    \  
 /\   \:::\   \:::\    \      /:::/\:::\    \  /:::/\:::\   \:::\    \  /:::/\:::\   \:::\____\  /:::/\:::\   \:::\    \  /\   \  /:::/\:::\    \ 
/::\   \:::\   \:::\____\    /:::/  \:::\____\/:::/  \:::\   \:::\____\/:::/  \:::\   \:::|    |/:::/  \:::\   \:::\____\/::\   \/:::/  \:::\____\
\:::\   \:::\   \::/    /   /:::/    \::/    /\::/    \:::\  /:::/    /\::/   |::::\  /:::|____|\::/    \:::\  /:::/    /\:::\  /:::/    \::/    /
 \:::\   \:::\   \/____/   /:::/    / \/____/  \/____/ \:::\/:::/    /  \/____|:::::\/:::/    /  \/____/ \:::\/:::/    /  \:::\/:::/    / \/____/ 
  \:::\   \:::\    \      /:::/    /                    \::::::/    /         |:::::::::/    /            \::::::/    /    \::::::/    /          
   \:::\   \:::\____\    /:::/    /                      \::::/    /          |::|\::::/    /              \::::/    /      \::::/____/           
    \:::\  /:::/    /    \::/    /                       /:::/    /           |::| \::/____/               /:::/    /        \:::\    \           
     \:::\/:::/    /      \/____/                       /:::/    /            |::|  ~|                    /:::/    /          \:::\    \          
      \::::::/    /                                    /:::/    /             |::|   |                   /:::/    /            \:::\    \         
       \::::/    /                                    /:::/    /              \::|   |                  /:::/    /              \:::\____\        
        \::/    /                                     \::/    /                \:|   |                  \::/    /                \::/    /        
         \/____/                                       \/____/                  \|___|                   \/____/                  \/____/         

# Policy Evaluation Exercise

## Dependency installation

If current option doesn't work, try to execute commented lines.

In [1]:
# !python -m pip install -e git+https://github.com/star-ai/rl-environments.git#egg=rlenvs
# !python -m pip install gym
!pip install -e git+https://github.com/star-ai/rl-environments.git#egg=rlenvs
!pip install gym

Obtaining rlenvs from git+https://github.com/star-ai/rl-environments.git#egg=rlenvs
  Updating ./src/rlenvs clone
Installing collected packages: rlenvs
  Found existing installation: rlenvs 0.1
    Can't uninstall 'rlenvs'. No files were found to uninstall.
  Running setup.py develop for rlenvs
Successfully installed rlenvs


## Import dependecies

In [0]:
from IPython.core.debugger import set_trace
import numpy as np
import pprint

# Import below can all of a sudden break
# NOTE: if running locally, remove src.rlenvs from import
from src.rlenvs.rlenvs.envs.gridworld import GridworldEnv

## Create Environment

OpenAI env. 

- **`env.P`** represents the transition probabilities of the environment.
- **`env.P[s][a]`** is a list of transition tuples `(prob, next_state, reward, done).`
- **`env.nS`** is a number of states in the environment. 
- **`env.nA`** is a number of actions in the environment.

In [0]:
env = GridworldEnv()

## What we implement

$$
v(s) = R_s + \gamma \sum_{s' \in S} P_{ss'}V(s')
$$

## Implementation

### Function to calculate state value

In [0]:
def calculate_state_value(policy, state, env, V, discount_factor):
    """"
    Calculate state value given policy, state, and current state value function.

    Args:
      policy: Policy - [S, A] matrix of probabilities of action A given state S
      env: Environment.
        env.P[s][a] return list of transition tuples (transition_probability, 
          next_state, reward, done).
      V: current state value function, V[s] return value for state s.
    """
    v = 0
    # Look at the possible next actions
    for a, action_prob in enumerate(policy[state]):
        # For each action, look at the possible next states...
        for prob, next_state, reward, done in env.P[state][a]:
            # Calculate the expected value
            v += action_prob * prob * (reward + discount_factor * V[next_state])
    return v

### Function to calculate all states values and maximum change between current state values and new state values

The change is required to identify if the State-value function is converged to the optimal.

In [0]:
def run_full_sweep(policy, env, V, discount_factor):
    """
    Run a full sweep over states.
    """
    new_V = np.zeros(env.nS)
    delta = 0
    # For each state, perform a "full backup"
    for s in range(env.nS):
        v = calculate_state_value(policy, s, env, V, discount_factor)

        # How much our value function changed (across any states)
        delta = max(delta, np.abs(v - V[s]))
        new_V[s] = v
    return new_V, delta

### Function to evaluate given policy

In [0]:
def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    """
    Evaluate a policy given an environment and a full description of the 
    environment's dynamics.
    
    Args:
        policy: [S, A] shaped matrix representing the policy.
        env: OpenAI env. env.P represents the transition probabilities of the 
          environment.
            env.P[s][a] is a list of transition tuples (prob, next_state, reward,
              done).
            env.nS is a number of states in the environment. 
            env.nA is a number of actions in the environment.
        theta: We stop evaluation once our value function change is less than 
          theta for all states.
        discount_factor: Gamma discount factor.
    
    Returns:
        Vector of length env.nS representing the value function.
    """
    
    # Start with a random (all 0) value function
    V = np.zeros(env.nS)
    while True:
        V, delta = run_full_sweep(policy, env, V, discount_factor)
        print(V)
        print(delta)
        # Stop evaluating once our value function change is below a threshold
        if delta < theta:
            break
    return np.array(V)

## Test Run

### Initialise random policy (All actions have equal probabilty)

In [0]:
random_policy = np.ones([env.nS, env.nA]) / env.nA

### Evaluate current policy

### Print results

In [10]:
pp = pprint.PrettyPrinter(indent=2)
print("Value Function:")
pp.pprint(np.reshape(v, (4, 4)))

Value Function:
array([[  0.        , -13.99989315, -19.99984167, -21.99982282],
       [-13.99989315, -17.99986052, -19.99984273, -19.99984167],
       [-19.99984167, -19.99984273, -17.99986052, -13.99989315],
       [-21.99982282, -19.99984167, -13.99989315,   0.        ]])


#### Simple test

In [11]:
# Test: Make sure the evaluated policy is what we expected
expected_v = np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14,
                       -22, -20, -14, 0])
print('Expected')
pp.pprint(np.reshape(expected_v, (4,4)))
print()
np.testing.assert_array_almost_equal(v, expected_v, decimal=2)
print('Test passed')

Expected
array([[  0, -14, -20, -22],
       [-14, -18, -20, -20],
       [-20, -20, -18, -14],
       [-22, -20, -14,   0]])

Test passed
