In [1]:
from copy import deepcopy
from collections import defaultdict
from pprint import pprint
import sys
if "../" not in sys.path:
    sys.path.append("../") 

import numpy as np
from envs.classic_gridworld import *

In [2]:
def policy_evaluation(policy, env, gamma=1., theta=1e-5, max_iterations=None):
    """Iteratively estimate the state-value function of ``policy`` on ``env``.

    Repeats in-place sweeps of the Bellman expectation backup over every
    state until the largest per-state change within a sweep drops below
    ``theta`` (or until ``max_iterations`` sweeps have been performed).

    Args:
        policy: Indexable by state; ``policy[state]`` yields one probability
            per action, the position in the row being the action id.
        env: Object exposing ``nS`` (number of states) and ``P``, where
            ``P[state][action]`` is an iterable of
            ``(prob, next_state, reward, done)`` tuples.
        gamma: Discount factor applied to the successor state's value.
        theta: Convergence threshold on the maximum absolute value change
            observed in one sweep.
        max_iterations: Optional cap on the number of sweeps. ``None``
            (the default) keeps the historical behaviour of looping until
            convergence, which may never happen (e.g. ``gamma=1`` with a
            state that loops on itself with a non-zero reward).

    Returns:
        ``np.ndarray`` of length ``env.nS`` holding the estimated values.
    """
    V = np.zeros(env.nS)
    sweeps = 0
    while max_iterations is None or sweeps < max_iterations:
        delta = 0
        for state in range(env.nS):
            v = 0
            # Expectation over the actions the policy may pick in this state...
            for action, action_prob in enumerate(policy[state]):
                # ...and over the stochastic outcomes of each action.
                # The `done` flag is not consulted: a terminal state keeps value 0
                # only if the environment models it as a zero-reward self-loop.
                for prob, next_state, reward, _done in env.P[state][action]:
                    v += action_prob * prob * (reward + gamma * V[next_state])
            delta = max(delta, np.abs(v - V[state]))
            # In-place update: later states in the same sweep see the new value.
            V[state] = v
        sweeps += 1
        if delta < theta:
            break
    return V

# 3x4 Gridworld

In [3]:
env = ClassicGridEnv3x4()

# Deterministic policy for the 12 states: states 0-3 take action 2,
# states 4-8 take action 3 and states 9-11 take action 0.
# Indexing the 4x4 identity matrix with those action ids yields one one-hot
# row per state, i.e. a (12, 4) float array of action probabilities.
optimal_3x4_policy = np.eye(4)[[2] * 4 + [3] * 5 + [0] * 3]

# Evaluation call left disabled, as in the original cell:
# policy_evaluation(optimal_3x4_policy, env, gamma=0.9)

In [27]:
# Hand-check of a single backup for state 6 / action 3 using hand-set
# successor values instead of a converged value function.
gamma = 0.9
V = np.zeros(env.nS)
V[[2, 3, 7]] = [0.72, 1, -1]

v = 0
for prob, next_state, reward, done in env.P[6][3]:
    print(prob, next_state, reward)
    # Only the discounted successor value is accumulated; the transition
    # reward is printed above but deliberately not added here.
    v += prob * (gamma * V[next_state])
v

0.1 7 -1.0
0.8 2 0.0
0.1 6 0.0


0.42840000000000006

In [15]:
# Show the transition table of state 2: a dict mapping each action id to a
# list of (prob, next_state, reward, done) outcomes.
env.P[2]

{0: [(0.1, 2, 0.0, False), (0.8, 1, 0.0, False), (0.1, 6, 0.0, False)],
 1: [(0.1, 1, 0.0, False), (0.8, 6, 0.0, False), (0.1, 3, 1.0, True)],
 2: [(0.1, 6, 0.0, False), (0.8, 3, 1.0, True), (0.1, 2, 0.0, False)],
 3: [(0.1, 3, 1.0, True), (0.8, 2, 0.0, False), (0.1, 1, 0.0, False)]}

In [None]:
# FIXME: unfinished scratch arithmetic. The expression below ended with a
# dangling `*`, which is a SyntaxError and aborts "Restart & Run All".
# It is kept verbatim as a comment until the intended final factor is known.
# 0.8 * 0.9 + 0.1 * 0.8 *

In [4]:
# Print the transition tables of the three special states, separated by a
# blank line.
special_states = (
    ('State 3: Goal', 3),
    ('State 7: Pit', 7),
    ('State 5: Obstacle', 5),
)
for position, (heading, state_id) in enumerate(special_states):
    if position > 0:
        print()
    print(heading)
    pprint(env.P[state_id])

State 3: Goal
{0: [(1.0, 3, 0, True)],
 1: [(1.0, 3, 0, True)],
 2: [(1.0, 3, 0, True)],
 3: [(1.0, 3, 0, True)]}

State 7: Pit
{0: [(1.0, 7, 0, True)],
 1: [(1.0, 7, 0, True)],
 2: [(1.0, 7, 0, True)],
 3: [(1.0, 7, 0, True)]}

State 5: Obstacle
{0: [(0.0, 5, 0, False)],
 1: [(0.0, 5, 0, False)],
 2: [(0.0, 5, 0, False)],
 3: [(0.0, 5, 0, False)]}


In [5]:
# Show the transition table of state 7 (labelled "Pit" in the cell above).
env.P[7]

{0: [(1.0, 7, 0, True)],
 1: [(1.0, 7, 0, True)],
 2: [(1.0, 7, 0, True)],
 3: [(1.0, 7, 0, True)]}

# 4x4 Gridworld

In [6]:
# Instantiate the 4x4 gridworld and call reset() on it.
env = ClassicGridEnv4x4()
env.reset()

# Dump the dynamics of state 3 and of state 0 (printed under the "Goal" heading).
for heading, state_id in (('State 3', 3), ('Goal', 0)):
    print(heading)
    pprint(env.P[state_id])

0
0
0
0
15
15
15
15
State 3
{0: [(1.0, 2, -1.0, False)],
 1: [(1.0, 7, -1.0, False)],
 2: [(1.0, 3, -1.0, False)],
 3: [(1.0, 3, -1.0, False)]}
Goal
{0: [(1.0, 0, 0, True)],
 1: [(1.0, 0, 0, True)],
 2: [(1.0, 0, 0, True)],
 3: [(1.0, 0, 0, True)]}


  isd /= isd.sum()
  return (csprob_n > np_random.rand()).argmax()


In [9]:
# Uniform random policy: every action equally likely in every state.
random_policy = np.full((env.nS, env.nA), 1.0 / env.nA)
# Evaluate it and lay the state values out as the 4x4 grid.
policy_evaluation(random_policy, env).reshape(4, 4)

array([[  0.        , -12.99993311, -18.99990384, -20.99989416],
       [-12.99993311, -16.99991792, -18.99991088, -18.9999119 ],
       [-18.99990384, -18.99991088, -16.9999248 , -12.99994386],
       [-20.99989416, -18.9999119 , -12.99994386,   0.        ]])

# 5x4 Gridworld

In [None]:
env = ClassicGridEnv5x4Static()
# Uniform random policy: every action equally likely in every state.
random_policy = np.full((env.nS, env.nA), 1.0 / env.nA)
state_values = policy_evaluation(random_policy, env)
pprint(state_values)

In [None]:
# Print the transition tables of states 12, 13 and 17 of the current env.
for heading, state_id in (('State 12', 12), ('State 13', 13), ('State 17', 17)):
    print(heading)
    pprint(env.P[state_id])

In [None]:
env = ClassicGridEnv5x4Dynamic()
# Uniform random policy: every action equally likely in every state.
random_policy = np.full((env.nS, env.nA), 1.0 / env.nA)
policy_evaluation(random_policy, env)


In [None]:
# Print the transition tables of states 12, 13 and 17 of the current env.
for heading, state_id in (('State 12', 12), ('State 13', 13), ('State 17', 17)):
    print(heading)
    pprint(env.P[state_id])

In [11]:
# Scratch arithmetic: 0.8 * (-0.02 + 0.9 * 1).
# NOTE(review): presumably a one-outcome backup (prob 0.8, step reward -0.02,
# gamma 0.9, successor value 1) — confirm; no -0.02 reward appears in the
# transition tables printed in this notebook.
0.8*(-0.02 + 0.9 *1)

0.7040000000000001

In [30]:
# Print 0..9, leaving out 6.
for i in range(10):
    if i != 6:
        print(i)


0
1
2
3
4
5
7
8
9
