# Policy Evaluation

In [3]:
import numpy as np

## Gridworld Example

In [29]:
action_map = {
    "up": np.array([-1, 0]),
    "right": np.array([0, 1]),
    "down": np.array([1, 0]),
    "left": np.array([0, -1])
}


def get_pos(ix, length):
    """
    Get the position in a square Gridworld
    """
    col = ix % length
    row = ix // length
    state = np.asarray([row, col])
    return state

In [110]:
class Grid:
    def __init__(self, length):
        self.length = length
        self.action_map = {
            "up": np.array([-1, 0]),
            "right": np.array([0, 1]),
            "down": np.array([1, 0]),
            "left": np.array([0, -1])
        }
        self.action_ix = {action: ix for ix, action in enumerate(self.action_map)}

    def get_pos(self, ix):
        """
        Compute cartesian coordinates of the griworld given an index
        """
        col = ix % self.length
        row = ix // self.length
        position = np.asarray([row, col])
        return position
    
    def get_ix(self, position):
        row, col = position
        ix = self.length * row + col
        return ix
    
    def move(self, ix, action):
        position = self.get_pos(ix)
        new_position = position + self.action_map[action]
        new_ix = self.get_ix(new_position)
        return new_ix
    
class GridV2(Grid):
    def __init__(self, length, terminal_states):
        super().__init__(length)
        self.terminal_states = terminal_states
    
    def _check_out_of_bounds(self, state):
        return (state < 0) | (state >= self.length ** 2)
    
    def move(self, state, action):
        if state in self.terminal_states:
            return state
        
        new_state = super().move(state, action)
        is_oob = self._check_out_of_bounds(new_state)
        return state * is_oob + new_state * (1 - is_oob)

In [111]:
terminal_states = np.array([0, 15])
gridworld = GridV2(l, terminal_states)

In [180]:
n_actions = 4
n_rewards = 2
state_size = 4 ** 2

# s', r | s, a
p_gridworld = np.zeros((state_size, n_rewards, state_size, n_actions))
p_gridworld.shape

(16, 2, 16, 4)

In [181]:
gridworld.move(0, "left")

0

In [182]:
p_gridworld[:, 1, 9].sum(axis=1).reshape(4, 4)

array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

In [183]:
actions = ["up", "right", "down", "left"]
for s in np.arange(state_size):
    for a in actions:
        s_prime = gridworld.move(s, a)
        r = 0 if s in terminal_states else -1
        a_ix = gridworld.action_ix[a]
        p_gridworld[s_prime, r, s, a_ix] = r
        print(f"p({s_prime:2}, {r:2} | {s:2}, {a:5}) = 1")

p( 0,  0 |  0, up   ) = 1
p( 0,  0 |  0, right) = 1
p( 0,  0 |  0, down ) = 1
p( 0,  0 |  0, left ) = 1
p( 1, -1 |  1, up   ) = 1
p( 2, -1 |  1, right) = 1
p( 5, -1 |  1, down ) = 1
p( 0, -1 |  1, left ) = 1
p( 2, -1 |  2, up   ) = 1
p( 3, -1 |  2, right) = 1
p( 6, -1 |  2, down ) = 1
p( 1, -1 |  2, left ) = 1
p( 3, -1 |  3, up   ) = 1
p( 4, -1 |  3, right) = 1
p( 7, -1 |  3, down ) = 1
p( 2, -1 |  3, left ) = 1
p( 0, -1 |  4, up   ) = 1
p( 5, -1 |  4, right) = 1
p( 8, -1 |  4, down ) = 1
p( 3, -1 |  4, left ) = 1
p( 1, -1 |  5, up   ) = 1
p( 6, -1 |  5, right) = 1
p( 9, -1 |  5, down ) = 1
p( 4, -1 |  5, left ) = 1
p( 2, -1 |  6, up   ) = 1
p( 7, -1 |  6, right) = 1
p(10, -1 |  6, down ) = 1
p( 5, -1 |  6, left ) = 1
p( 3, -1 |  7, up   ) = 1
p( 8, -1 |  7, right) = 1
p(11, -1 |  7, down ) = 1
p( 6, -1 |  7, left ) = 1
p( 4, -1 |  8, up   ) = 1
p( 9, -1 |  8, right) = 1
p(12, -1 |  8, down ) = 1
p( 7, -1 |  8, left ) = 1
p( 5, -1 |  9, up   ) = 1
p(10, -1 |  9, right) = 1
p(13, -1 |  

In [251]:
rewards = np.array([0, -1])
vk = np.zeros(state_size)

for _ in range(10):
    single_reward = (rewards + vk[:, None])
    vk = np.einsum("ijkl,ij->k", p_gridworld, single_reward) / 4
    print(vk.reshape(4, 4), end="\n" * 2)

[[0. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 0.]]

[[0.   0.25 0.   0.  ]
 [0.25 0.   0.   0.  ]
 [0.   0.   0.   0.25]
 [0.   0.   0.25 0.  ]]

[[0.     0.9375 0.9375 0.9375]
 [1.     0.875  1.     0.9375]
 [0.9375 1.     0.875  1.    ]
 [0.9375 0.9375 0.9375 0.    ]]

[[0.       0.3125   0.046875 0.046875]
 [0.3125   0.015625 0.09375  0.03125 ]
 [0.03125  0.09375  0.015625 0.3125  ]
 [0.046875 0.046875 0.3125   0.      ]]

[[0.         0.90625    0.875      0.890625  ]
 [0.9765625  0.796875   0.97265625 0.87890625]
 [0.87890625 0.97265625 0.796875   0.9765625 ]
 [0.890625   0.875      0.90625    0.        ]]

[[0.         0.35546875 0.08886719 0.09472656]
 [0.35839844 0.04296875 0.16308594 0.0703125 ]
 [0.0703125  0.16308594 0.04296875 0.35839844]
 [0.09472656 0.08886719 0.35546875 0.        ]]

[[0.         0.87817383 0.82446289 0.84692383]
 [0.94799805 0.73999023 0.9387207  0.82836914]
 [0.82836914 0.9387207  0.73999023 0.94799805]
 [0.84692383 0.82446289 0.87817383 0.  