In [4]:
ALL_ACTIONS = ("U","D","R","L")
class Grid:
    def __init__(self,dimension,start) -> None:
        rows , cols = dimension
        i , j = start
        self.cols = cols
        self.rows = rows
        self.i = i
        self.j = j
    def set(self,rewards,actions):
        self.rewards = rewards
        self.actions = actions
    def set_state(self,point):
        self.i = point[0]
        self.j = point[1]
    def get_current_state(self):
        return (self.i,self.j)
    def is_end(self,state):
        return state not in self.actions
    def get_next_state(self,state,action):
        (i,j) = state
        if action == 'U':
                i -= 1
        elif action == 'D':
            i += 1
        elif action == 'R':
            j += 1
        elif action == 'L':
            j -= 1
        return (i,j)
    def move(self, action):
    # check if legal move first
        if action in self.actions[(self.i, self.j)]:
            if action == 'U':
                self.i -= 1
            elif action == 'D':
                self.i += 1
            elif action == 'R':
                self.j += 1
            elif action == 'L':
                self.j -= 1
        return self.rewards.get((self.i, self.j), 0)
    def get_all_states(self):
        return set(self.actions.keys()) | set(self.rewards.keys())
    
    

In [5]:
def standard_grid():

    # .  .  .  1
    # .  x  . -1
    # s  .  .  .
    grid = Grid((3, 4), (2, 0))
    rewards = {(0, 3): 1, (1, 3): -1}
    actions = {
        (0, 0): ('D', 'R'),
        (0, 1): ('L', 'R'),
        (0, 2): ('L', 'D', 'R'),
        (1, 0): ('U', 'D'),
        (1, 2): ('U', 'D', 'R'),
        (2, 0): ('U', 'R'),
        (2, 1): ('L', 'R'),
        (2, 2): ('L', 'R', 'U'),
        (2, 3): ('L', 'U'),
        }
    grid.set(rewards,actions)
    return grid

In [6]:
grid  = standard_grid()

In [143]:
grid.actions.keys()

dict_keys([(0, 0), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (2, 2), (2, 3)])

In [189]:
import numpy as np
def play_game(grid, policy, max_steps=20):

    start_states = list(grid.actions.keys())
    start_idx = np.random.choice(len(start_states))
    grid.set_state(start_states[start_idx])

    s = grid.get_current_state()
    a = np.random.choice(ALL_ACTIONS) # first action is uniformly random
    states_actions_rewards = [(s, a, 0)]
    for _ in range(max_steps):
        r = grid.move(a)
        s = grid.get_current_state()

        if grid.is_end(s):
            states_actions_rewards.append((s, None, r))
            break
        else:
            a = policy[s]
            states_actions_rewards.append((s, a, r))


    G = 0
    GAMMA= 0.9
    states_actions_returns = []
    first = True
    for s, a, r in reversed(states_actions_rewards):
        if first:
            first = False
        else:
            states_actions_returns.append((s, a, G))
        G = r + GAMMA*G
    states_actions_returns.reverse() # we want it to be in order of state visited
    return states_actions_returns

In [191]:
policy = {
(2, 0): 'U',
(1, 0): 'U',
(0, 0): 'D',
(0, 1): 'R',
(0, 2): 'L',
(1, 2): 'D',
(2, 1): 'R',
(2, 2): 'R',
(2, 3): 'U',
}
play_game(grid,policy)


[((2, 3), 'D', -0.9), ((2, 3), 'U', -1.0)]

In [192]:
# initialize Q(s,a) and returns
Q = {}
states = grid.get_all_states()
for s in states:
    if s in grid.actions: # not a terminal state
        Q[s] = {}
        for a in ALL_ACTIONS:
            Q[s][a] = 0 # needs to be initialized to something so we can argmax it
    else:
          # terminal state or state we can't otherwise get to
        pass
    
Q

{(0, 1): {'U': 0, 'D': 0, 'R': 0, 'L': 0},
 (1, 2): {'U': 0, 'D': 0, 'R': 0, 'L': 0},
 (2, 1): {'U': 0, 'D': 0, 'R': 0, 'L': 0},
 (0, 0): {'U': 0, 'D': 0, 'R': 0, 'L': 0},
 (2, 0): {'U': 0, 'D': 0, 'R': 0, 'L': 0},
 (2, 3): {'U': 0, 'D': 0, 'R': 0, 'L': 0},
 (0, 2): {'U': 0, 'D': 0, 'R': 0, 'L': 0},
 (2, 2): {'U': 0, 'D': 0, 'R': 0, 'L': 0},
 (1, 0): {'U': 0, 'D': 0, 'R': 0, 'L': 0}}

In [51]:
def max_dict(d):

    max_key = None
    max_val = float('-inf')
    for k, v in d.items():
        if v > max_val:
            max_val = v
            max_key = k
    return max_key, max_val

In [193]:
learning_rate = 0.1
states = grid.get_all_states()
policy = {
(2, 0): 'U',
(1, 0): 'D',
(0, 0): 'R',
(0, 1): 'R',
(0, 2): 'L',
(1, 2): 'D',
(2, 1): 'R',
(2, 2): 'R',
(2, 3): 'U',
}
for t in range(5000):
    states_action_returns = play_game(grid,policy)
    print(len(states_action_returns),states_action_returns)
    seen_states_action = set()
    for s, a, G in states_action_returns:    
            sa = (s,a)
            if sa not in seen_states_action:
                old_q = Q[s][a]
                Q[s][a] = old_q + learning_rate * (G - old_q)
                seen_states_action.add(sa)
    for s in policy.keys():
        valid_actions = grid.actions[s]
        max_value = 0 
        for action in list(valid_actions):
            if Q[s][action] > max_value:
                max_value = Q[s][action]
                policy[s] = action
            

20 [((2, 0), 'L', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0)]
20 [((2, 0), 'L', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0)]
20 [((1, 0), 'L', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0), 'D', 0.0), ((2, 0), 'U', 0.0), ((1, 0),

 [((2, 1), 'R', 0.7290000000000001), ((2, 2), 'U', 0.81), ((1, 2), 'U', 0.9), ((0, 2), 'R', 1.0)]
6 [((2, 0), 'D', 0.5904900000000002), ((2, 0), 'U', 0.6561000000000001), ((1, 0), 'U', 0.7290000000000001), ((0, 0), 'R', 0.81), ((0, 1), 'R', 0.9), ((0, 2), 'R', 1.0)]
3 [((2, 2), 'U', 0.81), ((1, 2), 'U', 0.9), ((0, 2), 'R', 1.0)]
4 [((0, 1), 'L', 0.7290000000000001), ((0, 0), 'R', 0.81), ((0, 1), 'R', 0.9), ((0, 2), 'R', 1.0)]
4 [((2, 3), 'L', 0.7290000000000001), ((2, 2), 'U', 0.81), ((1, 2), 'U', 0.9), ((0, 2), 'R', 1.0)]
4 [((0, 0), 'U', 0.7290000000000001), ((0, 0), 'R', 0.81), ((0, 1), 'R', 0.9), ((0, 2), 'R', 1.0)]
6 [((2, 1), 'L', 0.5904900000000002), ((2, 0), 'U', 0.6561000000000001), ((1, 0), 'U', 0.7290000000000001), ((0, 0), 'R', 0.81), ((0, 1), 'R', 0.9), ((0, 2), 'R', 1.0)]
3 [((0, 2), 'D', 0.81), ((1, 2), 'U', 0.9), ((0, 2), 'R', 1.0)]
3 [((0, 2), 'L', 0.81), ((0, 1), 'R', 0.9), ((0, 2), 'R', 1.0)]
3 [((0, 1), 'U', 0.81), ((0, 1), 'R', 0.9), ((0, 2), 'R', 1.0)]
3 [((0, 2),

In [55]:
def print_values(V, g):
    for i in range(g.rows):
        print("---------------------------")
        for j in range(g.cols):
            v = V.get((i, j), 0)
            if v >= 0:
                print(" %.2f|" % v, end="")
            else:
                print("%.2f|" % v, end="")  
        print("")


def print_policy(P, g):
    for i in range(g.rows):
        print("---------------------------")
        for j in range(g.cols):
            a = P.get((i, j), " ")
            print("  %s  |" % a, end="")
        print("")



In [194]:
print_policy(policy,grid)

---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  R  |  U  |  L  |


In [195]:
Q

{(0, 1): {'U': 0.8099991791689896,
  'D': 0.8099978812916876,
  'R': 0.8999999999999995,
  'L': 0.7289999570420034},
 (1, 2): {'U': 0.8999999999999995,
  'D': 0.7289998620737583,
  'R': -0.9999999570420033,
  'L': 0.8099986099154762},
 (2, 1): {'U': 0.6560997116304921,
  'D': 0.6560998891179021,
  'R': 0.7289999999999996,
  'L': 0.5904895153074968},
 (0, 0): {'U': 0.7289999570420034,
  'D': 0.6560996819932486,
  'R': 0.8099999999999996,
  'L': 0.728999563776747},
 (2, 0): {'U': 0.6560999999999997,
  'D': 0.5904891791689897,
  'R': 0.6560996148584901,
  'L': 0.5904891791689897},
 (2, 3): {'U': -0.9999999469654363,
  'D': 0.6560993210906714,
  'R': 0.656099597717926,
  'L': 0.7289999999999996},
 (0, 2): {'U': 0.8999988740315358,
  'D': 0.8099987489239286,
  'R': 0.9999999999999996,
  'L': 0.8099998631085209},
 (2, 2): {'U': 0.8099999999999996,
  'D': 0.7289996983674162,
  'R': 0.6560972017720217,
  'L': 0.6560982708738611},
 (1, 0): {'U': 0.7289999999999996,
  'D': 0.5904899522688928,
  