# Gridworld Example

Rules:
1. If you are at $A$, you get a reward of +10 irrespective of the action only if you land in $A'$
2. If you are at $B$, you get a reward of +5 irrespective of the action only if you land in $B'$
3. If you are at a boundary except for $A$ and $B$, you get -1 if you take an action that takes you out of bounds, but you remain in the same state
4. Any other step has zero reward

In [27]:
import numpy as np
from numba import jit

In [486]:
# actions, rewards, (*state)
n_actions = 4
n_rewards = 4
state_size = 25

# p(s',r | a, s)
p_gridworld = np.zeros((state_size, n_rewards, n_actions, state_size))
p_gridworld.shape

(25, 4, 4, 25)

In [489]:
lower_bound = 0
upper_bound = np.sqrt(state_size).astype(int) - 1

In [503]:
rewards = np.array([0, 5, 10, -1])
reward_map = {r: ix for ix, r in enumerate(rewards)}

actions = ["up", "right", "down", "left"]
actions_ix_map = {a: ix for ix, a in enumerate(rewards)}

action_map = {
    "up": np.array([-1, 0]),
    "right": np.array([0, 1]),
    "down": np.array([1, 0]),
    "left": np.array([0, -1])
}

# mapping from special states to rewards
special_map = {
    1: 10,
    3: 5
}

# mapping from special states to terminal states
special_state_map = {
    1: 21,
    3: 13
}

In [504]:
def get_pos(ix):
    col = ix % 5
    row = ix // 5
    state = np.asarray([row, col])
    return state

def get_state(position):
    row, col = position
    return 5 * row + col

def move(state, action):
    position = get_pos(state)
    new_position = position + action_map[action]
    return new_position

def is_out_of_bounds(position, lb=0, ub=4):
    return (new_pos < lb).any() or (new_pos > ub).any()

In [519]:
# p(s',r | a, s)
p_gridworld = np.zeros((state_size, n_rewards, n_actions, state_size))
curr_pos = get_pos(s)
for s in range(state_size):
    curr_pos = get_pos(s)
    for r in reward_map:
        for action in action_map:
            a_pos = action_ix_map[action]
            r_pos = reward_map[r]
            new_pos = move(s, action)
            new_state = get_state(new_pos)

            val = 0
            if s in special_states:
                if r == special_map[s]:
                    val = 1
                new_state = special_state_map[s]
                new_pos = get_pos(new_state)
            elif is_out_of_bounds(new_pos):
                if r == -1:
                    val = 1
                new_pos = curr_pos
                new_state = s
            elif r == 0:
                val = 1
            
            if val == 1 and r == 10:
                pass
                print(f"{r=:2}, {action=:5}, {curr_pos} -> {new_pos}")
            p_gridworld[new_state, r_pos, a_pos, s] = val
p_gridworld = p_gridworld / p_gridworld.sum(axis=0, keepdims=True).sum(axis=1, keepdims=True)

r=10, action=up   , [0 1] -> [4 1]
r=10, action=right, [0 1] -> [4 1]
r=10, action=down , [0 1] -> [4 1]
r=10, action=left , [0 1] -> [4 1]


In [514]:
# Σ_{s', r, a} r * p(s', r | a, s)
b = (p_gridworld * rewards[None, :, None, None]).sum(axis=0).sum(axis=0).sum(axis=0) / 4
b

array([-0.5 , 10.  , -0.25,  5.  , -0.5 , -0.25,  0.  ,  0.  ,  0.  ,
       -0.25, -0.25,  0.  ,  0.  ,  0.  , -0.25, -0.25,  0.  ,  0.  ,
        0.  , -0.25, -0.5 , -0.25, -0.25, -0.25, -0.5 ])

In [515]:
γ = 0.9
I = np.eye(state_size)
A = I - γ / 4 * p_gridworld.sum(axis=1).sum(axis=1)
A.shape

(25, 25)

In [516]:
np.linalg.solve(A, b).reshape(5, 5).round(1)

array([[-0.7, 10.2,  0.2,  5.5, -0.4],
       [ 0.5,  1.3,  1.8,  2.5,  1.1],
       [ 2.2,  3.3,  4.1,  8.1,  3. ],
       [ 4.8,  7.1,  5.2,  4.3,  2.2],
       [ 8.5, 18.2,  7.5,  3.5,  1.4]])