In [111]:
import numpy as np
import matplotlib.pyplot as plt
from enum import IntEnum

In [216]:
STATE_PROBS = [0.9, 0.075, 0.025] # prob of accessible grid, prob of inaccessible grid, prob of loser grid
REWARDS = [0, -2, -1, 1] # accessible, inaccessible, loser, winner grid rewards
UNKNOWN_POLICY = -2 # the policy is unknown for now, the policies are going to be determined after creating the gridworld
ACTIONS = [0, 1, 2, 3] # up, down, left, right
ROW_SIZE = 10
COLUMN_SIZE = 10
THRESHOLD = 1e-4
DISCOUNT_FACTOR = 0.9

In [217]:
class Gridworld:
  def __init__(self, row_size, column_size, start_position):
    assert(row_size * column_size >= 10) # creating a complex gridworld
    self.gridworld, self.winner_position = self.create_gridworld(row_size, column_size)
    row, column = start_position
    assert(self.gridworld[row, column][0] == 0) # the starting position should be accessible and should not cause the termination of the game
    self.start_position = start_position
    self.position = start_position
    self.available_moves = self.check_available_moves(row_size, column_size)
    self.create_policy(row_size, column_size)
    self.print_gridworld(row_size, column_size)

  def create_gridworld(self, row_size, column_size):
    out = np.empty((row_size, column_size), dtype=object)
    total_number_of_grids = row_size * column_size
    number_of_accessible_grid = int(total_number_of_grids * STATE_PROBS[0])
    number_of_inaccessible_grid = int(total_number_of_grids * STATE_PROBS[1])
    # We subtract the number of winner grid which is 1.
    number_of_loser_grid = total_number_of_grids - number_of_accessible_grid - number_of_inaccessible_grid - 1
    # create a distribution of states
    new_state_probs = np.array([number_of_accessible_grid, number_of_inaccessible_grid, number_of_loser_grid]) / (total_number_of_grids - 1)
    # determine the winner state for policy creation
    winner_position_row = np.random.choice(row_size - 1)
    winner_position_column = np.random.choice(column_size - 1)
    state = np.random.choice(REWARDS[:-1], p=new_state_probs, size=(ROW_SIZE, COLUMN_SIZE))
    out[winner_position_row, winner_position_column] = [REWARDS[-1], UNKNOWN_POLICY]
    for i in range(row_size):
      for j in range(column_size):
        if i == winner_position_row and j == winner_position_column:
          continue
        out[i, j] = [state[i, j], UNKNOWN_POLICY]
    return out, (winner_position_row, winner_position_column)

  def game_over(self):
    row, column = self.starting_position
    return self.gridworld[row, column][0] % 2 == 1 # whether we lost or won the game terminates

  # helper method which doesn't take extreme cases into consideration
  def move_calculate_position(self, position, action):
    if action == 0:
      return [position[0] - 1, position[1]]
    elif action == 1:
      return [position[0] + 1, position[1]]
    elif action == 2:
      return [position[0], position[1] - 1]
    elif action == 3:
      return [position[0], position[1] + 1]

  def move_simulation(self, position, action):
    # check if we step into the invalid grid
    row, column = position
    can_move = self.gridworld[row, column][0] != -2
    not_going_to_move_up_or_down = (row == 0 and action == 0) or (row == ROW_SIZE - 1 and action == 1)
    not_going_to_move_left_or_right = (column == 0 and action == 2) or (column == COLUMN_SIZE - 1 and action == 3)
    can_move = can_move and not(not_going_to_move_up_or_down or not_going_to_move_left_or_right)
    if can_move:
      row, column = self.move_calculate_position(position, action)
      # check if we step into the invalid grid
      can_move = can_move and self.gridworld[row, column][0] != -2
    return can_move, (row, column)

  def create_transition_probs(self, position, action):
    # check if we step into the invalid grid
    row, column = position
    can_move = self.gridworld[row, column][0] != -2
    not_going_to_move_up_or_down = (row == 0 and action == 0) or (row == ROW_SIZE - 1 and action == 1)
    not_going_to_move_left_or_right = (column == 0 and action == 2) or (column == COLUMN_SIZE - 1 and action == 3)
    can_move = can_move and not(not_going_to_move_up_or_down or not_going_to_move_left_or_right)
    if can_move:
      row, column = self.move_calculate_position(position, action)
      # check if we step into the invalid grid
      can_move = can_move and self.gridworld[row, column][0] != -2
    return can_move, (row, column)

  def move(self, action):
    self.position = list(self.move_simulation(self.position, action))

  def check_available_moves(self, row_size, column_size):
    moves = np.empty((row_size, column_size), dtype=object)
    for i in range(row_size):
      for j in range(column_size):
        move_list = []
        if self.gridworld[i, j][0] != 0:
          moves[i, j] = move_list
          continue # non-playable grid
        if self.move_simulation([i, j], 1)[0]:
          move_list.append(1)
        if self.move_simulation([i, j], 0)[0]:
          move_list.append(0)
        if self.move_simulation([i, j], 2)[0]:
          move_list.append(2)
        if self.move_simulation([i, j], 3)[0]:
          move_list.append(3)
        moves[i, j] = move_list
    return moves

  def create_policy(self, row_size, column_size):
    for i in range(row_size):
      for j in range(column_size):
        if self.gridworld[i, j][0] != 0:
          continue # non-playable grid
        self.gridworld[i, j][1] = np.random.choice(ACTIONS)

  def print_gridworld(self, row_size, column_size):
    for i in range(row_size):
      for j in range(column_size):
        grid = self.gridworld[i, j]
        if grid[0] == 0:
          print("[  ]" + "\t", end="")
        elif grid[0] == -2:
          print("[xx]" + "\t", end="")
        elif grid[0] == -1:
          print("[-1]" + "\t", end="")
        elif grid[0] == 1:
          print("[+1]" + "\t", end="")
      print("\n")

  def print_policy(self, row_size, column_size):
    for i in range(row_size):
      for j in range(column_size):
        grid = self.gridworld[i, j]
        print(str(ACTIONS[grid[1]]) + "\t" if grid[0] >= 0 else "X\t", end="")
      print("\n")

In [218]:
gridworld = Gridworld(ROW_SIZE, COLUMN_SIZE, [ROW_SIZE - 1, 0]);

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[xx]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[xx]	[-1]	

[  ]	[xx]	[  ]	[  ]	[  ]	[  ]	[  ]	[-1]	[  ]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[-1]	[  ]	[  ]	[  ]	

[xx]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	

[xx]	[  ]	[xx]	[  ]	[  ]	[  ]	[  ]	[xx]	[+1]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[xx]	[  ]	[  ]	



In [219]:
gridworld.print_policy(ROW_SIZE, COLUMN_SIZE)

3	2	0	2	0	2	3	2	X	1	

2	1	0	1	0	1	1	1	X	X	

3	X	0	2	3	0	1	X	0	3	

3	1	0	1	2	0	3	1	3	1	

2	0	3	1	3	2	X	0	2	1	

X	0	0	2	0	0	1	3	1	2	

2	0	3	1	3	2	0	3	3	1	

3	2	0	3	3	2	3	2	3	2	

X	1	X	0	1	1	1	X	2	1	

3	1	3	2	1	0	0	X	0	2	



In [220]:
def create_transition_probs_rewards_table(gridworld):
  transition_probs = {}
  rewards = {}
  for i in range(ROW_SIZE):
    for j in range(COLUMN_SIZE):
      state = (i, j)
      if gridworld.gridworld[state][0] == 0:
        for action in ACTIONS:
          can_move, next_state = gridworld.create_transition_probs(state, action)
          if can_move:
            transition_probs[(state, action, next_state)] = 1
            rewards[(state, action, next_state)] = gridworld.gridworld[next_state][1]

  return transition_probs, rewards

In [222]:
def deterministic_policy_evaluation(gridworld, value_function):
  it = 0
  while True:
    error = 0
    for i in range(ROW_SIZE):
      for j in range(COLUMN_SIZE):
        state = (i, j)
        old_value = value_function[state]
        new_value = 0
        grid = gridworld.gridworld[i, j]
        if grid[0] == 0:
          for action in ACTIONS:
            for k in range(ROW_SIZE):
              for l in range(COLUMN_SIZE):
                next_state = (k, l)
                action_prob = 1 if action == grid[1] else 0
                reward = rewards.get((state, action, next_state), 0)
                new_value += action_prob * transition_probs.get((state, action, next_state), 0) * (reward + DISCOUNT_FACTOR * value_function[next_state])
          value_function[state] = new_value
          error = max(error, np.abs(old_value - new_value))
    print(f"Iteration: {it + 1}, Error: {error}")
    it += 1

    if error < THRESHOLD:
      break

  return value_function

In [223]:
transition_probs, rewards = create_transition_probs_rewards_table(gridworld)

In [224]:
value_function = {}
it = 0
for i in range(ROW_SIZE):
    for j in range(COLUMN_SIZE):
      state = (i, j)
      value_function[state] = 0

while True:
  value_function = deterministic_policy_evaluation(gridworld, value_function)
  is_policy_converged = True
  for i in range(ROW_SIZE):
      for j in range(COLUMN_SIZE):
        state = (i, j)
        grid = gridworld.gridworld[state]
        if grid[0] == 0:
          old_action = grid[1]
          new_action = None
          best_value = float('-inf')

          for action in ACTIONS:
              v = 0
              for k in range(ROW_SIZE):
                for l in range(COLUMN_SIZE):
                    next_state = (k, l)
                    reward = rewards.get((state, action, next_state), 0)
                    v += transition_probs.get((state, action, next_state), 0) * (reward + DISCOUNT_FACTOR * value_function[next_state])

              if v > best_value:
                best_value = v
                new_action = action

          grid[1] = new_action
          if new_action != old_action:
            is_policy_converged = False

  if is_policy_converged:
    break

  print(f"Iteration all: {it + 1}")
  it += 1

Iteration: 1, Error: 6.32
Iteration: 2, Error: 4.32
Iteration: 3, Error: 3.888000000000001
Iteration: 4, Error: 3.499200000000002
Iteration: 5, Error: 3.149280000000001
Iteration: 6, Error: 2.5509168000000004
Iteration: 7, Error: 2.295825120000001
Iteration: 8, Error: 1.8596183472000014
Iteration: 9, Error: 1.5062908612319994
Iteration: 10, Error: 1.2200955975979202
Iteration: 11, Error: 0.9882774340543126
Iteration: 12, Error: 0.8005047215839998
Iteration: 13, Error: 0.6484088244830382
Iteration: 14, Error: 0.5252111478312571
Iteration: 15, Error: 0.4254210297433225
Iteration: 16, Error: 0.34459103409209035
Iteration: 17, Error: 0.27911873761459205
Iteration: 18, Error: 0.22608617746782045
Iteration: 19, Error: 0.18312980374892973
Iteration: 20, Error: 0.14833514103663958
Iteration: 21, Error: 0.12015146423967238
Iteration: 22, Error: 0.0973226860341434
Iteration: 23, Error: 0.07883137568764909
Iteration: 24, Error: 0.06385341430699398
Iteration: 25, Error: 0.051721265588668075
Iterat

In [225]:
def print_value_function(value_function):
  values = np.zeros((ROW_SIZE, COLUMN_SIZE))
  for key, value in value_function.items():
    i, j = key
    values[i, j] = value
  for i in range(ROW_SIZE):
      for j in range(COLUMN_SIZE):
        print(f"{values[i, j]:.2f}" + "\t", end="")
      print("\n")

In [226]:
gridworld.print_gridworld(ROW_SIZE, COLUMN_SIZE)

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[xx]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[xx]	[-1]	

[  ]	[xx]	[  ]	[  ]	[  ]	[  ]	[  ]	[-1]	[  ]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[-1]	[  ]	[  ]	[  ]	

[xx]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	

[xx]	[  ]	[xx]	[  ]	[  ]	[  ]	[  ]	[xx]	[+1]	[  ]	

[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[  ]	[xx]	[  ]	[  ]	



In [227]:
gridworld.print_policy(ROW_SIZE, COLUMN_SIZE)

1	2	2	2	2	3	2	2	X	0	

1	2	2	1	1	0	0	0	X	X	

1	X	3	3	1	2	1	X	1	1	

0	2	2	3	1	1	3	3	1	2	

0	2	2	3	1	2	X	1	1	1	

X	0	1	1	1	1	3	1	1	1	

1	3	3	1	1	2	3	3	1	2	

3	3	3	3	0	2	2	3	0	2	

X	0	X	0	0	0	0	X	2	0	

3	3	3	0	0	0	0	X	3	0	



In [228]:
print_value_function(value_function)

29.00	29.10	28.19	25.37	24.83	25.26	24.74	25.26	0.00	0.00	

30.00	29.00	27.10	26.22	26.91	24.74	25.26	24.74	0.00	0.00	

30.00	0.00	26.22	26.91	26.57	26.91	26.16	0.00	27.48	25.73	

30.00	30.00	28.00	26.57	27.30	26.57	25.73	27.48	27.20	27.48	

30.00	29.00	26.10	27.30	27.00	27.30	0.00	30.00	28.00	27.20	

0.00	26.10	28.20	28.00	30.00	29.00	30.00	30.00	30.00	28.00	

26.67	28.20	28.00	30.00	30.00	30.00	30.00	30.00	30.00	30.00	

26.30	27.00	30.00	30.00	30.00	30.00	29.00	30.00	30.00	30.00	

0.00	26.30	0.00	30.00	30.00	29.00	29.10	0.00	0.00	29.00	

25.00	26.67	26.30	27.00	28.00	27.10	27.19	0.00	26.39	27.10	

