In [1]:
import numpy as np

In [2]:
# Every action label an agent may attempt: Up, Down, Left, Right.
ACTION_SPACE = ('U', 'D', 'L', 'R')

In [3]:
class WindyGrid:
  """Grid-world environment whose transitions may be stochastic.

  States are ``(i, j)`` tuples. The environment is configured in two steps:
  construct it with its dimensions and start state, then call ``set`` to
  attach the reward table, the per-state action lists and the transition
  probabilities.
  """

  def __init__(self, rows, cols, start):
    """Store the grid dimensions and place the agent at ``start``.

    Args:
      rows: number of grid rows.
      cols: number of grid columns.
      start: ``(i, j)`` tuple giving the agent's initial position.
    """
    self.rows = rows
    self.cols = cols
    self.i = start[0]
    self.j = start[1]

  def set(self, rewards, actions, probs):
    """Attach the environment dynamics.

    Args:
      rewards: dict mapping state -> reward received on entering it.
      actions: dict mapping state -> tuple of actions; a state absent
        from this dict is treated as terminal.
      probs: dict mapping ``(state, action)`` -> ``{next_state: probability}``.
    """
    self.rewards = rewards
    self.actions = actions
    self.probs = probs

  def set_state(self, s):
    """Move the agent directly to state ``s`` (an ``(i, j)`` tuple)."""
    self.i = s[0]
    self.j = s[1]

  def current_state(self):
    """Return the agent's current position as an ``(i, j)`` tuple."""
    return (self.i, self.j)

  def is_terminal(self, s):
    """Return True when ``s`` has no entry in the actions table."""
    return s not in self.actions

  def move(self, action):
    """Sample a successor state for ``action`` and move the agent there.

    Args:
      action: action label; ``(current_state, action)`` must be a key
        of ``self.probs``.

    Returns:
      The reward of the state that was entered (0 if it has no entry in
      the reward table).

    Raises:
      KeyError: if ``(current_state, action)`` is not in ``self.probs``.
    """
    s = (self.i, self.j)

    next_state_probs = self.probs[(s, action)]
    next_states = list(next_state_probs.keys())
    next_probs = list(next_state_probs.values())

    # np.random.choice only accepts 1-D input; a list of (i, j) tuples is
    # converted to a 2-D array and rejected with ValueError. Sample an index
    # into the list instead and look the state up.
    idx = np.random.choice(len(next_states), p=next_probs)
    s2 = next_states[idx]

    # update the current state
    self.i, self.j = s2

    return self.rewards.get(s2, 0)

  def game_over(self):
    """Return True when the agent's current state is terminal."""
    return self.is_terminal((self.i, self.j))

  def all_states(self):
    """Return the set of every state that has actions or a reward."""
    return set(self.actions.keys()) | set(self.rewards.keys())

In [4]:
def windy_grid():
  """Build the 3x4 windy grid world with the agent starting at (2, 0).

  Entering (0, 3) yields +1 and entering (1, 3) yields -1; neither state has
  actions, so both are terminal. Every transition is deterministic except
  'U' from (1, 2), which lands in (0, 2) or (1, 3) with equal probability.

  Returns:
    A configured WindyGrid instance.
  """
  reward_table = {(0, 3) : 1, (1, 3) : -1}

  # Actions the agent is allowed to pick in each non-terminal state.
  allowed = {
      (0, 0) : ('D', 'R'),
      (0, 1) : ('R', 'L'),
      (0, 2) : ('R', 'L', 'D'),
      (1, 0) : ('U', 'D'),
      (1, 2) : ('U', 'R', 'D'),
      (2, 0) : ('U', 'R'),
      (2, 1) : ('L', 'R'),
      (2, 2) : ('L', 'U', 'R'),
      (2, 3) : ('L', 'U')
  }

  # Destination of each action from each state when no wind is involved;
  # an action that cannot move the agent maps back to the same state.
  destinations = {
      (2, 0): {'U': (1, 0), 'D': (2, 0), 'L': (2, 0), 'R': (2, 1)},
      (1, 0): {'U': (0, 0), 'D': (2, 0), 'L': (1, 0), 'R': (1, 0)},
      (0, 0): {'U': (0, 0), 'D': (1, 0), 'L': (0, 0), 'R': (0, 1)},
      (0, 1): {'U': (0, 1), 'D': (0, 1), 'L': (0, 0), 'R': (0, 2)},
      (0, 2): {'U': (0, 2), 'D': (1, 2), 'L': (0, 1), 'R': (0, 3)},
      (2, 1): {'U': (2, 1), 'D': (2, 1), 'L': (2, 0), 'R': (2, 2)},
      (2, 2): {'U': (1, 2), 'D': (2, 2), 'L': (2, 1), 'R': (2, 3)},
      (2, 3): {'U': (1, 3), 'D': (2, 3), 'L': (2, 2), 'R': (2, 3)},
      (1, 2): {'U': (0, 2), 'D': (2, 2), 'L': (1, 2), 'R': (1, 3)},
  }

  # p(s' | s, a) represented as:
  # KEY: (s, a) --> VALUE: {s': p(s' | s, a)}
  transition_table = {
      (state, action): {target: 1.0}
      for state, moves in destinations.items()
      for action, target in moves.items()
  }
  # The wind: going up from (1, 2) reaches (0, 2) only half of the time,
  # otherwise the agent is pushed into (1, 3).
  transition_table[((1, 2), 'U')] = {(0, 2): 0.5, (1, 3): 0.5}

  env = WindyGrid(3, 4, (2, 0))
  env.set(reward_table, allowed, transition_table)
  return env

In [5]:
from tabulate import tabulate

In [6]:
# Convergence threshold: the policy-evaluation loop stops once the largest
# per-state change in V during a sweep falls below this value.
SMALL_ENOUGH = 1e-3

In [7]:
def print_values(V, g):
  """Print state values as a grid-formatted table.

  Args:
    V: dict mapping (row, col) -> value; cells missing from V show as 0.
    g: grid object providing ``rows`` and ``cols``.
  """
  cells = [
    [V.get((r, c), 0) for c in range(g.cols)]
    for r in range(g.rows)
  ]
  print(tabulate(cells, tablefmt="grid", floatfmt=".3f"))

def print_policy(P, g):
  """Print a policy as a grid-formatted table.

  Args:
    P: dict mapping (row, col) -> policy entry for that cell; cells
      missing from P are rendered as a single space.
    g: grid object providing ``rows`` and ``cols``.
  """
  cells = [
    [P.get((r, c), ' ') for c in range(g.cols)]
    for r in range(g.rows)
  ]
  print(tabulate(cells, tablefmt="grid"))

In [9]:
# Build the environment first; the lookup tables below are derived from it.
grid = windy_grid()

# transition_probs[(s, a, s')] = p(s' | s, a)
transition_probs = {
    (state, action, nxt): prob
    for (state, action), outcomes in grid.probs.items()
    for nxt, prob in outcomes.items()
}

# rewards[(s, a, s')] = reward for arriving in s' (0 when s' has no reward)
rewards = {
    (state, action, nxt): grid.rewards.get(nxt, 0)
    for (state, action), outcomes in grid.probs.items()
    for nxt in outcomes
}

## probabilistic policy: state -> {action: probability of taking it}
policy = {
    (2, 0): {'U': 0.5, 'R': 0.5},
    (1, 0): {'U': 1.0},
    (0, 0): {'R': 1.0},
    (0, 1): {'R': 1.0},
    (0, 2): {'R': 1.0},
    (1, 2): {'U': 1.0},
    (2, 1): {'R': 1.0},
    (2, 2): {'U': 1.0},
    (2, 3): {'L': 1.0},
}
print_policy(policy, grid)

+----------------------+------------+------------+------------+
| {'R': 1.0}           | {'R': 1.0} | {'R': 1.0} |            |
+----------------------+------------+------------+------------+
| {'U': 1.0}           |            | {'U': 1.0} |            |
+----------------------+------------+------------+------------+
| {'U': 0.5, 'R': 0.5} | {'R': 1.0} | {'U': 1.0} | {'L': 1.0} |
+----------------------+------------+------------+------------+


In [12]:
# Iterative policy evaluation: repeatedly apply the Bellman expectation
# update, in place, until no state's value changes by SMALL_ENOUGH or more.

# The state set never changes, so build it once and reuse it for every sweep.
states = grid.all_states()

# initialize V(s) = 0
V = {s: 0 for s in states}

gamma = 0.9 # discount factor

# repeat until convergence
it = 0
while True:
  biggest_change = 0
  for s in states:
    # terminal states keep their value of 0
    if grid.is_terminal(s):
      continue

    old_v = V[s]
    new_v = 0

    for a in ACTION_SPACE:
      # pi(a | s) does not depend on s2, so look it up once per action
      action_prob = policy[s].get(a, 0)
      if action_prob == 0:
        # every term for this action would be multiplied by zero
        continue

      for s2 in states:
        r = rewards.get((s, a, s2), 0)
        new_v += action_prob * transition_probs.get((s, a, s2), 0) * (r + gamma * V[s2])

    # in-place update: later states in this sweep already see the new value
    V[s] = new_v
    biggest_change = max(biggest_change, np.abs(old_v - V[s]))

  print("iter:", it, "biggest_change:", biggest_change)
  print_values(V, grid)
  it += 1

  if biggest_change < SMALL_ENOUGH:
    break

iter: 0 biggest_change: 1.0
+-------+-------+--------+-------+
| 0.000 | 0.000 |  1.000 | 0.000 |
+-------+-------+--------+-------+
| 0.000 | 0.000 | -0.500 | 0.000 |
+-------+-------+--------+-------+
| 0.000 | 0.000 | -0.450 | 0.000 |
+-------+-------+--------+-------+
iter: 1 biggest_change: 0.9
+--------+--------+--------+--------+
|  0.810 |  0.900 |  1.000 |  0.000 |
+--------+--------+--------+--------+
|  0.729 |  0.000 | -0.050 |  0.000 |
+--------+--------+--------+--------+
| -0.182 | -0.405 | -0.045 | -0.405 |
+--------+--------+--------+--------+
iter: 2 biggest_change: 0.4920750000000001
+-------+--------+--------+--------+
| 0.810 |  0.900 |  1.000 |  0.000 |
+-------+--------+--------+--------+
| 0.729 |  0.000 | -0.050 |  0.000 |
+-------+--------+--------+--------+
| 0.310 | -0.040 | -0.045 | -0.040 |
+-------+--------+--------+--------+
iter: 3 biggest_change: 0
+-------+--------+--------+--------+
| 0.810 |  0.900 |  1.000 |  0.000 |
+-------+--------+--------+----