In [4]:
from grid_world import negative_grid, standard_grid
from iterative_policy_evaluation import print_values, print_policy 
import numpy as np 
import operator 

In [17]:
# --- hyperparameters ---------------------------------------------------------
gamma = 0.9    # discount factor applied to Q of the next (state, action) in the TD target
alpha = 0.1    # base learning rate; SARSA divides it by a per-(state, action) counter
possible_actions = ('U', 'D', 'L', 'R')   # action labels; presumably up/down/left/right

# Exploration draws from the global NumPy RNG (np.random in random_action), so
# seed it here to make Restart & Run All reproduce the same Q-table and policy.
SEED = 42
np.random.seed(SEED)

def max_dict(dictionary):
    '''
    Return ``(best_key, best_value)`` for the entry with the largest value.

    When several keys share the maximal value, the one that comes first in the
    dict's iteration order wins. An empty dict raises ValueError (from max).
    '''
    best_key = max(dictionary, key=dictionary.__getitem__)
    return best_key, dictionary[best_key]

def random_action(a, epsilon = 0.1):
    '''
    Epsilon-greedy action selection.

    With probability ``1 - epsilon`` the proposed action ``a`` is kept as is.
    Otherwise an action is drawn uniformly from ``possible_actions``, and that
    draw may happen to return ``a`` again.
    '''
    keep_proposed = np.random.random() < (1 - epsilon)
    if keep_proposed:
        return a
    return np.random.choice(possible_actions)

In [23]:
# initialize grid
# negative_grid(step_cost=-0.1): presumably every non-terminal move yields a
# reward of -0.1 -- TODO confirm against grid_world.negative_grid.
grid = negative_grid(step_cost=-0.1)

# initialize Q: nested dict Q[state][action], set to 0 for every action of
# every state returned by grid.all_states()
Q = {}
for state in grid.all_states(): 
    Q[state] = {}
    for action in possible_actions: 
        Q[state][action] = 0
        
# let's also keep track of how many times Q[s] has been updated
# (starts empty; the SARSA cell adds/increments a key per updated state)
update_counts = {}
# per-(state, action) counters starting at 1.0; the SARSA cell divides alpha
# by these, so they act as a growing learning-rate divisor
update_counts_sa = {}

# NOTE(review): the loop variables (state, action, s, a) stay defined in the
# notebook namespace after this cell; later cells should not read them.
for s in grid.all_states():
    update_counts_sa[s] = {}
    for a in possible_actions:
        update_counts_sa[s][a] = 1.0

In [33]:
# the SARSA algorithm
#
# On-policy TD control: after each move, nudge Q[s][a] toward the one-step
# target  reward + gamma * Q[s'][a'],  where a' is the action that will actually
# be taken next (chosen epsilon-greedily from the same Q).
#
# NOTE: Q and the visit counters are not reset in this cell, so re-running it
# continues training from the current tables. Re-run the initialization cell
# first for a fresh start.

NUM_EPISODES = 10000
START_STATE = (2, 0)         # every episode starts here
EPSILON_DECAY_EVERY = 100    # episodes between increments of t
EPSILON_DECAY_STEP = 0.02    # amount added to t at each increment
EPSILON_SCALE = 0.5          # exploration rate is EPSILON_SCALE / t
ALPHA_COUNT_STEP = 0.005     # growth of the per-(s, a) learning-rate divisor per update

t = 1

for ep in range(NUM_EPISODES):
    # Exploration shrinks slowly as t grows; t is constant within an episode.
    if ep % EPSILON_DECAY_EVERY == 0:
        t += EPSILON_DECAY_STEP
    epsilon = EPSILON_SCALE / t

    # initialize starting state
    state = START_STATE
    grid.set_state(state)

    # First action: epsilon-greedy w.r.t. Q of the actual start state.
    action, _ = max_dict(Q[state])
    action = random_action(action, epsilon=epsilon)

    while not grid.game_over():
        # make action
        reward = grid.move(action)

        # observe s' and choose a' with the same epsilon-greedy policy (on-policy)
        state_prime = grid.current_state()
        action_prime, _ = max_dict(Q[state_prime])
        action_prime = random_action(action_prime, epsilon=epsilon)

        # The step size decays per (s, a) as that pair is updated more often.
        # This loop only ever updates the pre-move state, so Q of a terminal s'
        # keeps its initial value.
        step_size = alpha / update_counts_sa[state][action]
        td_target = reward + gamma * Q[state_prime][action_prime]
        Q[state][action] += step_size * (td_target - Q[state][action])

        update_counts[state] = update_counts.get(state, 0) + 1
        update_counts_sa[state][action] += ALPHA_COUNT_STEP

        state, action = state_prime, action_prime


# Greedy read-out. For every state listed in grid.actions, the policy is the
# argmax action of Q and V is the matching max Q value.
policy = {}
V = {}
for state in grid.actions.keys():
    policy[state], V[state] = max_dict(Q[state])


print("values:")
print_values(V, grid)
print("policy:")
print_policy(policy, grid)

values:
---------------------------
 0.53| 0.75| 1.00| 0.00|
---------------------------
 0.34| 0.00| 0.75| 0.00|
---------------------------
 0.17| 0.22| 0.46| 0.22|
policy:
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  R  |  U  |  L  |
