# Example 3.5: Gridworld

<img src="figures/chap.03.05.example3.5.gridworld.png" width="70%">

* 각 cell에서 취할 수 있는 actions: `north`, `south`, `east`, and `west`
* deterministic action
* 경계를 벗어나는 action에 대해서는 state가 변하지 않고 reward `-1`을 받음
* 나머지 다른 action(state `A`와 `B`에 있을때를 빼고)에 대해서는 reward `0`을 받음
* state `A`에서는 어떤 action을 하든지 state `A'`으로 가고 reward `+10`을 받음
* state `B`에서는 어떤 action을 하든지 state `B'`으로 가고 reward `+5`를 받음

In [1]:
import typing

import numpy as np
# Print floating-point array entries with one digit of precision.
np.set_printoptions(precision=1)

### Grid world state index

|   |   |   |   |   |
|----|----|----|----|----|
| 0,0  | 0,1 | 0,2 | 0,3 | 0,4 |
| 1,0  | 1,1 | 1,2 | 1,3 | 1,4 |
| 2,0  | 2,1 | 2,2 | 2,3 | 2,4 |
| 3,0  | 3,1 | 3,2 | 3,3 | 3,4 |
| 4,0  | 4,1 | 4,2 | 4,3 | 4,4 |

## Bellman equation

$$v_{\pi}(s) = \sum_{a} \pi(a | s)
\sum_{s', r} p(s', r | s, a)
\left[ r + \gamma v_{\pi}(s') \right]$$

* $p(s', r | s, a)$: deterministic
  * $p(s'=(4,1), \, r=10 \ | \ s=(0,1), \, a=\textrm{'north'}) = 1$
    * state $A$에서 'north' 방향으로 움직였을 때 다음 state가 $(4,1)$ 이고 reward 10을 받을 확률은 1
  * $p(s'=(2,1), \, r=0 \ | \ s=(1,1), \, a=\textrm{'north'}) = 0$
    * state $(1,1)$에서 'north' 방향으로 움직였을 때 다음 state가 $(2,1)$ 이고 reward 0을 받을 확률은 0
    * deterministic 이라 $s'=(0,1)$만 허용됨

#### `self.M`

* linear equation의 계수들을 모아놓은 matrix

$$ M V = R$$

$$ \left[ \begin{array}{cccc}
w_{0,0} & w_{0,1} & \cdots  & w_{0,24} \\
w_{1,0} & w_{1,1} & \cdots  & w_{1,24} \\
\vdots & \vdots & \vdots & \vdots \\
w_{24,0} & w_{24,1} & \cdots  & w_{24,24}
\end{array} \right]
\left[ \begin{array}{c}
v_{\pi}(s_{(0, 0)}) \\
v_{\pi}(s_{(0, 1)}) \\
\vdots \\
v_{\pi}(s_{(4, 4)}) \\
\end{array} \right]
= \left[ \begin{array}{c}
\frac{1}{4} R_{s_{(0, 0)}} \\
\frac{1}{4} R_{s_{(0, 1)}} \\
\vdots \\
\frac{1}{4} R_{s_{(4, 4)}} \\
\end{array} \right]
$$

where $R_{s_{(0, 0)}} = r_{a=\textrm{'north'}} + r_{a=\textrm{'south'}} + r_{a=\textrm{'east'}} + r_{a=\textrm{'west'}}$

In [2]:
class GridWorld(object):
  """Deterministic L x L gridworld with two special teleporting cells.

  States are integers in ``range(L * L)`` numbered row by row, so cell
  ``(i, j)`` is state ``i * L + j``.  Moving off the grid leaves the state
  unchanged and yields reward -1; any other ordinary move yields reward 0.
  Every action taken in cell ``A`` moves to ``A_`` with reward 10, and every
  action taken in cell ``B`` moves to ``B_`` with reward 5.

  Attributes:
    L: side length of the grid.
    N: number of states (``L * L``).
    actions_list: action names; positions are the action indices of ``p``.
    rewards_list: possible rewards; positions are the reward indices of ``p``.
    p: array indexed as ``p[s_, r_idx, s, a_idx]`` holding the probability
      of reaching state ``s_`` with reward ``rewards_list[r_idx]`` after
      taking action ``actions_list[a_idx]`` in state ``s``.
    state2ij: ``(N, 2)`` integer array mapping a state to its ``(i, j)`` cell.
    ij2state: dict mapping an ``(i, j)`` cell to its state index.
    A, B, A_, B_: coordinates of the special cells and their destinations.
  """

  def __init__(self, L: int = 5) -> None:
    """Build the index maps and fill in the transition array ``p``.

    Args:
      L: side length of the grid.  The special cells are fixed coordinates;
        ``A_ = (4, 1)`` only lies on the grid when ``L >= 5``.
    """
    self.L = L
    self.N = self.L * self.L
    self.actions_list = ['north', 'south', 'east', 'west']
    self.rewards_list = [10, 5, 0, -1]
    self.p = np.zeros((self.N, len(self.rewards_list),
                       self.N, len(self.actions_list)))

    # set state2ij and ij2state
    self.state2ij = np.zeros(shape=(self.N, 2), dtype=np.int64)
    self.ij2state = {}
    for i in range(self.L):
      for j in range(self.L):
        self.state2ij[i * self.L + j] = (i, j)
        self.ij2state[(i, j)] = i * self.L + j

    self.A = (0, 1)  # special site
    self.B = (0, 3)  # special site
    self.A_ = (4, 1)  # special site
    self.B_ = (2, 3)  # special site

    # assign transition matrix p describing model dynamics.
    # The model is deterministic, so each (state, action) pair has exactly
    # one (next state, reward) outcome: ask step() once per pair and mark
    # that single entry instead of scanning every candidate outcome.
    reward_index = {r: r_idx for r_idx, r in enumerate(self.rewards_list)}
    for s in range(self.N):
      for a_idx, a in enumerate(self.actions_list):
        next_s, reward = self.step(s, a)
        self.p[next_s, reward_index[reward], s, a_idx] = 1.

  def north(self, i: int, j: int) -> typing.Tuple[typing.Tuple[int, int], bool]:
    """Return the cell one row up and True, or ``(i, j)`` and False on the top row."""
    if i == 0:
      return (i, j), False
    else:
      return (i - 1, j), True

  def south(self, i: int, j: int) -> typing.Tuple[typing.Tuple[int, int], bool]:
    """Return the cell one row down and True, or ``(i, j)`` and False on the bottom row."""
    if i == self.L - 1:
      return (i, j), False
    else:
      return (i + 1, j), True

  def east(self, i: int, j: int) -> typing.Tuple[typing.Tuple[int, int], bool]:
    """Return the cell one column right and True, or ``(i, j)`` and False on the last column."""
    if j == self.L - 1:
      return (i, j), False
    else:
      return (i, j + 1), True

  def west(self, i: int, j: int) -> typing.Tuple[typing.Tuple[int, int], bool]:
    """Return the cell one column left and True, or ``(i, j)`` and False on the first column."""
    if j == 0:
      return (i, j), False
    else:
      return (i, j - 1), True

  def transition(self, state: int, action: str) -> typing.Tuple[int, bool]:
    """Apply a plain grid move, ignoring the special cells.

    Args:
      state: current state index.
      action: one of 'north', 'south', 'east', 'west'.

    Returns:
      ``(new_state, is_moving)`` where ``is_moving`` is False when the move
      would have left the grid and the state therefore did not change.

    Raises:
      ValueError: if ``action`` is not a known action name.
    """
    i, j = self.state2ij[state]
    # Work with plain ints so the returned coordinates match the annotations.
    i, j = int(i), int(j)
    if action == 'north':
      new_state, is_moving = self.north(i, j)
    elif action == 'south':
      new_state, is_moving = self.south(i, j)
    elif action == 'east':
      new_state, is_moving = self.east(i, j)
    elif action == 'west':
      new_state, is_moving = self.west(i, j)
    else:
      # Without this branch new_state would be unbound below.
      raise ValueError('unknown action: {!r}'.format(action))

    return self.ij2state[new_state], is_moving

  def is_A_or_B(self, state: int) -> bool:
    """Return True when ``state`` is one of the special cells ``A`` or ``B``.

    This is a pure query: it does not modify ``p``.
    """
    i, j = self.state2ij[state]
    return (i, j) in (self.A, self.B)

  def dynamics(self, state_: int, reward_idx: int,
               state: int, action_idx: int) -> float:
    """Return ``p(state_, rewards_list[reward_idx] | state, actions_list[action_idx])``."""
    return self.p[state_, reward_idx, state, action_idx]

  def step(self, state: int, action: str) -> typing.Tuple[int, int]:
    """Return the ``(next_state, reward)`` produced by ``action`` in ``state``.

    Args:
      state: current state index.
      action: one of the names in ``actions_list``.

    Returns:
      The next state index and the reward: 10 from cell ``A``, 5 from cell
      ``B``, -1 when the move would leave the grid, otherwise 0.
    """
    assert action in self.actions_list
    i, j = self.state2ij[state]
    # The special cells teleport regardless of the chosen action.
    if self.A == (i, j):
      return self.ij2state[self.A_], 10
    if self.B == (i, j):
      return self.ij2state[self.B_], 5

    new_state, is_moving = self.transition(state, action)
    if is_moving:
      return new_state, 0
    else:
      return new_state, -1

In [3]:
class Agent(object):
  """Agent that picks among four compass actions according to ``policy``."""

  def __init__(self) -> None:
    self.action_list = ['north', 'south', 'east', 'west']

  def policy(self, state: int, action: str) -> float:
    """Return the probability of taking ``action`` in ``state``.

    Every (state, action) pair gets 0.25, i.e. the four actions are
    equally likely in every state.
    """
    return 0.25

  def action(self, state: int) -> str:
    """Sample an action name for ``state`` from the policy distribution.

    Raises:
      AssertionError: if the policy probabilities for ``state`` do not sum
        to 1 (up to floating-point rounding).
    """
    probabilities = np.array([self.policy(state, action) for action in self.action_list])
    # Compare with a tolerance: an exact == 1 check rejects valid
    # distributions whose sum carries floating-point rounding error.
    assert np.isclose(np.sum(probabilities), 1.0), 'policy probabilities must sum to 1'

    return np.random.choice(self.action_list, p=probabilities)

In [4]:
# Create the environment with its default grid size (L = 5) and the agent
# whose policy is evaluated below.
model = GridWorld()
agent = Agent()

In [5]:
## Linear equations
# Allocate zero-initialised storage for the linear system in the state values.
L = model.L  # grid side length
N = model.N  # number of states
w = np.zeros((N, N))  # coefficient matrix: one row and one column per state
b = np.zeros(N)  # constant term of each state's equation
gamma = 0.9  # discount factor

In [6]:
# Accumulate the Bellman-equation terms for every state s:
#   w[s, s_] += gamma * pi(a | s) * p(s_, r | s, a)   over all a, r
#   b[s]     += pi(a | s) * p(s_, r | s, a) * r       over all s_, a, r
# w and b are only added to here, so they must still be all zeros when this
# cell runs.
for s in range(N):
  for s_ in range(N):
    for a_idx, a in enumerate(model.actions_list):
      for r_idx, r in enumerate(model.rewards_list):
        # Probability of choosing a in s and then landing in s_ with reward r.
        coeff = agent.policy(s, a) * model.dynamics(s_, r_idx, s, a_idx)
        w[s, s_] += coeff * gamma
        b[s] += coeff * r

In [7]:
# Rewrite v = b + w v as (w - I) v = -b so it has the "matrix @ unknown =
# vector" form a linear solver expects.  Both w and b are rebound here.
identity_matrix = np.eye(N)
w = w - identity_matrix
b = -b

In [8]:
# Solve w x = b for the state values, then lay them out as an L x L grid
# (row i, column j holds the value of cell (i, j)) and print them.
x = np.linalg.solve(w, b)
x = x.reshape(L, L)
print(x)

[[ 3.3  8.8  4.4  5.3  1.5]
 [ 1.5  3.   2.3  1.9  0.5]
 [ 0.1  0.7  0.7  0.4 -0.4]
 [-1.  -0.4 -0.4 -0.6 -1.2]
 [-1.9 -1.3 -1.2 -1.4 -2. ]]


### Results

<img src="figures/chap.03.05.example3.5.gridworld.png" width="70%">