In [1]:
import numpy as np

In [2]:
# every action label the agent can attempt: 'U'p, 'D'own, 'L'eft, 'R'ight
ACTION_SPACE = ('U', 'D', 'L', 'R')

In [3]:
class Grid:
  """Deterministic grid world that tracks the agent's current (row, col) cell.

  A state is a tuple (i, j) where i is the row index and j the column index.
  Call set() after construction to supply the rewards and the allowed actions.
  """

  # (row delta, col delta) produced by each action label
  _DELTAS = {'U': (-1, 0), 'D': (1, 0), 'L': (0, -1), 'R': (0, 1)}

  def __init__(self, rows, cols, start):
    """Create a rows x cols grid with the agent placed at start = (i, j)."""
    self.rows = rows
    self.cols = cols
    # remembered so reset() can return to the configured start position
    self.start = (start[0], start[1])
    self.i = start[0]
    self.j = start[1]

  def set(self, rewards, actions):
    """Attach the reward table and the per-state action table.

    rewards : dict of (i, j) -> reward
    actions : dict of (i, j) -> list of possible actions
    """
    self.rewards = rewards
    self.actions = actions

  def set_state(self, s):
    """Place the agent at state s = (i, j)."""
    self.i = s[0]
    self.j = s[1]

  def current_state(self):
    """Return the agent's position as an (i, j) tuple."""
    return (self.i, self.j)

  def is_terminal(self, s):
    """Return True when s has no entry in the action table."""
    return s not in self.actions

  def reset(self):
    """Put the agent back at the start position given to __init__ and return it."""
    self.i, self.j = self.start
    return (self.i, self.j)

  def get_next_state(self, s, a):
    """Return the state reached by taking action a in state s, without moving the agent.

    An action that is not listed for s leaves the state unchanged.
    The action table is indexed directly, so a state with no entry
    (a terminal state) raises KeyError.
    """
    i, j = s[0], s[1]

    if a in self.actions[(i, j)]:
      # labels outside _DELTAS fall back to "stay put"
      di, dj = self._DELTAS.get(a, (0, 0))
      i += di
      j += dj

    return i, j

  def move(self, action):
    """Apply action to the agent and return the reward of the cell it ends up in.

    A disallowed action leaves the agent in place; cells without a reward
    entry yield 0.
    """
    self.i, self.j = self.get_next_state((self.i, self.j), action)
    return self.rewards.get((self.i, self.j), 0)

  def undo_move(self, action):
    """Step the agent in the direction opposite to action (no legality check)."""
    di, dj = self._DELTAS.get(action, (0, 0))
    self.i -= di
    self.j -= dj

  def game_over(self):
    """Return True when the agent is standing on a state with no actions."""
    return (self.i, self.j) not in self.actions

  def all_states(self):
    """Return the set of states that appear in the action table or the reward table."""
    return set(self.actions) | set(self.rewards)

In [4]:
def standard_grid():
  """Build the 3x4 example grid with the agent starting at (2, 0).

  Cell (0, 3) carries reward +1 and cell (1, 3) carries reward -1; neither has
  any actions, so both are terminal. Cell (1, 1) has no entry in either table.
  """
  cell_rewards = {(0, 3): 1, (1, 3): -1}
  # moves permitted from each non-terminal cell
  allowed_moves = {
      (0, 0): ('D', 'R'),
      (0, 1): ('L', 'R'),
      (0, 2): ('L', 'D', 'R'),
      (1, 0): ('U', 'D'),
      (1, 2): ('U', 'D', 'R'),
      (2, 0): ('U', 'R'),
      (2, 1): ('L', 'R'),
      (2, 2): ('L', 'R', 'U'),
      (2, 3): ('L', 'U'),
  }
  env = Grid(3, 4, (2, 0))
  env.set(cell_rewards, allowed_moves)
  return env

In [5]:
# convergence threshold: sweeping stops once the largest per-state change in V is <= this
SMALL_ENOUGH = 1e-3 # threshold for convergence

In [12]:
def print_values(V, g):
  """Print the value table V as a g.rows x g.cols grid, one text row per grid row.

  V maps (i, j) -> value; cells missing from V are shown as 0. Non-negative
  values get a leading space so they line up with negative ones.
  """
  divider = "---------------------------"
  for row in range(g.rows):
    print(divider)
    for col in range(g.cols):
      value = V.get((row, col), 0)
      # the extra space stands in for the minus sign of negative values
      cell_format = " %.2f|" if value >= 0 else "%.2f|"
      print(cell_format % value, end="")
    print("")

In [7]:
def print_policy(P, g):
  """Print the policy P as a g.rows x g.cols grid of action labels.

  P maps (i, j) -> action; cells missing from P are shown as a blank.
  """
  divider = "---------------------------"
  for row in range(g.rows):
    print(divider)
    # build the whole text row first, then emit it in one go
    cells = ["  %s  |" % P.get((row, col), ' ') for col in range(g.cols)]
    print("".join(cells))

In [8]:
# Tabulate the environment model by trying every action in ACTION_SPACE from
# every non-terminal cell of the standard grid.
grid = standard_grid()

# transition_probs[(s, a, s')] = p(s' | s, a)
# any key not present has probability zero
transition_probs = {}

# rewards[(s, a, s')] = reward of the cell s' that the transition lands on
rewards = {}

for i in range(grid.rows):
  for j in range(grid.cols):
    s = (i, j)
    if grid.is_terminal(s):
      # no actions are defined here, so there is nothing to tabulate
      continue
    for a in ACTION_SPACE:
      # a disallowed action leaves the state unchanged, so s2 may equal s
      s2 = grid.get_next_state(s, a)
      # each (s, a) pair leads to exactly one s2
      transition_probs[(s, a, s2)] = 1
      if s2 in grid.rewards:
        rewards[(s, a, s2)] = grid.rewards[s2]

In [9]:
# fixed policy: exactly one action per state, keyed by (row, col).
# States without an entry are never assigned an action.
policy = {
    (2, 0): 'U',
    (1, 0): 'U',
    (0, 0): 'R',
    (0, 1): 'R',
    (0, 2): 'R',
    (1, 2): 'U',
    (2, 1): 'R',
    (2, 2): 'U',
    (2, 3): 'L'
}

# show the policy laid out on the grid
print_policy(policy, grid)

---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  R  |  U  |  L  |


In [10]:
# discount factor: multiplies the value of the next state in each update of V
gamma = 0.9 # discount factor

In [13]:
# Iterative policy evaluation of the fixed `policy` on `grid`.
# initialize V(s) = 0 for every state; terminal states are never updated below,
# so their value stays 0
V = {}
for s in grid.all_states():
  V[s] = 0

# repeat full sweeps over the states until convergence
it = 0
while True:
  # largest |old - new| seen during this sweep
  biggest_change = 0
  for s in grid.all_states():
    if not grid.is_terminal(s):
      old_v = V[s]
      new_v = 0 # will accumulate the answer
      # sum over every (action, next state) pair; pairs with zero action or
      # transition probability contribute nothing
      for a in ACTION_SPACE:
        for s2 in grid.all_states():

          # action probability is deterministic: 1 for the policy's action, else 0
          action_prob = 1 if policy.get(s) == a else 0

          # missing (s, a, s2) keys mean zero reward / zero transition probability
          r = rewards.get((s, a, s2), 0)
          new_v += action_prob * transition_probs.get((s, a, s2), 0) * (r + gamma * V[s2])

      # V is updated in place, so states visited later in this same sweep
      # already see the new value
      V[s] = new_v
      biggest_change = max(biggest_change, np.abs(old_v - new_v))

  # report progress after each sweep
  print("iter:", it, "biggest_change:", biggest_change)
  print_values(V, grid)
  it += 1

  # stop once no state changed by more than the threshold
  if biggest_change <= SMALL_ENOUGH:
    break

iter: 0 biggest_change: 1.0
---------------------------
 0.00| 0.00| 1.00| 0.00|
---------------------------
 0.00| 0.00| 0.00| 0.00|
---------------------------
 0.00| 0.00| 0.00| 0.00|
iter: 1 biggest_change: 0.9
---------------------------
 0.81| 0.90| 1.00| 0.00|
---------------------------
 0.73| 0.00| 0.90| 0.00|
---------------------------
 0.00| 0.00| 0.81| 0.00|
iter: 2 biggest_change: 0.7290000000000001
---------------------------
 0.81| 0.90| 1.00| 0.00|
---------------------------
 0.73| 0.00| 0.90| 0.00|
---------------------------
 0.66| 0.73| 0.81| 0.73|
iter: 3 biggest_change: 0
---------------------------
 0.81| 0.90| 1.00| 0.00|
---------------------------
 0.73| 0.00| 0.90| 0.00|
---------------------------
 0.66| 0.73| 0.81| 0.73|
